Skip to main content

git_semantic/search/
engine.rs

1use ndarray::Array1;
2use tracing::debug;
3
4use crate::cli::SearchFilters;
5use crate::embedding::ModelManager;
6use crate::index::SemanticIndex;
7
8use super::filter::FilterEngine;
9use super::{SearchError, SearchResult};
10
11pub struct SearchEngine {
12    model_manager: ModelManager,
13}
14
15impl SearchEngine {
16    pub fn new(mut model_manager: ModelManager) -> Result<Self, SearchError> {
17        model_manager.init()?;
18        Ok(Self { model_manager })
19    }
20
21    pub fn search(
22        &mut self,
23        index: &SemanticIndex,
24        query: &str,
25        num_results: usize,
26        filters: SearchFilters,
27    ) -> Result<Vec<SearchResult>, SearchError> {
28        debug!("Searching for: {}", query);
29
30        let query_embedding = self.model_manager.encode_text(query)?;
31
32        let mut results: Vec<SearchResult> = index
33            .entries
34            .iter()
35            .enumerate()
36            .map(|(idx, entry)| {
37                let embedding = Array1::from_vec(entry.embedding.clone());
38                let similarity = cosine_similarity(&query_embedding, &embedding);
39
40                SearchResult {
41                    commit: entry.commit.clone(),
42                    similarity,
43                    rank: idx + 1,
44                }
45            })
46            .collect();
47
48        let filter_engine = FilterEngine::new(filters);
49        results = filter_engine.apply(results)?;
50
51        results.sort_by(|a, b| {
52            b.similarity
53                .partial_cmp(&a.similarity)
54                .unwrap_or(std::cmp::Ordering::Equal)
55        });
56
57        results.truncate(num_results);
58
59        for (idx, result) in results.iter_mut().enumerate() {
60            result.rank = idx + 1;
61        }
62
63        Ok(results)
64    }
65}
66
67pub(crate) fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
68    let dot_product = a.dot(b);
69    let norm_a = a.dot(a).sqrt();
70    let norm_b = b.dot(b).sqrt();
71
72    if norm_a == 0.0 || norm_b == 0.0 {
73        return 0.0;
74    }
75
76    dot_product / (norm_a * norm_b)
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity_identical_vectors() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let sim = cosine_similarity(&a, &b);
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "identical vectors should have similarity ~1.0, got {sim}"
        );
    }

    #[test]
    fn test_cosine_similarity_orthogonal_vectors() {
        let a = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let b = Array1::from_vec(vec![0.0, 1.0, 0.0]);
        let sim = cosine_similarity(&a, &b);
        assert!(
            sim.abs() < 1e-6,
            "orthogonal vectors should have similarity ~0.0, got {sim}"
        );
    }

    #[test]
    fn test_cosine_similarity_opposite_vectors() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![-1.0, -2.0, -3.0]);
        let sim = cosine_similarity(&a, &b);
        assert!(
            (sim - (-1.0)).abs() < 1e-6,
            "opposite vectors should have similarity ~-1.0, got {sim}"
        );
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let b = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let sim = cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0, "zero vector should give similarity 0.0");
    }

    #[test]
    fn test_cosine_similarity_both_zero_vectors() {
        let a = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let b = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let sim = cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0, "two zero vectors should give similarity 0.0");
    }

    #[test]
    fn test_cosine_similarity_normalized_vectors() {
        // Pre-normalized vectors (unit length)
        let a = Array1::from_vec(vec![1.0, 0.0]);
        let val = std::f32::consts::FRAC_1_SQRT_2;
        let b = Array1::from_vec(vec![val, val]); // 45 degrees
        let sim = cosine_similarity(&a, &b);
        // Tight tolerance: only f32 rounding separates the result from cos(45°).
        assert!(
            (sim - val).abs() < 1e-6,
            "45-degree angle should give ~0.707, got {sim}"
        );
    }

    #[test]
    fn test_cosine_similarity_is_symmetric() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, -5.0, 6.0]);
        let ab = cosine_similarity(&a, &b);
        let ba = cosine_similarity(&b, &a);
        assert!((ab - ba).abs() < 1e-6, "expected symmetry, got {ab} vs {ba}");
    }

    #[test]
    fn test_cosine_similarity_is_scale_invariant() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, -5.0, 6.0]);
        let b_scaled = Array1::from_vec(vec![40.0, -50.0, 60.0]);
        // a·b = 12, |a| = sqrt(14), |b| = sqrt(77)  =>  cos = 12 / sqrt(1078)
        let expected = 12.0_f32 / 1078.0_f32.sqrt();
        let sim = cosine_similarity(&a, &b);
        let sim_scaled = cosine_similarity(&a, &b_scaled);
        assert!(
            (sim - expected).abs() < 1e-6,
            "expected {expected}, got {sim}"
        );
        assert!(
            (sim_scaled - sim).abs() < 1e-6,
            "scaling a vector must not change similarity: {sim_scaled} vs {sim}"
        );
    }

    #[test]
    fn test_cosine_similarity_384_dimensions() {
        // Simulate real embedding dimension (384 for bge-small-en-v1.5).
        // With a_i = i and b_i = n - i for i in 0..=n (the 1/384 scale cancels), both vectors
        // have the same norm, and cos = (n*Σi - Σi²) / Σi² = (n - 1) / (2n + 1); n = 383.
        let a = Array1::from_vec((0..384).map(|i| (i as f32) / 384.0).collect());
        let b = Array1::from_vec((0..384).map(|i| ((383 - i) as f32) / 384.0).collect());
        let sim = cosine_similarity(&a, &b);
        let expected = 382.0_f32 / 767.0;
        assert!(
            (sim - expected).abs() < 1e-4,
            "expected ~{expected}, got {sim}"
        );
    }
}