arms/core/
merge.rs

//! # Merge
//!
//! Trait and implementations for composing multiple points into one.
//!
//! This is one of the five primitives of ARMS:
//! `Merge: fn(points) -> point` - Compose together
//!
//! Merge is used for hierarchical composition:
//! - Chunks → Document
//! - Documents → Session
//! - Sessions → Domain
//!
//! Merge functions are pluggable - use whichever fits your use case.
14
15use super::Point;
16
17/// Trait for merging multiple points into one
18///
19/// Used for hierarchical composition and aggregation.
20pub trait Merge: Send + Sync {
21    /// Merge multiple points into a single point
22    ///
23    /// All points must have the same dimensionality.
24    /// The slice must not be empty.
25    fn merge(&self, points: &[Point]) -> Point;
26
27    /// Name of this merge function (for debugging/config)
28    fn name(&self) -> &'static str;
29}
30
31// ============================================================================
32// IMPLEMENTATIONS
33// ============================================================================
34
35/// Mean (average) of all points
36///
37/// The centroid of the input points.
38/// Good default for most hierarchical composition.
39#[derive(Clone, Copy, Debug, Default)]
40pub struct Mean;
41
42impl Merge for Mean {
43    fn merge(&self, points: &[Point]) -> Point {
44        assert!(!points.is_empty(), "Cannot merge empty slice");
45
46        let dims = points[0].dimensionality();
47        let n = points.len() as f32;
48
49        let mut result = vec![0.0; dims];
50        for p in points {
51            assert_eq!(
52                p.dimensionality(),
53                dims,
54                "All points must have same dimensionality"
55            );
56            for (r, d) in result.iter_mut().zip(p.dims()) {
57                *r += d / n;
58            }
59        }
60
61        Point::new(result)
62    }
63
64    fn name(&self) -> &'static str {
65        "mean"
66    }
67}
68
/// Mean in which every point carries its own weight.
///
/// A point's contribution is proportional to its weight, which makes this
/// suitable for recency weighting, importance weighting and similar schemes.
#[derive(Clone, Debug)]
pub struct WeightedMean {
    /// One weight per point, in the same order as the points passed to merge.
    weights: Vec<f32>,
}

impl WeightedMean {
    /// Builds a weighted mean from explicit weights.
    ///
    /// The weights are stored as given; they are normalized (divided by
    /// their sum) when a merge is performed.
    pub fn new(weights: Vec<f32>) -> Self {
        WeightedMean { weights }
    }

    /// Builds a weighted mean with `n` equal weights (equivalent to `Mean`).
    pub fn uniform(n: usize) -> Self {
        Self::new(vec![1.0; n])
    }

    /// Builds a weighted mean that favors more recent points.
    ///
    /// The first point is treated as the oldest and the last as the most
    /// recent. `decay` should be in (0, 1); smaller values decay faster.
    pub fn recency(n: usize, decay: f32) -> Self {
        let mut weights = Vec::with_capacity(n);
        // `age` counts down from n - 1 (oldest point) to 0 (newest point),
        // so the last weight is always decay^0.
        for age in (0..n).rev() {
            weights.push(decay.powi(age as i32));
        }
        Self::new(weights)
    }
}
102
103impl Merge for WeightedMean {
104    fn merge(&self, points: &[Point]) -> Point {
105        assert!(!points.is_empty(), "Cannot merge empty slice");
106        assert_eq!(
107            points.len(),
108            self.weights.len(),
109            "Number of points must match number of weights"
110        );
111
112        let dims = points[0].dimensionality();
113        let total_weight: f32 = self.weights.iter().sum();
114
115        let mut result = vec![0.0; dims];
116        for (p, &w) in points.iter().zip(&self.weights) {
117            assert_eq!(
118                p.dimensionality(),
119                dims,
120                "All points must have same dimensionality"
121            );
122            let normalized_w = w / total_weight;
123            for (r, d) in result.iter_mut().zip(p.dims()) {
124                *r += d * normalized_w;
125            }
126        }
127
128        Point::new(result)
129    }
130
131    fn name(&self) -> &'static str {
132        "weighted_mean"
133    }
134}
135
136/// Max pooling across points
137///
138/// Takes the maximum value of each dimension across all points.
139/// Preserves the strongest activations.
140#[derive(Clone, Copy, Debug, Default)]
141pub struct MaxPool;
142
143impl Merge for MaxPool {
144    fn merge(&self, points: &[Point]) -> Point {
145        assert!(!points.is_empty(), "Cannot merge empty slice");
146
147        let dims = points[0].dimensionality();
148        let mut result = points[0].dims().to_vec();
149
150        for p in &points[1..] {
151            assert_eq!(
152                p.dimensionality(),
153                dims,
154                "All points must have same dimensionality"
155            );
156            for (r, d) in result.iter_mut().zip(p.dims()) {
157                *r = r.max(*d);
158            }
159        }
160
161        Point::new(result)
162    }
163
164    fn name(&self) -> &'static str {
165        "max_pool"
166    }
167}
168
169/// Min pooling across points
170///
171/// Takes the minimum value of each dimension across all points.
172#[derive(Clone, Copy, Debug, Default)]
173pub struct MinPool;
174
175impl Merge for MinPool {
176    fn merge(&self, points: &[Point]) -> Point {
177        assert!(!points.is_empty(), "Cannot merge empty slice");
178
179        let dims = points[0].dimensionality();
180        let mut result = points[0].dims().to_vec();
181
182        for p in &points[1..] {
183            assert_eq!(
184                p.dimensionality(),
185                dims,
186                "All points must have same dimensionality"
187            );
188            for (r, d) in result.iter_mut().zip(p.dims()) {
189                *r = r.min(*d);
190            }
191        }
192
193        Point::new(result)
194    }
195
196    fn name(&self) -> &'static str {
197        "min_pool"
198    }
199}
200
201/// Sum of all points (no averaging)
202///
203/// Simple additive composition.
204#[derive(Clone, Copy, Debug, Default)]
205pub struct Sum;
206
207impl Merge for Sum {
208    fn merge(&self, points: &[Point]) -> Point {
209        assert!(!points.is_empty(), "Cannot merge empty slice");
210
211        let dims = points[0].dimensionality();
212        let mut result = vec![0.0; dims];
213
214        for p in points {
215            assert_eq!(
216                p.dimensionality(),
217                dims,
218                "All points must have same dimensionality"
219            );
220            for (r, d) in result.iter_mut().zip(p.dims()) {
221                *r += d;
222            }
223        }
224
225        Point::new(result)
226    }
227
228    fn name(&self) -> &'static str {
229        "sum"
230    }
231}
232
#[cfg(test)]
mod tests {
    //! Unit tests for every merge function in this module.

