// File: algorithms_edu/data_structures/sparse_table.rs

//! Implementation of a sparse table, a data structure that answers range queries on a static
//! array. For overlap-friendly (idempotent) functions a query takes $O(1)$ time; the table itself
//! uses $O(n \log n)$ memory. For functions that are only associative, a query takes $O(\log n)$ time.
//!
//! A function $f$ is associative if $f(a, f(b, c)) = f(f(a, b), c)$. Examples include scalar and matrix
//! addition and multiplication, and string concatenation.
//! An associative function is overlap-friendly if $f(f(a, b), f(b, c)) = f(f(a, b), c)$, i.e. counting
//! the shared element $b$ twice does not change the result. Examples include min, max, gcd and lcm.
8
9use crate::algo::math::log2::VecLog2;
10
11pub struct SparseTable<T, F>
12where
13    F: Fn(T, T) -> T,
14{
15    // The sparse table values.
16    values: Vec<Vec<Option<T>>>,
17    // Pre-computed array of log2 values
18    log2: Vec<usize>,
19    // The function to be applied
20    f: F,
21    overlap_friendly: bool,
22}
23
24impl<T: Clone, F: Fn(T, T) -> T> SparseTable<T, F> {
25    pub fn new(arr: &[T], f: F, overlap_friendly: bool) -> Self {
26        let n = arr.len();
27        let log2 = Vec::log2(n + 1);
28        let m = log2[n];
29        let mut values = vec![vec![None; n]; m + 1];
30        for (i, v) in arr.iter().enumerate() {
31            values[0][i] = Some(v.clone());
32        }
33        // Build sparse table combining the values of the previous intervals.
34        for i in 1..=m {
35            for j in 0..=(n - (1 << i)) {
36                let left_interval = values[i - 1][j].clone();
37                let right_interval = values[i - 1][j + (1 << (i - 1))].clone();
38                values[i][j] = Some(f(left_interval.unwrap(), right_interval.unwrap()));
39            }
40        }
41        Self {
42            values,
43            log2,
44            f,
45            overlap_friendly,
46        }
47    }
48    pub fn query(&self, mut l: usize, r: usize) -> T {
49        if self.overlap_friendly {
50            let len = r - l + 1;
51            let i = self.log2[len];
52            let left_interval = self.values[i][l].clone();
53            let right_interval = self.values[i][1 + r - (1 << i)].clone();
54            (self.f)(left_interval.unwrap(), right_interval.unwrap())
55        } else {
56            let mut p = self.log2[1 + r - l];
57            let mut acc = self.values[p][l].clone().unwrap();
58            l += 1 << p;
59            while l <= r {
60                p = self.log2[1 + r - l];
61                acc = (self.f)(acc, self.values[p][l].clone().unwrap());
62                l += 1 << p;
63            }
64
65            acc
66        }
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use lazy_static::lazy_static;
    use rand::{thread_rng, Rng};
    use std::fmt::Debug;

    const SAMPLE_SIZE: usize = 10;
    lazy_static! {
        static ref TEST_DATA: Vec<u128> = {
            let mut rng = thread_rng();
            (0..SAMPLE_SIZE).map(|_| rng.gen_range(1, 20)).collect()
        };
    }

    /// Builds a table over `data` and checks every inclusive range `[i, j]` against a naive
    /// left fold with the same function.
    fn validate_on<T, F>(data: &[T], f: F, overlap_ok: bool)
    where
        T: Clone + PartialEq + Debug,
        F: Fn(T, T) -> T,
    {
        let sparse_table = SparseTable::new(data, f, overlap_ok);
        // Iterate `i` over the full range so the single-element query at the last index is
        // covered too (`data[i + 1..=j]` is simply empty when `i == j`).
        for i in 0..data.len() {
            for j in i..data.len() {
                let expected = data[i + 1..=j]
                    .iter()
                    .cloned()
                    .fold(data[i].clone(), |acc, curr| (sparse_table.f)(acc, curr));
                let queried = sparse_table.query(i, j);
                assert_eq!(expected, queried, "query({}, {})", i, j);
            }
        }
    }

    /// Validates `f` on the shared random `u128` sample.
    fn validate<F>(f: F, overlap_ok: bool)
    where
        F: Fn(u128, u128) -> u128,
    {
        validate_on(TEST_DATA.as_slice(), f, overlap_ok);
    }

    #[test]
    fn test_sparse_table_min_max() {
        validate(std::cmp::min, true);
        validate(std::cmp::max, true);
    }

    #[test]
    fn test_sparse_table_add_mul() {
        validate(|a, b| a + b, false);
        validate(|a, b| a * b, false);
    }

    #[test]
    fn test_gcd_lcm() {
        use crate::algo::math::{gcd::GcdUnsigned, lcm::LcmUnsigned};
        validate(|a, b| a.gcd(b), true);
        validate(|a, b| a.lcm(b), true);
    }

    #[test]
    fn test_string_concat() {
        let mut rng = thread_rng();
        // Random 4-letter ASCII words; converting byte-by-byte via `char::from` needs no `unsafe`.
        let data: Vec<String> = (0..SAMPLE_SIZE * 4)
            .map(|_| rng.gen_range(b'a', b'z'))
            .collect::<Vec<_>>()
            .chunks_exact(4)
            .map(|chunk| chunk.iter().copied().map(char::from).collect())
            .collect();
        // Concatenation is associative but not commutative, so this also checks ordering.
        validate_on(&data, |a, b| a + &b, false);
    }

    #[test]
    fn test_sparse_table_matrix_add_mul() {
        let mut rng = thread_rng();
        let data: Vec<Matrix2x2> = (0..SAMPLE_SIZE * 4)
            .map(|_| rng.gen_range(-10i128, 10))
            .collect::<Vec<_>>()
            .chunks_exact(4)
            .map(|x| [[x[0], x[1]], [x[2], x[3]]])
            .collect();

        validate_on(&data, matrix_addition_2x2, false);
        // Matrix multiplication is associative but not commutative.
        validate_on(&data, matrix_multiplication_2x2, false);
    }

    type Matrix2x2 = [[i128; 2]; 2];

    /// Element-wise sum of two 2x2 matrices.
    fn matrix_addition_2x2(a: Matrix2x2, b: Matrix2x2) -> Matrix2x2 {
        [
            [a[0][0] + b[0][0], a[0][1] + b[0][1]],
            [a[1][0] + b[1][0], a[1][1] + b[1][1]],
        ]
    }

    /// Matrix product `a * b` of two 2x2 matrices.
    fn matrix_multiplication_2x2(a: Matrix2x2, b: Matrix2x2) -> Matrix2x2 {
        [
            [
                a[0][0] * b[0][0] + a[0][1] * b[1][0],
                a[0][0] * b[0][1] + a[0][1] * b[1][1],
            ],
            [
                a[1][0] * b[0][0] + a[1][1] * b[1][0],
                a[1][0] * b[0][1] + a[1][1] * b[1][1],
            ],
        ]
    }
}