// ciphercore_base/inline/data_structures.rs

1use crate::errors::Result;
2
3pub(super) trait CombineOp<T> {
4    fn combine(&mut self, arg1: T, arg2: T) -> Result<T>;
5}
6
7// Simple sum computation in the form of binary tree. `combine_op` must be associative.
8// Depth: RoundUp(log(len))
9// Complexity: len - 1
10pub(super) fn log_depth_sum<T: std::clone::Clone>(
11    items: &[T],
12    combine_op: &mut dyn CombineOp<T>,
13) -> Result<T> {
14    if items.is_empty() {
15        return Err(runtime_error!("Cannot combine empty vector"));
16    }
17    let mut combined_items = items.to_owned();
18    while combined_items.len() > 1 {
19        let mut new_combined_items = vec![];
20        for i in (0..combined_items.len()).step_by(2) {
21            let j = i + 1;
22            if j >= combined_items.len() {
23                new_combined_items.push(combined_items[i].clone());
24            } else {
25                let new_item =
26                    combine_op.combine(combined_items[i].clone(), combined_items[j].clone())?;
27                new_combined_items.push(new_item);
28            }
29        }
30        combined_items = new_combined_items;
31    }
32    Ok(combined_items[0].clone())
33}
34
35/// Computes prefix sums for the given vector. `combine_op` must be associative.
36/// Returned vector contains sums of `items[0]..items[i]` at position i.
37/// Depth: RoundUp(log(len))
38/// Complexity: len * RoundUp(log(len))
39#[allow(dead_code)]
40pub(super) fn prefix_sums_binary_ascent<T: std::clone::Clone>(
41    items: &[T],
42    combine_op: &mut dyn CombineOp<T>,
43) -> Result<Vec<T>> {
44    if items.is_empty() {
45        return Ok(vec![]);
46    }
47    let mut combined_items = items.to_owned();
48    let mut depth = 1;
49    // Invariant: combined_items[i] = sum(items[max(i - depth + 1, 0) : i + 1])
50    while depth < combined_items.len() {
51        for i in (depth..combined_items.len()).rev() {
52            combined_items[i] =
53                combine_op.combine(combined_items[i - depth].clone(), combined_items[i].clone())?;
54        }
55        depth *= 2;
56    }
57    Ok(combined_items)
58}
59
60/// Computes prefix sums for the given vector. `combine_op` must be associative.
61/// Returned vector contains sums of `items[0]..items[i]` at position i.
62/// Depth: 2 * RoundUp(sqrt(len))
63/// Complexity: 2 * len
64#[allow(dead_code)]
65pub(super) fn prefix_sums_sqrt_trick<T: std::clone::Clone>(
66    items: &[T],
67    combine_op: &mut dyn CombineOp<T>,
68) -> Result<Vec<T>> {
69    if items.is_empty() {
70        return Ok(vec![]);
71    }
72    let block_size = std::cmp::max(1, (items.len() as f64).sqrt() as usize);
73    let mut combined_items = items.to_owned();
74    // Invariant: combined_items[i] = sum(items[i - i % block_size : i + 1])
75    for i in 0..combined_items.len() {
76        if i % block_size != 0 {
77            combined_items[i] =
78                combine_op.combine(combined_items[i - 1].clone(), combined_items[i].clone())?;
79        }
80    }
81    // Now, compute the actual sums.
82    for i in block_size..combined_items.len() {
83        combined_items[i] = combine_op.combine(
84            combined_items[i - i % block_size - 1].clone(),
85            combined_items[i].clone(),
86        )?;
87    }
88    Ok(combined_items)
89}
90
91/// Computes prefix sums for the given vector. `combine_op` must be associative.
92/// Returned vector contains sums of `items[0]..items[i]` at position i.
93/// Depth: RoundUp(log(len)) * 2
94/// Complexity: len * 2
95#[allow(dead_code)]
96pub(super) fn prefix_sums_segment_tree<T: std::clone::Clone>(
97    items: &[T],
98    combine_op: &mut dyn CombineOp<T>,
99) -> Result<Vec<T>> {
100    if items.is_empty() {
101        return Ok(vec![]);
102    }
103    // We construct segment tree layer by layer.
104    let mut layers = vec![items.to_owned()];
105    let mut layer = 0;
106    while layers[layer].len() > 1 {
107        let mut next_layer = vec![];
108        for i in (0..layers[layer].len()).step_by(2) {
109            if i + 1 < layers[layer].len() {
110                next_layer.push(
111                    combine_op.combine(layers[layer][i].clone(), layers[layer][i + 1].clone())?,
112                );
113            } else {
114                next_layer.push(layers[layer][i].clone());
115            }
116        }
117        layer += 1;
118        layers.push(next_layer);
119    }
120    // Now, we go from top to bottom and compute prefix sums in each layer.
121    for i in (0..layers.len() - 1).rev() {
122        for j in 1..layers[i].len() {
123            if j % 2 == 1 {
124                layers[i][j] = layers[i + 1][j / 2].clone();
125            } else {
126                layers[i][j] =
127                    combine_op.combine(layers[i + 1][(j - 1) / 2].clone(), layers[i][j].clone())?;
128            }
129        }
130    }
131    Ok(layers[0].clone())
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain integer addition.
    struct IntCombiner;

    impl CombineOp<u64> for IntCombiner {
        fn combine(&mut self, arg1: u64, arg2: u64) -> Result<u64> {
            Ok(arg1 + arg2)
        }
    }

    /// String concatenation: associative but not commutative, so it catches operand-order bugs.
    struct ConcatCombiner;

    impl CombineOp<String> for ConcatCombiner {
        fn combine(&mut self, arg1: String, arg2: String) -> Result<String> {
            Ok(arg1 + &arg2)
        }
    }

    /// Returns `[1, 4, 9, ..., len^2]`.
    fn squares(len: u64) -> Vec<u64> {
        (1..=len).map(|i| i * i).collect()
    }

    #[test]
    fn test_log_depth_sum() {
        for len in 1..20 {
            let v = squares(len);
            let expected: u64 = v.iter().sum();
            let actual = log_depth_sum(&v, &mut IntCombiner).unwrap();
            assert_eq!(expected, actual);
        }
    }

    #[test]
    fn test_log_depth_sum_empty_is_error() {
        assert!(log_depth_sum::<u64>(&[], &mut IntCombiner).is_err());
    }

    #[test]
    fn test_prefix_sums() {
        for len in 0..20 {
            let v = squares(len);
            let expected: Vec<u64> = v
                .iter()
                .scan(0, |acc, &x| {
                    *acc += x;
                    Some(*acc)
                })
                .collect();
            let mut combiner = IntCombiner;
            let actual_bin_ascent = prefix_sums_binary_ascent(&v, &mut combiner).unwrap();
            assert_eq!(expected, actual_bin_ascent);
            let actual_sqrt = prefix_sums_sqrt_trick(&v, &mut combiner).unwrap();
            assert_eq!(expected, actual_sqrt);
            let actual_segment_tree = prefix_sums_segment_tree(&v, &mut combiner).unwrap();
            assert_eq!(expected, actual_segment_tree);
        }
    }

    #[test]
    fn test_operand_order_is_preserved() {
        for len in 0u8..20 {
            let v: Vec<String> = (0..len).map(|i| char::from(b'a' + i).to_string()).collect();
            let expected: Vec<String> = (1..=v.len()).map(|k| v[..k].concat()).collect();
            let mut combiner = ConcatCombiner;
            if let Some(total) = expected.last() {
                assert_eq!(total, &log_depth_sum(&v, &mut combiner).unwrap());
            }
            assert_eq!(expected, prefix_sums_binary_ascent(&v, &mut combiner).unwrap());
            assert_eq!(expected, prefix_sums_sqrt_trick(&v, &mut combiner).unwrap());
            assert_eq!(expected, prefix_sums_segment_tree(&v, &mut combiner).unwrap());
        }
    }

    /// Ignores the actual values: counts `combine` calls, and each returned value is the depth
    /// of the combination tree that produced it (inputs are expected to start at depth 0).
    struct TrackingCombiner {
        total_calls: u64,
    }

    impl CombineOp<u64> for TrackingCombiner {
        fn combine(&mut self, arg1: u64, arg2: u64) -> Result<u64> {
            self.total_calls += 1;
            Ok(std::cmp::max(arg1, arg2) + 1) // Return new depth.
        }
    }

    #[test]
    fn test_complexity() {
        for len in 200..300 {
            let v = vec![0; len];
            let n = len as u64;
            let log_n = (len as f64).log2().ceil() as u64;
            let sqrt_n = (len as f64).sqrt().ceil() as u64;
            {
                let mut combiner = TrackingCombiner { total_calls: 0 };
                let depth = log_depth_sum(&v, &mut combiner).unwrap();
                assert!(depth <= log_n);
                assert!(combiner.total_calls <= n - 1);
            }
            {
                let mut combiner = TrackingCombiner { total_calls: 0 };
                let depths = prefix_sums_binary_ascent(&v, &mut combiner).unwrap();
                assert!(*depths.iter().max().unwrap() <= log_n);
                assert!(combiner.total_calls <= n * log_n);
            }
            {
                let mut combiner = TrackingCombiner { total_calls: 0 };
                let depths = prefix_sums_sqrt_trick(&v, &mut combiner).unwrap();
                assert!(*depths.iter().max().unwrap() <= 2 * sqrt_n + 1);
                assert!(combiner.total_calls <= n * 2);
            }
            {
                let mut combiner = TrackingCombiner { total_calls: 0 };
                let depths = prefix_sums_segment_tree(&v, &mut combiner).unwrap();
                assert!(*depths.iter().max().unwrap() <= 2 * log_n + 1);
                assert!(combiner.total_calls <= n * 2);
            }
        }
    }
}