Skip to main content

nlink/netlink/
netfilter.rs

1//! Netfilter implementation for `Connection<Netfilter>`.
2//!
3//! This module provides methods for querying and managing connection tracking
4//! entries via the NETLINK_NETFILTER protocol.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use nlink::netlink::{Connection, Netfilter};
10//!
11//! let conn = Connection::<Netfilter>::new()?;
12//!
13//! // List all connection tracking entries
14//! let entries = conn.get_conntrack().await?;
//! for entry in &entries {
//!     // The addresses are `Option<IpAddr>`, so they are printed with `{:?}`.
//!     println!("{:?} {:?}:{} -> {:?}:{}",
//!         entry.proto,
//!         entry.orig.src_ip,
//!         entry.orig.src_port.unwrap_or(0),
//!         entry.orig.dst_ip,
//!         entry.orig.dst_port.unwrap_or(0));
//! }
23//! ```
24
25use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
26
27use winnow::binary::be_u16;
28use winnow::prelude::*;
29
30use super::connection::Connection;
31use super::error::Result;
32use super::parse::PResult;
33use super::protocol::{Netfilter, ProtocolState};
34use super::socket::NetlinkSocket;
35
// Netlink message types and request flags used by the dump request/response
// loop in `get_conntrack_family`.
const NLMSG_DONE: u16 = 3; // ends the dump: the collected entries are returned
const NLMSG_ERROR: u16 = 2; // payload starts with an i32 error code
const NLM_F_REQUEST: u16 = 0x01;
const NLM_F_DUMP: u16 = 0x300;

// Netfilter subsystem IDs (carried in the high byte of `nlmsg_type`)
const NFNL_SUBSYS_CTNETLINK: u8 = 1;

// Conntrack message types (carried in the low byte of `nlmsg_type`)
const IPCTNL_MSG_CT_GET: u8 = 1;

// Top-level conntrack attributes, dispatched on in `parse_conntrack`
const CTA_TUPLE_ORIG: u16 = 1;
const CTA_TUPLE_REPLY: u16 = 2;
const CTA_STATUS: u16 = 3;
const CTA_PROTOINFO: u16 = 4;
const CTA_TIMEOUT: u16 = 7;
const CTA_MARK: u16 = 8;
const CTA_COUNTERS_ORIG: u16 = 9;
const CTA_COUNTERS_REPLY: u16 = 10;
const CTA_ID: u16 = 12;

// Attributes nested inside CTA_TUPLE_ORIG / CTA_TUPLE_REPLY (see `parse_tuple`)
const CTA_TUPLE_IP: u16 = 1;
const CTA_TUPLE_PROTO: u16 = 2;

// Attributes nested inside CTA_TUPLE_IP (see `parse_tuple_ip`)
const CTA_IP_V4_SRC: u16 = 1;
const CTA_IP_V4_DST: u16 = 2;
const CTA_IP_V6_SRC: u16 = 3;
const CTA_IP_V6_DST: u16 = 4;

// Attributes nested inside CTA_TUPLE_PROTO (see `parse_tuple_proto`)
const CTA_PROTO_NUM: u16 = 1;
const CTA_PROTO_SRC_PORT: u16 = 2;
const CTA_PROTO_DST_PORT: u16 = 3;
const CTA_PROTO_ICMP_ID: u16 = 4;
const CTA_PROTO_ICMP_TYPE: u16 = 5;
const CTA_PROTO_ICMP_CODE: u16 = 6;

// Attributes nested inside CTA_PROTOINFO (see `parse_protoinfo`).
// CTA_PROTOINFO_TCP_STATE lives one level deeper, inside CTA_PROTOINFO_TCP,
// which is why both can share the value 1.
const CTA_PROTOINFO_TCP: u16 = 1;
const CTA_PROTOINFO_TCP_STATE: u16 = 1;

