1use std::io;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
3use std::sync::Arc;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use bytes::Bytes;
8use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Address, Ipv4Packet, Ipv6Address, Ipv6Packet};
9use tokio::sync::mpsc;
10
11use crate::tcp::TcpConnection;
12use crate::{buffer::BufferPool, udp::UdpTunnel};
13use crate::{debug, error};
14
/// Default value of `NetStackConfig::ipv4_addr`: `10.0.0.1`.
const DEFAULT_IPV4_ADDR: Ipv4Address = Ipv4Address::new(10, 0, 0, 1);
/// Default value of `NetStackConfig::ipv6_addr`: `0:fac::1`.
// NOTE(review): the leading `0x0` segment makes this `0:fac::1`, which is not in a
// unique-local range such as `fd00::/8` — confirm this is the intended address.
const DEFAULT_IPV6_ADDR: Ipv6Address = Ipv6Address::new(0x0, 0xfac, 0, 0, 0, 0, 0, 1);
17
/// An IPv4 or IPv6 packet view over a buffer of type `T`, as produced by
/// [`IpPacket::new_checked`].
#[derive(Debug)]
pub enum IpPacket<T: AsRef<[u8]>> {
    /// An IPv4 packet.
    Ipv4(Ipv4Packet<T>),
    /// An IPv6 packet.
    Ipv6(Ipv6Packet<T>),
}
23
24impl<T: AsRef<[u8]> + Copy> IpPacket<T> {
25 pub fn new_checked(packet: T) -> smoltcp::wire::Result<IpPacket<T>> {
26 let buffer = packet.as_ref();
27 match IpVersion::of_packet(buffer)? {
28 IpVersion::Ipv4 => Ok(IpPacket::Ipv4(Ipv4Packet::new_checked(packet)?)),
29 IpVersion::Ipv6 => Ok(IpPacket::Ipv6(Ipv6Packet::new_checked(packet)?)),
30 }
31 }
32
33 pub fn src_addr(&self) -> IpAddr {
34 match *self {
35 IpPacket::Ipv4(ref packet) => IpAddr::from(packet.src_addr()),
36 IpPacket::Ipv6(ref packet) => IpAddr::from(packet.src_addr()),
37 }
38 }
39
40 pub fn dst_addr(&self) -> IpAddr {
41 match *self {
42 IpPacket::Ipv4(ref packet) => IpAddr::from(packet.dst_addr()),
43 IpPacket::Ipv6(ref packet) => IpAddr::from(packet.dst_addr()),
44 }
45 }
46
47 pub fn protocol(&self) -> IpProtocol {
48 match *self {
49 IpPacket::Ipv4(ref packet) => packet.next_header(),
50 IpPacket::Ipv6(ref packet) => packet.next_header(),
51 }
52 }
53}
54
55impl<'a, T: AsRef<[u8]> + ?Sized> IpPacket<&'a T> {
56 #[inline]
57 pub fn payload(&self) -> &'a [u8] {
58 match *self {
59 IpPacket::Ipv4(ref packet) => packet.payload(),
60 IpPacket::Ipv6(ref packet) => packet.payload(),
61 }
62 }
63}
64
/// Tunables for [`NetStack::new`] and the TCP/UDP handlers it constructs.
///
/// Start from [`NetStackConfig::default`] and override individual fields as needed.
#[derive(Clone, Debug)]
pub struct NetStackConfig {
    /// Maximum transmission unit (default `1500`).
    /// NOTE(review): not read in this file; presumably bytes per IP packet — confirm.
    pub mtu: usize,
    /// Capacity of each of the three `mpsc` channels created by `NetStack::new`
    /// (default `4096`).
    pub channel_size: usize,
    /// Number of workers (default: `std::thread::available_parallelism()`, or `4`).
    /// NOTE(review): consumed outside this file — confirm what the workers are.
    pub number_workers: usize,

    /// TCP send buffer size (default `16 * 1024`). Presumably bytes per connection — TODO confirm.
    pub tcp_send_buffer_size: usize,
    /// TCP receive buffer size (default `16 * 1024`). Presumably bytes per connection — TODO confirm.
    pub tcp_recv_buffer_size: usize,

    /// First argument to `BufferPool::new` (default `32`); presumably the number of
    /// pooled buffers — confirm against `BufferPool`.
    pub buffer_pool_size: usize,
    /// Second argument to `BufferPool::new` (default `2 * 1024`); presumably the
    /// initial size of each pooled buffer in bytes — confirm against `BufferPool`.
    pub buffer_pool_default_buffer_size: usize,

