use std::ops::Add;
/// Fenwick tree (binary indexed tree) over `n` elements, addressed with
/// 0-based indices.
///
/// Supports point updates plus prefix, range and single-element queries for
/// any `T` that is `Copy + Default + Add`. `T::default()` is the starting
/// value of every sum, so it is expected to be the additive identity of `T`.
pub struct FenwickTree<T: Copy + Default + Add<Output = T>> {
    /// Number of elements.
    n: usize,
    /// 1-indexed node sums: `tree[i]` is the sum of the `lowbit(i)` elements
    /// that end at 0-based index `i - 1`. `tree[0]` is never read.
    tree: Vec<T>,
    /// Current value of every element (0-indexed). A single element cannot be
    /// recovered from the node sums with `Add` alone, so the plain values are
    /// kept alongside them (at the cost of storing each element twice).
    values: Vec<T>,
}

impl<T: Copy + Default + Add<Output = T>> FenwickTree<T> {
    /// Lowest set bit of `i`; this is the number of elements node `i` covers.
    fn lowbit(i: usize) -> usize {
        i & i.wrapping_neg()
    }

    /// Creates a tree of `n` elements, all equal to `T::default()`.
    pub fn new(n: usize) -> Self {
        FenwickTree {
            n,
            tree: vec![T::default(); n + 1],
            values: vec![T::default(); n],
        }
    }

    /// Builds a tree holding a copy of `data`, using one pass over the slice.
    pub fn build(data: &[T]) -> Self {
        let n = data.len();
        let mut tree = vec![T::default(); n + 1];
        for i in 1..=n {
            tree[i] = tree[i] + data[i - 1];
            // Node `i` is complete at this point; push its sum into the next
            // node whose range contains it.
            let parent = i + Self::lowbit(i);
            if parent <= n {
                let val = tree[i];
                tree[parent] = tree[parent] + val;
            }
        }
        FenwickTree {
            n,
            tree,
            values: data.to_vec(),
        }
    }

    /// Adds `delta` to the element at `idx`.
    ///
    /// An `idx` outside `0..len()` is ignored, as it always has been; the
    /// explicit check also keeps `idx + 1` from overflowing.
    pub fn update(&mut self, idx: usize, delta: T) {
        if idx >= self.n {
            return;
        }
        self.values[idx] = self.values[idx] + delta;
        let mut i = idx + 1;
        while i <= self.n {
            self.tree[i] = self.tree[i] + delta;
            i += Self::lowbit(i);
        }
    }

    /// Returns the sum of the elements at indices `0..=idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= len()`.
    pub fn prefix_sum(&self, idx: usize) -> T {
        assert!(idx < self.n, "prefix_sum index out of bounds");
        let mut sum = T::default();
        let mut i = idx + 1;
        while i > 0 {
            sum = sum + self.tree[i];
            i -= Self::lowbit(i);
        }
        sum
    }

    /// Returns the sum of the elements at indices `l..=r` without requiring
    /// `T` to support subtraction.
    ///
    /// # Panics
    ///
    /// Panics if `l > r` or `r >= len()`.
    pub fn range_sum(&self, l: usize, r: usize) -> T {
        assert!(l <= r && r < self.n, "range_sum index out of bounds");
        self.range_sum_scan(l, r)
    }

    /// Sums `l..=r` by walking down from `r`, taking a whole tree node when
    /// its range lies inside `[l, r]` and a single stored element otherwise.
    ///
    /// Every single-element step is followed only by nodes of strictly
    /// smaller size than the one that was rejected, so the walk needs
    /// O(log² n) additions in the worst case. Callers must have validated
    /// `l <= r < n`.
    fn range_sum_scan(&self, l: usize, r: usize) -> T {
        let mut sum = T::default();
        // `i` is 1-indexed; the elements still to be added are 0-based `l..i`.
        let mut i = r + 1;
        while i > l {
            let below = i - Self::lowbit(i);
            if below >= l {
                // Node `i` covers 0-based `below..i`, entirely inside the range.
                sum = sum + self.tree[i];
                i = below;
            } else {
                // Node `i` reaches left of `l`; take only its last element.
                sum = sum + self.values[i - 1];
                i -= 1;
            }
        }
        sum
    }

