use super::layer::*;
use super::node::*;
use crate::*;
use crate::monomap::{MonoReadHandle, MonoWriteHandle};
use crate::tree_file_format::*;
use std::sync::{atomic, Arc, RwLock};
use super::query_tools::{KnnQueryHeap, RoutingQueryHeap};
use crate::plugins::{GokoPlugin, TreePluginSet};
use errors::{GokoError, GokoResult};
use std::iter::Iterator;
use std::iter::Rev;
use std::ops::Range;
use std::slice::Iter;
use plugins::labels::*;
/// Rule used when descending the tree (see `CoverTreeReader::path`) to pick which covering
/// child of a node the query continues into.
///
/// Derives `PartialEq`/`Eq` so callers can compare partition types directly.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PartitionType {
    /// Continue into the nearest child that covers the query point.
    Nearest,
    /// Continue into the first child found that covers the query point.
    First,
}
/// Parameters shared (behind an `Arc`) by a cover tree's reader and writer handles.
#[derive(Debug)]
pub struct CoverTreeParameters<D: PointCloud> {
/// Node counter; `CoverTreeWriter::load` initialises it to 0.
/// NOTE(review): presumably incremented as nodes are created during construction — confirm.
pub total_nodes: atomic::AtomicUsize,
/// Base `b` of the scale: the radius for scale index `i` is `b^i` (see `CoverTreeReader::scale`).
pub scale_base: f32,
/// Saved and loaded as the protobuf `cutoff` field.
/// NOTE(review): presumably the size at or below which a node is not split further — confirm.
pub leaf_cutoff: usize,
/// Lowest resolved scale index; every scale index below it maps to the bottom layer
/// (internal index 0, see `internal_index`).
pub min_res_index: i32,
/// Saved and loaded with the tree.
/// NOTE(review): presumably controls whether nodes keep covered points as singletons — confirm.
pub use_singletons: bool,
/// Child-selection rule used by `CoverTreeReader::path`.
pub partition_type: PartitionType,
/// The point cloud the tree indexes.
pub point_cloud: Arc<D>,
/// Verbosity level; `CoverTreeWriter::load` sets it to 2. Not read in this file.
pub verbosity: u32,
/// Optional RNG seed; `CoverTreeWriter::load` sets it to `None`. Not read in this file.
pub rng_seed: Option<u64>,
/// Plugins attached to the tree, looked up by type via `get::<T>()`
/// (see `CoverTreeReader::get_plugin_and` and `CoverTreeWriter::add_plugin`).
pub plugins: RwLock<TreePluginSet>,
}
impl<D: PointCloud> CoverTreeParameters<D> {
    /// Maps a scale index to its position in a tree's `layers` vector.
    ///
    /// `min_res_index` maps to position 1 and each higher scale index to the next position.
    /// Every scale index strictly below `min_res_index` collapses onto position 0, the bottom
    /// catch-all layer.
    #[inline]
    pub fn internal_index(&self, scale_index: i32) -> usize {
        if scale_index >= self.min_res_index {
            (scale_index - self.min_res_index + 1) as usize
        } else {
            0
        }
    }
}
/// Iterator returned by `CoverTreeReader::layers`: `(scale_index, layer)` pairs, yielded from
/// the highest scale index down to the bottom catch-all layer.
pub type LayerIter<'a, D> = Rev<std::iter::Zip<Range<i32>, Iter<'a, CoverLayerReader<D>>>>;
/// Read handle onto a cover tree, created by `CoverTreeWriter::reader`.
pub struct CoverTreeReader<D: PointCloud> {
/// Tree parameters, shared with the writer.
parameters: Arc<CoverTreeParameters<D>>,
/// One reader per layer; index 0 is the bottom catch-all layer (see `internal_index`).
layers: Vec<CoverLayerReader<D>>,
/// `(scale_index, center_point_index)` of the root node.
root_address: NodeAddress,
/// Maps a point index to the node the point finally landed in: the node holding it as a
/// singleton, or the leaf centered on it (see `CoverTreeWriter::refresh_final_indexes`).
final_addresses: MonoReadHandle<usize, NodeAddress>,
}
impl<D: PointCloud> Clone for CoverTreeReader<D> {
fn clone(&self) -> CoverTreeReader<D> {
CoverTreeReader {
parameters: self.parameters.clone(),
layers: self.layers.clone(),
root_address: self.root_address,
final_addresses: self.final_addresses.clone(),
}
}
}
impl<D: PointCloud + LabeledCloud> CoverTreeReader<D> {
    /// Returns the label summary stored on the node at `node_address`
    /// (`(scale_index, center_point_index)`).
    ///
    /// Yields `None` when either the node lookup or the node's `label_summary()` yields `None`.
    ///
    /// # Panics
    /// Panics if `node_address.0` lies above the top layer.
    pub fn get_node_label_summary(
        &self,
        node_address: (i32, usize),
    ) -> Option<Arc<SummaryCounter<D::LabelSummary>>> {
        let nested = self.get_node_and(node_address, |node| node.label_summary());
        nested.flatten()
    }
}
impl<D: PointCloud + MetaCloud> CoverTreeReader<D> {
    /// Returns the metadata summary stored on the node at `node_address`
    /// (`(scale_index, center_point_index)`).
    ///
    /// Yields `None` when either the node lookup or the node's `metasummary()` yields `None`.
    ///
    /// # Panics
    /// Panics if `node_address.0` lies above the top layer.
    pub fn get_node_metasummary(
        &self,
        node_address: (i32, usize),
    ) -> Option<Arc<SummaryCounter<D::MetaSummary>>> {
        let nested = self.get_node_and(node_address, |node| node.metasummary());
        nested.flatten()
    }
}
impl<D: PointCloud> CoverTreeReader<D> {
    /// Returns the point cloud this tree indexes.
    pub fn point_cloud(&self) -> &Arc<D> {
        &self.parameters.point_cloud
    }

