libnsave/
flow.rs

1use crate::packet::{Packet, PacketKey, TransProto};
2use crate::store::*;
3use crate::tmohash::TmoHash;
4use etherparse::IpHeader;
5use std::net::IpAddr;
6
/// Default table capacity: the value `Flow::new` passes to `TmoHash::new`.
const MAX_TABLE_CAPACITY: usize = 1024;
/// Default idle timeout used by `Flow::new`; `Flow::timeout` expires a node once
/// `now - node.last_time` reaches this value.
const NODE_TIMEOUT: u128 = 10_000_000_000; // presumably 10 s expressed in nanoseconds -- TODO confirm the time unit
/// Default `max_seq_gap` used by `Flow::new` and `SeqStream::new`; also the
/// segment-count threshold above which `SeqStream::update` ignores packets.
const MAX_SEQ_GAP: usize = 8;
10
11#[derive(Debug)]
12pub struct FlowNode {
13    pub key: PacketKey,
14    pub start_time: u128,
15    pub last_time: u128,
16    seq_strm1: SeqStream,
17    seq_strm2: SeqStream,
18
19    pub store_ctx: Option<StoreCtx>,
20}
21
22impl FlowNode {
23    fn new(key: PacketKey, now: u128, max_seq_gap: usize) -> Self {
24        FlowNode {
25            key,
26            start_time: now,
27            last_time: now,
28            seq_strm1: SeqStream::new_with_arg(max_seq_gap),
29            seq_strm2: SeqStream::new_with_arg(max_seq_gap),
30            store_ctx: None,
31        }
32    }
33
34    pub fn update(&mut self, pkt: &Packet, now: u128) {
35        self.last_time = now;
36        if pkt.trans_proto() == TransProto::Tcp {
37            match &pkt.header.borrow().as_ref().unwrap().ip {
38                Some(IpHeader::Version4(ipv4h, _)) => {
39                    if self.key.addr1 == <[u8; 4] as std::convert::Into<IpAddr>>::into(ipv4h.source)
40                        && self.key.port1 == pkt.sport()
41                    {
42                        self.seq_strm1.update(pkt);
43                    } else if self.key.addr2
44                        == <[u8; 4] as std::convert::Into<IpAddr>>::into(ipv4h.source)
45                        && self.key.port2 == pkt.sport()
46                    {
47                        self.seq_strm2.update(pkt);
48                    }
49                }
50                Some(IpHeader::Version6(ipv6h, _)) => {
51                    if self.key.addr1
52                        == <[u8; 16] as std::convert::Into<IpAddr>>::into(ipv6h.source)
53                        && self.key.port1 == pkt.sport()
54                    {
55                        self.seq_strm1.update(pkt);
56                    } else if self.key.addr2
57                        == <[u8; 16] as std::convert::Into<IpAddr>>::into(ipv6h.source)
58                        && self.key.port2 == pkt.sport()
59                    {
60                        self.seq_strm2.update(pkt);
61                    }
62                }
63                None => {}
64            }
65        }
66    }
67
68    pub fn is_fin(&self) -> bool {
69        self.seq_strm1.is_fin() && self.seq_strm2.is_fin()
70    }
71}
72
/// One contiguous run of TCP sequence numbers seen so far, as maintained by
/// `SeqStream::update`.
#[derive(Debug)]
struct SeqSeg {
    /// Sequence number at which the run begins.
    start: u32,
    /// Sequence number expected directly after the run (the `start` of a packet
    /// that would extend it).
    next: u32,
}
78
79#[derive(Debug)]
80struct SeqStream {
81    segment: Vec<SeqSeg>,
82    fin: bool,
83}
84
85impl SeqStream {
86    #[allow(dead_code)]
87    fn new() -> Self {
88        SeqStream {
89            segment: Vec::with_capacity(MAX_SEQ_GAP),
90            fin: false,
91        }
92    }
93
94    fn new_with_arg(max_seq_gap: usize) -> Self {
95        SeqStream {
96            segment: Vec::with_capacity(max_seq_gap),
97            fin: false,
98        }
99    }
100
101    fn update(&mut self, pkt: &Packet) {
102        if self.segment.len() > MAX_SEQ_GAP {
103            return;
104        }
105
106        if pkt.fin() {
107            self.fin = true
108        }
109
110        let new_seg = if pkt.syn() && pkt.payload_len() == 0 {
111            SeqSeg {
112                start: pkt.seq(),
113                next: pkt.seq() + 1,
114            }
115        } else {
116            SeqSeg {
117                start: pkt.seq(),
118                next: pkt.seq() + pkt.payload_len(),
119            }
120        };
121
122        if self.segment.is_empty() {
123            self.segment.push(new_seg);
124            return;
125        }
126
127        // case 1
128        // vec:                  start,next  start,next
129        // new_seg: start,next
130        if new_seg.next < self.segment[0].start {
131            self.segment.insert(0, new_seg);
132            return;
133        }
134
135        // case 2
136        // vec:           start,next  start,next
137        // new_seg: start,next
138        if new_seg.next == self.segment[0].start {
139            self.segment[0].start = new_seg.start;
140            return;
141        }
142
143        // case 3
144        // vec:     start,next  start,next
145        // new_seg:                   start,next
146        if new_seg.start == self.segment[self.segment.len() - 1].next {
147            let last_index = self.segment.len() - 1;
148            self.segment[last_index].next = new_seg.next;
149            return;
150        }
151
152        // case 4
153        // vec:     start,next  start,next
154        // new_seg:                          start,next
155        if new_seg.start > self.segment[self.segment.len() - 1].next {
156            self.segment.push(new_seg);
157            return;
158        }
159
160        // 段之间段空洞情况
161        let mut i = 0;
162        while i < self.segment.len() - 1 {
163            // case 5
164            // vec:     start,next  start,next
165            // new_seg:       start,next
166            if new_seg.start == self.segment[i].next && new_seg.next == self.segment[i + 1].start {
167                self.segment[i].next = self.segment[i + 1].next;
168                self.segment.remove(i + 1);
169                return;
170            }
171
172            // case 6
173            // vec:     start,next        start,next
174            // new_seg:       start,next
175            if new_seg.start == self.segment[i].next && new_seg.next < self.segment[i + 1].start {
176                self.segment[i].next = new_seg.next;
177                return;
178            }
179
180            // case 7
181            // vec:     start,next        start,next
182            // new_seg:             start,next
183            if new_seg.start > self.segment[i].next && new_seg.next == self.segment[i + 1].start {
184                self.segment[i + 1].start = new_seg.start;
185                return;
186            }
187
188            // case 8
189            // vec:     start,next              start,next
190            // new_seg:             start,next
191            if new_seg.start > self.segment[i].next && new_seg.next < self.segment[i + 1].start {
192                self.segment.insert(i + 1, new_seg);
193                return;
194            }
195
196            i += 1;
197        }
198    }
199
200    fn is_fin(&self) -> bool {
201        self.fin && self.segment.len() == 1
202    }
203}
204
205pub struct Flow {
206    node_timeout: u128,
207    max_seq_gap: usize,
208    table: TmoHash<PacketKey, FlowNode>,
209}
210
211impl Flow {
212    pub fn new() -> Self {
213        Flow {
214            node_timeout: NODE_TIMEOUT,
215            max_seq_gap: MAX_SEQ_GAP,
216            table: TmoHash::new(MAX_TABLE_CAPACITY),
217        }
218    }
219
220    pub fn new_with_arg(max_table_capacity: usize, node_timeout: u128, max_seq_gap: usize) -> Self {
221        Flow {
222            node_timeout,
223            max_seq_gap,
224            table: TmoHash::new(max_table_capacity),
225        }
226    }
227
228    pub fn contains_key(&self, key: &PacketKey) -> bool {
229        self.table.contains_key(key)
230    }
231
232    // 返回插入节点的引用。如果已经存在,返回None
233    fn insert(&mut self, pkt: &Packet, now: u128) -> Option<&FlowNode> {
234        let key = pkt.hash_key();
235        if self.contains_key(&key) {
236            return None;
237        }
238        self.table
239            .insert(key, FlowNode::new(key, now, self.max_seq_gap))
240    }
241
242    // 返回插入节点的可变引用。如果已经存在,返回None
243    fn insert_mut(&mut self, pkt: &Packet, now: u128) -> Option<&mut FlowNode> {
244        let key = pkt.hash_key();
245        if self.contains_key(&key) {
246            return None;
247        }
248        self.table
249            .insert_mut(key, FlowNode::new(key, now, self.max_seq_gap))
250    }
251
252    // 返回packet所在节点的引用。如果不存在,返回None
253    pub fn get(&self, pkt: &Packet) -> Option<&FlowNode> {
254        let key = pkt.hash_key();
255        self.table.get(&key)
256    }
257
258    // 返回node的引用。如果table中没有,新建node
259    pub fn get_or_new(&mut self, pkt: &Packet, now: u128) -> Option<&FlowNode> {
260        let key = pkt.hash_key();
261        if self.contains_key(&key) {
262            return self.get(pkt);
263        }
264        self.insert(pkt, now)
265    }
266
267    // 返回packet所在节点的可变引用。如果不存在,返回None
268    pub fn get_mut(&mut self, pkt: &Packet) -> Option<&mut FlowNode> {
269        let key = pkt.hash_key();
270        self.table.get_mut(&key)
271    }
272
273    // 返回node的可变引用。如果table中没有,新建node
274    pub fn get_mut_or_new(&mut self, pkt: &Packet, now: u128) -> Option<&mut FlowNode> {
275        let key = pkt.hash_key();
276        if self.contains_key(&key) {
277            return self.get_mut(pkt);
278        }
279        self.insert_mut(pkt, now)
280    }
281
282    pub fn get_from_key(&self, key: &PacketKey) -> Option<&FlowNode> {
283        self.table.get(key)
284    }
285
286    pub fn get_mut_from_key(&mut self, key: &PacketKey) -> Option<&mut FlowNode> {
287        self.table.get_mut(key)
288    }
289
290    // 删除一个节点
291    pub fn remove(&mut self, key: &PacketKey) {
292        self.table.remove(key)
293    }
294
295    pub fn capacity(&self) -> usize {
296        self.table.capacity()
297    }
298
299    pub fn len(&self) -> usize {
300        self.table.len()
301    }
302
303    pub fn is_empty(&self) -> bool {
304        self.table.is_empty()
305    }
306
307    pub fn is_full(&self) -> bool {
308        self.table.is_full()
309    }
310
311    pub fn timeout<F>(&mut self, now: u128, fun: F)
312    where
313        F: Fn(&FlowNode),
314    {
315        self.table.timeout(|_key, node| {
316            if now - node.last_time >= self.node_timeout {
317                fun(node);
318                true
319            } else {
320                false
321            }
322        })
323    }
324}
325
326impl Default for Flow {
327    fn default() -> Self {
328        Self::new()
329    }
330}
331
332#[cfg(test)]
333mod tests {
334    use super::*;
335    use etherparse::*;
336
337    #[test]
338    fn test_seqstream_new() {
339        let seq_stm = SeqStream::new();
340        assert_eq!(seq_stm.segment.len(), 0);
341        assert!(!seq_stm.fin);
342    }
343
344    #[test]
345    fn test_seqstream_fin() {
346        let mut seq_stm = SeqStream::new();
347        let pkt_fin = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 10, true);
348        let _ = pkt_fin.decode();
349
350        seq_stm.update(&pkt_fin);
351        assert!(seq_stm.fin);
352        assert_eq!(1, seq_stm.segment.len());
353        assert!(seq_stm.is_fin());
354    }
355
356    // case 1.
357    #[test]
358    fn test_seqstream_pre() {
359        let mut seq_stm = SeqStream::new();
360        let pkt_syn = build_syn([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 1);
361        let _ = pkt_syn.decode();
362        let pkt1 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 2, false);
363        let _ = pkt1.decode();
364        let pkt2 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 12, false);
365        let _ = pkt2.decode();
366        let pkt3 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 22, false);
367        let _ = pkt3.decode();
368        let pkt_fin = build_fin([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 32);
369        let _ = pkt_fin.decode();
370
371        seq_stm.update(&pkt3);
372        assert_eq!(1, seq_stm.segment.len());
373        assert_eq!(22, seq_stm.segment[0].start);
374        assert_eq!(32, seq_stm.segment[0].next);
375
376        seq_stm.update(&pkt1);
377        assert_eq!(2, seq_stm.segment.len());
378
379        seq_stm.update(&pkt2);
380        assert_eq!(1, seq_stm.segment.len());
381        assert_eq!(2, seq_stm.segment[0].start);
382        assert_eq!(32, seq_stm.segment[0].next);
383
384        seq_stm.update(&pkt_syn);
385        seq_stm.update(&pkt_fin);
386        assert_eq!(1, seq_stm.segment.len());
387        assert_eq!(1, seq_stm.segment[0].start);
388        assert!(seq_stm.is_fin());
389        assert_eq!(32, seq_stm.segment[0].next);
390    }
391
392    // case 2.
393    #[test]
394    fn test_seqstream_case2() {
395        let mut seq_stm = SeqStream::new();
396        let pkt1 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 2, false);
397        let _ = pkt1.decode();
398        let pkt2 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 12, false);
399        let _ = pkt2.decode();
400        let pkt3 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 22, false);
401        let _ = pkt3.decode();
402
403        seq_stm.update(&pkt3);
404        assert_eq!(1, seq_stm.segment.len());
405        assert_eq!(22, seq_stm.segment[0].start);
406        assert_eq!(32, seq_stm.segment[0].next);
407
408        seq_stm.update(&pkt2);
409        assert_eq!(1, seq_stm.segment.len());
410        assert_eq!(12, seq_stm.segment[0].start);
411        assert_eq!(32, seq_stm.segment[0].next);
412
413        seq_stm.update(&pkt1);
414        assert_eq!(1, seq_stm.segment.len());
415        assert_eq!(2, seq_stm.segment[0].start);
416        assert_eq!(32, seq_stm.segment[0].next);
417    }
418
419    // case 3. syn, 三个连续,最后一个空fin
420    #[test]
421    fn test_seqstream_normal() {
422        let mut seq_stm = SeqStream::new();
423        let pkt_syn = build_syn([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 1);
424        let _ = pkt_syn.decode();
425        let pkt1 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 2, false);
426        let _ = pkt1.decode();
427        let pkt2 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 12, false);
428        let _ = pkt2.decode();
429        let pkt3 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 22, false);
430        let _ = pkt3.decode();
431        let pkt_fin = build_fin([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 32);
432        let _ = pkt_fin.decode();
433
434        seq_stm.update(&pkt_syn);
435        assert_eq!(1, seq_stm.segment.len());
436        assert_eq!(1, seq_stm.segment[0].start);
437        assert_eq!(2, seq_stm.segment[0].next);
438
439        seq_stm.update(&pkt1);
440        assert_eq!(1, seq_stm.segment.len());
441        assert_eq!(1, seq_stm.segment[0].start);
442        assert_eq!(12, seq_stm.segment[0].next);
443
444        seq_stm.update(&pkt2);
445        assert_eq!(1, seq_stm.segment.len());
446        assert_eq!(1, seq_stm.segment[0].start);
447        assert_eq!(22, seq_stm.segment[0].next);
448
449        seq_stm.update(&pkt3);
450        assert_eq!(1, seq_stm.segment.len());
451        assert_eq!(1, seq_stm.segment[0].start);
452        assert_eq!(32, seq_stm.segment[0].next);
453
454        seq_stm.update(&pkt_fin);
455        assert_eq!(1, seq_stm.segment.len());
456        assert_eq!(1, seq_stm.segment[0].start);
457        assert!(seq_stm.is_fin());
458        assert_eq!(32, seq_stm.segment[0].next);
459    }
460
461    // case 4
462    #[test]
463    fn test_seqstream_case4() {
464        let mut seq_stm = SeqStream::new();
465        let pkt_syn = build_syn([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 1);
466        let _ = pkt_syn.decode();
467        let pkt1 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 2, false);
468        let _ = pkt1.decode();
469        let pkt2 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 12, false);
470        let _ = pkt2.decode();
471        let pkt3 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 22, false);
472        let _ = pkt3.decode();
473        let pkt4 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 32, false);
474        let _ = pkt4.decode();
475        let pkt_fin = build_fin([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 42);
476        let _ = pkt_fin.decode();
477
478        seq_stm.update(&pkt_syn);
479        assert_eq!(1, seq_stm.segment.len());
480        assert_eq!(1, seq_stm.segment[0].start);
481        assert_eq!(2, seq_stm.segment[0].next);
482
483        seq_stm.update(&pkt2);
484        assert_eq!(2, seq_stm.segment.len());
485
486        seq_stm.update(&pkt4);
487        assert_eq!(3, seq_stm.segment.len());
488
489        seq_stm.update(&pkt1);
490        seq_stm.update(&pkt3);
491        seq_stm.update(&pkt_fin);
492        assert_eq!(1, seq_stm.segment.len());
493        assert_eq!(1, seq_stm.segment[0].start);
494        assert_eq!(42, seq_stm.segment[0].next);
495    }
496
    // Case 5 is covered by the case 1 test.
