use crate::error::Result;
use crate::storage::structured::{StructReader, StructWriter};
use crate::storage::{Storage, StorageInput, StorageOutput};
use std::cmp::Ordering;
use std::io::SeekFrom;
use std::sync::Arc;
/// Read-only interface for multi-dimensional range queries over indexed points.
///
/// Implementors must be `Send + Sync` and implement `Debug`.
pub trait BKDTree: Send + Sync + std::fmt::Debug {
/// Returns the document IDs of the points that lie inside the query box.
///
/// `mins[d]` and `maxs[d]` are the lower and upper bound for dimension `d`.
/// In the implementations in this file a `None` bound leaves that side of the
/// dimension unbounded, `include_min` / `include_max` select inclusive versus
/// exclusive comparison for every lower / upper bound, and the returned IDs
/// are sorted ascending with duplicates removed.
fn range_search(
&self,
mins: &[Option<f64>],
maxs: &[Option<f64>],
include_min: bool,
include_max: bool,
) -> Result<Vec<u64>>;
}
/// Magic number written as the first header field and checked by `BKDReader::open`.
/// The bytes 0x42 0x4B 0x44 0x54 spell "BKDT" if the value is stored little-endian
/// (NOTE(review): the byte order used by `StructWriter::write_u32` is not visible here — confirm).
pub const BKD_MAGIC: u32 = 0x54444B42;
/// Format version that `BKDWriter` stores in the header.
pub const BKD_VERSION: u32 = 1;
/// Header of a BKD file, in the order the fields are written by `BKDWriter`
/// and read back by `BKDReader::open`.
#[derive(Debug, Clone)]
pub struct BKDFileHeader {
/// File magic; `BKDReader::open` rejects anything other than `BKD_MAGIC`.
pub magic: u32,
/// Format version stored in the file (`BKDWriter` writes `BKD_VERSION`).
pub version: u32,
/// Number of `f64` values per point.
pub num_dims: u32,
/// Size of one stored value in bytes; `BKDWriter` always writes 8.
pub bytes_per_dim: u32,
/// Number of points passed to `BKDWriter::write`.
pub total_point_count: u64,
/// Number of leaf blocks the writer emitted.
pub num_blocks: u64,
/// Per-dimension minimum over all written points.
pub min_values: Vec<f64>,
/// Per-dimension maximum over all written points.
pub max_values: Vec<f64>,
/// File offset of the first index node; the reader treats any smaller offset as a leaf block.
pub index_start_offset: u64,
/// File offset where a search starts: the root index node, or the only leaf
/// block when everything fit into a single block.
pub root_node_offset: u64,
}
/// Serializes a set of points into a BKD file: a header, the leaf blocks, and
/// finally the table of index nodes.
pub struct BKDWriter<W: StorageOutput> {
/// Structured writer wrapping the output; every byte of the file goes through it.
writer: StructWriter<W>,
/// Maximum number of points per leaf block (512 unless changed via `with_block_size`).
block_size: usize,
/// Number of leaf blocks written so far.
num_blocks: u64,
/// Number of `f64` values per point.
num_dims: u32,
/// Running per-dimension minimum of the written points (starts at `f64::MAX`).
min_values: Vec<f64>,
/// Running per-dimension maximum of the written points (starts at `f64::MIN`).
max_values: Vec<f64>,
/// Index nodes collected while the leaves are written; serialized afterwards by `write_index`.
index_nodes: Vec<IndexNode>,
}
/// In-memory form of one inner node of the tree, kept until the index table is written.
#[derive(Debug, Clone)]
struct IndexNode {
/// Dimension this node splits on (`depth % num_dims` in `build_subtree`).
split_dim: u32,
/// Value of `split_dim` for the first point of the right half.
split_value: f64,
/// File offset of the left child: set to the leaf block position in `build_subtree`,
/// or to the child index node position in `write_index`.
left_offset: u64,
/// File offset of the right child, filled in the same way as `left_offset`.
right_offset: u64,
/// Position of the left child in `index_nodes`, or `None` when the child is a leaf block.
left_child_idx: Option<usize>,
/// Position of the right child in `index_nodes`, or `None` when the child is a leaf block.
right_child_idx: Option<usize>,
}
impl<W: StorageOutput> BKDWriter<W> {
pub fn new(writer: W, num_dims: u32) -> Self {
BKDWriter {
writer: StructWriter::new(writer),
block_size: 512,
num_blocks: 0,
num_dims,
min_values: vec![f64::MAX; num_dims as usize],
max_values: vec![f64::MIN; num_dims as usize],
index_nodes: Vec::new(),
}
}
pub fn with_block_size(mut self, block_size: usize) -> Self {
self.block_size = block_size;
self
}
pub fn write(&mut self, entries: &[(Vec<f64>, u64)]) -> Result<()> {
if entries.is_empty() {
self.write_header(0, 0, 0)?;
return Ok(());
}
for (vals, _) in entries {
for (i, &val) in vals.iter().enumerate() {
self.min_values[i] = self.min_values[i].min(val);
self.max_values[i] = self.max_values[i].max(val);
}
}
let total_count = entries.len() as u64;
let header_size = 4 + 4 + 4 + 4 + 8 + 8 + (self.num_dims as u64 * 8 * 2) + 8 + 8;
self.writer.write_u32(0)?; self.writer.seek(SeekFrom::Start(header_size))?;
let mut mutable_entries = entries.to_vec();
let root_node_idx = self.build_subtree(&mut mutable_entries, 0)?;
let index_start_offset = self.writer.stream_position()?;
self.write_index()?;
let node_size = 4 + 8 + 8 + 8;
let root_node_offset = if let Some(idx) = root_node_idx {
index_start_offset + (idx as u64) * node_size
} else {
header_size
};
self.writer.seek(SeekFrom::Start(0))?;
self.write_header(total_count, index_start_offset, root_node_offset)?;
self.writer.seek(SeekFrom::End(0))?;
Ok(())
}
fn write_header(&mut self, total_count: u64, index_start: u64, root_offset: u64) -> Result<()> {
self.writer.write_u32(BKD_MAGIC)?;
self.writer.write_u32(BKD_VERSION)?;
self.writer.write_u32(self.num_dims)?;
self.writer.write_u32(8)?; self.writer.write_u64(total_count)?;
self.writer.write_u64(self.num_blocks)?;
for &v in &self.min_values {
self.writer.write_f64(v)?;
}
for &v in &self.max_values {
self.writer.write_f64(v)?;
}
self.writer.write_u64(index_start)?;
self.writer.write_u64(root_offset)?;
Ok(())
}
fn build_subtree(
&mut self,
entries: &mut [(Vec<f64>, u64)],
depth: u32,
) -> Result<Option<usize>> {
if entries.is_empty() {
return Ok(None);
}
if entries.len() <= self.block_size {
self.write_leaf_block(entries)?;
self.num_blocks += 1;
return Ok(None);
}
let split_dim = depth % self.num_dims;
entries.sort_by(|a, b| {
a.0[split_dim as usize]
.partial_cmp(&b.0[split_dim as usize])
.unwrap_or(Ordering::Equal)
});
let mid = entries.len() / 2;
let (left_entries, right_entries) = entries.split_at_mut(mid);
let split_value = right_entries[0].0[split_dim as usize];
let node_idx = self.index_nodes.len();
self.index_nodes.push(IndexNode {
split_dim,
split_value,
left_offset: 0,
right_offset: 0,
left_child_idx: None,
right_child_idx: None,
});
let left_file_pos_before = self.writer.stream_position()?;
let left_child_node_idx = self.build_subtree(left_entries, depth + 1)?;
let left_is_leaf = left_child_node_idx.is_none();
let right_file_pos_before = self.writer.stream_position()?;
let right_child_node_idx = self.build_subtree(right_entries, depth + 1)?;
let right_is_leaf = right_child_node_idx.is_none();
self.index_nodes[node_idx].left_child_idx = left_child_node_idx;
self.index_nodes[node_idx].right_child_idx = right_child_node_idx;
if left_is_leaf {
self.index_nodes[node_idx].left_offset = left_file_pos_before;
}
if right_is_leaf {
self.index_nodes[node_idx].right_offset = right_file_pos_before;
}
Ok(Some(node_idx))
}
fn write_leaf_block(&mut self, entries: &[(Vec<f64>, u64)]) -> Result<()> {
let count = entries.len() as u32;
self.writer.write_u32(count)?;
for (vals, _) in entries {
for &val in vals {
self.writer.write_f64(val)?;
}
}
for (_, doc_id) in entries {
self.writer.write_u64(*doc_id)?;
}
Ok(())
}
fn write_index(&mut self) -> Result<()> {
let start_pos = self.writer.stream_position()?;
let node_size = 4 + 8 + 8 + 8;
for i in 0..self.index_nodes.len() {
let left_idx = self.index_nodes[i].left_child_idx;
if let Some(idx) = left_idx {
self.index_nodes[i].left_offset = start_pos + (idx as u64) * node_size;
}
let right_idx = self.index_nodes[i].right_child_idx;
if let Some(idx) = right_idx {
self.index_nodes[i].right_offset = start_pos + (idx as u64) * node_size;
}
}
for node in &self.index_nodes {
self.writer.write_u32(node.split_dim)?;
self.writer.write_f64(node.split_value)?;
self.writer.write_u64(node.left_offset)?;
self.writer.write_u64(node.right_offset)?;
}
Ok(())
}
pub fn finish(self) -> Result<()> {
self.writer.close()
}
}
/// Reader for a BKD file held in a `Storage`; answers range queries via the `BKDTree` trait.
#[derive(Debug)]
pub struct BKDReader {
/// Header parsed once in `open`.
header: BKDFileHeader,
/// Storage the file lives in; `range_search` opens a fresh input from it for every query.
storage: Arc<dyn Storage>,
/// Path of the BKD file inside `storage`.
path: String,
}
impl BKDReader {
pub fn open(storage: Arc<dyn Storage>, path: &str) -> Result<Self> {
let input = storage.open_input(path)?;
let mut reader = StructReader::new(input)?;
let magic = reader.read_u32()?;
if magic != BKD_MAGIC {
return Err(crate::error::LaurusError::storage(format!(
"Invalid BKD magic: {:x}",
magic
)));
}
let version = reader.read_u32()?;
let num_dims = reader.read_u32()?;
let bytes_per_dim = reader.read_u32()?;
let total_point_count = reader.read_u64()?;
let num_blocks = reader.read_u64()?;
let mut min_values = Vec::with_capacity(num_dims as usize);
for _ in 0..num_dims {
min_values.push(reader.read_f64()?);
}
let mut max_values = Vec::with_capacity(num_dims as usize);
for _ in 0..num_dims {
max_values.push(reader.read_f64()?);
}
let index_start_offset = reader.read_u64()?;
let root_node_offset = reader.read_u64()?;
let header = BKDFileHeader {
magic,
version,
num_dims,
bytes_per_dim,
total_point_count,
num_blocks,
min_values,
max_values,
index_start_offset,
root_node_offset,
};
Ok(BKDReader {
header,
storage,
path: path.to_string(),
})
}
fn visit_node<R: StorageInput>(
&self,
reader: &mut StructReader<R>,
offset: u64,
ctx: &QueryContext,
collector: &mut Vec<u64>,
) -> Result<()> {
if offset < self.header.index_start_offset {
return self.visit_leaf_block(reader, offset, ctx, collector);
}
reader.seek(SeekFrom::Start(offset))?;
let split_dim = reader.read_u32()? as usize;
let split_value = reader.read_f64()?;
let left_offset = reader.read_u64()?;
let right_offset = reader.read_u64()?;
if split_dim >= self.header.num_dims as usize {
return Err(crate::error::LaurusError::index(format!(
"Invalid split dimension {} (num_dims={})",
split_dim, self.header.num_dims
)));
}
let min = ctx.mins[split_dim];
let max = ctx.maxs[split_dim];
let go_left = min.is_none_or(|m| {
if ctx.include_min {
m <= split_value
} else {
m < split_value
}
});
if go_left {
self.visit_node(reader, left_offset, ctx, collector)?;
}
let go_right = max.is_none_or(|m| {
if ctx.include_max {
m >= split_value
} else {
m > split_value
}
});
if go_right {
self.visit_node(reader, right_offset, ctx, collector)?;
}
Ok(())
}
fn visit_leaf_block<R: StorageInput>(
&self,
reader: &mut StructReader<R>,
offset: u64,
ctx: &QueryContext,
collector: &mut Vec<u64>,
) -> Result<()> {
reader.seek(SeekFrom::Start(offset))?;
let count = reader.read_u32()?;
let mut points = Vec::with_capacity(count as usize);
for _ in 0..count {
let mut vals = Vec::with_capacity(self.header.num_dims as usize);
for _ in 0..self.header.num_dims {
vals.push(reader.read_f64()?);
}
points.push(vals);
}
let mut doc_ids = Vec::with_capacity(count as usize);
for _ in 0..count {
doc_ids.push(reader.read_u64()?);
}
for (vals, doc_id) in points.iter().zip(doc_ids.iter()) {
let mut matches = true;
for (i, &val) in vals.iter().enumerate() {
let gte_min =
ctx.mins[i].is_none_or(|m| if ctx.include_min { val >= m } else { val > m });
let lte_max =
ctx.maxs[i].is_none_or(|m| if ctx.include_max { val <= m } else { val < m });
if !gte_min || !lte_max {
matches = false;
break;
}
}
if matches {
collector.push(*doc_id);
}
}
Ok(())
}
}
/// Parameters of one range query, bundled so the tree traversal can pass them around by reference.
struct QueryContext<'a> {
/// Per-dimension lower bounds; `None` means no lower bound for that dimension.
mins: &'a [Option<f64>],
/// Per-dimension upper bounds; `None` means no upper bound for that dimension.
maxs: &'a [Option<f64>],
/// Whether a value equal to a lower bound matches.
include_min: bool,
/// Whether a value equal to an upper bound matches.
include_max: bool,
}
impl BKDTree for BKDReader {
    /// Runs a range query against the on-disk tree.
    ///
    /// An empty file yields an empty result without touching the storage.
    /// Otherwise `mins` and `maxs` must both have exactly `num_dims` entries,
    /// or an index error is returned. The file is reopened for every call and
    /// the matching document IDs come back sorted and deduplicated.
    fn range_search(
        &self,
        mins: &[Option<f64>],
        maxs: &[Option<f64>],
        include_min: bool,
        include_max: bool,
    ) -> Result<Vec<u64>> {
        let header = &self.header;
        if header.total_point_count == 0 {
            return Ok(Vec::new());
        }

        let expected_dims = header.num_dims as usize;
        if mins.len() != expected_dims || maxs.len() != expected_dims {
            return Err(crate::error::LaurusError::index(format!(
                "Dimension mismatch: expected {} dims, got mins={} maxs={}",
                expected_dims,
                mins.len(),
                maxs.len()
            )));
        }

        let mut reader = StructReader::new(self.storage.open_input(&self.path)?)?;
        let ctx = QueryContext {
            mins,
            maxs,
            include_min,
            include_max,
        };

        let mut hits = Vec::new();
        let root_offset = header.root_node_offset;
        // A root located before the index table is the file's single leaf block.
        if root_offset < header.index_start_offset {
            self.visit_leaf_block(&mut reader, root_offset, &ctx, &mut hits)?;
        } else {
            self.visit_node(&mut reader, root_offset, &ctx, &mut hits)?;
        }

        hits.sort_unstable();
        hits.dedup();
        Ok(hits)
    }
}
/// In-memory `BKDTree` implementation that answers queries by scanning all of its entries.
#[derive(Debug, Clone)]
pub struct SimpleBKDTree {
/// Points with their document IDs, sorted by the first dimension in `new`.
entries: Vec<(Vec<f64>, u64)>,
/// Number of dimensions compared per point during a search.
num_dims: u32,
/// Name of the field this tree belongs to, as passed to the constructor.
field_name: String,
}
impl BKDTree for SimpleBKDTree {
fn range_search(
&self,
mins: &[Option<f64>],
maxs: &[Option<f64>],
include_min: bool,
include_max: bool,
) -> Result<Vec<u64>> {
let mut doc_ids = Vec::new();
for (vals, doc_id) in &self.entries {
let mut matches = true;
for i in 0..self.num_dims as usize {
let val = vals[i];
let gte_min = mins[i].is_none_or(|m| if include_min { val >= m } else { val > m });
let lte_max = maxs[i].is_none_or(|m| if include_max { val <= m } else { val < m });
if !gte_min || !lte_max {
matches = false;
break;
}
}
if matches {
doc_ids.push(*doc_id);
}
}
doc_ids.sort_unstable();
doc_ids.dedup();
Ok(doc_ids)
}
}
impl SimpleBKDTree {
pub fn new(field_name: String, num_dims: u32, mut entries: Vec<(Vec<f64>, u64)>) -> Self {
entries.sort_by(|a, b| {
if a.0.is_empty() || b.0.is_empty() {
std::cmp::Ordering::Equal
} else {
a.0[0]
.partial_cmp(&b.0[0])
.unwrap_or(std::cmp::Ordering::Equal)
}
});
SimpleBKDTree {
entries,
num_dims,
field_name,
}
}
pub fn empty(field_name: String, num_dims: u32) -> Self {
SimpleBKDTree {
entries: Vec::new(),
num_dims,
field_name,
}
}
pub fn field_name(&self) -> &str {
&self.field_name
}
pub fn size(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::Storage;
    use crate::storage::memory::{MemoryStorage, MemoryStorageConfig};
    use std::sync::Arc;

    /// Six one-dimensional points, deliberately supplied out of order.
    fn create_test_tree() -> SimpleBKDTree {
        let points = vec![
            (vec![1.0], 10),
            (vec![3.0], 20),
            (vec![2.0], 30),
            (vec![5.0], 40),
            (vec![4.0], 50),
            (vec![1.5], 60),
        ];
        SimpleBKDTree::new(String::from("test_field"), 1, points)
    }

    /// Round trip through the file format: write 2-D points, read them back, query.
    #[test]
    fn test_bkd_writer_reader_2d() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let points = vec![
            (vec![10.0, 20.0], 1),
            (vec![15.0, 25.0], 2),
            (vec![20.0, 30.0], 3),
        ];

        let output = storage.create_output("test_2d.bkd").unwrap();
        let mut writer = BKDWriter::new(output, 2);
        writer.write(&points).unwrap();
        writer.finish().unwrap();

        let reader = BKDReader::open(storage.clone(), "test_2d.bkd").unwrap();
        assert_eq!(reader.header.num_dims, 2);

        let lower = [Some(10.0), Some(10.0)];
        let upper = [Some(15.0), Some(25.0)];
        let hits = reader.range_search(&lower, &upper, true, true).unwrap();
        assert_eq!(hits, vec![1, 2]);
    }