    /// Returns the reader for the layer that holds `scale_index`.
    ///
    /// Every scale index below `min_res_index` maps to the bottom layer
    /// (see `CoverTreeParameters::internal_index`).
    ///
    /// # Panics
    /// Panics if `scale_index` lies above the top layer.
    pub fn layer(&self, scale_index: i32) -> &CoverLayerReader<D> {
        &self.layers[self.parameters.internal_index(scale_index)]
    }

    /// Returns `scale_base ^ scale_index`, the radius associated with a scale index.
    pub fn scale(&self, scale_index: i32) -> f32 {
        self.parameters.scale_base.powi(scale_index)
    }

    /// Applies `f` to the node at `node_address` (`(scale_index, center_point_index)`).
    ///
    /// Returns `None` when the layer lookup yields nothing, which this file treats as
    /// "no such node".
    ///
    /// # Panics
    /// Panics if `node_address.0` lies above the top layer.
    pub fn get_node_and<F, T>(&self, node_address: (i32, usize), f: F) -> Option<T>
    where
        F: FnOnce(&CoverNode<D>) -> T,
    {
        self.layer(node_address.0).get_node_and(node_address.1, f)
    }

    /// Applies `f` to the children of the node at `node_address`, as reported by
    /// `CoverLayerReader::get_node_children_and`.
    ///
    /// # Panics
    /// Panics if `node_address.0` lies above the top layer.
    pub fn get_node_children_and<F, T>(&self, node_address: (i32, usize), f: F) -> Option<T>
    where
        F: FnOnce(NodeAddress, &[NodeAddress]) -> T,
    {
        self.layer(node_address.0)
            .get_node_children_and(node_address.1, f)
    }

    /// Returns the address of the root node.
    pub fn root_address(&self) -> NodeAddress {
        self.root_address
    }

    /// Iterates over `(scale_index, layer)` pairs. The highest scale index comes first; the
    /// bottom catch-all layer comes last and is reported with scale index `min_res_index - 1`.
    pub fn layers(&self) -> LayerIter<D> {
        // `self.layers[i]` holds scale index `min_res_index - 1 + i`, matching `internal_index`.
        let bottom = self.parameters.min_res_index - 1;
        (bottom..(bottom + self.layers.len() as i32))
            .zip(self.layers.iter())
            .rev()
    }

    /// Number of layers (not nodes; see `node_count`).
    pub fn len(&self) -> usize {
        self.layers.len()
    }

