#![no_std]
extern crate alloc;
use alloc::vec::Vec;
use core::ops::{AddAssign, SubAssign};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// A Fenwick tree (binary indexed tree) over a sequence of `T` values.
///
/// The tree has the same length as the sequence it represents. Node `i`
/// stores the sum of the original values at positions `(i & (i + 1))..=i`,
/// which lets prefix sums and point updates run in O(log n).
///
/// The derived comparison, ordering and hashing traits (and, with the `serde`
/// feature, the serialized form) all operate on this internal node layout,
/// not on the original values; in particular `Ord` is lexicographic over the
/// stored partial sums.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct BITree<T> {
    /// Node storage in Fenwick layout (see the type-level docs).
    inner: Vec<T>,
}
impl<T> BITree<T> {
    /// Creates an empty tree without allocating.
    #[inline]
    pub const fn new() -> Self {
        Self { inner: Vec::new() }
    }

    /// Creates a tree of length `n` whose nodes are all `T::default()`.
    ///
    /// When `T::default()` is the additive identity (as it is for the
    /// primitive numeric types), every partial sum of an all-zero sequence is
    /// also zero, so the result is already a valid tree without a rebuild.
    #[inline]
    pub fn new_zeros(n: usize) -> Self
    where
        T: Default,
    {
        let mut inner = Vec::with_capacity(n);
        inner.resize_with(n, T::default);
        Self { inner }
    }

    /// Creates an empty tree with room for at least `capacity` elements.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            inner: Vec::with_capacity(capacity),
        }
    }

    /// Returns `true` if the tree holds no elements.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns the number of elements in the represented sequence.
    #[inline]
    pub const fn len(&self) -> usize {
        self.inner.len()
    }

    /// Removes the last element of the sequence.
    ///
    /// Returns `true` if an element was removed and `false` if the tree was
    /// already empty. A node only ever accumulates values from lower indices
    /// (its parent `i | (i + 1)` is always greater than `i`), so dropping the
    /// highest node leaves all remaining prefix sums intact.
    #[inline(always)]
    pub fn pop(&mut self) -> bool {
        self.inner.pop().is_some()
    }

    /// Folds every node covering the first `index` values into `sum` with `op`.
    ///
    /// `index` is a count of leading elements in `0..=len()`, not a position;
    /// `0` leaves `sum` untouched.
    ///
    /// # Panics
    ///
    /// Panics if `index > self.len()`.
    #[inline(always)]
    fn walk_prefix<F: FnMut(&mut T, &T)>(&self, index: usize, sum: &mut T, mut op: F) {
        let len = self.inner.len();
        assert!(
            index <= len,
            "prefix length {index} is out of range for a tree of length {len}"
        );
        let mut current_idx = index;
        while current_idx > 0 {
            // Node `current_idx - 1` covers values `(current_idx & (current_idx - 1))..current_idx`.
            op(sum, &self.inner[current_idx - 1]);
            // Clear the lowest set bit to jump to the range just before it.
            current_idx &= current_idx - 1;
        }
    }

    /// Applies `op(node, &diff)` to every node whose range contains the
    /// 0-based position `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    #[inline(always)]
    fn walk_update<F: FnMut(&mut T, &T)>(&mut self, index: usize, diff: T, mut op: F) {
        let len = self.inner.len();
        assert!(
            index < len,
            "index {index} is out of range for a tree of length {len}"
        );
        let mut current_idx = index;
        while let Some(value) = self.inner.get_mut(current_idx) {
            op(value, &diff);
            // Set the lowest zero bit to move to the next enclosing node; the
            // loop ends once that runs past the end of the vector.
            current_idx |= current_idx + 1;
        }
    }
}
impl<T> Default for BITree<T> {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl<T: for<'a> AddAssign<&'a T>> From<Vec<T>> for BITree<T> {
    /// Converts raw values into Fenwick layout in O(n), reusing the buffer.
    ///
    /// Positions are visited in increasing order and each node is added into
    /// its parent `i | (i + 1)`. All of a node's children have smaller indices,
    /// so the node already holds its complete range total when it is folded
    /// upward.
    #[inline]
    fn from(values: Vec<T>) -> Self {
        let mut inner = values;
        let len = inner.len();
        rebuild(&mut inner, 0..len, |parent, child| parent.add_assign(child));
        Self { inner }
    }
}
impl<T: for<'a> SubAssign<&'a T>> From<BITree<T>> for Vec<T> {
#[inline]
fn from(mut bitree: BITree<T>) -> Self {
let n = bitree.inner.len();
rebuild(&mut bitree.inner, (0..n).rev(), |p, c| *p -= c);
bitree.inner
}
}
impl<T: for<'a> AddAssign<&'a T>> FromIterator<T> for BITree<T> {
    /// Collects the values into a `Vec`, then converts that buffer into
    /// Fenwick layout via `From<Vec<T>>`.
    #[inline]
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        let values: Vec<T> = iter.into_iter().collect();
        values.into()
    }
}
impl<T: for<'a> SubAssign<&'a T>> IntoIterator for BITree<T> {
    type Item = T;
    type IntoIter = alloc::vec::IntoIter<T>;

