use super::aabb::AABB;
/// How an axis-aligned cell relates to a query region, as reported by
/// [`IntersectVisitor::compare`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CellRelation {
    /// Every point of the cell satisfies the query.
    Inside,
    /// No point of the cell can satisfy the query.
    Outside,
    /// The cell straddles the query boundary; its points must be checked individually.
    Crosses,
}
/// Callbacks for an intersection query over a tree of axis-aligned cells.
///
/// NOTE(review): the traversal that drives this trait is not in this file;
/// presumably it calls `compare` on each cell, prunes `Outside` cells, calls
/// `visit_inside` for every doc in an `Inside` cell and `visit` for every
/// point in a `Crosses` cell — confirm against the tree implementation.
pub trait IntersectVisitor {
/// Classifies `cell` relative to this visitor's query region.
fn compare(&self, cell: &AABB) -> CellRelation;
/// Accepts `doc_id` without being given its coordinates (no per-point filtering possible).
fn visit_inside(&mut self, doc_id: u64);
/// Offers `doc_id` together with its coordinates `point`; the implementation decides whether to accept it.
fn visit(&mut self, doc_id: u64, point: &[f64]);
}
/// An [`IntersectVisitor`] that collects the doc ids of points falling inside an
/// axis-aligned query box.
///
/// Each dimension `d` is bounded by `mins[d]` and `maxs[d]`; an unbounded side is
/// stored as `f64::NEG_INFINITY` / `f64::INFINITY`. The `include_min` /
/// `include_max` flags choose inclusive (`>=` / `<=`) or exclusive (`>` / `<`)
/// comparisons and apply to every dimension alike.
pub struct RangeQueryVisitor {
    /// Per-dimension lower bounds (`NEG_INFINITY` when unbounded).
    mins: Vec<f64>,
    /// Per-dimension upper bounds (`INFINITY` when unbounded).
    maxs: Vec<f64>,
    /// Whether a value equal to its lower bound matches.
    include_min: bool,
    /// Whether a value equal to its upper bound matches.
    include_max: bool,
    /// Doc ids accepted so far, in the order they were accepted.
    hits: Vec<u64>,
}

impl RangeQueryVisitor {
    /// Builds a visitor for the box described by `mins` / `maxs`.
    ///
    /// `None` in either slice leaves that side of the dimension unbounded.
    /// `include_min` / `include_max` select inclusive or exclusive comparison for
    /// every dimension's lower / upper bound respectively. Because an unbounded
    /// side is represented by an infinity, an exclusive flag still rejects a value
    /// of exactly `-inf` (lower) or `+inf` (upper) in an unbounded dimension.
    ///
    /// # Panics
    ///
    /// Panics if `mins` and `maxs` have different lengths, since every dimension
    /// needs both a lower and an upper bound.
    pub fn new(
        mins: &[Option<f64>],
        maxs: &[Option<f64>],
        include_min: bool,
        include_max: bool,
    ) -> Self {
        // Fail fast: a length mismatch would otherwise surface later as an
        // out-of-bounds panic or as dimensions being silently ignored.
        assert_eq!(
            mins.len(),
            maxs.len(),
            "RangeQueryVisitor: mins and maxs must have the same number of dimensions"
        );
        RangeQueryVisitor {
            mins: mins.iter().map(|m| m.unwrap_or(f64::NEG_INFINITY)).collect(),
            maxs: maxs.iter().map(|m| m.unwrap_or(f64::INFINITY)).collect(),
            include_min,
            include_max,
            hits: Vec::new(),
        }
    }

    /// Consumes the visitor and returns the accepted doc ids in acceptance order.
    pub fn into_hits(self) -> Vec<u64> {
        self.hits
    }

    /// Returns `true` if every coordinate of `point` lies within its dimension's bounds.
    ///
    /// `point` is expected to have exactly one coordinate per query dimension. This
    /// is checked in debug builds; an over-long point panics in every build.
    fn point_matches(&self, point: &[f64]) -> bool {
        debug_assert_eq!(point.len(), self.mins.len());
        let n = point.len();
        // Slicing up front panics if `point` has more dimensions than the query
        // (as direct indexing did) and lets the zipped scan skip per-element checks.
        let (mins, maxs) = (&self.mins[..n], &self.maxs[..n]);
        point
            .iter()
            .zip(mins)
            .zip(maxs)
            .all(|((&v, &lo), &hi)| {
                let lower_ok = if self.include_min { v >= lo } else { v > lo };
                let upper_ok = if self.include_max { v <= hi } else { v < hi };
                lower_ok && upper_ok
            })
    }
}
impl IntersectVisitor for RangeQueryVisitor {
    /// Classifies `cell` against the query box.
    ///
    /// * [`CellRelation::Outside`]: in some dimension the cell lies wholly below the
    ///   lower bound or wholly above the upper bound, so none of its points can match.
    /// * [`CellRelation::Inside`]: in every dimension the cell's extent satisfies both
    ///   bounds, so every point in it matches.
    /// * [`CellRelation::Crosses`]: anything else; points must be checked one by one.
    fn compare(&self, cell: &AABB) -> CellRelation {
        debug_assert_eq!(cell.num_dims(), self.mins.len());
        let lo = cell.min();
        let hi = cell.max();
        let ndims = lo.len();

        // Disjoint when, in any dimension, the cell's upper corner fails the lower
        // bound or its lower corner fails the upper bound.
        let disjoint = (0..ndims).any(|d| {
            let below_min = if self.include_min {
                hi[d] < self.mins[d]
            } else {
                hi[d] <= self.mins[d]
            };
            if below_min {
                return true;
            }
            if self.include_max {
                lo[d] > self.maxs[d]
            } else {
                lo[d] >= self.maxs[d]
            }
        });
        if disjoint {
            return CellRelation::Outside;
        }

        // Contained when, in every dimension, the lower corner clears the lower bound
        // and the upper corner clears the upper bound.
        let contained = (0..ndims).all(|d| {
            let above_min = if self.include_min {
                lo[d] >= self.mins[d]
            } else {
                lo[d] > self.mins[d]
            };
            let below_max = if self.include_max {
                hi[d] <= self.maxs[d]
            } else {
                hi[d] < self.maxs[d]
            };
            above_min && below_max
        });

        if contained {
            CellRelation::Inside
        } else {
            CellRelation::Crosses
        }
    }

