use super::Point;
/// A strategy for collapsing a slice of [`Point`]s into a single [`Point`].
///
/// The `Send + Sync` supertraits mean every implementor must be safe to share
/// and send across threads.
pub trait Merge: Send + Sync {
/// Combines `points` into one point.
///
/// The implementations visible in this file all panic when `points` is empty
/// or when the points do not share one dimensionality.
fn merge(&self, points: &[Point]) -> Point;
/// Returns a fixed, static identifier for the strategy (for example `"mean"`).
fn name(&self) -> &'static str;
}
/// Merge strategy that produces the component-wise arithmetic mean of the
/// input points.
#[derive(Clone, Copy, Debug, Default)]
pub struct Mean;
impl Merge for Mean {
    /// Averages `points` component-wise.
    ///
    /// Every coordinate is divided by the number of points *before* it is added
    /// to the accumulator, so the accumulator always holds a partial mean
    /// rather than a raw sum.
    ///
    /// # Panics
    ///
    /// Panics if `points` is empty, or if any point's dimensionality differs
    /// from that of the first point.
    fn merge(&self, points: &[Point]) -> Point {
        assert!(!points.is_empty(), "Cannot merge empty slice");
        let width = points[0].dimensionality();
        let count = points.len() as f32;
        let averaged = points.iter().fold(vec![0.0; width], |mut acc, point| {
            assert_eq!(
                point.dimensionality(),
                width,
                "All points must have same dimensionality"
            );
            acc.iter_mut()
                .zip(point.dims())
                .for_each(|(slot, coord)| *slot += coord / count);
            acc
        });
        Point::new(averaged)
    }

    /// Returns the identifier `"mean"`.
    fn name(&self) -> &'static str {
        "mean"
    }
}
/// Merge strategy that computes a weighted average of the input points.
#[derive(Clone, Debug)]
pub struct WeightedMean {
/// One weight per input point, matched to points by position. The weights
/// are normalized by their sum when merging, so they need not add up to one.
weights: Vec<f32>,
}
impl WeightedMean {
pub fn new(weights: Vec<f32>) -> Self {
Self { weights }
}
pub fn uniform(n: usize) -> Self {
Self {
weights: vec![1.0; n],
}
}
pub fn recency(n: usize, decay: f32) -> Self {
let weights: Vec<f32> = (0..n).map(|i| decay.powi((n - 1 - i) as i32)).collect();
Self { weights }
}
}
impl Merge for WeightedMean {
    /// Computes the weighted average of `points`, pairing each point with the
    /// weight at the same index.
    ///
    /// Each weight is divided by the sum of all weights before it is applied,
    /// so the stored weights do not have to add up to one.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - `points` is empty;
    /// - the number of points differs from the number of weights;
    /// - the weights sum to zero, NaN, or infinity (normalizing by such a
    ///   total would otherwise silently produce NaN or meaningless
    ///   coordinates);
    /// - any point's dimensionality differs from that of the first point.
    fn merge(&self, points: &[Point]) -> Point {
        assert!(!points.is_empty(), "Cannot merge empty slice");
        assert_eq!(
            points.len(),
            self.weights.len(),
            "Number of points must match number of weights"
        );
        let dims = points[0].dimensionality();
        let total_weight: f32 = self.weights.iter().sum();
        // Guard the normalization below: dividing by a zero or non-finite total
        // would quietly poison every coordinate of the result.
        assert!(
            total_weight.is_finite() && total_weight != 0.0,
            "Total weight must be finite and non-zero"
        );
        let mut result = vec![0.0; dims];
        for (p, &w) in points.iter().zip(&self.weights) {
            assert_eq!(
                p.dimensionality(),
                dims,
                "All points must have same dimensionality"
            );
            let normalized_w = w / total_weight;
            for (r, d) in result.iter_mut().zip(p.dims()) {
                *r += d * normalized_w;
            }
        }
        Point::new(result)
    }

    /// Returns the identifier `"weighted_mean"`.
    fn name(&self) -> &'static str {
        "weighted_mean"
    }
}
/// Merge strategy that keeps, for each dimension, the largest coordinate found
/// among the input points.
#[derive(Clone, Copy, Debug, Default)]
pub struct MaxPool;
impl Merge for MaxPool {
    /// Takes the per-dimension maximum over `points`.
    ///
    /// The first point seeds the running maxima; every later point is then
    /// folded in one coordinate at a time.
    ///
    /// # Panics
    ///
    /// Panics if `points` is empty, or if any later point's dimensionality
    /// differs from that of the first point.
    fn merge(&self, points: &[Point]) -> Point {
        assert!(!points.is_empty(), "Cannot merge empty slice");
        let (seed, others) = (&points[0], &points[1..]);
        let width = seed.dimensionality();
        let maxima = others.iter().fold(seed.dims().to_vec(), |mut best, point| {
            assert_eq!(
                point.dimensionality(),
                width,
                "All points must have same dimensionality"
            );
            best.iter_mut()
                .zip(point.dims())
                .for_each(|(current, candidate)| *current = current.max(*candidate));
            best
        });
        Point::new(maxima)
    }