// Attributes nested inside CTA_COUNTERS_ORIG / CTA_COUNTERS_REPLY
// (see `parse_counters`)
const CTA_COUNTERS_PACKETS: u16 = 1;
const CTA_COUNTERS_BYTES: u16 = 2;

// Size in bytes of the netlink message header that precedes the nfgenmsg
// header in every message handled by this module.
const NLMSG_HDRLEN: usize = 16;
87
/// IP protocol carried by a conntrack entry.
///
/// The protocols this module knows by name get a dedicated variant; every
/// other protocol number is preserved verbatim in [`IpProtocol::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IpProtocol {
    /// TCP (6)
    Tcp,
    /// UDP (17)
    Udp,
    /// ICMP (1)
    Icmp,
    /// ICMPv6 (58)
    Icmpv6,
    /// Any other protocol, identified by its raw number.
    Other(u8),
}

impl IpProtocol {
    /// Map a raw protocol number onto its variant.
    ///
    /// Numbers without a dedicated variant come back as `Other(val)`, so
    /// `from_u8(n).number() == n` holds for every `n`.
    fn from_u8(val: u8) -> Self {
        // The named variants; anything not listed here falls through to `Other`.
        const NAMED: [IpProtocol; 4] = [
            IpProtocol::Icmp,
            IpProtocol::Tcp,
            IpProtocol::Udp,
            IpProtocol::Icmpv6,
        ];

        NAMED
            .iter()
            .copied()
            .find(|known| known.number() == val)
            .unwrap_or(Self::Other(val))
    }

    /// Get the protocol number.
    pub fn number(&self) -> u8 {
        match *self {
            Self::Other(raw) => raw,
            Self::Icmpv6 => 58,
            Self::Udp => 17,
            Self::Tcp => 6,
            Self::Icmp => 1,
        }
    }
}
125
/// TCP connection tracking state.
///
/// The numbering follows the kernel's `enum tcp_conntrack`. In that enum
/// `TCP_CONNTRACK_SYN_SENT2` is an alias of the obsolete
/// `TCP_CONNTRACK_LISTEN` (both are 9), so `MAX` is 10, `IGNORE` 11,
/// `RETRANS` 12 and `UNACK` 13.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TcpConntrackState {
    /// No state recorded (0).
    None,
    /// SYN sent (1).
    SynSent,
    /// SYN received (2).
    SynRecv,
    /// Connection established (3).
    Established,
    /// FIN wait (4).
    FinWait,
    /// Close wait (5).
    CloseWait,
    /// Last ACK (6).
    LastAck,
    /// Time wait (7).
    TimeWait,
    /// Closed (8).
    Close,
    /// Obsolete kernel state that shares the value 9 with `SynSent2`.
    ///
    /// Kept so existing `match` arms keep compiling; the decoder reports
    /// value 9 as [`TcpConntrackState::SynSent2`] and never produces this
    /// variant.
    Listen,
    /// Second SYN sent during a simultaneous open (9).
    SynSent2,
    /// Number of real states; used by the kernel as a sentinel (10).
    Max,
    /// Kernel pseudo-state (11).
    Ignore,
    /// Kernel pseudo-state (12).
    Retrans,
    /// Kernel pseudo-state (13).
    Unack,
    /// Any value this module has no name for.
    Unknown(u8),
}

