use crate::errors::GrandmaResult;
use pointcloud::*;
use rand::{thread_rng, Rng};
use std::fmt;
use std::sync::Arc;
/// A center point together with a set of other points and their distances to that center.
///
/// `dists` and `coverage` are parallel vectors: every constructor in this file pairs
/// `dists[i]` with `coverage[i]`. The center itself is kept out of `coverage`.
#[derive(Clone)]
pub(crate) struct CoveredData {
/// Distance from the center to each entry of `coverage`, in the same order.
dists: Vec<f32>,
/// Indexes of the points associated with this center (the center is not included).
coverage: Vec<PointIndex>,
/// Index of the point that the distances in `dists` are measured from.
pub(crate) center_index: PointIndex,
}
/// A set of point indexes that have not been assigned to any center yet.
#[derive(Debug, Clone)]
pub(crate) struct UncoveredData {
/// Indexes of the points that are still waiting to be covered.
coverage: Vec<PointIndex>,
}
impl UncoveredData {
pub(crate) fn pick_center<D: PointCloud>(
&mut self,
radius: f32,
point_cloud: &Arc<D>,
) -> GrandmaResult<CoveredData> {
let mut rng = thread_rng();
let new_center: usize = rng.gen_range(0, self.coverage.len());
let center_index = self.coverage.remove(new_center);
let dists = point_cloud.distances_to_point_index(center_index, &self.coverage)?;
let mut close_index = Vec::with_capacity(self.coverage.len());
let mut close_dist = Vec::with_capacity(self.coverage.len());
let mut far = Vec::new();
for (i, d) in self.coverage.iter().zip(&dists) {
if *d < radius {
close_index.push(*i);
close_dist.push(*d);
} else {
far.push(*i);
}
}
let close = CoveredData {
coverage: close_index,
dists: close_dist,
center_index,
};
self.coverage = far;
Ok(close)
}
pub(crate) fn len(&self) -> usize {
self.coverage.len()
}
}
impl fmt::Debug for CoveredData {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"CoveredData {{ center_index: {:?},coverage: {:?} }}",
self.center_index, self.coverage
)
}
}
/// Binary-searches `dist_indexes` for the position where `thresh` splits the slice.
///
/// The slice is assumed to be sorted by ascending distance (the `.0` field); the result
/// is unspecified otherwise.
///
/// Returns the index of an entry whose distance equals `thresh` if the search lands on
/// one. Otherwise it returns the number of entries whose distance is below `thresh`, i.e.
/// the insertion point. An empty slice yields `0`.
///
/// A probed distance that compares neither below nor above `thresh` (which includes NaN)
/// is treated as a match and its index is returned.
fn find_split(dist_indexes: &[(f32, usize)], thresh: f32) -> usize {
    // Without this guard `len() - 1` underflows on an empty slice.
    if dist_indexes.is_empty() {
        return 0;
    }
    let mut smaller = 0;
    let mut larger = dist_indexes.len() - 1;
    while smaller <= larger {
        // Same value as `(smaller + larger) / 2`, but cannot overflow.
        let m = smaller + (larger - smaller) / 2;
        let dist = dist_indexes[m].0;
        if dist < thresh {
            smaller = m + 1;
        } else if dist > thresh {
            // Everything is above the threshold; `m - 1` would underflow.
            if m == 0 {
                return 0;
            }
            larger = m - 1;
        } else {
            return m;
        }
    }
    smaller
}
impl CoveredData {
    /// Builds the initial covered set for a whole point cloud.
    ///
    /// The last index returned by `reference_indexes()` becomes the center; every other
    /// index is stored together with its distance to that center.
    ///
    /// # Errors
    /// Propagates any error returned by `point_cloud.distances_to_point_index`.
    ///
    /// # Panics
    /// Panics if `reference_indexes()` returns an empty collection.
    pub(crate) fn new<D: PointCloud>(point_cloud: &Arc<D>) -> GrandmaResult<CoveredData> {
        let mut remaining = point_cloud.reference_indexes();
        let center_index = remaining.pop().unwrap();
        let dists = point_cloud.distances_to_point_index(center_index, &remaining)?;
        let covered = CoveredData {
            dists,
            coverage: remaining,
            center_index,
        };
        Ok(covered)
    }

