1use anyhow::Result;
7use std::collections::HashSet;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10
11use crate::rag::{
12 embeddings::EmbeddingModel,
13 indexer::Indexer,
14 llm::LlmClient,
15 query_enhancer::{EnhancedQuery, QueryEnhancer, SearchStrategy},
16 result_verifier::{ResultVerifier, VerifiedResult},
17 EmbeddingProvider, RagConfig, RagSearchResult,
18};
19
20#[cfg(test)]
21use crate::rag::SmartSearchConfig;
22
/// Orchestrates multi-stage RAG search: query enhancement, semantic and/or
/// keyword retrieval, deduplication, and result verification.
pub struct SmartSearchEngine {
    // Full RAG configuration (index path, similarity threshold, smart-search flags).
    config: RagConfig,
    // Produces query variations and per-variation search strategies.
    query_enhancer: QueryEnhancer,
    // Scores/filters raw results into `VerifiedResult`s.
    result_verifier: ResultVerifier,
    // Present only when an embedding backend initialized successfully;
    // `None` means the engine falls back to keyword-only search.
    embedding_model: Option<Arc<EmbeddingModel>>,
    // Retained so capability queries can report LLM availability.
    #[allow(dead_code)] llm_client: Option<Arc<LlmClient>>,
}
32
33impl SmartSearchEngine {
34 pub async fn new(config: RagConfig, llm_client: Option<LlmClient>) -> Result<Self> {
36 log::info!(
37 "Initializing smart search engine with config: {:?}",
38 config.smart_search
39 );
40
41 let llm_client_arc = llm_client.map(Arc::new);
43
44 let embedding_model = Self::initialize_embedding_model(&config).await?;
46
47 let query_enhancer =
49 QueryEnhancer::new(llm_client_arc.clone(), config.smart_search.clone());
50
51 let result_verifier =
53 ResultVerifier::new(llm_client_arc.clone(), config.smart_search.clone());
54
55 Ok(Self {
56 config,
57 query_enhancer,
58 result_verifier,
59 embedding_model,
60 llm_client: llm_client_arc,
61 })
62 }
63
64 async fn initialize_embedding_model(config: &RagConfig) -> Result<Option<Arc<EmbeddingModel>>> {
66 if !config.smart_search.prefer_semantic {
67 log::info!("Semantic embeddings disabled by config");
68 return Ok(None);
69 }
70
71 if matches!(config.embedding.provider, EmbeddingProvider::Hash) {
73 log::info!("Default hash provider detected, attempting auto-selection of better model");
74 match EmbeddingModel::new_auto_select().await {
75 Ok(model) => {
76 log::info!(
77 "Successfully auto-selected embedding model (pooled): {:?}",
78 model.get_config().provider
79 );
80 return Ok(Some(Arc::new(model)));
81 }
82 Err(e) => {
83 log::warn!("Auto-selection failed, trying configured provider: {}", e);
84 }
85 }
86 }
87
88 match EmbeddingModel::new_with_config(config.embedding.clone()).await {
90 Ok(model) => {
91 log::info!(
92 "Successfully initialized embedding model (pooled): {:?}",
93 config.embedding.provider
94 );
95 Ok(Some(Arc::new(model)))
96 }
97 Err(e) => {
98 log::warn!(
99 "Failed to initialize embedding model, will use fallback: {}",
100 e
101 );
102 Ok(None)
103 }
104 }
105 }
106
107 pub async fn search(
109 &self,
110 query: &str,
111 max_results: Option<usize>,
112 ) -> Result<Vec<VerifiedResult>> {
113 log::info!("Starting smart search for: '{}'", query);
114
115 let enhanced_query = self.query_enhancer.enhance_query(query).await?;
117 log::debug!(
118 "Enhanced query with {} variations",
119 enhanced_query.variations.len()
120 );
121
122 let mut all_results = if self.config.smart_search.enable_multi_stage {
124 self.execute_multi_stage_search(&enhanced_query).await?
125 } else {
126 self.execute_single_stage_search(&enhanced_query).await?
127 };
128
129 log::debug!(
130 "Collected {} raw results from search stages",
131 all_results.len()
132 );
133
134 all_results = self.deduplicate_results(all_results);
136
137 let verified_results = self
139 .result_verifier
140 .verify_results(&enhanced_query, all_results)
141 .await?;
142
143 let final_results = self.finalize_results(verified_results, max_results);
145
146 log::info!(
147 "Smart search completed: {} verified results for '{}'",
148 final_results.len(),
149 query
150 );
151
152 Ok(final_results)
153 }
154
    /// Three-stage retrieval: (1) semantic search on the original query when
    /// an embedding model is available, (2) up to three enhanced query
    /// variations dispatched by their assigned strategy, (3) a keyword-search
    /// fallback on the original query. Per-stage failures are logged and
    /// swallowed; cross-stage duplicates are removed later by the caller via
    /// `deduplicate_results`.
    async fn execute_multi_stage_search(
        &self,
        query: &EnhancedQuery,
    ) -> Result<Vec<RagSearchResult>> {
        let mut all_results = Vec::new();

        // Stage 1: semantic pass, skipped entirely without an embedding model.
        if let Some(ref embedding_model) = self.embedding_model {
            log::debug!("Stage 1: Semantic search with original query");
            match self.semantic_search(&query.original, embedding_model).await {
                Ok(mut results) => {
                    log::debug!("Semantic search found {} results", results.len());
                    all_results.append(&mut results);
                }
                // Non-fatal: fall through to the remaining stages.
                Err(e) => log::warn!("Semantic search failed: {}", e),
            }
        }

        // Stage 2: only the first three variations are searched — presumably
        // to bound latency; TODO confirm the intended cap.
        log::debug!("Stage 2: Enhanced query variations");
        for (i, variation) in query.variations.iter().enumerate().take(3) {
            log::debug!("Searching with variation {}: '{}'", i + 1, variation.query);

            // Dispatch on the enhancer-chosen strategy; every branch degrades
            // to an empty result set on failure rather than aborting.
            let mut variation_results = match variation.strategy {
                SearchStrategy::Semantic => {
                    if let Some(ref embedding_model) = self.embedding_model {
                        self.semantic_search(&variation.query, embedding_model)
                            .await
                            .unwrap_or_default()
                    } else {
                        Vec::new()
                    }
                }
                SearchStrategy::Keyword => self
                    .keyword_search(&variation.query)
                    .await
                    .unwrap_or_default(),
                SearchStrategy::Code => {
                    self.code_search(&variation.query).await.unwrap_or_default()
                }
                // Mixed runs semantic (when available) and keyword, merging both.
                SearchStrategy::Mixed => {
                    let mut mixed_results = Vec::new();
                    if let Some(ref embedding_model) = self.embedding_model {
                        if let Ok(mut semantic_results) = self
                            .semantic_search(&variation.query, embedding_model)
                            .await
                        {
                            mixed_results.append(&mut semantic_results);
                        }
                    }
                    if let Ok(mut keyword_results) = self.keyword_search(&variation.query).await {
                        mixed_results.append(&mut keyword_results);
                    }
                    mixed_results
                }
                // Any strategy variant not handled above contributes nothing.
                _ => Vec::new(),
            };

            // Scale by the enhancer-assigned weight so weaker variations rank
            // below stronger ones after merging.
            for result in &mut variation_results {
                result.score *= variation.weight;
            }

            all_results.append(&mut variation_results);
        }

        // Stage 3: keyword fallback on the original query; the 1.1 factor
        // slightly boosts direct-query matches over variation-derived ones.
        log::debug!("Stage 3: Keyword fallback");
        let mut keyword_results = self
            .keyword_search(&query.original)
            .await
            .unwrap_or_default();
        for result in &mut keyword_results {
            result.score *= 1.1;
        }
        all_results.append(&mut keyword_results);

        Ok(all_results)
    }
237
238 async fn execute_single_stage_search(
240 &self,
241 query: &EnhancedQuery,
242 ) -> Result<Vec<RagSearchResult>> {
243 if let Some(ref embedding_model) = self.embedding_model {
244 self.semantic_search(&query.original, embedding_model).await
245 } else {
246 self.keyword_search(&query.original).await
247 }
248 }
249
250 async fn semantic_search(
252 &self,
253 query: &str,
254 embedding_model: &EmbeddingModel,
255 ) -> Result<Vec<RagSearchResult>> {
256 log::debug!("Performing semantic search for: '{}'", query);
257
258 let query_embedding = embedding_model.embed_text(query).await?;
260
261 let indexer = Indexer::new(&self.config)?;
263 let index_path = indexer.get_index_path();
264 let embedding_dir = index_path.join("embeddings");
265
266 if !embedding_dir.exists() {
267 log::debug!("No embeddings directory found");
268 return Ok(vec![]);
269 }
270
271 let mut results = Vec::new();
272 let entries = std::fs::read_dir(embedding_dir)?;
273
274 for entry in entries.flatten() {
275 if let Some(file_name) = entry.file_name().to_str() {
276 if file_name.ends_with(".json") {
277 match self
278 .load_and_score_embedding(&entry.path(), &query_embedding, embedding_model)
279 .await
280 {
281 Ok(Some(result)) => {
282 if result.score >= self.config.similarity_threshold {
283 results.push(result);
284 }
285 }
286 Ok(None) => continue,
287 Err(e) => {
288 log::warn!(
289 "Failed to process embedding file {:?}: {}",
290 entry.path(),
291 e
292 );
293 }
294 }
295 }
296 }
297 }
298
299 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
301
302 log::debug!("Semantic search found {} results", results.len());
303 Ok(results)
304 }
305
306 async fn keyword_search(&self, query: &str) -> Result<Vec<RagSearchResult>> {
308 log::debug!("Performing keyword search for: '{}'", query);
309
310 let indexer = Indexer::new(&self.config)?;
311 let index_path = indexer.get_index_path();
312 let embedding_dir = index_path.join("embeddings");
313
314 if !embedding_dir.exists() {
315 return Ok(vec![]);
316 }
317
318 let query_words: Vec<String> = query
319 .to_lowercase()
320 .split_whitespace()
321 .filter(|w| w.len() > 2)
322 .map(|w| w.to_string())
323 .collect();
324
325 let mut results = Vec::new();
326 let entries = std::fs::read_dir(embedding_dir)?;
327
328 for entry in entries.flatten() {
329 if let Some(file_name) = entry.file_name().to_str() {
330 if file_name.ends_with(".json") {
331 if let Ok(content) = std::fs::read_to_string(entry.path()) {
332 if let Ok(stored_chunk) =
333 serde_json::from_str::<crate::rag::StoredChunk>(&content)
334 {
335 let content_lower = stored_chunk.content.to_lowercase();
336
337 let matches = query_words
338 .iter()
339 .filter(|word| content_lower.contains(*word))
340 .count();
341
342 if matches > 0 {
343 let score = matches as f32 / query_words.len() as f32;
344
345 results.push(RagSearchResult {
346 id: stored_chunk.id,
347 content: stored_chunk.content,
348 source_path: stored_chunk.source_path,
349 source_type: stored_chunk.source_type,
350 title: stored_chunk.title,
351 section: stored_chunk.section,
352 score,
353 chunk_index: stored_chunk.chunk_index,
354 metadata: stored_chunk.metadata,
355 });
356 }
357 }
358 }
359 }
360 }
361 }
362
363 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
365
366 log::debug!("Keyword search found {} results", results.len());
367 Ok(results)
368 }
369
370 async fn code_search(&self, query: &str) -> Result<Vec<RagSearchResult>> {
372 log::debug!("Performing code search for: '{}'", query);
373
374 let mut results = self.keyword_search(query).await?;
376
377 for result in &mut results {
379 if self.is_code_file(&result.source_path) {
380 result.score *= 1.3;
381 }
382 }
383
384 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
385
386 Ok(results)
387 }
388
389 fn is_code_file(&self, path: &Path) -> bool {
391 if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
392 matches!(
393 extension,
394 "rs" | "js" | "ts" | "py" | "java" | "cpp" | "c" | "go" | "php" | "rb"
395 )
396 } else {
397 false
398 }
399 }
400
401 async fn load_and_score_embedding(
403 &self,
404 file_path: &PathBuf,
405 query_embedding: &[f32],
406 _embedding_model: &EmbeddingModel,
407 ) -> Result<Option<RagSearchResult>> {
408 let content = std::fs::read_to_string(file_path)?;
409 let chunk_data: crate::rag::StoredChunk = serde_json::from_str(&content)?;
410
411 let score = EmbeddingModel::cosine_similarity(query_embedding, &chunk_data.embedding);
413
414 Ok(Some(RagSearchResult {
415 id: chunk_data.id,
416 content: chunk_data.content,
417 source_path: chunk_data.source_path,
418 source_type: chunk_data.source_type,
419 title: chunk_data.title,
420 section: chunk_data.section,
421 score,
422 chunk_index: chunk_data.chunk_index,
423 metadata: chunk_data.metadata,
424 }))
425 }
426
427 fn deduplicate_results(&self, results: Vec<RagSearchResult>) -> Vec<RagSearchResult> {
429 let mut unique_results = Vec::new();
430 let mut seen_content = HashSet::new();
431 let original_count = results.len();
432
433 for result in results {
434 let content_hash = format!(
436 "{}_{}",
437 result.source_path.to_string_lossy(),
438 result.chunk_index
439 );
440
441 if !seen_content.contains(&content_hash) {
442 seen_content.insert(content_hash);
443 unique_results.push(result);
444 }
445 }
446
447 log::debug!(
448 "Deduplicated {} results to {}",
449 original_count,
450 unique_results.len()
451 );
452 unique_results
453 }
454
455 fn finalize_results(
457 &self,
458 mut results: Vec<VerifiedResult>,
459 max_results: Option<usize>,
460 ) -> Vec<VerifiedResult> {
461 results.sort_by(|a, b| b.confidence_score.partial_cmp(&a.confidence_score).unwrap());
463
464 let limit = max_results.unwrap_or(self.config.max_results);
466 if results.len() > limit {
467 results.truncate(limit);
468 }
469
470 results
471 }
472
    /// True when at least one "intelligent" backend (embedding model or LLM
    /// client) is available; otherwise only plain keyword search is possible.
    #[allow(dead_code)] pub fn is_intelligent_mode_available(&self) -> bool {
        self.embedding_model.is_some() || self.llm_client.is_some()
    }
