use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::data::GeoEcefPoint;
use crate::error::Result;
use crate::lexical::index::structures::aabb::AABB;
use crate::lexical::index::structures::visitor::{CellRelation, IntersectVisitor};
use crate::lexical::query::Query;
use crate::lexical::query::matcher::Matcher;
use crate::lexical::query::scorer::Scorer;
use crate::lexical::reader::LexicalIndexReader;
/// Query matching documents whose indexed 3D point lies within a straight-line
/// (Euclidean) distance of a centre point.
///
/// The search runs against the BKD tree that the reader exposes for `field`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Geo3dDistanceQuery {
/// Name of the field whose BKD tree is searched.
field: String,
/// Centre of the search sphere.
center: GeoEcefPoint,
/// Radius of the search sphere (metres, going by the `_m` suffix).
distance_m: f64,
/// Multiplier applied to scores by the scorer built from this query.
boost: f32,
}
impl Geo3dDistanceQuery {
    /// Creates a distance query on `field` around `center` with radius
    /// `distance_m`. The boost starts at `1.0`.
    pub fn new<F: Into<String>>(field: F, center: GeoEcefPoint, distance_m: f64) -> Self {
        let field = field.into();
        Self {
            field,
            center,
            distance_m,
            boost: 1.0,
        }
    }

    /// Returns the query with its boost replaced by `boost`.
    pub fn with_boost(self, boost: f32) -> Self {
        Self { boost, ..self }
    }

    /// Name of the field this query searches.
    pub fn field(&self) -> &str {
        self.field.as_str()
    }

    /// Centre of the search sphere.
    pub fn center(&self) -> GeoEcefPoint {
        self.center
    }

    /// Radius of the search sphere.
    pub fn distance_m(&self) -> f64 {
        self.distance_m
    }

    /// Collects every document whose point lies within the search sphere.
    ///
    /// Returns an empty list when the reader has no BKD tree for the field.
    /// Each document appears once (its smallest recorded distance wins) and
    /// the result is ordered by ascending distance.
    ///
    /// Scoring: hits reported through a fully contained cell carry no
    /// coordinates, so they get distance `0.0` and score `1.0`. Other hits
    /// score `1 - distance / radius`, clamped to `[0, 1]`, or `0.0` when the
    /// radius is not positive.
    ///
    /// # Errors
    ///
    /// Propagates errors from fetching or traversing the BKD tree.
    pub fn find_matches(&self, reader: &dyn LexicalIndexReader) -> Result<Vec<Geo3dMatch>> {
        let Some(bkd) = reader.get_bkd_tree(&self.field)? else {
            return Ok(Vec::new());
        };

        let mut sphere = SphereVisitor::new(self.center, self.distance_m);
        bkd.intersect(&mut sphere)?;

        let limit = self.distance_m;
        let mut found: Vec<Geo3dMatch> = sphere
            .hits
            .into_iter()
            .map(|hit| {
                let meters = hit.distance_sq.sqrt();
                let score = match (hit.from_inside_cell, limit <= 0.0) {
                    (true, _) => 1.0,
                    (false, true) => 0.0,
                    (false, false) => (1.0 - meters / limit).clamp(0.0, 1.0) as f32,
                };
                Geo3dMatch {
                    doc_id: hit.doc_id,
                    distance_m: meters,
                    score,
                }
            })
            .collect();

        // Incomparable (NaN) distances are treated as equal.
        let by_distance = |a: &Geo3dMatch, b: &Geo3dMatch| {
            a.distance_m
                .partial_cmp(&b.distance_m)
                .unwrap_or(std::cmp::Ordering::Equal)
        };

        // Group by document with the closest entry first, so that the dedup
        // below keeps the smallest distance for each document.
        found.sort_by(|a, b| a.doc_id.cmp(&b.doc_id).then_with(|| by_distance(a, b)));
        found.dedup_by_key(|m| m.doc_id);

        // Stable sort: equal distances stay in ascending doc-id order.
        found.sort_by(by_distance);
        Ok(found)
    }
}
/// One document matched by a 3D geo query.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Geo3dMatch {
/// Identifier of the matched document.
pub doc_id: u64,
/// Distance recorded for the match. Queries in this file that do not measure
/// a distance for a hit store `0.0` here.
pub distance_m: f64,
/// Relevance score computed by the query, before any boost is applied.
pub score: f32,
}
/// BKD-tree visitor that collects documents whose point lies inside a sphere.
struct SphereVisitor {
/// Sphere centre as `[x, y, z]`.
center: [f64; 3],
/// Squared sphere radius; all comparisons are made on squared distances.
radius_sq: f64,
/// Hits in the order they were reported to the visitor.
hits: Vec<SphereHit>,
}
/// A single document recorded by `SphereVisitor`.
#[derive(Debug, Clone, Copy)]
struct SphereHit {
/// Identifier of the document.
doc_id: u64,
/// Squared distance from the sphere centre; `0.0` for hits recorded through
/// `visit_inside`, which receives no coordinates.
distance_sq: f64,
/// `true` when the hit was recorded through `visit_inside`.
from_inside_cell: bool,
}
impl SphereVisitor {
    /// Builds a visitor for the sphere around `center`.
    ///
    /// A negative `radius` is treated as zero, so only points exactly at the
    /// centre can match.
    fn new(center: GeoEcefPoint, radius: f64) -> Self {
        let effective_radius = f64::max(radius, 0.0);
        let origin = [center.x, center.y, center.z];
        Self {
            center: origin,
            radius_sq: effective_radius * effective_radius,
            hits: Vec::new(),
        }
    }
}
impl IntersectVisitor for SphereVisitor {
    /// Classifies `cell` against the sphere.
    ///
    /// The cell is `Outside` when its nearest point is farther than the
    /// radius, `Inside` when its farthest point is within the radius
    /// (boundary inclusive), and `Crosses` otherwise.
    fn compare(&self, cell: &AABB) -> CellRelation {
        debug_assert_eq!(cell.num_dims(), 3, "SphereVisitor expects a 3D BKD");

        let nearest_sq = cell.min_distance_sq_to_point(&self.center);
        if nearest_sq > self.radius_sq {
            CellRelation::Outside
        } else if cell.max_distance_sq_to_point(&self.center) <= self.radius_sq {
            CellRelation::Inside
        } else {
            CellRelation::Crosses
        }
    }

