use crate::ct::{
Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess,
};
/// A single machine word ("limb") of a multi-limb integer.
pub type Limb = u64;

/// Number of bits in a [`Limb`].
///
/// Derived from the limb type itself so the two can never disagree if `Limb` is changed.
pub const LIMB_BITS: usize = Limb::BITS as usize;
/// Fixed-width unsigned integer stored as `LIMBS` [`Limb`]s.
///
/// Limbs are kept least-significant first: `limbs[0]` holds the lowest bits.
/// The derived `PartialEq`/`Eq` make no constant-time promise; use `ConstantTimeEq::ct_eq`
/// when the comparison must not depend on the values in a timing-observable way.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct Uint<const LIMBS: usize> {
    // Least-significant limb first.
    limbs: [Limb; LIMBS],
}
#[inline]
pub(crate) const fn adc(a: Limb, b: Limb, carry: Limb) -> (Limb, Limb) {
let ret = (a as u128) + (b as u128) + (carry as u128);
(ret as Limb, (ret >> LIMB_BITS) as Limb)
}
/// Computes `a - b - borrow` modulo `2^LIMB_BITS` and returns `(difference, borrow_out)`.
///
/// The subtraction runs in wrapping 128-bit arithmetic and the borrow-out is bit `LIMB_BITS`
/// of the wrapped result. This is exact for `borrow` in `{0, 1}`. The limb loop in
/// `Uint::sbb` feeds each step the previous step's borrow-out.
#[inline]
pub(crate) const fn sbb(a: Limb, b: Limb, borrow: Limb) -> (Limb, Limb) {
    let subtrahend = b as u128 + borrow as u128;
    let wide = (a as u128).wrapping_sub(subtrahend);
    let borrow_out = (wide >> LIMB_BITS) as Limb & 1;
    (wide as Limb, borrow_out)
}
impl<const LIMBS: usize> Uint<LIMBS> {
    /// Number of limbs in this integer type (mirrors the `LIMBS` const parameter).
    pub const LIMBS: usize = LIMBS;
    /// The value zero (every limb cleared).
    pub const ZERO: Self = Uint { limbs: [0; LIMBS] };
    /// The value one.
    pub const ONE: Self = Self::from_u64(1);

    /// Creates a `Uint` whose least-significant limb is `v` and whose other limbs are zero.
    ///
    /// # Panics
    ///
    /// Panics (or fails const evaluation) if `LIMBS == 0`, since there is no limb to hold `v`.
    pub const fn from_u64(v: u64) -> Self {
        let mut limbs = [0; LIMBS];
        limbs[0] = v;
        Uint { limbs }
    }

    /// Borrows the limbs, least-significant limb first.
    #[inline]
    pub const fn as_limbs(&self) -> &[Limb; LIMBS] {
        &self.limbs
    }

    /// Builds a `Uint` from limbs given least-significant limb first.
    #[inline]
    pub const fn from_limbs(limbs: [Limb; LIMBS]) -> Self {
        Uint { limbs }
    }

    /// Parses a big-endian byte string.
    ///
    /// `bytes` may be shorter than `LIMBS * 8`. Missing high-order bytes are treated as zero,
    /// so an empty slice yields zero.
    ///
    /// # Panics
    ///
    /// Panics if `bytes.len() > LIMBS * 8`.
    pub fn from_be_bytes(bytes: &[u8]) -> Self {
        assert!(bytes.len() <= LIMBS * 8, "input too large for Uint");
        let mut limbs = [0; LIMBS];
        // `rchunks` starts at the least-significant end of the string. Only the last chunk can
        // be short; it holds the most-significant bytes, so it is right-aligned in the buffer.
        for (limb, chunk) in limbs.iter_mut().zip(bytes.rchunks(8)) {
            let mut buf = [0u8; 8];
            buf[8 - chunk.len()..].copy_from_slice(chunk);
            *limb = Limb::from_be_bytes(buf);
        }
        Uint { limbs }
    }

