discrete_logarithm/
n_order.rs

1use std::collections::HashMap;
2
3use rug::{ops::Pow, Integer};
4
5use crate::{utils::fast_factor, Error};
6
7/// Returns the order of `a` modulo `n`.
8///
9/// The order of `a` modulo `n` is the smallest integer `k` such that `a**k` leaves a remainder of 1 with `n`.
10pub fn n_order(a: &Integer, n: &Integer) -> Result<Integer, Error> {
11    // Special case: n == 1, order is always 1
12    if *n == 1 {
13        return Ok(Integer::from(1));
14    }
15
16    // Validate n > 1
17    if *n < 1 {
18        return Err(Error::NotRelativelyPrime);
19    }
20
21    // Early return for trivial case
22    let a_mod = a.clone() % n;
23    if a_mod == 1 {
24        return Ok(Integer::from(1));
25    }
26
27    if a_mod.clone().gcd(n) != 1 {
28        return Err(Error::NotRelativelyPrime);
29    }
30
31    let factors = fast_factor(n);
32    n_order_with_factors(a, n, &factors)
33}
34
35/// Returns the order of `a` modulo `n`.
36///
37/// The order of `a` modulo `n` is the smallest integer `k` such that `a**k` leaves a remainder of 1 with `n`.
38///
39/// If the prime factorization of `n` is known, it can be passed as `n_factors` to speed up the computation.
40pub fn n_order_with_factors(
41    a: &Integer,
42    n: &Integer,
43    n_factors: &HashMap<Integer, usize>,
44) -> Result<Integer, Error> {
45    // Special case: n == 1, order is always 1
46    if *n == 1 {
47        return Ok(Integer::from(1));
48    }
49
50    // Validate n > 1
51    if *n < 1 {
52        return Err(Error::NotRelativelyPrime);
53    }
54
55    // Early return for trivial case
56    let a_mod = a.clone() % n;
57    if a_mod == 1 {
58        return Ok(Integer::from(1));
59    }
60
61    if a_mod.clone().gcd(n) != 1 {
62        return Err(Error::NotRelativelyPrime);
63    }
64
65    let mut factors = HashMap::new();
66    for (px, kx) in n_factors.iter() {
67        if *kx > 1 {
68            *factors.entry(px.clone()).or_insert(0) += kx - 1;
69        }
70        let fpx = fast_factor(&(px.clone() - 1));
71        for (py, ky) in fpx.iter() {
72            *factors.entry(py.clone()).or_insert(0) += ky;
73        }
74    }
75
76    let mut group_order = Integer::from(1);
77    for (px, kx) in factors.iter() {
78        group_order *= px.clone().pow(*kx as u32);
79    }
80
81    let mut order = Integer::from(1);
82    for (p, e) in factors {
83        let mut exponent = group_order.clone();
84        for f in 0..=e {
85            if a_mod.clone().pow_mod(&exponent, n).unwrap() != 1 {
86                order *= p.clone().pow((e - f + 1) as u32);
87                break;
88            }
89            exponent /= &p;
90        }
91    }
92
93    Ok(order)
94}
95
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;

    #[test]
    fn trial_mul() {
        assert_eq!(n_order(&2.into(), &13.into()).unwrap(), 12);
        for (a, res) in (1..=6).zip(vec![1, 3, 6, 3, 6, 2]) {
            assert_eq!(n_order(&a.into(), &7.into()).unwrap(), res);
        }
        assert_eq!(n_order(&5.into(), &17.into()).unwrap(), 16);
        assert_eq!(
            n_order(&17.into(), &11.into()),
            n_order(&6.into(), &11.into())
        );
        assert_eq!(n_order(&101.into(), &119.into()).unwrap(), 6);
        assert_eq!(
            n_order(&6.into(), &9.into()),
            Err(Error::NotRelativelyPrime)
        );
    }

    #[test]
    fn n_order_with_factors_large_modulus() {
        // n = p^2 with p = 10^50 + 151, whose factorization is supplied up front.
        let p = Integer::from(10).pow(50) + 151u64;
        let n = p.clone().square();
        let factors = HashMap::from([(p, 2)]);
        let expected = Integer::from_str("10000000000000000000000000000000000000000000000030100000000000000000000000000000000000000000000022650").unwrap();
        assert_eq!(
            n_order_with_factors(&11.into(), &n, &factors).unwrap(),
            expected
        );
    }

    #[test]
    fn n_order_with_factors_composite() {
        // 2^4 = 16 ≡ 1 (mod 15), while 2, 4 and 8 are not ≡ 1 (mod 15).
        let factors = HashMap::from([(Integer::from(3), 1), (Integer::from(5), 1)]);
        assert_eq!(
            n_order_with_factors(&2.into(), &15.into(), &factors).unwrap(),
            4
        );
        // `n_order` factors 15 itself and must agree.
        assert_eq!(n_order(&2.into(), &15.into()).unwrap(), 4);
    }

    #[test]
    fn n_order_negative_base() {
        // -1 ≡ 6 (mod 7) has order 2; -6 ≡ 1 (mod 7) has order 1.
        assert_eq!(n_order(&(-1).into(), &7.into()).unwrap(), 2);
        assert_eq!(n_order(&(-6).into(), &7.into()).unwrap(), 1);
        let factors = HashMap::from([(Integer::from(7), 1)]);
        assert_eq!(
            n_order_with_factors(&(-6).into(), &7.into(), &factors).unwrap(),
            1
        );
    }

    #[test]
    fn n_order_trivial_case() {
        // Test early return for a % n == 1
        assert_eq!(n_order(&1.into(), &7.into()).unwrap(), 1);
        assert_eq!(n_order(&8.into(), &7.into()).unwrap(), 1); // 8 % 7 = 1
        assert_eq!(n_order(&15.into(), &7.into()).unwrap(), 1); // 15 % 7 = 1
    }

    #[test]
    fn n_order_validation() {
        // Test n == 1 returns 1 (special case)
        assert_eq!(n_order(&2.into(), &1.into()).unwrap(), 1);
        assert_eq!(n_order(&0.into(), &1.into()).unwrap(), 1);

        // Test n < 1 validation
        assert_eq!(
            n_order(&2.into(), &0.into()),
            Err(Error::NotRelativelyPrime)
        );
        assert_eq!(
            n_order(&2.into(), &(-1).into()),
            Err(Error::NotRelativelyPrime)
        );
    }
}