use crate::{metrics::METRICS, node_info::NodeAddress, packet::*, Executor};
use std::{net::SocketAddr, sync::Arc};
use tokio::{
net::UdpSocket,
sync::{mpsc, oneshot},
};
use tracing::{debug, error, trace, warn};
pub struct OutboundPacket {
pub node_address: NodeAddress,
pub packet: Packet,
}
pub(crate) struct SendHandler {
send_ipv4: Option<Arc<UdpSocket>>,
send_ipv6: Option<Arc<UdpSocket>>,
handler_recv: mpsc::Receiver<OutboundPacket>,
exit: oneshot::Receiver<()>,
}
enum Error {
Io(std::io::Error),
SocketMismatch,
}
impl SendHandler {
pub(crate) fn spawn(
executor: Box<dyn Executor>,
send_ipv4: Option<Arc<UdpSocket>>,
send_ipv6: Option<Arc<UdpSocket>>,
) -> (mpsc::Sender<OutboundPacket>, oneshot::Sender<()>) {
let (exit_send, exit) = oneshot::channel();
let (handler_send, handler_recv) = mpsc::channel(30);
let mut send_handler = SendHandler {
send_ipv4,
send_ipv6,
handler_recv,
exit,
};
executor.spawn(Box::pin(async move {
debug!("Send handler starting");
send_handler.start().await;
}));
(handler_send, exit_send)
}
async fn start(&mut self) {
loop {
tokio::select! {
Some(packet) = self.handler_recv.recv() => {
let encoded_packet = packet.packet.encode(&packet.node_address.node_id);
if encoded_packet.len() > MAX_PACKET_SIZE {
warn!(
size = encoded_packet.len(),
max = MAX_PACKET_SIZE,
"Sending packet larger than max size"
);
}
let addr = &packet.node_address.socket_addr;
if let Err(e) = self.send(&encoded_packet, addr).await {
match e {
Error::Io(e) => {
trace!(%addr, error = %e, "Could not send packet.");
},
Error::SocketMismatch => {
error!(%addr, "Socket mismatch attempting to send a packet.")
}
}
} else {
METRICS.add_sent_bytes(encoded_packet.len());
}
}
_ = &mut self.exit => {
debug!("Send handler shutdown");
return;
}
}
}
}
async fn send(&self, encoded_packet: &[u8], socket_addr: &SocketAddr) -> Result<usize, Error> {
let socket = match socket_addr {
SocketAddr::V4(_) => {
if let Some(socket) = self.send_ipv4.as_ref() {
socket
} else {
return Err(Error::SocketMismatch);
}
}
SocketAddr::V6(_) => {
if let Some(socket) = self.send_ipv6.as_ref() {
socket
} else {
return Err(Error::SocketMismatch);
}
}
};
socket
.send_to(encoded_packet, socket_addr)
.await
.map_err(Error::Io)
}
}