1use smallvec::SmallVec;
8
9use super::{FieldValue, ParseContext, ParseResult, Protocol};
10use crate::schema::{DataKind, FieldDescriptor};
11
12pub use netlink_packet_core::constants::{NLM_F_ACK, NLM_F_ECHO, NLM_F_MULTIPART, NLM_F_REQUEST};
14
15pub use netlink_packet_core::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN};
17
/// Size in bytes of the fixed netlink message header decoded by this module:
/// length (4) + type (2) + flags (2) + sequence number (4) + pid (4).
pub const NETLINK_HEADER_LEN: usize = 4 + 2 + 2 + 4 + 4;

/// Link-layer type value this parser accepts at the root of a capture.
pub const LINKTYPE_NETLINK: u16 = 253;
23
/// Netlink protocol family numbers.
///
/// The names mirror the `NETLINK_*` identifiers of the Linux netlink API with
/// the `NETLINK_` prefix dropped. No constant is defined for value 17.
pub mod family {
    /// Routing and link updates.
    pub const ROUTE: u8 = 0;
    /// Unused number.
    pub const UNUSED: u8 = 1;
    /// Reserved for user-mode socket protocols.
    pub const USERSOCK: u8 = 2;
    /// Firewalling hook.
    pub const FIREWALL: u8 = 3;
    /// Socket monitoring.
    pub const SOCK_DIAG: u8 = 4;
    /// Netfilter/iptables ULOG.
    pub const NFLOG: u8 = 5;
    /// IPsec.
    pub const XFRM: u8 = 6;
    /// SELinux event notifications.
    pub const SELINUX: u8 = 7;
    /// Open-iSCSI.
    pub const ISCSI: u8 = 8;
    /// Auditing.
    pub const AUDIT: u8 = 9;
    /// FIB lookup.
    pub const FIB_LOOKUP: u8 = 10;
    /// Kernel connector.
    pub const CONNECTOR: u8 = 11;
    /// Netfilter subsystem.
    pub const NETFILTER: u8 = 12;
    /// IPv6 firewall.
    pub const IP6_FW: u8 = 13;
    /// DECnet routing messages.
    pub const DNRTMSG: u8 = 14;
    /// Kernel messages to userspace.
    pub const KOBJECT_UEVENT: u8 = 15;
    /// Generic netlink.
    pub const GENERIC: u8 = 16;
    /// SCSI transports.
    pub const SCSITRANSPORT: u8 = 18;
    /// eCryptfs.
    pub const ECRYPTFS: u8 = 19;
    /// RDMA.
    pub const RDMA: u8 = 20;
    /// Crypto layer.
    pub const CRYPTO: u8 = 21;
}
48
49fn family_name(family: u8) -> &'static str {
51 match family {
52 family::ROUTE => "ROUTE",
53 family::USERSOCK => "USERSOCK",
54 family::FIREWALL => "FIREWALL",
55 family::SOCK_DIAG => "SOCK_DIAG",
56 family::NFLOG => "NFLOG",
57 family::XFRM => "XFRM",
58 family::SELINUX => "SELINUX",
59 family::ISCSI => "ISCSI",
60 family::AUDIT => "AUDIT",
61 family::FIB_LOOKUP => "FIB_LOOKUP",
62 family::CONNECTOR => "CONNECTOR",
63 family::NETFILTER => "NETFILTER",
64 family::IP6_FW => "IP6_FW",
65 family::DNRTMSG => "DNRTMSG",
66 family::KOBJECT_UEVENT => "KOBJECT_UEVENT",
67 family::GENERIC => "GENERIC",
68 family::SCSITRANSPORT => "SCSITRANSPORT",
69 family::ECRYPTFS => "ECRYPTFS",
70 family::RDMA => "RDMA",
71 family::CRYPTO => "CRYPTO",
72 _ => "UNKNOWN",
73 }
74}
75
76fn msg_type_name(msg_type: u16) -> &'static str {
78 match msg_type {
79 NLMSG_NOOP => "NLMSG_NOOP",
80 NLMSG_ERROR => "NLMSG_ERROR",
81 NLMSG_DONE => "NLMSG_DONE",
82 NLMSG_OVERRUN => "NLMSG_OVERRUN",
83 _ => "PROTOCOL_SPECIFIC",
84 }
85}
86
/// Stateless parser for the netlink message header.
///
/// The type carries no data; its behaviour is provided by its `Protocol`
/// implementation.
#[derive(Clone, Copy, Debug)]
pub struct NetlinkProtocol;
93
94impl Protocol for NetlinkProtocol {
95 fn name(&self) -> &'static str {
96 "netlink"
97 }
98
99 fn display_name(&self) -> &'static str {
100 "Netlink"
101 }
102
103 fn can_parse(&self, context: &ParseContext) -> Option<u32> {
104 if context.is_root() && context.link_type == LINKTYPE_NETLINK {
106 return Some(100);
107 }
108
109 if context.parent_protocol == Some("linux_sll") && context.hint("is_netlink") == Some(1) {
111 return Some(100);
112 }
113
114 None
115 }
116
117 fn parse<'a>(&self, data: &'a [u8], context: &ParseContext) -> ParseResult<'a> {
118 if data.len() < NETLINK_HEADER_LEN {
120 return ParseResult::error(
121 format!("Netlink message too short: {} bytes", data.len()),
122 data,
123 );
124 }
125
126 let mut fields = SmallVec::new();
127
128 let msg_len = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
130 let msg_type = u16::from_le_bytes([data[4], data[5]]);
131 let msg_flags = u16::from_le_bytes([data[6], data[7]]);
132 let msg_seq = u32::from_le_bytes([data[8], data[9], data[10], data[11]]);
133 let msg_pid = u32::from_le_bytes([data[12], data[13], data[14], data[15]]);
134
135 fields.push(("msg_len", FieldValue::UInt32(msg_len)));
136 fields.push(("msg_type", FieldValue::UInt16(msg_type)));
137 fields.push(("msg_flags", FieldValue::UInt16(msg_flags)));
138 fields.push(("msg_seq", FieldValue::UInt32(msg_seq)));
139 fields.push(("msg_pid", FieldValue::UInt32(msg_pid)));
140
141 fields.push(("msg_type_name", FieldValue::Str(msg_type_name(msg_type))));
143
144 let is_request = (msg_flags & NLM_F_REQUEST) != 0;
146 let is_multipart = (msg_flags & NLM_F_MULTIPART) != 0;
147 let is_ack = (msg_flags & NLM_F_ACK) != 0;
148 let is_echo = (msg_flags & NLM_F_ECHO) != 0;
149
150 fields.push(("is_request", FieldValue::Bool(is_request)));
151 fields.push(("is_multipart", FieldValue::Bool(is_multipart)));
152 fields.push(("is_ack", FieldValue::Bool(is_ack)));
153 fields.push(("is_echo", FieldValue::Bool(is_echo)));
154
155 let nl_family = context
158 .hint("netlink_family")
159 .map(|f| f as u8)
160 .unwrap_or(family::ROUTE);
161
162 fields.push(("family", FieldValue::UInt8(nl_family)));
163 fields.push(("family_name", FieldValue::Str(family_name(nl_family))));
164
165 let payload_start = NETLINK_HEADER_LEN;
167 let payload_len = (msg_len as usize).saturating_sub(NETLINK_HEADER_LEN);
168 let remaining = if payload_start + payload_len <= data.len() {
169 &data[payload_start..payload_start + payload_len]
170 } else {
171 &data[payload_start..]
172 };
173
174 let mut child_hints = SmallVec::new();
176 child_hints.push(("netlink_family", nl_family as u64));
177 child_hints.push(("netlink_msg_type", msg_type as u64));
178
179 ParseResult::success(fields, remaining, child_hints)
180 }
181
182 fn schema_fields(&self) -> Vec<FieldDescriptor> {
183 vec![
184 FieldDescriptor::new("netlink.msg_len", DataKind::UInt32).set_nullable(true),
185 FieldDescriptor::new("netlink.msg_type", DataKind::UInt16).set_nullable(true),
186 FieldDescriptor::new("netlink.msg_flags", DataKind::UInt16).set_nullable(true),
187 FieldDescriptor::new("netlink.msg_seq", DataKind::UInt32).set_nullable(true),
188 FieldDescriptor::new("netlink.msg_pid", DataKind::UInt32).set_nullable(true),
189 FieldDescriptor::new("netlink.msg_type_name", DataKind::String).set_nullable(true),
190 FieldDescriptor::new("netlink.is_request", DataKind::Bool).set_nullable(true),
191 FieldDescriptor::new("netlink.is_multipart", DataKind::Bool).set_nullable(true),
192 FieldDescriptor::new("netlink.is_ack", DataKind::Bool).set_nullable(true),
193 FieldDescriptor::new("netlink.is_echo", DataKind::Bool).set_nullable(true),
194 FieldDescriptor::new("netlink.family", DataKind::UInt8).set_nullable(true),
195 FieldDescriptor::new("netlink.family_name", DataKind::String).set_nullable(true),
196 ]
197 }
198
199 fn child_protocols(&self) -> &[&'static str] {
200 &["rtnetlink"]
201 }
202
203 fn dependencies(&self) -> &'static [&'static str] {
204 &[] }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises a netlink header with every field little-endian, matching
    /// the layout `NetlinkProtocol::parse` reads.
    fn create_netlink_header(
        msg_len: u32,
        msg_type: u16,
        flags: u16,
        seq: u32,
        pid: u32,
    ) -> Vec<u8> {
        let mut header = Vec::with_capacity(NETLINK_HEADER_LEN);
        header.extend_from_slice(&msg_len.to_le_bytes());
        header.extend_from_slice(&msg_type.to_le_bytes());
        header.extend_from_slice(&flags.to_le_bytes());
        header.extend_from_slice(&seq.to_le_bytes());
        header.extend_from_slice(&pid.to_le_bytes());
        header
    }

    /// A root context with the netlink link type is accepted with score 100.
    #[test]
    fn test_can_parse_netlink_at_root() {
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        assert!(parser.can_parse(&ctx).is_some());
        assert_eq!(parser.can_parse(&ctx), Some(100));
    }

    /// A root context with a different link type is rejected.
    #[test]
    fn test_cannot_parse_ethernet() {
        let parser = NetlinkProtocol;
        // Link type 1 is Ethernet, not netlink.
        let ctx = ParseContext::new(1);

        assert!(parser.can_parse(&ctx).is_none());
    }

    /// The netlink link type alone is not enough once a parent protocol is set.
    #[test]
    fn test_cannot_parse_when_not_root() {
        let parser = NetlinkProtocol;
        let mut ctx = ParseContext::new(LINKTYPE_NETLINK);
        ctx.parent_protocol = Some("something");

        assert!(parser.can_parse(&ctx).is_none());
    }

    /// All five raw header fields round-trip through the parser.
    #[test]
    fn test_parse_netlink_header_basic() {
        let header = create_netlink_header(32, 16, NLM_F_REQUEST, 1, 1234);
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(result.get("msg_len"), Some(&FieldValue::UInt32(32)));
        assert_eq!(result.get("msg_type"), Some(&FieldValue::UInt16(16)));
        assert_eq!(
            result.get("msg_flags"),
            Some(&FieldValue::UInt16(NLM_F_REQUEST))
        );
        assert_eq!(result.get("msg_seq"), Some(&FieldValue::UInt32(1)));
        assert_eq!(result.get("msg_pid"), Some(&FieldValue::UInt32(1234)));
    }

    /// Input shorter than the fixed header yields an error result.
    #[test]
    fn test_parse_netlink_header_too_short() {
        // 10 bytes: fewer than NETLINK_HEADER_LEN.
        let short_data = vec![0u8; 10];
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&short_data, &ctx);

        assert!(!result.is_ok());
        assert!(result.error.is_some());
    }

    /// The `NLMSG_DONE` control type is reported by name.
    #[test]
    fn test_parse_nlmsg_done() {
        let header = create_netlink_header(16, NLMSG_DONE, 0, 1, 0);
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(
            result.get("msg_type_name"),
            Some(&FieldValue::Str("NLMSG_DONE"))
        );
    }

    /// Only `is_request` is set when the flags contain just `NLM_F_REQUEST`.
    #[test]
    fn test_request_flag() {
        let header = create_netlink_header(16, 16, NLM_F_REQUEST, 1, 0);
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(result.get("is_request"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("is_multipart"), Some(&FieldValue::Bool(false)));
    }

    /// Only `is_multipart` is set when the flags contain just `NLM_F_MULTIPART`.
    #[test]
    fn test_multipart_message_flag() {
        let header = create_netlink_header(32, 16, NLM_F_MULTIPART, 1, 0);
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(result.get("is_multipart"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("is_request"), Some(&FieldValue::Bool(false)));
    }

    /// Several flag bits set at once are decoded independently.
    #[test]
    fn test_combined_flags() {
        let flags = NLM_F_REQUEST | NLM_F_MULTIPART | NLM_F_ACK;
        let header = create_netlink_header(16, 16, flags, 1, 0);
        let parser = NetlinkProtocol;
        let ctx = ParseContext::new(LINKTYPE_NETLINK);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(result.get("is_request"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("is_multipart"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("is_ack"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("is_echo"), Some(&FieldValue::Bool(false)));
    }

    /// The family and message type are forwarded to child parsers as hints.
    #[test]
    fn test_child_hints_for_rtnetlink() {
        let header = create_netlink_header(32, 16, 0, 1, 0);
        let parser = NetlinkProtocol;
        let mut ctx = ParseContext::new(LINKTYPE_NETLINK);
        ctx.insert_hint("netlink_family", family::ROUTE as u64);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert!(result
            .child_hints
            .iter()
            .any(|(k, v)| *k == "netlink_family" && *v == family::ROUTE as u64));
        assert!(result
            .child_hints
            .iter()
            .any(|(k, v)| *k == "netlink_msg_type" && *v == 16));
    }

    /// The family hint is surfaced both as a number and as a name.
    #[test]
    fn test_family_name_route() {
        let header = create_netlink_header(16, 16, 0, 1, 0);
        let parser = NetlinkProtocol;
        let mut ctx = ParseContext::new(LINKTYPE_NETLINK);
        ctx.insert_hint("netlink_family", family::ROUTE as u64);

        let result = parser.parse(&header, &ctx);

        assert!(result.is_ok());
        assert_eq!(
            result.get("family"),
            Some(&FieldValue::UInt8(family::ROUTE))
        );
        assert_eq!(result.get("family_name"), Some(&FieldValue::Str("ROUTE")));
    }

    /// Every field the parser can emit is declared in the schema.
    #[test]
    fn test_netlink_schema_fields() {
        let parser = NetlinkProtocol;
        let fields = parser.schema_fields();

        let field_names: Vec<&str> = fields.iter().map(|f| f.name).collect();

        assert!(field_names.contains(&"netlink.msg_len"));
        assert!(field_names.contains(&"netlink.msg_type"));
        assert!(field_names.contains(&"netlink.msg_flags"));
        assert!(field_names.contains(&"netlink.msg_seq"));
        assert!(field_names.contains(&"netlink.msg_pid"));
        assert!(field_names.contains(&"netlink.msg_type_name"));
        assert!(field_names.contains(&"netlink.is_request"));
        assert!(field_names.contains(&"netlink.is_multipart"));
        assert!(field_names.contains(&"netlink.is_ack"));
        assert!(field_names.contains(&"netlink.is_echo"));
        assert!(field_names.contains(&"netlink.family"));
        assert!(field_names.contains(&"netlink.family_name"));
    }

    /// `rtnetlink` is listed as a child protocol.
    #[test]
    fn test_netlink_child_protocols() {
        let parser = NetlinkProtocol;
        let children = parser.child_protocols();

        assert!(children.contains(&"rtnetlink"));
    }

    /// The parser declares no dependencies.
    #[test]
    fn test_netlink_no_dependencies() {
        let parser = NetlinkProtocol;
        let deps = parser.dependencies();

        assert!(deps.is_empty());
    }

    /// End to end: the default registry picks the netlink parser for a
    /// `LINKTYPE_NETLINK` capture and exposes its decoded fields.
    #[test]
    fn test_parse_packet_selects_netlink() {
        use crate::protocol::{default_registry, parse_packet};

        let registry = default_registry();
        let header = create_netlink_header(32, 16, NLM_F_REQUEST, 1, 1234);

        let results = parse_packet(&registry, LINKTYPE_NETLINK, &header);

        assert!(
            !results.is_empty(),
            "parse_packet should return at least one protocol"
        );

        let protocol_names: Vec<&str> = results.iter().map(|(name, _)| *name).collect();
        assert!(
            protocol_names.contains(&"netlink"),
            "parse_packet should select 'netlink' for LINKTYPE_NETLINK, got: {:?}",
            protocol_names
        );

        let (_, netlink_result) = results.iter().find(|(name, _)| *name == "netlink").unwrap();
        assert_eq!(
            netlink_result.get("msg_type"),
            Some(&FieldValue::UInt16(16))
        );
        assert_eq!(
            netlink_result.get("is_request"),
            Some(&FieldValue::Bool(true))
        );
    }
}