oxidef_compact1 0.1.0-alpha.1

Oxidef is an experimental interface definition language and serialisation scheme for efficient and strongly-typed payloads.
Documentation
use core::cmp::min;

use bytes::BufMut;
use oxidef_extratypes::{vu29::Vu29, vu64::Vu64};

use crate::{
    codec::Compact1Codec,
    decoder::{DecError, Decoder, MeasureBuf},
    encoder::{EncError, Encoder, ImprovidentBufMut},
    vu29,
};

/// Compact1 wire support for [`Vu29`]; both directions delegate to the `vu29` module.
impl Compact1Codec for Vu29 {
    /// Appends this value to the encoder's buffer via `vu29::encode_vu29`.
    fn encode<B: ImprovidentBufMut>(&self, encoder: &mut Encoder<B>) -> Result<(), EncError>
    where
        Self: Sized,
    {
        let sink = encoder.buf.deref_mut();
        let raw = u32::from(*self);
        vu29::encode_vu29(sink, raw)
    }

    /// Reads one value from the decoder's buffer via `vu29::decode_vu29` and wraps it.
    fn decode<B: MeasureBuf>(decoder: &mut Decoder<B>) -> Result<Self, DecError>
    where
        Self: Sized,
    {
        let raw = vu29::decode_vu29(&mut decoder.buf)?;
        Ok(Vu29::new(raw))
    }
}

impl Compact1Codec for Vu64 {
    fn encode<B: ImprovidentBufMut>(&self, encoder: &mut Encoder<B>) -> Result<(), EncError>
    where
        Self: Sized,
    {
        if self.0 < ((1 << 21) - 1) {
            return vu29::encode_vu29(encoder.buf.deref_mut(), u32::try_from(self.0).unwrap());
        }
        let mut out = [0u8; 9];
        let encoding_len = encode_varint(VU64_MAX_PREFIX_BITS, self.0, &mut out);
        encoder.buf.put(&out[0..encoding_len]);
        Ok(())
    }

    fn decode<B: MeasureBuf>(decoder: &mut Decoder<B>) -> Result<Self, DecError>
    where
        Self: Sized,
    {
        let mut buf = [0u8; 9];
        buf[0] = decoder.buf.try_get_u8()?;
        let extra_bytes = decode_varint_header(VU64_MAX_PREFIX_BITS, buf[0]);
        decoder
            .buf
            .try_copy_to_slice(&mut buf[1..extra_bytes + 1])?;
        decode_exact_varint(VU64_MAX_PREFIX_BITS, &buf[0..1 + extra_bytes]).map(Vu64)
    }
}

/// Works out the shape of the shortest varint encoding able to hold `value`.
///
/// Returns `(extra_bytes, final_form, prefix_bits)`:
/// * `extra_bytes` — number of bytes following the header byte;
/// * `final_form` — whether the prefix is already `max_prefix_bits` long and one further
///   full byte of data had to be appended;
/// * `prefix_bits` — number of bits of the header byte taken up by the length prefix.
///
/// Panics if `value` does not fit even in the final form for the given `max_prefix_bits`.
fn calculate_space_for_varint_encoding(max_prefix_bits: u64, value: u64) -> (usize, bool, u64) {
    // Number of significant bits in `value` (zero for a zero value).
    let bits_required = u64::BITS - value.leading_zeros();

    // The shortest form is a single byte: one prefix bit and seven data bits.
    let mut capacity_bits: u32 = 7;
    let mut prefix_len: u64 = 1;
    let mut saturated = false;
    let mut trailing_bytes: usize = 0;

    loop {
        if capacity_bits >= bits_required {
            break;
        }
        trailing_bytes += 1;

        if prefix_len >= max_prefix_bits {
            // The prefix cannot grow any further, so this byte contributes all eight of its
            // bits. That step is only available once.
            assert!(!saturated);
            saturated = true;
            capacity_bits += 8;
            continue;
        }

        // Growing the prefix costs one bit of the extra byte's worth of capacity.
        prefix_len += 1;
        capacity_bits += 7;
    }

    assert!(capacity_bits >= bits_required);

    (trailing_bytes, saturated, prefix_len)
}

/// Serialises `value` as a prefix varint at the front of `out` and returns the byte count.
///
/// The header byte carries the length prefix in its high bits — a run of 1-bits, closed by a
/// 0-bit unless the final form is in use — and the lowest bits of `value` in whatever room is
/// left. The remainder of `value` follows, least-significant byte first.
///
/// Panics if `out` is shorter than the encoding.
fn encode_varint(max_prefix_bits: u64, mut value: u64, out: &mut [u8]) -> usize {
    let (extra_bytes, final_form, needed_prefix_bits) =
        calculate_space_for_varint_encoding(max_prefix_bits, value);

    // Bits of the header byte left over for data once the prefix is in place.
    let header_data_bits = 8 - needed_prefix_bits;

    // `needed_prefix_bits` 1-bits; the non-final form clears the lowest of them so that the
    // prefix ends with a 0-bit.
    let all_ones = (1 << needed_prefix_bits) - 1;
    let marker = if final_form { all_ones } else { all_ones - 1 };
    let prefix = marker << header_data_bits;

    let low_mask = (1 << header_data_bits) - 1;
    out[0] = (prefix | (low_mask & value)) as u8;
    value >>= header_data_bits;

    // Emit what is left of the value one byte at a time, lowest byte first.
    let mut cursor = 1;
    while cursor <= extra_bytes {
        out[cursor] = value as u8;
        value >>= 8;
        cursor += 1;
    }

    cursor
}

