use crate::{graph::GraphTree, VectorEntry, EMBEDDING_DIMENSION};
use serde::{Deserialize, Serialize};
use std::time::Instant;
/// Tuning knobs for [`SemanticSearch::search`].
///
/// Build with `SearchOptions::default()` or `SearchOptions::with_limit(..)` and refine
/// with the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchOptions {
    /// Maximum number of results returned; the ranked list is truncated to this length.
    pub limit: usize,
    /// Minimum raw cosine similarity (before any graph boost) an entry must reach.
    pub threshold: f32,
    /// When `true`, each match also gets a graph-boosted score, which then drives ranking.
    pub use_graph_boost: bool,
    /// If set, only entries whose `category` equals this value are considered.
    pub category: Option<String>,
    /// If set, only entries whose `namespace_id` equals this value are considered.
    pub namespace_id: Option<i64>,
}
impl Default for SearchOptions {
fn default() -> Self {
Self {
limit: 10,
threshold: 0.5,
use_graph_boost: true,
category: None,
namespace_id: None,
}
}
}
impl SearchOptions {
    /// Creates the default options (see `Default`) with the given result `limit`.
    #[must_use]
    pub fn with_limit(limit: usize) -> Self {
        Self {
            limit,
            ..Default::default()
        }
    }

    /// Sets the minimum raw cosine similarity an entry must reach to be returned.
    #[must_use]
    pub fn with_threshold(mut self, threshold: f32) -> Self {
        self.threshold = threshold;
        self
    }

    /// Restricts results to entries whose category equals `category`.
    #[must_use]
    pub fn with_category(mut self, category: impl Into<String>) -> Self {
        self.category = Some(category.into());
        self
    }

    /// Restricts results to entries whose namespace id equals `namespace_id`.
    #[must_use]
    pub fn with_namespace(mut self, namespace_id: i64) -> Self {
        self.namespace_id = Some(namespace_id);
        self
    }

    /// Enables or disables graph-based score boosting (enabled by default).
    ///
    /// Every other option already has a builder; this fills the gap for
    /// `use_graph_boost` so callers need not fall back to struct-update syntax.
    #[must_use]
    pub fn with_graph_boost(mut self, enabled: bool) -> Self {
        self.use_graph_boost = enabled;
        self
    }
}
/// One ranked match produced by [`SemanticSearch::search`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Id of the matched entry (copied from `entry.id` by `search`).
    pub id: i64,
    /// Cosine similarity between the query and the entry's embedding; `search` fills this
    /// from `cosine_similarity`, which clamps to `[0, 1]`.
    pub score: f32,
    /// Graph-boosted score; `search` sets it only when `use_graph_boost` is enabled.
    /// When present it is used instead of `score` for ordering.
    pub boosted_score: Option<f32>,
    /// Clone of the matched entry.
    pub entry: VectorEntry,
}
/// Results compare equal when they refer to the same entry `id`; `score`,
/// `boosted_score` and `entry` are ignored.
///
/// NOTE(review): the `Ord` impl for this type ranks by score and never looks at `id`,
/// so `==` and `cmp(..) == Equal` can disagree (same id with different scores, or
/// different ids with equal scores). Avoid ordered collections such as `BTreeSet`
/// keyed on `SearchResult` until the two are reconciled.
impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        self.id.eq(&other.id)
    }
}

