use super::aabb::AABB;
use super::visitor::{CellRelation, IntersectVisitor, RangeQueryVisitor};
use crate::error::Result;
use crate::storage::structured::{StructReader, StructWriter};
use crate::storage::{Storage, StorageInput, StorageOutput};
use std::io::SeekFrom;
use std::sync::Arc;
/// Read-side interface of a BKD tree: a spatial index over multi-dimensional
/// `f64` points, each tagged with a `u64` doc id.
pub trait BKDTree: Send + Sync + std::fmt::Debug {
    /// Walks the tree, letting `visitor` classify cells via `compare` and
    /// receive matching doc ids via `visit_inside` / `visit`.
    fn intersect(&self, visitor: &mut dyn IntersectVisitor) -> Result<()>;

    /// Returns the sorted, de-duplicated doc ids whose points satisfy the
    /// per-dimension bounds `mins` / `maxs`.
    ///
    /// A `None` bound leaves that side of the dimension open; `include_min` /
    /// `include_max` select inclusive or exclusive comparison, as interpreted
    /// by `RangeQueryVisitor`.
    fn range_search(
        &self,
        mins: &[Option<f64>],
        maxs: &[Option<f64>],
        include_min: bool,
        include_max: bool,
    ) -> Result<Vec<u64>> {
        let mut range_visitor = RangeQueryVisitor::new(mins, maxs, include_min, include_max);
        self.intersect(&mut range_visitor)?;
        let mut doc_ids: Vec<u64> = range_visitor.into_hits();
        // The writer does not require doc ids to be unique across points, so
        // the same id may be reported more than once.
        doc_ids.sort_unstable();
        doc_ids.dedup();
        Ok(doc_ids)
    }
}
/// File magic number. Its bytes are 0x42 0x4B 0x44 0x54 ("BKDT") when the
/// `u32` is stored little-endian.
pub const BKD_MAGIC: u32 = 0x5444_4B42;
/// On-disk format revision; readers reject any other value.
pub const BKD_VERSION: u32 = 2;
/// Fixed-layout header at offset 0 of every BKD file.
#[derive(Debug, Clone)]
pub struct BKDFileHeader {
    /// Must equal [`BKD_MAGIC`].
    pub magic: u32,
    /// Must equal [`BKD_VERSION`].
    pub version: u32,
    /// Dimensionality of every stored point.
    pub num_dims: u32,
    /// Width of one coordinate in bytes; `BKDWriter` always writes 8 (`f64`).
    pub bytes_per_dim: u32,
    /// Number of points (doc id / coordinate tuples) in the tree.
    pub total_point_count: u64,
    /// Number of leaf blocks written.
    pub num_blocks: u64,
    /// Per-dimension lower bounds over all points.
    pub min_values: Vec<f64>,
    /// Per-dimension upper bounds over all points.
    pub max_values: Vec<f64>,
    /// Byte offset where the index nodes begin; any offset below this
    /// addresses a leaf block. 0 for an empty tree.
    pub index_start_offset: u64,
    /// Byte offset of the root: an index node, or the single leaf block when
    /// all points fit in one leaf. 0 for an empty tree.
    pub root_node_offset: u64,
}
/// Builds a BKD tree over an in-memory point set and serializes it through a
/// [`StructWriter`]: header, then leaf blocks, then index nodes.
pub struct BKDWriter<W: StorageOutput> {
    // Structured output; seeked back to offset 0 to write the header last.
    writer: StructWriter<W>,
    // Maximum number of points stored in one leaf block.
    block_size: usize,
    // Number of leaf blocks written so far; recorded in the header.
    num_blocks: u64,
    // Dimensionality of every point.
    num_dims: u32,
    // Per-dimension bounds over all written points; recorded in the header.
    min_values: Vec<f64>,
    max_values: Vec<f64>,
    // Inner nodes collected during the build, in pre-order (root first);
    // serialized after all leaf blocks.
    index_nodes: Vec<IndexNode>,
}
/// In-memory form of one inner tree node, serialized by `write_index`.
#[derive(Debug, Clone)]
struct IndexNode {
    // Dimension the node's points were partitioned on.
    split_dim: u32,
    // Value of the first point of the right half along `split_dim`.
    split_value: f64,
    // Bounding box of the left child's points.
    left_min: Vec<f64>,
    left_max: Vec<f64>,
    // Bounding box of the right child's points.
    right_min: Vec<f64>,
    right_max: Vec<f64>,
    // Absolute byte offsets of the children: set directly for leaf children,
    // resolved from `*_child_idx` by `write_index` for inner children.
    left_offset: u64,
    right_offset: u64,
    // Position of an inner child in `BKDWriter::index_nodes`; `None` when the
    // child is a leaf block.
    left_child_idx: Option<usize>,
    right_child_idx: Option<usize>,
}
/// Result of building one subtree.
struct SubtreeInfo {
    // Index into `BKDWriter::index_nodes` for an inner node; `None` when the
    // subtree was written as a single leaf block.
    node_idx: Option<usize>,
    // Bounding box of every point in the subtree.
    min: Vec<f64>,
    max: Vec<f64>,
}
/// Borrowed inputs shared by every recursive step of the tree build.
struct BuildContext<'a> {
    /// Row-major coordinates: point `i` spans `points[i * num_dims..(i + 1) * num_dims]`.
    points: &'a [f64],
    /// Doc id of each point, parallel to the rows of `points`.
    doc_ids: &'a [u64],
    /// Number of coordinates per point.
    num_dims: usize,
}