    /// `true` if the tree has no layers.
    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }

    /// Returns the shared tree parameters.
    pub fn parameters(&self) -> &Arc<CoverTreeParameters<D>> {
        &self.parameters
    }

    /// Total number of nodes across all layers.
    pub fn node_count(&self) -> usize {
        self.layers.iter().map(|layer| layer.len()).sum()
    }

    /// Scale indexes from `min_res_index` up to, but excluding, `min_res_index - 1 + len()`.
    ///
    /// NOTE(review): this omits the bottom catch-all layer (scale index `min_res_index - 1`)
    /// that `layers()` does yield — presumably intentional; confirm.
    pub fn scale_range(&self) -> Range<i32> {
        (self.parameters.min_res_index)
            ..(self.parameters.min_res_index - 1 + self.layers.len() as i32)
    }

    /// Applies `transform_fn` to the tree-level plugin of type `T`, if one is attached.
    ///
    /// # Panics
    /// Panics if the plugin set's lock is poisoned.
    pub fn get_plugin_and<T: Send + Sync + 'static, F, S>(&self, transform_fn: F) -> Option<S>
    where
        F: FnOnce(&T) -> S,
    {
        self.parameters
            .plugins
            .read()
            .unwrap()
            .get::<T>()
            .map(transform_fn)
    }

    /// Applies `transform_fn` to the component of type `T` stored on the node at
    /// `node_address`.
    ///
    /// Returns `None` when either the node lookup or the node's plugin lookup yields `None`.
    ///
    /// # Panics
    /// Panics if `node_address.0` lies above the top layer.
    pub fn get_node_plugin_and<T: Send + Sync + 'static, F, S>(
        &self,
        node_address: (i32, usize),
        transform_fn: F,
    ) -> Option<S>
    where
        F: FnOnce(&T) -> S,
    {
        self.get_node_and(node_address, |n| n.get_plugin_and(transform_fn))
            .flatten()
    }

    /// Finds the nearest neighbors of `point`, presumably at most `k` of them. They are
    /// returned as `(distance, point_index)` pairs in the order produced by
    /// `KnnQueryHeap::unpack`; the tests in this file expect the nearest first.
    ///
    /// The search alternates two steps until no singleton-covering node is left unvisited:
    /// a greedy descent through covering children, then a pass over the singletons of the
    /// closest unvisited node that covers `point`.
    ///
    /// # Errors
    /// Returns an error if the root's center point cannot be fetched from the point cloud.
    pub fn knn<'a>(&self, point: &D::PointRef<'a>, k: usize) -> GokoResult<Vec<(f32, usize)>> {
        let mut query_heap = self.root_seeded_knn_heap(point, k)?;
        self.greedy_knn_nodes(&point, &mut query_heap);
        while let Some((_dist, address)) = query_heap.closest_unvisited_singleton_covering_address()
        {
            self.get_node_and(address, |n| {
                n.singleton_knn(&point, &self.parameters.point_cloud, &mut query_heap)
            });
            self.greedy_knn_nodes(&point, &mut query_heap);
        }
        Ok(query_heap.unpack())
    }

    /// Nearest-neighbor search that only follows covering children greedily. Unlike `knn`,
    /// it never scans node singletons.
    ///
    /// # Errors
    /// Returns an error if the root's center point cannot be fetched from the point cloud.
    pub fn routing_knn<'a>(
        &self,
        point: &D::PointRef<'a>,
        k: usize,
    ) -> GokoResult<Vec<(f32, usize)>> {
        let mut query_heap = self.root_seeded_knn_heap(point, k)?;
        // NOTE(review): this unconditional first pass means the descent always runs at least
        // twice, even when the first pass expands nothing. It is kept because whether a second
        // pass can then make progress depends on `KnnQueryHeap`'s visiting semantics.
        self.greedy_knn_nodes(point, &mut query_heap);
        while self.greedy_knn_nodes(point, &mut query_heap) {}
        Ok(query_heap.unpack())
    }

    /// Creates a query heap for `k` neighbors, seeded with the root node and the root
    /// center's distance to `point`.
    ///
    /// # Errors
    /// Propagates the error if the root's center cannot be fetched from the point cloud.
    fn root_seeded_knn_heap<'a>(
        &self,
        point: &D::PointRef<'a>,
        k: usize,
    ) -> GokoResult<KnnQueryHeap> {
        let mut query_heap = KnnQueryHeap::new(k, self.parameters.scale_base);
        let root_center = self.parameters.point_cloud.point(self.root_address.1)?;
        let dist_to_root = D::Metric::dist(&root_center, &point);
        query_heap.push_nodes(&[self.root_address], &[dist_to_root], None);
        Ok(query_heap)
    }

    /// Repeatedly expands, via `child_knn`, the closest unvisited node that covers `point`.
    /// Stops at the first leaf or missing node, or when the heap offers no candidate.
    ///
    /// Returns `true` if at least one node was expanded.
    fn greedy_knn_nodes<'a>(&self, point: &D::PointRef<'a>, query_heap: &mut KnnQueryHeap) -> bool {
        let mut did_something = false;
        while let Some((dist, nearest_address)) =
            query_heap.closest_unvisited_child_covering_address()
        {
            // A node that cannot be found is treated like a leaf: stop descending.
            let stop_here = self
                .get_node_and(nearest_address, |n| n.is_leaf())
                .unwrap_or(true);
            if stop_here {
                break;
            }
            self.get_node_and(nearest_address, |n| {
                n.child_knn(Some(dist), &point, &self.parameters.point_cloud, query_heap)
            });
            did_something = true;
        }
        did_something
    }

    /// Descends from the root towards `point`. Returns the `(distance, address)` of every node
    /// visited, root first.
    ///
    /// At each node, the configured `partition_type` picks the next step: the nearest covering
    /// child (`Nearest`) or the first covering child (`First`). The descent stops when the
    /// current node yields no covering child or cannot be found.
    ///
    /// # Errors
    /// Returns an error if the root's center cannot be fetched or a child lookup fails.
    pub fn path<'a>(&self, point: &D::PointRef<'a>) -> GokoResult<Vec<(f32, NodeAddress)>> {
        let root_center = self.parameters.point_cloud.point(self.root_address.1)?;
        let mut current_distance = D::Metric::dist(&root_center, &point);
        let mut current_address = self.root_address;
        let mut trace = vec![(current_distance, current_address)];
        while let Some(step) =
            self.get_node_and(current_address, |n| match self.parameters.partition_type {
                PartitionType::Nearest => n.nearest_covering_child(
                    self.parameters.scale_base,
                    current_distance,
                    point,
                    &self.parameters.point_cloud,
                ),
                PartitionType::First => n.first_covering_child(
                    self.parameters.scale_base,
                    current_distance,
                    point,
                    &self.parameters.point_cloud,
                ),
            })
        {
            match step? {
                Some(nearest) => {
                    trace.push(nearest);
                    current_distance = nearest.0;
                    current_address = nearest.1;
                }
                None => break,
            }
        }
        Ok(trace)
    }

    /// Returns the path, root first, to the node where `point_index` was finally placed. Each
    /// entry is `(distance from the point to the node's center, node address)`. The path is
    /// rebuilt by following parent addresses up from the recorded final address.
    ///
    /// # Errors
    /// Returns `GokoError::IndexNotInTree` if `point_index` has no recorded final address.
    /// Also propagates any error from computing the distances; previously such an error
    /// caused a panic.
    pub fn known_path(&self, point_index: usize) -> GokoResult<Vec<(f32, NodeAddress)>> {
        let final_address = self
            .final_addresses
            .get_and(&point_index, |addr| *addr)
            .ok_or(GokoError::IndexNotInTree(point_index))?;
        // Assuming every parent sits at a strictly higher scale index, the path holds at most
        // `root_scale - final_scale + 1` nodes; `max(0)` guards the cast.
        let capacity = (self.root_address.0 - final_address.0).max(0) as usize + 1;
        let mut path = Vec::with_capacity(capacity);
        let mut parent = Some(final_address);
        while let Some(addr) = parent {
            path.push(addr);
            parent = self.get_node_and(addr, |n| n.parent_address()).flatten();
        }
        path.reverse();
        let point_indexes: Vec<usize> = path.iter().map(|na| na.1).collect();
        let dists = self
            .parameters
            .point_cloud
            .distances_to_point_index(point_index, &point_indexes[..])?;
        Ok(dists.iter().zip(path).map(|(d, a)| (*d, a)).collect())
    }

    /// Returns the log, base `scale_base`, of the node's singleton count plus its child count.
    ///
    /// # Panics
    /// Panics if there is no node at `node_address`.
    pub fn node_fractal_dim(&self, node_address: NodeAddress) -> f32 {
        let count: f32 = self
            .get_node_and(node_address, |n| {
                (n.singletons_len() + n.children_len()) as f32
            })
            .expect("node_fractal_dim: no node at the given address");
        count.log(self.parameters.scale_base)
    }

    /// Weighted variant of `node_fractal_dim`.
    ///
    /// Every child, including the nested child at `(nested_scale, node_address.1)`,
    /// contributes its coverage count divided by the largest child coverage count. Each
    /// singleton contributes `1 / max`, where `max` is 1 if the node has no children. Returns
    /// the log, base `scale_base`, of the sum.
    ///
    /// # Panics
    /// Panics if the node or any of its children cannot be found.
    pub fn node_weighted_fractal_dim(&self, node_address: NodeAddress) -> f32 {
        let weighted_count: f32 = self
            .get_node_and(node_address, |n| {
                let singleton_count = n.singletons().len() as f32;
                let mut max_pop: usize = 1;
                let mut weighted_count: f32 = 0.0;
                if let Some((nested_scale, children)) = n.children() {
                    let mut pops: Vec<usize> = children
                        .iter()
                        .map(|child_addr| {
                            self.get_node_and(*child_addr, |child| child.coverage_count())
                                .unwrap()
                        })
                        .collect();
                    pops.push(
                        self.get_node_and((nested_scale, node_address.1), |child| {
                            child.coverage_count()
                        })
                        .unwrap(),
                    );
                    max_pop = *pops.iter().max().unwrap();
                    pops.iter()
                        .for_each(|p| weighted_count += (*p as f32) / (max_pop as f32));
                }
                weighted_count + singleton_count / (max_pop as f32)
            })
            .unwrap();
        weighted_count.log(self.parameters.scale_base)
    }

    /// Returns `log(child_count) - log(parent_count)`, with logs taken base `scale_base`.
    ///
    /// `parent_count` is the number of nodes in the layer at `scale_index`. `child_count` is
    /// the sum of their singleton and child counts.
    pub fn layer_fractal_dim(&self, scale_index: i32) -> f32 {
        let parent_layer = self.layer(scale_index);
        let parent_count = parent_layer.len() as f32;
        let mut child_count: f32 = 0.0;
        parent_layer
            .for_each_node(|_, n| child_count += (n.singletons_len() + n.children_len()) as f32);
        child_count.log(self.parameters.scale_base) - parent_count.log(self.parameters.scale_base)
    }

    /// Weighted variant of `layer_fractal_dim`.
    ///
    /// Parent coverage counts are normalised by the largest parent count. Child coverage
    /// counts, nested children included, and singletons are normalised by the largest child
    /// count; each maximum defaults to 1 when its list is empty. Returns the difference of the
    /// logs, base `scale_base`, of the two sums.
    ///
    /// # Panics
    /// Panics if a referenced child node cannot be found.
    pub fn layer_weighted_fractal_dim(&self, scale_index: i32) -> f32 {
        let parent_layer = self.layer(scale_index);
        let mut parent_coverage_counts: Vec<usize> = Vec::new();
        let mut child_coverage_counts: Vec<usize> = Vec::new();
        let mut singletons_count: f32 = 0.0;
        parent_layer.for_each_node(|center_index, n| {
            parent_coverage_counts.push(n.coverage_count());
            singletons_count += n.singletons().len() as f32;
            if let Some((nested_scale, children)) = n.children() {
                child_coverage_counts.extend(children.iter().map(|child_addr| {
                    self.get_node_and(*child_addr, |child| child.coverage_count())
                        .unwrap()
                }));
                child_coverage_counts.push(
                    self.get_node_and((nested_scale, *center_index), |child| {
                        child.coverage_count()
                    })
                    .unwrap(),
                );
            }
        });
        let max_parent_pop: f32 = *parent_coverage_counts.iter().max().unwrap_or(&1) as f32;
        let max_child_pop: f32 = *child_coverage_counts.iter().max().unwrap_or(&1) as f32;
        let weighted_child_sum: f32 = singletons_count / max_child_pop
            + child_coverage_counts
                .iter()
                .fold(0.0, |a, c| a + (*c as f32) / max_child_pop);
        let weighted_parent_sum: f32 = parent_coverage_counts
            .iter()
            .fold(0.0, |a, c| a + (*c as f32) / max_parent_pop);
        weighted_child_sum.log(self.parameters.scale_base)
            - weighted_parent_sum.log(self.parameters.scale_base)
    }

    /// Walks the tree from the root, following every child and nested-child address. Returns
    /// `false` as soon as an address cannot be resolved to a node.
    ///
    /// This is a silent check: the former debug `println!` tracing has been removed.
    pub(crate) fn no_dangling_refs(&self) -> bool {
        let mut refs_to_check = vec![self.root_address];
        while let Some(node_addr) = refs_to_check.pop() {
            let node_exists = self.get_node_and(node_addr, |n| {
                if let Some((nested_scale, other_children)) = n.children() {
                    refs_to_check.push((nested_scale, node_addr.1));
                    refs_to_check.extend(&other_children[..]);
                }
            });
            if node_exists.is_none() {
                return false;
            }
        }
        true
    }
}
/// Write handle onto a cover tree; `reader()` produces `CoverTreeReader`s over its state.
pub struct CoverTreeWriter<D: PointCloud> {
/// Tree parameters, shared with every reader created from this writer.
pub(crate) parameters: Arc<CoverTreeParameters<D>>,
/// One writer per layer; index 0 is the bottom catch-all layer (see `internal_index`).
pub(crate) layers: Vec<CoverLayerWriter<D>>,
/// `(scale_index, center_point_index)` of the root node.
pub(crate) root_address: NodeAddress,
/// Write side of the point-index -> final-node-address map (see `refresh_final_indexes`).
pub(crate) final_addresses: MonoWriteHandle<usize, NodeAddress>,
}
impl<D: PointCloud + LabeledCloud> CoverTreeWriter<D> {
    /// Attaches a default-configured `LabelSummaryPlugin` to the tree via `add_plugin`.
    pub fn generate_summaries(&mut self) {
        let plugin = LabelSummaryPlugin::default();
        self.add_plugin(plugin);
    }
}
impl<D: PointCloud + MetaCloud> CoverTreeWriter<D> {
    /// Attaches a default-configured `MetaSummaryPlugin` to the tree via `add_plugin`.
    pub fn generate_meta_summaries(&mut self) {
        let plugin = MetaSummaryPlugin::default();
        self.add_plugin(plugin);
    }
}
impl<D: PointCloud> CoverTreeWriter<D> {
    /// Attaches `plug_in` to the tree and computes its component for every node.
    ///
    /// `P::prepare_tree` runs first, with mutable access to the writer. Layers are then
    /// visited in storage order, starting with index 0 (the bottom catch-all layer). Every
    /// node for which `P::node_component` returns `Some` receives that component. Each layer
    /// is refreshed before the next one is visited, so — assuming layer readers observe
    /// refreshes — `P::node_component` sees components already attached to lower layers.
    /// Finally, the plugin itself is stored in the tree's plugin set.
    ///
    /// # Panics
    /// Panics if the plugin set's lock is poisoned.
    pub fn add_plugin<P: GokoPlugin<D>>(&mut self, plug_in: P) {
        P::prepare_tree(&plug_in, self);
        let reader = self.reader();
        for layer in self.layers.iter_mut() {
            layer.reader().for_each_node(|pi, n| {
                if let Some(node_component) = P::node_component(&plug_in, n, &reader) {
                    // SAFETY: NOTE(review) — `CoverLayerWriter::update_node`'s contract is
                    // defined outside this file; confirm it permits updating nodes of the
                    // layer being iterated through `layer.reader()`. The component is cloned
                    // because the update closure is presumably required to be `Fn` (as in
                    // `CoverTreeWriter::update_node`) and so cannot move it out.
                    unsafe {
                        layer.update_node(*pi, move |n| n.insert_plugin(node_component.clone()))
                    }
                }
            });
            layer.refresh()
        }
        self.parameters.plugins.write().unwrap().insert(plug_in);
    }