/// Integer `id` equality is reflexive, symmetric and transitive, so `Eq` holds.
impl Eq for SearchResult {}
/// Delegates to `Ord::cmp`, so `partial_cmp` always returns `Some`.
impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// Orders results by effective score, highest first.
///
/// The effective score is `boosted_score` when present, otherwise `score`, so sorting a
/// `Vec<SearchResult>` ascending puts the best-ranked results at the front.
///
/// NaN effective scores rank after every number and compare equal to each other. This
/// keeps the ordering total: treating NaN as "equal to everything" would break
/// transitivity, and `slice::sort` may then produce an arbitrary order or panic.
///
/// `id` is not consulted, so two different results with the same score compare
/// `Equal` here even though `PartialEq` treats them as distinct.
impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let self_score = self.boosted_score.unwrap_or(self.score);
        let other_score = other.boosted_score.unwrap_or(other.score);
        match (self_score.is_nan(), other_score.is_nan()) {
            // Operands reversed so that the higher score sorts first. Neither side is
            // NaN here, so `partial_cmp` cannot fail; the fallback is purely defensive.
            (false, false) => other_score
                .partial_cmp(&self_score)
                .unwrap_or(std::cmp::Ordering::Equal),
            (true, true) => std::cmp::Ordering::Equal,
            // A NaN score always ranks after a real one.
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
        }
    }
}
/// Linear-scan semantic search over a slice of [`VectorEntry`] values, with optional
/// score boosting from a [`GraphTree`].
pub struct SemanticSearch {
    /// Supplies boosted scores when `SearchOptions::use_graph_boost` is set.
    graph_tree: GraphTree,
}
impl SemanticSearch {
    /// Creates a searcher backed by a fresh `GraphTree::new()`.
    pub fn new() -> Self {
        Self {
            graph_tree: GraphTree::new(),
        }
    }

    /// Creates a searcher that uses `graph_tree` for score boosting.
    pub fn with_graph_tree(graph_tree: GraphTree) -> Self {
        Self { graph_tree }
    }

    /// Ranks `vectors` against `query` and returns the best matches with timing data.
    ///
    /// Entries are first filtered by `options.category` / `options.namespace_id`, then
    /// scored with [`cosine_similarity`]. Entries whose raw score is below
    /// `options.threshold` are dropped. When `options.use_graph_boost` is set, each
    /// survivor also receives a boosted score from the graph tree, and ranking uses the
    /// boosted score. At most `options.limit` results are returned, highest effective
    /// score first; ties keep their input order.
    ///
    /// In the returned latency, `vector_comparison_ms` covers filtering and similarity
    /// scoring, and `graph_traversal_ms` covers only the boosting pass (`None` when
    /// boosting is disabled).
    ///
    /// # Errors
    ///
    /// Returns `NexusError::InvalidInput` if `query.len() != EMBEDDING_DIMENSION`.
    pub fn search(
        &self,
        query: &[f32],
        vectors: &[VectorEntry],
        options: &SearchOptions,
    ) -> crate::Result<(Vec<SearchResult>, crate::SearchLatency)> {
        /// A scored entry, borrowed from `vectors` until it makes the final cut.
        struct Candidate<'a> {
            entry: &'a VectorEntry,
            score: f32,
            boosted_score: Option<f32>,
        }

        let start = Instant::now();
        if query.len() != EMBEDDING_DIMENSION {
            return Err(nexus_core::NexusError::InvalidInput(format!(
                "Query dimension mismatch: expected {}, got {}",
                EMBEDDING_DIMENSION,
                query.len()
            )));
        }

        // Phase 1: metadata filters + cosine similarity. Entries stay borrowed so that
        // only the final top-`limit` results are cloned, not every threshold match.
        let vector_start = Instant::now();
        let mut candidates: Vec<Candidate<'_>> = vectors
            .iter()
            .filter(|entry| Self::matches_filters(entry, options))
            .filter_map(|entry| {
                let score = cosine_similarity(query, &entry.embedding);
                // The threshold applies to the raw similarity, before any graph boost.
                (score >= options.threshold).then_some(Candidate {
                    entry,
                    score,
                    boosted_score: None,
                })
            })
            .collect();
        let vector_comparison_ms = Self::elapsed_ms(vector_start);

        // Phase 2: graph boosting, timed on its own so `graph_traversal_ms` measures it.
        // Boosting must happen before ranking because it can reorder results.
        let graph_traversal_ms = if options.use_graph_boost {
            let graph_start = Instant::now();
            for candidate in &mut candidates {
                candidate.boosted_score = Some(
                    self.graph_tree
                        .calculate_boosted_score(candidate.entry.id, candidate.score),
                );
            }
            Some(Self::elapsed_ms(graph_start))
        } else {
            None
        };

