use core::mem::MaybeUninit;
use crate::nodes::{LazyNode, Node};
/// A recursive segment tree with lazy range updates.
///
/// Nodes are stored in a flat `Vec` in heap order: node `k` has children `2k + 1` and
/// `2k + 2`, and node `0` is the root. `n` is the number of leaves, i.e. the length of the
/// slice the tree was built from.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the type itself.
pub struct LazyRecursive<T> {
    /// Heap-ordered node storage.
    nodes: Vec<T>,
    /// Number of leaves (input values).
    n: usize,
}
impl<T: LazyNode + Clone> LazyRecursive<T> {
pub fn build(values: &[T]) -> Self {
let n = values.len();
if n == 0 {
return Self {
nodes: Vec::new(),
n,
};
}
let mut nodes = Vec::with_capacity(4 * n);
unsafe { nodes.set_len(4 * n) };
Self::build_helper(0, 0, n - 1, values, &mut nodes);
let ptr = nodes.as_mut_ptr();
core::mem::forget(nodes);
let nodes = unsafe { Vec::from_raw_parts(ptr.cast::<T>(), 4 * n, 4 * n) };
Self { nodes, n }
}
fn build_helper(
curr_node: usize,
i: usize,
j: usize,
values: &[T],
nodes: &mut [MaybeUninit<T>],
) {
if i == j {
nodes[curr_node].write(values[i].clone());
return;
}
let mid = (i + j) / 2;
let left_node = 2 * curr_node + 1;
let right_node = 2 * curr_node + 2;
Self::build_helper(left_node, i, mid, values, nodes);
Self::build_helper(right_node, mid + 1, j, values, nodes);
let (top_nodes, bottom_nodes) = nodes.split_at_mut(curr_node + 1);
top_nodes[curr_node].write(Node::combine(
unsafe { bottom_nodes[left_node - curr_node - 1].assume_init_ref() },
unsafe { bottom_nodes[right_node - curr_node - 1].assume_init_ref() },
));
}
fn push(&mut self, u: usize, i: usize, j: usize) {
let (parent_slice, sons_slice) = self.nodes.split_at_mut(u + 1);
if let Some(value) = parent_slice[u].lazy_value() {
if i != j {
sons_slice[u].update_lazy_value(value); sons_slice[u + 1].update_lazy_value(value); }
}
self.nodes[u].lazy_update(i, j);
}
pub fn update(&mut self, i: usize, j: usize, value: &<T as Node>::Value) {
self.update_helper(i, j, value, 0, 0, self.n - 1);
}
fn update_helper(
&mut self,
left: usize,
right: usize,
value: &<T as Node>::Value,
curr_node: usize,
i: usize,
j: usize,
) {
if self.nodes[curr_node].lazy_value().is_some() {
self.push(curr_node, i, j);
}
if j < left || right < i {
return;
}
if left <= i && j <= right {
self.nodes[curr_node].update_lazy_value(value);
self.push(curr_node, i, j);
return;
}
let mid = (i + j) / 2;
let left_node = 2 * curr_node + 1;
let right_node = 2 * curr_node + 2;
self.update_helper(left, right, value, left_node, i, mid);
self.update_helper(left, right, value, right_node, mid + 1, j);
self.nodes[curr_node] = Node::combine(&self.nodes[left_node], &self.nodes[right_node]);
}
pub fn query(&mut self, left: usize, right: usize) -> Option<T> {
self.query_helper(left, right, 0, 0, self.n - 1)
}
fn query_helper(
&mut self,
left: usize,
right: usize,
curr_node: usize,
i: usize,
j: usize,
) -> Option<T> {
if j < left || right < i {
return None;
}
let mid = (i + j) / 2;
let left_node = 2 * curr_node + 1;
let right_node = 2 * curr_node + 2;
if self.nodes[curr_node].lazy_value().is_some() {
self.push(curr_node, i, j);
}
if left <= i && j <= right {
return Some(self.nodes[curr_node].clone());
}
match (
self.query_helper(left, right, left_node, i, mid),
self.query_helper(left, right, right_node, mid + 1, j),
) {
(Some(ans_left), Some(ans_right)) => Some(Node::combine(&ans_left, &ans_right)),
(Some(ans_left), None) => Some(ans_left),
(None, Some(ans_right)) => Some(ans_right),
(None, None) => None,
}
}
pub fn lower_bound<F, G>(&self, predicate: F, g: G, value: <T as Node>::Value) -> usize
where
F: Fn(&<T as Node>::Value, &<T as Node>::Value) -> bool,
G: Fn(&<T as Node>::Value, <T as Node>::Value) -> <T as Node>::Value,
{
self.lower_bound_helper(0, 0, self.n - 1, predicate, g, value)
}
fn lower_bound_helper<F, G>(
&self,
curr_node: usize,
i: usize,
j: usize,
predicate: F,
g: G,
value: <T as Node>::Value,
) -> usize
where
F: Fn(&<T as Node>::Value, &<T as Node>::Value) -> bool,
G: Fn(&<T as Node>::Value, <T as Node>::Value) -> <T as Node>::Value,
{
if i == j {
return i;
}
let mid = (i + j) / 2;
let left_node = 2 * curr_node + 1;
let right_node = 2 * curr_node + 2;
let left_value = self.nodes[left_node].value();
if predicate(left_value, &value) {
self.lower_bound_helper(left_node, i, mid, predicate, g, value)
} else {
let value = g(left_value, value);
self.lower_bound_helper(right_node, mid + 1, j, predicate, g, value)
}
}
}
#[cfg(test)]
mod tests {
    use crate::{
        nodes::Node,
        utils::{LazySetWrapper, Min},
    };

