aldrin-core 0.13.0

Shared core components of Aldrin, a message bus for service-oriented RPC and interprocess communication.
Documentation
#[cfg(test)]
mod test;

use crate::message::MessageDeserializeError;
use crate::DeserializeError;
use bytes::{Buf, BufMut, Bytes};

/// Write helpers for the primitive wire encodings, available on every [`BufMut`].
///
/// Unsigned integers are written through [`put_varint_le`](BufMutExt::put_varint_le); signed
/// integers are zigzag-encoded into the unsigned type of the same width first.
pub(crate) trait BufMutExt: BufMut {
    /// Writes a single discriminant byte.
    fn put_discriminant_u8(&mut self, discriminant: impl Into<u8>) {
        let byte: u8 = discriminant.into();
        self.put_u8(byte)
    }

    /// Writes `n` as a little-endian varint.
    fn put_varint_u16_le(&mut self, n: u16) {
        let le = n.to_le_bytes();
        self.put_varint_le(le);
    }

    /// Writes `n` zigzag-encoded as a little-endian varint.
    fn put_varint_i16_le(&mut self, n: i16) {
        let zigzag = zigzag_encode_i16(n);
        self.put_varint_u16_le(zigzag);
    }

    /// Writes `n` as a little-endian varint.
    fn put_varint_u32_le(&mut self, n: u32) {
        let le = n.to_le_bytes();
        self.put_varint_le(le);
    }

    /// Writes `n` zigzag-encoded as a little-endian varint.
    fn put_varint_i32_le(&mut self, n: i32) {
        let zigzag = zigzag_encode_i32(n);
        self.put_varint_u32_le(zigzag);
    }

    /// Writes `n` as a little-endian varint.
    fn put_varint_u64_le(&mut self, n: u64) {
        let le = n.to_le_bytes();
        self.put_varint_le(le);
    }

    /// Writes `n` zigzag-encoded as a little-endian varint.
    fn put_varint_i64_le(&mut self, n: i64) {
        let zigzag = zigzag_encode_i64(n);
        self.put_varint_u64_le(zigzag);
    }

    /// Writes the little-endian bytes of an `N`-byte integer in variable-length form.
    ///
    /// If the upper `N - 1` bytes are zero and the low byte is at most `255 - N`, only that
    /// byte is written. Otherwise a prefix byte `255 - N + len` is written, followed by the
    /// `len` low-order bytes. Here `len` counts the bytes up to and including the most
    /// significant non-zero one, and is at least 1.
    fn put_varint_le<const N: usize>(&mut self, bytes: [u8; N]) {
        // Highest index above 0 that holds a non-zero byte, if there is one.
        let highest = (1..N).rev().find(|&idx| bytes[idx] != 0);

        // First-byte values above this threshold are length prefixes.
        let threshold = 255 - N as u8;

        match highest {
            Some(idx) => {
                let len = idx + 1;
                self.put_u8(threshold + len as u8);
                self.put_slice(&bytes[..len]);
            }
            None => {
                // A single byte that would collide with a prefix value needs its own
                // one-byte length prefix.
                if bytes[0] > threshold {
                    self.put_u8(threshold + 1);
                }
                self.put_u8(bytes[0]);
            }
        }
    }
}

impl<T: BufMut + ?Sized> BufMutExt for T {}

/// Fallible read helpers used when deserializing values, implemented for every [`Buf`].
///
/// Every method reports a truncated input as [`DeserializeError::UnexpectedEoi`].
///
/// The varint readers decode an `N`-byte integer as follows. A first byte `b <= 255 - N` is
/// the value itself. Otherwise `b + N - 255` little-endian payload bytes follow, and all
/// high-order bytes not covered by the payload are zero.
pub(crate) trait ValueBufExt: Buf {
    /// Reads one byte and converts it into the discriminant type `T`.
    ///
    /// # Errors
    ///
    /// `UnexpectedEoi` if the buffer is empty; `InvalidSerialization` if the conversion to `T`
    /// fails.
    fn try_get_discriminant_u8<T: TryFrom<u8>>(&mut self) -> Result<T, DeserializeError> {
        self.try_get_u8()
            .map_err(|_| DeserializeError::UnexpectedEoi)?
            .try_into()
            .map_err(|_| DeserializeError::InvalidSerialization)
    }

