1use super::CodeTool;
12use super::ToolError;
13use dashmap::DashMap;
14use std::path::Path;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::time::Duration;
18use std::time::Instant;
19use tantivy::Index;
20use tantivy::IndexReader;
21use tantivy::IndexWriter;
22use tantivy::collector::TopDocs;
23use tantivy::doc;
24use tantivy::query::Query;
25use tantivy::query::QueryParser;
26use tantivy::schema::Field;
27use tantivy::schema::INDEXED;
28use tantivy::schema::STORED;
29use tantivy::schema::STRING;
30use tantivy::schema::Schema;
31use tantivy::schema::TEXT;
32use tantivy::schema::Value;
33use tokio::process::Command;
34
/// Search engine that puts several lookup back ends behind a single API.
///
/// The maps are `DashMap`s behind `Arc`, which is why they are filled and
/// queried through `&self`.
#[derive(Debug)]
pub struct MultiLayerSearchEngine {
    /// Symbols grouped by symbol name (see `add_symbol`).
    symbol_index: Arc<DashMap<String, Vec<Symbol>>>,
    /// Tantivy-backed index; `None` when `SearchConfig::enable_tantivy` is `false`.
    tantivy_index: Option<TantivySearchEngine>,
    /// Per-file parse summaries keyed by file path (see `add_ast_cache`).
    ast_cache: Arc<DashMap<PathBuf, CachedAst>>,
    /// Previously computed results keyed by a string derived from the query.
    query_cache: Arc<DashMap<String, CachedResult>>,
    /// Settings supplied at construction time.
    config: SearchConfig,
}
49
/// A Tantivy index bundled with its reader, writer and schema field handles.
pub struct TantivySearchEngine {
    /// The index itself; created in RAM by `TantivySearchEngine::new`.
    index: Index,
    /// Reader used to obtain searchers.
    reader: IndexReader,
    /// Index writer behind an async mutex; locked by `add_document`.
    writer: Arc<tokio::sync::Mutex<IndexWriter>>,
    /// Handles to the fields registered in the schema.
    schema: TantivySchema,
}
57
/// Field handles for the schema built in `TantivySearchEngine::new`.
#[derive(Debug, Clone)]
pub struct TantivySchema {
    /// File path of the indexed document.
    pub path: Field,
    /// Text content of the indexed document.
    pub content: Field,
    /// Symbol names; `add_document` writes them joined by single spaces.
    pub symbols: Field,
    /// Bytes field; not written by `add_document` in this file.
    pub ast: Field,
    /// Language label passed to `add_document`.
    pub language: Field,
    /// Line number; `add_document` always writes `1`.
    pub line_number: Field,
    /// Registered in the schema but not written by `add_document` in this file.
    pub function_name: Field,
    /// Registered in the schema but not written by `add_document` in this file.
    pub class_name: Field,
}
70
/// Tunables for `MultiLayerSearchEngine`.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Query-cache entry count at which `cache_result` evicts before inserting.
    pub max_cache_size: usize,
    /// Lifetime stamped on every cached query result.
    pub cache_ttl: Duration,
    /// Allow the in-memory symbol index layer.
    pub enable_symbol_index: bool,
    /// Allow the Tantivy layer; also decides whether the index is created at all.
    pub enable_tantivy: bool,
    /// Allow the AST cache layer.
    pub enable_ast_cache: bool,
    /// Allow shelling out to ripgrep.
    pub enable_ripgrep_fallback: bool,
    /// NOTE(review): not read by any code visible in this file — presumably a
    /// global result cap; confirm against callers.
    pub max_results: usize,
    /// Time limit applied to the ripgrep subprocess.
    pub timeout: Duration,
}
91
/// A single search request.
#[derive(Debug, Clone)]
pub struct SearchQuery {
    /// Text to look for.
    pub pattern: String,
    /// Kind of search; drives strategy selection.
    pub query_type: QueryType,
    /// Path filters: substring-matched against symbol paths and forwarded to
    /// ripgrep as `--glob` arguments.
    pub file_filters: Vec<String>,
    /// Compared against the lowercased file extension of symbol locations.
    pub language_filters: Vec<String>,
    /// Requested number of context lines.
    /// NOTE(review): not consumed by the code visible in this file.
    pub context_lines: usize,
    /// Maximum number of matches; `None` leaves the in-memory layers unbounded
    /// (the Tantivy layer falls back to 50).
    pub limit: Option<usize>,
    /// Enables the similarity fallback of the symbol index layer.
    pub fuzzy: bool,
    /// When `false`, ripgrep is run with `--ignore-case`.
    pub case_sensitive: bool,
    /// Where the ripgrep layer searches.
    pub scope: SearchScope,
}
114
/// What the caller is looking for; `select_strategy` maps each kind to a
/// `SearchStrategy`.
#[derive(Debug, Clone, PartialEq)]
pub enum QueryType {
    /// Symbol name lookup (fast path when the name is an exact index key).
    Symbol,
    /// Full-text search.
    FullText,
    /// Definition lookup; handled by the semantic strategy.
    Definition,
    /// Reference lookup; handled by the semantic strategy.
    References,
    /// Semantic search.
    Semantic,
    /// Unclassified query; the strategy is guessed from the pattern's shape.
    General,
}
131
/// Where the ripgrep layer should look.
#[derive(Debug, Clone)]
pub enum SearchScope {
    /// The current directory (`.`).
    Workspace,
    /// A single directory.
    Directory(PathBuf),
    /// An explicit list of files.
    Files(Vec<PathBuf>),
    /// The current directory, restricted to a fixed list of source extensions.
    GitRepository,
}
144
/// Envelope wrapping a tool's result together with descriptive data.
#[derive(Debug, Clone)]
pub struct ToolOutput<T> {
    /// The payload itself.
    pub result: T,
    /// Context attached to the result.
    pub context: Context,
    /// Changes associated with the result; the search engine leaves this empty.
    pub changes: Vec<Change>,
    /// How the result was produced.
    pub metadata: Metadata,
    /// One-line human-readable description.
    pub summary: String,
}
154
/// Context attached to a `ToolOutput`.
///
/// NOTE(review): the only producer visible in this file
/// (`build_overall_context`) fills every field with an empty value, so the
/// intended contents of the fields below are assumptions — confirm.
#[derive(Debug, Clone)]
pub struct Context {
    /// Presumably the text preceding the location.
    pub before: String,
    /// Presumably the text following the location.
    pub after: String,
    /// Lines around the location.
    pub surrounding: Vec<Line>,
    /// Position the context refers to.
    pub location: Location,
    /// Enclosing code scope.
    pub scope: Scope,
}
169
/// One source line inside a `Context`.
#[derive(Debug, Clone)]
pub struct Line {
    /// Line number within the file.
    pub number: usize,
    /// Text of the line.
    pub content: String,
    /// Whether this line is itself a match.
    pub is_match: bool,
}
177
/// A position inside a file.
#[derive(Debug, Clone)]
pub struct Location {
    /// File the position belongs to.
    pub file: PathBuf,
    /// Line number.
    pub line: usize,
    /// Column number.
    pub column: usize,
    /// Byte offset from the start of the file.
    pub byte_offset: usize,
}
186
/// Names of the code constructs enclosing a location; each is optional.
#[derive(Debug, Clone)]
pub struct Scope {
    /// Enclosing function, if any.
    pub function: Option<String>,
    /// Enclosing class, if any.
    pub class: Option<String>,
    /// Enclosing module, if any.
    pub module: Option<String>,
    /// Enclosing namespace, if any.
    pub namespace: Option<String>,
}
195
/// A change reported alongside a tool result.
#[derive(Debug, Clone)]
pub struct Change {
    /// Whether something was added, modified or deleted.
    pub change_type: ChangeType,
    /// Human-readable description of the change.
    pub description: String,
    /// Where the change applies.
    pub location: Location,
}
203
/// Kind of a `Change`.
#[derive(Debug, Clone)]
pub enum ChangeType {
    /// Something was added.
    Addition,
    /// Something was modified.
    Modification,
    /// Something was removed.
    Deletion,
}
210
/// Describes how a search result was produced.
#[derive(Debug, Clone)]
pub struct Metadata {
    /// Layer that actually ran.
    pub search_layer: SearchLayer,
    /// Time from the start of `search` until the matches were ready.
    pub duration: Duration,
    /// Number of matches in the result.
    pub total_results: usize,
    /// Strategy chosen by `select_strategy` (may differ from the layer that
    /// ran when the preferred layer is disabled).
    pub strategy: SearchStrategy,
    /// Language guessed from the query's file filters, if any.
    pub language: Option<String>,
}
225
/// Back end that executed a search.
#[derive(Debug, Clone, PartialEq)]
pub enum SearchLayer {
    /// In-memory symbol index.
    SymbolIndex,
    /// Tantivy index.
    Tantivy,
    /// Cached per-file symbol data.
    AstCache,
    /// External ripgrep process.
    RipgrepFallback,
    /// Several layers merged by `search_combined`.
    Combined,
}
235
/// Plan chosen for a query before a concrete layer is picked.
#[derive(Debug, Clone, PartialEq)]
pub enum SearchStrategy {
    /// Prefer the symbol index.
    FastSymbolLookup,
    /// Prefer the Tantivy index.
    FullTextIndex,
    /// Prefer the AST cache.
    SemanticAnalysis,
    /// Prefer ripgrep. Never returned by `select_strategy` in this file.
    PatternMatching,
    /// Combine all enabled layers.
    Hybrid,
}
245
/// Result type returned by the search engine: a list of matches in a `ToolOutput`.
pub type SearchResult = ToolOutput<Vec<Match>>;
248
/// A single search hit.
#[derive(Debug, Clone)]
pub struct Match {
    /// File containing the hit.
    pub file: PathBuf,
    /// Line of the hit.
    pub line: usize,
    /// Column of the hit.
    pub column: usize,
    /// Matched text (a symbol name or a source line, depending on the layer).
    pub content: String,
    /// Relevance score; higher sorts first.
    pub score: f32,
}
258
/// A named code entity stored in the symbol index and in cached ASTs.
#[derive(Debug, Clone)]
pub struct Symbol {
    /// Symbol name; also the key used by the symbol index.
    pub name: String,
    /// What kind of entity this is.
    pub kind: SymbolKind,
    /// Where the symbol is located.
    pub location: Location,
    /// Constructs enclosing the symbol.
    pub scope: Scope,
    /// Declared visibility.
    pub visibility: Visibility,
}
268
/// Category of a `Symbol`.
#[derive(Debug, Clone)]
pub enum SymbolKind {
    /// Free function.
    Function,
    /// Method.
    Method,
    /// Class.
    Class,
    /// Interface.
    Interface,
    /// Struct.
    Struct,
    /// Enum.
    Enum,
    /// Variable.
    Variable,
    /// Constant.
    Constant,
    /// Module.
    Module,
    /// Namespace.
    Namespace,
}
282
/// Visibility of a `Symbol`; public symbols get a small scoring bonus in
/// `calculate_semantic_score`.
#[derive(Debug, Clone)]
pub enum Visibility {
    /// Public.
    Public,
    /// Private.
    Private,
    /// Protected.
    Protected,
    /// Internal.
    Internal,
}
290
/// Per-file parse summary stored in the AST cache.
#[derive(Debug, Clone)]
pub struct CachedAst {
    /// File the entry describes; also the cache key.
    pub file_path: PathBuf,
    /// Language label of the file.
    pub language: String,
    /// Symbols found in the file; scanned by `search_ast_cache`.
    pub symbols: Vec<Symbol>,
    /// Dependency names; not read by the code visible in this file.
    pub dependencies: Vec<String>,
    /// Modification time; not read by the code visible in this file.
    pub last_modified: std::time::SystemTime,
    /// Parse duration; not read by the code visible in this file.
    pub parse_duration: Duration,
}
301
/// A query-cache entry: a result plus the data needed to expire it.
#[derive(Debug, Clone)]
pub struct CachedResult {
    /// The cached search result.
    pub result: SearchResult,
    /// When the entry was stored.
    pub timestamp: Instant,
    /// How long after `timestamp` the entry stays valid.
    pub ttl: Duration,
}
309
310impl Default for SearchConfig {
311 fn default() -> Self {
312 Self {
313 max_cache_size: 1000,
314 cache_ttl: Duration::from_secs(300), enable_symbol_index: true,
316 enable_tantivy: true,
317 enable_ast_cache: true,
318 enable_ripgrep_fallback: true,
319 max_results: 100,
320 timeout: Duration::from_secs(10),
321 }
322 }
323}
324
325impl Default for SearchQuery {
326 fn default() -> Self {
327 Self {
328 pattern: String::new(),
329 query_type: QueryType::General,
330 file_filters: Vec::new(),
331 language_filters: Vec::new(),
332 context_lines: 3,
333 limit: Some(50),
334 fuzzy: false,
335 case_sensitive: true,
336 scope: SearchScope::Workspace,
337 }
338 }
339}
340
341impl MultiLayerSearchEngine {
342 pub fn new(config: SearchConfig) -> Result<Self, ToolError> {
344 let symbol_index = Arc::new(DashMap::new());
345 let ast_cache = Arc::new(DashMap::new());
346 let query_cache = Arc::new(DashMap::new());
347
348 let tantivy_index = if config.enable_tantivy {
349 Some(TantivySearchEngine::new()?)
350 } else {
351 None
352 };
353
354 Ok(Self {
355 symbol_index,
356 tantivy_index,
357 ast_cache,
358 query_cache,
359 config,
360 })
361 }
362
363 pub async fn search(&self, query: SearchQuery) -> Result<SearchResult, ToolError> {
365 let start_time = Instant::now();
366
367 if let Some(cached) = self.get_cached_result(&query) {
369 return Ok(cached);
370 }
371
372 let strategy = self.select_strategy(&query);
374 let layer = match strategy {
375 SearchStrategy::FastSymbolLookup if self.config.enable_symbol_index => {
376 SearchLayer::SymbolIndex
377 }
378 SearchStrategy::FullTextIndex if self.config.enable_tantivy => SearchLayer::Tantivy,
379 SearchStrategy::SemanticAnalysis if self.config.enable_ast_cache => {
380 SearchLayer::AstCache
381 }
382 SearchStrategy::PatternMatching if self.config.enable_ripgrep_fallback => {
383 SearchLayer::RipgrepFallback
384 }
385 SearchStrategy::Hybrid => SearchLayer::Combined,
386 _ => {
387 if self.config.enable_symbol_index {
389 SearchLayer::SymbolIndex
390 } else if self.config.enable_ast_cache {
391 SearchLayer::AstCache
392 } else if self.config.enable_tantivy {
393 SearchLayer::Tantivy
394 } else if self.config.enable_ripgrep_fallback {
395 SearchLayer::RipgrepFallback
396 } else {
397 SearchLayer::SymbolIndex
399 }
400 }
401 };
402
403 let matches = match layer {
405 SearchLayer::SymbolIndex => self.search_symbol_index(&query).await?,
406 SearchLayer::Tantivy => self.search_tantivy(&query).await?,
407 SearchLayer::AstCache => self.search_ast_cache(&query).await?,
408 SearchLayer::RipgrepFallback => self.search_ripgrep(&query).await?,
409 SearchLayer::Combined => self.search_combined(&query).await?,
410 };
411
412 let enhanced_matches = self.enhance_matches_with_context(matches, &query).await?;
414
415 let duration = start_time.elapsed();
416 let result = ToolOutput {
417 result: enhanced_matches.clone(),
418 context: self
419 .build_overall_context(&enhanced_matches, &query)
420 .await?,
421 changes: Vec::new(), metadata: Metadata {
423 search_layer: layer,
424 duration,
425 total_results: enhanced_matches.len(),
426 strategy,
427 language: self.detect_language(&query),
428 },
429 summary: self.generate_summary(&enhanced_matches, &query),
430 };
431
432 self.cache_result(&query, &result);
434
435 Ok(result)
436 }
437
438 async fn search_symbol_index(&self, query: &SearchQuery) -> Result<Vec<Match>, ToolError> {
440 let _start = Instant::now();
441 let mut matches = Vec::new();
442
443 if let Some(symbols) = self.symbol_index.get(&query.pattern) {
445 for symbol in symbols.iter() {
446 if self.matches_filters(symbol, query) {
447 matches.push(Match {
448 file: symbol.location.file.clone(),
449 line: symbol.location.line,
450 column: symbol.location.column,
451 content: symbol.name.clone(),
452 score: 1.0, });
454 }
455 }
456 }
457
458 if query.fuzzy && matches.is_empty() {
460 for entry in self.symbol_index.iter() {
461 let similarity = self.calculate_similarity(&query.pattern, entry.key());
462 if similarity > 0.7 {
463 for symbol in entry.value().iter() {
465 if self.matches_filters(symbol, query) {
466 matches.push(Match {
467 file: symbol.location.file.clone(),
468 line: symbol.location.line,
469 column: symbol.location.column,
470 content: symbol.name.clone(),
471 score: similarity,
472 });
473 }
474 }
475 }
476 }
477 }
478
479 matches.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
481
482 if let Some(limit) = query.limit {
484 matches.truncate(limit);
485 }
486
487 Ok(matches)
488 }
489
490 async fn search_tantivy(&self, query: &SearchQuery) -> Result<Vec<Match>, ToolError> {
492 if let Some(ref tantivy) = self.tantivy_index {
493 tantivy.search(query).await
494 } else {
495 Err(ToolError::NotImplemented("Tantivy search not enabled"))
496 }
497 }
498
499 async fn search_ast_cache(&self, query: &SearchQuery) -> Result<Vec<Match>, ToolError> {
501 let mut matches = Vec::new();
502
503 for entry in self.ast_cache.iter() {
504 let ast = entry.value();
505 for symbol in &ast.symbols {
506 if self.matches_semantic_query(symbol, query) {
507 matches.push(Match {
508 file: symbol.location.file.clone(),
509 line: symbol.location.line,
510 column: symbol.location.column,
511 content: format!("{} {}", symbol.kind.as_str(), symbol.name),
512 score: self.calculate_semantic_score(symbol, query),
513 });
514 }
515 }
516 }
517
518 matches.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
520
521 if let Some(limit) = query.limit {
522 matches.truncate(limit);
523 }
524
525 Ok(matches)
526 }
527
528 async fn search_ripgrep(&self, query: &SearchQuery) -> Result<Vec<Match>, ToolError> {
530 let mut cmd = Command::new("rg");
531 cmd.arg("--json")
532 .arg("--with-filename")
533 .arg("--line-number")
534 .arg("--column");
535
536 if !query.case_sensitive {
537 cmd.arg("--ignore-case");
538 }
539
540 if let Some(limit) = query.limit {
541 cmd.arg("--max-count").arg(limit.to_string());
542 }
543
544 for filter in &query.file_filters {
546 cmd.arg("--glob").arg(filter);
547 }
548
549 cmd.arg(&query.pattern);
551
552 match &query.scope {
554 SearchScope::Workspace => {
555 cmd.arg(".");
556 }
557 SearchScope::Directory(path) => {
558 cmd.arg(path);
559 }
560 SearchScope::Files(files) => {
561 for file in files {
562 cmd.arg(file);
563 }
564 }
565 SearchScope::GitRepository => {
566 cmd.arg("--type-add")
567 .arg("git:*.{rs,py,js,ts,go,java,c,cpp,h}")
568 .arg("--type")
569 .arg("git")
570 .arg(".");
571 }
572 }
573
574 let output = tokio::time::timeout(self.config.timeout, cmd.output())
575 .await
576 .map_err(|_| ToolError::InvalidQuery("Search timeout".to_string()))?
577 .map_err(ToolError::Io)?;
578
579 if !output.status.success() {
580 return Err(ToolError::InvalidQuery(format!(
581 "ripgrep failed: {}",
582 String::from_utf8_lossy(&output.stderr)
583 )));
584 }
585
586 self.parse_ripgrep_output(&output.stdout)
587 }
588
589 async fn search_combined(&self, query: &SearchQuery) -> Result<Vec<Match>, ToolError> {
591 let mut all_matches = Vec::new();
592 let mut any_layer_enabled = false;
593
594 if self.config.enable_symbol_index {
596 any_layer_enabled = true;
597 if let Ok(matches) = self.search_symbol_index(query).await {
598 all_matches.extend(matches);
599 }
600 }
601
602 if self.config.enable_tantivy && all_matches.len() < query.limit.unwrap_or(50) {
603 any_layer_enabled = true;
604 if let Ok(matches) = self.search_tantivy(query).await {
605 all_matches.extend(matches);
606 }
607 }
608
609 if self.config.enable_ast_cache && all_matches.len() < query.limit.unwrap_or(50) {
610 any_layer_enabled = true;
611 if let Ok(matches) = self.search_ast_cache(query).await {
612 all_matches.extend(matches);
613 }
614 }
615
616 if self.config.enable_ripgrep_fallback && all_matches.len() < query.limit.unwrap_or(50) {
617 any_layer_enabled = true;
618 if let Ok(matches) = self.search_ripgrep(query).await {
619 all_matches.extend(matches);
620 }
621 }
622
623 if !any_layer_enabled {
626 return Ok(Vec::new());
627 }
628
629 if !all_matches.is_empty() {
631 all_matches.sort_by(|a, b| {
632 match b.score.partial_cmp(&a.score).unwrap() {
634 std::cmp::Ordering::Equal => match a.file.cmp(&b.file) {
635 std::cmp::Ordering::Equal => a.line.cmp(&b.line),
636 other => other,
637 },
638 other => other,
639 }
640 });
641
642 all_matches
644 .dedup_by(|a, b| a.file == b.file && a.line == b.line && a.column == b.column);
645 }
646
647 if let Some(limit) = query.limit {
648 all_matches.truncate(limit);
649 }
650
651 Ok(all_matches)
652 }
653
654 fn select_strategy(&self, query: &SearchQuery) -> SearchStrategy {
656 match query.query_type {
657 QueryType::Symbol => {
658 if self.symbol_index.contains_key(&query.pattern) {
659 SearchStrategy::FastSymbolLookup
660 } else {
661 SearchStrategy::FullTextIndex
662 }
663 }
664 QueryType::Definition | QueryType::References => SearchStrategy::SemanticAnalysis,
665 QueryType::FullText => SearchStrategy::FullTextIndex,
666 QueryType::Semantic => SearchStrategy::SemanticAnalysis,
667 QueryType::General => {
668 if self.is_likely_symbol(&query.pattern) {
670 SearchStrategy::FastSymbolLookup
671 } else if query.pattern.len() < 50 && !query.pattern.contains(' ') {
672 SearchStrategy::FullTextIndex
673 } else {
674 SearchStrategy::Hybrid
675 }
676 }
677 }
678 }
679
680 fn is_likely_symbol(&self, pattern: &str) -> bool {
682 pattern
684 .chars()
685 .all(|c| c.is_alphanumeric() || c == '_' || c == ':')
686 && !pattern.contains(' ')
687 && pattern.len() < 100
688 }
689
690 fn calculate_similarity(&self, a: &str, b: &str) -> f32 {
692 let distance = self.levenshtein_distance(a, b);
694 let max_len = a.len().max(b.len()) as f32;
695 if max_len == 0.0 {
696 1.0
697 } else {
698 1.0 - (distance as f32 / max_len)
699 }
700 }
701
702 fn levenshtein_distance(&self, a: &str, b: &str) -> usize {
704 let a_chars: Vec<char> = a.chars().collect();
705 let b_chars: Vec<char> = b.chars().collect();
706 let a_len = a_chars.len();
707 let b_len = b_chars.len();
708
709 let mut matrix = vec![vec![0; b_len + 1]; a_len + 1];
710
711 for i in 0..=a_len {
712 matrix[i][0] = i;
713 }
714 for j in 0..=b_len {
715 matrix[0][j] = j;
716 }
717
718 for i in 1..=a_len {
719 for j in 1..=b_len {
720 let cost = if a_chars[i - 1] == b_chars[j - 1] {
721 0
722 } else {
723 1
724 };
725 matrix[i][j] = (matrix[i - 1][j] + 1)
726 .min(matrix[i][j - 1] + 1)
727 .min(matrix[i - 1][j - 1] + cost);
728 }
729 }
730
731 matrix[a_len][b_len]
732 }
733
734 fn matches_filters(&self, symbol: &Symbol, query: &SearchQuery) -> bool {
736 if !query.file_filters.is_empty() {
738 let file_str = symbol.location.file.to_string_lossy();
739 if !query
740 .file_filters
741 .iter()
742 .any(|filter| file_str.contains(filter))
743 {
744 return false;
745 }
746 }
747
748 if !query.language_filters.is_empty()
751 && let Some(ext) = symbol.location.file.extension()
752 {
753 let ext_str = ext.to_string_lossy().to_lowercase();
754 if !query.language_filters.contains(&ext_str) {
755 return false;
756 }
757 }
758
759 true
760 }
761
762 fn matches_semantic_query(&self, symbol: &Symbol, query: &SearchQuery) -> bool {
764 match query.query_type {
765 QueryType::Definition => {
766 matches!(
767 symbol.kind,
768 SymbolKind::Function | SymbolKind::Class | SymbolKind::Struct
769 )
770 }
771 QueryType::References => {
772 symbol.name.contains(&query.pattern)
774 }
775 _ => symbol.name.contains(&query.pattern),
776 }
777 }
778
779 fn calculate_semantic_score(&self, symbol: &Symbol, query: &SearchQuery) -> f32 {
781 let mut score = 0.0;
782
783 if symbol.name == query.pattern {
785 score += 1.0;
786 } else if symbol.name.contains(&query.pattern) {
787 score += 0.8;
788 } else {
789 score += self.calculate_similarity(&symbol.name, &query.pattern) * 0.6;
790 }
791
792 if query.query_type == QueryType::Definition {
794 match symbol.kind {
795 SymbolKind::Function | SymbolKind::Method => score += 0.2,
796 SymbolKind::Class | SymbolKind::Struct => score += 0.3,
797 _ => {}
798 }
799 }
800
801 if matches!(symbol.visibility, Visibility::Public) {
803 score += 0.1;
804 }
805
806 score.min(1.0)
807 }
808
809 fn parse_ripgrep_output(&self, output: &[u8]) -> Result<Vec<Match>, ToolError> {
811 let output_str = String::from_utf8_lossy(output);
812 let mut matches = Vec::new();
813
814 for line in output_str.lines() {
815 if let Ok(json) = serde_json::from_str::<serde_json::Value>(line)
816 && json["type"] == "match"
817 && let Some(data) = json["data"].as_object()
818 {
819 let file = PathBuf::from(data["path"]["text"].as_str().unwrap_or("").to_string());
820 let line_num = data["line_number"].as_u64().unwrap_or(0) as usize;
821 let column = data["submatches"][0]["start"].as_u64().unwrap_or(0) as usize + 1; let content = data["lines"]["text"].as_str().unwrap_or("").to_string();
823
824 matches.push(Match {
825 file,
826 line: line_num,
827 column,
828 content,
829 score: 0.8, });
831 }
832 }
833
834 Ok(matches)
835 }
836
/// Hook for enriching matches with surrounding context.
///
/// Currently a pass-through: the matches are returned unchanged and the
/// query is ignored.
async fn enhance_matches_with_context(
    &self,
    matches: Vec<Match>,
    _query: &SearchQuery,
) -> Result<Vec<Match>, ToolError> {
    Ok(matches)
}
847
848 async fn build_overall_context(
850 &self,
851 _matches: &[Match],
852 _query: &SearchQuery,
853 ) -> Result<Context, ToolError> {
854 Ok(Context {
856 before: String::new(),
857 after: String::new(),
858 surrounding: Vec::new(),
859 location: Location {
860 file: PathBuf::new(),
861 line: 0,
862 column: 0,
863 byte_offset: 0,
864 },
865 scope: Scope {
866 function: None,
867 class: None,
868 module: None,
869 namespace: None,
870 },
871 })
872 }
873
874 fn detect_language(&self, query: &SearchQuery) -> Option<String> {
876 for filter in &query.file_filters {
878 if filter.ends_with(".rs") {
879 return Some("rust".to_string());
880 } else if filter.ends_with(".py") {
881 return Some("python".to_string());
882 } else if filter.ends_with(".js") || filter.ends_with(".ts") {
883 return Some("javascript".to_string());
884 }
885 }
886 None
887 }
888
889 fn generate_summary(&self, matches: &[Match], query: &SearchQuery) -> String {
891 format!(
892 "Found {} matches for '{}' using {} search",
893 matches.len(),
894 query.pattern,
895 match query.query_type {
896 QueryType::Symbol => "symbol",
897 QueryType::FullText => "full-text",
898 QueryType::Definition => "definition",
899 QueryType::References => "reference",
900 QueryType::Semantic => "semantic",
901 QueryType::General => "general",
902 }
903 )
904 }
905
906 fn get_cached_result(&self, query: &SearchQuery) -> Option<SearchResult> {
908 let cache_key = self.generate_cache_key(query);
909 if let Some(cached) = self.query_cache.get(&cache_key) {
910 if cached.timestamp.elapsed() < cached.ttl {
911 return Some(cached.result.clone());
912 }
913 self.query_cache.remove(&cache_key);
915 }
916 None
917 }
918
919 fn cache_result(&self, query: &SearchQuery, result: &SearchResult) {
921 if self.query_cache.len() >= self.config.max_cache_size {
922 if let Some(entry) = self.query_cache.iter().next() {
924 let key = entry.key().clone();
925 drop(entry);
926 self.query_cache.remove(&key);
927 }
928 }
929
930 let cache_key = self.generate_cache_key(query);
931 self.query_cache.insert(
932 cache_key,
933 CachedResult {
934 result: result.clone(),
935 timestamp: Instant::now(),
936 ttl: self.config.cache_ttl,
937 },
938 );
939 }
940
941 fn generate_cache_key(&self, query: &SearchQuery) -> String {
943 format!(
944 "{}:{}:{}:{}",
945 query.pattern,
946 query.query_type.as_str(),
947 query.file_filters.join(","),
948 query.language_filters.join(",")
949 )
950 }
951
952 pub fn add_symbol(&self, symbol: Symbol) {
954 self.symbol_index
955 .entry(symbol.name.clone())
956 .or_default()
957 .push(symbol);
958 }
959
960 pub fn add_ast_cache(&self, ast: CachedAst) {
962 self.ast_cache.insert(ast.file_path.clone(), ast);
963 }
964
965 pub async fn find_references(&self, symbol_name: &str) -> Result<SearchResult, ToolError> {
967 let query = SearchQuery {
968 pattern: symbol_name.to_string(),
969 query_type: QueryType::References,
970 ..Default::default()
971 };
972 self.search(query).await
973 }
974
975 pub async fn find_definition(&self, symbol_name: &str) -> Result<SearchResult, ToolError> {
977 let query = SearchQuery {
978 pattern: symbol_name.to_string(),
979 query_type: QueryType::Definition,
980 ..Default::default()
981 };
982 self.search(query).await
983 }
984}
985
986impl TantivySearchEngine {
987 pub fn new() -> Result<Self, ToolError> {
989 let mut schema_builder = Schema::builder();
990
991 let path = schema_builder.add_text_field("path", TEXT | STORED);
992 let content = schema_builder.add_text_field("content", TEXT);
993 let symbols = schema_builder.add_text_field("symbols", TEXT | STORED);
994 let ast = schema_builder.add_bytes_field("ast", STORED);
995 let language = schema_builder.add_text_field("language", STRING | STORED);
996 let line_number = schema_builder.add_u64_field("line_number", INDEXED | STORED);
997 let function_name = schema_builder.add_text_field("function_name", TEXT | STORED);
998 let class_name = schema_builder.add_text_field("class_name", TEXT | STORED);
999
1000 let schema = schema_builder.build();
1001
1002 let index = Index::create_in_ram(schema.clone());
1003 let reader = index
1004 .reader()
1005 .map_err(|e| ToolError::InvalidQuery(format!("Failed to create reader: {}", e)))?;
1006
1007 let writer = index
1008 .writer(50_000_000) .map_err(|e| ToolError::InvalidQuery(format!("Failed to create writer: {}", e)))?;
1010
1011 Ok(Self {
1012 index,
1013 reader,
1014 writer: Arc::new(tokio::sync::Mutex::new(writer)),
1015 schema: TantivySchema {
1016 path,
1017 content,
1018 symbols,
1019 ast,
1020 language,
1021 line_number,
1022 function_name,
1023 class_name,
1024 },
1025 })
1026 }
1027
1028 pub async fn search(&self, query: &SearchQuery) -> Result<Vec<Match>, ToolError> {
1030 let searcher = self.reader.searcher();
1031 let schema = &self.schema;
1032
1033 let tantivy_query: Box<dyn Query> = match query.query_type {
1035 QueryType::Symbol => {
1036 let query_parser = QueryParser::for_index(&self.index, vec![schema.symbols]);
1037 query_parser
1038 .parse_query(&query.pattern)
1039 .map_err(|e| ToolError::InvalidQuery(format!("Parse error: {}", e)))?
1040 }
1041 QueryType::FullText => {
1042 let query_parser = QueryParser::for_index(&self.index, vec![schema.content]);
1043 query_parser
1044 .parse_query(&query.pattern)
1045 .map_err(|e| ToolError::InvalidQuery(format!("Parse error: {}", e)))?
1046 }
1047 _ => {
1048 let query_parser =
1049 QueryParser::for_index(&self.index, vec![schema.content, schema.symbols]);
1050 query_parser
1051 .parse_query(&query.pattern)
1052 .map_err(|e| ToolError::InvalidQuery(format!("Parse error: {}", e)))?
1053 }
1054 };
1055
1056 let top_docs = searcher
1057 .search(
1058 &tantivy_query,
1059 &TopDocs::with_limit(query.limit.unwrap_or(50)),
1060 )
1061 .map_err(|e| ToolError::InvalidQuery(format!("Search error: {}", e)))?;
1062
1063 let mut matches = Vec::new();
1064 for (_score, doc_address) in top_docs {
1065 let retrieved_doc: tantivy::TantivyDocument = searcher
1066 .doc(doc_address)
1067 .map_err(|e| ToolError::InvalidQuery(format!("Doc retrieval error: {}", e)))?;
1068
1069 let path = retrieved_doc
1070 .get_first(schema.path)
1071 .and_then(|f| f.as_str())
1072 .unwrap_or("")
1073 .to_string();
1074
1075 let line_num = retrieved_doc
1076 .get_first(schema.line_number)
1077 .and_then(|f| f.as_u64())
1078 .unwrap_or(1) as usize;
1079
1080 let content = retrieved_doc
1081 .get_first(schema.content)
1082 .and_then(|f| f.as_str())
1083 .unwrap_or("")
1084 .to_string();
1085
1086 matches.push(Match {
1087 file: PathBuf::from(path),
1088 line: line_num,
1089 column: 1,
1090 content,
1091 score: 0.9, });
1093 }
1094
1095 Ok(matches)
1096 }
1097
1098 pub async fn add_document(
1100 &self,
1101 path: &Path,
1102 content: &str,
1103 symbols: &[Symbol],
1104 language: &str,
1105 ) -> Result<(), ToolError> {
1106 let mut writer = self.writer.lock().await;
1107 let schema = &self.schema;
1108
1109 let symbols_text = symbols
1110 .iter()
1111 .map(|s| s.name.clone())
1112 .collect::<Vec<_>>()
1113 .join(" ");
1114
1115 let doc = doc!(
1116 schema.path => path.to_string_lossy().to_string(),
1117 schema.content => content,
1118 schema.symbols => symbols_text,
1119 schema.language => language,
1120 schema.line_number => 1u64,
1121 );
1122
1123 writer
1124 .add_document(doc)
1125 .map_err(|e| ToolError::InvalidQuery(format!("Add document error: {}", e)))?;
1126
1127 writer
1128 .commit()
1129 .map_err(|e| ToolError::InvalidQuery(format!("Commit error: {}", e)))?;
1130
1131 Ok(())
1132 }
1133}
1134
1135impl std::fmt::Debug for TantivySearchEngine {
1136 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1137 f.debug_struct("TantivySearchEngine")
1138 .field("schema", &self.schema)
1139 .finish_non_exhaustive()
1140 }
1141}
1142
1143impl CodeTool for MultiLayerSearchEngine {
1146 type Query = SearchQuery;
1147 type Output = SearchResult;
1148
1149 fn search(&self, query: Self::Query) -> Result<Self::Output, ToolError> {
1150 tokio::task::block_in_place(|| {
1152 tokio::runtime::Handle::current().block_on(self.search(query))
1153 })
1154 }
1155}
1156
1157impl QueryType {
1160 const fn as_str(&self) -> &'static str {
1161 match self {
1162 QueryType::Symbol => "symbol",
1163 QueryType::FullText => "fulltext",
1164 QueryType::Definition => "definition",
1165 QueryType::References => "references",
1166 QueryType::Semantic => "semantic",
1167 QueryType::General => "general",
1168 }
1169 }
1170}
1171
1172impl SymbolKind {
1173 const fn as_str(&self) -> &'static str {
1174 match self {
1175 SymbolKind::Function => "function",
1176 SymbolKind::Method => "method",
1177 SymbolKind::Class => "class",
1178 SymbolKind::Interface => "interface",
1179 SymbolKind::Struct => "struct",
1180 SymbolKind::Enum => "enum",
1181 SymbolKind::Variable => "variable",
1182 SymbolKind::Constant => "constant",
1183 SymbolKind::Module => "module",
1184 SymbolKind::Namespace => "namespace",
1185 }
1186 }
1187}
1188
1189impl SearchQuery {
1192 pub fn new(pattern: impl Into<String>) -> Self {
1193 Self {
1194 pattern: pattern.into(),
1195 ..Default::default()
1196 }
1197 }
1198
1199 pub fn symbol(pattern: impl Into<String>) -> Self {
1200 Self {
1201 pattern: pattern.into(),
1202 query_type: QueryType::Symbol,
1203 ..Default::default()
1204 }
1205 }
1206
1207 pub fn full_text(pattern: impl Into<String>) -> Self {
1208 Self {
1209 pattern: pattern.into(),
1210 query_type: QueryType::FullText,
1211 ..Default::default()
1212 }
1213 }
1214
1215 pub fn definition(symbol: impl Into<String>) -> Self {
1216 Self {
1217 pattern: symbol.into(),
1218 query_type: QueryType::Definition,
1219 ..Default::default()
1220 }
1221 }
1222
1223 pub fn references(symbol: impl Into<String>) -> Self {
1224 Self {
1225 pattern: symbol.into(),
1226 query_type: QueryType::References,
1227 ..Default::default()
1228 }
1229 }
1230
1231 pub fn with_file_filters(mut self, filters: Vec<String>) -> Self {
1232 self.file_filters = filters;
1233 self
1234 }
1235
1236 pub fn with_language_filters(mut self, filters: Vec<String>) -> Self {
1237 self.language_filters = filters;
1238 self
1239 }
1240
1241 pub const fn with_context_lines(mut self, lines: usize) -> Self {
1242 self.context_lines = lines;
1243 self
1244 }
1245
1246 pub const fn with_limit(mut self, limit: usize) -> Self {
1247 self.limit = Some(limit);
1248 self
1249 }
1250
1251 pub const fn fuzzy(mut self) -> Self {
1252 self.fuzzy = true;
1253 self
1254 }
1255
1256 pub const fn case_insensitive(mut self) -> Self {
1257 self.case_sensitive = false;
1258 self
1259 }
1260
1261 pub fn in_directory(mut self, path: impl Into<PathBuf>) -> Self {
1262 self.scope = SearchScope::Directory(path.into());
1263 self
1264 }
1265
1266 pub fn in_files(mut self, files: Vec<PathBuf>) -> Self {
1267 self.scope = SearchScope::Files(files);
1268 self
1269 }
1270}
1271
1272#[cfg(test)]
1273mod tests {
1274 use super::*;
1275
/// An exact symbol query is answered by the symbol index layer.
#[tokio::test]
async fn test_symbol_search() {
    // Tantivy and ripgrep are switched off so the test needs no external state.
    let config = SearchConfig {
        enable_tantivy: false,
        enable_ripgrep_fallback: false,
        ..Default::default()
    };
    let engine = MultiLayerSearchEngine::new(config).unwrap();

    let symbol = Symbol {
        name: "test_function".to_string(),
        kind: SymbolKind::Function,
        location: Location {
            file: PathBuf::from("test.rs"),
            line: 10,
            column: 5,
            byte_offset: 100,
        },
        scope: Scope {
            function: None,
            class: None,
            module: Some("test_module".to_string()),
            namespace: None,
        },
        visibility: Visibility::Public,
    };

    engine.add_symbol(symbol);

    let query = SearchQuery::symbol("test_function");
    let result = engine.search(query).await.unwrap();

    assert_eq!(result.result.len(), 1);
    assert_eq!(result.result[0].file, PathBuf::from("test.rs"));
    assert_eq!(result.result[0].line, 10);
    assert_eq!(result.metadata.search_layer, SearchLayer::SymbolIndex);
}
1315
/// A misspelled symbol name is still found when fuzzy matching is on.
#[tokio::test]
async fn test_fuzzy_search() {
    // Only the symbol index stays enabled, so the fallback layer selection
    // must end up there.
    let config = SearchConfig {
        enable_tantivy: false,
        enable_ripgrep_fallback: false,
        enable_ast_cache: false,
        ..Default::default()
    };
    let engine = MultiLayerSearchEngine::new(config).unwrap();

    let symbol = Symbol {
        name: "calculateSum".to_string(),
        kind: SymbolKind::Function,
        location: Location {
            file: PathBuf::from("math.js"),
            line: 5,
            column: 1,
            byte_offset: 50,
        },
        scope: Scope {
            function: None,
            class: None,
            module: None,
            namespace: None,
        },
        visibility: Visibility::Public,
    };

    engine.add_symbol(symbol);

    // "calculaeSum" is "calculateSum" with one letter missing.
    let query = SearchQuery::symbol("calculaeSum").fuzzy();
    let result = engine.search(query).await.unwrap();

    assert!(
        !result.result.is_empty(),
        "Fuzzy search should find similar symbols"
    );
    assert!(
        result.result[0].score > 0.7,
        "Score should be above 0.7 for similar match"
    );
}
1359
/// `General` queries pick their strategy from the shape of the pattern.
#[tokio::test]
async fn test_search_strategy_selection() {
    let config = SearchConfig {
        enable_tantivy: false,
        enable_ripgrep_fallback: false,
        ..Default::default()
    };
    let engine = MultiLayerSearchEngine::new(config).unwrap();

    // An identifier-like pattern takes the fast symbol path.
    let query = SearchQuery::new("function_name");
    let strategy = engine.select_strategy(&query);
    assert_eq!(strategy, SearchStrategy::FastSymbolLookup);

    // A pattern containing spaces falls through to the hybrid strategy.
    let query = SearchQuery::new("this is a long text search query");
    let strategy = engine.select_strategy(&query);
    assert_eq!(strategy, SearchStrategy::Hybrid);
}
1379
/// Searches the same query three times: fresh, cached, and after the TTL
/// has expired.
///
/// NOTE(review): currently ignored because of a deadlock. Re-enable once the
/// query-cache paths are confirmed not to call `remove` on the map while a
/// guard into it is still held.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "Temporarily disabled due to async deadlock issue"]
async fn test_cache_functionality() {
    // Tiny cache with a short TTL so expiry can be observed quickly.
    let config = SearchConfig {
        max_cache_size: 2,
        cache_ttl: Duration::from_millis(100),
        enable_symbol_index: true,
        enable_tantivy: false,
        enable_ripgrep_fallback: false,
        enable_ast_cache: false,
        max_results: 10,
        timeout: Duration::from_secs(2),
    };
    let engine = MultiLayerSearchEngine::new(config).unwrap();

    let symbol = Symbol {
        name: "test_cache_function".to_string(),
        kind: SymbolKind::Function,
        location: Location {
            file: PathBuf::from("test.rs"),
            line: 1,
            column: 1,
            byte_offset: 0,
        },
        scope: Scope {
            function: None,
            class: None,
            module: None,
            namespace: None,
        },
        visibility: Visibility::Public,
    };
    engine.add_symbol(symbol);

    let query = SearchQuery::symbol("test_cache_function");
    // Guard every search so a hang fails the test instead of blocking it.
    let timeout_duration = Duration::from_secs(3);

    // First search: computed and stored in the cache.
    let result1 = tokio::time::timeout(timeout_duration, engine.search(query.clone()))
        .await
        .expect("First search timed out")
        .unwrap();
    assert!(!result1.result.is_empty(), "Should find the test symbol");

    // Second search: served while the entry is still within its TTL.
    let result2 = tokio::time::timeout(timeout_duration, engine.search(query.clone()))
        .await
        .expect("Second search timed out")
        .unwrap();
    assert_eq!(result1.result.len(), result2.result.len());

    // Wait past the 100 ms TTL.
    tokio::time::sleep(Duration::from_millis(150)).await;

    // Third search: the expired entry has to be discarded and recomputed.
    let result3 = tokio::time::timeout(timeout_duration, engine.search(query))
        .await
        .expect("Third search timed out")
        .unwrap();
    assert_eq!(result1.result.len(), result3.result.len());
}
1443
/// Spot checks for the normalised string similarity.
#[tokio::test]
async fn test_similarity_calculation() {
    let config = SearchConfig {
        enable_tantivy: false,
        enable_ripgrep_fallback: false,
        ..Default::default()
    };
    let engine = MultiLayerSearchEngine::new(config).unwrap();

    // Identical strings, including two empty ones, are a perfect match.
    assert_eq!(engine.calculate_similarity("hello", "hello"), 1.0);
    assert_eq!(engine.calculate_similarity("", ""), 1.0);

    // One edit over five characters gives 0.8.
    let similarity = engine.calculate_similarity("hello", "helo");
    assert!(similarity >= 0.8);

    // Unrelated strings score low.
    let similarity = engine.calculate_similarity("test", "completely_different");
    assert!(similarity < 0.3);
}
1462
/// Every builder method sets exactly the field it is named after.
#[tokio::test]
async fn test_query_builder() {
    let query = SearchQuery::symbol("test_function")
        .with_file_filters(vec!["*.rs".to_string()])
        .with_language_filters(vec!["rust".to_string()])
        .with_context_lines(5)
        .with_limit(10)
        .fuzzy()
        .case_insensitive()
        .in_directory("/path/to/project");

    assert_eq!(query.pattern, "test_function");
    assert_eq!(query.query_type, QueryType::Symbol);
    assert_eq!(query.file_filters, vec!["*.rs"]);
    assert_eq!(query.language_filters, vec!["rust"]);
    assert_eq!(query.context_lines, 5);
    assert_eq!(query.limit, Some(10));
    assert!(query.fuzzy);
    assert!(!query.case_sensitive);
    match query.scope {
        SearchScope::Directory(path) => assert_eq!(path, PathBuf::from("/path/to/project")),
        _ => panic!("Expected Directory scope"),
    }
}
1487}