    /// Records `doc_id` without checking coordinates; presumably only called for
    /// docs in cells already classified as [`CellRelation::Inside`].
    fn visit_inside(&mut self, doc_id: u64) {
        self.hits.push(doc_id);
    }

    /// Records `doc_id` only if `point` satisfies every dimension's bounds.
    fn visit(&mut self, doc_id: u64, point: &[f64]) {
        if !self.point_matches(point) {
            return;
        }
        self.hits.push(doc_id);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an AABB from explicit corners, panicking if `AABB::new` rejects them.
    fn cell(min: Vec<f64>, max: Vec<f64>) -> AABB {
        AABB::new(min, max).unwrap()
    }

    #[test]
    fn range_query_visitor_inside_outside_crosses() {
        let v = RangeQueryVisitor::new(&[Some(10.0)], &[Some(20.0)], true, true);
        assert_eq!(
            v.compare(&cell(vec![12.0], vec![18.0])),
            CellRelation::Inside
        );
        assert_eq!(
            v.compare(&cell(vec![25.0], vec![30.0])),
            CellRelation::Outside
        );
        assert_eq!(
            v.compare(&cell(vec![5.0], vec![8.0])),
            CellRelation::Outside
        );
        assert_eq!(
            v.compare(&cell(vec![15.0], vec![25.0])),
            CellRelation::Crosses
        );
        // A degenerate cell sitting exactly on an inclusive bound is fully inside.
        assert_eq!(
            v.compare(&cell(vec![10.0], vec![10.0])),
            CellRelation::Inside
        );
    }

    #[test]
    fn range_query_visitor_exclusive_boundary_outside() {
        let v = RangeQueryVisitor::new(&[Some(10.0)], &[None], false, true);
        assert_eq!(
            v.compare(&cell(vec![5.0], vec![10.0])),
            CellRelation::Outside
        );
        assert_eq!(
            v.compare(&cell(vec![11.0], vec![20.0])),
            CellRelation::Inside
        );
        assert_eq!(
            v.compare(&cell(vec![10.0], vec![20.0])),
            CellRelation::Crosses
        );
    }

    #[test]
    fn range_query_visitor_exclusive_upper_boundary() {
        let v = RangeQueryVisitor::new(&[None], &[Some(20.0)], true, false);
        // A cell starting exactly at an exclusive upper bound cannot contain a match.
        assert_eq!(
            v.compare(&cell(vec![20.0], vec![30.0])),
            CellRelation::Outside
        );
        assert_eq!(
            v.compare(&cell(vec![0.0], vec![19.0])),
            CellRelation::Inside
        );
        assert_eq!(
            v.compare(&cell(vec![0.0], vec![20.0])),
            CellRelation::Crosses
        );
    }

    #[test]
    fn range_query_visitor_unbounded_dimensions() {
        let v = RangeQueryVisitor::new(&[Some(0.0), None], &[Some(100.0), None], true, true);
        assert_eq!(
            v.compare(&cell(vec![10.0, -1e9], vec![90.0, 1e9])),
            CellRelation::Inside
        );
        assert_eq!(
            v.compare(&cell(vec![-50.0, 0.0], vec![-10.0, 10.0])),
            CellRelation::Outside
        );
    }

    #[test]
    fn range_query_visitor_multi_dimension_relations() {
        let v = RangeQueryVisitor::new(
            &[Some(0.0), Some(0.0)],
            &[Some(10.0), Some(10.0)],
            true,
            true,
        );
        // Inside on dim 0 but straddling the upper bound on dim 1.
        assert_eq!(
            v.compare(&cell(vec![1.0, 5.0], vec![9.0, 15.0])),
            CellRelation::Crosses
        );
        // Disjoint on dim 1 alone is enough to be Outside.
        assert_eq!(
            v.compare(&cell(vec![1.0, 11.0], vec![9.0, 15.0])),
            CellRelation::Outside
        );
    }

    #[test]
    fn range_query_visitor_visit_filters_points_on_crosses() {
        let mut v = RangeQueryVisitor::new(&[Some(10.0)], &[Some(20.0)], true, true);
        v.visit(1, &[5.0]);
        v.visit(2, &[15.0]);
        v.visit(3, &[20.0]);
        v.visit(4, &[21.0]);
        let hits = v.into_hits();
        assert_eq!(hits, vec![2, 3]);
    }

    #[test]
    fn range_query_visitor_visit_respects_exclusive_bounds() {
        let mut v = RangeQueryVisitor::new(&[Some(10.0)], &[Some(20.0)], false, false);
        v.visit(1, &[10.0]);
        v.visit(2, &[15.0]);
        v.visit(3, &[20.0]);
        assert_eq!(v.into_hits(), vec![2]);
    }

    #[test]
    fn range_query_visitor_visit_inside_skips_filter() {
        let mut v = RangeQueryVisitor::new(&[Some(10.0)], &[Some(20.0)], true, true);
        v.visit_inside(7);
        v.visit_inside(8);
        let hits = v.into_hits();
        assert_eq!(hits, vec![7, 8]);
    }
}