1use super::Point;
16
/// Strategy for collapsing several [`Point`]s into a single [`Point`].
///
/// The `Send + Sync` supertraits mean every implementor must be safe to share
/// and move across threads.
pub trait Merge: Send + Sync {
    /// Combines `points` into one point.
    ///
    /// The trait itself places no requirement on how empty or
    /// mixed-dimensionality input is handled; the implementations in this
    /// file panic in both cases.
    fn merge(&self, points: &[Point]) -> Point;

    /// Returns a static, human-readable identifier for this strategy.
    fn name(&self) -> &'static str;
}
30
31#[derive(Clone, Copy, Debug, Default)]
40pub struct Mean;
41
42impl Merge for Mean {
43 fn merge(&self, points: &[Point]) -> Point {
44 assert!(!points.is_empty(), "Cannot merge empty slice");
45
46 let dims = points[0].dimensionality();
47 let n = points.len() as f32;
48
49 let mut result = vec![0.0; dims];
50 for p in points {
51 assert_eq!(
52 p.dimensionality(),
53 dims,
54 "All points must have same dimensionality"
55 );
56 for (r, d) in result.iter_mut().zip(p.dims()) {
57 *r += d / n;
58 }
59 }
60
61 Point::new(result)
62 }
63
64 fn name(&self) -> &'static str {
65 "mean"
66 }
67}
68
/// Merge strategy that averages points using one relative weight per point.
#[derive(Clone, Debug)]
pub struct WeightedMean {
    /// Relative weights, stored exactly as supplied to the constructor.
    weights: Vec<f32>,
}

impl WeightedMean {
    /// Builds a merger from explicit `weights`.
    ///
    /// The weights are stored as given; no validation or normalization is
    /// performed here.
    pub fn new(weights: Vec<f32>) -> Self {
        WeightedMean { weights }
    }

    /// Builds a merger with `n` weights, all equal to `1.0`.
    pub fn uniform(n: usize) -> Self {
        Self::new(vec![1.0; n])
    }

