1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#![cfg_attr(not(feature = "std"), no_std)]
#![allow(async_fn_in_trait)]
#![warn(clippy::large_futures)]

use core::fmt;

use no_std_net::{Ipv4Addr, SocketAddrV4};

use self::udp::UdpPacketHeader;

#[cfg(feature = "io")]
pub mod io;

pub mod bytes;
pub mod ip;
pub mod udp;

use bytes::BytesIn;

/// Errors returned by the IPv4/UDP encode and decode helpers in this crate.
///
/// The variants carry no data, so the type is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Error {
    /// Input ended too early (mapped from `bytes::Error::DataUnderflow`; shown as "Data underflow").
    DataUnderflow,
    /// Output did not fit (mapped from `bytes::Error::BufferOverflow`; shown as "Buffer overflow").
    BufferOverflow,
    /// Malformed data (mapped from `bytes::Error::InvalidFormat`; shown as "Invalid format").
    InvalidFormat,
    /// A checksum did not verify (shown as "Invalid checksum").
    InvalidChecksum,
}

impl From<bytes::Error> for Error {
    fn from(value: bytes::Error) -> Self {
        match value {
            bytes::Error::BufferOverflow => Self::BufferOverflow,
            bytes::Error::DataUnderflow => Self::DataUnderflow,
            bytes::Error::InvalidFormat => Self::InvalidFormat,
        }
    }
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let str = match self {
            Self::DataUnderflow => "Data underflow",
            Self::BufferOverflow => "Buffer overflow",
            Self::InvalidFormat => "Invalid format",
            Self::InvalidChecksum => "Invalid checksum",
        };

        write!(f, "{}", str)
    }
}

/// Integrates the crate `Error` with `std::error::Error` when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for Error {
    // No variant wraps another error, so the provided `source()` (returning `None`) is kept.
}

/// Decodes an IPv4 packet carrying UDP, returning `(source, destination, udp_data)`.
///
/// The IP parts of `filter_src` / `filter_dst` are passed to [`ip::decode`]. A `None`
/// filter is passed as `Ipv4Addr::UNSPECIFIED`. The expected protocol passed alongside
/// them is [`UdpPacketHeader::PROTO`]. The filters' ports (or `None`) are then passed
/// to [`udp::decode`]. How a non-matching packet is reported is decided by those two
/// functions (presumably `Ok(None)` — confirm there).
///
/// Returns `Ok(None)` whenever `ip::decode` does; otherwise returns whatever
/// `udp::decode` returns.
///
/// # Errors
/// Propagates any error from `ip::decode` or `udp::decode`.
// The nested `Result<Option<(…)>>` return type trips `clippy::type_complexity`.
#[allow(clippy::type_complexity)]
pub fn ip_udp_decode(
    packet: &[u8],
    filter_src: Option<SocketAddrV4>,
    filter_dst: Option<SocketAddrV4>,
) -> Result<Option<(SocketAddrV4, SocketAddrV4, &[u8])>, Error> {
    let src_ip = filter_src
        .map(|addr| *addr.ip())
        .unwrap_or(Ipv4Addr::UNSPECIFIED);
    let dst_ip = filter_dst
        .map(|addr| *addr.ip())
        .unwrap_or(Ipv4Addr::UNSPECIFIED);

    let Some((src, dst, _, udp_packet)) =
        ip::decode(packet, src_ip, dst_ip, Some(UdpPacketHeader::PROTO))?
    else {
        return Ok(None);
    };

    let src_port = filter_src.map(|addr| addr.port());
    let dst_port = filter_dst.map(|addr| addr.port());

    udp::decode(src, dst, udp_packet, src_port, dst_port)
}