    /// IPv4 address used by the stack (default `10.0.0.1`).
    pub ipv4_addr: Ipv4Addr,
    /// Prefix length paired with `ipv4_addr` (default `24`).
    pub ipv4_prefix_len: u8,
    /// IPv6 address used by the stack (default `0:fac::1`).
    pub ipv6_addr: Ipv6Addr,
    /// Prefix length paired with `ipv6_addr` (default `64`).
    pub ipv6_prefix_len: u8,

    /// TCP keep-alive setting (default 28 s). NOTE(review): interval vs. idle time is
    /// decided by the TCP handler outside this file — confirm.
    pub tcp_keep_alive: Duration,
    /// TCP timeout (default 60 s). NOTE(review): semantics defined outside this file.
    pub tcp_timeout: Duration,
    /// Batch size for packet processing (default `32`). NOTE(review): consumed outside this file.
    pub packet_batch_size: usize,
    /// TTL / hop limit for generated IP packets (default `64`), presumably — confirm.
    pub ip_ttl: u8,
}
87
88impl Default for NetStackConfig {
89 fn default() -> Self {
90 Self {
91 mtu: 1500,
92 channel_size: 4096,
93 number_workers: std::thread::available_parallelism().map_or(4, |n| n.get()),
94 tcp_send_buffer_size: 16 * 1024,
95 tcp_recv_buffer_size: 16 * 1024,
96 buffer_pool_size: 32,
97 buffer_pool_default_buffer_size: 2 * 1024,
98 ipv4_addr: DEFAULT_IPV4_ADDR,
99 ipv4_prefix_len: 24,
100 ipv6_addr: DEFAULT_IPV6_ADDR,
101 ipv6_prefix_len: 64,
102 tcp_timeout: Duration::from_secs(60),
103 tcp_keep_alive: Duration::from_secs(28),
104 packet_batch_size: 32,
105 ip_ttl: 64,
106 }
107 }
108}
109
/// Owner of the stack-side channel endpoints created by [`NetStack::new`].
///
/// Call [`NetStack::split`] to obtain the inbound packet sink and the outbound packet stream.
pub struct NetStack {
    /// Sender whose receiver was handed to the `UdpTunnel` returned by `NetStack::new`.
    udp_inbound: mpsc::Sender<Packet>,
    /// Sender whose receiver was handed to the `TcpConnection` returned by `NetStack::new`.
    tcp_inbound: mpsc::Sender<Packet>,
    /// Receives packets sent by both the `TcpConnection` and the `UdpTunnel`.
    packet_outbound: mpsc::Receiver<Packet>,
}
115
/// A raw packet buffer passed through the stack's channels.
pub struct Packet {
    /// The packet contents.
    data: Bytes,
}
119
120impl Packet {
121 pub fn new(data: impl Into<Bytes>) -> Self {
122 Packet { data: data.into() }
123 }
124
125 pub fn data(&self) -> &[u8] {
126 &self.data
127 }
128
129 pub fn into_bytes(self) -> Bytes {
130 self.data
131 }
132}
133
134impl<T> From<T> for Packet
135where
136 T: Into<Bytes>,
137{
138 fn from(data: T) -> Self {
139 Packet::new(data)
140 }
141}
142
143impl NetStack {
144 pub fn new(config: NetStackConfig) -> (Self, TcpConnection, UdpTunnel) {
145 let (packet_sender, packet_receiver) = mpsc::channel::<Packet>(config.channel_size);
146 let (udp_inbound_app, udp_outbound_stack) = mpsc::channel::<Packet>(config.channel_size);
147 let (tcp_inbound_app, tcp_outbound_stack) = mpsc::channel::<Packet>(config.channel_size);
148 let buffer_pool = Arc::new(BufferPool::new(
149 config.buffer_pool_size,
150 config.buffer_pool_default_buffer_size,
151 ));
152
153 (
154 NetStack {
155 udp_inbound: udp_inbound_app,
156 tcp_inbound: tcp_inbound_app,
157 packet_outbound: packet_receiver,
158 },
159 TcpConnection::new(
160 config.clone(),
161 tcp_outbound_stack,
162 packet_sender.clone(),
163 buffer_pool.clone(),
164 ),
165 UdpTunnel::new(
166 config.into(),
167 udp_outbound_stack,
168 packet_sender.clone(),
169 buffer_pool.clone(),
170 ),
171 )
172 }
173
174 pub fn split(self) -> (StackSplitSink, StackSplitStream) {
175 (
176 StackSplitSink::new(self.udp_inbound, self.tcp_inbound),
177 StackSplitStream::new(self.packet_outbound),
178 )
179 }
180}
181
182pub struct StackSplitSink {
183 udp_inbound: mpsc::Sender<Packet>,
184 tcp_inbound: mpsc::Sender<Packet>,
185 packet_container: Option<(Packet, IpProtocol)>,
186}
187
188impl StackSplitSink {
189 pub fn new(udp_inbound: mpsc::Sender<Packet>, tcp_inbound: mpsc::Sender<Packet>) -> Self {
190 Self {
191 udp_inbound,
192 tcp_inbound,
193 packet_container: None,
194 }
195 }
196}
197
198impl futures::Sink<Packet> for StackSplitSink {
199 type Error = io::Error;
200
201 fn poll_ready(
202 mut self: std::pin::Pin<&mut Self>,
203 cx: &mut Context<'_>,
204 ) -> Poll<Result<(), Self::Error>> {
205 if self.packet_container.is_some() {
206 match self.as_mut().poll_flush(cx) {
207 Poll::Ready(Ok(())) => {}
208 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
209 Poll::Pending => return Poll::Pending,
210 }
211 }
212 Poll::Ready(Ok(()))
213 }
214
215 fn start_send(mut self: std::pin::Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
216 if item.data().is_empty() {
217 return Ok(());
218 }
219
220 let packet = IpPacket::new_checked(item.data())
221 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
222
223 let protocol = packet.protocol();
224 if matches!(
225 protocol,
226 IpProtocol::Tcp | IpProtocol::Udp | IpProtocol::Icmp | IpProtocol::Icmpv6
227 ) {
228 self.packet_container.replace((item, protocol));
229 } else {
230 error!("IP packet ignored protocol: {protocol:?}");
231 }
232
233 Ok(())
234 }
235
236 fn poll_flush(
237 mut self: std::pin::Pin<&mut Self>,
238 cx: &mut Context<'_>,
239 ) -> Poll<Result<(), Self::Error>> {
240 let (item, proto) = match self.packet_container.take() {
241 Some(val) => val,
242 None => return Poll::Ready(Ok(())),
243 };
244
245 let sender = match proto {
246 IpProtocol::Udp => self.udp_inbound.clone(),
247 IpProtocol::Tcp | IpProtocol::Icmp | IpProtocol::Icmpv6 => self.tcp_inbound.clone(),
248 _ => {
249 error!("Unsupported protocol for packet: {proto:?}");
250 return Poll::Ready(Ok(()));
251 }
252 };
253 let mut fut = Box::pin(sender.reserve());
254
255 match fut.as_mut().poll(cx) {
256 Poll::Ready(Ok(permit)) => {
257 permit.send(item);
258 Poll::Ready(Ok(()))
259 }
260 Poll::Ready(Err(_)) => {
261 let msg = format!("Failed to send packet: channel closed for protocol {proto:?}");
262 debug!("{}", msg);
263 Poll::Ready(Err(io::Error::new(io::ErrorKind::BrokenPipe, msg)))
264 }
265 Poll::Pending => {
266 self.packet_container = Some((item, proto));
267 Poll::Pending
268 }
269 }
270 }
271
272 fn poll_close(
273 self: std::pin::Pin<&mut Self>,
274 _cx: &mut Context<'_>,
275 ) -> Poll<Result<(), Self::Error>> {
276 Poll::Ready(Ok(()))
277 }
278}
279
280pub struct StackSplitStream {
281 packet_outbound: mpsc::Receiver<Packet>,
282}
283
284impl StackSplitStream {
285 pub fn new(packet_outbound: mpsc::Receiver<Packet>) -> Self {
286 Self { packet_outbound }
287 }
288}
289
290impl futures::Stream for StackSplitStream {
291 type Item = io::Result<Packet>;
292
293 fn poll_next(
294 mut self: std::pin::Pin<&mut Self>,
295 cx: &mut Context<'_>,
296 ) -> Poll<Option<Self::Item>> {
297 match self.packet_outbound.poll_recv(cx) {
298 Poll::Ready(Some(packet)) => Poll::Ready(Some(Ok(packet))),
299 Poll::Ready(None) => Poll::Ready(None),
300 Poll::Pending => Poll::Pending,
301 }
302 }
303}