git_semantic/search/
engine.rs1use ndarray::Array1;
2use tracing::debug;
3
4use crate::cli::SearchFilters;
5use crate::embedding::ModelManager;
6use crate::index::SemanticIndex;
7
8use super::filter::FilterEngine;
9use super::{SearchError, SearchResult};
10
11pub struct SearchEngine {
12 model_manager: ModelManager,
13}
14
15impl SearchEngine {
16 pub fn new(mut model_manager: ModelManager) -> Result<Self, SearchError> {
17 model_manager.init()?;
18 Ok(Self { model_manager })
19 }
20
21 pub fn search(
22 &mut self,
23 index: &SemanticIndex,
24 query: &str,
25 num_results: usize,
26 filters: SearchFilters,
27 ) -> Result<Vec<SearchResult>, SearchError> {
28 debug!("Searching for: {}", query);
29
30 let query_embedding = self.model_manager.encode_text(query)?;
31
32 let mut results: Vec<SearchResult> = index
33 .entries
34 .iter()
35 .enumerate()
36 .map(|(idx, entry)| {
37 let embedding = Array1::from_vec(entry.embedding.clone());
38 let similarity = cosine_similarity(&query_embedding, &embedding);
39
40 SearchResult {
41 commit: entry.commit.clone(),
42 similarity,
43 rank: idx + 1,
44 }
45 })
46 .collect();
47
48 let filter_engine = FilterEngine::new(filters);
49 results = filter_engine.apply(results)?;
50
51 results.sort_by(|a, b| {
52 b.similarity
53 .partial_cmp(&a.similarity)
54 .unwrap_or(std::cmp::Ordering::Equal)
55 });
56
57 results.truncate(num_results);
58
59 for (idx, result) in results.iter_mut().enumerate() {
60 result.rank = idx + 1;
61 }
62
63 Ok(results)
64 }
65}
66
67pub(crate) fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
68 let dot_product = a.dot(b);
69 let norm_a = a.dot(a).sqrt();
70 let norm_b = b.dot(b).sqrt();
71
72 if norm_a == 0.0 || norm_b == 0.0 {
73 return 0.0;
74 }
75
76 dot_product / (norm_a * norm_b)
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Vector length used by the high-dimensional test.
    const DIM: usize = 384;

    /// Builds an owned 1-D array from a slice of components.
    fn vector(components: &[f32]) -> Array1<f32> {
        Array1::from_vec(components.to_vec())
    }

    /// Two vectors with the same components score ~1.0.
    #[test]
    fn test_cosine_similarity_identical_vectors() {
        let lhs = vector(&[1.0, 2.0, 3.0]);
        let rhs = vector(&[1.0, 2.0, 3.0]);

        let sim = cosine_similarity(&lhs, &rhs);

        assert!(
            (sim - 1.0).abs() < 1e-6,
            "identical vectors should have similarity ~1.0, got {sim}"
        );
    }

    /// Perpendicular unit vectors score ~0.0.
    #[test]
    fn test_cosine_similarity_orthogonal_vectors() {
        let x_axis = vector(&[1.0, 0.0, 0.0]);
        let y_axis = vector(&[0.0, 1.0, 0.0]);

        let sim = cosine_similarity(&x_axis, &y_axis);

        assert!(
            sim.abs() < 1e-6,
            "orthogonal vectors should have similarity ~0.0, got {sim}"
        );
    }

    /// A vector and its negation score ~-1.0.
    #[test]
    fn test_cosine_similarity_opposite_vectors() {
        let forward = vector(&[1.0, 2.0, 3.0]);
        let backward = vector(&[-1.0, -2.0, -3.0]);

        let sim = cosine_similarity(&forward, &backward);

        assert!(
            (sim + 1.0).abs() < 1e-6,
            "opposite vectors should have similarity ~-1.0, got {sim}"
        );
    }

    /// A zero-norm operand short-circuits to exactly 0.0.
    #[test]
    fn test_cosine_similarity_zero_vector() {
        let zero = vector(&[0.0, 0.0, 0.0]);
        let other = vector(&[1.0, 2.0, 3.0]);

        let sim = cosine_similarity(&zero, &other);

        assert_eq!(sim, 0.0, "zero vector should give similarity 0.0");
    }

    /// Unit vectors 45 degrees apart score about 1/sqrt(2).
    #[test]
    fn test_cosine_similarity_normalized_vectors() {
        let val = std::f32::consts::FRAC_1_SQRT_2;
        let axis = vector(&[1.0, 0.0]);
        let diagonal = vector(&[val, val]);

        let sim = cosine_similarity(&axis, &diagonal);

        assert!(
            (sim - val).abs() < 0.01,
            "45-degree angle should give ~0.707, got {sim}"
        );
    }

    /// An ascending ramp against a descending ramp lands strictly between 0 and 1.
    #[test]
    fn test_cosine_similarity_384_dimensions() {
        let ascending: Vec<f32> = (0..DIM).map(|i| (i as f32) / (DIM as f32)).collect();
        let descending: Vec<f32> = (0..DIM)
            .map(|i| ((DIM - 1 - i) as f32) / (DIM as f32))
            .collect();
        let a = Array1::from_vec(ascending);
        let b = Array1::from_vec(descending);

        let sim = cosine_similarity(&a, &b);

        assert!(
            sim > 0.0 && sim < 1.0,
            "should be between 0 and 1, got {sim}"
        );
    }
}