// competitive_programming_rs/string/rolling_hash.rs

/// Polynomial rolling hash over byte strings, computed modulo the Mersenne
/// prime `2^61 - 1` using only 64-bit arithmetic.
pub mod rolling_hash {
    /// Mask selecting the low 30 bits of a value.
    const MASK_30: u64 = (1 << 30) - 1;
    /// Mask selecting the low 31 bits of a value.
    const MASK_31: u64 = (1 << 31) - 1;
    /// The modulus `2^61 - 1`; it is also the mask for the low 61 bits.
    const MOD: u64 = (1 << 61) - 1;

    /// Prefix hashes of a byte string together with the powers of the base,
    /// allowing the hash of any substring to be computed in constant time.
    #[derive(Debug, Clone)]
    pub struct RollingHash {
        /// `hash[i]` is the hash of the first `i` bytes of the input.
        hash: Vec<u64>,
        /// `pow[i]` is `base^i` reduced modulo `MOD`.
        pow: Vec<u64>,
    }

    impl RollingHash {
        /// Precomputes prefix hashes of `s` for the polynomial base `base`.
        ///
        /// `base` is reduced modulo `2^61 - 1` before use, so any `u64` is
        /// accepted; two bases that are congruent modulo `2^61 - 1` produce
        /// identical hashes.
        pub fn new(s: &[u8], base: u64) -> RollingHash {
            // `mod_mul` is only overflow-free for operands below 2^61, so the
            // caller-supplied base must be reduced once up front.
            let base = modulo(base);

            let mut hash = Vec::with_capacity(s.len() + 1);
            let mut pow = Vec::with_capacity(s.len() + 1);
            hash.push(0);
            pow.push(1);

            let (mut prefix, mut power) = (0u64, 1u64);
            for &byte in s {
                // `mod_mul` already returns a fully reduced value.
                power = mod_mul(power, base);
                // The sum is below MOD + 256, far from overflowing `u64`.
                prefix = modulo(mod_mul(prefix, base) + u64::from(byte));
                pow.push(power);
                hash.push(prefix);
            }
            RollingHash { hash, pow }
        }

        /// Returns the hash of the half-open byte range `[l, r)` of the input.
        ///
        /// The hash of an empty range is `0`.
        ///
        /// # Panics
        ///
        /// Panics if `l > r` or if `r` exceeds the length of the input.
        pub fn get_hash(&self, l: usize, r: usize) -> u64 {
            assert!(
                l <= r && r < self.hash.len(),
                "invalid range {}..{} for an input of length {}",
                l,
                r,
                self.hash.len() - 1
            );
            // Adding MOD before subtracting keeps the intermediate value
            // non-negative; both terms are below MOD.
            modulo(self.hash[r] + MOD - mod_mul(self.hash[l], self.pow[r - l]))
        }
    }

    /// Multiplies `a` and `b` modulo `2^61 - 1` without 128-bit arithmetic.
    ///
    /// Both operands must be below `2^61`; under that precondition every
    /// intermediate value fits in a `u64`.
    fn mod_mul(a: u64, b: u64) -> u64 {
        // Split each operand as `prefix * 2^31 + suffix`.
        let (a_prefix, a_suffix) = (a >> 31, a & MASK_31);
        let (b_prefix, b_suffix) = (b >> 31, b & MASK_31);
        // Cross terms, which carry a weight of 2^31 in the full product.
        let m = a_suffix * b_prefix + a_prefix * b_suffix;
        // 2^62 is congruent to 2 and 2^61 to 1 modulo MOD, so the high product
        // contributes `2 * a_prefix * b_prefix`, and `m * 2^31` splits into
        // `(m >> 30)` (weight 2^61) plus `(m & MASK_30) << 31`.
        modulo(a_prefix * b_prefix * 2 + (m >> 30) + ((m & MASK_30) << 31) + a_suffix * b_suffix)
    }

    /// Reduces any `u64` into the range `[0, MOD)`.
    fn modulo(v: u64) -> u64 {
        // 2^61 is congruent to 1, so the bits above position 60 fold onto the
        // low 61 bits. The folded value is at most MOD + 7, hence a single
        // conditional subtraction completes the reduction.
        let v = (v & MOD) + (v >> 61);
        if v >= MOD {
            v - MOD
        } else {
            v
        }
    }
}
46
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distributions::Uniform;
    use rand::Rng;

    const BASE: u64 = 1_000_000_007;

    /// Asserts that the substrings `t[l1..r1]` and `t[l2..r2]` are equal
    /// exactly when their rolling hashes are equal.
    fn assert_hash_agrees(
        t: &str,
        hasher: &rolling_hash::RollingHash,
        (l1, r1): (usize, usize),
        (l2, r2): (usize, usize),
    ) {
        let same = t[l1..r1] == t[l2..r2];
        let same_hash = hasher.get_hash(l1, r1) == hasher.get_hash(l2, r2);
        assert_eq!(
            same,
            same_hash,
            "{:?} {:?} {} {}",
            &t[l1..r1],
            &t[l2..r2],
            hasher.get_hash(l1, r1),
            hasher.get_hash(l2, r2)
        );
    }

    /// Compares hash equality against real substring equality on random
    /// lowercase strings that are doubled, so every substring of the first
    /// half reappears at a different offset in the second half.
    #[test]
    fn test_rolling_hash() {
        let half_len = 30;
        let mut rng = rand::thread_rng();

        for _ in 0..100 {
            let s: String = (0..half_len)
                .map(|_| char::from(b'a' + rng.sample(Uniform::from(0..26u8))))
                .collect();

            let t = s.repeat(2);
            let n = t.len();
            let rolling_hash = rolling_hash::RollingHash::new(t.as_bytes(), BASE);

            // Substrings sharing an end point: they differ in length unless
            // they start at the same index.
            for i in 0..n {
                for j in i..n {
                    for k in (j + 1)..n {
                        assert_hash_agrees(&t, &rolling_hash, (i, k), (j, k));
                    }
                }
            }

            // Equal-length windows at different offsets: this is where the
            // doubled string yields genuinely equal substrings, so the
            // "equal strings give equal hashes" direction is exercised.
            for len in 1..=n {
                for i in 0..=(n - len) {
                    for j in i..=(n - len) {
                        assert_hash_agrees(&t, &rolling_hash, (i, i + len), (j, j + len));
                    }
                }
            }
        }
    }
}