/// Compares two values that are expected to be comparable, returning a total
/// [`Ordering`](std::cmp::Ordering).
///
/// This adapts a `PartialOrd` type (such as `f64`) for APIs that need a plain
/// `Ordering`, e.g. `sort_by`, `max_by` or `binary_search_by`.
///
/// # Panics
///
/// Panics with "Comparing incomparable elements" when `a.partial_cmp(b)` is
/// `None` (for floats, whenever a `NaN` is involved).
pub fn asserting_cmp<T: PartialOrd>(a: &T, b: &T) -> std::cmp::Ordering {
    match a.partial_cmp(b) {
        Some(ordering) => ordering,
        None => panic!("Comparing incomparable elements"),
    }
}
16
17pub fn slice_lower_bound<T: PartialOrd>(slice: &[T], key: &T) -> usize {
20 slice
21 .binary_search_by(|x| asserting_cmp(x, key).then(std::cmp::Ordering::Greater))
22 .unwrap_err()
23}
24
25pub fn slice_upper_bound<T: PartialOrd>(slice: &[T], key: &T) -> usize {
28 slice
29 .binary_search_by(|x| asserting_cmp(x, key).then(std::cmp::Ordering::Less))
30 .unwrap_err()
31}
32
/// Merges two sorted sequences into a single sorted `Vec`.
///
/// The merge is stable: whenever the front of `i1` is `<=` the front of `i2`,
/// the element from `i1` is taken first. When either input runs out, whatever
/// remains of the other is appended unchanged. If the fronts are incomparable
/// (`a <= b` is false), the element from `i2` is taken.
pub fn merge_sorted<T: PartialOrd>(
    i1: impl IntoIterator<Item = T>,
    i2: impl IntoIterator<Item = T>,
) -> Vec<T> {
    let mut left = i1.into_iter().peekable();
    let mut right = i2.into_iter().peekable();
    let (left_min, _) = left.size_hint();
    let (right_min, _) = right.size_hint();
    let mut merged = Vec::with_capacity(left_min + right_min);

    loop {
        let take_left = match (left.peek(), right.peek()) {
            (Some(a), Some(b)) => a <= b,
            _ => break,
        };
        // Both sides were just peeked as non-empty, so `next()` yields `Some`.
        let front = if take_left { left.next() } else { right.next() };
        merged.extend(front);
    }

    // At most one side still has elements; drain both in order.
    merged.extend(left);
    merged.extend(right);
    merged
}
47
48pub fn merge_sort<T: Ord>(mut v: Vec<T>) -> Vec<T> {
50 if v.len() < 2 {
51 v
52 } else {
53 let v2 = v.split_off(v.len() / 2);
54 merge_sorted(merge_sort(v), merge_sort(v2))
55 }
56}
57
/// Coordinate compression for `i64` coordinates: maps each distinct coordinate
/// to its rank among all distinct coordinates.
pub struct SparseIndex {
    /// Distinct coordinates in strictly increasing order.
    coords: Vec<i64>,
}

impl SparseIndex {
    /// Builds an index over `coords`. Input order is irrelevant and duplicates
    /// are collapsed.
    pub fn new(coords: Vec<i64>) -> Self {
        let mut distinct = coords;
        distinct.sort_unstable();
        distinct.dedup();
        SparseIndex { coords: distinct }
    }

    /// Looks up `q` among the indexed coordinates.
    ///
    /// Returns `Ok(i)` when `q` is the `i`-th smallest distinct coordinate, and
    /// `Err(i)` otherwise, where `i` is the number of coordinates smaller than `q`.
    pub fn compress(&self, q: i64) -> Result<usize, usize> {
        let sorted: &[i64] = &self.coords;
        sorted.binary_search(&q)
    }
}
77
78#[derive(Default)]
93pub struct PiecewiseLinearConvexFn {
94 recent_lines: Vec<(f64, f64)>,
95 sorted_lines: Vec<(f64, f64)>,
96 intersections: Vec<f64>,
97 amortized_work: usize,
98}
99
100impl PiecewiseLinearConvexFn {
101 pub fn max_with(&mut self, new_m: f64, new_b: f64) {
103 self.recent_lines.push((new_m, new_b));
104 }
105
106 fn max_with_sorted(&mut self, new_m: f64, new_b: f64) {
108 while let Some(&(last_m, last_b)) = self.sorted_lines.last() {
109 if (new_m - last_m).abs() > 1e-9 {
111 let intersect = (new_b - last_b) / (last_m - new_m);
112 if self.intersections.last() < Some(&intersect) {
113 self.intersections.push(intersect);
114 break;
115 }
116 }
117 self.intersections.pop();
118 self.sorted_lines.pop();
119 }
120 self.sorted_lines.push((new_m, new_b));
121 }
122
123 fn eval_unoptimized(&self, x: f64) -> f64 {
125 let idx = slice_lower_bound(&self.intersections, &x);
126 self.recent_lines
127 .iter()
128 .chain(self.sorted_lines.get(idx))
129 .map(|&(m, b)| m * x + b)
130 .max_by(asserting_cmp)
131 .unwrap_or(-1e18)
132 }
133
134 pub fn evaluate(&mut self, x: f64) -> f64 {
136 self.amortized_work += self.recent_lines.len();
137 if self.amortized_work > self.sorted_lines.len() {
138 self.amortized_work = 0;
139 self.recent_lines.sort_unstable_by(asserting_cmp);
140 self.intersections.clear();
141 let all_lines = merge_sorted(self.recent_lines.drain(..), self.sorted_lines.drain(..));
142 for (new_m, new_b) in all_lines {
143 self.max_with_sorted(new_m, new_b);
144 }
145 }
146 self.eval_unoptimized(x)
147 }
148}
149
#[cfg(test)]
mod test {
    use super::*;

