1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
//! Modular Exponentiation
//!
//! Implementation of modular exponentiation using the right-to-left binary
//! method, based on the algorithm from *Applied Cryptography* (Bruce Schneier).
//! You can find more details in that book, or on
//! [Wikipedia](https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method).

extern crate num;

use std::ops::{Shr};
use num::traits::{Num, One, Zero, Bounded};

/// Computes `base` raised to `exponent`, reduced modulo `modulus`, using the
/// right-to-left binary method.
///
/// All parameters are generic, provided they implement the following traits:
///
/// * Num
/// * PartialOrd
/// * Shr<T, Output=T>
/// * Copy
/// * Bounded
///
/// You can find the `Num` and `Bounded` traits in the [num](https://crates.io/crates/num) crate.
///
/// The result is always in the range `[0, modulus)`, including for negative
/// `base` values of signed types. Exponents less than or equal to zero are
/// treated as zero, so the result is then `1 % modulus`.
///
/// # Examples
///
/// ```
/// use mod_exp::mod_exp;
///
/// assert_eq!(mod_exp(5, 3, 13), 8);
/// assert_eq!(mod_exp(7, 0, 1), 0);
/// ```
///
/// # Panics
///
/// Panics (via `assert!`) if `modulus` is not positive, or if the type `T` is
/// too small to hold `(modulus - 1) * (modulus - 1)`, which is the largest
/// intermediate product the computation may form.
pub fn mod_exp<T>(base: T, exponent: T, modulus: T) -> T
where
    T: Num + PartialOrd + Shr<T, Output = T> + Copy + Bounded,
{
    let one: T = One::one();
    let two: T = one + one;
    let zero: T = Zero::zero();
    let max: T = Bounded::max_value();

    assert!(modulus > zero);

    // Every integer is congruent to 0 modulo 1. Returning early also avoids
    // dividing by `modulus - 1 == 0` in the overflow check below.
    if modulus == one {
        return zero;
    }

    // Operands are kept in `[0, modulus)`, so the largest product formed is
    // `(modulus - 1)^2`. For x > 0 and integer division, `x * x <= max` holds
    // exactly when `x <= max / x`, so `<=` rejects only moduli that could
    // really overflow.
    let largest = modulus - one;
    assert!(largest <= max / largest);

    // `%` keeps the sign of the dividend; shift a negative residue into
    // `[0, modulus)` so all products stay non-negative and the result is canonical.
    let mut base = base % modulus;
    if base < zero {
        base = base + modulus;
    }

    let mut result = one;
    let mut exponent = exponent;
    while exponent > zero {
        // Multiply in the current power of `base` when the low bit is set.
        if exponent % two == one {
            result = (result * base) % modulus;
        }
        exponent = exponent >> one;
        base = (base * base) % modulus;
    }

    result
}

#[cfg(test)]
mod tests {
    use super::mod_exp;

    /// 4^13 mod 497 = 445.
    #[test]
    fn test_mod_exp() {
        let base = 4i64;
        let exponent = 13i64;
        let modulus = 497i64;
        assert_eq!(mod_exp(base, exponent, modulus), 445i64);
    }

    /// A zero exponent leaves the accumulator at one.
    #[test]
    fn test_zero_exponent() {
        assert_eq!(mod_exp(9u32, 0u32, 7u32), 1u32);
    }

    /// 253 * 253 does not fit in a `u8`, so the overflow guard must trip.
    #[test]
    #[should_panic(expected = "assertion failed")]
    fn test_overflow_lhs() {
        let modulus = 254u8;
        mod_exp(1u8, 1u8, modulus);
    }
}