    /// Records a document reported without coordinates. No distance can be
    /// computed, so the hit is stored with a squared distance of `0.0` and
    /// flagged as coming from an inside cell.
    fn visit_inside(&mut self, doc_id: u64) {
        let hit = SphereHit {
            doc_id,
            distance_sq: 0.0,
            from_inside_cell: true,
        };
        self.hits.push(hit);
    }

    /// Records `doc_id` when `point` lies within the sphere (boundary
    /// inclusive), storing its exact squared distance.
    fn visit(&mut self, doc_id: u64, point: &[f64]) {
        debug_assert_eq!(point.len(), 3, "SphereVisitor expects a 3D BKD");

        let [cx, cy, cz] = self.center;
        let (dx, dy, dz) = (point[0] - cx, point[1] - cy, point[2] - cz);
        let distance_sq = dx * dx + dy * dy + dz * dz;

        // Written as a positive `<=` test so a NaN distance is never recorded.
        if distance_sq <= self.radius_sq {
            self.hits.push(SphereHit {
                doc_id,
                distance_sq,
                from_inside_cell: false,
            });
        }
    }
}
/// Cursor-based matcher over a precomputed list of geo matches.
#[derive(Debug)]
pub struct Geo3dMatcher {
/// Matches walked by the cursor. `skip_to` scans forward for the first entry
/// whose doc id is greater than or equal to the target.
matches: Vec<Geo3dMatch>,
/// Index of the current match; equal to `matches.len()` once exhausted.
cursor: usize,
}
impl Geo3dMatcher {
    /// Creates a matcher positioned on the first match.
    ///
    /// The matches are reordered by ascending doc id. `skip_to` advances to
    /// the first entry whose doc id is `>= target`, which is only meaningful
    /// when iteration is in doc-id order. The queries in this file hand over
    /// matches sorted by distance, so the order is normalised here rather
    /// than trusted. The sort is stable, so entries sharing a doc id keep
    /// their relative order.
    pub fn new(mut matches: Vec<Geo3dMatch>) -> Self {
        matches.sort_by_key(|m| m.doc_id);
        Self { matches, cursor: 0 }
    }
}
impl Matcher for Geo3dMatcher {
    /// Doc id under the cursor, or `u64::MAX` once the matcher is exhausted.
    fn doc_id(&self) -> u64 {
        self.matches
            .get(self.cursor)
            .map_or(u64::MAX, |current| current.doc_id)
    }

    /// Advances by one entry; returns `Ok(true)` while the cursor still
    /// points at a match. The cursor never moves past the end of the list.
    fn next(&mut self) -> Result<bool> {
        let end = self.matches.len();
        self.cursor = (self.cursor + 1).min(end);
        Ok(self.cursor < end)
    }

