use crate::plugins::TreePluginSet;
use crate::*;
use super::*;
use super::data_caches::*;
use super::layer::*;
use super::node::*;
use pbr::ProgressBar;
use std::fs::read_to_string;
use std::cmp::{max, min};
use std::path::Path;
use std::sync::{atomic, Arc, RwLock};
use yaml_rust::YamlLoader;
use crossbeam_channel::{unbounded, Receiver, Sender};
use errors::GokoResult;
use std::time::Instant;
/// A cover-tree node that still has to be split.
///
/// `split` consumes one of these and produces the finished `CoverNode` together
/// with the `BuilderNode`s for the children that still need processing.
#[derive(Debug)]
struct BuilderNode {
/// Address of the node this one was split from; `None` for the root created by `new`.
parent_address: Option<NodeAddress>,
/// Scale-index half of this node's address.
scale_index: i32,
/// The data assigned to this node; its `center_index` is the point half of the address.
covered: CoveredData,
}
/// Message a worker sends back for one processed node: the node's scale index,
/// its center point index and the finished `CoverNode`, or the error that
/// stopped the split.
type NodeSplitResult<D> = GokoResult<(i32, PointIndex, CoverNode<D>)>;
impl BuilderNode {
    /// Creates the root work item for the point cloud held by `parameters`.
    ///
    /// The root's scale index is `ceil(log_scale_base(max_distance))`, where
    /// `max_distance` is reported by the freshly built `CoveredData`.
    ///
    /// # Errors
    /// Propagates any error returned by `CoveredData::new`.
    fn new<D: PointCloud>(parameters: &CoverTreeParameters<D>) -> GokoResult<BuilderNode> {
        let covered = CoveredData::new::<D>(&parameters.point_cloud)?;
        let scale_index = (covered.max_distance()).log(parameters.scale_base).ceil() as i32;
        Ok(BuilderNode {
            parent_address: None,
            scale_index,
            covered,
        })
    }

    /// Returns this node's address: `(scale_index, center point index)`.
    #[inline]
    fn address(&self) -> NodeAddress {
        (self.scale_index, self.covered.center_index)
    }

    /// Splits this node on the rayon thread pool and recursively schedules its children.
    ///
    /// Each processed node produces exactly one message on `node_sender`:
    /// `Ok((scale_index, point_index, node))` on success, or the split error.
    /// Children are only scheduled after their parent's message has been sent.
    fn split_parallel<D: PointCloud>(
        self,
        parameters: &Arc<CoverTreeParameters<D>>,
        node_sender: &Arc<Sender<NodeSplitResult<D>>>,
    ) {
        let parameters = Arc::clone(parameters);
        let node_sender = Arc::clone(node_sender);
        rayon::spawn(move || {
            // Grab the address first: `split` consumes `self`.
            let (si, pi) = self.address();
            match self.split(&parameters) {
                Ok((new_node, new_nodes)) => {
                    // `send` only fails once every receiver has been dropped, i.e. nobody
                    // is collecting results any more. Stop quietly instead of panicking
                    // inside the thread pool; scheduling the children would be wasted work.
                    if node_sender.send(Ok((si, pi, new_node))).is_err() {
                        return;
                    }
                    // Last pushed child is scheduled first (same order as popping).
                    for node in new_nodes.into_iter().rev() {
                        node.split_parallel(&parameters, &node_sender);
                    }
                }
                Err(e) => {
                    // Same reasoning as above: a closed channel means there is no one
                    // left to report the error to.
                    let _ = node_sender.send(Err(e));
                }
            }
        });
    }

