use common::{DistanceMetric, PaginationCursor, QueryResponse, SearchResult, Vector};
use crate::distance::calculate_distance;
/// Scores every vector against `query` and returns one page of the best matches.
///
/// Each vector is scored with `calculate_distance(query, &vector.values, metric)`.
/// Results are ranked with the largest score first; entries with equal scores
/// are ordered by ascending `id`. NaN scores are ranked after every other
/// score so the ordering stays total and the sort cannot misbehave.
///
/// # Arguments
///
/// * `query` - the query embedding.
/// * `vectors` - the candidates to score; an empty slice yields an empty response.
/// * `top_k` - page size. `0` is replaced by `1` and values above `10_000` are
///   clamped to `10_000`; both cases log a warning.
/// * `metric` - forwarded unchanged to `calculate_distance`.
/// * `include_metadata` - when `false`, `metadata` is `None` on every result.
/// * `include_vectors` - when `false`, `vector` is `None` on every result.
/// * `cursor` - when present, only entries ranked strictly after
///   (`cursor.last_score`, `cursor.last_id`) are returned.
///
/// # Returns
///
/// A `QueryResponse` holding at most `top_k` results. `has_more` is
/// `Some(true)` when further entries exist after this page, in which case
/// `next_cursor` encodes the score and id of the last returned result.
/// `search_time_ms` is always `0` here (presumably filled in by the caller —
/// confirm).
pub fn brute_force_search(
    query: &[f32],
    vectors: &[Vector],
    top_k: usize,
    metric: DistanceMetric,
    include_metadata: bool,
    include_vectors: bool,
    cursor: Option<&PaginationCursor>,
) -> QueryResponse {
    let top_k = if top_k == 0 {
        tracing::warn!("top_k of 0 is invalid, using 1");
        1
    } else if top_k > 10_000 {
        tracing::warn!("top_k {} exceeds maximum, clamping to 10000", top_k);
        10_000
    } else {
        top_k
    };

    if vectors.is_empty() {
        return QueryResponse {
            results: vec![],
            next_cursor: None,
            has_more: Some(false),
            search_time_ms: 0,
        };
    }

    // Score everything, dropping entries at or before the cursor *before*
    // sorting so that later pages sort fewer candidates. The cursor check uses
    // the same ordering as the sort, so both always agree on what "after" means.
    let mut ranked: Vec<(f32, &Vector)> = vectors
        .iter()
        .map(|v| (calculate_distance(query, &v.values, metric), v))
        .filter(|(score, v)| {
            cursor.map_or(true, |c| {
                compare_by_rank(c.last_score, &c.last_id, *score, &v.id)
                    == std::cmp::Ordering::Less
            })
        })
        .collect();
    ranked.sort_by(|a, b| compare_by_rank(a.0, &a.1.id, b.0, &b.1.id));

    // Anything beyond `top_k` only serves to tell us another page exists.
    let has_more = ranked.len() > top_k;
    ranked.truncate(top_k);

    let results: Vec<SearchResult> = ranked
        .iter()
        .map(|(score, vector)| SearchResult {
            id: vector.id.clone(),
            score: *score,
            metadata: if include_metadata {
                vector.metadata.clone()
            } else {
                None
            },
            vector: if include_vectors {
                Some(vector.values.clone())
            } else {
                None
            },
        })
        .collect();

    let next_cursor = if has_more {
        results.last().map(|last_result| {
            PaginationCursor::new(last_result.score, last_result.id.clone()).encode()
        })
    } else {
        None
    };

    QueryResponse {
        results,
        next_cursor,
        has_more: Some(has_more),
        search_time_ms: 0,
    }
}

