#![allow(clippy::should_implement_trait)]
use crate::trit::{Trit, INV, N, P, Z};
/// Number of 2-bit trit slots in the `u128` backing word (128 bits / 2 = 64 trits).
const PAIRS: usize = (u128::BITS / 2) as usize;
/// Mask with the low bit of every 2-bit pair set (`0b0101…01`).
///
/// `u128::MAX / 3` is exactly that pattern, because `0b01 * 3 == 0b11` in every pair.
const LO_MASK: u128 = u128::MAX / 3;
/// A balanced-ternary integer of 64 trits, packed two bits per trit into a `u128`.
///
/// Trit `k` occupies bits `2k..2k+2` and has weight 3^k. Pair patterns: `00` is 0,
/// `10` is +1, `01` is -1; `11` is invalid (see `Tern::is_valid`). Equality and
/// hashing are derived and therefore compare the raw bit pattern.
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
pub struct Tern(pub(crate) u128);
impl Tern {
    /// The value 0. Every pair is `00`, so a zero-filled word is a valid zero.
    pub const ZERO: Tern = Tern(0);
    /// The value +1: the lowest pair holds the `10` (`P`) pattern.
    pub const ONE: Tern = Tern(0b10);
    /// The value -1: the lowest pair holds the `01` (`N`) pattern.
    pub const NEG_ONE: Tern = Tern(0b01);

    /// Returns the trit at position `pos` (0 is the least significant, weight 3^0).
    ///
    /// `pos` must be below 64; the bound is only checked in debug builds.
    #[inline]
    pub fn trit_at(self, pos: usize) -> Trit {
        debug_assert!(pos < PAIRS);
        Trit(((self.0 >> (pos * 2)) & 0b11) as u8)
    }

    /// Returns a copy of `self` with the trit at `pos` replaced by `t`.
    ///
    /// Only the low two bits of `t`'s raw pattern are stored, so a malformed trit
    /// cannot spill into the neighbouring pair. `pos` must be below 64 (debug-checked).
    #[inline]
    pub fn with_trit(mut self, pos: usize, t: Trit) -> Tern {
        debug_assert!(pos < PAIRS);
        let shift = pos * 2;
        let bits = (t.0 as u128) & 0b11;
        self.0 = (self.0 & !(0b11u128 << shift)) | (bits << shift);
        self
    }

    /// True when every pair is `00`, i.e. the value is 0.
    #[inline]
    pub fn is_zero(self) -> bool {
        self.0 == 0
    }

    /// True when no pair holds the invalid `11` pattern.
    pub fn is_valid(self) -> bool {
        let hi = (self.0 >> 1) & LO_MASK;
        let lo = self.0 & LO_MASK;
        (hi & lo) == 0
    }

    /// Sign of the value: the sign of the most significant non-zero trit, or `Z` for 0.
    pub fn sign(self) -> Trit {
        if self.is_zero() {
            return Z;
        }
        for i in (0..PAIRS).rev() {
            let t = self.trit_at(i);
            if !t.is_zero() {
                return t.sign();
            }
        }
        Z
    }

    /// Negation: swaps the `01` and `10` patterns in every pair.
    ///
    /// `00` pairs stay `00`; invalid `11` pairs are left unchanged.
    pub fn neg(self) -> Tern {
        let hi = (self.0 >> 1) & LO_MASK;
        let lo = self.0 & LO_MASK;
        // Pairs that are exactly `01` or `10` get both of their bits flipped.
        let diff = hi ^ lo;
        let flip_mask = diff | (diff << 1);
        Tern(self.0 ^ flip_mask)
    }

    /// Absolute value. Balanced ternary is symmetric, so this never overflows.
    pub fn abs(self) -> Tern {
        if self.sign() == N {
            self.neg()
        } else {
            self
        }
    }

    /// Ripple-carry addition, one trit per step.
    ///
    /// A carry out of trit 63 (a sum that does not fit in 64 trits) trips a debug
    /// assertion; release builds silently drop it. Trits whose `to_i8()` is `None`
    /// count as 0.
    pub fn add(self, rhs: Tern) -> Tern {
        let mut result = 0u128;
        let mut carry = Z;
        for i in 0..PAIRS {
            let (digit, next_carry) = add3(self.trit_at(i), rhs.trit_at(i), carry);
            result |= (digit.0 as u128) << (i * 2);
            carry = next_carry;
        }
        debug_assert!(carry.is_zero(), "Tern overflow beyond 64 trits");
        Tern(result)
    }

    /// `self - rhs`, computed as `self + (-rhs)`; overflow behaves as in [`Tern::add`].
    #[inline]
    pub fn sub(self, rhs: Tern) -> Tern {
        self.add(rhs.neg())
    }

