use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::lexical::query::Query;
use crate::lexical::query::matcher::Matcher;
use crate::lexical::query::scorer::Scorer;
use crate::lexical::reader::LexicalIndexReader;
/// A geographic coordinate given as latitude and longitude in degrees.
///
/// [`GeoPoint::new`] range-checks both coordinates. The fields are public, so a
/// value built with a struct literal (or deserialized) bypasses that check.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Serialize,
    Deserialize,
    rkyv::Archive,
    rkyv::Serialize,
    rkyv::Deserialize,
)]
pub struct GeoPoint {
    /// Latitude in degrees.
    pub lat: f64,
    /// Longitude in degrees.
    pub lon: f64,
}
impl GeoPoint {
    /// Builds a point, rejecting coordinates that fall outside the valid range.
    ///
    /// # Errors
    ///
    /// Returns an error when `lat` is not within `-90..=90` or `lon` is not within
    /// `-180..=180`. NaN is contained in neither range, so it is rejected as well.
    pub fn new(lat: f64, lon: f64) -> Result<Self> {
        if !(-90.0..=90.0).contains(&lat) {
            return Err(crate::error::LaurusError::other(format!(
                "Invalid latitude: {lat} (must be between -90 and 90)"
            )));
        }
        if !(-180.0..=180.0).contains(&lon) {
            return Err(crate::error::LaurusError::other(format!(
                "Invalid longitude: {lon} (must be between -180 and 180)"
            )));
        }
        Ok(GeoPoint { lat, lon })
    }

    /// Great-circle distance to `other` in kilometres, computed with the haversine
    /// formula on a sphere of radius 6371 km.
    pub fn distance_to(&self, other: &GeoPoint) -> f64 {
        const EARTH_RADIUS_KM: f64 = 6371.0;
        let lat1_rad = self.lat.to_radians();
        let lat2_rad = other.lat.to_radians();
        let delta_lat = (other.lat - self.lat).to_radians();
        let delta_lon = (other.lon - self.lon).to_radians();
        // Rounding can push the haversine term marginally outside [0, 1] for
        // (near-)antipodal points; without the clamp `(1.0 - a).sqrt()` becomes NaN.
        let a = ((delta_lat / 2.0).sin().powi(2)
            + lat1_rad.cos() * lat2_rad.cos() * (delta_lon / 2.0).sin().powi(2))
        .clamp(0.0, 1.0);
        let c = 2.0 * a.sqrt().atan2((1.0 - a).sqrt());
        EARTH_RADIUS_KM * c
    }

    /// Initial bearing from `self` towards `other`, in degrees normalised to
    /// `[0, 360)`.
    pub fn bearing_to(&self, other: &GeoPoint) -> f64 {
        let lat1_rad = self.lat.to_radians();
        let lat2_rad = other.lat.to_radians();
        let delta_lon = (other.lon - self.lon).to_radians();
        let y = delta_lon.sin() * lat2_rad.cos();
        let x = lat1_rad.cos() * lat2_rad.sin() - lat1_rad.sin() * lat2_rad.cos() * delta_lon.cos();
        let bearing_rad = y.atan2(x);
        // `atan2` yields (-180, 180]; shift into the non-negative range.
        (bearing_rad.to_degrees() + 360.0) % 360.0
    }

    /// Returns `true` when the point lies inside the closed rectangle
    /// `[min_lat, max_lat] x [min_lon, max_lon]` (edges included).
    pub fn within_bounds(&self, min_lat: f64, max_lat: f64, min_lon: f64, max_lon: f64) -> bool {
        self.lat >= min_lat && self.lat <= max_lat && self.lon >= min_lon && self.lon <= max_lon
    }
}
/// An axis-aligned latitude/longitude rectangle described by two corners.
///
/// [`GeoBoundingBox::new`] enforces that `top_left` is not south or east of
/// `bottom_right`; the fields are public, so other construction paths skip that
/// check.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GeoBoundingBox {
    /// Corner holding the maximum latitude and minimum longitude.
    pub top_left: GeoPoint,
    /// Corner holding the minimum latitude and maximum longitude.
    pub bottom_right: GeoPoint,
}
impl GeoBoundingBox {
    /// Creates a box from its two corners.
    ///
    /// # Errors
    ///
    /// Fails when `top_left` has a smaller latitude than `bottom_right`, or a
    /// larger longitude. Equal values pass both checks, so degenerate (zero-width
    /// or zero-height) boxes are accepted.
    pub fn new(top_left: GeoPoint, bottom_right: GeoPoint) -> Result<Self> {
        if top_left.lat < bottom_right.lat {
            return Err(crate::error::LaurusError::other(
                "Top-left latitude must be greater than bottom-right latitude",
            ));
        }
        if top_left.lon > bottom_right.lon {
            return Err(crate::error::LaurusError::other(
                "Top-left longitude must be less than bottom-right longitude",
            ));
        }
        Ok(Self {
            top_left,
            bottom_right,
        })
    }

    /// Returns `true` when `point` lies inside the box, edges included.
    pub fn contains(&self, point: &GeoPoint) -> bool {
        let lat_span = self.bottom_right.lat..=self.top_left.lat;
        let lon_span = self.top_left.lon..=self.bottom_right.lon;
        lat_span.contains(&point.lat) && lon_span.contains(&point.lon)
    }