    /// Yields the original values in order, not the stored partial sums.
    ///
    /// The tree is decoded in O(n) via `From<BITree<T>> for Vec<T>` first.
    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        let values: Vec<T> = self.into();
        values.into_iter()
    }
}
/// Calls `op(&mut inner[parent], &inner[child])` for every `child` yielded by
/// `indices` whose Fenwick parent `child | (child + 1)` lies inside `inner`.
///
/// Iterating `0..len` with addition converts raw values into Fenwick layout.
/// Iterating `(0..len).rev()` with subtraction converts back. Indices whose
/// parent falls outside the slice are skipped; this includes every index
/// `>= inner.len()`.
#[inline(always)]
fn rebuild<T, I, F>(inner: &mut [T], indices: I, mut op: F)
where
    I: IntoIterator<Item = usize>,
    F: FnMut(&mut T, &T),
{
    let len = inner.len();
    for child in indices {
        let parent = child | (child + 1);
        if parent >= len {
            continue;
        }
        // `child | (child + 1) >= child + 1`, so `parent > child`: splitting at
        // `parent` puts the child in the lower half and the parent at the start
        // of the upper half. This yields the disjoint `&T` / `&mut T` pair
        // without `unsafe`.
        let (lower, upper) = inner.split_at_mut(parent);
        op(&mut upper[0], &lower[child]);
    }
}
impl<T: for<'a> AddAssign<&'a T> + for<'a> SubAssign<&'a T>> BITree<T> {
#[inline]
pub fn add_prefix_sum(&self, index: usize, sum: &mut T) {
self.walk_prefix(index, sum, |s, v| *s += v);
}
#[inline]
pub fn sub_prefix_sum(&self, index: usize, sum: &mut T) {
self.walk_prefix(index, sum, |s, v| *s -= v);
}
#[inline]
pub fn prefix_sum(&self, index: usize) -> T
where
T: Default,
{
let mut sum = T::default();
self.add_prefix_sum(index, &mut sum);
sum
}
#[inline]
pub fn add_at(&mut self, index: usize, diff: T) {
self.walk_update(index, diff, |v, d| *v += d);
}
#[inline]
pub fn sub_at(&mut self, index: usize, diff: T) {
self.walk_update(index, diff, |v, d| *v -= d);
}
#[inline]
pub fn push(&mut self, mut value: T) {
let n = self.inner.len();
for i in 0..n.trailing_ones() {
let child = n & !(1 << i);
value += &self.inner[child];
}
self.inner.push(value);
}
}
impl<T: for<'a> AddAssign<&'a T> + for<'a> SubAssign<&'a T> + PartialOrd> BITree<T> {
pub fn sub_binary_search(&self, remainder: &mut T) -> usize {
let n = self.inner.len();
let mut pos = 0;
let mut mask = n.checked_ilog2().map_or(0, |log| 1 << log);
while mask > 0 {
let next = pos + mask;
if next <= n {
let value = &self.inner[next - 1];
if !(*remainder < *value) {
pos = next;
*remainder -= value;
}
}
mask >>= 1;
}
pos
}
#[inline]
pub fn binary_search(&self, mut target: T) -> Result<usize, usize>
where
T: Default,
{
let pos = self.sub_binary_search(&mut target);
let zero = T::default();
if target < zero {
Err(pos)
} else if zero < target {
Err(pos + 1)
} else {
Ok(pos)
}
}
}
#[cfg(test)]
mod tests {
    extern crate std;
    use super::BITree;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Runs `sub_binary_search` for each `(target, (index, remainder))` case.
    fn check_sub_binary_search(bitree: &BITree<usize>, cases: &[(usize, (usize, usize))]) {
        for &(target, expected) in cases {
            let mut remaining = target;
            let idx = bitree.sub_binary_search(&mut remaining);
            assert_eq!((idx, remaining), expected, "target={}", target);
        }
    }

