Skip to main content

reflex/parsers/
go.rs

1//! Go language parser using Tree-sitter
2//!
3//! Extracts symbols from Go source code:
4//! - Functions (func)
5//! - Types (struct, interface)
6//! - Methods (with receiver type)
7//! - Constants (const declarations and blocks)
8//! - Variables (var declarations and short declarations with :=)
//! - Imports (extracted separately by `GoDependencyExtractor`; `parse` does not emit package or import symbols)
10
11use anyhow::{Context, Result};
12use streaming_iterator::StreamingIterator;
13use tree_sitter::{Parser, Query, QueryCursor};
14use crate::models::{Language, SearchResult, Span, SymbolKind};
15
16/// Parse Go source code and extract symbols
17pub fn parse(path: &str, source: &str) -> Result<Vec<SearchResult>> {
18    let mut parser = Parser::new();
19    let language = tree_sitter_go::LANGUAGE;
20
21    parser
22        .set_language(&language.into())
23        .context("Failed to set Go language")?;
24
25    let tree = parser
26        .parse(source, None)
27        .context("Failed to parse Go source")?;
28
29    let root_node = tree.root_node();
30
31    let mut symbols = Vec::new();
32
33    // Extract different types of symbols using Tree-sitter queries
34    symbols.extend(extract_functions(source, &root_node, &language.into())?);
35    symbols.extend(extract_types(source, &root_node, &language.into())?);
36    symbols.extend(extract_interfaces(source, &root_node, &language.into())?);
37    symbols.extend(extract_methods(source, &root_node, &language.into())?);
38    symbols.extend(extract_constants(source, &root_node, &language.into())?);
39    symbols.extend(extract_variables(source, &root_node, &language.into())?);
40
41    // Add file path to all symbols
42    for symbol in &mut symbols {
43        symbol.path = path.to_string();
44        symbol.lang = Language::Go;
45    }
46
47    Ok(symbols)
48}
49
50/// Extract function declarations
51fn extract_functions(
52    source: &str,
53    root: &tree_sitter::Node,
54    language: &tree_sitter::Language,
55) -> Result<Vec<SearchResult>> {
56    let query_str = r#"
57        (function_declaration
58            name: (identifier) @name) @function
59    "#;
60
61    let query = Query::new(language, query_str)
62        .context("Failed to create function query")?;
63
64    extract_symbols(source, root, &query, SymbolKind::Function, None)
65}
66
67/// Extract type declarations (structs)
68fn extract_types(
69    source: &str,
70    root: &tree_sitter::Node,
71    language: &tree_sitter::Language,
72) -> Result<Vec<SearchResult>> {
73    let query_str = r#"
74        (type_declaration
75            (type_spec
76                name: (type_identifier) @name
77                type: (struct_type))) @struct
78    "#;
79
80    let query = Query::new(language, query_str)
81        .context("Failed to create struct query")?;
82
83    extract_symbols(source, root, &query, SymbolKind::Struct, None)
84}
85
86/// Extract interface declarations
87fn extract_interfaces(
88    source: &str,
89    root: &tree_sitter::Node,
90    language: &tree_sitter::Language,
91) -> Result<Vec<SearchResult>> {
92    let query_str = r#"
93        (type_declaration
94            (type_spec
95                name: (type_identifier) @name
96                type: (interface_type))) @interface
97    "#;
98
99    let query = Query::new(language, query_str)
100        .context("Failed to create interface query")?;
101
102    extract_symbols(source, root, &query, SymbolKind::Interface, None)
103}
104
105/// Extract method declarations (functions with receivers)
106fn extract_methods(
107    source: &str,
108    root: &tree_sitter::Node,
109    language: &tree_sitter::Language,
110) -> Result<Vec<SearchResult>> {
111    let query_str = r#"
112        (method_declaration
113            receiver: (parameter_list
114                (parameter_declaration
115                    type: [(type_identifier) (pointer_type (type_identifier))] @receiver_type))
116            name: (field_identifier) @method_name) @method
117    "#;
118
119    let query = Query::new(language, query_str)
120        .context("Failed to create method query")?;
121
122    let mut cursor = QueryCursor::new();
123    let mut matches = cursor.matches(&query, *root, source.as_bytes());
124
125    let mut symbols = Vec::new();
126
127    while let Some(match_) = matches.next() {
128        let mut receiver_type = None;
129        let mut method_name = None;
130        let mut method_node = None;
131
132        for capture in match_.captures {
133            let capture_name: &str = &query.capture_names()[capture.index as usize];
134            match capture_name {
135                "receiver_type" => {
136                    receiver_type = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
137                }
138                "method_name" => {
139                    method_name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
140                }
141                "method" => {
142                    method_node = Some(capture.node);
143                }
144                _ => {}
145            }
146        }
147
148        if let (Some(receiver_type), Some(method_name), Some(node)) = (receiver_type, method_name, method_node) {
149            // Clean up receiver type (remove * if pointer)
150            let clean_receiver = receiver_type.trim_start_matches('*');
151            let scope = format!("type {}", clean_receiver);
152            let span = node_to_span(&node);
153            let preview = extract_preview(source, &span);
154
155            symbols.push(SearchResult::new(
156                String::new(),
157                Language::Go,
158                SymbolKind::Method,
159                Some(method_name),
160                span,
161                Some(scope),
162                preview,
163            ));
164        }
165    }
166
167    Ok(symbols)
168}
169
170/// Extract constant declarations
171fn extract_constants(
172    source: &str,
173    root: &tree_sitter::Node,
174    language: &tree_sitter::Language,
175) -> Result<Vec<SearchResult>> {
176    let query_str = r#"
177        (const_declaration
178            (const_spec
179                name: (identifier) @name)) @const
180    "#;
181
182    let query = Query::new(language, query_str)
183        .context("Failed to create const query")?;
184
185    extract_symbols(source, root, &query, SymbolKind::Constant, None)
186}
187
188/// Extract variable declarations (var and := short declarations)
189fn extract_variables(
190    source: &str,
191    root: &tree_sitter::Node,
192    language: &tree_sitter::Language,
193) -> Result<Vec<SearchResult>> {
194    // Match both var_spec and short_var_declaration to capture all variable declarations
195    let query_str = r#"
196        (var_spec
197            name: (identifier) @name) @var
198
199        (short_var_declaration
200            left: (expression_list (identifier) @name)) @short_var
201    "#;
202
203    let query = Query::new(language, query_str)
204        .context("Failed to create var query")?;
205
206    let mut cursor = QueryCursor::new();
207    let mut matches = cursor.matches(&query, *root, source.as_bytes());
208
209    let mut symbols = Vec::new();
210
211    while let Some(match_) = matches.next() {
212        let mut name = None;
213        let mut decl_node = None;
214
215        for capture in match_.captures {
216            let capture_name: &str = &query.capture_names()[capture.index as usize];
217            match capture_name {
218                "name" => {
219                    name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
220                }
221                "var" | "short_var" => {
222                    decl_node = Some(capture.node);
223                }
224                _ => {}
225            }
226        }
227
228        if let (Some(name), Some(node)) = (name, decl_node) {
229            let span = node_to_span(&node);
230            let preview = extract_preview(source, &span);
231
232            symbols.push(SearchResult::new(
233                String::new(),
234                Language::Go,
235                SymbolKind::Variable,
236                Some(name),
237                span,
238                None,
239                preview,
240            ));
241        }
242    }
243
244    Ok(symbols)
245}
246
247/// Generic symbol extraction helper
248fn extract_symbols(
249    source: &str,
250    root: &tree_sitter::Node,
251    query: &Query,
252    kind: SymbolKind,
253    scope: Option<String>,
254) -> Result<Vec<SearchResult>> {
255    let mut cursor = QueryCursor::new();
256    let mut matches = cursor.matches(query, *root, source.as_bytes());
257
258    let mut symbols = Vec::new();
259
260    while let Some(match_) = matches.next() {
261        // Find the name capture and the full node
262        let mut name = None;
263        let mut full_node = None;
264
265        for capture in match_.captures {
266            let capture_name: &str = &query.capture_names()[capture.index as usize];
267            if capture_name == "name" {
268                name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
269            } else {
270                // Assume any other capture is the full node
271                full_node = Some(capture.node);
272            }
273        }
274
275        if let (Some(name), Some(node)) = (name, full_node) {
276            let span = node_to_span(&node);
277            let preview = extract_preview(source, &span);
278
279            symbols.push(SearchResult::new(
280                String::new(),
281                Language::Go,
282                kind.clone(),
283                Some(name),
284                span,
285                scope.clone(),
286                preview,
287            ));
288        }
289    }
290
291    Ok(symbols)
292}
293
294/// Convert a Tree-sitter node to a Span
295fn node_to_span(node: &tree_sitter::Node) -> Span {
296    let start = node.start_position();
297    let end = node.end_position();
298
299    Span::new(
300        start.row + 1,  // Convert 0-indexed to 1-indexed
301        start.column,
302        end.row + 1,
303        end.column,
304    )
305}
306
307/// Extract a preview (7 lines) around the symbol
308fn extract_preview(source: &str, span: &Span) -> String {
309    let lines: Vec<&str> = source.lines().collect();
310
311    // Extract 7 lines: the start line and 6 following lines
312    let start_idx = (span.start_line - 1) as usize; // Convert back to 0-indexed
313    let end_idx = (start_idx + 7).min(lines.len());
314
315    lines[start_idx..end_idx].join("\n")
316}
317
318#[cfg(test)]
319mod tests {
320    use super::*;
321
322    #[test]
323    fn test_parse_function() {
324        let source = r#"
325package main
326
327func helloWorld() string {
328    return "Hello, world!"
329}
330        "#;
331
332        let symbols = parse("test.go", source).unwrap();
333        assert_eq!(symbols.len(), 1);
334        assert_eq!(symbols[0].symbol.as_deref(), Some("helloWorld"));
335        assert!(matches!(symbols[0].kind, SymbolKind::Function));
336    }
337
338    #[test]
339    fn test_parse_struct() {
340        let source = r#"
341package main
342
343type User struct {
344    Name string
345    Age  int
346}
347        "#;
348
349        let symbols = parse("test.go", source).unwrap();
350        assert_eq!(symbols.len(), 1);
351        assert_eq!(symbols[0].symbol.as_deref(), Some("User"));
352        assert!(matches!(symbols[0].kind, SymbolKind::Struct));
353    }
354
355    #[test]
356    fn test_parse_interface() {
357        let source = r#"
358package main
359
360type Reader interface {
361    Read(p []byte) (n int, err error)
362}
363        "#;
364
365        let symbols = parse("test.go", source).unwrap();
366        assert_eq!(symbols.len(), 1);
367        assert_eq!(symbols[0].symbol.as_deref(), Some("Reader"));
368        assert!(matches!(symbols[0].kind, SymbolKind::Interface));
369    }
370
371    #[test]
372    fn test_parse_method() {
373        let source = r#"
374package main
375
376type User struct {
377    Name string
378}
379
380func (u *User) GetName() string {
381    return u.Name
382}
383
384func (u User) SetName(name string) {
385    u.Name = name
386}
387        "#;
388
389        let symbols = parse("test.go", source).unwrap();
390
391        let method_symbols: Vec<_> = symbols.iter()
392            .filter(|s| matches!(s.kind, SymbolKind::Method))
393            .collect();
394
395        assert_eq!(method_symbols.len(), 2);
396        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("GetName")));
397        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("SetName")));
398
399        // Check scope
400        for method in method_symbols {
401            // Removed: scope field no longer exists: assert_eq!(method.scope.as_ref().unwrap(), "type User");
402        }
403    }
404
405    #[test]
406    fn test_parse_constants() {
407        let source = r#"
408package main
409
410const MaxSize = 100
411const DefaultTimeout = 30
412
413const (
414    StatusActive   = 1
415    StatusInactive = 2
416)
417        "#;
418
419        let symbols = parse("test.go", source).unwrap();
420
421        let const_symbols: Vec<_> = symbols.iter()
422            .filter(|s| matches!(s.kind, SymbolKind::Constant))
423            .collect();
424
425        assert_eq!(const_symbols.len(), 4);
426        assert!(const_symbols.iter().any(|s| s.symbol.as_deref() == Some("MaxSize")));
427        assert!(const_symbols.iter().any(|s| s.symbol.as_deref() == Some("DefaultTimeout")));
428        assert!(const_symbols.iter().any(|s| s.symbol.as_deref() == Some("StatusActive")));
429        assert!(const_symbols.iter().any(|s| s.symbol.as_deref() == Some("StatusInactive")));
430    }
431
432    #[test]
433    fn test_parse_variables() {
434        let source = r#"
435package main
436
437var GlobalConfig Config
438var (
439    Logger  *log.Logger
440    Version = "1.0.0"
441)
442        "#;
443
444        let symbols = parse("test.go", source).unwrap();
445
446        let var_symbols: Vec<_> = symbols.iter()
447            .filter(|s| matches!(s.kind, SymbolKind::Variable))
448            .collect();
449
450        assert_eq!(var_symbols.len(), 3);
451        assert!(var_symbols.iter().any(|s| s.symbol.as_deref() == Some("GlobalConfig")));
452        assert!(var_symbols.iter().any(|s| s.symbol.as_deref() == Some("Logger")));
453        assert!(var_symbols.iter().any(|s| s.symbol.as_deref() == Some("Version")));
454    }
455
456    #[test]
457    fn test_parse_mixed_symbols() {
458        let source = r#"
459package main
460
461const DefaultPort = 8080
462
463type Server struct {
464    Port int
465}
466
467type Handler interface {
468    Handle(req *Request) error
469}
470
471func (s *Server) Start() error {
472    return nil
473}
474
475func NewServer(port int) *Server {
476    return &Server{Port: port}
477}
478
479var globalServer *Server
480        "#;
481
482        let symbols = parse("test.go", source).unwrap();
483
484        // Should find: const, struct, interface, method, function, var
485        assert!(symbols.len() >= 6);
486
487        let kinds: Vec<&SymbolKind> = symbols.iter().map(|s| &s.kind).collect();
488        assert!(kinds.contains(&&SymbolKind::Constant));
489        assert!(kinds.contains(&&SymbolKind::Struct));
490        assert!(kinds.contains(&&SymbolKind::Interface));
491        assert!(kinds.contains(&&SymbolKind::Method));
492        assert!(kinds.contains(&&SymbolKind::Function));
493        assert!(kinds.contains(&&SymbolKind::Variable));
494    }
495
496    #[test]
497    fn test_parse_multiple_methods() {
498        let source = r#"
499package main
500
501type Calculator struct{}
502
503func (c *Calculator) Add(a, b int) int {
504    return a + b
505}
506
507func (c *Calculator) Subtract(a, b int) int {
508    return a - b
509}
510
511func (c *Calculator) Multiply(a, b int) int {
512    return a * b
513}
514        "#;
515
516        let symbols = parse("test.go", source).unwrap();
517
518        let method_symbols: Vec<_> = symbols.iter()
519            .filter(|s| matches!(s.kind, SymbolKind::Method))
520            .collect();
521
522        assert_eq!(method_symbols.len(), 3);
523        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("Add")));
524        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("Subtract")));
525        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("Multiply")));
526    }
527
528    #[test]
529    fn test_parse_type_alias() {
530        let source = r#"
531package main
532
533type UserID string
534type Age int
535
536type Config struct {
537    Host string
538    Port int
539}
540        "#;
541
542        let symbols = parse("test.go", source).unwrap();
543
544        // Should find the Config struct (type aliases UserID and Age are type_spec but not struct_type)
545        let struct_symbols: Vec<_> = symbols.iter()
546            .filter(|s| matches!(s.kind, SymbolKind::Struct))
547            .collect();
548
549        assert_eq!(struct_symbols.len(), 1);
550        assert_eq!(struct_symbols[0].symbol.as_deref(), Some("Config"));
551    }
552
553    #[test]
554    fn test_parse_embedded_interface() {
555        let source = r#"
556package main
557
558type Reader interface {
559    Read(p []byte) (n int, err error)
560}
561
562type Writer interface {
563    Write(p []byte) (n int, err error)
564}
565
566type ReadWriter interface {
567    Reader
568    Writer
569}
570        "#;
571
572        let symbols = parse("test.go", source).unwrap();
573
574        let interface_symbols: Vec<_> = symbols.iter()
575            .filter(|s| matches!(s.kind, SymbolKind::Interface))
576            .collect();
577
578        assert_eq!(interface_symbols.len(), 3);
579        assert!(interface_symbols.iter().any(|s| s.symbol.as_deref() == Some("Reader")));
580        assert!(interface_symbols.iter().any(|s| s.symbol.as_deref() == Some("Writer")));
581        assert!(interface_symbols.iter().any(|s| s.symbol.as_deref() == Some("ReadWriter")));
582    }
583
584    #[test]
585    fn test_local_variables_included() {
586        let source = r#"
587package main
588
589var globalCount int = 10
590
591func calculate(x int) int {
592    localVar := x * 2
593    var anotherLocal int = 5
594    return localVar + anotherLocal
595}
596        "#;
597
598        let symbols = parse("test.go", source).unwrap();
599
600        let var_symbols: Vec<_> = symbols.iter()
601            .filter(|s| matches!(s.kind, SymbolKind::Variable))
602            .collect();
603
604        // Should find globalCount, localVar (short declaration), and anotherLocal (var declaration)
605        assert_eq!(var_symbols.len(), 3);
606        assert!(var_symbols.iter().any(|s| s.symbol.as_deref() == Some("globalCount")));
607        assert!(var_symbols.iter().any(|s| s.symbol.as_deref() == Some("localVar")));
608        assert!(var_symbols.iter().any(|s| s.symbol.as_deref() == Some("anotherLocal")));
609    }
610
611    #[test]
612    fn test_extract_go_imports() {
613        let source = r#"package main
614
615import (
616	"fmt"
617	"encoding/json"
618	"github.com/gin-gonic/gin"
619	"myproject/internal/models"
620)
621
622func main() {
623	fmt.Println("Hello")
624}
625"#;
626
627        let deps = GoDependencyExtractor::extract_dependencies(source).unwrap();
628
629        assert_eq!(deps.len(), 4, "Should extract 4 import statements");
630        assert!(deps.iter().any(|d| d.imported_path == "fmt"));
631        assert!(deps.iter().any(|d| d.imported_path == "encoding/json"));
632        assert!(deps.iter().any(|d| d.imported_path == "github.com/gin-gonic/gin"));
633        assert!(deps.iter().any(|d| d.imported_path == "myproject/internal/models"));
634
635        // Check stdlib classification
636        let fmt_dep = deps.iter().find(|d| d.imported_path == "fmt").unwrap();
637        assert!(matches!(fmt_dep.import_type, ImportType::Stdlib),
638                "fmt should be classified as Stdlib");
639
640        let json_dep = deps.iter().find(|d| d.imported_path == "encoding/json").unwrap();
641        assert!(matches!(json_dep.import_type, ImportType::Stdlib),
642                "encoding/json should be classified as Stdlib");
643
644        // Check external classification
645        let gin_dep = deps.iter().find(|d| d.imported_path == "github.com/gin-gonic/gin").unwrap();
646        assert!(matches!(gin_dep.import_type, ImportType::External),
647                "github.com/gin-gonic/gin should be classified as External");
648
649        // Check myproject classification (ambiguous but should be External)
650        let models_dep = deps.iter().find(|d| d.imported_path == "myproject/internal/models").unwrap();
651        assert!(matches!(models_dep.import_type, ImportType::External),
652                "myproject/internal/models should be classified as External");
653    }
654
655    #[test]
656    fn test_extract_go_imports_with_comments() {
657        // Real-world Go code from Kubernetes with inline comments
658        let source = r#"package main
659
660import (
661	"os"
662	_ "time/tzdata" // for timeZone support in CronJob
663
664	"k8s.io/component-base/cli"
665	_ "k8s.io/component-base/logs/json/register"          // for JSON log format registration
666	_ "k8s.io/component-base/metrics/prometheus/clientgo" // load all the prometheus client-go plugins
667)
668
669func main() {
670	os.Exit(0)
671}
672"#;
673
674        let deps = GoDependencyExtractor::extract_dependencies(source).unwrap();
675
676        println!("Extracted {} dependencies:", deps.len());
677        for dep in &deps {
678            println!("  - {} (line {})", dep.imported_path, dep.line_number);
679        }
680
681        // Should extract all imports, even those with _ alias and comments
682        assert!(deps.len() >= 4, "Should extract at least 4 imports, got {}", deps.len());
683        assert!(deps.iter().any(|d| d.imported_path == "os"));
684        assert!(deps.iter().any(|d| d.imported_path == "time/tzdata"));
685        assert!(deps.iter().any(|d| d.imported_path == "k8s.io/component-base/cli"));
686    }
687
688    #[test]
689    fn test_find_all_go_mods() {
690        use tempfile::TempDir;
691        use std::fs;
692
693        let temp = TempDir::new().unwrap();
694        let root = temp.path();
695
696        // Create multiple Go modules
697        let service1 = root.join("services/auth");
698        fs::create_dir_all(&service1).unwrap();
699        fs::write(service1.join("go.mod"), "module github.com/myorg/auth\n\ngo 1.21\n").unwrap();
700
701        let service2 = root.join("services/api");
702        fs::create_dir_all(&service2).unwrap();
703        fs::write(service2.join("go.mod"), "module github.com/myorg/api\n\ngo 1.21\n").unwrap();
704
705        // Create vendor directory that should be skipped
706        let vendor = root.join("vendor");
707        fs::create_dir_all(&vendor).unwrap();
708        fs::write(vendor.join("go.mod"), "module github.com/external/lib\n").unwrap();
709
710        let mods = find_all_go_mods(root).unwrap();
711
712        // Should find 2 modules (skipping vendor)
713        assert_eq!(mods.len(), 2);
714        assert!(mods.iter().any(|p| p.ends_with("services/auth/go.mod")));
715        assert!(mods.iter().any(|p| p.ends_with("services/api/go.mod")));
716    }
717
718    #[test]
719    fn test_parse_all_go_modules() {
720        use tempfile::TempDir;
721        use std::fs;
722
723        let temp = TempDir::new().unwrap();
724        let root = temp.path();
725
726        // Create multiple Go modules
727        let service1 = root.join("services/auth");
728        fs::create_dir_all(&service1).unwrap();
729        fs::write(
730            service1.join("go.mod"),
731            "module github.com/myorg/auth\n\ngo 1.21\n"
732        ).unwrap();
733
734        let service2 = root.join("cmd/api");
735        fs::create_dir_all(&service2).unwrap();
736        fs::write(
737            service2.join("go.mod"),
738            "module github.com/myorg/api\n\ngo 1.21\n"
739        ).unwrap();
740
741        let modules = parse_all_go_modules(root).unwrap();
742
743        // Should find 2 modules
744        assert_eq!(modules.len(), 2);
745
746        // Check module names
747        let names: Vec<_> = modules.iter().map(|m| m.name.as_str()).collect();
748        assert!(names.contains(&"github.com/myorg/auth"));
749        assert!(names.contains(&"github.com/myorg/api"));
750
751        // Check project roots
752        for module in &modules {
753            assert!(module.project_root.starts_with("services/") || module.project_root.starts_with("cmd/"));
754            assert!(module.abs_project_root.ends_with(&module.project_root));
755        }
756    }
757
758    #[test]
759    fn test_resolve_go_import() {
760        use tempfile::TempDir;
761        use std::fs;
762
763        let temp = TempDir::new().unwrap();
764        let root = temp.path();
765
766        // Create a Go module structure
767        let myapp = root.join("myapp");
768        fs::create_dir_all(myapp.join("pkg/models")).unwrap();
769        fs::write(
770            myapp.join("go.mod"),
771            "module github.com/myorg/myapp\n\ngo 1.21\n"
772        ).unwrap();
773
774        let modules = parse_all_go_modules(root).unwrap();
775        assert_eq!(modules.len(), 1);
776
777        // Test sub-package import resolution
778        // "github.com/myorg/myapp/pkg/models" → "myapp/pkg/models.go" or "myapp/pkg/models/models.go"
779        let resolved = resolve_go_import_to_path(
780            "github.com/myorg/myapp/pkg/models",
781            &modules,
782            None
783        );
784
785        assert!(resolved.is_some());
786        let path = resolved.unwrap();
787        assert!(path.contains("myapp/pkg/models"));
788        assert!(path.ends_with(".go"));
789    }
790
791    #[test]
792    fn test_resolve_go_import_module_root() {
793        use tempfile::TempDir;
794        use std::fs;
795
796        let temp = TempDir::new().unwrap();
797        let root = temp.path();
798
799        let myapp = root.join("cmd/server");
800        fs::create_dir_all(&myapp).unwrap();
801        fs::write(
802            myapp.join("go.mod"),
803            "module github.com/myorg/server\n\ngo 1.21\n"
804        ).unwrap();
805
806        let modules = parse_all_go_modules(root).unwrap();
807
808        // Test module root import (no sub-package)
809        let resolved = resolve_go_import_to_path(
810            "github.com/myorg/server",
811            &modules,
812            None
813        );
814
815        assert!(resolved.is_some());
816        let path = resolved.unwrap();
817        // Should try main.go or server.go
818        assert!(path.contains("cmd/server"));
819        assert!(path.ends_with(".go"));
820    }
821
822    #[test]
823    fn test_resolve_go_import_not_found() {
824        use tempfile::TempDir;
825        use std::fs;
826
827        let temp = TempDir::new().unwrap();
828        let root = temp.path();
829
830        let myapp = root.join("myapp");
831        fs::create_dir_all(&myapp).unwrap();
832        fs::write(
833            myapp.join("go.mod"),
834            "module github.com/myorg/myapp\n\ngo 1.21\n"
835        ).unwrap();
836
837        let modules = parse_all_go_modules(root).unwrap();
838
839        // Try to resolve an import for a different module
840        let resolved = resolve_go_import_to_path(
841            "github.com/other/package",
842            &modules,
843            None
844        );
845
846        // Should return None for modules not in the monorepo
847        assert!(resolved.is_none());
848    }
849
850    #[test]
851    fn test_resolve_go_import_relative() {
852        let modules = vec![];
853
854        // Relative imports are not supported yet
855        let resolved = resolve_go_import_to_path(
856            "./utils",
857            &modules,
858            Some("myapp/pkg/api/handler.go"),
859        );
860
861        assert!(resolved.is_none());
862    }
863
864    #[test]
865    fn test_resolve_go_import_root_module_no_leading_slash() {
866        // When go.mod is at the repo root, project_root is "" and paths must not
867        // start with "/" (which would cause ambiguous fuzzy matches in the DB).
868        use tempfile::TempDir;
869        use std::fs;
870
871        let temp = TempDir::new().unwrap();
872        let root = temp.path();
873
874        // go.mod at repo root → project_root = ""
875        fs::write(
876            root.join("go.mod"),
877            "module k8s.io/kubernetes\n\ngo 1.21\n",
878        ).unwrap();
879
880        let modules = parse_all_go_modules(root).unwrap();
881        assert_eq!(modules.len(), 1);
882        assert_eq!(modules[0].project_root, "");
883
884        // Sub-package import
885        let resolved = resolve_go_import_to_path(
886            "k8s.io/kubernetes/test/internal/metric",
887            &modules,
888            None,
889        );
890        assert!(resolved.is_some());
891        let path = resolved.unwrap();
892        assert!(!path.starts_with('/'), "path must not start with '/': {}", path);
893        assert!(path.ends_with(".go"));
894        assert!(path.contains("test/internal/metric"));
895
896        // Module-root import
897        let resolved = resolve_go_import_to_path(
898            "k8s.io/kubernetes",
899            &modules,
900            None,
901        );
902        assert!(resolved.is_some());
903        let path = resolved.unwrap();
904        assert!(!path.starts_with('/'), "path must not start with '/': {}", path);
905        assert!(path.ends_with(".go"));
906    }
907}
908
909// ============================================================================
910// Dependency Extraction
911// ============================================================================
912
913use crate::models::ImportType;
914use crate::parsers::{DependencyExtractor, ImportInfo};
915
916/// Go dependency extractor
917pub struct GoDependencyExtractor;
918
919impl DependencyExtractor for GoDependencyExtractor {
920    fn extract_dependencies(source: &str) -> Result<Vec<ImportInfo>> {
921        let mut parser = Parser::new();
922        let language = tree_sitter_go::LANGUAGE;
923
924        parser
925            .set_language(&language.into())
926            .context("Failed to set Go language")?;
927
928        let tree = parser
929            .parse(source, None)
930            .context("Failed to parse Go source")?;
931
932        let root_node = tree.root_node();
933
934        let mut imports = Vec::new();
935
936        // Extract import statements
937        imports.extend(extract_go_imports(source, &root_node)?);
938
939        Ok(imports)
940    }
941}
942
943/// Extract Go import statements
944fn extract_go_imports(
945    source: &str,
946    root: &tree_sitter::Node,
947) -> Result<Vec<ImportInfo>> {
948    let language = tree_sitter_go::LANGUAGE;
949
950    // Go imports can be single or in groups
951    let query_str = r#"
952        (import_declaration
953            (import_spec
954                path: (interpreted_string_literal) @import_path)) @import
955
956        (import_declaration
957            (import_spec_list
958                (import_spec
959                    path: (interpreted_string_literal) @import_path))) @import
960    "#;
961
962    let query = Query::new(&language.into(), query_str)
963        .context("Failed to create Go import query")?;
964
965    let mut cursor = QueryCursor::new();
966    let mut matches = cursor.matches(&query, *root, source.as_bytes());
967
968    let mut imports = Vec::new();
969
970    while let Some(match_) = matches.next() {
971        let mut import_path = None;
972        let mut import_node = None;
973
974        for capture in match_.captures {
975            let capture_name: &str = &query.capture_names()[capture.index as usize];
976            match capture_name {
977                "import_path" => {
978                    // Remove quotes from string literal
979                    let raw_path = capture.node.utf8_text(source.as_bytes()).unwrap_or("");
980                    import_path = Some(raw_path.trim_matches('"').to_string());
981                }
982                "import" => {
983                    import_node = Some(capture.node);
984                }
985                _ => {}
986            }
987        }
988
989        if let (Some(path), Some(node)) = (import_path, import_node) {
990            let import_type = classify_go_import(&path);
991            let line_number = node.start_position().row + 1;
992
993            imports.push(ImportInfo {
994                imported_path: path,
995                import_type,
996                line_number,
997                imported_symbols: None, // Go imports entire packages, not selective symbols
998            });
999        }
1000    }
1001
1002    Ok(imports)
1003}
1004
/// Find and parse go.mod to extract module name
///
/// Looks for `go.mod` directly inside `root` (no recursion) and returns the
/// path named by its first `module` directive.
///
/// The directive is recognised with any whitespace after the keyword (space or
/// tab), a trailing `// comment` is ignored, and surrounding double quotes or
/// backticks around the module path are removed.
///
/// Returns `None` if go.mod is missing, unreadable, or has no usable `module`
/// directive.
pub fn find_go_module_name(root: &std::path::Path) -> Option<String> {
    // A missing or unreadable file simply yields `None`.
    let content = std::fs::read_to_string(root.join("go.mod")).ok()?;

    content.lines().find_map(|line| {
        // Drop any trailing comment: "module foo // Deprecated: ..." -> "module foo"
        let directive = line.split("//").next().unwrap_or("").trim();

        // The keyword must be followed by whitespace so that identifiers which
        // merely start with "module" are not mistaken for the directive.
        let rest = directive.strip_prefix("module")?;
        if !rest.starts_with(char::is_whitespace) {
            return None;
        }

        // "module k8s.io/kubernetes" -> "k8s.io/kubernetes"; quoted paths are unwrapped.
        let module_name = rest.trim().trim_matches(|c| c == '"' || c == '`');
        (!module_name.is_empty()).then(|| module_name.to_string())
    })
}
1027
/// Reclassify a Go import based on the module prefix
///
/// Public entry point that forwards both arguments unchanged to
/// `classify_go_import_impl`.
///
/// This should be called by the indexer after extraction to correctly identify internal imports:
/// extraction classifies without a module prefix, so imports that belong to
/// the project's own module can only be recognised once the prefix is known.
///
/// * `import_path` - the import path as written in the Go source, without quotes
/// * `module_prefix` - the project's module path (e.g. from go.mod), or `None`
///   to fall back to the prefix-free heuristics
pub fn reclassify_go_import(import_path: &str, module_prefix: Option<&str>) -> ImportType {
    classify_go_import_impl(import_path, module_prefix)
}
1033
/// Classify a Go import as internal, external, or stdlib
///
/// First-pass classification used during import extraction, where no module
/// prefix is available: it calls `classify_go_import_impl` with `None`, so
/// only relative imports (`./`, `../`) can come back as internal.
fn classify_go_import(import_path: &str) -> ImportType {
    classify_go_import_impl(import_path, None)
}
1038
1039/// Internal implementation of Go import classification
1040fn classify_go_import_impl(import_path: &str, module_prefix: Option<&str>) -> ImportType {
1041    // If we have a module prefix, check if import starts with it → Internal
1042    if let Some(prefix) = module_prefix {
1043        if import_path.starts_with(prefix) {
1044            return ImportType::Internal;
1045        }
1046        // Also check for multi-module repos - imports starting with k8s.io/* for Kubernetes
1047        // Extract the domain portion and check if it matches
1048        if let Some(import_domain) = import_path.split('/').next() {
1049            if let Some(module_domain) = prefix.split('/').next() {
1050                // If domains match (e.g., both start with k8s.io), consider it internal
1051                if import_domain == module_domain && module_domain.contains('.') {
1052                    return ImportType::Internal;
1053                }
1054            }
1055        }
1056    }
1057    // Relative imports (./ or ../) - rare in Go but possible
1058    if import_path.starts_with("./") || import_path.starts_with("../") {
1059        return ImportType::Internal;
1060    }
1061
1062    // Internal imports often start with company domain or project path
1063    // Check for common patterns like github.com/your-org/project
1064    // For now, we'll consider anything that looks like a full URL path as external
1065    // and short stdlib-like paths as stdlib
1066
1067    // Go standard library modules (common ones)
1068    const STDLIB_MODULES: &[&str] = &[
1069        "fmt", "io", "os", "path", "strings", "bytes", "bufio", "errors",
1070        "context", "sync", "time", "encoding/json", "encoding/xml", "encoding/csv",
1071        "net/http", "net/url", "net", "crypto", "crypto/tls", "crypto/sha256",
1072        "database/sql", "log", "math", "regexp", "strconv", "sort", "reflect",
1073        "runtime", "testing", "flag", "filepath", "unicode", "html", "text/template",
1074    ];
1075
1076    // Check if it's a stdlib module
1077    if STDLIB_MODULES.contains(&import_path) {
1078        return ImportType::Stdlib;
1079    }
1080
1081    // If it contains a domain (has dots and slashes), it's external
1082    if import_path.contains('/') && import_path.split('/').next().unwrap_or("").contains('.') {
1083        return ImportType::External;
1084    }
1085
1086    // Short paths without domains are likely stdlib
1087    if !import_path.contains('/') || import_path.split('/').count() <= 2 {
1088        return ImportType::Stdlib;
1089    }
1090
1091    // Everything else is external
1092    ImportType::External
1093}
1094
1095// ============================================================================
1096// Monorepo Support & Path Resolution
1097// ============================================================================
1098
/// Represents a Go module with its location
///
/// One value is produced per go.mod file that declares a module; import
/// resolution matches import paths against `name` and maps them to paths
/// under `project_root`.
#[derive(Debug, Clone)]
pub struct GoModule {
    /// Module name (e.g., "k8s.io/kubernetes", "github.com/myorg/myproject"),
    /// taken from the `module` directive of go.mod
    pub name: String,
    /// Project root relative to index root (e.g., "services/api").
    /// Empty when go.mod sits directly in the index root; resolution then
    /// emits paths without a directory prefix.
    pub project_root: String,
    /// Absolute path to project root (the directory containing go.mod).
    /// NOTE(review): this is the walker's path for that directory, so it is
    /// only absolute if the index root handed to the walker was — confirm.
    pub abs_project_root: std::path::PathBuf,
}
1109
1110/// Recursively find all go.mod files in the repository, respecting .gitignore
1111pub fn find_all_go_mods(index_root: &std::path::Path) -> Result<Vec<std::path::PathBuf>> {
1112    use ignore::WalkBuilder;
1113
1114    let mut go_mod_files = Vec::new();
1115
1116    let walker = WalkBuilder::new(index_root)
1117        .follow_links(false)
1118        .git_ignore(true)
1119        .build();
1120
1121    for entry in walker {
1122        let entry = entry?;
1123        let path = entry.path();
1124
1125        if !path.is_file() {
1126            continue;
1127        }
1128
1129        let filename = path.file_name()
1130            .and_then(|n| n.to_str())
1131            .unwrap_or("");
1132
1133        // Look for go.mod files
1134        if filename == "go.mod" {
1135            // Skip vendor directories
1136            let path_str = path.to_string_lossy();
1137            if path_str.contains("/vendor/") {
1138                log::trace!("Skipping go.mod in vendor directory: {:?}", path);
1139                continue;
1140            }
1141
1142            go_mod_files.push(path.to_path_buf());
1143        }
1144    }
1145
1146    log::debug!("Found {} go.mod files", go_mod_files.len());
1147    Ok(go_mod_files)
1148}
1149
1150/// Parse all Go modules in a monorepo and track their project roots
1151pub fn parse_all_go_modules(index_root: &std::path::Path) -> Result<Vec<GoModule>> {
1152    let go_mod_files = find_all_go_mods(index_root)?;
1153
1154    if go_mod_files.is_empty() {
1155        log::debug!("No go.mod files found in {:?}", index_root);
1156        return Ok(Vec::new());
1157    }
1158
1159    let mut modules = Vec::new();
1160    let mod_count = go_mod_files.len();
1161
1162    for go_mod_path in &go_mod_files {
1163        let project_root = go_mod_path
1164            .parent()
1165            .ok_or_else(|| anyhow::anyhow!("go.mod has no parent directory"))?;
1166
1167        // Read and parse go.mod to extract module name
1168        if let Ok(content) = std::fs::read_to_string(go_mod_path) {
1169            for line in content.lines() {
1170                let trimmed = line.trim();
1171                if trimmed.starts_with("module ") {
1172                    let module_name = trimmed["module ".len()..].trim().to_string();
1173
1174                    let relative_project_root = project_root
1175                        .strip_prefix(index_root)
1176                        .unwrap_or(project_root)
1177                        .to_string_lossy()
1178                        .to_string();
1179
1180                    log::debug!(
1181                        "Found Go module '{}' at {:?}",
1182                        module_name,
1183                        relative_project_root
1184                    );
1185
1186                    modules.push(GoModule {
1187                        name: module_name,
1188                        project_root: relative_project_root,
1189                        abs_project_root: project_root.to_path_buf(),
1190                    });
1191                    break;
1192                }
1193            }
1194        }
1195    }
1196
1197    log::info!(
1198        "Loaded {} Go modules from {} go.mod files",
1199        modules.len(),
1200        mod_count
1201    );
1202
1203    Ok(modules)
1204}
1205
1206/// Resolve a Go import to a file path
1207///
1208/// Handles:
1209/// - Internal imports: `mymodule/pkg/utils` → `pkg/utils.go` or `pkg/utils/utils.go`
1210/// - Sub-packages: `mymodule/internal/models` → `internal/models/models.go`
1211/// - Relative imports: `./utils` (rare in Go but possible)
1212pub fn resolve_go_import_to_path(
1213    import_path: &str,
1214    modules: &[GoModule],
1215    _current_file_path: Option<&str>,
1216) -> Option<String> {
1217    // Handle relative imports (rare in Go)
1218    if import_path.starts_with("./") || import_path.starts_with("../") {
1219        // Go relative imports are rare and complex - skip for now
1220        return None;
1221    }
1222
1223    // Find matching module
1224    for module in modules {
1225        if import_path.starts_with(&module.name) {
1226            // Strip module name to get sub-package path
1227            // "k8s.io/kubernetes/pkg/api" with module "k8s.io/kubernetes" → "pkg/api"
1228            let sub_path = import_path.strip_prefix(&module.name)
1229                .unwrap_or(import_path)
1230                .trim_start_matches('/');
1231
1232            if sub_path.is_empty() {
1233                // Importing the module root - could be multiple files
1234                // Try common patterns
1235                let basename = module.name.split('/').last().unwrap_or("main");
1236                let candidates = if module.project_root.is_empty() {
1237                    vec![
1238                        "main.go".to_string(),
1239                        format!("{}.go", basename),
1240                    ]
1241                } else {
1242                    vec![
1243                        format!("{}/main.go", module.project_root),
1244                        format!("{}/{}.go", module.project_root, basename),
1245                    ]
1246                };
1247
1248                for candidate in candidates {
1249                    log::trace!("Checking Go module root: {}", candidate);
1250                    return Some(candidate);
1251                }
1252            } else {
1253                // Sub-package import
1254                // Try both single file and package directory patterns
1255                let package_name = sub_path.split('/').last().unwrap_or(sub_path);
1256                let candidates = if module.project_root.is_empty() {
1257                    vec![
1258                        format!("{}.go", sub_path),
1259                        format!("{}/{}.go", sub_path, package_name),
1260                    ]
1261                } else {
1262                    vec![
1263                        format!("{}/{}.go", module.project_root, sub_path),
1264                        format!("{}/{}/{}.go", module.project_root, sub_path, package_name),
1265                    ]
1266                };
1267
1268                for candidate in candidates {
1269                    log::trace!("Checking Go package path: {}", candidate);
1270                    return Some(candidate);
1271                }
1272            }
1273        }
1274    }
1275
1276    None
1277}