use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;
use thiserror::Error;
use common::VectorId;
/// One raw search hit as produced by the backend search function:
/// (vector id, similarity score, optional stored vector, optional metadata).
type SearchResultRow = (VectorId, f32, Option<Vec<f32>>, Option<serde_json::Value>);
/// Errors surfaced by batch search execution, cursor parsing, and
/// request validation.
#[derive(Debug, Error)]
pub enum BatchSearchError {
    /// The request carried more queries than the configured maximum.
    #[error("Query batch too large: {0} exceeds maximum {1}")]
    BatchTooLarge(usize, usize),
    /// A pagination cursor failed to decode or had the wrong type.
    #[error("Invalid cursor: {0}")]
    InvalidCursor(String),
    /// Latitude/longitude outside the valid WGS84 ranges.
    #[error("Invalid geo-coordinate: lat={0}, lon={1}")]
    InvalidGeoCoordinate(f64, f64),
    // NOTE(review): the variants below are part of the public error surface
    // but are not constructed anywhere in this file — presumably reserved
    // for callers / future execution paths.
    #[error("Unsupported aggregation type: {0}")]
    UnsupportedAggregation(String),
    #[error("Invalid scoring function: {0}")]
    InvalidScoringFunction(String),
    #[error("Query timeout exceeded: {0}ms")]
    Timeout(u64),
    #[error("Internal error: {0}")]
    Internal(String),
}
/// Tunables governing how a batch of queries is executed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryConfig {
    // Maximum number of queries accepted in one batch request.
    pub max_batch_size: usize,
    // NOTE(review): declared but not enforced by `BatchQueryExecutor`,
    // which runs queries sequentially.
    pub max_concurrency: usize,
    // NOTE(review): declared but not enforced by `BatchQueryExecutor`.
    pub query_timeout_ms: u64,
    // When true, identical query vectors reuse the first computed response.
    pub deduplicate_queries: bool,
    pub collect_stats: bool,
}
impl Default for BatchQueryConfig {
fn default() -> Self {
Self {
max_batch_size: 100,
max_concurrency: 16,
query_timeout_ms: 5000,
deduplicate_queries: true,
collect_stats: true,
}
}
}
/// A single query within a batch request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryItem {
    // Caller-chosen id echoed back in the matching response.
    pub query_id: String,
    // Query embedding to search with.
    pub vector: Vec<f32>,
    // Number of hits to return for this query.
    pub top_k: usize,
    // Optional metadata filter applied by the search backend.
    pub filter: Option<FilterExpression>,
    // Optional pagination cursor from a previous page.
    pub cursor: Option<SearchCursor>,
    // Optional function-scoring configuration for re-ranking.
    pub scoring: Option<ScoringConfig>,
}
/// A batch of queries executed together against one namespace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryRequest {
    pub namespace: String,
    pub queries: Vec<BatchQueryItem>,
    // When false, hit vectors are stripped from the responses.
    pub include_vectors: bool,
    // When false, hit metadata is stripped from the responses.
    pub include_metadata: bool,
    // Facet/aggregation requests computed over the batch results.
    pub facets: Option<Vec<FacetRequest>>,
}
/// Per-query result within a batch response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryItemResponse {
    // Echo of the request's `query_id`.
    pub query_id: String,
    pub results: Vec<SearchHit>,
    // Cursor for fetching the next page, when more results exist.
    pub next_cursor: Option<SearchCursor>,
    pub took_ms: u64,
    // Size of the candidate set considered (see executor notes: this is
    // the over-fetched set, not an index-wide count).
    pub total_matches: usize,
    pub explanation: Option<QueryExplanation>,
}
/// Top-level response for a batch request: one entry per query plus
/// optional facets and aggregate timing statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryResponse {
    pub responses: Vec<BatchQueryItemResponse>,
    pub facets: Option<HashMap<String, FacetResult>>,
    pub stats: BatchQueryStats,
}
/// Aggregate timing/outcome statistics for one batch execution.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BatchQueryStats {
    pub total_queries: usize,
    pub successful_queries: usize,
    pub failed_queries: usize,
    // Wall-clock time for the whole batch.
    pub total_took_ms: u64,
    pub avg_query_ms: f64,
    pub max_query_ms: u64,
    pub min_query_ms: u64,
    // Number of queries answered by reusing an earlier identical query.
    pub deduplicated_count: usize,
}
/// A single scored search result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchHit {
    pub id: VectorId,
    pub score: f32,
    // Present only when the request set `include_vectors`.
    pub vector: Option<Vec<f32>>,
    // Present only when the request set `include_metadata`.
    pub metadata: Option<serde_json::Value>,
    // Sort key values used for search-after pagination.
    pub sort_values: Vec<SortValue>,
}
/// Opaque pagination token handed back to clients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchCursor {
    pub cursor_type: CursorType,
    // Base64-encoded payload; format depends on `cursor_type`
    // ("offset:total" or a JSON array of sort values).
    pub value: String,
    // Unix epoch milliseconds at creation (for TTL checks by callers).
    pub created_at: u64,
}
impl SearchCursor {
    /// Builds an offset-style cursor whose payload encodes `offset:total`.
    pub fn cursor_based(offset: usize, total: usize) -> Self {
        Self {
            cursor_type: CursorType::CursorBased,
            value: base64_encode(&format!("{}:{}", offset, total)),
            created_at: current_timestamp_ms(),
        }
    }
    /// Builds a search-after cursor from the last hit's sort values,
    /// serialized as a JSON array.
    pub fn search_after(sort_values: &[SortValue]) -> Self {
        let payload = serde_json::to_string(sort_values).unwrap_or_default();
        Self {
            cursor_type: CursorType::SearchAfter,
            value: base64_encode(&payload),
            created_at: current_timestamp_ms(),
        }
    }
    /// Recovers the page offset from a cursor-based cursor.
    ///
    /// # Errors
    /// `InvalidCursor` when the cursor has the wrong type, the payload is
    /// not valid base64, or the offset fails to parse.
    pub fn parse_offset(&self) -> Result<usize, BatchSearchError> {
        if !matches!(self.cursor_type, CursorType::CursorBased) {
            return Err(BatchSearchError::InvalidCursor(
                "Expected cursor-based cursor".into(),
            ));
        }
        let decoded = base64_decode(&self.value)
            .map_err(|_| BatchSearchError::InvalidCursor("Invalid base64".into()))?;
        // Payload is "offset:total"; only the offset is needed here.
        decoded
            .split(':')
            .next()
            .and_then(|piece| piece.parse().ok())
            .ok_or_else(|| BatchSearchError::InvalidCursor("Invalid offset format".into()))
    }
    /// Recovers the sort values from a search-after cursor.
    ///
    /// # Errors
    /// `InvalidCursor` when the cursor has the wrong type or the payload
    /// is not base64-encoded JSON.
    pub fn parse_sort_values(&self) -> Result<Vec<SortValue>, BatchSearchError> {
        if !matches!(self.cursor_type, CursorType::SearchAfter) {
            return Err(BatchSearchError::InvalidCursor(
                "Expected search-after cursor".into(),
            ));
        }
        let decoded = base64_decode(&self.value)
            .map_err(|_| BatchSearchError::InvalidCursor("Invalid base64".into()))?;
        serde_json::from_str(&decoded)
            .map_err(|_| BatchSearchError::InvalidCursor("Invalid sort values".into()))
    }
}
/// Discriminates the two pagination strategies a `SearchCursor` can carry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CursorType {
    // Classic offset pagination ("skip N results").
    CursorBased,
    // Keyset pagination ("resume after these sort values").
    SearchAfter,
}
/// A typed sort-key component used for keyset pagination.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SortValue {
    // Relevance score; compares in DESCENDING order (see `compare`).
    Score(f32),
    Integer(i64),
    Float(f64),
    String(String),
    Null,
}
impl SortValue {
    /// Orders two sort values for pagination.
    ///
    /// `Score` compares in DESCENDING order (a higher score sorts first);
    /// other variants compare ascending. `Null` sorts after everything
    /// else, and values of mismatched variants are treated as equal
    /// (incomparable NaN floats also compare equal).
    pub fn compare(&self, other: &SortValue) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (self, other) {
            // Reversed on purpose: best score first.
            (SortValue::Score(lhs), SortValue::Score(rhs)) => lhs
                .partial_cmp(rhs)
                .map(Ordering::reverse)
                .unwrap_or(Ordering::Equal),
            (SortValue::Integer(lhs), SortValue::Integer(rhs)) => Ord::cmp(lhs, rhs),
            (SortValue::Float(lhs), SortValue::Float(rhs)) => {
                lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal)
            }
            (SortValue::String(lhs), SortValue::String(rhs)) => Ord::cmp(lhs, rhs),
            (SortValue::Null, SortValue::Null) => Ordering::Equal,
            (SortValue::Null, _) => Ordering::Greater,
            (_, SortValue::Null) => Ordering::Less,
            // Mixed variants carry no meaningful ordering.
            _ => Ordering::Equal,
        }
    }
}
/// Server-side pagination policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PaginationConfig {
    pub page_size: usize,
    // Hard cap on offset-based paging depth.
    pub max_offset: usize,
    // NOTE(review): TTL is declared here but cursor expiry is not checked
    // anywhere in this file — confirm where enforcement happens.
    pub cursor_ttl_secs: u64,
    pub default_sort: Vec<SortField>,
}
impl Default for PaginationConfig {
fn default() -> Self {
Self {
page_size: 20,
max_offset: 10000,
cursor_ttl_secs: 3600,
default_sort: vec![SortField {
field: "_score".into(),
order: SortOrder::Descending,
}],
}
}
}
/// A (field, direction) pair in a sort specification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SortField {
    pub field: String,
    pub order: SortOrder,
}
/// Sort direction for a `SortField`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum SortOrder {
    Ascending,
    Descending,
}
/// A client request for one facet/aggregation over a metadata field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetRequest {
    // Name under which the result is keyed in the response.
    pub name: String,
    // Metadata field to aggregate over.
    pub field: String,
    pub agg_type: AggregationType,
    // Per-request bucket cap (further capped by the executor's own limit).
    pub max_buckets: Option<usize>,
    // Bucket boundaries; only used for `AggregationType::Range`.
    pub ranges: Option<Vec<RangeBucket>>,
}
/// Supported aggregation kinds. Terms/Range produce buckets; the rest
/// produce a single numeric value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationType {
    Terms,
    Range,
    // NOTE(review): DateHistogram/Histogram are declared but no executor
    // in this file computes them.
    DateHistogram { interval: String },
    Histogram { interval: f64 },
    Min,
    Max,
    Avg,
    Sum,
    Count,
    Cardinality,
}
/// One half-open range [from, to) for a range aggregation; `None` on
/// either side means unbounded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RangeBucket {
    pub key: String,
    pub from: Option<f64>,
    pub to: Option<f64>,
}
/// Result of one facet computation: either a bucket list (terms/range)
/// or a single numeric value (min/max/avg/...), never both.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetResult {
    pub name: String,
    pub field: String,
    pub buckets: Option<Vec<FacetBucket>>,
    pub value: Option<f64>,
    pub doc_count: usize,
}
/// One bucket of a terms/range aggregation, with optional nested facets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetBucket {
    pub key: String,
    pub doc_count: usize,
    pub sub_aggregations: Option<HashMap<String, FacetResult>>,
}
/// Computes facet aggregations over in-memory metadata values.
pub struct FacetExecutor {
    // Global cap on bucket counts, applied on top of per-request caps.
    max_buckets: usize,
}
impl FacetExecutor {
    /// Creates an executor that never returns more than `max_buckets`
    /// buckets from any bucketing aggregation.
    pub fn new(max_buckets: usize) -> Self {
        Self { max_buckets }
    }
    /// Counts occurrences of each scalar value (strings, numbers, bools;
    /// arrays/objects/null are skipped) and returns the most frequent
    /// buckets, capped at `min(max_buckets, self.max_buckets)`.
    ///
    /// Buckets are ordered by count descending, then key ascending — the
    /// key tie-break makes the output deterministic (HashMap iteration
    /// order is randomized, so equal counts would otherwise come back in
    /// an arbitrary order).
    pub fn terms_aggregation(
        &self,
        values: &[Option<serde_json::Value>],
        max_buckets: usize,
    ) -> Vec<FacetBucket> {
        let mut counts: HashMap<String, usize> = HashMap::new();
        for value in values.iter().flatten() {
            let key = match value {
                serde_json::Value::String(s) => s.clone(),
                serde_json::Value::Number(n) => n.to_string(),
                serde_json::Value::Bool(b) => b.to_string(),
                // Arrays/objects/null are not bucketable scalars.
                _ => continue,
            };
            *counts.entry(key).or_insert(0) += 1;
        }
        let mut buckets: Vec<_> = counts
            .into_iter()
            .map(|(key, count)| FacetBucket {
                key,
                doc_count: count,
                sub_aggregations: None,
            })
            .collect();
        buckets.sort_by(|a, b| {
            b.doc_count
                .cmp(&a.doc_count)
                .then_with(|| a.key.cmp(&b.key))
        });
        buckets.truncate(max_buckets.min(self.max_buckets));
        buckets
    }
    /// Counts how many values fall into each half-open range [from, to);
    /// a missing bound is unbounded on that side. Returns one bucket per
    /// requested range, in request order.
    pub fn range_aggregation(
        &self,
        values: &[Option<f64>],
        ranges: &[RangeBucket],
    ) -> Vec<FacetBucket> {
        ranges
            .iter()
            .map(|range| {
                let count = values
                    .iter()
                    .filter(|v| {
                        if let Some(val) = v {
                            let from_ok = range.from.is_none_or(|f| *val >= f);
                            let to_ok = range.to.is_none_or(|t| *val < t);
                            from_ok && to_ok
                        } else {
                            false
                        }
                    })
                    .count();
                FacetBucket {
                    key: range.key.clone(),
                    doc_count: count,
                    sub_aggregations: None,
                }
            })
            .collect()
    }
    /// Reduces the non-missing values to a single metric. Returns 0.0 for
    /// an empty input or for bucketing aggregation kinds (Terms, Range,
    /// histograms), which have no single numeric result.
    pub fn numeric_aggregation(&self, values: &[Option<f64>], agg_type: &AggregationType) -> f64 {
        let valid_values: Vec<f64> = values.iter().filter_map(|v| *v).collect();
        if valid_values.is_empty() {
            return 0.0;
        }
        match agg_type {
            AggregationType::Min => valid_values.iter().copied().fold(f64::INFINITY, f64::min),
            AggregationType::Max => valid_values
                .iter()
                .copied()
                .fold(f64::NEG_INFINITY, f64::max),
            AggregationType::Avg => valid_values.iter().sum::<f64>() / valid_values.len() as f64,
            AggregationType::Sum => valid_values.iter().sum(),
            AggregationType::Count => valid_values.len() as f64,
            AggregationType::Cardinality => {
                // Exact distinct count. f64 is not Hash, so distinct bit
                // patterns are counted (0.0 and -0.0 thus differ).
                let distinct: std::collections::HashSet<u64> =
                    valid_values.iter().map(|v| v.to_bits()).collect();
                distinct.len() as f64
            }
            _ => 0.0,
        }
    }
}
/// A WGS84 latitude/longitude pair in decimal degrees.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct GeoPoint {
    pub lat: f64,
    pub lon: f64,
}
impl GeoPoint {
    /// Validates and constructs a coordinate. Latitude must lie in
    /// [-90, 90] and longitude in [-180, 180].
    ///
    /// # Errors
    /// `InvalidGeoCoordinate` when either component is out of range.
    pub fn new(lat: f64, lon: f64) -> Result<Self, BatchSearchError> {
        let lat_ok = (-90.0..=90.0).contains(&lat);
        let lon_ok = (-180.0..=180.0).contains(&lon);
        if lat_ok && lon_ok {
            Ok(Self { lat, lon })
        } else {
            Err(BatchSearchError::InvalidGeoCoordinate(lat, lon))
        }
    }
    /// Great-circle distance via the haversine formula, assuming a
    /// spherical Earth of radius 6371 km.
    pub fn distance_km(&self, other: &GeoPoint) -> f64 {
        const EARTH_RADIUS_KM: f64 = 6371.0;
        let phi1 = self.lat.to_radians();
        let phi2 = other.lat.to_radians();
        let d_phi = (other.lat - self.lat).to_radians();
        let d_lambda = (other.lon - self.lon).to_radians();
        // Haversine of the central angle.
        let h = (d_phi / 2.0).sin().powi(2)
            + phi1.cos() * phi2.cos() * (d_lambda / 2.0).sin().powi(2);
        EARTH_RADIUS_KM * (2.0 * h.sqrt().asin())
    }
    /// Haversine distance in meters.
    pub fn distance_m(&self, other: &GeoPoint) -> f64 {
        self.distance_km(other) * 1000.0
    }
    /// Haversine distance in statute miles.
    pub fn distance_miles(&self, other: &GeoPoint) -> f64 {
        self.distance_km(other) * 0.621371
    }
}
/// Spatial predicates applied to a document's geo-point field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GeoFilter {
    // Within `distance` (in `unit`) of `center`, inclusive.
    Distance {
        center: GeoPoint,
        distance: f64,
        unit: DistanceUnit,
    },
    // Inside the lat/lon-aligned box, inclusive on all edges.
    BoundingBox {
        top_left: GeoPoint,
        bottom_right: GeoPoint,
    },
    // Inside the polygon (ray casting; needs at least 3 vertices).
    Polygon {
        points: Vec<GeoPoint>,
    },
}
/// Length unit for geo-distance values.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DistanceUnit {
    Meters,
    Kilometers,
    Miles,
    Feet,
}
impl DistanceUnit {
    /// Converts `distance` expressed in this unit into meters.
    pub fn to_meters(&self, distance: f64) -> f64 {
        let meters_per_unit = match self {
            DistanceUnit::Meters => 1.0,
            DistanceUnit::Kilometers => 1000.0,
            DistanceUnit::Miles => 1609.34,
            DistanceUnit::Feet => 0.3048,
        };
        distance * meters_per_unit
    }
}
/// Stateless evaluator for `GeoFilter` predicates.
pub struct GeoFilterExecutor;
impl GeoFilterExecutor {
    /// Returns true when `point` satisfies `filter`.
    pub fn matches(filter: &GeoFilter, point: &GeoPoint) -> bool {
        match filter {
            GeoFilter::Distance {
                center,
                distance,
                unit,
            } => {
                // Inclusive haversine-distance comparison in meters.
                let max_distance_m = unit.to_meters(*distance);
                center.distance_m(point) <= max_distance_m
            }
            GeoFilter::BoundingBox {
                top_left,
                bottom_right,
            } => {
                // Inclusive on all four edges.
                // NOTE(review): the longitude test assumes the box does not
                // cross the antimeridian (top_left.lon <= bottom_right.lon)
                // — confirm whether wrap-around boxes must be supported.
                point.lat <= top_left.lat
                    && point.lat >= bottom_right.lat
                    && point.lon >= top_left.lon
                    && point.lon <= bottom_right.lon
            }
            GeoFilter::Polygon { points } => Self::point_in_polygon(point, points),
        }
    }
    /// Ray-casting point-in-polygon test: cast a horizontal ray from the
    /// point and count edge crossings; an odd count means inside.
    /// Degenerate polygons (< 3 vertices) never match. Operates on raw
    /// lat/lon, so polygons spanning the antimeridian are not handled.
    fn point_in_polygon(point: &GeoPoint, polygon: &[GeoPoint]) -> bool {
        if polygon.len() < 3 {
            return false;
        }
        let mut inside = false;
        let n = polygon.len();
        // `j` trails `i` by one so each iteration examines edge (j, i),
        // starting with the closing edge (last, first).
        let mut j = n - 1;
        for i in 0..n {
            let pi = &polygon[i];
            let pj = &polygon[j];
            // Edge straddles the point's latitude, and the point lies west
            // of the edge's longitude at that latitude. The latitude check
            // guarantees pj.lat != pi.lat, so the division is safe.
            if ((pi.lat > point.lat) != (pj.lat > point.lat))
                && (point.lon
                    < (pj.lon - pi.lon) * (point.lat - pi.lat) / (pj.lat - pi.lat) + pi.lon)
            {
                inside = !inside;
            }
            j = i;
        }
        inside
    }
}
/// Function-scoring configuration: how per-function scores are combined
/// with each other (`score_mode`) and with the base similarity score
/// (`boost_mode`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoringConfig {
    pub score_mode: ScoreMode,
    pub functions: Vec<ScoreFunction>,
    pub boost_mode: BoostMode,
    // Hits scoring below this threshold after re-scoring are dropped.
    pub min_score: Option<f32>,
}
impl Default for ScoringConfig {
fn default() -> Self {
Self {
score_mode: ScoreMode::Multiply,
functions: Vec::new(),
boost_mode: BoostMode::Multiply,
min_score: None,
}
}
}
/// How the scores of multiple `ScoreFunction`s are folded into one value.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ScoreMode {
    Multiply,
    Sum,
    Average,
    First,
    Max,
    Min,
}
/// How the combined function score is merged with the base similarity score.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum BoostMode {
    Multiply,
    Replace,
    Sum,
    Average,
    Max,
    Min,
}
/// A single re-scoring function evaluated per document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScoreFunction {
    // Constant boost, independent of the document.
    Weight { weight: f32 },
    // Reads a numeric metadata field, applies `modifier`, scales by
    // `factor`; `missing` substitutes for absent fields.
    FieldValue {
        field: String,
        factor: f32,
        modifier: FieldValueModifier,
        missing: f32,
    },
    // Numeric distance decay from `origin`; score 1.0 within `offset`.
    Decay {
        field: String,
        origin: f64,
        scale: f64,
        offset: f64,
        decay: f64,
        decay_type: DecayType,
    },
    // NOTE(review): the executor derives this from `seed` only — the
    // optional `field` is currently ignored.
    RandomScore { seed: u64, field: Option<String> },
    // Scripted score; the executor only honors a numeric `boost` param.
    Script {
        source: String,
        params: HashMap<String, f64>,
    },
    // Geo-distance decay from `origin`; the executor applies a Gaussian
    // curve regardless of any decay-type setting.
    GeoDecay {
        field: String,
        origin: GeoPoint,
        scale: f64,
        scale_unit: DistanceUnit,
        offset: f64,
        decay: f64,
    },
}
/// Transform applied to a field value before the factor multiplication.
/// Log*/Ln* variants add 1 or 2 before taking the logarithm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum FieldValueModifier {
    None,
    Log,
    Log1p,
    Log2p,
    Ln,
    Ln1p,
    Ln2p,
    Square,
    Sqrt,
    Reciprocal,
}
/// Shape of the decay curve used by `ScoreFunction::Decay`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DecayType {
    Gaussian,
    Linear,
    Exponential,
}
pub struct ScoreFunctionExecutor;
impl ScoreFunctionExecutor {
pub fn apply(
function: &ScoreFunction,
original_score: f32,
metadata: Option<&serde_json::Value>,
) -> f32 {
match function {
ScoreFunction::Weight { weight } => *weight,
ScoreFunction::FieldValue {
field,
factor,
modifier,
missing,
} => {
let value = metadata
.and_then(|m| m.get(field))
.and_then(|v| v.as_f64())
.unwrap_or(*missing as f64);
let modified = Self::apply_modifier(value, *modifier);
(modified * *factor as f64) as f32
}
ScoreFunction::Decay {
origin,
scale,
offset,
decay,
decay_type,
field,
} => {
let value = metadata
.and_then(|m| m.get(field))
.and_then(|v| v.as_f64())
.unwrap_or(*origin);
let distance = (value - origin).abs() - offset;
if distance <= 0.0 {
return 1.0;
}
Self::compute_decay(distance, *scale, *decay, *decay_type) as f32
}
ScoreFunction::RandomScore { seed, .. } => {
let hash = (*seed as u32).wrapping_mul(2654435769);
hash as f32 / u32::MAX as f32
}
ScoreFunction::Script { source, params } => {
Self::evaluate_script(source, params, original_score, metadata)
}
ScoreFunction::GeoDecay {
field,
origin,
scale,
scale_unit,
offset,
decay,
} => {
let point = metadata.and_then(|m| m.get(field)).and_then(|v| {
let lat = v.get("lat")?.as_f64()?;
let lon = v.get("lon")?.as_f64()?;
Some(GeoPoint { lat, lon })
});
if let Some(point) = point {
let distance_m = origin.distance_m(&point);
let scale_m = scale_unit.to_meters(*scale);
let offset_m = scale_unit.to_meters(*offset);
let adjusted_distance = (distance_m - offset_m).max(0.0);
Self::compute_decay(adjusted_distance, scale_m, *decay, DecayType::Gaussian)
as f32
} else {
0.0
}
}
}
}
fn apply_modifier(value: f64, modifier: FieldValueModifier) -> f64 {
match modifier {
FieldValueModifier::None => value,
FieldValueModifier::Log => value.log10(),
FieldValueModifier::Log1p => (1.0 + value).log10(),
FieldValueModifier::Log2p => (2.0 + value).log10(),
FieldValueModifier::Ln => value.ln(),
FieldValueModifier::Ln1p => (1.0 + value).ln(),
FieldValueModifier::Ln2p => (2.0 + value).ln(),
FieldValueModifier::Square => value * value,
FieldValueModifier::Sqrt => value.sqrt(),
FieldValueModifier::Reciprocal => 1.0 / value.max(0.001),
}
}
fn compute_decay(distance: f64, scale: f64, decay: f64, decay_type: DecayType) -> f64 {
let lambda = scale.ln().abs() / decay.ln().abs();
match decay_type {
DecayType::Gaussian => (-0.5 * (distance / lambda).powi(2)).exp(),
DecayType::Linear => ((lambda - distance) / lambda).max(0.0),
DecayType::Exponential => (-distance / lambda).exp(),
}
}
fn evaluate_script(
_source: &str,
params: &HashMap<String, f64>,
original_score: f32,
_metadata: Option<&serde_json::Value>,
) -> f32 {
let boost = params.get("boost").copied().unwrap_or(1.0) as f32;
original_score * boost
}
pub fn combine_scores(scores: &[f32], mode: ScoreMode) -> f32 {
if scores.is_empty() {
return 1.0;
}
match mode {
ScoreMode::Multiply => scores.iter().product(),
ScoreMode::Sum => scores.iter().sum(),
ScoreMode::Average => scores.iter().sum::<f32>() / scores.len() as f32,
ScoreMode::First => scores[0],
ScoreMode::Max => scores.iter().copied().fold(f32::NEG_INFINITY, f32::max),
ScoreMode::Min => scores.iter().copied().fold(f32::INFINITY, f32::min),
}
}
pub fn combine_with_original(original: f32, function_score: f32, mode: BoostMode) -> f32 {
match mode {
BoostMode::Multiply => original * function_score,
BoostMode::Replace => function_score,
BoostMode::Sum => original + function_score,
BoostMode::Average => (original + function_score) / 2.0,
BoostMode::Max => original.max(function_score),
BoostMode::Min => original.min(function_score),
}
}
}
/// Human-readable breakdown of how a document's final score was produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryExplanation {
    pub score: f32,
    pub description: String,
    pub details: Vec<ScoreDetail>,
}
/// One node in a score-explanation tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreDetail {
    pub name: String,
    pub value: f32,
    pub description: String,
    pub details: Option<Vec<ScoreDetail>>,
}
/// Builds `QueryExplanation`s by replaying the scoring pipeline.
pub struct QueryExplainer;
impl QueryExplainer {
    /// Recomputes the document's score step by step, recording the base
    /// similarity, each function's contribution, the combined function
    /// score, and the final boost-mode merge.
    pub fn explain(
        id: &VectorId,
        original_score: f32,
        scoring_config: Option<&ScoringConfig>,
        metadata: Option<&serde_json::Value>,
    ) -> QueryExplanation {
        let mut details = vec![ScoreDetail {
            name: "vector_similarity".into(),
            value: original_score,
            description: "Base vector similarity score".into(),
            details: None,
        }];
        let mut final_score = original_score;
        if let Some(config) = scoring_config {
            let mut function_scores = Vec::new();
            // One detail entry per configured function, in order.
            for (i, func) in config.functions.iter().enumerate() {
                let func_score = ScoreFunctionExecutor::apply(func, original_score, metadata);
                function_scores.push(func_score);
                details.push(ScoreDetail {
                    name: format!("function_{}", i),
                    value: func_score,
                    description: Self::describe_function(func),
                    details: None,
                });
            }
            if !function_scores.is_empty() {
                let combined =
                    ScoreFunctionExecutor::combine_scores(&function_scores, config.score_mode);
                details.push(ScoreDetail {
                    name: "combined_functions".into(),
                    value: combined,
                    description: format!("Functions combined using {:?}", config.score_mode),
                    details: None,
                });
                // NOTE(review): min_score filtering is not reflected here;
                // a hit's explanation may show a score the executor would
                // have filtered out.
                final_score = ScoreFunctionExecutor::combine_with_original(
                    original_score,
                    combined,
                    config.boost_mode,
                );
            }
        }
        QueryExplanation {
            score: final_score,
            description: format!("Explanation for document {}", id),
            details,
        }
    }
    /// One-line human-readable summary of a scoring function.
    fn describe_function(func: &ScoreFunction) -> String {
        match func {
            ScoreFunction::Weight { weight } => format!("Constant weight: {}", weight),
            ScoreFunction::FieldValue { field, factor, .. } => {
                format!("Field value boost on '{}' with factor {}", field, factor)
            }
            ScoreFunction::Decay {
                field, decay_type, ..
            } => format!("{:?} decay on field '{}'", decay_type, field),
            ScoreFunction::RandomScore { seed, .. } => format!("Random score with seed {}", seed),
            ScoreFunction::Script { source, .. } => format!("Script: {}", source),
            ScoreFunction::GeoDecay { field, .. } => format!("Geo decay on field '{}'", field),
        }
    }
}
/// Boolean metadata-filter tree evaluated per document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FilterExpression {
    // Exact JSON-value equality on a field.
    Term {
        field: String,
        value: serde_json::Value,
    },
    // Field equals any of the listed values.
    Terms {
        field: String,
        values: Vec<serde_json::Value>,
    },
    // Numeric comparisons; any bound may be omitted.
    Range {
        field: String,
        gte: Option<f64>,
        gt: Option<f64>,
        lte: Option<f64>,
        lt: Option<f64>,
    },
    Exists { field: String },
    // String field starts with `prefix`.
    Prefix { field: String, prefix: String },
    Geo { field: String, filter: GeoFilter },
    And(Vec<FilterExpression>),
    Or(Vec<FilterExpression>),
    Not(Box<FilterExpression>),
}
/// Stateless evaluator for `FilterExpression` trees over JSON metadata.
pub struct FilterExecutor;
impl FilterExecutor {
    /// Returns true when the document's metadata satisfies `filter`.
    ///
    /// NOTE(review): documents with no metadata at all never match — even
    /// `Not(...)` returns false for them, because the None check happens
    /// before the expression is examined. Confirm this is the intended
    /// semantics for negation over missing documents.
    pub fn matches(filter: &FilterExpression, metadata: Option<&serde_json::Value>) -> bool {
        let metadata = match metadata {
            Some(m) => m,
            None => return false,
        };
        match filter {
            FilterExpression::Term { field, value } => metadata.get(field) == Some(value),
            FilterExpression::Terms { field, values } => {
                metadata.get(field).is_some_and(|v| values.contains(v))
            }
            FilterExpression::Range {
                field,
                gte,
                gt,
                lte,
                lt,
            } => {
                // Non-numeric or missing fields never match a range.
                let value = metadata.get(field).and_then(|v| v.as_f64());
                if let Some(v) = value {
                    gte.is_none_or(|x| v >= x)
                        && gt.is_none_or(|x| v > x)
                        && lte.is_none_or(|x| v <= x)
                        && lt.is_none_or(|x| v < x)
                } else {
                    false
                }
            }
            FilterExpression::Exists { field } => metadata.get(field).is_some(),
            FilterExpression::Prefix { field, prefix } => metadata
                .get(field)
                .and_then(|v| v.as_str())
                .is_some_and(|s| s.starts_with(prefix)),
            FilterExpression::Geo { field, filter } => {
                // The field must be an object with numeric "lat"/"lon".
                let point = metadata.get(field).and_then(|v| {
                    let lat = v.get("lat")?.as_f64()?;
                    let lon = v.get("lon")?.as_f64()?;
                    Some(GeoPoint { lat, lon })
                });
                point.is_some_and(|p| GeoFilterExecutor::matches(filter, &p))
            }
            // And([]) is vacuously true; Or([]) is vacuously false.
            FilterExpression::And(filters) => {
                filters.iter().all(|f| Self::matches(f, Some(metadata)))
            }
            FilterExpression::Or(filters) => {
                filters.iter().any(|f| Self::matches(f, Some(metadata)))
            }
            FilterExpression::Not(filter) => !Self::matches(filter, Some(metadata)),
        }
    }
}
/// Runs batches of vector queries against a caller-supplied search
/// function, handling validation, dedup, pagination, and re-scoring.
pub struct BatchQueryExecutor {
    config: BatchQueryConfig,
}
impl BatchQueryExecutor {
    /// Creates an executor governed by `config`.
    pub fn new(config: BatchQueryConfig) -> Self {
        Self { config }
    }
    /// Executes every query in `request` through `search_fn`, applying
    /// deduplication, cursor pagination, and optional function scoring,
    /// and returns per-query responses plus aggregate timing stats.
    ///
    /// `search_fn(vector, k, filter)` returns raw candidate rows; it is
    /// asked for `top_k * 2` rows so re-scoring and cursor skipping can
    /// still fill a page.
    ///
    /// # Errors
    /// `BatchTooLarge` when the batch exceeds `config.max_batch_size`;
    /// `InvalidCursor` when any query carries a malformed cursor (the
    /// whole batch fails in that case).
    ///
    /// NOTE(review): execution is sequential — `max_concurrency` and
    /// `query_timeout_ms` are not enforced here. Facet requests are
    /// currently answered with an empty placeholder map.
    pub fn execute(
        &self,
        request: &BatchQueryRequest,
        search_fn: impl Fn(
            &[f32],
            usize,
            Option<&FilterExpression>,
        ) -> Vec<(VectorId, f32, Option<Vec<f32>>, Option<serde_json::Value>)>,
    ) -> Result<BatchQueryResponse, BatchSearchError> {
        let start = Instant::now();
        if request.queries.len() > self.config.max_batch_size {
            return Err(BatchSearchError::BatchTooLarge(
                request.queries.len(),
                self.config.max_batch_size,
            ));
        }
        let mut responses: Vec<BatchQueryItemResponse> = Vec::with_capacity(request.queries.len());
        let mut query_times: Vec<u64> = Vec::new();
        let mut deduplicated_count = 0;
        // Exact-vector dedup: f32 bit patterns -> index of the response
        // already computed for that vector.
        let mut seen_queries: HashMap<Vec<u32>, usize> = HashMap::new();
        for query in &request.queries {
            let query_start = Instant::now();
            let query_hash: Vec<u32> = query.vector.iter().map(|f| f.to_bits()).collect();
            if self.config.deduplicate_queries {
                if let Some(&existing_idx) = seen_queries.get(&query_hash) {
                    // Reuse the earlier result wholesale; only the id differs.
                    let mut response = responses[existing_idx].clone();
                    response.query_id = query.query_id.clone();
                    responses.push(response);
                    deduplicated_count += 1;
                    continue;
                }
                seen_queries.insert(query_hash, responses.len());
            }
            // Over-fetch so pagination/re-scoring can still fill a page.
            let raw_results = search_fn(&query.vector, query.top_k * 2, query.filter.as_ref());
            let (results, next_cursor) = self.apply_pagination(&raw_results, query)?;
            let scored_results = self.apply_scoring(&results, query.scoring.as_ref());
            let hits: Vec<SearchHit> = scored_results
                .into_iter()
                .take(query.top_k)
                .map(|(id, score, vec, meta)| SearchHit {
                    id,
                    score,
                    vector: if request.include_vectors { vec } else { None },
                    metadata: if request.include_metadata { meta } else { None },
                    sort_values: vec![SortValue::Score(score)],
                })
                .collect();
            let query_time = query_start.elapsed().as_millis() as u64;
            query_times.push(query_time);
            responses.push(BatchQueryItemResponse {
                query_id: query.query_id.clone(),
                results: hits,
                next_cursor,
                took_ms: query_time,
                // Approximate: the size of the over-fetched candidate set,
                // not an index-wide match count.
                total_matches: raw_results.len(),
                explanation: None,
            });
        }
        // TODO(review): facet execution is not wired up yet; requests get
        // an empty map so the response shape is stable.
        let facets = request.facets.as_ref().map(|_| HashMap::new());
        let total_took = start.elapsed().as_millis() as u64;
        let stats = BatchQueryStats {
            total_queries: request.queries.len(),
            successful_queries: responses.len(),
            failed_queries: 0,
            total_took_ms: total_took,
            avg_query_ms: if !query_times.is_empty() {
                query_times.iter().sum::<u64>() as f64 / query_times.len() as f64
            } else {
                0.0
            },
            max_query_ms: query_times.iter().copied().max().unwrap_or(0),
            min_query_ms: query_times.iter().copied().min().unwrap_or(0),
            deduplicated_count,
        };
        Ok(BatchQueryResponse {
            responses,
            facets,
            stats,
        })
    }
    /// Applies the query's cursor (if any) to `results` and builds the
    /// follow-up cursor for the next page.
    fn apply_pagination(
        &self,
        results: &[SearchResultRow],
        query: &BatchQueryItem,
    ) -> Result<(Vec<SearchResultRow>, Option<SearchCursor>), BatchSearchError> {
        if let Some(cursor) = &query.cursor {
            match cursor.cursor_type {
                CursorType::CursorBased => {
                    let offset = cursor.parse_offset()?;
                    let paginated: Vec<_> = results.iter().skip(offset).cloned().collect();
                    let next_cursor = if offset + query.top_k < results.len() {
                        Some(SearchCursor::cursor_based(
                            offset + query.top_k,
                            results.len(),
                        ))
                    } else {
                        None
                    };
                    Ok((paginated, next_cursor))
                }
                CursorType::SearchAfter => {
                    let sort_values = cursor.parse_sort_values()?;
                    // A cursor with no sort values can never position us.
                    let Some(first_sort) = sort_values.first() else {
                        return Err(BatchSearchError::InvalidCursor(
                            "Empty sort values".into(),
                        ));
                    };
                    // First hit strictly after the cursor position (scores
                    // sort descending, so "after" means a lower score).
                    // When the cursor is past every result, skip them all
                    // instead of restarting from the top.
                    let start_idx = results
                        .iter()
                        .position(|(_, score, _, _)| {
                            let sv = SortValue::Score(*score);
                            sv.compare(first_sort) == std::cmp::Ordering::Greater
                        })
                        .unwrap_or(results.len());
                    let paginated: Vec<_> = results.iter().skip(start_idx).cloned().collect();
                    let next_cursor = if start_idx + query.top_k < results.len() {
                        // checked_sub guards top_k == 0 against underflow.
                        query
                            .top_k
                            .checked_sub(1)
                            .and_then(|last_idx| paginated.get(last_idx))
                            .map(|last| SearchCursor::search_after(&[SortValue::Score(last.1)]))
                    } else {
                        None
                    };
                    Ok((paginated, next_cursor))
                }
            }
        } else {
            // No cursor: first page; advertise a cursor only when more
            // results exist beyond this page.
            let next_cursor = if results.len() > query.top_k {
                Some(SearchCursor::cursor_based(query.top_k, results.len()))
            } else {
                None
            };
            Ok((results.to_vec(), next_cursor))
        }
    }
    /// Re-scores every row with the configured functions, drops rows
    /// below `min_score`, and re-sorts by final score descending.
    /// A missing config or empty function list passes rows through.
    fn apply_scoring(
        &self,
        results: &[SearchResultRow],
        scoring: Option<&ScoringConfig>,
    ) -> Vec<SearchResultRow> {
        let config = match scoring {
            Some(c) if !c.functions.is_empty() => c,
            _ => return results.to_vec(),
        };
        let mut scored: Vec<_> = results
            .iter()
            .map(|(id, score, vec, meta)| {
                let function_scores: Vec<f32> = config
                    .functions
                    .iter()
                    .map(|f| ScoreFunctionExecutor::apply(f, *score, meta.as_ref()))
                    .collect();
                let combined =
                    ScoreFunctionExecutor::combine_scores(&function_scores, config.score_mode);
                let final_score = ScoreFunctionExecutor::combine_with_original(
                    *score,
                    combined,
                    config.boost_mode,
                );
                (id.clone(), final_score, vec.clone(), meta.clone())
            })
            .collect();
        if let Some(min) = config.min_score {
            scored.retain(|(_, s, _, _)| *s >= min);
        }
        // NaN-tolerant descending sort on the final score.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored
    }
}
/// Standard (RFC 4648) base64 encoding with '=' padding.
fn base64_encode(input: &str) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let data = input.as_bytes();
    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);
    for group in data.chunks(3) {
        // Pack up to three bytes into a 24-bit word (missing bytes are 0).
        let second = group.get(1).copied();
        let third = group.get(2).copied();
        let word = ((group[0] as u32) << 16)
            | ((second.unwrap_or(0) as u32) << 8)
            | third.unwrap_or(0) as u32;
        // Emit four 6-bit symbols, padding the positions that had no input.
        encoded.push(ALPHABET[(word >> 18 & 0x3f) as usize] as char);
        encoded.push(ALPHABET[(word >> 12 & 0x3f) as usize] as char);
        encoded.push(match second {
            Some(_) => ALPHABET[(word >> 6 & 0x3f) as usize] as char,
            None => '=',
        });
        encoded.push(match third {
            Some(_) => ALPHABET[(word & 0x3f) as usize] as char,
            None => '=',
        });
    }
    encoded
}
/// Decodes standard (RFC 4648) base64 into a UTF-8 string. Trailing '='
/// padding is optional.
///
/// # Errors
/// Returns an error for characters outside the base64 alphabet, for a
/// dangling single trailing symbol (which cannot encode a whole byte —
/// previously this produced a garbage byte instead of failing), and for
/// decoded bytes that are not valid UTF-8.
fn base64_decode(input: &str) -> Result<String, &'static str> {
    const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let mut result = Vec::new();
    let input = input.trim_end_matches('=');
    let bytes: Vec<u8> = input.bytes().collect();
    for chunk in bytes.chunks(4) {
        // One leftover symbol supplies only 6 bits — never a valid length.
        if chunk.len() == 1 {
            return Err("Invalid base64");
        }
        let mut n = 0u32;
        for (i, &b) in chunk.iter().enumerate() {
            let pos = CHARS.iter().position(|&c| c == b).ok_or("Invalid base64")?;
            n |= (pos as u32) << (18 - i * 6);
        }
        result.push((n >> 16) as u8);
        if chunk.len() > 2 {
            result.push((n >> 8) as u8);
        }
        if chunk.len() > 3 {
            result.push(n as u8);
        }
    }
    String::from_utf8(result).map_err(|_| "Invalid UTF-8")
}
/// Milliseconds since the Unix epoch; 0 if the clock reads before 1970.
fn current_timestamp_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_batch_query_config_default() {
        let config = BatchQueryConfig::default();
        assert_eq!(config.max_batch_size, 100);
        assert_eq!(config.max_concurrency, 16);
        assert_eq!(config.query_timeout_ms, 5000);
    }
    #[test]
    fn test_cursor_based_pagination() {
        let cursor = SearchCursor::cursor_based(20, 100);
        assert_eq!(cursor.cursor_type, CursorType::CursorBased);
        let offset = cursor.parse_offset().unwrap();
        assert_eq!(offset, 20);
    }
    #[test]
    fn test_search_after_pagination() {
        let sort_values = vec![SortValue::Score(0.95), SortValue::String("doc123".into())];
        let cursor = SearchCursor::search_after(&sort_values);
        assert_eq!(cursor.cursor_type, CursorType::SearchAfter);
        let parsed = cursor.parse_sort_values().unwrap();
        assert_eq!(parsed.len(), 2);
    }
    #[test]
    fn test_sort_value_comparison() {
        // Scores intentionally compare descending: the higher score
        // sorts first (Ordering::Less).
        assert_eq!(
            SortValue::Score(0.9).compare(&SortValue::Score(0.8)),
            std::cmp::Ordering::Less
        );
        assert_eq!(
            SortValue::Integer(10).compare(&SortValue::Integer(5)),
            std::cmp::Ordering::Greater
        );
    }
    #[test]
    fn test_geo_point_distance() {
        let nyc = GeoPoint::new(40.7128, -74.0060).unwrap();
        let la = GeoPoint::new(34.0522, -118.2437).unwrap();
        let distance = nyc.distance_km(&la);
        assert!(distance > 3900.0 && distance < 4000.0);
    }
    #[test]
    fn test_geo_point_validation() {
        assert!(GeoPoint::new(45.0, 90.0).is_ok());
        assert!(GeoPoint::new(91.0, 0.0).is_err());
        assert!(GeoPoint::new(0.0, 181.0).is_err());
    }
    #[test]
    fn test_geo_distance_filter() {
        let center = GeoPoint::new(40.7128, -74.0060).unwrap();
        let filter = GeoFilter::Distance {
            center,
            distance: 100.0,
            unit: DistanceUnit::Kilometers,
        };
        let nearby = GeoPoint::new(40.8, -74.1).unwrap();
        assert!(GeoFilterExecutor::matches(&filter, &nearby));
        let far = GeoPoint::new(34.0522, -118.2437).unwrap();
        assert!(!GeoFilterExecutor::matches(&filter, &far));
    }
    #[test]
    fn test_geo_bounding_box() {
        let filter = GeoFilter::BoundingBox {
            top_left: GeoPoint::new(41.0, -75.0).unwrap(),
            bottom_right: GeoPoint::new(40.0, -73.0).unwrap(),
        };
        let inside = GeoPoint::new(40.5, -74.0).unwrap();
        assert!(GeoFilterExecutor::matches(&filter, &inside));
        let outside = GeoPoint::new(42.0, -74.0).unwrap();
        assert!(!GeoFilterExecutor::matches(&filter, &outside));
    }
    #[test]
    fn test_terms_aggregation() {
        let executor = FacetExecutor::new(100);
        let values: Vec<Option<serde_json::Value>> = vec![
            Some(serde_json::json!("cat")),
            Some(serde_json::json!("dog")),
            Some(serde_json::json!("cat")),
            Some(serde_json::json!("bird")),
            Some(serde_json::json!("cat")),
        ];
        let buckets = executor.terms_aggregation(&values, 10);
        assert_eq!(buckets.len(), 3);
        assert_eq!(buckets[0].key, "cat");
        assert_eq!(buckets[0].doc_count, 3);
    }
    #[test]
    fn test_range_aggregation() {
        let executor = FacetExecutor::new(100);
        let values: Vec<Option<f64>> =
            vec![Some(5.0), Some(15.0), Some(25.0), Some(35.0), Some(45.0)];
        let ranges = vec![
            RangeBucket {
                key: "low".into(),
                from: None,
                to: Some(20.0),
            },
            RangeBucket {
                key: "medium".into(),
                from: Some(20.0),
                to: Some(40.0),
            },
            RangeBucket {
                key: "high".into(),
                from: Some(40.0),
                to: None,
            },
        ];
        let buckets = executor.range_aggregation(&values, &ranges);
        assert_eq!(buckets.len(), 3);
        assert_eq!(buckets[0].doc_count, 2);
        assert_eq!(buckets[1].doc_count, 2);
        assert_eq!(buckets[2].doc_count, 1);
    }
    #[test]
    fn test_numeric_aggregations() {
        let executor = FacetExecutor::new(100);
        let values: Vec<Option<f64>> = vec![Some(10.0), Some(20.0), Some(30.0)];
        assert_eq!(
            executor.numeric_aggregation(&values, &AggregationType::Min),
            10.0
        );
        assert_eq!(
            executor.numeric_aggregation(&values, &AggregationType::Max),
            30.0
        );
        assert_eq!(
            executor.numeric_aggregation(&values, &AggregationType::Avg),
            20.0
        );
        assert_eq!(
            executor.numeric_aggregation(&values, &AggregationType::Sum),
            60.0
        );
    }
    #[test]
    fn test_score_function_weight() {
        let func = ScoreFunction::Weight { weight: 2.0 };
        let score = ScoreFunctionExecutor::apply(&func, 0.5, None);
        assert_eq!(score, 2.0);
    }
    #[test]
    fn test_score_function_field_value() {
        let func = ScoreFunction::FieldValue {
            field: "popularity".into(),
            factor: 0.1,
            modifier: FieldValueModifier::Log1p,
            missing: 1.0,
        };
        let metadata = serde_json::json!({"popularity": 100.0});
        let score = ScoreFunctionExecutor::apply(&func, 0.5, Some(&metadata));
        // log10(101) * 0.1 ~= 0.2004
        assert!(score > 0.19 && score < 0.21);
    }
    #[test]
    fn test_combine_scores() {
        let scores = vec![2.0f32, 3.0, 4.0];
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Multiply),
            24.0
        );
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Sum),
            9.0
        );
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Average),
            3.0
        );
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Max),
            4.0
        );
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Min),
            2.0
        );
    }
    #[test]
    fn test_filter_term() {
        let filter = FilterExpression::Term {
            field: "category".into(),
            value: serde_json::json!("tech"),
        };
        let metadata = serde_json::json!({"category": "tech"});
        assert!(FilterExecutor::matches(&filter, Some(&metadata)));
        let other = serde_json::json!({"category": "science"});
        assert!(!FilterExecutor::matches(&filter, Some(&other)));
    }
    #[test]
    fn test_filter_range() {
        let filter = FilterExpression::Range {
            field: "price".into(),
            gte: Some(10.0),
            gt: None,
            lte: Some(100.0),
            lt: None,
        };
        let match1 = serde_json::json!({"price": 50});
        assert!(FilterExecutor::matches(&filter, Some(&match1)));
        let nomatch = serde_json::json!({"price": 5});
        assert!(!FilterExecutor::matches(&filter, Some(&nomatch)));
    }
    #[test]
    fn test_filter_boolean_and() {
        let filter = FilterExpression::And(vec![
            FilterExpression::Term {
                field: "category".into(),
                value: serde_json::json!("tech"),
            },
            FilterExpression::Range {
                field: "price".into(),
                gte: Some(10.0),
                gt: None,
                lte: None,
                lt: None,
            },
        ]);
        let match1 = serde_json::json!({"category": "tech", "price": 50});
        assert!(FilterExecutor::matches(&filter, Some(&match1)));
        let nomatch = serde_json::json!({"category": "tech", "price": 5});
        assert!(!FilterExecutor::matches(&filter, Some(&nomatch)));
    }
    #[test]
    fn test_query_explanation() {
        let id = "doc123".to_string();
        let scoring = ScoringConfig {
            functions: vec![ScoreFunction::Weight { weight: 1.5 }],
            ..Default::default()
        };
        let explanation = QueryExplainer::explain(&id, 0.8, Some(&scoring), None);
        assert!(explanation.score > 0.0);
        assert!(!explanation.details.is_empty());
        assert!(explanation.description.contains("doc123"));
    }
    #[test]
    fn test_batch_query_executor() {
        let config = BatchQueryConfig::default();
        let executor = BatchQueryExecutor::new(config);
        let request = BatchQueryRequest {
            namespace: "test".into(),
            queries: vec![
                BatchQueryItem {
                    query_id: "q1".into(),
                    vector: vec![1.0, 0.0, 0.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
                BatchQueryItem {
                    query_id: "q2".into(),
                    vector: vec![0.0, 1.0, 0.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
            ],
            include_vectors: false,
            include_metadata: true,
            facets: None,
        };
        let search_fn = |_vector: &[f32], _top_k: usize, _filter: Option<&FilterExpression>| {
            vec![
                (
                    "doc1".into(),
                    0.9,
                    None,
                    Some(serde_json::json!({"cat": "a"})),
                ),
                (
                    "doc2".into(),
                    0.8,
                    None,
                    Some(serde_json::json!({"cat": "b"})),
                ),
            ]
        };
        let response = executor.execute(&request, search_fn).unwrap();
        assert_eq!(response.responses.len(), 2);
        assert_eq!(response.stats.total_queries, 2);
        assert_eq!(response.stats.successful_queries, 2);
    }
    #[test]
    fn test_batch_too_large_error() {
        let config = BatchQueryConfig {
            max_batch_size: 2,
            ..Default::default()
        };
        let executor = BatchQueryExecutor::new(config);
        let request = BatchQueryRequest {
            namespace: "test".into(),
            queries: vec![
                BatchQueryItem {
                    query_id: "q1".into(),
                    vector: vec![1.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
                BatchQueryItem {
                    query_id: "q2".into(),
                    vector: vec![1.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
                BatchQueryItem {
                    query_id: "q3".into(),
                    vector: vec![1.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
            ],
            include_vectors: false,
            include_metadata: false,
            facets: None,
        };
        let result = executor.execute(&request, |_, _, _| vec![]);
        assert!(matches!(result, Err(BatchSearchError::BatchTooLarge(3, 2))));
    }
    #[test]
    fn test_base64_roundtrip() {
        let original = "hello:world:123";
        let encoded = base64_encode(original);
        let decoded = base64_decode(&encoded).unwrap();
        assert_eq!(decoded, original);
    }
    #[test]
    fn test_distance_unit_conversion() {
        assert_eq!(DistanceUnit::Meters.to_meters(100.0), 100.0);
        assert_eq!(DistanceUnit::Kilometers.to_meters(1.0), 1000.0);
        assert!((DistanceUnit::Miles.to_meters(1.0) - 1609.34).abs() < 0.01);
    }
    #[test]
    fn test_query_deduplication() {
        let config = BatchQueryConfig {
            deduplicate_queries: true,
            ..Default::default()
        };
        let executor = BatchQueryExecutor::new(config);
        let request = BatchQueryRequest {
            namespace: "test".into(),
            queries: vec![
                BatchQueryItem {
                    query_id: "q1".into(),
                    vector: vec![1.0, 0.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
                BatchQueryItem {
                    query_id: "q2".into(),
                    vector: vec![1.0, 0.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
            ],
            include_vectors: false,
            include_metadata: false,
            facets: None,
        };
        let response = executor
            .execute(&request, |_, _, _| vec![("doc1".into(), 0.9, None, None)])
            .unwrap();
        assert_eq!(response.stats.deduplicated_count, 1);
    }
}