1use crate::models::{Language, SearchResult, Span, SymbolKind};
12use anyhow::{Context, Result};
13use streaming_iterator::StreamingIterator;
14use tree_sitter::{Parser, Query, QueryCursor};
15
16pub fn parse(path: &str, source: &str) -> Result<Vec<SearchResult>> {
18 let mut parser = Parser::new();
19 let language = tree_sitter_go::LANGUAGE;
20
21 parser
22 .set_language(&language.into())
23 .context("Failed to set Go language")?;
24
25 let tree = parser
26 .parse(source, None)
27 .context("Failed to parse Go source")?;
28
29 let root_node = tree.root_node();
30
31 let mut symbols = Vec::new();
32
33 symbols.extend(extract_functions(source, &root_node, &language.into())?);
35 symbols.extend(extract_types(source, &root_node, &language.into())?);
36 symbols.extend(extract_interfaces(source, &root_node, &language.into())?);
37 symbols.extend(extract_methods(source, &root_node, &language.into())?);
38 symbols.extend(extract_constants(source, &root_node, &language.into())?);
39 symbols.extend(extract_variables(source, &root_node, &language.into())?);
40
41 for symbol in &mut symbols {
43 symbol.path = path.to_string();
44 symbol.lang = Language::Go;
45 }
46
47 Ok(symbols)
48}
49
50fn extract_functions(
52 source: &str,
53 root: &tree_sitter::Node,
54 language: &tree_sitter::Language,
55) -> Result<Vec<SearchResult>> {
56 let query_str = r#"
57 (function_declaration
58 name: (identifier) @name) @function
59 "#;
60
61 let query = Query::new(language, query_str).context("Failed to create function query")?;
62
63 extract_symbols(source, root, &query, SymbolKind::Function, None)
64}
65
66fn extract_types(
68 source: &str,
69 root: &tree_sitter::Node,
70 language: &tree_sitter::Language,
71) -> Result<Vec<SearchResult>> {
72 let query_str = r#"
73 (type_declaration
74 (type_spec
75 name: (type_identifier) @name
76 type: (struct_type))) @struct
77 "#;
78
79 let query = Query::new(language, query_str).context("Failed to create struct query")?;
80
81 extract_symbols(source, root, &query, SymbolKind::Struct, None)
82}
83
84fn extract_interfaces(
86 source: &str,
87 root: &tree_sitter::Node,
88 language: &tree_sitter::Language,
89) -> Result<Vec<SearchResult>> {
90 let query_str = r#"
91 (type_declaration
92 (type_spec
93 name: (type_identifier) @name
94 type: (interface_type))) @interface
95 "#;
96
97 let query = Query::new(language, query_str).context("Failed to create interface query")?;
98
99 extract_symbols(source, root, &query, SymbolKind::Interface, None)
100}
101
102fn extract_methods(
104 source: &str,
105 root: &tree_sitter::Node,
106 language: &tree_sitter::Language,
107) -> Result<Vec<SearchResult>> {
108 let query_str = r#"
109 (method_declaration
110 receiver: (parameter_list
111 (parameter_declaration
112 type: [(type_identifier) (pointer_type (type_identifier))] @receiver_type))
113 name: (field_identifier) @method_name) @method
114 "#;
115
116 let query = Query::new(language, query_str).context("Failed to create method query")?;
117
118 let mut cursor = QueryCursor::new();
119 let mut matches = cursor.matches(&query, *root, source.as_bytes());
120
121 let mut symbols = Vec::new();
122
123 while let Some(match_) = matches.next() {
124 let mut receiver_type = None;
125 let mut method_name = None;
126 let mut method_node = None;
127
128 for capture in match_.captures {
129 let capture_name: &str = &query.capture_names()[capture.index as usize];
130 match capture_name {
131 "receiver_type" => {
132 receiver_type = Some(
133 capture
134 .node
135 .utf8_text(source.as_bytes())
136 .unwrap_or("")
137 .to_string(),
138 );
139 }
140 "method_name" => {
141 method_name = Some(
142 capture
143 .node
144 .utf8_text(source.as_bytes())
145 .unwrap_or("")
146 .to_string(),
147 );
148 }
149 "method" => {
150 method_node = Some(capture.node);
151 }
152 _ => {}
153 }
154 }
155
156 if let (Some(receiver_type), Some(method_name), Some(node)) =
157 (receiver_type, method_name, method_node)
158 {
159 let clean_receiver = receiver_type.trim_start_matches('*');
161 let scope = format!("type {}", clean_receiver);
162 let span = node_to_span(&node);
163 let preview = extract_preview(source, &span);
164
165 symbols.push(SearchResult::new(
166 String::new(),
167 Language::Go,
168 SymbolKind::Method,
169 Some(method_name),
170 span,
171 Some(scope),
172 preview,
173 ));
174 }
175 }
176
177 Ok(symbols)
178}
179
180fn extract_constants(
182 source: &str,
183 root: &tree_sitter::Node,
184 language: &tree_sitter::Language,
185) -> Result<Vec<SearchResult>> {
186 let query_str = r#"
187 (const_declaration
188 (const_spec
189 name: (identifier) @name)) @const
190 "#;
191
192 let query = Query::new(language, query_str).context("Failed to create const query")?;
193
194 extract_symbols(source, root, &query, SymbolKind::Constant, None)
195}
196
197fn extract_variables(
199 source: &str,
200 root: &tree_sitter::Node,
201 language: &tree_sitter::Language,
202) -> Result<Vec<SearchResult>> {
203 let query_str = r#"
205 (var_spec
206 name: (identifier) @name) @var
207
208 (short_var_declaration
209 left: (expression_list (identifier) @name)) @short_var
210 "#;
211
212 let query = Query::new(language, query_str).context("Failed to create var query")?;
213
214 let mut cursor = QueryCursor::new();
215 let mut matches = cursor.matches(&query, *root, source.as_bytes());
216
217 let mut symbols = Vec::new();
218
219 while let Some(match_) = matches.next() {
220 let mut name = None;
221 let mut decl_node = None;
222
223 for capture in match_.captures {
224 let capture_name: &str = &query.capture_names()[capture.index as usize];
225 match capture_name {
226 "name" => {
227 name = Some(
228 capture
229 .node
230 .utf8_text(source.as_bytes())
231 .unwrap_or("")
232 .to_string(),
233 );
234 }
235 "var" | "short_var" => {
236 decl_node = Some(capture.node);
237 }
238 _ => {}
239 }
240 }
241
242 if let (Some(name), Some(node)) = (name, decl_node) {
243 let span = node_to_span(&node);
244 let preview = extract_preview(source, &span);
245
246 symbols.push(SearchResult::new(
247 String::new(),
248 Language::Go,
249 SymbolKind::Variable,
250 Some(name),
251 span,
252 None,
253 preview,
254 ));
255 }
256 }
257
258 Ok(symbols)
259}
260
261fn extract_symbols(
263 source: &str,
264 root: &tree_sitter::Node,
265 query: &Query,
266 kind: SymbolKind,
267 scope: Option<String>,
268) -> Result<Vec<SearchResult>> {
269 let mut cursor = QueryCursor::new();
270 let mut matches = cursor.matches(query, *root, source.as_bytes());
271
272 let mut symbols = Vec::new();
273
274 while let Some(match_) = matches.next() {
275 let mut name = None;
277 let mut full_node = None;
278
279 for capture in match_.captures {
280 let capture_name: &str = &query.capture_names()[capture.index as usize];
281 if capture_name == "name" {
282 name = Some(
283 capture
284 .node
285 .utf8_text(source.as_bytes())
286 .unwrap_or("")
287 .to_string(),
288 );
289 } else {
290 full_node = Some(capture.node);
292 }
293 }
294
295 if let (Some(name), Some(node)) = (name, full_node) {
296 let span = node_to_span(&node);
297 let preview = extract_preview(source, &span);
298
299 symbols.push(SearchResult::new(
300 String::new(),
301 Language::Go,
302 kind.clone(),
303 Some(name),
304 span,
305 scope.clone(),
306 preview,
307 ));
308 }
309 }
310
311 Ok(symbols)
312}
313
314fn node_to_span(node: &tree_sitter::Node) -> Span {
316 let start = node.start_position();
317 let end = node.end_position();
318
319 Span::new(
320 start.row + 1, start.column,
322 end.row + 1,
323 end.column,
324 )
325}
326
327fn extract_preview(source: &str, span: &Span) -> String {
329 let lines: Vec<&str> = source.lines().collect();
330
331 let start_idx = (span.start_line - 1) as usize; let end_idx = (start_idx + 7).min(lines.len());
334
335 lines[start_idx..end_idx].join("\n")
336}
337
338#[cfg(test)]
339mod tests {
340 use super::*;
341
342 #[test]
343 fn test_parse_function() {
344 let source = r#"
345package main
346
347func helloWorld() string {
348 return "Hello, world!"
349}
350 "#;
351
352 let symbols = parse("test.go", source).unwrap();
353 assert_eq!(symbols.len(), 1);
354 assert_eq!(symbols[0].symbol.as_deref(), Some("helloWorld"));
355 assert!(matches!(symbols[0].kind, SymbolKind::Function));
356 }
357
358 #[test]
359 fn test_parse_struct() {
360 let source = r#"
361package main
362
363type User struct {
364 Name string
365 Age int
366}
367 "#;
368
369 let symbols = parse("test.go", source).unwrap();
370 assert_eq!(symbols.len(), 1);
371 assert_eq!(symbols[0].symbol.as_deref(), Some("User"));
372 assert!(matches!(symbols[0].kind, SymbolKind::Struct));
373 }
374
375 #[test]
376 fn test_parse_interface() {
377 let source = r#"
378package main
379
380type Reader interface {
381 Read(p []byte) (n int, err error)
382}
383 "#;
384
385 let symbols = parse("test.go", source).unwrap();
386 assert_eq!(symbols.len(), 1);
387 assert_eq!(symbols[0].symbol.as_deref(), Some("Reader"));
388 assert!(matches!(symbols[0].kind, SymbolKind::Interface));
389 }
390
391 #[test]
392 fn test_parse_method() {
393 let source = r#"
394package main
395
396type User struct {
397 Name string
398}
399
400func (u *User) GetName() string {
401 return u.Name
402}
403
404func (u User) SetName(name string) {
405 u.Name = name
406}
407 "#;
408
409 let symbols = parse("test.go", source).unwrap();
410
411 let method_symbols: Vec<_> = symbols
412 .iter()
413 .filter(|s| matches!(s.kind, SymbolKind::Method))
414 .collect();
415
416 assert_eq!(method_symbols.len(), 2);
417 assert!(
418 method_symbols
419 .iter()
420 .any(|s| s.symbol.as_deref() == Some("GetName"))
421 );
422 assert!(
423 method_symbols
424 .iter()
425 .any(|s| s.symbol.as_deref() == Some("SetName"))
426 );
427
428 for method in method_symbols {
430 }
432 }
433
434 #[test]
435 fn test_parse_constants() {
436 let source = r#"
437package main
438
439const MaxSize = 100
440const DefaultTimeout = 30
441
442const (
443 StatusActive = 1
444 StatusInactive = 2
445)
446 "#;
447
448 let symbols = parse("test.go", source).unwrap();
449
450 let const_symbols: Vec<_> = symbols
451 .iter()
452 .filter(|s| matches!(s.kind, SymbolKind::Constant))
453 .collect();
454
455 assert_eq!(const_symbols.len(), 4);
456 assert!(
457 const_symbols
458 .iter()
459 .any(|s| s.symbol.as_deref() == Some("MaxSize"))
460 );
461 assert!(
462 const_symbols
463 .iter()
464 .any(|s| s.symbol.as_deref() == Some("DefaultTimeout"))
465 );
466 assert!(
467 const_symbols
468 .iter()
469 .any(|s| s.symbol.as_deref() == Some("StatusActive"))
470 );
471 assert!(
472 const_symbols
473 .iter()
474 .any(|s| s.symbol.as_deref() == Some("StatusInactive"))
475 );
476 }
477
478 #[test]
479 fn test_parse_variables() {
480 let source = r#"
481package main
482
483var GlobalConfig Config
484var (
485 Logger *log.Logger
486 Version = "1.0.0"
487)
488 "#;
489
490 let symbols = parse("test.go", source).unwrap();
491
492 let var_symbols: Vec<_> = symbols
493 .iter()
494 .filter(|s| matches!(s.kind, SymbolKind::Variable))
495 .collect();
496
497 assert_eq!(var_symbols.len(), 3);
498 assert!(
499 var_symbols
500 .iter()
501 .any(|s| s.symbol.as_deref() == Some("GlobalConfig"))
502 );
503 assert!(
504 var_symbols
505 .iter()
506 .any(|s| s.symbol.as_deref() == Some("Logger"))
507 );
508 assert!(
509 var_symbols
510 .iter()
511 .any(|s| s.symbol.as_deref() == Some("Version"))
512 );
513 }
514
515 #[test]
516 fn test_parse_mixed_symbols() {
517 let source = r#"
518package main
519
520const DefaultPort = 8080
521
522type Server struct {
523 Port int
524}
525
526type Handler interface {
527 Handle(req *Request) error
528}
529
530func (s *Server) Start() error {
531 return nil
532}
533
534func NewServer(port int) *Server {
535 return &Server{Port: port}
536}
537
538var globalServer *Server
539 "#;
540
541 let symbols = parse("test.go", source).unwrap();
542
543 assert!(symbols.len() >= 6);
545
546 let kinds: Vec<&SymbolKind> = symbols.iter().map(|s| &s.kind).collect();
547 assert!(kinds.contains(&&SymbolKind::Constant));
548 assert!(kinds.contains(&&SymbolKind::Struct));
549 assert!(kinds.contains(&&SymbolKind::Interface));
550 assert!(kinds.contains(&&SymbolKind::Method));
551 assert!(kinds.contains(&&SymbolKind::Function));
552 assert!(kinds.contains(&&SymbolKind::Variable));
553 }
554
555 #[test]
556 fn test_parse_multiple_methods() {
557 let source = r#"
558package main
559
560type Calculator struct{}
561
562func (c *Calculator) Add(a, b int) int {
563 return a + b
564}
565
566func (c *Calculator) Subtract(a, b int) int {
567 return a - b
568}
569
570func (c *Calculator) Multiply(a, b int) int {
571 return a * b
572}
573 "#;
574
575 let symbols = parse("test.go", source).unwrap();
576
577 let method_symbols: Vec<_> = symbols
578 .iter()
579 .filter(|s| matches!(s.kind, SymbolKind::Method))
580 .collect();
581
582 assert_eq!(method_symbols.len(), 3);
583 assert!(
584 method_symbols
585 .iter()
586 .any(|s| s.symbol.as_deref() == Some("Add"))
587 );
588 assert!(
589 method_symbols
590 .iter()
591 .any(|s| s.symbol.as_deref() == Some("Subtract"))
592 );
593 assert!(
594 method_symbols
595 .iter()
596 .any(|s| s.symbol.as_deref() == Some("Multiply"))
597 );
598 }
599
600 #[test]
601 fn test_parse_type_alias() {
602 let source = r#"
603package main
604
605type UserID string
606type Age int
607
608type Config struct {
609 Host string
610 Port int
611}
612 "#;
613
614 let symbols = parse("test.go", source).unwrap();
615
616 let struct_symbols: Vec<_> = symbols
618 .iter()
619 .filter(|s| matches!(s.kind, SymbolKind::Struct))
620 .collect();
621
622 assert_eq!(struct_symbols.len(), 1);
623 assert_eq!(struct_symbols[0].symbol.as_deref(), Some("Config"));
624 }
625
626 #[test]
627 fn test_parse_embedded_interface() {
628 let source = r#"
629package main
630
631type Reader interface {
632 Read(p []byte) (n int, err error)
633}
634
635type Writer interface {
636 Write(p []byte) (n int, err error)
637}
638
639type ReadWriter interface {
640 Reader
641 Writer
642}
643 "#;
644
645 let symbols = parse("test.go", source).unwrap();
646
647 let interface_symbols: Vec<_> = symbols
648 .iter()
649 .filter(|s| matches!(s.kind, SymbolKind::Interface))
650 .collect();
651
652 assert_eq!(interface_symbols.len(), 3);
653 assert!(
654 interface_symbols
655 .iter()
656 .any(|s| s.symbol.as_deref() == Some("Reader"))
657 );
658 assert!(
659 interface_symbols
660 .iter()
661 .any(|s| s.symbol.as_deref() == Some("Writer"))
662 );
663 assert!(
664 interface_symbols
665 .iter()
666 .any(|s| s.symbol.as_deref() == Some("ReadWriter"))
667 );
668 }
669
670 #[test]
671 fn test_local_variables_included() {
672 let source = r#"
673package main
674
675var globalCount int = 10
676
677func calculate(x int) int {
678 localVar := x * 2
679 var anotherLocal int = 5
680 return localVar + anotherLocal
681}
682 "#;
683
684 let symbols = parse("test.go", source).unwrap();
685
686 let var_symbols: Vec<_> = symbols
687 .iter()
688 .filter(|s| matches!(s.kind, SymbolKind::Variable))
689 .collect();
690
691 assert_eq!(var_symbols.len(), 3);
693 assert!(
694 var_symbols
695 .iter()
696 .any(|s| s.symbol.as_deref() == Some("globalCount"))
697 );
698 assert!(
699 var_symbols
700 .iter()
701 .any(|s| s.symbol.as_deref() == Some("localVar"))
702 );
703 assert!(
704 var_symbols
705 .iter()
706 .any(|s| s.symbol.as_deref() == Some("anotherLocal"))
707 );
708 }
709
710 #[test]
711 fn test_extract_go_imports() {
712 let source = r#"package main
713
714import (
715 "fmt"
716 "encoding/json"
717 "github.com/gin-gonic/gin"
718 "myproject/internal/models"
719)
720
721func main() {
722 fmt.Println("Hello")
723}
724"#;
725
726 let deps = GoDependencyExtractor::extract_dependencies(source).unwrap();
727
728 assert_eq!(deps.len(), 4, "Should extract 4 import statements");
729 assert!(deps.iter().any(|d| d.imported_path == "fmt"));
730 assert!(deps.iter().any(|d| d.imported_path == "encoding/json"));
731 assert!(
732 deps.iter()
733 .any(|d| d.imported_path == "github.com/gin-gonic/gin")
734 );
735 assert!(
736 deps.iter()
737 .any(|d| d.imported_path == "myproject/internal/models")
738 );
739
740 let fmt_dep = deps.iter().find(|d| d.imported_path == "fmt").unwrap();
742 assert!(
743 matches!(fmt_dep.import_type, ImportType::Stdlib),
744 "fmt should be classified as Stdlib"
745 );
746
747 let json_dep = deps
748 .iter()
749 .find(|d| d.imported_path == "encoding/json")
750 .unwrap();
751 assert!(
752 matches!(json_dep.import_type, ImportType::Stdlib),
753 "encoding/json should be classified as Stdlib"
754 );
755
756 let gin_dep = deps
758 .iter()
759 .find(|d| d.imported_path == "github.com/gin-gonic/gin")
760 .unwrap();
761 assert!(
762 matches!(gin_dep.import_type, ImportType::External),
763 "github.com/gin-gonic/gin should be classified as External"
764 );
765
766 let models_dep = deps
768 .iter()
769 .find(|d| d.imported_path == "myproject/internal/models")
770 .unwrap();
771 assert!(
772 matches!(models_dep.import_type, ImportType::External),
773 "myproject/internal/models should be classified as External"
774 );
775 }
776
777 #[test]
778 fn test_extract_go_imports_with_comments() {
779 let source = r#"package main
781
782import (
783 "os"
784 _ "time/tzdata" // for timeZone support in CronJob
785
786 "k8s.io/component-base/cli"
787 _ "k8s.io/component-base/logs/json/register" // for JSON log format registration
788 _ "k8s.io/component-base/metrics/prometheus/clientgo" // load all the prometheus client-go plugins
789)
790
791func main() {
792 os.Exit(0)
793}
794"#;
795
796 let deps = GoDependencyExtractor::extract_dependencies(source).unwrap();
797
798 println!("Extracted {} dependencies:", deps.len());
799 for dep in &deps {
800 println!(" - {} (line {})", dep.imported_path, dep.line_number);
801 }
802
803 assert!(
805 deps.len() >= 4,
806 "Should extract at least 4 imports, got {}",
807 deps.len()
808 );
809 assert!(deps.iter().any(|d| d.imported_path == "os"));
810 assert!(deps.iter().any(|d| d.imported_path == "time/tzdata"));
811 assert!(
812 deps.iter()
813 .any(|d| d.imported_path == "k8s.io/component-base/cli")
814 );
815 }
816
817 #[test]
818 fn test_find_all_go_mods() {
819 use std::fs;
820 use tempfile::TempDir;
821
822 let temp = TempDir::new().unwrap();
823 let root = temp.path();
824
825 let service1 = root.join("services/auth");
827 fs::create_dir_all(&service1).unwrap();
828 fs::write(
829 service1.join("go.mod"),
830 "module github.com/myorg/auth\n\ngo 1.21\n",
831 )
832 .unwrap();
833
834 let service2 = root.join("services/api");
835 fs::create_dir_all(&service2).unwrap();
836 fs::write(
837 service2.join("go.mod"),
838 "module github.com/myorg/api\n\ngo 1.21\n",
839 )
840 .unwrap();
841
842 let vendor = root.join("vendor");
844 fs::create_dir_all(&vendor).unwrap();
845 fs::write(vendor.join("go.mod"), "module github.com/external/lib\n").unwrap();
846
847 let mods = find_all_go_mods(root).unwrap();
848
849 assert_eq!(mods.len(), 2);
851 assert!(mods.iter().any(|p| p.ends_with("services/auth/go.mod")));
852 assert!(mods.iter().any(|p| p.ends_with("services/api/go.mod")));
853 }
854
855 #[test]
856 fn test_parse_all_go_modules() {
857 use std::fs;
858 use tempfile::TempDir;
859
860 let temp = TempDir::new().unwrap();
861 let root = temp.path();
862
863 let service1 = root.join("services/auth");
865 fs::create_dir_all(&service1).unwrap();
866 fs::write(
867 service1.join("go.mod"),
868 "module github.com/myorg/auth\n\ngo 1.21\n",
869 )
870 .unwrap();
871
872 let service2 = root.join("cmd/api");
873 fs::create_dir_all(&service2).unwrap();
874 fs::write(
875 service2.join("go.mod"),
876 "module github.com/myorg/api\n\ngo 1.21\n",
877 )
878 .unwrap();
879
880 let modules = parse_all_go_modules(root).unwrap();
881
882 assert_eq!(modules.len(), 2);
884
885 let names: Vec<_> = modules.iter().map(|m| m.name.as_str()).collect();
887 assert!(names.contains(&"github.com/myorg/auth"));
888 assert!(names.contains(&"github.com/myorg/api"));
889
890 for module in &modules {
892 assert!(
893 module.project_root.starts_with("services/")
894 || module.project_root.starts_with("cmd/")
895 );
896 assert!(module.abs_project_root.ends_with(&module.project_root));
897 }
898 }
899
900 #[test]
901 fn test_resolve_go_import() {
902 use std::fs;
903 use tempfile::TempDir;
904
905 let temp = TempDir::new().unwrap();
906 let root = temp.path();
907
908 let myapp = root.join("myapp");
910 fs::create_dir_all(myapp.join("pkg/models")).unwrap();
911 fs::write(
912 myapp.join("go.mod"),
913 "module github.com/myorg/myapp\n\ngo 1.21\n",
914 )
915 .unwrap();
916
917 let modules = parse_all_go_modules(root).unwrap();
918 assert_eq!(modules.len(), 1);
919
920 let resolved =
923 resolve_go_import_to_path("github.com/myorg/myapp/pkg/models", &modules, None);
924
925 assert!(resolved.is_some());
926 let path = resolved.unwrap();
927 assert!(path.contains("myapp/pkg/models"));
928 assert!(path.ends_with(".go"));
929 }
930
931 #[test]
932 fn test_resolve_go_import_module_root() {
933 use std::fs;
934 use tempfile::TempDir;
935
936 let temp = TempDir::new().unwrap();
937 let root = temp.path();
938
939 let myapp = root.join("cmd/server");
940 fs::create_dir_all(&myapp).unwrap();
941 fs::write(
942 myapp.join("go.mod"),
943 "module github.com/myorg/server\n\ngo 1.21\n",
944 )
945 .unwrap();
946
947 let modules = parse_all_go_modules(root).unwrap();
948
949 let resolved = resolve_go_import_to_path("github.com/myorg/server", &modules, None);
951
952 assert!(resolved.is_some());
953 let path = resolved.unwrap();
954 assert!(path.contains("cmd/server"));
956 assert!(path.ends_with(".go"));
957 }
958
959 #[test]
960 fn test_resolve_go_import_not_found() {
961 use std::fs;
962 use tempfile::TempDir;
963
964 let temp = TempDir::new().unwrap();
965 let root = temp.path();
966
967 let myapp = root.join("myapp");
968 fs::create_dir_all(&myapp).unwrap();
969 fs::write(
970 myapp.join("go.mod"),
971 "module github.com/myorg/myapp\n\ngo 1.21\n",
972 )
973 .unwrap();
974
975 let modules = parse_all_go_modules(root).unwrap();
976
977 let resolved = resolve_go_import_to_path("github.com/other/package", &modules, None);
979
980 assert!(resolved.is_none());
982 }
983
984 #[test]
985 fn test_resolve_go_import_relative() {
986 let modules = vec![];
987
988 let resolved =
990 resolve_go_import_to_path("./utils", &modules, Some("myapp/pkg/api/handler.go"));
991
992 assert!(resolved.is_none());
993 }
994
995 #[test]
996 fn test_resolve_go_import_root_module_no_leading_slash() {
997 use std::fs;
1000 use tempfile::TempDir;
1001
1002 let temp = TempDir::new().unwrap();
1003 let root = temp.path();
1004
1005 fs::write(root.join("go.mod"), "module k8s.io/kubernetes\n\ngo 1.21\n").unwrap();
1007
1008 let modules = parse_all_go_modules(root).unwrap();
1009 assert_eq!(modules.len(), 1);
1010 assert_eq!(modules[0].project_root, "");
1011
1012 let resolved =
1014 resolve_go_import_to_path("k8s.io/kubernetes/test/internal/metric", &modules, None);
1015 assert!(resolved.is_some());
1016 let path = resolved.unwrap();
1017 assert!(
1018 !path.starts_with('/'),
1019 "path must not start with '/': {}",
1020 path
1021 );
1022 assert!(path.ends_with(".go"));
1023 assert!(path.contains("test/internal/metric"));
1024
1025 let resolved = resolve_go_import_to_path("k8s.io/kubernetes", &modules, None);
1027 assert!(resolved.is_some());
1028 let path = resolved.unwrap();
1029 assert!(
1030 !path.starts_with('/'),
1031 "path must not start with '/': {}",
1032 path
1033 );
1034 assert!(path.ends_with(".go"));
1035 }
1036}
1037
1038use crate::models::ImportType;
1043use crate::parsers::{DependencyExtractor, ImportInfo};
1044
1045pub struct GoDependencyExtractor;
1047
1048impl DependencyExtractor for GoDependencyExtractor {
1049 fn extract_dependencies(source: &str) -> Result<Vec<ImportInfo>> {
1050 let mut parser = Parser::new();
1051 let language = tree_sitter_go::LANGUAGE;
1052
1053 parser
1054 .set_language(&language.into())
1055 .context("Failed to set Go language")?;
1056
1057 let tree = parser
1058 .parse(source, None)
1059 .context("Failed to parse Go source")?;
1060
1061 let root_node = tree.root_node();
1062
1063 let mut imports = Vec::new();
1064
1065 imports.extend(extract_go_imports(source, &root_node)?);
1067
1068 Ok(imports)
1069 }
1070}
1071
1072fn extract_go_imports(source: &str, root: &tree_sitter::Node) -> Result<Vec<ImportInfo>> {
1074 let language = tree_sitter_go::LANGUAGE;
1075
1076 let query_str = r#"
1078 (import_declaration
1079 (import_spec
1080 path: (interpreted_string_literal) @import_path)) @import
1081
1082 (import_declaration
1083 (import_spec_list
1084 (import_spec
1085 path: (interpreted_string_literal) @import_path))) @import
1086 "#;
1087
1088 let query =
1089 Query::new(&language.into(), query_str).context("Failed to create Go import query")?;
1090
1091 let mut cursor = QueryCursor::new();
1092 let mut matches = cursor.matches(&query, *root, source.as_bytes());
1093
1094 let mut imports = Vec::new();
1095
1096 while let Some(match_) = matches.next() {
1097 let mut import_path = None;
1098 let mut import_node = None;
1099
1100 for capture in match_.captures {
1101 let capture_name: &str = &query.capture_names()[capture.index as usize];
1102 match capture_name {
1103 "import_path" => {
1104 let raw_path = capture.node.utf8_text(source.as_bytes()).unwrap_or("");
1106 import_path = Some(raw_path.trim_matches('"').to_string());
1107 }
1108 "import" => {
1109 import_node = Some(capture.node);
1110 }
1111 _ => {}
1112 }
1113 }
1114
1115 if let (Some(path), Some(node)) = (import_path, import_node) {
1116 let import_type = classify_go_import(&path);
1117 let line_number = node.start_position().row + 1;
1118
1119 imports.push(ImportInfo {
1120 imported_path: path,
1121 import_type,
1122 line_number,
1123 imported_symbols: None, });
1125 }
1126 }
1127
1128 Ok(imports)
1129}
1130
/// Reads `<root>/go.mod` and returns the module path from its `module` directive.
///
/// Returns `None` when the file is missing or unreadable, or when no usable
/// `module` directive is found.
///
/// The directive is tokenised on whitespace, so `module<TAB>path` is accepted.
/// A trailing `// comment` is ignored, and surrounding double quotes or
/// backticks are stripped from the path. The parenthesised block form
/// (`module (` on its own line) is not supported and yields `None`.
pub fn find_go_module_name(root: &std::path::Path) -> Option<String> {
    // No separate `exists()` check: a failed read already covers a missing file.
    let content = std::fs::read_to_string(root.join("go.mod")).ok()?;

    content.lines().find_map(|line| {
        // Everything after `//` is a comment in go.mod.
        let directive = line.split("//").next().unwrap_or(line);
        let mut tokens = directive.split_whitespace();

        if tokens.next()? != "module" {
            return None;
        }

        let name = tokens.next()?.trim_matches(|c| c == '"' || c == '`');
        (!name.is_empty() && name != "(").then(|| name.to_string())
    })
}
1153
1154pub fn reclassify_go_import(import_path: &str, module_prefix: Option<&str>) -> ImportType {
1157 classify_go_import_impl(import_path, module_prefix)
1158}
1159
1160fn classify_go_import(import_path: &str) -> ImportType {
1162 classify_go_import_impl(import_path, None)
1163}
1164
1165fn classify_go_import_impl(import_path: &str, module_prefix: Option<&str>) -> ImportType {
1167 if let Some(prefix) = module_prefix {
1169 if import_path.starts_with(prefix) {
1170 return ImportType::Internal;
1171 }
1172 if let Some(import_domain) = import_path.split('/').next() {
1175 if let Some(module_domain) = prefix.split('/').next() {
1176 if import_domain == module_domain && module_domain.contains('.') {
1178 return ImportType::Internal;
1179 }
1180 }
1181 }
1182 }
1183 if import_path.starts_with("./") || import_path.starts_with("../") {
1185 return ImportType::Internal;
1186 }
1187
1188 const STDLIB_MODULES: &[&str] = &[
1195 "fmt",
1196 "io",
1197 "os",
1198 "path",
1199 "strings",
1200 "bytes",
1201 "bufio",
1202 "errors",
1203 "context",
1204 "sync",
1205 "time",
1206 "encoding/json",
1207 "encoding/xml",
1208 "encoding/csv",
1209 "net/http",
1210 "net/url",
1211 "net",
1212 "crypto",
1213 "crypto/tls",
1214 "crypto/sha256",
1215 "database/sql",
1216 "log",
1217 "math",
1218 "regexp",
1219 "strconv",
1220 "sort",
1221 "reflect",
1222 "runtime",
1223 "testing",
1224 "flag",
1225 "filepath",
1226 "unicode",
1227 "html",
1228 "text/template",
1229 ];
1230
1231 if STDLIB_MODULES.contains(&import_path) {
1233 return ImportType::Stdlib;
1234 }
1235
1236 if import_path.contains('/') && import_path.split('/').next().unwrap_or("").contains('.') {
1238 return ImportType::External;
1239 }
1240
1241 if !import_path.contains('/') || import_path.split('/').count() <= 2 {
1243 return ImportType::Stdlib;
1244 }
1245
1246 ImportType::External
1248}
1249
/// A Go module discovered from a `go.mod` file.
#[derive(Debug, Clone)]
pub struct GoModule {
    /// Module path taken from the `module` directive of the go.mod file.
    pub name: String,
    /// Directory containing go.mod, relative to the index root that was
    /// scanned. Empty when go.mod sits directly in the index root.
    pub project_root: String,
    /// Directory containing go.mod, as reported by the directory walk.
    /// NOTE(review): this is only absolute if the index root passed to the
    /// walk was absolute — confirm against callers.
    pub abs_project_root: std::path::PathBuf,
}
1264
1265pub fn find_all_go_mods(index_root: &std::path::Path) -> Result<Vec<std::path::PathBuf>> {
1267 use ignore::WalkBuilder;
1268
1269 let mut go_mod_files = Vec::new();
1270
1271 let walker = WalkBuilder::new(index_root)
1272 .follow_links(false)
1273 .git_ignore(true)
1274 .build();
1275
1276 for entry in walker {
1277 let entry = entry?;
1278 let path = entry.path();
1279
1280 if !path.is_file() {
1281 continue;
1282 }
1283
1284 let filename = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
1285
1286 if filename == "go.mod" {
1288 let path_str = path.to_string_lossy();
1290 if path_str.contains("/vendor/") {
1291 log::trace!("Skipping go.mod in vendor directory: {:?}", path);
1292 continue;
1293 }
1294
1295 go_mod_files.push(path.to_path_buf());
1296 }
1297 }
1298
1299 log::debug!("Found {} go.mod files", go_mod_files.len());
1300 Ok(go_mod_files)
1301}
1302
1303pub fn parse_all_go_modules(index_root: &std::path::Path) -> Result<Vec<GoModule>> {
1305 let go_mod_files = find_all_go_mods(index_root)?;
1306
1307 if go_mod_files.is_empty() {
1308 log::debug!("No go.mod files found in {:?}", index_root);
1309 return Ok(Vec::new());
1310 }
1311
1312 let mut modules = Vec::new();
1313 let mod_count = go_mod_files.len();
1314
1315 for go_mod_path in &go_mod_files {
1316 let project_root = go_mod_path
1317 .parent()
1318 .ok_or_else(|| anyhow::anyhow!("go.mod has no parent directory"))?;
1319
1320 if let Ok(content) = std::fs::read_to_string(go_mod_path) {
1322 for line in content.lines() {
1323 let trimmed = line.trim();
1324 if trimmed.starts_with("module ") {
1325 let module_name = trimmed["module ".len()..].trim().to_string();
1326
1327 let relative_project_root = project_root
1328 .strip_prefix(index_root)
1329 .unwrap_or(project_root)
1330 .to_string_lossy()
1331 .to_string();
1332
1333 log::debug!(
1334 "Found Go module '{}' at {:?}",
1335 module_name,
1336 relative_project_root
1337 );
1338
1339 modules.push(GoModule {
1340 name: module_name,
1341 project_root: relative_project_root,
1342 abs_project_root: project_root.to_path_buf(),
1343 });
1344 break;
1345 }
1346 }
1347 }
1348 }
1349
1350 log::info!(
1351 "Loaded {} Go modules from {} go.mod files",
1352 modules.len(),
1353 mod_count
1354 );
1355
1356 Ok(modules)
1357}
1358
1359pub fn resolve_go_import_to_path(
1366 import_path: &str,
1367 modules: &[GoModule],
1368 _current_file_path: Option<&str>,
1369) -> Option<String> {
1370 if import_path.starts_with("./") || import_path.starts_with("../") {
1372 return None;
1374 }
1375
1376 for module in modules {
1378 if import_path.starts_with(&module.name) {
1379 let sub_path = import_path
1382 .strip_prefix(&module.name)
1383 .unwrap_or(import_path)
1384 .trim_start_matches('/');
1385
1386 if sub_path.is_empty() {
1387 let basename = module.name.split('/').last().unwrap_or("main");
1390 let candidates = if module.project_root.is_empty() {
1391 vec!["main.go".to_string(), format!("{}.go", basename)]
1392 } else {
1393 vec![
1394 format!("{}/main.go", module.project_root),
1395 format!("{}/{}.go", module.project_root, basename),
1396 ]
1397 };
1398
1399 for candidate in candidates {
1400 log::trace!("Checking Go module root: {}", candidate);
1401 return Some(candidate);
1402 }
1403 } else {
1404 let package_name = sub_path.split('/').last().unwrap_or(sub_path);
1407 let candidates = if module.project_root.is_empty() {
1408 vec![
1409 format!("{}.go", sub_path),
1410 format!("{}/{}.go", sub_path, package_name),
1411 ]
1412 } else {
1413 vec![
1414 format!("{}/{}.go", module.project_root, sub_path),
1415 format!("{}/{}/{}.go", module.project_root, sub_path, package_name),
1416 ]
1417 };
1418
1419 for candidate in candidates {
1420 log::trace!("Checking Go package path: {}", candidate);
1421 return Some(candidate);
1422 }
1423 }
1424 }
1425 }
1426
1427 None
1428}