use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use common::{DistanceMetric, VectorId};
/// Tuning knobs for post-retrieval processing (MMR diversification and grouping).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedSearchConfig {
    /// When `true`, `AdvancedSearchExecutor::process_results` reranks with MMR.
    pub enable_mmr: bool,
    /// Relevance/diversity trade-off handed to `MmrReranker::new`, which clamps it to `[0, 1]`.
    pub mmr_lambda: f32,
    /// Intended size of the MMR candidate pool.
    /// NOTE(review): no code in this file reads this field — confirm where it is applied.
    pub mmr_candidates: usize,
    /// Whether result grouping is requested.
    /// NOTE(review): not read by the executor in this file.
    pub enable_grouping: bool,
    /// Metadata field to group by, if any.
    pub group_by_field: Option<String>,
    /// Maximum hits kept per group (presumably passed to `ResultGrouper::new` — confirm).
    pub max_per_group: usize,
}

impl Default for AdvancedSearchConfig {
    /// MMR and grouping disabled; lambda 0.5, a 100-candidate pool, at most 3 hits per group.
    fn default() -> Self {
        let (mmr_lambda, mmr_candidates, max_per_group) = (0.5, 100, 3);
        Self {
            enable_mmr: false,
            enable_grouping: false,
            group_by_field: None,
            mmr_lambda,
            mmr_candidates,
            max_per_group,
        }
    }
}
/// A query made of weighted positive and negative example vectors.
///
/// [`MultiVectorQuery::compute_query_vector`] folds the examples into one unit-length vector.
#[derive(Debug, Clone)]
pub struct MultiVectorQuery {
    /// Vectors the results should resemble.
    pub positive_vectors: Vec<Vec<f32>>,
    /// Weight of each positive vector, paired by position (extra entries on either side are ignored).
    pub positive_weights: Vec<f32>,
    /// Vectors whose direction is subtracted from the query.
    pub negative_vectors: Vec<Vec<f32>>,
    /// Weight of each negative vector, paired by position.
    pub negative_weights: Vec<f32>,
    /// Maximum number of results requested.
    pub top_k: usize,
    /// Optional minimum score. `AdvancedSearchExecutor::process_results` drops candidates whose
    /// score is below it (despite the name, it is compared as a score with `>=`).
    pub distance_threshold: Option<f32>,
}

impl MultiVectorQuery {
    /// Shared constructor: only positive examples, no negatives, no threshold.
    fn positives_only(
        positive_vectors: Vec<Vec<f32>>,
        positive_weights: Vec<f32>,
        top_k: usize,
    ) -> Self {
        Self {
            positive_vectors,
            positive_weights,
            negative_vectors: Vec::new(),
            negative_weights: Vec::new(),
            top_k,
            distance_threshold: None,
        }
    }

    /// Query by one vector with weight `1.0`.
    pub fn single(vector: Vec<f32>, top_k: usize) -> Self {
        Self::positives_only(vec![vector], vec![1.0], top_k)
    }

    /// Query by several vectors, each weighted `1 / vectors.len()`.
    pub fn multi(vectors: Vec<Vec<f32>>, top_k: usize) -> Self {
        let count = vectors.len();
        let share = 1.0 / count as f32;
        Self::positives_only(vectors, vec![share; count], top_k)
    }

    /// Adds a negative example with the given weight.
    pub fn with_negative(mut self, vector: Vec<f32>, weight: f32) -> Self {
        self.negative_weights.push(weight);
        self.negative_vectors.push(vector);
        self
    }

    /// Sets the minimum score threshold.
    pub fn with_threshold(self, threshold: f32) -> Self {
        Self {
            distance_threshold: Some(threshold),
            ..self
        }
    }

    /// Replaces the positive weights. The length is not checked against `positive_vectors`.
    pub fn with_weights(self, weights: Vec<f32>) -> Self {
        Self {
            positive_weights: weights,
            ..self
        }
    }