    use super::LazyRecursive;

    /// Min-query tree with "assign" range updates, as exercised by the tests below.
    type LSMin<T> = LazySetWrapper<Min<T>>;

    #[test]
    fn build_works() {
        let n = 16;
        let nodes: Vec<LSMin<usize>> = (0..n).map(|x| LSMin::initialize(&x)).collect();
        let mut segment_tree = LazyRecursive::build(&nodes);
        for i in 0..n {
            let temp = segment_tree.query(i, i).unwrap();
            assert_eq!(temp.value(), &i);
        }
    }

    #[test]
    fn non_empty_query_returns_some() {
        let nodes: Vec<LSMin<usize>> = (0..10).map(|x| LSMin::initialize(&x)).collect();
        let mut segment_tree = LazyRecursive::build(&nodes);
        assert!(segment_tree.query(0, 9).is_some());
    }

    #[test]
    fn empty_query_returns_none() {
        let nodes: Vec<LSMin<usize>> = (0..10).map(|x| LSMin::initialize(&x)).collect();
        let mut segment_tree = LazyRecursive::build(&nodes);
        assert!(segment_tree.query(10, 0).is_none());
    }

    #[test]
    fn update_works() {
        let nodes: Vec<LSMin<usize>> = (0..10).map(|x| LSMin::initialize(&x)).collect();
        let mut segment_tree = LazyRecursive::build(&nodes);
        let value = 20;
        segment_tree.update(0, 9, &value);
        assert_eq!(segment_tree.query(0, 1).unwrap().value(), &value);
    }

    /// A partial update leaves lazy tags parked in interior nodes; queries that cut through
    /// those nodes must still observe the assigned values.
    #[test]
    fn partial_update_works() {
        let nodes: Vec<LSMin<usize>> = (0..10).map(|x| LSMin::initialize(&x)).collect();
        let mut segment_tree = LazyRecursive::build(&nodes);
        let value = 20;
        segment_tree.update(2, 5, &value);
        assert_eq!(segment_tree.query(0, 9).unwrap().value(), &0);
        assert_eq!(segment_tree.query(2, 5).unwrap().value(), &value);
        assert_eq!(segment_tree.query(4, 7).unwrap().value(), &6);
        for i in 2..=5 {
            assert_eq!(segment_tree.query(i, i).unwrap().value(), &value);
        }
        assert_eq!(segment_tree.query(6, 6).unwrap().value(), &6);
    }

    #[test]
    fn single_element_works() {
        let nodes: Vec<LSMin<usize>> = vec![LSMin::initialize(&7)];
        let mut segment_tree = LazyRecursive::build(&nodes);
        assert_eq!(segment_tree.query(0, 0).unwrap().value(), &7);
        let value = 3;
        segment_tree.update(0, 0, &value);
        assert_eq!(segment_tree.query(0, 0).unwrap().value(), &value);
    }

    #[test]
    fn query_works() {
        let nodes: Vec<LSMin<usize>> = (0..10).map(|x| LSMin::initialize(&x)).collect();
        let mut segment_tree = LazyRecursive::build(&nodes);
        assert_eq!(segment_tree.query(1, 9).unwrap().value(), &1);
    }
}