        // Phase 3: rank (stable sort, so ties keep input order), keep the best `limit`,
        // and only then clone the surviving entries.
        candidates.sort_by(|a, b| {
            Self::rank_effective_scores(
                a.boosted_score.unwrap_or(a.score),
                b.boosted_score.unwrap_or(b.score),
            )
        });
        candidates.truncate(options.limit);
        let results: Vec<SearchResult> = candidates
            .into_iter()
            .map(|candidate| SearchResult {
                id: candidate.entry.id,
                score: candidate.score,
                boosted_score: candidate.boosted_score,
                entry: candidate.entry.clone(),
            })
            .collect();

        let latency = crate::SearchLatency {
            total_ms: Self::elapsed_ms(start),
            vector_comparison_ms,
            graph_traversal_ms,
        };
        Ok((results, latency))
    }

    /// Returns the graph tree used for score boosting.
    pub fn graph_tree(&self) -> &GraphTree {
        &self.graph_tree
    }

    /// Returns mutable access to the graph tree used for score boosting.
    pub fn graph_tree_mut(&mut self) -> &mut GraphTree {
        &mut self.graph_tree
    }

    /// Returns `true` when `entry` passes the optional category and namespace filters.
    fn matches_filters(entry: &VectorEntry, options: &SearchOptions) -> bool {
        let category_ok = match &options.category {
            Some(category) => entry.category == *category,
            None => true,
        };
        let namespace_ok = match options.namespace_id {
            Some(namespace_id) => entry.namespace_id == namespace_id,
            None => true,
        };
        category_ok && namespace_ok
    }

    /// Orders two effective scores so that higher scores come first and NaN ranks last.
    ///
    /// For non-NaN scores this is the same descending order `SearchResult`'s `Ord` uses.
    fn rank_effective_scores(a: f32, b: f32) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            (false, false) => b.partial_cmp(&a).unwrap_or(std::cmp::Ordering::Equal),
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
        }
    }

    /// Whole milliseconds elapsed since `since`.
    fn elapsed_ms(since: Instant) -> u64 {
        // `as_millis` yields u128; clamp first so the narrowing cast cannot wrap.
        since.elapsed().as_millis().min(u128::from(u64::MAX)) as u64
    }
}
impl Default for SemanticSearch {
fn default() -> Self {
Self::new()
}
}
/// Cosine similarity of `a` and `b`, clamped to `[0.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either one has zero
/// magnitude. Negative similarities (vectors pointing apart) are clamped to `0.0`.
/// NaN components propagate to a NaN result, because `f32::clamp` preserves NaN.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Mismatched or empty inputs have no meaningful similarity.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));
    // A zero vector has no direction, so the ratio below would be undefined.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    (dot / (norm_a * norm_b)).clamp(0.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_entry(id: i64, embedding: Vec<f32>) -> VectorEntry {
        VectorEntry::new(id, embedding, "general".to_string(), 1)
    }

    #[test]
    fn test_search_options_default() {
        let opts = SearchOptions::default();
        assert_eq!(opts.limit, 10);
        assert!((opts.threshold - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_search_options_builder() {
        let opts = SearchOptions::with_limit(5)
            .with_threshold(0.8)
            .with_category("facts");
        assert_eq!(opts.limit, 5);
        assert!((opts.threshold - 0.8).abs() < 0.01);
        assert_eq!(opts.category, Some("facts".to_string()));
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs_score_zero() {
        let empty: [f32; 0] = [];
        assert!(cosine_similarity(&empty, &empty).abs() < 1e-6);
        assert!(cosine_similarity(&[1.0, 2.0], &[1.0]).abs() < 1e-6);
        assert!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]).abs() < 1e-6);
        // Opposite vectors have similarity -1, which is clamped to 0.
        assert!(cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]).abs() < 1e-6);
    }

    #[test]
    fn test_semantic_search_basic() {
        let search = SemanticSearch::new();
        let query = vec![1.0; EMBEDDING_DIMENSION];
        let vectors = vec![
            create_test_entry(1, vec![0.9; EMBEDDING_DIMENSION]),
            create_test_entry(2, vec![0.1; EMBEDDING_DIMENSION]),
        ];
        let opts = SearchOptions::with_limit(10).with_threshold(0.0);
        let (results, latency) = search.search(&query, &vectors, &opts).unwrap();
        assert_eq!(results.len(), 2);
        assert!(latency.total_ms < 5_000);
        assert!(results[0].score >= results[1].score);
    }

    #[test]
    fn test_semantic_search_threshold() {
        let search = SemanticSearch::new();
        let query = vec![1.0; EMBEDDING_DIMENSION];
        let vectors = vec![
            create_test_entry(1, vec![1.0; EMBEDDING_DIMENSION]),
            create_test_entry(2, vec![0.0; EMBEDDING_DIMENSION]),
        ];
        let opts = SearchOptions::with_limit(10).with_threshold(0.9);
        let (results, _) = search.search(&query, &vectors, &opts).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1);
    }

    #[test]
    fn test_semantic_search_category_filter() {
        let search = SemanticSearch::new();
        let query = vec![1.0; EMBEDDING_DIMENSION];
        let mut entry1 = create_test_entry(1, vec![1.0; EMBEDDING_DIMENSION]);
        entry1.category = "facts".to_string();
        let mut entry2 = create_test_entry(2, vec![1.0; EMBEDDING_DIMENSION]);
        entry2.category = "general".to_string();
        let vectors = vec![entry1, entry2];
        let opts = SearchOptions::with_limit(10)
            .with_threshold(0.0)
            .with_category("facts");
        let (results, _) = search.search(&query, &vectors, &opts).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entry.category, "facts");
    }

    #[test]
    fn test_semantic_search_namespace_filter() {
        let search = SemanticSearch::new();
        let query = vec![1.0; EMBEDDING_DIMENSION];
        let mut entry1 = create_test_entry(1, vec![1.0; EMBEDDING_DIMENSION]);
        entry1.namespace_id = 1;
        let mut entry2 = create_test_entry(2, vec![1.0; EMBEDDING_DIMENSION]);
        entry2.namespace_id = 2;
        let vectors = vec![entry1, entry2];
        let opts = SearchOptions::with_limit(10)
            .with_threshold(0.0)
            .with_namespace(2);
        let (results, _) = search.search(&query, &vectors, &opts).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 2);
    }

    #[test]
    fn test_semantic_search_respects_limit() {
        let search = SemanticSearch::new();
        let query = vec![1.0; EMBEDDING_DIMENSION];
        let vectors: Vec<VectorEntry> = (1..=3)
            .map(|id| create_test_entry(id, vec![1.0; EMBEDDING_DIMENSION]))
            .collect();
        let opts = SearchOptions::with_limit(2).with_threshold(0.0);
        let (results, _) = search.search(&query, &vectors, &opts).unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_semantic_search_without_graph_boost() {
        let search = SemanticSearch::new();
        let query = vec![1.0; EMBEDDING_DIMENSION];
        let vectors = vec![create_test_entry(1, vec![1.0; EMBEDDING_DIMENSION])];
        let opts = SearchOptions {
            use_graph_boost: false,
            ..SearchOptions::with_limit(10)
        };
        let (results, latency) = search.search(&query, &vectors, &opts).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].boosted_score.is_none());
        assert!(latency.graph_traversal_ms.is_none());
    }

    #[test]
    fn test_semantic_search_rejects_wrong_query_dimension() {
        let search = SemanticSearch::new();
        let query = vec![1.0; EMBEDDING_DIMENSION + 1];
        let vectors: Vec<VectorEntry> = Vec::new();
        assert!(search
            .search(&query, &vectors, &SearchOptions::default())
            .is_err());
    }

    #[test]
    fn test_search_result_ordering() {
        let r1 = SearchResult {
            id: 1,
            score: 0.9,
            boosted_score: None,
            entry: create_test_entry(1, vec![0.1; EMBEDDING_DIMENSION]),
        };
        let r2 = SearchResult {
            id: 2,
            score: 0.8,
            boosted_score: None,
            entry: create_test_entry(2, vec![0.1; EMBEDDING_DIMENSION]),
        };
        // Higher scores sort first, so the better result compares as "less".
        assert!(r1 < r2);
    }
}