    /// Turns this work item into a finished `CoverNode` plus the work items for its children.
    ///
    /// * If the node covers at most `leaf_cutoff` points, or its scale index is below
    ///   `min_res_index`, every covered point is stored as a singleton and no children
    ///   are produced.
    /// * Otherwise the covered data is split at the next scale: the part that stays with
    ///   this node's center becomes the nested child, and centers are repeatedly picked
    ///   from the remainder until it is empty. A picked group of exactly one point is
    ///   stored as a singleton when `use_singletons` is set, otherwise it becomes a child.
    ///
    /// `parameters.total_nodes` is incremented for every child work item created here,
    /// before this function returns.
    ///
    /// # Errors
    /// Propagates errors from `insert_nested_child`, `insert_child` and `pick_center`.
    ///
    /// # Panics
    /// Panics if `covered.split` fails (see the note at the call).
    fn split<D: PointCloud>(
        self,
        parameters: &Arc<CoverTreeParameters<D>>,
    ) -> GokoResult<(CoverNode<D>, Vec<BuilderNode>)> {
        let scale_index = self.scale_index;
        let covered = self.covered;
        let current_address = (scale_index, covered.center_index);
        let mut node = CoverNode::new(self.parent_address, current_address);
        let radius = covered.max_distance();
        let mut new_nodes = Vec::new();
        node.set_radius(radius);

        if covered.len() <= parameters.leaf_cutoff || scale_index < parameters.min_res_index {
            node.insert_singletons(covered.into_indexes());
        } else {
            // Jump straight to the scale implied by the actual radius (but not below
            // `min_res_index`), while always descending at least one level.
            let next_scale_index = min(
                scale_index - 1,
                max(
                    radius.log(parameters.scale_base).ceil() as i32,
                    parameters.min_res_index,
                ),
            );
            let next_scale = parameters.scale_base.powi(next_scale_index);

            // NOTE(review): the return type of `CoveredData::split` is not visible here;
            // if it is a `GokoResult`, this should propagate with `?` instead of panicking.
            let (close, mut fars) = covered.split(next_scale).unwrap();

            // The nested child keeps this node's center.
            node.insert_nested_child(next_scale_index, close.len())?;
            new_nodes.push(BuilderNode {
                parent_address: Some(current_address),
                scale_index: next_scale_index,
                covered: close,
            });
            parameters
                .total_nodes
                .fetch_add(1, atomic::Ordering::SeqCst);

            while fars.len() > 0 {
                let new_close = fars.pick_center(next_scale, &parameters.point_cloud)?;
                if new_close.len() == 1 && parameters.use_singletons {
                    node.insert_singleton(new_close.center_index);
                } else {
                    node.insert_child(
                        (next_scale_index, new_close.center_index),
                        new_close.len(),
                    )?;
                    new_nodes.push(BuilderNode {
                        parent_address: Some(current_address),
                        scale_index: next_scale_index,
                        covered: new_close,
                    });
                    parameters
                        .total_nodes
                        .fetch_add(1, atomic::Ordering::SeqCst);
                }
            }
        }

        // A lone child covering a single point adds nothing: drop it, undo its
        // contribution to the node count and keep the point as a singleton here.
        if new_nodes.len() == 1 && new_nodes[0].covered.len() == 1 {
            node.remove_children();
            parameters
                .total_nodes
                .fetch_sub(1, atomic::Ordering::SeqCst);
            node.insert_singletons(new_nodes.pop().unwrap().covered.into_indexes());
        }
        Ok((node, new_nodes))
    }
}
/// Settings used to build a cover tree.
///
/// `Default::default()` yields the same values as `CoverTreeBuilder::new()`.
#[derive(Debug)]
pub struct CoverTreeBuilder {
    /// Base of the scale: a node at scale index `s` works at distance `scale_base^s`.
    pub scale_base: f32,
    /// Nodes covering at most this many points are not split further.
    pub leaf_cutoff: usize,
    /// Nodes whose scale index is below this value are not split further; it is also
    /// the lower clamp applied when a child's scale index is derived from the radius.
    pub min_res_index: i32,
    /// When `true`, a picked group of exactly one point is stored on its parent
    /// as a singleton instead of becoming a child node.
    pub use_singletons: bool,
    /// Values greater than 1 make `build` print a progress bar and timing summary.
    pub verbosity: u32,
}