    /// Returns the current value of the element at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= len()`.
    pub fn point_value(&self, idx: usize) -> T {
        self.values[idx]
    }

    /// Number of elements in the tree.
    pub fn len(&self) -> usize {
        self.n
    }

    /// Returns `true` when the tree has no elements.
    pub fn is_empty(&self) -> bool {
        self.n == 0
    }
}

impl<T> FenwickTree<T>
where
    T: Copy + Default + Add<Output = T> + std::ops::Sub<Output = T>,
{
    /// Returns the sum of the elements at indices `l..=r` as the difference
    /// of two prefix sums.
    ///
    /// # Panics
    ///
    /// Panics if `l > r` or `r >= len()`.
    pub fn range_sum_fast(&self, l: usize, r: usize) -> T {
        assert!(l <= r && r < self.n, "range_sum_fast index out of bounds");
        if l == 0 {
            self.prefix_sum(r)
        } else {
            self.prefix_sum(r) - self.prefix_sum(l - 1)
        }
    }
}
/// Two-dimensional Fenwick tree over a `rows` x `cols` grid, addressed with
/// 0-based `(row, col)` coordinates.
///
/// `T::default()` is the starting value of every sum, so it is expected to be
/// the additive identity of `T`.
pub struct FenwickTree2D<T: Copy + Default + Add<Output = T>> {
    /// Number of grid rows.
    rows: usize,
    /// Number of grid columns.
    cols: usize,
    /// 1-indexed node sums; row 0 and column 0 are never read.
    tree: Vec<Vec<T>>,
}

impl<T: Copy + Default + Add<Output = T>> FenwickTree2D<T> {
    /// Creates a `rows` x `cols` grid with every cell equal to `T::default()`.
    pub fn new(rows: usize, cols: usize) -> Self {
        FenwickTree2D {
            rows,
            cols,
            tree: vec![vec![T::default(); cols + 1]; rows + 1],
        }
    }

    /// Adds `delta` to the cell at `(row, col)`.
    ///
    /// Coordinates outside the grid leave the tree unchanged, because the
    /// corresponding loop never runs.
    pub fn update(&mut self, row: usize, col: usize, delta: T) {
        let mut i = row + 1;
        while i <= self.rows {
            let mut j = col + 1;
            while j <= self.cols {
                self.tree[i][j] = self.tree[i][j] + delta;
                j += j & j.wrapping_neg();
            }
            i += i & i.wrapping_neg();
        }
    }

    /// Returns the sum of all cells in the rectangle `(0, 0)..=(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics (index out of bounds) if `row >= rows` or `col >= cols`.
    pub fn prefix_sum(&self, row: usize, col: usize) -> T {
        let mut sum = T::default();
        let mut i = row + 1;
        while i > 0 {
            let mut j = col + 1;
            while j > 0 {
                sum = sum + self.tree[i][j];
                j -= j & j.wrapping_neg();
            }
            i -= i & i.wrapping_neg();
        }
        sum
    }

    /// Returns the sum of all cells in the rectangle `(r1, c1)..=(r2, c2)`
    /// by inclusion-exclusion over four prefix sums.
    ///
    /// # Panics
    ///
    /// Panics (index out of bounds) if `r2 >= rows` or `c2 >= cols`.
    pub fn range_sum(&self, r1: usize, c1: usize, r2: usize, c2: usize) -> T
    where
        T: std::ops::Sub<Output = T>,
    {
        // Whole rectangle from the origin to the bottom-right corner.
        let br = self.prefix_sum(r2, c2);
        // Strip to the left of the query.
        let bl = if c1 > 0 {
            self.prefix_sum(r2, c1 - 1)
        } else {
            T::default()
        };
        // Strip above the query.
        let tr = if r1 > 0 {
            self.prefix_sum(r1 - 1, c2)
        } else {
            T::default()
        };
        // Corner shared by both strips, which is subtracted twice.
        let tl = if r1 > 0 && c1 > 0 {
            self.prefix_sum(r1 - 1, c1 - 1)
        } else {
            T::default()
        };
        // Add before subtracting: `br - bl - tr` can dip below zero even when
        // the final answer does not (the shared corner is removed twice
        // before being restored), which panics or wraps for unsigned `T`.
        (br + tl) - (bl + tr)
    }
}
/// Multiset of `i64` values drawn from a fixed inclusive range, with rank and
/// k-th-smallest queries, backed by a Fenwick tree of per-value counts.
pub struct OrderStatisticsTree {
    /// Smallest storable value; value `v` lives at tree index `v - min_val`.
    min_val: i64,
    /// Number of distinct storable values (`max_val - min_val + 1`).
    size: usize,
    /// Occurrence count of every storable value.
    bit: FenwickTree<i64>,
    /// Number of stored elements, duplicates included.
    total: usize,
}