    /// Writes the value as exactly `LIMBS * 8` big-endian bytes into `out`.
    ///
    /// # Panics
    ///
    /// Panics if `out.len() != LIMBS * 8`.
    pub fn write_be_bytes(&self, out: &mut [u8]) {
        assert_eq!(out.len(), LIMBS * 8, "output buffer has wrong length");
        // Big-endian output starts with the most-significant limb.
        for (chunk, limb) in out.chunks_exact_mut(8).zip(self.limbs.iter().rev()) {
            chunk.copy_from_slice(&limb.to_be_bytes());
        }
    }

    /// Parses a little-endian byte string.
    ///
    /// `bytes` may be shorter than `LIMBS * 8`. Missing high-order bytes are treated as zero,
    /// so an empty slice yields zero.
    ///
    /// # Panics
    ///
    /// Panics if `bytes.len() > LIMBS * 8`.
    pub fn from_le_bytes(bytes: &[u8]) -> Self {
        assert!(bytes.len() <= LIMBS * 8, "input too large for Uint");
        let mut limbs = [0; LIMBS];
        // Only the last chunk can be short; it holds the most-significant bytes, which sit at
        // the start of a little-endian buffer.
        for (limb, chunk) in limbs.iter_mut().zip(bytes.chunks(8)) {
            let mut buf = [0u8; 8];
            buf[..chunk.len()].copy_from_slice(chunk);
            *limb = Limb::from_le_bytes(buf);
        }
        Uint { limbs }
    }

    /// Writes the value as exactly `LIMBS * 8` little-endian bytes into `out`.
    ///
    /// # Panics
    ///
    /// Panics if `out.len() != LIMBS * 8`.
    pub fn write_le_bytes(&self, out: &mut [u8]) {
        assert_eq!(out.len(), LIMBS * 8, "output buffer has wrong length");
        for (chunk, limb) in out.chunks_exact_mut(8).zip(&self.limbs) {
            chunk.copy_from_slice(&limb.to_le_bytes());
        }
    }

    /// Computes `self + rhs + carry` and returns the wrapped sum with the final carry-out.
    ///
    /// `carry` is normally 0 or 1. It is fed into the lowest limb and then propagated
    /// limb by limb.
    pub fn adc(&self, rhs: &Self, carry: Limb) -> (Self, Limb) {
        let mut limbs = [0; LIMBS];
        let mut carry = carry;
        for ((out, &a), &b) in limbs.iter_mut().zip(&self.limbs).zip(&rhs.limbs) {
            // Module-level single-limb `adc`, not this method.
            let (sum, carry_out) = adc(a, b, carry);
            *out = sum;
            carry = carry_out;
        }
        (Uint { limbs }, carry)
    }

    /// Computes `self - rhs - borrow` and returns the wrapped difference with the final
    /// borrow-out (1 when the true result is negative, for `borrow` in `{0, 1}`).
    pub fn sbb(&self, rhs: &Self, borrow: Limb) -> (Self, Limb) {
        let mut limbs = [0; LIMBS];
        let mut borrow = borrow;
        for ((out, &a), &b) in limbs.iter_mut().zip(&self.limbs).zip(&rhs.limbs) {
            // Module-level single-limb `sbb`, not this method.
            let (diff, borrow_out) = sbb(a, b, borrow);
            *out = diff;
            borrow = borrow_out;
        }
        (Uint { limbs }, borrow)
    }

    /// Addition modulo `2^(LIMBS * LIMB_BITS)`.
    #[inline]
    pub fn wrapping_add(&self, rhs: &Self) -> Self {
        self.adc(rhs, 0).0
    }

    /// Subtraction modulo `2^(LIMBS * LIMB_BITS)`.
    #[inline]
    pub fn wrapping_sub(&self, rhs: &Self) -> Self {
        self.sbb(rhs, 0).0
    }

    /// Returns a set `Choice` when the value is zero (via `ct_eq` against [`Self::ZERO`]).
    #[inline]
    pub fn is_zero(&self) -> Choice {
        self.ct_eq(&Self::ZERO)
    }