    /// Midpoint of the box in coordinate space.
    ///
    /// # Panics
    ///
    /// Panics if the clamped midpoint is still rejected by [`GeoPoint::new`],
    /// which happens when a corner coordinate is NaN.
    pub fn center(&self) -> GeoPoint {
        let mid_lat = (self.top_left.lat + self.bottom_right.lat) / 2.0;
        let mid_lon = (self.top_left.lon + self.bottom_right.lon) / 2.0;
        GeoPoint::new(mid_lat.clamp(-90.0, 90.0), mid_lon.clamp(-180.0, 180.0))
            .expect("clamped center must be valid")
    }

    /// Returns `(width, height)` of the box in degrees of longitude and latitude.
    pub fn dimensions(&self) -> (f64, f64) {
        (
            self.bottom_right.lon - self.top_left.lon,
            self.top_left.lat - self.bottom_right.lat,
        )
    }

    /// Largest great-circle distance, in kilometres, from the center to any of
    /// the four corners.
    pub fn max_distance_from_center(&self) -> f64 {
        let center = self.center();
        // The two corners that are not stored are rebuilt from the stored ones;
        // if validation rejects them, a stored corner is used in their place.
        let north_east =
            GeoPoint::new(self.top_left.lat, self.bottom_right.lon).unwrap_or(self.top_left);
        let south_west =
            GeoPoint::new(self.bottom_right.lat, self.top_left.lon).unwrap_or(self.bottom_right);

        let mut farthest = 0.0_f64;
        for corner in [self.top_left, self.bottom_right, north_east, south_west] {
            farthest = f64::max(farthest, center.distance_to(&corner));
        }
        farthest
    }
}
/// Query matching documents whose geo field lies within a radius of a center
/// point.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeoDistanceQuery {
    /// Name of the document field holding the geo value.
    field: String,
    /// Center of the search circle.
    center: GeoPoint,
    /// Search radius in kilometres.
    distance_km: f64,
    /// Multiplier applied to scores produced for this query.
    boost: f32,
}
impl GeoDistanceQuery {
pub fn new<F: Into<String>>(field: F, center: GeoPoint, distance_km: f64) -> Self {
GeoDistanceQuery {
field: field.into(),
center,
distance_km,
boost: 1.0,
}
}
pub fn with_boost(mut self, boost: f32) -> Self {
self.boost = boost;
self
}
pub fn field(&self) -> &str {
&self.field
}
pub fn center(&self) -> GeoPoint {
self.center
}
pub fn distance_km(&self) -> f64 {
self.distance_km
}
pub fn find_matches(&self, reader: &dyn LexicalIndexReader) -> Result<Vec<GeoMatch>> {
let mut matches = Vec::new();
let mut seen_docs = std::collections::HashSet::new();
let bounding_box = self.create_bounding_box();
let candidates = self.get_spatial_candidates(reader, &bounding_box)?;
for (doc_id, point) in candidates {
if seen_docs.contains(&doc_id) {
continue;
}
seen_docs.insert(doc_id);
let distance = self.center.distance_to(&point);
if distance <= self.distance_km {
let score = if distance == 0.0 {
1.0
} else {
(1.0 - (distance / self.distance_km)).max(0.0) as f32
};
matches.push(GeoMatch {
doc_id,
point,
distance_km: distance,
relevance_score: score,
});
}
}
matches.sort_by(|a, b| {
a.distance_km
.partial_cmp(&b.distance_km)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| {
b.relevance_score
.partial_cmp(&a.relevance_score)
.unwrap_or(std::cmp::Ordering::Equal)
})
});
Ok(matches)
}
fn create_bounding_box(&self) -> GeoBoundingBox {
let lat_deg_km = 111.0; let lon_deg_km = (111.0 * self.center.lat.to_radians().cos()).max(0.001);
let lat_delta = self.distance_km / lat_deg_km;
let lon_delta = self.distance_km / lon_deg_km;
let top_left = GeoPoint::new(
(self.center.lat + lat_delta).min(90.0),
(self.center.lon - lon_delta).max(-180.0),
)
.unwrap_or(self.center);
let bottom_right = GeoPoint::new(
(self.center.lat - lat_delta).max(-90.0),
(self.center.lon + lon_delta).min(180.0),
)
.unwrap_or(self.center);
GeoBoundingBox::new(top_left, bottom_right).unwrap_or_else(|_| {
let fallback_top_left = GeoPoint::new(
(self.center.lat + 0.01).min(90.0),
(self.center.lon - 0.01).max(-180.0),
)
.unwrap_or(self.center);
let fallback_bottom_right = GeoPoint::new(
(self.center.lat - 0.01).max(-90.0),
(self.center.lon + 0.01).min(180.0),
)
.unwrap_or(self.center);
GeoBoundingBox::new(fallback_top_left, fallback_bottom_right).unwrap_or(
GeoBoundingBox {
top_left: self.center,
bottom_right: self.center,
},
)
})
}
fn get_spatial_candidates(
&self,
reader: &dyn LexicalIndexReader,
bounding_box: &GeoBoundingBox,
) -> Result<Vec<(u64, GeoPoint)>> {
let mut candidates = Vec::new();
if let Some(bkd_tree) = reader.get_bkd_tree(&self.field)? {
let mins = [
Some(bounding_box.bottom_right.lat), Some(bounding_box.top_left.lon), ];
let maxs = [
Some(bounding_box.top_left.lat), Some(bounding_box.bottom_right.lon), ];
let doc_ids = bkd_tree.range_search(&mins, &maxs, true, true)?;
for doc_id in doc_ids {
if let Some(doc) = reader.document(doc_id)?
&& let Some(field_value) = doc.get_field(&self.field)
&& let Some((lat, lon)) = field_value.as_geo()
{
let geo_point = GeoPoint::new(lat, lon)?;
candidates.push((doc_id, geo_point));
}
}
return Ok(candidates);
}
let max_doc = reader.max_doc();
for doc_id in 0..max_doc {
if let Some(doc) = reader.document(doc_id)? {
if let Some(field_value) = doc.get_field(&self.field) {
if let Some((lat, lon)) = field_value.as_geo() {
let geo_point = GeoPoint::new(lat, lon)?;
if bounding_box.contains(&geo_point) {
let distance = self.center.distance_to(&geo_point);
if distance <= self.distance_km {
candidates.push((doc_id, geo_point));
}
}
}
}
}
}
Ok(candidates)
}
}
#[cfg(test)]
impl GeoDistanceQuery {
    /// Linear score: `1.0` at the center, `0.0` at or beyond the radius.
    fn calculate_distance_score(&self, distance_km: f64) -> f32 {
        if distance_km > self.distance_km {
            0.0
        } else {
            (1.0 - distance_km / self.distance_km) as f32
        }
    }

