use std::net::SocketAddr;
use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use super::{Transport, TransportKind};
use crate::error::{FrameError, SrxError, TransportError};
/// Largest UDP payload that fits in one IPv4 datagram: the 65 535-byte IP packet limit minus
/// the 20-byte IPv4 header and the 8-byte UDP header (= 65 507).
const MAX_DATAGRAM: usize = 65_535 - 20 - 8;
pub struct UdpTransport {
socket: Arc<Mutex<Option<UdpSocket>>>,
}
impl UdpTransport {
pub async fn connect(addr: impl tokio::net::ToSocketAddrs) -> crate::error::Result<Self> {
let socket = UdpSocket::bind("0.0.0.0:0")
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
socket
.connect(addr)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Ok(Self {
socket: Arc::new(Mutex::new(Some(socket))),
})
}
pub async fn bind(addr: impl tokio::net::ToSocketAddrs) -> crate::error::Result<Self> {
let socket = UdpSocket::bind(addr)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Ok(Self {
socket: Arc::new(Mutex::new(Some(socket))),
})
}
pub async fn connect_peer(&self, addr: SocketAddr) -> crate::error::Result<()> {
let mut g = self.socket.lock().await;
let s = g
.as_mut()
.ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
s.connect(addr)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Ok(())
}
#[must_use]
pub fn from_socket(socket: UdpSocket) -> Self {
Self {
socket: Arc::new(Mutex::new(Some(socket))),
}
}
pub async fn send_framed(&self, payload: &[u8]) -> crate::error::Result<()> {
let len = u32::try_from(payload.len()).map_err(|_| {
SrxError::Frame(FrameError::FrameTooLarge {
size: payload.len(),
max: MAX_DATAGRAM,
})
})?;
let total = 4 + payload.len();
if total > MAX_DATAGRAM {
return Err(SrxError::Frame(FrameError::FrameTooLarge {
size: total,
max: MAX_DATAGRAM,
}));
}
let mut framed = Vec::with_capacity(total);
framed.extend_from_slice(&len.to_be_bytes());
framed.extend_from_slice(payload);
self.send(Bytes::from(framed)).await
}
pub async fn recv_framed(&self) -> crate::error::Result<Bytes> {
let data = self.recv().await?;
if data.len() < 4 {
return Err(SrxError::Frame(FrameError::Corrupted(
"datagram too short for length prefix".into(),
)));
}
let len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
let payload_len = len as usize;
if data.len() != 4 + payload_len {
return Err(SrxError::Frame(FrameError::Corrupted(
"datagram size does not match length prefix".into(),
)));
}
Ok(data.slice(4..))
}
}
#[async_trait]
impl Transport for UdpTransport {
fn kind(&self) -> TransportKind {
TransportKind::Udp
}
async fn send(&self, data: Bytes) -> crate::error::Result<()> {
let mut g = self.socket.lock().await;
let s = g
.as_mut()
.ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
s.send(&data)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
Ok(())
}
async fn recv(&self) -> crate::error::Result<Bytes> {
let mut g = self.socket.lock().await;
let s = g
.as_mut()
.ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
let mut buf = vec![0u8; MAX_DATAGRAM];
let n = s
.recv(&mut buf)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
buf.truncate(n);
Ok(Bytes::from(buf))
}
async fn is_healthy(&self) -> bool {
self.socket.lock().await.is_some()
}
async fn close(&self) -> crate::error::Result<()> {
*self.socket.lock().await = None;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Binds two loopback sockets, connects each one to the other, and wraps both as transports.
    async fn linked_pair() -> (UdpTransport, UdpTransport) {
        let left = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let right = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let left_addr = left.local_addr().unwrap();
        let right_addr = right.local_addr().unwrap();
        left.connect(right_addr).await.unwrap();
        right.connect(left_addr).await.unwrap();
        (
            UdpTransport::from_socket(left),
            UdpTransport::from_socket(right),
        )
    }

    /// Raw datagrams flow in both directions over a connected pair.
    #[tokio::test]
    async fn connected_udp_roundtrip() {
        let (local, remote) = linked_pair().await;
        let responder = tokio::spawn(async move {
            remote.send(Bytes::from_static(b"ping")).await.unwrap();
            let answer = remote.recv().await.unwrap();
            assert_eq!(answer.as_ref(), b"pong");
        });
        let greeting = local.recv().await.unwrap();
        assert_eq!(greeting.as_ref(), b"ping");
        local.send(Bytes::from_static(b"pong")).await.unwrap();
        responder.await.unwrap();
    }

    /// Length-prefixed frames survive a round trip in both directions.
    #[tokio::test]
    async fn framed_roundtrip_udp() {
        let (local, remote) = linked_pair().await;
        let body = b"framed-payload-srx";
        let responder = tokio::spawn(async move {
            let received = remote.recv_framed().await.unwrap();
            assert_eq!(received.as_ref(), body);
            remote.send_framed(b"ack").await.unwrap();
        });
        local.send_framed(body).await.unwrap();
        let ack = local.recv_framed().await.unwrap();
        assert_eq!(ack.as_ref(), b"ack");
        responder.await.unwrap();
    }
}