    /// Builds a merger with `n` geometrically decaying weights.
    ///
    /// The weight at index `i` is `decay^(n - 1 - i)`, so the last weight is
    /// always `1.0` and earlier weights are scaled by successive powers of
    /// `decay`.
    pub fn recency(n: usize, decay: f32) -> Self {
        // Walking the exponents from `n - 1` down to `0` produces the weights
        // already in index order.
        let weights = (0..n)
            .rev()
            .map(|exponent| decay.powi(exponent as i32))
            .collect();
        Self::new(weights)
    }
}
102
103impl Merge for WeightedMean {
104 fn merge(&self, points: &[Point]) -> Point {
105 assert!(!points.is_empty(), "Cannot merge empty slice");
106 assert_eq!(
107 points.len(),
108 self.weights.len(),
109 "Number of points must match number of weights"
110 );
111
112 let dims = points[0].dimensionality();
113 let total_weight: f32 = self.weights.iter().sum();
114
115 let mut result = vec![0.0; dims];
116 for (p, &w) in points.iter().zip(&self.weights) {
117 assert_eq!(
118 p.dimensionality(),
119 dims,
120 "All points must have same dimensionality"
121 );
122 let normalized_w = w / total_weight;
123 for (r, d) in result.iter_mut().zip(p.dims()) {
124 *r += d * normalized_w;
125 }
126 }
127
128 Point::new(result)
129 }
130
131 fn name(&self) -> &'static str {
132 "weighted_mean"
133 }
134}
135
136#[derive(Clone, Copy, Debug, Default)]
141pub struct MaxPool;
142
143impl Merge for MaxPool {
144 fn merge(&self, points: &[Point]) -> Point {
145 assert!(!points.is_empty(), "Cannot merge empty slice");
146
147 let dims = points[0].dimensionality();
148 let mut result = points[0].dims().to_vec();
149
150 for p in &points[1..] {
151 assert_eq!(
152 p.dimensionality(),
153 dims,
154 "All points must have same dimensionality"
155 );
156 for (r, d) in result.iter_mut().zip(p.dims()) {
157 *r = r.max(*d);
158 }
159 }
160
161 Point::new(result)
162 }
163
164 fn name(&self) -> &'static str {
165 "max_pool"
166 }
167}
168
169#[derive(Clone, Copy, Debug, Default)]
173pub struct MinPool;
174
175impl Merge for MinPool {
176 fn merge(&self, points: &[Point]) -> Point {
177 assert!(!points.is_empty(), "Cannot merge empty slice");
178
179 let dims = points[0].dimensionality();
180 let mut result = points[0].dims().to_vec();
181
182 for p in &points[1..] {
183 assert_eq!(
184 p.dimensionality(),
185 dims,
186 "All points must have same dimensionality"
187 );
188 for (r, d) in result.iter_mut().zip(p.dims()) {
189 *r = r.min(*d);
190 }
191 }
192
193 Point::new(result)
194 }
195
196 fn name(&self) -> &'static str {
197 "min_pool"
198 }
199}
200
201#[derive(Clone, Copy, Debug, Default)]
205pub struct Sum;
206
207impl Merge for Sum {
208 fn merge(&self, points: &[Point]) -> Point {
209 assert!(!points.is_empty(), "Cannot merge empty slice");
210
211 let dims = points[0].dimensionality();
212 let mut result = vec![0.0; dims];
213
214 for p in points {
215 assert_eq!(
216 p.dimensionality(),
217 dims,
218 "All points must have same dimensionality"
219 );
220 for (r, d) in result.iter_mut().zip(p.dims()) {
221 *r += d;
222 }
223 }
224
225 Point::new(result)
226 }
227
228 fn name(&self) -> &'static str {
229 "sum"
230 }
231}
232
/// Unit tests for every merge strategy defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Merging a single point yields that point's coordinates unchanged.
    #[test]
    fn test_mean_single() {
        let points = vec![Point::new(vec![1.0, 2.0, 3.0])];
        let merged = Mean.merge(&points);
        assert_eq!(merged.dims(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mean_multiple() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![3.0, 4.0]),
        ];
        let merged = Mean.merge(&points);
        assert_eq!(merged.dims(), &[2.0, 3.0]);
    }

    #[test]
    fn test_weighted_mean() {
        let points = vec![
            Point::new(vec![0.0, 0.0]),
            Point::new(vec![10.0, 10.0]),
        ];
        let merger = WeightedMean::new(vec![1.0, 3.0]);
        let merged = merger.merge(&points);
        // (0 * 1 + 10 * 3) / (1 + 3) = 7.5
        assert!((merged.dims()[0] - 7.5).abs() < 0.0001);
        assert!((merged.dims()[1] - 7.5).abs() < 0.0001);
    }

    #[test]
    fn test_weighted_mean_recency() {
        let merger = WeightedMean::recency(3, 0.5);
        assert_eq!(merger.weights.len(), 3);
        // Exponents run 2, 1, 0, so the weights are 0.5^2, 0.5^1, 0.5^0.
        assert!((merger.weights[0] - 0.25).abs() < 0.0001);
        assert!((merger.weights[1] - 0.5).abs() < 0.0001);
        assert!((merger.weights[2] - 1.0).abs() < 0.0001);
    }

    /// Uniform weights must reproduce the plain arithmetic mean.
    #[test]
    fn test_weighted_mean_uniform() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![3.0, 4.0]),
        ];
        let merger = WeightedMean::uniform(2);
        assert_eq!(merger.weights.len(), 2);
        let merged = merger.merge(&points);
        assert!((merged.dims()[0] - 2.0).abs() < 0.0001);
        assert!((merged.dims()[1] - 3.0).abs() < 0.0001);
    }

    #[test]
    fn test_max_pool() {
        let points = vec![
            Point::new(vec![1.0, 5.0, 2.0]),
            Point::new(vec![3.0, 2.0, 4.0]),
            Point::new(vec![2.0, 3.0, 1.0]),
        ];
        let merged = MaxPool.merge(&points);
        assert_eq!(merged.dims(), &[3.0, 5.0, 4.0]);
    }

    #[test]
    fn test_min_pool() {
        let points = vec![
            Point::new(vec![1.0, 5.0, 2.0]),
            Point::new(vec![3.0, 2.0, 4.0]),
            Point::new(vec![2.0, 3.0, 1.0]),
        ];
        let merged = MinPool.merge(&points);
        assert_eq!(merged.dims(), &[1.0, 2.0, 1.0]);
    }

    #[test]
    fn test_sum() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![3.0, 4.0]),
        ];
        let merged = Sum.merge(&points);
        assert_eq!(merged.dims(), &[4.0, 6.0]);
    }

    /// Every strategy in this module reports its expected identifier.
    #[test]
    fn test_merge_names() {
        assert_eq!(Mean.name(), "mean");
        assert_eq!(WeightedMean::uniform(1).name(), "weighted_mean");
        assert_eq!(MaxPool.name(), "max_pool");
        assert_eq!(MinPool.name(), "min_pool");
        assert_eq!(Sum.name(), "sum");
    }

    #[test]
    #[should_panic(expected = "Cannot merge empty")]
    fn test_merge_empty_panics() {
        let points: Vec<Point> = vec![];
        Mean.merge(&points);
    }

    #[test]
    #[should_panic(expected = "same dimensionality")]
    fn test_merge_dimension_mismatch_panics() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![1.0, 2.0, 3.0]),
        ];
        Mean.merge(&points);
    }

    /// Supplying a different number of weights than points is rejected.
    #[test]
    #[should_panic(expected = "Number of points must match number of weights")]
    fn test_weighted_mean_weight_count_mismatch_panics() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![3.0, 4.0]),
        ];
        WeightedMean::uniform(3).merge(&points);
    }
}