impl OrderStatisticsTree {
    /// Creates an empty multiset able to hold values in `min_val..=max_val`.
    ///
    /// # Panics
    ///
    /// Panics if `max_val < min_val`, or if the number of values in the range
    /// does not fit in an `i64`.
    pub fn new(min_val: i64, max_val: i64) -> Self {
        assert!(max_val >= min_val, "max_val must be >= min_val");
        // Checked arithmetic: `max_val - min_val + 1` overflows `i64` for
        // extreme bounds and would otherwise wrap to a bogus size in release
        // builds.
        let span = max_val
            .checked_sub(min_val)
            .and_then(|d| d.checked_add(1))
            .expect("value range too large");
        let size = span as usize;
        OrderStatisticsTree {
            min_val,
            size,
            bit: FenwickTree::new(size),
            total: 0,
        }
    }

    /// Inserts one occurrence of `val`.
    ///
    /// # Panics
    ///
    /// Panics if `val` is outside the range given to [`OrderStatisticsTree::new`].
    pub fn insert(&mut self, val: i64) {
        let idx = self.compress(val);
        self.bit.update(idx, 1);
        self.total += 1;
    }

    /// Removes one occurrence of `val`; does nothing when `val` is not stored.
    ///
    /// # Panics
    ///
    /// Panics if `val` is outside the range given to [`OrderStatisticsTree::new`].
    pub fn remove(&mut self, val: i64) {
        let idx = self.compress(val);
        // Read the exact count of this one value as a difference of prefix
        // sums. A single tree node aggregates a whole block of neighbouring
        // values, so it cannot tell whether `val` itself is present.
        let cur = self.bit.range_sum_fast(idx, idx);
        if cur > 0 {
            self.bit.update(idx, -1);
            self.total -= 1;
        }
    }

    /// Returns the number of stored elements strictly less than `val`.
    ///
    /// # Panics
    ///
    /// Panics if `val` is outside the range given to [`OrderStatisticsTree::new`].
    pub fn rank(&self, val: i64) -> usize {
        let idx = self.compress(val);
        if idx == 0 {
            return 0;
        }
        self.bit.prefix_sum(idx - 1) as usize
    }

    /// Returns the `k`-th smallest stored element (0-based, duplicates
    /// counted), or `None` when `k >= len()`.
    pub fn select(&self, k: usize) -> Option<i64> {
        if k >= self.total {
            return None;
        }
        let mut pos = 0usize;
        // Elements that must still be skipped, plus the one being selected.
        let mut remaining = (k + 1) as i64;
        // Largest power of two not exceeding `size`. `size >= 1` here because
        // a non-empty multiset implies at least one storable value.
        let mut step = 1usize << (usize::BITS - 1 - self.size.leading_zeros());
        // Binary descent over the tree nodes: extend `pos` whenever the node
        // reached holds fewer elements than are still needed, so `pos` ends
        // as the count of values lying entirely before the answer.
        while step > 0 {
            let next = pos + step;
            if next <= self.size && self.bit.tree[next] < remaining {
                remaining -= self.bit.tree[next];
                pos = next;
            }
            step >>= 1;
        }
        Some(self.min_val + pos as i64)
    }

    /// Number of stored elements, duplicates included.
    pub fn len(&self) -> usize {
        self.total
    }

    /// Returns `true` when no elements are stored.
    pub fn is_empty(&self) -> bool {
        self.total == 0
    }