    /// Converts the next byte into the discriminant type `T` without consuming it.
    ///
    /// # Errors
    ///
    /// Same as [`try_get_discriminant_u8`](Self::try_get_discriminant_u8).
    fn try_peek_discriminant_u8<T: TryFrom<u8>>(&self) -> Result<T, DeserializeError> {
        // `Buf::chunk` is empty exactly when nothing remains. `first()` therefore checks for
        // end of input and reads the byte in one step, without an indexing panic path.
        let first = self
            .chunk()
            .first()
            .copied()
            .ok_or(DeserializeError::UnexpectedEoi)?;

        first
            .try_into()
            .map_err(|_| DeserializeError::InvalidSerialization)
    }

    /// Reads a discriminant and checks that it equals `discriminant`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `try_get_discriminant_u8`. Returns `UnexpectedValue` if the
    /// decoded discriminant differs from the expected one.
    fn ensure_discriminant_u8<T: TryFrom<u8> + PartialEq>(
        &mut self,
        discriminant: T,
    ) -> Result<(), DeserializeError> {
        if self.try_get_discriminant_u8::<T>()? == discriminant {
            Ok(())
        } else {
            Err(DeserializeError::UnexpectedValue)
        }
    }

    /// Reads a varint-encoded `u16`.
    fn try_get_varint_u16_le(&mut self) -> Result<u16, DeserializeError> {
        self.try_get_varint_le().map(u16::from_le_bytes)
    }

    /// Reads a zigzag-encoded varint and decodes it into an `i16`.
    fn try_get_varint_i16_le(&mut self) -> Result<i16, DeserializeError> {
        self.try_get_varint_u16_le().map(zigzag_decode_i16)
    }

    /// Reads a varint-encoded `u32`.
    fn try_get_varint_u32_le(&mut self) -> Result<u32, DeserializeError> {
        self.try_get_varint_le().map(u32::from_le_bytes)
    }

    /// Reads a zigzag-encoded varint and decodes it into an `i32`.
    fn try_get_varint_i32_le(&mut self) -> Result<i32, DeserializeError> {
        self.try_get_varint_u32_le().map(zigzag_decode_i32)
    }

    /// Reads a varint-encoded `u64`.
    fn try_get_varint_u64_le(&mut self) -> Result<u64, DeserializeError> {
        self.try_get_varint_le().map(u64::from_le_bytes)
    }

    /// Reads a zigzag-encoded varint and decodes it into an `i64`.
    fn try_get_varint_i64_le(&mut self) -> Result<i64, DeserializeError> {
        self.try_get_varint_u64_le().map(zigzag_decode_i64)
    }

    /// Reads one `N`-byte varint and returns its little-endian bytes.
    ///
    /// # Errors
    ///
    /// `UnexpectedEoi` if the prefix byte or any announced payload byte is missing.
    fn try_get_varint_le<const N: usize>(&mut self) -> Result<[u8; N], DeserializeError> {
        let mut bytes = [0; N];

        let first = self
            .try_get_u8()
            .map_err(|_| DeserializeError::UnexpectedEoi)?;

        if first > 255 - N as u8 {
            // Length prefix: this many little-endian payload bytes follow.
            let num_bytes = first as usize + N - 255;

            self.try_copy_to_slice(&mut bytes[..num_bytes])
                .map_err(|_| DeserializeError::UnexpectedEoi)?;
        } else {
            // Small values are stored inline in the first byte.
            bytes[0] = first;
        }

        Ok(bytes)
    }

    /// Takes the next `len` bytes via `Buf::copy_to_bytes`.
    ///
    /// # Errors
    ///
    /// `UnexpectedEoi` if fewer than `len` bytes remain. The buffer is not advanced in that
    /// case.
    fn try_copy_to_bytes(&mut self, len: usize) -> Result<Bytes, DeserializeError> {
        if self.remaining() >= len {
            Ok(self.copy_to_bytes(len))
        } else {
            Err(DeserializeError::UnexpectedEoi)
        }
    }

    /// Advances past `len` bytes.
    ///
    /// # Errors
    ///
    /// `UnexpectedEoi` if fewer than `len` bytes remain. The buffer is not advanced in that
    /// case.
    fn try_skip(&mut self, len: usize) -> Result<(), DeserializeError> {
        if self.remaining() >= len {
            self.advance(len);
            Ok(())
        } else {
            Err(DeserializeError::UnexpectedEoi)
        }
    }

    /// Advances past one `N`-byte varint without decoding it.
    ///
    /// # Errors
    ///
    /// `UnexpectedEoi` if the prefix byte or any announced payload byte is missing.
    fn try_skip_varint_le<const N: usize>(&mut self) -> Result<(), DeserializeError> {
        let first = self
            .try_get_u8()
            .map_err(|_| DeserializeError::UnexpectedEoi)?;

        if first > 255 - N as u8 {
            let num_bytes = first as usize + N - 255;
            self.try_skip(num_bytes)?;
        }

        Ok(())
    }
}

