Skip to main content

directory_indexer/search/
engine.rs

1use log::info;
2use std::path::PathBuf;
3
4use crate::{
5    embedding::EmbeddingProvider,
6    error::{IndexerError, Result},
7    storage::{QdrantStore, SqliteStore},
8};
9
10pub struct SearchEngine {
11    #[allow(dead_code)]
12    sqlite_store: SqliteStore,
13    #[allow(dead_code)]
14    vector_store: QdrantStore,
15    #[allow(dead_code)]
16    embedding_provider: Box<dyn EmbeddingProvider>,
17}
18
/// Parameters of a free-text semantic search.
///
/// Validated by `SearchEngine::validate_query_static` before a search runs.
#[derive(Clone, Debug)]
pub struct SearchQuery {
    /// Query text to embed; must contain at least one non-whitespace character.
    pub text: String,
    /// When set, only results whose file path lies under this directory are kept.
    pub directory_filter: Option<PathBuf>,
    /// Maximum number of results to return; must be greater than zero.
    pub limit: usize,
    /// When set, results scoring below this value are dropped; must lie in `0.0..=1.0`.
    pub similarity_threshold: Option<f32>,
}
26
/// One scored hit returned by a search.
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// Path of the file the matching chunk belongs to.
    pub file_path: String,
    /// Identifier of the matching chunk within the file
    /// (presumably its position in the file's chunk list — confirm against the indexer).
    pub chunk_id: usize,
    /// Similarity score; results are ranked with higher scores first.
    pub score: f32,
    /// Parent directory names of the file; may be empty.
    pub parent_directories: Vec<String>,
}
34
35impl SearchEngine {
36    pub fn new(
37        sqlite_store: SqliteStore,
38        vector_store: QdrantStore,
39        embedding_provider: Box<dyn EmbeddingProvider>,
40    ) -> Self {
41        Self {
42            sqlite_store,
43            vector_store,
44            embedding_provider,
45        }
46    }
47
48    pub fn validate_query(&self, query: &SearchQuery) -> Result<()> {
49        Self::validate_query_static(query)
50    }
51
52    pub fn filter_results_by_directory(
53        &self,
54        results: Vec<SearchResult>,
55        directory_filter: &Option<PathBuf>,
56    ) -> Vec<SearchResult> {
57        Self::filter_results_by_directory_static(results, directory_filter)
58    }
59
60    pub fn apply_similarity_threshold(
61        &self,
62        results: Vec<SearchResult>,
63        threshold: Option<f32>,
64    ) -> Vec<SearchResult> {
65        Self::apply_similarity_threshold_static(results, threshold)
66    }
67
68    pub fn rank_results(&self, results: Vec<SearchResult>) -> Vec<SearchResult> {
69        Self::rank_results_static(results)
70    }
71
72    pub fn limit_results(&self, results: Vec<SearchResult>, limit: usize) -> Vec<SearchResult> {
73        Self::limit_results_static(results, limit)
74    }
75
76    // Static versions for easier unit testing
77    pub fn validate_query_static(query: &SearchQuery) -> Result<()> {
78        if query.text.trim().is_empty() {
79            return Err(IndexerError::invalid_input("Search query cannot be empty"));
80        }
81
82        if query.limit == 0 {
83            return Err(IndexerError::invalid_input(
84                "Search limit must be greater than 0",
85            ));
86        }
87
88        if let Some(threshold) = query.similarity_threshold {
89            if !(0.0..=1.0).contains(&threshold) {
90                return Err(IndexerError::invalid_input(
91                    "Similarity threshold must be between 0.0 and 1.0",
92                ));
93            }
94        }
95
96        if let Some(ref dir_filter) = query.directory_filter {
97            if !dir_filter.is_dir() && !dir_filter.exists() {
98                return Err(IndexerError::invalid_input(
99                    "Directory filter must be a valid directory path",
100                ));
101            }
102        }
103
104        Ok(())
105    }
106
107    pub fn filter_results_by_directory_static(
108        results: Vec<SearchResult>,
109        directory_filter: &Option<PathBuf>,
110    ) -> Vec<SearchResult> {
111        if let Some(filter_dir) = directory_filter {
112            let filter_str = filter_dir.to_string_lossy();
113            results
114                .into_iter()
115                .filter(|result| result.file_path.starts_with(filter_str.as_ref()))
116                .collect()
117        } else {
118            results
119        }
120    }
121
122    pub fn apply_similarity_threshold_static(
123        results: Vec<SearchResult>,
124        threshold: Option<f32>,
125    ) -> Vec<SearchResult> {
126        if let Some(min_score) = threshold {
127            results
128                .into_iter()
129                .filter(|result| result.score >= min_score)
130                .collect()
131        } else {
132            results
133        }
134    }
135
136    pub fn rank_results_static(mut results: Vec<SearchResult>) -> Vec<SearchResult> {
137        results.sort_by(|a, b| {
138            b.score
139                .partial_cmp(&a.score)
140                .unwrap_or(std::cmp::Ordering::Equal)
141        });
142        results
143    }
144
145    pub fn limit_results_static(results: Vec<SearchResult>, limit: usize) -> Vec<SearchResult> {
146        results.into_iter().take(limit).collect()
147    }
148
149    pub async fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
150        let text = &query.text;
151        let limit = query.limit;
152        info!("Searching for: '{text}' with limit: {limit}");
153
154        // Validate query
155        self.validate_query(&query)?;
156
157        // Generate embedding for the query
158        let query_embedding = self
159            .embedding_provider
160            .generate_embedding(query.text.clone())
161            .await?;
162
163        // Perform vector search
164        let search_results = self.vector_store.search(query_embedding, limit).await?;
165
166        // Apply directory filtering if specified
167        let filtered_results =
168            self.filter_results_by_directory(search_results, &query.directory_filter);
169
170        // Apply similarity threshold if specified
171        let threshold_results =
172            self.apply_similarity_threshold(filtered_results, query.similarity_threshold);
173
174        // Rank results
175        let ranked_results = self.rank_results(threshold_results);
176
177        // Limit results
178        let final_results = self.limit_results(ranked_results, limit);
179
180        Ok(final_results)
181    }
182
183    pub async fn find_similar_files(
184        &self,
185        file_path: PathBuf,
186        limit: usize,
187    ) -> Result<Vec<SearchResult>> {
188        info!("Finding files similar to: {file_path:?} with limit: {limit}");
189
190        if !file_path.exists() {
191            return Err(IndexerError::not_found(format!(
192                "File not found: {}",
193                file_path.display()
194            )));
195        }
196        if !file_path.is_file() {
197            return Err(IndexerError::invalid_input(format!(
198                "Path is not a file: {}",
199                file_path.display()
200            )));
201        }
202
203        // Try to get file from database to retrieve chunks
204        let normalized_path = crate::utils::normalize_path(&file_path)?;
205        let file_record = self.sqlite_store.get_file_by_path(&normalized_path)?;
206
207        // Generate embedding for the file
208        let file_embedding = if let Some(file_record) = file_record {
209            // Parse chunks JSON to get file chunks
210            let chunks = match file_record.chunks_json {
211                Some(chunks_json) => {
212                    serde_json::from_value::<Vec<String>>(chunks_json).map_err(|e| {
213                        IndexerError::file_processing(format!("Failed to parse chunks: {e}"))
214                    })?
215                }
216                None => {
217                    return Err(IndexerError::not_found(format!(
218                        "No chunks found for file: {}",
219                        file_path.display()
220                    )));
221                }
222            };
223
224            if chunks.is_empty() {
225                return Err(IndexerError::not_found(format!(
226                    "No chunks found for file: {}",
227                    file_path.display()
228                )));
229            }
230
231            // Use the first chunk as representative of the file
232            let representative_chunk = &chunks[0];
233            self.embedding_provider
234                .generate_embedding(representative_chunk.clone())
235                .await?
236        } else {
237            // File not indexed, read from filesystem and generate embedding
238            let content = std::fs::read_to_string(&file_path)
239                .map_err(|e| IndexerError::file_processing(format!("Failed to read file: {e}")))?;
240
241            // Use first 512 chars as representative content
242            let representative_content = if content.len() > 512 {
243                &content[..512]
244            } else {
245                &content
246            };
247
248            self.embedding_provider
249                .generate_embedding(representative_content.to_string())
250                .await?
251        };
252
253        // Search for similar chunks
254        let search_results = self.vector_store.search(file_embedding, limit + 5).await?;
255
256        // Filter out results from the same file and group by file path
257        let mut file_scores: std::collections::HashMap<String, (f32, usize)> =
258            std::collections::HashMap::new();
259        let file_path_str = file_path.to_string_lossy().to_string();
260
261        for result in search_results {
262            // Skip if it's the same file
263            if result.file_path == file_path_str {
264                continue;
265            }
266
267            // Keep track of the best score for each file
268            let entry = file_scores
269                .entry(result.file_path.clone())
270                .or_insert((0.0, 0));
271            if result.score > entry.0 {
272                entry.0 = result.score;
273                entry.1 = result.chunk_id;
274            }
275        }
276
277        // Sort by score and take top results, convert to SearchResult
278        let mut similar_files: Vec<_> = file_scores.into_iter().collect();
279        similar_files.sort_by(|a, b| {
280            b.1 .0
281                .partial_cmp(&a.1 .0)
282                .unwrap_or(std::cmp::Ordering::Equal)
283        });
284        similar_files.truncate(limit);
285
286        let results: Vec<SearchResult> = similar_files
287            .into_iter()
288            .map(|(file_path, (score, chunk_id))| SearchResult {
289                file_path,
290                chunk_id,
291                score,
292                parent_directories: vec![], // Could be populated if needed
293            })
294            .collect();
295
296        Ok(results)
297    }
298
299    pub async fn get_file_content(
300        &self,
301        file_path: PathBuf,
302        chunk_range: Option<(usize, usize)>,
303    ) -> Result<String> {
304        info!("Getting content for: {file_path:?} with chunks: {chunk_range:?}");
305
306        if !file_path.exists() {
307            return Err(IndexerError::not_found(format!(
308                "File not found: {}",
309                file_path.display()
310            )));
311        }
312        if !file_path.is_file() {
313            return Err(IndexerError::invalid_input(format!(
314                "Path is not a file: {}",
315                file_path.display()
316            )));
317        }
318
319        // Try to get file from database
320        let normalized_path = crate::utils::normalize_path(&file_path)?;
321        let file_record = self.sqlite_store.get_file_by_path(&normalized_path)?;
322
323        // If chunks are stored in database, use those; otherwise read from file system
324        let content = if let Some(file_record) = file_record {
325            if let Some(chunks_json) = file_record.chunks_json {
326                let chunks = serde_json::from_value::<Vec<String>>(chunks_json).map_err(|e| {
327                    IndexerError::file_processing(format!("Failed to parse chunks: {e}"))
328                })?;
329
330                if let Some((start, end)) = chunk_range {
331                    // Return specific chunk range (1-indexed to 0-indexed)
332                    let start_idx = start.saturating_sub(1);
333                    let end_idx = end.min(chunks.len());
334
335                    if start_idx >= chunks.len() {
336                        return Err(IndexerError::invalid_input(format!(
337                            "Chunk range {start}-{end} exceeds available chunks ({})",
338                            chunks.len()
339                        )));
340                    }
341
342                    chunks[start_idx..end_idx].join("\n")
343                } else {
344                    // Return all chunks
345                    chunks.join("\n")
346                }
347            } else {
348                // File indexed but no chunks stored, read from filesystem
349                let content = std::fs::read_to_string(&file_path).map_err(|e| {
350                    IndexerError::file_processing(format!("Failed to read file: {e}"))
351                })?;
352
353                if let Some((start, end)) = chunk_range {
354                    // Split content into chunks on-the-fly for files without stored chunks
355                    let lines: Vec<&str> = content.lines().collect();
356                    let lines_per_chunk = lines.len().div_ceil(10); // Approximate 10 chunks
357                    let total_chunks = lines.len().div_ceil(lines_per_chunk);
358
359                    if start > total_chunks || start == 0 {
360                        return Err(IndexerError::invalid_input(format!(
361                            "Chunk {start} is out of range. File has {total_chunks} estimated chunks"
362                        )));
363                    }
364
365                    let start_line = (start - 1) * lines_per_chunk;
366                    let end_line = (end * lines_per_chunk).min(lines.len());
367
368                    lines[start_line..end_line].join("\n")
369                } else {
370                    content
371                }
372            }
373        } else {
374            // File not indexed, read directly from file system
375            let content = std::fs::read_to_string(&file_path)
376                .map_err(|e| IndexerError::file_processing(format!("Failed to read file: {e}")))?;
377
378            if let Some((start, end)) = chunk_range {
379                // Split content into chunks on-the-fly for unindexed files
380                let lines: Vec<&str> = content.lines().collect();
381                let lines_per_chunk = lines.len().div_ceil(10); // Approximate 10 chunks
382                let total_chunks = lines.len().div_ceil(lines_per_chunk);
383
384                if start > total_chunks || start == 0 {
385                    return Err(IndexerError::invalid_input(format!(
386                        "Chunk {start} is out of range. File has {total_chunks} estimated chunks"
387                    )));
388                }
389
390                let start_line = (start - 1) * lines_per_chunk;
391                let end_line = (end * lines_per_chunk).min(lines.len());
392
393                lines[start_line..end_line].join("\n")
394            } else {
395                content
396            }
397        };
398
399        Ok(content)
400    }
401}
402
403pub async fn create_search_engine() -> Result<SearchEngine> {
404    let config = crate::Config::load()?;
405    crate::environment::validate_environment(&config).await?;
406
407    let sqlite_store = crate::storage::SqliteStore::new(&config.storage.sqlite_path)?;
408    let vector_store = crate::storage::QdrantStore::new(
409        &config.storage.qdrant.endpoint,
410        config.storage.qdrant.collection.clone(),
411    )
412    .await?;
413    let embedding_provider = crate::embedding::create_embedding_provider(&config.embedding)?;
414
415    Ok(SearchEngine::new(
416        sqlite_store,
417        vector_store,
418        embedding_provider,
419    ))
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds a query with no directory filter.
    fn query(text: &str, limit: usize, similarity_threshold: Option<f32>) -> SearchQuery {
        SearchQuery {
            text: text.to_string(),
            directory_filter: None,
            limit,
            similarity_threshold,
        }
    }

    /// Builds a result with a single parent directory.
    fn hit(file_path: &str, chunk_id: usize, score: f32, parent: &str) -> SearchResult {
        SearchResult {
            file_path: file_path.to_string(),
            chunk_id,
            score,
            parent_directories: vec![parent.to_string()],
        }
    }

    /// Four results across three directories, already in descending score order.
    fn create_sample_search_results() -> Vec<SearchResult> {
        vec![
            hit("/home/user/docs/readme.md", 0, 0.9, "docs"),
            hit("/home/user/code/main.rs", 1, 0.8, "code"),
            hit("/home/user/docs/api.md", 0, 0.7, "docs"),
            hit("/home/user/other/test.txt", 0, 0.5, "other"),
        ]
    }

    /// Asserts that `q` fails validation with a message containing `needle`.
    fn assert_rejected(q: &SearchQuery, needle: &str) {
        let err = SearchEngine::validate_query_static(q).expect_err("query should be rejected");
        assert!(err.to_string().contains(needle), "unexpected error: {err}");
    }

    #[test]
    fn test_validate_query_success() {
        let q = query("test search", 10, Some(0.5));
        assert!(SearchEngine::validate_query_static(&q).is_ok());
    }

    #[test]
    fn test_validate_query_empty_text() {
        assert_rejected(&query("", 10, None), "cannot be empty");
    }

    #[test]
    fn test_validate_query_whitespace_only_text() {
        assert_rejected(&query("   \t\n  ", 10, None), "cannot be empty");
    }

    #[test]
    fn test_validate_query_zero_limit() {
        assert_rejected(&query("test", 0, None), "must be greater than 0");
    }

    #[test]
    fn test_validate_query_invalid_similarity_threshold() {
        for threshold in [-0.1, 1.1] {
            assert_rejected(&query("test", 10, Some(threshold)), "between 0.0 and 1.0");
        }
    }

    #[test]
    fn test_validate_query_valid_similarity_threshold() {
        for threshold in [0.0, 0.5, 1.0] {
            let q = query("test", 10, Some(threshold));
            assert!(SearchEngine::validate_query_static(&q).is_ok());
        }
    }

    #[test]
    fn test_filter_results_by_directory_with_filter() {
        let docs = Some(PathBuf::from("/home/user/docs"));
        let kept =
            SearchEngine::filter_results_by_directory_static(create_sample_search_results(), &docs);

        assert_eq!(kept.len(), 2);
        assert!(kept.iter().all(|r| r.file_path.starts_with("/home/user/docs")));
    }

    #[test]
    fn test_filter_results_by_directory_no_filter() {
        let all = create_sample_search_results();
        let expected = all.len();

        let kept = SearchEngine::filter_results_by_directory_static(all, &None);

        assert_eq!(kept.len(), expected);
    }

    #[test]
    fn test_filter_results_by_directory_no_matches() {
        let missing = Some(PathBuf::from("/nonexistent/path"));
        let kept = SearchEngine::filter_results_by_directory_static(
            create_sample_search_results(),
            &missing,
        );

        assert!(kept.is_empty());
    }

    #[test]
    fn test_apply_similarity_threshold_with_threshold() {
        let kept = SearchEngine::apply_similarity_threshold_static(
            create_sample_search_results(),
            Some(0.75),
        );

        assert_eq!(kept.len(), 2);
        assert!(kept.iter().all(|r| r.score >= 0.75));
    }

    #[test]
    fn test_apply_similarity_threshold_no_threshold() {
        let all = create_sample_search_results();
        let expected = all.len();

        let kept = SearchEngine::apply_similarity_threshold_static(all, None);

        assert_eq!(kept.len(), expected);
    }

    #[test]
    fn test_apply_similarity_threshold_no_matches() {
        let kept = SearchEngine::apply_similarity_threshold_static(
            create_sample_search_results(),
            Some(0.95),
        );

        assert!(kept.is_empty());
    }

    #[test]
    fn test_rank_results() {
        let ranked = SearchEngine::rank_results_static(create_sample_search_results());

        let scores: Vec<f32> = ranked.iter().map(|r| r.score).collect();
        assert_eq!(scores, vec![0.9, 0.8, 0.7, 0.5]);

        // Every adjacent pair must be in descending order.
        assert!(ranked.windows(2).all(|pair| pair[0].score >= pair[1].score));
    }

    #[test]
    fn test_rank_results_empty() {
        assert!(SearchEngine::rank_results_static(Vec::new()).is_empty());
    }

    #[test]
    fn test_limit_results() {
        let limited = SearchEngine::limit_results_static(create_sample_search_results(), 2);
        assert_eq!(limited.len(), 2);
    }

    #[test]
    fn test_limit_results_larger_than_available() {
        let all = create_sample_search_results();
        let expected = all.len();

        let limited = SearchEngine::limit_results_static(all, 10);

        assert_eq!(limited.len(), expected);
    }

    #[test]
    fn test_limit_results_zero() {
        let limited = SearchEngine::limit_results_static(create_sample_search_results(), 0);
        assert!(limited.is_empty());
    }

    // Integration tests for full search functionality will be in tests/search_integration_tests.rs

    #[test]
    fn test_search_query_creation() {
        let q = SearchQuery {
            directory_filter: Some(PathBuf::from("/test/dir")),
            ..query("test query", 5, Some(0.8))
        };

        assert_eq!(q.text, "test query");
        assert_eq!(q.directory_filter, Some(PathBuf::from("/test/dir")));
        assert_eq!(q.limit, 5);
        assert_eq!(q.similarity_threshold, Some(0.8));
    }

    #[test]
    fn test_search_result_creation() {
        let r = hit("/test/file.txt", 1, 0.85, "test");

        assert_eq!(r.file_path, "/test/file.txt");
        assert_eq!(r.chunk_id, 1);
        assert_eq!(r.score, 0.85);
        assert_eq!(r.parent_directories, vec!["test".to_string()]);
    }
}