impl Default for CoverTreeBuilder {
    /// Same values as `CoverTreeBuilder::new()`.
    ///
    /// A derived `Default` would zero every field, and a `scale_base` of `0.0`
    /// cannot be used as a logarithm/power base.
    fn default() -> Self {
        CoverTreeBuilder {
            scale_base: 2.0,
            leaf_cutoff: 1,
            min_res_index: -10,
            use_singletons: true,
            verbosity: 2,
        }
    }
}
impl CoverTreeBuilder {
pub fn new() -> CoverTreeBuilder {
CoverTreeBuilder {
scale_base: 2.0,
leaf_cutoff: 1,
min_res_index: -10,
use_singletons: true,
verbosity: 2,
}
}
pub fn from_yaml<P: AsRef<Path>>(path: P) -> Self {
let config = read_to_string(&path).expect("Unable to read config file");
let params_files = YamlLoader::load_from_str(&config).unwrap();
let params = ¶ms_files[0];
CoverTreeBuilder {
scale_base: params["scale_base"].as_f64().unwrap_or(2.0) as f32,
leaf_cutoff: params["leaf_cutoff"].as_i64().unwrap_or(1) as usize,
min_res_index: params["min_res_index"].as_i64().unwrap_or(-10) as i32,
use_singletons: params["use_singletons"].as_bool().unwrap_or(true),
verbosity: params["verbosity"].as_i64().unwrap_or(2) as u32,
}
}
pub fn set_scale_base(&mut self, x: f32) -> &mut Self {
self.scale_base = x;
self
}
pub fn set_leaf_cutoff(&mut self, x: usize) -> &mut Self {
self.leaf_cutoff = x;
self
}
pub fn set_min_res_index(&mut self, x: i32) -> &mut Self {
self.min_res_index = x;
self
}
pub fn set_use_singletons(&mut self, x: bool) -> &mut Self {
self.use_singletons = x;
self
}
pub fn set_verbosity(&mut self, x: u32) -> &mut Self {
self.verbosity = x;
self
}
pub fn build<D: PointCloud>(&self, point_cloud: Arc<D>) -> GokoResult<CoverTreeWriter<D>> {
let parameters = CoverTreeParameters {
total_nodes: atomic::AtomicUsize::new(1),
scale_base: self.scale_base,
leaf_cutoff: self.leaf_cutoff,
min_res_index: self.min_res_index,
use_singletons: self.use_singletons,
point_cloud,
verbosity: self.verbosity,
plugins: RwLock::new(TreePluginSet::new()),
};
let root = BuilderNode::new(¶meters)?;
let root_address = root.address();
let scale_range = root_address.0 - parameters.min_res_index;
let mut layers = Vec::with_capacity(scale_range as usize);
layers.push(CoverLayerWriter::new(parameters.min_res_index - 1));
for i in 0..(scale_range + 1) {
layers.push(CoverLayerWriter::new(parameters.min_res_index + i as i32));
}
let (node_sender, node_receiver): (
Sender<NodeSplitResult<D>>,
Receiver<NodeSplitResult<D>>,
) = unbounded();
let node_sender = Arc::new(node_sender);
let parameters = Arc::new(parameters);
root.split_parallel(¶meters, &node_sender);
let mut pb = ProgressBar::new(1u64);
if parameters.verbosity > 1 {
pb.format("╢▌▌░╟");
}
let (_final_addresses_reader, final_addresses) = monomap::new();
let mut cover_tree = CoverTreeWriter {
parameters: Arc::clone(¶meters),
layers,
root_address,
final_addresses,
};
let mut inserted_nodes: usize = 0;
let now = Instant::now();
loop {
if let Ok(res) = node_receiver.recv() {
let (scale_index, point_index, new_node) = res.unwrap();
unsafe {
cover_tree.insert_raw(scale_index, point_index, new_node);
}
inserted_nodes += 1;
if parameters.verbosity > 1 {
pb.total = parameters.total_nodes.load(atomic::Ordering::SeqCst) as u64;
pb.inc();
}
}
if inserted_nodes == parameters.total_nodes.load(atomic::Ordering::SeqCst) {
break;
}
}
if parameters.verbosity > 1 {
println!("\nWriting layers...");
}
cover_tree.refresh();
if parameters.verbosity > 1 {
println!(
"Finished building, took {:?} with {} per second",
now.elapsed(),
(inserted_nodes as f32) / now.elapsed().as_secs_f32()
);
}
Ok(cover_tree)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time;

    /// Wraps `data` (rows of `data_dim` floats) in a `DefaultCloud<L2>` and returns
    /// tree parameters with a leaf cutoff of 0, `min_res_index` -9 and no output.
    pub fn create_test_parameters(
        data: Vec<f32>,
        data_dim: usize,
    ) -> Arc<CoverTreeParameters<DefaultCloud<L2>>> {
        let point_cloud = Arc::new(DefaultCloud::<L2>::new(data, data_dim).unwrap());
        Arc::new(CoverTreeParameters {
            total_nodes: atomic::AtomicUsize::new(1),
            scale_base: 2.0,
            leaf_cutoff: 0,
            min_res_index: -9,
            use_singletons: true,
            point_cloud,
            verbosity: 0,
            plugins: RwLock::new(TreePluginSet::new()),
        })
    }

    /// Splitting the root once must create as many children as it adds to `total_nodes`.
    #[test]
    fn splits_conditions() {
        // 19 random points in [0, 1) followed by the origin as the last point.
        let mut data = Vec::with_capacity(20);
        for _i in 0..19 {
            data.push(rand::random::<f32>());
        }
        data.push(0.0);

        let test_parameters = create_test_parameters(data, 1);
        let build_node = BuilderNode::new(&test_parameters).unwrap();
        let (scale_index, center_index) = build_node.address();

        println!("{:?}", build_node);
        println!(
            "The center_index for the covered data should be 19 but is {}",
            build_node.covered.center_index
        );
        assert!(center_index == 19);
        println!("The scale_index should be 0, but is {}", scale_index);
        assert!(scale_index == 0);

        let (new_node, unfinished_nodes) = build_node.split(&test_parameters).unwrap();
        // The counter starts at 1 for the root, so the rest are children created by the split.
        let split_count = test_parameters.total_nodes.load(atomic::Ordering::SeqCst) - 1;
        println!(
            "We should have split count be equal to the work count: split {} , work {}",
            split_count,
            unfinished_nodes.len()
        );
        println!("We shouldn't be a leaf: {}", new_node.is_leaf());
        assert!(!new_node.is_leaf());
        println!(
            "We should have children count be equal to the split count: {}",
            new_node.children_len()
        );
        assert!(new_node.children_len() == split_count);
    }

    /// A parallel split of four 1-D points must report exactly one message per counted node.
    #[test]
    fn tree_structure_condition() {
        let data = vec![0.49, 0.491, -0.49, 0.0];
        let test_parameters = create_test_parameters(data, 1);
        let build_node = BuilderNode::new(&test_parameters).unwrap();
        let (node_sender, node_receiver) = unbounded::<NodeSplitResult<DefaultCloud<L2>>>();
        let node_sender = Arc::new(node_sender);
        build_node.split_parallel(&test_parameters, &node_sender);

        // Collect results until every counted node has reported, rather than sleeping
        // for a fixed time and hoping the workers are done. A node's children are
        // counted before the node's own message is sent, so the counts can only match
        // once no work is outstanding. The timeout keeps a broken build from hanging.
        let mut results = Vec::new();
        while results.len() < test_parameters.total_nodes.load(atomic::Ordering::SeqCst) {
            let message = node_receiver
                .recv_timeout(time::Duration::from_secs(5))
                .expect("timed out waiting for a split result");
            results.push(message);
        }

        let split_count = test_parameters.total_nodes.load(atomic::Ordering::SeqCst) - 1;
        println!(
            "Split count {}, node_receiver {}",
            split_count,
            results.len()
        );
        assert!(split_count + 1 == results.len());
        assert!(node_receiver.is_empty());
        assert!(split_count == 3);

        for pat in results {
            let (scale_index, center_index, node) = pat.unwrap();
            println!("{:?}", node);
            match (scale_index, center_index) {
                (-1, 3) => assert!(!node.is_leaf()),
                (-2, 3) => assert!(node.is_leaf()),
                (-2, 2) => assert!(node.is_leaf()),
                (-2, 0) => assert!(!node.is_leaf()),
                (-2, 1) => assert!(!node.is_leaf()),
                _ => {}
            };
        }
    }

    /// Builds the full tree with singletons enabled and checks the top two layers.
    #[test]
    fn insertion_tree_structure_condition() {
        let data = vec![0.49, 0.491, -0.49, 0.0];
        let point_cloud = Arc::new(DefaultCloud::<L2>::new(data, 1).unwrap());
        let builder = CoverTreeBuilder {
            scale_base: 2.0,
            leaf_cutoff: 1,
            min_res_index: -9,
            use_singletons: true,
            verbosity: 0,
        };
        let tree = builder.build(point_cloud).unwrap();
        let reader = tree.reader();

        println!("Testing top layer");
        let top_layer = reader.layer(-1);
        println!("Should only be one node");
        assert!(top_layer.len() == 1);
        println!("The root should not be a leaf");
        assert!(reader.get_node_and((-1, 3), |n| !n.is_leaf()).unwrap());
        println!("The root should have children");
        assert!(reader
            .get_node_and((-1, 3), |n| n.children().is_some())
            .unwrap());

        println!("Testing Mid Layer");
        let mid_layer = reader.layer(-2);
        println!("Should have 2 nodes");
        assert!(mid_layer.len() == 2);
        println!("Nested child of root should leafify");
        assert!(reader.get_node_and((-2, 3), |n| n.is_leaf()).unwrap());
        println!("Nested child of root should not have any children");
        assert!(reader
            .get_node_and((-2, 3), |n| n.children().is_none())
            .unwrap());
        println!("-0.49 is a singleton that shouldn't be here.");
        assert!(reader.get_node_and((-2, 2), |n| n.is_leaf()).is_none());
        assert!(reader.no_dangling_refs());
    }

    /// With singletons disabled, the lone point -0.49 must get a node of its own.
    #[test]
    fn singleltons_off_condition() {
        let data = vec![0.49, 0.491, -0.49, 0.0];
        let point_cloud = Arc::new(DefaultCloud::<L2>::new(data, 1).unwrap());
        let builder = CoverTreeBuilder {
            scale_base: 2.0,
            leaf_cutoff: 1,
            min_res_index: -9,
            use_singletons: false,
            verbosity: 0,
        };
        let tree = builder.build(point_cloud).unwrap();
        let reader = tree.reader();

        println!("-0.49 is a singleton that should be here.");
        assert!(reader.get_node_and((-2, 2), |n| n.is_leaf()).is_some());
        assert!(reader.no_dangling_refs());
    }
}