use crate::algo::math::log2::VecLog2;
/// A sparse table answering range queries `f(a[l], a[l + 1], ..., a[r])` over an
/// immutable slice `a`.
///
/// Row `p` of the table stores, for every start index `j`, the result of folding
/// `f` over the `2^p` elements `a[j..j + 2^p]`. Building the table performs
/// `O(n log n)` applications of `f` and stores `O(n log n)` values.
///
/// `f` must be associative. If `overlap_friendly` is `true`, `f` must also be
/// idempotent (`f(x, x) == x`, e.g. `min`, `max`, `gcd`). That lets a query
/// combine two possibly overlapping rows with a single call to `f`. Otherwise,
/// queries fold disjoint power-of-two blocks from left to right, using
/// `O(log n)` calls. That mode also works for non-commutative operations such
/// as string concatenation or matrix multiplication.
pub struct SparseTable<T, F>
where
    F: Fn(T, T) -> T,
{
    /// `values[p][j]` is `f` folded over `a[j..j + 2^p]`; row `p` has exactly
    /// `n - 2^p + 1` entries, and row 0 is a copy of the input.
    values: Vec<Vec<T>>,
    /// Table from the project's `VecLog2` helper, indexed by range length.
    /// This code relies on `log2[k]` being the largest `p` with `2^p <= k`.
    log2: Vec<usize>,
    f: F,
    overlap_friendly: bool,
}

impl<T: Clone, F: Fn(T, T) -> T> SparseTable<T, F> {
    /// Builds the table for `arr` using the combining function `f`.
    ///
    /// See the type-level documentation for the requirements that
    /// `overlap_friendly` places on `f`. An empty `arr` is accepted; every
    /// query on the resulting table panics.
    pub fn new(arr: &[T], f: F, overlap_friendly: bool) -> Self {
        let n = arr.len();
        let log2: Vec<usize> = Vec::log2(n + 1);
        // Highest level whose blocks (length 2^max_level) still fit in `arr`.
        // An empty input only needs the (empty) level 0, so don't consult log2[0].
        let max_level = if n == 0 { 0 } else { log2[n] };

        let mut values: Vec<Vec<T>> = Vec::with_capacity(max_level + 1);
        values.push(arr.to_vec());
        for p in 1..=max_level {
            let half = 1usize << (p - 1);
            let prev = &values[p - 1];
            // Block [j, j + 2^p) is the union of the two half-blocks starting at
            // j and j + half; the zip stops exactly at the last block that fits.
            let row: Vec<T> = prev
                .iter()
                .zip(&prev[half..])
                .map(|(left, right)| f(left.clone(), right.clone()))
                .collect();
            values.push(row);
        }

        Self {
            values,
            log2,
            f,
            overlap_friendly,
        }
    }