    /// Combines the examples into a single L2-normalised vector of length `dimensions`.
    ///
    /// Each positive vector adds `weight * v` and each negative vector subtracts `weight * v`.
    /// Components past `dimensions` are ignored and shorter vectors contribute zeros. An
    /// all-zero sum is returned unnormalised.
    pub fn compute_query_vector(&self, dimensions: usize) -> Vec<f32> {
        let mut combined = vec![0.0_f32; dimensions];
        let positives = self
            .positive_vectors
            .iter()
            .zip(self.positive_weights.iter().copied());
        // Negating the weight is exact in IEEE-754, so `+= x * -w` equals `-= x * w`.
        let negatives = self
            .negative_vectors
            .iter()
            .zip(self.negative_weights.iter().map(|w| -w));
        for (example, weight) in positives.chain(negatives) {
            for (slot, &component) in combined.iter_mut().zip(example) {
                *slot += component * weight;
            }
        }
        let magnitude = combined.iter().map(|c| c * c).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            combined.iter_mut().for_each(|c| *c /= magnitude);
        }
        combined
    }
}
/// One ranked hit produced by `MmrReranker::rerank` or `AdvancedSearchExecutor::process_results`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedSearchResult {
/// Identifier of the matched vector.
pub id: VectorId,
/// The candidate's relevance score as supplied by the caller (not the MMR score).
pub score: f32,
/// Position before reranking: the candidate's index in the slice given to `MmrReranker::rerank`.
/// In the non-MMR path of `process_results` it is the post-sort position, so it equals `final_rank`.
pub original_rank: usize,
/// 0-based position of this hit in the returned list.
pub final_rank: usize,
/// MMR objective value at selection time; `Some` only for results produced by `MmrReranker`.
pub mmr_score: Option<f32>,
/// Group label; every constructor in this file sets it to `None`.
pub group_key: Option<String>,
}
/// Maximal Marginal Relevance (MMR) reranker.
///
/// Greedily selects results that balance relevance against redundancy:
/// `mmr = lambda * relevance - (1 - lambda) * max_sim`. Here `max_sim` is the highest cosine
/// similarity between a candidate and any already-selected candidate.
pub struct MmrReranker {
    /// Relevance weight, clamped to `[0, 1]` at construction.
    lambda: f32,
}

impl MmrReranker {
    /// Creates a reranker. `lambda` is clamped to `[0, 1]`: 1 = pure relevance, 0 = pure diversity.
    pub fn new(lambda: f32) -> Self {
        Self {
            lambda: lambda.clamp(0.0, 1.0),
        }
    }