    /// Returns the writer for the layer that holds `scale_index`.
    ///
    /// # Safety
    /// NOTE(review): nothing in this body is itself unsafe. The marker presumably guards
    /// invariants that callers must uphold when mutating a layer directly; confirm against
    /// `CoverLayerWriter`.
    ///
    /// # Panics
    /// Panics if `scale_index` lies above the top layer.
    pub(crate) unsafe fn layer(&mut self, scale_index: i32) -> &mut CoverLayerWriter<D> {
        &mut self.layers[self.parameters.internal_index(scale_index)]
    }

    /// Applies `update_fn` to the node at `address` through its layer's writer.
    ///
    /// # Safety
    /// Callers must uphold the contract of `CoverLayerWriter::update_node`, which is defined
    /// outside this file.
    ///
    /// # Panics
    /// Panics if `address.0` lies above the top layer.
    pub(crate) unsafe fn update_node<F>(&mut self, address: NodeAddress, update_fn: F)
    where
        F: Fn(&mut CoverNode<D>) + 'static + Send + Sync,
    {
        self.layers[self.parameters.internal_index(address.0)].update_node(address.1, update_fn);
    }

    /// Creates a reader over this tree. It shares the parameters `Arc`, takes a reader from
    /// every layer, and takes a fresh handle on the final-address map.
    pub fn reader(&self) -> CoverTreeReader<D> {
        CoverTreeReader {
            parameters: Arc::clone(&self.parameters),
            layers: self.layers.iter().map(|l| l.reader()).collect(),
            root_address: self.root_address,
            final_addresses: self.final_addresses.factory().handle(),
        }
    }

