Skip to main content

reflex/parsers/
go.rs

//! Go language parser using Tree-sitter
//!
//! [`parse`] extracts symbols from Go source code:
//! - Functions (`func`)
//! - Types (`struct`, `interface`)
//! - Methods (scoped by their receiver type)
//! - Constants (`const` declarations and blocks)
//! - Variables (`var` declarations and short declarations with `:=`)
//!
//! Package imports are not returned as symbols; [`GoDependencyExtractor`] extracts
//! import statements separately as dependencies.
10
11use crate::models::{Language, SearchResult, Span, SymbolKind};
12use anyhow::{Context, Result};
13use streaming_iterator::StreamingIterator;
14use tree_sitter::{Parser, Query, QueryCursor};
15
16/// Parse Go source code and extract symbols
17pub fn parse(path: &str, source: &str) -> Result<Vec<SearchResult>> {
18    let mut parser = Parser::new();
19    let language = tree_sitter_go::LANGUAGE;
20
21    parser
22        .set_language(&language.into())
23        .context("Failed to set Go language")?;
24
25    let tree = parser
26        .parse(source, None)
27        .context("Failed to parse Go source")?;
28
29    let root_node = tree.root_node();
30
31    let mut symbols = Vec::new();
32
33    // Extract different types of symbols using Tree-sitter queries
34    symbols.extend(extract_functions(source, &root_node, &language.into())?);
35    symbols.extend(extract_types(source, &root_node, &language.into())?);
36    symbols.extend(extract_interfaces(source, &root_node, &language.into())?);
37    symbols.extend(extract_methods(source, &root_node, &language.into())?);
38    symbols.extend(extract_constants(source, &root_node, &language.into())?);
39    symbols.extend(extract_variables(source, &root_node, &language.into())?);
40
41    // Add file path to all symbols
42    for symbol in &mut symbols {
43        symbol.path = path.to_string();
44        symbol.lang = Language::Go;
45    }
46
47    Ok(symbols)
48}
49
50/// Extract function declarations
51fn extract_functions(
52    source: &str,
53    root: &tree_sitter::Node,
54    language: &tree_sitter::Language,
55) -> Result<Vec<SearchResult>> {
56    let query_str = r#"
57        (function_declaration
58            name: (identifier) @name) @function
59    "#;
60
61    let query = Query::new(language, query_str).context("Failed to create function query")?;
62
63    extract_symbols(source, root, &query, SymbolKind::Function, None)
64}
65
66/// Extract type declarations (structs)
67fn extract_types(
68    source: &str,
69    root: &tree_sitter::Node,
70    language: &tree_sitter::Language,
71) -> Result<Vec<SearchResult>> {
72    let query_str = r#"
73        (type_declaration
74            (type_spec
75                name: (type_identifier) @name
76                type: (struct_type))) @struct
77    "#;
78
79    let query = Query::new(language, query_str).context("Failed to create struct query")?;
80
81    extract_symbols(source, root, &query, SymbolKind::Struct, None)
82}
83
84/// Extract interface declarations
85fn extract_interfaces(
86    source: &str,
87    root: &tree_sitter::Node,
88    language: &tree_sitter::Language,
89) -> Result<Vec<SearchResult>> {
90    let query_str = r#"
91        (type_declaration
92            (type_spec
93                name: (type_identifier) @name
94                type: (interface_type))) @interface
95    "#;
96
97    let query = Query::new(language, query_str).context("Failed to create interface query")?;
98
99    extract_symbols(source, root, &query, SymbolKind::Interface, None)
100}
101
102/// Extract method declarations (functions with receivers)
103fn extract_methods(
104    source: &str,
105    root: &tree_sitter::Node,
106    language: &tree_sitter::Language,
107) -> Result<Vec<SearchResult>> {
108    let query_str = r#"
109        (method_declaration
110            receiver: (parameter_list
111                (parameter_declaration
112                    type: [(type_identifier) (pointer_type (type_identifier))] @receiver_type))
113            name: (field_identifier) @method_name) @method
114    "#;
115
116    let query = Query::new(language, query_str).context("Failed to create method query")?;
117
118    let mut cursor = QueryCursor::new();
119    let mut matches = cursor.matches(&query, *root, source.as_bytes());
120
121    let mut symbols = Vec::new();
122
123    while let Some(match_) = matches.next() {
124        let mut receiver_type = None;
125        let mut method_name = None;
126        let mut method_node = None;
127
128        for capture in match_.captures {
129            let capture_name: &str = &query.capture_names()[capture.index as usize];
130            match capture_name {
131                "receiver_type" => {
132                    receiver_type = Some(
133                        capture
134                            .node
135                            .utf8_text(source.as_bytes())
136                            .unwrap_or("")
137                            .to_string(),
138                    );
139                }
140                "method_name" => {
141                    method_name = Some(
142                        capture
143                            .node
144                            .utf8_text(source.as_bytes())
145                            .unwrap_or("")
146                            .to_string(),
147                    );
148                }
149                "method" => {
150                    method_node = Some(capture.node);
151                }
152                _ => {}
153            }
154        }
155
156        if let (Some(receiver_type), Some(method_name), Some(node)) =
157            (receiver_type, method_name, method_node)
158        {
159            // Clean up receiver type (remove * if pointer)
160            let clean_receiver = receiver_type.trim_start_matches('*');
161            let scope = format!("type {}", clean_receiver);
162            let span = node_to_span(&node);
163            let preview = extract_preview(source, &span);
164
165            symbols.push(SearchResult::new(
166                String::new(),
167                Language::Go,
168                SymbolKind::Method,
169                Some(method_name),
170                span,
171                Some(scope),
172                preview,
173            ));
174        }
175    }
176
177    Ok(symbols)
178}
179
180/// Extract constant declarations
181fn extract_constants(
182    source: &str,
183    root: &tree_sitter::Node,
184    language: &tree_sitter::Language,
185) -> Result<Vec<SearchResult>> {
186    let query_str = r#"
187        (const_declaration
188            (const_spec
189                name: (identifier) @name)) @const
190    "#;
191
192    let query = Query::new(language, query_str).context("Failed to create const query")?;
193
194    extract_symbols(source, root, &query, SymbolKind::Constant, None)
195}
196
197/// Extract variable declarations (var and := short declarations)
198fn extract_variables(
199    source: &str,
200    root: &tree_sitter::Node,
201    language: &tree_sitter::Language,
202) -> Result<Vec<SearchResult>> {
203    // Match both var_spec and short_var_declaration to capture all variable declarations
204    let query_str = r#"
205        (var_spec
206            name: (identifier) @name) @var
207
208        (short_var_declaration
209            left: (expression_list (identifier) @name)) @short_var
210    "#;
211
212    let query = Query::new(language, query_str).context("Failed to create var query")?;
213
214    let mut cursor = QueryCursor::new();
215    let mut matches = cursor.matches(&query, *root, source.as_bytes());
216
217    let mut symbols = Vec::new();
218
219    while let Some(match_) = matches.next() {
220        let mut name = None;
221        let mut decl_node = None;
222
223        for capture in match_.captures {
224            let capture_name: &str = &query.capture_names()[capture.index as usize];
225            match capture_name {
226                "name" => {
227                    name = Some(
228                        capture
229                            .node
230                            .utf8_text(source.as_bytes())
231                            .unwrap_or("")
232                            .to_string(),
233                    );
234                }
235                "var" | "short_var" => {
236                    decl_node = Some(capture.node);
237                }
238                _ => {}
239            }
240        }
241
242        if let (Some(name), Some(node)) = (name, decl_node) {
243            let span = node_to_span(&node);
244            let preview = extract_preview(source, &span);
245
246            symbols.push(SearchResult::new(
247                String::new(),
248                Language::Go,
249                SymbolKind::Variable,
250                Some(name),
251                span,
252                None,
253                preview,
254            ));
255        }
256    }
257
258    Ok(symbols)
259}
260
261/// Generic symbol extraction helper
262fn extract_symbols(
263    source: &str,
264    root: &tree_sitter::Node,
265    query: &Query,
266    kind: SymbolKind,
267    scope: Option<String>,
268) -> Result<Vec<SearchResult>> {
269    let mut cursor = QueryCursor::new();
270    let mut matches = cursor.matches(query, *root, source.as_bytes());
271
272    let mut symbols = Vec::new();
273
274    while let Some(match_) = matches.next() {
275        // Find the name capture and the full node
276        let mut name = None;
277        let mut full_node = None;
278
279        for capture in match_.captures {
280            let capture_name: &str = &query.capture_names()[capture.index as usize];
281            if capture_name == "name" {
282                name = Some(
283                    capture
284                        .node
285                        .utf8_text(source.as_bytes())
286                        .unwrap_or("")
287                        .to_string(),
288                );
289            } else {
290                // Assume any other capture is the full node
291                full_node = Some(capture.node);
292            }
293        }
294
295        if let (Some(name), Some(node)) = (name, full_node) {
296            let span = node_to_span(&node);
297            let preview = extract_preview(source, &span);
298
299            symbols.push(SearchResult::new(
300                String::new(),
301                Language::Go,
302                kind.clone(),
303                Some(name),
304                span,
305                scope.clone(),
306                preview,
307            ));
308        }
309    }
310
311    Ok(symbols)
312}
313
314/// Convert a Tree-sitter node to a Span
315fn node_to_span(node: &tree_sitter::Node) -> Span {
316    let start = node.start_position();
317    let end = node.end_position();
318
319    Span::new(
320        start.row + 1, // Convert 0-indexed to 1-indexed
321        start.column,
322        end.row + 1,
323        end.column,
324    )
325}
326
327/// Extract a preview (7 lines) around the symbol
328fn extract_preview(source: &str, span: &Span) -> String {
329    let lines: Vec<&str> = source.lines().collect();
330
331    // Extract 7 lines: the start line and 6 following lines
332    let start_idx = (span.start_line - 1) as usize; // Convert back to 0-indexed
333    let end_idx = (start_idx + 7).min(lines.len());
334
335    lines[start_idx..end_idx].join("\n")
336}
337
#[cfg(test)]
mod tests {
    // Unit tests for symbol extraction, import extraction, and go.mod-based
    // import resolution in this module.
    use super::*;