    /// Returns a set `Choice` when the lowest bit is 1.
    #[inline]
    pub fn is_odd(&self) -> Choice {
        Choice::from((self.limbs[0] & 1) as u8)
    }

    /// Returns the number of significant bits; zero has bit length 0.
    ///
    /// Variable-time: the search stops at the most-significant non-zero limb, so the running
    /// time depends on the value.
    pub fn bit_len(&self) -> usize {
        match self.limbs.iter().rposition(|&limb| limb != 0) {
            Some(top) => (top + 1) * LIMB_BITS - self.limbs[top].leading_zeros() as usize,
            None => 0,
        }
    }

    /// Logical right shift by one bit; the lowest bit is discarded and the top bit becomes 0.
    pub fn shr1(&self) -> Self {
        let mut limbs = self.limbs;
        let mut carry: Limb = 0;
        // Walk from the most-significant limb down. Each limb's low bit becomes the top bit
        // of the limb below it.
        for limb in limbs.iter_mut().rev() {
            let low_bit = *limb & 1;
            *limb = (*limb >> 1) | (carry << (LIMB_BITS - 1));
            carry = low_bit;
        }
        Uint { limbs }
    }

    /// One step of binary restoring long division, shared by [`Self::reduce`] and
    /// [`Self::divrem`].
    ///
    /// Computes `2 * r + bit`. If that value is `>= divisor`, subtracts `divisor`, choosing the
    /// result with `conditional_select` rather than a branch. Returns the new remainder and a
    /// `Choice` that is set when the subtraction was taken (the next quotient bit). With
    /// `r < divisor` and `bit` in `{0, 1}`, the returned remainder is again `< divisor`.
    #[inline]
    fn div_step(r: &Self, bit: Limb, divisor: &Self) -> (Self, Choice) {
        let (mut shifted, carry) = r.adc(r, 0);
        shifted.limbs[0] |= bit;
        let (diff, borrow) = shifted.sbb(divisor, 0);
        // `2r + bit >= divisor` iff the doubling overflowed the fixed width (`carry`) or the
        // subtraction did not borrow.
        let ge = Choice::from((carry | (borrow ^ 1)) as u8);
        (Self::conditional_select(&diff, &shifted, ge), ge)
    }

    /// Returns `self mod modulus`.
    ///
    /// Processes every bit of `self`, most-significant first, so the iteration count does not
    /// depend on the values. Each step selects its result with `conditional_select`.
    ///
    /// # Panics
    ///
    /// Panics if `modulus` is zero.
    pub fn reduce(&self, modulus: &Self) -> Self {
        assert!(
            !bool::from(modulus.is_zero()),
            "Uint::reduce: modulus must be nonzero"
        );
        let mut r = Self::ZERO;
        for &limb in self.limbs.iter().rev() {
            for bit in (0..LIMB_BITS).rev() {
                r = Self::div_step(&r, (limb >> bit) & 1, modulus).0;
            }
        }
        r
    }