    #[test]
    fn test_new() {
        let lengths: [usize; 5] = [1, 6, 3, 9, 2];
        let actual_index = BITree::from_iter(lengths);
        // Node i stores the sum of values (i & (i + 1))..=i.
        assert_eq!(actual_index.inner, [1, 7, 3, 19, 2]);

        let n = 5;
        let tree = BITree::<usize>::new_zeros(n);
        assert_eq!(tree.len(), n);
        assert!(!tree.is_empty());
        assert_eq!(tree.prefix_sum(0), 0);
        assert_eq!(tree.prefix_sum(3), 0);
        assert_eq!(tree.prefix_sum(5), 0);
    }

    #[test]
    fn test_prefix_sum() {
        let bitree = BITree::from_iter([1usize, 6, 3, 9, 2]);
        let cases = [(0, 0), (1, 1), (2, 7), (3, 10), (4, 19), (5, 21)];
        for (idx, expected_sum) in cases {
            assert_eq!(bitree.prefix_sum(idx), expected_sum, "idx={}", idx);
        }
    }

    #[test]
    fn test_update_index() {
        let mut bitree = BITree::from_iter([1usize, 6, 3, 9, 2]);
        bitree.add_at(0, 1);
        // Position 0 is covered by nodes 0, 1 and 3.
        assert_eq!(bitree.inner, [2, 8, 3, 20, 2]);
    }

    #[test]
    fn test_binary_search() {
        let bitree = BITree::from_iter([1usize, 6, 3, 9, 2]);
        let cases = [
            (0, Ok(0)),
            (1, Ok(1)),
            (7, Ok(2)),
            (10, Ok(3)),
            (19, Ok(4)),
            (21, Ok(5)),
            (6, Err(2)),
            (9, Err(3)),
            (18, Err(4)),
            (20, Err(5)),
            (22, Err(6)),
        ];
        for (target, expected) in cases {
            assert_eq!(bitree.binary_search(target), expected, "target={}", target);
        }
    }

    #[test]
    #[ntest::timeout(1000)]
    fn test_zero_array() {
        let f0: BITree<usize> = BITree::from_iter([0]);
        assert_eq!(f0.prefix_sum(0), 0);
        assert_eq!(f0.binary_search(1), Err(2));
        let mut remaining = 1usize;
        assert_eq!(f0.sub_binary_search(&mut remaining), 1);
        assert_eq!(remaining, 1);
    }

    #[test]
    fn test_sub_binary_search_empty() {
        let bitree: BITree<usize> = BITree::new();
        let mut remaining = 5;
        assert_eq!(bitree.sub_binary_search(&mut remaining), 0);
        assert_eq!(remaining, 5);
        assert_eq!(bitree.binary_search(0usize), Ok(0));
        assert_eq!(bitree.binary_search(5usize), Err(1));
    }

    #[test]
    fn test_sub_binary_search_single() {
        check_sub_binary_search(
            &BITree::from_iter([7usize]),
            &[(0, (0, 0)), (1, (0, 1)), (7, (1, 0)), (8, (1, 1))],
        );
    }

    #[test]
    fn test_sub_binary_search_power_of_two_len() {
        check_sub_binary_search(
            &BITree::from_iter([2usize, 3, 5, 7]),
            &[
                (0, (0, 0)),
                (2, (1, 0)),
                (3, (1, 1)),
                (5, (2, 0)),
                (6, (2, 1)),
                (10, (3, 0)),
                (11, (3, 1)),
                (17, (4, 0)),
                (18, (4, 1)),
            ],
        );
    }

    #[test]
    fn test_sub_binary_search_uniform_seven() {
        check_sub_binary_search(
            &BITree::from_iter([1usize; 7]),
            &[
                (0, (0, 0)),
                (1, (1, 0)),
                (2, (2, 0)),
                (3, (3, 0)),
                (4, (4, 0)),
                (5, (5, 0)),
                (6, (6, 0)),
                (7, (7, 0)),
                (8, (7, 1)),
            ],
        );
    }

    #[test]
    fn test_sub_binary_search_exceeds_total() {
        let bitree = BITree::from_iter([1usize, 6, 3, 9, 2]);
        let mut remaining = 100;
        assert_eq!(bitree.sub_binary_search(&mut remaining), 5);
        assert_eq!(remaining, 100 - 21);
    }

    #[test]
    fn test_push_empty() {
        let mut bitree = BITree::new();
        bitree.push(5);
        assert_eq!(bitree.inner, [5]);
        assert_eq!(bitree.prefix_sum(1), 5);
    }

    #[test]
    fn test_push_sequence() {
        let mut bitree = BITree::new();
        for v in [1, 6, 3, 9, 2] {
            bitree.push(v);
        }
        for (idx, expected_sum) in [(1, 1), (2, 7), (3, 10), (4, 19), (5, 21)] {
            assert_eq!(bitree.prefix_sum(idx), expected_sum, "idx={}", idx);
        }
    }

