1use std::cell::RefCell;
2use std::collections::HashMap;
3use std::rc::Rc;
4
5use smoltcp::iface::{Route, SocketHandle};
6use smoltcp::socket::*;
7use smoltcp::time::Instant;
8use smoltcp::wire::{IpAddress, IpCidr, IpEndpoint, IpProtocol, IpVersion};
9
10use crate::connection::{Connect, Connection, ConnectionMeta, Disconnect, Send};
11use crate::interface::*;
12use crate::metrics::ChannelMetrics;
13use crate::patch_smoltcp::GetSocketSafe;
14use crate::protocol::Protocol;
15use crate::socket::*;
16use crate::{port, StackConfig};
17use crate::{Error, Result};
18
19use ya_relay_util::Payload;
20
21#[derive(Clone)]
22pub struct Stack<'a> {
23 iface: Rc<RefCell<CaptureInterface<'a>>>,
24 metrics: Rc<RefCell<HashMap<SocketDesc, ChannelMetrics>>>,
25 ports: Rc<RefCell<port::Allocator>>,
26 config: Rc<StackConfig>,
27}
28
29impl<'a> Stack<'a> {
30 pub fn new(iface: CaptureInterface<'a>, config: Rc<StackConfig>) -> Self {
31 Self {
32 iface: Rc::new(RefCell::new(iface)),
33 metrics: Default::default(),
34 ports: Default::default(),
35 config,
36 }
37 }
38
39 pub fn address(&self) -> Result<IpCidr> {
40 {
41 let iface = self.iface.borrow();
42 iface.inner().ip_addrs().iter().next().cloned()
43 }
44 .ok_or(Error::NetEmpty)
45 }
46
47 pub fn addresses(&self) -> Vec<IpCidr> {
48 self.iface.borrow().inner().ip_addrs().to_vec()
49 }
50
51 pub fn add_address(&self, address: IpCidr) {
52 let mut iface = self.iface.borrow_mut();
53 add_iface_address(&mut iface, address);
54 }
55
56 pub fn add_route(&self, net_ip: IpCidr, route: Route) {
57 let mut iface = self.iface.borrow_mut();
58 add_iface_route(&mut iface, net_ip, route);
59 }
60
61 #[inline]
62 pub(crate) fn iface(&self) -> Rc<RefCell<CaptureInterface<'a>>> {
63 self.iface.clone()
64 }
65
66 #[inline]
67 pub(crate) fn metrics(&self) -> Rc<RefCell<HashMap<SocketDesc, ChannelMetrics>>> {
68 self.metrics.clone()
69 }
70
71 pub(crate) fn on_sent(&self, desc: &SocketDesc, size: usize) {
72 let mut metrics = self.metrics.borrow_mut();
73 metrics.entry(*desc).or_default().tx.push(size as f32);
74 }
75
76 pub(crate) fn on_received(&self, desc: &SocketDesc, size: usize) {
77 let mut metrics = self.metrics.borrow_mut();
78 metrics.entry(*desc).or_default().rx.push(size as f32);
79 }
80}
81
82impl<'a> Stack<'a> {
83 pub fn bind(
84 &self,
85 protocol: Protocol,
86 endpoint: impl Into<SocketEndpoint>,
87 ) -> Result<SocketHandle> {
88 let endpoint = endpoint.into();
89 let mut iface = self.iface.borrow_mut();
90
91 let handle = match protocol {
92 Protocol::Tcp => {
93 if let SocketEndpoint::Ip(ep) = endpoint {
94 let mut socket = tcp_socket(self.config.tcp_mem.rx, self.config.tcp_mem.tx);
95 socket.listen(ep).map_err(|e| Error::Other(e.to_string()))?;
96 socket.set_defaults();
97 iface.add_socket(socket)
98 } else {
99 return Err(Error::Other("Expected an IP endpoint".to_string()));
100 }
101 }
102 Protocol::Udp => {
103 if let SocketEndpoint::Ip(ep) = endpoint {
104 let mut socket = udp_socket(self.config.udp_mem.rx, self.config.udp_mem.tx);
105 socket.bind(ep).map_err(|e| Error::Other(e.to_string()))?;
106 iface.add_socket(socket)
107 } else {
108 return Err(Error::Other("Expected an IP endpoint".to_string()));
109 }
110 }
111 Protocol::Icmp | Protocol::Ipv6Icmp => {
112 if let SocketEndpoint::Icmp(e) = endpoint {
113 let mut socket = icmp_socket(self.config.icmp_mem.rx, self.config.icmp_mem.tx);
114 socket.bind(e).map_err(|e| Error::Other(e.to_string()))?;
115 iface.add_socket(socket)
116 } else {
117 return Err(Error::Other("Expected an ICMP endpoint".to_string()));
118 }
119 }
120 _ => {
121 let ip_version = {
122 match endpoint {
123 SocketEndpoint::Ip(ep) => match ep.addr {
124 IpAddress::Ipv4(_) => IpVersion::Ipv4,
125 IpAddress::Ipv6(_) => IpVersion::Ipv6,
126 },
127 _ => return Err(Error::Other("Expected an IP endpoint".to_string())),
128 }
129 };
130
131 let socket = raw_socket(
132 ip_version,
133 map_protocol(protocol)?,
134 self.config.raw_mem.rx,
135 self.config.raw_mem.tx,
136 );
137 iface.add_socket(socket)
138 }
139 };
140
141 Ok(handle)
142 }
143
144 pub fn unbind(
145 &self,
146 protocol: Protocol,
147 endpoint: impl Into<SocketEndpoint>,
148 ) -> Result<SocketHandle> {
149 let endpoint = endpoint.into();
150 log::trace!("Unbinding {protocol:?} {endpoint}");
151 let mut iface = self.iface.borrow_mut();
152 let mut sockets = iface.sockets_mut();
153
154 let handle = sockets
155 .find(|(_, s)| s.local_endpoint() == endpoint)
156 .and_then(|(h, _)| match protocol {
157 Protocol::Tcp | Protocol::Udp | Protocol::Icmp | Protocol::Ipv6Icmp => Some(h),
158 _ => None,
159 })
160 .ok_or(Error::SocketClosed)?;
161
162 let _ = endpoint.ip_endpoint().map(|e| {
163 let mut ports = self.ports.borrow_mut();
164 ports.free(protocol, e.port);
165 });
166
167 drop(sockets);
168 iface.remove_socket(handle);
169 Ok(handle)
170 }
171
172 pub fn connect(&self, remote: IpEndpoint) -> Result<Connect<'a>> {
173 let ip = self.address()?.address();
174
175 let mut iface = self.iface.borrow_mut();
176 let mut ports = self.ports.borrow_mut();
177
178 let protocol = Protocol::Tcp;
179 let handle = iface.add_socket(tcp_socket(self.config.tcp_mem.rx, self.config.tcp_mem.tx));
180 let port = ports.next(protocol)?;
181 let local: IpEndpoint = (ip, port).into();
182
183 match {
184 let (socket, ctx) = iface.get_socket_and_context::<tcp::Socket>(handle);
185 socket.connect(ctx, remote, local).map(|_| socket)
186 } {
187 Ok(socket) => socket.set_defaults(),
188 Err(e) => {
189 iface.remove_socket(handle);
190 ports.free(Protocol::Tcp, port);
191 return Err(Error::ConnectionError(e.to_string()));
192 }
193 }
194
195 let meta = ConnectionMeta {
196 protocol,
197 local,
198 remote,
199 };
200 Ok(Connect::new(
201 Connection { handle, meta },
202 self.iface.clone(),
203 ))
204 }
205
206 pub fn disconnect(&self, handle: SocketHandle) -> Disconnect<'a> {
207 let mut iface = self.iface.borrow_mut();
208 if let Ok(sock) = iface.get_socket_safe::<tcp::Socket>(handle) {
209 log::trace!("Disconnecting. Socket handle: {handle}.");
210 sock.close();
211 }
212 Disconnect::new(handle, self.iface.clone())
213 }
214
215 pub(crate) fn abort(&self, handle: SocketHandle) {
216 let mut iface = self.iface.borrow_mut();
217 if let Ok(sock) = iface.get_socket_safe::<tcp::Socket>(handle) {
218 log::trace!("Aborting. Socket handle: {handle}.");
219 sock.abort();
220 }
221 }
222
223 pub(crate) fn remove(&self, meta: &ConnectionMeta, handle: SocketHandle) {
224 let mut iface = self.iface.borrow_mut();
225 let mut metrics = self.metrics.borrow_mut();
226 let mut ports = self.ports.borrow_mut();
227
228 if let Some((handle, socket)) = {
229 let mut sockets = iface.sockets();
230 sockets.find(|(h, _)| h == &handle)
231 } {
232 log::trace!("Removing connection: {meta}. Socket handle: {handle}");
233
234 metrics.remove(&socket.desc());
235 iface.remove_socket(handle);
236 ports.free(meta.protocol, meta.local.port);
237 }
238 }
239
240 #[inline]
241 pub fn send<B: Into<Payload>, F: Fn() + 'static>(
242 &self,
243 data: B,
244 conn: Connection,
245 f: F,
246 ) -> Send<'a> {
247 Send::new(data.into(), conn, self.iface.clone(), f)
248 }
249
250 #[inline]
251 pub fn receive<B: Into<Payload>>(&self, data: B) {
252 let mut iface = self.iface.borrow_mut();
253 iface.device_mut().phy_rx(data.into());
254 }
255
256 #[inline]
257 pub fn poll(&self) -> bool {
258 let mut iface = self.iface.borrow_mut();
259 iface.poll(Instant::now())
260 }
261}
262
263fn map_protocol(protocol: Protocol) -> Result<IpProtocol> {
264 match protocol {
265 Protocol::HopByHop => Ok(IpProtocol::HopByHop),
266 Protocol::Icmp => Ok(IpProtocol::Icmp),
267 Protocol::Igmp => Ok(IpProtocol::Igmp),
268 Protocol::Tcp => Ok(IpProtocol::Tcp),
269 Protocol::Udp => Ok(IpProtocol::Udp),
270 Protocol::Ipv6Route => Ok(IpProtocol::Ipv6Route),
271 Protocol::Ipv6Frag => Ok(IpProtocol::Ipv6Frag),
272 Protocol::Ipv6Icmp => Ok(IpProtocol::Icmpv6),
273 Protocol::Ipv6NoNxt => Ok(IpProtocol::Ipv6NoNxt),
274 Protocol::Ipv6Opts => Ok(IpProtocol::Ipv6Opts),
275 _ => Err(Error::ProtocolNotSupported(protocol.to_string())),
276 }
277}