    /// Inserts `node`, centered on `point_index`, into the layer for `scale_index`.
    ///
    /// # Safety
    /// Callers must uphold the contract of `CoverLayerWriter::insert_raw`, which is defined
    /// outside this file.
    ///
    /// # Panics
    /// Panics if `scale_index` lies above the top layer.
    pub(crate) unsafe fn insert_raw(
        &mut self,
        scale_index: i32,
        point_index: usize,
        node: CoverNode<D>,
    ) {
        self.layers[self.parameters.internal_index(scale_index)].insert_raw(point_index, node);
    }

    /// Rebuilds a writer over `point_cloud` from its protobuf form, as produced by `save`.
    ///
    /// A `partition_type` other than `"first"` falls back to `PartitionType::Nearest`.
    /// `save` does not write `verbosity` or `rng_seed`, so they are set to 2 and `None`
    /// respectively. Layers are loaded in parallel, and the final-address map is rebuilt with
    /// `refresh_final_indexes`.
    ///
    /// NOTE(review): nothing in this body currently returns `Err`.
    pub fn load(cover_proto: &CoreProto, point_cloud: Arc<D>) -> GokoResult<CoverTreeWriter<D>> {
        let partition_type = match cover_proto.partition_type.as_str() {
            "first" => PartitionType::First,
            _ => PartitionType::Nearest,
        };
        let parameters = Arc::new(CoverTreeParameters {
            total_nodes: atomic::AtomicUsize::new(0),
            use_singletons: cover_proto.use_singletons,
            scale_base: cover_proto.scale_base as f32,
            leaf_cutoff: cover_proto.cutoff as usize,
            min_res_index: cover_proto.resolution as i32,
            point_cloud,
            verbosity: 2,
            partition_type,
            plugins: RwLock::new(TreePluginSet::new()),
            rng_seed: None,
        });
        let root_address = (
            cover_proto.get_root_scale(),
            cover_proto.get_root_index() as usize,
        );
        let layers: Vec<CoverLayerWriter<D>> = cover_proto
            .get_layers()
            .par_iter()
            .map(|l| CoverLayerWriter::load(l))
            .collect();
        let (_final_addresses_reader, final_addresses) = monomap::new();
        let mut tree = CoverTreeWriter {
            parameters,
            layers,
            root_address,
            final_addresses,
        };
        tree.refresh_final_indexes();
        Ok(tree)
    }