/// Encodes a UDP datagram wrapped in an IPv4 packet into `buf`.
///
/// The IPv4 layer is written by [`ip::encode`]. It uses the IP addresses of `src`/`dst`
/// and protocol [`UdpPacketHeader::PROTO`]. Inside the IP layer, [`udp::encode`] is
/// called with the full socket addresses. `encoder` (an
/// `FnOnce(&mut [u8]) -> Result<usize, Error>`) is forwarded to `udp::encode`; presumably
/// it fills the UDP payload and reports its length — confirm in `udp::encode`.
///
/// Returns the slice produced by `ip::encode`.
///
/// # Errors
/// Returns any error from `ip::encode` or `udp::encode`. The latter presumably includes
/// errors raised by `encoder`.
pub fn ip_udp_encode<F>(
    buf: &mut [u8],
    src: SocketAddrV4,
    dst: SocketAddrV4,
    encoder: F,
) -> Result<&[u8], Error>
where
    F: FnOnce(&mut [u8]) -> Result<usize, Error>,
{
    let (src_ip, dst_ip) = (*src.ip(), *dst.ip());

    ip::encode(buf, src_ip, dst_ip, UdpPacketHeader::PROTO, |ip_payload| {
        // The IP layer only needs to know how many bytes the UDP layer produced.
        let udp_datagram = udp::encode(ip_payload, src, dst, encoder)?;
        Ok(udp_datagram.len())
    })
}

/// Sums `bytes` as big-endian 16-bit words, as the first step of an Internet (RFC 1071)
/// checksum.
///
/// The word whose index equals `checksum_word` is counted as zero. That word is bytes
/// `2 * checksum_word` and `2 * checksum_word + 1`, which lets e.g. a header's checksum
/// field be left out of the sum. Pass an index past the end (such as `usize::MAX`) to
/// include every word. A trailing odd byte is treated as the high byte of a word whose
/// low byte is zero.
///
/// The result is the raw, unfolded sum; [`checksum_finish`] turns it into the final
/// 16-bit checksum. For inputs large enough that the sum would exceed `u32::MAX`,
/// end-around carries are folded back in. This keeps the value congruent modulo `0xffff`
/// (so the finished checksum is still correct) instead of overflowing.
///
/// # Panics
/// Only if `BytesIn` breaks the assumption noted at the `expect` below.
pub fn checksum_accumulate(bytes: &[u8], checksum_word: usize) -> u32 {
    let mut bytes = BytesIn::new(bytes);

    // A u32 accumulator overflows after ~65_537 words of 0xffff (panic in debug, wrong
    // checksum in release). A u64 would need more than 2^48 words to overflow.
    let mut sum: u64 = 0;
    while !bytes.is_empty() {
        // Word index of the word about to be read (offset is in bytes).
        let skip = (bytes.offset() >> 1) == checksum_word;
        // Read a full word. If only one byte is left, pad it with a zero low byte.
        // Assumes a failed `arr()` does not consume input — TODO confirm in `bytes`.
        let arr = bytes.arr().ok().unwrap_or_else(|| {
            [
                bytes
                    .byte()
                    .expect("loop condition guarantees at least one byte remains"),
                0,
            ]
        });

        let word = if skip { 0 } else { u16::from_be_bytes(arr) };

        sum += u64::from(word);
    }

    // Fold only when necessary, so every input that previously fit in a u32 yields
    // exactly the same value as before.
    while sum > u64::from(u32::MAX) {
        sum = (sum >> 16) + (sum & 0xffff);
    }

    // Cannot truncate: the loop above guarantees `sum <= u32::MAX`.
    sum as u32
}

/// Completes an Internet (RFC 1071) checksum from a 32-bit word sum.
///
/// The high half is repeatedly added back into the low half (end-around carry) until
/// the value fits in 16 bits. The ones' complement of that 16-bit value is returned.
pub fn checksum_finish(sum: u32) -> u16 {
    let mut folded = sum;
    loop {
        let carry = folded >> 16;
        if carry == 0 {
            break;
        }
        folded = carry + (folded & 0xffff);
    }

    // `folded` now fits in 16 bits, so the cast is lossless.
    !(folded as u16)
}