Skip to main content

oxihuman_core/
segment_tree_v2.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3#![allow(dead_code)]
4
5//! Segment tree for range min/max/sum queries over i64 values.
6
/// The operation a [`SegTreeV2`] aggregates its range with.
///
/// All three variants are commutative, which the single-accumulator query
/// loop in [`SegTreeV2::query`] relies on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
pub enum SegOp {
    /// Range sum; an empty range yields `0`.
    Sum,
    /// Range minimum; an empty range yields `i64::MAX`.
    Min,
    /// Range maximum; an empty range yields `i64::MIN`.
    Max,
}

/// A segment tree supporting point updates and range queries.
///
/// Stored as an iterative, bottom-up tree in a flat `Vec` of length `2 * n`:
/// leaves live at `tree[n..2n]`, and each internal node `i` (`1 <= i < n`)
/// holds `combine(tree[2i], tree[2i + 1])`. `tree[0]` is never read.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct SegTreeV2 {
    /// Number of leaves (input elements).
    n: usize,
    /// Flat node storage, length `2 * n`.
    tree: Vec<i64>,
    /// Aggregation operation used for every internal node and query.
    op: SegOp,
}

impl SegTreeV2 {
    /// Neutral element of the configured operation.
    ///
    /// It is the result for an empty range and the starting accumulator of a
    /// query.
    fn identity(&self) -> i64 {
        match self.op {
            SegOp::Sum => 0,
            SegOp::Min => i64::MAX,
            SegOp::Max => i64::MIN,
        }
    }

    /// Combines two partial results with the configured operation.
    ///
    /// `Sum` uses plain `+`, so overflow follows Rust's default `i64`
    /// arithmetic: it panics in debug builds and wraps in release builds.
    fn combine(&self, a: i64, b: i64) -> i64 {
        match self.op {
            SegOp::Sum => a + b,
            SegOp::Min => a.min(b),
            SegOp::Max => a.max(b),
        }
    }

    /// Build a tree over `data` that aggregates with `op`.
    #[allow(dead_code)]
    pub fn build(data: &[i64], op: SegOp) -> Self {
        let n = data.len();
        let mut tree = vec![0i64; 2 * n];
        // Leaves occupy the second half of the storage.
        tree[n..].copy_from_slice(data);
        let mut st = Self { n, tree, op };
        // Fill internal nodes from the bottom up so both children of node `i`
        // are final before `i` is computed.
        for i in (1..n).rev() {
            st.tree[i] = st.combine(st.tree[2 * i], st.tree[2 * i + 1]);
        }
        st
    }

    /// Point update: set element `i` to `val` and recompute its ancestors.
    ///
    /// Indices `i >= len()` are silently ignored (no change, no panic).
    #[allow(dead_code)]
    pub fn update(&mut self, i: usize, val: i64) {
        if i >= self.n {
            return;
        }
        let mut pos = i + self.n;
        self.tree[pos] = val;
        // Walk up to the root (index 1); shifting the root right yields 0,
        // which ends the loop.
        pos >>= 1;
        while pos > 0 {
            self.tree[pos] = self.combine(self.tree[2 * pos], self.tree[2 * pos + 1]);
            pos >>= 1;
        }
    }

    /// Range query over the half-open interval `[l, r)`.
    ///
    /// `r` is clamped to `len()`, so a range that extends past the end covers
    /// only the existing elements instead of indexing out of bounds. An empty
    /// (or entirely out-of-range) interval returns the operation's identity:
    /// `0` for `Sum`, `i64::MAX` for `Min`, `i64::MIN` for `Max`.
    #[allow(dead_code)]
    pub fn query(&self, l: usize, r: usize) -> i64 {
        let r = r.min(self.n);
        let mut res = self.identity();
        if l >= r {
            return res;
        }
        let mut lo = l + self.n;
        let mut hi = r + self.n;
        while lo < hi {
            // An odd `lo` is a right child whose parent extends left of the
            // range: take it alone and step right.
            if lo & 1 == 1 {
                res = self.combine(res, self.tree[lo]);
                lo += 1;
            }
            // An odd `hi` means `hi - 1` is a left child lying inside the
            // range whose parent would overshoot it: take it alone.
            if hi & 1 == 1 {
                hi -= 1;
                res = self.combine(res, self.tree[hi]);
            }
            lo >>= 1;
            hi >>= 1;
        }
        res
    }

    /// Number of elements.
    #[allow(dead_code)]
    pub fn len(&self) -> usize {
        self.n
    }

    /// Returns true if the tree holds no elements.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.n == 0
    }
}
111
112/// Build a sum segment tree.
113#[allow(dead_code)]
114pub fn seg2_sum(data: &[i64]) -> SegTreeV2 {
115    SegTreeV2::build(data, SegOp::Sum)
116}
117
118/// Build a min segment tree.
119#[allow(dead_code)]
120pub fn seg2_min(data: &[i64]) -> SegTreeV2 {
121    SegTreeV2::build(data, SegOp::Min)
122}
123
124/// Build a max segment tree.
125#[allow(dead_code)]
126pub fn seg2_max(data: &[i64]) -> SegTreeV2 {
127    SegTreeV2::build(data, SegOp::Max)
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sum_query_full_range() {
        let tree = seg2_sum(&[1, 2, 3, 4, 5]);
        let total = tree.query(0, 5);
        assert_eq!(total, 15);
    }

    #[test]
    fn sum_query_partial() {
        let tree = seg2_sum(&[1, 2, 3, 4, 5]);
        // Indices 1..4 hold 2, 3 and 4.
        let partial = tree.query(1, 4);
        assert_eq!(partial, 9);
    }

    #[test]
    fn min_query() {
        let tree = seg2_min(&[5, 3, 8, 1, 7]);
        let smallest = tree.query(0, 5);
        assert_eq!(smallest, 1);
    }

    #[test]
    fn max_query() {
        let tree = seg2_max(&[5, 3, 8, 1, 7]);
        let largest = tree.query(0, 5);
        assert_eq!(largest, 8);
    }

    #[test]
    fn update_and_requery() {
        let mut tree = seg2_sum(&[1, 2, 3]);
        tree.update(1, 10);
        // After replacing the middle element: 1 + 10 + 3.
        let total = tree.query(0, 3);
        assert_eq!(total, 14);
    }

    #[test]
    fn len_correct() {
        let tree = seg2_sum(&[1, 2, 3, 4]);
        let count = tree.len();
        assert_eq!(count, 4);
    }

    #[test]
    fn empty_query_returns_identity() {
        let tree = seg2_sum(&[1, 2, 3]);
        // [2, 2) is empty, so the sum identity comes back.
        let nothing = tree.query(2, 2);
        assert_eq!(nothing, 0);
    }

    #[test]
    fn single_element() {
        let tree = seg2_min(&[42]);
        let only = tree.query(0, 1);
        assert_eq!(only, 42);
    }

    #[test]
    fn is_empty_false() {
        let tree = seg2_sum(&[1]);
        let empty = tree.is_empty();
        assert!(!empty);
    }

    #[test]
    fn max_update_changes_result() {
        let mut tree = seg2_max(&[1, 2, 3]);
        tree.update(0, 100);
        let largest = tree.query(0, 3);
        assert_eq!(largest, 100);
    }
}