manx_cli/rag/search_engine.rs

//! Smart search engine orchestrator for intelligent RAG search
//!
//! This module coordinates query enhancement, embedding selection, multi-stage search,
//! and result verification to provide the best possible search experience.
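//!
//! A minimal usage sketch (illustrative only, not compiled as a doctest;
//! assumes a `RagConfig` is available from the caller's configuration):
//!
//! ```ignore
//! let engine = SmartSearchEngine::new(config, None).await?;
//! let results = engine.search("how to configure logging", Some(5)).await?;
//! log::info!("got {} verified results", results.len());
//! ```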

use anyhow::Result;
use std::collections::HashSet;
use std::path::{Path, PathBuf};

use crate::rag::{
    embeddings::EmbeddingModel,
    indexer::Indexer,
    llm::LlmClient,
    query_enhancer::{EnhancedQuery, QueryEnhancer, SearchStrategy},
    result_verifier::{ResultVerifier, VerifiedResult},
    EmbeddingProvider, RagConfig, RagSearchResult,
};

#[cfg(test)]
use crate::rag::SmartSearchConfig;

/// Smart search engine that orchestrates intelligent search strategies
pub struct SmartSearchEngine {
    config: RagConfig,
    query_enhancer: QueryEnhancer,
    result_verifier: ResultVerifier,
    embedding_model: Option<EmbeddingModel>,
    #[allow(dead_code)] // Used in public API methods
    llm_client: Option<LlmClient>,
}

impl SmartSearchEngine {
    /// Create a new smart search engine
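    ///
    /// `llm_client` may be `None`: the query enhancer and result verifier are
    /// still constructed, just without an LLM backing them (the
    /// initialization test below relies on this). Construction sketch
    /// (illustrative only):
    ///
    /// ```ignore
    /// let engine = SmartSearchEngine::new(config, None).await?;
    /// ```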
    pub async fn new(config: RagConfig, llm_client: Option<LlmClient>) -> Result<Self> {
        log::info!(
            "Initializing smart search engine with config: {:?}",
            config.smart_search
        );

        // Initialize embedding model based on smart search preferences
        let embedding_model = Self::initialize_embedding_model(&config).await?;

        // Create query enhancer
        let query_enhancer = QueryEnhancer::new(llm_client.clone(), config.smart_search.clone());

        // Create result verifier
        let result_verifier = ResultVerifier::new(llm_client.clone(), config.smart_search.clone());

        Ok(Self {
            config,
            query_enhancer,
            result_verifier,
            embedding_model,
            llm_client,
        })
    }

    /// Initialize the best available embedding model
    async fn initialize_embedding_model(config: &RagConfig) -> Result<Option<EmbeddingModel>> {
        if !config.smart_search.prefer_semantic {
            log::info!("Semantic embeddings disabled by config");
            return Ok(None);
        }

        // Try smart auto-selection first if using default hash provider
        if matches!(config.embedding.provider, EmbeddingProvider::Hash) {
            log::info!("Default hash provider detected, attempting auto-selection of better model");
            match EmbeddingModel::new_auto_select().await {
                Ok(model) => {
                    log::info!(
                        "Successfully auto-selected embedding model: {:?}",
                        model.get_config().provider
                    );
                    return Ok(Some(model));
                }
                Err(e) => {
                    log::warn!("Auto-selection failed, trying configured provider: {}", e);
                }
            }
        }

        // Try to initialize with configured embedding provider
        match EmbeddingModel::new_with_config(config.embedding.clone()).await {
            Ok(model) => {
                log::info!(
                    "Successfully initialized embedding model: {:?}",
                    config.embedding.provider
                );
                Ok(Some(model))
            }
            Err(e) => {
                log::warn!(
                    "Failed to initialize embedding model, will use fallback: {}",
                    e
                );
                Ok(None)
            }
        }
    }

    /// Perform intelligent search with multi-stage strategy
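    ///
    /// The pipeline runs five stages: query enhancement, multi- or
    /// single-stage retrieval, deduplication, result verification, and final
    /// ranking/limiting. When `max_results` is `None`, the configured
    /// `max_results` limit applies. A minimal call sketch (illustrative only,
    /// not compiled as a doctest):
    ///
    /// ```ignore
    /// let top = engine.search("tokio select macro", Some(5)).await?;
    /// for r in &top {
    ///     println!("confidence: {:.2}", r.confidence_score);
    /// }
    /// ```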
    pub async fn search(
        &self,
        query: &str,
        max_results: Option<usize>,
    ) -> Result<Vec<VerifiedResult>> {
        log::info!("Starting smart search for: '{}'", query);

        // Stage 1: Query Enhancement
        let enhanced_query = self.query_enhancer.enhance_query(query).await?;
        log::debug!(
            "Enhanced query with {} variations",
            enhanced_query.variations.len()
        );

        // Stage 2: Multi-strategy search execution
        let mut all_results = if self.config.smart_search.enable_multi_stage {
            self.execute_multi_stage_search(&enhanced_query).await?
        } else {
            self.execute_single_stage_search(&enhanced_query).await?
        };

        log::debug!(
            "Collected {} raw results from search stages",
            all_results.len()
        );

        // Stage 3: Deduplication and initial filtering
        all_results = self.deduplicate_results(all_results);

        // Stage 4: Result verification and scoring
        let verified_results = self
            .result_verifier
            .verify_results(&enhanced_query, all_results)
            .await?;

        // Stage 5: Final ranking and limiting
        let final_results = self.finalize_results(verified_results, max_results);

        log::info!(
            "Smart search completed: {} verified results for '{}'",
            final_results.len(),
            query
        );

        Ok(final_results)
    }

    /// Execute multi-stage search with different strategies
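    ///
    /// Stage order: (1) semantic search on the original query, (2) up to
    /// three enhanced variations, each weighted by its `variation.weight`,
    /// and (3) a keyword pass on the original query with a slight boost for
    /// exact term matches.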
    async fn execute_multi_stage_search(
        &self,
        query: &EnhancedQuery,
    ) -> Result<Vec<RagSearchResult>> {
        let mut all_results = Vec::new();

        // Stage 1: Direct semantic search with original query
        if let Some(ref embedding_model) = self.embedding_model {
            log::debug!("Stage 1: Semantic search with original query");
            match self.semantic_search(&query.original, embedding_model).await {
                Ok(mut results) => {
                    log::debug!("Semantic search found {} results", results.len());
                    all_results.append(&mut results);
                }
                Err(e) => log::warn!("Semantic search failed: {}", e),
            }
        }

        // Stage 2: Enhanced query variations
        log::debug!("Stage 2: Enhanced query variations");
        for (i, variation) in query.variations.iter().enumerate().take(3) {
            // Limit to top 3 variations
            log::debug!("Searching with variation {}: '{}'", i + 1, variation.query);

            let mut variation_results = match variation.strategy {
                SearchStrategy::Semantic => {
                    if let Some(ref embedding_model) = self.embedding_model {
                        self.semantic_search(&variation.query, embedding_model)
                            .await
                            .unwrap_or_default()
                    } else {
                        Vec::new()
                    }
                }
                SearchStrategy::Keyword => self
                    .keyword_search(&variation.query)
                    .await
                    .unwrap_or_default(),
                SearchStrategy::Code => {
                    self.code_search(&variation.query).await.unwrap_or_default()
                }
                SearchStrategy::Mixed => {
                    let mut mixed_results = Vec::new();
                    if let Some(ref embedding_model) = self.embedding_model {
                        if let Ok(mut semantic_results) = self
                            .semantic_search(&variation.query, embedding_model)
                            .await
                        {
                            mixed_results.append(&mut semantic_results);
                        }
                    }
                    if let Ok(mut keyword_results) = self.keyword_search(&variation.query).await {
                        mixed_results.append(&mut keyword_results);
                    }
                    mixed_results
                }
                _ => Vec::new(),
            };

            // Apply variation weight to scores
            for result in &mut variation_results {
                result.score *= variation.weight;
            }

            all_results.append(&mut variation_results);
        }

        // Stage 3: Keyword fallback for exact matches
        log::debug!("Stage 3: Keyword fallback");
        let mut keyword_results = self
            .keyword_search(&query.original)
            .await
            .unwrap_or_default();
        // Boost keyword results slightly since they contain exact term matches
        for result in &mut keyword_results {
            result.score *= 1.1;
        }
        all_results.append(&mut keyword_results);

        Ok(all_results)
    }

    /// Execute single-stage search (simpler approach)
    async fn execute_single_stage_search(
        &self,
        query: &EnhancedQuery,
    ) -> Result<Vec<RagSearchResult>> {
        if let Some(ref embedding_model) = self.embedding_model {
            self.semantic_search(&query.original, embedding_model).await
        } else {
            self.keyword_search(&query.original).await
        }
    }

    /// Perform semantic search using embeddings
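    ///
    /// Chunks are read from `<index_path>/embeddings/*.json`, one
    /// `StoredChunk` per file; results scoring below
    /// `config.similarity_threshold` are dropped.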
    async fn semantic_search(
        &self,
        query: &str,
        embedding_model: &EmbeddingModel,
    ) -> Result<Vec<RagSearchResult>> {
        log::debug!("Performing semantic search for: '{}'", query);

        // Generate query embedding
        let query_embedding = embedding_model.embed_text(query).await?;

        // Search through stored embeddings
        let indexer = Indexer::new(&self.config)?;
        let index_path = indexer.get_index_path();
        let embedding_dir = index_path.join("embeddings");

        if !embedding_dir.exists() {
            log::debug!("No embeddings directory found");
            return Ok(vec![]);
        }

        let mut results = Vec::new();
        let entries = std::fs::read_dir(embedding_dir)?;

        for entry in entries.flatten() {
            if let Some(file_name) = entry.file_name().to_str() {
                if file_name.ends_with(".json") {
                    match self
                        .load_and_score_embedding(&entry.path(), &query_embedding, embedding_model)
                        .await
                    {
                        Ok(Some(result)) => {
                            if result.score >= self.config.similarity_threshold {
                                results.push(result);
                            }
                        }
                        Ok(None) => continue,
                        Err(e) => {
                            log::warn!(
                                "Failed to process embedding file {:?}: {}",
                                entry.path(),
                                e
                            );
                        }
                    }
                }
            }
        }

        // Sort by similarity score, descending; `total_cmp` avoids the panic
        // that `partial_cmp(..).unwrap()` would hit on NaN scores
        results.sort_by(|a, b| b.score.total_cmp(&a.score));

        log::debug!("Semantic search found {} results", results.len());
        Ok(results)
    }

    /// Perform keyword-based search
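    ///
    /// The score is the fraction of query words (lowercased, longer than two
    /// characters) that occur in the chunk. For example, the query
    /// "async runtime shutdown" against a chunk containing "async" and
    /// "runtime" but not "shutdown" scores 2/3 ≈ 0.67.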
    async fn keyword_search(&self, query: &str) -> Result<Vec<RagSearchResult>> {
        log::debug!("Performing keyword search for: '{}'", query);

        let indexer = Indexer::new(&self.config)?;
        let index_path = indexer.get_index_path();
        let embedding_dir = index_path.join("embeddings");

        if !embedding_dir.exists() {
            return Ok(vec![]);
        }

        let query_words: Vec<String> = query
            .to_lowercase()
            .split_whitespace()
            .filter(|w| w.len() > 2)
            .map(|w| w.to_string())
            .collect();

        let mut results = Vec::new();
        let entries = std::fs::read_dir(embedding_dir)?;

        for entry in entries.flatten() {
            if let Some(file_name) = entry.file_name().to_str() {
                if file_name.ends_with(".json") {
                    if let Ok(content) = std::fs::read_to_string(entry.path()) {
                        if let Ok(stored_chunk) =
                            serde_json::from_str::<crate::rag::StoredChunk>(&content)
                        {
                            let content_lower = stored_chunk.content.to_lowercase();

                            let matches = query_words
                                .iter()
                                .filter(|word| content_lower.contains(*word))
                                .count();

                            if matches > 0 {
                                let score = matches as f32 / query_words.len() as f32;

                                results.push(RagSearchResult {
                                    id: stored_chunk.id,
                                    content: stored_chunk.content,
                                    source_path: stored_chunk.source_path,
                                    source_type: stored_chunk.source_type,
                                    title: stored_chunk.title,
                                    section: stored_chunk.section,
                                    score,
                                    chunk_index: stored_chunk.chunk_index,
                                    metadata: stored_chunk.metadata,
                                });
                            }
                        }
                    }
                }
            }
        }

        // Sort by keyword match score, descending (NaN-safe)
        results.sort_by(|a, b| b.score.total_cmp(&a.score));

        log::debug!("Keyword search found {} results", results.len());
        Ok(results)
    }

    /// Perform code-specific search
    async fn code_search(&self, query: &str) -> Result<Vec<RagSearchResult>> {
        log::debug!("Performing code search for: '{}'", query);

        // For now, use keyword search with code-specific boosting
        let mut results = self.keyword_search(query).await?;

        // Boost results that appear to be from code files
        for result in &mut results {
            if self.is_code_file(&result.source_path) {
                result.score *= 1.3;
            }
        }

        results.sort_by(|a, b| b.score.total_cmp(&a.score));

        Ok(results)
    }

    /// Check if a file appears to be a code file
    fn is_code_file(&self, path: &Path) -> bool {
        if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
            matches!(
                extension,
                "rs" | "js" | "ts" | "py" | "java" | "cpp" | "c" | "go" | "php" | "rb"
            )
        } else {
            false
        }
    }

    /// Load and score a single embedding file (adapted from mod.rs)
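    ///
    /// Scoring uses `EmbeddingModel::cosine_similarity`, i.e. standard cosine
    /// similarity: the dot product of the two vectors divided by the product
    /// of their magnitudes.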
    async fn load_and_score_embedding(
        &self,
        file_path: &Path,
        query_embedding: &[f32],
        _embedding_model: &EmbeddingModel,
    ) -> Result<Option<RagSearchResult>> {
        let content = std::fs::read_to_string(file_path)?;
        let chunk_data: crate::rag::StoredChunk = serde_json::from_str(&content)?;

        // Calculate similarity score
        let score = EmbeddingModel::cosine_similarity(query_embedding, &chunk_data.embedding);

        Ok(Some(RagSearchResult {
            id: chunk_data.id,
            content: chunk_data.content,
            source_path: chunk_data.source_path,
            source_type: chunk_data.source_type,
            title: chunk_data.title,
            section: chunk_data.section,
            score,
            chunk_index: chunk_data.chunk_index,
            metadata: chunk_data.metadata,
        }))
    }

    /// Remove duplicate results, keyed by source path and chunk index
    fn deduplicate_results(&self, results: Vec<RagSearchResult>) -> Vec<RagSearchResult> {
        let mut unique_results = Vec::new();
        let mut seen_keys = HashSet::new();
        let original_count = results.len();

        for result in results {
            // Build a composite key of source path and chunk index; `insert`
            // returns false when the key has already been seen
            let dedup_key = format!(
                "{}_{}",
                result.source_path.to_string_lossy(),
                result.chunk_index
            );

            if seen_keys.insert(dedup_key) {
                unique_results.push(result);
            }
        }

        log::debug!(
            "Deduplicated {} results to {}",
            original_count,
            unique_results.len()
        );
        unique_results
    }

    /// Finalize results with ranking and limiting
    fn finalize_results(
        &self,
        mut results: Vec<VerifiedResult>,
        max_results: Option<usize>,
    ) -> Vec<VerifiedResult> {
        // Sort by confidence score, descending (the verifier already sorts,
        // but we re-sort defensively and NaN-safely)
        results.sort_by(|a, b| b.confidence_score.total_cmp(&a.confidence_score));

        // Apply the result limit; `truncate` is a no-op when already within it
        let limit = max_results.unwrap_or(self.config.max_results);
        results.truncate(limit);

        results
    }

    /// Check if the search engine is ready to perform intelligent search
    /// This is a public API method for external consumers
    #[allow(dead_code)] // Public API method - may be used by external code
    pub fn is_intelligent_mode_available(&self) -> bool {
        self.embedding_model.is_some() || self.llm_client.is_some()
    }

    /// Get search engine capabilities for debugging
    /// This is a public API method for external consumers
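    ///
    /// Debugging sketch (illustrative only):
    ///
    /// ```ignore
    /// let caps = engine.get_capabilities();
    /// log::debug!("search capabilities: {:?}", caps);
    /// ```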
    #[allow(dead_code)] // Public API method - may be used by external code
    pub fn get_capabilities(&self) -> SearchCapabilities {
        SearchCapabilities {
            has_semantic_embeddings: self.embedding_model.is_some(),
            has_llm_client: self.llm_client.is_some(),
            has_query_enhancement: self.config.smart_search.enable_query_enhancement,
            has_result_verification: self.config.smart_search.enable_result_verification,
            multi_stage_enabled: self.config.smart_search.enable_multi_stage,
        }
    }
}

/// Search engine capabilities information
/// This is a public API struct for external consumers
#[derive(Debug)]
#[allow(dead_code)] // Public API struct - may be used by external code
pub struct SearchCapabilities {
    pub has_semantic_embeddings: bool,
    pub has_llm_client: bool,
    pub has_query_enhancement: bool,
    pub has_result_verification: bool,
    pub multi_stage_enabled: bool,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::rag::{CodeSecurityLevel, EmbeddingConfig, EmbeddingProvider};

    fn create_test_config() -> RagConfig {
        RagConfig {
            enabled: true,
            index_path: PathBuf::from("/tmp/test_index"),
            max_results: 10,
            similarity_threshold: 0.6,
            allow_pdf_processing: false,
            allow_code_processing: true,
            code_security_level: CodeSecurityLevel::Moderate,
            mask_secrets: true,
            max_file_size_mb: 100,
            embedding: EmbeddingConfig {
                provider: EmbeddingProvider::Hash,
                dimension: 384,
                model_path: None,
                api_key: None,
                endpoint: None,
                timeout_seconds: 30,
                batch_size: 32,
            },
            smart_search: SmartSearchConfig::default(),
        }
    }

    #[tokio::test]
    async fn test_search_engine_initialization() {
        let config = create_test_config();
        let engine = SmartSearchEngine::new(config, None).await;
        assert!(engine.is_ok());
    }
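
    #[tokio::test]
    async fn test_code_file_detection_via_engine() {
        // Exercises the private `is_code_file` helper through a real engine
        // instance; relies on construction succeeding with the hash provider,
        // as the initialization test above demonstrates.
        let engine = SmartSearchEngine::new(create_test_config(), None)
            .await
            .expect("engine should initialize with test config");
        assert!(engine.is_code_file(Path::new("src/main.rs")));
        assert!(engine.is_code_file(Path::new("script.py")));
        assert!(!engine.is_code_file(Path::new("README.md")));
        assert!(!engine.is_code_file(Path::new("Makefile")));
    }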

    #[test]
    fn test_code_file_detection() {
        // Sanity-check the extension parsing that `is_code_file` relies on;
        // the engine-backed test above covers the method itself.
        let path = PathBuf::from("test.rs");
        assert_eq!(path.extension().and_then(|ext| ext.to_str()), Some("rs"));
    }
}