// safebit 0.1.0
//
// Safe and secure bit access into integer types.
// Documentation
//! `Word` definitions and operations.

use crate::error::WordError;
use core::fmt::Debug;
use core::{mem::size_of, ops::Range};

/// Fixed-sized integer type.
///
/// The operator supertraits are what [`Word::get`] needs to move and mask
/// bits without knowing the concrete integer type.
pub trait Word:
    core::ops::Shr<usize, Output = Self>
    + core::ops::Shl<usize, Output = Self>
    + core::ops::Not<Output = Self>
    + core::ops::BitOr<Output = Self>
    + core::ops::BitOrAssign
    + core::ops::BitAnd<Output = Self>
    + core::cmp::PartialEq
    + Sized
    + Debug
    + Copy
{
    /// `true` for signed integer types.
    const IS_SIGNED: bool;
    /// The value `0`.
    const ZERO: Self;
    /// The value `1`.
    const ONE: Self;
    /// Width of the type in bits.
    const BIT_LEN: usize;
    /// Unsigned counterpart used to build bit masks (`Self` for unsigned types).
    type UnsignedType: Word + TypeCast<Self>;
    /// Byte array holding the type's byte representation.
    type ByteSlice: Debug + Default + AsRef<[u8]> + AsMut<[u8]>;
    /// Returns `true` if the value is below zero.
    fn is_negative(&self) -> bool;
    /// Builds a value from its little-endian byte representation.
    fn from_le_bytes(slice: Self::ByteSlice) -> Self;
    /// Builds a value from its big-endian byte representation.
    fn from_be_bytes(slice: Self::ByteSlice) -> Self;
    /// Builds a value from its native-endian byte representation.
    fn from_ne_bytes(slice: Self::ByteSlice) -> Self;
    /// Platform sanity check for the implementing type.
    fn assert_platform_compatibility();

    /// Extract the bits `range` of `self` (bit 0 is the least significant bit)
    /// and return them right-aligned in an `OW`.
    ///
    /// If the slice is narrower than `OW` and `OW` is signed, the top bit of
    /// the slice is treated as its sign bit and the result is sign-extended.
    /// Otherwise, all bits of the result above the slice are zero.
    ///
    /// # Errors
    ///
    /// Returns [`WordError::InvalidBounds`] if any of these hold:
    /// - the range is empty or reversed (`start >= end`);
    /// - it reaches past the end of `self` (`end > Self::BIT_LEN`);
    /// - it is wider than the output type (`end - start > OW::BIT_LEN`).
    ///
    /// TODO: should we switch to start+len, to prevent ranges with start>end?
    fn get<OW>(self, range: Range<usize>) -> Result<OW, WordError>
    where
        Self: TypeCast<OW>,
        OW: Word + TypeCast<OW::UnsignedType>,
    {
        // Rejecting `end > Self::BIT_LEN` also guarantees
        // `range.start < Self::BIT_LEN`, so the shift below can never overflow
        // (which would panic in debug builds and wrap in release builds).
        if range.start >= range.end || range.end > Self::BIT_LEN {
            return Err(WordError::InvalidBounds);
        }

        let len = range.len();
        if len > OW::BIT_LEN {
            return Err(WordError::InvalidBounds);
        }

        // Move the slice to "the right" so its lowest bit becomes bit 0.
        let shifted: Self = self >> range.start;

        // The range's length was checked, so the slice survives truncation to OW.
        let shifted_ow: OW = shifted.typecast();

        // Nothing to mask if the slice fills OW completely.
        if len == OW::BIT_LEN {
            return Ok(shifted_ow);
        }

        // From here on 0 < len < OW::BIT_LEN, so every shift amount below is
        // strictly smaller than the width of the shifted type.
        let ones: OW::UnsignedType = !OW::UnsignedType::ZERO;
        let sign_bit = OW::ONE << (len - 1);
        if OW::IS_SIGNED && shifted_ow & sign_bit != OW::ZERO {
            // Negative slice: sign-extend by setting every bit above the slice.
            let above_slice: OW = (ones << len).typecast();
            Ok(above_slice | shifted_ow)
        } else {
            // Unsigned output or non-negative slice: keep only the low `len` bits.
            let slice_bits: OW = (ones >> (OW::BIT_LEN - len)).typecast();
            Ok(slice_bits & shifted_ow)
        }
    }
}

