use crate::errors::Result;
pub(super) trait CombineOp<T> {
fn combine(&mut self, arg1: T, arg2: T) -> Result<T>;
}
pub(super) fn log_depth_sum<T: std::clone::Clone>(
items: &[T],
combine_op: &mut dyn CombineOp<T>,
) -> Result<T> {
if items.is_empty() {
return Err(runtime_error!("Cannot combine empty vector"));
}
let mut combined_items = items.to_owned();
while combined_items.len() > 1 {
let mut new_combined_items = vec![];
for i in (0..combined_items.len()).step_by(2) {
let j = i + 1;
if j >= combined_items.len() {
new_combined_items.push(combined_items[i].clone());
} else {
let new_item =
combine_op.combine(combined_items[i].clone(), combined_items[j].clone())?;
new_combined_items.push(new_item);
}
}
combined_items = new_combined_items;
}
Ok(combined_items[0].clone())
}
#[allow(dead_code)]
pub(super) fn prefix_sums_binary_ascent<T: std::clone::Clone>(
items: &[T],
combine_op: &mut dyn CombineOp<T>,
) -> Result<Vec<T>> {
if items.is_empty() {
return Ok(vec![]);
}
let mut combined_items = items.to_owned();
let mut depth = 1;
while depth < combined_items.len() {
for i in (depth..combined_items.len()).rev() {
combined_items[i] =
combine_op.combine(combined_items[i - depth].clone(), combined_items[i].clone())?;
}
depth *= 2;
}
Ok(combined_items)
}
#[allow(dead_code)]
pub(super) fn prefix_sums_sqrt_trick<T: std::clone::Clone>(
items: &[T],
combine_op: &mut dyn CombineOp<T>,
) -> Result<Vec<T>> {
if items.is_empty() {
return Ok(vec![]);
}
let block_size = std::cmp::max(1, (items.len() as f64).sqrt() as usize);
let mut combined_items = items.to_owned();
for i in 0..combined_items.len() {
if i % block_size != 0 {
combined_items[i] =
combine_op.combine(combined_items[i - 1].clone(), combined_items[i].clone())?;
}
}
for i in block_size..combined_items.len() {
combined_items[i] = combine_op.combine(
combined_items[i - i % block_size - 1].clone(),
combined_items[i].clone(),
)?;
}
Ok(combined_items)
}
#[allow(dead_code)]
pub(super) fn prefix_sums_segment_tree<T: std::clone::Clone>(
items: &[T],
combine_op: &mut dyn CombineOp<T>,
) -> Result<Vec<T>> {
if items.is_empty() {
return Ok(vec![]);
}
let mut layers = vec![items.to_owned()];
let mut layer = 0;
while layers[layer].len() > 1 {
let mut next_layer = vec![];
for i in (0..layers[layer].len()).step_by(2) {
if i + 1 < layers[layer].len() {
next_layer.push(
combine_op.combine(layers[layer][i].clone(), layers[layer][i + 1].clone())?,
);
} else {
next_layer.push(layers[layer][i].clone());
}
}
layer += 1;
layers.push(next_layer);
}
for i in (0..layers.len() - 1).rev() {
for j in 1..layers[i].len() {
if j % 2 == 1 {
layers[i][j] = layers[i + 1][j / 2].clone();
} else {
layers[i][j] =
combine_op.combine(layers[i + 1][(j - 1) / 2].clone(), layers[i][j].clone())?;
}
}
}
Ok(layers[0].clone())
}
#[cfg(test)]
mod tests {
use super::*;
struct IntCombiner {}
impl CombineOp<u64> for IntCombiner {
fn combine(&mut self, arg1: u64, arg2: u64) -> Result<u64> {
return Ok(arg1 + arg2);
}
}
#[test]
fn test_log_depth_sum() {
for len in 1..20 {
let mut v = vec![];
for i in 1..len + 1 {
v.push(i * i);
}
let mut expected = 0;
for x in v.clone() {
expected += x;
}
let mut combiner = IntCombiner {};
let actual = log_depth_sum(&v, &mut combiner).unwrap();
assert_eq!(expected, actual);
}
}
#[test]
fn test_prefix_sums() {
for len in 0..20 {
let mut v = vec![];
for i in 1..len + 1 {
v.push(i * i);
}
let mut expected = vec![];
let mut sum = 0;
for x in v.clone() {
sum += x;
expected.push(sum);
}
let mut combiner = IntCombiner {};
let actual_bin_ascent = prefix_sums_binary_ascent(&v, &mut combiner).unwrap();
assert_eq!(expected, actual_bin_ascent);
let actual_sqrt = prefix_sums_sqrt_trick(&v, &mut combiner).unwrap();
assert_eq!(expected, actual_sqrt);
let actual_fenwick = prefix_sums_segment_tree(&v, &mut combiner).unwrap();
assert_eq!(expected, actual_fenwick);
}
}
struct TrackingCombiner {
total_calls: u64,
}
impl CombineOp<u64> for TrackingCombiner {
fn combine(&mut self, arg1: u64, arg2: u64) -> Result<u64> {
self.total_calls += 1;
return Ok(std::cmp::max(arg1, arg2) + 1); }
}
#[test]
fn test_complexity() {
for len in 200..300 {
let v = vec![0; len];
{
let mut combiner = TrackingCombiner { total_calls: 0 };
let depth = log_depth_sum(&v, &mut combiner).unwrap();
assert!(depth <= (len as f64).log(2.0).ceil() as u64);
assert!(combiner.total_calls <= len as u64 - 1);
}
{
let mut combiner = TrackingCombiner { total_calls: 0 };
let depths = prefix_sums_binary_ascent(&v, &mut combiner).unwrap();
let log_n: u64 = (len as f64).log(2.0).ceil() as u64;
assert!(depths.iter().max().unwrap() <= &log_n);
assert!(combiner.total_calls <= (len as u64) * log_n);
}
{
let mut combiner = TrackingCombiner { total_calls: 0 };
let depths = prefix_sums_sqrt_trick(&v, &mut combiner).unwrap();
let sqrt_n: u64 = (len as f64).sqrt().ceil() as u64;
assert!(depths.iter().max().unwrap() <= &(2 * sqrt_n + 1));
assert!(combiner.total_calls <= (len as u64) * 2);
}
{
let mut combiner = TrackingCombiner { total_calls: 0 };
let depths = prefix_sums_segment_tree(&v, &mut combiner).unwrap();
let log_n: u64 = (len as f64).log(2.0).ceil() as u64;
assert!(depths.iter().max().unwrap() <= &(2 * log_n + 1));
assert!(combiner.total_calls <= (len as u64) * 2);
}
}
}
}