1use std::{
2 net::IpAddr,
3 pin::Pin,
4 task::{ready, Context, Poll},
5};
6
7use futures::{Sink, Stream};
8use smoltcp::wire::IpProtocol;
9use tokio::sync::mpsc::{channel, Receiver};
10use tokio_util::sync::PollSender;
11use tracing::{debug, trace};
12
13use crate::{
14 filter::{IpFilter, IpFilters},
15 packet::{AnyIpPktFrame, IpPacket},
16 runner::Runner,
17 tcp::TcpListener,
18 udp::UdpSocket,
19};
20
/// Builder that configures and assembles a [`Stack`] together with its
/// optional TCP and UDP endpoints.
///
/// Every option has a default (see the `Default` impl). The setter methods
/// take and return the builder by value so they can be chained before `build`.
pub struct StackBuilder {
    /// When `true`, `build` creates the UDP channel and returns a `UdpSocket`.
    enable_udp: bool,
    /// When `true`, `build` creates the TCP channel and returns a `Runner`
    /// and a `TcpListener`.
    enable_tcp: bool,
    /// When `true`, ICMP/ICMPv6 packets are sent through a clone of the TCP
    /// sender; `build` rejects this flag unless `enable_tcp` is also set.
    enable_icmp: bool,
    /// Capacity of the channel whose receiving half backs the `Stack` stream.
    stack_buffer_size: usize,
    /// Capacity of the channel feeding the `UdpSocket`.
    udp_buffer_size: usize,
    /// Capacity of the channel feeding the `TcpListener`.
    tcp_buffer_size: usize,
    /// MTU value handed to `TcpListener::new`.
    mtu: usize,
    /// Address filters moved into the built `Stack`.
    ip_filters: IpFilters<'static>,
}
31
32impl Default for StackBuilder {
33 fn default() -> Self {
34 Self {
35 enable_udp: false,
36 enable_tcp: false,
37 enable_icmp: false,
38 stack_buffer_size: 1024,
39 udp_buffer_size: 512,
40 tcp_buffer_size: 512,
41 mtu: 1504, ip_filters: IpFilters::with_non_broadcast(),
43 }
44 }
45}
46
47#[allow(unused)]
48impl StackBuilder {
49 pub fn enable_udp(mut self, enable: bool) -> Self {
50 self.enable_udp = enable;
51 self
52 }
53
54 pub fn enable_tcp(mut self, enable: bool) -> Self {
55 self.enable_tcp = enable;
56 self
57 }
58
59 pub fn enable_icmp(mut self, enable: bool) -> Self {
60 self.enable_icmp = enable;
61 self
62 }
63
64 pub fn stack_buffer_size(mut self, size: usize) -> Self {
65 self.stack_buffer_size = size;
66 self
67 }
68
69 pub fn udp_buffer_size(mut self, size: usize) -> Self {
70 self.udp_buffer_size = size;
71 self
72 }
73
74 pub fn tcp_buffer_size(mut self, size: usize) -> Self {
75 self.tcp_buffer_size = size;
76 self
77 }
78
79 pub fn set_ip_filters(mut self, filters: IpFilters<'static>) -> Self {
80 self.ip_filters = filters;
81 self
82 }
83
84 pub fn add_ip_filter(mut self, filter: IpFilter<'static>) -> Self {
85 self.ip_filters.add(filter);
86 self
87 }
88
89 pub fn add_ip_filter_fn<F>(mut self, filter: F) -> Self
90 where
91 F: Fn(&IpAddr, &IpAddr) -> bool + Send + Sync + 'static,
92 {
93 self.ip_filters.add_fn(filter);
94 self
95 }
96
97 pub fn mtu(mut self, mtu: usize) -> Self {
98 self.mtu = mtu;
99 self
100 }
101
102 #[allow(clippy::type_complexity)]
103 pub fn build(
104 self,
105 ) -> std::io::Result<(
106 Stack,
107 Option<Runner>,
108 Option<UdpSocket>,
109 Option<TcpListener>,
110 )> {
111 let (stack_tx, stack_rx) = channel(self.stack_buffer_size);
112
113 let (udp_tx, udp_rx) = if self.enable_udp {
114 let (udp_tx, udp_rx) = channel(self.udp_buffer_size);
115 (Some(PollSender::new(udp_tx)), Some(udp_rx))
116 } else {
117 (None, None)
118 };
119
120 let (tcp_tx, tcp_rx) = if self.enable_tcp {
121 let (tcp_tx, tcp_rx) = channel(self.tcp_buffer_size);
122 (Some(PollSender::new(tcp_tx)), Some(tcp_rx))
123 } else {
124 (None, None)
125 };
126
127 if self.enable_icmp && !self.enable_tcp {
130 use std::io::{Error, ErrorKind::InvalidInput};
131 return Err(Error::new(InvalidInput, "ICMP requires TCP"));
132 }
133 let icmp_tx = if self.enable_icmp {
134 tcp_tx.clone()
135 } else {
136 None
137 };
138
139 let udp_socket = udp_rx.map(|udp_rx| UdpSocket::new(udp_rx, stack_tx.clone()));
140
141 let (tcp_runner, tcp_listener) = if let Some(tcp_rx) = tcp_rx {
142 let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx, self.mtu)?;
143 (Some(tcp_runner), Some(tcp_listener))
144 } else {
145 (None, None)
146 };
147
148 let stack = Stack {
149 ip_filters: self.ip_filters,
150 stack_rx,
151 sink_buf: None,
152 udp_tx,
153 tcp_tx,
154 icmp_tx,
155 };
156
157 Ok((stack, tcp_runner, udp_socket, tcp_listener))
158 }
159}
160
/// Packet-level front end of the network stack.
///
/// Implements `Sink<AnyIpPktFrame>`, which dispatches accepted packets to the
/// per-protocol channels, and `Stream`, which yields the frames received on
/// `stack_rx`.
pub struct Stack {
    /// Address filters consulted by `start_send` before a packet is accepted.
    ip_filters: IpFilters<'static>,
    /// Single-slot buffer: the packet accepted by `start_send`, with its
    /// protocol, until `poll_send` hands it to the matching channel.
    sink_buf: Option<(AnyIpPktFrame, IpProtocol)>,
    /// Sender for UDP packets; `None` when UDP was not enabled.
    udp_tx: Option<PollSender<AnyIpPktFrame>>,
    /// Sender for TCP packets; `None` when TCP was not enabled.
    tcp_tx: Option<PollSender<AnyIpPktFrame>>,
    /// Sender for ICMP and ICMPv6 packets; `None` when ICMP was not enabled.
    /// `StackBuilder::build` makes it a clone of `tcp_tx`.
    icmp_tx: Option<PollSender<AnyIpPktFrame>>,
    /// Receiving half of the channel whose senders `StackBuilder::build`
    /// passes to `UdpSocket::new` and `TcpListener::new`.
    stack_rx: Receiver<AnyIpPktFrame>,
}
169
170impl Stack {
171 fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
172 let (item, proto) = match self.sink_buf.take() {
173 Some(val) => val,
174 None => return Poll::Ready(Ok(())),
175 };
176
177 let tx = match proto {
178 IpProtocol::Tcp => self.tcp_tx.as_mut(),
179 IpProtocol::Udp => self.udp_tx.as_mut(),
180 IpProtocol::Icmp | IpProtocol::Icmpv6 => self.icmp_tx.as_mut(),
181 _ => unreachable!(),
182 };
183
184 let Some(tx) = tx else {
185 return Poll::Ready(Ok(()));
186 };
187
188 match tx.poll_reserve(cx) {
189 Poll::Pending => {
190 self.sink_buf = Some((item, proto));
191 Poll::Pending
192 }
193 Poll::Ready(Err(_)) => Poll::Ready(Err(channel_closed_err("channel is closed"))),
194 Poll::Ready(Ok(_)) => match tx.send_item(item) {
195 Ok(()) => Poll::Ready(Ok(())),
196 Err(_) => Poll::Ready(Err(channel_closed_err("channel is closed"))),
197 },
198 }
199 }
200}
201
202impl Stream for Stack {
204 type Item = std::io::Result<AnyIpPktFrame>;
205
206 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
207 match self.stack_rx.poll_recv(cx) {
208 Poll::Ready(Some(pkt)) => Poll::Ready(Some(Ok(pkt))),
209 Poll::Ready(None) => Poll::Ready(None),
210 Poll::Pending => Poll::Pending,
211 }
212 }
213}
214
215impl Sink<AnyIpPktFrame> for Stack {
217 type Error = std::io::Error;
218
219 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
220 if self.sink_buf.is_some() {
225 ready!(self.poll_send(cx))?;
226 }
227 Poll::Ready(Ok(()))
228 }
229
230 fn start_send(mut self: Pin<&mut Self>, item: AnyIpPktFrame) -> Result<(), Self::Error> {
231 if item.is_empty() {
232 return Ok(());
233 }
234
235 use std::io::{Error, ErrorKind::InvalidInput};
236 let packet = IpPacket::new_checked(item.as_slice())
237 .map_err(|err| Error::new(InvalidInput, format!("invalid IP packet: {err}")))?;
238
239 let src_ip = packet.src_addr();
240 let dst_ip = packet.dst_addr();
241
242 let addr_allowed = self.ip_filters.is_allowed(&src_ip, &dst_ip);
243 if !addr_allowed {
244 trace!("IP packet {src_ip} -> {dst_ip} (allowed? {addr_allowed}) throwing away",);
245 return Ok(());
246 }
247
248 let protocol = packet.protocol();
249 if matches!(
250 protocol,
251 IpProtocol::Tcp | IpProtocol::Udp | IpProtocol::Icmp | IpProtocol::Icmpv6
252 ) {
253 self.sink_buf.replace((item, protocol));
254 } else {
255 debug!("tun IP packet ignored (protocol: {:?})", protocol);
256 }
257
258 Ok(())
259 }
260
261 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
262 self.poll_send(cx)
263 }
264
265 fn poll_close(
266 mut self: Pin<&mut Self>,
267 _cx: &mut Context<'_>,
268 ) -> Poll<Result<(), Self::Error>> {
269 self.stack_rx.close();
270 Poll::Ready(Ok(()))
271 }
272}
273
/// Wraps `err` in a [`std::io::Error`] of kind `BrokenPipe`.
///
/// `err` may be anything convertible into a boxed error, e.g. a string
/// message.
fn channel_closed_err<E>(err: E) -> std::io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    use std::io::{Error, ErrorKind};

    let payload: Box<dyn std::error::Error + Send + Sync> = err.into();
    Error::new(ErrorKind::BrokenPipe, payload)
}