    /// The constructor keeps every entry and orders them by value.
    #[test]
    fn test_bkd_tree_creation() {
        let tree = create_test_tree();
        assert_eq!(tree.size(), 6);
        assert_eq!(tree.field_name(), "test_field");
        assert!(!tree.is_empty());

        let sorted_by_value = vec![
            (vec![1.0], 10),
            (vec![1.5], 60),
            (vec![2.0], 30),
            (vec![3.0], 20),
            (vec![4.0], 50),
            (vec![5.0], 40),
        ];
        assert_eq!(tree.entries, sorted_by_value);
    }

    /// An empty tree reports no entries and matches nothing.
    #[test]
    fn test_empty_tree() {
        let tree = SimpleBKDTree::empty(String::from("empty_field"), 1);
        assert_eq!(tree.size(), 0);
        assert!(tree.is_empty());

        let hits = tree
            .range_search(&[Some(1.0)], &[Some(5.0)], true, true)
            .unwrap();
        assert_eq!(hits, Vec::<u64>::new());
    }

    /// Inclusive bounds return the points sitting exactly on both ends.
    #[test]
    fn test_range_search_exact_bounds() {
        let tree = create_test_tree();
        let hits = tree
            .range_search(&[Some(2.0)], &[Some(4.0)], true, true)
            .unwrap();
        assert_eq!(hits, vec![20, 30, 50]);
    }
}