    /// Shift-and-add multiplication over the trits of `rhs`.
    ///
    /// Overflow is not reported reliably: partial products shifted past trit 63 lose
    /// their high trits, and the accumulation may trip the debug assertion in
    /// [`Tern::add`].
    pub fn mul(self, rhs: Tern) -> Tern {
        let mut result = Tern::ZERO;
        for i in 0..PAIRS {
            let t = rhs.trit_at(i);
            if t.is_zero() {
                continue;
            }
            let partial = self.scale(t).shift_up(i);
            result = result.add(partial);
        }
        result
    }

    /// Multiplies by a single trit: `10` gives `self`, `01` gives `-self`, anything else 0.
    fn scale(self, t: Trit) -> Tern {
        match t.0 {
            0b01 => self.neg(),
            0b10 => self,
            _ => Tern::ZERO,
        }
    }

    /// Multiplies by 3^`positions`. Trits pushed past position 63 are discarded.
    fn shift_up(self, positions: usize) -> Tern {
        if positions >= PAIRS {
            return Tern::ZERO;
        }
        Tern(self.0 << (positions * 2))
    }

    /// Position of the most significant pair that is not `00`, or `None` if all are.
    fn top_trit(self) -> Option<usize> {
        // OR each pair's high bit onto its low bit: every occupied pair leaves one
        // marker bit at an even position 2k.
        let occupied = (self.0 | (self.0 >> 1)) & LO_MASK;
        if occupied == 0 {
            None
        } else {
            Some((127 - occupied.leading_zeros() as usize) / 2)
        }
    }

    /// Euclidean division: returns `(q, r)` with `self == q * rhs + r` and `0 <= r < |rhs|`.
    ///
    /// The work is done entirely in balanced ternary. From the highest usable
    /// position down, each step subtracts `±rhs * 3^i` when that strictly shrinks
    /// `|remainder|`. A final normalization moves the remainder into `[0, |rhs|)`.
    ///
    /// # Panics
    /// Panics with "division by zero" if `rhs` is zero.
    pub fn div_rem(self, rhs: Tern) -> (Tern, Tern) {
        assert!(!rhs.is_zero(), "division by zero");
        if self.is_zero() {
            return (Tern::ZERO, Tern::ZERO);
        }
        let top = rhs
            .top_trit()
            .expect("a nonzero divisor has at least one occupied trit");
        // Largest shift that keeps every trit of `rhs` inside the word. A larger shift
        // would silently drop high trits and feed a wrong multiple into the reduction.
        let max_shift = PAIRS - 1 - top;
        let rhs_sign = rhs.sign();
        let mut remainder = self;
        let mut quotient = Tern::ZERO;
        for i in (0..=max_shift).rev() {
            let rem_sign = remainder.sign();
            if rem_sign == Z {
                break;
            }
            let shifted = rhs.shift_up(i);
            // Only moving toward zero can shrink |remainder|. That is: subtract when
            // the signs agree, add when they differ. Either way the result stays
            // within max(|remainder|, |shifted|), so it cannot overflow 64 trits.
            let (candidate, digit) = if rem_sign == rhs_sign {
                (remainder.sub(shifted), P)
            } else {
                (remainder.add(shifted), N)
            };
            if candidate.abs() < remainder.abs() {
                remainder = candidate;
                quotient = quotient.with_trit(i, digit);
            }
        }
        // r += |rhs| means q -= sign(rhs); r -= |rhs| means q += sign(rhs). For any
        // dividend that fits in about 3^63 the greedy pass leaves |r| <= |rhs| / 2, so
        // at most one correction runs. The loops also cover dividends near the
        // 64-trit limit.
        let divisor_abs = rhs.abs();
        let unit = if rhs_sign == P { Tern::ONE } else { Tern::NEG_ONE };
        while remainder.sign() == N {
            remainder = remainder.add(divisor_abs);
            quotient = quotient.sub(unit);
        }
        while remainder >= divisor_abs {
            remainder = remainder.sub(divisor_abs);
            quotient = quotient.add(unit);
        }
        (quotient, remainder)
    }

    /// Non-negative greatest common divisor via the Euclidean algorithm.
    ///
    /// `gcd(0, 0)` is 0.
    pub fn gcd(a: Tern, b: Tern) -> Tern {
        let mut a = a.abs();
        let mut b = b.abs();
        while !b.is_zero() {
            let (_, r) = a.div_rem(b);
            a = b;
            b = r;
        }
        a
    }

