1use std::{
2 net::IpAddr,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use futures::{Sink, Stream};
8use smoltcp::wire::IpProtocol;
9use tokio::sync::mpsc::{channel, Receiver, Sender};
10use tracing::{debug, trace};
11
12use crate::{
13 filter::{IpFilter, IpFilters},
14 packet::{AnyIpPktFrame, IpPacket},
15 runner::Runner,
16 tcp::TcpListener,
17 udp::UdpSocket,
18};
19
20pub struct StackBuilder {
21 enable_udp: bool,
22 enable_tcp: bool,
23 enable_icmp: bool,
24 stack_buffer_size: usize,
25 udp_buffer_size: usize,
26 tcp_buffer_size: usize,
27 ip_filters: IpFilters<'static>,
28}
29
30impl Default for StackBuilder {
31 fn default() -> Self {
32 Self {
33 enable_udp: false,
34 enable_tcp: false,
35 enable_icmp: false,
36 stack_buffer_size: 1024,
37 udp_buffer_size: 512,
38 tcp_buffer_size: 512,
39 ip_filters: IpFilters::with_non_broadcast(),
40 }
41 }
42}
43
44#[allow(unused)]
45impl StackBuilder {
46 pub fn enable_udp(mut self, enable: bool) -> Self {
47 self.enable_udp = enable;
48 self
49 }
50
51 pub fn enable_tcp(mut self, enable: bool) -> Self {
52 self.enable_tcp = enable;
53 self
54 }
55
56 pub fn enable_icmp(mut self, enable: bool) -> Self {
57 self.enable_icmp = enable;
58 self
59 }
60
61 pub fn stack_buffer_size(mut self, size: usize) -> Self {
62 self.stack_buffer_size = size;
63 self
64 }
65
66 pub fn udp_buffer_size(mut self, size: usize) -> Self {
67 self.udp_buffer_size = size;
68 self
69 }
70
71 pub fn tcp_buffer_size(mut self, size: usize) -> Self {
72 self.tcp_buffer_size = size;
73 self
74 }
75
76 pub fn set_ip_filters(mut self, filters: IpFilters<'static>) -> Self {
77 self.ip_filters = filters;
78 self
79 }
80
81 pub fn add_ip_filter(mut self, filter: IpFilter<'static>) -> Self {
82 self.ip_filters.add(filter);
83 self
84 }
85
86 pub fn add_ip_filter_fn<F>(mut self, filter: F) -> Self
87 where
88 F: Fn(&IpAddr, &IpAddr) -> bool + Send + Sync + 'static,
89 {
90 self.ip_filters.add_fn(filter);
91 self
92 }
93
94 #[allow(clippy::type_complexity)]
95 pub fn build(
96 self,
97 ) -> std::io::Result<(
98 Stack,
99 Option<Runner>,
100 Option<UdpSocket>,
101 Option<TcpListener>,
102 )> {
103 let (stack_tx, stack_rx) = channel(self.stack_buffer_size);
104
105 let (udp_tx, udp_rx) = if self.enable_udp {
106 let (udp_tx, udp_rx) = channel(self.udp_buffer_size);
107 (Some(udp_tx), Some(udp_rx))
108 } else {
109 (None, None)
110 };
111
112 let (tcp_tx, tcp_rx) = if self.enable_tcp {
113 let (tcp_tx, tcp_rx) = channel(self.tcp_buffer_size);
114 (Some(tcp_tx), Some(tcp_rx))
115 } else {
116 (None, None)
117 };
118
119 if self.enable_icmp && !self.enable_tcp {
122 use std::io::{Error, ErrorKind::InvalidInput};
123 return Err(Error::new(InvalidInput, "ICMP requires TCP"));
124 }
125 let icmp_tx = if self.enable_icmp {
126 tcp_tx.clone()
127 } else {
128 None
129 };
130
131 let udp_socket = udp_rx.map(|udp_rx| UdpSocket::new(udp_rx, stack_tx.clone()));
132
133 let (tcp_runner, tcp_listener) = if let Some(tcp_rx) = tcp_rx {
134 let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx)?;
135 (Some(tcp_runner), Some(tcp_listener))
136 } else {
137 (None, None)
138 };
139
140 let stack = Stack {
141 ip_filters: self.ip_filters,
142 stack_rx,
143 sink_buf: None,
144 udp_tx,
145 tcp_tx,
146 icmp_tx,
147 };
148
149 Ok((stack, tcp_runner, udp_socket, tcp_listener))
150 }
151}
152
153pub struct Stack {
154 ip_filters: IpFilters<'static>,
155 sink_buf: Option<(AnyIpPktFrame, IpProtocol)>,
156 udp_tx: Option<Sender<AnyIpPktFrame>>,
157 tcp_tx: Option<Sender<AnyIpPktFrame>>,
158 icmp_tx: Option<Sender<AnyIpPktFrame>>,
159 stack_rx: Receiver<AnyIpPktFrame>,
160}
161
162impl Stack {
163 fn poll_send(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
164 let (item, proto) = match self.sink_buf.take() {
165 Some(val) => val,
166 None => return Poll::Ready(Ok(())),
167 };
168
169 let ready_res = match proto {
170 IpProtocol::Tcp => self.tcp_tx.as_mut().map(|tx| tx.try_reserve()),
171 IpProtocol::Udp => self.udp_tx.as_mut().map(|tx| tx.try_reserve()),
172 IpProtocol::Icmp | IpProtocol::Icmpv6 => {
173 self.icmp_tx.as_mut().map(|tx| tx.try_reserve())
174 }
175 _ => unreachable!(),
176 };
177
178 let Some(ready_res) = ready_res else {
179 return Poll::Ready(Ok(()));
180 };
181
182 let permit = match ready_res {
183 Ok(permit) => permit,
184 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
185 self.sink_buf.replace((item, proto));
186 return Poll::Pending;
187 }
188 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
189 return Poll::Ready(Err(channel_closed_err("channel is closed")));
190 }
191 };
192
193 permit.send(item);
194 Poll::Ready(Ok(()))
195 }
196}
197
198impl Stream for Stack {
200 type Item = std::io::Result<AnyIpPktFrame>;
201
202 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
203 match self.stack_rx.poll_recv(cx) {
204 Poll::Ready(Some(pkt)) => Poll::Ready(Some(Ok(pkt))),
205 Poll::Ready(None) => Poll::Ready(None),
206 Poll::Pending => Poll::Pending,
207 }
208 }
209}
210
211impl Sink<AnyIpPktFrame> for Stack {
213 type Error = std::io::Error;
214
215 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216 if self.sink_buf.is_none() {
217 Poll::Ready(Ok(()))
218 } else {
219 Poll::Pending
220 }
221 }
222
223 fn start_send(mut self: Pin<&mut Self>, item: AnyIpPktFrame) -> Result<(), Self::Error> {
224 if item.is_empty() {
225 return Ok(());
226 }
227
228 use std::io::{Error, ErrorKind::InvalidInput};
229 let packet = IpPacket::new_checked(item.as_slice())
230 .map_err(|err| Error::new(InvalidInput, format!("invalid IP packet: {err}")))?;
231
232 let src_ip = packet.src_addr();
233 let dst_ip = packet.dst_addr();
234
235 let addr_allowed = self.ip_filters.is_allowed(&src_ip, &dst_ip);
236 if !addr_allowed {
237 trace!("IP packet {src_ip} -> {dst_ip} (allowed? {addr_allowed}) throwing away",);
238 return Ok(());
239 }
240
241 let protocol = packet.protocol();
242 if matches!(
243 protocol,
244 IpProtocol::Tcp | IpProtocol::Udp | IpProtocol::Icmp | IpProtocol::Icmpv6
245 ) {
246 self.sink_buf.replace((item, protocol));
247 } else {
248 debug!("tun IP packet ignored (protocol: {:?})", protocol);
249 }
250
251 Ok(())
252 }
253
254 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
255 self.poll_send(cx)
256 }
257
258 fn poll_close(
259 mut self: Pin<&mut Self>,
260 _cx: &mut Context<'_>,
261 ) -> Poll<Result<(), Self::Error>> {
262 self.stack_rx.close();
263 Poll::Ready(Ok(()))
264 }
265}
266
/// Wraps `err` in an `std::io::Error` of kind `BrokenPipe`.
fn channel_closed_err<E>(err: E) -> std::io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    use std::io::{Error, ErrorKind};

    let kind = ErrorKind::BrokenPipe;
    Error::new(kind, err)
}