oxirs_vec/
tree_indices_covertree.rs1use crate::tree_indices_types::TreeIndexConfig;
7use crate::Vector;
8use anyhow::Result;
9use std::cmp::Ordering;
10
/// Tree-shaped nearest-neighbour index over the vectors stored in `data`.
///
/// Points are added to `data` first; `build` then creates the node structure and
/// `search` walks it to collect the closest points to a query.
pub struct CoverTree {
    /// Top node of the tree. `new` initialises it to `None`; `build` replaces it
    /// once `data` is non-empty.
    pub(crate) root: Option<Box<CoverNode>>,
    /// Indexed entries. Nodes refer to an entry by its position in this vector.
    /// The `String` is presumably the vector's external identifier — TODO confirm.
    pub(crate) data: Vec<(String, Vector)>,
    /// Index configuration; `search` uses its `distance_metric` to score candidates.
    pub(crate) config: TreeIndexConfig,
    /// Set to `2.0` by `new`. NOTE(review): not read by any code visible in this
    /// file section — presumably the cover-tree expansion base; confirm before use.
    base: f32,
}
18
/// A single node of a [`CoverTree`].
pub(crate) struct CoverNode {
    /// Position of this node's point inside `CoverTree::data`.
    point: usize,
    /// Level assigned to the node when it was created (see `CoverTree::get_level`).
    level: i32,
    /// Nodes attached directly below this one.
    #[allow(clippy::vec_box)]
    children: Vec<Box<CoverNode>>,
}
28
29impl CoverTree {
30 pub fn new(config: TreeIndexConfig) -> Self {
31 Self {
32 root: None,
33 data: Vec::new(),
34 config,
35 base: 2.0, }
37 }
38
39 pub fn build(&mut self) -> Result<()> {
40 if self.data.is_empty() {
41 return Ok(());
42 }
43
44 self.root = Some(Box::new(CoverNode {
46 point: 0,
47 level: self.get_level(0),
48 children: Vec::new(),
49 }));
50
51 for idx in 1..self.data.len() {
53 self.insert(idx)?;
54 }
55
56 Ok(())
57 }
58
59 fn get_level(&self, _point_idx: usize) -> i32 {
60 ((self.data.len() as f32).log2() as i32).max(0)
62 }
63
64 fn insert(&mut self, point_idx: usize) -> Result<()> {
65 let level = self.get_level(point_idx);
68 if let Some(root) = &mut self.root {
69 root.children.push(Box::new(CoverNode {
70 point: point_idx,
71 level,
72 children: Vec::new(),
73 }));
74 }
75 Ok(())
76 }
77
78 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
79 if self.root.is_none() {
80 return Vec::new();
81 }
82
83 let mut results = Vec::new();
84 self.search_node(
85 self.root
86 .as_ref()
87 .expect("tree should have root after build"),
88 query,
89 k,
90 &mut results,
91 );
92
93 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
94 results.truncate(k);
95 results
96 }
97
98 #[allow(clippy::only_used_in_recursion)]
99 fn search_node(
100 &self,
101 node: &CoverNode,
102 query: &[f32],
103 k: usize,
104 results: &mut Vec<(usize, f32)>,
105 ) {
106 if results.len() >= k * 10 {
108 return;
109 }
110
111 let point_data = &self.data[node.point].1.as_f32();
112 let dist = self.config.distance_metric.distance(query, point_data);
113
114 results.push((node.point, dist));
115
116 for child in &node.children {
118 self.search_node(child, query, k, results);
119 }
120 }
121}