/// Total ordering used for ranking: larger scores first, NaN scores last,
/// ties broken by ascending id.
///
/// Returns `Ordering::Less` when entry `a` ranks before entry `b`.
fn compare_by_rank(a_score: f32, a_id: &str, b_score: f32, b_id: &str) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    let by_score = match (a_score.is_nan(), b_score.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN, so `partial_cmp` always yields `Some`; the
        // operands are swapped to rank larger scores first.
        (false, false) => b_score.partial_cmp(&a_score).unwrap_or(Ordering::Equal),
    };
    by_score.then_with(|| a_id.cmp(b_id))
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a vector with no metadata and no expiry information.
    fn make_vector(id: &str, values: Vec<f32>) -> Vector {
        Vector {
            id: id.to_string(),
            values,
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Same as `make_vector`, but with the given metadata attached.
    fn make_vector_with_metadata(
        id: &str,
        values: Vec<f32>,
        metadata: serde_json::Value,
    ) -> Vector {
        Vector {
            metadata: Some(metadata),
            ..make_vector(id, values)
        }
    }

    /// Five vectors `v1`..`v5` that drift progressively away from `[1.0, 0.0]`.
    fn five_ranked_vectors() -> Vec<Vector> {
        vec![
            make_vector("v1", vec![1.0, 0.0]),
            make_vector("v2", vec![0.9, 0.1]),
            make_vector("v3", vec![0.8, 0.2]),
            make_vector("v4", vec![0.7, 0.3]),
            make_vector("v5", vec![0.6, 0.4]),
        ]
    }

    /// Runs a cosine search that returns neither metadata nor vector values.
    fn cosine_page(
        query: &[f32],
        vectors: &[Vector],
        top_k: usize,
        cursor: Option<&PaginationCursor>,
    ) -> QueryResponse {
        brute_force_search(
            query,
            vectors,
            top_k,
            DistanceMetric::Cosine,
            false,
            false,
            cursor,
        )
    }

    /// Decodes the cursor attached to `response`, panicking if it is missing.
    fn next_cursor_of(response: &QueryResponse) -> PaginationCursor {
        PaginationCursor::decode(response.next_cursor.as_ref().unwrap()).unwrap()
    }

    /// Collects the result ids of `response` in ranked order.
    fn result_ids(response: &QueryResponse) -> Vec<&str> {
        response.results.iter().map(|r| r.id.as_str()).collect()
    }

    #[test]
    fn test_brute_force_search_empty() {
        let no_vectors: Vec<Vector> = Vec::new();
        let response = brute_force_search(
            &[1.0, 0.0, 0.0],
            &no_vectors,
            5,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );

        assert!(response.results.is_empty());
        assert_eq!(response.has_more, Some(false));
        assert!(response.next_cursor.is_none());
    }

    #[test]
    fn test_brute_force_search_single_vector() {
        let only = [make_vector("v1", vec![1.0, 0.0, 0.0])];
        let response = brute_force_search(
            &[1.0, 0.0, 0.0],
            &only,
            5,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );

        assert_eq!(response.results.len(), 1);
        let hit = &response.results[0];
        assert_eq!(hit.id, "v1");
        assert!((hit.score - 1.0).abs() < 1e-6);
        assert_eq!(response.has_more, Some(false));
    }

    #[test]
    fn test_brute_force_search_ordering() {
        let candidates = [
            make_vector("v1", vec![1.0, 0.0, 0.0]),
            make_vector("v2", vec![0.0, 1.0, 0.0]),
            make_vector("v3", vec![0.707, 0.707, 0.0]),
        ];
        let response = brute_force_search(
            &[1.0, 0.0, 0.0],
            &candidates,
            3,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );

        // Comparing the whole id list checks both the length and the ranking.
        assert_eq!(result_ids(&response), ["v1", "v3", "v2"]);
    }

    #[test]
    fn test_brute_force_search_top_k() {
        let candidates = five_ranked_vectors();
        let response = brute_force_search(
            &[1.0, 0.0],
            &candidates,
            3,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );

        assert_eq!(response.results.len(), 3);
        assert_eq!(response.results[0].id, "v1");
        assert_eq!(response.has_more, Some(true));
        assert!(response.next_cursor.is_some());
    }

    #[test]
    fn test_brute_force_search_include_metadata() {
        let candidates = [make_vector_with_metadata(
            "v1",
            vec![1.0, 0.0],
            json!({"key": "value"}),
        )];

        // Metadata must be present exactly when it was requested.
        for &include_metadata in &[true, false] {
            let response = brute_force_search(
                &[1.0, 0.0],
                &candidates,
                1,
                DistanceMetric::Cosine,
                include_metadata,
                false,
                None,
            );
            assert_eq!(response.results[0].metadata.is_some(), include_metadata);
        }
    }

    #[test]
    fn test_brute_force_search_include_vectors() {
        let candidates = [make_vector("v1", vec![1.0, 0.0])];
        let search = |include_vectors: bool| {
            brute_force_search(
                &[1.0, 0.0],
                &candidates,
                1,
                DistanceMetric::Cosine,
                false,
                include_vectors,
                None,
            )
        };

        let with_values = search(true);
        assert_eq!(
            with_values.results[0].vector.as_deref(),
            Some(&[1.0, 0.0][..])
        );

        let without_values = search(false);
        assert!(without_values.results[0].vector.is_none());
    }

    #[test]
    fn test_brute_force_search_euclidean() {
        let candidates = [
            make_vector("v1", vec![1.0, 0.0]),
            make_vector("v2", vec![3.0, 4.0]),
            make_vector("v3", vec![0.5, 0.0]),
        ];
        let response = brute_force_search(
            &[0.0, 0.0],
            &candidates,
            3,
            DistanceMetric::Euclidean,
            false,
            false,
            None,
        );

        for (position, expected_id) in ["v3", "v1", "v2"].iter().enumerate() {
            assert_eq!(response.results[position].id, *expected_id);
        }
    }

    #[test]
    fn test_pagination_basic() {
        let query = [1.0, 0.0];
        let candidates = five_ranked_vectors();

        let first = cosine_page(&query, &candidates, 2, None);
        assert_eq!(result_ids(&first), ["v1", "v2"]);
        assert_eq!(first.has_more, Some(true));
        assert!(first.next_cursor.is_some());

        let after_first = next_cursor_of(&first);
        let second = cosine_page(&query, &candidates, 2, Some(&after_first));
        assert_eq!(result_ids(&second), ["v3", "v4"]);
        assert_eq!(second.has_more, Some(true));

        let after_second = next_cursor_of(&second);
        let third = cosine_page(&query, &candidates, 2, Some(&after_second));
        assert_eq!(result_ids(&third), ["v5"]);
        assert_eq!(third.has_more, Some(false));
        assert!(third.next_cursor.is_none());
    }

    #[test]
    fn test_pagination_cursor_encode_decode() {
        let encoded = PaginationCursor::new(0.95, "test_id".to_string()).encode();
        let round_tripped = PaginationCursor::decode(&encoded).unwrap();

        assert!((round_tripped.last_score - 0.95).abs() < 1e-6);
        assert_eq!(round_tripped.last_id, "test_id");
    }

    #[test]
    fn test_pagination_with_tie_scores() {
        let query = [1.0, 0.0];
        // Identical values give identical scores, so only the id decides order.
        let candidates: Vec<Vector> = ["a", "b", "c", "d"]
            .iter()
            .map(|id| make_vector(id, vec![1.0, 0.0]))
            .collect();

        let first = cosine_page(&query, &candidates, 2, None);
        assert_eq!(result_ids(&first), ["a", "b"]);
        assert_eq!(first.has_more, Some(true));

        let after_first = next_cursor_of(&first);
        let second = cosine_page(&query, &candidates, 2, Some(&after_first));
        assert_eq!(result_ids(&second), ["c", "d"]);
        assert_eq!(second.has_more, Some(false));
    }

    #[test]
    fn test_pagination_no_more_results() {
        let candidates = [
            make_vector("v1", vec![1.0, 0.0]),
            make_vector("v2", vec![0.9, 0.1]),
        ];
        let response = cosine_page(&[1.0, 0.0], &candidates, 5, None);

        assert_eq!(response.results.len(), 2);
        assert_eq!(response.has_more, Some(false));
        assert!(response.next_cursor.is_none());
    }
}