    /// Converts an `i64` exactly. Every `i64` needs at most 41 trits.
    pub fn from_i64(mut v: i64) -> Tern {
        let mut result = Tern::ZERO;
        let mut pos = 0usize;
        while v != 0 && pos < PAIRS {
            // div_euclid never overflows for divisor 3. The former `(v - 1) / 3`
            // overflowed for i64::MIN.
            let q = v.div_euclid(3);
            let (trit, next) = match v.rem_euclid(3) {
                0 => (Z, q),
                1 => (P, q),
                2 => (N, q + 1),
                _ => unreachable!(),
            };
            result = result.with_trit(pos, trit);
            v = next;
            pos += 1;
        }
        result
    }

    /// Converts to `i64`.
    ///
    /// # Panics
    /// Panics with "Tern::to_i64 overflow" if the value lies outside the `i64` range.
    pub fn to_i64(self) -> i64 {
        let wide = self.to_i128();
        let range = i128::from(i64::MIN)..=i128::from(i64::MAX);
        assert!(range.contains(&wide), "Tern::to_i64 overflow");
        wide as i64
    }

    /// Exact value as `i128`, using Horner's rule from the top trit.
    ///
    /// Every 64-trit word fits, since (3^64 - 1) / 2 < 2^101. This is a conversion
    /// helper only; no arithmetic routine goes through it. Trits whose `to_i8()` is
    /// `None` count as 0.
    fn to_i128(self) -> i128 {
        (0..PAIRS).rev().fold(0i128, |acc, i| {
            acc * 3 + self.trit_at(i).to_i8().map_or(0, |d| d as i128)
        })
    }

