1use smallvec::SmallVec;
11
12use super::{FieldValue, ParseContext, ParseResult, Protocol, TunnelType};
13use crate::schema::{DataKind, FieldDescriptor};
14
/// IP protocol number assigned to GRE (decimal 47), matched against the
/// `ip_protocol` hint supplied by the IP layer.
pub const IP_PROTOCOL_GRE: u8 = 0x2F;
17
/// Values of the 3-bit GRE version field (the low three bits of the first
/// 16-bit header word).
pub mod gre_version {
    /// Version 0: standard GRE.
    pub const STANDARD: u8 = 0b000;
    /// Version 1: the enhanced GRE variant used by PPTP.
    pub const PPTP_ENHANCED: u8 = 0b001;
}
25
/// Computes the 16-bit one's-complement Internet checksum (RFC 1071) of `data`.
///
/// Bytes are summed as big-endian 16-bit words. An odd trailing byte is
/// treated as the high byte of a final word padded with zero. Carries are
/// folded back into the low 16 bits and the result is complemented.
///
/// A buffer that already contains a correct checksum yields `0`.
fn internet_checksum(data: &[u8]) -> u16 {
    let words = data.chunks_exact(2);
    // `remainder` borrows from `data`, not from the iterator, so it stays
    // valid after `words` is consumed below.
    let trailing = words.remainder();

    // u64 accumulator: a u32 overflows after ~65 537 words (~131 KB of input),
    // which large (e.g. jumbogram) payloads can exceed.
    let mut sum: u64 = words
        .map(|pair| u64::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum();

    if let [last] = trailing {
        sum += u64::from(*last) << 8;
    }

    while sum >> 16 != 0 {
        sum = (sum & 0xFFFF) + (sum >> 16);
    }

    // After folding, `sum` fits in 16 bits, so truncating before `!` is exact.
    !(sum as u16)
}
50
/// Stateless dissector for GRE headers; all behaviour lives in its
/// `Protocol` implementation.
#[derive(Clone, Copy, Debug)]
pub struct GreProtocol;
54
55impl Protocol for GreProtocol {
56 fn name(&self) -> &'static str {
57 "gre"
58 }
59
60 fn display_name(&self) -> &'static str {
61 "GRE"
62 }
63
64 fn can_parse(&self, context: &ParseContext) -> Option<u32> {
65 match context.hint("ip_protocol") {
67 Some(proto) if proto == IP_PROTOCOL_GRE as u64 => Some(100),
68 _ => None,
69 }
70 }
71
72 fn parse<'a>(&self, data: &'a [u8], _context: &ParseContext) -> ParseResult<'a> {
73 if data.len() < 4 {
75 return ParseResult::error("GRE header too short".to_string(), data);
76 }
77
78 let mut fields = SmallVec::new();
79 let mut offset = 0;
80
81 let flags = u16::from_be_bytes([data[0], data[1]]);
91
92 let checksum_present = (flags & 0x8000) != 0;
93 let key_present = (flags & 0x2000) != 0;
94 let sequence_present = (flags & 0x1000) != 0;
95 let version = (flags & 0x0007) as u8;
96
97 fields.push(("checksum_present", FieldValue::Bool(checksum_present)));
98 fields.push(("key_present", FieldValue::Bool(key_present)));
99 fields.push(("sequence_present", FieldValue::Bool(sequence_present)));
100 fields.push(("version", FieldValue::UInt8(version)));
101
102 let version_valid =
104 version == gre_version::STANDARD || version == gre_version::PPTP_ENHANCED;
105 fields.push(("version_valid", FieldValue::Bool(version_valid)));
106
107 let version_name = match version {
109 gre_version::STANDARD => "Standard",
110 gre_version::PPTP_ENHANCED => "PPTP-Enhanced",
111 _ => "Unknown",
112 };
113 fields.push(("version_name", FieldValue::Str(version_name)));
114
115 offset += 2;
116
117 let protocol_type = u16::from_be_bytes([data[offset], data[offset + 1]]);
119 fields.push(("protocol", FieldValue::UInt16(protocol_type)));
120 offset += 2;
121
122 let header_start = 0;
124
125 if checksum_present {
129 if data.len() < offset + 4 {
130 return ParseResult::error("GRE: missing checksum field".to_string(), data);
131 }
132 let checksum = u16::from_be_bytes([data[offset], data[offset + 1]]);
133 fields.push(("checksum", FieldValue::UInt16(checksum)));
134
135 let computed = internet_checksum(data);
140 let checksum_valid = computed == 0;
141 fields.push(("checksum_valid", FieldValue::Bool(checksum_valid)));
142
143 offset += 4;
145 }
146
147 let mut key_value: Option<u32> = None;
149 if key_present {
150 if data.len() < offset + 4 {
151 return ParseResult::error("GRE: missing key field".to_string(), data);
152 }
153 let key = u32::from_be_bytes([
154 data[offset],
155 data[offset + 1],
156 data[offset + 2],
157 data[offset + 3],
158 ]);
159 fields.push(("key", FieldValue::UInt32(key)));
160 key_value = Some(key);
161 offset += 4;
162 }
163
164 if sequence_present {
166 if data.len() < offset + 4 {
167 return ParseResult::error("GRE: missing sequence field".to_string(), data);
168 }
169 let sequence = u32::from_be_bytes([
170 data[offset],
171 data[offset + 1],
172 data[offset + 2],
173 data[offset + 3],
174 ]);
175 fields.push(("sequence", FieldValue::UInt32(sequence)));
176 offset += 4;
177 }
178
179 fields.push((
181 "header_length",
182 FieldValue::UInt8((offset - header_start) as u8),
183 ));
184
185 let mut child_hints = SmallVec::new();
187 child_hints.push(("ethertype", protocol_type as u64));
188
189 if let Some(key) = key_value {
191 child_hints.push(("gre_key", key as u64));
192 }
193
194 child_hints.push(("tunnel_type", TunnelType::Gre as u64));
202 if let Some(key) = key_value {
204 child_hints.push(("tunnel_id", key as u64));
205 }
206
207 ParseResult::success(fields, &data[offset..], child_hints)
208 }
209
210 fn schema_fields(&self) -> Vec<FieldDescriptor> {
211 vec![
212 FieldDescriptor::new("gre.checksum_present", DataKind::Bool).set_nullable(true),
213 FieldDescriptor::new("gre.key_present", DataKind::Bool).set_nullable(true),
214 FieldDescriptor::new("gre.sequence_present", DataKind::Bool).set_nullable(true),
215 FieldDescriptor::new("gre.version", DataKind::UInt8).set_nullable(true),
216 FieldDescriptor::new("gre.version_valid", DataKind::Bool).set_nullable(true),
217 FieldDescriptor::new("gre.version_name", DataKind::String).set_nullable(true),
218 FieldDescriptor::new("gre.protocol", DataKind::UInt16).set_nullable(true),
219 FieldDescriptor::new("gre.checksum", DataKind::UInt16).set_nullable(true),
220 FieldDescriptor::new("gre.checksum_valid", DataKind::Bool).set_nullable(true),
221 FieldDescriptor::new("gre.key", DataKind::UInt32).set_nullable(true),
222 FieldDescriptor::new("gre.sequence", DataKind::UInt32).set_nullable(true),
223 FieldDescriptor::new("gre.header_length", DataKind::UInt8).set_nullable(true),
224 ]
225 }
226
227 fn child_protocols(&self) -> &[&'static str] {
228 &["ipv4", "ipv6", "ethernet"]
230 }
231
232 fn dependencies(&self) -> &'static [&'static str] {
233 &["ipv4", "ipv6"] }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ethernet::ethertype;

    /// Builds the fixed 4-byte GRE header (version 0) with the requested flag
    /// bits set. Optional fields are not appended; callers extend the vector.
    fn create_gre_header(checksum: bool, key: bool, sequence: bool, protocol: u16) -> Vec<u8> {
        let mut header = Vec::new();

        let mut flags: u16 = 0;
        if checksum {
            flags |= 0x8000;
        }
        if key {
            flags |= 0x2000;
        }
        if sequence {
            flags |= 0x1000;
        }
        header.extend_from_slice(&flags.to_be_bytes());
        header.extend_from_slice(&protocol.to_be_bytes());

        header
    }

    #[test]
    fn test_can_parse_with_ip_protocol_47() {
        let parser = GreProtocol;

        // No hint at all.
        let ctx1 = ParseContext::new(1);
        assert!(parser.can_parse(&ctx1).is_none());

        // TCP, not GRE.
        let mut ctx2 = ParseContext::new(1);
        ctx2.insert_hint("ip_protocol", 6);
        assert!(parser.can_parse(&ctx2).is_none());

        let mut ctx3 = ParseContext::new(1);
        ctx3.insert_hint("ip_protocol", 47);
        assert!(parser.can_parse(&ctx3).is_some());
        assert_eq!(parser.can_parse(&ctx3), Some(100));
    }

    #[test]
    fn test_basic_gre_header_parsing() {
        let mut header = create_gre_header(false, false, false, ethertype::IPV4);
        header.extend_from_slice(&[0x45, 0x00, 0x00, 0x28]);

        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        let result = parser.parse(&header, &context);

        assert!(result.is_ok());
        assert_eq!(
            result.get("checksum_present"),
            Some(&FieldValue::Bool(false))
        );
        assert_eq!(result.get("key_present"), Some(&FieldValue::Bool(false)));
        assert_eq!(
            result.get("sequence_present"),
            Some(&FieldValue::Bool(false))
        );
        assert_eq!(result.get("version"), Some(&FieldValue::UInt8(0)));
        assert_eq!(
            result.get("protocol"),
            Some(&FieldValue::UInt16(ethertype::IPV4))
        );

        // Optional fields must be absent when their flags are clear.
        assert!(result.get("checksum").is_none());
        assert!(result.get("key").is_none());
        assert!(result.get("sequence").is_none());

        assert_eq!(result.remaining.len(), 4);
    }

    #[test]
    fn test_gre_with_checksum() {
        let mut header = create_gre_header(true, false, false, ethertype::IPV4);
        // Checksum 0xABCD followed by 2 reserved bytes.
        header.extend_from_slice(&[0xAB, 0xCD, 0x00, 0x00]);
        header.extend_from_slice(&[0x45, 0x00]);

        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        let result = parser.parse(&header, &context);

        assert!(result.is_ok());
        assert_eq!(
            result.get("checksum_present"),
            Some(&FieldValue::Bool(true))
        );
        assert_eq!(result.get("checksum"), Some(&FieldValue::UInt16(0xABCD)));
        assert_eq!(result.remaining.len(), 2);
    }

    #[test]
    fn test_gre_with_key() {
        let mut header = create_gre_header(false, true, false, ethertype::IPV4);
        header.extend_from_slice(&[0x00, 0x01, 0x02, 0x03]);
        header.extend_from_slice(&[0x45, 0x00]);

        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        let result = parser.parse(&header, &context);

        assert!(result.is_ok());
        assert_eq!(result.get("key_present"), Some(&FieldValue::Bool(true)));
        assert_eq!(result.get("key"), Some(&FieldValue::UInt32(0x00010203)));
        assert_eq!(result.hint("gre_key"), Some(0x00010203u64));
        assert_eq!(result.remaining.len(), 2);
    }

    #[test]
    fn test_gre_with_sequence() {
        let mut header = create_gre_header(false, false, true, ethertype::IPV4);
        header.extend_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
        header.extend_from_slice(&[0x45, 0x00]);

        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        let result = parser.parse(&header, &context);

        assert!(result.is_ok());
        assert_eq!(
            result.get("sequence_present"),
            Some(&FieldValue::Bool(true))
        );
        assert_eq!(
            result.get("sequence"),
            Some(&FieldValue::UInt32(0xDEADBEEF))
        );
        assert_eq!(result.remaining.len(), 2);
    }

    #[test]
    fn test_gre_with_all_optional_fields() {
        let mut header = create_gre_header(true, true, true, ethertype::IPV6);
        // Checksum + reserved.
        header.extend_from_slice(&[0x12, 0x34, 0x00, 0x00]);
        // Key.
        header.extend_from_slice(&[0xAA, 0xBB, 0xCC, 0xDD]);
        // Sequence number.
        header.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]);
        // Start of payload.
        header.extend_from_slice(&[0x60, 0x00]);

        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        let result = parser.parse(&header, &context);

        assert!(result.is_ok());
        assert_eq!(
            result.get("checksum_present"),
            Some(&FieldValue::Bool(true))
        );
        assert_eq!(result.get("key_present"), Some(&FieldValue::Bool(true)));
        assert_eq!(
            result.get("sequence_present"),
            Some(&FieldValue::Bool(true))
        );
        assert_eq!(result.get("checksum"), Some(&FieldValue::UInt16(0x1234)));
        assert_eq!(result.get("key"), Some(&FieldValue::UInt32(0xAABBCCDD)));
        assert_eq!(result.get("sequence"), Some(&FieldValue::UInt32(1)));
        assert_eq!(
            result.get("protocol"),
            Some(&FieldValue::UInt16(ethertype::IPV6))
        );
        assert_eq!(result.remaining.len(), 2);
    }

    #[test]
    fn test_child_protocol_hint_ethertype() {
        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        let header_ipv4 = create_gre_header(false, false, false, ethertype::IPV4);
        let result = parser.parse(&header_ipv4, &context);
        assert!(result.is_ok());
        assert_eq!(result.hint("ethertype"), Some(ethertype::IPV4 as u64));

        let header_ipv6 = create_gre_header(false, false, false, ethertype::IPV6);
        let result = parser.parse(&header_ipv6, &context);
        assert!(result.is_ok());
        assert_eq!(result.hint("ethertype"), Some(ethertype::IPV6 as u64));

        // 0x6558: Transparent Ethernet Bridging.
        let header_teb = create_gre_header(false, false, false, 0x6558);
        let result = parser.parse(&header_teb, &context);
        assert!(result.is_ok());
        assert_eq!(result.hint("ethertype"), Some(0x6558u64));
    }

    #[test]
    fn test_gre_too_short() {
        let short_header = [0x00, 0x00];
        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        let result = parser.parse(&short_header, &context);
        assert!(!result.is_ok());
        assert!(result.error.is_some());
    }

    #[test]
    fn test_gre_missing_key_field() {
        // K bit set but no key bytes follow.
        let header = create_gre_header(false, true, false, ethertype::IPV4);
        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        let result = parser.parse(&header, &context);
        assert!(!result.is_ok());
        assert!(result.error.unwrap().contains("missing key field"));
    }

    #[test]
    fn test_gre_schema_fields() {
        let parser = GreProtocol;
        let fields = parser.schema_fields();

        assert!(!fields.is_empty());
        let field_names: Vec<&str> = fields.iter().map(|f| f.name).collect();
        assert!(field_names.contains(&"gre.checksum_present"));
        assert!(field_names.contains(&"gre.key_present"));
        assert!(field_names.contains(&"gre.sequence_present"));
        assert!(field_names.contains(&"gre.version"));
        assert!(field_names.contains(&"gre.protocol"));
        assert!(field_names.contains(&"gre.checksum"));
        assert!(field_names.contains(&"gre.key"));
        assert!(field_names.contains(&"gre.sequence"));
    }

    #[test]
    fn test_version_0_standard_gre() {
        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        let header = create_gre_header(false, false, false, ethertype::IPV4);
        let result = parser.parse(&header, &context);

        assert!(result.is_ok());
        assert_eq!(
            result.get("version"),
            Some(&FieldValue::UInt8(gre_version::STANDARD))
        );
        assert_eq!(result.get("version_valid"), Some(&FieldValue::Bool(true)));
        assert_eq!(
            result.get("version_name"),
            Some(&FieldValue::Str("Standard"))
        );
    }

    #[test]
    fn test_version_1_pptp_enhanced() {
        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        let mut header = Vec::new();
        // Version 1, no optional-field flags.
        let flags: u16 = 0x0001;
        header.extend_from_slice(&flags.to_be_bytes());
        // 0x880B: PPP.
        header.extend_from_slice(&0x880Bu16.to_be_bytes());
        let result = parser.parse(&header, &context);

        assert!(result.is_ok());
        assert_eq!(
            result.get("version"),
            Some(&FieldValue::UInt8(gre_version::PPTP_ENHANCED))
        );
        assert_eq!(result.get("version_valid"), Some(&FieldValue::Bool(true)));
        assert_eq!(
            result.get("version_name"),
            Some(&FieldValue::Str("PPTP-Enhanced"))
        );
    }

    #[test]
    fn test_invalid_version() {
        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        for version in 2..=7u16 {
            let mut header = Vec::new();
            // Only the version bits are set.
            let flags: u16 = version;
            header.extend_from_slice(&flags.to_be_bytes());
            header.extend_from_slice(&ethertype::IPV4.to_be_bytes());

            let result = parser.parse(&header, &context);

            // Unknown versions are reported, not rejected.
            assert!(result.is_ok());
            assert_eq!(
                result.get("version"),
                Some(&FieldValue::UInt8(version as u8))
            );
            assert_eq!(result.get("version_valid"), Some(&FieldValue::Bool(false)));
            assert_eq!(
                result.get("version_name"),
                Some(&FieldValue::Str("Unknown"))
            );
        }
    }

    #[test]
    fn test_checksum_valid() {
        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        let mut packet = Vec::new();
        let flags: u16 = 0x8000;
        packet.extend_from_slice(&flags.to_be_bytes());
        packet.extend_from_slice(&ethertype::IPV4.to_be_bytes());
        // Checksum placeholder (zero) + reserved.
        packet.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        packet.extend_from_slice(&[0x45, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x00]);

        // Compute over the packet with a zero checksum, then store it.
        let checksum = internet_checksum(&packet);
        packet[4] = (checksum >> 8) as u8;
        packet[5] = (checksum & 0xFF) as u8;

        let result = parser.parse(&packet, &context);

        assert!(result.is_ok());
        assert_eq!(
            result.get("checksum_present"),
            Some(&FieldValue::Bool(true))
        );
        assert_eq!(result.get("checksum_valid"), Some(&FieldValue::Bool(true)));
    }

    #[test]
    fn test_checksum_invalid() {
        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        let mut packet = Vec::new();
        let flags: u16 = 0x8000;
        packet.extend_from_slice(&flags.to_be_bytes());
        packet.extend_from_slice(&ethertype::IPV4.to_be_bytes());
        // Arbitrary (wrong) checksum + reserved.
        packet.extend_from_slice(&[0xAB, 0xCD, 0x00, 0x00]);
        packet.extend_from_slice(&[0x45, 0x00, 0x00, 0x28]);

        let result = parser.parse(&packet, &context);

        // A bad checksum is reported, not treated as a parse failure.
        assert!(result.is_ok());
        assert_eq!(
            result.get("checksum_present"),
            Some(&FieldValue::Bool(true))
        );
        assert_eq!(result.get("checksum_valid"), Some(&FieldValue::Bool(false)));
    }

    #[test]
    fn test_header_length() {
        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        let header_min = create_gre_header(false, false, false, ethertype::IPV4);
        let result = parser.parse(&header_min, &context);
        assert!(result.is_ok());
        assert_eq!(result.get("header_length"), Some(&FieldValue::UInt8(4)));

        let mut header_chk = create_gre_header(true, false, false, ethertype::IPV4);
        header_chk.extend_from_slice(&[0x00; 4]);
        let result = parser.parse(&header_chk, &context);
        assert!(result.is_ok());
        assert_eq!(result.get("header_length"), Some(&FieldValue::UInt8(8)));

        let mut header_key = create_gre_header(false, true, false, ethertype::IPV4);
        header_key.extend_from_slice(&[0x00; 4]);
        let result = parser.parse(&header_key, &context);
        assert!(result.is_ok());
        assert_eq!(result.get("header_length"), Some(&FieldValue::UInt8(8)));

        let mut header_seq = create_gre_header(false, false, true, ethertype::IPV4);
        header_seq.extend_from_slice(&[0x00; 4]);
        let result = parser.parse(&header_seq, &context);
        assert!(result.is_ok());
        assert_eq!(result.get("header_length"), Some(&FieldValue::UInt8(8)));

        let mut header_all = create_gre_header(true, true, true, ethertype::IPV4);
        header_all.extend_from_slice(&[0x00; 4]);
        header_all.extend_from_slice(&[0x00; 4]);
        header_all.extend_from_slice(&[0x00; 4]);
        let result = parser.parse(&header_all, &context);
        assert!(result.is_ok());
        assert_eq!(result.get("header_length"), Some(&FieldValue::UInt8(16)));
    }

    #[test]
    fn test_schema_fields_complete() {
        let parser = GreProtocol;
        let fields = parser.schema_fields();

        let field_names: Vec<&str> = fields.iter().map(|f| f.name).collect();
        assert!(field_names.contains(&"gre.checksum_present"));
        assert!(field_names.contains(&"gre.key_present"));
        assert!(field_names.contains(&"gre.sequence_present"));
        assert!(field_names.contains(&"gre.version"));
        assert!(field_names.contains(&"gre.version_valid"));
        assert!(field_names.contains(&"gre.version_name"));
        assert!(field_names.contains(&"gre.protocol"));
        assert!(field_names.contains(&"gre.checksum"));
        assert!(field_names.contains(&"gre.checksum_valid"));
        assert!(field_names.contains(&"gre.key"));
        assert!(field_names.contains(&"gre.sequence"));
        assert!(field_names.contains(&"gre.header_length"));
    }

    #[test]
    fn test_internet_checksum_function() {
        // All-zero data sums to 0, whose complement is 0xFFFF.
        let zeros = [0u8; 10];
        assert_eq!(internet_checksum(&zeros), 0xFFFF);

        let ffff = [0xFF, 0xFF];
        assert_eq!(internet_checksum(&ffff), 0x0000);

        // Odd length: the last byte is padded as the high byte of a word.
        // 0x0102 + 0x0300 = 0x0402 -> complement 0xFBFD.
        let odd = [0x01, 0x02, 0x03];
        assert_eq!(internet_checksum(&odd), 0xFBFD);

        // Appending the checksum makes the whole buffer verify to zero.
        let mut test_data = vec![0x00, 0x01, 0x02, 0x03, 0x04, 0x05];
        let initial_sum = internet_checksum(&test_data);
        test_data.push((initial_sum >> 8) as u8);
        test_data.push((initial_sum & 0xFF) as u8);
        assert_eq!(internet_checksum(&test_data), 0);
    }

    #[test]
    fn test_gre_missing_checksum_field() {
        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        // C bit set but no checksum bytes follow.
        let header = create_gre_header(true, false, false, ethertype::IPV4);

        let result = parser.parse(&header, &context);
        assert!(!result.is_ok());
        assert!(result.error.unwrap().contains("missing checksum field"));
    }

    #[test]
    fn test_gre_missing_sequence_field() {
        let parser = GreProtocol;
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 47);

        // S bit set but no sequence bytes follow.
        let header = create_gre_header(false, false, true, ethertype::IPV4);

        let result = parser.parse(&header, &context);
        assert!(!result.is_ok());
        assert!(result.error.unwrap().contains("missing sequence field"));
    }
}