ya_relay_stack/
stack.rs

1use std::cell::RefCell;
2use std::collections::HashMap;
3use std::rc::Rc;
4
5use smoltcp::iface::{Route, SocketHandle};
6use smoltcp::socket::*;
7use smoltcp::time::Instant;
8use smoltcp::wire::{IpAddress, IpCidr, IpEndpoint, IpProtocol, IpVersion};
9
10use crate::connection::{Connect, Connection, ConnectionMeta, Disconnect, Send};
11use crate::interface::*;
12use crate::metrics::ChannelMetrics;
13use crate::patch_smoltcp::GetSocketSafe;
14use crate::protocol::Protocol;
15use crate::socket::*;
16use crate::{port, StackConfig};
17use crate::{Error, Result};
18
19use ya_relay_util::Payload;
20
21#[derive(Clone)]
22pub struct Stack<'a> {
23    iface: Rc<RefCell<CaptureInterface<'a>>>,
24    metrics: Rc<RefCell<HashMap<SocketDesc, ChannelMetrics>>>,
25    ports: Rc<RefCell<port::Allocator>>,
26    config: Rc<StackConfig>,
27}
28
29impl<'a> Stack<'a> {
30    pub fn new(iface: CaptureInterface<'a>, config: Rc<StackConfig>) -> Self {
31        Self {
32            iface: Rc::new(RefCell::new(iface)),
33            metrics: Default::default(),
34            ports: Default::default(),
35            config,
36        }
37    }
38
39    pub fn address(&self) -> Result<IpCidr> {
40        {
41            let iface = self.iface.borrow();
42            iface.inner().ip_addrs().iter().next().cloned()
43        }
44        .ok_or(Error::NetEmpty)
45    }
46
47    pub fn addresses(&self) -> Vec<IpCidr> {
48        self.iface.borrow().inner().ip_addrs().to_vec()
49    }
50
51    pub fn add_address(&self, address: IpCidr) {
52        let mut iface = self.iface.borrow_mut();
53        add_iface_address(&mut iface, address);
54    }
55
56    pub fn add_route(&self, net_ip: IpCidr, route: Route) {
57        let mut iface = self.iface.borrow_mut();
58        add_iface_route(&mut iface, net_ip, route);
59    }
60
61    #[inline]
62    pub(crate) fn iface(&self) -> Rc<RefCell<CaptureInterface<'a>>> {
63        self.iface.clone()
64    }
65
66    #[inline]
67    pub(crate) fn metrics(&self) -> Rc<RefCell<HashMap<SocketDesc, ChannelMetrics>>> {
68        self.metrics.clone()
69    }
70
71    pub(crate) fn on_sent(&self, desc: &SocketDesc, size: usize) {
72        let mut metrics = self.metrics.borrow_mut();
73        metrics.entry(*desc).or_default().tx.push(size as f32);
74    }
75
76    pub(crate) fn on_received(&self, desc: &SocketDesc, size: usize) {
77        let mut metrics = self.metrics.borrow_mut();
78        metrics.entry(*desc).or_default().rx.push(size as f32);
79    }
80}
81
82impl<'a> Stack<'a> {
83    pub fn bind(
84        &self,
85        protocol: Protocol,
86        endpoint: impl Into<SocketEndpoint>,
87    ) -> Result<SocketHandle> {
88        let endpoint = endpoint.into();
89        let mut iface = self.iface.borrow_mut();
90
91        let handle = match protocol {
92            Protocol::Tcp => {
93                if let SocketEndpoint::Ip(ep) = endpoint {
94                    let mut socket = tcp_socket(self.config.tcp_mem.rx, self.config.tcp_mem.tx);
95                    socket.listen(ep).map_err(|e| Error::Other(e.to_string()))?;
96                    socket.set_defaults();
97                    iface.add_socket(socket)
98                } else {
99                    return Err(Error::Other("Expected an IP endpoint".to_string()));
100                }
101            }
102            Protocol::Udp => {
103                if let SocketEndpoint::Ip(ep) = endpoint {
104                    let mut socket = udp_socket(self.config.udp_mem.rx, self.config.udp_mem.tx);
105                    socket.bind(ep).map_err(|e| Error::Other(e.to_string()))?;
106                    iface.add_socket(socket)
107                } else {
108                    return Err(Error::Other("Expected an IP endpoint".to_string()));
109                }
110            }
111            Protocol::Icmp | Protocol::Ipv6Icmp => {
112                if let SocketEndpoint::Icmp(e) = endpoint {
113                    let mut socket = icmp_socket(self.config.icmp_mem.rx, self.config.icmp_mem.tx);
114                    socket.bind(e).map_err(|e| Error::Other(e.to_string()))?;
115                    iface.add_socket(socket)
116                } else {
117                    return Err(Error::Other("Expected an ICMP endpoint".to_string()));
118                }
119            }
120            _ => {
121                let ip_version = {
122                    match endpoint {
123                        SocketEndpoint::Ip(ep) => match ep.addr {
124                            IpAddress::Ipv4(_) => IpVersion::Ipv4,
125                            IpAddress::Ipv6(_) => IpVersion::Ipv6,
126                        },
127                        _ => return Err(Error::Other("Expected an IP endpoint".to_string())),
128                    }
129                };
130
131                let socket = raw_socket(
132                    ip_version,
133                    map_protocol(protocol)?,
134                    self.config.raw_mem.rx,
135                    self.config.raw_mem.tx,
136                );
137                iface.add_socket(socket)
138            }
139        };
140
141        Ok(handle)
142    }
143
144    pub fn unbind(
145        &self,
146        protocol: Protocol,
147        endpoint: impl Into<SocketEndpoint>,
148    ) -> Result<SocketHandle> {
149        let endpoint = endpoint.into();
150        log::trace!("Unbinding {protocol:?} {endpoint}");
151        let mut iface = self.iface.borrow_mut();
152        let mut sockets = iface.sockets_mut();
153
154        let handle = sockets
155            .find(|(_, s)| s.local_endpoint() == endpoint)
156            .and_then(|(h, _)| match protocol {
157                Protocol::Tcp | Protocol::Udp | Protocol::Icmp | Protocol::Ipv6Icmp => Some(h),
158                _ => None,
159            })
160            .ok_or(Error::SocketClosed)?;
161
162        let _ = endpoint.ip_endpoint().map(|e| {
163            let mut ports = self.ports.borrow_mut();
164            ports.free(protocol, e.port);
165        });
166
167        drop(sockets);
168        iface.remove_socket(handle);
169        Ok(handle)
170    }
171
172    pub fn connect(&self, remote: IpEndpoint) -> Result<Connect<'a>> {
173        let ip = self.address()?.address();
174
175        let mut iface = self.iface.borrow_mut();
176        let mut ports = self.ports.borrow_mut();
177
178        let protocol = Protocol::Tcp;
179        let handle = iface.add_socket(tcp_socket(self.config.tcp_mem.rx, self.config.tcp_mem.tx));
180        let port = ports.next(protocol)?;
181        let local: IpEndpoint = (ip, port).into();
182
183        match {
184            let (socket, ctx) = iface.get_socket_and_context::<tcp::Socket>(handle);
185            socket.connect(ctx, remote, local).map(|_| socket)
186        } {
187            Ok(socket) => socket.set_defaults(),
188            Err(e) => {
189                iface.remove_socket(handle);
190                ports.free(Protocol::Tcp, port);
191                return Err(Error::ConnectionError(e.to_string()));
192            }
193        }
194
195        let meta = ConnectionMeta {
196            protocol,
197            local,
198            remote,
199        };
200        Ok(Connect::new(
201            Connection { handle, meta },
202            self.iface.clone(),
203        ))
204    }
205
206    pub fn disconnect(&self, handle: SocketHandle) -> Disconnect<'a> {
207        let mut iface = self.iface.borrow_mut();
208        if let Ok(sock) = iface.get_socket_safe::<tcp::Socket>(handle) {
209            log::trace!("Disconnecting. Socket handle: {handle}.");
210            sock.close();
211        }
212        Disconnect::new(handle, self.iface.clone())
213    }
214
215    pub(crate) fn abort(&self, handle: SocketHandle) {
216        let mut iface = self.iface.borrow_mut();
217        if let Ok(sock) = iface.get_socket_safe::<tcp::Socket>(handle) {
218            log::trace!("Aborting. Socket handle: {handle}.");
219            sock.abort();
220        }
221    }
222
223    pub(crate) fn remove(&self, meta: &ConnectionMeta, handle: SocketHandle) {
224        let mut iface = self.iface.borrow_mut();
225        let mut metrics = self.metrics.borrow_mut();
226        let mut ports = self.ports.borrow_mut();
227
228        if let Some((handle, socket)) = {
229            let mut sockets = iface.sockets();
230            sockets.find(|(h, _)| h == &handle)
231        } {
232            log::trace!("Removing connection: {meta}. Socket handle: {handle}");
233
234            metrics.remove(&socket.desc());
235            iface.remove_socket(handle);
236            ports.free(meta.protocol, meta.local.port);
237        }
238    }
239
240    #[inline]
241    pub fn send<B: Into<Payload>, F: Fn() + 'static>(
242        &self,
243        data: B,
244        conn: Connection,
245        f: F,
246    ) -> Send<'a> {
247        Send::new(data.into(), conn, self.iface.clone(), f)
248    }
249
250    #[inline]
251    pub fn receive<B: Into<Payload>>(&self, data: B) {
252        let mut iface = self.iface.borrow_mut();
253        iface.device_mut().phy_rx(data.into());
254    }
255
256    #[inline]
257    pub fn poll(&self) -> bool {
258        let mut iface = self.iface.borrow_mut();
259        iface.poll(Instant::now())
260    }
261}
262
263fn map_protocol(protocol: Protocol) -> Result<IpProtocol> {
264    match protocol {
265        Protocol::HopByHop => Ok(IpProtocol::HopByHop),
266        Protocol::Icmp => Ok(IpProtocol::Icmp),
267        Protocol::Igmp => Ok(IpProtocol::Igmp),
268        Protocol::Tcp => Ok(IpProtocol::Tcp),
269        Protocol::Udp => Ok(IpProtocol::Udp),
270        Protocol::Ipv6Route => Ok(IpProtocol::Ipv6Route),
271        Protocol::Ipv6Frag => Ok(IpProtocol::Ipv6Frag),
272        Protocol::Ipv6Icmp => Ok(IpProtocol::Icmpv6),
273        Protocol::Ipv6NoNxt => Ok(IpProtocol::Ipv6NoNxt),
274        Protocol::Ipv6Opts => Ok(IpProtocol::Ipv6Opts),
275        _ => Err(Error::ProtocolNotSupported(protocol.to_string())),
276    }
277}