use crate::fixed_point::core_types::errors::OverflowDetected;
/// A single ternary value (trit) whose signed value is -1, 0 or 1.
///
/// The enum is `#[repr(i8)]` with explicit discriminants, so `trit as i8`
/// yields the signed value directly (relied upon by `Trit::as_i8` and the
/// packed encoding).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(i8)]
pub enum Trit {
/// Signed value -1.
Neg = -1,
/// Signed value 0; also the padding value used by `pack_trits`.
Zero = 0,
/// Signed value +1.
Pos = 1,
}
impl Trit {
    /// Builds a trit from its signed value.
    ///
    /// # Errors
    ///
    /// Returns [`OverflowDetected::InvalidInput`] unless `val` is `-1`, `0` or `1`.
    #[inline]
    pub fn from_i8(val: i8) -> Result<Self, OverflowDetected> {
        let trit = match val {
            -1 => Self::Neg,
            0 => Self::Zero,
            1 => Self::Pos,
            _ => return Err(OverflowDetected::InvalidInput),
        };
        Ok(trit)
    }

    /// Returns the signed value of this trit (`-1`, `0` or `1`).
    #[inline]
    pub const fn as_i8(self) -> i8 {
        self as i8
    }

    /// Returns the base-3 digit used by the packed byte encoding:
    /// `Neg -> 0`, `Zero -> 1`, `Pos -> 2` (the signed value shifted by one).
    #[inline]
    const fn to_packed(self) -> u8 {
        match self {
            Self::Neg => 0,
            Self::Zero => 1,
            Self::Pos => 2,
        }
    }

