ipstack/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use ahash::AHashMap;
4use packet::{NetworkPacket, NetworkTuple, TransportHeader};
5use std::{sync::Arc, time::Duration};
6use tokio::{
7    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
8    select,
9    sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
10    task::JoinHandle,
11};
12
13pub(crate) type PacketSender = UnboundedSender<NetworkPacket>;
14pub(crate) type PacketReceiver = UnboundedReceiver<NetworkPacket>;
15pub(crate) type SessionCollection = AHashMap<NetworkTuple, PacketSender>;
16
17mod error;
18mod packet;
19mod stream;
20
21pub use self::error::{IpStackError, Result};
22pub use self::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport};
23pub use self::stream::{TcpConfig, TcpOptions};
24pub use etherparse::IpNumber;
25
/// Time-to-live for IP packets built by the stack on Unix (the usual Unix default).
#[cfg(unix)]
const TTL: u8 = 64;

/// Time-to-live for IP packets built by the stack on Windows (the usual Windows default).
#[cfg(windows)]
const TTL: u8 = 128;

/// First two bytes (flags) of the 4-byte packet-information header prepended to
/// outgoing packets on Unix when `packet_information` is enabled.
#[cfg(unix)]
const TUN_FLAGS: [u8; 2] = [0x00, 0x00];

// Linux-style packet information: the last two header bytes carry the EtherType.
// NOTE(review): FreeBSD's tun(4) in TUNSIFHEAD mode prefixes a 4-byte address family
// rather than an EtherType — confirm which framing the device layer expects there.
/// EtherType of IPv6 (0x86DD).
#[cfg(any(target_os = "linux", target_os = "android", target_os = "freebsd", target_os = "espidf"))]
const TUN_PROTO_IP6: [u8; 2] = [0x86, 0xdd];
/// EtherType of IPv4 (0x0800).
#[cfg(any(target_os = "linux", target_os = "android", target_os = "freebsd", target_os = "espidf"))]
const TUN_PROTO_IP4: [u8; 2] = [0x08, 0x00];

// Darwin utun: the 4-byte header is the protocol family in network byte order.
/// Darwin's `AF_INET6` is 30 (0x1E); 10 (0x0A) is the Linux value and is not
/// recognised as IPv6 by utun.
#[cfg(any(target_os = "macos", target_os = "ios"))]
const TUN_PROTO_IP6: [u8; 2] = [0x00, 0x1E];
/// Darwin's `AF_INET` (2).
#[cfg(any(target_os = "macos", target_os = "ios"))]
const TUN_PROTO_IP4: [u8; 2] = [0x00, 0x02];