    /// Recomputes the point-index -> final-node-address map by walking the tree from the root.
    ///
    /// Each singleton maps to the node holding it. A leaf's center point maps to the leaf
    /// itself. Non-leaf centers are resolved through their nested child.
    ///
    /// # Panics
    /// Panics if the tree references a node that does not exist.
    pub fn refresh_final_indexes(&mut self) {
        let reader = self.reader();
        let mut unvisited_nodes: Vec<NodeAddress> = vec![self.root_address];
        while let Some(cur_add) = unvisited_nodes.pop() {
            reader
                .get_node_and(cur_add, |n| {
                    for singleton in n.singletons() {
                        self.final_addresses.insert(*singleton, cur_add);
                    }
                    if let Some((nested_si, child_addresses)) = n.children() {
                        unvisited_nodes.extend(child_addresses);
                        unvisited_nodes.push((nested_si, cur_add.1));
                    } else {
                        self.final_addresses.insert(cur_add.1, cur_add);
                    }
                })
                .expect("refresh_final_indexes: tree references a node that does not exist");
        }
        // NOTE(review): refreshed twice, presumably so both copies behind the monomap's
        // read/write handles receive the inserts — confirm against `monomap`.
        self.final_addresses.refresh();
        self.final_addresses.refresh();
    }

