// dsi_bitstream/utils/find_change.rs

/*
 * SPDX-FileCopyrightText: 2025 Tommaso Fontana
 *
 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
 */

/// Iterates over the points where a monotonically non-decreasing function
/// changes value.
///
/// Each call to [`next`](Iterator::next) returns a pair `(x, f(x))`, where `x`
/// is the smallest input at which the function takes the value `f(x)`. The
/// first item is always `(0, f(0))`; the following ones are the inputs at which
/// the value changes, in increasing order.
///
/// This is useful to generate data following the implied distribution of a
/// code, by passing e.g. a code-length function.
///
/// Change points are located with an exponential search followed by a binary
/// search, so each item costs a number of evaluations logarithmic in the
/// distance from the previous change point.
///
/// The input `u64::MAX` is never evaluated (the codes this is meant for cannot
/// encode it), so a change happening exactly at `u64::MAX` is not reported.
/// Monotonicity is checked only in debug builds, via `debug_assert!`.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct FindChangePoints<F: Fn(u64) -> usize> {
    func: F,
    // Smallest input producing `prev_value` (the last reported change point).
    current: u64,
    // Value of `func` at `current`; `None` until the first item is produced.
    prev_value: Option<usize>,
    // Set once the iterator has returned `None`; it then keeps returning `None`.
    done: bool,
}

impl<F: Fn(u64) -> usize> FindChangePoints<F> {
    /// Creates an iterator over the change points of `func`, which must be
    /// monotonically non-decreasing.
    pub fn new(func: F) -> Self {
        Self {
            func,
            current: 0,
            prev_value: None,
            done: false,
        }
    }

    /// Evaluates the function at `x`, checking in debug builds that the result
    /// is not smaller than `prev`, the value at the last change point.
    fn eval(&self, x: u64, prev: usize) -> usize {
        let value = (self.func)(x);
        debug_assert!(
            value >= prev,
            "Function is not monotonic as f({}) = {} < {} = f({})",
            x,
            value,
            prev,
            self.current,
        );
        value
    }
}

impl<F: Fn(u64) -> usize> Iterator for FindChangePoints<F> {
    /// (first input, output)
    type Item = (u64, usize);

    fn next(&mut self) -> Option<Self::Item> {
        // Largest input ever evaluated: none of our codes can encode
        // `u64::MAX`, so it is skipped. Clamping every probe to this bound also
        // guarantees that neither the probe nor the step can overflow.
        const LAST: u64 = u64::MAX - 1;

        if self.done {
            return None;
        }

        // The first change point is always 0, so there is nothing to search.
        let prev = match self.prev_value {
            Some(prev) => prev,
            None => {
                let value = (self.func)(0);
                self.prev_value = Some(value);
                return Some((0, value));
            }
        };

        // Exponential search: probe current + 1, + 2, + 4, ... (clamped to
        // LAST) until the value changes. `lo` is the largest probed input still
        // mapping to `prev`; `hi` is the first probe with a different value.
        let mut lo = self.current;
        let mut step: u64 = 1;
        let (hi, hi_value) = loop {
            if lo >= LAST {
                // f(LAST) == f(current), so (given monotonicity) there are no
                // further change points among the inputs we evaluate.
                self.done = true;
                return None;
            }
            let probe = self.current.saturating_add(step).min(LAST);
            let value = self.eval(probe, prev);
            if value != prev {
                break (probe, value);
            }
            lo = probe;
            step = step.saturating_mul(2);
        };

        // Binary search in (lo, hi]: f(lo) == prev and f(hi) != prev, so the
        // change point is the smallest input in that range whose value differs
        // from `prev`. Invariant: `right_value == f(right)`.
        let mut left = lo + 1;
        let mut right = hi;
        let mut right_value = hi_value;
        while left < right {
            let mid = left + (right - left) / 2;
            let mid_value = self.eval(mid, prev);
            if mid_value == prev {
                left = mid + 1;
            } else {
                right = mid;
                right_value = mid_value;
            }
        }

        self.current = right;
        self.prev_value = Some(right_value);
        Some((right, right_value))
    }
}

// Once `done` is set, `next` always returns `None`.
impl<F: Fn(u64) -> usize> core::iter::FusedIterator for FindChangePoints<F> {}

#[cfg(test)]
mod test {
    use super::FindChangePoints;

    #[test]
    fn test_find_change_points() {
        test_func(crate::codes::len_gamma);
        test_func(crate::codes::len_delta);
        test_func(crate::codes::len_omega);
        test_func(|x| crate::codes::len_zeta(x, 3));
        test_func(|x| crate::codes::len_pi(x, 3));
    }

    /// Checks the change points reported for `func`:
    /// - every reported `(first, len)` satisfies `func(first) == len`, and
    ///   `first` is the first input with that value;
    /// - reported inputs and values are strictly increasing;
    /// - on the prefix `0..PREFIX`, the reported points are exactly those found
    ///   by a linear scan, i.e. no change point is skipped.
    fn test_func(func: impl Fn(u64) -> usize) {
        // Small enough for an exhaustive scan, large enough to cover many changes.
        const PREFIX: u64 = 1 << 12;

        let points: Vec<(u64, usize)> = FindChangePoints::new(&func).collect();

        for &(first, len) in &points {
            // first check that the len is actually correct
            assert_eq!(func(first), len);
            // then check that it's the first one with that len
            if first > 0 {
                assert_ne!(func(first - 1), len);
            }
        }

        for pair in points.windows(2) {
            assert!(pair[0].0 < pair[1].0, "inputs not increasing: {:?}", pair);
            assert!(pair[0].1 < pair[1].1, "values not increasing: {:?}", pair);
        }

        // Completeness on a prefix: compare against a brute-force scan.
        let expected: Vec<(u64, usize)> = (0..PREFIX)
            .filter(|&x| x == 0 || func(x - 1) != func(x))
            .map(|x| (x, func(x)))
            .collect();
        let actual: Vec<(u64, usize)> = points
            .iter()
            .copied()
            .take_while(|&(x, _)| x < PREFIX)
            .collect();
        assert_eq!(actual, expected);
    }
}