nomad_protocol/transport/
socket.rs1use std::io;
7use std::net::SocketAddr;
8use std::sync::Arc;
9
10use tokio::net::UdpSocket;
11
12use super::frame::sizes;
13
/// Default length, in bytes, of a socket's receive buffer (65535, i.e. `u16::MAX`).
///
/// The UDP length field is 16 bits wide, so a buffer of this size can hold any ordinary
/// (non-jumbogram) datagram payload. Used by `NomadSocket::from_socket` and as the
/// `NomadSocketBuilder` default.
pub const DEFAULT_RECV_BUFFER_SIZE: usize = u16::MAX as usize;
16
17#[derive(Debug)]
22pub struct NomadSocket {
23 socket: Arc<UdpSocket>,
25 recv_buffer: Vec<u8>,
27 max_payload_size: usize,
29}
30
31impl NomadSocket {
32 pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
34 let socket = UdpSocket::bind(addr).await?;
35 Ok(Self::from_socket(socket))
36 }
37
38 pub fn from_socket(socket: UdpSocket) -> Self {
40 Self {
41 socket: Arc::new(socket),
42 recv_buffer: vec![0u8; DEFAULT_RECV_BUFFER_SIZE],
43 max_payload_size: sizes::DEFAULT_MAX_PAYLOAD,
44 }
45 }
46
47 pub fn set_max_payload_size(&mut self, size: usize) {
49 self.max_payload_size = size;
50 }
51
52 pub fn max_payload_size(&self) -> usize {
54 self.max_payload_size
55 }
56
57 pub fn local_addr(&self) -> io::Result<SocketAddr> {
59 self.socket.local_addr()
60 }
61
62 pub async fn connect(&self, addr: SocketAddr) -> io::Result<()> {
67 self.socket.connect(addr).await
68 }
69
70 pub async fn send_to(&self, data: &[u8], addr: SocketAddr) -> io::Result<usize> {
72 self.socket.send_to(data, addr).await
73 }
74
75 pub async fn send(&self, data: &[u8]) -> io::Result<usize> {
77 self.socket.send(data).await
78 }
79
80 pub async fn recv_from(&mut self) -> io::Result<(&[u8], SocketAddr)> {
82 let (len, addr) = self.socket.recv_from(&mut self.recv_buffer).await?;
83 Ok((&self.recv_buffer[..len], addr))
84 }
85
86 pub async fn recv(&mut self) -> io::Result<&[u8]> {
88 let len = self.socket.recv(&mut self.recv_buffer).await?;
89 Ok(&self.recv_buffer[..len])
90 }
91
92 pub fn try_recv_from(&mut self) -> io::Result<Option<(usize, SocketAddr)>> {
96 match self.socket.try_recv_from(&mut self.recv_buffer) {
97 Ok((len, addr)) => Ok(Some((len, addr))),
98 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None),
99 Err(e) => Err(e),
100 }
101 }
102
103 pub fn recv_data(&self, len: usize) -> &[u8] {
105 &self.recv_buffer[..len]
106 }
107
108 pub fn inner(&self) -> &UdpSocket {
110 &self.socket
111 }
112
113 pub fn socket_arc(&self) -> Arc<UdpSocket> {
115 Arc::clone(&self.socket)
116 }
117
118 pub fn max_frame_size(&self) -> usize {
120 self.max_payload_size + sizes::DATA_FRAME_HEADER_SIZE + sizes::AEAD_TAG_SIZE
121 }
122}
123
/// Configuration for creating a [`NomadSocket`] with non-default sizes.
///
/// Start from `NomadSocketBuilder::new()` (or `Default`), adjust it with the chained
/// setters, then finish with `bind` or `from_socket`.
#[derive(Clone, Debug)]
pub struct NomadSocketBuilder {
    /// Length in bytes of the receive buffer allocated for the socket.
    recv_buffer_size: usize,
    /// Maximum payload size the built socket will report.
    max_payload_size: usize,
}
130
131impl Default for NomadSocketBuilder {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137impl NomadSocketBuilder {
138 pub fn new() -> Self {
140 Self {
141 recv_buffer_size: DEFAULT_RECV_BUFFER_SIZE,
142 max_payload_size: sizes::DEFAULT_MAX_PAYLOAD,
143 }
144 }
145
146 pub fn recv_buffer_size(mut self, size: usize) -> Self {
148 self.recv_buffer_size = size;
149 self
150 }
151
152 pub fn max_payload_size(mut self, size: usize) -> Self {
154 self.max_payload_size = size;
155 self
156 }
157
158 pub async fn bind(self, addr: SocketAddr) -> io::Result<NomadSocket> {
160 let socket = UdpSocket::bind(addr).await?;
161 Ok(self.from_socket(socket))
162 }
163
164 pub fn from_socket(self, socket: UdpSocket) -> NomadSocket {
166 NomadSocket {
167 socket: Arc::new(socket),
168 recv_buffer: vec![0u8; self.recv_buffer_size],
169 max_payload_size: self.max_payload_size,
170 }
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Binds a socket with default settings to an ephemeral loopback port.
    async fn bind_loopback() -> NomadSocket {
        NomadSocket::bind("127.0.0.1:0".parse().unwrap())
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_socket_bind() {
        let socket = bind_loopback().await;
        let addr = socket.local_addr().unwrap();
        assert_ne!(addr.port(), 0);
    }

    #[tokio::test]
    async fn test_socket_send_recv() {
        let mut server = bind_loopback().await;
        let server_addr = server.local_addr().unwrap();

        let client = bind_loopback().await;

        let data = b"hello NOMAD";
        client.send_to(data, server_addr).await.unwrap();

        let (received, from) = server.recv_from().await.unwrap();
        assert_eq!(received, data);
        assert_eq!(from, client.local_addr().unwrap());
    }

    #[tokio::test]
    async fn test_socket_connected() {
        let mut server = bind_loopback().await;
        let server_addr = server.local_addr().unwrap();

        let client = bind_loopback().await;
        client.connect(server_addr).await.unwrap();

        let data = b"connected send";
        client.send(data).await.unwrap();

        let (received, _) = server.recv_from().await.unwrap();
        assert_eq!(received, data);
    }

    #[tokio::test]
    async fn test_socket_connected_recv() {
        let mut server = bind_loopback().await;
        let server_addr = server.local_addr().unwrap();

        let client = bind_loopback().await;
        server.connect(client.local_addr().unwrap()).await.unwrap();

        let data = b"to connected";
        client.send_to(data, server_addr).await.unwrap();

        let received = server.recv().await.unwrap();
        assert_eq!(received, data);
    }

    #[tokio::test]
    async fn test_try_recv_from_without_data_returns_none() {
        let mut socket = bind_loopback().await;
        assert!(socket.try_recv_from().unwrap().is_none());
    }

    #[tokio::test]
    async fn test_try_recv_from_then_recv_data() {
        let mut server = bind_loopback().await;
        let server_addr = server.local_addr().unwrap();

        let client = bind_loopback().await;

        let data = b"polled datagram";
        client.send_to(data, server_addr).await.unwrap();

        // Readiness can be reported spuriously, so keep polling until the datagram is read.
        let (len, from) = loop {
            server.inner().readable().await.unwrap();
            if let Some(received) = server.try_recv_from().unwrap() {
                break received;
            }
        };
        assert_eq!(from, client.local_addr().unwrap());
        assert_eq!(server.recv_data(len), data);
    }

    #[tokio::test]
    async fn test_set_max_payload_size() {
        let mut socket = bind_loopback().await;
        socket.set_max_payload_size(512);
        assert_eq!(socket.max_payload_size(), 512);
    }

    #[test]
    fn test_socket_builder() {
        let builder = NomadSocketBuilder::new()
            .recv_buffer_size(4096)
            .max_payload_size(1400);

        assert_eq!(builder.recv_buffer_size, 4096);
        assert_eq!(builder.max_payload_size, 1400);
    }

    #[test]
    fn test_builder_default_matches_new() {
        let builder = NomadSocketBuilder::default();
        assert_eq!(builder.recv_buffer_size, DEFAULT_RECV_BUFFER_SIZE);
        assert_eq!(builder.max_payload_size, sizes::DEFAULT_MAX_PAYLOAD);
    }

    #[tokio::test]
    async fn test_builder_recv_buffer_size_applies() {
        let socket = NomadSocketBuilder::new()
            .recv_buffer_size(4096)
            .bind("127.0.0.1:0".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(socket.recv_buffer.len(), 4096);
    }

    #[tokio::test]
    async fn test_max_frame_size() {
        let socket = NomadSocketBuilder::new()
            .max_payload_size(1200)
            .bind("127.0.0.1:0".parse().unwrap())
            .await
            .unwrap();

        let expected = 1200 + sizes::DATA_FRAME_HEADER_SIZE + sizes::AEAD_TAG_SIZE;
        assert_eq!(socket.max_frame_size(), expected);
    }
}