498
499    // case 6
500    #[test]
501    fn test_seqstream_case6() {
502        let mut seq_stm = SeqStream::new();
503        let pkt_syn = build_syn([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 1);
504        let _ = pkt_syn.decode();
505        let pkt1 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 2, false);
506        let _ = pkt1.decode();
507        let pkt2 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 12, false);
508        let _ = pkt2.decode();
509        let pkt3 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 22, false);
510        let _ = pkt3.decode();
511        let pkt4 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 32, false);
512        let _ = pkt4.decode();
513        let pkt_fin = build_fin([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 42);
514        let _ = pkt_fin.decode();
515
516        seq_stm.update(&pkt_syn);
517        assert_eq!(1, seq_stm.segment.len());
518        assert_eq!(1, seq_stm.segment[0].start);
519        assert_eq!(2, seq_stm.segment[0].next);
520
521        seq_stm.update(&pkt4);
522        assert_eq!(2, seq_stm.segment.len());
523
524        seq_stm.update(&pkt1);
525        assert_eq!(2, seq_stm.segment.len());
526
527        seq_stm.update(&pkt2);
528        assert_eq!(2, seq_stm.segment.len());
529
530        seq_stm.update(&pkt3);
531        assert_eq!(1, seq_stm.segment.len());
532
533        seq_stm.update(&pkt_fin);
534        assert_eq!(1, seq_stm.segment.len());
535        assert_eq!(1, seq_stm.segment[0].start);
536        assert_eq!(42, seq_stm.segment[0].next);
537    }
538
539    // case 7
540    #[test]
541    fn test_seqstream_case7() {
542        let mut seq_stm = SeqStream::new();
543        let pkt_syn = build_syn([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 1);
544        let _ = pkt_syn.decode();
545        let pkt1 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 2, false);
546        let _ = pkt1.decode();
547        let pkt2 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 12, false);
548        let _ = pkt2.decode();
549        let pkt3 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 22, false);
550        let _ = pkt3.decode();
551        let pkt4 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 32, false);
552        let _ = pkt4.decode();
553        let pkt_fin = build_fin([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 42);
554        let _ = pkt_fin.decode();
555
556        seq_stm.update(&pkt_syn);
557        assert_eq!(1, seq_stm.segment.len());
558        assert_eq!(1, seq_stm.segment[0].start);
559        assert_eq!(2, seq_stm.segment[0].next);
560
561        seq_stm.update(&pkt_fin);
562        assert_eq!(2, seq_stm.segment.len());
563        assert!(!seq_stm.is_fin());
564
565        seq_stm.update(&pkt4);
566        assert_eq!(2, seq_stm.segment.len());
567
568        seq_stm.update(&pkt3);
569        assert_eq!(2, seq_stm.segment.len());
570
571        seq_stm.update(&pkt2);
572        assert_eq!(2, seq_stm.segment.len());
573
574        seq_stm.update(&pkt1);
575        assert_eq!(1, seq_stm.segment.len());
576        assert_eq!(1, seq_stm.segment[0].start);
577        assert_eq!(42, seq_stm.segment[0].next);
578    }
579
580    // case 8
581    #[test]
582    fn test_seqstream_case8() {
583        let mut seq_stm = SeqStream::new();
584        let pkt_syn = build_syn([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 1);
585        let _ = pkt_syn.decode();
586        let pkt1 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 2, false);
587        let _ = pkt1.decode();
588        let pkt2 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 12, false);
589        let _ = pkt2.decode();
590        let pkt3 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 22, false);
591        let _ = pkt3.decode();
592        let pkt4 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 32, false);
593        let _ = pkt4.decode();
594        let pkt_fin = build_fin([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 42);
595        let _ = pkt_fin.decode();
596
597        seq_stm.update(&pkt_syn);
598        assert_eq!(1, seq_stm.segment.len());
599        assert_eq!(1, seq_stm.segment[0].start);
600        assert_eq!(2, seq_stm.segment[0].next);
601
602        seq_stm.update(&pkt_fin);
603        assert!(!seq_stm.is_fin());
604        assert_eq!(2, seq_stm.segment.len());
605        assert_eq!(1, seq_stm.segment[0].start);
606        assert_eq!(2, seq_stm.segment[0].next);
607        assert_eq!(42, seq_stm.segment[1].start);
608        assert_eq!(42, seq_stm.segment[1].next);
609
610        dbg!("before update pkt2. segment: {}", &seq_stm.segment);
611        seq_stm.update(&pkt2);
612        dbg!("update pkt2. segment: {}", &seq_stm.segment);
613        assert_eq!(3, seq_stm.segment.len());
614        assert_eq!(1, seq_stm.segment[0].start);
615        assert_eq!(2, seq_stm.segment[0].next);
616        assert_eq!(12, seq_stm.segment[1].start);
617        assert_eq!(22, seq_stm.segment[1].next);
618        assert_eq!(42, seq_stm.segment[2].start);
619        assert_eq!(42, seq_stm.segment[2].next);
620
621        seq_stm.update(&pkt4);
622        assert_eq!(3, seq_stm.segment.len());
623        assert_eq!(1, seq_stm.segment[0].start);
624        assert_eq!(2, seq_stm.segment[0].next);
625        assert_eq!(12, seq_stm.segment[1].start);
626        assert_eq!(22, seq_stm.segment[1].next);
627        assert_eq!(32, seq_stm.segment[2].start);
628        assert_eq!(42, seq_stm.segment[2].next);
629
630        seq_stm.update(&pkt1);
631        assert_eq!(2, seq_stm.segment.len());
632        assert_eq!(1, seq_stm.segment[0].start);
633        assert_eq!(22, seq_stm.segment[0].next);
634        assert_eq!(32, seq_stm.segment[1].start);
635        assert_eq!(42, seq_stm.segment[1].next);
636
637        seq_stm.update(&pkt3);
638        assert_eq!(1, seq_stm.segment.len());
639        assert_eq!(1, seq_stm.segment[0].start);
640        assert_eq!(42, seq_stm.segment[0].next);
641        assert!(seq_stm.is_fin());
642    }
643
644    #[test]
645    fn test_node_update() {
646        let pkt_c2s = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 2, false);
647        let _ = pkt_c2s.decode();
648        let pkt_c2s_fin = build_fin([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 12);
649        let _ = pkt_c2s_fin.decode();
650        let pkt_s2c = build_tcp([2, 2, 2, 2], [1, 1, 1, 1], 80, 333, 2, false);
651        let _ = pkt_s2c.decode();
652        let pkt_s2c_fin = build_fin([2, 2, 2, 2], [1, 1, 1, 1], 80, 333, 12);
653        let _ = pkt_s2c_fin.decode();
654        let mut node = FlowNode::new(pkt_c2s.hash_key(), 888, MAX_SEQ_GAP);
655
656        assert_eq!(888, node.last_time);
657        assert_eq!(pkt_c2s.hash_key(), node.key);
658        assert_eq!(pkt_s2c.hash_key(), node.key);
659
660        node.update(&pkt_c2s, 1000);
661        assert_eq!(1000, node.last_time);
662        assert_eq!(0, node.seq_strm1.segment.len());
663        assert_eq!(1, node.seq_strm2.segment.len());
664        assert_eq!(2, node.seq_strm2.segment[0].start);
665        assert_eq!(12, node.seq_strm2.segment[0].next);
666
667        node.update(&pkt_s2c, 1001);
668        assert_eq!(1001, node.last_time);
669        assert_eq!(1, node.seq_strm1.segment.len());
670        assert_eq!(1, node.seq_strm2.segment.len());
671        assert_eq!(2, node.seq_strm1.segment[0].start);
672        assert_eq!(12, node.seq_strm1.segment[0].next);
673
674        node.update(&pkt_c2s_fin, 1002);
675        assert_eq!(1, node.seq_strm1.segment.len());
676        assert_eq!(1, node.seq_strm2.segment.len());
677        assert!(node.seq_strm2.is_fin());
678        assert!(!node.is_fin());
679
680        node.update(&pkt_s2c_fin, 1003);
681        assert_eq!(1, node.seq_strm1.segment.len());
682        assert_eq!(1, node.seq_strm2.segment.len());
683        assert!(node.seq_strm2.is_fin());
684        assert!(node.seq_strm1.is_fin());
685        assert!(node.is_fin());
686    }
687
688    #[test]
689    fn test_flow() {
690        let pkt_c2s = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 2, false);
691        let _ = pkt_c2s.decode();
692        let pkt_c2s_fin = build_fin([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 12);
693        let _ = pkt_c2s_fin.decode();
694        let pkt_s2c = build_tcp([2, 2, 2, 2], [1, 1, 1, 1], 80, 333, 2, false);
695        let _ = pkt_s2c.decode();
696        let pkt_s2c_fin = build_fin([2, 2, 2, 2], [1, 1, 1, 1], 80, 333, 12);
697        let _ = pkt_s2c_fin.decode();
698        let mut flow = Flow::new();
699
700        let node = flow.get_mut_or_new(&pkt_c2s, 1000).unwrap();
701        node.update(&pkt_c2s, 1000);
702        assert_eq!(1000, node.last_time);
703        assert_eq!(0, node.seq_strm1.segment.len());
704        assert_eq!(1, node.seq_strm2.segment.len());
705        assert_eq!(2, node.seq_strm2.segment[0].start);
706        assert_eq!(12, node.seq_strm2.segment[0].next);
707        assert_eq!(1, flow.len());
708
709        let node = flow.get_mut_or_new(&pkt_s2c, 1001).unwrap();
710        node.update(&pkt_s2c, 1001);
711        assert_eq!(1001, node.last_time);
712        assert_eq!(1, node.seq_strm1.segment.len());
713        assert_eq!(1, node.seq_strm2.segment.len());
714        assert_eq!(2, node.seq_strm1.segment[0].start);
715        assert_eq!(12, node.seq_strm1.segment[0].next);
716        assert_eq!(1, flow.len());
717
718        let node = flow.get_mut_or_new(&pkt_c2s_fin, 1002).unwrap();
719        node.update(&pkt_c2s_fin, 1002);
720        let node = flow.get_mut_or_new(&pkt_s2c_fin, 1003).unwrap();
721        node.update(&pkt_s2c_fin, 1003);
722        assert!(node.is_fin());
723        let key = node.key;
724        flow.remove(&key);
725        assert_eq!(0, flow.len());
726    }
727
728    #[test]
729    fn test_flow_timeout() {
730        let pkt1 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 2, false);
731        let _ = pkt1.decode();
732        let pkt2 = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 333, 80, 12, false);
733        let _ = pkt2.decode();
734        let mut flow = Flow::new();
735        let mut now = 1000;
736
737        let node = flow.get_mut_or_new(&pkt1, now).unwrap();
738        node.update(&pkt1, now);
739        assert_eq!(now, node.start_time);
740        assert_eq!(now, node.last_time);
741
742        now += 100;
743        let node = flow.get_mut_or_new(&pkt2, now).unwrap();
744        node.update(&pkt2, now);
745        assert_eq!(now, node.last_time);
746
747        now += NODE_TIMEOUT;
748        flow.timeout(now, |node| {
749            test_call_node(node);
750        });
751        assert!(flow.is_empty());
752    }
753
754    fn test_call_node(node: &FlowNode) {
755        assert_eq!(node.start_time, 1000);
756    }
757
758    fn build_tcp(
759        sip: [u8; 4],
760        dip: [u8; 4],
761        sport: u16,
762        dport: u16,
763        seq: u32,
764        fin: bool,
765    ) -> Packet {
766        let mut builder = PacketBuilder::ethernet2(
767            [1, 2, 3, 4, 5, 6], //source mac
768            [7, 8, 9, 10, 11, 12],
769        ) //destionation mac
770        .ipv4(
771            sip, //source ip
772            dip, //desitionation ip
773            20,
774        ) //time to life
775        .tcp(
776            sport, //source port
777            dport, //desitnation port
778            seq,   //sequence number
779            1024,
780        ) //window size
781        //set additional tcp header fields
782        .ns() //set the ns flag
783        //supported flags: ns(), fin(), syn(), rst(), psh(), ece(), cwr()
784        .ack(123) //ack flag + the ack number
785        .urg(23) //urg flag + urgent pointer
786        .options(&[
787            TcpOptionElement::Noop,
788            TcpOptionElement::MaximumSegmentSize(1234),
789        ])
790        .unwrap();
791        if fin {
792            builder = builder.fin();
793        }
794
795        //payload of the tcp packet
796        let payload = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
797        //get some memory to store the result
798        let mut result = Vec::<u8>::with_capacity(builder.size(payload.len()));
799        //serialize
800        //this will automatically set all length fields, checksums and identifiers (ethertype & protocol)
801        builder.write(&mut result, &payload).unwrap();
802        // println!("result len:{}", result.len());
803
804        let pkt = Packet::new(result, 1);
805        let _ = pkt.decode();
806        pkt
807    }
808
809    // sync包,不带载荷
810    fn build_syn(sip: [u8; 4], dip: [u8; 4], sport: u16, dport: u16, seq: u32) -> Packet {
811        let builder = PacketBuilder::ethernet2(
812            [1, 2, 3, 4, 5, 6], //source mac
813            [7, 8, 9, 10, 11, 12],
814        ) //destionation mac
815        .ipv4(
816            sip, //source ip
817            dip, //desitionation ip
818            20,
819        ) //time to life
820        .tcp(
821            sport, //source port
822            dport, //desitnation port
823            seq,   //sequence number
824            1024,
825        ) //window size
826        //set additional tcp header fields
827        .ns() //set the ns flag
828        //supported flags: ns(), fin(), syn(), rst(), psh(), ece(), cwr()
829        .syn()
830        .ack(123) //ack flag + the ack number
831        .urg(23) //urg flag + urgent pointer
832        .options(&[
833            TcpOptionElement::Noop,
834            TcpOptionElement::MaximumSegmentSize(1234),
835        ])
836        .unwrap();
837
838        //payload of the tcp packet
839        let payload = [];
840        //get some memory to store the result
841        let mut result = Vec::<u8>::with_capacity(builder.size(payload.len()));
842        //serialize
843        //this will automatically set all length fields, checksums and identifiers (ethertype & protocol)
844        builder.write(&mut result, &payload).unwrap();
845        // println!("result len:{}", result.len());
846
847        let pkt = Packet::new(result, 1);
848        let _ = pkt.decode();
849        pkt
850    }
851
852    // fin包,不带载荷
853    fn build_fin(sip: [u8; 4], dip: [u8; 4], sport: u16, dport: u16, seq: u32) -> Packet {
854        let builder = PacketBuilder::ethernet2(
855            [1, 2, 3, 4, 5, 6], //source mac
856            [7, 8, 9, 10, 11, 12],
857        ) //destionation mac
858        .ipv4(
859            sip, //source ip
860            dip, //desitionation ip
861            20,
862        ) //time to life
863        .tcp(
864            sport, //source port
865            dport, //desitnation port
866            seq,   //sequence number
867            1024,
868        ) //window size
869        //set additional tcp header fields
870        .ns() //set the ns flag
871        //supported flags: ns(), fin(), syn(), rst(), psh(), ece(), cwr()
872        .fin()
873        .ack(123) //ack flag + the ack number
874        .urg(23) //urg flag + urgent pointer
875        .options(&[
876            TcpOptionElement::Noop,
877            TcpOptionElement::MaximumSegmentSize(1234),
878        ])
879        .unwrap();
880
881        //payload of the tcp packet
882        let payload = [];
883        //get some memory to store the result
884        let mut result = Vec::<u8>::with_capacity(builder.size(payload.len()));
885        //serialize
886        //this will automatically set all length fields, checksums and identifiers (ethertype & protocol)
887        builder.write(&mut result, &payload).unwrap();
888        // println!("result len:{}", result.len());
889
890        let pkt = Packet::new(result, 1);
891        let _ = pkt.decode();
892        pkt
893    }
894}