    /// Returns `(self / divisor, self % divisor)`.
    ///
    /// Uses the same fixed-iteration bit-serial algorithm as [`Self::reduce`], and also
    /// accumulates the quotient bits.
    ///
    /// # Panics
    ///
    /// Panics if `divisor` is zero.
    pub fn divrem(&self, divisor: &Self) -> (Self, Self) {
        assert!(
            !bool::from(divisor.is_zero()),
            "Uint::divrem: divisor must be nonzero"
        );
        let mut q = Self::ZERO;
        let mut r = Self::ZERO;
        for &limb in self.limbs.iter().rev() {
            for bit in (0..LIMB_BITS).rev() {
                let (next_r, q_bit) = Self::div_step(&r, (limb >> bit) & 1, divisor);
                r = next_r;
                // Shift the quotient left by one and append this step's bit.
                let (mut q_shifted, _) = q.adc(&q, 0);
                q_shifted.limbs[0] |= Limb::from(q_bit.unwrap_u8());
                q = q_shifted;
            }
        }
        (q, r)
    }
}
impl<const LIMBS: usize> Default for Uint<LIMBS> {
    /// Returns zero: every limb cleared, the same value as `Uint::ZERO`.
    #[inline]
    fn default() -> Self {
        Self::from_limbs([0; LIMBS])
    }
}
impl<const LIMBS: usize> ConstantTimeEq for Uint<LIMBS> {
    /// Compares the two values by delegating to the limb-array `ct_eq` from `crate::ct`.
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (&self.limbs, &other.limbs);
        lhs.ct_eq(rhs)
    }
}
impl<const LIMBS: usize> ConstantTimeGreater for Uint<LIMBS> {
    /// Returns a set `Choice` when `self > other`.
    ///
    /// Computes `other - self` and reports its borrow-out. The subtraction borrows exactly
    /// when `other < self`.
    #[inline]
    fn ct_gt(&self, other: &Self) -> Choice {
        Choice::from(other.sbb(self, 0).1 as u8)
    }
}
/// Marker impl: all methods come from the trait's provided defaults in `crate::ct`.
/// NOTE(review): presumably `ct_lt` is built from `ct_gt`/`ct_eq` there — confirm against the
/// trait definition.
impl<const LIMBS: usize> ConstantTimeLess for Uint<LIMBS> {}
impl<const LIMBS: usize> ConditionallySelectable for Uint<LIMBS> {
    /// Selects between `a` and `b` limb by limb.
    ///
    /// Applies `Limb::conditional_select` with the same `choice` to every limb.
    #[inline]
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let mut limbs = [0; LIMBS];
        for ((out, lhs), rhs) in limbs.iter_mut().zip(&a.limbs).zip(&b.limbs) {
            *out = Limb::conditional_select(lhs, rhs, choice);
        }
        Uint { limbs }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    type U128 = Uint<2>;

    /// Builds a two-limb `Uint` from a native `u128` (low limb first).
    fn from_u128(v: u128) -> U128 {
        Uint::from_limbs([v as u64, (v >> 64) as u64])
    }

    /// Converts a two-limb `Uint` back to a native `u128`.
    fn to_u128(u: &U128) -> u128 {
        (u.as_limbs()[0] as u128) | ((u.as_limbs()[1] as u128) << 64)
    }

    const CASES: &[u128] = &[
        0,
        1,
        2,
        u64::MAX as u128,
        (u64::MAX as u128) + 1,
        u128::MAX,
        u128::MAX - 1,
        0x0123_4567_89ab_cdef_fedc_ba98_7654_3210,
        1 << 64,
        1 << 127,
    ];

    #[test]
    fn add_sub_match_u128() {
        for &a in CASES {
            for &b in CASES {
                let (sum, carry) = from_u128(a).adc(&from_u128(b), 0);
                assert_eq!(to_u128(&sum), a.wrapping_add(b));
                assert_eq!(carry == 1, a.checked_add(b).is_none());
                let (diff, borrow) = from_u128(a).sbb(&from_u128(b), 0);
                assert_eq!(to_u128(&diff), a.wrapping_sub(b));
                assert_eq!(borrow == 1, a < b);
            }
        }
    }

    #[test]
    fn ct_compare_matches_u128() {
        for &a in CASES {
            for &b in CASES {
                let (x, y) = (from_u128(a), from_u128(b));
                assert_eq!(bool::from(x.ct_eq(&y)), a == b);
                assert_eq!(bool::from(x.ct_gt(&y)), a > b);
                assert_eq!(bool::from(x.ct_lt(&y)), a < b);
            }
        }
        assert!(bool::from(U128::ZERO.is_zero()));
        assert!(!bool::from(U128::ONE.is_zero()));
    }

    #[test]
    fn conditional_select_picks_correctly() {
        let a = from_u128(0xaaaa_aaaa_aaaa_aaaa);
        let b = from_u128(0x5555_5555_5555_5555);
        assert_eq!(U128::conditional_select(&a, &b, Choice::from(1)), a);
        assert_eq!(U128::conditional_select(&a, &b, Choice::from(0)), b);
    }