impl BuildContext<'_> {
    /// Coordinate `d` of point `i`.
    #[inline]
    fn value(&self, i: u32, d: usize) -> f64 {
        let row_start = i as usize * self.num_dims;
        self.points[row_start + d]
    }
}
/// Computes the per-dimension `(min, max)` bounding box of the points named by
/// `indices`.
///
/// With no indices the result is the empty box: every min is `+inf` and every
/// max is `-inf`.
fn compute_aabb(ctx: &BuildContext<'_>, indices: &[u32]) -> (Vec<f64>, Vec<f64>) {
    let dims = ctx.num_dims;
    let mut lo = vec![f64::INFINITY; dims];
    let mut hi = vec![f64::NEG_INFINITY; dims];
    for &idx in indices {
        let start = idx as usize * dims;
        let row = &ctx.points[start..start + dims];
        for ((lo_d, hi_d), &v) in lo.iter_mut().zip(hi.iter_mut()).zip(row) {
            if v < *lo_d {
                *lo_d = v;
            }
            if v > *hi_d {
                *hi_d = v;
            }
        }
    }
    (lo, hi)
}
/// Returns the dimension with the largest extent `max[d] - min[d]`. Ties go to
/// the lowest dimension.
///
/// A dimension whose bounds are equal has extent 0. This includes bounds that
/// are both `+inf` or both `-inf`, where plain subtraction would yield NaN.
/// A NaN extent in dimension 0 could never be beaten, so the build would keep
/// splitting on a dimension with no spread and the resulting cells would not
/// prune anything on the other dimensions.
fn widest_axis(min: &[f64], max: &[f64]) -> u32 {
    debug_assert_eq!(min.len(), max.len());
    debug_assert!(!min.is_empty());
    let extent = |d: usize| {
        if max[d] > min[d] {
            max[d] - min[d]
        } else {
            0.0
        }
    };
    let mut best = 0usize;
    let mut best_range = extent(0);
    for d in 1..min.len() {
        let r = extent(d);
        if r > best_range {
            best = d;
            best_range = r;
        }
    }
    best as u32
}
impl<W: StorageOutput> BKDWriter<W> {
    /// Creates a writer for `num_dims`-dimensional points on top of `writer`.
    ///
    /// Leaf blocks hold at most 512 points unless changed with
    /// [`BKDWriter::with_block_size`].
    pub fn new(writer: W, num_dims: u32) -> Self {
        BKDWriter {
            writer: StructWriter::new(writer),
            block_size: 512,
            num_blocks: 0,
            num_dims,
            // Start from the empty box (+inf / -inf), the same sentinels
            // `compute_aabb` uses. This keeps the header bounds exact even
            // when every value in a dimension is infinite, which
            // f64::MAX / f64::MIN sentinels would not.
            min_values: vec![f64::INFINITY; num_dims as usize],
            max_values: vec![f64::NEG_INFINITY; num_dims as usize],
            index_nodes: Vec::new(),
        }
    }

    /// Sets the maximum number of points stored in one leaf block.
    ///
    /// `0` is treated as `1`. A leaf must hold at least one point; otherwise
    /// splitting would eventually produce an empty half and the build would
    /// fail.
    pub fn with_block_size(mut self, block_size: usize) -> Self {
        self.block_size = block_size.max(1);
        self
    }

    /// Builds the tree over `doc_ids.len()` points and writes the whole file.
    ///
    /// `points` is row-major: point `i` occupies
    /// `points[i * num_dims..(i + 1) * num_dims]` and belongs to `doc_ids[i]`.
    ///
    /// Layout: header, then leaf blocks (written depth-first while building),
    /// then fixed-size index nodes. The header is written last, once the index
    /// offsets are known.
    ///
    /// # Errors
    ///
    /// Returns an index error in any of these cases:
    /// - `points.len()` differs from `doc_ids.len() * num_dims`;
    /// - points are given but the writer has zero dimensions;
    /// - there are more than `u32::MAX` points;
    /// - any coordinate is NaN.
    ///
    /// Storage errors are propagated.
    pub fn write(&mut self, points: &[f64], doc_ids: &[u64]) -> Result<()> {
        let num_dims = self.num_dims as usize;
        let expected = doc_ids.len().checked_mul(num_dims).ok_or_else(|| {
            crate::error::LaurusError::index(
                "Point count overflows when multiplied by num_dims".to_string(),
            )
        })?;
        if points.len() != expected {
            return Err(crate::error::LaurusError::index(format!(
                "Point buffer size mismatch: expected {} doc_ids * {} dims = {} f64s, got {}",
                doc_ids.len(),
                num_dims,
                expected,
                points.len()
            )));
        }
        if doc_ids.is_empty() {
            self.write_header(0, 0, 0)?;
            return Ok(());
        }
        if num_dims == 0 {
            // With no dimension there is nothing to split on; `widest_axis`
            // would index an empty slice as soon as a node exceeds the block
            // size.
            return Err(crate::error::LaurusError::index(
                "BKD tree requires at least one dimension to index points".to_string(),
            ));
        }
        if doc_ids.len() > u32::MAX as usize {
            // Points are addressed by `u32` indices during the build; a longer
            // input would otherwise be silently truncated.
            return Err(crate::error::LaurusError::index(format!(
                "Too many points for one BKD tree: {} (maximum {})",
                doc_ids.len(),
                u32::MAX
            )));
        }
        for (offset, &v) in points.iter().enumerate() {
            if v.is_nan() {
                let doc_idx = offset / num_dims;
                let dim = offset % num_dims;
                return Err(crate::error::LaurusError::index(format!(
                    "Point at doc index {doc_idx} dim {dim} is NaN; BKD requires \
                     totally-ordered values (NaN has no defined ordering)"
                )));
            }
        }
        // `num_dims > 0` was checked above, so `chunks_exact` is valid.
        for point in points.chunks_exact(num_dims) {
            for (d, &v) in point.iter().enumerate() {
                self.min_values[d] = self.min_values[d].min(v);
                self.max_values[d] = self.max_values[d].max(v);
            }
        }
        let total_count = doc_ids.len() as u64;
        let header_size = 4 + 4 + 4 + 4 + 8 + 8 + (self.num_dims as u64 * 8 * 2) + 8 + 8;
        // Emit a placeholder, then skip over the header region. The real
        // header is written at offset 0 after the tree and index are on disk.
        self.writer.write_u32(0)?;
        self.writer.seek(SeekFrom::Start(header_size))?;
        let mut indices: Vec<u32> = (0..doc_ids.len() as u32).collect();
        let ctx = BuildContext {
            points,
            doc_ids,
            num_dims,
        };
        let root_info = self.build_subtree(&ctx, &mut indices)?;
        let index_start_offset = self.writer.stream_position()?;
        self.write_index()?;
        let node_size = Self::node_size(self.num_dims);
        let root_node_offset = if let Some(idx) = root_info.node_idx {
            index_start_offset + (idx as u64) * node_size
        } else {
            // Everything fit into one leaf, written right after the header.
            header_size
        };
        self.writer.seek(SeekFrom::Start(0))?;
        self.write_header(total_count, index_start_offset, root_node_offset)?;
        self.writer.seek(SeekFrom::End(0))?;
        Ok(())
    }

