use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use elara_core::{ElaraError, ElaraResult};
use elara_wire::{Frame, MAX_FRAME_SIZE};
/// UDP transport built on a Tokio [`UdpSocket`].
///
/// The socket is stored behind an [`Arc`] so that additional handles to the same socket can be
/// handed out (see [`UdpTransport::socket`]) without giving up ownership of the transport.
#[derive(Debug)]
pub struct UdpTransport {
    /// The bound socket, reference-counted so it can be shared with other tasks.
    socket: Arc<UdpSocket>,
    /// Address the socket reported as its local address when the transport was created.
    local_addr: SocketAddr,
}
impl UdpTransport {
pub async fn bind(addr: SocketAddr) -> ElaraResult<Self> {
tracing::info!(
bind_addr = %addr,
"Attempting to bind UDP transport"
);
let socket = UdpSocket::bind(addr)
.await
.map_err(|e| {
tracing::error!(
bind_addr = %addr,
error = %e,
"Failed to bind UDP socket"
);
ElaraError::TransportError(e.to_string())
})?;
let local_addr = socket
.local_addr()
.map_err(|e| {
tracing::error!(
error = %e,
"Failed to get local address"
);
ElaraError::TransportError(e.to_string())
})?;
tracing::info!(
local_addr = %local_addr,
"UDP transport bound successfully"
);
Ok(UdpTransport {
socket: Arc::new(socket),
local_addr,
})
}
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
pub async fn send_to(&self, frame: &Frame, dest: SocketAddr) -> ElaraResult<()> {
let bytes = frame.serialize()?;
let size = bytes.len();
tracing::debug!(
dest = %dest,
size = size,
"Sending frame"
);
self.socket
.send_to(&bytes, dest)
.await
.map_err(|e| {
tracing::warn!(
dest = %dest,
size = size,
error = %e,
"Failed to send frame"
);
ElaraError::TransportError(e.to_string())
})?;
Ok(())
}
pub async fn send_bytes_to(&self, bytes: &[u8], dest: SocketAddr) -> ElaraResult<()> {
self.socket
.send_to(bytes, dest)
.await
.map_err(|e| ElaraError::TransportError(e.to_string()))?;
Ok(())
}
pub async fn recv_from(&self) -> ElaraResult<(Frame, SocketAddr)> {
let mut buf = vec![0u8; MAX_FRAME_SIZE];
let (len, addr) = self
.socket
.recv_from(&mut buf)
.await
.map_err(|e| {
tracing::warn!(
error = %e,
"Failed to receive from UDP socket"
);
ElaraError::TransportError(e.to_string())
})?;
tracing::debug!(
source = %addr,
size = len,
"Received frame"
);
let frame = Frame::parse(&buf[..len])?;
Ok((frame, addr))
}
pub async fn recv_bytes_from(&self) -> ElaraResult<(Vec<u8>, SocketAddr)> {
let mut buf = vec![0u8; MAX_FRAME_SIZE];
let (len, addr) = self
.socket
.recv_from(&mut buf)
.await
.map_err(|e| ElaraError::TransportError(e.to_string()))?;
Ok((buf[..len].to_vec(), addr))
}
pub fn socket(&self) -> Arc<UdpSocket> {
Arc::clone(&self.socket)
}
}
/// Receiving half of a bounded Tokio channel carrying `(payload bytes, sender address)` pairs.
pub type PacketReceiver = mpsc::Receiver<(Vec<u8>, SocketAddr)>;
/// Sending half of a bounded Tokio channel carrying `(payload bytes, sender address)` pairs.
pub type PacketSender = mpsc::Sender<(Vec<u8>, SocketAddr)>;
/// Spawns a background task that forwards every datagram received on `socket` into a channel.
///
/// Each datagram is copied into its own `Vec<u8>` and sent, together with the sender's address,
/// to the returned [`PacketReceiver`]. Receive errors are logged at `warn` level and the loop
/// keeps going.
///
/// The task stops as soon as the returned receiver is dropped (or closed), even if no further
/// datagram ever arrives, so it does not keep the socket alive for a consumer that is gone.
///
/// `buffer_size` is the channel capacity. A value of `0` is treated as `1`, because
/// `mpsc::channel` panics when asked for a zero-capacity channel.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime, since it uses `tokio::spawn`.
pub fn start_receive_loop(socket: Arc<UdpSocket>, buffer_size: usize) -> PacketReceiver {
    let (tx, rx) = mpsc::channel(buffer_size.max(1));

    tokio::spawn(async move {
        // One reusable receive buffer; each packet is copied out before the next read.
        let mut buf = vec![0u8; MAX_FRAME_SIZE];

        loop {
            // Race the socket read against receiver shutdown. Both futures are cancel-safe,
            // so dropping the losing branch never loses a datagram.
            let received = tokio::select! {
                _ = tx.closed() => break,
                result = socket.recv_from(&mut buf) => result,
            };

            match received {
                Ok((len, addr)) => {
                    let packet = buf[..len].to_vec();
                    // A send error means the receiver was dropped while we were waiting.
                    if tx.send((packet, addr)).await.is_err() {
                        break;
                    }
                }
                Err(e) => {
                    tracing::warn!("UDP receive error: {}", e);
                }
            }
        }
    });

    rx
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Binding to port 0 on loopback must report a concrete, non-zero local port.
    #[tokio::test]
    async fn test_udp_transport_bind() {
        let any_port = "127.0.0.1:0".parse().unwrap();
        let transport = UdpTransport::bind(any_port).await.unwrap();

        let assigned_port = transport.local_addr().port();
        assert_ne!(assigned_port, 0);
    }
}