// ntp-proto 1.7.2
//
// ntpd-rs packet parsing and algorithms
// Documentation
use crate::io::NonBlockingWrite;
use crate::packet::ExtensionField;
use crate::packet::error::ParsingError;
use crate::packet::extension_fields::ExtensionFieldTypeId;
use crate::packet::v5::server_reference_id::BloomFilter;
use std::borrow::Cow;
use std::convert::Infallible;

/// A reference ID request extension field.
///
/// Describes a range of `payload_len` bytes starting at byte `offset`; see
/// `to_response` for how the range is applied to a filter's bytes.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct ReferenceIdRequest {
    /// Length in bytes of the extension field payload (the on-wire field
    /// length written by `serialize` is this value plus 4).
    payload_len: u16,
    /// Byte offset at which the requested range starts.
    offset: u16,
}

impl ReferenceIdRequest {
    /// Creates a request for `payload_len` bytes starting at byte `offset`.
    ///
    /// Returns `None` when `payload_len` is not a multiple of 4, or when
    /// `payload_len + offset` exceeds 512.
    pub const fn new(payload_len: u16, offset: u16) -> Option<Self> {
        if !payload_len.is_multiple_of(4) {
            return None;
        }

        // Use checked addition: a plain `+` on `u16` panics in debug builds
        // and wraps in release builds, where e.g. `new(4, u16::MAX)` would
        // wrap around to 3 and wrongly pass the bound check.
        match payload_len.checked_add(offset) {
            Some(end) if end <= 512 => Some(Self {
                payload_len,
                offset,
            }),
            _ => None,
        }
    }

    /// Builds the response to this request from `filter`.
    ///
    /// The response borrows `payload_len` bytes of `filter.as_bytes()`
    /// starting at `offset`. Returns `None` when that range does not lie
    /// within the filter's bytes.
    pub fn to_response(self, filter: &BloomFilter) -> Option<ReferenceIdResponse<'_>> {
        let offset = usize::from(self.offset);
        let payload_len = usize::from(self.payload_len);

        let bytes = filter.as_bytes().get(offset..)?.get(..payload_len)?.into();

        Some(ReferenceIdResponse { bytes })
    }

    /// Writes this request to `writer` as an extension field.
    ///
    /// Layout (all integers big-endian): the type id, the field length
    /// (`payload_len + 4`), the offset, two zero bytes, and then zero-filled
    /// 4-byte words for the remainder of the payload.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `writer`, or an `InvalidInput` error
    /// when `payload_len + 4` does not fit the 16-bit length field.
    ///
    /// # Panics
    ///
    /// Panics if `payload_len` is not a multiple of 4.
    pub fn serialize(&self, mut writer: impl NonBlockingWrite) -> std::io::Result<()> {
        let payload_len = self.payload_len;
        // `decode` accepts any length up to `u16::MAX`, so adding the header
        // size must be checked instead of silently wrapping in release builds.
        let ef_len: u16 = payload_len.checked_add(4).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "reference id request payload too long",
            )
        })?;

        writer.write_all(
            &ExtensionFieldTypeId::ReferenceIdRequest
                .to_type_id()
                .to_be_bytes(),
        )?;
        writer.write_all(&ef_len.to_be_bytes())?;
        writer.write_all(&self.offset.to_be_bytes())?;
        writer.write_all(&[0; 2])?;

        let words = payload_len / 4;
        assert_eq!(payload_len % 4, 0);

        // The first payload word (offset + two zero bytes) is already written,
        // so only `words - 1` padding words remain.
        for _ in 1..words {
            writer.write_all(&[0; 4])?;
        }

        Ok(())
    }

    /// Decodes a request from the payload bytes of an extension field.
    ///
    /// `payload_len` is taken from `msg.len()` and the offset from the first
    /// two bytes (big-endian); the remaining bytes are ignored.
    ///
    /// # Errors
    ///
    /// Returns `ParsingError::IncorrectLength` when `msg` is shorter than two
    /// bytes or longer than `u16::MAX` bytes.
    pub fn decode(msg: &[u8]) -> Result<Self, ParsingError<Infallible>> {
        // Report an oversized payload as a parse error instead of panicking.
        let Ok(payload_len) = u16::try_from(msg.len()) else {
            return Err(ParsingError::IncorrectLength);
        };
        let &[hi, lo, ..] = msg else {
            return Err(ParsingError::IncorrectLength);
        };

        Ok(Self {
            payload_len,
            offset: u16::from_be_bytes([hi, lo]),
        })
    }

    /// Returns the byte offset at which the requested range starts.
    pub const fn offset(&self) -> u16 {
        self.offset
    }

    /// Returns the payload length in bytes.
    pub const fn payload_len(&self) -> u16 {
        self.payload_len
    }
}