    /// Exponential-decay score plus small heuristic bonuses, capped at `1.0`.
    /// Distances beyond the radius score `0.0`.
    fn calculate_distance_score_enhanced(&self, distance_km: f64, point: &GeoPoint) -> f32 {
        if distance_km > self.distance_km {
            return 0.0;
        }
        let ratio = distance_km / self.distance_km;
        let decay = (-2.0 * ratio).exp() as f32;
        // Extra credit for hits closer than 0.1 km.
        let proximity = if distance_km < 0.1 { 0.1 } else { 0.0 };
        let geographic = self.calculate_geographic_relevance(point);
        let density = self.estimate_population_density(point) * 0.05;
        (decay + proximity + geographic + density).min(1.0)
    }

    /// Sum of three latitude/longitude based bonuses for `point`.
    fn calculate_geographic_relevance(&self, point: &GeoPoint) -> f32 {
        let bonus = |applies: bool, value: f32| if applies { value } else { 0.0 };
        let lat_abs = point.lat.abs();
        let temperate = bonus(lat_abs > 23.5 && lat_abs < 66.5, 0.05);
        let meridian = bonus(point.lon.abs() % 15.0 < 1.0, 0.02);
        let equator = bonus(lat_abs < 5.0, 0.03);
        temperate + meridian + equator
    }