/// Implements `Word` for the unsigned primitive `$uint`, using
/// `$counterpart` as its `UnsignedType` (the invocations pass the type itself).
macro_rules! impl_word_u {
    ($uint:ident, $counterpart:ident) => {
        impl Word for $uint {
            type UnsignedType = $counterpart;
            type ByteSlice = [u8; size_of::<$uint>()];

            const IS_SIGNED: bool = false;
            const ZERO: Self = 0;
            const ONE: Self = 1;
            // An `as` cast never panics; whether `u32 -> usize` is lossless
            // is statically asserted in `assert_platform_compatibility`.
            const BIT_LEN: usize = <$uint>::BITS as usize;

            fn is_negative(&self) -> bool {
                // Unsigned values are never below zero.
                false
            }

            // The byte constructors forward to the inherent functions of the
            // same name.
            fn from_le_bytes(slice: Self::ByteSlice) -> Self {
                <$uint>::from_le_bytes(slice)
            }

            fn from_be_bytes(slice: Self::ByteSlice) -> Self {
                <$uint>::from_be_bytes(slice)
            }

            fn from_ne_bytes(slice: Self::ByteSlice) -> Self {
                <$uint>::from_ne_bytes(slice)
            }

            fn assert_platform_compatibility() {
                // Compile-time check: the bit width round-trips through usize.
                const _: () = assert!((<$uint>::BITS as usize) as u32 == <$uint>::BITS);
            }
        }
    };
}

/// Implements `Word` for the signed primitive `$sint`, using `$uint` as its
/// `UnsignedType` (the invocations pair each type with the unsigned type of
/// the same width).
macro_rules! impl_word_i {
    ($sint:ident, $uint:ident) => {
        impl Word for $sint {
            type UnsignedType = $uint;
            type ByteSlice = [u8; size_of::<$sint>()];

            const IS_SIGNED: bool = true;
            const ZERO: Self = 0;
            const ONE: Self = 1;
            // An `as` cast never panics; whether `u32 -> usize` is lossless
            // is statically asserted in `assert_platform_compatibility`.
            const BIT_LEN: usize = <$sint>::BITS as usize;

            fn is_negative(&self) -> bool {
                0 > *self
            }

            // The byte constructors forward to the inherent functions of the
            // same name.
            fn from_le_bytes(slice: Self::ByteSlice) -> Self {
                <$sint>::from_le_bytes(slice)
            }

            fn from_be_bytes(slice: Self::ByteSlice) -> Self {
                <$sint>::from_be_bytes(slice)
            }

            fn from_ne_bytes(slice: Self::ByteSlice) -> Self {
                <$sint>::from_ne_bytes(slice)
            }

            fn assert_platform_compatibility() {
                // Compile-time check: the bit width round-trips through usize.
                const _: () = assert!((<$sint>::BITS as usize) as u32 == <$sint>::BITS);
            }
        }
    };
}

// Unsigned words: every type is its own unsigned counterpart.
impl_word_u! { u8, u8 }
impl_word_u! { u16, u16 }
impl_word_u! { u32, u32 }
impl_word_u! { u64, u64 }
impl_word_u! { u128, u128 }

// Signed words, each paired with the unsigned type of equal width.
impl_word_i! { i8, u8 }
impl_word_i! { i16, u16 }
impl_word_i! { i32, u32 }
impl_word_i! { i64, u64 }
impl_word_i! { i128, u128 }

/// Trait for the `as` operator.
///
/// `typecast` converts exactly like `self as O`: narrowing keeps the low bits,
/// and widening sign-extends a signed source and zero-extends an unsigned one.
/// It is implemented for every ordered pair of the primitive integer types
/// `u8`, `u16`, `u32`, `u64`, `u128`, `i8`, `i16`, `i32`, `i64` and `i128`,
/// including the identity pairs.
pub trait TypeCast<O> {
    /// Converts `self` into `O` with `as` semantics.
    fn typecast(self) -> O;
}

/// Implements `TypeCast<$ow>` for `$iw` via a plain `as` cast.
macro_rules! impl_typecast {
    ($iw:ident, $ow:ident) => {
        impl TypeCast<$ow> for $iw {
            fn typecast(self) -> $ow {
                self as $ow
            }
        }
    };
}

/// Implements `TypeCast<T>` for `$iw`, once for every target `T` listed.
macro_rules! impl_typecast_into_each {
    ($iw:ident => $($ow:ident),+ $(,)?) => {
        $(impl_typecast!($iw, $ow);)+
    };
}