    /// Serialises the tree parameters, the root address, and every layer into a `CoreProto`.
    pub fn save(&self) -> CoreProto {
        let mut cover_proto = CoreProto::new();
        match self.parameters.partition_type {
            PartitionType::First => cover_proto.set_partition_type("first".to_string()),
            PartitionType::Nearest => cover_proto.set_partition_type("nearest".to_string()),
        }
        cover_proto.set_scale_base(self.parameters.scale_base);
        cover_proto.set_cutoff(self.parameters.leaf_cutoff as u64);
        cover_proto.set_resolution(self.parameters.min_res_index);
        cover_proto.set_use_singletons(self.parameters.use_singletons);
        cover_proto.set_dim(self.parameters.point_cloud.dim() as u64);
        cover_proto.set_count(self.parameters.point_cloud.len() as u64);
        cover_proto.set_root_scale(self.root_address.0);
        cover_proto.set_root_index(self.root_address.1 as u64);
        cover_proto.set_layers(self.layers.iter().map(|l| l.save()).collect());
        cover_proto
    }

    /// Refreshes every layer, starting from the last entry of `layers` (the top layer).
    pub fn refresh(&mut self) {
        self.layers.iter_mut().rev().for_each(|l| l.refresh());
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::utils::cover_tree_from_labeled_yaml;
    use std::path::Path;

    /// The five labeled 1-D points shared by the small-tree tests.
    fn basic_point_cloud() -> DefaultLabeledCloud<L2> {
        let data = vec![0.499, 0.49, 0.48, -0.49, 0.0];
        let labels = vec![0, 0, 0, 1, 1];
        DefaultLabeledCloud::<L2>::new_simple(data, 1, labels)
    }

    /// Builder configuration shared by the small-tree tests; only `use_singletons` varies.
    fn basic_builder(use_singletons: bool) -> CoverTreeBuilder {
        CoverTreeBuilder {
            scale_base: 2.0,
            leaf_cutoff: 1,
            min_res_index: -9,
            use_singletons,
            partition_type: PartitionType::Nearest,
            verbosity: 0,
            rng_seed: Some(0),
        }
    }

    /// Builds a tree from `../data/mnist_complex.yml`.
    ///
    /// # Panics
    /// Panics if the file is missing or the tree cannot be built.
    pub(crate) fn build_mnist_tree() -> CoverTreeWriter<DefaultLabeledCloud<L2>> {
        let file_name = "../data/mnist_complex.yml";
        let path = Path::new(file_name);
        if !path.exists() {
            // A literal format string keeps this valid under the Rust 2021 `panic!` rules.
            panic!("{} does not exist", file_name);
        }
        cover_tree_from_labeled_yaml(&path).unwrap()
    }

    /// Builds the small five-point tree with singletons enabled.
    pub(crate) fn build_basic_tree() -> CoverTreeWriter<DefaultLabeledCloud<L2>> {
        basic_builder(true)
            .build(Arc::new(basic_point_cloud()))
            .unwrap()
    }

    #[test]
    fn len_is_num_layers() {
        let tree = build_basic_tree();
        let reader = tree.reader();
        assert_eq!(reader.len(), reader.layers().count());
    }

    #[test]
    fn layer_has_correct_scale_index() {
        let tree = build_basic_tree();
        let reader = tree.reader();
        let mut got_one = false;
        for (si, l) in reader.layers() {
            println!(
                "Scale Index, correct: {:?}, Scale Index, layer: {:?}",
                si,
                l.scale_index()
            );
            assert_eq!(si, l.scale_index());
            got_one = true;
        }
        assert!(got_one);
    }

    #[test]
    fn greedy_knn_nodes() {
        let tree = basic_builder(false)
            .build(Arc::new(basic_point_cloud()))
            .unwrap();
        let reader = tree.reader();
        let point = [-0.5];
        let mut query_heap = KnnQueryHeap::new(5, reader.parameters.scale_base);
        let dist_to_root = reader
            .parameters
            .point_cloud
            .distances_to_point(&point.as_ref(), &[reader.root_address().1])
            .unwrap()[0];
        query_heap.push_nodes(&[reader.root_address()], &[dist_to_root], None);
        assert_eq!(
            reader.root_address(),
            query_heap
                .closest_unvisited_child_covering_address()
                .unwrap()
                .1
        );
        reader.greedy_knn_nodes(&point.as_ref(), &mut query_heap);
        println!("{:#?}", query_heap);
        println!(
            "{:#?}",
            query_heap.closest_unvisited_child_covering_address()
        );
    }

    #[test]
    fn path_sanity() {
        let writer = build_basic_tree();
        let reader = writer.reader();
        let trace = reader.path(&[0.495f32].as_ref()).unwrap();
        assert!(trace.len() == 4 || trace.len() == 3);
        println!("{:?}", trace);
        // Scale indexes must strictly decrease along the path.
        for step in trace.windows(2) {
            assert!((step[0].1).0 > (step[1].1).0);
        }
    }

    #[test]
    fn known_path_sanity() {
        let writer = build_basic_tree();
        let reader = writer.reader();
        for i in 0..5 {
            let trace = reader.known_path(i).unwrap();
            println!("i {}, trace {:?}", i, trace);
            println!(
                "final address: {:?}",
                reader.final_addresses.get_and(&i, |i| *i)
            );
            let ad = trace.last().unwrap().1;
            reader
                .get_node_and(ad, |n| {
                    if !n.is_leaf() {
                        assert!(n.singletons().contains(&i));
                    } else {
                        assert!(
                            (ad.1 != i && n.singletons().contains(&i))
                                || (ad.1 == i && !n.singletons().contains(&i))
                        );
                    }
                })
                .unwrap();
        }
        let known_trace = reader.known_path(4).unwrap();
        let trace = reader.path(&[0.0f32].as_ref()).unwrap();
        println!(
            "Testing known: {:?} matches unknown {:?}",
            known_trace, trace
        );
        for (p, kp) in trace.iter().zip(known_trace) {
            assert_eq!(*p, kp);
        }
    }

    #[test]
    fn knn_singletons_on() {
        println!("2 nearest neighbors of 0.1 are 0.48 and 0.0");
        let writer = build_basic_tree();
        let reader = writer.reader();
        let zero_nbrs = reader.knn(&[0.1f32].as_ref(), 2).unwrap();
        println!("{:?}", zero_nbrs);
        assert_eq!(zero_nbrs[0].1, 4);
        assert_eq!(zero_nbrs[1].1, 2);
    }

    #[test]
    fn label_summary() {
        let mut tree = basic_builder(false)
            .build(Arc::new(basic_point_cloud()))
            .unwrap();
        tree.generate_summaries();
        let reader = tree.reader();
        for (_, layer) in reader.layers() {
            layer.for_each_node(|_, n| println!("{:?}", n.label_summary()));
        }
        let l = reader
            .get_node_label_summary(reader.root_address())
            .unwrap();
        assert_eq!(l.summary.items.len(), 2);
        assert_eq!(l.nones, 0);
        assert_eq!(l.errors, 0);
    }

    #[test]
    fn knn_singletons_off() {
        let tree = basic_builder(false)
            .build(Arc::new(basic_point_cloud()))
            .unwrap();
        let reader = tree.reader();
        println!("2 nearest neighbors of 0.1 are 0.48 and 0.0");
        let zero_nbrs = reader.knn(&[0.1f32].as_ref(), 2).unwrap();
        println!("{:?}", zero_nbrs);
        assert_eq!(zero_nbrs[0].1, 4);
        assert_eq!(zero_nbrs[1].1, 2);
    }

    #[test]
    fn test_save_load_tree() {
        let point_cloud = Arc::new(basic_point_cloud());
        let tree = basic_builder(false)
            .build(Arc::clone(&point_cloud))
            .unwrap();
        let reader = tree.reader();
        let proto = tree.save();
        assert_eq!(reader.layers.len(), proto.get_layers().len());
        for (layer, proto_layer) in reader.layers.iter().zip(proto.get_layers()) {
            assert_eq!(layer.len(), proto_layer.get_nodes().len());
        }
        let reconstructed_tree_writer =
            CoverTreeWriter::load(&proto, Arc::clone(&point_cloud)).unwrap();
        let reconstructed_tree = reconstructed_tree_writer.reader();
        assert_eq!(reader.layers.len(), reconstructed_tree.layers.len());
        for (layer, reconstructed_layer) in reader.layers.iter().zip(reconstructed_tree.layers) {
            assert_eq!(layer.len(), reconstructed_layer.len());
            layer.for_each_node(|pi, n| {
                reconstructed_layer
                    .get_node_and(*pi, |rn| {
                        assert_eq!(n.address(), rn.address());
                        assert_eq!(n.parent_address(), rn.parent_address());
                        assert_eq!(n.singletons(), rn.singletons());
                    })
                    .unwrap();
            })
        }
    }
}