    /// Polynomial rolling hash of `bytes` (base 257, reduced modulo 2_000_000_011).
    ///
    /// The result lies in `[0, 2_000_000_011)`. It is deterministic, but not a
    /// collision-resistant hash.
    pub fn content_hash(bytes: &[u8]) -> Tern {
        let base = Tern::from_i64(257);
        let modulus = Tern::from_i64(2_000_000_011);
        let mut hash = Tern::ZERO;
        for &b in bytes {
            // Reduce after every byte so the running value stays small.
            hash = hash.mul(base).add(Tern::from_i64(i64::from(b)));
            let (_, rem) = hash.div_rem(modulus);
            hash = rem;
        }
        hash
    }
}
/// Adds three balanced trits and returns `(digit, carry)` with
/// `a + b + c == digit + 3 * carry`.
///
/// Trits whose `to_i8()` is `None` are counted as 0. The `(INV, Z)` fallback is
/// reached only if the total leaves `-3..=3`. NOTE(review): that should be impossible
/// if `to_i8` only yields -1/0/1 — confirm against `crate::trit`.
#[inline]
fn add3(a: Trit, b: Trit, c: Trit) -> (Trit, Trit) {
    let total: i16 = [a, b, c]
        .iter()
        .map(|trit| trit.to_i8().unwrap_or(0) as i16)
        .sum();
    match total {
        -3 => (Z, N),
        -2 => (P, N),
        -1 => (N, Z),
        0 => (Z, Z),
        1 => (P, Z),
        2 => (N, P),
        3 => (Z, P),
        _ => (INV, Z),
    }
}
impl PartialOrd for Tern {
    /// Always `Some`: the ordering is total and delegates to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for Tern {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
let diff = self.sub(*other);
match diff.sign() {
N => std::cmp::Ordering::Less,
Z => std::cmp::Ordering::Equal,
P => std::cmp::Ordering::Greater,
_ => std::cmp::Ordering::Equal,
}
}
}
impl std::fmt::Debug for Tern {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Tern({})", self.to_i64())
}
}
impl std::fmt::Display for Tern {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.to_i64())
}
}
impl std::ops::Add for Tern {
type Output = Tern;
fn add(self, rhs: Tern) -> Tern {
Tern::add(self, rhs)
}
}
impl std::ops::Sub for Tern {
type Output = Tern;
fn sub(self, rhs: Tern) -> Tern {
Tern::sub(self, rhs)
}
}
impl std::ops::Mul for Tern {
type Output = Tern;
fn mul(self, rhs: Tern) -> Tern {
Tern::mul(self, rhs)
}
}
impl std::ops::Neg for Tern {
type Output = Tern;
fn neg(self) -> Tern {
Tern::neg(self)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand for building a `Tern` from a native integer.
    fn t(v: i64) -> Tern {
        Tern::from_i64(v)
    }

    // Calls `check(a, b)` for every pair drawn from `range`, with `a` in the outer loop.
    fn for_each_pair(range: std::ops::RangeInclusive<i64>, mut check: impl FnMut(i64, i64)) {
        for a in range.clone() {
            for b in range.clone() {
                check(a, b);
            }
        }
    }

    #[test]
    fn zero_roundtrip() {
        let zero = t(0);
        assert_eq!(zero.to_i64(), 0);
        assert!(zero.is_zero());
    }

    #[test]
    fn roundtrip_range() {
        (-100i64..=100).for_each(|v| assert_eq!(t(v).to_i64(), v, "roundtrip failed for {v}"));
    }

    #[test]
    fn addition_exhaustive() {
        for_each_pair(-30..=30, |a, b| {
            assert_eq!(t(a).add(t(b)).to_i64(), a + b, "{a}+{b}");
        });
    }

    #[test]
    fn subtraction_exhaustive() {
        for_each_pair(-30..=30, |a, b| {
            assert_eq!(t(a).sub(t(b)).to_i64(), a - b, "{a}-{b}");
        });
    }

    #[test]
    fn multiplication_exhaustive() {
        for_each_pair(-20..=20, |a, b| {
            assert_eq!(t(a).mul(t(b)).to_i64(), a * b, "{a}*{b}");
        });
    }

    #[test]
    fn division_exhaustive() {
        for_each_pair(-30..=30, |a, b| {
            if b == 0 {
                return;
            }
            let (q, r) = t(a).div_rem(t(b));
            let (q_val, r_val) = (q.to_i64(), r.to_i64());
            assert_eq!(
                q_val * b + r_val,
                a,
                "div_rem({a},{b}): q={}, r={}",
                q_val,
                r_val
            );
            // Euclidean contract: 0 <= r < |b|.
            assert!(
                (0..b.abs()).contains(&r_val),
                "non-Euclidean remainder for div_rem({a},{b}): q={}, r={}",
                q_val,
                r_val
            );
        });
    }

    #[test]
    fn negation_involution() {
        for v in -50i64..=50 {
            let twice = t(v).neg().neg();
            assert_eq!(twice.to_i64(), v);
        }
    }

    #[test]
    fn sign_correct() {
        let cases = [(-5i64, N), (0, Z), (7, P)];
        for (v, expected) in &cases {
            assert_eq!(t(*v).sign(), *expected);
        }
    }

    #[test]
    fn abs_nonnegative() {
        for v in -20i64..=20 {
            let magnitude = t(v).abs();
            assert!(magnitude.sign() != N, "abs of {v} is negative");
            assert_eq!(magnitude.to_i64(), v.abs());
        }
    }

    #[test]
    fn gcd_basic() {
        for &(a, b, expected) in &[(12i64, 8i64, 4i64), (17, 13, 1), (0, 5, 5)] {
            assert_eq!(Tern::gcd(t(a), t(b)).to_i64(), expected);
        }
    }

    #[test]
    fn ordering() {
        let (small, large) = (t(-5), t(3));
        assert!(small < large);
        assert!(large > small);
        assert_eq!(t(4).cmp(&t(4)), std::cmp::Ordering::Equal);
    }

    #[test]
    fn valid_packed_word() {
        (-100i64..=100).for_each(|v| assert!(t(v).is_valid(), "Tern({v}) is invalid"));
    }

    #[test]
    fn zero_is_memset_safe() {
        let zero = Tern::ZERO;
        assert_eq!(zero.0, 0u128);
        assert!(zero.is_zero());
    }

    #[test]
    fn content_hash_deterministic() {
        assert_eq!(Tern::content_hash(b"neberu"), Tern::content_hash(b"neberu"));
    }

    #[test]
    fn content_hash_distinct() {
        assert_ne!(Tern::content_hash(b"hello"), Tern::content_hash(b"world"));
    }

    #[test]
    fn large_value_roundtrip() {
        const BIG: i64 = 1_000_000_000;
        assert_eq!(t(BIG).to_i64(), BIG);
        assert_eq!(t(-BIG).to_i64(), -BIG);
    }

    #[test]
    fn trit_at_and_with_trit() {
        let x = Tern::ZERO
            .with_trit(0, P)
            .with_trit(1, N)
            .with_trit(2, Z);
        assert_eq!(x.trit_at(0), P);
        assert_eq!(x.trit_at(1), N);
        assert_eq!(x.trit_at(2), Z);
        // 1 * 3^0 + (-1) * 3^1 + 0 * 3^2 == -2
        assert_eq!(x.to_i64(), -2);
    }

    #[test]
    fn no_i128_bridge_in_div_rem() {
        let (dividend, divisor) = (999_983i64, 997i64);
        let (q, r) = t(dividend).div_rem(t(divisor));
        assert_eq!(q.to_i64() * divisor + r.to_i64(), dividend);
    }
}