/// Implements `TypeCast` from each listed source into every primitive integer.
macro_rules! impl_typecast_matrix {
    ($($iw:ident),+ $(,)?) => {
        $(impl_typecast_into_each!($iw => u8, u16, u32, u64, u128, i8, i16, i32, i64, i128);)+
    };
}

impl_typecast_matrix!(u8, u16, u32, u64, u128, i8, i16, i32, i64, i128);

#[cfg(test)]
mod tests {
    extern crate alloc;
    extern crate std;
    use crate::error::WordError;
    use crate::util::tests::init_logger;
    use crate::word::TypeCast;
    use crate::word::Word;
    use alloc::format;
    use log::trace;

    #[test]
    fn test_get() {
        // uint -> uint
        let t0: u16 = 0xff00;
        assert_eq!(t0.get::<u16>(4..12).unwrap(), 0xf0);

        // uint -> smaller uint
        let t1: u16 = 0xff00;
        assert_eq!(t1.get::<u8>(4..12).unwrap(), 0xf0);

        // sint -> smaller uint
        let t2: i16 = 0xff00 as _;
        assert_eq!(t2.get::<u8>(4..12).unwrap(), 0xf0);

        // sint -> smaller sint
        let t3: i16 = 0xff00 as _;
        assert_eq!(t3.get::<i8>(4..12).unwrap(), -16);
        assert_eq!(t3.get::<i8>(0..1).unwrap(), 0);

        // 0b1111_1111_1110_1001 ->
        //                   001
        let t4: i16 = -23;
        assert_eq!(t4.get::<i8>(0..3).unwrap(), 1);

        // uint -> bigger sint
        let t5: u8 = 0xff;
        assert_eq!(t5.get::<i16>(0..8).unwrap(), -1);
    }

    /// Exhaustively compares `u8 -> i16` slicing against a string-based
    /// reference model, for every `u8` value and every (start, end) pair.
    #[test]
    fn test_u8_to_i16() {
        init_logger();
        // Inclusive bound: `0..u8::MAX` would skip 0xff.
        for word in 0..=u8::MAX {
            let word_str = format!("{word:08b}");
            let word_i16 = word as i8 as i16;
            let word_i16_str = format!("{word_i16:016b}");

            // Take every possible subslice. Both bounds are inclusive of the
            // word width so the full-width slice (0..8) and the empty ranges
            // starting at 8 are exercised too.
            for start_pos in 0..=word_str.len() {
                for end_pos in 0..=word_str.len() {
                    let read_result: Result<i16, WordError> = word.get(start_pos..end_pos);
                    trace!("{word_i16_str}[{start_pos}..{end_pos}]");
                    match read_result {
                        Ok(read_value) => {
                            // Convert read value into bit string
                            let i = read_value as u16;
                            let read_value_string = format!("{i:016b}");

                            // Convert ground truth into bit string: the bit
                            // string is MSB-first, so bit positions are mirrored.
                            let inverted_end = 16 - start_pos;
                            let inverted_start = 16 - end_pos;
                            let substr = &word_i16_str[inverted_start..inverted_end];
                            let pad_length = 16 - (end_pos - start_pos);
                            trace!("{pad_length} {substr}");
                            // i16 is signed, so the slice's top bit is sign-extended.
                            let ground_truth_string = if substr.starts_with('1') {
                                let padding = "1".repeat(pad_length);
                                format!("{padding}{substr}")
                            } else {
                                let padding = "0".repeat(pad_length);
                                format!("{padding}{substr}")
                            };

                            assert_eq!(read_value_string, ground_truth_string)
                        }
                        Err(WordError::InvalidBounds) => {
                            // If invalid bounds are reported, they better be invalid.
                            trace!("invalid bounds");
                            assert!(end_pos <= start_pos)
                        }
                    }
                }
            }
        }
    }

    #[test]
    fn test_typecast_signed() {
        // Sign extension is only performed if the *source* type is signed.
        let t0: u8 = 0xff;
        let t1: i16 = t0.typecast();
        assert_eq!(t1, 0xff);

        let t3: u8 = 0xff;
        let t4: i8 = t3 as _;
        let t5: i16 = t4.typecast();
        assert_eq!(t5, -1);
    }

    #[test]
    fn test_overflowing_get() {
        let t0: u16 = 42;
        assert_eq!(t0.get::<u8>(0..9), Err(WordError::InvalidBounds));
    }
}