Skip to main content

dsi_bitstream/utils/
find_change.rs

/*
 * SPDX-FileCopyrightText: 2025 Tommaso Fontana
 *
 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
 */

/// Iterates over the points where the given function changes value.
///
/// The function must be monotonic non-decreasing. This is only checked with
/// `debug_assert!`, so in release builds a non-monotonic function yields
/// unspecified results.
///
/// Each call to [`Iterator::next`] returns a tuple `(first input, value)`:
/// the first input at which the function takes a new value, and that value.
/// The first item is always `(0, func(0))`.
///
/// Inputs are searched in `0..u64::MAX`. The input `u64::MAX` itself is never
/// evaluated, because most codes cannot encode it.
///
/// Each change point is located with an exponential search followed by a
/// binary search. This costs `O(log d)` evaluations of the function, where `d`
/// is the distance from the previous change point.
///
/// This is useful to generate data following the implied distribution of a code.
pub struct FindChangePoints<F: Fn(u64) -> usize> {
    func: F,
    /// First input of the most recently returned run. Once the iterator is
    /// exhausted, this is the last input that was searched.
    current: u64,
    /// `func(current)`, or `None` if `next` has not been called yet.
    prev_value: Option<usize>,
}

impl<F: Fn(u64) -> usize> FindChangePoints<F> {
    /// The largest input ever passed to the function. `u64::MAX` is skipped
    /// because most codes cannot encode it.
    const LAST_INPUT: u64 = u64::MAX - 1;

    /// Creates an iterator over the change points of `func`, which must be
    /// monotonic non-decreasing.
    #[must_use]
    pub fn new(func: F) -> Self {
        Self {
            func,
            current: 0,
            prev_value: None,
        }
    }
}

impl<F: Fn(u64) -> usize> Iterator for FindChangePoints<F> {
    /// (first input, output)
    type Item = (u64, usize);

    fn next(&mut self) -> Option<Self::Item> {
        // The first run always starts at 0, so no search is needed.
        let Some(prev_value) = self.prev_value else {
            let value = (self.func)(0);
            self.prev_value = Some(value);
            return Some((0, value));
        };

        // Exponential search, starting from the last change point, for an
        // input whose value differs from `prev_value`. When the step would
        // overshoot, the probe is clamped to `LAST_INPUT` rather than giving
        // up. Otherwise change points between the last in-range probe and the
        // end of the domain would be missed.
        // Invariant: func(lo) == prev_value.
        let mut lo = self.current;
        let mut step: u64 = 1;
        let (mut hi, mut hi_value) = loop {
            if lo >= Self::LAST_INPUT {
                // The function is constant up to the end of the domain.
                // Remember this so that later calls return immediately.
                self.current = lo;
                return None;
            }
            let probe = self.current.saturating_add(step).min(Self::LAST_INPUT);
            let value = (self.func)(probe);
            debug_assert!(
                value >= prev_value,
                "Function is not monotonic as f({}) = {} < {} = f({})",
                probe,
                value,
                prev_value,
                self.current,
            );
            if value != prev_value {
                break (probe, value);
            }
            lo = probe;
            step = step.saturating_mul(2);
        };

        // Binary search in (lo, hi] for the first input whose value differs
        // from `prev_value`.
        // Invariant: func(lo) == prev_value != func(hi) == hi_value.
        while hi - lo > 1 {
            let mid = lo + (hi - lo) / 2;
            let mid_value = (self.func)(mid);
            debug_assert!(
                mid_value >= prev_value,
                "Function is not monotonic as f({}) = {} < {} = f({})",
                mid,
                mid_value,
                prev_value,
                self.current,
            );
            if mid_value == prev_value {
                lo = mid;
            } else {
                hi = mid;
                hi_value = mid_value;
            }
        }

        self.current = hi;
        self.prev_value = Some(hi_value);
        Some((hi, hi_value))
    }
}
#[cfg(test)]
mod tests {
    use super::FindChangePoints;

    #[test]
    fn test_find_change_points() {
        test_func(crate::codes::len_gamma);
        test_func(crate::codes::len_delta);
        test_func(crate::codes::len_omega);
        test_func(|x| crate::codes::len_zeta(x, 3));
        test_func(|x| crate::codes::len_pi(x, 3));
    }

    /// Checks that the points reported by `FindChangePoints` for `func` are
    /// exactly its change points, with none skipped:
    /// - the first reported point is `(0, func(0))`;
    /// - every reported value equals `func` at the reported input;
    /// - inputs and values strictly increase;
    /// - just before each reported input, `func` still has the previous value.
    ///   For a monotonic function, this implies `func` is constant between
    ///   consecutive reported points.
    fn test_func(func: impl Fn(u64) -> usize) {
        let mut prev: Option<(u64, usize)> = None;
        for (first, len) in FindChangePoints::new(&func) {
            // the reported value is the actual value of the function
            assert_eq!(func(first), len);
            match prev {
                None => assert_eq!(first, 0, "the first change point must be 0"),
                Some((prev_first, prev_len)) => {
                    assert!(first > prev_first);
                    assert!(len > prev_len);
                    // `first` is the first input with the new value, and no
                    // change point was skipped since `prev_first`
                    assert_eq!(func(first - 1), prev_len);
                }
            }
            prev = Some((first, len));
        }
        assert!(prev.is_some(), "the iterator must yield at least f(0)");
    }
}