Skip to main content

alkahest_cas/number_theory/
mod.rs

//! V3-1 — Integer number theory via FLINT `fmpz`.
//!
//! Wraps proven primality, integer factorisation, totients, Jacobi symbols,
//! modular square roots (`fmpz_sqrtmod`), `k`th roots modulo a prime `p` when `gcd(k, p−1) = 1`,
//! brute-force discrete logs, and quadratic Dirichlet characters (odd square-free conductor).
6
7use crate::errors::AlkahestError;
8use crate::flint::ffi::{self as ffi, FmpzFactorStruct};
9use crate::flint::FlintInteger;
10use rug::Complete;
11use rug::Integer;
12use std::cmp::Ordering;
13use std::fmt;
14use std::str::FromStr;
15
16// ---------------------------------------------------------------------------
17// NumberTheoryError
18// ---------------------------------------------------------------------------
19
/// Error raised by an integer number-theory primitive (reported with `E-NT-*` codes).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum NumberTheoryError {
    /// The argument could not be accepted at all (for example an unparsable
    /// decimal string or a zero root degree); `msg` describes the problem.
    InvalidInput { msg: &'static str },
    /// The argument parsed but lies outside the operation's domain (sign,
    /// parity, size, square-freeness); `msg` describes the violated condition.
    Domain { msg: &'static str },
    /// No discrete logarithm or modular root exists for the given inputs.
    NoSolution,
    /// The modulus failed the primality check required by the operation.
    CompositeModulus,
    /// The requested root degree is neither 2 nor coprime to `p − 1`.
    UnsupportedNthRoot,
}
29
30impl fmt::Display for NumberTheoryError {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        match self {
33            NumberTheoryError::InvalidInput { msg } => write!(f, "{msg}"),
34            NumberTheoryError::Domain { msg } => write!(f, "{msg}"),
35            NumberTheoryError::NoSolution => {
36                write!(f, "no discrete logarithm or modular root exists")
37            }
38            NumberTheoryError::CompositeModulus => write!(f, "operation requires a prime modulus"),
39            NumberTheoryError::UnsupportedNthRoot => {
40                write!(f, "nth root modulo p requires k=2 or gcd(k,p−1)=1")
41            }
42        }
43    }
44}
45
46impl std::error::Error for NumberTheoryError {}
47
48impl AlkahestError for NumberTheoryError {
49    fn code(&self) -> &'static str {
50        match self {
51            NumberTheoryError::InvalidInput { .. } => "E-NT-001",
52            NumberTheoryError::Domain { .. } => "E-NT-002",
53            NumberTheoryError::NoSolution => "E-NT-003",
54            NumberTheoryError::CompositeModulus => "E-NT-004",
55            NumberTheoryError::UnsupportedNthRoot => "E-NT-005",
56        }
57    }
58
59    fn remediation(&self) -> Option<&'static str> {
60        match self {
61            NumberTheoryError::InvalidInput { .. } => {
62                Some("pass arbitrary-precision integers as decimal strings without spaces")
63            }
64            NumberTheoryError::Domain { .. } => {
65                Some("check parity, positivity, and defined ranges")
66            }
67            NumberTheoryError::NoSolution => {
68                Some("verify solvability: residue in ⟨base⟩, or quadratic residue for k=2")
69            }
70            NumberTheoryError::CompositeModulus => {
71                Some("use a prime field modulus where the FLINT primitives apply")
72            }
73            NumberTheoryError::UnsupportedNthRoot => Some(
74                "use sqrt (k=2) or primes with gcd(k,p−1)=1; Tonelli–Shanks chains are deferred",
75            ),
76        }
77    }
78}
79
80fn parse_int(s: &str) -> Result<Integer, NumberTheoryError> {
81    Integer::from_str(s.trim()).map_err(|_| NumberTheoryError::InvalidInput {
82        msg: "invalid decimal integer string",
83    })
84}
85
86fn parse_nonnegative(s: &str) -> Result<Integer, NumberTheoryError> {
87    let z = parse_int(s)?;
88    if z.cmp0() == Ordering::Less {
89        Err(NumberTheoryError::Domain {
90            msg: "expected a non-negative integer",
91        })
92    } else {
93        Ok(z)
94    }
95}
96
97/// Multiplicative inverse of `a` modulo `m` when `gcd(a,m)=1`.
98fn mod_inverse(mut a: Integer, m: &Integer) -> Option<Integer> {
99    if m.cmp0() != Ordering::Greater {
100        return None;
101    }
102    if m == &Integer::from(1) {
103        return Some(Integer::from(0));
104    }
105    a %= m;
106    let (g, s, _) = a.extended_gcd(m.clone(), Integer::new());
107    if g != 1 && g != -1 {
108        return None;
109    }
110    let mut inv = if g == -1 { -s } else { s };
111    inv %= m;
112    if inv.cmp0() == Ordering::Less {
113        inv += m;
114    }
115    Some(inv)
116}
117
118fn integer_is_odd(n: &Integer) -> bool {
119    (n.clone() % Integer::from(2_u32)).cmp0() != Ordering::Equal
120}
121
122/// Positive integer parser (strictly \(> 0\) when required by the caller).
123fn parse_positive(s: &str) -> Result<Integer, NumberTheoryError> {
124    let z = parse_nonnegative(s)?;
125    if z.is_zero() {
126        Err(NumberTheoryError::Domain {
127            msg: "expected a positive integer",
128        })
129    } else {
130        Ok(z)
131    }
132}
133
134/// Exact primality (`fmpz_is_prime`).
135pub fn isprime(n: &str) -> Result<bool, NumberTheoryError> {
136    let z = parse_int(n)?;
137    if z.cmp0() != Ordering::Greater || z < 2 {
138        return Ok(false);
139    }
140    let fz = FlintInteger::from_rug(&z);
141    let r = unsafe { ffi::fmpz_is_prime(fz.inner_ptr()) };
142    Ok(r != 0)
143}
144
145/// Full factorisation: `(sign, list of (prime, exponent))` for \(\prod p^e\cdot \mathrm{sign}\).
146pub fn factorint(n: &str) -> Result<(i32, Vec<(String, u64)>), NumberTheoryError> {
147    let z = parse_int(n)?;
148    let fz = FlintInteger::from_rug(&z);
149    unsafe {
150        let mut fac = std::mem::MaybeUninit::<FmpzFactorStruct>::uninit();
151        ffi::fmpz_factor_init(fac.as_mut_ptr());
152        let mut fac = fac.assume_init();
153        ffi::fmpz_factor(&mut fac, fz.inner_ptr());
154        let mut out = Vec::with_capacity(fac.num.max(0) as usize);
155        for i in 0..fac.num {
156            let mut base = FlintInteger::new();
157            ffi::fmpz_set(base.inner_mut_ptr(), fac.p.add(i as usize));
158            let exp = *fac.exp.add(i as usize);
159            out.push((base.to_string(), exp));
160        }
161        let sign = fac.sign;
162        ffi::fmpz_factor_clear(&mut fac);
163        Ok((sign, out))
164    }
165}
166
167/// Next prime strictly after `n` (`fmpz_nextprime`).
168pub fn nextprime(n: &str, proved: bool) -> Result<String, NumberTheoryError> {
169    let z = parse_int(n)?;
170    let fz = FlintInteger::from_rug(&z);
171    let mut res = FlintInteger::new();
172    unsafe {
173        ffi::fmpz_nextprime(
174            res.inner_mut_ptr(),
175            fz.inner_ptr(),
176            if proved { 1 } else { 0 },
177        );
178    }
179    Ok(res.to_string())
180}
181
182/// Euler totient \(\varphi(n)\) (`fmpz_euler_phi`).
183pub fn totient(n: &str) -> Result<String, NumberTheoryError> {
184    let z = parse_positive(n)?;
185    let fz = FlintInteger::from_rug(&z);
186    let mut out = FlintInteger::new();
187    unsafe {
188        ffi::fmpz_euler_phi(out.inner_mut_ptr(), fz.inner_ptr());
189    }
190    Ok(out.to_string())
191}
192
193/// Jacobi symbol \((a | n)\) for odd \(n > 1\) (`fmpz_jacobi`).
194pub fn jacobi_symbol(a: &str, n: &str) -> Result<i32, NumberTheoryError> {
195    let na = parse_int(a)?;
196    let nn = parse_positive(n)?;
197    if nn <= 1 || !integer_is_odd(&nn) {
198        return Err(NumberTheoryError::Domain {
199            msg: "Jacobi denominator must be odd and greater than 1",
200        });
201    }
202    let fa = FlintInteger::from_rug(&na);
203    let fn_ = FlintInteger::from_rug(&nn);
204    let j = unsafe { ffi::fmpz_jacobi(fa.inner_ptr(), fn_.inner_ptr()) };
205    Ok(j as i32)
206}
207
208/// Modular \(k\)th root: some \(x\) with \(x^k \equiv a \pmod p\) for prime \(p\).
209///
210/// Implemented for `k == 2` via `fmpz_sqrtmod`, or when \(\gcd(k, p{-}1)=1\) via exponent inversion.
211pub fn nthroot_mod(a: &str, k: u64, p: &str) -> Result<String, NumberTheoryError> {
212    if k == 0 {
213        return Err(NumberTheoryError::InvalidInput {
214            msg: "root degree must be ≥ 1",
215        });
216    }
217    let pm = parse_positive(p)?;
218    let fp = FlintInteger::from_rug(&pm);
219    if unsafe { ffi::fmpz_is_prime(fp.inner_ptr()) } == 0 {
220        return Err(NumberTheoryError::CompositeModulus);
221    }
222
223    let mut ared = parse_int(a)?;
224    ared %= &pm;
225
226    let mut out = FlintInteger::new();
227
228    if k == 2 {
229        let fa = FlintInteger::from_rug(&ared);
230        let ok = unsafe { ffi::fmpz_sqrtmod(out.inner_mut_ptr(), fa.inner_ptr(), fp.inner_ptr()) };
231        if ok == 0 {
232            return Err(NumberTheoryError::NoSolution);
233        }
234        return Ok(out.to_string());
235    }
236
237    let ord = pm.clone() - 1;
238    let kk = Integer::from(k);
239    if kk.clone().gcd(&ord) != 1 {
240        return Err(NumberTheoryError::UnsupportedNthRoot);
241    }
242    let mut inv_e = mod_inverse(kk.clone(), &ord).ok_or(NumberTheoryError::UnsupportedNthRoot)?;
243    inv_e %= &ord;
244    let fa = FlintInteger::from_rug(&ared);
245    let fe = FlintInteger::from_rug(&inv_e);
246    unsafe {
247        ffi::fmpz_powm(
248            out.inner_mut_ptr(),
249            fa.inner_ptr(),
250            fe.inner_ptr(),
251            fp.inner_ptr(),
252        );
253    }
254    Ok(out.to_string())
255}
256
257/// Smallest exponent \(e \geq 0\) with \(\mathit{base}^e \equiv \mathit{residue}\pmod{p}\) (`p` prime).
258///
259/// This uses a deterministic linear sweep over exponents bounded by \(p{-}1\); it is tuned for API
260/// parity and moderate primes, not large-field cryptography.
261pub fn discrete_log(residue: &str, base: &str, p: &str) -> Result<String, NumberTheoryError> {
262    let pm = parse_positive(p)?;
263    if pm < 2 {
264        return Err(NumberTheoryError::Domain {
265            msg: "modulus must be at least 2",
266        });
267    }
268    let fp = FlintInteger::from_rug(&pm);
269    if unsafe { ffi::fmpz_is_prime(fp.inner_ptr()) } == 0 {
270        return Err(NumberTheoryError::CompositeModulus);
271    }
272
273    let ord = pm.clone() - Integer::from(1);
274    let mut b = parse_int(base)?;
275    let mut r = parse_int(residue)?;
276    r %= &pm;
277    b %= &pm;
278
279    if b.is_zero() {
280        return if r.is_zero() {
281            Ok("1".into())
282        } else {
283            Err(NumberTheoryError::NoSolution)
284        };
285    }
286
287    let mut cur = Integer::from(1);
288    let mut exp = Integer::from(0);
289    while exp < ord {
290        if cur == r {
291            return Ok(exp.to_string());
292        }
293        cur = (&cur * &b).complete();
294        cur %= &pm;
295        exp += 1;
296    }
297    Err(NumberTheoryError::NoSolution)
298}
299
300/// Quadratic Dirichlet character: Jacobi symbol \((· | q)\) for odd square-free \(q≥3\).
301#[derive(Clone, Debug)]
302pub struct QuadraticDirichlet {
303    modulus: Integer,
304}
305
306impl QuadraticDirichlet {
307    pub fn new(conductor: &str) -> Result<Self, NumberTheoryError> {
308        let q = parse_positive(conductor)?;
309        if q <= 2 || !integer_is_odd(&q) {
310            return Err(NumberTheoryError::Domain {
311                msg: "quadratic Dirichlet conductor must be odd and ≥ 3",
312            });
313        }
314        let (_sign, fac) = factorint(conductor)?;
315        for (_, e) in &fac {
316            if *e != 1 {
317                return Err(NumberTheoryError::Domain {
318                    msg: "conductor must be square-free",
319                });
320            }
321        }
322        Ok(QuadraticDirichlet { modulus: q })
323    }
324
325    pub fn conductor(&self) -> String {
326        self.modulus.to_string()
327    }
328
329    /// \(\chi_q(n)\) as `−1`, `0`, or `1`.
330    pub fn eval(&self, n: &str) -> Result<i32, NumberTheoryError> {
331        jacobi_symbol(n, &self.modulus.to_string())
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use rug::ops::Pow;
    use std::collections::HashMap;

    /// 2^127 − 1 is a Mersenne prime.
    #[test]
    fn mersenne_m127_prime() {
        let mersenne = Integer::from(2u32).pow(127_u32) - 1_u32;
        let decimal = mersenne.to_string();
        assert!(isprime(&decimal).unwrap());
    }

    /// 2^32 − 1 = 3 · 5 · 17 · 257 · 65537; check the largest factor appears once.
    #[test]
    fn factorint_f5() {
        let n: u128 = (1u128 << 32) - 1;
        let (sign, pairs) = factorint(&n.to_string()).unwrap();
        assert_eq!(sign, 1);
        let exponents: HashMap<String, u64> = pairs.into_iter().collect();
        assert_eq!(exponents.get("65537").copied(), Some(1));
    }

    /// The prime after 13 skips 14, 15 and 16.
    #[test]
    fn nextprime_gap() {
        let next = nextprime("13", true).unwrap();
        assert_eq!(next, "17");
    }

    /// φ(12) counts 1, 5, 7, 11.
    #[test]
    fn totient_twelve() {
        let phi = totient("12").unwrap();
        assert_eq!(phi, "4");
    }

    /// (2|15) = (2|3)(2|5) = (−1)(−1) = 1.
    #[test]
    fn jacobi_two_fifteen() {
        let symbol = jacobi_symbol("2", "15").unwrap();
        assert_eq!(symbol, 1);
    }

    /// Any returned square root of 144 modulo 401 must square back to 144.
    #[test]
    fn sqrt_mod_prime() {
        let root: u64 = nthroot_mod("144", 2, "401").unwrap().parse().unwrap();
        assert_eq!((root * root) % 401, 144);
    }

    /// Fifth root modulo 10007, where gcd(5, 10006) = 1 so exponent inversion applies.
    #[test]
    fn nth_root_via_coprime_exponent() {
        let modulus = Integer::from(10007);
        let target = Integer::from(42);
        let degree = 5u64;

        // Precondition for the exponent-inversion path.
        let group_order = modulus.clone() - Integer::from(1);
        assert_eq!(Integer::from(degree).gcd(&group_order), Integer::from(1));

        let root_str = nthroot_mod(&target.to_string(), degree, &modulus.to_string()).unwrap();
        let root = Integer::from_str(&root_str).unwrap();
        let recomputed = root.clone().pow(degree as u32) % &modulus;
        assert_eq!(recomputed, target % &modulus);
    }

    /// 3^4 = 81 ≡ 13 (mod 17).
    #[test]
    fn discrete_log_three_mod_seventeen() {
        let exponent = discrete_log("13", "3", "17").unwrap();
        assert_eq!(exponent, "4");
    }

    /// χ_15(14) = (−1|15) = −1, and χ_15(3) = 0 because 3 divides 15.
    #[test]
    fn dirichlet_phi_fifteen() {
        let chi = QuadraticDirichlet::new("15").unwrap();
        assert_eq!(chi.eval("14").unwrap(), -1);
        assert_eq!(chi.eval("3").unwrap(), 0);
    }
}