//! segtri 0.2.1
//!
//! Segment tree with customizable data type and update operations.
//!
//! Documentation
use std::ops::{Add, Mul, Range};

use crate::modify_op::ModifyOp;

#[cfg(test)]
use std::sync::atomic::{AtomicUsize, Ordering};
#[cfg(test)]
// Test-only counter of child nodes allocated when `SegNode::modify` splits a
// `Same` node into two children; the laziness test reads it to check how many
// children a sequence of modifications created.
static CHILD_CREATED_CNT: AtomicUsize = AtomicUsize::new(0);

/// An interior node whose range has been split into two children, so its
/// points are not (necessarily) all equal.
pub struct DivergedSegNode<T, Op> {
    /// Aggregate over the node's whole range. It already includes every op
    /// applied to this node, including ones still pending for the children.
    data_acc: T,
    /// Ops applied to the whole node that have not yet been pushed into the
    /// children; reset to `Op::nop()` once they are forwarded.
    pending_ops_for_children: Op,
    /// Child covering the left half (`start..mid`) of this node's range.
    l_child: Box<SegNode<T, Op>>,
    /// Child covering the right half (`mid..end`); may be the larger half.
    r_child: Box<SegNode<T, Op>>,
}

/// A node of the segment tree. Nodes do not store their own range; every
/// operation receives the node's range from its caller.
pub enum SegNode<T, Op> {
    /// Every point in the node's range holds this value. It is the per-point
    /// value, not the aggregate: `query` multiplies it by the range length.
    Same(T),
    /// The range has been split into two children with a cached aggregate.
    Diverged(DivergedSegNode<T, Op>),
}

use SegNode::*;

impl<T, Op> SegNode<T, Op> {
    /// Applies `op` to every point covered by this node.
    ///
    /// * A `Same` node stores a per-point value, so `op` is applied to it
    ///   with a segment length of 1.
    /// * A `Diverged` node applies `op` to its cached aggregate over
    ///   `node_range_len` points and queues it in `pending_ops_for_children`,
    ///   so the children are only updated when `resolve_pending_ops` later
    ///   pushes the queued ops down.
    fn modify_whole_with(&mut self, node_range_len: usize, op: &Op)
    where
        Op: ModifyOp<T>,
    {
        match self {
            Same(per_point) => op.apply(per_point, 1),
            Diverged(node) => {
                op.apply(&mut node.data_acc, node_range_len);
                // Queue after the op(s) already pending so the children see
                // the ops in the order they were applied to this node.
                node.pending_ops_for_children.combine(op);
            }
        }
    }
}

/// Returns the split point of `range`: `start + len / 2`, rounded down.
///
/// For every non-empty range this equals `(start + end) / 2`. Computing it as
/// an offset from `start` avoids overflowing `usize` when `start + end`
/// exceeds `usize::MAX`, which happens for trees longer than about
/// `usize::MAX / 2`. Empty (or reversed) ranges yield `start`.
fn range_mid(range: &Range<usize>) -> usize {
    range.start + range.len() / 2
}

/// Splits `range` at `range_mid(range)` into a left half `start..mid` and a
/// right half `mid..end`. The midpoint rounds down, so for odd lengths the
/// right half is the larger one.
fn split_lr_range(range: &Range<usize>) -> (Range<usize>, Range<usize>) {
    let split_at = range_mid(range);
    let left = range.start..split_at;
    let right = split_at..range.end;
    (left, right)
}

/// Returns the overlap of two half-open ranges.
///
/// When the ranges are disjoint the result has `start >= end`, which
/// `Range::is_empty` reports as empty.
fn range_intersect(
    range1: &Range<usize>,
    range2: &Range<usize>,
) -> Range<usize> {
    let lower = std::cmp::max(range1.start, range2.start);
    let upper = std::cmp::min(range1.end, range2.end);
    lower..upper
}

impl<T, Op> DivergedSegNode<T, Op>
where
    Op: ModifyOp<T>,
{
    /// Forwards the ops queued on this node to both children, then clears
    /// the queue back to `Op::nop()`.
    ///
    /// `l_child_range_len` / `r_child_range_len` are the number of points
    /// covered by the left and right child; they are handed to
    /// `modify_whole_with` so each child scales the op to its own size.
    fn resolve_pending_ops(
        &mut self,
        l_child_range_len: usize,
        r_child_range_len: usize,
    ) {
        // Borrow the three fields separately so the queued op can be read
        // while the children are mutated.
        let Self {
            pending_ops_for_children: pending,
            l_child: left,
            r_child: right,
            ..
        } = self;
        left.modify_whole_with(l_child_range_len, pending);
        right.modify_whole_with(r_child_range_len, pending);
        *pending = Op::nop();
    }
}