    /// Reranks `(id, relevance, embedding)` candidates and returns at most `top_k` results.
    ///
    /// The first pick is the highest-relevance candidate. Every later pick is the unselected
    /// candidate with the greatest MMR score, with ties going to the lowest index so output is
    /// deterministic. `original_rank` is the candidate's index in `candidates` and `final_rank`
    /// is its 0-based output position.
    ///
    /// Returns an empty vector when `candidates` is empty or `top_k == 0`.
    pub fn rerank(
        &self,
        candidates: &[(VectorId, f32, Vec<f32>)],
        top_k: usize,
    ) -> Vec<AdvancedSearchResult> {
        // `top_k == 0` must yield nothing rather than the single best candidate.
        if candidates.is_empty() || top_k == 0 {
            return Vec::new();
        }
        // Bound the reservation by the candidate count so huge `top_k` values (e.g. usize::MAX
        // meaning "all") do not trigger a capacity-overflow panic.
        let mut results = Vec::with_capacity(top_k.min(candidates.len()));
        let mut is_selected = vec![false; candidates.len()];
        // Running max similarity of each candidate to the selected set. It is updated only with
        // the newest pick, in selection order, which matches a full re-fold from NEG_INFINITY.
        let mut max_sim = vec![f32::NEG_INFINITY; candidates.len()];

        // With nothing selected yet there is no redundancy term, so the first pick is by relevance.
        let first_idx = candidates
            .iter()
            .enumerate()
            .max_by(|a, b| {
                a.1 .1
                    .partial_cmp(&b.1 .1)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(i, _)| i)
            .unwrap_or(0);
        let mut next = Some((first_idx, candidates[first_idx].1));

        while let Some((idx, mmr_score)) = next {
            is_selected[idx] = true;
            results.push(AdvancedSearchResult {
                id: candidates[idx].0.clone(),
                score: candidates[idx].1,
                original_rank: idx,
                final_rank: results.len(),
                mmr_score: Some(mmr_score),
                group_key: None,
            });
            if results.len() >= top_k {
                break;
            }
            next = self.pick_next(candidates, &candidates[idx].2, &is_selected, &mut max_sim);
        }
        results
    }

    /// Updates each unselected candidate's running `max_sim` with `newest`. Returns the
    /// unselected candidate with the highest MMR score, or `None` if none scores above
    /// `NEG_INFINITY` (for example all NaN, or nothing left).
    fn pick_next(
        &self,
        candidates: &[(VectorId, f32, Vec<f32>)],
        newest: &[f32],
        is_selected: &[bool],
        max_sim: &mut [f32],
    ) -> Option<(usize, f32)> {
        let mut best = None;
        let mut best_mmr = f32::NEG_INFINITY;
        for (idx, (_, relevance, vector)) in candidates.iter().enumerate() {
            if is_selected[idx] {
                continue;
            }
            max_sim[idx] = max_sim[idx].max(self.cosine_similarity(vector, newest));
            let mmr = self.lambda * relevance - (1.0 - self.lambda) * max_sim[idx];
            // Strict `>` keeps the lowest index on ties.
            if mmr > best_mmr {
                best_mmr = mmr;
                best = Some((idx, mmr));
            }
        }
        best
    }

    /// Cosine similarity of `a` and `b`. Returns 0.0 when either vector has zero (or NaN) norm.
    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm_a > 0.0 && norm_b > 0.0 {
            dot / (norm_a * norm_b)
        } else {
            0.0
        }
    }
}
/// Post-filter that keeps only results whose score clears a metric-dependent threshold.
pub struct RangeQuery {
    metric: DistanceMetric,
    threshold: f32,
}

impl RangeQuery {
    /// Creates a filter for `metric` with cut-off `threshold`.
    pub fn new(metric: DistanceMetric, threshold: f32) -> Self {
        Self { metric, threshold }
    }

    /// Returns the entries of `results` that pass the threshold, in their original order.
    pub fn filter(&self, results: Vec<(VectorId, f32)>) -> Vec<(VectorId, f32)> {
        let mut kept = results;
        kept.retain(|(_, score)| self.passes_threshold(*score));
        kept
    }

    /// Cosine and dot-product scores must be `>= threshold`. Euclidean scores must be
    /// `>= -threshold`.
    /// NOTE(review): the Euclidean rule assumes scores are negated distances, so this means
    /// "distance <= threshold" — confirm against the scorer.
    fn passes_threshold(&self, score: f32) -> bool {
        let minimum = match self.metric {
            DistanceMetric::Euclidean => -self.threshold,
            DistanceMetric::Cosine | DistanceMetric::DotProduct => self.threshold,
        };
        score >= minimum
    }
}
/// Buckets search results by a metadata field, keeping at most `max_per_group` per bucket.
pub struct ResultGrouper {
    group_field: String,
    max_per_group: usize,
}

impl ResultGrouper {
    /// Creates a grouper keyed on the metadata field `group_field`.
    pub fn new(group_field: String, max_per_group: usize) -> Self {
        Self {
            group_field,
            max_per_group,
        }
    }

    /// Groups `(id, score, metadata)` triples by the value of `metadata[group_field]`.
    ///
    /// String values are used verbatim and numbers via `to_string()`. Missing metadata, a
    /// missing field, or any other JSON type goes to the `"_ungrouped"` bucket. Each bucket keeps
    /// the first `max_per_group` entries in input order; the input is not re-sorted, so pass
    /// results best-first. Every key that is seen appears in the map, even if
    /// `max_per_group == 0` leaves its bucket empty.
    pub fn group(
        &self,
        results: Vec<(VectorId, f32, Option<serde_json::Value>)>,
    ) -> HashMap<String, Vec<(VectorId, f32)>> {
        let mut buckets: HashMap<String, Vec<(VectorId, f32)>> = HashMap::new();
        for (id, score, metadata) in results {
            let key = self.group_key_of(metadata.as_ref());
            let bucket = buckets.entry(key).or_default();
            if bucket.len() < self.max_per_group {
                bucket.push((id, score));
            }
        }
        buckets
    }