    /// Writes the file header at the current stream position.
    ///
    /// Field order: magic, version, num_dims, bytes_per_dim (always 8),
    /// total point count, block count, min values, max values, index start
    /// offset, root offset.
    fn write_header(&mut self, total_count: u64, index_start: u64, root_offset: u64) -> Result<()> {
        self.writer.write_u32(BKD_MAGIC)?;
        self.writer.write_u32(BKD_VERSION)?;
        self.writer.write_u32(self.num_dims)?;
        self.writer.write_u32(8)?;
        self.writer.write_u64(total_count)?;
        self.writer.write_u64(self.num_blocks)?;
        for &v in &self.min_values {
            self.writer.write_f64(v)?;
        }
        for &v in &self.max_values {
            self.writer.write_f64(v)?;
        }
        self.writer.write_u64(index_start)?;
        self.writer.write_u64(root_offset)?;
        Ok(())
    }

    /// Serialized size of one index node. The parts are:
    /// - `split_dim` (4 bytes) and `split_value` (8 bytes);
    /// - four `num_dims`-long `f64` vectors (left/right min/max);
    /// - two `u64` child offsets.
    #[inline]
    fn node_size(num_dims: u32) -> u64 {
        28 + 32 * num_dims as u64
    }

    /// Recursively builds the subtree over `indices`. Leaf blocks are appended
    /// to the stream and inner nodes to `self.index_nodes` in pre-order, so the
    /// root of the whole tree is node 0.
    ///
    /// Returns the subtree's bounding box and, for an inner node, its index in
    /// `self.index_nodes`; `None` means the subtree became a single leaf.
    /// `indices` is reordered in place.
    fn build_subtree(
        &mut self,
        ctx: &BuildContext<'_>,
        indices: &mut [u32],
    ) -> Result<SubtreeInfo> {
        if indices.is_empty() {
            return Err(crate::error::LaurusError::index(
                "build_subtree called with empty indices".to_string(),
            ));
        }
        let (subtree_min, subtree_max) = compute_aabb(ctx, indices);
        if indices.len() <= self.block_size {
            self.write_leaf_block(ctx, indices, &subtree_min, &subtree_max)?;
            self.num_blocks += 1;
            return Ok(SubtreeInfo {
                node_idx: None,
                min: subtree_min,
                max: subtree_max,
            });
        }
        let split_dim = widest_axis(&subtree_min, &subtree_max);
        let split_dim_us = split_dim as usize;
        let mid = indices.len() / 2;
        // Partition around the median instead of fully sorting. Afterwards
        // `indices[mid]` holds the median along `split_dim`, everything before
        // it is <= and everything after it is >=. This is O(n) per node rather
        // than O(n log n). Order within each half does not matter, because each
        // child recomputes its own bounding box.
        indices.select_nth_unstable_by(mid, |&a, &b| {
            ctx.value(a, split_dim_us)
                .total_cmp(&ctx.value(b, split_dim_us))
        });
        // `len > block_size >= 1`, hence `len >= 2` and both halves are non-empty.
        let (left_indices, right_indices) = indices.split_at_mut(mid);
        let split_value = ctx.value(right_indices[0], split_dim_us);
        let node_idx = self.index_nodes.len();
        // Reserve this node's slot before recursing so parents precede children.
        self.index_nodes.push(IndexNode {
            split_dim,
            split_value,
            left_min: Vec::new(),
            left_max: Vec::new(),
            right_min: Vec::new(),
            right_max: Vec::new(),
            left_offset: 0,
            right_offset: 0,
            left_child_idx: None,
            right_child_idx: None,
        });
        // If a child turns out to be a leaf, it is written starting exactly here.
        let left_file_pos_before = self.writer.stream_position()?;
        let left_info = self.build_subtree(ctx, left_indices)?;
        let left_is_leaf = left_info.node_idx.is_none();
        let right_file_pos_before = self.writer.stream_position()?;
        let right_info = self.build_subtree(ctx, right_indices)?;
        let right_is_leaf = right_info.node_idx.is_none();
        let node = &mut self.index_nodes[node_idx];
        node.left_child_idx = left_info.node_idx;
        node.right_child_idx = right_info.node_idx;
        node.left_min = left_info.min;
        node.left_max = left_info.max;
        node.right_min = right_info.min;
        node.right_max = right_info.max;
        // Offsets of inner children are resolved later by `write_index`.
        if left_is_leaf {
            node.left_offset = left_file_pos_before;
        }
        if right_is_leaf {
            node.right_offset = right_file_pos_before;
        }
        Ok(SubtreeInfo {
            node_idx: Some(node_idx),
            min: subtree_min,
            max: subtree_max,
        })
    }

