use alloy_core::primitives::ruint::algorithms;
use sp_core::U256;
/// Modular arithmetic on fixed-width unsigned integers.
///
/// Every operation returns a value in `[0, modulus)` computed from the exact
/// (non-truncated) mathematical result. A zero modulus is not treated as an error:
/// implementations are expected to return zero in that case, mirroring ruint's
/// `Uint::{reduce_mod, add_mod, mul_mod}`.
pub trait Modular {
    /// Returns `self mod modulus`, or zero when `modulus` is zero.
    #[must_use]
    fn reduce_mod(self, modulus: Self) -> Self;

    /// Returns `(self + rhs) mod modulus`, where the sum is taken without wrapping
    /// at the type's width, or zero when `modulus` is zero.
    #[must_use]
    fn add_mod(self, rhs: Self, modulus: Self) -> Self;

    /// Returns `(self * rhs) mod modulus`, where the product is taken at double width
    /// (no truncation before reducing), or zero when `modulus` is zero.
    #[must_use]
    fn mul_mod(self, rhs: Self, modulus: Self) -> Self;
}
impl Modular for U256 {
    fn reduce_mod(mut self, modulus: Self) -> Self {
        // A zero modulus yields zero instead of panicking on division by zero.
        if modulus.is_zero() {
            return Self::zero();
        }
        // Only pay for the division when the value is not already reduced.
        if self >= modulus {
            self %= modulus;
        }
        self
    }

    fn add_mod(self, rhs: Self, modulus: Self) -> Self {
        if modulus.is_zero() {
            return Self::zero();
        }
        let lhs = self.reduce_mod(modulus);
        let rhs = rhs.reduce_mod(modulus);
        // Both operands are now below `modulus`, so the exact sum is below
        // `2 * modulus` and one conditional subtraction is enough. If the addition
        // carried out of 256 bits, the exact sum minus `modulus` still fits in 256
        // bits, so the wrapping subtraction recovers the correct residue.
        let (mut result, overflow) = lhs.overflowing_add(rhs);
        if overflow || result >= modulus {
            result = result.overflowing_sub(modulus).0;
        }
        result
    }

    fn mul_mod(self, rhs: Self, mut modulus: Self) -> Self {
        if modulus.is_zero() {
            return Self::zero();
        }
        // The full product of two 4-limb (256-bit) values fits in 8 limbs (512 bits).
        // The limb count is fixed here, so a plain array suffices; no pointer cast or
        // `unsafe` is needed.
        let mut product = [0u64; 8];
        // Accumulate `self * rhs` into the zeroed buffer. The returned flag reports
        // overflow of the buffer, which cannot happen with 8 limbs.
        let overflow = algorithms::addmul(&mut product, &self.0, &rhs.0);
        debug_assert!(!overflow, "addmul overflowed for 256-bit inputs");
        // `div` writes the remainder into the divisor limbs (`modulus.0`), which is
        // the result; `product` is clobbered.
        algorithms::div(&mut product, &mut modulus.0);
        modulus
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_core::primitives;
    use proptest::{prop_assert_eq, proptest};

    /// Builds the alloy/ruint reference value from the same limb array that backs `U256`.
    fn alloy_u256(limbs: [u64; 4]) -> primitives::U256 {
        primitives::U256::from_limbs(limbs)
    }

    /// Boundary inputs that uniformly random limbs essentially never produce:
    /// zero (exercises the zero-modulus branches), tiny moduli, a single full limb,
    /// only the top bit set, and the maximum value (exercises carries).
    const EDGE_CASES: [[u64; 4]; 6] = [
        [0, 0, 0, 0],
        [1, 0, 0, 0],
        [2, 0, 0, 0],
        [u64::MAX, 0, 0, 0],
        [0, 0, 0, 1 << 63],
        [u64::MAX; 4],
    ];

    #[test]
    fn test_reduce_mod() {
        proptest!(|(a: [u64; 4], m: [u64; 4])| {
            let ours = U256(a).reduce_mod(U256(m));
            let theirs = alloy_u256(a).reduce_mod(alloy_u256(m));
            prop_assert_eq!(&ours.0, theirs.as_limbs());
        });
    }

    #[test]
    fn test_add_mod() {
        proptest!(|(a: [u64; 4], b: [u64; 4], m: [u64; 4])| {
            let ours = U256(a).add_mod(U256(b), U256(m));
            let theirs = alloy_u256(a).add_mod(alloy_u256(b), alloy_u256(m));
            prop_assert_eq!(&ours.0, theirs.as_limbs());
        });
    }

    #[test]
    fn test_mul_mod() {
        proptest!(|(a: [u64; 4], b: [u64; 4], m: [u64; 4])| {
            let ours = U256(a).mul_mod(U256(b), U256(m));
            let theirs = alloy_u256(a).mul_mod(alloy_u256(b), alloy_u256(m));
            prop_assert_eq!(&ours.0, theirs.as_limbs());
        });
    }

    /// Deterministic cross-check of every operation over all combinations of
    /// `EDGE_CASES`, covering paths the random tests cannot reliably reach.
    #[test]
    fn test_edge_cases() {
        for &a in EDGE_CASES.iter() {
            for &b in EDGE_CASES.iter() {
                for &m in EDGE_CASES.iter() {
                    let ctx = (a, b, m);
                    assert_eq!(
                        &U256(a).reduce_mod(U256(m)).0,
                        alloy_u256(a).reduce_mod(alloy_u256(m)).as_limbs(),
                        "reduce_mod {:?}",
                        ctx
                    );
                    assert_eq!(
                        &U256(a).add_mod(U256(b), U256(m)).0,
                        alloy_u256(a)
                            .add_mod(alloy_u256(b), alloy_u256(m))
                            .as_limbs(),
                        "add_mod {:?}",
                        ctx
                    );
                    assert_eq!(
                        &U256(a).mul_mod(U256(b), U256(m)).0,
                        alloy_u256(a)
                            .mul_mod(alloy_u256(b), alloy_u256(m))
                            .as_limbs(),
                        "mul_mod {:?}",
                        ctx
                    );
                }
            }
        }
    }
}