    #[test]
    fn test_push_after_initialization() {
        let mut bitree = BITree::from_iter([1, 6, 3]);
        bitree.push(9);
        bitree.push(2);
        for (idx, expected_sum) in [(1, 1), (2, 7), (3, 10), (4, 19), (5, 21)] {
            assert_eq!(bitree.prefix_sum(idx), expected_sum, "idx={}", idx);
        }
    }

    #[test]
    fn test_push_matches_bulk_construction() {
        let values = [3u64, 1, 4, 1, 5, 9, 2, 6, 5];
        let mut pushed = BITree::new();
        for v in values {
            pushed.push(v);
        }
        assert_eq!(pushed, BITree::from_iter(values));
    }

    #[test]
    fn test_pop_empty() {
        let mut bitree: BITree<usize> = BITree::new();
        assert!(!bitree.pop());
    }

    #[test]
    fn test_pop_single() {
        let mut bitree = BITree::from_iter([5]);
        assert!(bitree.pop());
        assert!(bitree.is_empty());
    }

    #[test]
    fn test_pop_sequence() {
        let mut bitree = BITree::from_iter([1, 6, 3, 9, 2]);
        assert!(bitree.pop());
        assert!(bitree.pop());
        assert!(bitree.pop());
        assert_eq!(bitree.len(), 2);
        assert_eq!(bitree.prefix_sum(1), 1);
        assert_eq!(bitree.prefix_sum(2), 7);
    }

    #[test]
    fn test_push_pop_alternating() {
        let mut bitree = BITree::new();
        bitree.push(1);
        bitree.push(6);
        assert!(bitree.pop());
        bitree.push(3);
        assert!(bitree.pop());
        bitree.push(9);
        bitree.push(2);
        assert!(bitree.pop());
        assert_eq!(bitree.prefix_sum(1), 1);
        assert_eq!(bitree.prefix_sum(2), 10);
    }

    #[test]
    fn test_zero_handling() {
        let mut bitree = BITree::new();
        bitree.push(0);
        bitree.push(0);
        assert!(bitree.pop());
        assert_eq!(bitree.prefix_sum(1), 0);
    }

    #[test]
    fn test_negative_values() {
        let mut bitree: BITree<i32> = BITree::new();
        bitree.push(-1);
        bitree.push(2);
        bitree.push(-3);
        assert!(bitree.pop());
        assert_eq!(bitree.prefix_sum(2), 1);
    }

    #[test]
    fn test_vec_round_trip() {
        let values = vec![1, 6, 3, 9, 2];
        let bitree = BITree::from(values.clone());
        assert_eq!(bitree.inner, [1, 7, 3, 19, 2]);
        assert_eq!(Vec::from(bitree), values);
    }

    #[test]
    fn test_into_iter_yields_original_values() {
        let bitree = BITree::from_iter([1i32, 6, 3, 9, 2]);
        let collected: Vec<i32> = bitree.into_iter().collect();
        assert_eq!(collected, [1, 6, 3, 9, 2]);
    }

    #[test]
    fn test_sub_at_and_prefix_accumulators() {
        let mut bitree = BITree::from_iter([1i32, 6, 3, 9, 2]);
        // Values become [1, 6, 3, 5, 2].
        bitree.sub_at(3, 4);
        assert_eq!(bitree.prefix_sum(4), 15);
        assert_eq!(bitree.prefix_sum(5), 17);
        let mut acc = 100;
        bitree.sub_prefix_sum(2, &mut acc);
        assert_eq!(acc, 93);
        bitree.add_prefix_sum(5, &mut acc);
        assert_eq!(acc, 110);
    }

    #[test]
    fn test_new_zeros_then_add_at() {
        let mut bitree = BITree::<u32>::new_zeros(4);
        bitree.add_at(2, 5);
        assert_eq!(bitree.prefix_sum(2), 0);
        assert_eq!(bitree.prefix_sum(3), 5);
        assert_eq!(bitree.prefix_sum(4), 5);
    }

    #[test]
    fn test_default_and_with_capacity_are_empty() {
        let a: BITree<u8> = BITree::default();
        let b: BITree<u8> = BITree::with_capacity(16);
        assert!(a.is_empty());
        assert!(b.is_empty());
        assert_eq!(a, b);
    }

    #[test]
    #[should_panic]
    fn test_prefix_sum_out_of_range_panics() {
        let bitree = BITree::from_iter([1, 2, 3]);
        let _ = bitree.prefix_sum(4);
    }

    #[test]
    #[should_panic]
    fn test_add_at_out_of_range_panics() {
        let mut bitree = BITree::from_iter([1, 2, 3]);
        bitree.add_at(3, 1);
    }
}