1use std::{
2 io,
3 net::IpAddr,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use futures::{Sink, Stream};
9use smoltcp::wire::IpProtocol;
10use tokio::sync::mpsc::{channel, Receiver, Sender};
11use tracing::{debug, trace};
12
13use crate::{
14 filter::{IpFilter, IpFilters},
15 packet::{AnyIpPktFrame, IpPacket},
16 runner::Runner,
17 tcp::TcpListener,
18 udp::UdpSocket,
19};
20
21pub struct StackBuilder {
22 enable_udp: bool,
23 enable_tcp: bool,
24 enable_icmp: bool,
25 stack_buffer_size: usize,
26 udp_buffer_size: usize,
27 tcp_buffer_size: usize,
28 ip_filters: IpFilters<'static>,
29}
30
31impl Default for StackBuilder {
32 fn default() -> Self {
33 Self {
34 enable_udp: false,
35 enable_tcp: false,
36 enable_icmp: false,
37 stack_buffer_size: 1024,
38 udp_buffer_size: 512,
39 tcp_buffer_size: 512,
40 ip_filters: IpFilters::with_non_broadcast(),
41 }
42 }
43}
44
45#[allow(unused)]
46impl StackBuilder {
47 pub fn enable_udp(mut self, enable: bool) -> Self {
48 self.enable_udp = enable;
49 self
50 }
51
52 pub fn enable_tcp(mut self, enable: bool) -> Self {
53 self.enable_tcp = enable;
54 self
55 }
56
57 pub fn enable_icmp(mut self, enable: bool) -> Self {
58 self.enable_icmp = enable;
59 self
60 }
61
62 pub fn stack_buffer_size(mut self, size: usize) -> Self {
63 self.stack_buffer_size = size;
64 self
65 }
66
67 pub fn udp_buffer_size(mut self, size: usize) -> Self {
68 self.udp_buffer_size = size;
69 self
70 }
71
72 pub fn tcp_buffer_size(mut self, size: usize) -> Self {
73 self.tcp_buffer_size = size;
74 self
75 }
76
77 pub fn set_ip_filters(mut self, filters: IpFilters<'static>) -> Self {
78 self.ip_filters = filters;
79 self
80 }
81
82 pub fn add_ip_filter(mut self, filter: IpFilter<'static>) -> Self {
83 self.ip_filters.add(filter);
84 self
85 }
86
87 pub fn add_ip_filter_fn<F>(mut self, filter: F) -> Self
88 where
89 F: Fn(&IpAddr, &IpAddr) -> bool + Send + Sync + 'static,
90 {
91 self.ip_filters.add_fn(filter);
92 self
93 }
94
95 pub fn build(
96 self,
97 ) -> std::io::Result<(
98 Stack,
99 Option<Runner>,
100 Option<UdpSocket>,
101 Option<TcpListener>,
102 )> {
103 let (stack_tx, stack_rx) = channel(self.stack_buffer_size);
104
105 let (udp_tx, udp_rx) = if self.enable_udp {
106 let (udp_tx, udp_rx) = channel(self.udp_buffer_size);
107 (Some(udp_tx), Some(udp_rx))
108 } else {
109 (None, None)
110 };
111
112 let (tcp_tx, tcp_rx) = if self.enable_tcp {
113 let (tcp_tx, tcp_rx) = channel(self.tcp_buffer_size);
114 (Some(tcp_tx), Some(tcp_rx))
115 } else {
116 (None, None)
117 };
118
119 if self.enable_icmp && !self.enable_tcp {
122 return Err(std::io::Error::new(
123 std::io::ErrorKind::InvalidInput,
124 "Enabling icmp requires enabling tcp",
125 ));
126 }
127 let icmp_tx = if self.enable_icmp {
128 if let Some(ref tcp_tx) = tcp_tx {
129 Some(tcp_tx.clone())
130 } else {
131 None
132 }
133 } else {
134 None
135 };
136
137 let udp_socket = if let Some(udp_rx) = udp_rx {
138 Some(UdpSocket::new(udp_rx, stack_tx.clone()))
139 } else {
140 None
141 };
142
143 let (tcp_runner, tcp_listener) = if let Some(tcp_rx) = tcp_rx {
144 let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx);
145 (Some(tcp_runner), Some(tcp_listener))
146 } else {
147 (None, None)
148 };
149
150 let stack = Stack {
151 ip_filters: self.ip_filters,
152 stack_rx,
153 sink_buf: None,
154 udp_tx,
155 tcp_tx,
156 icmp_tx,
157 };
158
159 Ok((stack, tcp_runner, udp_socket, tcp_listener))
160 }
161}
162
163pub struct Stack {
164 ip_filters: IpFilters<'static>,
165 sink_buf: Option<(AnyIpPktFrame, IpProtocol)>,
166 udp_tx: Option<Sender<AnyIpPktFrame>>,
167 tcp_tx: Option<Sender<AnyIpPktFrame>>,
168 icmp_tx: Option<Sender<AnyIpPktFrame>>,
169 stack_rx: Receiver<AnyIpPktFrame>,
170}
171
172impl Stack {
173 fn poll_send(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
174 let (item, proto) = match self.sink_buf.take() {
175 Some(val) => val,
176 None => return Poll::Ready(Ok(())),
177 };
178
179 let ready_res = match proto {
180 IpProtocol::Tcp => self.tcp_tx.as_mut().map(|tx| tx.try_reserve()),
181 IpProtocol::Udp => self.udp_tx.as_mut().map(|tx| tx.try_reserve()),
182 IpProtocol::Icmp | IpProtocol::Icmpv6 => {
183 self.icmp_tx.as_mut().map(|tx| tx.try_reserve())
184 }
185 _ => unreachable!(),
186 };
187
188 let Some(ready_res) = ready_res else {
189 return Poll::Ready(Ok(()));
190 };
191
192 let permit = match ready_res {
193 Ok(permit) => permit,
194 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
195 self.sink_buf.replace((item, proto));
196 return Poll::Pending;
197 }
198 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
199 return Poll::Ready(Err(channel_closed_err("channel is closed")));
200 }
201 };
202
203 permit.send(item);
204 Poll::Ready(Ok(()))
205 }
206}
207
208impl Stream for Stack {
210 type Item = io::Result<AnyIpPktFrame>;
211
212 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
213 match self.stack_rx.poll_recv(cx) {
214 Poll::Ready(Some(pkt)) => Poll::Ready(Some(Ok(pkt))),
215 Poll::Ready(None) => Poll::Ready(None),
216 Poll::Pending => Poll::Pending,
217 }
218 }
219}
220
221impl Sink<AnyIpPktFrame> for Stack {
223 type Error = io::Error;
224
225 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
226 if self.sink_buf.is_none() {
227 Poll::Ready(Ok(()))
228 } else {
229 Poll::Pending
230 }
231 }
232
233 fn start_send(mut self: Pin<&mut Self>, item: AnyIpPktFrame) -> Result<(), Self::Error> {
234 if item.is_empty() {
235 return Ok(());
236 }
237
238 let packet = IpPacket::new_checked(item.as_slice()).map_err(|err| {
239 io::Error::new(
240 io::ErrorKind::InvalidInput,
241 format!("invalid IP packet: {}", err),
242 )
243 })?;
244
245 let src_ip = packet.src_addr();
246 let dst_ip = packet.dst_addr();
247
248 let addr_allowed = self.ip_filters.is_allowed(&src_ip, &dst_ip);
249 if !addr_allowed {
250 trace!(
251 "IP packet {} -> {} (allowed? {}) throwing away",
252 src_ip,
253 dst_ip,
254 addr_allowed,
255 );
256 return Ok(());
257 }
258
259 let protocol = packet.protocol();
260 if matches!(
261 protocol,
262 IpProtocol::Tcp | IpProtocol::Udp | IpProtocol::Icmp | IpProtocol::Icmpv6
263 ) {
264 self.sink_buf.replace((item, protocol));
265 } else {
266 debug!("tun IP packet ignored (protocol: {:?})", protocol);
267 }
268
269 Ok(())
270 }
271
272 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
273 self.poll_send(cx)
274 }
275
276 fn poll_close(
277 mut self: Pin<&mut Self>,
278 _cx: &mut Context<'_>,
279 ) -> Poll<Result<(), Self::Error>> {
280 self.stack_rx.close();
281 Poll::Ready(Ok(()))
282 }
283}
284
/// Wraps `err` in an [`io::Error`] of kind [`io::ErrorKind::BrokenPipe`].
///
/// Used when the receiving half of a protocol channel has been closed.
fn channel_closed_err<E: Into<Box<dyn std::error::Error + Send + Sync>>>(err: E) -> io::Error {
    let kind = io::ErrorKind::BrokenPipe;
    io::Error::new(kind, err)
}