use anyhow::Result;
use std::collections::HashSet;
use std::path::{Path, PathBuf};

use crate::rag::{
    embeddings::EmbeddingModel,
    indexer::Indexer,
    llm::LlmClient,
    query_enhancer::{EnhancedQuery, QueryEnhancer, SearchStrategy},
    result_verifier::{ResultVerifier, VerifiedResult},
    EmbeddingProvider, RagConfig, RagSearchResult,
};

#[cfg(test)]
use crate::rag::SmartSearchConfig;

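/// Multi-stage search engine over the local RAG index. Combines semantic
/// search (when an embedding model is available) with keyword matching,
/// plus optional LLM-backed query enhancement and result verification.
///
/// A minimal usage sketch, assuming an async runtime, a populated index,
/// and a `RagConfig` built elsewhere:
///
/// ```ignore
/// let engine = SmartSearchEngine::new(config, None).await?;
/// let results = engine.search("how are chunks embedded?", Some(5)).await?;
/// for r in &results {
///     println!("confidence: {:.2}", r.confidence_score);
/// }
/// ```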
pub struct SmartSearchEngine {
    config: RagConfig,
    query_enhancer: QueryEnhancer,
    result_verifier: ResultVerifier,
    embedding_model: Option<EmbeddingModel>,
    #[allow(dead_code)]
    llm_client: Option<LlmClient>,
}

impl SmartSearchEngine {
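    /// Builds the engine from `config`: resolves an embedding model (with
    /// fallbacks) and constructs the query enhancer and result verifier,
    /// both of which share the optional `llm_client`.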
    pub async fn new(config: RagConfig, llm_client: Option<LlmClient>) -> Result<Self> {
        log::info!(
            "Initializing smart search engine with config: {:?}",
            config.smart_search
        );

        let embedding_model = Self::initialize_embedding_model(&config).await?;

        let query_enhancer = QueryEnhancer::new(llm_client.clone(), config.smart_search.clone());

        let result_verifier = ResultVerifier::new(llm_client.clone(), config.smart_search.clone());

        Ok(Self {
            config,
            query_enhancer,
            result_verifier,
            embedding_model,
            llm_client,
        })
    }

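    /// Resolves an embedding model in three steps: return `None` if
    /// `prefer_semantic` is disabled; if the configured provider is the
    /// default `Hash`, try auto-selecting a better model; otherwise (or on
    /// auto-select failure) fall back to the configured provider. Errors are
    /// downgraded to `Ok(None)` so the engine can degrade to keyword search.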
    async fn initialize_embedding_model(config: &RagConfig) -> Result<Option<EmbeddingModel>> {
        if !config.smart_search.prefer_semantic {
            log::info!("Semantic embeddings disabled by config");
            return Ok(None);
        }

        if matches!(config.embedding.provider, EmbeddingProvider::Hash) {
            log::info!("Default hash provider detected, attempting auto-selection of better model");
            match EmbeddingModel::new_auto_select().await {
                Ok(model) => {
                    log::info!(
                        "Successfully auto-selected embedding model: {:?}",
                        model.get_config().provider
                    );
                    return Ok(Some(model));
                }
                Err(e) => {
                    log::warn!("Auto-selection failed, trying configured provider: {}", e);
                }
            }
        }

        match EmbeddingModel::new_with_config(config.embedding.clone()).await {
            Ok(model) => {
                log::info!(
                    "Successfully initialized embedding model: {:?}",
                    config.embedding.provider
                );
                Ok(Some(model))
            }
            Err(e) => {
                log::warn!(
                    "Failed to initialize embedding model, will use fallback: {}",
                    e
                );
                Ok(None)
            }
        }
    }

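    /// Runs the full pipeline: enhance the query, execute multi-stage or
    /// single-stage retrieval, deduplicate the raw hits, verify them, and
    /// truncate to `max_results` (or the configured default).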
    pub async fn search(
        &self,
        query: &str,
        max_results: Option<usize>,
    ) -> Result<Vec<VerifiedResult>> {
        log::info!("Starting smart search for: '{}'", query);

        let enhanced_query = self.query_enhancer.enhance_query(query).await?;
        log::debug!(
            "Enhanced query with {} variations",
            enhanced_query.variations.len()
        );

        let mut all_results = if self.config.smart_search.enable_multi_stage {
            self.execute_multi_stage_search(&enhanced_query).await?
        } else {
            self.execute_single_stage_search(&enhanced_query).await?
        };

        log::debug!(
            "Collected {} raw results from search stages",
            all_results.len()
        );

        all_results = self.deduplicate_results(all_results);

        let verified_results = self
            .result_verifier
            .verify_results(&enhanced_query, all_results)
            .await?;

        let final_results = self.finalize_results(verified_results, max_results);

        log::info!(
            "Smart search completed: {} verified results for '{}'",
            final_results.len(),
            query
        );

        Ok(final_results)
    }

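    /// Three-stage retrieval: (1) semantic search with the original query,
    /// (2) up to three enhanced variations, each run with its own strategy
    /// and its scores scaled by the variation weight, and (3) a keyword pass
    /// over the original query with a small 1.1x boost.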
    async fn execute_multi_stage_search(
        &self,
        query: &EnhancedQuery,
    ) -> Result<Vec<RagSearchResult>> {
        let mut all_results = Vec::new();

        if let Some(ref embedding_model) = self.embedding_model {
            log::debug!("Stage 1: Semantic search with original query");
            match self.semantic_search(&query.original, embedding_model).await {
                Ok(mut results) => {
                    log::debug!("Semantic search found {} results", results.len());
                    all_results.append(&mut results);
                }
                Err(e) => log::warn!("Semantic search failed: {}", e),
            }
        }

        log::debug!("Stage 2: Enhanced query variations");
        // Only the first three variations are searched.
        for (i, variation) in query.variations.iter().enumerate().take(3) {
            log::debug!("Searching with variation {}: '{}'", i + 1, variation.query);

            let mut variation_results = match variation.strategy {
                SearchStrategy::Semantic => {
                    if let Some(ref embedding_model) = self.embedding_model {
                        self.semantic_search(&variation.query, embedding_model)
                            .await
                            .unwrap_or_default()
                    } else {
                        Vec::new()
                    }
                }
                SearchStrategy::Keyword => self
                    .keyword_search(&variation.query)
                    .await
                    .unwrap_or_default(),
                SearchStrategy::Code => {
                    self.code_search(&variation.query).await.unwrap_or_default()
                }
                SearchStrategy::Mixed => {
                    let mut mixed_results = Vec::new();
                    if let Some(ref embedding_model) = self.embedding_model {
                        if let Ok(mut semantic_results) = self
                            .semantic_search(&variation.query, embedding_model)
                            .await
                        {
                            mixed_results.append(&mut semantic_results);
                        }
                    }
                    if let Ok(mut keyword_results) = self.keyword_search(&variation.query).await {
                        mixed_results.append(&mut keyword_results);
                    }
                    mixed_results
                }
                _ => Vec::new(),
            };

            // Scale each variation's scores by its enhancer-assigned weight.
            for result in &mut variation_results {
                result.score *= variation.weight;
            }

            all_results.append(&mut variation_results);
        }

        log::debug!("Stage 3: Keyword fallback");
        let mut keyword_results = self
            .keyword_search(&query.original)
            .await
            .unwrap_or_default();
        // Slightly boost direct keyword matches on the original query.
        for result in &mut keyword_results {
            result.score *= 1.1;
        }
        all_results.append(&mut keyword_results);

        Ok(all_results)
    }

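    /// Single pass: semantic search when an embedding model is available,
    /// otherwise keyword search.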
    async fn execute_single_stage_search(
        &self,
        query: &EnhancedQuery,
    ) -> Result<Vec<RagSearchResult>> {
        if let Some(ref embedding_model) = self.embedding_model {
            self.semantic_search(&query.original, embedding_model).await
        } else {
            self.keyword_search(&query.original).await
        }
    }

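    /// Embeds `query` and scores it against every stored chunk embedding
    /// (JSON files under `<index>/embeddings`) by cosine similarity, keeping
    /// hits at or above the configured `similarity_threshold`.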
    async fn semantic_search(
        &self,
        query: &str,
        embedding_model: &EmbeddingModel,
    ) -> Result<Vec<RagSearchResult>> {
        log::debug!("Performing semantic search for: '{}'", query);

        let query_embedding = embedding_model.embed_text(query).await?;

        let indexer = Indexer::new(&self.config)?;
        let index_path = indexer.get_index_path();
        let embedding_dir = index_path.join("embeddings");

        if !embedding_dir.exists() {
            log::debug!("No embeddings directory found");
            return Ok(vec![]);
        }

        let mut results = Vec::new();
        let entries = std::fs::read_dir(embedding_dir)?;

        for entry in entries.flatten() {
            if let Some(file_name) = entry.file_name().to_str() {
                if file_name.ends_with(".json") {
                    match self
                        .load_and_score_embedding(&entry.path(), &query_embedding, embedding_model)
                        .await
                    {
                        Ok(Some(result)) => {
                            if result.score >= self.config.similarity_threshold {
                                results.push(result);
                            }
                        }
                        Ok(None) => continue,
                        Err(e) => {
                            log::warn!(
                                "Failed to process embedding file {:?}: {}",
                                entry.path(),
                                e
                            );
                        }
                    }
                }
            }
        }

        // Sort by score descending; treat incomparable (NaN) scores as equal
        // rather than panicking.
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        log::debug!("Semantic search found {} results", results.len());
        Ok(results)
    }

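    /// Scores each stored chunk by the fraction of query words (longer than
    /// two characters, case-insensitive) found in its content; for example,
    /// a four-word query with three words present scores 0.75.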
    async fn keyword_search(&self, query: &str) -> Result<Vec<RagSearchResult>> {
        log::debug!("Performing keyword search for: '{}'", query);

        let indexer = Indexer::new(&self.config)?;
        let index_path = indexer.get_index_path();
        let embedding_dir = index_path.join("embeddings");

        if !embedding_dir.exists() {
            return Ok(vec![]);
        }

        // Ignore very short words; they match too broadly to be useful.
        let query_words: Vec<String> = query
            .to_lowercase()
            .split_whitespace()
            .filter(|w| w.len() > 2)
            .map(|w| w.to_string())
            .collect();

        let mut results = Vec::new();
        let entries = std::fs::read_dir(embedding_dir)?;

        for entry in entries.flatten() {
            if let Some(file_name) = entry.file_name().to_str() {
                if file_name.ends_with(".json") {
                    if let Ok(content) = std::fs::read_to_string(entry.path()) {
                        if let Ok(stored_chunk) =
                            serde_json::from_str::<crate::rag::StoredChunk>(&content)
                        {
                            let content_lower = stored_chunk.content.to_lowercase();

                            let matches = query_words
                                .iter()
                                .filter(|word| content_lower.contains(*word))
                                .count();

                            if matches > 0 {
                                // Score is the fraction of query words found.
                                let score = matches as f32 / query_words.len() as f32;

                                results.push(RagSearchResult {
                                    id: stored_chunk.id,
                                    content: stored_chunk.content,
                                    source_path: stored_chunk.source_path,
                                    source_type: stored_chunk.source_type,
                                    title: stored_chunk.title,
                                    section: stored_chunk.section,
                                    score,
                                    chunk_index: stored_chunk.chunk_index,
                                    metadata: stored_chunk.metadata,
                                });
                            }
                        }
                    }
                }
            }
        }

        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        log::debug!("Keyword search found {} results", results.len());
        Ok(results)
    }

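    /// Keyword search with a 1.3x score boost for results whose source path
    /// has a recognized code-file extension.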
    async fn code_search(&self, query: &str) -> Result<Vec<RagSearchResult>> {
        log::debug!("Performing code search for: '{}'", query);

        let mut results = self.keyword_search(query).await?;

        // Boost results that come from recognized source-code files.
        for result in &mut results {
            if self.is_code_file(&result.source_path) {
                result.score *= 1.3;
            }
        }

        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(results)
    }

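    /// Returns true if `path` has a known source-code extension.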
    fn is_code_file(&self, path: &Path) -> bool {
        if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
            matches!(
                extension,
                "rs" | "js" | "ts" | "py" | "java" | "cpp" | "c" | "go" | "php" | "rb"
            )
        } else {
            false
        }
    }

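    /// Loads a stored chunk from `file_path` and scores it against
    /// `query_embedding` by cosine similarity; threshold filtering is left
    /// to the caller.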
    async fn load_and_score_embedding(
        &self,
        file_path: &Path,
        query_embedding: &[f32],
        _embedding_model: &EmbeddingModel,
    ) -> Result<Option<RagSearchResult>> {
        let content = std::fs::read_to_string(file_path)?;
        let chunk_data: crate::rag::StoredChunk = serde_json::from_str(&content)?;

        let score = EmbeddingModel::cosine_similarity(query_embedding, &chunk_data.embedding);

        Ok(Some(RagSearchResult {
            id: chunk_data.id,
            content: chunk_data.content,
            source_path: chunk_data.source_path,
            source_type: chunk_data.source_type,
            title: chunk_data.title,
            section: chunk_data.section,
            score,
            chunk_index: chunk_data.chunk_index,
            metadata: chunk_data.metadata,
        }))
    }

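    /// Removes duplicate hits produced by overlapping search stages, keyed
    /// on source path plus chunk index; the first occurrence wins.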
    fn deduplicate_results(&self, results: Vec<RagSearchResult>) -> Vec<RagSearchResult> {
        let mut unique_results = Vec::new();
        let mut seen_keys = HashSet::new();
        let original_count = results.len();

        for result in results {
            // Key on source path + chunk index: hits from different search
            // stages that refer to the same chunk are duplicates.
            let dedup_key = format!(
                "{}_{}",
                result.source_path.to_string_lossy(),
                result.chunk_index
            );

            if !seen_keys.contains(&dedup_key) {
                seen_keys.insert(dedup_key);
                unique_results.push(result);
            }
        }

        log::debug!(
            "Deduplicated {} results to {}",
            original_count,
            unique_results.len()
        );
        unique_results
    }

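    /// Sorts verified results by confidence (descending) and truncates to
    /// `max_results`, defaulting to the configured `max_results` limit.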
    fn finalize_results(
        &self,
        mut results: Vec<VerifiedResult>,
        max_results: Option<usize>,
    ) -> Vec<VerifiedResult> {
        results.sort_by(|a, b| {
            b.confidence_score
                .partial_cmp(&a.confidence_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // `truncate` is a no-op when there are fewer than `limit` results.
        let limit = max_results.unwrap_or(self.config.max_results);
        results.truncate(limit);

        results
    }

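    /// True when at least one smart backend (embedding model or LLM client)
    /// is available.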
    #[allow(dead_code)]
    pub fn is_intelligent_mode_available(&self) -> bool {
        self.embedding_model.is_some() || self.llm_client.is_some()
    }

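    /// Reports which search features are active, combining runtime
    /// availability (embedding model, LLM client) with configuration flags.
    /// Illustrative:
    ///
    /// ```ignore
    /// let caps = engine.get_capabilities();
    /// if !caps.has_semantic_embeddings {
    ///     eprintln!("falling back to keyword search");
    /// }
    /// ```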
    #[allow(dead_code)]
    pub fn get_capabilities(&self) -> SearchCapabilities {
        SearchCapabilities {
            has_semantic_embeddings: self.embedding_model.is_some(),
            has_llm_client: self.llm_client.is_some(),
            has_query_enhancement: self.config.smart_search.enable_query_enhancement,
            has_result_verification: self.config.smart_search.enable_result_verification,
            multi_stage_enabled: self.config.smart_search.enable_multi_stage,
        }
    }
}

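/// Snapshot of the engine's active features; see
/// [`SmartSearchEngine::get_capabilities`].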
#[derive(Debug)]
#[allow(dead_code)]
pub struct SearchCapabilities {
    pub has_semantic_embeddings: bool,
    pub has_llm_client: bool,
    pub has_query_enhancement: bool,
    pub has_result_verification: bool,
    pub multi_stage_enabled: bool,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::rag::{CodeSecurityLevel, EmbeddingConfig, EmbeddingProvider};

    fn create_test_config() -> RagConfig {
        RagConfig {
            enabled: true,
            index_path: PathBuf::from("/tmp/test_index"),
            max_results: 10,
            similarity_threshold: 0.6,
            allow_pdf_processing: false,
            allow_code_processing: true,
            code_security_level: CodeSecurityLevel::Moderate,
            mask_secrets: true,
            max_file_size_mb: 100,
            embedding: EmbeddingConfig {
                provider: EmbeddingProvider::Hash,
                dimension: 384,
                model_path: None,
                api_key: None,
                endpoint: None,
                timeout_seconds: 30,
                batch_size: 32,
            },
            smart_search: SmartSearchConfig::default(),
        }
    }

    #[tokio::test]
    async fn test_search_engine_initialization() {
        let config = create_test_config();
        let engine = SmartSearchEngine::new(config, None).await;
        assert!(engine.is_ok());
    }

    #[test]
    fn test_code_file_detection() {
        // `is_code_file` needs an engine instance, which requires async
        // initialization, so this test only checks extension parsing.
        let _engine_config = create_test_config();
        let path = PathBuf::from("test.rs");
        assert_eq!(path.extension().and_then(|ext| ext.to_str()), Some("rs"));
    }
}