libnsave/
packet.rs

1use crate::common::*;
2use etherparse::{Ethernet2Header, IpHeader, PacketHeaders, TransportHeader, VlanHeader};
3use serde::{Deserialize, Serialize};
4use std::cell::RefCell;
5use std::collections::hash_map::DefaultHasher;
6use std::fmt;
7use std::hash::{Hash, Hasher};
8use std::io::Write;
9use std::net::IpAddr;
10use std::net::Ipv4Addr;
11use std::ops::Deref;
12
13#[derive(Eq, PartialEq, Clone, Debug)]
14pub struct PktHeader {
15    link: Option<Ethernet2Header>,
16    vlan: Option<VlanHeader>,
17    pub ip: Option<IpHeader>,
18    transport: Option<TransportHeader>,
19    payload_offset: usize,
20    payload_len: usize,
21}
22
23unsafe impl Send for PktHeader {}
24unsafe impl Sync for PktHeader {}
25
/// Error returned by [`Packet::decode`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketError {
    /// The frame could not be parsed, or it lacks an IP or a transport header.
    DecodeErr,
}
29
30#[derive(Eq, PartialEq, Clone)]
31pub struct Packet {
32    pub timestamp: u128,
33    pub header: RefCell<Option<PktHeader>>,
34    pub data: Vec<u8>,
35}
36
37impl Packet {
38    pub fn new(data: Vec<u8>, ts: u128) -> Self {
39        Packet {
40            timestamp: ts,
41            data,
42            header: RefCell::new(None),
43        }
44    }
45
46    pub fn decode(&self) -> Result<(), PacketError> {
47        match PacketHeaders::from_ethernet_slice(self) {
48            Ok(headers) => {
49                if headers.ip.is_none() || headers.transport.is_none() {
50                    return Err(PacketError::DecodeErr);
51                }
52
53                self.header.replace(Some(PktHeader {
54                    link: headers.link,
55                    vlan: headers.vlan,
56                    ip: headers.ip,
57                    transport: headers.transport,
58                    payload_offset: headers.payload.as_ptr() as usize - self.data.as_ptr() as usize,
59                    payload_len: self.data.len()
60                        - (headers.payload.as_ptr() as usize - self.data.as_ptr() as usize),
61                }));
62                Ok(())
63            }
64            Err(_) => Err(PacketError::DecodeErr),
65        }
66    }
67
68    pub fn sport(&self) -> u16 {
69        match &self.header.borrow().as_ref().unwrap().transport {
70            Some(TransportHeader::Udp(udph)) => udph.source_port,
71            Some(TransportHeader::Tcp(tcph)) => tcph.source_port,
72            _ => 0,
73        }
74    }
75
76    pub fn dport(&self) -> u16 {
77        match &self.header.borrow().as_ref().unwrap().transport {
78            Some(TransportHeader::Udp(udph)) => udph.destination_port,
79            Some(TransportHeader::Tcp(tcph)) => tcph.destination_port,
80            _ => 0,
81        }
82    }
83
84    pub fn seq(&self) -> u32 {
85        if let Some(TransportHeader::Tcp(tcph)) = &self.header.borrow().as_ref().unwrap().transport
86        {
87            tcph.sequence_number
88        } else {
89            0
90        }
91    }
92
93    pub fn syn(&self) -> bool {
94        if let Some(TransportHeader::Tcp(tcph)) = &self.header.borrow().as_ref().unwrap().transport
95        {
96            tcph.syn
97        } else {
98            false
99        }
100    }
101
102    pub fn fin(&self) -> bool {
103        if let Some(TransportHeader::Tcp(tcph)) = &self.header.borrow().as_ref().unwrap().transport
104        {
105            tcph.fin
106        } else {
107            false
108        }
109    }
110
111    pub fn payload(&self) -> &[u8] {
112        let offset = self.header.borrow().as_ref().unwrap().payload_offset;
113        let len = self.header.borrow().as_ref().unwrap().payload_len;
114        &self.data[offset..offset + len]
115    }
116
117    pub fn payload_len(&self) -> u32 {
118        self.header
119            .borrow()
120            .as_ref()
121            .unwrap()
122            .payload_len
123            .try_into()
124            .unwrap()
125    }
126
127    pub fn trans_proto(&self) -> TransProto {
128        match self.header.borrow().as_ref().unwrap().transport {
129            Some(TransportHeader::Udp(_)) => TransProto::Udp,
130            Some(TransportHeader::Tcp(_)) => TransProto::Tcp,
131            Some(TransportHeader::Icmpv4(_)) => TransProto::Icmp4,
132            Some(TransportHeader::Icmpv6(_)) => TransProto::Icmp6,
133            None => panic!("unknown transport protocol."),
134        }
135    }
136
137    pub fn hash_key(&self) -> PacketKey {
138        match &self.header.borrow().as_ref().unwrap().ip {
139            Some(IpHeader::Version4(ipv4h, _)) => {
140                if ipv4h.source > ipv4h.destination {
141                    PacketKey {
142                        addr1: ipv4h.source.into(),
143                        port1: self.sport(),
144                        addr2: ipv4h.destination.into(),
145                        port2: self.dport(),
146                        trans_proto: self.trans_proto(),
147                    }
148                } else if ipv4h.source < ipv4h.destination {
149                    PacketKey {
150                        addr1: ipv4h.destination.into(),
151                        port1: self.dport(),
152                        addr2: ipv4h.source.into(),
153                        port2: self.sport(),
154                        trans_proto: self.trans_proto(),
155                    }
156                } else if self.sport() >= self.dport() {
157                    PacketKey {
158                        addr1: ipv4h.source.into(),
159                        port1: self.sport(),
160                        addr2: ipv4h.destination.into(),
161                        port2: self.dport(),
162                        trans_proto: self.trans_proto(),
163                    }
164                } else {
165                    PacketKey {
166                        addr1: ipv4h.destination.into(),
167                        port1: self.dport(),
168                        addr2: ipv4h.source.into(),
169                        port2: self.sport(),
170                        trans_proto: self.trans_proto(),
171                    }
172                }
173            }
174            Some(IpHeader::Version6(ipv6h, _)) => {
175                if ipv6h.source > ipv6h.destination {
176                    PacketKey {
177                        addr1: ipv6h.source.into(),
178                        port1: self.sport(),
179                        addr2: ipv6h.destination.into(),
180                        port2: self.dport(),
181                        trans_proto: self.trans_proto(),
182                    }
183                } else if ipv6h.source < ipv6h.destination {
184                    PacketKey {
185                        addr1: ipv6h.destination.into(),
186                        port1: self.dport(),
187                        addr2: ipv6h.source.into(),
188                        port2: self.sport(),
189                        trans_proto: self.trans_proto(),
190                    }
191                } else if self.sport() >= self.dport() {
192                    PacketKey {
193                        addr1: ipv6h.source.into(),
194                        port1: self.sport(),
195                        addr2: ipv6h.destination.into(),
196                        port2: self.dport(),
197                        trans_proto: self.trans_proto(),
198                    }
199                } else {
200                    PacketKey {
201                        addr1: ipv6h.destination.into(),
202                        port1: self.dport(),
203                        addr2: ipv6h.source.into(),
204                        port2: self.sport(),
205                        trans_proto: self.trans_proto(),
206                    }
207                }
208            }
209            None => PacketKey {
210                addr1: Ipv4Addr::new(0, 0, 0, 0).into(),
211                port1: 0,
212                addr2: Ipv4Addr::new(0, 0, 0, 0).into(),
213                port2: 0,
214                trans_proto: TransProto::Icmp6,
215            },
216        }
217    }
218
219    pub fn hash_value(&self) -> u64 {
220        hash_val(self)
221    }
222
223    // 按照StorePacket格式写入
224    pub fn serialize_into<W: Write>(&self, writer: &mut W) -> Result<(), StoreError> {
225        let next_offset: u32 = 0;
226        writer.write_all(&next_offset.to_le_bytes())?;
227        writer.write_all(&self.timestamp.to_le_bytes())?;
228        writer.write_all(&((self.data.len() as u16).to_le_bytes()))?;
229        writer.write_all(&self.data)?;
230        Ok(())
231    }
232
233    pub fn serialize_size(&self) -> u32 {
234        22 + self.data.len() as u32
235    }
236}
237
238impl Deref for Packet {
239    type Target = Vec<u8>;
240
241    fn deref(&self) -> &Self::Target {
242        &self.data
243    }
244}
245
246impl fmt::Debug for Packet {
247    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248        f.debug_struct("Packet")
249            .field("timestamp", &self.timestamp)
250            .field("header", &self.header)
251            .field("data", &self.data)
252            .finish()
253    }
254}
255
256unsafe impl Send for Packet {}
257unsafe impl Sync for Packet {}
258
259impl Hash for Packet {
260    fn hash<H: Hasher>(&self, state: &mut H) {
261        self.hash_key().hash(state)
262    }
263}
264
/// Hashes `t` with a freshly created std `DefaultHasher` and returns the 64-bit digest.
fn hash_val<T>(t: &T) -> u64
where
    T: Hash,
{
    let mut hasher = DefaultHasher::default();
    t.hash(&mut hasher);
    hasher.finish()
}
270
271#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Hash, Clone, Copy)]
272pub enum TransProto {
273    Udp,
274    Tcp,
275    Icmp4,
276    Icmp6,
277}
278
279#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Hash, Clone, Copy)]
280pub struct PacketKey {
281    pub addr1: IpAddr,
282    pub port1: u16,
283    pub addr2: IpAddr,
284    pub port2: u16,
285    pub trans_proto: TransProto,
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use etherparse::*;

    #[test]
    fn test_decode() {
        let pkt = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 1, 2);
        // Decoding an already-decoded packet again must succeed.
        assert!(pkt.decode().is_ok(), "re-decoding a valid frame failed");

        match &pkt.header.borrow().as_ref().unwrap().ip {
            Some(IpHeader::Version4(ipv4h, _)) => {
                assert_eq!(ipv4h.source, [1u8, 1, 1, 1]);
                assert_eq!(ipv4h.destination, [2u8, 2, 2, 2]);
            }
            other => panic!("expected an IPv4 header, got {:?}", other),
        }
        assert_eq!(TransProto::Tcp, pkt.trans_proto());
        assert_eq!(1, pkt.sport());
        assert_eq!(2, pkt.dport());
        assert!(!pkt.syn());
        assert_eq!(1234, pkt.seq());
        assert!(pkt.fin());
        assert_eq!(10, pkt.payload_len());
        assert_eq!([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], pkt.payload());
    }

    #[test]
    fn test_key() {
        let pkt = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 1, 2);
        let key = PacketKey {
            addr1: Ipv4Addr::new(2, 2, 2, 2).into(),
            port1: 2,
            addr2: Ipv4Addr::new(1, 1, 1, 1).into(),
            port2: 1,
            trans_proto: TransProto::Tcp,
        };
        assert_eq!(key, pkt.hash_key());

        let pkt = build_tcp([1, 1, 1, 1], [1, 1, 1, 1], 1, 2);
        let key = PacketKey {
            addr1: Ipv4Addr::new(1, 1, 1, 1).into(),
            port1: 2,
            addr2: Ipv4Addr::new(1, 1, 1, 1).into(),
            port2: 1,
            trans_proto: TransProto::Tcp,
        };
        assert_eq!(key, pkt.hash_key());

        let pkt = build_tcp([1, 1, 1, 1], [1, 1, 1, 1], 1, 1);
        let key = PacketKey {
            addr1: Ipv4Addr::new(1, 1, 1, 1).into(),
            port1: 1,
            addr2: Ipv4Addr::new(1, 1, 1, 1).into(),
            port2: 1,
            trans_proto: TransProto::Tcp,
        };
        assert_eq!(key, pkt.hash_key());
    }

    #[test]
    fn test_hash() {
        let pkt_c2s = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 1, 2);
        let pkt_s2c = build_tcp([2, 2, 2, 2], [1, 1, 1, 1], 2, 1);
        let pkt_other = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 1, 3);

        assert_eq!(pkt_c2s.hash_key(), pkt_s2c.hash_key());
        assert_eq!(hash_val(&pkt_c2s), hash_val(&pkt_s2c));
        assert_ne!(hash_val(&pkt_c2s), hash_val(&pkt_other));
    }

    #[test]
    fn test_hash_ipv6_udp() {
        let c2s = build_udp6([1; 16], [2; 16], 1000, 53);
        let s2c = build_udp6([2; 16], [1; 16], 53, 1000);

        assert_eq!(TransProto::Udp, c2s.trans_proto());
        assert_eq!(c2s.hash_key(), s2c.hash_key());
        let key = PacketKey {
            addr1: IpAddr::from([2u8; 16]),
            port1: 53,
            addr2: IpAddr::from([1u8; 16]),
            port2: 1000,
            trans_proto: TransProto::Udp,
        };
        assert_eq!(key, c2s.hash_key());
        assert_eq!(c2s.hash_value(), s2c.hash_value());
    }

    #[test]
    fn test_serialize_layout() {
        let pkt = build_tcp([1, 1, 1, 1], [2, 2, 2, 2], 1, 2);
        let mut out: Vec<u8> = Vec::new();
        assert!(pkt.serialize_into(&mut out).is_ok());

        assert_eq!(pkt.serialize_size() as usize, out.len());
        assert_eq!([0u8; 4], out[0..4]);
        assert_eq!(
            pkt.timestamp,
            u128::from_le_bytes(out[4..20].try_into().unwrap())
        );
        assert_eq!(
            pkt.data.len(),
            u16::from_le_bytes([out[20], out[21]]) as usize
        );
        assert_eq!(&pkt.data[..], &out[22..]);
    }

    /// Builds and decodes an Ethernet/IPv4/TCP frame with a fixed 10-byte payload.
    fn build_tcp(sip: [u8; 4], dip: [u8; 4], sport: u16, dport: u16) -> Packet {
        let builder = PacketBuilder::ethernet2(
            [1, 2, 3, 4, 5, 6],    // source MAC
            [7, 8, 9, 10, 11, 12], // destination MAC
        )
        .ipv4(
            sip, // source IP
            dip, // destination IP
            20,  // time to live
        )
        .tcp(
            sport, // source port
            dport, // destination port
            1234,  // sequence number
            1024,  // window size
        )
        // additional TCP header fields
        .ns() // set the NS flag
        // supported flags: ns(), fin(), syn(), rst(), psh(), ece(), cwr()
        .fin()
        .ack(123) // ACK flag + acknowledgement number
        .urg(23) // URG flag + urgent pointer
        .options(&[
            TcpOptionElement::Noop,
            TcpOptionElement::MaximumSegmentSize(1234),
        ])
        .unwrap();

        // payload of the TCP segment
        let payload = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut frame = Vec::<u8>::with_capacity(builder.size(payload.len()));
        // sets all length fields, checksums and identifiers (ethertype & protocol)
        builder.write(&mut frame, &payload).unwrap();

        let pkt = Packet::new(frame, 1);
        assert!(pkt.decode().is_ok(), "builder output must decode");
        pkt
    }

    /// Builds and decodes an Ethernet/IPv6/UDP frame with a fixed 4-byte payload.
    fn build_udp6(sip: [u8; 16], dip: [u8; 16], sport: u16, dport: u16) -> Packet {
        let builder = PacketBuilder::ethernet2([1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12])
            .ipv6(sip, dip, 47) // hop limit
            .udp(sport, dport);

        let payload = [1, 2, 3, 4];
        let mut frame = Vec::<u8>::with_capacity(builder.size(payload.len()));
        builder.write(&mut frame, &payload).unwrap();

        let pkt = Packet::new(frame, 1);
        assert!(pkt.decode().is_ok(), "builder output must decode");
        pkt
    }
}