ya_relay_stack/
socket.rs

1use derive_more::Display;
2use std::hash::{Hash, Hasher};
3
4use derive_more::From;
5use smoltcp::socket::*;
6use smoltcp::storage::PacketMetadata;
7use smoltcp::time::Duration;
8use smoltcp::wire::{IpAddress, IpEndpoint, IpListenEndpoint, IpProtocol, IpVersion};
9
10use crate::{Error, Protocol};
11
/// Environment variable overriding the TCP socket timeout, in milliseconds.
pub const ENV_VAR_TCP_TIMEOUT: &str = "YA_NET_TCP_TIMEOUT_MS";
/// Environment variable overriding the TCP keep-alive interval, in milliseconds.
pub const ENV_VAR_TCP_KEEP_ALIVE: &str = "YA_NET_TCP_KEEP_ALIVE_MS";
/// Environment variable overriding the TCP ACK delay, in milliseconds.
pub const ENV_VAR_TCP_ACK_DELAY: &str = "YA_NET_TCP_ACK_DELAY_MS";
/// Environment variable toggling Nagle's algorithm: any non-zero integer enables it.
///
/// Uses its own name so it cannot be confused with the ACK-delay setting above.
pub const ENV_VAR_TCP_NAGLE: &str = "YA_NET_TCP_NAGLE";
16pub const TCP_CONN_TIMEOUT: Duration = Duration::from_secs(45);
17pub const TCP_DISCONN_TIMEOUT: Duration = Duration::from_secs(2);
18const META_STORAGE_SIZE: usize = 1024;
19
20lazy_static::lazy_static! {
21    pub static ref TCP_NAGLE_ENABLED: bool = env_opt(ENV_VAR_TCP_NAGLE, |v| v != 0)
22        .flatten()
23        .unwrap_or(false);
24    pub static ref TCP_TIMEOUT: Option<Duration> = env_opt(ENV_VAR_TCP_TIMEOUT, Duration::from_millis)
25        .unwrap_or(Some(Duration::from_secs(120)));
26    pub static ref TCP_KEEP_ALIVE: Option<Duration> = env_opt(ENV_VAR_TCP_KEEP_ALIVE, Duration::from_millis)
27        .unwrap_or(Some(Duration::from_secs(30)));
28    pub static ref TCP_ACK_DELAY: Option<Duration> = env_opt(ENV_VAR_TCP_ACK_DELAY, Duration::from_millis)
29        .unwrap_or(Some(Duration::from_millis(40)));
30}
31
/// Reads environment variable `var` as a `u64` and converts it with `f`.
///
/// Returns:
/// * `None` if the variable is not set (or is not valid Unicode),
/// * `Some(None)` if it is set but does not parse as a `u64`,
/// * `Some(Some(f(value)))` otherwise.
fn env_opt<T, F: FnOnce(u64) -> T>(var: &str, f: F) -> Option<Option<T>> {
    match std::env::var(var) {
        Ok(raw) => Some(raw.parse::<u64>().ok().map(f)),
        Err(_) => None,
    }
}
37
38/// Socket quintuplet
39#[derive(Clone, Display, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
40#[display(
41    fmt = "SocketDesc {{ protocol: {}, local: {}, remote: {} }}",
42    protocol,
43    local,
44    remote
45)]
46pub struct SocketDesc {
47    pub protocol: Protocol,
48    pub local: SocketEndpoint,
49    pub remote: SocketEndpoint,
50}
51
52impl SocketDesc {
53    pub fn new(
54        protocol: Protocol,
55        local: impl Into<SocketEndpoint>,
56        remote: impl Into<SocketEndpoint>,
57    ) -> Self {
58        Self {
59            protocol,
60            local: local.into(),
61            remote: remote.into(),
62        }
63    }
64}
65
66#[derive(Clone, Copy, Debug, Eq, PartialEq)]
67pub enum SocketState<T> {
68    Tcp { state: tcp::State, inner: T },
69    Other { inner: T },
70}
71
72impl<T> SocketState<T> {
73    pub fn inner_mut(&mut self) -> &mut T {
74        match self {
75            Self::Tcp { inner, .. } | Self::Other { inner } => inner,
76        }
77    }
78
79    pub fn set_inner(&mut self, value: T) {
80        match self {
81            Self::Tcp { inner, .. } | Self::Other { inner } => *inner = value,
82        }
83    }
84}
85
86impl<T: Default> From<tcp::State> for SocketState<T> {
87    fn from(state: tcp::State) -> Self {
88        Self::Tcp {
89            state,
90            inner: Default::default(),
91        }
92    }
93}
94
95impl<T: Default> Default for SocketState<T> {
96    fn default() -> Self {
97        SocketState::Other {
98            inner: Default::default(),
99        }
100    }
101}
102
103impl<T> ToString for SocketState<T> {
104    fn to_string(&self) -> String {
105        match self {
106            Self::Tcp { state, .. } => format!("{:?}", state),
107            _ => String::default(),
108        }
109    }
110}
111
112/// Socket endpoint kind
113#[derive(From, Display, Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
114pub enum SocketEndpoint {
115    Ip(IpEndpoint),
116    #[display(fmt = "{:?}", _0)]
117    Icmp(icmp::Endpoint),
118    Other,
119}
120
121impl SocketEndpoint {
122    #[inline]
123    pub fn ip_endpoint(&self) -> Result<IpEndpoint, Error> {
124        match self {
125            SocketEndpoint::Ip(endpoint) => Ok(*endpoint),
126            other => Err(Error::EndpointInvalid(*other)),
127        }
128    }
129
130    #[inline]
131    pub fn is_specified(&self) -> bool {
132        match self {
133            Self::Ip(ip) => {
134                let ip: IpListenEndpoint = IpListenEndpoint::from(*ip);
135                ip.is_specified()
136            }
137            Self::Icmp(icmp) => icmp.is_specified(),
138            Self::Other => false,
139        }
140    }
141
142    pub fn addr_repr(&self) -> String {
143        match self {
144            Self::Ip(ip) => format!("{}", ip.addr),
145            _ => Default::default(),
146        }
147    }
148
149    pub fn port_repr(&self) -> String {
150        match self {
151            Self::Ip(ip) => format!("{}", ip.port),
152            Self::Icmp(icmp) => match icmp {
153                icmp::Endpoint::Unspecified => "*".to_string(),
154                endpoint => format!("{:?}", endpoint),
155            },
156            Self::Other => Default::default(),
157        }
158    }
159}
160
161#[allow(clippy::derived_hash_with_manual_eq)]
162impl Hash for SocketEndpoint {
163    fn hash<H: Hasher>(&self, state: &mut H) {
164        match self {
165            Self::Ip(ip) => {
166                state.write_u8(1);
167                ip.hash(state);
168            }
169            Self::Icmp(icmp) => {
170                state.write_u8(2);
171                match icmp {
172                    icmp::Endpoint::Unspecified => state.write_u8(1),
173                    icmp::Endpoint::Udp(ip) => {
174                        state.write_u8(2);
175                        ip.hash(state);
176                    }
177                    icmp::Endpoint::Ident(id) => {
178                        state.write_u8(3);
179                        id.hash(state);
180                    }
181                }
182            }
183            Self::Other => state.write_u8(3),
184        }
185    }
186}
187
188impl PartialEq<IpEndpoint> for SocketEndpoint {
189    fn eq(&self, other: &IpEndpoint) -> bool {
190        match &self {
191            Self::Ip(endpoint) => endpoint == other,
192            _ => false,
193        }
194    }
195}
196
197impl From<Option<IpEndpoint>> for SocketEndpoint {
198    fn from(opt: Option<IpEndpoint>) -> Self {
199        match opt {
200            Some(endpoint) => Self::Ip(endpoint),
201            None => Self::Other,
202        }
203    }
204}
205
206impl From<u16> for SocketEndpoint {
207    fn from(ident: u16) -> Self {
208        Self::Icmp(icmp::Endpoint::Ident(ident))
209    }
210}
211
212impl<T: Into<IpAddress>> From<(T, u16)> for SocketEndpoint {
213    fn from((t, port): (T, u16)) -> Self {
214        let endpoint: IpEndpoint = (t, port).into();
215        Self::from(endpoint)
216    }
217}
218
219use thiserror::Error;
220
221#[derive(Error, Debug)]
222pub enum RecvError {
223    #[error(transparent)]
224    Tcp(#[from] smoltcp::socket::tcp::RecvError),
225    #[error(transparent)]
226    Udp(#[from] smoltcp::socket::udp::RecvError),
227    #[error(transparent)]
228    Raw(#[from] smoltcp::socket::raw::RecvError),
229    #[error(transparent)]
230    Icmp(#[from] smoltcp::socket::icmp::RecvError),
231    #[error("Dhcpv4 error")]
232    Dhcpv4,
233    #[error("DNS error")]
234    Dns,
235}
236
237/// Common interface for various socket types
238pub trait SocketExt {
239    fn protocol(&self) -> Protocol;
240    fn local_endpoint(&self) -> SocketEndpoint;
241    fn remote_endpoint(&self) -> SocketEndpoint;
242
243    fn is_closed(&self) -> bool;
244    fn close(&mut self);
245
246    fn can_recv(&self) -> bool;
247    fn recv(&mut self) -> std::result::Result<Option<(IpEndpoint, Vec<u8>)>, RecvError>;
248
249    fn can_send(&self) -> bool;
250    fn send_capacity(&self) -> usize;
251    fn send_queue(&self) -> usize;
252
253    fn state<T: Default>(&self) -> SocketState<T>;
254    fn desc(&self) -> SocketDesc;
255}
256
257impl<'a> SocketExt for Socket<'a> {
258    fn protocol(&self) -> Protocol {
259        match &self {
260            Self::Tcp(_) => Protocol::Tcp,
261            Self::Udp(_) => Protocol::Udp,
262            Self::Icmp(_) => Protocol::Icmp,
263            Self::Raw(_) => Protocol::Ethernet,
264            Self::Dhcpv4(_) => Protocol::None,
265            Socket::Dns(_) => Protocol::None,
266        }
267    }
268
269    fn local_endpoint(&self) -> SocketEndpoint {
270        match &self {
271            Self::Tcp(s) => s.local_endpoint().into(),
272            Self::Udp(s) => {
273                let Some(addr) = s.endpoint().addr else {
274                    return SocketEndpoint::Other
275                };
276                let port = s.endpoint().port;
277                SocketEndpoint::Ip(IpEndpoint { addr, port })
278            }
279            _ => SocketEndpoint::Other,
280        }
281    }
282
283    fn remote_endpoint(&self) -> SocketEndpoint {
284        match &self {
285            Self::Tcp(s) => s.remote_endpoint().into(),
286            _ => SocketEndpoint::Other,
287        }
288    }
289
290    fn is_closed(&self) -> bool {
291        match &self {
292            Self::Tcp(s) => s.state() == tcp::State::Closed,
293            Self::Udp(s) => !s.is_open(),
294            Self::Icmp(s) => !s.is_open(),
295            Self::Raw(_) => false,
296            Self::Dhcpv4(_) => false,
297            Self::Dns(_) => false,
298        }
299    }
300
301    fn close(&mut self) {
302        match self {
303            Self::Tcp(s) => s.close(),
304            Self::Udp(s) => s.close(),
305            _ => (),
306        }
307    }
308
309    fn can_recv(&self) -> bool {
310        match &self {
311            Self::Tcp(s) => s.can_recv(),
312            Self::Udp(s) => s.can_recv(),
313            Self::Icmp(s) => s.can_recv(),
314            Self::Raw(s) => s.can_recv(),
315            Self::Dhcpv4(_) => false,
316            Self::Dns(_) => false,
317        }
318    }
319
320    fn recv(&mut self) -> std::result::Result<Option<(IpEndpoint, Vec<u8>)>, RecvError> {
321        let result = match self {
322            Self::Tcp(tcp) => tcp
323                .recv(|bytes| (bytes.len(), bytes.to_vec()))
324                .map(|vec| (tcp.remote_endpoint(), vec))
325                .map_err(RecvError::from),
326            Self::Udp(udp) => udp
327                .recv()
328                .map(|(bytes, endpoint)| (Some(endpoint.endpoint), bytes.to_vec()))
329                .map_err(RecvError::from),
330            Self::Icmp(icmp) => icmp
331                .recv()
332                .map(|(bytes, address)| (Some((address, 0).into()), bytes.to_vec()))
333                .map_err(RecvError::from),
334            Self::Raw(raw) => raw
335                .recv()
336                .map(|bytes| {
337                    let addr = smoltcp::wire::Ipv4Address::UNSPECIFIED.into_address();
338                    let port = 0;
339                    (Some(IpEndpoint::new(addr, port)), bytes.to_vec())
340                })
341                .map_err(RecvError::from),
342            Self::Dhcpv4(_) => Err(RecvError::Dhcpv4),
343            Self::Dns(_) => Err(RecvError::Dns),
344        };
345
346        match result {
347            Ok((Some(endpoint), bytes)) => Ok(Some((endpoint, bytes))),
348            Ok((None, _)) => Ok(None),
349            Err(RecvError::Udp(smoltcp::socket::udp::RecvError::Exhausted)) => Ok(None),
350            Err(err) => Err(err),
351        }
352    }
353
354    fn can_send(&self) -> bool {
355        match &self {
356            Self::Tcp(s) => s.can_send(),
357            Self::Udp(s) => s.can_send(),
358            Self::Icmp(s) => s.can_send(),
359            Self::Raw(s) => s.can_send(),
360            Self::Dhcpv4(_) => false,
361            Self::Dns(_) => false,
362        }
363    }
364
365    fn send_capacity(&self) -> usize {
366        match &self {
367            Self::Tcp(s) => s.send_capacity(),
368            Self::Udp(s) => s.payload_send_capacity(),
369            Self::Icmp(s) => s.payload_send_capacity(),
370            Self::Raw(s) => s.payload_send_capacity(),
371            Self::Dhcpv4(_) => 0,
372            Self::Dns(_) => 0,
373        }
374    }
375
376    fn send_queue(&self) -> usize {
377        match &self {
378            Self::Tcp(s) => s.send_queue(),
379            _ => {
380                if self.can_send() {
381                    self.send_capacity() // mock value
382                } else {
383                    0
384                }
385            }
386        }
387    }
388
389    fn state<T: Default>(&self) -> SocketState<T> {
390        match &self {
391            Self::Tcp(s) => SocketState::from(s.state()),
392            _ => SocketState::Other {
393                inner: Default::default(),
394            },
395        }
396    }
397
398    fn desc(&self) -> SocketDesc {
399        SocketDesc {
400            protocol: self.protocol(),
401            local: self.local_endpoint(),
402            remote: self.remote_endpoint(),
403        }
404    }
405}
406
407pub trait TcpSocketExt {
408    fn set_defaults(&mut self);
409}
410
411impl<'a> TcpSocketExt for tcp::Socket<'a> {
412    fn set_defaults(&mut self) {
413        self.set_nagle_enabled(*TCP_NAGLE_ENABLED);
414        self.set_timeout(*TCP_TIMEOUT);
415        self.set_keep_alive(*TCP_KEEP_ALIVE);
416        self.set_ack_delay(*TCP_ACK_DELAY);
417    }
418}
419
420#[derive(Clone, Copy, Debug)]
421pub struct SocketMemory {
422    pub tx: Memory,
423    pub rx: Memory,
424}
425
426impl SocketMemory {
427    pub fn default_tcp() -> Self {
428        Self {
429            rx: Memory::default_tcp_rx(),
430            tx: Memory::default_tcp_tx(),
431        }
432    }
433
434    pub fn default_udp() -> Self {
435        Self {
436            rx: Memory::default_udp_rx(),
437            tx: Memory::default_udp_tx(),
438        }
439    }
440
441    pub fn default_icmp() -> Self {
442        Self::default_udp()
443    }
444
445    pub fn default_raw() -> Self {
446        Self::default_tcp()
447    }
448}
449
450/// Buffer size bounds used in auto tuning.
451/// Currently, only `max` is used; other values are reserved for future use
452#[derive(Clone, Copy, Debug)]
453pub struct Memory {
454    min: usize,
455    default: usize,
456    max: usize,
457}
458
459impl Memory {
460    pub fn new(min: usize, default: usize, max: usize) -> Result<Self, Error> {
461        if default < min || default > max {
462            return Err(Error::Other(format!(
463                "Invalid memory bounds: {min} <= {default} <= {max}",
464            )));
465        }
466        Ok(Self { min, default, max })
467    }
468
469    pub fn set_min(&mut self, min: usize) -> Result<(), Error> {
470        if min > self.default {
471            return Err(Error::Other(format!(
472                "Invalid min memory bound: {min} <= {}",
473                self.default
474            )));
475        }
476
477        self.min = min;
478        Ok(())
479    }
480
481    pub fn set_default(&mut self, default: usize) -> Result<(), Error> {
482        if default < self.min || default > self.max {
483            return Err(Error::Other(format!(
484                "Invalid default memory size: {} <= {default} <= {}",
485                self.min, self.max,
486            )));
487        }
488
489        self.default = default;
490        Ok(())
491    }
492
493    pub fn set_max(&mut self, max: usize) -> Result<(), Error> {
494        if max < self.default {
495            return Err(Error::Other(format!(
496                "Invalid max memory bound: {} <= {max}",
497                self.default
498            )));
499        }
500
501        self.max = max;
502        Ok(())
503    }
504
505    pub fn default_tcp_rx() -> Self {
506        // 5.15.0-39-generic #42-Ubuntu SMP
507        // net.ipv4.tcp_rmem = 4096	131072	6291456
508        Self::new(4 * 1024, 128 * 1024, 4 * 1024 * 1024).expect("Invalid TCP recv buffer bounds")
509    }
510
511    pub fn default_tcp_tx() -> Self {
512        // 5.15.0-39-generic #42-Ubuntu SMP
513        // net.ipv4.tcp_wmem = 4096	16384	4194304
514        Self::new(4 * 1024, 16 * 1024, 128 * 1024).expect("Invalid TCP send buffer bounds")
515    }
516
517    pub fn default_udp_rx() -> Self {
518        // 5.15.0-39-generic #42-Ubuntu SMP
519        // net.ipv4.udp_mem = 763233	1017647	1526466
520        // net.ipv4.udp_rmem_min = 4096
521        Self::new(10 * 1024, 128 * 1024, 1490 * 1024).expect("Invalid UDP recv buffer bounds")
522    }
523
524    pub fn default_udp_tx() -> Self {
525        // 5.15.0-39-generic #42-Ubuntu SMP
526        // net.ipv4.udp_mem = 763233	1017647	1526466
527        // net.ipv4.udp_wmem_min = 4096
528        Self::new(10 * 1024, 128 * 1024, 1490 * 1024).expect("Invalid UDP send buffer bounds")
529    }
530}
531
532pub fn tcp_socket<'a>(rx_mem: Memory, tx_mem: Memory) -> tcp::Socket<'a> {
533    let rx_buf = tcp::SocketBuffer::new(vec![0; rx_mem.max]);
534    let tx_buf = tcp::SocketBuffer::new(vec![0; tx_mem.max]);
535    let mut socket = tcp::Socket::new(rx_buf, tx_buf);
536    socket.set_defaults();
537    socket
538}
539
540pub fn udp_socket<'a>(rx_mem: Memory, tx_mem: Memory) -> udp::Socket<'a> {
541    let rx_buf =
542        udp::PacketBuffer::new(meta_storage(META_STORAGE_SIZE), payload_storage(rx_mem.max));
543    let tx_buf =
544        udp::PacketBuffer::new(meta_storage(META_STORAGE_SIZE), payload_storage(tx_mem.max));
545    udp::Socket::new(rx_buf, tx_buf)
546}
547
548pub fn icmp_socket<'a>(rx_mem: Memory, tx_mem: Memory) -> icmp::Socket<'a> {
549    let rx_buf =
550        icmp::PacketBuffer::new(meta_storage(META_STORAGE_SIZE), payload_storage(rx_mem.max));
551    let tx_buf =
552        icmp::PacketBuffer::new(meta_storage(META_STORAGE_SIZE), payload_storage(tx_mem.max));
553    icmp::Socket::new(rx_buf, tx_buf)
554}
555
556pub fn raw_socket<'a>(
557    ip_version: IpVersion,
558    ip_protocol: IpProtocol,
559    rx_mem: Memory,
560    tx_mem: Memory,
561) -> raw::Socket<'a> {
562    let rx_buf =
563        raw::PacketBuffer::new(meta_storage(META_STORAGE_SIZE), payload_storage(rx_mem.max));
564    let tx_buf =
565        raw::PacketBuffer::new(meta_storage(META_STORAGE_SIZE), payload_storage(tx_mem.max));
566    raw::Socket::new(ip_version, ip_protocol, rx_buf, tx_buf)
567}
568
569#[inline]
570fn meta_storage<H: Clone>(size: usize) -> Vec<PacketMetadata<H>> {
571    vec![PacketMetadata::EMPTY; size]
572}
573
/// Allocates a payload buffer of `size` default-initialized elements.
#[inline]
fn payload_storage<T: Default + Clone>(size: usize) -> Vec<T> {
    let mut storage = Vec::with_capacity(size);
    storage.resize(size, T::default());
    storage
}