    /// Moves forward to the first entry at or after the cursor whose doc id
    /// is `>= target`. Returns `Ok(false)`, leaving the matcher exhausted,
    /// when there is none.
    fn skip_to(&mut self, target: u64) -> Result<bool> {
        let offset = self
            .matches
            .iter()
            .skip(self.cursor)
            .position(|candidate| candidate.doc_id >= target);

        match offset {
            Some(step) => {
                self.cursor += step;
                Ok(true)
            }
            None => {
                self.cursor = self.cursor.max(self.matches.len());
                Ok(false)
            }
        }
    }

    /// Cost estimate: the total number of matches held.
    fn cost(&self) -> u64 {
        let total = self.matches.len();
        total as u64
    }

    /// `true` once the cursor has moved past the last match.
    fn is_exhausted(&self) -> bool {
        self.matches.len() <= self.cursor
    }
}
/// Scorer that looks up a precomputed per-document score and applies a boost.
#[derive(Debug)]
pub struct Geo3dScorer {
/// Unboosted score for each matched document id.
doc_scores: HashMap<u64, f32>,
/// Multiplier applied to every score returned.
boost: f32,
}
impl Geo3dScorer {
    /// Builds a scorer from `matches`, indexing each score by doc id.
    ///
    /// When a doc id appears more than once, the later entry's score wins.
    pub fn new(matches: Vec<Geo3dMatch>, boost: f32) -> Self {
        let doc_scores: HashMap<u64, f32> = matches
            .into_iter()
            .map(|entry| (entry.doc_id, entry.score))
            .collect();
        Self { doc_scores, boost }
    }
}
impl Scorer for Geo3dScorer {
    /// Boosted score for `doc_id`. Unknown documents use a base score of
    /// `0.0`. Term frequency and field length are ignored.
    fn score(&self, doc_id: u64, _term_freq: f32, _field_length: Option<f32>) -> f32 {
        let base = match self.doc_scores.get(&doc_id) {
            Some(&stored) => stored,
            None => 0.0,
        };
        base * self.boost
    }

    /// Current boost factor.
    fn boost(&self) -> f32 {
        self.boost
    }

    /// Replaces the boost factor.
    fn set_boost(&mut self, boost: f32) {
        self.boost = boost;
    }

    /// Largest stored score (never below `0.0`) multiplied by the boost.
    fn max_score(&self) -> f32 {
        let mut best = 0.0_f32;
        for &stored in self.doc_scores.values() {
            best = f32::max(best, stored);
        }
        best * self.boost
    }