    /// Returns `f` folded over the inclusive range `arr[l..=r]` of the slice the
    /// table was built from.
    ///
    /// Uses one call to `f` when the table is overlap-friendly, and `O(log n)`
    /// calls otherwise.
    ///
    /// # Panics
    ///
    /// Panics if `l > r` or if `r` is not a valid index into the original slice.
    pub fn query(&self, mut l: usize, r: usize) -> T {
        let n = self.values[0].len();
        assert!(
            l <= r && r < n,
            "SparseTable::query: invalid range [{}, {}] for length {}",
            l,
            r,
            n
        );
        if self.overlap_friendly {
            // Two blocks of length 2^p cover [l, r]. They may overlap, which is
            // harmless for an idempotent `f`.
            let p = self.log2[r - l + 1];
            let left = self.values[p][l].clone();
            let right = self.values[p][r + 1 - (1 << p)].clone();
            (self.f)(left, right)
        } else {
            // Greedily consume the largest power-of-two block starting at `l`,
            // keeping blocks disjoint and combined in left-to-right order.
            let mut p = self.log2[r - l + 1];
            let mut acc = self.values[p][l].clone();
            l += 1 << p;
            while l <= r {
                p = self.log2[r - l + 1];
                acc = (self.f)(acc, self.values[p][l].clone());
                l += 1 << p;
            }
            acc
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use lazy_static::lazy_static;
    use rand::{thread_rng, Rng};

    const SAMPLE_SIZE: usize = 10;

    lazy_static! {
        static ref TEST_DATA: Vec<u128> = {
            let mut rng = thread_rng();
            (0..SAMPLE_SIZE).map(|_| rng.gen_range(1, 20)).collect()
        };
    }

    /// Checks every inclusive range `[i, j]` of `TEST_DATA` against a naive
    /// left fold using the same combining function.
    fn validate<F>(f: F, overlap_ok: bool)
    where
        F: Fn(u128, u128) -> u128,
    {
        let sparse_table = SparseTable::new(&TEST_DATA, f, overlap_ok);
        // `i` runs up to the last index so single-element queries at the end are covered too.
        for i in 0..SAMPLE_SIZE {
            for j in i..SAMPLE_SIZE {
                let expected = TEST_DATA[i + 1..=j]
                    .iter()
                    .fold(TEST_DATA[i], |acc, curr| (sparse_table.f)(acc, *curr));
                let queried = sparse_table.query(i, j);
                assert_eq!(expected, queried);
            }
        }
    }

    #[test]
    fn test_sparse_table_min_max() {
        validate(std::cmp::min, true);
        validate(std::cmp::max, true);
    }

    #[test]
    fn test_sparse_table_add_mul() {
        validate(|a, b| a + b, false);
        validate(|a, b| a * b, false);
    }

    #[test]
    fn test_gcd_lcm() {
        use crate::algo::math::{gcd::GcdUnsigned, lcm::LcmUnsigned};
        validate(|a, b| a.gcd(b), true);
        validate(|a, b| a.lcm(b), true);
    }

    /// Concatenation is associative but not commutative, so this checks that
    /// non-overlapping queries combine blocks in left-to-right order.
    #[test]
    fn test_string_concat() {
        let mut rng = thread_rng();
        let data: Vec<String> = (0..SAMPLE_SIZE * 4)
            .map(|_| rng.gen_range(b'a', b'z'))
            .collect::<Vec<_>>()
            .chunks_exact(4)
            // Every byte is ASCII, so it converts directly to a `char`; no unsafe needed.
            .map(|chunk| chunk.iter().map(|&b| char::from(b)).collect::<String>())
            .collect();
        let sparse_table = SparseTable::new(&data, |a, b| a + &b, false);
        for i in 0..SAMPLE_SIZE {
            for j in i..SAMPLE_SIZE {
                let expected = data[i + 1..=j].iter().fold(data[i].clone(), |acc, curr| {
                    (sparse_table.f)(acc, curr.clone())
                });
                let queried = sparse_table.query(i, j);
                assert_eq!(expected, queried);
            }
        }
    }

    /// Matrix multiplication is another associative, non-commutative operation.
    #[test]
    fn test_sparse_table_matrix_add_mul() {
        let mut rng = thread_rng();
        let data: Vec<Matrix2x2> = (0..SAMPLE_SIZE * 4)
            .map(|_| rng.gen_range(-10i128, 10))
            .collect::<Vec<_>>()
            .chunks_exact(4)
            .map(|x| [[x[0], x[1]], [x[2], x[3]]])
            .collect();
        let sparse_table_add = SparseTable::new(&data, matrix_addition_2x2, false);
        let sparse_table_mul = SparseTable::new(&data, matrix_multiplication_2x2, false);
        for i in 0..SAMPLE_SIZE {
            for j in i..SAMPLE_SIZE {
                let (expected_sum, expected_product) =
                    data[i + 1..=j]
                        .iter()
                        .fold((data[i], data[i]), |acc, curr| {
                            (
                                matrix_addition_2x2(acc.0, *curr),
                                matrix_multiplication_2x2(acc.1, *curr),
                            )
                        });
                let queried_sum = sparse_table_add.query(i, j);
                let queried_product = sparse_table_mul.query(i, j);
                assert_eq!(expected_sum, queried_sum);
                assert_eq!(expected_product, queried_product);
            }
        }
    }

    type Matrix2x2 = [[i128; 2]; 2];

    /// Element-wise sum of two 2x2 matrices.
    fn matrix_addition_2x2(a: Matrix2x2, b: Matrix2x2) -> Matrix2x2 {
        [
            [a[0][0] + b[0][0], a[0][1] + b[0][1]],
            [a[1][0] + b[1][0], a[1][1] + b[1][1]],
        ]
    }

    /// Standard matrix product `a * b` of two 2x2 matrices.
    fn matrix_multiplication_2x2(a: Matrix2x2, b: Matrix2x2) -> Matrix2x2 {
        [
            [
                a[0][0] * b[0][0] + a[0][1] * b[1][0],
                a[0][0] * b[0][1] + a[0][1] * b[1][1],
            ],
            [
                a[1][0] * b[0][0] + a[1][1] * b[1][0],
                a[1][0] * b[0][1] + a[1][1] * b[1][1],
            ],
        ]
    }
}