use crate::utils;
use core::marker::PhantomData;
use core::ops::RangeBounds;
/// Describes the monoid that a [`SegTree`] aggregates over.
///
/// Results are only meaningful when `op` is associative and `ID` is its
/// identity element, i.e. `ID ⊕ x == x ⊕ ID == x` for every `x`.
pub trait SegTreeSpec {
    /// Value stored in every leaf and internal node of the tree.
    type T: Clone;

    /// Identity element: the result of an empty range and the fill value
    /// for padding leaves.
    const ID: Self::T;

    /// Combines in place: `*acc = *acc ⊕ rhs`, with `acc` as the LEFT operand.
    fn op(acc: &mut Self::T, rhs: &Self::T);
}
/// Fixed-size segment tree over `Spec::T`, stored as an implicit binary heap.
///
/// Node `1` is the root and node `i` has children `2 * i` and `2 * i + 1`.
/// The leaves occupy `data[max_size..2 * max_size]`.
pub struct SegTree<Spec>
where
    Spec: SegTreeSpec,
{
    /// Number of logical elements exposed to callers.
    size: usize,
    /// `size` rounded up to a power of two; also the index of the first leaf.
    max_size: usize,
    /// Flattened tree with `2 * max_size` slots; slot `0` is unused.
    data: Box<[Spec::T]>,
    /// Ties the tree to its spec type without storing a value of it.
    _spec: PhantomData<Spec>,
}
impl<Spec: SegTreeSpec> SegTree<Spec> {
    /// Creates a tree over `size` elements, each initialised to `Spec::ID`.
    ///
    /// Internal nodes are also `Spec::ID`. This is consistent because
    /// combining the identity with itself yields the identity.
    pub fn new(size: usize) -> Self {
        let max_size = size.next_power_of_two();
        Self {
            size,
            max_size,
            data: vec![Spec::ID; max_size * 2].into_boxed_slice(),
            _spec: PhantomData,
        }
    }

    /// Builds a tree whose leaves are clones of `values`, in order.
    pub fn from_slice(values: &[Spec::T]) -> Self {
        let size = values.len();
        let max_size = size.next_power_of_two();
        let mut data = vec![Spec::ID; 2 * max_size];
        data[max_size..max_size + size].clone_from_slice(values);
        Self::from_leaves(size, max_size, data)
    }

    /// Builds a tree whose leaves are the elements of `vec`, moved in order.
    pub fn from_vec(vec: Vec<Spec::T>) -> Self {
        let size = vec.len();
        let max_size = size.next_power_of_two();
        let mut data = vec![Spec::ID; 2 * max_size];
        // `vec` has `size <= max_size` items, so the zip stops at its end.
        for (slot, value) in data[max_size..].iter_mut().zip(vec) {
            *slot = value;
        }
        Self::from_leaves(size, max_size, data)
    }

    /// Combines the elements in `range` in index order and returns the result.
    ///
    /// An empty range yields `Spec::ID`. The combination preserves operand
    /// order, so `op` need not be commutative. It performs O(log n) `op` calls.
    ///
    /// # Panics
    ///
    /// Panics if the range is reversed or extends past the tree's size. The
    /// check is delegated to `utils::validate_range`.
    pub fn query<R: RangeBounds<usize>>(&self, range: R) -> Spec::T {
        let (left, right) = utils::parse_range(range, self.size);
        utils::validate_range(left, right, self.size);
        if left == right {
            return Spec::ID;
        }
        let mut left = left + self.max_size;
        let mut right = right + self.max_size;
        // `acc_left` covers a prefix of the range and grows rightwards.
        // `acc_right` covers a suffix of the range and grows leftwards.
        let mut acc_left = Spec::ID;
        let mut acc_right = Spec::ID;
        while left < right {
            if left % 2 == 1 {
                Spec::op(&mut acc_left, &self.data[left]);
                left += 1;
            }
            if right % 2 == 1 {
                right -= 1;
                // This node lies to the LEFT of everything already folded into
                // `acc_right`, so it must be the left operand:
                // acc_right = node ⊕ acc_right.
                let mut node = self.data[right].clone();
                Spec::op(&mut node, &acc_right);
                acc_right = node;
            }
            left /= 2;
            right /= 2;
        }
        Spec::op(&mut acc_left, &acc_right);
        acc_left
    }

    /// Replaces the element at `index` with `value` and refreshes its ancestors.
    ///
    /// # Panics
    ///
    /// Panics with "update index out of bounds" if `index >= size`.
    pub fn update(&mut self, index: usize, value: Spec::T) {
        assert!(index < self.size, "update index out of bounds");
        let leaf_index = index + self.max_size;
        self.data[leaf_index] = value;
        self.recompute(leaf_index);
    }

