pcapsql_core/protocol/
netlink.rs

1//! Netlink protocol parser.
2//!
3//! Parses Linux Netlink messages from LINKTYPE_NETLINK (253) captures.
4//! This is the base parser that handles the common netlink header and
5//! sets hints for family-specific child parsers (rtnetlink, nfnetlink, etc.).
6
7use smallvec::SmallVec;
8
9use super::{FieldValue, ParseContext, ParseResult, Protocol};
10use crate::schema::{DataKind, FieldDescriptor};
11
12// Re-export constants from netlink-packet-core for use elsewhere
13pub use netlink_packet_core::constants::{NLM_F_ACK, NLM_F_ECHO, NLM_F_MULTIPART, NLM_F_REQUEST};
14
15// Message type constants from netlink-packet-core
16pub use netlink_packet_core::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN};
17
/// Size in bytes of the fixed netlink message header (`nlmsghdr`):
/// length (u32) + type (u16) + flags (u16) + sequence (u32) + pid (u32).
pub const NETLINK_HEADER_LEN: usize = 4 + 2 + 2 + 4 + 4;

/// PCAP link-layer type value that identifies a netlink capture.
pub const LINKTYPE_NETLINK: u16 = 253;
23
/// Netlink protocol family numbers.
///
/// Each constant corresponds to the `NETLINK_*` value of the same name in
/// `<linux/netlink.h>`. Family 17 has no constant here.
pub mod family {
    /// `NETLINK_ROUTE`
    pub const ROUTE: u8 = 0;
    /// `NETLINK_UNUSED`
    pub const UNUSED: u8 = 1;
    /// `NETLINK_USERSOCK`
    pub const USERSOCK: u8 = 2;
    /// `NETLINK_FIREWALL`
    pub const FIREWALL: u8 = 3;
    /// `NETLINK_SOCK_DIAG`
    pub const SOCK_DIAG: u8 = 4;
    /// `NETLINK_NFLOG`
    pub const NFLOG: u8 = 5;
    /// `NETLINK_XFRM`
    pub const XFRM: u8 = 6;
    /// `NETLINK_SELINUX`
    pub const SELINUX: u8 = 7;
    /// `NETLINK_ISCSI`
    pub const ISCSI: u8 = 8;
    /// `NETLINK_AUDIT`
    pub const AUDIT: u8 = 9;
    /// `NETLINK_FIB_LOOKUP`
    pub const FIB_LOOKUP: u8 = 10;
    /// `NETLINK_CONNECTOR`
    pub const CONNECTOR: u8 = 11;
    /// `NETLINK_NETFILTER`
    pub const NETFILTER: u8 = 12;
    /// `NETLINK_IP6_FW`
    pub const IP6_FW: u8 = 13;
    /// `NETLINK_DNRTMSG`
    pub const DNRTMSG: u8 = 14;
    /// `NETLINK_KOBJECT_UEVENT`
    pub const KOBJECT_UEVENT: u8 = 15;
    /// `NETLINK_GENERIC`
    pub const GENERIC: u8 = 16;
    /// `NETLINK_SCSITRANSPORT`
    pub const SCSITRANSPORT: u8 = 18;
    /// `NETLINK_ECRYPTFS`
    pub const ECRYPTFS: u8 = 19;
    /// `NETLINK_RDMA`
    pub const RDMA: u8 = 20;
    /// `NETLINK_CRYPTO`
    pub const CRYPTO: u8 = 21;
}
48
49/// Get the name of a netlink family.
50fn family_name(family: u8) -> &'static str {
51    match family {
52        family::ROUTE => "ROUTE",
53        family::USERSOCK => "USERSOCK",
54        family::FIREWALL => "FIREWALL",
55        family::SOCK_DIAG => "SOCK_DIAG",
56        family::NFLOG => "NFLOG",
57        family::XFRM => "XFRM",
58        family::SELINUX => "SELINUX",
59        family::ISCSI => "ISCSI",
60        family::AUDIT => "AUDIT",
61        family::FIB_LOOKUP => "FIB_LOOKUP",
62        family::CONNECTOR => "CONNECTOR",
63        family::NETFILTER => "NETFILTER",
64        family::IP6_FW => "IP6_FW",
65        family::DNRTMSG => "DNRTMSG",
66        family::KOBJECT_UEVENT => "KOBJECT_UEVENT",
67        family::GENERIC => "GENERIC",
68        family::SCSITRANSPORT => "SCSITRANSPORT",
69        family::ECRYPTFS => "ECRYPTFS",
70        family::RDMA => "RDMA",
71        family::CRYPTO => "CRYPTO",
72        _ => "UNKNOWN",
73    }
74}
75
/// Get the symbolic name of a netlink message type.
///
/// Types below `NLMSG_MIN_TYPE` (0x10) are reserved by the kernel for
/// control messages. The four defined ones (`NLMSG_NOOP`, `NLMSG_ERROR`,
/// `NLMSG_DONE` and `NLMSG_OVERRUN`) are named explicitly. The remaining
/// reserved values are reported as `"RESERVED"`. Everything from 0x10 upward
/// is defined by the specific netlink family and is reported as
/// `"PROTOCOL_SPECIFIC"`.
fn msg_type_name(msg_type: u16) -> &'static str {
    // Mirrors NLMSG_MIN_TYPE from <linux/netlink.h>: "< 0x10: reserved control messages".
    const NLMSG_MIN_TYPE: u16 = 0x10;

    match msg_type {
        NLMSG_NOOP => "NLMSG_NOOP",
        NLMSG_ERROR => "NLMSG_ERROR",
        NLMSG_DONE => "NLMSG_DONE",
        NLMSG_OVERRUN => "NLMSG_OVERRUN",
        // Reserved control range without a defined meaning; not family-specific.
        t if t < NLMSG_MIN_TYPE => "RESERVED",
        _ => "PROTOCOL_SPECIFIC",
    }
}
86
/// Stateless parser for the base Linux netlink message header.
///
/// Decodes the fixed 16-byte header and hands the netlink family and
/// message type to family-specific child parsers as hints.
#[derive(Clone, Copy, Debug)]
pub struct NetlinkProtocol;
93
94impl Protocol for NetlinkProtocol {
95    fn name(&self) -> &'static str {
96        "netlink"
97    }
98
99    fn display_name(&self) -> &'static str {
100        "Netlink"
101    }
102
103    fn can_parse(&self, context: &ParseContext) -> Option<u32> {
104        // Parse at root level with LINKTYPE_NETLINK
105        if context.is_root() && context.link_type == LINKTYPE_NETLINK {
106            return Some(100);
107        }
108
109        // Also parse when parent is linux_sll with ARPHRD_NETLINK
110        if context.parent_protocol == Some("linux_sll") && context.hint("is_netlink") == Some(1) {
111            return Some(100);
112        }
113
114        None
115    }
116
117    fn parse<'a>(&self, data: &'a [u8], context: &ParseContext) -> ParseResult<'a> {
118        // Netlink header is 16 bytes minimum
119        if data.len() < NETLINK_HEADER_LEN {
120            return ParseResult::error(
121                format!("Netlink message too short: {} bytes", data.len()),
122                data,
123            );
124        }
125
126        let mut fields = SmallVec::new();
127
128        // Parse netlink header (little-endian!)
129        let msg_len = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
130        let msg_type = u16::from_le_bytes([data[4], data[5]]);
131        let msg_flags = u16::from_le_bytes([data[6], data[7]]);
132        let msg_seq = u32::from_le_bytes([data[8], data[9], data[10], data[11]]);
133        let msg_pid = u32::from_le_bytes([data[12], data[13], data[14], data[15]]);
134
135        fields.push(("msg_len", FieldValue::UInt32(msg_len)));
136        fields.push(("msg_type", FieldValue::UInt16(msg_type)));
137        fields.push(("msg_flags", FieldValue::UInt16(msg_flags)));
138        fields.push(("msg_seq", FieldValue::UInt32(msg_seq)));
139        fields.push(("msg_pid", FieldValue::UInt32(msg_pid)));
140
141        // Message type name
142        fields.push(("msg_type_name", FieldValue::Str(msg_type_name(msg_type))));
143
144        // Flag extraction using crate constants
145        let is_request = (msg_flags & NLM_F_REQUEST) != 0;
146        let is_multipart = (msg_flags & NLM_F_MULTIPART) != 0;
147        let is_ack = (msg_flags & NLM_F_ACK) != 0;
148        let is_echo = (msg_flags & NLM_F_ECHO) != 0;
149
150        fields.push(("is_request", FieldValue::Bool(is_request)));
151        fields.push(("is_multipart", FieldValue::Bool(is_multipart)));
152        fields.push(("is_ack", FieldValue::Bool(is_ack)));
153        fields.push(("is_echo", FieldValue::Bool(is_echo)));
154
155        // Get the netlink family from context hint (set by PCAP reader or cooked header)
156        // Default to ROUTE if not specified
157        let nl_family = context
158            .hint("netlink_family")
159            .map(|f| f as u8)
160            .unwrap_or(family::ROUTE);
161
162        fields.push(("family", FieldValue::UInt8(nl_family)));
163        fields.push(("family_name", FieldValue::Str(family_name(nl_family))));
164
165        // Calculate remaining payload
166        let payload_start = NETLINK_HEADER_LEN;
167        let payload_len = (msg_len as usize).saturating_sub(NETLINK_HEADER_LEN);
168        let remaining = if payload_start + payload_len <= data.len() {
169            &data[payload_start..payload_start + payload_len]
170        } else {
171            &data[payload_start..]
172        };
173
174        // Set hints for child protocols
175        let mut child_hints = SmallVec::new();
176        child_hints.push(("netlink_family", nl_family as u64));
177        child_hints.push(("netlink_msg_type", msg_type as u64));
178
179        ParseResult::success(fields, remaining, child_hints)
180    }
181
182    fn schema_fields(&self) -> Vec<FieldDescriptor> {
183        vec![
184            FieldDescriptor::new("netlink.msg_len", DataKind::UInt32).set_nullable(true),
185            FieldDescriptor::new("netlink.msg_type", DataKind::UInt16).set_nullable(true),
186            FieldDescriptor::new("netlink.msg_flags", DataKind::UInt16).set_nullable(true),
187            FieldDescriptor::new("netlink.msg_seq", DataKind::UInt32).set_nullable(true),
188            FieldDescriptor::new("netlink.msg_pid", DataKind::UInt32).set_nullable(true),
189            FieldDescriptor::new("netlink.msg_type_name", DataKind::String).set_nullable(true),
190            FieldDescriptor::new("netlink.is_request", DataKind::Bool).set_nullable(true),
191            FieldDescriptor::new("netlink.is_multipart", DataKind::Bool).set_nullable(true),
192            FieldDescriptor::new("netlink.is_ack", DataKind::Bool).set_nullable(true),
193            FieldDescriptor::new("netlink.is_echo", DataKind::Bool).set_nullable(true),
194            FieldDescriptor::new("netlink.family", DataKind::UInt8).set_nullable(true),
195            FieldDescriptor::new("netlink.family_name", DataKind::String).set_nullable(true),
196        ]
197    }
198
199    fn child_protocols(&self) -> &[&'static str] {
200        &["rtnetlink"]
201    }
202
203    fn dependencies(&self) -> &'static [&'static str] {
204        &[] // Root protocol - no dependencies
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 16-byte netlink header with every field encoded little-endian,
    /// matching the byte order `NetlinkProtocol::parse` decodes.
    fn create_netlink_header(
        msg_len: u32,
        msg_type: u16,
        flags: u16,
        seq: u32,
        pid: u32,
    ) -> Vec<u8> {
        let mut header = Vec::with_capacity(NETLINK_HEADER_LEN);
        header.extend_from_slice(&msg_len.to_le_bytes());
        header.extend_from_slice(&msg_type.to_le_bytes());
        header.extend_from_slice(&flags.to_le_bytes());
        header.extend_from_slice(&seq.to_le_bytes());
        header.extend_from_slice(&pid.to_le_bytes());
        header
    }

    // ==========================================================================
    // can_parse
    // ==========================================================================
    #[test]
    fn test_can_parse_netlink_at_root() {
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        assert_eq!(parser.can_parse(&ctx), Some(100));
    }

    #[test]
    fn test_cannot_parse_ethernet() {
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(1); // Ethernet LINKTYPE

        assert!(parser.can_parse(&ctx).is_none());
    }

    #[test]
    fn test_cannot_parse_when_not_root() {
        let parser = NetlinkProtocol;
        let mut ctx = ParseContext::new(LINKTYPE_NETLINK);
        ctx.parent_protocol = Some("something");

        assert!(parser.can_parse(&ctx).is_none());
    }

    #[test]
    fn test_can_parse_under_linux_sll_with_netlink_hint() {
        let parser = NetlinkProtocol;
        let mut ctx = ParseContext::new(113); // LINKTYPE_LINUX_SLL
        ctx.parent_protocol = Some("linux_sll");
        ctx.insert_hint("is_netlink", 1);

        assert_eq!(parser.can_parse(&ctx), Some(100));
    }

    #[test]
    fn test_cannot_parse_under_linux_sll_without_netlink_hint() {
        let parser = NetlinkProtocol;
        let mut ctx = ParseContext::new(113); // LINKTYPE_LINUX_SLL
        ctx.parent_protocol = Some("linux_sll");

        assert!(parser.can_parse(&ctx).is_none());
    }

    // ==========================================================================
    // Header parsing
    // ==========================================================================
    #[test]
    fn test_parse_netlink_header_basic() {
        let header = create_netlink_header(32, 16, NLM_F_REQUEST, 1, 1234);
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(result.get("msg_len"), Some(&FieldValue::UInt32(32)));
        assert_eq!(result.get("msg_type"), Some(&FieldValue::UInt16(16)));
        assert_eq!(
            result.get("msg_flags"),
            Some(&FieldValue::UInt16(NLM_F_REQUEST))
        );
        assert_eq!(result.get("msg_seq"), Some(&FieldValue::UInt32(1)));
        assert_eq!(result.get("msg_pid"), Some(&FieldValue::UInt32(1234)));
    }

    #[test]
    fn test_parse_netlink_header_too_short() {
        let short_data = vec![0u8; 10]; // Less than 16 bytes
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&short_data, &ctx);

        assert!(!result.is_ok());
        assert!(result.error.is_some());
    }

    // ==========================================================================
    // Message type names
    // ==========================================================================
    #[test]
    fn test_parse_nlmsg_done() {
        let header = create_netlink_header(16, NLMSG_DONE, 0, 1, 0);
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(
            result.get("msg_type_name"),
            Some(&FieldValue::Str("NLMSG_DONE"))
        );
    }

    #[test]
    fn test_parse_nlmsg_error() {
        let header = create_netlink_header(16, NLMSG_ERROR, 0, 1, 0);
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(
            result.get("msg_type_name"),
            Some(&FieldValue::Str("NLMSG_ERROR"))
        );
    }

    #[test]
    fn test_protocol_specific_msg_type_name() {
        // 0x10 is the first message type outside the reserved control range.
        let header = create_netlink_header(16, 0x10, 0, 1, 0);
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(
            result.get("msg_type_name"),
            Some(&FieldValue::Str("PROTOCOL_SPECIFIC"))
        );
    }

    // ==========================================================================
    // Flags
    // ==========================================================================
    #[test]
    fn test_request_flag() {
        let header = create_netlink_header(16, 16, NLM_F_REQUEST, 1, 0);
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(result.get("is_request"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("is_multipart"), Some(&FieldValue::Bool(false)));
    }

    #[test]
    fn test_multipart_message_flag() {
        let header = create_netlink_header(32, 16, NLM_F_MULTIPART, 1, 0);
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(result.get("is_multipart"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("is_request"), Some(&FieldValue::Bool(false)));
    }

    #[test]
    fn test_combined_flags() {
        let flags = NLM_F_REQUEST | NLM_F_MULTIPART | NLM_F_ACK;
        let header = create_netlink_header(16, 16, flags, 1, 0);
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(result.get("is_request"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("is_multipart"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("is_ack"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("is_echo"), Some(&FieldValue::Bool(false)));
    }

    // ==========================================================================
    // Family handling and child hints
    // ==========================================================================
    #[test]
    fn test_child_hints_for_rtnetlink() {
        let header = create_netlink_header(32, 16, 0, 1, 0);
        let parser = NetlinkProtocol;
        let mut ctx = ParseContext::new(LINKTYPE_NETLINK);
        ctx.insert_hint("netlink_family", family::ROUTE as u64);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert!(result
            .child_hints
            .iter()
            .any(|(k, v)| *k == "netlink_family" && *v == family::ROUTE as u64));
        assert!(result
            .child_hints
            .iter()
            .any(|(k, v)| *k == "netlink_msg_type" && *v == 16));
    }

    #[test]
    fn test_family_name_route() {
        let header = create_netlink_header(16, 16, 0, 1, 0);
        let parser = NetlinkProtocol;
        let mut ctx = ParseContext::new(LINKTYPE_NETLINK);
        ctx.insert_hint("netlink_family", family::ROUTE as u64);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(
            result.get("family"),
            Some(&FieldValue::UInt8(family::ROUTE))
        );
        assert_eq!(result.get("family_name"), Some(&FieldValue::Str("ROUTE")));
    }

    #[test]
    fn test_family_defaults_to_route_without_hint() {
        let header = create_netlink_header(16, 16, 0, 1, 0);
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(
            result.get("family"),
            Some(&FieldValue::UInt8(family::ROUTE))
        );
        assert_eq!(result.get("family_name"), Some(&FieldValue::Str("ROUTE")));
    }

    // ==========================================================================
    // Schema and registry metadata
    // ==========================================================================
    #[test]
    fn test_netlink_schema_fields() {
        let parser = NetlinkProtocol;
        let fields = parser.schema_fields();
        let field_names: Vec<&str> = fields.iter().map(|f| f.name).collect();

        let expected = [
            "netlink.msg_len",
            "netlink.msg_type",
            "netlink.msg_flags",
            "netlink.msg_seq",
            "netlink.msg_pid",
            "netlink.msg_type_name",
            "netlink.is_request",
            "netlink.is_multipart",
            "netlink.is_ack",
            "netlink.is_echo",
            "netlink.family",
            "netlink.family_name",
        ];

        for name in expected.iter() {
            assert!(field_names.contains(name), "schema is missing {}", name);
        }
        assert_eq!(
            field_names.len(),
            expected.len(),
            "unexpected schema fields: {:?}",
            field_names
        );
    }

    #[test]
    fn test_netlink_child_protocols() {
        let parser = NetlinkProtocol;
        let children = parser.child_protocols();

        assert!(children.contains(&"rtnetlink"));
    }

    #[test]
    fn test_netlink_no_dependencies() {
        let parser = NetlinkProtocol;
        let deps = parser.dependencies();

        assert!(deps.is_empty());
    }

    // ==========================================================================
    // parse_packet integration: netlink is selected for LINKTYPE_NETLINK
    // ==========================================================================
    #[test]
    fn test_parse_packet_selects_netlink() {
        use crate::protocol::{default_registry, parse_packet};

        let registry = default_registry();
        let header = create_netlink_header(32, 16, NLM_F_REQUEST, 1, 1234);

        let results = parse_packet(&registry, LINKTYPE_NETLINK, &header);

        assert!(
            !results.is_empty(),
            "parse_packet should return at least one protocol"
        );

        let protocol_names: Vec<&str> = results.iter().map(|(name, _)| *name).collect();
        assert!(
            protocol_names.contains(&"netlink"),
            "parse_packet should select 'netlink' for LINKTYPE_NETLINK, got: {:?}",
            protocol_names
        );

        let (_, netlink_result) = results
            .iter()
            .find(|(name, _)| *name == "netlink")
            .expect("netlink result present (checked above)");
        assert_eq!(
            netlink_result.get("msg_type"),
            Some(&FieldValue::UInt16(16))
        );
        assert_eq!(
            netlink_result.get("is_request"),
            Some(&FieldValue::Bool(true))
        );
    }
}