    /// Inverse of [`Trit::to_packed`]: turns a base-3 digit back into a trit.
    ///
    /// # Errors
    ///
    /// Returns [`OverflowDetected::InvalidInput`] if `val` is greater than 2.
    #[inline]
    fn from_packed(val: u8) -> Result<Self, OverflowDetected> {
        // Index = packed digit; order must mirror `to_packed`.
        let by_digit = [Trit::Neg, Trit::Zero, Trit::Pos];
        by_digit
            .get(usize::from(val))
            .copied()
            .ok_or(OverflowDetected::InvalidInput)
    }
}
/// Packs trits five per byte, treating each group as a base-3 number.
///
/// Each trit contributes its packed digit (`Neg -> 0`, `Zero -> 1`,
/// `Pos -> 2`); the first trit of a group is the most significant digit, so
/// every byte lies in `0..=242`. A trailing group with fewer than five trits
/// is padded with `Trit::Zero`. The result has `ceil(trits.len() / 5)` bytes
/// and is empty for empty input.
pub fn pack_trits(trits: &[Trit]) -> Vec<u8> {
    trits
        .chunks(5)
        .map(|group| {
            // Walk all five digit positions; missing ones become Zero padding.
            (0..5)
                .map(|pos| group.get(pos).copied().unwrap_or(Trit::Zero))
                // Max before the last step is 80, so 80 * 3 + 2 = 242 never overflows.
                .fold(0u8, |acc, trit| acc * 3 + trit.to_packed())
        })
        .collect()
}
pub fn unpack_trits(data: &[u8], count: usize) -> Result<Vec<Trit>, OverflowDetected> {
let mut trits = Vec::with_capacity(count);
let mut extracted = 0;
for &byte in data {
if extracted >= count {
break;
}
let mut remaining = byte;
let mut chunk = [Trit::Zero; 5];
for j in (0..5).rev() {
let d = remaining % 3;
remaining /= 3;
chunk[j] = Trit::from_packed(d)?;
}
for j in 0..5 {
if extracted >= count {
break;
}
trits.push(chunk[j]);
extracted += 1;
}
}
Ok(trits)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pack_unpack_roundtrip() {
        let trits = vec![Trit::Pos, Trit::Neg, Trit::Zero, Trit::Pos, Trit::Neg];
        let packed = pack_trits(&trits);
        assert_eq!(packed.len(), 1);
        let unpacked = unpack_trits(&packed, 5).unwrap();
        assert_eq!(trits, unpacked);
    }

    #[test]
    fn test_pack_unpack_roundtrip_partial() {
        let trits = vec![
            Trit::Pos, Trit::Zero, Trit::Neg, Trit::Pos, Trit::Zero,
            Trit::Neg, Trit::Pos,
        ];
        let packed = pack_trits(&trits);
        assert_eq!(packed.len(), 2);
        let unpacked = unpack_trits(&packed, 7).unwrap();
        assert_eq!(trits, unpacked);
    }

    #[test]
    fn test_pack_all_zeros() {
        let trits = vec![Trit::Zero; 10];
        let packed = pack_trits(&trits);
        assert_eq!(packed.len(), 2);
        let unpacked = unpack_trits(&packed, 10).unwrap();
        assert_eq!(trits, unpacked);
    }

    #[test]
    fn test_pack_all_pos() {
        let trits = vec![Trit::Pos; 5];
        let packed = pack_trits(&trits);
        assert_eq!(packed[0], 242);
        let unpacked = unpack_trits(&packed, 5).unwrap();
        assert_eq!(trits, unpacked);
    }

    #[test]
    fn test_pack_all_neg() {
        let trits = vec![Trit::Neg; 5];
        let packed = pack_trits(&trits);
        assert_eq!(packed[0], 0);
        let unpacked = unpack_trits(&packed, 5).unwrap();
        assert_eq!(trits, unpacked);
    }

    #[test]
    fn test_encoding_fits_in_u8() {
        // Five base-3 digits span 0..=3^5 - 1; the largest code must fit in a byte
        // and must be exactly what the encoder emits for five `Pos` trits.
        let max_code = 3u32.pow(5) - 1;
        assert_eq!(max_code, 242);
        assert!(max_code <= u32::from(u8::MAX));
        assert_eq!(pack_trits(&[Trit::Pos; 5]), vec![242]);
    }

    #[test]
    fn test_every_valid_code_roundtrips() {
        // Decoding any valid code and re-encoding it must reproduce the byte.
        for code in 0..=242u8 {
            let trits = unpack_trits(&[code], 5).unwrap();
            assert_eq!(trits.len(), 5);
            assert_eq!(pack_trits(&trits), vec![code]);
        }
    }

    #[test]
    fn test_empty() {
        let trits: Vec<Trit> = vec![];
        let packed = pack_trits(&trits);
        assert!(packed.is_empty());
        let unpacked = unpack_trits(&packed, 0).unwrap();
        assert!(unpacked.is_empty());
    }

    #[test]
    fn test_single_trit() {
        for trit in [Trit::Neg, Trit::Zero, Trit::Pos] {
            let packed = pack_trits(&[trit]);
            assert_eq!(packed.len(), 1);
            let unpacked = unpack_trits(&packed, 1).unwrap();
            assert_eq!(unpacked, vec![trit]);
        }
    }

    #[test]
    fn test_large_roundtrip() {
        let trits: Vec<Trit> = (0..1000)
            .map(|i| match i % 3 {
                0 => Trit::Neg,
                1 => Trit::Zero,
                _ => Trit::Pos,
            })
            .collect();
        let packed = pack_trits(&trits);
        assert_eq!(packed.len(), 200);
        let unpacked = unpack_trits(&packed, 1000).unwrap();
        assert_eq!(trits, unpacked);
    }

    #[test]
    fn test_trit_conversions() {
        assert_eq!(Trit::from_i8(-1).unwrap(), Trit::Neg);
        assert_eq!(Trit::from_i8(0).unwrap(), Trit::Zero);
        assert_eq!(Trit::from_i8(1).unwrap(), Trit::Pos);
        assert!(Trit::from_i8(2).is_err());
        assert!(Trit::from_i8(-2).is_err());
        for trit in [Trit::Neg, Trit::Zero, Trit::Pos] {
            assert_eq!(Trit::from_i8(trit.as_i8()).unwrap(), trit);
        }
    }
}