    /// Lower/upper bounds on a slice containing a run of duplicates, then on the
    /// deduplicated slice, where each value occupies exactly one index.
    #[test]
    fn test_bounds() {
        let mut vals = vec![16, 45, 45, 45, 82];

        // The run of 45s occupies indices 1..4.
        assert_eq!(slice_upper_bound(&vals, &44), 1);
        assert_eq!(slice_lower_bound(&vals, &45), 1);
        assert_eq!(slice_upper_bound(&vals, &45), 4);
        assert_eq!(slice_lower_bound(&vals, &46), 4);

        // After dedup every value is unique, so [lower, upper) has length one.
        vals.dedup();
        for (i, q) in vals.iter().enumerate() {
            assert_eq!(slice_lower_bound(&vals, q), i);
            assert_eq!(slice_upper_bound(&vals, q), i + 1);
        }
    }

    /// `merge_sorted` with an empty input on either side and with overlapping inputs.
    #[test]
    fn test_merge_sorted() {
        let vals1 = vec![16, 45, 45, 82];
        let vals2 = vec![-20, 40, 45, 50];
        let vals_merged = vec![-20, 16, 40, 45, 45, 45, 50, 82];

        // `Option` implements `IntoIterator`, so `None`/`Some(_)` act as 0- and
        // 1-element inputs.
        assert_eq!(merge_sorted(None, Some(42)), vec![42]);
        assert_eq!(merge_sorted(vals1.iter().cloned(), None), vals1);
        assert_eq!(merge_sorted(vals1, vals2), vals_merged);
    }

    /// `merge_sort` on unsorted input (with a duplicate) and on already-sorted input.
    #[test]
    fn test_merge_sort() {
        let unsorted = vec![8, -5, 1, 4, -3, 4];
        let sorted = vec![-5, -3, 1, 4, 4, 8];

        assert_eq!(merge_sort(unsorted), sorted);
        // Already-sorted input must come back unchanged.
        assert_eq!(merge_sort(sorted.clone()), sorted);
    }

    /// Every indexed coordinate compresses to its rank; its immediate neighbours
    /// miss and report the insertion point.
    #[test]
    fn test_coord_compress() {
        let mut coords = vec![16, 99, 45, 18];
        let index = SparseIndex::new(coords.clone());

        coords.sort_unstable();
        for (i, q) in coords.into_iter().enumerate() {
            assert_eq!(index.compress(q - 1), Err(i));
            assert_eq!(index.compress(q), Ok(i));
            assert_eq!(index.compress(q + 1), Err(i + 1));
        }
    }

    /// Each inclusive range `(i, j)` contributes the coordinates `i` and `j + 1`;
    /// the shared coordinate 20 appears only once after deduplication.
    #[test]
    fn test_range_compress() {
        let queries = vec![(0, 10), (10, 19), (20, 29)];
        let coords = queries.iter().flat_map(|&(i, j)| vec![i, j + 1]).collect();
        let index = SparseIndex::new(coords);

        assert_eq!(index.coords, vec![0, 10, 11, 20, 30]);
    }

    /// Adds lines one at a time and checks `f` at several points after each
    /// addition. The last line has the same slope as the third, exercising the
    /// parallel-line branch of the hull construction.
    #[test]
    fn test_convex_hull_trick() {
        let lines = [(0, -3), (-1, 0), (1, -8), (-2, 1), (1, -4)];
        let xs = [0, 1, 2, 3, 4, 5];
        // Row k holds f(x) for each x in `xs` after the first k + 1 lines are added.
        let results = [
            [-3, -3, -3, -3, -3, -3],
            [0, -1, -2, -3, -3, -3],
            [0, -1, -2, -3, -3, -3],
            [1, -1, -2, -3, -3, -3],
            [1, -1, -2, -1, 0, 1],
        ];
        let mut func = PiecewiseLinearConvexFn::default();
        // With no lines, evaluation falls back to the -1e18 value.
        assert_eq!(func.evaluate(0.0), -1e18);
        for (&(slope, intercept), expected) in lines.iter().zip(results.iter()) {
            func.max_with(slope as f64, intercept as f64);
            let ys: Vec<i64> = xs.iter().map(|&x| func.evaluate(x as f64) as i64).collect();
            assert_eq!(expected, &ys[..]);
        }
    }
}