    /// Splits this set around `thresh`.
    ///
    /// Points whose stored distance is strictly less than `thresh` stay with the center
    /// in the returned [`CoveredData`]; the rest are handed back as [`UncoveredData`].
    /// Relative order is preserved on both sides. This implementation always returns `Ok`.
    pub(crate) fn split(self, thresh: f32) -> GrandmaResult<(CoveredData, UncoveredData)> {
        let CoveredData {
            dists,
            coverage,
            center_index,
        } = self;

        let total = coverage.len();
        let mut near = CoveredData {
            coverage: Vec::with_capacity(total),
            dists: Vec::with_capacity(total),
            center_index,
        };
        let mut beyond = UncoveredData {
            coverage: Vec::new(),
        };

        for (index, distance) in coverage.into_iter().zip(dists) {
            // Keep the `<` form: a NaN distance must fall through to the uncovered side.
            if distance < thresh {
                near.coverage.push(index);
                near.dists.push(distance);
            } else {
                beyond.coverage.push(index);
            }
        }

        Ok((near, beyond))
    }

    /// Consumes the set and returns the covered indexes (the center is not included).
    pub(crate) fn into_indexes(self) -> Vec<PointIndex> {
        let CoveredData { coverage, .. } = self;
        coverage
    }

    /// Largest stored distance, or negative infinity when no distances are stored.
    pub(crate) fn max_distance(&self) -> f32 {
        self.dists
            .iter()
            .fold(std::f32::NEG_INFINITY, |largest, &dist| largest.max(dist))
    }

    /// Number of points in the set, counting the center.
    pub(crate) fn len(&self) -> usize {
        1 + self.coverage.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds the shared fixture: 19 values of `rand::random::<f32>() + 3.0`
    /// followed by a single `0.0`.
    fn sample_points() -> Vec<f32> {
        let mut points: Vec<f32> = (0..19).map(|_| rand::random::<f32>() + 3.0).collect();
        points.push(0.0);
        points
    }

    /// Labels a point `1` when its value is above `0.5`, `0` otherwise.
    fn labels_for(points: &[f32]) -> Vec<u64> {
        points
            .iter()
            .map(|&value| if value > 0.5 { 1 } else { 0 })
            .collect()
    }

    /// Splitting at 1.0 should leave only the center on the close side.
    #[test]
    fn splits_correctly_1() {
        let data = sample_points();
        let labels = labels_for(&data);
        let cloud = Arc::new(DefaultLabeledCloud::<L2>::new_simple(data, 1, labels));
        let cache = CoveredData::new(&cloud).unwrap();
        let (close, far) = cache.split(1.0).unwrap();
        assert_eq!(1, close.len());
        assert_eq!(19, far.len());
    }

    /// After picking a center, the center, the covered set and the uncovered set
    /// must be pairwise disjoint.
    #[test]
    fn uncovered_splits_correctly_1() {
        let data = sample_points();
        let labels = labels_for(&data);
        let point_cloud = Arc::new(DefaultLabeledCloud::<L2>::new_simple(data, 1, labels));
        let mut cache = UncoveredData {
            coverage: (0..19 as PointIndex).collect(),
        };
        let close = cache.pick_center(1.0, &point_cloud).unwrap();
        assert!(!close.coverage.contains(&close.center_index));
        assert!(!cache.coverage.contains(&close.center_index));
        assert!(close.coverage.iter().all(|i| !cache.coverage.contains(i)));
        assert!(cache.coverage.iter().all(|i| !close.coverage.contains(i)));
    }

    /// Stored distances should match the raw values, and the close side of a split
    /// should list the expected indexes in order.
    #[test]
    fn correct_dists() {
        let data = sample_points();
        let labels = labels_for(&data);
        let point_cloud = DefaultLabeledCloud::<L2>::new_simple(data.clone(), 1, labels);
        let cache = CoveredData::new(&Arc::new(point_cloud)).unwrap();
        let thresh = 0.5;
        let (true_close, _true_far): (Vec<usize>, Vec<usize>) =
            (0..19).partition(|&i| data[i] < thresh);
        for i in 0..19 {
            assert_approx_eq!(data[i], cache.dists[i]);
        }
        let (close, _far) = cache.split(thresh).unwrap();
        for (expected, actual) in true_close.iter().zip(close.coverage) {
            assert_eq!(*expected, actual);
        }
    }
}