impl<T: Buf + ?Sized> ValueBufExt for T {}

pub(crate) trait MessageBufExt: Buf {
    fn try_get_discriminant_u8<T: TryFrom<u8>>(&mut self) -> Result<T, MessageDeserializeError> {
        self.try_get_u8()
            .map_err(|_| MessageDeserializeError::UnexpectedEoi)?
            .try_into()
            .map_err(|_| MessageDeserializeError::InvalidSerialization)
    }

    fn ensure_discriminant_u8<T: TryFrom<u8> + PartialEq>(
        &mut self,
        discriminant: T,
    ) -> Result<(), MessageDeserializeError> {
        if self.try_get_discriminant_u8::<T>()? == discriminant {
            Ok(())
        } else {
            Err(MessageDeserializeError::UnexpectedMessage)
        }
    }

    fn try_get_varint_u32_le(&mut self) -> Result<u32, MessageDeserializeError> {
        self.try_get_varint_le().map(u32::from_le_bytes)
    }

    fn try_get_varint_le<const N: usize>(&mut self) -> Result<[u8; N], MessageDeserializeError> {
        let mut bytes = [0; N];

        let first = self
            .try_get_u8()
            .map_err(|_| MessageDeserializeError::UnexpectedEoi)?;

        if first > 255 - N as u8 {
            let num_bytes = first as usize + N - 255;

            self.try_copy_to_slice(&mut bytes[..num_bytes])
                .map_err(|_| MessageDeserializeError::UnexpectedEoi)?;
        } else {
            bytes[0] = first;
        }

        Ok(bytes)
    }
}

impl<T: Buf + ?Sized> MessageBufExt for T {}

/// Zigzag-encodes an `i16` into a `u16`.
///
/// Negative and non-negative values are interleaved (0 → 0, -1 → 1, 1 → 2, -2 → 3, …), so
/// values of small magnitude map to small unsigned numbers.
fn zigzag_encode_i16(n: i16) -> u16 {
    // Arithmetic shift: all ones for negative `n`, all zeros otherwise.
    let sign_mask = n >> 15;
    ((n << 1) ^ sign_mask) as u16
}

/// Reverses zigzag encoding, turning a `u16` back into the `i16` it represents.
///
/// The low bit selects the sign and the remaining bits hold the magnitude
/// (0 → 0, 1 → -1, 2 → 1, 3 → -2, …).
fn zigzag_decode_i16(n: u16) -> i16 {
    // 0xFFFF when the low bit marks a negative value, 0 otherwise.
    let sign_mask = (n & 1).wrapping_neg();
    ((n >> 1) ^ sign_mask) as i16
}

/// Zigzag-encodes an `i32` into a `u32`.
///
/// Negative and non-negative values are interleaved (0 → 0, -1 → 1, 1 → 2, -2 → 3, …), so
/// values of small magnitude map to small unsigned numbers.
fn zigzag_encode_i32(n: i32) -> u32 {
    // Arithmetic shift: all ones for negative `n`, all zeros otherwise.
    let sign_mask = n >> 31;
    ((n << 1) ^ sign_mask) as u32
}

/// Reverses zigzag encoding, turning a `u32` back into the `i32` it represents.
///
/// The low bit selects the sign and the remaining bits hold the magnitude
/// (0 → 0, 1 → -1, 2 → 1, 3 → -2, …).
fn zigzag_decode_i32(n: u32) -> i32 {
    // All ones when the low bit marks a negative value, 0 otherwise.
    let sign_mask = (n & 1).wrapping_neg();
    ((n >> 1) ^ sign_mask) as i32
}

/// Zigzag-encodes an `i64` into a `u64`.
///
/// Negative and non-negative values are interleaved (0 → 0, -1 → 1, 1 → 2, -2 → 3, …), so
/// values of small magnitude map to small unsigned numbers.
fn zigzag_encode_i64(n: i64) -> u64 {
    // Arithmetic shift: all ones for negative `n`, all zeros otherwise.
    let sign_mask = n >> 63;
    ((n << 1) ^ sign_mask) as u64
}

/// Reverses zigzag encoding, turning a `u64` back into the `i64` it represents.
///
/// The low bit selects the sign and the remaining bits hold the magnitude
/// (0 → 0, 1 → -1, 2 → 1, 3 → -2, …).
fn zigzag_decode_i64(n: u64) -> i64 {
    // All ones when the low bit marks a negative value, 0 otherwise.
    let sign_mask = (n & 1).wrapping_neg();
    ((n >> 1) ^ sign_mask) as i64
}