    /// Maps `val` to its 0-based tree index.
    ///
    /// # Panics
    ///
    /// Panics if `val` is outside `min_val..=min_val + size - 1`.
    fn compress(&self, val: i64) -> usize {
        // `checked_sub` keeps a far-out-of-range `val` from overflowing the
        // subtraction before the range test can reject it.
        match val.checked_sub(self.min_val) {
            Some(offset) if offset >= 0 && offset < self.size as i64 => offset as usize,
            _ => panic!(
                "value {val} out of range [{}, {}]",
                self.min_val,
                self.min_val + self.size as i64 - 1
            ),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Prefix sums of a tree built from a slice.
    #[test]
    fn fenwick_prefix_sum() {
        let tree = FenwickTree::build(&[1i64, 2, 3, 4, 5]);
        for &(idx, expected) in &[(0usize, 1i64), (1, 3), (4, 15)] {
            assert_eq!(tree.prefix_sum(idx), expected);
        }
    }

    /// A point update is visible in every prefix that contains it.
    #[test]
    fn fenwick_update() {
        let mut tree = FenwickTree::build(&[1i64, 2, 3, 4, 5]);
        tree.update(2, 10);
        assert_eq!(tree.prefix_sum(4), 25);
        assert_eq!(tree.prefix_sum(2), 16);
    }

    /// Subtraction-based range sums, including the full range.
    #[test]
    fn fenwick_range_sum_fast() {
        let tree = FenwickTree::build(&[1i64, 2, 3, 4, 5]);
        for &(lo, hi, expected) in &[(1usize, 3usize, 9i64), (0, 4, 15)] {
            assert_eq!(tree.range_sum_fast(lo, hi), expected);
        }
    }

    /// Linear-time construction over a larger input.
    #[test]
    fn fenwick_build_correctness() {
        let values: Vec<i64> = (1..=100).collect();
        let tree = FenwickTree::build(&values);
        assert_eq!(tree.prefix_sum(99), 5050);
    }

    /// 2D prefix sums after updates along the diagonal.
    #[test]
    fn fenwick2d_basic() {
        let mut grid = FenwickTree2D::<i64>::new(4, 4);
        for &(pos, amount) in &[(0usize, 1i64), (1, 2), (2, 3)] {
            grid.update(pos, pos, amount);
        }
        for &(corner, expected) in &[(2usize, 6i64), (1, 3), (0, 1)] {
            assert_eq!(grid.prefix_sum(corner, corner), expected);
        }
    }

    /// 2D rectangle sums over a filled 3x3 interior.
    #[test]
    fn fenwick2d_range_sum() {
        let mut grid = FenwickTree2D::<i64>::new(5, 5);
        for row in 1..=3 {
            for col in 1..=3 {
                grid.update(row, col, 1);
            }
        }
        let cases = [
            ((1usize, 1usize, 3usize, 3usize), 9i64),
            ((0, 0, 4, 4), 9),
            ((2, 2, 3, 3), 4),
        ];
        for &((r1, c1, r2, c2), expected) in &cases {
            assert_eq!(grid.range_sum(r1, c1, r2, c2), expected);
        }
    }

    /// Rank and select over a handful of inserted values.
    #[test]
    fn ost_rank_select() {
        let mut set = OrderStatisticsTree::new(0, 100);
        for &value in &[5i64, 3, 8, 1] {
            set.insert(value);
        }
        assert_eq!(set.rank(5), 2);
        assert_eq!(set.rank(1), 0);
        let in_order = [Some(1i64), Some(3), Some(5), Some(8), None];
        for (k, expected) in in_order.iter().enumerate() {
            assert_eq!(set.select(k), *expected);
        }
    }

    /// Removing a stored value shrinks the set and shifts later ranks.
    #[test]
    fn ost_remove() {
        let mut set = OrderStatisticsTree::new(0, 50);
        for &value in &[10i64, 20, 30] {
            set.insert(value);
        }
        set.remove(20);
        assert_eq!(set.len(), 2);
        assert_eq!(set.select(1), Some(30));
    }

    /// Ranges that start below zero are offset correctly.
    #[test]
    fn ost_negative_range() {
        let mut set = OrderStatisticsTree::new(-50, 50);
        for &value in &[-10i64, 0, 10] {
            set.insert(value);
        }
        assert_eq!(set.rank(0), 1);
        assert_eq!(set.select(0), Some(-10));
        assert_eq!(set.select(1), Some(0));
    }
}