/// Returns how many extra bytes follow the given header byte.
///
/// This is the number of leading 1-bits in `header`, clamped to `max_prefix_bits`.
///
/// The clamp is computed in `u64`, so a `max_prefix_bits` larger than `u32::MAX` is not
/// silently truncated before the comparison.
#[inline]
fn decode_varint_header(max_prefix_bits: u64, header: u8) -> usize {
    let leading_ones = u64::from(header.leading_ones());
    // The result is at most 8 (a `u8` has only eight bits), so this cast cannot lose anything.
    leading_ones.min(max_prefix_bits) as usize
}

/// Decodes one varint from `inp`, which must hold exactly the header byte plus the number of
/// extra bytes reported by [`decode_varint_header`] for that header (otherwise this may panic!).
///
/// Returns `DecError::VarUintNotCanonical` when the decoded value would have fitted in a
/// different number of bytes than were supplied.
fn decode_exact_varint(max_prefix_bits: u64, inp: &[u8]) -> Result<u64, DecError> {
    // The prefix takes one bit per byte of the encoding, up to `max_prefix_bits`; what is left
    // of the header byte holds the lowest bits of the value.
    let header_data_bits = 8 - min(inp.len(), max_prefix_bits as usize);
    let header = inp[0];
    let header_mask: u8 = (1 << header_data_bits) - 1;

    let mut value = (header & header_mask) as u64;

    // The remaining bytes are stacked on top of the header's data bits, lowest byte first.
    for (position, &byte) in inp[1..].iter().enumerate() {
        value |= (byte as u64) << (header_data_bits + 8 * position);
    }

    // Work out how long the shortest encoding of the decoded value is and insist that the
    // input used exactly that length.
    let (canonical_extra_bytes, _, _) =
        calculate_space_for_varint_encoding(max_prefix_bits, value);

    if canonical_extra_bytes + 1 == inp.len() {
        Ok(value)
    } else {
        Err(DecError::VarUintNotCanonical)
    }
}

/// Maximum number of length-prefix bits in the header byte of an encoded `Vu64`.
///
/// This is the `max_prefix_bits` argument the `Vu64` codec passes to the varint helpers in
/// this file.
pub const VU64_MAX_PREFIX_BITS: u64 = 8;

#[cfg(test)]
mod tests {
    use super::{
        decode_exact_varint, decode_varint_header, encode_varint, VU64_MAX_PREFIX_BITS,
    };

    /// Encodes values on either side of every length boundary, checks the encoded length and
    /// checks that decoding gives the original value back.
    #[test]
    fn test_roundtrip_vu64() {
        for (test_case, expected_length) in [
            (0, 1),
            (1, 1),
            (126, 1),
            (127, 1),
            (128, 2),
            (16_383, 2),
            (16_384, 3),
            (2_097_151, 3),
            (2_097_152, 4),
            (268_435_455, 4),
            (268_435_456, 5),
            (34_359_738_367, 5),
            (34_359_738_368, 6),
            (4_398_046_511_103, 6),
            (4_398_046_511_104, 7),
            (562_949_953_421_311, 7),
            (562_949_953_421_312, 8),
            (72_057_594_037_927_935, 8),
            (72_057_594_037_927_936, 9),
            (18_446_744_073_709_551_615, 9), // max uint64
        ] {
            let mut out = [0u8; 9];
            let written = encode_varint(VU64_MAX_PREFIX_BITS, test_case, &mut out[..]);
            assert_eq!(
                expected_length, written,
                "{test_case} when encoded should be {expected_length} bytes, not {written} bytes"
            );

            // then round-trip it
            let how_many_more_bytes = decode_varint_header(VU64_MAX_PREFIX_BITS, out[0]);
            let roundtripped =
                decode_exact_varint(VU64_MAX_PREFIX_BITS, &out[0..how_many_more_bytes + 1])
                    .unwrap();
            assert_eq!(roundtripped, test_case);
        }
    }

    /// Feeds the decoder encodings that are longer than necessary and checks each is rejected.
    #[test]
    fn test_noncanonical_vu64() {
        const TEST_CASES: &[&[u8]] = &[
            // 2- to 8-byte forms: each of these holds 7 more data bits than the next-shorter
            // form, so only the lowest bit of the last byte is within reach of the shorter
            // form. A last byte of 0x00 or 0x01 therefore means the value would have fitted in
            // fewer bytes.
            &[0xbf, 0x01],
            &[0xc0, 0xff, 0x01],
            &[0xe0, 0xff, 0xff, 0x01],
            &[0xf0, 0xff, 0xff, 0xff, 0x01],
            &[0xf8, 0xff, 0xff, 0xff, 0xff, 0x01],
            &[0xfc, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01],
            &[0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01],
            &[0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00],
            // 9-byte form: this holds a full 8 bits more than the 8-byte form, so only a last
            // byte of 0x00 is non-canonical.
            &[0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00],
        ];

        for test_case in TEST_CASES {
            let how_many_more_bytes = decode_varint_header(VU64_MAX_PREFIX_BITS, test_case[0]);
            assert_eq!(how_many_more_bytes, test_case.len() - 1);

            assert!(
                decode_exact_varint(VU64_MAX_PREFIX_BITS, test_case).is_err(),
                "{test_case:?} should be rejected as non-canonical"
            );
        }
    }
}