1use derive_more::Display;
2use std::hash::{Hash, Hasher};
3
4use derive_more::From;
5use smoltcp::socket::*;
6use smoltcp::storage::PacketMetadata;
7use smoltcp::time::Duration;
8use smoltcp::wire::{IpAddress, IpEndpoint, IpListenEndpoint, IpProtocol, IpVersion};
9
10use crate::{Error, Protocol};
11
12pub const ENV_VAR_TCP_TIMEOUT: &str = "YA_NET_TCP_TIMEOUT_MS";
13pub const ENV_VAR_TCP_KEEP_ALIVE: &str = "YA_NET_TCP_KEEP_ALIVE_MS";
14pub const ENV_VAR_TCP_ACK_DELAY: &str = "YA_NET_TCP_ACK_DELAY_MS";
15pub const ENV_VAR_TCP_NAGLE: &str = "YA_NET_TCP_ACK_DELAY";
16pub const TCP_CONN_TIMEOUT: Duration = Duration::from_secs(45);
17pub const TCP_DISCONN_TIMEOUT: Duration = Duration::from_secs(2);
18const META_STORAGE_SIZE: usize = 1024;
19
20lazy_static::lazy_static! {
21 pub static ref TCP_NAGLE_ENABLED: bool = env_opt(ENV_VAR_TCP_NAGLE, |v| v != 0)
22 .flatten()
23 .unwrap_or(false);
24 pub static ref TCP_TIMEOUT: Option<Duration> = env_opt(ENV_VAR_TCP_TIMEOUT, Duration::from_millis)
25 .unwrap_or(Some(Duration::from_secs(120)));
26 pub static ref TCP_KEEP_ALIVE: Option<Duration> = env_opt(ENV_VAR_TCP_KEEP_ALIVE, Duration::from_millis)
27 .unwrap_or(Some(Duration::from_secs(30)));
28 pub static ref TCP_ACK_DELAY: Option<Duration> = env_opt(ENV_VAR_TCP_ACK_DELAY, Duration::from_millis)
29 .unwrap_or(Some(Duration::from_millis(40)));
30}
31
32fn env_opt<T, F: FnOnce(u64) -> T>(var: &str, f: F) -> Option<Option<T>> {
33 std::env::var(var)
34 .ok()
35 .map(|v| v.parse::<u64>().map(f).ok())
36}
37
/// Identifies a socket by its protocol and its local and remote endpoints.
///
/// Displayed as `SocketDesc { protocol: .., local: .., remote: .. }`.
/// The derived `Ord`/`PartialOrd` compare fields in declaration order
/// (`protocol`, then `local`, then `remote`), so keep that order stable.
#[derive(Clone, Display, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
#[display(
    fmt = "SocketDesc {{ protocol: {}, local: {}, remote: {} }}",
    protocol,
    local,
    remote
)]
pub struct SocketDesc {
    /// Protocol carried by the socket.
    pub protocol: Protocol,
    /// Local endpoint of the socket.
    pub local: SocketEndpoint,
    /// Remote endpoint of the socket; `SocketExt::desc` fills in
    /// `SocketEndpoint::Other` for every non-TCP socket.
    pub remote: SocketEndpoint,
}
51
52impl SocketDesc {
53 pub fn new(
54 protocol: Protocol,
55 local: impl Into<SocketEndpoint>,
56 remote: impl Into<SocketEndpoint>,
57 ) -> Self {
58 Self {
59 protocol,
60 local: local.into(),
61 remote: remote.into(),
62 }
63 }
64}
65
/// Socket state paired with caller-chosen data in `inner`.
///
/// `SocketExt::state` produces `Tcp` (with the smoltcp [`tcp::State`]) for TCP
/// sockets and `Other` for every other socket kind.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SocketState<T> {
    /// A TCP socket in the given smoltcp state.
    Tcp { state: tcp::State, inner: T },
    /// Any non-TCP socket.
    Other { inner: T },
}
71
72impl<T> SocketState<T> {
73 pub fn inner_mut(&mut self) -> &mut T {
74 match self {
75 Self::Tcp { inner, .. } | Self::Other { inner } => inner,
76 }
77 }
78
79 pub fn set_inner(&mut self, value: T) {
80 match self {
81 Self::Tcp { inner, .. } | Self::Other { inner } => *inner = value,
82 }
83 }
84}
85
86impl<T: Default> From<tcp::State> for SocketState<T> {
87 fn from(state: tcp::State) -> Self {
88 Self::Tcp {
89 state,
90 inner: Default::default(),
91 }
92 }
93}
94
95impl<T: Default> Default for SocketState<T> {
96 fn default() -> Self {
97 SocketState::Other {
98 inner: Default::default(),
99 }
100 }
101}
102
103impl<T> ToString for SocketState<T> {
104 fn to_string(&self) -> String {
105 match self {
106 Self::Tcp { state, .. } => format!("{:?}", state),
107 _ => String::default(),
108 }
109 }
110}
111
/// Endpoint of a socket: an IP endpoint, an ICMP endpoint, or nothing known.
///
/// `From` conversions for the variant payloads come from `derive_more`; further
/// conversions (`Option<IpEndpoint>`, `u16` ICMP ident, `(addr, port)`) are
/// implemented by hand in this module. `Hash` is implemented manually rather
/// than derived.
#[derive(From, Display, Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum SocketEndpoint {
    /// IP address and port.
    Ip(IpEndpoint),
    /// ICMP endpoint; displayed using its `Debug` representation.
    #[display(fmt = "{:?}", _0)]
    Icmp(icmp::Endpoint),
    /// No endpoint available for this socket.
    Other,
}
120
121impl SocketEndpoint {
122 #[inline]
123 pub fn ip_endpoint(&self) -> Result<IpEndpoint, Error> {
124 match self {
125 SocketEndpoint::Ip(endpoint) => Ok(*endpoint),
126 other => Err(Error::EndpointInvalid(*other)),
127 }
128 }
129
130 #[inline]
131 pub fn is_specified(&self) -> bool {
132 match self {
133 Self::Ip(ip) => {
134 let ip: IpListenEndpoint = IpListenEndpoint::from(*ip);
135 ip.is_specified()
136 }
137 Self::Icmp(icmp) => icmp.is_specified(),
138 Self::Other => false,
139 }
140 }
141
142 pub fn addr_repr(&self) -> String {
143 match self {
144 Self::Ip(ip) => format!("{}", ip.addr),
145 _ => Default::default(),
146 }
147 }
148
149 pub fn port_repr(&self) -> String {
150 match self {
151 Self::Ip(ip) => format!("{}", ip.port),
152 Self::Icmp(icmp) => match icmp {
153 icmp::Endpoint::Unspecified => "*".to_string(),
154 endpoint => format!("{:?}", endpoint),
155 },
156 Self::Other => Default::default(),
157 }
158 }
159}
160
161#[allow(clippy::derived_hash_with_manual_eq)]
162impl Hash for SocketEndpoint {
163 fn hash<H: Hasher>(&self, state: &mut H) {
164 match self {
165 Self::Ip(ip) => {
166 state.write_u8(1);
167 ip.hash(state);
168 }
169 Self::Icmp(icmp) => {
170 state.write_u8(2);
171 match icmp {
172 icmp::Endpoint::Unspecified => state.write_u8(1),
173 icmp::Endpoint::Udp(ip) => {
174 state.write_u8(2);
175 ip.hash(state);
176 }
177 icmp::Endpoint::Ident(id) => {
178 state.write_u8(3);
179 id.hash(state);
180 }
181 }
182 }
183 Self::Other => state.write_u8(3),
184 }
185 }
186}
187
188impl PartialEq<IpEndpoint> for SocketEndpoint {
189 fn eq(&self, other: &IpEndpoint) -> bool {
190 match &self {
191 Self::Ip(endpoint) => endpoint == other,
192 _ => false,
193 }
194 }
195}
196
197impl From<Option<IpEndpoint>> for SocketEndpoint {
198 fn from(opt: Option<IpEndpoint>) -> Self {
199 match opt {
200 Some(endpoint) => Self::Ip(endpoint),
201 None => Self::Other,
202 }
203 }
204}
205
206impl From<u16> for SocketEndpoint {
207 fn from(ident: u16) -> Self {
208 Self::Icmp(icmp::Endpoint::Ident(ident))
209 }
210}
211
212impl<T: Into<IpAddress>> From<(T, u16)> for SocketEndpoint {
213 fn from((t, port): (T, u16)) -> Self {
214 let endpoint: IpEndpoint = (t, port).into();
215 Self::from(endpoint)
216 }
217}
218
219use thiserror::Error;
220
/// Error returned by [`SocketExt::recv`].
///
/// Wraps the receive error of the underlying smoltcp socket type. For a
/// `Socket`, calling `recv` on a DHCPv4 or DNS socket yields the unit variants.
#[derive(Error, Debug)]
pub enum RecvError {
    /// Receive error from a TCP socket.
    #[error(transparent)]
    Tcp(#[from] smoltcp::socket::tcp::RecvError),
    /// Receive error from a UDP socket.
    #[error(transparent)]
    Udp(#[from] smoltcp::socket::udp::RecvError),
    /// Receive error from a raw socket.
    #[error(transparent)]
    Raw(#[from] smoltcp::socket::raw::RecvError),
    /// Receive error from an ICMP socket.
    #[error(transparent)]
    Icmp(#[from] smoltcp::socket::icmp::RecvError),
    /// `recv` was called on a DHCPv4 socket.
    #[error("Dhcpv4 error")]
    Dhcpv4,
    /// `recv` was called on a DNS socket.
    #[error("DNS error")]
    Dns,
}
236
/// Protocol-agnostic accessors and operations over smoltcp sockets.
pub trait SocketExt {
    /// Protocol carried by the socket.
    fn protocol(&self) -> Protocol;
    /// Local endpoint, or `SocketEndpoint::Other` when it cannot be expressed.
    fn local_endpoint(&self) -> SocketEndpoint;
    /// Remote endpoint, or `SocketEndpoint::Other` when there is none.
    fn remote_endpoint(&self) -> SocketEndpoint;

    /// Whether the socket is closed.
    fn is_closed(&self) -> bool;
    /// Closes the socket, for socket kinds that support closing.
    fn close(&mut self);

    /// Whether the socket has data ready to be received.
    fn can_recv(&self) -> bool;
    /// Receives one chunk of data together with the peer endpoint it came from.
    ///
    /// `Ok(None)` means nothing was returned to the caller (for example, an empty
    /// queue or no known peer endpoint); see the implementation for specifics.
    fn recv(&mut self) -> std::result::Result<Option<(IpEndpoint, Vec<u8>)>, RecvError>;

    /// Whether the socket can accept outgoing data.
    fn can_send(&self) -> bool;
    /// Capacity available for outgoing data.
    fn send_capacity(&self) -> usize;
    /// For TCP sockets, the amount of data queued for sending; other socket
    /// kinds define this differently (see the `Socket` implementation).
    fn send_queue(&self) -> usize;

    /// Current state of the socket, with a default-initialised `inner` value.
    fn state<T: Default>(&self) -> SocketState<T>;
    /// Descriptor combining `protocol`, `local_endpoint` and `remote_endpoint`.
    fn desc(&self) -> SocketDesc;
}
256
257impl<'a> SocketExt for Socket<'a> {
258 fn protocol(&self) -> Protocol {
259 match &self {
260 Self::Tcp(_) => Protocol::Tcp,
261 Self::Udp(_) => Protocol::Udp,
262 Self::Icmp(_) => Protocol::Icmp,
263 Self::Raw(_) => Protocol::Ethernet,
264 Self::Dhcpv4(_) => Protocol::None,
265 Socket::Dns(_) => Protocol::None,
266 }
267 }
268
269 fn local_endpoint(&self) -> SocketEndpoint {
270 match &self {
271 Self::Tcp(s) => s.local_endpoint().into(),
272 Self::Udp(s) => {
273 let Some(addr) = s.endpoint().addr else {
274 return SocketEndpoint::Other
275 };
276 let port = s.endpoint().port;
277 SocketEndpoint::Ip(IpEndpoint { addr, port })
278 }
279 _ => SocketEndpoint::Other,
280 }
281 }
282
283 fn remote_endpoint(&self) -> SocketEndpoint {
284 match &self {
285 Self::Tcp(s) => s.remote_endpoint().into(),
286 _ => SocketEndpoint::Other,
287 }
288 }
289
290 fn is_closed(&self) -> bool {
291 match &self {
292 Self::Tcp(s) => s.state() == tcp::State::Closed,
293 Self::Udp(s) => !s.is_open(),
294 Self::Icmp(s) => !s.is_open(),
295 Self::Raw(_) => false,
296 Self::Dhcpv4(_) => false,
297 Self::Dns(_) => false,
298 }
299 }
300
301 fn close(&mut self) {
302 match self {
303 Self::Tcp(s) => s.close(),
304 Self::Udp(s) => s.close(),
305 _ => (),
306 }
307 }
308
309 fn can_recv(&self) -> bool {
310 match &self {
311 Self::Tcp(s) => s.can_recv(),
312 Self::Udp(s) => s.can_recv(),
313 Self::Icmp(s) => s.can_recv(),
314 Self::Raw(s) => s.can_recv(),
315 Self::Dhcpv4(_) => false,
316 Self::Dns(_) => false,
317 }
318 }
319
320 fn recv(&mut self) -> std::result::Result<Option<(IpEndpoint, Vec<u8>)>, RecvError> {
321 let result = match self {
322 Self::Tcp(tcp) => tcp
323 .recv(|bytes| (bytes.len(), bytes.to_vec()))
324 .map(|vec| (tcp.remote_endpoint(), vec))
325 .map_err(RecvError::from),
326 Self::Udp(udp) => udp
327 .recv()
328 .map(|(bytes, endpoint)| (Some(endpoint.endpoint), bytes.to_vec()))
329 .map_err(RecvError::from),
330 Self::Icmp(icmp) => icmp
331 .recv()
332 .map(|(bytes, address)| (Some((address, 0).into()), bytes.to_vec()))
333 .map_err(RecvError::from),
334 Self::Raw(raw) => raw
335 .recv()
336 .map(|bytes| {
337 let addr = smoltcp::wire::Ipv4Address::UNSPECIFIED.into_address();
338 let port = 0;
339 (Some(IpEndpoint::new(addr, port)), bytes.to_vec())
340 })
341 .map_err(RecvError::from),
342 Self::Dhcpv4(_) => Err(RecvError::Dhcpv4),
343 Self::Dns(_) => Err(RecvError::Dns),
344 };
345
346 match result {
347 Ok((Some(endpoint), bytes)) => Ok(Some((endpoint, bytes))),
348 Ok((None, _)) => Ok(None),
349 Err(RecvError::Udp(smoltcp::socket::udp::RecvError::Exhausted)) => Ok(None),
350 Err(err) => Err(err),
351 }
352 }
353
354 fn can_send(&self) -> bool {
355 match &self {
356 Self::Tcp(s) => s.can_send(),
357 Self::Udp(s) => s.can_send(),
358 Self::Icmp(s) => s.can_send(),
359 Self::Raw(s) => s.can_send(),
360 Self::Dhcpv4(_) => false,
361 Self::Dns(_) => false,
362 }
363 }
364
365 fn send_capacity(&self) -> usize {
366 match &self {
367 Self::Tcp(s) => s.send_capacity(),
368 Self::Udp(s) => s.payload_send_capacity(),
369 Self::Icmp(s) => s.payload_send_capacity(),
370 Self::Raw(s) => s.payload_send_capacity(),
371 Self::Dhcpv4(_) => 0,
372 Self::Dns(_) => 0,
373 }
374 }
375
376 fn send_queue(&self) -> usize {
377 match &self {
378 Self::Tcp(s) => s.send_queue(),
379 _ => {
380 if self.can_send() {
381 self.send_capacity() } else {
383 0
384 }
385 }
386 }
387 }
388
389 fn state<T: Default>(&self) -> SocketState<T> {
390 match &self {
391 Self::Tcp(s) => SocketState::from(s.state()),
392 _ => SocketState::Other {
393 inner: Default::default(),
394 },
395 }
396 }
397
398 fn desc(&self) -> SocketDesc {
399 SocketDesc {
400 protocol: self.protocol(),
401 local: self.local_endpoint(),
402 remote: self.remote_endpoint(),
403 }
404 }
405}
406
/// Extension for TCP sockets that applies this module's default tuning.
pub trait TcpSocketExt {
    /// Applies the default TCP settings. For `tcp::Socket` these are Nagle,
    /// timeout, keep-alive and ACK delay, read from the module's lazily
    /// initialised, environment-derived statics.
    fn set_defaults(&mut self);
}
410
411impl<'a> TcpSocketExt for tcp::Socket<'a> {
412 fn set_defaults(&mut self) {
413 self.set_nagle_enabled(*TCP_NAGLE_ENABLED);
414 self.set_timeout(*TCP_TIMEOUT);
415 self.set_keep_alive(*TCP_KEEP_ALIVE);
416 self.set_ack_delay(*TCP_ACK_DELAY);
417 }
418}
419
/// Transmit and receive buffer bounds for a single socket.
#[derive(Clone, Copy, Debug)]
pub struct SocketMemory {
    /// Bounds of the transmit buffer.
    pub tx: Memory,
    /// Bounds of the receive buffer.
    pub rx: Memory,
}
425
426impl SocketMemory {
427 pub fn default_tcp() -> Self {
428 Self {
429 rx: Memory::default_tcp_rx(),
430 tx: Memory::default_tcp_tx(),
431 }
432 }
433
434 pub fn default_udp() -> Self {
435 Self {
436 rx: Memory::default_udp_rx(),
437 tx: Memory::default_udp_tx(),
438 }
439 }
440
441 pub fn default_icmp() -> Self {
442 Self::default_udp()
443 }
444
445 pub fn default_raw() -> Self {
446 Self::default_tcp()
447 }
448}
449
/// Size bounds of a socket buffer.
///
/// The constructor and setters keep `min <= default <= max`. The socket
/// constructors in this module allocate `max` elements for each buffer.
#[derive(Clone, Copy, Debug)]
pub struct Memory {
    min: usize,
    default: usize,
    max: usize,
}
458
459impl Memory {
460 pub fn new(min: usize, default: usize, max: usize) -> Result<Self, Error> {
461 if default < min || default > max {
462 return Err(Error::Other(format!(
463 "Invalid memory bounds: {min} <= {default} <= {max}",
464 )));
465 }
466 Ok(Self { min, default, max })
467 }
468
469 pub fn set_min(&mut self, min: usize) -> Result<(), Error> {
470 if min > self.default {
471 return Err(Error::Other(format!(
472 "Invalid min memory bound: {min} <= {}",
473 self.default
474 )));
475 }
476
477 self.min = min;
478 Ok(())
479 }
480
481 pub fn set_default(&mut self, default: usize) -> Result<(), Error> {
482 if default < self.min || default > self.max {
483 return Err(Error::Other(format!(
484 "Invalid default memory size: {} <= {default} <= {}",
485 self.min, self.max,
486 )));
487 }
488
489 self.default = default;
490 Ok(())
491 }
492
493 pub fn set_max(&mut self, max: usize) -> Result<(), Error> {
494 if max < self.default {
495 return Err(Error::Other(format!(
496 "Invalid max memory bound: {} <= {max}",
497 self.default
498 )));
499 }
500
501 self.max = max;
502 Ok(())
503 }
504
505 pub fn default_tcp_rx() -> Self {
506 Self::new(4 * 1024, 128 * 1024, 4 * 1024 * 1024).expect("Invalid TCP recv buffer bounds")
509 }
510
511 pub fn default_tcp_tx() -> Self {
512 Self::new(4 * 1024, 16 * 1024, 128 * 1024).expect("Invalid TCP send buffer bounds")
515 }
516
517 pub fn default_udp_rx() -> Self {
518 Self::new(10 * 1024, 128 * 1024, 1490 * 1024).expect("Invalid UDP recv buffer bounds")
522 }
523
524 pub fn default_udp_tx() -> Self {
525 Self::new(10 * 1024, 128 * 1024, 1490 * 1024).expect("Invalid UDP send buffer bounds")
529 }
530}
531
532pub fn tcp_socket<'a>(rx_mem: Memory, tx_mem: Memory) -> tcp::Socket<'a> {
533 let rx_buf = tcp::SocketBuffer::new(vec![0; rx_mem.max]);
534 let tx_buf = tcp::SocketBuffer::new(vec![0; tx_mem.max]);
535 let mut socket = tcp::Socket::new(rx_buf, tx_buf);
536 socket.set_defaults();
537 socket
538}
539
540pub fn udp_socket<'a>(rx_mem: Memory, tx_mem: Memory) -> udp::Socket<'a> {
541 let rx_buf =
542 udp::PacketBuffer::new(meta_storage(META_STORAGE_SIZE), payload_storage(rx_mem.max));
543 let tx_buf =
544 udp::PacketBuffer::new(meta_storage(META_STORAGE_SIZE), payload_storage(tx_mem.max));
545 udp::Socket::new(rx_buf, tx_buf)
546}
547
548pub fn icmp_socket<'a>(rx_mem: Memory, tx_mem: Memory) -> icmp::Socket<'a> {
549 let rx_buf =
550 icmp::PacketBuffer::new(meta_storage(META_STORAGE_SIZE), payload_storage(rx_mem.max));
551 let tx_buf =
552 icmp::PacketBuffer::new(meta_storage(META_STORAGE_SIZE), payload_storage(tx_mem.max));
553 icmp::Socket::new(rx_buf, tx_buf)
554}
555
556pub fn raw_socket<'a>(
557 ip_version: IpVersion,
558 ip_protocol: IpProtocol,
559 rx_mem: Memory,
560 tx_mem: Memory,
561) -> raw::Socket<'a> {
562 let rx_buf =
563 raw::PacketBuffer::new(meta_storage(META_STORAGE_SIZE), payload_storage(rx_mem.max));
564 let tx_buf =
565 raw::PacketBuffer::new(meta_storage(META_STORAGE_SIZE), payload_storage(tx_mem.max));
566 raw::Socket::new(ip_version, ip_protocol, rx_buf, tx_buf)
567}
568
569#[inline]
570fn meta_storage<H: Clone>(size: usize) -> Vec<PacketMetadata<H>> {
571 vec![PacketMetadata::EMPTY; size]
572}
573
/// Allocates `size` default-initialised payload elements (zeros for `u8`).
#[inline]
fn payload_storage<T: Default + Clone>(size: usize) -> Vec<T> {
    let mut storage: Vec<T> = Vec::with_capacity(size);
    storage.resize(size, T::default());
    storage
}