use std::{marker::PhantomData, ops::DerefMut};
use crate::model::{Ball, BallNode, GetNeighbors, Model};
/// Divisor applied to the point-to-ball distance to size a freshly split ball
/// (used in `split_ball`).
const EXTRA_THRESHOLD: f64 = 25.;
/// Multiplier on the closest ball's radius: a point whose distance is strictly
/// below the product is absorbed by that ball instead of spawning a new one
/// (used in `update`).
const INTRA_THRESHOLD: f64 = 16.;
/// Multiplier on the sum of two balls' radii: the balls are merged when the
/// distance between their centers is strictly below the product
/// (used in `should_merge`).
const MERGE_THRESHOLD: f64 = 1.;
/// Factor by which `decay` scales the weight of every ball whose data differs
/// from the ball touched by the current fit.
const DECAY_FACTOR: f64 = 0.95;
/// `decay` keeps only the balls whose weight is strictly above this value.
const DECAY_THRESHOLD: f64 = 1E-2;
/// Number of neighbor slots examined when inserting a candidate neighbor, and
/// the length above which `update_local_graph` drops the last neighbor.
const MAX_NEIGHBORS: usize = 2;
/// Incremental fitting algorithm over a `Model<Point>`.
///
/// The algorithm is parameterized by two boxed closures: a distance between
/// points and a function combining two points with associated `f64` weights.
pub struct Algo<Point: PartialEq + 'static> {
/// Distance between two points, as used for every threshold comparison.
dist: Box<dyn Fn(&Point, &Point) -> f64>,
/// Combines `(point, weight)` pairs into a new point.
/// NOTE(review): presumably a weighted average — confirm against the supplied implementation.
combine: Box<dyn Fn(&Point, f64, &Point, f64) -> Point>,
/// Zero-sized marker tying the struct to `Point`.
phantom: PhantomData<Point>,
}
impl<Point: PartialEq + 'static> Algo<Point> {
pub fn new<Dist, Combine>(dist: Dist, combine: Combine) -> Self
where
Dist: Fn(&Point, &Point) -> f64 + 'static,
Combine: Fn(&Point, f64, &Point, f64) -> Point + 'static,
{
Self {
dist: Box::new(dist),
combine: Box::new(combine),
phantom: PhantomData,
}
}
/// Feeds one `point` into `model`.
///
/// The model is asked for the neighborhood of `point`. When it is empty the
/// point seeds a new ball. Otherwise the first neighborhood entry is used as
/// the candidate: it is updated (or a new ball is split off), the candidate's
/// local graph is refreshed when a neighbor is reported, and finally every
/// other ball is decayed.
pub fn fit<'a>(&'a self, model: &'a mut Model<Point>, point: Point) {
    let neighborhood = model.get_neighborhood(&point);
    if let Some(candidate) = neighborhood.first() {
        let (vertex, maybe_neighbor) = self.update(model, candidate, point, &neighborhood);
        if let Some(neighbor) = maybe_neighbor {
            self.update_local_graph(candidate, neighbor);
        }
        self.decay(model, vertex);
    } else {
        // Nothing to compare against yet: the point becomes a ball on its own.
        self.init(model, point);
    }
}
/// Seeds `model` with a ball centered on `point`, created with an infinite
/// radius, a zero weight and an empty neighbor list. Returns the node handed
/// back by `Model::add_ball`.
fn init(&self, model: &mut Model<Point>, point: Point) -> BallNode<Point> {
    let seed = Ball::new(point, f64::INFINITY, 0.);
    let no_neighbors = Vec::new();
    model.add_ball(seed, no_neighbors)
}
/// Absorbs `point` into the closest ball, or splits a new ball off it.
///
/// `vertex` is the closest node and `neighborhood` the full neighborhood the
/// model returned for `point`.
///
/// Returns `(touched, maybe_neighbor)`:
/// * when the distance to the closest center is strictly below
///   `INTRA_THRESHOLD * radius`, the closest ball is updated in place; the
///   result is a clone of `vertex` plus the second neighborhood entry, if any;
/// * otherwise a new ball is added to the model (with
///   `neighborhood.get_neighbors()` as its neighbor argument) and the new
///   node is returned in both positions.
fn update(
    &self,
    model: &mut Model<Point>,
    vertex: &BallNode<Point>,
    point: Point,
    neighborhood: &Vec<BallNode<Point>>,
) -> (BallNode<Point>, Option<BallNode<Point>>) {
    // NOTE(review): this mutable handle stays alive across `add_ball` below —
    // confirm `add_ball` never needs access to this node's data.
    let mut nearest = vertex.deref_data_mut();
    let gap = (self.dist)(&nearest.center, &point);
    let absorbed = gap < INTRA_THRESHOLD * nearest.radius;
    if !absorbed {
        let ball = self.split_ball(point, gap, &nearest);
        let fresh = model.add_ball(ball, neighborhood.get_neighbors());
        return (fresh.clone(), Some(fresh));
    }
    self.update_ball(&mut nearest, point, gap);
    (vertex.clone(), neighborhood.get(1).cloned())
}
/// Folds `point` (at distance `dist` from the ball's center) into `ball`:
/// the center is recomputed first, then the radius, and finally the weight is
/// incremented by one.
fn update_ball(&self, ball: &mut impl DerefMut<Target = Ball<Point>>, point: Point, dist: f64) {
    let center = self.update_mu(ball, point);
    ball.center = center;
    let radius = self.update_sigma(ball, dist);
    ball.radius = radius;
    // The weight is bumped last so both helpers above see the previous weight.
    ball.weight += 1.;
}
/// Computes the new center of `ball` by combining its current center (with
/// the ball's current weight) and `point` (with weight `1`).
fn update_mu(&self, ball: &impl DerefMut<Target = Ball<Point>>, point: Point) -> Point {
    let prior_weight = ball.weight;
    (self.combine)(&ball.center, prior_weight, &point, 1.)
}
/// Computes the new radius of `ball` after seeing a point at distance `dist`.
///
/// A ball with zero weight simply adopts `dist`; otherwise the result is
/// `(radius * weight + dist) / (weight + 1)`.
fn update_sigma(&self, ball: &impl DerefMut<Target = Ball<Point>>, dist: f64) -> f64 {
    let weight = ball.weight;
    if weight != 0. {
        return (ball.radius * weight + dist) / (weight + 1.);
    }
    dist
}
/// Creates the ball spawned by a `point` lying at distance `d` from
/// `neighbor`.
///
/// The new center combines the neighbor's center with weight `-1` and the
/// point with weight `5`; the radius is `d / EXTRA_THRESHOLD` and the weight
/// starts at `1`.
fn split_ball(
    &self,
    point: Point,
    d: f64,
    neighbor: &impl DerefMut<Target = Ball<Point>>,
) -> Ball<Point> {
    let center = (self.combine)(&neighbor.center, -1., &point, 5.);
    Ball::new(center, d / EXTRA_THRESHOLD, 1.)
}
/// Refreshes the neighbor list of `vertex` given a candidate neighbor.
///
/// The current neighbors are collected, the candidate is offered to
/// `rebuild_neighborhood`, a possible merge with the head of the list is
/// applied by `rebuild_merge`, and one trailing entry is dropped when the list
/// is longer than `MAX_NEIGHBORS` before being stored back on `vertex`.
fn update_local_graph(&self, vertex: &BallNode<Point>, maybe_neighbor: BallNode<Point>) {
    let current: Vec<BallNode<Point>> = vertex.iter_neighbors().collect();
    let with_candidate = self.rebuild_neighborhood(vertex, current, maybe_neighbor);
    let mut refreshed = self.rebuild_merge(vertex, with_candidate);
    let overflowing = refreshed.len() > MAX_NEIGHBORS;
    if overflowing {
        // Only the last entry is dropped.
        refreshed.pop();
    }
    vertex.set_neighbors(refreshed.get_neighbors());
}
/// Offers `maybe_neighbor` as a neighbor of `vertex` and returns the
/// resulting neighbor list.
///
/// Only the first `MAX_NEIGHBORS` slots of `neighborhood` are examined, in
/// order:
/// * reaching the end of the list appends the candidate;
/// * meeting an entry equal to the candidate leaves the list untouched;
/// * meeting an entry farther from `vertex`'s center than the candidate
///   inserts the candidate in front of that entry.
///
/// The list may come back one element longer than it was; trimming is left to
/// the caller.
fn rebuild_neighborhood(
    &self,
    vertex: &BallNode<Point>,
    mut neighborhood: Vec<BallNode<Point>>,
    maybe_neighbor: BallNode<Point>,
) -> Vec<BallNode<Point>> {
    let current_point = &vertex.deref_data().center;
    // Distance from any node's center to the center of `vertex`.
    let dist_to_current =
        |p: &BallNode<Point>| (self.dist)(&p.deref_data().center, current_point);
    let candidate_dist = dist_to_current(&maybe_neighbor);
    for i in 0..MAX_NEIGHBORS {
        if i == neighborhood.len() {
            // Ran out of existing neighbors: the candidate takes the free slot.
            neighborhood.push(maybe_neighbor);
            break;
        }
        if neighborhood[i].eq(&maybe_neighbor) {
            // Already a neighbor at this position: nothing to do.
            break;
        }
        if dist_to_current(&neighborhood[i]) > candidate_dist {
            // The candidate is closer than the entry at `i`: it goes first.
            neighborhood.insert(i, maybe_neighbor);
            break;
        }
    }
    neighborhood
}
/// Merges the first neighbor into `vertex` when `should_merge` says so, and
/// returns the neighbor list without the merged entry.
///
/// An empty `neighborhood` is returned unchanged (there is nothing to merge
/// with) instead of panicking on an out-of-bounds index.
fn rebuild_merge(
    &self,
    vertex: &BallNode<Point>,
    mut neighborhood: Vec<BallNode<Point>>,
) -> Vec<BallNode<Point>> {
    // `None` when there is no neighbor at all; otherwise the merge verdict and
    // the center-to-center distance for the head of the list.
    let verdict = neighborhood
        .first()
        .map(|closest| self.should_merge(vertex, closest));
    if let Some((true, d)) = verdict {
        self.merge_balls(vertex, &neighborhood[0], d);
        neighborhood.remove(0);
    }
    neighborhood
}
/// Decides whether the balls held by `first` and `second` should be merged.
///
/// Returns `(should_merge, d)` where `d` is the distance between the two
/// centers and `should_merge` is true when `d` is strictly below
/// `(radius_first + radius_second) * MERGE_THRESHOLD`.
fn should_merge(&self, first: &BallNode<Point>, second: &BallNode<Point>) -> (bool, f64) {
    let current_data = first.deref_data();
    let neighbor_data = second.deref_data();
    let d = (self.dist)(&current_data.center, &neighbor_data.center);
    let should_merge = d < (current_data.radius + neighbor_data.radius) * MERGE_THRESHOLD;
    (should_merge, d)
}
fn merge_balls(&self, vertex: &BallNode<Point>, neighbor: &BallNode<Point>, d: f64) {
let mut current_data = vertex.deref_data_mut();
let mut neighbor_data = neighbor.deref_data_mut();
current_data.center = (self.combine)(
¤t_data.center,
current_data.weight,
&neighbor_data.center,
neighbor_data.weight,
);
current_data.radius = d
+ (current_data.radius * current_data.weight
+ neighbor_data.radius * neighbor_data.weight)
/ (current_data.weight + neighbor_data.weight);
current_data.weight = current_data.weight + neighbor_data.weight;
neighbor_data.weight = 0.;
}
/// Ages the model after a fit.
///
/// Every ball whose data compares unequal (through `ne` on the `deref_data()`
/// values) to the data of `vertex` has its weight multiplied by
/// `DECAY_FACTOR`. Only balls whose weight is strictly above
/// `DECAY_THRESHOLD` are kept in `model.graph`.
fn decay(&self, model: &mut Model<Point>, vertex: BallNode<Point>) {
    model.graph.retain(|node| {
        let should_decay = node.deref_data().ne(&vertex.deref_data());
        if should_decay {
            node.deref_data_mut().weight *= DECAY_FACTOR;
        }
        let surviving = node.deref_data().weight > DECAY_THRESHOLD;
        surviving
    })
}
}
#[cfg(test)]
mod tests {
use approx_eq::assert_approx_eq;
use crate::algorithm::*;
use crate::space;
// A single point seeds one ball: centered on the point, with an infinite
// radius and a zero weight.
#[test]
fn test_init() {
    let (points, model) = build_model(1);
    let mut balls = model.iter_balls();
    let seed = balls.next().unwrap();
    assert_eq!(points[0], seed.center);
    assert_eq!(f64::INFINITY, seed.radius);
    assert_eq!(0., seed.weight);
}
// The second point is absorbed by the seed ball: the center moves onto that
// point, the radius becomes the measured distance and the weight becomes 1.
#[test]
fn test_update() {
    let (points, model) = build_model(2);
    let mut balls = model.iter_balls();
    let updated = balls.next().unwrap();
    assert_eq!(points[1], updated.center);
    assert_eq!(20., updated.radius);
    assert_eq!(1., updated.weight);
}
// The third point is too far from the first ball: a second ball is created
// while the first one keeps its shape and decays once.
#[test]
fn test_new() {
    let (points, model) = build_model(3);
    let mut balls = model.iter_balls();

    let original = balls.next().unwrap();
    assert_eq!(points[1], original.center);
    assert_eq!(20., original.radius);
    assert_approx_eq!(DECAY_FACTOR, original.weight);

    let spawned = balls.next().unwrap();
    assert_eq!(vec![18.5, -16.5], spawned.center);
    assert_eq!(15.68, spawned.radius);
    assert_eq!(1., spawned.weight);
}
// With two balls in the model, each one is the only neighbor of the other.
#[test]
fn test_neighborhood_init() {
    let (_points, model) = build_model(3);
    let mut balls = model.iter_balls();
    let (first, second) = (balls.next().unwrap(), balls.next().unwrap());

    let mut first_neighbors = model.graph[0].iter_neighbors();
    assert_eq!(
        second.center,
        first_neighbors.next().unwrap().deref_data().center
    );
    assert!(first_neighbors.next().is_none());

    let mut second_neighbors = model.graph[1].iter_neighbors();
    assert_eq!(
        first.center,
        second_neighbors.next().unwrap().deref_data().center
    );
    assert!(second_neighbors.next().is_none());
}
// After four points there are three balls; the newest one is appended to the
// neighbor list of the first ball.
#[test]
fn test_neighborhood_refine_append() {
    let (_points, model) = build_model(4);
    let mut balls = model.iter_balls();
    let (first, second, third) = (
        balls.next().unwrap(),
        balls.next().unwrap(),
        balls.next().unwrap(),
    );

    // First ball: second, then third.
    let mut first_neighbors = model.graph[0].iter_neighbors();
    assert_eq!(
        second.center,
        first_neighbors.next().unwrap().deref_data().center
    );
    assert_eq!(
        third.center,
        first_neighbors.next().unwrap().deref_data().center
    );

    // Second ball: only the first.
    let mut second_neighbors = model.graph[1].iter_neighbors();
    assert_eq!(
        first.center,
        second_neighbors.next().unwrap().deref_data().center
    );
    assert!(second_neighbors.next().is_none());

    // Third ball: first, then second.
    let mut third_neighbors = model.graph[2].iter_neighbors();
    assert_eq!(
        first.center,
        third_neighbors.next().unwrap().deref_data().center
    );
    assert_eq!(
        second.center,
        third_neighbors.next().unwrap().deref_data().center
    );
}
// After five points there are four balls; the newest one is inserted in front
// of the existing neighbor of the second ball.
#[test]
fn test_neighborhood_refine_prepend() {
    let (_points, model) = build_model(5);
    let mut balls = model.iter_balls();
    let (first, second, third, fourth) = (
        balls.next().unwrap(),
        balls.next().unwrap(),
        balls.next().unwrap(),
        balls.next().unwrap(),
    );

    // First ball: second, then third.
    let mut first_neighbors = model.graph[0].iter_neighbors();
    assert_eq!(
        second.center,
        first_neighbors.next().unwrap().deref_data().center
    );
    assert_eq!(
        third.center,
        first_neighbors.next().unwrap().deref_data().center
    );

    // Second ball: fourth, then first.
    let mut second_neighbors = model.graph[1].iter_neighbors();
    assert_eq!(
        fourth.center,
        second_neighbors.next().unwrap().deref_data().center
    );
    assert_eq!(
        first.center,
        second_neighbors.next().unwrap().deref_data().center
    );

    // Third ball: first, then second.
    let mut third_neighbors = model.graph[2].iter_neighbors();
    assert_eq!(
        first.center,
        third_neighbors.next().unwrap().deref_data().center
    );
    assert_eq!(
        second.center,
        third_neighbors.next().unwrap().deref_data().center
    );
}
// After the whole sample has been fitted exactly three balls remain: one
// heavy ball and two light ones, the heavy ball keeping a single neighbor.
#[test]
fn test_merge() {
    let (_points, model) = build_model(8);
    let mut balls = model.iter_balls();
    let (first, second, third) = (
        balls.next().unwrap(),
        balls.next().unwrap(),
        balls.next().unwrap(),
    );
    assert!(balls.next().is_none());

    // Weights.
    assert!(first.weight > 3.);
    assert!(second.weight < 1.);
    assert!(third.weight < 1.);

    // Center positions.
    assert!(first.center[0] < 10.);
    assert!(first.center[1] < 0.);
    assert!(second.center[0] > 10.);
    assert!(second.center[1] > 0.);
    assert!(third.center[0] > 10.);
    assert!(third.center[1] > 0.);

    // The first ball is left with the third ball as its only neighbor.
    let mut first_neighbors = model.graph[0].iter_neighbors();
    assert_eq!(
        third.center,
        first_neighbors.next().unwrap().deref_data().center
    );
    assert!(first_neighbors.next().is_none());
}
// Fits the first `count` sample points, in order, into a fresh model and
// returns the whole sample together with that model. Panics when `count`
// exceeds the sample size.
fn build_model(count: usize) -> (Vec<Vec<f64>>, Model<Vec<f64>>) {
    let dataset = build_sample();
    let algo = Algo::new(space::euclid_dist, space::real_combine);
    let mut model = Model::new(space::euclid_dist);
    (0..count).for_each(|index| algo.fit(&mut model, dataset[index].clone()));
    (dataset, model)
}
// Fixed two-dimensional sample shared by every test, in feeding order.
fn build_sample() -> Vec<Vec<f64>> {
    const POINTS: [[f64; 2]; 8] = [
        [5., -1.],
        [1., 1.],
        [15., -13.],
        [11., 23.],
        [31., -3.],
        [10., -9.],
        [6., -4.],
        [-2., -5.],
    ];
    POINTS.iter().map(|pair| pair.to_vec()).collect()
}
}