1use crate::addr::Inet4Addr;
20use std::net;
21
/// Adds `buffer`, read as big-endian 16-bit words, to the running Internet-checksum
/// accumulator `current` and returns the new accumulator.
///
/// A trailing odd byte is treated as the high byte of a word whose low byte is zero.
/// Carries out of the low 16 bits are folded back in while summing (end-around carry).
/// This keeps the accumulator small, so it cannot overflow however long `buffer` is.
/// Once the total exceeds 16 bits, the returned value is therefore not the plain
/// arithmetic sum. It is equivalent in ones'-complement arithmetic, which is all
/// `finish_netsum` relies on.
const fn partial_netsum(current: u32, buffer: &[u8]) -> u32 {
    // Fold any carries already present in `current` so the first addition cannot overflow.
    let mut sum = (current & 0xFFFF) + (current >> 16);
    let len = buffer.len();
    let mut i = 0;
    while i + 1 < len {
        sum += ((buffer[i] as u32) << 8) | (buffer[i + 1] as u32);
        sum = (sum & 0xFFFF) + (sum >> 16);
        i += 2;
    }
    if i < len {
        // Odd trailing byte: pad with a zero low byte.
        sum += (buffer[i] as u32) << 8;
        sum = (sum & 0xFFFF) + (sum >> 16);
    }
    sum
}
38
/// Reduces a 32-bit checksum accumulator to 16 bits with end-around carry and
/// returns its ones' complement, which is the final Internet checksum.
const fn finish_netsum(sum: u32) -> u16 {
    // Two folds are always enough for a u32. The first leaves at most 0x1FFFE and
    // the second at most 0xFFFF. Folding a value that already fits changes nothing.
    let once = (sum >> 16) + (sum & 0xFFFF);
    let twice = (once >> 16) + (once & 0xFFFF);
    !(twice as u16)
}
46
47#[derive(Clone, Debug)]
48pub enum Tail<'a> {
49    Payload(&'a [u8]),
50    Fragment(Box<Fragment<'a>>),
51    #[allow(dead_code)]
52    None,
53}
54
55impl<'a> Tail<'a> {
56    fn len(&self) -> usize {
57        match self {
58            Tail::Payload(x) => x.len(),
59            Tail::Fragment(x) => x.len(),
60            Tail::None => 0,
61        }
62    }
63
64    fn partial_netsum(&self, current: u32) -> u32 {
65        match self {
66            Tail::Payload(x) => partial_netsum(current, x),
67            Tail::Fragment(x) => x.partial_netsum(current),
68            Tail::None => current,
69        }
70    }
71}
72
73#[derive(Clone, Debug)]
74pub struct Fragment<'a> {
75    buffer: Vec<u8>,
76    tail: Tail<'a>,
77}
78
79impl<'a> Fragment<'a> {
80    fn len(&self) -> usize {
81        self.buffer.len() + self.tail.len()
82    }
83    fn partial_netsum(&self, current: u32) -> u32 {
84        self.tail
85            .partial_netsum(partial_netsum(current, &self.buffer))
86    }
87    fn netsum(&self) -> u16 {
88        finish_netsum(self.partial_netsum(0))
89    }
90    pub fn flatten(&self) -> Vec<u8> {
91        let mut x = self;
92        let mut ret = vec![];
93        loop {
94            ret.extend_from_slice(&x.buffer);
95            match &x.tail {
96                Tail::None => break,
97                Tail::Payload(x) => {
98                    ret.extend_from_slice(x);
99                    break;
100                }
101                Tail::Fragment(f) => {
102                    x = f.as_ref();
103                }
104            }
105        }
106        ret
107    }
108
109    const fn from_tail(tail: Tail) -> Fragment {
110        Fragment {
111            buffer: vec![],
112            tail,
113        }
114    }
115    fn push_u8(&mut self, b: u8) {
116        self.buffer.push(b);
117    }
118    fn push_bytes(&mut self, b: &[u8]) {
119        self.buffer.extend_from_slice(b);
120    }
121    fn push_be16(&mut self, b: u16) {
122        self.push_bytes(&b.to_be_bytes());
123    }
124
125    fn new_ethernet<'l>(
126        dst: &[u8; 6],
127        src: &[u8; 6],
128        ethertype: u16,
129        payload: Tail<'l>,
130    ) -> Fragment<'l> {
131        let mut f = Fragment::from_tail(payload);
132        f.push_bytes(dst);
133        f.push_bytes(src);
134        f.push_be16(ethertype);
135        f
136    }
137
138    fn new_ipv4<'l>(
139        src: &net::Ipv4Addr,
140        srcmac: &[u8; 6],
141        dst: &net::Ipv4Addr,
142        dstmac: &[u8; 6],
143        protocol: u8,
144        payload: Tail<'l>,
145    ) -> Fragment<'l> {
146        let mut f = Fragment::from_tail(payload);
147        f.push_u8(0x45); f.push_u8(0x00); f.push_be16(20_u16 + f.tail.len() as u16); f.push_be16(0x0000); f.push_be16(0x0000); f.push_u8(0x01); f.push_u8(protocol);
154        f.push_be16(0x0000); f.push_bytes(&src.octets());
156        f.push_bytes(&dst.octets());
157        let netsum = finish_netsum(partial_netsum(0, &f.buffer));
158        f.buffer[10] = (netsum >> 8) as u8;
159        f.buffer[11] = (netsum & 0xFF) as u8;
160        Self::new_ethernet(dstmac, srcmac, 0x0800_u16, Tail::Fragment(Box::new(f)))
161    }
162
163    pub fn new_udp4<'l>(
164        src: Inet4Addr,
165        srcmac: &[u8; 6],
166        dst: Inet4Addr,
167        dstmac: &[u8; 6],
168        payload: Tail<'l>,
169    ) -> Fragment<'l> {
170        let mut f = Self::from_tail(payload);
171        f.push_be16(src.port());
172        f.push_be16(dst.port());
173        f.push_be16(8_u16 + f.tail.len() as u16); f.push_be16(0x0000); let l = f.len();
176        let mut pseudohdr = Self::from_tail(Tail::Fragment(Box::new(f.clone())));
177        let udp_protocol: u8 = 17;
178        pseudohdr.push_bytes(&u32::to_be_bytes(src.ip()));
179        pseudohdr.push_bytes(&u32::to_be_bytes(dst.ip()));
180        pseudohdr.push_u8(0x00_u8);
181        pseudohdr.push_u8(udp_protocol);
182        pseudohdr.push_be16(l as u16);
183        let netsum = pseudohdr.netsum();
184        f.buffer[6] = (netsum >> 8) as u8;
185        f.buffer[7] = (netsum & 0xFF) as u8;
186        let t = Tail::Fragment(Box::new(f.clone()));
187        Self::new_ipv4(
188            &src.ip().into(),
189            srcmac,
190            &dst.ip().into(),
191            dstmac,
192            udp_protocol,
193            t,
194        )
195    }
196}
197
#[test]
fn test_udp_packet() {
    let u = Fragment::new_udp4(
        "192.0.2.1:1".parse().unwrap(),
        &[2, 0, 0, 0, 0, 0],
        "192.0.2.2:2".parse().unwrap(),
        &[2, 0, 0, 0, 0, 1],
        Tail::Payload(&[1, 2, 3, 4]),
    );
    println!("u={:?}", u);

    let bytes = u.flatten();
    // 14-byte Ethernet + 20-byte IPv4 + 8-byte UDP + 4-byte payload.
    assert_eq!(bytes.len(), 46);

    // Ethernet: destination MAC, source MAC, EtherType IPv4.
    assert_eq!(&bytes[0..6], &[2, 0, 0, 0, 0, 1]);
    assert_eq!(&bytes[6..12], &[2, 0, 0, 0, 0, 0]);
    assert_eq!(&bytes[12..14], &[0x08, 0x00]);

    // IPv4 header.
    let ip = &bytes[14..34];
    assert_eq!(ip[0], 0x45);
    assert_eq!(&ip[2..4], &32_u16.to_be_bytes());
    assert_eq!(ip[9], 17);
    assert_eq!(&ip[12..16], &[192, 0, 2, 1]);
    assert_eq!(&ip[16..20], &[192, 0, 2, 2]);
    // A correct header checksum makes the header sum verify to zero.
    assert_eq!(finish_netsum(partial_netsum(0, ip)), 0);

    // UDP header and payload.
    let udp = &bytes[34..];
    assert_eq!(&udp[0..2], &1_u16.to_be_bytes());
    assert_eq!(&udp[2..4], &2_u16.to_be_bytes());
    assert_eq!(&udp[4..6], &12_u16.to_be_bytes());
    assert_eq!(&udp[8..], &[1, 2, 3, 4]);
    // The UDP checksum verifies over pseudo-header + segment.
    let mut pseudo = ip[12..20].to_vec();
    pseudo.extend_from_slice(&[0, 17]);
    pseudo.extend_from_slice(&12_u16.to_be_bytes());
    assert_eq!(finish_netsum(partial_netsum(partial_netsum(0, &pseudo), udp)), 0);
}
209
#[test]
fn test_checksum() {
    let data = vec![8, 0, 0, 0, 0x12, 0x34, 0x00, 0x01];
    assert_eq!(finish_netsum(partial_netsum(0, &data)), 0xE5CA);

    // Appending the checksum to the data makes the whole thing verify to zero.
    let mut with_sum = data.clone();
    with_sum.extend_from_slice(&0xE5CA_u16.to_be_bytes());
    assert_eq!(finish_netsum(partial_netsum(0, &with_sum)), 0x0000);

    // Summing in pieces (even-length prefix) gives the same result.
    assert_eq!(
        finish_netsum(partial_netsum(partial_netsum(0, &data[..4]), &data[4..])),
        0xE5CA
    );

    // An odd trailing byte is padded with a zero low byte: 0x0100 -> !0x0100.
    assert_eq!(finish_netsum(partial_netsum(0, &[0x01])), 0xFEFF);

    // A carry out of 16 bits is folded back in: 0xFFFF + 0x0002 -> 0x0002.
    assert_eq!(finish_netsum(partial_netsum(0, &[0xFF, 0xFF, 0x00, 0x02])), 0xFFFD);

    // RFC 1071 section 3 example: the folded sum is 0xDDF2, so the checksum is 0x220D.
    let rfc = [0x00, 0x01, 0xF2, 0x03, 0xF4, 0xF5, 0xF6, 0xF7];
    assert_eq!(finish_netsum(partial_netsum(0, &rfc)), 0x220D);
}