use crate::ip::{IpV4Frame, IpV4Header};
use crate::{calc_ip_checksum_finalize, calc_ip_checksum_incomplete};
use byte_struct::*;
pub use ufmt::derive::uDebug;
/// UDP header: four 16-bit fields, serialized big-endian (`#[byte_struct_be]`).
///
/// The `ByteStruct` derive lays the fields out in declaration order, so the
/// order below is the on-wire order — do not reorder fields.
#[derive(ByteStruct, Clone, Copy, uDebug, Debug, PartialEq, Eq)]
#[byte_struct_be]
pub struct UdpHeader {
    /// Source port.
    pub src_port: u16,
    /// Destination port.
    pub dst_port: u16,
    /// UDP length field; `calc_udp_checksum` uses it as the number of frame
    /// bytes covered by the checksum (presumably header + payload per RFC 768).
    pub length: u16,
    /// UDP checksum; `calc_udp_checksum` includes this field's current value
    /// in its sum.
    pub checksum: u16,
}
impl UdpHeader {
fn len(&self) -> usize {
Self::BYTE_LEN
}
pub fn to_be_bytes(&self) -> [u8; Self::BYTE_LEN] {
let mut bytes = [0_u8; Self::BYTE_LEN];
self.write_bytes(&mut bytes);
bytes
}
}
/// A UDP datagram: a [`UdpHeader`] followed by a fixed-size payload `T`.
///
/// The `ByteStruct` impl below serializes the header bytes first, then the
/// payload bytes.
#[derive(Clone, Copy, uDebug, Debug, PartialEq, Eq)]
pub struct UdpFrame<T: ByteStruct> {
    /// UDP header.
    pub header: UdpHeader,
    /// Payload; occupies `T::BYTE_LEN` bytes when serialized.
    pub data: T,
}
impl<T: ByteStruct> UdpFrame<T> {
pub fn to_be_bytes(&self) -> [u8; Self::BYTE_LEN] {
let mut bytes = [0_u8; Self::BYTE_LEN];
self.write_bytes(&mut bytes);
bytes
}
}
/// Serialized size of a UDP frame: the UDP header plus the payload.
///
/// The IPv4 header ([`IpV4Header`]) is not part of the UDP frame. It is carried
/// by the enclosing `IpV4Frame`, so it must not be counted here. Counting it
/// would make `to_be_bytes` emit extra trailing bytes and oversize any frame
/// that embeds this one.
impl<T> ByteStructLen for UdpFrame<T>
where
    T: ByteStruct,
{
    const BYTE_LEN: usize = UdpHeader::BYTE_LEN + T::BYTE_LEN;
}
/// Byte-level (de)serialization of a UDP frame: header bytes, then payload bytes.
impl<T> ByteStruct for UdpFrame<T>
where
    T: ByteStruct,
{
    /// Parses a frame from the start of `bytes`.
    ///
    /// The header occupies `[0, UdpHeader::BYTE_LEN)` and the payload the next
    /// `T::BYTE_LEN` bytes. The bounds are computed explicitly so that `T`
    /// receives exactly its own length.
    ///
    /// # Panics
    /// Panics if `bytes` is shorter than `UdpHeader::BYTE_LEN + T::BYTE_LEN`.
    fn read_bytes(bytes: &[u8]) -> Self {
        let header_end = UdpHeader::BYTE_LEN;
        let data_end = header_end + T::BYTE_LEN;
        UdpFrame::<T> {
            header: UdpHeader::read_bytes(&bytes[..header_end]),
            data: T::read_bytes(&bytes[header_end..data_end]),
        }
    }

    /// Writes the header followed by the payload into the start of `bytes`.
    ///
    /// # Panics
    /// Panics if `bytes` is shorter than `UdpHeader::BYTE_LEN + T::BYTE_LEN`.
    fn write_bytes(&self, bytes: &mut [u8]) {
        let header_end = UdpHeader::BYTE_LEN;
        let data_end = header_end + T::BYTE_LEN;
        self.header.write_bytes(&mut bytes[..header_end]);
        self.data.write_bytes(&mut bytes[header_end..data_end]);
    }
}
/// Computes the UDP checksum for the datagram carried in `ipframe`.
///
/// The sum covers, in order:
/// - the IPv4 source address;
/// - the IPv4 destination address;
/// - a zero byte, the protocol number and the UDP length (big-endian);
/// - the serialized UDP frame, truncated to the UDP length from the header
///   (and never past the end of the serialized buffer).
///
/// Partial sums come from `calc_ip_checksum_incomplete`. They are turned into
/// the final `u16` by `calc_ip_checksum_finalize`.
///
/// NOTE(review): the frame's current `checksum` field is summed as-is, so
/// callers computing a fresh checksum presumably must zero it first. Confirm
/// against callers. This function also does not apply RFC 768's rule that a
/// computed zero is transmitted as `0xFFFF`.
pub fn calc_udp_checksum<T: ByteStruct>(ipframe: &IpV4Frame<UdpFrame<T>>) -> u16
where
    [(); UdpFrame::<T>::BYTE_LEN]:,
{
    let ip = &ipframe.header;
    let udp_len = ipframe.data.header.length;
    let [len_hi, len_lo] = udp_len.to_be_bytes();

    // Tail of the pseudo-header: zero pad, protocol number, UDP length.
    let pseudo_tail: [u8; 4] = [0, ip.protocol as u8, len_hi, len_lo];

    // Only the bytes the length field claims are checksummed, clamped to the buffer.
    let datagram = ipframe.data.to_be_bytes();
    let covered = datagram.len().min(usize::from(udp_len));

    let chunks: [&[u8]; 4] = [
        &ip.src_ipaddr.0,
        &ip.dst_ipaddr.0,
        &pseudo_tail,
        &datagram[..covered],
    ];
    let sum: u32 = chunks
        .iter()
        .map(|&chunk| calc_ip_checksum_incomplete(chunk))
        .sum();

    calc_ip_checksum_finalize(sum)
}