    /// Writes one leaf block at the current position. The parts, in order:
    /// - point count (`u32`);
    /// - leaf min and max vectors;
    /// - all coordinates, row-major;
    /// - all doc ids.
    fn write_leaf_block(
        &mut self,
        ctx: &BuildContext<'_>,
        indices: &[u32],
        leaf_min: &[f64],
        leaf_max: &[f64],
    ) -> Result<()> {
        let count = indices.len() as u32;
        self.writer.write_u32(count)?;
        for &v in leaf_min {
            self.writer.write_f64(v)?;
        }
        for &v in leaf_max {
            self.writer.write_f64(v)?;
        }
        for &i in indices {
            let base = i as usize * ctx.num_dims;
            for d in 0..ctx.num_dims {
                self.writer.write_f64(ctx.points[base + d])?;
            }
        }
        for &i in indices {
            self.writer.write_u64(ctx.doc_ids[i as usize])?;
        }
        Ok(())
    }

    /// Writes all index nodes contiguously from the current position.
    ///
    /// Nodes are laid out in `index_nodes` order with a fixed size, so an inner
    /// child's absolute offset follows from its index.
    fn write_index(&mut self) -> Result<()> {
        let start_pos = self.writer.stream_position()?;
        let node_size = Self::node_size(self.num_dims);
        let offset_of = |idx: usize| start_pos + (idx as u64) * node_size;
        for node in &mut self.index_nodes {
            if let Some(idx) = node.left_child_idx {
                node.left_offset = offset_of(idx);
            }
            if let Some(idx) = node.right_child_idx {
                node.right_offset = offset_of(idx);
            }
        }
        for node in &self.index_nodes {
            self.writer.write_u32(node.split_dim)?;
            self.writer.write_f64(node.split_value)?;
            for &v in &node.left_min {
                self.writer.write_f64(v)?;
            }
            for &v in &node.left_max {
                self.writer.write_f64(v)?;
            }
            for &v in &node.right_min {
                self.writer.write_f64(v)?;
            }
            for &v in &node.right_max {
                self.writer.write_f64(v)?;
            }
            self.writer.write_u64(node.left_offset)?;
            self.writer.write_u64(node.right_offset)?;
        }
        Ok(())
    }

    /// Closes the underlying structured writer, consuming `self`.
    pub fn finish(self) -> Result<()> {
        self.writer.close()
    }
}
/// Reader for a BKD file. It keeps the parsed header and re-opens `path` from
/// `storage` for every `intersect` call.
#[derive(Debug)]
pub struct BKDReader {
    // Header parsed once by `open`.
    header: BKDFileHeader,
    // Storage the file lives in.
    storage: Arc<dyn Storage>,
    // Path of the BKD file within `storage`.
    path: String,
}
impl BKDReader {
    /// Returns the header parsed when the file was opened.
    #[inline]
    pub fn header(&self) -> &BKDFileHeader {
        let Self { header, .. } = self;
        header
    }
}
impl BKDReader {
    /// Opens the BKD file at `path` in `storage` and parses its header.
    ///
    /// Only the header is read here. Each `intersect` call re-opens the file
    /// and walks the tree from `root_node_offset`.
    ///
    /// # Errors
    ///
    /// Returns a storage error in any of these cases:
    /// - the magic number or version does not match;
    /// - the header declares a coordinate width other than 8 bytes;
    /// - the header cannot be read.
    pub fn open(storage: Arc<dyn Storage>, path: &str) -> Result<Self> {
        let input = storage.open_input(path)?;
        let mut reader = StructReader::new(input)?;
        let magic = reader.read_u32()?;
        if magic != BKD_MAGIC {
            return Err(crate::error::LaurusError::storage(format!(
                "Invalid BKD magic: {:x}",
                magic
            )));
        }
        let version = reader.read_u32()?;
        if version != BKD_VERSION {
            return Err(crate::error::LaurusError::storage(format!(
                "Unsupported BKD version: {} (expected {}). Pre-release format \
                 changes do not support older revisions; rebuild the index.",
                version, BKD_VERSION
            )));
        }
        let num_dims = reader.read_u32()?;
        let bytes_per_dim = reader.read_u32()?;
        // Every traversal routine below reads and skips coordinates as 8-byte
        // `f64`s. Any other width would be silently misread.
        if bytes_per_dim != 8 {
            return Err(crate::error::LaurusError::storage(format!(
                "Unsupported BKD bytes_per_dim: {} (expected 8)",
                bytes_per_dim
            )));
        }
        let total_point_count = reader.read_u64()?;
        let num_blocks = reader.read_u64()?;
        let min_values = Self::read_f64_values(&mut reader, num_dims as usize)?;
        let max_values = Self::read_f64_values(&mut reader, num_dims as usize)?;
        let index_start_offset = reader.read_u64()?;
        let root_node_offset = reader.read_u64()?;
        let header = BKDFileHeader {
            magic,
            version,
            num_dims,
            bytes_per_dim,
            total_point_count,
            num_blocks,
            min_values,
            max_values,
            index_start_offset,
            root_node_offset,
        };
        Ok(BKDReader {
            header,
            storage,
            path: path.to_string(),
        })
    }

    /// Reads `count` consecutive `f64` values.
    fn read_f64_values<R: StorageInput>(
        reader: &mut StructReader<R>,
        count: usize,
    ) -> Result<Vec<f64>> {
        // `count` may come from an unvalidated header. Bound the up-front
        // allocation so a corrupt value fails on the read, not by aborting on
        // an enormous allocation.
        let mut values = Vec::with_capacity(count.min(64));
        for _ in 0..count {
            values.push(reader.read_f64()?);
        }
        Ok(values)
    }

    /// Reads a bounding box stored as a min vector followed by a max vector.
    fn read_child_aabb<R: StorageInput>(
        reader: &mut StructReader<R>,
        num_dims: usize,
    ) -> Result<AABB> {
        let min = Self::read_f64_values(reader, num_dims)?;
        let max = Self::read_f64_values(reader, num_dims)?;
        AABB::new(min, max)
    }