    /// Returns the identifier `"max_pool"`.
    fn name(&self) -> &'static str {
        "max_pool"
    }
}
/// Merge strategy that keeps, for each dimension, the smallest coordinate found
/// among the input points.
#[derive(Clone, Copy, Debug, Default)]
pub struct MinPool;
impl Merge for MinPool {
    /// Takes the per-dimension minimum over `points`.
    ///
    /// The first point seeds the running minima; every later point is then
    /// folded in one coordinate at a time.
    ///
    /// # Panics
    ///
    /// Panics if `points` is empty, or if any later point's dimensionality
    /// differs from that of the first point.
    fn merge(&self, points: &[Point]) -> Point {
        assert!(!points.is_empty(), "Cannot merge empty slice");
        let (seed, others) = (&points[0], &points[1..]);
        let width = seed.dimensionality();
        let minima = others.iter().fold(seed.dims().to_vec(), |mut lowest, point| {
            assert_eq!(
                point.dimensionality(),
                width,
                "All points must have same dimensionality"
            );
            lowest
                .iter_mut()
                .zip(point.dims())
                .for_each(|(current, candidate)| *current = current.min(*candidate));
            lowest
        });
        Point::new(minima)
    }

    /// Returns the identifier `"min_pool"`.
    fn name(&self) -> &'static str {
        "min_pool"
    }
}
/// Merge strategy that adds the input points together component-wise.
#[derive(Clone, Copy, Debug, Default)]
pub struct Sum;
impl Merge for Sum {
    /// Adds `points` together component-wise, starting from an all-zero
    /// accumulator sized to the first point.
    ///
    /// # Panics
    ///
    /// Panics if `points` is empty, or if any point's dimensionality differs
    /// from that of the first point.
    fn merge(&self, points: &[Point]) -> Point {
        assert!(!points.is_empty(), "Cannot merge empty slice");
        let width = points[0].dimensionality();
        let mut totals = vec![0.0; width];
        points.iter().for_each(|point| {
            assert_eq!(
                point.dimensionality(),
                width,
                "All points must have same dimensionality"
            );
            totals
                .iter_mut()
                .zip(point.dims())
                .for_each(|(total, coord)| *total += coord);
        });
        Point::new(totals)
    }

    /// Returns the identifier `"sum"`.
    fn name(&self) -> &'static str {
        "sum"
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the merge strategies defined in the parent module.

    use super::*;

    /// A single point is its own mean.
    #[test]
    fn test_mean_single() {
        let points = vec![Point::new(vec![1.0, 2.0, 3.0])];
        let merged = Mean.merge(&points);
        assert_eq!(merged.dims(), &[1.0, 2.0, 3.0]);
    }

    /// The mean of two points is their component-wise midpoint.
    #[test]
    fn test_mean_multiple() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![3.0, 4.0]),
        ];
        let merged = Mean.merge(&points);
        assert_eq!(merged.dims(), &[2.0, 3.0]);
    }

    /// Weights 1:3 pull the result three quarters of the way to the second point.
    #[test]
    fn test_weighted_mean() {
        let points = vec![
            Point::new(vec![0.0, 0.0]),
            Point::new(vec![10.0, 10.0]),
        ];
        let merger = WeightedMean::new(vec![1.0, 3.0]);
        let merged = merger.merge(&points);
        assert!((merged.dims()[0] - 7.5).abs() < 0.0001);
        assert!((merged.dims()[1] - 7.5).abs() < 0.0001);
    }

    /// Equal weights reproduce the plain component-wise average.
    #[test]
    fn test_weighted_mean_uniform() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![3.0, 4.0]),
        ];
        let merged = WeightedMean::uniform(2).merge(&points);
        assert!((merged.dims()[0] - 2.0).abs() < 0.0001);
        assert!((merged.dims()[1] - 3.0).abs() < 0.0001);
    }

    /// Recency weights are successive powers of `decay`, ending at 1.0.
    #[test]
    fn test_weighted_mean_recency() {
        let merger = WeightedMean::recency(3, 0.5);
        assert_eq!(merger.weights.len(), 3);
        assert!((merger.weights[0] - 0.25).abs() < 0.0001);
        assert!((merger.weights[1] - 0.5).abs() < 0.0001);
        assert!((merger.weights[2] - 1.0).abs() < 0.0001);
    }

    #[test]
    fn test_max_pool() {
        let points = vec![
            Point::new(vec![1.0, 5.0, 2.0]),
            Point::new(vec![3.0, 2.0, 4.0]),
            Point::new(vec![2.0, 3.0, 1.0]),
        ];
        let merged = MaxPool.merge(&points);
        assert_eq!(merged.dims(), &[3.0, 5.0, 4.0]);
    }

    #[test]
    fn test_min_pool() {
        let points = vec![
            Point::new(vec![1.0, 5.0, 2.0]),
            Point::new(vec![3.0, 2.0, 4.0]),
            Point::new(vec![2.0, 3.0, 1.0]),
        ];
        let merged = MinPool.merge(&points);
        assert_eq!(merged.dims(), &[1.0, 2.0, 1.0]);
    }

    #[test]
    fn test_sum() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![3.0, 4.0]),
        ];
        let merged = Sum.merge(&points);
        assert_eq!(merged.dims(), &[4.0, 6.0]);
    }

    /// Every strategy reports its own static identifier.
    #[test]
    fn test_merge_names() {
        assert_eq!(Mean.name(), "mean");
        assert_eq!(WeightedMean::uniform(1).name(), "weighted_mean");
        assert_eq!(MaxPool.name(), "max_pool");
        assert_eq!(MinPool.name(), "min_pool");
        assert_eq!(Sum.name(), "sum");
    }

    #[test]
    #[should_panic(expected = "Cannot merge empty")]
    fn test_merge_empty_panics() {
        let points: Vec<Point> = vec![];
        Mean.merge(&points);
    }

    #[test]
    #[should_panic(expected = "same dimensionality")]
    fn test_merge_dimension_mismatch_panics() {
        let points = vec![
            Point::new(vec![1.0, 2.0]),
            Point::new(vec![1.0, 2.0, 3.0]),
        ];
        Mean.merge(&points);
    }
}