    #[test]
    fn be_bytes_roundtrip() {
        let v = 0x0123_4567_89ab_cdef_fedc_ba98_7654_3210u128;
        let u = from_u128(v);
        let mut buf = [0u8; 16];
        u.write_be_bytes(&mut buf);
        assert_eq!(buf, v.to_be_bytes());
        assert_eq!(U128::from_be_bytes(&buf), u);
        assert_eq!(U128::from_be_bytes(&[0x01, 0x00]), from_u128(0x100));
        assert_eq!(U128::from_be_bytes(&[]), U128::ZERO);
    }

    #[test]
    fn le_bytes_roundtrip() {
        let v = 0x0123_4567_89ab_cdef_fedc_ba98_7654_3210u128;
        let u = from_u128(v);
        let mut buf = [0u8; 16];
        u.write_le_bytes(&mut buf);
        assert_eq!(buf, v.to_le_bytes());
        assert_eq!(U128::from_le_bytes(&buf), u);
        assert_eq!(U128::from_le_bytes(&[0x00, 0x01]), from_u128(0x100));
        assert_eq!(U128::from_le_bytes(&[]), U128::ZERO);
    }

    #[test]
    fn short_inputs_spanning_limbs_are_zero_extended() {
        // Nine bytes: one full limb plus one byte spilling into the next limb.
        let bytes = [0xab, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];

        let mut padded_be = [0u8; 16];
        padded_be[16 - bytes.len()..].copy_from_slice(&bytes);
        assert_eq!(
            U128::from_be_bytes(&bytes),
            from_u128(u128::from_be_bytes(padded_be))
        );

        let mut padded_le = [0u8; 16];
        padded_le[..bytes.len()].copy_from_slice(&bytes);
        assert_eq!(
            U128::from_le_bytes(&bytes),
            from_u128(u128::from_le_bytes(padded_le))
        );
    }

    #[test]
    #[should_panic(expected = "input too large for Uint")]
    fn from_be_bytes_rejects_oversized_input() {
        let _ = U128::from_be_bytes(&[0u8; 17]);
    }

    #[test]
    fn shr1_bit_len_and_parity_match_u128() {
        for &a in CASES {
            let x = from_u128(a);
            assert_eq!(to_u128(&x.shr1()), a >> 1);
            assert_eq!(x.bit_len(), (128 - a.leading_zeros()) as usize);
            assert_eq!(bool::from(x.is_odd()), a & 1 == 1);
        }
    }

    #[test]
    fn divrem_and_reduce_match_u128() {
        for &a in CASES {
            for &b in CASES {
                if b == 0 {
                    continue;
                }
                let (x, y) = (from_u128(a), from_u128(b));
                let (q, r) = x.divrem(&y);
                assert_eq!(to_u128(&q), a / b, "quotient of {:#x} / {:#x}", a, b);
                assert_eq!(to_u128(&r), a % b, "remainder of {:#x} / {:#x}", a, b);
                assert_eq!(to_u128(&x.reduce(&y)), a % b, "{:#x} mod {:#x}", a, b);
            }
        }
    }

    #[test]
    fn larger_widths_compile_and_work() {
        let mut a = Uint::<64>::ONE;
        a = a.wrapping_add(&Uint::<64>::ONE);
        assert_eq!(a.as_limbs()[0], 2);
        assert!(bool::from(Uint::<64>::default().is_zero()));
    }

    #[test]
    #[should_panic(expected = "modulus must be nonzero")]
    fn reduce_zero_modulus_panics() {
        let _ = Uint::<2>::from_u64(7).reduce(&Uint::<2>::ZERO);
    }

    #[test]
    #[should_panic(expected = "divisor must be nonzero")]
    fn divrem_zero_divisor_panics() {
        let _ = Uint::<2>::from_u64(7).divrem(&Uint::<2>::ZERO);
    }
}