    /// Intersects the subtree at `offset` with `visitor`.
    ///
    /// Offsets below `index_start_offset` are leaf blocks. Otherwise the node's
    /// two child boxes are classified:
    /// - `Outside` children are skipped;
    /// - `Inside` children are reported wholesale via `collect_subtree`;
    /// - `Crosses` children are recursed into.
    fn intersect_subtree<R: StorageInput>(
        &self,
        reader: &mut StructReader<R>,
        offset: u64,
        visitor: &mut dyn IntersectVisitor,
        scratch: &mut IntersectScratch,
    ) -> Result<()> {
        if offset < self.header.index_start_offset {
            return self.intersect_leaf(reader, offset, visitor, scratch);
        }
        let num_dims = self.header.num_dims as usize;
        reader.seek(SeekFrom::Start(offset))?;
        // Node layout: split_dim, split_value, left box, right box,
        // left offset, right offset.
        let _split_dim = reader.read_u32()?;
        let _split_value = reader.read_f64()?;
        let left_aabb = Self::read_child_aabb(reader, num_dims)?;
        let right_aabb = Self::read_child_aabb(reader, num_dims)?;
        let left_offset = reader.read_u64()?;
        let right_offset = reader.read_u64()?;
        match visitor.compare(&left_aabb) {
            CellRelation::Outside => {}
            CellRelation::Inside => self.collect_subtree(reader, left_offset, visitor)?,
            CellRelation::Crosses => {
                self.intersect_subtree(reader, left_offset, visitor, scratch)?
            }
        }
        match visitor.compare(&right_aabb) {
            CellRelation::Outside => {}
            CellRelation::Inside => self.collect_subtree(reader, right_offset, visitor)?,
            CellRelation::Crosses => {
                self.intersect_subtree(reader, right_offset, visitor, scratch)?
            }
        }
        Ok(())
    }

    /// Intersects one leaf block with `visitor`.
    ///
    /// The leaf's own box is compared first:
    /// - `Inside` reports every doc id without decoding coordinates;
    /// - `Crosses` decodes the coordinates into `scratch` and hands each point
    ///   to `visitor.visit`.
    fn intersect_leaf<R: StorageInput>(
        &self,
        reader: &mut StructReader<R>,
        offset: u64,
        visitor: &mut dyn IntersectVisitor,
        scratch: &mut IntersectScratch,
    ) -> Result<()> {
        reader.seek(SeekFrom::Start(offset))?;
        let count = reader.read_u32()? as usize;
        let num_dims = self.header.num_dims as usize;
        let leaf_aabb = Self::read_child_aabb(reader, num_dims)?;
        match visitor.compare(&leaf_aabb) {
            CellRelation::Outside => Ok(()),
            CellRelation::Inside => {
                let point_bytes = (count as u64) * (num_dims as u64) * 8;
                reader.seek(SeekFrom::Current(point_bytes as i64))?;
                for _ in 0..count {
                    let doc_id = reader.read_u64()?;
                    visitor.visit_inside(doc_id);
                }
                Ok(())
            }
            CellRelation::Crosses => {
                let needed = count * num_dims;
                let buf = scratch.point_slice(needed);
                for slot in buf.iter_mut() {
                    *slot = reader.read_f64()?;
                }
                for i in 0..count {
                    let doc_id = reader.read_u64()?;
                    let point = &buf[i * num_dims..(i + 1) * num_dims];
                    visitor.visit(doc_id, point);
                }
                Ok(())
            }
        }
    }

    /// Reports every doc id under the subtree at `offset` as inside, without
    /// comparing any boxes.
    fn collect_subtree<R: StorageInput>(
        &self,
        reader: &mut StructReader<R>,
        offset: u64,
        visitor: &mut dyn IntersectVisitor,
    ) -> Result<()> {
        if offset < self.header.index_start_offset {
            return self.collect_leaf(reader, offset, visitor);
        }
        let num_dims = self.header.num_dims as usize;
        reader.seek(SeekFrom::Start(offset))?;
        let _split_dim = reader.read_u32()?;
        let _split_value = reader.read_f64()?;
        // Skip both child boxes: 2 boxes * (min + max) * num_dims * 8 bytes.
        let aabb_bytes = (num_dims as u64) * 16 * 2;
        reader.seek(SeekFrom::Current(aabb_bytes as i64))?;
        let left_offset = reader.read_u64()?;
        let right_offset = reader.read_u64()?;
        self.collect_subtree(reader, left_offset, visitor)?;
        self.collect_subtree(reader, right_offset, visitor)?;
        Ok(())
    }

    /// Reports every doc id in the leaf block at `offset` as inside, skipping
    /// the leaf box and coordinates.
    fn collect_leaf<R: StorageInput>(
        &self,
        reader: &mut StructReader<R>,
        offset: u64,
        visitor: &mut dyn IntersectVisitor,
    ) -> Result<()> {
        reader.seek(SeekFrom::Start(offset))?;
        let count = reader.read_u32()? as usize;
        let num_dims = self.header.num_dims as usize;
        let skip_bytes = (num_dims as u64) * 16 + (count as u64) * (num_dims as u64) * 8;
        reader.seek(SeekFrom::Current(skip_bytes as i64))?;
        for _ in 0..count {
            let doc_id = reader.read_u64()?;
            visitor.visit_inside(doc_id);
        }
        Ok(())
    }
}
/// Reusable coordinate buffer for decoding `Crosses` leaves during one
/// `intersect` call.
struct IntersectScratch {
    // Grows to the largest request seen and is never shrunk.
    points: Vec<f64>,
}

impl IntersectScratch {
    /// Creates an empty scratch buffer.
    fn new() -> Self {
        Self { points: Vec::new() }
    }