    /// Heuristic density estimate in `[0, 1]` derived purely from coordinates.
    fn estimate_population_density(&self, point: &GeoPoint) -> f32 {
        let lat_abs = point.lat.abs();
        let lon_abs = point.lon.abs();
        let lat_density = (1.0 - (lat_abs / 90.0)) as f32;
        let lon_density = (1.0 - (lon_abs / 180.0)) as f32;
        let near_coast = lon_abs < 10.0 || (lat_abs - 40.0).abs() < 5.0;
        let coastal_bonus = if near_coast { 0.2 } else { 0.0 };
        ((lat_density + lon_density) / 2.0 + coastal_bonus).min(1.0)
    }
}
impl Query for GeoDistanceQuery {
    /// Runs the search and wraps the matches in a [`GeoMatcher`].
    fn matcher(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Matcher>> {
        let matches = self.find_matches(reader)?;
        Ok(Box::new(GeoMatcher::new(matches)))
    }

    /// Runs the search and wraps the per-document scores in a [`GeoScorer`]
    /// carrying this query's boost.
    fn scorer(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Scorer>> {
        let matches = self.find_matches(reader)?;
        Ok(Box::new(GeoScorer::new(matches, self.boost)))
    }

    fn boost(&self) -> f32 {
        self.boost
    }

    fn set_boost(&mut self, boost: f32) {
        self.boost = boost;
    }

    fn clone_box(&self) -> Box<dyn Query> {
        Box::new(self.clone())
    }

    fn description(&self) -> String {
        format!(
            "GeoDistanceQuery(field: {}, center: {:?}, distance: {}km)",
            self.field, self.center, self.distance_km
        )
    }

    /// A non-positive radius is reported as empty. A NaN radius is too, because
    /// `distance <= NaN` is never true and so nothing can match.
    fn is_empty(&self, _reader: &dyn LexicalIndexReader) -> Result<bool> {
        Ok(self.distance_km.is_nan() || self.distance_km <= 0.0)
    }

    /// Cost estimate: twice the reader's document count.
    fn cost(&self, reader: &dyn LexicalIndexReader) -> Result<u64> {
        // Widen straight to u64: the previous detour through u32 truncated counts
        // above u32::MAX. Saturate so the doubling cannot overflow.
        let doc_count = reader.doc_count() as u64;
        Ok(doc_count.saturating_mul(2))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Query matching documents whose geo field lies inside a bounding box.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeoBoundingBoxQuery {
    /// Name of the document field holding the geo value.
    field: String,
    /// Rectangle that matching points must fall into.
    bounding_box: GeoBoundingBox,
    /// Multiplier applied to scores produced for this query.
    boost: f32,
}
impl GeoBoundingBoxQuery {
    /// Creates a query for documents whose `field` lies inside `bounding_box`.
    /// The boost starts at `1.0`.
    pub fn new<F: Into<String>>(field: F, bounding_box: GeoBoundingBox) -> Self {
        GeoBoundingBoxQuery {
            field: field.into(),
            bounding_box,
            boost: 1.0,
        }
    }

    /// Returns the query with its boost replaced by `boost`.
    pub fn with_boost(mut self, boost: f32) -> Self {
        self.boost = boost;
        self
    }

    /// Name of the geo field this query searches.
    pub fn field(&self) -> &str {
        &self.field
    }

    /// The rectangle being searched.
    pub fn bounding_box(&self) -> &GeoBoundingBox {
        &self.bounding_box
    }

    /// Collects every document whose point lies inside the box.
    ///
    /// Each document id is reported at most once. The score is `1.0` at the box
    /// center and decreases linearly with distance, reaching `0.0` at the
    /// farthest corner. Results are ordered by descending score, ties broken by
    /// ascending distance from the center.
    ///
    /// # Errors
    ///
    /// Propagates reader errors and fails if a stored coordinate is rejected by
    /// [`GeoPoint::new`].
    pub fn find_matches(&self, reader: &dyn LexicalIndexReader) -> Result<Vec<GeoMatch>> {
        let candidates = self.get_candidates_in_bounds(reader)?;

        let mut seen_docs = std::collections::HashSet::with_capacity(candidates.len());
        let mut matches = Vec::new();
        // Center and farthest-corner distance are the same for every match, so
        // they are computed once, and only when the first match shows up.
        let mut scoring_frame: Option<(GeoPoint, f64)> = None;

        for (doc_id, point) in candidates {
            // `insert` returns false when the id was already recorded.
            if !seen_docs.insert(doc_id) || !self.bounding_box.contains(&point) {
                continue;
            }
            let (center, max_distance) = *scoring_frame.get_or_insert_with(|| {
                (
                    self.bounding_box.center(),
                    self.bounding_box.max_distance_from_center(),
                )
            });
            let distance = center.distance_to(&point);
            let relevance_score = if distance == 0.0 {
                1.0
            } else {
                ((max_distance - distance) / max_distance).max(0.0) as f32
            };
            matches.push(GeoMatch {
                doc_id,
                point,
                distance_km: distance,
                relevance_score,
            });
        }

        // `total_cmp` never panics and is a total order, unlike
        // `partial_cmp(..).unwrap()`, which aborts on a NaN distance.
        matches.sort_by(|a, b| {
            b.relevance_score
                .total_cmp(&a.relevance_score)
                .then_with(|| a.distance_km.total_cmp(&b.distance_km))
        });
        Ok(matches)
    }

    /// Gathers `(doc_id, point)` pairs that may fall inside the box.
    ///
    /// When the reader exposes a BKD tree for the field, the tree is queried with
    /// the box (dimension 0 = latitude, dimension 1 = longitude). Otherwise every
    /// document up to `max_doc` is scanned and filtered with
    /// [`GeoBoundingBox::contains`].
    fn get_candidates_in_bounds(
        &self,
        reader: &dyn LexicalIndexReader,
    ) -> Result<Vec<(u64, GeoPoint)>> {
        let mut candidates = Vec::new();

        if let Some(bkd_tree) = reader.get_bkd_tree(&self.field)? {
            let mins = [
                Some(self.bounding_box.bottom_right.lat),
                Some(self.bounding_box.top_left.lon),
            ];
            let maxs = [
                Some(self.bounding_box.top_left.lat),
                Some(self.bounding_box.bottom_right.lon),
            ];
            // NOTE(review): the two `true` flags presumably make both bounds
            // inclusive; confirm against `range_search`.
            let doc_ids = bkd_tree.range_search(&mins, &maxs, true, true)?;
            for doc_id in doc_ids {
                if let Some(doc) = reader.document(doc_id)?
                    && let Some(field_value) = doc.get_field(&self.field)
                    && let Some((lat, lon)) = field_value.as_geo()
                {
                    let geo_point = GeoPoint::new(lat, lon)?;
                    candidates.push((doc_id, geo_point));
                }
            }
            return Ok(candidates);
        }

        // No spatial index for this field: fall back to a full scan.
        let max_doc = reader.max_doc();
        for doc_id in 0..max_doc {
            if let Some(doc) = reader.document(doc_id)?
                && let Some(field_value) = doc.get_field(&self.field)
                && let Some((lat, lon)) = field_value.as_geo()
            {
                let geo_point = GeoPoint::new(lat, lon)?;
                if self.bounding_box.contains(&geo_point) {
                    candidates.push((doc_id, geo_point));
                }
            }
        }
        Ok(candidates)
    }
}
#[cfg(test)]
impl GeoBoundingBoxQuery {
    /// Produces synthetic candidates: a 20 x 20 grid spanning the box (ids from
    /// 2000) followed by 10 points on an ellipse 10% larger than the box (ids
    /// from 3000). Coordinates rejected by `GeoPoint::new` are skipped.
    fn generate_bounding_box_candidates(&self) -> Vec<(u64, GeoPoint)> {
        let bbox = &self.bounding_box;
        let (width, height) = bbox.dimensions();
        let mut candidates = Vec::new();

        let grid_size = 20;
        for row in 0..grid_size {
            let lat_ratio = row as f64 / (grid_size - 1) as f64;
            let lat = bbox.bottom_right.lat + lat_ratio * height;
            for col in 0..grid_size {
                let lon_ratio = col as f64 / (grid_size - 1) as f64;
                let lon = bbox.top_left.lon + lon_ratio * width;
                if let Ok(point) = GeoPoint::new(lat, lon) {
                    candidates.push(((row * grid_size + col + 2000) as u64, point));
                }
            }
        }

        let expansion_factor = 0.1;
        let expanded_width = width * (1.0 + expansion_factor);
        let expanded_height = height * (1.0 + expansion_factor);
        let center = bbox.center();
        for step in 0..10 {
            let angle = (step as f64 / 10.0) * 2.0 * std::f64::consts::PI;
            let lat = center.lat + angle.sin() * expanded_height / 2.0;
            let lon = center.lon + angle.cos() * expanded_width / 2.0;
            if let Ok(point) = GeoPoint::new(lat, lon) {
                candidates.push(((step + 3000) as u64, point));
            }
        }
        candidates
    }

    /// Score in `[.., 1.0]`: closeness to the center relative to the box
    /// diagonal, plus the edge and corner bonuses, capped at `1.0`.
    fn calculate_bounding_box_score(&self, point: &GeoPoint) -> f32 {
        let bbox = &self.bounding_box;
        let center = bbox.center();
        let (width, height) = bbox.dimensions();
        // Diagonal approximated with 111 km per degree on both axes.
        let diagonal_km = ((width * 111.0).powi(2) + (height * 111.0).powi(2)).sqrt();
        let normalized_distance = center.distance_to(point) / diagonal_km;
        let base_score = (1.0 - normalized_distance.min(1.0)) as f32;
        (base_score
            + self.calculate_edge_proximity_bonus(point)
            + self.calculate_corner_bonus(point))
        .min(1.0)
    }

    /// `0.05` for each axis on which `point` is within 10% of the box extent
    /// from an edge.
    fn calculate_edge_proximity_bonus(&self, point: &GeoPoint) -> f32 {
        let bbox = &self.bounding_box;
        let (width, height) = bbox.dimensions();
        let edge_threshold = 0.1;
        let bonus_if = |near: bool| -> f32 { if near { 0.05 } else { 0.0 } };

        let lat_gap = (point.lat - bbox.bottom_right.lat).min(bbox.top_left.lat - point.lat);
        let lon_gap = (point.lon - bbox.top_left.lon).min(bbox.bottom_right.lon - point.lon);

        bonus_if(lat_gap < height * edge_threshold) + bonus_if(lon_gap < width * edge_threshold)
    }

    /// `0.1` when `point` is less than 1 km from any corner of the box.
    fn calculate_corner_bonus(&self, point: &GeoPoint) -> f32 {
        let bbox = &self.bounding_box;
        let corners = [
            bbox.top_left,
            GeoPoint::new(bbox.top_left.lat, bbox.bottom_right.lon).unwrap(),
            bbox.bottom_right,
            GeoPoint::new(bbox.bottom_right.lat, bbox.top_left.lon).unwrap(),
        ];
        let corner_threshold_km = 1.0;
        let nearest_km = corners
            .iter()
            .map(|corner| point.distance_to(corner))
            .fold(f64::INFINITY, f64::min);
        if nearest_km < corner_threshold_km {
            0.1
        } else {
            0.0
        }
    }
}
impl Query for GeoBoundingBoxQuery {
    /// Runs the search and wraps the matches in a [`GeoMatcher`].
    fn matcher(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Matcher>> {
        let matches = self.find_matches(reader)?;
        Ok(Box::new(GeoMatcher::new(matches)))
    }

    /// Runs the search and wraps the per-document scores in a [`GeoScorer`]
    /// carrying this query's boost.
    fn scorer(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Scorer>> {
        let matches = self.find_matches(reader)?;
        Ok(Box::new(GeoScorer::new(matches, self.boost)))
    }

    fn boost(&self) -> f32 {
        self.boost
    }

    fn set_boost(&mut self, boost: f32) {
        self.boost = boost;
    }

    fn clone_box(&self) -> Box<dyn Query> {
        Box::new(self.clone())
    }

    fn description(&self) -> String {
        format!(
            "GeoBoundingBoxQuery(field: {}, bounds: {:?})",
            self.field, self.bounding_box
        )
    }

    /// A box with non-positive width or height is reported as empty.
    fn is_empty(&self, _reader: &dyn LexicalIndexReader) -> Result<bool> {
        let (width, height) = self.bounding_box.dimensions();
        Ok(width <= 0.0 || height <= 0.0)
    }

    /// Cost estimate: the reader's document count.
    fn cost(&self, reader: &dyn LexicalIndexReader) -> Result<u64> {
        // Widen straight to u64: the previous detour through u32 truncated counts
        // above u32::MAX.
        Ok(reader.doc_count() as u64)
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// One document matched by a geo query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeoMatch {
    /// Identifier of the matching document.
    pub doc_id: u64,
    /// The document's geo point.
    pub point: GeoPoint,
    /// Distance in kilometres from the query's reference point (circle center or
    /// box center).
    pub distance_km: f64,
    /// Unboosted relevance score assigned by the query.
    pub relevance_score: f32,
}
/// Cursor over a list of [`GeoMatch`] values.
#[derive(Debug)]
pub struct GeoMatcher {
    /// Matches in iteration order.
    matches: Vec<GeoMatch>,
    /// Index of the current match; equals `matches.len()` once exhausted.
    current_index: usize,
}
impl GeoMatcher {
    /// Creates a matcher positioned on the closest match.
    ///
    /// Matches are ordered by ascending `distance_km`; equal distances keep their
    /// incoming relative order. A NaN distance sorts after every number.
    pub fn new(mut matches: Vec<GeoMatch>) -> Self {
        // `sort_by` requires a total order. `partial_cmp(..).unwrap_or(Equal)` is
        // not one when a distance is NaN; `total_cmp` is.
        matches.sort_by(|a, b| a.distance_km.total_cmp(&b.distance_km));
        GeoMatcher {
            matches,
            current_index: 0,
        }
    }
}
impl Matcher for GeoMatcher {
    /// Document id at the cursor, or `u64::MAX` once the matcher is exhausted.
    fn doc_id(&self) -> u64 {
        self.matches
            .get(self.current_index)
            .map_or(u64::MAX, |current| current.doc_id)
    }

    /// Advances the cursor by one; returns whether it still points at a match.
    fn next(&mut self) -> Result<bool> {
        let total = self.matches.len();
        // Never move past `total`, so repeated calls after exhaustion are no-ops.
        self.current_index = (self.current_index + 1).min(total);
        Ok(self.current_index < total)
    }

    /// Moves the cursor forward to the first remaining match, in iteration
    /// order, whose id is at least `target`. The cursor does not move if the
    /// current match already qualifies.
    fn skip_to(&mut self, target: u64) -> Result<bool> {
        let found = self
            .matches
            .iter()
            .skip(self.current_index)
            .position(|candidate| candidate.doc_id >= target);
        match found {
            Some(offset) => {
                self.current_index += offset;
                Ok(true)
            }
            None => {
                // Nothing qualifies: park the cursor at the end.
                self.current_index = self.current_index.max(self.matches.len());
                Ok(false)
            }
        }
    }

    /// Number of matches held by this matcher.
    fn cost(&self) -> u64 {
        self.matches.len() as u64
    }

    fn is_exhausted(&self) -> bool {
        self.matches.len() <= self.current_index
    }
}
/// Scorer that looks up precomputed geo relevance scores by document id.
#[derive(Debug)]
pub struct GeoScorer {
    /// Unboosted relevance score per matched document id.
    doc_scores: HashMap<u64, f32>,
    /// Multiplier applied to every score returned.
    boost: f32,
}
impl GeoScorer {
pub fn new(matches: Vec<GeoMatch>, boost: f32) -> Self {
let mut doc_scores = HashMap::new();
for geo_match in matches {
doc_scores.insert(geo_match.doc_id, geo_match.relevance_score);
}
GeoScorer { doc_scores, boost }
}
}
impl Scorer for GeoScorer {
    /// Boosted score for `doc_id`; documents without a recorded score get `0.0`
    /// times the boost. Term frequency and field length are ignored.
    fn score(&self, doc_id: u64, _term_freq: f32, _field_length: Option<f32>) -> f32 {
        let base = match self.doc_scores.get(&doc_id) {
            Some(&recorded) => recorded,
            None => 0.0,
        };
        base * self.boost
    }

    fn boost(&self) -> f32 {
        self.boost
    }

    fn set_boost(&mut self, boost: f32) {
        self.boost = boost;
    }

    /// Largest recorded score (never below `0.0`) multiplied by the boost.
    fn max_score(&self) -> f32 {
        let mut best = 0.0_f32;
        for &score in self.doc_scores.values() {
            best = best.max(score);
        }
        best * self.boost
    }

    fn name(&self) -> &'static str {
        "GeoScorer"
    }
}
/// A geo query of either supported kind.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GeoQuery {
    /// Matches points within a radius of a center point.
    Distance(GeoDistanceQuery),
    /// Matches points inside a bounding box.
    BoundingBox(GeoBoundingBoxQuery),
}
impl GeoQuery {
    /// Creates a distance query around (`lat`, `lon`) with radius `radius_km`.
    ///
    /// # Errors
    ///
    /// Fails when the coordinates are rejected by [`GeoPoint::new`].
    pub fn within_radius<F: Into<String>>(
        field: F,
        lat: f64,
        lon: f64,
        radius_km: f64,
    ) -> Result<Self> {
        let center = GeoPoint::new(lat, lon)?;
        Ok(GeoQuery::Distance(GeoDistanceQuery::new(
            field, center, radius_km,
        )))
    }

    /// Creates a bounding-box query from minimum and maximum coordinates.
    ///
    /// # Errors
    ///
    /// Fails when a coordinate is rejected by [`GeoPoint::new`] or the corners
    /// are rejected by [`GeoBoundingBox::new`].
    pub fn within_bounding_box<F: Into<String>>(
        field: F,
        min_lat: f64,
        min_lon: f64,
        max_lat: f64,
        max_lon: f64,
    ) -> Result<Self> {
        // Top-left is the (max latitude, min longitude) corner.
        let top_left = GeoPoint::new(max_lat, min_lon)?;
        let bottom_right = GeoPoint::new(min_lat, max_lon)?;
        let bbox = GeoBoundingBox::new(top_left, bottom_right)?;
        Ok(GeoQuery::BoundingBox(GeoBoundingBoxQuery::new(field, bbox)))
    }

    /// Creates a distance query from an already validated center point.
    pub fn from_center_and_radius<F: Into<String>>(
        field: F,
        center: GeoPoint,
        radius_km: f64,
    ) -> Self {
        GeoQuery::Distance(GeoDistanceQuery::new(field, center, radius_km))
    }

    /// Creates a bounding-box query from an existing box.
    pub fn from_bounding_box<F: Into<String>>(field: F, bbox: GeoBoundingBox) -> Self {
        GeoQuery::BoundingBox(GeoBoundingBoxQuery::new(field, bbox))
    }

    /// Returns the query with its boost replaced by `boost`.
    pub fn with_boost(self, boost: f32) -> Self {
        // Match on `self` by value so the inner query is moved rather than cloned.
        match self {
            GeoQuery::Distance(query) => GeoQuery::Distance(query.with_boost(boost)),
            GeoQuery::BoundingBox(query) => GeoQuery::BoundingBox(query.with_boost(boost)),
        }
    }

    /// Name of the geo field the wrapped query searches.
    pub fn field(&self) -> &str {
        match self {
            GeoQuery::Distance(query) => query.field(),
            GeoQuery::BoundingBox(query) => query.field(),
        }
    }

    /// Boost of the wrapped query.
    pub fn boost(&self) -> f32 {
        match self {
            GeoQuery::Distance(query) => query.boost(),
            GeoQuery::BoundingBox(query) => query.boost(),
        }
    }
}
/// Every method forwards to the wrapped query, except `clone_box` and `as_any`,
/// which operate on the `GeoQuery` value itself.
impl Query for GeoQuery {
    fn matcher(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Matcher>> {
        self.delegate().matcher(reader)
    }

    fn scorer(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Scorer>> {
        self.delegate().scorer(reader)
    }

    fn boost(&self) -> f32 {
        self.delegate().boost()
    }

    fn set_boost(&mut self, boost: f32) {
        self.delegate_mut().set_boost(boost);
    }

    fn clone_box(&self) -> Box<dyn Query> {
        Box::new(self.clone())
    }

    fn description(&self) -> String {
        self.delegate().description()
    }

    fn is_empty(&self, reader: &dyn LexicalIndexReader) -> Result<bool> {
        self.delegate().is_empty(reader)
    }

    fn cost(&self, reader: &dyn LexicalIndexReader) -> Result<u64> {
        self.delegate().cost(reader)
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

impl GeoQuery {
    /// Borrows the wrapped query as a trait object.
    fn delegate(&self) -> &dyn Query {
        match self {
            GeoQuery::Distance(inner) => inner,
            GeoQuery::BoundingBox(inner) => inner,
        }
    }

    /// Mutably borrows the wrapped query as a trait object.
    fn delegate_mut(&mut self) -> &mut dyn Query {
        match self {
            GeoQuery::Distance(inner) => inner,
            GeoQuery::BoundingBox(inner) => inner,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Valid coordinates are stored as given; out-of-range ones are rejected.
    #[test]
    fn test_geo_point_creation() {
        let point = GeoPoint::new(40.7128, -74.0060).unwrap();
        assert_eq!(point.lat, 40.7128);
        assert_eq!(point.lon, -74.0060);
        assert!(GeoPoint::new(91.0, 0.0).is_err());
        assert!(GeoPoint::new(0.0, 181.0).is_err());
    }

    /// New York to Los Angeles is roughly 3944 km.
    #[test]
    fn test_geo_distance_calculation() {
        let nyc = GeoPoint::new(40.7128, -74.0060).unwrap();
        let la = GeoPoint::new(34.0522, -118.2437).unwrap();
        let distance = nyc.distance_to(&la);
        assert!((distance - 3944.0).abs() < 100.0);
    }

    /// Los Angeles lies broadly west of New York.
    #[test]
    fn test_geo_bearing() {
        let nyc = GeoPoint::new(40.7128, -74.0060).unwrap();
        let la = GeoPoint::new(34.0522, -118.2437).unwrap();
        let bearing = nyc.bearing_to(&la);
        assert!(bearing > 200.0 && bearing < 300.0);
    }

    #[test]
    fn test_geo_bounding_box() {
        let top_left = GeoPoint::new(41.0, -75.0).unwrap();
        let bottom_right = GeoPoint::new(40.0, -74.0).unwrap();
        let bbox = GeoBoundingBox::new(top_left, bottom_right).unwrap();
        let inside_point = GeoPoint::new(40.5, -74.5).unwrap();
        let outside_point = GeoPoint::new(42.0, -73.0).unwrap();
        assert!(bbox.contains(&inside_point));
        assert!(!bbox.contains(&outside_point));
        let center = bbox.center();
        assert_eq!(center.lat, 40.5);
        assert_eq!(center.lon, -74.5);
    }

    #[test]
    fn test_geo_distance_query() {
        let center = GeoPoint::new(40.7128, -74.0060).unwrap();
        let query = GeoDistanceQuery::new("location", center, 10.0).with_boost(1.5);
        assert_eq!(query.field(), "location");
        assert_eq!(query.center(), center);
        assert_eq!(query.distance_km(), 10.0);
        assert_eq!(query.boost(), 1.5);
    }

    /// The linear score runs from 1.0 at the center to 0.0 at the radius.
    #[test]
    fn test_geo_distance_scoring() {
        let center = GeoPoint::new(0.0, 0.0).unwrap();
        let query = GeoDistanceQuery::new("location", center, 10.0);
        assert_eq!(query.calculate_distance_score(0.0), 1.0);
        assert_eq!(query.calculate_distance_score(5.0), 0.5);
        assert_eq!(query.calculate_distance_score(10.0), 0.0);
        assert_eq!(query.calculate_distance_score(15.0), 0.0);
    }

    #[test]
    fn test_geo_bounding_box_query() {
        let top_left = GeoPoint::new(41.0, -75.0).unwrap();
        let bottom_right = GeoPoint::new(40.0, -74.0).unwrap();
        let bbox = GeoBoundingBox::new(top_left, bottom_right).unwrap();
        let query = GeoBoundingBoxQuery::new("location", bbox);
        assert_eq!(query.field(), "location");
        assert_eq!(query.bounding_box().top_left, top_left);
        assert_eq!(query.bounding_box().bottom_right, bottom_right);
    }

    /// The matcher iterates in ascending distance order.
    #[test]
    fn test_geo_matcher() {
        let matches = vec![
            GeoMatch {
                doc_id: 3,
                point: GeoPoint::new(0.0, 0.0).unwrap(),
                distance_km: 1.0,
                relevance_score: 0.9,
            },
            GeoMatch {
                doc_id: 1,
                point: GeoPoint::new(0.0, 0.0).unwrap(),
                distance_km: 2.0,
                relevance_score: 0.8,
            },
        ];
        let mut matcher = GeoMatcher::new(matches);
        assert_eq!(matcher.doc_id(), 3);
        assert!(matcher.next().unwrap());
        assert_eq!(matcher.doc_id(), 1);
        assert!(!matcher.next().unwrap());
    }

    #[test]
    fn test_geo_scorer() {
        let matches = vec![GeoMatch {
            doc_id: 1,
            point: GeoPoint::new(0.0, 0.0).unwrap(),
            distance_km: 1.0,
            relevance_score: 0.9,
        }];
        let scorer = GeoScorer::new(matches, 2.0);
        assert_eq!(scorer.score(1, 1.0, None), 0.9 * 2.0);
        assert_eq!(scorer.score(999, 1.0, None), 0.0);
        assert_eq!(scorer.max_score(), 0.9 * 2.0);
        assert_eq!(scorer.name(), "GeoScorer");
    }

    /// Closer points must score higher with the enhanced scoring.
    #[test]
    fn test_enhanced_distance_scoring() {
        let center = GeoPoint::new(40.7128, -74.0060).unwrap();
        let query = GeoDistanceQuery::new("location", center, 10.0);
        let close_point = GeoPoint::new(40.7130, -74.0062).unwrap();
        let close_score = query.calculate_distance_score_enhanced(0.05, &close_point);
        let mid_point = GeoPoint::new(40.7200, -74.0100).unwrap();
        let mid_score = query.calculate_distance_score_enhanced(1.0, &mid_point);
        let far_point = GeoPoint::new(40.8000, -74.1000).unwrap();
        let far_score = query.calculate_distance_score_enhanced(9.0, &far_point);
        assert!(close_score > mid_score);
        assert!(mid_score > far_score);
        assert!(close_score > 0.9);
    }

    #[test]
    fn test_bounding_box_enhanced_functionality() {
        let top_left = GeoPoint::new(41.0, -75.0).unwrap();
        let bottom_right = GeoPoint::new(40.0, -74.0).unwrap();
        let bbox = GeoBoundingBox::new(top_left, bottom_right).unwrap();
        let query = GeoBoundingBoxQuery::new("location", bbox);
        let candidates = query.generate_bounding_box_candidates();
        assert!(!candidates.is_empty());
        let within_count = candidates
            .iter()
            .filter(|(_, point)| query.bounding_box().contains(point))
            .count();
        assert!(within_count > 0);
        let center_point = query.bounding_box().center();
        // Was `¢er_point`: a mis-decoded `&cent` entity that did not compile.
        let center_score = query.calculate_bounding_box_score(&center_point);
        let corner_point = query.bounding_box().top_left;
        let corner_score = query.calculate_bounding_box_score(&corner_point);
        assert!(center_score >= corner_score);
    }

    /// A 5 km search around New York yields a small box centered on the query.
    #[test]
    fn test_spatial_bounding_box_creation() {
        let center = GeoPoint::new(40.7128, -74.0060).unwrap();
        let query = GeoDistanceQuery::new("location", center, 5.0);
        let bbox = query.create_bounding_box();
        // Was `¢er`: a mis-decoded `&cent` entity that did not compile.
        assert!(bbox.contains(&center));
        let (width, height) = bbox.dimensions();
        assert!(width > 0.0 && width < 1.0);
        assert!(height > 0.0 && height < 1.0);
        let bbox_center = bbox.center();
        let center_distance = center.distance_to(&bbox_center);
        assert!(center_distance < 1.0);
    }

    #[test]
    fn test_geographic_relevance_calculation() {
        let center = GeoPoint::new(40.7128, -74.0060).unwrap();
        let query = GeoDistanceQuery::new("location", center, 10.0);
        let temperate_point = GeoPoint::new(45.0, 0.0).unwrap();
        let tropical_point = GeoPoint::new(10.0, 0.0).unwrap();
        let temperate_bonus = query.calculate_geographic_relevance(&temperate_point);
        let tropical_bonus = query.calculate_geographic_relevance(&tropical_point);
        assert!(temperate_bonus > tropical_bonus);
        let equator_point = GeoPoint::new(2.0, 0.0).unwrap();
        let non_equator_point = GeoPoint::new(45.0, 0.0).unwrap();
        let equator_geo_bonus = query.calculate_geographic_relevance(&equator_point);
        let non_equator_geo_bonus = query.calculate_geographic_relevance(&non_equator_point);
        assert!(equator_geo_bonus > 0.0);
        assert!(non_equator_geo_bonus > 0.0);
    }
}