netstack_smoltcp/
stack.rs

1use std::{
2    net::IpAddr,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use futures::{Sink, Stream};
8use smoltcp::wire::IpProtocol;
9use tokio::sync::mpsc::{channel, Receiver, Sender};
10use tracing::{debug, trace};
11
12use crate::{
13    filter::{IpFilter, IpFilters},
14    packet::{AnyIpPktFrame, IpPacket},
15    runner::Runner,
16    tcp::TcpListener,
17    udp::UdpSocket,
18};
19
20pub struct StackBuilder {
21    enable_udp: bool,
22    enable_tcp: bool,
23    enable_icmp: bool,
24    stack_buffer_size: usize,
25    udp_buffer_size: usize,
26    tcp_buffer_size: usize,
27    ip_filters: IpFilters<'static>,
28}
29
30impl Default for StackBuilder {
31    fn default() -> Self {
32        Self {
33            enable_udp: false,
34            enable_tcp: false,
35            enable_icmp: false,
36            stack_buffer_size: 1024,
37            udp_buffer_size: 512,
38            tcp_buffer_size: 512,
39            ip_filters: IpFilters::with_non_broadcast(),
40        }
41    }
42}
43
44#[allow(unused)]
45impl StackBuilder {
46    pub fn enable_udp(mut self, enable: bool) -> Self {
47        self.enable_udp = enable;
48        self
49    }
50
51    pub fn enable_tcp(mut self, enable: bool) -> Self {
52        self.enable_tcp = enable;
53        self
54    }
55
56    pub fn enable_icmp(mut self, enable: bool) -> Self {
57        self.enable_icmp = enable;
58        self
59    }
60
61    pub fn stack_buffer_size(mut self, size: usize) -> Self {
62        self.stack_buffer_size = size;
63        self
64    }
65
66    pub fn udp_buffer_size(mut self, size: usize) -> Self {
67        self.udp_buffer_size = size;
68        self
69    }
70
71    pub fn tcp_buffer_size(mut self, size: usize) -> Self {
72        self.tcp_buffer_size = size;
73        self
74    }
75
76    pub fn set_ip_filters(mut self, filters: IpFilters<'static>) -> Self {
77        self.ip_filters = filters;
78        self
79    }
80
81    pub fn add_ip_filter(mut self, filter: IpFilter<'static>) -> Self {
82        self.ip_filters.add(filter);
83        self
84    }
85
86    pub fn add_ip_filter_fn<F>(mut self, filter: F) -> Self
87    where
88        F: Fn(&IpAddr, &IpAddr) -> bool + Send + Sync + 'static,
89    {
90        self.ip_filters.add_fn(filter);
91        self
92    }
93
94    #[allow(clippy::type_complexity)]
95    pub fn build(
96        self,
97    ) -> std::io::Result<(
98        Stack,
99        Option<Runner>,
100        Option<UdpSocket>,
101        Option<TcpListener>,
102    )> {
103        let (stack_tx, stack_rx) = channel(self.stack_buffer_size);
104
105        let (udp_tx, udp_rx) = if self.enable_udp {
106            let (udp_tx, udp_rx) = channel(self.udp_buffer_size);
107            (Some(udp_tx), Some(udp_rx))
108        } else {
109            (None, None)
110        };
111
112        let (tcp_tx, tcp_rx) = if self.enable_tcp {
113            let (tcp_tx, tcp_rx) = channel(self.tcp_buffer_size);
114            (Some(tcp_tx), Some(tcp_rx))
115        } else {
116            (None, None)
117        };
118
119        // ICMP is handled by TCP's Interface.
120        // smoltcp's interface will always send replies to EchoRequest
121        if self.enable_icmp && !self.enable_tcp {
122            use std::io::{Error, ErrorKind::InvalidInput};
123            return Err(Error::new(InvalidInput, "ICMP requires TCP"));
124        }
125        let icmp_tx = if self.enable_icmp {
126            tcp_tx.clone()
127        } else {
128            None
129        };
130
131        let udp_socket = udp_rx.map(|udp_rx| UdpSocket::new(udp_rx, stack_tx.clone()));
132
133        let (tcp_runner, tcp_listener) = if let Some(tcp_rx) = tcp_rx {
134            let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx)?;
135            (Some(tcp_runner), Some(tcp_listener))
136        } else {
137            (None, None)
138        };
139
140        let stack = Stack {
141            ip_filters: self.ip_filters,
142            stack_rx,
143            sink_buf: None,
144            udp_tx,
145            tcp_tx,
146            icmp_tx,
147        };
148
149        Ok((stack, tcp_runner, udp_socket, tcp_listener))
150    }
151}
152
153pub struct Stack {
154    ip_filters: IpFilters<'static>,
155    sink_buf: Option<(AnyIpPktFrame, IpProtocol)>,
156    udp_tx: Option<Sender<AnyIpPktFrame>>,
157    tcp_tx: Option<Sender<AnyIpPktFrame>>,
158    icmp_tx: Option<Sender<AnyIpPktFrame>>,
159    stack_rx: Receiver<AnyIpPktFrame>,
160}
161
162impl Stack {
163    fn poll_send(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
164        let (item, proto) = match self.sink_buf.take() {
165            Some(val) => val,
166            None => return Poll::Ready(Ok(())),
167        };
168
169        let ready_res = match proto {
170            IpProtocol::Tcp => self.tcp_tx.as_mut().map(|tx| tx.try_reserve()),
171            IpProtocol::Udp => self.udp_tx.as_mut().map(|tx| tx.try_reserve()),
172            IpProtocol::Icmp | IpProtocol::Icmpv6 => {
173                self.icmp_tx.as_mut().map(|tx| tx.try_reserve())
174            }
175            _ => unreachable!(),
176        };
177
178        let Some(ready_res) = ready_res else {
179            return Poll::Ready(Ok(()));
180        };
181
182        let permit = match ready_res {
183            Ok(permit) => permit,
184            Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
185                self.sink_buf.replace((item, proto));
186                return Poll::Pending;
187            }
188            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
189                return Poll::Ready(Err(channel_closed_err("channel is closed")));
190            }
191        };
192
193        permit.send(item);
194        Poll::Ready(Ok(()))
195    }
196}
197
198// Recv from stack.
199impl Stream for Stack {
200    type Item = std::io::Result<AnyIpPktFrame>;
201
202    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
203        match self.stack_rx.poll_recv(cx) {
204            Poll::Ready(Some(pkt)) => Poll::Ready(Some(Ok(pkt))),
205            Poll::Ready(None) => Poll::Ready(None),
206            Poll::Pending => Poll::Pending,
207        }
208    }
209}
210
211// Send to stack.
212impl Sink<AnyIpPktFrame> for Stack {
213    type Error = std::io::Error;
214
215    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216        if self.sink_buf.is_none() {
217            Poll::Ready(Ok(()))
218        } else {
219            Poll::Pending
220        }
221    }
222
223    fn start_send(mut self: Pin<&mut Self>, item: AnyIpPktFrame) -> Result<(), Self::Error> {
224        if item.is_empty() {
225            return Ok(());
226        }
227
228        use std::io::{Error, ErrorKind::InvalidInput};
229        let packet = IpPacket::new_checked(item.as_slice())
230            .map_err(|err| Error::new(InvalidInput, format!("invalid IP packet: {err}")))?;
231
232        let src_ip = packet.src_addr();
233        let dst_ip = packet.dst_addr();
234
235        let addr_allowed = self.ip_filters.is_allowed(&src_ip, &dst_ip);
236        if !addr_allowed {
237            trace!("IP packet {src_ip} -> {dst_ip} (allowed? {addr_allowed}) throwing away",);
238            return Ok(());
239        }
240
241        let protocol = packet.protocol();
242        if matches!(
243            protocol,
244            IpProtocol::Tcp | IpProtocol::Udp | IpProtocol::Icmp | IpProtocol::Icmpv6
245        ) {
246            self.sink_buf.replace((item, protocol));
247        } else {
248            debug!("tun IP packet ignored (protocol: {:?})", protocol);
249        }
250
251        Ok(())
252    }
253
254    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
255        self.poll_send(cx)
256    }
257
258    fn poll_close(
259        mut self: Pin<&mut Self>,
260        _cx: &mut Context<'_>,
261    ) -> Poll<Result<(), Self::Error>> {
262        self.stack_rx.close();
263        Poll::Ready(Ok(()))
264    }
265}
266
/// Wraps `err` in an I/O error of kind `BrokenPipe`, used when a protocol channel is closed.
fn channel_closed_err<E>(err: E) -> std::io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    use std::io::{Error, ErrorKind};

    let kind = ErrorKind::BrokenPipe;
    Error::new(kind, err)
}