    /// Returns a mutable view of the first `len` slots, growing the buffer
    /// with zeros if needed. Existing contents are kept.
    fn point_slice(&mut self, len: usize) -> &mut [f64] {
        let target = self.points.len().max(len);
        self.points.resize(target, 0.0);
        &mut self.points[..len]
    }
}
impl BKDTree for BKDReader {
    /// Re-opens the file and walks the tree from the root recorded in the
    /// header, reusing one scratch buffer for all `Crosses` leaves.
    fn intersect(&self, visitor: &mut dyn IntersectVisitor) -> Result<()> {
        // An empty tree has no blocks; its header offsets are both 0.
        if self.header.total_point_count == 0 {
            return Ok(());
        }
        let mut reader = StructReader::new(self.storage.open_input(&self.path)?)?;
        let mut scratch = IntersectScratch::new();
        let root = self.header.root_node_offset;
        // Offsets before the index region address leaf blocks. A tree whose
        // points fit in one block has that leaf as its root.
        let root_is_leaf = root < self.header.index_start_offset;
        if root_is_leaf {
            self.intersect_leaf(&mut reader, root, visitor, &mut scratch)
        } else {
            self.intersect_subtree(&mut reader, root, visitor, &mut scratch)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::Storage;
    use crate::storage::memory::{MemoryStorage, MemoryStorageConfig};
    use std::sync::Arc;

    #[test]
    fn test_bkd_writer_reader_2d() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let points: Vec<f64> = vec![10.0, 20.0, 15.0, 25.0, 20.0, 30.0];
        let doc_ids: Vec<u64> = vec![1, 2, 3];
        {
            let output = storage.create_output("test_2d.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 2);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        {
            let reader = BKDReader::open(storage.clone(), "test_2d.bkd").unwrap();
            assert_eq!(reader.header.num_dims, 2);
            let results = reader
                .range_search(
                    &[Some(10.0), Some(10.0)],
                    &[Some(15.0), Some(25.0)],
                    true,
                    true,
                )
                .unwrap();
            assert_eq!(results, vec![1, 2]);
        }
    }

    #[test]
    fn test_bkd_writer_empty() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let points: Vec<f64> = vec![];
        let doc_ids: Vec<u64> = vec![];
        {
            let output = storage.create_output("empty.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 2);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        let reader = BKDReader::open(storage.clone(), "empty.bkd").unwrap();
        assert_eq!(reader.header.total_point_count, 0);
        let results = reader
            .range_search(&[None, None], &[None, None], true, true)
            .unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_bkd_writer_size_mismatch_rejected() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let points: Vec<f64> = vec![1.0, 2.0, 3.0];
        let doc_ids: Vec<u64> = vec![10, 20];
        let output = storage.create_output("bad.bkd").unwrap();
        let mut writer = BKDWriter::new(output, 2);
        let err = writer.write(&points, &doc_ids).unwrap_err();
        assert!(
            format!("{err:?}").contains("Point buffer size mismatch"),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_bkd_writer_reader_1d_multi_block() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let n: usize = 2_000;
        let points: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let doc_ids: Vec<u64> = (0..n as u64).collect();
        {
            let output = storage.create_output("range1d.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 1).with_block_size(128);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        let reader = BKDReader::open(storage.clone(), "range1d.bkd").unwrap();
        let results = reader
            .range_search(&[Some(100.0)], &[Some(200.0)], true, true)
            .unwrap();
        let expected: Vec<u64> = (100u64..=200u64).collect();
        assert_eq!(results, expected);
    }

    #[test]
    fn test_bkd_writer_reader_3d_multi_block() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let n: usize = 1_000;
        let mut points: Vec<f64> = Vec::with_capacity(n * 3);
        let mut doc_ids: Vec<u64> = Vec::with_capacity(n);
        for i in 0..n {
            let v = i as f64;
            points.push(v);
            points.push(v + 1000.0);
            points.push(v + 2000.0);
            doc_ids.push(i as u64);
        }
        {
            let output = storage.create_output("range3d.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 3).with_block_size(64);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        let reader = BKDReader::open(storage.clone(), "range3d.bkd").unwrap();
        assert_eq!(reader.header.num_dims, 3);
        assert_eq!(reader.header.version, BKD_VERSION);
        let results = reader
            .range_search(
                &[Some(100.0), None, None],
                &[Some(150.0), Some(1200.0), None],
                true,
                true,
            )
            .unwrap();
        let expected: Vec<u64> = (100u64..=150u64)
            .filter(|&i| (i as f64) + 1000.0 <= 1200.0)
            .collect();
        assert_eq!(results, expected);
    }

    #[test]
    fn test_bkd_reader_rejects_version_mismatch() {
        use crate::storage::structured::StructWriter;
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        {
            let output = storage.create_output("v1.bkd").unwrap();
            let mut writer = StructWriter::new(output);
            // Hand-written version-1 header for an empty 2-dimensional tree.
            writer.write_u32(BKD_MAGIC).unwrap();
            writer.write_u32(1).unwrap(); // version
            writer.write_u32(2).unwrap(); // num_dims
            writer.write_u32(8).unwrap(); // bytes_per_dim
            writer.write_u64(0).unwrap(); // total_point_count
            writer.write_u64(0).unwrap(); // num_blocks
            writer.write_f64(0.0).unwrap(); // min_values[0]
            writer.write_f64(0.0).unwrap(); // min_values[1]
            writer.write_f64(0.0).unwrap(); // max_values[0]
            writer.write_f64(0.0).unwrap(); // max_values[1]
            writer.write_u64(0).unwrap(); // index_start_offset
            writer.write_u64(0).unwrap(); // root_node_offset
            writer.close().unwrap();
        }
        let err = BKDReader::open(storage.clone(), "v1.bkd").unwrap_err();
        let msg = format!("{err:?}");
        assert!(
            msg.contains("Unsupported BKD version"),
            "unexpected error: {msg}"
        );
    }

    /// Classifies `cell` against `query`:
    /// - `Outside` if they are disjoint on any dimension;
    /// - `Inside` if `cell` is fully contained in `query`;
    /// - `Crosses` otherwise.
    fn classify(query: &AABB, cell: &AABB) -> CellRelation {
        let qmin = query.min();
        let qmax = query.max();
        let cmin = cell.min();
        let cmax = cell.max();
        for d in 0..cell.num_dims() {
            if cmax[d] < qmin[d] || cmin[d] > qmax[d] {
                return CellRelation::Outside;
            }
        }
        for d in 0..cell.num_dims() {
            if cmin[d] < qmin[d] || cmax[d] > qmax[d] {
                return CellRelation::Crosses;
            }
        }
        CellRelation::Inside
    }

    /// Visitor that records wholesale (`Inside`) hits separately from
    /// per-point (`Crosses`) hits.
    struct TracingVisitor {
        query: AABB,
        inside_hits: Vec<u64>,
        crosses_hits: Vec<u64>,
    }

    impl TracingVisitor {
        fn new(query: AABB) -> Self {
            Self {
                query,
                inside_hits: Vec::new(),
                crosses_hits: Vec::new(),
            }
        }
    }

    impl IntersectVisitor for TracingVisitor {
        fn compare(&self, cell: &AABB) -> CellRelation {
            classify(&self.query, cell)
        }

        fn visit_inside(&mut self, doc_id: u64) {
            self.inside_hits.push(doc_id);
        }

        fn visit(&mut self, doc_id: u64, point: &[f64]) {
            if self.query.contains_point(point) {
                self.crosses_hits.push(doc_id);
            }
        }
    }

    /// Visitor that logs every cell relation it returns, to verify pruning.
    struct RecordingVisitor {
        query: AABB,
        relations: std::cell::RefCell<Vec<CellRelation>>,
        hits: Vec<u64>,
    }

    impl RecordingVisitor {
        fn new(query: AABB) -> Self {
            Self {
                query,
                relations: std::cell::RefCell::new(Vec::new()),
                hits: Vec::new(),
            }
        }
    }

    impl IntersectVisitor for RecordingVisitor {
        fn compare(&self, cell: &AABB) -> CellRelation {
            let relation = classify(&self.query, cell);
            self.relations.borrow_mut().push(relation);
            relation
        }

        fn visit_inside(&mut self, doc_id: u64) {
            self.hits.push(doc_id);
        }

        fn visit(&mut self, doc_id: u64, point: &[f64]) {
            if self.query.contains_point(point) {
                self.hits.push(doc_id);
            }
        }
    }

    #[test]
    fn widest_axis_picks_largest_range() {
        assert_eq!(widest_axis(&[0.0, 0.0], &[10.0, 100.0]), 1);
        assert_eq!(widest_axis(&[0.0, 0.0], &[100.0, 10.0]), 0);
        assert_eq!(widest_axis(&[0.0, 0.0], &[5.0, 5.0]), 0);
        assert_eq!(widest_axis(&[0.0, 0.0, 0.0], &[1.0, 50.0, 10.0]), 1);
    }

    #[test]
    fn build_subtree_root_split_is_widest_axis() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let n: usize = 256;
        let mut points: Vec<f64> = Vec::with_capacity(n * 2);
        let mut doc_ids: Vec<u64> = Vec::with_capacity(n);
        for i in 0..n {
            points.push(i as f64);
            points.push(0.0);
            doc_ids.push(i as u64);
        }
        {
            let output = storage.create_output("wide_dim0.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 2).with_block_size(32);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        points.clear();
        doc_ids.clear();
        for i in 0..n {
            points.push(0.0);
            points.push(i as f64);
            doc_ids.push(i as u64);
        }
        {
            let output = storage.create_output("wide_dim1.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 2).with_block_size(32);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        // The root is the first index node, so its split_dim is the first
        // u32 of the index region.
        fn root_split_dim(storage: &Arc<MemoryStorage>, path: &str) -> u32 {
            let reader = BKDReader::open(storage.clone(), path).unwrap();
            let index_start = reader.header.index_start_offset;
            let input = storage.open_input(path).unwrap();
            let mut sr = StructReader::new(input).unwrap();
            sr.seek(SeekFrom::Start(index_start)).unwrap();
            sr.read_u32().unwrap()
        }
        assert_eq!(
            root_split_dim(&storage, "wide_dim0.bkd"),
            0,
            "root should split on dim 0 when dim 0 is widest"
        );
        assert_eq!(
            root_split_dim(&storage, "wide_dim1.bkd"),
            1,
            "root should split on dim 1 when dim 1 is widest"
        );
    }

    #[test]
    fn build_subtree_skewed_data_round_trip() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let n: usize = 1_000;
        let mut points: Vec<f64> = Vec::with_capacity(n * 3);
        let mut doc_ids: Vec<u64> = Vec::with_capacity(n);
        for i in 0..n {
            let v = i as f64;
            points.push(v);
            points.push(v / (n as f64));
            points.push(v / (n as f64 * 1000.0));
            doc_ids.push(i as u64);
        }
        {
            let output = storage.create_output("skewed.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 3).with_block_size(64);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        let reader = BKDReader::open(storage.clone(), "skewed.bkd").unwrap();
        let results = reader
            .range_search(
                &[Some(100.0), None, None],
                &[Some(200.0), None, None],
                true,
                true,
            )
            .unwrap();
        assert_eq!(results, (100u64..=200u64).collect::<Vec<_>>());
    }

    #[test]
    fn intersect_scratch_reuse_across_many_crosses_leaves() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let n: usize = 4_096;
        let block_size: usize = 32;
        let points: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let doc_ids: Vec<u64> = (0..n as u64).collect();
        {
            let output = storage.create_output("scratch.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 1).with_block_size(block_size);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        let reader = BKDReader::open(storage.clone(), "scratch.bkd").unwrap();
        // Half-integer bounds make the boundary leaves `Crosses`.
        let lower = 10.5;
        let upper = (n - 10) as f64 + 0.5;
        let results = reader
            .range_search(&[Some(lower)], &[Some(upper)], true, true)
            .unwrap();
        let expected: Vec<u64> = (11u64..=(n as u64 - 10)).collect();
        assert_eq!(results, expected);
        let results2 = reader
            .range_search(&[Some(lower)], &[Some(upper)], true, true)
            .unwrap();
        assert_eq!(results2, expected);
    }

    #[test]
    fn intersect_inside_avoids_per_point_filter() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let n: usize = 256;
        let points: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let doc_ids: Vec<u64> = (0..n as u64).collect();
        {
            let output = storage.create_output("inside.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 1).with_block_size(32);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        let reader = BKDReader::open(storage.clone(), "inside.bkd").unwrap();
        let query = AABB::new(vec![-1e9], vec![1e9]).unwrap();
        let mut v = TracingVisitor::new(query);
        reader.intersect(&mut v).unwrap();
        assert_eq!(v.inside_hits.len(), n);
        assert!(v.crosses_hits.is_empty());
        v.inside_hits.sort_unstable();
        let expected: Vec<u64> = (0..n as u64).collect();
        assert_eq!(v.inside_hits, expected);
    }

    #[test]
    fn intersect_outside_prunes_subtree() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let n: usize = 128;
        let points: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let doc_ids: Vec<u64> = (0..n as u64).collect();
        {
            let output = storage.create_output("outside.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 1).with_block_size(16);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        let reader = BKDReader::open(storage.clone(), "outside.bkd").unwrap();
        let query = AABB::new(vec![1000.0], vec![2000.0]).unwrap();
        let mut v = RecordingVisitor::new(query);
        reader.intersect(&mut v).unwrap();
        assert!(v.hits.is_empty());
        assert!(
            v.relations
                .borrow()
                .iter()
                .any(|r| matches!(r, CellRelation::Outside)),
            "expected at least one Outside compare, got {:?}",
            v.relations.borrow()
        );
    }

    #[test]
    fn intersect_crosses_filters_per_point() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let n: usize = 200;
        let points: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let doc_ids: Vec<u64> = (0..n as u64).collect();
        {
            let output = storage.create_output("crosses.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 1).with_block_size(16);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        let reader = BKDReader::open(storage.clone(), "crosses.bkd").unwrap();
        let query = AABB::new(vec![50.5], vec![100.5]).unwrap();
        let mut v = TracingVisitor::new(query);
        reader.intersect(&mut v).unwrap();
        let expected: Vec<u64> = (51u64..=100u64).collect();
        let mut got = v.crosses_hits.clone();
        got.append(&mut v.inside_hits.clone());
        got.sort_unstable();
        got.dedup();
        assert_eq!(got, expected);
        assert!(!v.crosses_hits.is_empty());
    }

    #[test]
    fn range_search_default_impl_matches_legacy_semantics() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let n: usize = 500;
        let points: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let doc_ids: Vec<u64> = (0..n as u64).collect();
        {
            let output = storage.create_output("legacy.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 1).with_block_size(64);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        let reader = BKDReader::open(storage.clone(), "legacy.bkd").unwrap();
        let inclusive = reader
            .range_search(&[Some(100.0)], &[Some(200.0)], true, true)
            .unwrap();
        assert_eq!(inclusive, (100u64..=200u64).collect::<Vec<_>>());
        let exclusive = reader
            .range_search(&[Some(100.0)], &[Some(200.0)], false, false)
            .unwrap();
        assert_eq!(exclusive, (101u64..=199u64).collect::<Vec<_>>());
        let lower_only = reader
            .range_search(&[Some(490.0)], &[None], true, true)
            .unwrap();
        assert_eq!(lower_only, (490u64..n as u64).collect::<Vec<_>>());
    }

    #[test]
    fn test_bkd_writer_reader_2d_single_leaf_aabb() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let points: Vec<f64> = vec![1.0, 100.0, 2.0, 200.0, 3.0, 300.0];
        let doc_ids: Vec<u64> = vec![10, 20, 30];
        {
            let output = storage.create_output("single.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 2);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        let reader = BKDReader::open(storage.clone(), "single.bkd").unwrap();
        let results = reader
            .range_search(
                &[Some(2.0), Some(150.0)],
                &[Some(3.0), Some(250.0)],
                true,
                true,
            )
            .unwrap();
        assert_eq!(results, vec![20]);
    }

    #[test]
    fn write_rejects_nan_coordinate() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let points: Vec<f64> = vec![1.0, 2.0, f64::NAN, 4.0];
        let doc_ids: Vec<u64> = vec![10, 20];
        let output = storage.create_output("nan.bkd").unwrap();
        let mut writer = BKDWriter::new(output, 2);
        let err = writer.write(&points, &doc_ids).unwrap_err();
        let msg = format!("{err:?}");
        assert!(msg.contains("NaN"), "unexpected error: {msg}");
        assert!(msg.contains("doc index 1"), "unexpected error: {msg}");
        assert!(msg.contains("dim 0"), "unexpected error: {msg}");
    }

    #[test]
    fn write_accepts_infinity_and_round_trips() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let points: Vec<f64> = vec![f64::NEG_INFINITY, -10.0, 0.0, 10.0, f64::INFINITY];
        let doc_ids: Vec<u64> = vec![100, 200, 300, 400, 500];
        {
            let output = storage.create_output("inf.bkd").unwrap();
            let mut writer = BKDWriter::new(output, 1);
            writer.write(&points, &doc_ids).unwrap();
            writer.finish().unwrap();
        }
        let reader = BKDReader::open(storage.clone(), "inf.bkd").unwrap();
        let mut all = reader.range_search(&[None], &[None], true, true).unwrap();
        all.sort_unstable();
        assert_eq!(all, vec![100, 200, 300, 400, 500]);
        let finite = reader
            .range_search(&[Some(-100.0)], &[Some(100.0)], true, true)
            .unwrap();
        assert_eq!(finite, vec![200, 300, 400]);
        let lower_inf = reader
            .range_search(&[Some(f64::NEG_INFINITY)], &[Some(0.0)], true, true)
            .unwrap();
        assert_eq!(lower_inf, vec![100, 200, 300]);
        let upper_inf = reader
            .range_search(&[Some(0.0)], &[Some(f64::INFINITY)], true, true)
            .unwrap();
        assert_eq!(upper_inf, vec![300, 400, 500]);
    }
}