use crate::model::UnsupervisedLearn;
use std::iter::Iterator;
use std::clone::Clone;
use std::collections::{HashSet, HashMap};
/// Density-based clustering model (DBSCAN) over points of type `T`.
///
/// The model keeps its own copy of the training points together with one
/// label per point and a cache of already computed neighbourhoods.
pub struct DBSCAN<T> {
    /// Distance threshold used for neighbourhood queries and prediction.
    eps: f64,
    /// Minimum neighbourhood size a point needs in order to seed or extend
    /// a cluster.
    min_points: usize,
    /// User supplied distance function between two points.
    dist: fn(&T, &T) -> f64,
    /// Cache: index of a training point -> indices of the training points
    /// in its neighbourhood.
    neighbours: HashMap<usize, HashSet<usize>>,
    /// Marker tying the model to the point type `T`.
    phantom: std::marker::PhantomData<T>,
    /// One label per training point: `-2` = not visited yet, `-1` = noise,
    /// `>= 0` = cluster id.
    label: Vec<i32>,
    /// Owned copy of the points handed to `train`.
    cluster_data: Vec<T>,
}
impl<T> DBSCAN<T>
where T: std::marker::Sized
{
pub fn new(eps: f64, min_points: usize, dist: fn(&T, &T) -> f64) -> DBSCAN<T>
{
DBSCAN
{
eps,
min_points,
dist: dist,
neighbours: HashMap::new(),
phantom: std::marker::PhantomData,
label: Vec::new(),
cluster_data: Vec::new()
}
}
pub fn get_number_clusters(self: &Self) -> usize
{
let mut clusters: HashSet<i32> = HashSet::new();
self.label.iter().for_each(|l| {clusters.insert(*l);});
let mut num_clusters: usize = 0;
clusters.iter().for_each(|c|
{
if *c >= 0
{
num_clusters += 1;
}
});
return num_clusters;
}
}
impl<T> UnsupervisedLearn<Vec<T>, Vec<i32>> for DBSCAN<T>
where
    T: Clone,
{
    /// Clusters `input` and stores one label per point in the model.
    ///
    /// Labels are `-1` for noise and `0, 1, 2, ...` for clusters, numbered
    /// in the order their seed points appear in `input`.
    ///
    /// Every call starts from a clean slate: previously stored points,
    /// labels and cached neighbourhoods are discarded. Keeping them would
    /// leave stale neighbourhoods in the cache (they were computed against
    /// the old data set) and would restart cluster numbering at `0`, so new
    /// clusters would collide with old ids.
    fn train<'a, 'b>(self: &'a mut Self, input: &'b Vec<T>) {
        const UNVISITED: i32 = -2;
        const NOISE: i32 = -1;

        self.neighbours.clear();
        self.cluster_data.clear();
        self.cluster_data.extend(input.iter().cloned());
        self.label.clear();
        self.label.resize(input.len(), UNVISITED);

        let mut cluster: i32 = 0;
        // Iterate by index: the helpers below need `&mut self`, so no borrow
        // of `cluster_data` may be held across the loop body.
        for idx in 0..self.cluster_data.len() {
            // Points already claimed by an earlier expansion are done.
            if self.label[idx] != UNVISITED {
                continue;
            }
            let neighbours = self.get_neighbours(idx);
            if neighbours.len() < self.min_points {
                // Not dense enough to seed a cluster; may still be adopted
                // later as a border point of some cluster.
                self.label[idx] = NOISE;
            } else {
                self.expand_cluster(idx, neighbours, cluster);
                cluster += 1;
            }
        }
    }

    /// Assigns each point of `x` the label of its nearest training point.
    ///
    /// A query point whose nearest training point is farther away than
    /// `eps` is reported as noise (`-1`). The comparison is inclusive
    /// (`<= eps`) so that it agrees with the neighbourhood test used while
    /// training. On ties the training point with the lowest index wins.
    ///
    /// # Panics
    ///
    /// Panics with "Model is not trained" when the model holds no training
    /// data.
    fn predict<'a>(self: &'a Self, x: &'a Vec<T>) -> Vec<i32> {
        const NOISE: i32 = -1;

        if self.cluster_data.is_empty() {
            panic!("Model is not trained");
        }

        x.iter()
            .map(|query| {
                // Seed the search with the first training point, then scan
                // the rest; strict `<` keeps the earliest point on ties.
                let mut nearest_label: i32 = self.label[0];
                let mut nearest_dist: f64 = (self.dist)(query, &self.cluster_data[0]);
                for (&candidate_label, candidate) in
                    self.label.iter().zip(self.cluster_data.iter()).skip(1)
                {
                    let d: f64 = (self.dist)(query, candidate);
                    if d < nearest_dist {
                        nearest_dist = d;
                        nearest_label = candidate_label;
                    }
                }
                if nearest_dist <= self.eps {
                    nearest_label
                } else {
                    NOISE
                }
            })
            .collect()
    }
}
impl<T> DBSCAN<T>
where
    T: Clone,
{
    /// Returns the indices of all stored points whose distance to the point
    /// at `point_idx` is at most `eps`.
    ///
    /// Results are memoised in `self.neighbours`, so each neighbourhood is
    /// computed at most once; the returned set is an owned copy.
    ///
    /// # Panics
    ///
    /// Panics if `point_idx` is not a valid index into the stored points.
    fn get_neighbours(&mut self, point_idx: usize) -> HashSet<usize> {
        if let Some(cached) = self.neighbours.get(&point_idx) {
            return cached.clone();
        }

        let centre: &T = &self.cluster_data[point_idx];
        let found: HashSet<usize> = self
            .cluster_data
            .iter()
            .enumerate()
            .filter(|(_, other)| (self.dist)(centre, other) <= self.eps)
            .map(|(other_idx, _)| other_idx)
            .collect();

        self.neighbours.insert(point_idx, found.clone());
        found
    }

    /// Grows cluster `cluster` outwards from the core point `point_idx`.
    ///
    /// `neighbours` must be the neighbourhood of `point_idx`. Every point
    /// reachable through a chain of core points (points whose neighbourhood
    /// holds at least `min_points` entries) receives the label `cluster`.
    /// Points previously marked as noise are adopted as border points but
    /// are not expanded further. Points that already belong to a cluster
    /// are left untouched.
    ///
    /// The expansion is driven by an explicit worklist rather than
    /// recursion so that large clusters cannot overflow the call stack.
    fn expand_cluster(&mut self, point_idx: usize, neighbours: HashSet<usize>, cluster: i32) {
        const UNVISITED: i32 = -2;

        self.label[point_idx] = cluster;

        let mut frontier: Vec<usize> = neighbours.into_iter().collect();
        while let Some(q_idx) = frontier.pop() {
            let previous = self.label[q_idx];
            if previous >= 0 {
                // Already owned by this or another cluster.
                continue;
            }
            self.label[q_idx] = cluster;

            // Only points seen for the first time can turn out to be core
            // points; former noise points are known not to be dense enough.
            if previous == UNVISITED {
                // Expand through the neighbourhood of `q_idx` itself, not
                // through that of the seed point.
                let reachable = self.get_neighbours(q_idx);
                if reachable.len() >= self.min_points {
                    frontier.extend(reachable.into_iter().filter(|&n| self.label[n] < 0));
                }
            }
        }
    }
}