    /// Wraps a fully populated leaf buffer and computes every internal node bottom-up.
    fn from_leaves(size: usize, max_size: usize, data: Vec<Spec::T>) -> Self {
        let mut tree = Self {
            size,
            max_size,
            data: data.into_boxed_slice(),
            _spec: PhantomData,
        };
        for node in (1..max_size).rev() {
            tree.pull(node);
        }
        tree
    }

    /// Recomputes every ancestor of `index`, from its parent up to the root.
    fn recompute(&mut self, mut index: usize) {
        while index > 1 {
            index /= 2;
            self.pull(index);
        }
    }

    /// Sets `node` to the combination of its left child and its right child.
    fn pull(&mut self, node: usize) {
        let mut combined = self.data[2 * node].clone();
        Spec::op(&mut combined, &self.data[2 * node + 1]);
        self.data[node] = combined;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Additive monoid over `i64`, used by every test below.
    struct SumSpec;

    impl SegTreeSpec for SumSpec {
        type T = i64;
        const ID: Self::T = 0;
        fn op(acc: &mut Self::T, rhs: &Self::T) {
            *acc += *rhs;
        }
    }

    /// Shared checks for a tree built from `[1, 2, 3]`.
    fn assert_one_two_three(tree: &SegTree<SumSpec>) {
        assert_eq!(tree.query(0..1), 1);
        assert_eq!(tree.query(1..2), 2);
        assert_eq!(tree.query(2..3), 3);
        assert_eq!(tree.query(..2), 3);
        assert_eq!(tree.query(1..), 5);
        assert_eq!(tree.query(..), 6);
    }

    #[test]
    fn test_new_empty() {
        let tree = SegTree::<SumSpec>::new(10);
        assert_eq!(tree.query(..), 0);
    }

    #[test]
    fn test_from_slice_with_query() {
        let values = [1, 2, 3];
        assert_one_two_three(&SegTree::<SumSpec>::from_slice(&values));
    }

    #[test]
    fn test_from_vec_with_query() {
        assert_one_two_three(&SegTree::<SumSpec>::from_vec(vec![1, 2, 3]));
    }

    #[test]
    fn test_query_sub_ranges() {
        let tree = SegTree::<SumSpec>::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(tree.query(0..3), 6);
        assert_eq!(tree.query(2..5), 12);
        assert_eq!(tree.query(4..), 26);
        assert_eq!(tree.query(..=6), 28);
        assert_eq!(tree.query(7..8), 8);
    }

    #[test]
    fn test_query_empty_range() {
        let tree = SegTree::<SumSpec>::from_vec(vec![1, 2, 3]);
        assert_eq!(tree.query(1..1), 0);
        assert_eq!(tree.query(3..3), 0);
    }

    #[test]
    fn test_update() {
        let mut tree = SegTree::<SumSpec>::from_vec(vec![1, 2, 3, 4, 5]);
        assert_eq!(tree.query(..), 15);
        tree.update(2, 10);
        assert_eq!(tree.query(..), 1 + 2 + 10 + 4 + 5);
        assert_eq!(tree.query(2..3), 10);
        assert_eq!(tree.query(..2), 3);
    }

    #[test]
    fn test_large_tree() {
        // 1 + 2 + ... + 1000
        const FULL: i64 = 500_500;
        // 1 + 2 + ... + 500
        const FIRST_HALF: i64 = 125_250;
        let mut tree = SegTree::<SumSpec>::from_vec((1..=1000).collect());
        assert_eq!(tree.query(..), FULL);
        assert_eq!(tree.query(..500), FIRST_HALF);
        // Index 499 holds 500; replacing it with 1000 adds 500 to every range covering it.
        tree.update(499, 1000);
        assert_eq!(tree.query(..), FULL + 500);
        assert_eq!(tree.query(..500), FIRST_HALF + 500);
    }

    #[test]
    #[should_panic(expected = "update index out of bounds")]
    fn test_panic_update_out_of_bounds() {
        let mut tree = SegTree::<SumSpec>::new(10);
        tree.update(10, 5);
    }

    #[test]
    #[should_panic]
    fn test_panic_query_out_of_bounds() {
        let tree = SegTree::<SumSpec>::new(10);
        tree.query(..11);
    }

    #[test]
    #[should_panic]
    // The reversed range is intentional: this test checks that `query` rejects it.
    #[allow(clippy::reversed_empty_ranges)]
    fn test_panic_query_invalid_range() {
        let tree = SegTree::<SumSpec>::new(10);
        tree.query(5..4);
    }
}