1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
extern crate num;
use std::ops::{Shr};
use num::traits::{Num, One, Zero, Bounded};
/// Computes `base^exponent mod modulus` by binary (square-and-multiply)
/// exponentiation, reducing after every multiplication so every operand stays
/// below `modulus`.
///
/// The result is the canonical residue in `0..modulus`, including for a
/// negative `base` with a signed `T`. An `exponent <= 0` is treated like zero
/// and yields `1 % modulus`.
///
/// # Panics
///
/// Panics (with a plain `assertion failed: …` message) if `modulus <= 0`, or if
/// `(modulus - 1) * (modulus - 1)` does not fit in `T`. That product is the
/// largest intermediate the loop can form, so larger moduli could overflow.
pub fn mod_exp<T>(base: T, exponent: T, modulus: T) -> T
where
    T: Num + PartialOrd + Shr<T, Output = T> + Copy + Bounded,
{
    let one: T = One::one();
    let two: T = one + one;
    let zero: T = Zero::zero();
    let max: T = Bounded::max_value();

    // A zero modulus would divide by zero below, and `modulus - one` would
    // underflow for unsigned `T`; a negative modulus has no canonical residues.
    assert!(modulus > zero);
    // Reduced operands are at most `modulus - 1`, so their product must fit in
    // `T`. For `largest > 0`, `largest <= max / largest` is exactly
    // `largest * largest <= max`. `largest == 0` (modulus 1) is trivially safe
    // and must short-circuit before the division by zero.
    let largest = modulus - one;
    assert!(largest == zero || largest <= max / largest);

    // `1 % modulus` instead of `1` so that modulus 1 yields 0 even when the
    // loop never runs (exponent 0).
    let mut result = one % modulus;
    let mut base = base % modulus;
    // For signed `T`, `%` keeps the dividend's sign; shift into `0..modulus`.
    if base < zero {
        base = base + modulus;
    }
    let mut exponent = exponent;
    while exponent > zero {
        // Low bit set: fold the current power of `base` into the result.
        if exponent % two == one {
            result = (result * base) % modulus;
        }
        exponent = exponent >> one;
        base = (base * base) % modulus;
    }
    result
}
#[cfg(test)]
mod tests {
    use super::mod_exp;

    #[test]
    fn test_mod_exp() {
        // 4^13 = 67_108_864 = 135_027 * 497 + 445
        let base = 4i64;
        let exponent = 13i64;
        let modulus = 497i64;
        assert_eq!(mod_exp(base, exponent, modulus), 445i64);
    }

    #[test]
    fn test_zero_exponent() {
        // x^0 is 1 for any modulus greater than 1.
        assert_eq!(mod_exp(7u32, 0u32, 13u32), 1u32);
    }

    #[test]
    fn test_largest_safe_u8_modulus() {
        // (16 - 1)^2 = 225 fits in u8, so modulus 16 must be accepted;
        // 3^4 = 81 = 5 * 16 + 1.
        assert_eq!(mod_exp(3u8, 4u8, 16u8), 1u8);
    }

    #[test]
    #[should_panic(expected = "assertion failed")]
    fn test_overflow_lhs() {
        // (254 - 1)^2 does not fit in u8, so the overflow guard must trip.
        let modulus = 254u8;
        mod_exp(1u8, 1u8, modulus);
    }
}