    use super::*;

    #[test]
    fn test_mean_single() {
        let points = vec![Point::new(vec![1.0, 2.0, 3.0])];
        let merged = Mean.merge(&points);
        assert_eq!(merged.dims(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mean_multiple() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![3.0, 4.0]),
        ];
        let merged = Mean.merge(&points);
        assert_eq!(merged.dims(), &[2.0, 3.0]);
    }

    #[test]
    fn test_weighted_mean() {
        let points = vec![
            Point::new(vec![0.0, 0.0]),
            Point::new(vec![10.0, 10.0]),
        ];
        // Weight second point 3x more than first
        let merger = WeightedMean::new(vec![1.0, 3.0]);
        let merged = merger.merge(&points);
        // (0*0.25 + 10*0.75, 0*0.25 + 10*0.75) = (7.5, 7.5)
        assert!((merged.dims()[0] - 7.5).abs() < 0.0001);
        assert!((merged.dims()[1] - 7.5).abs() < 0.0001);
    }

    #[test]
    fn test_weighted_mean_uniform_matches_mean() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![3.0, 4.0]),
        ];
        // Uniform weights are documented as equivalent to `Mean`.
        let weighted = WeightedMean::uniform(points.len()).merge(&points);
        let mean = Mean.merge(&points);
        for (w, m) in weighted.dims().iter().zip(mean.dims()) {
            assert!((w - m).abs() < 0.0001);
        }
    }

    #[test]
    fn test_weighted_mean_recency() {
        let merger = WeightedMean::recency(3, 0.5);
        // decay = 0.5, n = 3
        // weights: [0.5^2, 0.5^1, 0.5^0] = [0.25, 0.5, 1.0]
        assert_eq!(merger.weights.len(), 3);
        assert!((merger.weights[0] - 0.25).abs() < 0.0001);
        assert!((merger.weights[1] - 0.5).abs() < 0.0001);
        assert!((merger.weights[2] - 1.0).abs() < 0.0001);
    }

    #[test]
    #[should_panic(expected = "Number of points must match number of weights")]
    fn test_weighted_mean_weight_count_mismatch_panics() {
        let points = vec![Point::new(vec![1.0, 2.0])];
        // Two weights for one point violates the merge precondition.
        WeightedMean::new(vec![1.0, 2.0]).merge(&points);
    }

    #[test]
    fn test_max_pool() {
        let points = vec![
            Point::new(vec![1.0, 5.0, 2.0]),
            Point::new(vec![3.0, 2.0, 4.0]),
            Point::new(vec![2.0, 3.0, 1.0]),
        ];
        let merged = MaxPool.merge(&points);
        assert_eq!(merged.dims(), &[3.0, 5.0, 4.0]);
    }

    #[test]
    fn test_min_pool() {
        let points = vec![
            Point::new(vec![1.0, 5.0, 2.0]),
            Point::new(vec![3.0, 2.0, 4.0]),
            Point::new(vec![2.0, 3.0, 1.0]),
        ];
        let merged = MinPool.merge(&points);
        assert_eq!(merged.dims(), &[1.0, 2.0, 1.0]);
    }

    #[test]
    fn test_sum() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![3.0, 4.0]),
        ];
        let merged = Sum.merge(&points);
        assert_eq!(merged.dims(), &[4.0, 6.0]);
    }

    #[test]
    fn test_merge_names() {
        assert_eq!(Mean.name(), "mean");
        assert_eq!(WeightedMean::uniform(1).name(), "weighted_mean");
        assert_eq!(MaxPool.name(), "max_pool");
        assert_eq!(MinPool.name(), "min_pool");
        assert_eq!(Sum.name(), "sum");
    }

    #[test]
    #[should_panic(expected = "Cannot merge empty")]
    fn test_merge_empty_panics() {
        let points: Vec<Point> = vec![];
        Mean.merge(&points);
    }

    #[test]
    #[should_panic(expected = "same dimensionality")]
    fn test_merge_dimension_mismatch_panics() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![1.0, 2.0, 3.0]),
        ];
        Mean.merge(&points);
    }
}