Skip to main content

oxihuman_core/
segment_tree.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3#![allow(dead_code)]
4
5//! Range-sum segment tree over f32 values.
6
/// A segment tree over `f32` values that supports point updates and
/// half-open range-sum queries.
///
/// Storage is a flat array of `2 * n` slots: the input values live in
/// `data[n..2 * n]`, and every internal node `i` in `1..n` holds
/// `data[2 * i] + data[2 * i + 1]`. Slot `0` is never read.
///
/// The `Default` value is an empty tree (`len() == 0`).
#[allow(dead_code)]
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SegmentTree {
    /// Number of leaves, i.e. the length of the slice the tree was built from.
    n: usize,
    /// Flat node storage of length `2 * n` (leaves in the upper half).
    data: Vec<f32>,
}

#[allow(dead_code)]
impl SegmentTree {
    /// Build a tree from a slice of `f32` values.
    ///
    /// An empty slice yields an empty tree whose [`total`](Self::total) is `0.0`.
    pub fn build(values: &[f32]) -> Self {
        let n = values.len();
        let mut data = vec![0.0_f32; 2 * n];
        // Leaves occupy the upper half of the array.
        data[n..].copy_from_slice(values);
        // Fill parents from the bottom up so both children are final first.
        for i in (1..n).rev() {
            data[i] = data[2 * i] + data[2 * i + 1];
        }
        Self { n, data }
    }

    /// Set the element at `pos` to `value` and refresh every ancestor sum.
    ///
    /// # Panics
    ///
    /// Panics if `pos >= self.len()`.
    pub fn update(&mut self, pos: usize, value: f32) {
        // Checked up front so that `pos + n` can never wrap around to a
        // valid internal index in builds without overflow checks.
        assert!(
            pos < self.n,
            "SegmentTree::update: position {} out of range for length {}",
            pos,
            self.n
        );
        let mut node = pos + self.n;
        self.data[node] = value;
        // Walk towards the root (index 1), recomputing each parent.
        while node > 1 {
            node >>= 1;
            self.data[node] = self.data[2 * node] + self.data[2 * node + 1];
        }
    }

    /// Sum over the half-open range `[l, r)`.
    ///
    /// `r` is clamped to `self.len()`, so a right bound past the end sums up
    /// to the last element. An empty or inverted range (`l >= r` after
    /// clamping) returns `0.0`.
    pub fn query(&self, l: usize, r: usize) -> f32 {
        let r = r.min(self.n);
        if l >= r {
            return 0.0;
        }
        let mut sum = 0.0_f32;
        // Both bounds are now <= n, so shifting into leaf space cannot overflow.
        let mut lo = l + self.n;
        let mut hi = r + self.n;
        while lo < hi {
            // An odd `lo` is a right child: take it and step past it.
            if lo & 1 != 0 {
                sum += self.data[lo];
                lo += 1;
            }
            // An odd `hi` means `hi - 1` is a left child inside the range.
            if hi & 1 != 0 {
                hi -= 1;
                sum += self.data[hi];
            }
            lo >>= 1;
            hi >>= 1;
        }
        sum
    }

    /// Return the single element stored at `pos`.
    ///
    /// # Panics
    ///
    /// Panics if `pos >= self.len()`.
    pub fn get(&self, pos: usize) -> f32 {
        assert!(
            pos < self.n,
            "SegmentTree::get: position {} out of range for length {}",
            pos,
            self.n
        );
        self.data[self.n + pos]
    }

    /// Total sum of all elements (`0.0` for an empty tree).
    pub fn total(&self) -> f32 {
        // The root lives at index 1; an empty tree has no such slot.
        self.data.get(1).copied().unwrap_or(0.0)
    }

    /// Number of elements the tree was built from.
    pub fn len(&self) -> usize {
        self.n
    }

    /// Returns `true` when the tree holds no elements.
    pub fn is_empty(&self) -> bool {
        self.n == 0
    }
}
83
/// Free-function form of `SegmentTree::build`: constructs a tree from `values`.
pub fn build_segment_tree(values: &[f32]) -> SegmentTree {
    SegmentTree::build(values)
}
87
88pub fn seg_query(tree: &SegmentTree, l: usize, r: usize) -> f32 {
89    tree.query(l, r)
90}
91
92pub fn seg_update(tree: &mut SegmentTree, pos: usize, val: f32) {
93    tree.update(pos, val);
94}
95
96pub fn seg_total(tree: &SegmentTree) -> f32 {
97    tree.total()
98}
99
100pub fn seg_get(tree: &SegmentTree, pos: usize) -> f32 {
101    tree.get(pos)
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance shared by every floating-point comparison in this module.
    const EPS: f32 = 1e-6;

    /// True when `actual` lies strictly within `EPS` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn total_sum() {
        let tree = build_segment_tree(&[1.0, 2.0, 3.0, 4.0]);
        assert!(close(seg_total(&tree), 10.0));
    }

    #[test]
    fn range_query() {
        let tree = build_segment_tree(&[1.0, 2.0, 3.0, 4.0]);
        assert!(close(seg_query(&tree, 1, 3), 5.0));
    }

    #[test]
    fn point_update() {
        let mut tree = build_segment_tree(&[1.0, 2.0, 3.0]);
        seg_update(&mut tree, 1, 10.0);
        assert!(close(seg_get(&tree, 1), 10.0));
        assert!(close(seg_total(&tree), 14.0));
    }

    #[test]
    fn empty_tree() {
        let tree = build_segment_tree(&[]);
        assert!(tree.is_empty());
        assert!(close(seg_total(&tree), 0.0));
    }

    #[test]
    fn single_element() {
        let tree = build_segment_tree(&[42.0]);
        assert!(close(seg_total(&tree), 42.0));
        assert!(close(seg_query(&tree, 0, 1), 42.0));
    }

    #[test]
    fn full_range_equals_total() {
        let tree = build_segment_tree(&[1.0, 5.0, 2.0, 8.0]);
        let whole = seg_query(&tree, 0, 4);
        assert!(close(whole, seg_total(&tree)));
    }

    #[test]
    fn update_first_element() {
        let mut tree = build_segment_tree(&[0.0, 3.0, 5.0]);
        seg_update(&mut tree, 0, 7.0);
        assert!(close(seg_total(&tree), 15.0));
    }

    #[test]
    fn len_matches_input() {
        let tree = build_segment_tree(&[1.0, 2.0, 3.0]);
        assert_eq!(tree.len(), 3);
    }

    #[test]
    fn get_returns_leaf() {
        let tree = build_segment_tree(&[9.0, 4.0, 7.0]);
        assert!(close(seg_get(&tree, 2), 7.0));
    }

    #[test]
    fn repeated_updates_consistent() {
        let mut tree = build_segment_tree(&[1.0, 1.0, 1.0, 1.0]);
        for pos in [0, 3] {
            seg_update(&mut tree, pos, 10.0);
        }
        assert!(close(seg_total(&tree), 22.0));
    }
}