impl TcpConntrackState {
    /// Decode the raw state byte of a `CTA_PROTOINFO_TCP_STATE` attribute.
    ///
    /// Values without a named variant are returned as `Unknown(val)`.
    fn from_u8(val: u8) -> Self {
        match val {
            0 => Self::None,
            1 => Self::SynSent,
            2 => Self::SynRecv,
            3 => Self::Established,
            4 => Self::FinWait,
            5 => Self::CloseWait,
            6 => Self::LastAck,
            7 => Self::TimeWait,
            8 => Self::Close,
            // LISTEN is obsolete; current kernels use this slot for SYN_SENT2.
            9 => Self::SynSent2,
            10 => Self::Max,
            11 => Self::Ignore,
            12 => Self::Retrans,
            13 => Self::Unack,
            other => Self::Unknown(other),
        }
    }
}
169
/// One direction of a tracked connection: its addresses plus the
/// protocol-specific identifiers (ports for TCP/UDP, id/type/code for ICMP).
///
/// Every field is optional; a field is `None` when the corresponding
/// attribute was absent from the message or too short to decode.
/// `Default` yields a tuple with every field unset.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConntrackTuple {
    /// Source IP address.
    pub src_ip: Option<IpAddr>,
    /// Destination IP address.
    pub dst_ip: Option<IpAddr>,
    /// Source port (TCP/UDP).
    pub src_port: Option<u16>,
    /// Destination port (TCP/UDP).
    pub dst_port: Option<u16>,
    /// ICMP ID.
    pub icmp_id: Option<u16>,
    /// ICMP type.
    pub icmp_type: Option<u8>,
    /// ICMP code.
    pub icmp_code: Option<u8>,
}
188
/// Packet/byte counters for one direction of a connection.
///
/// `Default` yields zeroed counters.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConntrackCounters {
    /// Number of packets.
    pub packets: u64,
    /// Number of bytes.
    pub bytes: u64,
}
197
198/// A connection tracking entry.
199#[derive(Debug, Clone)]
200pub struct ConntrackEntry {
201    /// IP protocol (TCP, UDP, ICMP, etc.).
202    pub proto: IpProtocol,
203    /// Original direction tuple.
204    pub orig: ConntrackTuple,
205    /// Reply direction tuple.
206    pub reply: ConntrackTuple,
207    /// TCP connection state (if TCP).
208    pub tcp_state: Option<TcpConntrackState>,
209    /// Timeout in seconds.
210    pub timeout: Option<u32>,
211    /// Connection mark.
212    pub mark: Option<u32>,
213    /// Connection status flags.
214    pub status: Option<u32>,
215    /// Connection ID.
216    pub id: Option<u32>,
217    /// Original direction counters.
218    pub counters_orig: Option<ConntrackCounters>,
219    /// Reply direction counters.
220    pub counters_reply: Option<ConntrackCounters>,
221}
222
223impl Default for ConntrackEntry {
224    fn default() -> Self {
225        Self {
226            proto: IpProtocol::Other(0),
227            orig: ConntrackTuple::default(),
228            reply: ConntrackTuple::default(),
229            tcp_state: None,
230            timeout: None,
231            mark: None,
232            status: None,
233            id: None,
234            counters_orig: None,
235            counters_reply: None,
236        }
237    }
238}
239
240/// nfgenmsg header (4 bytes).
241#[repr(C)]
242#[derive(Debug, Clone, Copy, Default)]
243struct NfGenMsg {
244    family: u8,
245    version: u8,
246    res_id: u16,
247}
248
249impl NfGenMsg {
250    fn parse(input: &mut &[u8]) -> PResult<Self> {
251        let family = winnow::binary::le_u8.parse_next(input)?;
252        let version = winnow::binary::le_u8.parse_next(input)?;
253        let res_id = be_u16.parse_next(input)?;
254        Ok(Self {
255            family,
256            version,
257            res_id,
258        })
259    }
260}
261
262impl Connection<Netfilter> {
263    /// Create a new netfilter connection.
264    ///
265    /// # Example
266    ///
267    /// ```ignore
268    /// use nlink::netlink::{Connection, Netfilter};
269    ///
270    /// let conn = Connection::<Netfilter>::new()?;
271    /// ```
272    pub fn new() -> Result<Self> {
273        let socket = NetlinkSocket::new(Netfilter::PROTOCOL)?;
274        Ok(Self::from_parts(socket, Netfilter))
275    }
276
277    /// Get all connection tracking entries.
278    ///
279    /// # Example
280    ///
281    /// ```ignore
282    /// use nlink::netlink::{Connection, Netfilter};
283    ///
284    /// let conn = Connection::<Netfilter>::new()?;
285    /// let entries = conn.get_conntrack().await?;
286    ///
287    /// for entry in &entries {
288    ///     println!("{:?}: {:?} -> {:?}",
289    ///         entry.proto,
290    ///         entry.orig.src_ip,
291    ///         entry.orig.dst_ip);
292    /// }
293    /// ```
294    pub async fn get_conntrack(&self) -> Result<Vec<ConntrackEntry>> {
295        self.get_conntrack_family(libc::AF_INET as u8).await
296    }
297
298    /// Get connection tracking entries for IPv6.
299    pub async fn get_conntrack_v6(&self) -> Result<Vec<ConntrackEntry>> {
300        self.get_conntrack_family(libc::AF_INET6 as u8).await
301    }
302
303    /// Get connection tracking entries for a specific address family.
304    async fn get_conntrack_family(&self, family: u8) -> Result<Vec<ConntrackEntry>> {
305        let seq = self.socket().next_seq();
306        let pid = self.socket().pid();
307
308        // Build request
309        let mut buf = Vec::with_capacity(64);
310
311        // Netlink header (16 bytes)
312        // Message type: (NFNL_SUBSYS_CTNETLINK << 8) | IPCTNL_MSG_CT_GET
313        let msg_type = ((NFNL_SUBSYS_CTNETLINK as u16) << 8) | (IPCTNL_MSG_CT_GET as u16);
314
315        buf.extend_from_slice(&0u32.to_ne_bytes()); // nlmsg_len (fill later)
316        buf.extend_from_slice(&msg_type.to_ne_bytes()); // nlmsg_type
317        buf.extend_from_slice(&(NLM_F_REQUEST | NLM_F_DUMP).to_ne_bytes()); // nlmsg_flags
318        buf.extend_from_slice(&seq.to_ne_bytes()); // nlmsg_seq
319        buf.extend_from_slice(&pid.to_ne_bytes()); // nlmsg_pid
320
321        // nfgenmsg (4 bytes)
322        buf.push(family); // nfgen_family
323        buf.push(0); // version (NFNETLINK_V0)
324        buf.extend_from_slice(&0u16.to_be_bytes()); // res_id
325
326        // Update length
327        let len = buf.len() as u32;
328        buf[0..4].copy_from_slice(&len.to_ne_bytes());
329
330        // Send request
331        self.socket().send(&buf).await?;
332
333        // Receive responses
334        let mut entries = Vec::new();
335
336        loop {
337            let data = self.socket().recv_msg().await?;
338
339            let mut offset = 0;
340            while offset + 16 <= data.len() {
341                let nlmsg_len = u32::from_ne_bytes([
342                    data[offset],
343                    data[offset + 1],
344                    data[offset + 2],
345                    data[offset + 3],
346                ]) as usize;
347
348                let nlmsg_type = u16::from_ne_bytes([data[offset + 4], data[offset + 5]]);
349
350                if nlmsg_len < 16 || offset + nlmsg_len > data.len() {
351                    break;
352                }
353
354                match nlmsg_type {
355                    NLMSG_DONE => return Ok(entries),
356                    NLMSG_ERROR => {
357                        if nlmsg_len >= 20 {
358                            let errno = i32::from_ne_bytes([
359                                data[offset + 16],
360                                data[offset + 17],
361                                data[offset + 18],
362                                data[offset + 19],
363                            ]);
364                            if errno != 0 {
365                                return Err(super::error::Error::from_errno(-errno));
366                            }
367                        }
368                    }
369                    _ => {
370                        // Check if it's a conntrack message
371                        let subsys = (nlmsg_type >> 8) as u8;
372                        if subsys == NFNL_SUBSYS_CTNETLINK
373                            && let Some(entry) =
374                                self.parse_conntrack(&data[offset..offset + nlmsg_len])
375                        {
376                            entries.push(entry);
377                        }
378                    }
379                }
380
381                // Align to 4 bytes
382                offset += (nlmsg_len + 3) & !3;
383            }
384        }
385    }
386
387    /// Parse a conntrack message using winnow.
388    fn parse_conntrack(&self, data: &[u8]) -> Option<ConntrackEntry> {
389        // Skip netlink header (16 bytes)
390        if data.len() < NLMSG_HDRLEN + 4 {
391            return None;
392        }
393
394        let mut input = &data[NLMSG_HDRLEN..];
395
396        // Parse nfgenmsg header
397        let _nfmsg = NfGenMsg::parse(&mut input).ok()?;
398
399        // Parse attributes
400        let mut entry = ConntrackEntry::default();
401
402        while input.len() >= 4 {
403            let (attr_type, attr_data) = parse_nla(&mut input)?;
404
405            match attr_type & 0x7FFF {
406                // Remove NLA_F_NESTED flag
407                CTA_TUPLE_ORIG => {
408                    if let Some((tuple, proto)) = parse_tuple(attr_data) {
409                        entry.orig = tuple;
410                        entry.proto = proto;
411                    }
412                }
413                CTA_TUPLE_REPLY => {
414                    if let Some((tuple, _)) = parse_tuple(attr_data) {
415                        entry.reply = tuple;
416                    }
417                }
418                CTA_STATUS => {
419                    if attr_data.len() >= 4 {
420                        entry.status = Some(u32::from_be_bytes([
421                            attr_data[0],
422                            attr_data[1],
423                            attr_data[2],
424                            attr_data[3],
425                        ]));
426                    }
427                }
428                CTA_TIMEOUT => {
429                    if attr_data.len() >= 4 {
430                        entry.timeout = Some(u32::from_be_bytes([
431                            attr_data[0],
432                            attr_data[1],
433                            attr_data[2],
434                            attr_data[3],
435                        ]));
436                    }
437                }
438                CTA_MARK => {
439                    if attr_data.len() >= 4 {
440                        entry.mark = Some(u32::from_be_bytes([
441                            attr_data[0],
442                            attr_data[1],
443                            attr_data[2],
444                            attr_data[3],
445                        ]));
446                    }
447                }
448                CTA_ID => {
449                    if attr_data.len() >= 4 {
450                        entry.id = Some(u32::from_be_bytes([
451                            attr_data[0],
452                            attr_data[1],
453                            attr_data[2],
454                            attr_data[3],
455                        ]));
456                    }
457                }
458                CTA_PROTOINFO => {
459                    entry.tcp_state = parse_protoinfo(attr_data);
460                }
461                CTA_COUNTERS_ORIG => {
462                    entry.counters_orig = parse_counters(attr_data);
463                }
464                CTA_COUNTERS_REPLY => {
465                    entry.counters_reply = parse_counters(attr_data);
466                }
467                _ => {}
468            }
469        }
470
471        Some(entry)
472    }
473}
474
/// Split one netlink attribute off the front of `input`.
///
/// An attribute is a 4-byte header (`nla_len`, `nla_type`, both `u16` in
/// host byte order) followed by `nla_len - 4` payload bytes and padding up
/// to the next 4-byte boundary.
///
/// On success returns the raw `nla_type` (flag bits included) together
/// with the payload, and advances `input` past the attribute and its
/// padding. The padding is only skipped when it is fully present, so a
/// final attribute without trailing padding is still accepted.
///
/// Returns `None`, leaving `input` untouched, when fewer than four bytes
/// remain, when `nla_len` is smaller than the header, or when the payload
/// would run past the end of `input`.
fn parse_nla<'a>(input: &mut &'a [u8]) -> Option<(u16, &'a [u8])> {
    const NLA_HDRLEN: usize = 4;

    let buf: &'a [u8] = *input;
    if buf.len() < NLA_HDRLEN {
        return None;
    }

    // Netlink headers are exchanged in the host's byte order, not
    // little-endian, so decode them with the native-endian readers.
    let len = u16::from_ne_bytes([buf[0], buf[1]]) as usize;
    let attr_type = u16::from_ne_bytes([buf[2], buf[3]]);

    if len < NLA_HDRLEN || len > buf.len() {
        return None;
    }
    let payload = &buf[NLA_HDRLEN..len];

    // Skip the alignment padding if the buffer actually contains it.
    let aligned = (len + 3) & !3;
    let consumed = if aligned <= buf.len() { aligned } else { len };
    *input = &buf[consumed..];

    Some((attr_type, payload))
}
507
508/// Parse a conntrack tuple.
509fn parse_tuple(data: &[u8]) -> Option<(ConntrackTuple, IpProtocol)> {
510    let mut input = data;
511    let mut tuple = ConntrackTuple::default();
512    let mut proto = IpProtocol::Other(0);
513
514    while input.len() >= 4 {
515        let (attr_type, attr_data) = parse_nla(&mut input)?;
516
517        match attr_type & 0x7FFF {
518            CTA_TUPLE_IP => {
519                parse_tuple_ip(attr_data, &mut tuple);
520            }
521            CTA_TUPLE_PROTO => {
522                proto = parse_tuple_proto(attr_data, &mut tuple);
523            }
524            _ => {}
525        }
526    }
527
528    Some((tuple, proto))
529}
530
531/// Parse IP addresses from tuple.
532fn parse_tuple_ip(data: &[u8], tuple: &mut ConntrackTuple) {
533    let mut input = data;
534
535    while input.len() >= 4 {
536        if let Some((attr_type, attr_data)) = parse_nla(&mut input) {
537            match attr_type {
538                CTA_IP_V4_SRC => {
539                    if attr_data.len() >= 4 {
540                        tuple.src_ip = Some(IpAddr::V4(Ipv4Addr::new(
541                            attr_data[0],
542                            attr_data[1],
543                            attr_data[2],
544                            attr_data[3],
545                        )));
546                    }
547                }
548                CTA_IP_V4_DST => {
549                    if attr_data.len() >= 4 {
550                        tuple.dst_ip = Some(IpAddr::V4(Ipv4Addr::new(
551                            attr_data[0],
552                            attr_data[1],
553                            attr_data[2],
554                            attr_data[3],
555                        )));
556                    }
557                }
558                CTA_IP_V6_SRC => {
559                    if attr_data.len() >= 16 {
560                        let mut octets = [0u8; 16];
561                        octets.copy_from_slice(&attr_data[..16]);
562                        tuple.src_ip = Some(IpAddr::V6(Ipv6Addr::from(octets)));
563                    }
564                }
565                CTA_IP_V6_DST => {
566                    if attr_data.len() >= 16 {
567                        let mut octets = [0u8; 16];
568                        octets.copy_from_slice(&attr_data[..16]);
569                        tuple.dst_ip = Some(IpAddr::V6(Ipv6Addr::from(octets)));
570                    }
571                }
572                _ => {}
573            }
574        } else {
575            break;
576        }
577    }
578}
579
580/// Parse protocol info from tuple.
581fn parse_tuple_proto(data: &[u8], tuple: &mut ConntrackTuple) -> IpProtocol {
582    let mut input = data;
583    let mut proto = IpProtocol::Other(0);
584
585    while input.len() >= 4 {
586        if let Some((attr_type, attr_data)) = parse_nla(&mut input) {
587            match attr_type {
588                CTA_PROTO_NUM => {
589                    if !attr_data.is_empty() {
590                        proto = IpProtocol::from_u8(attr_data[0]);
591                    }
592                }
593                CTA_PROTO_SRC_PORT => {
594                    if attr_data.len() >= 2 {
595                        tuple.src_port = Some(u16::from_be_bytes([attr_data[0], attr_data[1]]));
596                    }
597                }
598                CTA_PROTO_DST_PORT => {
599                    if attr_data.len() >= 2 {
600                        tuple.dst_port = Some(u16::from_be_bytes([attr_data[0], attr_data[1]]));
601                    }
602                }
603                CTA_PROTO_ICMP_ID => {
604                    if attr_data.len() >= 2 {
605                        tuple.icmp_id = Some(u16::from_be_bytes([attr_data[0], attr_data[1]]));
606                    }
607                }
608                CTA_PROTO_ICMP_TYPE => {
609                    if !attr_data.is_empty() {
610                        tuple.icmp_type = Some(attr_data[0]);
611                    }
612                }
613                CTA_PROTO_ICMP_CODE => {
614                    if !attr_data.is_empty() {
615                        tuple.icmp_code = Some(attr_data[0]);
616                    }
617                }
618                _ => {}
619            }
620        } else {
621            break;
622        }
623    }
624
625    proto
626}
627
628/// Parse protoinfo for TCP state.
629fn parse_protoinfo(data: &[u8]) -> Option<TcpConntrackState> {
630    let mut input = data;
631
632    while input.len() >= 4 {
633        let (attr_type, attr_data) = parse_nla(&mut input)?;
634
635        if (attr_type & 0x7FFF) == CTA_PROTOINFO_TCP {
636            // Parse TCP protoinfo
637            let mut tcp_input = attr_data;
638            while tcp_input.len() >= 4 {
639                if let Some((tcp_attr, tcp_data)) = parse_nla(&mut tcp_input) {
640                    if tcp_attr == CTA_PROTOINFO_TCP_STATE && !tcp_data.is_empty() {
641                        return Some(TcpConntrackState::from_u8(tcp_data[0]));
642                    }
643                } else {
644                    break;
645                }
646            }
647        }
648    }
649
650    None
651}
652
653/// Parse counters.
654fn parse_counters(data: &[u8]) -> Option<ConntrackCounters> {
655    let mut input = data;
656    let mut counters = ConntrackCounters::default();
657
658    while input.len() >= 4 {
659        if let Some((attr_type, attr_data)) = parse_nla(&mut input) {
660            match attr_type {
661                CTA_COUNTERS_PACKETS => {
662                    if attr_data.len() >= 8 {
663                        counters.packets = u64::from_be_bytes([
664                            attr_data[0],
665                            attr_data[1],
666                            attr_data[2],
667                            attr_data[3],
668                            attr_data[4],
669                            attr_data[5],
670                            attr_data[6],
671                            attr_data[7],
672                        ]);
673                    }
674                }
675                CTA_COUNTERS_BYTES => {
676                    if attr_data.len() >= 8 {
677                        counters.bytes = u64::from_be_bytes([
678                            attr_data[0],
679                            attr_data[1],
680                            attr_data[2],
681                            attr_data[3],
682                            attr_data[4],
683                            attr_data[5],
684                            attr_data[6],
685                            attr_data[7],
686                        ]);
687                    }
688                }
689                _ => {}
690            }
691        } else {
692            break;
693        }
694    }
695
696    Some(counters)
697}
698
#[cfg(test)]
mod tests {
    use super::*;

    /// `number()` and `from_u8()` agree with each other for TCP and UDP.
    #[test]
    fn ip_protocol_roundtrip() {
        for (proto, number) in [(IpProtocol::Tcp, 6u8), (IpProtocol::Udp, 17u8)] {
            assert_eq!(proto.number(), number);
            assert_eq!(IpProtocol::from_u8(number), proto);
        }
    }

    /// Raw state bytes decode to the expected TCP conntrack states.
    #[test]
    fn tcp_state_from_u8() {
        let expected = [
            (3u8, TcpConntrackState::Established),
            (7u8, TcpConntrackState::TimeWait),
        ];
        for (raw, state) in expected {
            assert_eq!(TcpConntrackState::from_u8(raw), state);
        }
    }
}