#[cfg(test)]
mod tests;
use crate::numeric::{Numeric, SignedInteger, UnsignedInteger};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
/// Number of bits covered by one decomposition level, i.e. the base-2 logarithm of the
/// decomposition base.
///
/// The `SignedDecomposable` implementations in this file build their per-level block
/// mask as `(1 << base_log) - 1` and shift by multiples of this value.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Deserialize, Serialize)]
pub struct DecompositionBaseLog(pub usize);
/// Total number of levels used by a decomposition.
///
/// `SignedDecomposable::round_to_closest_multiple` multiplies this count by the base log
/// to obtain how many most-significant bits of a value are kept.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Deserialize, Serialize)]
pub struct DecompositionLevelCount(pub usize);
/// Zero-based index of a single decomposition level.
///
/// In the implementations in this file, level `l` addresses the block of `base_log` bits
/// ending `base_log * (l + 1)` bits below the top of the integer, so level 0 is the most
/// significant block.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Deserialize, Serialize)]
pub struct DecompositionLevel(pub usize);
/// Operations needed to decompose an integer into signed digits, one digit per block of
/// `base_log` bits, starting from the most significant bits of the value.
pub trait SignedDecomposable: Sized {
    /// Rounds `self` to the closest value that only uses its `level * base_log` most
    /// significant bits.
    ///
    /// Note that `level` is the *number* of levels here, not the index of one level.
    fn round_to_closest_multiple(
        self,
        base_log: DecompositionBaseLog,
        level: DecompositionLevelCount,
    ) -> Self;

    /// Computes the signed digit of `self` at the given `level`, taking `previous_carry`
    /// into account.
    ///
    /// Returns `(digit, carry)`: the digit for this level and the carry that this level
    /// produces.
    fn signed_decompose_one_level(
        self,
        previous_carry: Self,
        base_log: DecompositionBaseLog,
        level: DecompositionLevel,
    ) -> (Self, Self);

    /// Moves `self`, interpreted as the digit of the given `level`, to the bit position
    /// that this level occupies in the full-width integer.
    fn set_val_at_level(self, base_log: DecompositionBaseLog, level: DecompositionLevel) -> Self;
}
// Generates the `SignedDecomposable` implementation for one unsigned integer type.
//
// The expansion uses `<$Type as Numeric>::BITS` and the `into_signed` / `into_unsigned`
// conversions, so the traits providing them must be in scope where the macro is invoked.
macro_rules! implement {
    ($Type: tt) => {
        impl SignedDecomposable for $Type {
            /// Rounds `self` to the closest multiple of `2^(BITS - level * base_log)`,
            /// i.e. keeps only the `level * base_log` most significant bits, rounding
            /// half up. A round-up out of the top bit wraps around to zero.
            ///
            /// When `level * base_log == BITS` every bit is representable and `self` is
            /// returned unchanged.
            ///
            /// # Panics
            ///
            /// With overflow checks enabled, panics if `level * base_log` is zero or
            /// exceeds the bit width of the type.
            fn round_to_closest_multiple(
                self,
                base_log: DecompositionBaseLog,
                level: DecompositionLevelCount,
            ) -> Self {
                // Number of least significant bits that are rounded away.
                let shift: usize = <Self as Numeric>::BITS - level.0 * base_log.0;
                if shift == 0 {
                    // Nothing is dropped; without this guard `shift - 1` would underflow.
                    return self;
                }
                // The most significant dropped bit decides whether to round up.
                let rounding_bit_pos = shift - 1;
                let mask: Self = 1 << rounding_bit_pos;
                let rounding_bit = (self & mask) >> rounding_bit_pos;
                // Truncate, apply the rounding, then move the kept bits back in place.
                let mut res = self >> shift;
                res += rounding_bit;
                res <<= shift;
                res
            }

            /// Extracts the `base_log`-bit block of `self` at `level`, adds
            /// `previous_carry`, and re-centres the result as a signed digit.
            ///
            /// Returns `(digit, carry)`. `digit` is the signed digit stored in two's
            /// complement in the unsigned type; `carry` is `0` or `1`.
            /// `previous_carry` is presumably the carry returned for the neighbouring,
            /// less significant level — confirm against the caller.
            fn signed_decompose_one_level(
                self,
                previous_carry: Self,
                base_log: DecompositionBaseLog,
                level: DecompositionLevel,
            ) -> (Self, Self) {
                // Mask selecting a whole block, and mask selecting the top bit of a block.
                let block_bit_mask: Self = (1 << base_log.0) - 1;
                let msb_block_mask: Self = 1 << (base_log.0 - 1);
                // Bring the requested block down to the least significant bits.
                let mut tmp = (self >> (<Self as Numeric>::BITS - base_log.0 * (level.0 + 1)))
                    & block_bit_mask;
                // A carry is produced when the block's top bit is set either before or
                // after the incoming carry has been added.
                let mut carry: Self = tmp & msb_block_mask;
                tmp = tmp.wrapping_add(previous_carry);
                carry |= tmp & msb_block_mask;
                let left = tmp.into_signed();
                // `carry << 1` equals `2^base_log` when a carry is produced, else zero.
                let right = (carry << 1).into_signed();
                // Wrapping subtraction: for `base_log == BITS - 1` the signed difference
                // does not fit and a plain `-` would panic with overflow checks enabled.
                let res = left.wrapping_sub(right).into_unsigned();
                // Normalise the carry flag to 0 or 1.
                carry >>= base_log.0 - 1;
                (res, carry)
            }

            /// Shifts `self` left so that its least significant bit lands on the lowest
            /// bit of the block addressed by `level`. Bits shifted past the top of the
            /// type are discarded.
            fn set_val_at_level(
                self,
                base_log: DecompositionBaseLog,
                level: DecompositionLevel,
            ) -> Self {
                let shift: usize = <Self as Numeric>::BITS - (base_log.0 * (level.0 + 1));
                self << shift
            }
        }
    };
}
// Provide `SignedDecomposable` for each of these unsigned integer widths.
implement!(u8);
implement!(u16);
implement!(u32);
implement!(u64);
implement!(u128);