/// Minimum MTU required for IPv6 (per RFC 8200 §5: MTU ≥ 1280).
/// Also satisfies IPv4 minimum MTU (RFC 791 §3.1: 68 bytes).
const MIN_MTU: u16 = 1280;
48
49/// Configuration for the IP stack.
50///
51/// This structure holds configuration parameters that control the behavior of the IP stack,
52/// including network settings and protocol-specific timeouts.
53///
54/// # Examples
55///
56/// ```
57/// use ipstack::IpStackConfig;
58/// use std::time::Duration;
59///
60/// let mut config = IpStackConfig::default();
61/// config.mtu(1500).expect("Failed to set MTU")
62///       .udp_timeout(Duration::from_secs(60))
63///       .packet_information(false);
64/// ```
65#[non_exhaustive]
66pub struct IpStackConfig {
67    /// Maximum Transmission Unit (MTU) size in bytes.
68    /// Default is `MIN_MTU` (1280).
69    pub mtu: u16,
70    /// Whether to include packet information headers (Unix platforms only).
71    /// Default is `false`.
72    pub packet_information: bool,
73    /// TCP-specific configuration parameters.
74    pub tcp_config: Arc<TcpConfig>,
75    /// Timeout for UDP connections.
76    /// Default is 30 seconds.
77    pub udp_timeout: Duration,
78}
79
80impl Default for IpStackConfig {
81    fn default() -> Self {
82        IpStackConfig {
83            mtu: MIN_MTU,
84            packet_information: false,
85            tcp_config: Arc::new(TcpConfig::default()),
86            udp_timeout: Duration::from_secs(30),
87        }
88    }
89}
90
91impl IpStackConfig {
92    /// Set custom TCP configuration.
93    ///
94    /// # Arguments
95    ///
96    /// * `config` - The TCP configuration to use
97    ///
98    /// # Examples
99    ///
100    /// ```
101    /// use ipstack::{IpStackConfig, TcpConfig};
102    ///
103    /// let mut config = IpStackConfig::default();
104    /// config.with_tcp_config(TcpConfig::default());
105    /// ```
106    pub fn with_tcp_config(&mut self, config: TcpConfig) -> &mut Self {
107        self.tcp_config = Arc::new(config);
108        self
109    }
110
111    /// Set the UDP connection timeout.
112    ///
113    /// # Arguments
114    ///
115    /// * `timeout` - The timeout duration for UDP connections
116    ///
117    /// # Examples
118    ///
119    /// ```
120    /// use ipstack::IpStackConfig;
121    /// use std::time::Duration;
122    ///
123    /// let mut config = IpStackConfig::default();
124    /// config.udp_timeout(Duration::from_secs(60));
125    /// ```
126    pub fn udp_timeout(&mut self, timeout: Duration) -> &mut Self {
127        self.udp_timeout = timeout;
128        self
129    }
130
131    /// Set the Maximum Transmission Unit (MTU) size.
132    ///
133    /// # Arguments
134    ///
135    /// * `mtu` - The MTU size in bytes
136    ///
137    /// # Examples
138    ///
139    /// ```
140    /// use ipstack::IpStackConfig;
141    ///
142    /// let mut config = IpStackConfig::default();
143    /// config.mtu(1500).expect("Failed to set MTU");
144    /// ```
145    pub fn mtu(&mut self, mtu: u16) -> Result<&mut Self, IpStackError> {
146        if mtu < MIN_MTU {
147            return Err(IpStackError::InvalidMtuSize(mtu));
148        }
149        self.mtu = mtu;
150        Ok(self)
151    }
152
153    /// Set the Maximum Transmission Unit (MTU) size without validation.
154    pub fn mtu_unchecked(&mut self, mtu: u16) -> &mut Self {
155        self.mtu = mtu;
156        self
157    }
158
159    /// Enable or disable packet information headers (Unix platforms only).
160    ///
161    /// When enabled on Unix platforms, the TUN device will include 4-byte packet
162    /// information headers.
163    ///
164    /// # Arguments
165    ///
166    /// * `packet_information` - Whether to include packet information headers
167    ///
168    /// # Examples
169    ///
170    /// ```
171    /// use ipstack::IpStackConfig;
172    ///
173    /// let mut config = IpStackConfig::default();
174    /// config.packet_information(true);
175    /// ```
176    pub fn packet_information(&mut self, packet_information: bool) -> &mut Self {
177        self.packet_information = packet_information;
178        self
179    }
180}
181
182/// The main IP stack instance.
183///
184/// `IpStack` provides a userspace TCP/IP stack implementation for TUN devices.
185/// It processes network packets and creates stream abstractions for TCP, UDP, and
186/// unknown transport protocols.
187///
188/// # Examples
189///
190/// ```no_run
191/// use ipstack::{IpStack, IpStackConfig, IpStackStream};
192/// use std::net::Ipv4Addr;
193///
194/// #[tokio::main]
195/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
196///     // Configure TUN device
197///     let mut config = tun::Configuration::default();
198///     config
199///         .address(Ipv4Addr::new(10, 0, 0, 1))
200///         .netmask(Ipv4Addr::new(255, 255, 255, 0))
201///         .up();
202///
203///     // Create IP stack
204///     let ipstack_config = IpStackConfig::default();
205///     let mut ip_stack = IpStack::new(ipstack_config, tun::create_as_async(&config)?);
206///
207///     // Accept incoming streams
208///     while let Ok(stream) = ip_stack.accept().await {
209///         match stream {
210///             IpStackStream::Tcp(tcp) => {
211///                 // Handle TCP connection
212///             }
213///             IpStackStream::Udp(udp) => {
214///                 // Handle UDP connection
215///             }
216///             _ => {}
217///         }
218///     }
219///     Ok(())
220/// }
221/// ```
222pub struct IpStack {
223    accept_receiver: UnboundedReceiver<IpStackStream>,
224    handle: JoinHandle<Result<()>>,
225}
226
227impl IpStack {
228    /// Create a new IP stack instance.
229    ///
230    /// # Arguments
231    ///
232    /// * `config` - Configuration for the IP stack
233    /// * `device` - An async TUN device implementing `AsyncRead` + `AsyncWrite`
234    ///
235    /// # Examples
236    ///
237    /// ```no_run
238    /// use ipstack::{IpStack, IpStackConfig};
239    /// use std::net::Ipv4Addr;
240    ///
241    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
242    /// let mut tun_config = tun::Configuration::default();
243    /// tun_config.address(Ipv4Addr::new(10, 0, 0, 1))
244    ///           .netmask(Ipv4Addr::new(255, 255, 255, 0))
245    ///           .up();
246    ///
247    /// let ipstack_config = IpStackConfig::default();
248    /// let ip_stack = IpStack::new(ipstack_config, tun::create_as_async(&tun_config)?);
249    /// # Ok(())
250    /// # }
251    /// ```
252    pub fn new<Device>(config: IpStackConfig, device: Device) -> IpStack
253    where
254        Device: AsyncRead + AsyncWrite + Unpin + Send + 'static,
255    {
256        let (accept_sender, accept_receiver) = mpsc::unbounded_channel::<IpStackStream>();
257        IpStack {
258            accept_receiver,
259            handle: run(config, device, accept_sender),
260        }
261    }
262
263    /// Accept an incoming network stream.
264    ///
265    /// This method waits for and returns the next incoming network connection or packet.
266    /// The returned `IpStackStream` enum indicates the type of stream (TCP, UDP, or unknown).
267    ///
268    /// # Returns
269    ///
270    /// * `Ok(IpStackStream)` - The next incoming stream
271    /// * `Err(IpStackError::AcceptError)` - If the IP stack has been shut down
272    ///
273    /// # Examples
274    ///
275    /// ```no_run
276    /// use ipstack::{IpStack, IpStackConfig, IpStackStream};
277    ///
278    /// # async fn example(mut ip_stack: IpStack) -> Result<(), Box<dyn std::error::Error>> {
279    /// match ip_stack.accept().await? {
280    ///     IpStackStream::Tcp(tcp) => {
281    ///         println!("New TCP connection from {}", tcp.peer_addr());
282    ///     }
283    ///     IpStackStream::Udp(udp) => {
284    ///         println!("New UDP stream from {}", udp.peer_addr());
285    ///     }
286    ///     IpStackStream::UnknownTransport(unknown) => {
287    ///         println!("Unknown transport protocol: {:?}", unknown.ip_protocol());
288    ///     }
289    ///     IpStackStream::UnknownNetwork(data) => {
290    ///         println!("Unknown network packet: {} bytes", data.len());
291    ///     }
292    /// }
293    /// # Ok(())
294    /// # }
295    /// ```
296    pub async fn accept(&mut self) -> Result<IpStackStream, IpStackError> {
297        self.accept_receiver.recv().await.ok_or(IpStackError::AcceptError)
298    }
299}
300
301impl Drop for IpStack {
302    fn drop(&mut self) {
303        self.handle.abort();
304    }
305}
306
307fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
308    config: IpStackConfig,
309    mut device: Device,
310    accept_sender: UnboundedSender<IpStackStream>,
311) -> JoinHandle<Result<()>> {
312    let mut sessions: SessionCollection = AHashMap::new();
313    let (session_remove_tx, mut session_remove_rx) = mpsc::unbounded_channel::<NetworkTuple>();
314    let pi = config.packet_information;
315    let offset = if pi && cfg!(unix) { 4 } else { 0 };
316    let mut buffer = vec![0_u8; config.mtu as usize + offset];
317    let (up_pkt_sender, mut up_pkt_receiver) = mpsc::unbounded_channel::<NetworkPacket>();
318
319    tokio::spawn(async move {
320        loop {
321            select! {
322                Ok(n) = device.read(&mut buffer) => {
323                    if let Err(e) = process_device_read(&buffer[offset..n], &mut sessions, &session_remove_tx, &up_pkt_sender, &config, &accept_sender).await {
324                        let io_err: std::io::Error = e.into();
325                        if io_err.kind() == std::io::ErrorKind::ConnectionRefused {
326                            log::trace!("Received junk data: {io_err}");
327                        } else {
328                            log::warn!("process_device_read error: {io_err}");
329                        }
330                    }
331                }
332                Some(network_tuple) = session_remove_rx.recv() => {
333                    sessions.remove(&network_tuple);
334                    log::debug!("session destroyed: {network_tuple}");
335                }
336                Some(packet) = up_pkt_receiver.recv() => {
337                    process_upstream_recv(packet, &mut device, #[cfg(unix)]pi).await?;
338                }
339            }
340        }
341    })
342}
343
344async fn process_device_read(
345    data: &[u8],
346    sessions: &mut SessionCollection,
347    session_remove_tx: &UnboundedSender<NetworkTuple>,
348    up_pkt_sender: &PacketSender,
349    config: &IpStackConfig,
350    accept_sender: &UnboundedSender<IpStackStream>,
351) -> Result<()> {
352    let Ok(packet) = NetworkPacket::parse(data) else {
353        let stream = IpStackStream::UnknownNetwork(data.to_owned());
354        accept_sender.send(stream)?;
355        return Ok(());
356    };
357
358    if let TransportHeader::Unknown = packet.transport_header() {
359        let stream = IpStackStream::UnknownTransport(IpStackUnknownTransport::new(
360            packet.src_addr().ip(),
361            packet.dst_addr().ip(),
362            packet.payload.unwrap_or_default(),
363            &packet.ip,
364            config.mtu,
365            up_pkt_sender.clone(),
366        ));
367        accept_sender.send(stream)?;
368        return Ok(());
369    }
370
371    let network_tuple = packet.network_tuple();
372    match sessions.entry(network_tuple) {
373        std::collections::hash_map::Entry::Occupied(entry) => {
374            let len = packet.payload.as_ref().map(|p| p.len()).unwrap_or(0);
375            log::trace!("packet sent to stream: {network_tuple} len {len}");
376            entry.get().send(packet).map_err(std::io::Error::other)?;
377        }
378        std::collections::hash_map::Entry::Vacant(entry) => {
379            let (tx, rx) = tokio::sync::oneshot::channel::<()>();
380            let ip_stack_stream = create_stream(packet, config, up_pkt_sender.clone(), Some(tx))?;
381            let session_remove_tx = session_remove_tx.clone();
382            tokio::spawn(async move {
383                rx.await.ok();
384                if let Err(e) = session_remove_tx.send(network_tuple) {
385                    log::error!("Failed to send session removal for {network_tuple}: {e}");
386                }
387            });
388            let packet_sender = ip_stack_stream.stream_sender()?;
389            accept_sender.send(ip_stack_stream)?;
390            entry.insert(packet_sender);
391            log::debug!("session created: {network_tuple}");
392        }
393    }
394    Ok(())
395}
396
397fn create_stream(
398    packet: NetworkPacket,
399    cfg: &IpStackConfig,
400    up_pkt_sender: PacketSender,
401    msgr: Option<::tokio::sync::oneshot::Sender<()>>,
402) -> Result<IpStackStream> {
403    let src_addr = packet.src_addr();
404    let dst_addr = packet.dst_addr();
405    match packet.transport_header() {
406        TransportHeader::Tcp(h) => {
407            let stream = IpStackTcpStream::new(src_addr, dst_addr, h.clone(), up_pkt_sender, cfg.mtu, msgr, cfg.tcp_config.clone())?;
408            Ok(IpStackStream::Tcp(stream))
409        }
410        TransportHeader::Udp(_) => {
411            let payload = packet.payload.unwrap_or_default();
412            let stream = IpStackUdpStream::new(src_addr, dst_addr, payload, up_pkt_sender, cfg.mtu, cfg.udp_timeout, msgr);
413            Ok(IpStackStream::Udp(stream))
414        }
415        TransportHeader::Unknown => Err(IpStackError::UnsupportedTransportProtocol),
416    }
417}
418
419async fn process_upstream_recv<Device: AsyncWrite + Unpin + 'static>(
420    up_packet: NetworkPacket,
421    device: &mut Device,
422    #[cfg(unix)] packet_information: bool,
423) -> Result<()> {
424    #[allow(unused_mut)]
425    let Ok(mut packet_bytes) = up_packet.to_bytes() else {
426        log::warn!("to_bytes error");
427        return Ok(());
428    };
429    #[cfg(unix)]
430    if packet_information {
431        if up_packet.src_addr().is_ipv4() {
432            packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat());
433        } else {
434            packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat());
435        }
436    }
437    device.write_all(&packet_bytes).await?;
438    // device.flush().await?;
439
440    Ok(())
441}