impl<T, Op> SegNode<T, Op>
where
    T: Clone,
    for<'x> &'x T: Mul<usize, Output = T> + Add<Output = T>,
    Op: ModifyOp<T>,
{
    /// Creates a tree in which every point holds `same_point_data`.
    ///
    /// Only a single `Same` node is allocated; `modify` splits it into
    /// children on demand.
    pub fn from_same_point_data(same_point_data: T) -> Self {
        Self::Same(same_point_data)
    }

    /// Builds a tree whose `i`-th point holds `point_data[i]`.
    ///
    /// The returned root represents `point_data.len()` points; pass a node
    /// range of that length (e.g. `0..point_data.len()`) to `query`/`modify`.
    ///
    /// # Panics
    ///
    /// Panics if `point_data` is empty.
    pub fn with_points(point_data: &[T]) -> Self {
        // Checked in all builds: for an empty slice the recursion below would
        // keep splitting `0..0` into two empty halves until the stack overflows.
        assert!(
            !point_data.is_empty(),
            "with_points requires at least one point"
        );
        if let [only_point] = point_data {
            return Self::Same(only_point.clone());
        }
        let (l_range, r_range) = split_lr_range(&(0..point_data.len()));
        // Children are built from sub-slices, i.e. relative to offset 0.
        // `split_lr_range` places the split at the same relative position for
        // any range of a given length, so the children's internal layout
        // agrees with the absolute ranges later passed to `query`/`modify`.
        let mut l_child =
            Box::new(Self::with_points(&point_data[l_range.clone()]));
        let mut r_child =
            Box::new(Self::with_points(&point_data[r_range.clone()]));
        // A whole-node query never recurses, so this is O(1) per child.
        let data_acc = &l_child.query(&l_range, &l_range)
            + &r_child.query(&r_range, &r_range);
        Self::Diverged(DivergedSegNode {
            data_acc,
            pending_ops_for_children: Op::nop(),
            l_child,
            r_child,
        })
    }

    /// Returns the aggregate of the points in `target_range`.
    ///
    /// `node_range` is the range of points this node covers, and
    /// `target_range` must be a non-empty sub-range of it. This is checked
    /// only in debug builds. Takes `&mut self` because ops pending on the
    /// visited nodes are pushed down to their children.
    pub fn query(
        &mut self,
        node_range: &Range<usize>,
        target_range: &Range<usize>,
    ) -> T {
        debug_assert!(!target_range.is_empty());
        debug_assert!(
            target_range.start >= node_range.start
                && target_range.end <= node_range.end,
            "target: {:?} node: {:?}",
            target_range,
            node_range
        );

        let diverged = match self {
            // Uniform node: per-point value times the number of points.
            Same(s) => return &*s * target_range.len(),
            Diverged(d) => {
                // Whole-node query: the cached aggregate is already current.
                if target_range == node_range {
                    return d.data_acc.clone();
                }
                d
            }
        };

        let (l_child_range, r_child_range) = split_lr_range(node_range);
        let l_target_range = range_intersect(&l_child_range, target_range);
        let r_target_range = range_intersect(&r_child_range, target_range);

        // The children must reflect every op queued on this node before
        // they are read.
        diverged
            .resolve_pending_ops(l_child_range.len(), r_child_range.len());

        match (l_target_range.is_empty(), r_target_range.is_empty()) {
            (true, _) => {
                diverged.r_child.query(&r_child_range, &r_target_range)
            }
            (_, true) => {
                diverged.l_child.query(&l_child_range, &l_target_range)
            }
            _ => {
                &diverged.l_child.query(&l_child_range, &l_target_range)
                    + &diverged.r_child.query(&r_child_range, &r_target_range)
            }
        }
    }

    /// Applies `op` to every point in `target_range`.
    ///
    /// `node_range` is the range this node covers, and `target_range` must be
    /// a non-empty sub-range of it. This is checked only in debug builds.
    /// When the whole node is covered, the op is recorded via
    /// `modify_whole_with` without recursing. A partially covered `Same` node
    /// is first split into two `Same` children holding its value.
    pub fn modify(
        &mut self,
        node_range: &Range<usize>,
        target_range: &Range<usize>,
        op: &Op,
    ) {
        debug_assert!(!target_range.is_empty());
        debug_assert!(!range_intersect(node_range, target_range).is_empty());
        debug_assert!(
            target_range.start >= node_range.start
                && target_range.end <= node_range.end,
            "target: {:?} node: {:?}",
            target_range,
            node_range
        );

        if node_range == target_range {
            self.modify_whole_with(node_range.len(), op);
            return;
        }

        let (l_child_range, r_child_range) = split_lr_range(node_range);
        let l_target_range = range_intersect(&l_child_range, target_range);
        let r_target_range = range_intersect(&r_child_range, target_range);

        // Modifies whichever children overlap the target, then returns the
        // refreshed aggregate over both children (whole-child queries).
        let modify_and_query_children =
            |l_child: &mut Self, r_child: &mut Self| {
                if !r_target_range.is_empty() {
                    r_child.modify(&r_child_range, &r_target_range, op);
                }
                if !l_target_range.is_empty() {
                    l_child.modify(&l_child_range, &l_target_range, op);
                }
                &l_child.query(&l_child_range, &l_child_range)
                    + &r_child.query(&r_child_range, &r_child_range)
            };

        match self {
            Same(s) => {
                #[cfg(test)]
                CHILD_CREATED_CNT.fetch_add(2, Ordering::SeqCst);
                // Both halves start out with this node's per-point value.
                let mut l_child = Box::new(SegNode::Same(s.clone()));
                let mut r_child = Box::new(SegNode::Same(s.clone()));
                *self = Diverged(DivergedSegNode {
                    data_acc: modify_and_query_children(
                        &mut l_child,
                        &mut r_child,
                    ),
                    l_child,
                    r_child,
                    pending_ops_for_children: Op::nop(),
                });
            }
            Diverged(diverged) => {
                diverged.resolve_pending_ops(
                    l_child_range.len(),
                    r_child_range.len(),
                );
                diverged.data_acc = modify_and_query_children(
                    &mut diverged.l_child,
                    &mut diverged.r_child,
                );
            }
        }
    }
}