479
    /// Snapshot of which optional search features are currently active,
    /// derived from the initialized backends and the smart-search config flags.
    #[allow(dead_code)] pub fn get_capabilities(&self) -> SearchCapabilities {
        SearchCapabilities {
            has_semantic_embeddings: self.embedding_model.is_some(),
            has_llm_client: self.llm_client.is_some(),
            has_query_enhancement: self.config.smart_search.enable_query_enhancement,
            has_result_verification: self.config.smart_search.enable_result_verification,
            multi_stage_enabled: self.config.smart_search.enable_multi_stage,
        }
    }
492}
493
/// Read-only report of the search features a `SmartSearchEngine` instance
/// can currently offer (see `get_capabilities`).
#[derive(Debug)]
#[allow(dead_code)] pub struct SearchCapabilities {
    // An embedding backend initialized, so semantic search is available.
    pub has_semantic_embeddings: bool,
    // An LLM client was supplied at construction time.
    pub has_llm_client: bool,
    // Mirrors `smart_search.enable_query_enhancement`.
    pub has_query_enhancement: bool,
    // Mirrors `smart_search.enable_result_verification`.
    pub has_result_verification: bool,
    // Mirrors `smart_search.enable_multi_stage`.
    pub multi_stage_enabled: bool,
}
505
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rag::{CodeSecurityLevel, EmbeddingConfig, EmbeddingProvider};

    /// Minimal config: hash embedding provider, default smart-search
    /// settings, throwaway index path under /tmp.
    fn create_test_config() -> RagConfig {
        RagConfig {
            enabled: true,
            index_path: PathBuf::from("/tmp/test_index"),
            max_results: 10,
            similarity_threshold: 0.6,
            allow_pdf_processing: false,
            allow_code_processing: true,
            code_security_level: CodeSecurityLevel::Moderate,
            mask_secrets: true,
            max_file_size_mb: 100,
            embedding: EmbeddingConfig {
                provider: EmbeddingProvider::Hash,
                dimension: 384,
                model_path: None,
                api_key: None,
                endpoint: None,
                timeout_seconds: 30,
                batch_size: 32,
            },
            smart_search: SmartSearchConfig::default(),
        }
    }

    // Construction should succeed even with the hash provider and no LLM
    // client — embedding/LLM failures degrade gracefully rather than error.
    #[tokio::test]
    async fn test_search_engine_initialization() {
        let config = create_test_config();
        let engine = SmartSearchEngine::new(config, None).await;
        assert!(engine.is_ok());
    }

    // NOTE(review): this only asserts PathBuf extension parsing; it never
    // calls SmartSearchEngine::is_code_file, so the detection logic itself
    // is not exercised here — consider testing through the engine.
    #[test]
    fn test_code_file_detection() {
        let _engine_config = create_test_config();
        let path = PathBuf::from("test.rs");
        assert!(path.extension().and_then(|ext| ext.to_str()) == Some("rs"));
    }
}