/// A reference ID response extension field.
///
/// Holds the payload bytes, either borrowed (lifetime `'a`) or owned.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ReferenceIdResponse<'a> {
    /// Payload bytes carried by the response.
    bytes: Cow<'a, [u8]>,
}

impl<'a> ReferenceIdResponse<'a> {
    /// Creates a response borrowing `bytes` as its payload.
    ///
    /// Returns `None` when `bytes.len()` is not a multiple of 4 or is larger
    /// than 512.
    pub const fn new(bytes: &'a [u8]) -> Option<Self> {
        if !bytes.len().is_multiple_of(4) {
            return None;
        }

        if bytes.len() > 512 {
            return None;
        }

        Some(Self {
            bytes: Cow::Borrowed(bytes),
        })
    }

    /// Converts the response into one that owns its payload, copying the
    /// bytes if they were borrowed.
    pub fn into_owned(self) -> ReferenceIdResponse<'static> {
        ReferenceIdResponse {
            bytes: Cow::Owned(self.bytes.into_owned()),
        }
    }

    /// Writes this response to `writer` as an extension field.
    ///
    /// Layout: the big-endian type id, the big-endian field length (payload
    /// length plus 4), the payload bytes, and then zero bytes padding the
    /// field to a multiple of 4 bytes.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `writer`, or an `InvalidInput` error
    /// when the payload length plus 4 does not fit the 16-bit length field.
    pub fn serialize(&self, mut writer: impl NonBlockingWrite) -> std::io::Result<()> {
        // `decode` performs no length validation, so the total length
        // (payload + type + length) is computed with checked arithmetic
        // rather than unwrapping and then adding in `u16`, which could panic
        // or wrap around in release builds.
        let len: u16 = self
            .bytes
            .len()
            .checked_add(4)
            .and_then(|total| u16::try_from(total).ok())
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "reference id response payload too long",
                )
            })?;

        writer.write_all(
            &ExtensionFieldTypeId::ReferenceIdResponse
                .to_type_id()
                .to_be_bytes(),
        )?;
        writer.write_all(&len.to_be_bytes())?;
        writer.write_all(self.bytes.as_ref())?;
        if !len.is_multiple_of(4) {
            // 1 to 3 zero bytes bring the field up to the next 4-byte boundary.
            let padding = usize::from(4 - (len % 4));
            writer.write_all(&[0u8; 3][..padding])?;
        }

        Ok(())
    }

    /// Wraps `bytes` as a response payload without any validation; unlike
    /// `new`, neither the length bound nor the 4-byte alignment is checked.
    pub const fn decode(bytes: &'a [u8]) -> Self {
        Self {
            bytes: Cow::Borrowed(bytes),
        }
    }

    /// Returns the payload bytes.
    pub fn bytes(&self) -> &[u8] {
        &self.bytes
    }
}

impl From<ReferenceIdRequest> for ExtensionField<'static> {
    fn from(value: ReferenceIdRequest) -> Self {
        Self::ReferenceIdRequest(value)
    }
}

impl<'a> From<ReferenceIdResponse<'a>> for ExtensionField<'a> {
    fn from(value: ReferenceIdResponse<'a>) -> Self {
        Self::ReferenceIdResponse(value)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // An empty message has no room for the two offset bytes.
    #[test]
    fn test_reference_id_request_too_short() {
        let result = ReferenceIdRequest::decode(&[]);
        assert!(matches!(result, Err(ParsingError::IncorrectLength)));
    }

    // The payload length is the message length; the offset is the first
    // two bytes read as a big-endian integer.
    #[test]
    fn test_reference_id_request_decode() {
        let msg = [0u8, 2, 0, 0, 0];
        let request = ReferenceIdRequest::decode(&msg).unwrap();
        assert_eq!(request.payload_len, 5);
        assert_eq!(request.offset, 2);
    }

    // Payloads of 1 to 4 zero bytes all serialize to 8 bytes: the length
    // field records payload + 4, and the payload is zero-padded to 4 bytes.
    #[test]
    fn test_reference_id_response_serialize() {
        let cases: [(usize, u8); 4] = [(1, 5), (2, 6), (3, 7), (4, 8)];

        for (payload_len, expected_len) in cases {
            let payload = vec![0u8; payload_len];
            let response = ReferenceIdResponse {
                bytes: Cow::Borrowed(payload.as_slice()),
            };

            let mut out = vec![];
            response.serialize(&mut out).unwrap();

            assert_eq!(out, [0xF5, 4, 0, expected_len, 0, 0, 0, 0]);
        }
    }
}