#[cfg(test)]
mod test {
    use serial_test::{parallel, serial};

    use crate::SegTree;

    use super::*;

    #[test]
    #[parallel]
    fn test_split_lr_range() {
        assert_eq!(split_lr_range(&(4..6)), (4..5, 5..6))
    }

    /// Checks how many child nodes a series of modifications materialises,
    /// using the test-only `CHILD_CREATED_CNT` counter.
    ///
    /// Runs `#[serial]` because the counter is a process-wide static that any
    /// concurrently running `modify` call would also bump.
    #[test]
    #[serial]
    fn test_laziness() {
        // Far too many points to allocate eagerly; only a lazily split tree
        // can handle this length.
        let len = usize::MAX / 4 + 1;
        // Range-add: adding `n` to each point of a segment of `seg_len`
        // points adds `seg_len * n` to the segment's aggregate.
        struct Add(usize);
        impl ModifyOp<usize> for Add {
            fn nop() -> Self {
                Add(0)
            }

            fn combine(&mut self, another_op: &Self) {
                self.0 += another_op.0
            }

            fn apply(&self, orig_seg_data: &mut usize, seg_len: usize) {
                *orig_seg_data += seg_len * self.0
            }
        }
        let mut seg = SegTree::new(len, 1);
        // NOTE(review): `len / 4 - 1..0` has start > end, i.e. it is an empty
        // range; presumably this checks that an empty modify changes nothing
        // and allocates no children — confirm intent.
        seg.modify(&(len / 4 - 1..0), &Add(1));
        assert_eq!(seg.query_point(len / 5), 1);
        assert_eq!(CHILD_CREATED_CNT.load(Ordering::Acquire), 0);
        seg.modify(&(len / 4..len / 4 * 3), &Add(1));
        assert_eq!(CHILD_CREATED_CNT.load(Ordering::Acquire), 6);
        assert_eq!(seg.query_point(len / 4 - 1), 1);
        assert_eq!(seg.query_point(len / 4), 2);
        assert_eq!(seg.query_point(len / 4 * 3 - 1), 2);
        assert_eq!(seg.query_point(len / 4 * 3), 1);
        // Point queries must not split any further nodes.
        assert_eq!(CHILD_CREATED_CNT.load(Ordering::Acquire), 6);
        seg.modify(&(len / 16 * 7..len / 2), &Add(1));
        assert_eq!(CHILD_CREATED_CNT.load(Ordering::Acquire), 10);
        assert_eq!(seg.query(&(len / 2 - 89..len / 2)), 267);
        // Range queries must not split any further nodes either.
        assert_eq!(CHILD_CREATED_CNT.load(Ordering::Acquire), 10);
    }
}