// netstack_smoltcp/stack.rs

1use std::{
2    io,
3    net::IpAddr,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use futures::{Sink, Stream};
9use smoltcp::wire::IpProtocol;
10use tokio::sync::mpsc::{channel, Receiver, Sender};
11use tracing::{debug, trace};
12
13use crate::{
14    filter::{IpFilter, IpFilters},
15    packet::{AnyIpPktFrame, IpPacket},
16    runner::Runner,
17    tcp::TcpListener,
18    udp::UdpSocket,
19};
20
21pub struct StackBuilder {
22    enable_udp: bool,
23    enable_tcp: bool,
24    enable_icmp: bool,
25    stack_buffer_size: usize,
26    udp_buffer_size: usize,
27    tcp_buffer_size: usize,
28    ip_filters: IpFilters<'static>,
29}
30
31impl Default for StackBuilder {
32    fn default() -> Self {
33        Self {
34            enable_udp: false,
35            enable_tcp: false,
36            enable_icmp: false,
37            stack_buffer_size: 1024,
38            udp_buffer_size: 512,
39            tcp_buffer_size: 512,
40            ip_filters: IpFilters::with_non_broadcast(),
41        }
42    }
43}
44
45#[allow(unused)]
46impl StackBuilder {
47    pub fn enable_udp(mut self, enable: bool) -> Self {
48        self.enable_udp = enable;
49        self
50    }
51
52    pub fn enable_tcp(mut self, enable: bool) -> Self {
53        self.enable_tcp = enable;
54        self
55    }
56
57    pub fn enable_icmp(mut self, enable: bool) -> Self {
58        self.enable_icmp = enable;
59        self
60    }
61
62    pub fn stack_buffer_size(mut self, size: usize) -> Self {
63        self.stack_buffer_size = size;
64        self
65    }
66
67    pub fn udp_buffer_size(mut self, size: usize) -> Self {
68        self.udp_buffer_size = size;
69        self
70    }
71
72    pub fn tcp_buffer_size(mut self, size: usize) -> Self {
73        self.tcp_buffer_size = size;
74        self
75    }
76
77    pub fn set_ip_filters(mut self, filters: IpFilters<'static>) -> Self {
78        self.ip_filters = filters;
79        self
80    }
81
82    pub fn add_ip_filter(mut self, filter: IpFilter<'static>) -> Self {
83        self.ip_filters.add(filter);
84        self
85    }
86
87    pub fn add_ip_filter_fn<F>(mut self, filter: F) -> Self
88    where
89        F: Fn(&IpAddr, &IpAddr) -> bool + Send + Sync + 'static,
90    {
91        self.ip_filters.add_fn(filter);
92        self
93    }
94
95    pub fn build(
96        self,
97    ) -> std::io::Result<(
98        Stack,
99        Option<Runner>,
100        Option<UdpSocket>,
101        Option<TcpListener>,
102    )> {
103        let (stack_tx, stack_rx) = channel(self.stack_buffer_size);
104
105        let (udp_tx, udp_rx) = if self.enable_udp {
106            let (udp_tx, udp_rx) = channel(self.udp_buffer_size);
107            (Some(udp_tx), Some(udp_rx))
108        } else {
109            (None, None)
110        };
111
112        let (tcp_tx, tcp_rx) = if self.enable_tcp {
113            let (tcp_tx, tcp_rx) = channel(self.tcp_buffer_size);
114            (Some(tcp_tx), Some(tcp_rx))
115        } else {
116            (None, None)
117        };
118
119        // ICMP is handled by TCP's Interface.
120        // smoltcp's interface will always send replies to EchoRequest
121        if self.enable_icmp && !self.enable_tcp {
122            return Err(std::io::Error::new(
123                std::io::ErrorKind::InvalidInput,
124                "Enabling icmp requires enabling tcp",
125            ));
126        }
127        let icmp_tx = if self.enable_icmp {
128            if let Some(ref tcp_tx) = tcp_tx {
129                Some(tcp_tx.clone())
130            } else {
131                None
132            }
133        } else {
134            None
135        };
136
137        let udp_socket = if let Some(udp_rx) = udp_rx {
138            Some(UdpSocket::new(udp_rx, stack_tx.clone()))
139        } else {
140            None
141        };
142
143        let (tcp_runner, tcp_listener) = if let Some(tcp_rx) = tcp_rx {
144            let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx);
145            (Some(tcp_runner), Some(tcp_listener))
146        } else {
147            (None, None)
148        };
149
150        let stack = Stack {
151            ip_filters: self.ip_filters,
152            stack_rx,
153            sink_buf: None,
154            udp_tx,
155            tcp_tx,
156            icmp_tx,
157        };
158
159        Ok((stack, tcp_runner, udp_socket, tcp_listener))
160    }
161}
162
/// User-facing end of the network stack.
///
/// As a `Stream` it yields packets produced by the stack. As a `Sink` it accepts raw
/// IP packets and dispatches them to the per-protocol channels.
// Field order is also drop order; keep it unless that is intended to change.
pub struct Stack {
    /// Filters checked against each packet's source/destination in `start_send`.
    ip_filters: IpFilters<'static>,
    /// Packet accepted by `start_send` but not yet delivered, with its protocol.
    sink_buf: Option<(AnyIpPktFrame, IpProtocol)>,
    /// Sender to the UDP socket; `None` when UDP is disabled.
    udp_tx: Option<Sender<AnyIpPktFrame>>,
    /// Sender to the TCP listener; `None` when TCP is disabled.
    tcp_tx: Option<Sender<AnyIpPktFrame>>,
    /// Sender for ICMP / ICMPv6; when set, it is a clone of `tcp_tx`.
    icmp_tx: Option<Sender<AnyIpPktFrame>>,
    /// Receives the packets sent by the UDP socket and the TCP listener.
    stack_rx: Receiver<AnyIpPktFrame>,
}
171
172impl Stack {
173    fn poll_send(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
174        let (item, proto) = match self.sink_buf.take() {
175            Some(val) => val,
176            None => return Poll::Ready(Ok(())),
177        };
178
179        let ready_res = match proto {
180            IpProtocol::Tcp => self.tcp_tx.as_mut().map(|tx| tx.try_reserve()),
181            IpProtocol::Udp => self.udp_tx.as_mut().map(|tx| tx.try_reserve()),
182            IpProtocol::Icmp | IpProtocol::Icmpv6 => {
183                self.icmp_tx.as_mut().map(|tx| tx.try_reserve())
184            }
185            _ => unreachable!(),
186        };
187
188        let Some(ready_res) = ready_res else {
189            return Poll::Ready(Ok(()));
190        };
191
192        let permit = match ready_res {
193            Ok(permit) => permit,
194            Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
195                self.sink_buf.replace((item, proto));
196                return Poll::Pending;
197            }
198            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
199                return Poll::Ready(Err(channel_closed_err("channel is closed")));
200            }
201        };
202
203        permit.send(item);
204        Poll::Ready(Ok(()))
205    }
206}
207
208// Recv from stack.
209impl Stream for Stack {
210    type Item = io::Result<AnyIpPktFrame>;
211
212    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
213        match self.stack_rx.poll_recv(cx) {
214            Poll::Ready(Some(pkt)) => Poll::Ready(Some(Ok(pkt))),
215            Poll::Ready(None) => Poll::Ready(None),
216            Poll::Pending => Poll::Pending,
217        }
218    }
219}
220
221// Send to stack.
222impl Sink<AnyIpPktFrame> for Stack {
223    type Error = io::Error;
224
225    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
226        if self.sink_buf.is_none() {
227            Poll::Ready(Ok(()))
228        } else {
229            Poll::Pending
230        }
231    }
232
233    fn start_send(mut self: Pin<&mut Self>, item: AnyIpPktFrame) -> Result<(), Self::Error> {
234        if item.is_empty() {
235            return Ok(());
236        }
237
238        let packet = IpPacket::new_checked(item.as_slice()).map_err(|err| {
239            io::Error::new(
240                io::ErrorKind::InvalidInput,
241                format!("invalid IP packet: {}", err),
242            )
243        })?;
244
245        let src_ip = packet.src_addr();
246        let dst_ip = packet.dst_addr();
247
248        let addr_allowed = self.ip_filters.is_allowed(&src_ip, &dst_ip);
249        if !addr_allowed {
250            trace!(
251                "IP packet {} -> {} (allowed? {}) throwing away",
252                src_ip,
253                dst_ip,
254                addr_allowed,
255            );
256            return Ok(());
257        }
258
259        let protocol = packet.protocol();
260        if matches!(
261            protocol,
262            IpProtocol::Tcp | IpProtocol::Udp | IpProtocol::Icmp | IpProtocol::Icmpv6
263        ) {
264            self.sink_buf.replace((item, protocol));
265        } else {
266            debug!("tun IP packet ignored (protocol: {:?})", protocol);
267        }
268
269        Ok(())
270    }
271
272    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
273        self.poll_send(cx)
274    }
275
276    fn poll_close(
277        mut self: Pin<&mut Self>,
278        _cx: &mut Context<'_>,
279    ) -> Poll<Result<(), Self::Error>> {
280        self.stack_rx.close();
281        Poll::Ready(Ok(()))
282    }
283}
284
/// Wraps `err` in an [`io::Error`] of kind [`io::ErrorKind::BrokenPipe`].
///
/// Used to report that a channel's receiving side has gone away.
fn channel_closed_err<E: Into<Box<dyn std::error::Error + Send + Sync>>>(err: E) -> io::Error {
    let kind = io::ErrorKind::BrokenPipe;
    io::Error::new(kind, err)
}