Skip to main content

manx_cli/rag/
search_engine.rs

//! Smart search engine orchestrator for intelligent RAG search
//!
//! This module coordinates query enhancement, embedding selection, multi-stage search,
//! and result verification to provide the best possible search experience.
6use anyhow::Result;
7use std::collections::HashSet;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10
11use crate::rag::{
12    embeddings::EmbeddingModel,
13    indexer::Indexer,
14    llm::LlmClient,
15    query_enhancer::{EnhancedQuery, QueryEnhancer, SearchStrategy},
16    result_verifier::{ResultVerifier, VerifiedResult},
17    EmbeddingProvider, RagConfig, RagSearchResult,
18};
19
20#[cfg(test)]
21use crate::rag::SmartSearchConfig;
22
/// Smart search engine that orchestrates intelligent search strategies
///
/// Coordinates query enhancement, single/multi-stage retrieval, deduplication,
/// and result verification (see `search` for the full pipeline).
pub struct SmartSearchEngine {
    /// Full RAG configuration; the `smart_search` sub-config drives behavior.
    config: RagConfig,
    /// Expands a raw query into weighted variations with per-variation strategies.
    query_enhancer: QueryEnhancer,
    /// Scores and filters raw hits into `VerifiedResult`s.
    result_verifier: ResultVerifier,
    /// Shared embedding model; `None` means semantic search is unavailable and
    /// keyword search is used as the fallback.
    embedding_model: Option<Arc<EmbeddingModel>>,
    /// Optional LLM client shared with the enhancer and verifier.
    #[allow(dead_code)] // Used in public API methods
    llm_client: Option<Arc<LlmClient>>,
}
32
33impl SmartSearchEngine {
34    /// Create a new smart search engine
35    pub async fn new(config: RagConfig, llm_client: Option<LlmClient>) -> Result<Self> {
36        log::info!(
37            "Initializing smart search engine with config: {:?}",
38            config.smart_search
39        );
40
41        // Wrap LLM client in Arc for sharing
42        let llm_client_arc = llm_client.map(Arc::new);
43
44        // Initialize embedding model based on smart search preferences
45        let embedding_model = Self::initialize_embedding_model(&config).await?;
46
47        // Create query enhancer
48        let query_enhancer =
49            QueryEnhancer::new(llm_client_arc.clone(), config.smart_search.clone());
50
51        // Create result verifier
52        let result_verifier =
53            ResultVerifier::new(llm_client_arc.clone(), config.smart_search.clone());
54
55        Ok(Self {
56            config,
57            query_enhancer,
58            result_verifier,
59            embedding_model,
60            llm_client: llm_client_arc,
61        })
62    }
63
64    /// Initialize the best available embedding model (wrapped in Arc for sharing)
65    async fn initialize_embedding_model(config: &RagConfig) -> Result<Option<Arc<EmbeddingModel>>> {
66        if !config.smart_search.prefer_semantic {
67            log::info!("Semantic embeddings disabled by config");
68            return Ok(None);
69        }
70
71        // Try smart auto-selection first if using default hash provider
72        if matches!(config.embedding.provider, EmbeddingProvider::Hash) {
73            log::info!("Default hash provider detected, attempting auto-selection of better model");
74            match EmbeddingModel::new_auto_select().await {
75                Ok(model) => {
76                    log::info!(
77                        "Successfully auto-selected embedding model (pooled): {:?}",
78                        model.get_config().provider
79                    );
80                    return Ok(Some(Arc::new(model)));
81                }
82                Err(e) => {
83                    log::warn!("Auto-selection failed, trying configured provider: {}", e);
84                }
85            }
86        }
87
88        // Try to initialize with configured embedding provider
89        match EmbeddingModel::new_with_config(config.embedding.clone()).await {
90            Ok(model) => {
91                log::info!(
92                    "Successfully initialized embedding model (pooled): {:?}",
93                    config.embedding.provider
94                );
95                Ok(Some(Arc::new(model)))
96            }
97            Err(e) => {
98                log::warn!(
99                    "Failed to initialize embedding model, will use fallback: {}",
100                    e
101                );
102                Ok(None)
103            }
104        }
105    }
106
107    /// Perform intelligent search with multi-stage strategy
108    pub async fn search(
109        &self,
110        query: &str,
111        max_results: Option<usize>,
112    ) -> Result<Vec<VerifiedResult>> {
113        log::info!("Starting smart search for: '{}'", query);
114
115        // Stage 1: Query Enhancement
116        let enhanced_query = self.query_enhancer.enhance_query(query).await?;
117        log::debug!(
118            "Enhanced query with {} variations",
119            enhanced_query.variations.len()
120        );
121
122        // Stage 2: Multi-strategy search execution
123        let mut all_results = if self.config.smart_search.enable_multi_stage {
124            self.execute_multi_stage_search(&enhanced_query).await?
125        } else {
126            self.execute_single_stage_search(&enhanced_query).await?
127        };
128
129        log::debug!(
130            "Collected {} raw results from search stages",
131            all_results.len()
132        );
133
134        // Stage 3: Deduplication and initial filtering
135        all_results = self.deduplicate_results(all_results);
136
137        // Stage 4: Result verification and scoring
138        let verified_results = self
139            .result_verifier
140            .verify_results(&enhanced_query, all_results)
141            .await?;
142
143        // Stage 5: Final ranking and limiting
144        let final_results = self.finalize_results(verified_results, max_results);
145
146        log::info!(
147            "Smart search completed: {} verified results for '{}'",
148            final_results.len(),
149            query
150        );
151
152        Ok(final_results)
153    }
154
    /// Execute multi-stage search with different strategies
    ///
    /// Stage 1 runs a semantic pass over the original query (when an embedding
    /// model exists), stage 2 searches the top three enhanced variations using
    /// each variation's own strategy (weighting scores by the variation), and
    /// stage 3 adds a keyword pass over the original query with a small boost.
    /// Results are concatenated unranked; dedup and verification happen later.
    async fn execute_multi_stage_search(
        &self,
        query: &EnhancedQuery,
    ) -> Result<Vec<RagSearchResult>> {
        let mut all_results = Vec::new();

        // Stage 1: Direct semantic search with original query
        if let Some(ref embedding_model) = self.embedding_model {
            log::debug!("Stage 1: Semantic search with original query");
            match self.semantic_search(&query.original, embedding_model).await {
                Ok(mut results) => {
                    log::debug!("Semantic search found {} results", results.len());
                    all_results.append(&mut results);
                }
                // A semantic failure is non-fatal; later stages still run.
                Err(e) => log::warn!("Semantic search failed: {}", e),
            }
        }

        // Stage 2: Enhanced query variations
        log::debug!("Stage 2: Enhanced query variations");
        for (i, variation) in query.variations.iter().enumerate().take(3) {
            // Limit to top 3 variations
            log::debug!("Searching with variation {}: '{}'", i + 1, variation.query);

            // Each variation declares its own strategy; failures degrade to an
            // empty result set (unwrap_or_default) rather than aborting.
            let mut variation_results = match variation.strategy {
                SearchStrategy::Semantic => {
                    if let Some(ref embedding_model) = self.embedding_model {
                        self.semantic_search(&variation.query, embedding_model)
                            .await
                            .unwrap_or_default()
                    } else {
                        Vec::new()
                    }
                }
                SearchStrategy::Keyword => self
                    .keyword_search(&variation.query)
                    .await
                    .unwrap_or_default(),
                SearchStrategy::Code => {
                    self.code_search(&variation.query).await.unwrap_or_default()
                }
                SearchStrategy::Mixed => {
                    // Mixed = union of semantic (when available) and keyword hits.
                    let mut mixed_results = Vec::new();
                    if let Some(ref embedding_model) = self.embedding_model {
                        if let Ok(mut semantic_results) = self
                            .semantic_search(&variation.query, embedding_model)
                            .await
                        {
                            mixed_results.append(&mut semantic_results);
                        }
                    }
                    if let Ok(mut keyword_results) = self.keyword_search(&variation.query).await {
                        mixed_results.append(&mut keyword_results);
                    }
                    mixed_results
                }
                // NOTE(review): any other strategy variant is silently skipped
                // here — confirm that's intended if SearchStrategy grows.
                _ => Vec::new(),
            };

            // Apply variation weight to scores
            for result in &mut variation_results {
                result.score *= variation.weight;
            }

            all_results.append(&mut variation_results);
        }

        // Stage 3: Keyword fallback for exact matches
        log::debug!("Stage 3: Keyword fallback");
        let mut keyword_results = self
            .keyword_search(&query.original)
            .await
            .unwrap_or_default();
        // Boost keyword results slightly since they're exact matches
        for result in &mut keyword_results {
            result.score *= 1.1;
        }
        all_results.append(&mut keyword_results);

        Ok(all_results)
    }
237
238    /// Execute single-stage search (simpler approach)
239    async fn execute_single_stage_search(
240        &self,
241        query: &EnhancedQuery,
242    ) -> Result<Vec<RagSearchResult>> {
243        if let Some(ref embedding_model) = self.embedding_model {
244            self.semantic_search(&query.original, embedding_model).await
245        } else {
246            self.keyword_search(&query.original).await
247        }
248    }
249
250    /// Perform semantic search using embeddings
251    async fn semantic_search(
252        &self,
253        query: &str,
254        embedding_model: &EmbeddingModel,
255    ) -> Result<Vec<RagSearchResult>> {
256        log::debug!("Performing semantic search for: '{}'", query);
257
258        // Generate query embedding
259        let query_embedding = embedding_model.embed_text(query).await?;
260
261        // Search through stored embeddings
262        let indexer = Indexer::new(&self.config)?;
263        let index_path = indexer.get_index_path();
264        let embedding_dir = index_path.join("embeddings");
265
266        if !embedding_dir.exists() {
267            log::debug!("No embeddings directory found");
268            return Ok(vec![]);
269        }
270
271        let mut results = Vec::new();
272        let entries = std::fs::read_dir(embedding_dir)?;
273
274        for entry in entries.flatten() {
275            if let Some(file_name) = entry.file_name().to_str() {
276                if file_name.ends_with(".json") {
277                    match self
278                        .load_and_score_embedding(&entry.path(), &query_embedding, embedding_model)
279                        .await
280                    {
281                        Ok(Some(result)) => {
282                            if result.score >= self.config.similarity_threshold {
283                                results.push(result);
284                            }
285                        }
286                        Ok(None) => continue,
287                        Err(e) => {
288                            log::warn!(
289                                "Failed to process embedding file {:?}: {}",
290                                entry.path(),
291                                e
292                            );
293                        }
294                    }
295                }
296            }
297        }
298
299        // Sort by similarity score
300        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
301
302        log::debug!("Semantic search found {} results", results.len());
303        Ok(results)
304    }
305
306    /// Perform keyword-based search
307    async fn keyword_search(&self, query: &str) -> Result<Vec<RagSearchResult>> {
308        log::debug!("Performing keyword search for: '{}'", query);
309
310        let indexer = Indexer::new(&self.config)?;
311        let index_path = indexer.get_index_path();
312        let embedding_dir = index_path.join("embeddings");
313
314        if !embedding_dir.exists() {
315            return Ok(vec![]);
316        }
317
318        let query_words: Vec<String> = query
319            .to_lowercase()
320            .split_whitespace()
321            .filter(|w| w.len() > 2)
322            .map(|w| w.to_string())
323            .collect();
324
325        let mut results = Vec::new();
326        let entries = std::fs::read_dir(embedding_dir)?;
327
328        for entry in entries.flatten() {
329            if let Some(file_name) = entry.file_name().to_str() {
330                if file_name.ends_with(".json") {
331                    if let Ok(content) = std::fs::read_to_string(entry.path()) {
332                        if let Ok(stored_chunk) =
333                            serde_json::from_str::<crate::rag::StoredChunk>(&content)
334                        {
335                            let content_lower = stored_chunk.content.to_lowercase();
336
337                            let matches = query_words
338                                .iter()
339                                .filter(|word| content_lower.contains(*word))
340                                .count();
341
342                            if matches > 0 {
343                                let score = matches as f32 / query_words.len() as f32;
344
345                                results.push(RagSearchResult {
346                                    id: stored_chunk.id,
347                                    content: stored_chunk.content,
348                                    source_path: stored_chunk.source_path,
349                                    source_type: stored_chunk.source_type,
350                                    title: stored_chunk.title,
351                                    section: stored_chunk.section,
352                                    score,
353                                    chunk_index: stored_chunk.chunk_index,
354                                    metadata: stored_chunk.metadata,
355                                });
356                            }
357                        }
358                    }
359                }
360            }
361        }
362
363        // Sort by keyword match score
364        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
365
366        log::debug!("Keyword search found {} results", results.len());
367        Ok(results)
368    }
369
370    /// Perform code-specific search
371    async fn code_search(&self, query: &str) -> Result<Vec<RagSearchResult>> {
372        log::debug!("Performing code search for: '{}'", query);
373
374        // For now, use keyword search with code-specific boosting
375        let mut results = self.keyword_search(query).await?;
376
377        // Boost results that appear to be from code files
378        for result in &mut results {
379            if self.is_code_file(&result.source_path) {
380                result.score *= 1.3;
381            }
382        }
383
384        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
385
386        Ok(results)
387    }
388
389    /// Check if a file appears to be a code file
390    fn is_code_file(&self, path: &Path) -> bool {
391        if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
392            matches!(
393                extension,
394                "rs" | "js" | "ts" | "py" | "java" | "cpp" | "c" | "go" | "php" | "rb"
395            )
396        } else {
397            false
398        }
399    }
400
401    /// Load and score a single embedding file (adapted from mod.rs)
402    async fn load_and_score_embedding(
403        &self,
404        file_path: &PathBuf,
405        query_embedding: &[f32],
406        _embedding_model: &EmbeddingModel,
407    ) -> Result<Option<RagSearchResult>> {
408        let content = std::fs::read_to_string(file_path)?;
409        let chunk_data: crate::rag::StoredChunk = serde_json::from_str(&content)?;
410
411        // Calculate similarity score
412        let score = EmbeddingModel::cosine_similarity(query_embedding, &chunk_data.embedding);
413
414        Ok(Some(RagSearchResult {
415            id: chunk_data.id,
416            content: chunk_data.content,
417            source_path: chunk_data.source_path,
418            source_type: chunk_data.source_type,
419            title: chunk_data.title,
420            section: chunk_data.section,
421            score,
422            chunk_index: chunk_data.chunk_index,
423            metadata: chunk_data.metadata,
424        }))
425    }
426
427    /// Remove duplicate results based on content similarity
428    fn deduplicate_results(&self, results: Vec<RagSearchResult>) -> Vec<RagSearchResult> {
429        let mut unique_results = Vec::new();
430        let mut seen_content = HashSet::new();
431        let original_count = results.len();
432
433        for result in results {
434            // Create a simple hash of the content for deduplication
435            let content_hash = format!(
436                "{}_{}",
437                result.source_path.to_string_lossy(),
438                result.chunk_index
439            );
440
441            if !seen_content.contains(&content_hash) {
442                seen_content.insert(content_hash);
443                unique_results.push(result);
444            }
445        }
446
447        log::debug!(
448            "Deduplicated {} results to {}",
449            original_count,
450            unique_results.len()
451        );
452        unique_results
453    }
454
455    /// Finalize results with ranking and limiting
456    fn finalize_results(
457        &self,
458        mut results: Vec<VerifiedResult>,
459        max_results: Option<usize>,
460    ) -> Vec<VerifiedResult> {
461        // Sort by confidence score (already done in verifier, but ensuring)
462        results.sort_by(|a, b| b.confidence_score.partial_cmp(&a.confidence_score).unwrap());
463
464        // Apply limit
465        let limit = max_results.unwrap_or(self.config.max_results);
466        if results.len() > limit {
467            results.truncate(limit);
468        }
469
470        results
471    }
472
473    /// Check if the search engine is ready to perform intelligent search
474    /// This is a public API method for external consumers
475    #[allow(dead_code)] // Public API method - may be used by external code
476    pub fn is_intelligent_mode_available(&self) -> bool {
477        self.embedding_model.is_some() || self.llm_client.is_some()
478    }
479
480    /// Get search engine capabilities for debugging
481    /// This is a public API method for external consumers
482    #[allow(dead_code)] // Public API method - may be used by external code
483    pub fn get_capabilities(&self) -> SearchCapabilities {
484        SearchCapabilities {
485            has_semantic_embeddings: self.embedding_model.is_some(),
486            has_llm_client: self.llm_client.is_some(),
487            has_query_enhancement: self.config.smart_search.enable_query_enhancement,
488            has_result_verification: self.config.smart_search.enable_result_verification,
489            multi_stage_enabled: self.config.smart_search.enable_multi_stage,
490        }
491    }
492}
493
/// Search engine capabilities information
/// This is a public API struct for external consumers
#[derive(Debug)]
#[allow(dead_code)] // Public API struct - may be used by external code
pub struct SearchCapabilities {
    /// True when an embedding model was successfully initialized.
    pub has_semantic_embeddings: bool,
    /// True when an LLM client was supplied at construction.
    pub has_llm_client: bool,
    /// Mirrors `smart_search.enable_query_enhancement`.
    pub has_query_enhancement: bool,
    /// Mirrors `smart_search.enable_result_verification`.
    pub has_result_verification: bool,
    /// Mirrors `smart_search.enable_multi_stage`.
    pub multi_stage_enabled: bool,
}
505
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rag::{CodeSecurityLevel, EmbeddingConfig, EmbeddingProvider};

    /// Build a minimal RagConfig pointing at a throwaway index path, using
    /// the default hash embedding provider and default smart-search options.
    fn create_test_config() -> RagConfig {
        RagConfig {
            enabled: true,
            index_path: PathBuf::from("/tmp/test_index"),
            max_results: 10,
            similarity_threshold: 0.6,
            allow_pdf_processing: false,
            allow_code_processing: true,
            code_security_level: CodeSecurityLevel::Moderate,
            mask_secrets: true,
            max_file_size_mb: 100,
            embedding: EmbeddingConfig {
                provider: EmbeddingProvider::Hash,
                dimension: 384,
                model_path: None,
                api_key: None,
                endpoint: None,
                timeout_seconds: 30,
                batch_size: 32,
            },
            smart_search: SmartSearchConfig::default(),
        }
    }

    // Engine construction should succeed even with no LLM client; embedding
    // initialization failures degrade to `None` rather than erroring.
    #[tokio::test]
    async fn test_search_engine_initialization() {
        let config = create_test_config();
        let engine = SmartSearchEngine::new(config, None).await;
        assert!(engine.is_ok());
    }

    // NOTE(review): this only exercises `Path::extension`, not the engine's
    // `is_code_file` — consider testing the method directly.
    #[test]
    fn test_code_file_detection() {
        let _engine_config = create_test_config();
        // We can't easily test the full engine without setting it up, but we can test the logic
        let path = PathBuf::from("test.rs");
        // This would be tested with the actual engine instance
        assert!(path.extension().and_then(|ext| ext.to_str()) == Some("rs"));
    }
}