    /// Static scorer name.
    fn name(&self) -> &'static str {
        "Geo3dScorer"
    }
}
impl Query for Geo3dDistanceQuery {
    /// Runs the search and wraps the matches in a `Geo3dMatcher`.
    fn matcher(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Matcher>> {
        let matches = self.find_matches(reader)?;
        Ok(Box::new(Geo3dMatcher::new(matches)))
    }

    /// Runs the search and wraps the matches, with this query's boost, in a
    /// `Geo3dScorer`.
    fn scorer(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Scorer>> {
        let matches = self.find_matches(reader)?;
        let scorer = Geo3dScorer::new(matches, self.boost);
        Ok(Box::new(scorer))
    }

    /// Current boost factor.
    fn boost(&self) -> f32 {
        self.boost
    }

    /// Replaces the boost factor.
    fn set_boost(&mut self, boost: f32) {
        self.boost = boost;
    }

    /// Boxed clone of this query.
    fn clone_box(&self) -> Box<dyn Query> {
        let copy: Self = self.clone();
        Box::new(copy)
    }

    /// Human-readable summary of the query parameters.
    fn description(&self) -> String {
        let Self {
            field,
            center,
            distance_m,
            ..
        } = self;
        format!(
            "Geo3dDistanceQuery(field: {}, center: {:?}, distance: {}m)",
            field, center, distance_m
        )
    }

    /// Reports the query as empty when the radius is zero or negative.
    fn is_empty(&self, _reader: &dyn LexicalIndexReader) -> Result<bool> {
        let non_positive_radius = self.distance_m <= 0.0;
        Ok(non_positive_radius)
    }

    /// Cost estimate: twice the reader's document count, saturating.
    fn cost(&self, reader: &dyn LexicalIndexReader) -> Result<u64> {
        Ok(reader.doc_count().saturating_mul(2))
    }

    /// Exposes the query for downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Query matching documents whose indexed 3D point lies inside an
/// axis-aligned box, bounds inclusive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Geo3dBoundingBoxQuery {
/// Name of the field whose BKD tree is searched.
field: String,
/// Corner holding the lower bound on each axis.
min: GeoEcefPoint,
/// Corner holding the upper bound on each axis.
max: GeoEcefPoint,
/// Multiplier applied to scores by the scorer built from this query.
boost: f32,
}
impl Geo3dBoundingBoxQuery {
pub fn new<F: Into<String>>(field: F, min: GeoEcefPoint, max: GeoEcefPoint) -> Result<Self> {
if min.x > max.x {
return Err(crate::error::LaurusError::other(format!(
"Geo3dBoundingBoxQuery: min.x ({}) must be <= max.x ({})",
min.x, max.x
)));
}
if min.y > max.y {
return Err(crate::error::LaurusError::other(format!(
"Geo3dBoundingBoxQuery: min.y ({}) must be <= max.y ({})",
min.y, max.y
)));
}
if min.z > max.z {
return Err(crate::error::LaurusError::other(format!(
"Geo3dBoundingBoxQuery: min.z ({}) must be <= max.z ({})",
min.z, max.z
)));
}
Ok(Self {
field: field.into(),
min,
max,
boost: 1.0,
})
}
pub fn with_boost(mut self, boost: f32) -> Self {
self.boost = boost;
self
}
pub fn field(&self) -> &str {
&self.field
}
pub fn min(&self) -> GeoEcefPoint {
self.min
}
pub fn max(&self) -> GeoEcefPoint {
self.max
}
pub fn find_matches(&self, reader: &dyn LexicalIndexReader) -> Result<Vec<Geo3dMatch>> {
let Some(bkd) = reader.get_bkd_tree(&self.field)? else {
return Ok(Vec::new());
};
let mins = [Some(self.min.x), Some(self.min.y), Some(self.min.z)];
let maxs = [Some(self.max.x), Some(self.max.y), Some(self.max.z)];
let doc_ids = bkd.range_search(&mins, &maxs, true, true)?;
Ok(doc_ids
.into_iter()
.map(|doc_id| Geo3dMatch {
doc_id,
distance_m: 0.0,
score: 1.0,
})
.collect())
}
}
impl Query for Geo3dBoundingBoxQuery {
    /// Runs the search and wraps the matches in a `Geo3dMatcher`.
    fn matcher(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Matcher>> {
        let matches = self.find_matches(reader)?;
        Ok(Box::new(Geo3dMatcher::new(matches)))
    }

    /// Runs the search and wraps the matches, with this query's boost, in a
    /// `Geo3dScorer`.
    fn scorer(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Scorer>> {
        let matches = self.find_matches(reader)?;
        let scorer = Geo3dScorer::new(matches, self.boost);
        Ok(Box::new(scorer))
    }

    /// Current boost factor.
    fn boost(&self) -> f32 {
        self.boost
    }

    /// Replaces the boost factor.
    fn set_boost(&mut self, boost: f32) {
        self.boost = boost;
    }

    /// Boxed clone of this query.
    fn clone_box(&self) -> Box<dyn Query> {
        let copy: Self = self.clone();
        Box::new(copy)
    }

    /// Human-readable summary of the query parameters.
    fn description(&self) -> String {
        let Self {
            field, min, max, ..
        } = self;
        format!(
            "Geo3dBoundingBoxQuery(field: {}, min: {:?}, max: {:?})",
            field, min, max
        )
    }

    /// A bounding-box query is never reported as empty.
    fn is_empty(&self, _reader: &dyn LexicalIndexReader) -> Result<bool> {
        Ok(false)
    }

    /// Cost estimate: twice the reader's document count, saturating.
    fn cost(&self, reader: &dyn LexicalIndexReader) -> Result<u64> {
        Ok(reader.doc_count().saturating_mul(2))
    }

    /// Exposes the query for downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Query returning up to `k` documents whose indexed 3D point is closest to a
/// centre point, found by searching spheres of increasing radius.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Geo3dNearestQuery {
/// Name of the field whose BKD tree is searched.
field: String,
/// Point the neighbours are measured from.
center: GeoEcefPoint,
/// Maximum number of documents to return.
k: usize,
/// Radius of the first search pass (metres, going by the `_m` suffix).
initial_radius_m: f64,
/// Radius at or beyond which the search stops expanding.
max_radius_m: f64,
/// Multiplier applied to scores by the scorer built from this query.
boost: f32,
}
impl Geo3dNearestQuery {
    /// Radius of the first search pass unless overridden.
    pub const DEFAULT_INITIAL_RADIUS_M: f64 = 1_000.0;
    /// Radius at which the search stops expanding unless overridden.
    pub const DEFAULT_MAX_RADIUS_M: f64 = 1.0e10;

    /// Creates a nearest-neighbour query on `field` for the `k` documents
    /// closest to `center`, using the default radii and a boost of `1.0`.
    pub fn new<F: Into<String>>(field: F, center: GeoEcefPoint, k: usize) -> Self {
        Self {
            field: field.into(),
            center,
            k,
            initial_radius_m: Self::DEFAULT_INITIAL_RADIUS_M,
            max_radius_m: Self::DEFAULT_MAX_RADIUS_M,
            boost: 1.0,
        }
    }

    /// Sets the radius of the first search pass; negative values become `0.0`.
    pub fn with_initial_radius(mut self, radius_m: f64) -> Self {
        self.initial_radius_m = radius_m.max(0.0);
        self
    }

    /// Sets the largest radius the search may use; negative values become `0.0`.
    pub fn with_max_radius(mut self, radius_m: f64) -> Self {
        self.max_radius_m = radius_m.max(0.0);
        self
    }

    /// Returns the query with its boost replaced by `boost`.
    pub fn with_boost(mut self, boost: f32) -> Self {
        self.boost = boost;
        self
    }

    /// Name of the field this query searches.
    pub fn field(&self) -> &str {
        &self.field
    }

    /// Point the neighbours are measured from.
    pub fn center(&self) -> GeoEcefPoint {
        self.center
    }

    /// Maximum number of documents returned.
    pub fn k(&self) -> usize {
        self.k
    }

    /// Configured radius of the first search pass.
    pub fn initial_radius_m(&self) -> f64 {
        self.initial_radius_m
    }

    /// Configured upper bound on the search radius.
    pub fn max_radius_m(&self) -> f64 {
        self.max_radius_m
    }

    /// Finds up to `k` documents nearest to the centre, ordered by ascending
    /// distance.
    ///
    /// The search starts at the initial radius, capped at the maximum radius,
    /// and doubles it until at least `k` distinct documents are found or the
    /// maximum radius has been searched. A radius of zero grows to
    /// `min(1.0, max_radius_m)` first. Each document is counted once, using
    /// its smallest recorded distance.
    ///
    /// Returns an empty list when `k` is zero or the reader has no BKD tree
    /// for the field.
    ///
    /// Scores are relative to the farthest returned hit: a hit scores
    /// `1 - distance / max_distance`, so the farthest hit scores `0.0`. When
    /// every returned hit is at distance zero, all of them score `1.0`.
    ///
    /// # Errors
    ///
    /// Propagates errors from fetching or traversing the BKD tree.
    pub fn find_matches(&self, reader: &dyn LexicalIndexReader) -> Result<Vec<Geo3dMatch>> {
        if self.k == 0 {
            return Ok(Vec::new());
        }
        let Some(bkd) = reader.get_bkd_tree(&self.field)? else {
            return Ok(Vec::new());
        };

        // Cap the first pass too. Otherwise an initial radius larger than the
        // maximum would return hits from beyond the configured limit.
        let mut radius = self.initial_radius_m.max(0.0).min(self.max_radius_m);

        let mut hits = loop {
            let mut pass = NearestVisitor::new(self.center, radius);
            bkd.intersect(&mut pass)?;

            // Keep one entry per document. Sorting by (doc id, distance)
            // first makes the dedup retain the closest one.
            let mut unique = pass.hits;
            unique.sort_by(|a, b| {
                a.doc_id.cmp(&b.doc_id).then_with(|| {
                    a.distance_sq
                        .partial_cmp(&b.distance_sq)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
            });
            unique.dedup_by_key(|h| h.doc_id);

            if unique.len() >= self.k || radius >= self.max_radius_m {
                break unique;
            }

            let next = if radius == 0.0 {
                // Doubling zero would never grow; jump to one metre instead.
                self.max_radius_m.min(1.0)
            } else {
                (radius * 2.0).min(self.max_radius_m)
            };
            if next == radius {
                // The radius can no longer grow (e.g. it has saturated).
                break unique;
            }
            radius = next;
        };

        // Stable sort: equal distances stay in ascending doc-id order.
        hits.sort_by(|a, b| {
            a.distance_sq
                .partial_cmp(&b.distance_sq)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        hits.truncate(self.k);

        let max_distance = hits.last().map_or(0.0, |h| h.distance_sq.sqrt());
        Ok(hits
            .into_iter()
            .map(|h| {
                let distance = h.distance_sq.sqrt();
                let score = if max_distance == 0.0 {
                    1.0
                } else {
                    (1.0 - distance / max_distance).clamp(0.0, 1.0) as f32
                };
                Geo3dMatch {
                    doc_id: h.doc_id,
                    distance_m: distance,
                    score,
                }
            })
            .collect())
    }
}
/// BKD-tree visitor that records the exact squared distance of every point
/// found within a search radius.
struct NearestVisitor {
/// Search centre as `[x, y, z]`.
center: [f64; 3],
/// Squared search radius; all comparisons are made on squared distances.
radius_sq: f64,
/// Hits in the order they were reported to the visitor.
hits: Vec<NearestHit>,
}
/// A single document recorded by `NearestVisitor`.
#[derive(Debug, Clone, Copy)]
struct NearestHit {
/// Identifier of the document.
doc_id: u64,
/// Squared distance from the search centre to the document's point.
distance_sq: f64,
}
impl NearestVisitor {
    /// Builds a visitor for the sphere of radius `radius_m` around `center`.
    ///
    /// A negative radius is treated as zero.
    fn new(center: GeoEcefPoint, radius_m: f64) -> Self {
        let effective_radius = f64::max(radius_m, 0.0);
        let origin = [center.x, center.y, center.z];
        Self {
            center: origin,
            radius_sq: effective_radius * effective_radius,
            hits: Vec::new(),
        }
    }
}
impl IntersectVisitor for NearestVisitor {
    /// Classifies `cell` against the search sphere.
    ///
    /// The cell is `Outside` when its nearest point is farther than the
    /// radius and `Crosses` otherwise. `Inside` is never returned,
    /// presumably so that each point is delivered to `visit` with its
    /// coordinates and an exact distance can be recorded.
    fn compare(&self, cell: &AABB) -> CellRelation {
        debug_assert_eq!(cell.num_dims(), 3, "NearestVisitor expects a 3D BKD");

        let nearest_sq = cell.min_distance_sq_to_point(&self.center);
        if nearest_sq > self.radius_sq {
            return CellRelation::Outside;
        }
        CellRelation::Crosses
    }

    /// Intentionally a no-op: documents reported without coordinates are not
    /// recorded.
    fn visit_inside(&mut self, _doc_id: u64) {}

    /// Records `doc_id` with its exact squared distance when `point` lies
    /// within the search radius (boundary inclusive).
    fn visit(&mut self, doc_id: u64, point: &[f64]) {
        debug_assert_eq!(point.len(), 3, "NearestVisitor expects a 3D BKD");

        let [cx, cy, cz] = self.center;
        let (dx, dy, dz) = (point[0] - cx, point[1] - cy, point[2] - cz);
        let distance_sq = dx * dx + dy * dy + dz * dz;

        // Written as a positive `<=` test so a NaN distance is never recorded.
        if distance_sq <= self.radius_sq {
            self.hits.push(NearestHit {
                doc_id,
                distance_sq,
            });
        }
    }
}
impl Query for Geo3dNearestQuery {
    /// Runs the search and wraps the matches in a `Geo3dMatcher`.
    fn matcher(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Matcher>> {
        let matches = self.find_matches(reader)?;
        Ok(Box::new(Geo3dMatcher::new(matches)))
    }

    /// Runs the search and wraps the matches, with this query's boost, in a
    /// `Geo3dScorer`.
    fn scorer(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Scorer>> {
        let matches = self.find_matches(reader)?;
        let scorer = Geo3dScorer::new(matches, self.boost);
        Ok(Box::new(scorer))
    }

    /// Current boost factor.
    fn boost(&self) -> f32 {
        self.boost
    }

    /// Replaces the boost factor.
    fn set_boost(&mut self, boost: f32) {
        self.boost = boost;
    }

    /// Boxed clone of this query.
    fn clone_box(&self) -> Box<dyn Query> {
        let copy: Self = self.clone();
        Box::new(copy)
    }

    /// Human-readable summary of the query parameters.
    fn description(&self) -> String {
        let Self {
            field,
            center,
            k,
            initial_radius_m,
            max_radius_m,
            ..
        } = self;
        format!(
            "Geo3dNearestQuery(field: {}, center: {:?}, k: {}, initial_radius: {}m, max_radius: {}m)",
            field, center, k, initial_radius_m, max_radius_m
        )
    }

    /// Reports the query as empty when no neighbours are requested (`k == 0`).
    fn is_empty(&self, _reader: &dyn LexicalIndexReader) -> Result<bool> {
        let nothing_requested = self.k == 0;
        Ok(nothing_requested)
    }

    /// Cost estimate: four times the reader's document count, saturating.
    fn cost(&self, reader: &dyn LexicalIndexReader) -> Result<u64> {
        Ok(reader.doc_count().saturating_mul(4))
    }

    /// Exposes the query for downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the visitors and for the query builders/accessors.
    //! None of them needs an index reader.

    use super::*;

    /// Builds a 3D cell from its min and max corners.
    fn cell(min: [f64; 3], max: [f64; 3]) -> AABB {
        AABB::new(min.to_vec(), max.to_vec()).unwrap()
    }

    /// Builds a sphere visitor centred at `(cx, cy, cz)`.
    fn visitor(cx: f64, cy: f64, cz: f64, radius: f64) -> SphereVisitor {
        SphereVisitor::new(GeoEcefPoint::new(cx, cy, cz), radius)
    }

    #[test]
    fn sphere_visitor_compare_outside() {
        let v = visitor(0.0, 0.0, 0.0, 5.0);
        let c = cell([100.0, 0.0, 0.0], [110.0, 10.0, 10.0]);
        assert_eq!(v.compare(&c), CellRelation::Outside);
    }

    #[test]
    fn sphere_visitor_compare_inside() {
        let v = visitor(0.0, 0.0, 0.0, 100.0);
        let c = cell([-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]);
        assert_eq!(v.compare(&c), CellRelation::Inside);
    }

    #[test]
    fn sphere_visitor_compare_crosses_at_boundary() {
        let v = visitor(0.0, 0.0, 0.0, 5.0);
        let c = cell([0.0, 0.0, 0.0], [10.0, 0.0, 0.0]);
        assert_eq!(v.compare(&c), CellRelation::Crosses);
    }

    #[test]
    fn sphere_visitor_visit_filters_by_radius() {
        let mut v = visitor(0.0, 0.0, 0.0, 5.0);
        // Distance 3: inside the sphere.
        v.visit(1, &[1.0, 2.0, 2.0]);
        // Distance 5: exactly on the boundary, which is inclusive.
        v.visit(2, &[5.0, 0.0, 0.0]);
        // Distance 10: outside the sphere.
        v.visit(3, &[10.0, 0.0, 0.0]);
        assert_eq!(v.hits.len(), 2);
        assert_eq!(v.hits[0].doc_id, 1);
        assert!(!v.hits[0].from_inside_cell);
        assert_eq!(v.hits[1].doc_id, 2);
    }

    #[test]
    fn sphere_visitor_visit_inside_marks_hit_uniformly() {
        let mut v = visitor(0.0, 0.0, 0.0, 100.0);
        v.visit_inside(42);
        assert_eq!(v.hits.len(), 1);
        assert_eq!(v.hits[0].doc_id, 42);
        assert_eq!(v.hits[0].distance_sq, 0.0);
        assert!(v.hits[0].from_inside_cell);
    }

    /// A negative radius is clamped to zero, so a point exactly at the centre
    /// still matches while anything farther away does not.
    #[test]
    fn sphere_visitor_negative_radius_clamps_to_zero() {
        let mut v = visitor(0.0, 0.0, 0.0, -1.0);
        v.visit(1, &[0.0, 0.0, 0.0]);
        v.visit(2, &[1.0, 0.0, 0.0]);
        assert_eq!(v.hits.len(), 1);
        assert_eq!(v.hits[0].doc_id, 1);
    }

    #[test]
    fn geo3d_distance_query_basics() {
        let q = Geo3dDistanceQuery::new("position", GeoEcefPoint::new(1.0, 2.0, 3.0), 500.0)
            .with_boost(2.0);
        assert_eq!(q.field(), "position");
        assert_eq!(q.center(), GeoEcefPoint::new(1.0, 2.0, 3.0));
        assert_eq!(q.distance_m(), 500.0);
        assert_eq!(q.boost(), 2.0);
        let cloned = q.clone_box();
        assert!(cloned.description().contains("Geo3dDistanceQuery"));
    }

    #[test]
    fn geo3d_bbox_query_basics() {
        let q = Geo3dBoundingBoxQuery::new(
            "position",
            GeoEcefPoint::new(0.0, 0.0, 0.0),
            GeoEcefPoint::new(10.0, 20.0, 30.0),
        )
        .unwrap()
        .with_boost(3.0);
        assert_eq!(q.field(), "position");
        assert_eq!(q.min(), GeoEcefPoint::new(0.0, 0.0, 0.0));
        assert_eq!(q.max(), GeoEcefPoint::new(10.0, 20.0, 30.0));
        assert_eq!(q.boost(), 3.0);
        let cloned = q.clone_box();
        assert!(cloned.description().contains("Geo3dBoundingBoxQuery"));
    }

    #[test]
    fn geo3d_bbox_query_accepts_degenerate_box() {
        let q = Geo3dBoundingBoxQuery::new(
            "position",
            GeoEcefPoint::new(5.0, 5.0, 5.0),
            GeoEcefPoint::new(5.0, 5.0, 5.0),
        );
        assert!(q.is_ok(), "degenerate (zero-volume) box must be accepted");
    }

    #[test]
    fn geo3d_bbox_query_accepts_zero_volume_axis() {
        let q = Geo3dBoundingBoxQuery::new(
            "position",
            GeoEcefPoint::new(0.0, 0.0, 5.0),
            GeoEcefPoint::new(10.0, 10.0, 5.0),
        );
        assert!(
            q.is_ok(),
            "box with zero extent on one axis must be accepted"
        );
    }

    #[test]
    fn geo3d_nearest_query_basics() {
        let q = Geo3dNearestQuery::new("position", GeoEcefPoint::new(1.0, 2.0, 3.0), 5)
            .with_initial_radius(100.0)
            .with_max_radius(1_000_000.0)
            .with_boost(2.5);
        assert_eq!(q.field(), "position");
        assert_eq!(q.center(), GeoEcefPoint::new(1.0, 2.0, 3.0));
        assert_eq!(q.k(), 5);
        assert_eq!(q.initial_radius_m(), 100.0);
        assert_eq!(q.max_radius_m(), 1_000_000.0);
        assert_eq!(q.boost(), 2.5);
        let cloned = q.clone_box();
        assert!(cloned.description().contains("Geo3dNearestQuery"));
    }

    #[test]
    fn nearest_visitor_compare_outside_when_cell_too_far() {
        let v = NearestVisitor::new(GeoEcefPoint::new(0.0, 0.0, 0.0), 5.0);
        let c = cell([100.0, 100.0, 100.0], [110.0, 110.0, 110.0]);
        assert_eq!(v.compare(&c), CellRelation::Outside);
    }

    #[test]
    fn nearest_visitor_never_returns_inside() {
        let v = NearestVisitor::new(GeoEcefPoint::new(0.0, 0.0, 0.0), 10_000.0);
        let c = cell([-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]);
        assert_eq!(v.compare(&c), CellRelation::Crosses);
    }

    #[test]
    fn nearest_visitor_visit_records_exact_distance() {
        let mut v = NearestVisitor::new(GeoEcefPoint::new(0.0, 0.0, 0.0), 100.0);
        // Squared distance 25.
        v.visit(1, &[3.0, 4.0, 0.0]);
        // Squared distance 36.
        v.visit(2, &[0.0, 0.0, 6.0]);
        // Distance 200: outside the 100 radius, so not recorded.
        v.visit(3, &[200.0, 0.0, 0.0]);
        assert_eq!(v.hits.len(), 2);
        assert_eq!(v.hits[0].doc_id, 1);
        assert_eq!(v.hits[0].distance_sq, 25.0);
        assert_eq!(v.hits[1].doc_id, 2);
        assert_eq!(v.hits[1].distance_sq, 36.0);
    }

    #[test]
    fn geo3d_bbox_query_rejects_inverted_box() {
        let err = Geo3dBoundingBoxQuery::new(
            "position",
            GeoEcefPoint::new(10.0, 0.0, 0.0),
            GeoEcefPoint::new(5.0, 10.0, 10.0),
        )
        .unwrap_err();
        assert!(format!("{err:?}").contains("min.x"));

        let err_y = Geo3dBoundingBoxQuery::new(
            "position",
            GeoEcefPoint::new(0.0, 10.0, 0.0),
            GeoEcefPoint::new(10.0, 5.0, 10.0),
        )
        .unwrap_err();
        assert!(format!("{err_y:?}").contains("min.y"));

        let err_z = Geo3dBoundingBoxQuery::new(
            "position",
            GeoEcefPoint::new(0.0, 0.0, 10.0),
            GeoEcefPoint::new(10.0, 10.0, 5.0),
        )
        .unwrap_err();
        assert!(format!("{err_z:?}").contains("min.z"));
    }
}