    /// Derives the bucket label from borrowed metadata, cloning only the final string.
    fn group_key_of(&self, metadata: Option<&serde_json::Value>) -> String {
        match metadata.and_then(|m| m.get(&self.group_field)) {
            Some(serde_json::Value::String(s)) => s.clone(),
            Some(serde_json::Value::Number(n)) => n.to_string(),
            _ => "_ungrouped".to_string(),
        }
    }
}
/// Applies threshold filtering and optional MMR reranking to raw search candidates.
pub struct AdvancedSearchExecutor {
    config: AdvancedSearchConfig,
}

impl AdvancedSearchExecutor {
    /// Wraps `config` for use by [`Self::process_results`].
    pub fn new(config: AdvancedSearchConfig) -> Self {
        Self { config }
    }

    /// Turns `(id, score, embedding)` candidates into ranked results for `query`.
    ///
    /// 1. If `query.distance_threshold` is set, candidates scoring below it are dropped.
    /// 2. With `config.enable_mmr`, the survivors are passed to `MmrReranker` in input order,
    ///    and its output is returned unchanged.
    /// 3. Otherwise survivors are sorted by descending score and cut to `query.top_k`.
    ///    `original_rank` and `final_rank` are both the post-sort position.
    ///
    /// NOTE(review): `config.mmr_candidates` and the grouping fields are not consulted here.
    pub fn process_results(
        &self,
        candidates: Vec<(VectorId, f32, Vec<f32>)>,
        query: &MultiVectorQuery,
    ) -> Vec<AdvancedSearchResult> {
        let mut pool = candidates;
        if let Some(min_score) = query.distance_threshold {
            pool.retain(|entry| entry.1 >= min_score);
        }

        if self.config.enable_mmr {
            MmrReranker::new(self.config.mmr_lambda).rerank(&pool, query.top_k)
        } else {
            // Descending by score; incomparable (NaN) pairs are treated as equal.
            pool.sort_by(|lhs, rhs| {
                rhs.1
                    .partial_cmp(&lhs.1)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            pool.into_iter()
                .take(query.top_k)
                .enumerate()
                .map(|(position, (id, score, _embedding))| AdvancedSearchResult {
                    id,
                    score,
                    original_rank: position,
                    final_rank: position,
                    mmr_score: None,
                    group_key: None,
                })
                .collect()
        }
    }

    /// For every `(negative, weight)` pair, subtracts `cosine(embedding, negative) * weight`
    /// from each result's score, one pair at a time in order. Unpaired trailing vectors or
    /// weights are ignored.
    pub fn apply_negative_penalty(
        &self,
        results: &mut [(VectorId, f32, Vec<f32>)],
        negative_vectors: &[Vec<f32>],
        negative_weights: &[f32],
    ) {
        results.iter_mut().for_each(|(_, score, embedding)| {
            negative_vectors
                .iter()
                .zip(negative_weights)
                .for_each(|(negative, &weight)| {
                    *score -= self.cosine_similarity(embedding, negative) * weight;
                });
        });
    }

    /// Cosine similarity of `a` and `b`. Returns 0.0 when either vector has zero (or NaN) norm.
    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        let l2 = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let (norm_a, norm_b) = (l2(a), l2(b));
        if norm_a > 0.0 && norm_b > 0.0 {
            let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
            dot / (norm_a * norm_b)
        } else {
            0.0
        }
    }
}
/// Per-search counters and timing.
/// NOTE(review): nothing in this file populates these fields — confirm the producer.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SearchStats {
/// Number of candidates fed into post-processing (presumably — confirm).
pub candidates_considered: usize,
/// Candidates remaining after threshold filtering (presumably — confirm).
pub after_threshold: usize,
/// Results remaining after MMR reranking (presumably — confirm).
pub after_mmr: usize,
/// Number of result groups produced (presumably — confirm).
pub num_groups: usize,
/// Elapsed time in milliseconds, judging by the name — confirm.
pub latency_ms: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `(id, score, embedding)` candidate tuple.
    fn candidate(id: &str, score: f32, embedding: &[f32]) -> (VectorId, f32, Vec<f32>) {
        (id.to_string(), score, embedding.to_vec())
    }

