Skip to main content

netstack_smoltcp/
stack.rs

1use std::{
2    net::IpAddr,
3    pin::Pin,
4    task::{ready, Context, Poll},
5};
6
7use futures::{Sink, Stream};
8use smoltcp::wire::IpProtocol;
9use tokio::sync::mpsc::{channel, Receiver};
10use tokio_util::sync::PollSender;
11use tracing::{debug, trace};
12
13use crate::{
14    filter::{IpFilter, IpFilters},
15    packet::{AnyIpPktFrame, IpPacket},
16    runner::Runner,
17    tcp::TcpListener,
18    udp::UdpSocket,
19};
20
21pub struct StackBuilder {
22    enable_udp: bool,
23    enable_tcp: bool,
24    enable_icmp: bool,
25    stack_buffer_size: usize,
26    udp_buffer_size: usize,
27    tcp_buffer_size: usize,
28    mtu: usize,
29    ip_filters: IpFilters<'static>,
30}
31
32impl Default for StackBuilder {
33    fn default() -> Self {
34        Self {
35            enable_udp: false,
36            enable_tcp: false,
37            enable_icmp: false,
38            stack_buffer_size: 1024,
39            udp_buffer_size: 512,
40            tcp_buffer_size: 512,
41            mtu: 1504, // 1500 for Ethernet + 4 for VLAN
42            ip_filters: IpFilters::with_non_broadcast(),
43        }
44    }
45}
46
47#[allow(unused)]
48impl StackBuilder {
49    pub fn enable_udp(mut self, enable: bool) -> Self {
50        self.enable_udp = enable;
51        self
52    }
53
54    pub fn enable_tcp(mut self, enable: bool) -> Self {
55        self.enable_tcp = enable;
56        self
57    }
58
59    pub fn enable_icmp(mut self, enable: bool) -> Self {
60        self.enable_icmp = enable;
61        self
62    }
63
64    pub fn stack_buffer_size(mut self, size: usize) -> Self {
65        self.stack_buffer_size = size;
66        self
67    }
68
69    pub fn udp_buffer_size(mut self, size: usize) -> Self {
70        self.udp_buffer_size = size;
71        self
72    }
73
74    pub fn tcp_buffer_size(mut self, size: usize) -> Self {
75        self.tcp_buffer_size = size;
76        self
77    }
78
79    pub fn set_ip_filters(mut self, filters: IpFilters<'static>) -> Self {
80        self.ip_filters = filters;
81        self
82    }
83
84    pub fn add_ip_filter(mut self, filter: IpFilter<'static>) -> Self {
85        self.ip_filters.add(filter);
86        self
87    }
88
89    pub fn add_ip_filter_fn<F>(mut self, filter: F) -> Self
90    where
91        F: Fn(&IpAddr, &IpAddr) -> bool + Send + Sync + 'static,
92    {
93        self.ip_filters.add_fn(filter);
94        self
95    }
96
97    pub fn mtu(mut self, mtu: usize) -> Self {
98        self.mtu = mtu;
99        self
100    }
101
102    #[allow(clippy::type_complexity)]
103    pub fn build(
104        self,
105    ) -> std::io::Result<(
106        Stack,
107        Option<Runner>,
108        Option<UdpSocket>,
109        Option<TcpListener>,
110    )> {
111        let (stack_tx, stack_rx) = channel(self.stack_buffer_size);
112
113        let (udp_tx, udp_rx) = if self.enable_udp {
114            let (udp_tx, udp_rx) = channel(self.udp_buffer_size);
115            (Some(PollSender::new(udp_tx)), Some(udp_rx))
116        } else {
117            (None, None)
118        };
119
120        let (tcp_tx, tcp_rx) = if self.enable_tcp {
121            let (tcp_tx, tcp_rx) = channel(self.tcp_buffer_size);
122            (Some(PollSender::new(tcp_tx)), Some(tcp_rx))
123        } else {
124            (None, None)
125        };
126
127        // ICMP is handled by TCP's Interface.
128        // smoltcp's interface will always send replies to EchoRequest
129        if self.enable_icmp && !self.enable_tcp {
130            use std::io::{Error, ErrorKind::InvalidInput};
131            return Err(Error::new(InvalidInput, "ICMP requires TCP"));
132        }
133        let icmp_tx = if self.enable_icmp {
134            tcp_tx.clone()
135        } else {
136            None
137        };
138
139        let udp_socket = udp_rx.map(|udp_rx| UdpSocket::new(udp_rx, stack_tx.clone()));
140
141        let (tcp_runner, tcp_listener) = if let Some(tcp_rx) = tcp_rx {
142            let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx, self.mtu)?;
143            (Some(tcp_runner), Some(tcp_listener))
144        } else {
145            (None, None)
146        };
147
148        let stack = Stack {
149            ip_filters: self.ip_filters,
150            stack_rx,
151            sink_buf: None,
152            udp_tx,
153            tcp_tx,
154            icmp_tx,
155        };
156
157        Ok((stack, tcp_runner, udp_socket, tcp_listener))
158    }
159}
160
161pub struct Stack {
162    ip_filters: IpFilters<'static>,
163    sink_buf: Option<(AnyIpPktFrame, IpProtocol)>,
164    udp_tx: Option<PollSender<AnyIpPktFrame>>,
165    tcp_tx: Option<PollSender<AnyIpPktFrame>>,
166    icmp_tx: Option<PollSender<AnyIpPktFrame>>,
167    stack_rx: Receiver<AnyIpPktFrame>,
168}
169
170impl Stack {
171    fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
172        let (item, proto) = match self.sink_buf.take() {
173            Some(val) => val,
174            None => return Poll::Ready(Ok(())),
175        };
176
177        let tx = match proto {
178            IpProtocol::Tcp => self.tcp_tx.as_mut(),
179            IpProtocol::Udp => self.udp_tx.as_mut(),
180            IpProtocol::Icmp | IpProtocol::Icmpv6 => self.icmp_tx.as_mut(),
181            _ => unreachable!(),
182        };
183
184        let Some(tx) = tx else {
185            return Poll::Ready(Ok(()));
186        };
187
188        match tx.poll_reserve(cx) {
189            Poll::Pending => {
190                self.sink_buf = Some((item, proto));
191                Poll::Pending
192            }
193            Poll::Ready(Err(_)) => Poll::Ready(Err(channel_closed_err("channel is closed"))),
194            Poll::Ready(Ok(_)) => match tx.send_item(item) {
195                Ok(()) => Poll::Ready(Ok(())),
196                Err(_) => Poll::Ready(Err(channel_closed_err("channel is closed"))),
197            },
198        }
199    }
200}
201
202// Recv from stack.
203impl Stream for Stack {
204    type Item = std::io::Result<AnyIpPktFrame>;
205
206    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
207        match self.stack_rx.poll_recv(cx) {
208            Poll::Ready(Some(pkt)) => Poll::Ready(Some(Ok(pkt))),
209            Poll::Ready(None) => Poll::Ready(None),
210            Poll::Pending => Poll::Pending,
211        }
212    }
213}
214
215// Send to stack.
216impl Sink<AnyIpPktFrame> for Stack {
217    type Error = std::io::Error;
218
219    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
220        // If a buffered item exists, try to flush it first. This also properly
221        // registers the waker via poll_reserve so we get woken when the channel
222        // has capacity. Without this, returning Pending here with _cx unused
223        // means the task never gets rescheduled.
224        if self.sink_buf.is_some() {
225            ready!(self.poll_send(cx))?;
226        }
227        Poll::Ready(Ok(()))
228    }
229
230    fn start_send(mut self: Pin<&mut Self>, item: AnyIpPktFrame) -> Result<(), Self::Error> {
231        if item.is_empty() {
232            return Ok(());
233        }
234
235        use std::io::{Error, ErrorKind::InvalidInput};
236        let packet = IpPacket::new_checked(item.as_slice())
237            .map_err(|err| Error::new(InvalidInput, format!("invalid IP packet: {err}")))?;
238
239        let src_ip = packet.src_addr();
240        let dst_ip = packet.dst_addr();
241
242        let addr_allowed = self.ip_filters.is_allowed(&src_ip, &dst_ip);
243        if !addr_allowed {
244            trace!("IP packet {src_ip} -> {dst_ip} (allowed? {addr_allowed}) throwing away",);
245            return Ok(());
246        }
247
248        let protocol = packet.protocol();
249        if matches!(
250            protocol,
251            IpProtocol::Tcp | IpProtocol::Udp | IpProtocol::Icmp | IpProtocol::Icmpv6
252        ) {
253            self.sink_buf.replace((item, protocol));
254        } else {
255            debug!("tun IP packet ignored (protocol: {:?})", protocol);
256        }
257
258        Ok(())
259    }
260
261    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
262        self.poll_send(cx)
263    }
264
265    fn poll_close(
266        mut self: Pin<&mut Self>,
267        _cx: &mut Context<'_>,
268    ) -> Poll<Result<(), Self::Error>> {
269        self.stack_rx.close();
270        Poll::Ready(Ok(()))
271    }
272}
273
/// Builds a `BrokenPipe` I/O error wrapping `err`. Used when a packet channel
/// turns out to be closed.
fn channel_closed_err<E>(err: E) -> std::io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    use std::io::{Error, ErrorKind};
    Error::new(ErrorKind::BrokenPipe, err)
}