/// Modular multiplication: returns `(a * b) mod q`.
///
/// The product of two `u32` values is formed in `u64`, where it always fits,
/// and is then reduced by `q`.
///
/// # Preconditions
/// `a < q`, `b < q` and `q < 2^28`. These are checked with `debug_assert!`
/// only, so release builds do not verify them.
#[inline(always)]
pub fn mod_mul_28(a: u32, b: u32, q: u32) -> u32 {
    debug_assert!(a < q, "mod_mul_28: a={a} >= q={q}");
    debug_assert!(b < q, "mod_mul_28: b={b} >= q={q}");
    debug_assert!(q < (1u32 << 28), "mod_mul_28: q={q} >= 2^28");
    let wide = u64::from(a) * u64::from(b);
    let reduced = wide % u64::from(q);
    // `reduced < q <= u32::MAX`, so this narrowing cast cannot truncate.
    reduced as u32
}
/// Modular addition: returns `(a + b) mod q`.
///
/// The conditional subtraction of `q` uses an all-ones/all-zeros mask instead
/// of an `if`, keeping the source branch-free.
///
/// The `_28` suffix matches the 28-bit moduli used elsewhere in this file, but
/// nothing here bounds `q`. The carry out of the 32-bit addition is therefore
/// folded into the reduction decision, which keeps the result correct for any
/// `q` up to `u32::MAX`.
///
/// # Preconditions
/// `a < q` and `b < q`. These are checked with `debug_assert!` only.
#[inline(always)]
pub fn mod_add_28(a: u32, b: u32, q: u32) -> u32 {
    debug_assert!(a < q, "mod_add_28: a={a} >= q={q}");
    debug_assert!(b < q, "mod_add_28: b={b} >= q={q}");
    // For q > 2^31 the true sum can exceed u32::MAX. A plain `a + b` would panic
    // in debug builds and silently wrap to a wrong answer in release builds, so
    // capture the carry explicitly instead.
    let (s, carry) = a.overflowing_add(b);
    // Reduce when the true sum (s + carry * 2^32) is >= q. If carry is set, the
    // true sum is certainly >= q, and `s.wrapping_sub(q)` equals
    // s + 2^32 - q, which is the correctly reduced value.
    let mask = ((carry | (s >= q)) as u32).wrapping_neg();
    s.wrapping_sub(q & mask)
}
/// Modular subtraction: returns `(a - b) mod q`.
///
/// The raw difference is computed with wrapping arithmetic. When it borrowed
/// (`a < b`), `q` is added back via a mask, so the source has no `if`.
///
/// # Preconditions
/// `a < q` and `b < q`. These are checked with `debug_assert!` only.
#[inline(always)]
pub fn mod_sub_28(a: u32, b: u32, q: u32) -> u32 {
    debug_assert!(a < q, "mod_sub_28: a={a} >= q={q}");
    debug_assert!(b < q, "mod_sub_28: b={b} >= q={q}");
    // All ones when the subtraction borrows, all zeros otherwise.
    let borrow_mask = 0u32.wrapping_sub(u32::from(a < b));
    let diff = a.wrapping_sub(b);
    diff.wrapping_add(borrow_mask & q)
}
/// Modular exponentiation: returns `base^exp mod q` using square-and-multiply.
///
/// `base` may be any `u32`; it is reduced modulo `q` first. With the usual
/// convention `x^0 = 1`, `exp == 0` yields `1 mod q`, which is `0` when
/// `q == 1`.
///
/// Intermediate values are kept in `u64`. Both factors of every product are
/// `< q <= u32::MAX`, so no product can overflow.
///
/// The loop branches on the bits of `exp`, so this is not constant-time with
/// respect to the exponent.
///
/// # Panics
/// Panics if `q == 0`.
pub fn mod_pow_32(base: u32, exp: u32, q: u32) -> u32 {
    assert!(q != 0, "mod_pow_32: modulus q must be nonzero");
    let q64 = u64::from(q);
    // Reduce the accumulator's starting value too: otherwise `exp == 0` with
    // `q == 1` would return 1, which is not a valid residue mod 1.
    let mut result = 1 % q64;
    let mut b = u64::from(base) % q64;
    let mut e = exp;
    while e > 0 {
        if e & 1 == 1 {
            result = result * b % q64;
        }
        e >>= 1;
        // Skip the final, unused squaring.
        if e > 0 {
            b = b * b % q64;
        }
    }
    // `result < q <= u32::MAX`, so the narrowing cast cannot truncate.
    result as u32
}
/// Modular inverse via Fermat's little theorem: returns `a^(q-2) mod q`.
///
/// The result satisfies `a * inv ≡ 1 (mod q)` only when `q` is prime. For a
/// composite `q`, the returned value is generally NOT an inverse and no error
/// is reported, so callers must guarantee primality.
///
/// `a` may be any `u32`; it is interpreted modulo `q`.
///
/// # Panics
/// - If `q < 2`: such moduli admit no useful inverse, and `q - 2` would
///   underflow.
/// - If `a ≡ 0 (mod q)`: zero has no inverse.
pub fn mod_inv_32(a: u32, q: u32) -> u32 {
    assert!(q >= 2, "mod_inv_32: modulus q={q} must be >= 2");
    // Test the residue rather than the raw value. `a == q` (or any other
    // multiple of q) is also zero mod q and would otherwise silently yield 0.
    assert!(a % q != 0, "No inverse for zero");
    mod_pow_32(a, q - 2, q)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mod_mul_28_basic() {
        // 2^28 - 57: just under the 2^28 bound that mod_mul_28 requires.
        let q = 268_435_399u32;
        for (a, b, want) in [(0, 123, 0), (1, 42, 42), (42, 1, 42)] {
            assert_eq!(mod_mul_28(a, b, q), want);
        }
        // Largest operands: compare against a direct u64 computation.
        let top = q - 1;
        let reference = (u64::from(top) * u64::from(top) % u64::from(q)) as u32;
        assert_eq!(mod_mul_28(top, top, q), reference);
    }

    #[test]
    fn test_mod_add_sub_28() {
        let q = 100_000_007u32;
        // (a, b, expected) triples, checked in order.
        let add_cases = [(3, 5, 8), (q - 1, 1, 0), (q - 1, q - 1, q - 2)];
        for &(a, b, want) in &add_cases {
            assert_eq!(mod_add_28(a, b, q), want);
        }
        let sub_cases = [(5, 3, 2), (3, 5, q - 2), (0, 0, 0), (0, 1, q - 1)];
        for &(a, b, want) in &sub_cases {
            assert_eq!(mod_sub_28(a, b, q), want);
        }
    }

    #[test]
    fn test_mod_pow_32_basic() {
        let q = 1009u32; // prime
        let cases = [
            (2, 10, 15), // 1024 - 1009
            (42, 0, 1),
            (42, 1, 42),
            (0, 100, 0),
            (7, q - 1, 1), // Fermat: x^(q-1) = 1 for prime q and x not divisible by q
        ];
        for (base, exp, want) in cases {
            assert_eq!(mod_pow_32(base, exp, q), want);
        }
    }

    #[test]
    fn test_mod_inv_32_basic() {
        let q = 17u32; // prime, so every nonzero residue has an inverse
        (1..q).for_each(|a| {
            let inv = mod_inv_32(a, q);
            assert_eq!(
                mod_mul_28(a, inv, q),
                1,
                "Inverse failed for a={a}, q={q}, inv={inv}"
            );
        });
    }
}