    /// Builds an `(id, score, metadata)` tuple whose metadata carries a `category` field.
    fn categorized(
        id: &str,
        score: f32,
        category: &str,
    ) -> (VectorId, f32, Option<serde_json::Value>) {
        (
            id.to_string(),
            score,
            Some(serde_json::json!({ "category": category })),
        )
    }

    #[test]
    fn test_multi_vector_query() {
        let basis = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let query = MultiVectorQuery::multi(basis, 10);
        assert_eq!(query.positive_vectors.len(), 2);
        assert_eq!(query.positive_weights.len(), 2);
        assert_eq!(query.positive_weights[0], 0.5);
    }

    #[test]
    fn test_query_vector_computation() {
        let computed = MultiVectorQuery::single(vec![1.0, 0.0, 0.0], 10).compute_query_vector(3);
        assert_eq!(computed.len(), 3);
        assert!((computed[0] - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_negative_vector() {
        let query = MultiVectorQuery::single(vec![1.0, 0.0, 0.0], 10)
            .with_negative(vec![0.0, 1.0, 0.0], 0.5);
        assert_eq!(query.negative_vectors.len(), 1);
        assert_eq!(query.negative_weights[0], 0.5);
    }

    #[test]
    fn test_mmr_reranker() {
        let pool = [
            candidate("a", 0.9, &[1.0, 0.0, 0.0]),
            candidate("b", 0.85, &[0.95, 0.1, 0.0]),
            candidate("c", 0.8, &[0.0, 1.0, 0.0]),
            candidate("d", 0.75, &[0.0, 0.0, 1.0]),
        ];
        let reranked = MmrReranker::new(0.5).rerank(&pool, 3);
        assert_eq!(reranked.len(), 3);
        assert_eq!(reranked[0].id, "a");
    }

    #[test]
    fn test_range_query() {
        let scored: Vec<(VectorId, f32)> = [("a", 0.95), ("b", 0.75), ("c", 0.85)]
            .iter()
            .map(|&(id, s)| (id.to_string(), s))
            .collect();
        let filtered = RangeQuery::new(DistanceMetric::Cosine, 0.8).filter(scored);
        assert_eq!(filtered.len(), 2);
        assert!(filtered.iter().all(|(_, s)| *s >= 0.8));
    }

    #[test]
    fn test_result_grouper() {
        let tagged = vec![
            categorized("a", 0.9, "tech"),
            categorized("b", 0.85, "tech"),
            categorized("c", 0.8, "tech"),
            categorized("d", 0.75, "science"),
        ];
        let groups = ResultGrouper::new("category".to_string(), 2).group(tagged);
        assert_eq!(groups.len(), 2);
        assert_eq!(groups["tech"].len(), 2);
        assert_eq!(groups["science"].len(), 1);
    }

    #[test]
    fn test_advanced_search_executor() {
        let executor = AdvancedSearchExecutor::new(AdvancedSearchConfig {
            enable_mmr: false,
            ..Default::default()
        });
        let pool = vec![
            candidate("a", 0.9, &[1.0, 0.0]),
            candidate("b", 0.8, &[0.0, 1.0]),
            candidate("c", 0.7, &[0.5, 0.5]),
        ];
        let results = executor.process_results(pool, &MultiVectorQuery::single(vec![1.0, 0.0], 2));
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert_eq!(results[1].id, "b");
    }

    #[test]
    fn test_threshold_filtering() {
        let executor = AdvancedSearchExecutor::new(AdvancedSearchConfig::default());
        let pool = vec![
            candidate("a", 0.9, &[1.0, 0.0]),
            candidate("b", 0.5, &[0.0, 1.0]),
            candidate("c", 0.85, &[0.5, 0.5]),
        ];
        let query = MultiVectorQuery::single(vec![1.0, 0.0], 10).with_threshold(0.7);
        let results = executor.process_results(pool, &query);
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.score >= 0.7));
    }
}