nomad_protocol/transport/
socket.rs

1//! Async UDP socket wrapper for NOMAD transport.
2//!
3//! Provides a high-level interface for sending and receiving NOMAD frames
4//! over UDP.
5
6use std::io;
7use std::net::SocketAddr;
8use std::sync::Arc;
9
10use tokio::net::UdpSocket;
11
12use super::frame::sizes;
13
/// Default receive buffer size, in bytes.
///
/// Equal to `u16::MAX` (65535). The UDP header's length field is 16 bits wide,
/// so (IPv6 jumbograms aside) no datagram can be larger than this buffer.
pub const DEFAULT_RECV_BUFFER_SIZE: usize = u16::MAX as usize;
16
17/// Async UDP socket wrapper for NOMAD.
18///
19/// Provides convenient methods for sending/receiving frames with
20/// proper buffer management.
21#[derive(Debug)]
22pub struct NomadSocket {
23    /// The underlying UDP socket.
24    socket: Arc<UdpSocket>,
25    /// Receive buffer.
26    recv_buffer: Vec<u8>,
27    /// Maximum payload size (for MTU considerations).
28    max_payload_size: usize,
29}
30
31impl NomadSocket {
32    /// Create a new NOMAD socket bound to the given address.
33    pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
34        let socket = UdpSocket::bind(addr).await?;
35        Ok(Self::from_socket(socket))
36    }
37
38    /// Create a NOMAD socket from an existing UDP socket.
39    pub fn from_socket(socket: UdpSocket) -> Self {
40        Self {
41            socket: Arc::new(socket),
42            recv_buffer: vec![0u8; DEFAULT_RECV_BUFFER_SIZE],
43            max_payload_size: sizes::DEFAULT_MAX_PAYLOAD,
44        }
45    }
46
47    /// Set the maximum payload size (for MTU considerations).
48    pub fn set_max_payload_size(&mut self, size: usize) {
49        self.max_payload_size = size;
50    }
51
52    /// Get the maximum payload size.
53    pub fn max_payload_size(&self) -> usize {
54        self.max_payload_size
55    }
56
57    /// Get the local address.
58    pub fn local_addr(&self) -> io::Result<SocketAddr> {
59        self.socket.local_addr()
60    }
61
62    /// Connect to a remote address (for client sockets).
63    ///
64    /// After connecting, `send` and `recv` can be used instead of
65    /// `send_to` and `recv_from`.
66    pub async fn connect(&self, addr: SocketAddr) -> io::Result<()> {
67        self.socket.connect(addr).await
68    }
69
70    /// Send data to a specific address.
71    pub async fn send_to(&self, data: &[u8], addr: SocketAddr) -> io::Result<usize> {
72        self.socket.send_to(data, addr).await
73    }
74
75    /// Send data to the connected address.
76    pub async fn send(&self, data: &[u8]) -> io::Result<usize> {
77        self.socket.send(data).await
78    }
79
80    /// Receive data and return the sender's address.
81    pub async fn recv_from(&mut self) -> io::Result<(&[u8], SocketAddr)> {
82        let (len, addr) = self.socket.recv_from(&mut self.recv_buffer).await?;
83        Ok((&self.recv_buffer[..len], addr))
84    }
85
86    /// Receive data from the connected address.
87    pub async fn recv(&mut self) -> io::Result<&[u8]> {
88        let len = self.socket.recv(&mut self.recv_buffer).await?;
89        Ok(&self.recv_buffer[..len])
90    }
91
92    /// Try to receive data without blocking.
93    ///
94    /// Returns `Ok(None)` if no data is available.
95    pub fn try_recv_from(&mut self) -> io::Result<Option<(usize, SocketAddr)>> {
96        match self.socket.try_recv_from(&mut self.recv_buffer) {
97            Ok((len, addr)) => Ok(Some((len, addr))),
98            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None),
99            Err(e) => Err(e),
100        }
101    }
102
103    /// Get the received data after a successful `try_recv_from`.
104    pub fn recv_data(&self, len: usize) -> &[u8] {
105        &self.recv_buffer[..len]
106    }
107
108    /// Get a reference to the underlying socket.
109    pub fn inner(&self) -> &UdpSocket {
110        &self.socket
111    }
112
113    /// Get a clone of the Arc-wrapped socket.
114    pub fn socket_arc(&self) -> Arc<UdpSocket> {
115        Arc::clone(&self.socket)
116    }
117
118    /// Calculate maximum frame size considering headers.
119    pub fn max_frame_size(&self) -> usize {
120        self.max_payload_size + sizes::DATA_FRAME_HEADER_SIZE + sizes::AEAD_TAG_SIZE
121    }
122}
123
124/// Builder for creating NOMAD sockets with custom options.
125#[derive(Debug, Clone)]
126pub struct NomadSocketBuilder {
127    recv_buffer_size: usize,
128    max_payload_size: usize,
129}
130
131impl Default for NomadSocketBuilder {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137impl NomadSocketBuilder {
138    /// Create a new socket builder with default options.
139    pub fn new() -> Self {
140        Self {
141            recv_buffer_size: DEFAULT_RECV_BUFFER_SIZE,
142            max_payload_size: sizes::DEFAULT_MAX_PAYLOAD,
143        }
144    }
145
146    /// Set the receive buffer size.
147    pub fn recv_buffer_size(mut self, size: usize) -> Self {
148        self.recv_buffer_size = size;
149        self
150    }
151
152    /// Set the maximum payload size.
153    pub fn max_payload_size(mut self, size: usize) -> Self {
154        self.max_payload_size = size;
155        self
156    }
157
158    /// Bind to the given address and create a socket.
159    pub async fn bind(self, addr: SocketAddr) -> io::Result<NomadSocket> {
160        let socket = UdpSocket::bind(addr).await?;
161        Ok(self.from_socket(socket))
162    }
163
164    /// Create a socket from an existing UDP socket.
165    pub fn from_socket(self, socket: UdpSocket) -> NomadSocket {
166        NomadSocket {
167            socket: Arc::new(socket),
168            recv_buffer: vec![0u8; self.recv_buffer_size],
169            max_payload_size: self.max_payload_size,
170        }
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Bind a fresh NOMAD socket to an ephemeral loopback port.
    async fn loopback() -> NomadSocket {
        NomadSocket::bind("127.0.0.1:0".parse().unwrap())
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_socket_bind() {
        let socket = loopback().await;
        let addr = socket.local_addr().unwrap();
        assert_ne!(addr.port(), 0);
    }

    #[tokio::test]
    async fn test_default_options() {
        let socket = loopback().await;
        assert_eq!(socket.max_payload_size(), sizes::DEFAULT_MAX_PAYLOAD);
        assert_eq!(socket.recv_buffer.len(), DEFAULT_RECV_BUFFER_SIZE);
    }

    #[tokio::test]
    async fn test_set_max_payload_size() {
        let mut socket = loopback().await;
        socket.set_max_payload_size(512);
        assert_eq!(socket.max_payload_size(), 512);
    }

    #[tokio::test]
    async fn test_socket_send_recv() {
        let mut server = loopback().await;
        let server_addr = server.local_addr().unwrap();
        let client = loopback().await;

        // Send from client
        let data = b"hello NOMAD";
        client.send_to(data, server_addr).await.unwrap();

        // Receive on server
        let (received, from) = server.recv_from().await.unwrap();
        assert_eq!(received, data);
        assert_eq!(from, client.local_addr().unwrap());
    }

    #[tokio::test]
    async fn test_socket_connected() {
        let mut server = loopback().await;
        let server_addr = server.local_addr().unwrap();
        let client = loopback().await;
        client.connect(server_addr).await.unwrap();

        // Send using connected interface
        let data = b"connected send";
        client.send(data).await.unwrap();

        // Receive on server
        let (received, _) = server.recv_from().await.unwrap();
        assert_eq!(received, data);
    }

    #[tokio::test]
    async fn test_socket_connected_recv() {
        let mut server = loopback().await;
        let client = loopback().await;
        let server_addr = server.local_addr().unwrap();
        let client_addr = client.local_addr().unwrap();
        client.connect(server_addr).await.unwrap();
        server.connect(client_addr).await.unwrap();

        client.send(b"ping").await.unwrap();
        let received = server.recv().await.unwrap();
        assert_eq!(received, b"ping");
    }

    #[tokio::test]
    async fn test_try_recv_from_empty() {
        let mut socket = loopback().await;
        assert!(socket.try_recv_from().unwrap().is_none());
    }

    #[tokio::test]
    async fn test_try_recv_from_with_data() {
        let mut server = loopback().await;
        let server_addr = server.local_addr().unwrap();
        let client = loopback().await;

        let data = b"nonblocking";
        client.send_to(data, server_addr).await.unwrap();

        // Readiness can be reported before the datagram is actually readable,
        // so retry until try_recv_from yields it.
        let (len, from) = loop {
            server.inner().readable().await.unwrap();
            if let Some(result) = server.try_recv_from().unwrap() {
                break result;
            }
        };
        assert_eq!(server.recv_data(len), data);
        assert_eq!(from, client.local_addr().unwrap());
    }

    #[test]
    fn test_socket_builder() {
        let builder = NomadSocketBuilder::new()
            .recv_buffer_size(4096)
            .max_payload_size(1400);

        assert_eq!(builder.recv_buffer_size, 4096);
        assert_eq!(builder.max_payload_size, 1400);
    }

    #[test]
    fn test_builder_default_matches_new() {
        let builder = NomadSocketBuilder::default();
        assert_eq!(builder.recv_buffer_size, DEFAULT_RECV_BUFFER_SIZE);
        assert_eq!(builder.max_payload_size, sizes::DEFAULT_MAX_PAYLOAD);
    }

    #[tokio::test]
    async fn test_builder_applies_options() {
        let socket = NomadSocketBuilder::new()
            .recv_buffer_size(4096)
            .max_payload_size(1400)
            .bind("127.0.0.1:0".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(socket.recv_buffer.len(), 4096);
        assert_eq!(socket.max_payload_size(), 1400);
    }

    #[tokio::test]
    async fn test_max_frame_size() {
        let socket = NomadSocketBuilder::new()
            .max_payload_size(1200)
            .bind("127.0.0.1:0".parse().unwrap())
            .await
            .unwrap();

        // max_frame_size = payload + header + tag
        let expected = 1200 + sizes::DATA_FRAME_HEADER_SIZE + sizes::AEAD_TAG_SIZE;
        assert_eq!(socket.max_frame_size(), expected);
    }
}