use crate::*;
use layer::*;
use node::*;
use std::sync::{atomic, Arc, RwLock};
use tree_file_format::*;
use crate::plugins::{GrandmaPlugin, TreePluginSet};
use crate::query_tools::{KnnQueryHeap, MultiscaleQueryHeap, RoutingQueryHeap};
use errors::GrandmaResult;
use std::collections::HashMap;
use std::iter::Iterator;
use std::iter::Rev;
use std::ops::Range;
use std::slice::Iter;
use plugins::labels::{LabelSummaryPlugin,TreeLabelSummary};
/// Parameters and shared state of a cover tree.
///
/// `CoverTreeReader` and `CoverTreeWriter` each hold this behind an `Arc`.
#[derive(Debug)]
pub struct CoverTreeParameters<D: PointCloud> {
/// Atomic node counter; `CoverTreeWriter::load` starts it at 0.
/// NOTE(review): nothing in the code shown here updates it — confirm who maintains it.
pub total_nodes: atomic::AtomicUsize,
/// Base of the scale; `CoverTreeReader::scale` computes `scale_base.powi(scale_index)`.
pub scale_base: f32,
/// Loaded from the proto's `cutoff` field.
/// NOTE(review): presumably the size below which a node is kept as a leaf — confirm against the builder.
pub leaf_cutoff: usize,
/// Scale indexes below this value all map to internal layer index 0 (see `internal_index`).
pub min_res_index: i32,
/// NOTE(review): presumably toggles singleton storage on nodes — confirm against the builder.
pub use_singletons: bool,
/// Shared handle to the point cloud the tree refers to.
pub point_cloud: Arc<D>,
/// Verbosity level; `CoverTreeWriter::load` hard-codes it to 2.
pub verbosity: u32,
/// Tree-level plugin components, guarded by a `RwLock` so they can be
/// replaced through a shared `Arc<CoverTreeParameters>`.
pub plugins: RwLock<TreePluginSet>,
}
impl<D: PointCloud> CoverTreeParameters<D> {
    /// Converts a scale index into a position in the tree's layer vector.
    ///
    /// Every scale index below `min_res_index` shares position 0; a scale
    /// index at or above it maps to `scale_index - min_res_index + 1`.
    #[inline]
    pub fn internal_index(&self, scale_index: i32) -> usize {
        if scale_index >= self.min_res_index {
            let offset = scale_index - self.min_res_index + 1;
            offset as usize
        } else {
            0
        }
    }
}
/// Iterator type returned by `CoverTreeReader::layers`: a reversed zip of a
/// scale-index range with a slice iterator over the layer readers, so each item
/// is a `(scale index, &CoverLayerReader)` pair.
pub type LayerIter<'a, D> = Rev<std::iter::Zip<Range<i32>, Iter<'a, CoverLayerReader<D>>>>;
/// Read-side handle to a cover tree.
///
/// Obtained from `CoverTreeWriter::reader` or by cloning another reader.
pub struct CoverTreeReader<D: PointCloud> {
// Parameters shared with the writer and with other readers.
parameters: Arc<CoverTreeParameters<D>>,
// One reader per layer, indexed by `CoverTreeParameters::internal_index`.
layers: Vec<CoverLayerReader<D>>,
// Address of the root node as captured when this reader was created.
root_address: NodeAddress,
}
impl<D: PointCloud> Clone for CoverTreeReader<D> {
fn clone(&self) -> CoverTreeReader<D> {
CoverTreeReader {
parameters: self.parameters.clone(),
layers: self.layers.clone(),
root_address: self.root_address,
}
}
}
impl<D: PointCloud + LabeledCloud> CoverTreeReader<D> {
    /// Applies `transform_fn` to the label summary of the node at `node_address`.
    ///
    /// Returns `None` when the node is not found or when the node has no label
    /// summary; otherwise returns the closure's result.
    ///
    /// # Panics
    /// Panics if the scale index maps past the end of the layer vector.
    pub fn get_node_label_summary_and<F, S>(
        &self,
        node_address: (i32, PointIndex),
        transform_fn: F,
    ) -> Option<S>
    where
        F: Fn(&D::LabelSummary) -> S,
    {
        let (scale_index, point_index) = node_address;
        let layer = &self.layers[self.parameters.internal_index(scale_index)];
        layer
            .get_node_and(point_index, |node| node.label_summary().map(transform_fn))
            .flatten()
    }
}
impl<D: PointCloud> CoverTreeReader<D> {
/// Borrows the shared point-cloud handle stored in the tree parameters.
pub fn point_cloud(&self) -> &Arc<D> {
    let params = &self.parameters;
    &params.point_cloud
}
/// Borrows the layer reader responsible for `scale_index`.
///
/// # Panics
/// Panics if the scale index maps past the end of the layer vector.
pub fn layer(&self, scale_index: i32) -> &CoverLayerReader<D> {
    let slot = self.parameters.internal_index(scale_index);
    &self.layers[slot]
}
/// Returns `scale_base` raised to the integer power `scale_index`.
pub fn scale(&self, scale_index: i32) -> f32 {
    let base = self.parameters.scale_base;
    base.powi(scale_index)
}
/// Runs `f` on the node at `node_address` and returns its result, or `None`
/// when the layer has no node at that point index.
///
/// # Panics
/// Panics if the scale index maps past the end of the layer vector.
pub fn get_node_and<F, T>(&self, node_address: (i32, PointIndex), f: F) -> Option<T>
where
    F: FnOnce(&CoverNode<D>) -> T,
{
    let (scale_index, point_index) = node_address;
    let slot = self.parameters.internal_index(scale_index);
    self.layers[slot].get_node_and(point_index, f)
}
/// Forwards to the layer's `get_node_children_and` for the node at
/// `node_address`, handing `f` a node address and a slice of node addresses.
/// Returns whatever the layer call returns.
///
/// # Panics
/// Panics if the scale index maps past the end of the layer vector.
pub fn get_node_children_and<F, T>(&self, node_address: (i32, PointIndex), f: F) -> Option<T>
where
    F: FnOnce(NodeAddress, &[NodeAddress]) -> T,
{
    let (scale_index, point_index) = node_address;
    let slot = self.parameters.internal_index(scale_index);
    self.layers[slot].get_node_children_and(point_index, f)
}
pub fn root_address(&self) -> NodeAddress {
self.root_address
}
/// Iterates over `(scale index, layer reader)` pairs, highest scale index first.
///
/// Position 0 of the layer vector is paired with `min_res_index - 1`, and each
/// following position with the next scale index; the zip is then reversed.
pub fn layers(&self) -> LayerIter<D> {
    let first_scale = self.parameters.min_res_index - 1;
    let end_scale = self.layers.len() as i32 + self.parameters.min_res_index - 1;
    let scale_indexes = first_scale..end_scale;
    scale_indexes.zip(self.layers.iter()).rev()
}
pub fn len(&self) -> usize {
self.layers.len()
}
pub fn is_empty(&self) -> bool {
self.layers.is_empty()
}
pub fn parameters(&self) -> &Arc<CoverTreeParameters<D>> {
&self.parameters
}
/// Total of `len()` over every layer reader.
pub fn node_count(&self) -> usize {
    self.layers().map(|(_scale_index, layer)| layer.len()).sum()
}
/// Half-open range of scale indexes from `min_res_index` up to
/// `min_res_index - 1 + number_of_layers`.
///
/// NOTE(review): `layers()` starts one index lower (`min_res_index - 1`);
/// confirm that leaving the bottom layer out of this range is intended.
pub fn scale_range(&self) -> Range<i32> {
    let start = self.parameters.min_res_index;
    let end = start - 1 + self.layers.len() as i32;
    start..end
}
/// Applies `transform_fn` to the tree-level plugin component of type `T`.
///
/// Returns `None` when the plugin set has no component of that type.
///
/// # Panics
/// Panics if the plugin `RwLock` is poisoned.
pub fn get_plugin_and<T: Send + Sync + 'static, F, S>(&self, transform_fn: F) -> Option<S>
where
    F: FnOnce(&T) -> S,
{
    // The read guard is held only for the duration of this call.
    let plugin_set = self.parameters.plugins.read().unwrap();
    let transformed = plugin_set.get::<T>().map(transform_fn);
    transformed
}
/// Applies `transform_fn` to the plugin component of type `T` stored on the
/// node at `node_address`.
///
/// Returns `None` when the node is not found or the node-level lookup yields
/// nothing.
///
/// # Panics
/// Panics if the scale index maps past the end of the layer vector.
pub fn get_node_plugin_and<T: Send + Sync + 'static, F, S>(
    &self,
    node_address: (i32, PointIndex),
    transform_fn: F,
) -> Option<S>
where
    F: FnOnce(&T) -> S,
{
    let (scale_index, point_index) = node_address;
    let layer = &self.layers[self.parameters.internal_index(scale_index)];
    layer
        .get_node_and(point_index, |node| node.get_plugin_and(transform_fn))
        .flatten()
}
/// k-nearest-neighbour query for `point`.
///
/// Seeds a `KnnQueryHeap` with the root node, descends greedily through child
/// nodes, and then alternates between searching the singletons of the closest
/// unvisited covering node and descending again, until the heap reports no
/// more unvisited singleton-covering nodes. Returns the heap's unpacked
/// `(distance, point index)` pairs.
///
/// # Errors
/// Fails if the root point cannot be fetched or its distance to `point`
/// cannot be computed.
pub fn knn<'a, T: Into<PointRef<'a>>>(
    &self,
    point: T,
    k: usize,
) -> GrandmaResult<Vec<(f32, PointIndex)>> {
    let params = &self.parameters;
    let mut heap = KnnQueryHeap::new(k, params.scale_base);
    let query: PointRef<'a> = point.into();
    let root_center = params.point_cloud.point(self.root_address.1)?;
    let root_distance = D::Metric::dist(&root_center, query)?;
    heap.push_nodes(&[self.root_address], &[root_distance], None);
    self.greedy_knn_nodes(&query, &mut heap);
    while let Some((_, covering_address)) = heap.closest_unvisited_singleton_covering_address() {
        // NOTE(review): whatever `singleton_knn` returns is discarded here —
        // confirm it cannot carry an error that should be propagated.
        self.get_node_and(covering_address, |node| {
            node.singleton_knn(&query, &params.point_cloud, &mut heap)
        });
        self.greedy_knn_nodes(&query, &mut heap);
    }
    Ok(heap.unpack())
}
/// k-nearest query that only descends through child nodes.
///
/// Seeds a `KnnQueryHeap` with the root node, then calls the greedy descent
/// once and keeps calling it for as long as it reports progress. Unlike `knn`,
/// no singleton search is performed here. Returns the heap's unpacked
/// `(distance, point index)` pairs.
///
/// # Errors
/// Fails if the root point cannot be fetched or its distance to `point`
/// cannot be computed.
pub fn routing_knn<'a, T: Into<PointRef<'a>>>(
    &self,
    point: T,
    k: usize,
) -> GrandmaResult<Vec<(f32, PointIndex)>> {
    let params = &self.parameters;
    let mut heap = KnnQueryHeap::new(k, params.scale_base);
    let query: PointRef<'a> = point.into();
    let root_center = params.point_cloud.point(self.root_address.1)?;
    let root_distance = D::Metric::dist(&root_center, query)?;
    heap.push_nodes(&[self.root_address], &[root_distance], None);
    self.greedy_knn_nodes(&query, &mut heap);
    loop {
        let made_progress = self.greedy_knn_nodes(&query, &mut heap);
        if !made_progress {
            break;
        }
    }
    Ok(heap.unpack())
}
/// Repeatedly expands the closest unvisited child-covering node in `query_heap`.
///
/// Stops when the heap has no such node, or when the node it offers is a leaf
/// or cannot be found (a missing node is treated as a leaf). Returns `true`
/// if at least one node was expanded during this call.
fn greedy_knn_nodes<'a, T: Into<PointRef<'a>>>(
    &self,
    point: T,
    query_heap: &mut KnnQueryHeap,
) -> bool {
    let query: PointRef<'a> = point.into();
    let mut expanded_any = false;
    while let Some((dist, nearest_address)) =
        query_heap.closest_unvisited_child_covering_address()
    {
        let treat_as_leaf = self
            .get_node_and(nearest_address, |node| node.is_leaf())
            .unwrap_or(true);
        if treat_as_leaf {
            break;
        }
        self.get_node_and(nearest_address, |node| {
            node.child_knn(Some(dist), &query, &self.parameters.point_cloud, query_heap)
        });
        expanded_any = true;
    }
    expanded_any
}
/// Runs a k-nearest-node query at every scale of the tree.
///
/// Seeds a `MultiscaleQueryHeap` with the root node, then walks the layers from
/// the highest scale index down. Within a scale, unqueried nodes are popped
/// from the heap and expanded through `child_knn` for as long as
/// `q_dist - scale_base^si` stays below the distance of the furthest node the
/// heap reports for that scale. The first popped node that fails this test,
/// or a scale with no furthest node, ends the work on that scale.
///
/// Returns the heap's unpacked contents: for each scale index, a list of
/// `(distance, node address)` pairs.
///
/// This function writes nothing to stdout; callers that want to inspect the
/// heap should do so on the returned value.
///
/// # Errors
/// Fails if the root point cannot be fetched or its distance to `point`
/// cannot be computed.
pub fn multiscale_knn<'a, T: Into<PointRef<'a>>>(
    &self,
    point: T,
    k: usize,
) -> GrandmaResult<HashMap<i32, Vec<(f32, NodeAddress)>>> {
    let mut query_heap = MultiscaleQueryHeap::new(k, self.parameters.scale_base);
    let point: PointRef<'a> = point.into();
    let root_center = self.parameters.point_cloud.point(self.root_address.1)?;
    let dist_to_root = D::Metric::dist(&root_center, point)?;
    query_heap.push_nodes(&[self.root_address], &[dist_to_root], None);
    for (si, _) in self.layers() {
        // The radius at this scale does not change inside the inner loop.
        let scale_radius = self.parameters.scale_base.powi(si);
        while let Some((q_dist, nearest_address)) = query_heap.pop_closest_unqueried(si) {
            match query_heap.furthest_node(si) {
                Some((furthest_distance, _)) => {
                    if q_dist - scale_radius < furthest_distance {
                        self.get_node_and(nearest_address, |n| {
                            n.child_knn(
                                Some(q_dist),
                                &point,
                                &self.parameters.point_cloud,
                                &mut query_heap,
                            )
                        });
                    } else {
                        break;
                    }
                }
                None => break,
            }
        }
    }
    Ok(query_heap.unpack())
}
/// Traces the path an insertion of `point` would follow, without modifying the tree.
///
/// Starts at the root and repeatedly asks the current node for its covering
/// child, recording each `(distance, node address)` visited. The walk ends
/// when the current node cannot be found or reports no covering child.
/// The first entry of the returned trace is always the root.
///
/// # Errors
/// Fails if the root point cannot be fetched, the root distance cannot be
/// computed, or a `covering_child` lookup returns an error.
pub fn dry_insert<'a, T: Into<PointRef<'a>>>(
    &self,
    point: T,
) -> GrandmaResult<Vec<(f32, NodeAddress)>> {
    let query: PointRef<'a> = point.into();
    let params = &self.parameters;
    let root_center = params.point_cloud.point(self.root_address.1)?;
    let mut current_distance = D::Metric::dist(&root_center, query)?;
    let mut current_address = self.root_address;
    let mut trace = vec![(current_distance, current_address)];
    loop {
        let lookup = self.get_node_and(current_address, |node| {
            node.covering_child(
                params.scale_base,
                current_distance,
                query,
                &params.point_cloud,
            )
        });
        let step = match lookup {
            Some(result) => result?,
            None => break,
        };
        match step {
            Some((distance, address)) => {
                trace.push((distance, address));
                current_distance = distance;
                current_address = address;
            }
            None => break,
        }
    }
    Ok(trace)
}
/// Logarithm, in base `scale_base`, of the node's singleton count plus its
/// children count.
///
/// # Panics
/// Panics if no node exists at `node_address`.
pub fn node_fractal_dim(&self, node_address: NodeAddress) -> f32 {
    let base = self.parameters.scale_base;
    let member_count: f32 = self
        .get_node_and(node_address, |node| {
            (node.singletons_len() + node.children_len()) as f32
        })
        .unwrap();
    member_count.log(base)
}
/// Weighted variant of `node_fractal_dim`.
///
/// For a node with children, collects the cover count of every child plus the
/// cover count of the node at `(nested_scale, node_address.1)`, divides each by
/// the largest of them, sums the ratios, and adds the singleton count divided
/// by that same maximum. For a node without children the value is just the
/// singleton count. The result is the logarithm of that value in base
/// `scale_base`.
///
/// # Panics
/// Panics if the node, any listed child, or the nested node cannot be found.
pub fn node_weighted_fractal_dim(&self, node_address: NodeAddress) -> f32 {
    let base = self.parameters.scale_base;
    let weighted_count: f32 = self
        .get_node_and(node_address, |node| {
            let singleton_count = node.singletons().len() as f32;
            match node.children() {
                Some((nested_scale, children)) => {
                    let mut pops: Vec<usize> = children
                        .iter()
                        .map(|child_addr| {
                            self.get_node_and(*child_addr, |child| child.cover_count())
                                .unwrap()
                        })
                        .collect();
                    let nested_pop = self
                        .get_node_and((nested_scale, node_address.1), |child| {
                            child.cover_count()
                        })
                        .unwrap();
                    pops.push(nested_pop);
                    let max_pop = *pops.iter().max().unwrap();
                    let child_weight = pops
                        .iter()
                        .fold(0.0f32, |acc, pop| acc + (*pop as f32) / (max_pop as f32));
                    child_weight + singleton_count / (max_pop as f32)
                }
                // Without children the child weight is 0 and the divisor is 1.
                None => singleton_count,
            }
        })
        .unwrap();
    weighted_count.log(base)
}
/// Difference of two base-`scale_base` logarithms for the layer at
/// `scale_index`: log of the summed singleton and children counts over all
/// nodes, minus log of the number of nodes in the layer.
///
/// # Panics
/// Panics if the scale index maps past the end of the layer vector.
pub fn layer_fractal_dim(&self, scale_index: i32) -> f32 {
    let base = self.parameters.scale_base;
    let layer = self.layer(scale_index);
    let node_total = layer.len() as f32;
    let mut member_total: f32 = 0.0;
    layer.for_each_node(|_, node| {
        member_total += (node.singletons_len() + node.children_len()) as f32;
    });
    member_total.log(base) - node_total.log(base)
}
/// Weighted variant of `layer_fractal_dim`.
///
/// Gathers the cover count of every node in the layer (parents) and, for nodes
/// with children, the cover counts of each child and of the nested node at
/// `(nested_scale, center_index)`. Each group is normalised by its own maximum
/// (1 when the group is empty); the singleton total is normalised by the child
/// maximum and added to the child sum. Returns log of the child sum minus log
/// of the parent sum, both in base `scale_base`.
///
/// # Panics
/// Panics if the scale index maps past the end of the layer vector, or if a
/// listed child or nested node cannot be found.
pub fn layer_weighted_fractal_dim(&self, scale_index: i32) -> f32 {
    let base = self.parameters.scale_base;
    let mut parent_pops: Vec<usize> = Vec::new();
    let mut child_pops: Vec<usize> = Vec::new();
    let mut singleton_total: f32 = 0.0;
    self.layer(scale_index).for_each_node(|center_index, node| {
        parent_pops.push(node.cover_count());
        singleton_total += node.singletons().len() as f32;
        if let Some((nested_scale, children)) = node.children() {
            for child_addr in children.iter() {
                let child_pop = self
                    .get_node_and(*child_addr, |child| child.cover_count())
                    .unwrap();
                child_pops.push(child_pop);
            }
            let nested_pop = self
                .get_node_and((nested_scale, *center_index), |child| child.cover_count())
                .unwrap();
            child_pops.push(nested_pop);
        }
    });
    let max_parent_pop: f32 = *parent_pops.iter().max().unwrap_or(&1) as f32;
    let max_child_pop: f32 = *child_pops.iter().max().unwrap_or(&1) as f32;
    let child_ratio_sum: f32 = child_pops
        .iter()
        .fold(0.0, |acc, pop| acc + (*pop as f32) / max_child_pop);
    let weighted_child_sum: f32 = singleton_total / max_child_pop + child_ratio_sum;
    let weighted_parent_sum: f32 = parent_pops
        .iter()
        .fold(0.0, |acc, pop| acc + (*pop as f32) / max_parent_pop);
    weighted_child_sum.log(base) - weighted_parent_sum.log(base)
}
/// Checks that every node address reachable from the root resolves to a node.
///
/// Works through a stack seeded with the root address. For each node found,
/// the nested address `(nested_scale, same point index)` and all listed
/// children are pushed for checking. Returns `false` as soon as an address
/// has no node, `true` once the stack is exhausted. Progress is printed to
/// stdout.
pub(crate) fn no_dangling_refs(&self) -> bool {
    let mut pending = vec![self.root_address];
    while let Some(node_addr) = pending.pop() {
        println!("checking {:?}", node_addr);
        println!("refs_to_check: {:?}", pending);
        let found = self.get_node_and(node_addr, |node| {
            if let Some((nested_scale, other_children)) = node.children() {
                println!(
                    "Pushing: {:?}, {:?}",
                    (nested_scale, other_children),
                    other_children
                );
                pending.push((nested_scale, node_addr.1));
                pending.extend(&other_children[..]);
            }
        });
        match found {
            Some(()) => {}
            None => return false,
        }
    }
    true
}
}
/// Write-side handle to a cover tree.
///
/// Holds one `CoverLayerWriter` per layer and hands out `CoverTreeReader`s via `reader`.
pub struct CoverTreeWriter<D: PointCloud> {
// Parameters shared with every reader created from this writer.
pub(crate) parameters: Arc<CoverTreeParameters<D>>,
// One writer per layer, indexed by `CoverTreeParameters::internal_index`.
pub(crate) layers: Vec<CoverLayerWriter<D>>,
// Address of the root node.
pub(crate) root_address: NodeAddress,
}
impl<D: PointCloud + LabeledCloud> CoverTreeWriter<D> {
    /// Installs the label-summary plugin on the tree, using a default
    /// `TreeLabelSummary` as the tree-level component.
    pub fn generate_summaries(&mut self) {
        let tree_component = TreeLabelSummary::default();
        self.add_plugin::<LabelSummaryPlugin>(tree_component);
    }
}
impl<D: PointCloud> CoverTreeWriter<D> {
/// Installs plugin `P` on the tree.
///
/// For every layer, a node component is computed for each node from the tree
/// component `plug_in`, the node, and a reader taken before any update, and an
/// update inserting that component is issued on the layer; the layer is then
/// refreshed. Finally `plug_in` itself is stored in the tree-level plugin set.
///
/// # Panics
/// Panics if the plugin `RwLock` is poisoned.
pub fn add_plugin<P: GrandmaPlugin<D>>(
&mut self,
plug_in: <P as plugins::GrandmaPlugin<D>>::TreeComponent,
) where
<P as plugins::GrandmaPlugin<D>>::TreeComponent: 'static,
<P as plugins::GrandmaPlugin<D>>::NodeComponent: 'static,
{
// Reader created up front; node components are computed against it.
let reader = self.reader();
for layer in self.layers.iter_mut() {
layer.reader().for_each_node(|pi, n| {
let node_component = P::node_component(&plug_in, n, &reader);
// NOTE(review): no SAFETY justification is recorded for this call and the
// contract of the layer's `update_node` is not visible here — confirm.
// The component is cloned inside the update closure rather than moved out of it.
unsafe { layer.update_node(*pi, move |n| n.insert_plugin(node_component.clone())) }
});
// Refresh this layer before moving on to the next one.
layer.refresh()
}
self.parameters.plugins.write().unwrap().insert(plug_in);
}
/// Mutably borrows the layer writer responsible for `scale_index`.
///
/// # Safety
/// NOTE(review): the invariant callers must uphold is not stated in the code
/// shown here — confirm why this accessor is marked `unsafe`.
///
/// # Panics
/// Panics if the scale index maps past the end of the layer vector.
pub(crate) unsafe fn layer(&mut self, scale_index: i32) -> &mut CoverLayerWriter<D> {
    let slot = self.parameters.internal_index(scale_index);
    &mut self.layers[slot]
}
/// Forwards `update_fn` to the layer that owns `address`, targeting the node
/// at that address's point index.
///
/// # Safety
/// NOTE(review): the invariant callers must uphold is not stated in the code
/// shown here — confirm against the layer writer's `update_node`.
///
/// # Panics
/// Panics if the scale index maps past the end of the layer vector.
pub(crate) unsafe fn update_node<F>(&mut self, address: NodeAddress, update_fn: F)
where
    F: Fn(&mut CoverNode<D>) + 'static + Send + Sync,
{
    let (scale_index, point_index) = address;
    let slot = self.parameters.internal_index(scale_index);
    self.layers[slot].update_node(point_index, update_fn);
}
pub fn reader(&self) -> CoverTreeReader<D> {
CoverTreeReader {
parameters: Arc::clone(&self.parameters),
layers: self.layers.iter().map(|l| l.reader()).collect(),
root_address: self.root_address,
}
}
/// Hands `node` to the layer responsible for `scale_index`, keyed by `point_index`.
///
/// # Safety
/// NOTE(review): the invariant callers must uphold is not stated in the code
/// shown here — confirm against the layer writer's `insert_raw`.
///
/// # Panics
/// Panics if the scale index maps past the end of the layer vector.
pub(crate) unsafe fn insert_raw(
    &mut self,
    scale_index: i32,
    point_index: PointIndex,
    node: CoverNode<D>,
) {
    let slot = self.parameters.internal_index(scale_index);
    let target_layer = &mut self.layers[slot];
    target_layer.insert_raw(point_index, node);
}
/// Rebuilds a writer from a serialized `CoreProto` and a point cloud.
///
/// Scalar parameters are copied from the proto; `total_nodes` starts at 0,
/// `verbosity` is set to 2 and the plugin set starts out fresh. Layers are
/// loaded through a parallel iterator over the proto's layers. As written,
/// this always returns `Ok`.
pub fn load(cover_proto: &CoreProto, point_cloud: Arc<D>) -> GrandmaResult<CoverTreeWriter<D>> {
    let root_address = (
        cover_proto.get_root_scale(),
        cover_proto.get_root_index() as usize,
    );
    let layers = cover_proto
        .get_layers()
        .par_iter()
        .map(|layer_proto| CoverLayerWriter::load(layer_proto))
        .collect();
    let parameters = CoverTreeParameters {
        total_nodes: atomic::AtomicUsize::new(0),
        use_singletons: cover_proto.use_singletons,
        scale_base: cover_proto.scale_base as f32,
        leaf_cutoff: cover_proto.cutoff as usize,
        min_res_index: cover_proto.resolution as i32,
        point_cloud,
        verbosity: 2,
        plugins: RwLock::new(TreePluginSet::new()),
    };
    Ok(CoverTreeWriter {
        parameters: Arc::new(parameters),
        layers,
        root_address,
    })
}
/// Serializes the tree into a `CoreProto`.
///
/// Writes the scalar parameters, the point cloud's dimension and length, the
/// root address, and the saved form of every layer.
pub fn save(&self) -> CoreProto {
    let params = &self.parameters;
    let (root_scale, root_index) = self.root_address;
    let mut proto = CoreProto::new();
    proto.set_scale_base(params.scale_base);
    proto.set_cutoff(params.leaf_cutoff as u64);
    proto.set_resolution(params.min_res_index);
    proto.set_use_singletons(params.use_singletons);
    proto.set_dim(params.point_cloud.dim() as u64);
    proto.set_count(params.point_cloud.len() as u64);
    proto.set_root_scale(root_scale);
    proto.set_root_index(root_index as u64);
    let saved_layers = self.layers.iter().map(|layer| layer.save()).collect();
    proto.set_layers(saved_layers);
    proto
}
/// Calls `refresh` on every layer writer, going through the layer vector from
/// its last element to its first.
pub fn refresh(&mut self) {
    for layer in self.layers.iter_mut().rev() {
        layer.refresh();
    }
}
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::utils::cover_tree_from_labeled_yaml;
use std::path::Path;
/// Builds a cover tree from the labeled YAML description at
/// `../data/mnist_complex.yml`.
///
/// # Panics
/// Panics with `"<file name> does not exist"` when the file is missing, and
/// panics if building the tree from the YAML fails.
pub(crate) fn build_mnist_tree() -> CoverTreeWriter<DefaultLabeledCloud<L2>> {
    let file_name = "../data/mnist_complex.yml";
    let path = Path::new(file_name);
    // A format string is required here: passing a non-literal `String` to
    // `panic!` is deprecated and rejected outright in the 2021 edition.
    assert!(path.exists(), "{} does not exist", file_name);
    cover_tree_from_labeled_yaml(&path).unwrap()
}
/// Builds a small tree over five one-dimensional labeled points, with
/// singletons enabled.
///
/// # Panics
/// Panics if the builder fails.
pub(crate) fn build_basic_tree() -> CoverTreeWriter<DefaultLabeledCloud<L2>> {
    let coordinates = vec![0.499, 0.49, 0.48, -0.49, 0.0];
    let point_labels = vec![0, 0, 0, 1, 1];
    let cloud = Arc::new(DefaultLabeledCloud::<L2>::new_simple(
        coordinates,
        1,
        point_labels,
    ));
    let builder = CoverTreeBuilder {
        scale_base: 2.0,
        leaf_cutoff: 1,
        min_res_index: -9,
        use_singletons: true,
        verbosity: 0,
    };
    builder.build(cloud).unwrap()
}
/// `len()` must equal the number of items yielded by `layers()`.
#[test]
fn len_is_num_layers() {
    // Keep the writer bound for the whole test, as the original does.
    let tree = build_basic_tree();
    let reader = tree.reader();
    let yielded_layers = reader.layers().count();
    assert_eq!(reader.len(), yielded_layers);
}
/// The scale index paired with each layer by `layers()` must match the
/// layer's own `scale_index()`, and at least one layer must be yielded.
#[test]
fn layer_has_correct_scale_index() {
    let tree = build_basic_tree();
    let reader = tree.reader();
    let mut checked_layers = 0usize;
    for (expected_index, layer) in reader.layers() {
        println!(
            "Scale Index, correct: {:?}, Scale Index, layer: {:?}",
            expected_index,
            layer.scale_index()
        );
        assert_eq!(expected_index, layer.scale_index());
        checked_layers += 1;
    }
    assert!(checked_layers > 0);
}
/// Seeds a query heap with the root, checks the root is the first
/// child-covering candidate, then runs one greedy descent and prints the
/// resulting heap state.
#[test]
fn greedy_knn_nodes() {
    let coordinates = vec![0.499, 0.49, 0.48, -0.49, 0.0];
    let point_labels = vec![0, 0, 0, 1, 1];
    let cloud = DefaultLabeledCloud::<L2>::new_simple(coordinates, 1, point_labels);
    let builder = CoverTreeBuilder {
        scale_base: 2.0,
        leaf_cutoff: 1,
        min_res_index: -9,
        use_singletons: false,
        verbosity: 0,
    };
    let tree = builder.build(Arc::new(cloud)).unwrap();
    let reader = tree.reader();
    let query = [-0.5];
    let root = reader.root_address();
    let mut heap = KnnQueryHeap::new(5, reader.parameters.scale_base);
    let root_distance = reader
        .parameters
        .point_cloud
        .distances_to_point(&query, &[root.1])
        .unwrap()[0];
    heap.push_nodes(&[root], &[root_distance], None);
    let first_candidate = heap.closest_unvisited_child_covering_address().unwrap().1;
    assert_eq!(root, first_candidate);
    reader.greedy_knn_nodes(&query, &mut heap);
    println!("{:#?}", heap);
    println!("{:#?}", heap.closest_unvisited_child_covering_address());
}
/// A dry insert on the basic tree yields a trace of 3 or 4 nodes whose scale
/// indexes strictly decrease along the trace.
#[test]
fn dry_insert_sanity() {
    let writer = build_basic_tree();
    let reader = writer.reader();
    let trace = reader.dry_insert(&[0.495f32][..]).unwrap();
    assert!(trace.len() == 4 || trace.len() == 3);
    println!("{:?}", trace);
    for step in trace.windows(2) {
        let upper_scale = (step[0].1).0;
        let lower_scale = (step[1].1).0;
        assert!(upper_scale > lower_scale);
    }
}
/// The first entry reported at the root's scale by a multiscale query must be
/// the root itself at distance 0.495.
#[test]
fn multiscale_sanity() {
    let writer = build_basic_tree();
    let reader = writer.reader();
    let root = reader.root_address();
    let per_scale = reader.multiscale_knn(&[0.495f32][..], 2).unwrap();
    let root_scale_hits = per_scale.get(&root.0).unwrap();
    assert_eq!(root_scale_hits[0], (0.495, root));
    println!("{:?}", per_scale);
}
/// With singletons enabled, the two nearest neighbours of 0.1 are point 4
/// (0.0) followed by point 2 (0.48).
#[test]
fn knn_singletons_on() {
    // The query point is 0.1, so the message names 0.1.
    println!("2 nearest neighbors of 0.1 are 0.48 and 0.0");
    let writer = build_basic_tree();
    let reader = writer.reader();
    let zero_nbrs = reader.knn(&[0.1f32][..], 2).unwrap();
    println!("{:?}", zero_nbrs);
    // `assert_eq!` reports both values on failure, unlike `assert!(a == b)`.
    assert_eq!(zero_nbrs[0].1, 4);
    assert_eq!(zero_nbrs[1].1, 2);
}
/// After `generate_summaries`, the root node's label summary must exist and
/// report two distinct labels with no missing or erroneous entries.
#[test]
fn label_summary() {
    let data = vec![0.499, 0.49, 0.48, -0.49, 0.0];
    let labels = vec![0, 0, 0, 1, 1];
    let point_cloud = DefaultLabeledCloud::<L2>::new_simple(data, 1, labels);
    let builder = CoverTreeBuilder {
        scale_base: 2.0,
        leaf_cutoff: 1,
        min_res_index: -9,
        use_singletons: false,
        verbosity: 0,
    };
    let mut tree = builder.build(Arc::new(point_cloud)).unwrap();
    tree.generate_summaries();
    let reader = tree.reader();
    for (_, layer) in reader.layers() {
        layer.for_each_node(|_, n| println!("{:?}", n.label_summary()));
    }
    // The lookup returns `None` when the root has no summary, in which case the
    // closure (and its assertions) never runs. Require `Some` so the test cannot
    // pass vacuously.
    reader
        .get_node_label_summary_and(reader.root_address(), |l| {
            assert_eq!(l.items.len(), 2);
            assert_eq!(l.nones, 0);
            assert_eq!(l.errors, 0);
        })
        .expect("root node should carry a label summary after generate_summaries()");
}
/// With singletons disabled, the two nearest neighbours of 0.1 are point 4
/// followed by point 2.
#[test]
fn knn_singletons_off() {
    let coordinates = vec![0.499, 0.49, 0.48, -0.49, 0.0];
    let point_labels = vec![0, 0, 0, 1, 1];
    let cloud = DefaultLabeledCloud::<L2>::new_simple(coordinates, 1, point_labels);
    let builder = CoverTreeBuilder {
        scale_base: 2.0,
        leaf_cutoff: 1,
        min_res_index: -9,
        use_singletons: false,
        verbosity: 0,
    };
    let tree = builder.build(Arc::new(cloud)).unwrap();
    let reader = tree.reader();
    println!("2 nearest neighbors of 0.1 are 0.48 and 0.0");
    let neighbours = reader.knn(&[0.1f32][..], 2).unwrap();
    println!("{:?}", neighbours);
    let nearest_index = neighbours[0].1;
    let second_index = neighbours[1].1;
    assert!(nearest_index == 4);
    assert!(second_index == 2);
}
}