    #[test]
    fn test_parse_function() {
        let source = r#"
package main

func helloWorld() string {
    return "Hello, world!"
}
        "#;

        let symbols = parse("test.go", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("helloWorld"));
        assert!(matches!(symbols[0].kind, SymbolKind::Function));
    }

    #[test]
    fn test_parse_struct() {
        let source = r#"
package main

type User struct {
    Name string
    Age  int
}
        "#;

        let symbols = parse("test.go", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("User"));
        assert!(matches!(symbols[0].kind, SymbolKind::Struct));
    }

    #[test]
    fn test_parse_interface() {
        let source = r#"
package main

type Reader interface {
    Read(p []byte) (n int, err error)
}
        "#;

        let symbols = parse("test.go", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("Reader"));
        assert!(matches!(symbols[0].kind, SymbolKind::Interface));
    }

    #[test]
    fn test_parse_method() {
        let source = r#"
package main

type User struct {
    Name string
}

func (u *User) GetName() string {
    return u.Name
}

func (u User) SetName(name string) {
    u.Name = name
}
        "#;

        let symbols = parse("test.go", source).unwrap();

        let method_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Method))
            .collect();

        assert_eq!(method_symbols.len(), 2);
        assert!(
            method_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("GetName"))
        );
        assert!(
            method_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("SetName"))
        );

        // Methods must not also be reported as plain functions.
        assert!(
            !symbols
                .iter()
                .any(|s| matches!(s.kind, SymbolKind::Function))
        );
    }

    #[test]
    fn test_parse_constants() {
        let source = r#"
package main

const MaxSize = 100
const DefaultTimeout = 30

const (
    StatusActive   = 1
    StatusInactive = 2
)
        "#;

        let symbols = parse("test.go", source).unwrap();

        let const_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Constant))
            .collect();

        assert_eq!(const_symbols.len(), 4);
        assert!(
            const_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("MaxSize"))
        );
        assert!(
            const_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("DefaultTimeout"))
        );
        assert!(
            const_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("StatusActive"))
        );
        assert!(
            const_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("StatusInactive"))
        );
    }

    #[test]
    fn test_parse_variables() {
        let source = r#"
package main

var GlobalConfig Config
var (
    Logger  *log.Logger
    Version = "1.0.0"
)
        "#;

        let symbols = parse("test.go", source).unwrap();

        let var_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Variable))
            .collect();

        assert_eq!(var_symbols.len(), 3);
        assert!(
            var_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("GlobalConfig"))
        );
        assert!(
            var_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("Logger"))
        );
        assert!(
            var_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("Version"))
        );
    }

    #[test]
    fn test_parse_mixed_symbols() {
        let source = r#"
package main

const DefaultPort = 8080

type Server struct {
    Port int
}

type Handler interface {
    Handle(req *Request) error
}

func (s *Server) Start() error {
    return nil
}

func NewServer(port int) *Server {
    return &Server{Port: port}
}

var globalServer *Server
        "#;

        let symbols = parse("test.go", source).unwrap();

        // Should find: const, struct, interface, method, function, var
        assert!(symbols.len() >= 6);

        let kinds: Vec<&SymbolKind> = symbols.iter().map(|s| &s.kind).collect();
        assert!(kinds.contains(&&SymbolKind::Constant));
        assert!(kinds.contains(&&SymbolKind::Struct));
        assert!(kinds.contains(&&SymbolKind::Interface));
        assert!(kinds.contains(&&SymbolKind::Method));
        assert!(kinds.contains(&&SymbolKind::Function));
        assert!(kinds.contains(&&SymbolKind::Variable));
    }

    #[test]
    fn test_parse_multiple_methods() {
        let source = r#"
package main

type Calculator struct{}

func (c *Calculator) Add(a, b int) int {
    return a + b
}

func (c *Calculator) Subtract(a, b int) int {
    return a - b
}

func (c *Calculator) Multiply(a, b int) int {
    return a * b
}
        "#;

        let symbols = parse("test.go", source).unwrap();

        let method_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Method))
            .collect();

        assert_eq!(method_symbols.len(), 3);
        assert!(
            method_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("Add"))
        );
        assert!(
            method_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("Subtract"))
        );
        assert!(
            method_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("Multiply"))
        );
    }

    #[test]
    fn test_parse_type_alias() {
        let source = r#"
package main

type UserID string
type Age int

type Config struct {
    Host string
    Port int
}
        "#;

        let symbols = parse("test.go", source).unwrap();

        // Should find the Config struct (type aliases UserID and Age are type_spec but not struct_type)
        let struct_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Struct))
            .collect();

        assert_eq!(struct_symbols.len(), 1);
        assert_eq!(struct_symbols[0].symbol.as_deref(), Some("Config"));
    }

    #[test]
    fn test_parse_embedded_interface() {
        let source = r#"
package main

type Reader interface {
    Read(p []byte) (n int, err error)
}

type Writer interface {
    Write(p []byte) (n int, err error)
}

type ReadWriter interface {
    Reader
    Writer
}
        "#;

        let symbols = parse("test.go", source).unwrap();

        let interface_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Interface))
            .collect();

        assert_eq!(interface_symbols.len(), 3);
        assert!(
            interface_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("Reader"))
        );
        assert!(
            interface_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("Writer"))
        );
        assert!(
            interface_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("ReadWriter"))
        );
    }

    #[test]
    fn test_local_variables_included() {
        let source = r#"
package main

var globalCount int = 10

func calculate(x int) int {
    localVar := x * 2
    var anotherLocal int = 5
    return localVar + anotherLocal
}
        "#;

        let symbols = parse("test.go", source).unwrap();

        let var_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Variable))
            .collect();

        // Should find globalCount, localVar (short declaration), and anotherLocal (var declaration)
        assert_eq!(var_symbols.len(), 3);
        assert!(
            var_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("globalCount"))
        );
        assert!(
            var_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("localVar"))
        );
        assert!(
            var_symbols
                .iter()
                .any(|s| s.symbol.as_deref() == Some("anotherLocal"))
        );
    }

    #[test]
    fn test_extract_go_imports() {
        let source = r#"package main

import (
	"fmt"
	"encoding/json"
	"github.com/gin-gonic/gin"
	"myproject/internal/models"
)

func main() {
	fmt.Println("Hello")
}
"#;

        let deps = GoDependencyExtractor::extract_dependencies(source).unwrap();

        assert_eq!(deps.len(), 4, "Should extract 4 import statements");
        assert!(deps.iter().any(|d| d.imported_path == "fmt"));
        assert!(deps.iter().any(|d| d.imported_path == "encoding/json"));
        assert!(
            deps.iter()
                .any(|d| d.imported_path == "github.com/gin-gonic/gin")
        );
        assert!(
            deps.iter()
                .any(|d| d.imported_path == "myproject/internal/models")
        );

        // Check stdlib classification
        let fmt_dep = deps.iter().find(|d| d.imported_path == "fmt").unwrap();
        assert!(
            matches!(fmt_dep.import_type, ImportType::Stdlib),
            "fmt should be classified as Stdlib"
        );

        let json_dep = deps
            .iter()
            .find(|d| d.imported_path == "encoding/json")
            .unwrap();
        assert!(
            matches!(json_dep.import_type, ImportType::Stdlib),
            "encoding/json should be classified as Stdlib"
        );

        // Check external classification
        let gin_dep = deps
            .iter()
            .find(|d| d.imported_path == "github.com/gin-gonic/gin")
            .unwrap();
        assert!(
            matches!(gin_dep.import_type, ImportType::External),
            "github.com/gin-gonic/gin should be classified as External"
        );

        // Check myproject classification (ambiguous but should be External)
        let models_dep = deps
            .iter()
            .find(|d| d.imported_path == "myproject/internal/models")
            .unwrap();
        assert!(
            matches!(models_dep.import_type, ImportType::External),
            "myproject/internal/models should be classified as External"
        );
    }

    #[test]
    fn test_extract_go_imports_with_comments() {
        // Real-world Go code from Kubernetes with inline comments
        let source = r#"package main

import (
	"os"
	_ "time/tzdata" // for timeZone support in CronJob

	"k8s.io/component-base/cli"
	_ "k8s.io/component-base/logs/json/register"          // for JSON log format registration
	_ "k8s.io/component-base/metrics/prometheus/clientgo" // load all the prometheus client-go plugins
)

func main() {
	os.Exit(0)
}
"#;

        let deps = GoDependencyExtractor::extract_dependencies(source).unwrap();

        println!("Extracted {} dependencies:", deps.len());
        for dep in &deps {
            println!("  - {} (line {})", dep.imported_path, dep.line_number);
        }

        // Should extract all imports, even those with _ alias and comments
        assert!(
            deps.len() >= 4,
            "Should extract at least 4 imports, got {}",
            deps.len()
        );
        assert!(deps.iter().any(|d| d.imported_path == "os"));
        assert!(deps.iter().any(|d| d.imported_path == "time/tzdata"));
        assert!(
            deps.iter()
                .any(|d| d.imported_path == "k8s.io/component-base/cli")
        );
    }

    #[test]
    fn test_find_all_go_mods() {
        use std::fs;
        use tempfile::TempDir;

        let temp = TempDir::new().unwrap();
        let root = temp.path();

        // Create multiple Go modules
        let service1 = root.join("services/auth");
        fs::create_dir_all(&service1).unwrap();
        fs::write(
            service1.join("go.mod"),
            "module github.com/myorg/auth\n\ngo 1.21\n",
        )
        .unwrap();

        let service2 = root.join("services/api");
        fs::create_dir_all(&service2).unwrap();
        fs::write(
            service2.join("go.mod"),
            "module github.com/myorg/api\n\ngo 1.21\n",
        )
        .unwrap();

        // Create vendor directory that should be skipped
        let vendor = root.join("vendor");
        fs::create_dir_all(&vendor).unwrap();
        fs::write(vendor.join("go.mod"), "module github.com/external/lib\n").unwrap();

        let mods = find_all_go_mods(root).unwrap();

        // Should find 2 modules (skipping vendor)
        assert_eq!(mods.len(), 2);
        assert!(mods.iter().any(|p| p.ends_with("services/auth/go.mod")));
        assert!(mods.iter().any(|p| p.ends_with("services/api/go.mod")));
    }

    #[test]
    fn test_parse_all_go_modules() {
        use std::fs;
        use tempfile::TempDir;

        let temp = TempDir::new().unwrap();
        let root = temp.path();

        // Create multiple Go modules
        let service1 = root.join("services/auth");
        fs::create_dir_all(&service1).unwrap();
        fs::write(
            service1.join("go.mod"),
            "module github.com/myorg/auth\n\ngo 1.21\n",
        )
        .unwrap();

        let service2 = root.join("cmd/api");
        fs::create_dir_all(&service2).unwrap();
        fs::write(
            service2.join("go.mod"),
            "module github.com/myorg/api\n\ngo 1.21\n",
        )
        .unwrap();

        let modules = parse_all_go_modules(root).unwrap();

        // Should find 2 modules
        assert_eq!(modules.len(), 2);

        // Check module names
        let names: Vec<_> = modules.iter().map(|m| m.name.as_str()).collect();
        assert!(names.contains(&"github.com/myorg/auth"));
        assert!(names.contains(&"github.com/myorg/api"));

        // Check project roots
        for module in &modules {
            assert!(
                module.project_root.starts_with("services/")
                    || module.project_root.starts_with("cmd/")
            );
            assert!(module.abs_project_root.ends_with(&module.project_root));
        }
    }

    #[test]
    fn test_resolve_go_import() {
        use std::fs;
        use tempfile::TempDir;

        let temp = TempDir::new().unwrap();
        let root = temp.path();

        // Create a Go module structure
        let myapp = root.join("myapp");
        fs::create_dir_all(myapp.join("pkg/models")).unwrap();
        fs::write(
            myapp.join("go.mod"),
            "module github.com/myorg/myapp\n\ngo 1.21\n",
        )
        .unwrap();

        let modules = parse_all_go_modules(root).unwrap();
        assert_eq!(modules.len(), 1);

        // Test sub-package import resolution
        // "github.com/myorg/myapp/pkg/models" → "myapp/pkg/models.go" or "myapp/pkg/models/models.go"
        let resolved =
            resolve_go_import_to_path("github.com/myorg/myapp/pkg/models", &modules, None);

        assert!(resolved.is_some());
        let path = resolved.unwrap();
        assert!(path.contains("myapp/pkg/models"));
        assert!(path.ends_with(".go"));
    }

    #[test]
    fn test_resolve_go_import_module_root() {
        use std::fs;
        use tempfile::TempDir;

        let temp = TempDir::new().unwrap();
        let root = temp.path();

        let myapp = root.join("cmd/server");
        fs::create_dir_all(&myapp).unwrap();
        fs::write(
            myapp.join("go.mod"),
            "module github.com/myorg/server\n\ngo 1.21\n",
        )
        .unwrap();

        let modules = parse_all_go_modules(root).unwrap();

        // Test module root import (no sub-package)
        let resolved = resolve_go_import_to_path("github.com/myorg/server", &modules, None);

        assert!(resolved.is_some());
        let path = resolved.unwrap();
        // Should try main.go or server.go
        assert!(path.contains("cmd/server"));
        assert!(path.ends_with(".go"));
    }

    #[test]
    fn test_resolve_go_import_not_found() {
        use std::fs;
        use tempfile::TempDir;

        let temp = TempDir::new().unwrap();
        let root = temp.path();

        let myapp = root.join("myapp");
        fs::create_dir_all(&myapp).unwrap();
        fs::write(
            myapp.join("go.mod"),
            "module github.com/myorg/myapp\n\ngo 1.21\n",
        )
        .unwrap();

        let modules = parse_all_go_modules(root).unwrap();

        // Try to resolve an import for a different module
        let resolved = resolve_go_import_to_path("github.com/other/package", &modules, None);

        // Should return None for modules not in the monorepo
        assert!(resolved.is_none());
    }

    #[test]
    fn test_resolve_go_import_relative() {
        let modules = vec![];

        // Relative imports are not supported yet
        let resolved =
            resolve_go_import_to_path("./utils", &modules, Some("myapp/pkg/api/handler.go"));

        assert!(resolved.is_none());
    }

    #[test]
    fn test_resolve_go_import_root_module_no_leading_slash() {
        // When go.mod is at the repo root, project_root is "" and paths must not
        // start with "/" (which would cause ambiguous fuzzy matches in the DB).
        use std::fs;
        use tempfile::TempDir;

        let temp = TempDir::new().unwrap();
        let root = temp.path();

        // go.mod at repo root → project_root = ""
        fs::write(root.join("go.mod"), "module k8s.io/kubernetes\n\ngo 1.21\n").unwrap();

        let modules = parse_all_go_modules(root).unwrap();
        assert_eq!(modules.len(), 1);
        assert_eq!(modules[0].project_root, "");

        // Sub-package import
        let resolved =
            resolve_go_import_to_path("k8s.io/kubernetes/test/internal/metric", &modules, None);
        assert!(resolved.is_some());
        let path = resolved.unwrap();
        assert!(
            !path.starts_with('/'),
            "path must not start with '/': {}",
            path
        );
        assert!(path.ends_with(".go"));
        assert!(path.contains("test/internal/metric"));

        // Module-root import
        let resolved = resolve_go_import_to_path("k8s.io/kubernetes", &modules, None);
        assert!(resolved.is_some());
        let path = resolved.unwrap();
        assert!(
            !path.starts_with('/'),
            "path must not start with '/': {}",
            path
        );
        assert!(path.ends_with(".go"));
    }
}
1037
1038// ============================================================================
1039// Dependency Extraction
1040// ============================================================================
1041
1042use crate::models::ImportType;
1043use crate::parsers::{DependencyExtractor, ImportInfo};
1044
1045/// Go dependency extractor
1046pub struct GoDependencyExtractor;
1047
1048impl DependencyExtractor for GoDependencyExtractor {
1049    fn extract_dependencies(source: &str) -> Result<Vec<ImportInfo>> {
1050        let mut parser = Parser::new();
1051        let language = tree_sitter_go::LANGUAGE;
1052
1053        parser
1054            .set_language(&language.into())
1055            .context("Failed to set Go language")?;
1056
1057        let tree = parser
1058            .parse(source, None)
1059            .context("Failed to parse Go source")?;
1060
1061        let root_node = tree.root_node();
1062
1063        let mut imports = Vec::new();
1064
1065        // Extract import statements
1066        imports.extend(extract_go_imports(source, &root_node)?);
1067
1068        Ok(imports)
1069    }
1070}
1071
1072/// Extract Go import statements
1073fn extract_go_imports(source: &str, root: &tree_sitter::Node) -> Result<Vec<ImportInfo>> {
1074    let language = tree_sitter_go::LANGUAGE;
1075
1076    // Go imports can be single or in groups
1077    let query_str = r#"
1078        (import_declaration
1079            (import_spec
1080                path: (interpreted_string_literal) @import_path)) @import
1081
1082        (import_declaration
1083            (import_spec_list
1084                (import_spec
1085                    path: (interpreted_string_literal) @import_path))) @import
1086    "#;
1087
1088    let query =
1089        Query::new(&language.into(), query_str).context("Failed to create Go import query")?;
1090
1091    let mut cursor = QueryCursor::new();
1092    let mut matches = cursor.matches(&query, *root, source.as_bytes());
1093
1094    let mut imports = Vec::new();
1095
1096    while let Some(match_) = matches.next() {
1097        let mut import_path = None;
1098        let mut import_node = None;
1099
1100        for capture in match_.captures {
1101            let capture_name: &str = &query.capture_names()[capture.index as usize];
1102            match capture_name {
1103                "import_path" => {
1104                    // Remove quotes from string literal
1105                    let raw_path = capture.node.utf8_text(source.as_bytes()).unwrap_or("");
1106                    import_path = Some(raw_path.trim_matches('"').to_string());
1107                }
1108                "import" => {
1109                    import_node = Some(capture.node);
1110                }
1111                _ => {}
1112            }
1113        }
1114
1115        if let (Some(path), Some(node)) = (import_path, import_node) {
1116            let import_type = classify_go_import(&path);
1117            let line_number = node.start_position().row + 1;
1118
1119            imports.push(ImportInfo {
1120                imported_path: path,
1121                import_type,
1122                line_number,
1123                imported_symbols: None, // Go imports entire packages, not selective symbols
1124            });
1125        }
1126    }
1127
1128    Ok(imports)
1129}
1130
/// Read `root/go.mod` and return the module path declared by its `module` directive.
///
/// Only the `go.mod` directly inside `root` is consulted; there is no upward or
/// recursive search. The first usable `module` directive wins. Accepted forms:
/// `module example.com/m`, any whitespace after the keyword (e.g. a tab), a quoted
/// path (`module "example.com/m"`), and a trailing `// comment`.
///
/// Returns `None` if `go.mod` is missing or unreadable, or if it has no `module`
/// directive with a non-empty path.
pub fn find_go_module_name(root: &std::path::Path) -> Option<String> {
    /// Return the module path if the trimmed `line` is a `module` directive.
    fn module_directive(line: &str) -> Option<&str> {
        let rest = line.strip_prefix("module")?;
        // The keyword must be followed by whitespace (`modulefoo` is not a directive).
        if !rest.starts_with(|c: char| c.is_whitespace()) {
            return None;
        }
        // Drop a trailing `// comment`, then the surrounding whitespace.
        let rest = rest.split("//").next().unwrap_or("").trim();
        // go.mod allows the path to be written as a quoted string.
        let path = ['"', '`']
            .iter()
            .find_map(|&quote| rest.strip_prefix(quote)?.strip_suffix(quote))
            .unwrap_or(rest)
            .trim();
        if path.is_empty() {
            None
        } else {
            Some(path)
        }
    }

    // A missing or unreadable go.mod simply means "no module name".
    let content = std::fs::read_to_string(root.join("go.mod")).ok()?;
    content
        .lines()
        .find_map(|line| module_directive(line.trim()))
        .map(str::to_owned)
}
1153
1154/// Reclassify a Go import based on the module prefix
1155/// This should be called by the indexer after extraction to correctly identify internal imports
1156pub fn reclassify_go_import(import_path: &str, module_prefix: Option<&str>) -> ImportType {
1157    classify_go_import_impl(import_path, module_prefix)
1158}
1159
1160/// Classify a Go import as internal, external, or stdlib
1161fn classify_go_import(import_path: &str) -> ImportType {
1162    classify_go_import_impl(import_path, None)
1163}
1164
1165/// Internal implementation of Go import classification
1166fn classify_go_import_impl(import_path: &str, module_prefix: Option<&str>) -> ImportType {
1167    // If we have a module prefix, check if import starts with it → Internal
1168    if let Some(prefix) = module_prefix {
1169        if import_path.starts_with(prefix) {
1170            return ImportType::Internal;
1171        }
1172        // Also check for multi-module repos - imports starting with k8s.io/* for Kubernetes
1173        // Extract the domain portion and check if it matches
1174        if let Some(import_domain) = import_path.split('/').next() {
1175            if let Some(module_domain) = prefix.split('/').next() {
1176                // If domains match (e.g., both start with k8s.io), consider it internal
1177                if import_domain == module_domain && module_domain.contains('.') {
1178                    return ImportType::Internal;
1179                }
1180            }
1181        }
1182    }
1183    // Relative imports (./ or ../) - rare in Go but possible
1184    if import_path.starts_with("./") || import_path.starts_with("../") {
1185        return ImportType::Internal;
1186    }
1187
1188    // Internal imports often start with company domain or project path
1189    // Check for common patterns like github.com/your-org/project
1190    // For now, we'll consider anything that looks like a full URL path as external
1191    // and short stdlib-like paths as stdlib
1192
1193    // Go standard library modules (common ones)
1194    const STDLIB_MODULES: &[&str] = &[
1195        "fmt",
1196        "io",
1197        "os",
1198        "path",
1199        "strings",
1200        "bytes",
1201        "bufio",
1202        "errors",
1203        "context",
1204        "sync",
1205        "time",
1206        "encoding/json",
1207        "encoding/xml",
1208        "encoding/csv",
1209        "net/http",
1210        "net/url",
1211        "net",
1212        "crypto",
1213        "crypto/tls",
1214        "crypto/sha256",
1215        "database/sql",
1216        "log",
1217        "math",
1218        "regexp",
1219        "strconv",
1220        "sort",
1221        "reflect",
1222        "runtime",
1223        "testing",
1224        "flag",
1225        "filepath",
1226        "unicode",
1227        "html",
1228        "text/template",
1229    ];
1230
1231    // Check if it's a stdlib module
1232    if STDLIB_MODULES.contains(&import_path) {
1233        return ImportType::Stdlib;
1234    }
1235
1236    // If it contains a domain (has dots and slashes), it's external
1237    if import_path.contains('/') && import_path.split('/').next().unwrap_or("").contains('.') {
1238        return ImportType::External;
1239    }
1240
1241    // Short paths without domains are likely stdlib
1242    if !import_path.contains('/') || import_path.split('/').count() <= 2 {
1243        return ImportType::Stdlib;
1244    }
1245
1246    // Everything else is external
1247    ImportType::External
1248}
1249
1250// ============================================================================
1251// Monorepo Support & Path Resolution
1252// ============================================================================
1253
/// A Go module found in the index, identified by its `go.mod` file.
#[derive(Clone, Debug)]
pub struct GoModule {
    /// Module path from the `module` directive, e.g. `"k8s.io/kubernetes"` or
    /// `"github.com/myorg/myproject"`.
    pub name: String,
    /// Directory holding `go.mod`, relative to the index root (e.g. `"services/api"`);
    /// empty when the module sits at the index root.
    pub project_root: String,
    /// Filesystem path of that same directory (absolute path to the project root).
    pub abs_project_root: std::path::PathBuf,
}
1264
1265/// Recursively find all go.mod files in the repository, respecting .gitignore
1266pub fn find_all_go_mods(index_root: &std::path::Path) -> Result<Vec<std::path::PathBuf>> {
1267    use ignore::WalkBuilder;
1268
1269    let mut go_mod_files = Vec::new();
1270
1271    let walker = WalkBuilder::new(index_root)
1272        .follow_links(false)
1273        .git_ignore(true)
1274        .build();
1275
1276    for entry in walker {
1277        let entry = entry?;
1278        let path = entry.path();
1279
1280        if !path.is_file() {
1281            continue;
1282        }
1283
1284        let filename = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
1285
1286        // Look for go.mod files
1287        if filename == "go.mod" {
1288            // Skip vendor directories
1289            let path_str = path.to_string_lossy();
1290            if path_str.contains("/vendor/") {
1291                log::trace!("Skipping go.mod in vendor directory: {:?}", path);
1292                continue;
1293            }
1294
1295            go_mod_files.push(path.to_path_buf());
1296        }
1297    }
1298
1299    log::debug!("Found {} go.mod files", go_mod_files.len());
1300    Ok(go_mod_files)
1301}
1302
1303/// Parse all Go modules in a monorepo and track their project roots
1304pub fn parse_all_go_modules(index_root: &std::path::Path) -> Result<Vec<GoModule>> {
1305    let go_mod_files = find_all_go_mods(index_root)?;
1306
1307    if go_mod_files.is_empty() {
1308        log::debug!("No go.mod files found in {:?}", index_root);
1309        return Ok(Vec::new());
1310    }
1311
1312    let mut modules = Vec::new();
1313    let mod_count = go_mod_files.len();
1314
1315    for go_mod_path in &go_mod_files {
1316        let project_root = go_mod_path
1317            .parent()
1318            .ok_or_else(|| anyhow::anyhow!("go.mod has no parent directory"))?;
1319
1320        // Read and parse go.mod to extract module name
1321        if let Ok(content) = std::fs::read_to_string(go_mod_path) {
1322            for line in content.lines() {
1323                let trimmed = line.trim();
1324                if trimmed.starts_with("module ") {
1325                    let module_name = trimmed["module ".len()..].trim().to_string();
1326
1327                    let relative_project_root = project_root
1328                        .strip_prefix(index_root)
1329                        .unwrap_or(project_root)
1330                        .to_string_lossy()
1331                        .to_string();
1332
1333                    log::debug!(
1334                        "Found Go module '{}' at {:?}",
1335                        module_name,
1336                        relative_project_root
1337                    );
1338
1339                    modules.push(GoModule {
1340                        name: module_name,
1341                        project_root: relative_project_root,
1342                        abs_project_root: project_root.to_path_buf(),
1343                    });
1344                    break;
1345                }
1346            }
1347        }
1348    }
1349
1350    log::info!(
1351        "Loaded {} Go modules from {} go.mod files",
1352        modules.len(),
1353        mod_count
1354    );
1355
1356    Ok(modules)
1357}
1358
1359/// Resolve a Go import to a file path
1360///
1361/// Handles:
1362/// - Internal imports: `mymodule/pkg/utils` → `pkg/utils.go` or `pkg/utils/utils.go`
1363/// - Sub-packages: `mymodule/internal/models` → `internal/models/models.go`
1364/// - Relative imports: `./utils` (rare in Go but possible)
1365pub fn resolve_go_import_to_path(
1366    import_path: &str,
1367    modules: &[GoModule],
1368    _current_file_path: Option<&str>,
1369) -> Option<String> {
1370    // Handle relative imports (rare in Go)
1371    if import_path.starts_with("./") || import_path.starts_with("../") {
1372        // Go relative imports are rare and complex - skip for now
1373        return None;
1374    }
1375
1376    // Find matching module
1377    for module in modules {
1378        if import_path.starts_with(&module.name) {
1379            // Strip module name to get sub-package path
1380            // "k8s.io/kubernetes/pkg/api" with module "k8s.io/kubernetes" → "pkg/api"
1381            let sub_path = import_path
1382                .strip_prefix(&module.name)
1383                .unwrap_or(import_path)
1384                .trim_start_matches('/');
1385
1386            if sub_path.is_empty() {
1387                // Importing the module root - could be multiple files
1388                // Try common patterns
1389                let basename = module.name.split('/').last().unwrap_or("main");
1390                let candidates = if module.project_root.is_empty() {
1391                    vec!["main.go".to_string(), format!("{}.go", basename)]
1392                } else {
1393                    vec![
1394                        format!("{}/main.go", module.project_root),
1395                        format!("{}/{}.go", module.project_root, basename),
1396                    ]
1397                };
1398
1399                for candidate in candidates {
1400                    log::trace!("Checking Go module root: {}", candidate);
1401                    return Some(candidate);
1402                }
1403            } else {
1404                // Sub-package import
1405                // Try both single file and package directory patterns
1406                let package_name = sub_path.split('/').last().unwrap_or(sub_path);
1407                let candidates = if module.project_root.is_empty() {
1408                    vec![
1409                        format!("{}.go", sub_path),
1410                        format!("{}/{}.go", sub_path, package_name),
1411                    ]
1412                } else {
1413                    vec![
1414                        format!("{}/{}.go", module.project_root, sub_path),
1415                        format!("{}/{}/{}.go", module.project_root, sub_path, package_name),
1416                    ]
1417                };
1418
1419                for candidate in candidates {
1420                    log::trace!("Checking Go package path: {}", candidate);
1421                    return Some(candidate);
1422                }
1423            }
1424        }
1425    }
1426
1427    None
1428}