// ombrac_netstack/stack.rs

1use std::io;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
3use std::sync::Arc;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use bytes::Bytes;
8use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Address, Ipv4Packet, Ipv6Address, Ipv6Packet};
9use tokio::sync::mpsc;
10
11use crate::tcp::TcpConnection;
12use crate::{buffer::BufferPool, udp::UdpTunnel};
13use crate::{debug, error};
14
15const DEFAULT_IPV4_ADDR: Ipv4Address = Ipv4Address::new(10, 0, 0, 1);
16const DEFAULT_IPV6_ADDR: Ipv6Address = Ipv6Address::new(0x0, 0xfac, 0, 0, 0, 0, 0, 1);
17
18#[derive(Debug)]
19pub enum IpPacket<T: AsRef<[u8]>> {
20    Ipv4(Ipv4Packet<T>),
21    Ipv6(Ipv6Packet<T>),
22}
23
24impl<T: AsRef<[u8]> + Copy> IpPacket<T> {
25    pub fn new_checked(packet: T) -> smoltcp::wire::Result<IpPacket<T>> {
26        let buffer = packet.as_ref();
27        match IpVersion::of_packet(buffer)? {
28            IpVersion::Ipv4 => Ok(IpPacket::Ipv4(Ipv4Packet::new_checked(packet)?)),
29            IpVersion::Ipv6 => Ok(IpPacket::Ipv6(Ipv6Packet::new_checked(packet)?)),
30        }
31    }
32
33    pub fn src_addr(&self) -> IpAddr {
34        match *self {
35            IpPacket::Ipv4(ref packet) => IpAddr::from(packet.src_addr()),
36            IpPacket::Ipv6(ref packet) => IpAddr::from(packet.src_addr()),
37        }
38    }
39
40    pub fn dst_addr(&self) -> IpAddr {
41        match *self {
42            IpPacket::Ipv4(ref packet) => IpAddr::from(packet.dst_addr()),
43            IpPacket::Ipv6(ref packet) => IpAddr::from(packet.dst_addr()),
44        }
45    }
46
47    pub fn protocol(&self) -> IpProtocol {
48        match *self {
49            IpPacket::Ipv4(ref packet) => packet.next_header(),
50            IpPacket::Ipv6(ref packet) => packet.next_header(),
51        }
52    }
53}
54
55impl<'a, T: AsRef<[u8]> + ?Sized> IpPacket<&'a T> {
56    #[inline]
57    pub fn payload(&self) -> &'a [u8] {
58        match *self {
59            IpPacket::Ipv4(ref packet) => packet.payload(),
60            IpPacket::Ipv6(ref packet) => packet.payload(),
61        }
62    }
63}
64
65#[derive(Clone, Debug)]
66pub struct NetStackConfig {
67    pub mtu: usize,
68    pub channel_size: usize,
69    pub number_workers: usize,
70
71    pub tcp_send_buffer_size: usize,
72    pub tcp_recv_buffer_size: usize,
73
74    pub buffer_pool_size: usize,
75    pub buffer_pool_default_buffer_size: usize,
76
77    pub ipv4_addr: Ipv4Addr,
78    pub ipv4_prefix_len: u8,
79    pub ipv6_addr: Ipv6Addr,
80    pub ipv6_prefix_len: u8,
81
82    pub tcp_keep_alive: Duration,
83    pub tcp_timeout: Duration,
84    pub packet_batch_size: usize,
85    pub ip_ttl: u8,
86}
87
88impl Default for NetStackConfig {
89    fn default() -> Self {
90        Self {
91            mtu: 1500,
92            channel_size: 4096,
93            number_workers: std::thread::available_parallelism().map_or(4, |n| n.get()),
94            tcp_send_buffer_size: 16 * 1024,
95            tcp_recv_buffer_size: 16 * 1024,
96            buffer_pool_size: 32,
97            buffer_pool_default_buffer_size: 2 * 1024,
98            ipv4_addr: DEFAULT_IPV4_ADDR,
99            ipv4_prefix_len: 24,
100            ipv6_addr: DEFAULT_IPV6_ADDR,
101            ipv6_prefix_len: 64,
102            tcp_timeout: Duration::from_secs(60),
103            tcp_keep_alive: Duration::from_secs(28),
104            packet_batch_size: 32,
105            ip_ttl: 64,
106        }
107    }
108}
109
110pub struct NetStack {
111    udp_inbound: mpsc::Sender<Packet>,
112    tcp_inbound: mpsc::Sender<Packet>,
113    packet_outbound: mpsc::Receiver<Packet>,
114}
115
116pub struct Packet {
117    data: Bytes,
118}
119
120impl Packet {
121    pub fn new(data: impl Into<Bytes>) -> Self {
122        Packet { data: data.into() }
123    }
124
125    pub fn data(&self) -> &[u8] {
126        &self.data
127    }
128
129    pub fn into_bytes(self) -> Bytes {
130        self.data
131    }
132}
133
134impl<T> From<T> for Packet
135where
136    T: Into<Bytes>,
137{
138    fn from(data: T) -> Self {
139        Packet::new(data)
140    }
141}
142
143impl NetStack {
144    pub fn new(config: NetStackConfig) -> (Self, TcpConnection, UdpTunnel) {
145        let (packet_sender, packet_receiver) = mpsc::channel::<Packet>(config.channel_size);
146        let (udp_inbound_app, udp_outbound_stack) = mpsc::channel::<Packet>(config.channel_size);
147        let (tcp_inbound_app, tcp_outbound_stack) = mpsc::channel::<Packet>(config.channel_size);
148        let buffer_pool = Arc::new(BufferPool::new(
149            config.buffer_pool_size,
150            config.buffer_pool_default_buffer_size,
151        ));
152
153        (
154            NetStack {
155                udp_inbound: udp_inbound_app,
156                tcp_inbound: tcp_inbound_app,
157                packet_outbound: packet_receiver,
158            },
159            TcpConnection::new(
160                config.clone(),
161                tcp_outbound_stack,
162                packet_sender.clone(),
163                buffer_pool.clone(),
164            ),
165            UdpTunnel::new(
166                config.into(),
167                udp_outbound_stack,
168                packet_sender.clone(),
169                buffer_pool.clone(),
170            ),
171        )
172    }
173
174    pub fn split(self) -> (StackSplitSink, StackSplitStream) {
175        (
176            StackSplitSink::new(self.udp_inbound, self.tcp_inbound),
177            StackSplitStream::new(self.packet_outbound),
178        )
179    }
180}
181
182pub struct StackSplitSink {
183    udp_inbound: mpsc::Sender<Packet>,
184    tcp_inbound: mpsc::Sender<Packet>,
185    packet_container: Option<(Packet, IpProtocol)>,
186}
187
188impl StackSplitSink {
189    pub fn new(udp_inbound: mpsc::Sender<Packet>, tcp_inbound: mpsc::Sender<Packet>) -> Self {
190        Self {
191            udp_inbound,
192            tcp_inbound,
193            packet_container: None,
194        }
195    }
196}
197
198impl futures::Sink<Packet> for StackSplitSink {
199    type Error = io::Error;
200
201    fn poll_ready(
202        mut self: std::pin::Pin<&mut Self>,
203        cx: &mut Context<'_>,
204    ) -> Poll<Result<(), Self::Error>> {
205        if self.packet_container.is_some() {
206            match self.as_mut().poll_flush(cx) {
207                Poll::Ready(Ok(())) => {}
208                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
209                Poll::Pending => return Poll::Pending,
210            }
211        }
212        Poll::Ready(Ok(()))
213    }
214
215    fn start_send(mut self: std::pin::Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
216        if item.data().is_empty() {
217            return Ok(());
218        }
219
220        let packet = IpPacket::new_checked(item.data())
221            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
222
223        let protocol = packet.protocol();
224        if matches!(
225            protocol,
226            IpProtocol::Tcp | IpProtocol::Udp | IpProtocol::Icmp | IpProtocol::Icmpv6
227        ) {
228            self.packet_container.replace((item, protocol));
229        } else {
230            error!("IP packet ignored protocol: {protocol:?}");
231        }
232
233        Ok(())
234    }
235
236    fn poll_flush(
237        mut self: std::pin::Pin<&mut Self>,
238        cx: &mut Context<'_>,
239    ) -> Poll<Result<(), Self::Error>> {
240        let (item, proto) = match self.packet_container.take() {
241            Some(val) => val,
242            None => return Poll::Ready(Ok(())),
243        };
244
245        let sender = match proto {
246            IpProtocol::Udp => self.udp_inbound.clone(),
247            IpProtocol::Tcp | IpProtocol::Icmp | IpProtocol::Icmpv6 => self.tcp_inbound.clone(),
248            _ => {
249                error!("Unsupported protocol for packet: {proto:?}");
250                return Poll::Ready(Ok(()));
251            }
252        };
253        let mut fut = Box::pin(sender.reserve());
254
255        match fut.as_mut().poll(cx) {
256            Poll::Ready(Ok(permit)) => {
257                permit.send(item);
258                Poll::Ready(Ok(()))
259            }
260            Poll::Ready(Err(_)) => {
261                let msg = format!("Failed to send packet: channel closed for protocol {proto:?}");
262                debug!("{}", msg);
263                Poll::Ready(Err(io::Error::new(io::ErrorKind::BrokenPipe, msg)))
264            }
265            Poll::Pending => {
266                self.packet_container = Some((item, proto));
267                Poll::Pending
268            }
269        }
270    }
271
272    fn poll_close(
273        self: std::pin::Pin<&mut Self>,
274        _cx: &mut Context<'_>,
275    ) -> Poll<Result<(), Self::Error>> {
276        Poll::Ready(Ok(()))
277    }
278}
279
280pub struct StackSplitStream {
281    packet_outbound: mpsc::Receiver<Packet>,
282}
283
284impl StackSplitStream {
285    pub fn new(packet_outbound: mpsc::Receiver<Packet>) -> Self {
286        Self { packet_outbound }
287    }
288}
289
290impl futures::Stream for StackSplitStream {
291    type Item = io::Result<Packet>;
292
293    fn poll_next(
294        mut self: std::pin::Pin<&mut Self>,
295        cx: &mut Context<'_>,
296    ) -> Poll<Option<Self::Item>> {
297        match self.packet_outbound.poll_recv(cx) {
298            Poll::Ready(Some(packet)) => Poll::Ready(Some(Ok(packet))),
299            Poll::Ready(None) => Poll::Ready(None),
300            Poll::Pending => Poll::Pending,
301        }
302    }
303}