1use anyhow::{Context, Result};
12use streaming_iterator::StreamingIterator;
13use tree_sitter::{Parser, Query, QueryCursor};
14use crate::models::{Language, SearchResult, Span, SymbolKind};
15
16pub fn parse(path: &str, source: &str) -> Result<Vec<SearchResult>> {
18 let mut parser = Parser::new();
19 let language = tree_sitter_go::LANGUAGE;
20
21 parser
22 .set_language(&language.into())
23 .context("Failed to set Go language")?;
24
25 let tree = parser
26 .parse(source, None)
27 .context("Failed to parse Go source")?;
28
29 let root_node = tree.root_node();
30
31 let mut symbols = Vec::new();
32
33 symbols.extend(extract_functions(source, &root_node, &language.into())?);
35 symbols.extend(extract_types(source, &root_node, &language.into())?);
36 symbols.extend(extract_interfaces(source, &root_node, &language.into())?);
37 symbols.extend(extract_methods(source, &root_node, &language.into())?);
38 symbols.extend(extract_constants(source, &root_node, &language.into())?);
39 symbols.extend(extract_variables(source, &root_node, &language.into())?);
40
41 for symbol in &mut symbols {
43 symbol.path = path.to_string();
44 symbol.lang = Language::Go;
45 }
46
47 Ok(symbols)
48}
49
50fn extract_functions(
52 source: &str,
53 root: &tree_sitter::Node,
54 language: &tree_sitter::Language,
55) -> Result<Vec<SearchResult>> {
56 let query_str = r#"
57 (function_declaration
58 name: (identifier) @name) @function
59 "#;
60
61 let query = Query::new(language, query_str)
62 .context("Failed to create function query")?;
63
64 extract_symbols(source, root, &query, SymbolKind::Function, None)
65}
66
67fn extract_types(
69 source: &str,
70 root: &tree_sitter::Node,
71 language: &tree_sitter::Language,
72) -> Result<Vec<SearchResult>> {
73 let query_str = r#"
74 (type_declaration
75 (type_spec
76 name: (type_identifier) @name
77 type: (struct_type))) @struct
78 "#;
79
80 let query = Query::new(language, query_str)
81 .context("Failed to create struct query")?;
82
83 extract_symbols(source, root, &query, SymbolKind::Struct, None)
84}
85
86fn extract_interfaces(
88 source: &str,
89 root: &tree_sitter::Node,
90 language: &tree_sitter::Language,
91) -> Result<Vec<SearchResult>> {
92 let query_str = r#"
93 (type_declaration
94 (type_spec
95 name: (type_identifier) @name
96 type: (interface_type))) @interface
97 "#;
98
99 let query = Query::new(language, query_str)
100 .context("Failed to create interface query")?;
101
102 extract_symbols(source, root, &query, SymbolKind::Interface, None)
103}
104
105fn extract_methods(
107 source: &str,
108 root: &tree_sitter::Node,
109 language: &tree_sitter::Language,
110) -> Result<Vec<SearchResult>> {
111 let query_str = r#"
112 (method_declaration
113 receiver: (parameter_list
114 (parameter_declaration
115 type: [(type_identifier) (pointer_type (type_identifier))] @receiver_type))
116 name: (field_identifier) @method_name) @method
117 "#;
118
119 let query = Query::new(language, query_str)
120 .context("Failed to create method query")?;
121
122 let mut cursor = QueryCursor::new();
123 let mut matches = cursor.matches(&query, *root, source.as_bytes());
124
125 let mut symbols = Vec::new();
126
127 while let Some(match_) = matches.next() {
128 let mut receiver_type = None;
129 let mut method_name = None;
130 let mut method_node = None;
131
132 for capture in match_.captures {
133 let capture_name: &str = &query.capture_names()[capture.index as usize];
134 match capture_name {
135 "receiver_type" => {
136 receiver_type = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
137 }
138 "method_name" => {
139 method_name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
140 }
141 "method" => {
142 method_node = Some(capture.node);
143 }
144 _ => {}
145 }
146 }
147
148 if let (Some(receiver_type), Some(method_name), Some(node)) = (receiver_type, method_name, method_node) {
149 let clean_receiver = receiver_type.trim_start_matches('*');
151 let scope = format!("type {}", clean_receiver);
152 let span = node_to_span(&node);
153 let preview = extract_preview(source, &span);
154
155 symbols.push(SearchResult::new(
156 String::new(),
157 Language::Go,
158 SymbolKind::Method,
159 Some(method_name),
160 span,
161 Some(scope),
162 preview,
163 ));
164 }
165 }
166
167 Ok(symbols)
168}
169
170fn extract_constants(
172 source: &str,
173 root: &tree_sitter::Node,
174 language: &tree_sitter::Language,
175) -> Result<Vec<SearchResult>> {
176 let query_str = r#"
177 (const_declaration
178 (const_spec
179 name: (identifier) @name)) @const
180 "#;
181
182 let query = Query::new(language, query_str)
183 .context("Failed to create const query")?;
184
185 extract_symbols(source, root, &query, SymbolKind::Constant, None)
186}
187
188fn extract_variables(
190 source: &str,
191 root: &tree_sitter::Node,
192 language: &tree_sitter::Language,
193) -> Result<Vec<SearchResult>> {
194 let query_str = r#"
196 (var_spec
197 name: (identifier) @name) @var
198
199 (short_var_declaration
200 left: (expression_list (identifier) @name)) @short_var
201 "#;
202
203 let query = Query::new(language, query_str)
204 .context("Failed to create var query")?;
205
206 let mut cursor = QueryCursor::new();
207 let mut matches = cursor.matches(&query, *root, source.as_bytes());
208
209 let mut symbols = Vec::new();
210
211 while let Some(match_) = matches.next() {
212 let mut name = None;
213 let mut decl_node = None;
214
215 for capture in match_.captures {
216 let capture_name: &str = &query.capture_names()[capture.index as usize];
217 match capture_name {
218 "name" => {
219 name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
220 }
221 "var" | "short_var" => {
222 decl_node = Some(capture.node);
223 }
224 _ => {}
225 }
226 }
227
228 if let (Some(name), Some(node)) = (name, decl_node) {
229 let span = node_to_span(&node);
230 let preview = extract_preview(source, &span);
231
232 symbols.push(SearchResult::new(
233 String::new(),
234 Language::Go,
235 SymbolKind::Variable,
236 Some(name),
237 span,
238 None,
239 preview,
240 ));
241 }
242 }
243
244 Ok(symbols)
245}
246
247fn extract_symbols(
249 source: &str,
250 root: &tree_sitter::Node,
251 query: &Query,
252 kind: SymbolKind,
253 scope: Option<String>,
254) -> Result<Vec<SearchResult>> {
255 let mut cursor = QueryCursor::new();
256 let mut matches = cursor.matches(query, *root, source.as_bytes());
257
258 let mut symbols = Vec::new();
259
260 while let Some(match_) = matches.next() {
261 let mut name = None;
263 let mut full_node = None;
264
265 for capture in match_.captures {
266 let capture_name: &str = &query.capture_names()[capture.index as usize];
267 if capture_name == "name" {
268 name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
269 } else {
270 full_node = Some(capture.node);
272 }
273 }
274
275 if let (Some(name), Some(node)) = (name, full_node) {
276 let span = node_to_span(&node);
277 let preview = extract_preview(source, &span);
278
279 symbols.push(SearchResult::new(
280 String::new(),
281 Language::Go,
282 kind.clone(),
283 Some(name),
284 span,
285 scope.clone(),
286 preview,
287 ));
288 }
289 }
290
291 Ok(symbols)
292}
293
294fn node_to_span(node: &tree_sitter::Node) -> Span {
296 let start = node.start_position();
297 let end = node.end_position();
298
299 Span::new(
300 start.row + 1, start.column,
302 end.row + 1,
303 end.column,
304 )
305}
306
307fn extract_preview(source: &str, span: &Span) -> String {
309 let lines: Vec<&str> = source.lines().collect();
310
311 let start_idx = (span.start_line - 1) as usize; let end_idx = (start_idx + 7).min(lines.len());
314
315 lines[start_idx..end_idx].join("\n")
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::Path;
    use tempfile::TempDir;

    /// Creates `dir` (including parents) and writes a `go.mod` with `contents` into it.
    fn write_go_mod(dir: &Path, contents: &str) {
        fs::create_dir_all(dir).unwrap();
        fs::write(dir.join("go.mod"), contents).unwrap();
    }

    #[test]
    fn test_parse_function() {
        let source = r#"
package main

func helloWorld() string {
    return "Hello, world!"
}
        "#;

        let symbols = parse("test.go", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("helloWorld"));
        assert!(matches!(symbols[0].kind, SymbolKind::Function));
    }

    #[test]
    fn test_parse_struct() {
        let source = r#"
package main

type User struct {
    Name string
    Age  int
}
        "#;

        let symbols = parse("test.go", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("User"));
        assert!(matches!(symbols[0].kind, SymbolKind::Struct));
    }

    #[test]
    fn test_parse_interface() {
        let source = r#"
package main

type Reader interface {
    Read(p []byte) (n int, err error)
}
        "#;

        let symbols = parse("test.go", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("Reader"));
        assert!(matches!(symbols[0].kind, SymbolKind::Interface));
    }

    #[test]
    fn test_parse_method() {
        let source = r#"
package main

type User struct {
    Name string
}

func (u *User) GetName() string {
    return u.Name
}

func (u User) SetName(name string) {
    u.Name = name
}
        "#;

        let symbols = parse("test.go", source).unwrap();

        let method_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Method))
            .collect();

        // Both pointer and value receivers must be picked up.
        assert_eq!(method_symbols.len(), 2);
        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("GetName")));
        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("SetName")));
    }

    #[test]
    fn test_parse_constants() {
        let source = r#"
package main

const MaxSize = 100
const DefaultTimeout = 30

const (
    StatusActive = 1
    StatusInactive = 2
)
        "#;

        let symbols = parse("test.go", source).unwrap();

        let const_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Constant))
            .collect();

        assert_eq!(const_symbols.len(), 4);
        assert!(const_symbols.iter().any(|s| s.symbol.as_deref() == Some("MaxSize")));
        assert!(const_symbols.iter().any(|s| s.symbol.as_deref() == Some("DefaultTimeout")));
        assert!(const_symbols.iter().any(|s| s.symbol.as_deref() == Some("StatusActive")));
        assert!(const_symbols.iter().any(|s| s.symbol.as_deref() == Some("StatusInactive")));
    }

    #[test]
    fn test_parse_variables() {
        let source = r#"
package main

var GlobalConfig Config
var (
    Logger *log.Logger
    Version = "1.0.0"
)
        "#;

        let symbols = parse("test.go", source).unwrap();

        let var_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Variable))
            .collect();

        assert_eq!(var_symbols.len(), 3);
        assert!(var_symbols.iter().any(|s| s.symbol.as_deref() == Some("GlobalConfig")));
        assert!(var_symbols.iter().any(|s| s.symbol.as_deref() == Some("Logger")));
        assert!(var_symbols.iter().any(|s| s.symbol.as_deref() == Some("Version")));
    }

    #[test]
    fn test_parse_mixed_symbols() {
        let source = r#"
package main

const DefaultPort = 8080

type Server struct {
    Port int
}

type Handler interface {
    Handle(req *Request) error
}

func (s *Server) Start() error {
    return nil
}

func NewServer(port int) *Server {
    return &Server{Port: port}
}

var globalServer *Server
        "#;

        let symbols = parse("test.go", source).unwrap();

        assert!(symbols.len() >= 6);

        // Every extractor must have contributed at least one symbol.
        let kinds: Vec<&SymbolKind> = symbols.iter().map(|s| &s.kind).collect();
        assert!(kinds.contains(&&SymbolKind::Constant));
        assert!(kinds.contains(&&SymbolKind::Struct));
        assert!(kinds.contains(&&SymbolKind::Interface));
        assert!(kinds.contains(&&SymbolKind::Method));
        assert!(kinds.contains(&&SymbolKind::Function));
        assert!(kinds.contains(&&SymbolKind::Variable));
    }

    #[test]
    fn test_parse_multiple_methods() {
        let source = r#"
package main

type Calculator struct{}

func (c *Calculator) Add(a, b int) int {
    return a + b
}

func (c *Calculator) Subtract(a, b int) int {
    return a - b
}

func (c *Calculator) Multiply(a, b int) int {
    return a * b
}
        "#;

        let symbols = parse("test.go", source).unwrap();

        let method_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Method))
            .collect();

        assert_eq!(method_symbols.len(), 3);
        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("Add")));
        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("Subtract")));
        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("Multiply")));
    }

    #[test]
    fn test_parse_type_alias() {
        let source = r#"
package main

type UserID string
type Age int

type Config struct {
    Host string
    Port int
}
        "#;

        let symbols = parse("test.go", source).unwrap();

        // Named non-struct types must not be reported as structs.
        let struct_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Struct))
            .collect();

        assert_eq!(struct_symbols.len(), 1);
        assert_eq!(struct_symbols[0].symbol.as_deref(), Some("Config"));
    }

    #[test]
    fn test_parse_embedded_interface() {
        let source = r#"
package main

type Reader interface {
    Read(p []byte) (n int, err error)
}

type Writer interface {
    Write(p []byte) (n int, err error)
}

type ReadWriter interface {
    Reader
    Writer
}
        "#;

        let symbols = parse("test.go", source).unwrap();

        let interface_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Interface))
            .collect();

        assert_eq!(interface_symbols.len(), 3);
        assert!(interface_symbols.iter().any(|s| s.symbol.as_deref() == Some("Reader")));
        assert!(interface_symbols.iter().any(|s| s.symbol.as_deref() == Some("Writer")));
        assert!(interface_symbols.iter().any(|s| s.symbol.as_deref() == Some("ReadWriter")));
    }

    #[test]
    fn test_local_variables_included() {
        let source = r#"
package main

var globalCount int = 10

func calculate(x int) int {
    localVar := x * 2
    var anotherLocal int = 5
    return localVar + anotherLocal
}
        "#;

        let symbols = parse("test.go", source).unwrap();

        let var_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Variable))
            .collect();

        // One package-level variable plus the two locals inside `calculate`.
        assert_eq!(var_symbols.len(), 3);
        assert!(var_symbols.iter().any(|s| s.symbol.as_deref() == Some("globalCount")));
        assert!(var_symbols.iter().any(|s| s.symbol.as_deref() == Some("localVar")));
        assert!(var_symbols.iter().any(|s| s.symbol.as_deref() == Some("anotherLocal")));
    }

    #[test]
    fn test_extract_go_imports() {
        let source = r#"package main

import (
    "fmt"
    "encoding/json"
    "github.com/gin-gonic/gin"
    "myproject/internal/models"
)

func main() {
    fmt.Println("Hello")
}
"#;

        let deps = GoDependencyExtractor::extract_dependencies(source).unwrap();

        assert_eq!(deps.len(), 4, "Should extract 4 import statements");
        assert!(deps.iter().any(|d| d.imported_path == "fmt"));
        assert!(deps.iter().any(|d| d.imported_path == "encoding/json"));
        assert!(deps.iter().any(|d| d.imported_path == "github.com/gin-gonic/gin"));
        assert!(deps.iter().any(|d| d.imported_path == "myproject/internal/models"));

        let fmt_dep = deps.iter().find(|d| d.imported_path == "fmt").unwrap();
        assert!(
            matches!(fmt_dep.import_type, ImportType::Stdlib),
            "fmt should be classified as Stdlib"
        );

        let json_dep = deps.iter().find(|d| d.imported_path == "encoding/json").unwrap();
        assert!(
            matches!(json_dep.import_type, ImportType::Stdlib),
            "encoding/json should be classified as Stdlib"
        );

        let gin_dep = deps
            .iter()
            .find(|d| d.imported_path == "github.com/gin-gonic/gin")
            .unwrap();
        assert!(
            matches!(gin_dep.import_type, ImportType::External),
            "github.com/gin-gonic/gin should be classified as External"
        );

        // Without a module prefix the extractor cannot know this path is project-local.
        let models_dep = deps
            .iter()
            .find(|d| d.imported_path == "myproject/internal/models")
            .unwrap();
        assert!(
            matches!(models_dep.import_type, ImportType::External),
            "myproject/internal/models should be classified as External"
        );
    }

    #[test]
    fn test_extract_go_imports_with_comments() {
        let source = r#"package main

import (
    "os"
    _ "time/tzdata" // for timeZone support in CronJob

    "k8s.io/component-base/cli"
    _ "k8s.io/component-base/logs/json/register" // for JSON log format registration
    _ "k8s.io/component-base/metrics/prometheus/clientgo" // load all the prometheus client-go plugins
)

func main() {
    os.Exit(0)
}
"#;

        let deps = GoDependencyExtractor::extract_dependencies(source).unwrap();

        // Printed output is only shown by the test harness when the test fails.
        println!("Extracted {} dependencies:", deps.len());
        for dep in &deps {
            println!("  - {} (line {})", dep.imported_path, dep.line_number);
        }

        assert!(deps.len() >= 4, "Should extract at least 4 imports, got {}", deps.len());
        assert!(deps.iter().any(|d| d.imported_path == "os"));
        assert!(deps.iter().any(|d| d.imported_path == "time/tzdata"));
        assert!(deps.iter().any(|d| d.imported_path == "k8s.io/component-base/cli"));
    }

    #[test]
    fn test_find_all_go_mods() {
        let temp = TempDir::new().unwrap();
        let root = temp.path();

        write_go_mod(
            &root.join("services/auth"),
            "module github.com/myorg/auth\n\ngo 1.21\n",
        );
        write_go_mod(
            &root.join("services/api"),
            "module github.com/myorg/api\n\ngo 1.21\n",
        );
        // A go.mod under vendor/ must be ignored.
        write_go_mod(&root.join("vendor"), "module github.com/external/lib\n");

        let mods = find_all_go_mods(root).unwrap();

        assert_eq!(mods.len(), 2);
        assert!(mods.iter().any(|p| p.ends_with("services/auth/go.mod")));
        assert!(mods.iter().any(|p| p.ends_with("services/api/go.mod")));
    }

    #[test]
    fn test_parse_all_go_modules() {
        let temp = TempDir::new().unwrap();
        let root = temp.path();

        write_go_mod(
            &root.join("services/auth"),
            "module github.com/myorg/auth\n\ngo 1.21\n",
        );
        write_go_mod(
            &root.join("cmd/api"),
            "module github.com/myorg/api\n\ngo 1.21\n",
        );

        let modules = parse_all_go_modules(root).unwrap();

        assert_eq!(modules.len(), 2);

        let names: Vec<_> = modules.iter().map(|m| m.name.as_str()).collect();
        assert!(names.contains(&"github.com/myorg/auth"));
        assert!(names.contains(&"github.com/myorg/api"));

        for module in &modules {
            assert!(
                module.project_root.starts_with("services/")
                    || module.project_root.starts_with("cmd/")
            );
            assert!(module.abs_project_root.ends_with(&module.project_root));
        }
    }

    #[test]
    fn test_resolve_go_import() {
        let temp = TempDir::new().unwrap();
        let root = temp.path();

        let myapp = root.join("myapp");
        fs::create_dir_all(myapp.join("pkg/models")).unwrap();
        write_go_mod(&myapp, "module github.com/myorg/myapp\n\ngo 1.21\n");

        let modules = parse_all_go_modules(root).unwrap();
        assert_eq!(modules.len(), 1);

        let resolved =
            resolve_go_import_to_path("github.com/myorg/myapp/pkg/models", &modules, None);

        assert!(resolved.is_some());
        let path = resolved.unwrap();
        assert!(path.contains("myapp/pkg/models"));
        assert!(path.ends_with(".go"));
    }

    #[test]
    fn test_resolve_go_import_module_root() {
        let temp = TempDir::new().unwrap();
        let root = temp.path();

        write_go_mod(
            &root.join("cmd/server"),
            "module github.com/myorg/server\n\ngo 1.21\n",
        );

        let modules = parse_all_go_modules(root).unwrap();

        let resolved = resolve_go_import_to_path("github.com/myorg/server", &modules, None);

        assert!(resolved.is_some());
        let path = resolved.unwrap();
        assert!(path.contains("cmd/server"));
        assert!(path.ends_with(".go"));
    }

    #[test]
    fn test_resolve_go_import_not_found() {
        let temp = TempDir::new().unwrap();
        let root = temp.path();

        write_go_mod(
            &root.join("myapp"),
            "module github.com/myorg/myapp\n\ngo 1.21\n",
        );

        let modules = parse_all_go_modules(root).unwrap();

        let resolved = resolve_go_import_to_path("github.com/other/package", &modules, None);

        assert!(resolved.is_none());
    }

    #[test]
    fn test_resolve_go_import_relative() {
        let modules = vec![];

        let resolved =
            resolve_go_import_to_path("./utils", &modules, Some("myapp/pkg/api/handler.go"));

        assert!(resolved.is_none());
    }

    #[test]
    fn test_resolve_go_import_root_module_no_leading_slash() {
        let temp = TempDir::new().unwrap();
        let root = temp.path();

        // The module lives directly at the index root, so its relative root is empty.
        write_go_mod(root, "module k8s.io/kubernetes\n\ngo 1.21\n");

        let modules = parse_all_go_modules(root).unwrap();
        assert_eq!(modules.len(), 1);
        assert_eq!(modules[0].project_root, "");

        let resolved =
            resolve_go_import_to_path("k8s.io/kubernetes/test/internal/metric", &modules, None);
        assert!(resolved.is_some());
        let path = resolved.unwrap();
        assert!(!path.starts_with('/'), "path must not start with '/': {}", path);
        assert!(path.ends_with(".go"));
        assert!(path.contains("test/internal/metric"));

        let resolved = resolve_go_import_to_path("k8s.io/kubernetes", &modules, None);
        assert!(resolved.is_some());
        let path = resolved.unwrap();
        assert!(!path.starts_with('/'), "path must not start with '/': {}", path);
        assert!(path.ends_with(".go"));
    }
}
908
909use crate::models::ImportType;
914use crate::parsers::{DependencyExtractor, ImportInfo};
915
916pub struct GoDependencyExtractor;
918
919impl DependencyExtractor for GoDependencyExtractor {
920 fn extract_dependencies(source: &str) -> Result<Vec<ImportInfo>> {
921 let mut parser = Parser::new();
922 let language = tree_sitter_go::LANGUAGE;
923
924 parser
925 .set_language(&language.into())
926 .context("Failed to set Go language")?;
927
928 let tree = parser
929 .parse(source, None)
930 .context("Failed to parse Go source")?;
931
932 let root_node = tree.root_node();
933
934 let mut imports = Vec::new();
935
936 imports.extend(extract_go_imports(source, &root_node)?);
938
939 Ok(imports)
940 }
941}
942
943fn extract_go_imports(
945 source: &str,
946 root: &tree_sitter::Node,
947) -> Result<Vec<ImportInfo>> {
948 let language = tree_sitter_go::LANGUAGE;
949
950 let query_str = r#"
952 (import_declaration
953 (import_spec
954 path: (interpreted_string_literal) @import_path)) @import
955
956 (import_declaration
957 (import_spec_list
958 (import_spec
959 path: (interpreted_string_literal) @import_path))) @import
960 "#;
961
962 let query = Query::new(&language.into(), query_str)
963 .context("Failed to create Go import query")?;
964
965 let mut cursor = QueryCursor::new();
966 let mut matches = cursor.matches(&query, *root, source.as_bytes());
967
968 let mut imports = Vec::new();
969
970 while let Some(match_) = matches.next() {
971 let mut import_path = None;
972 let mut import_node = None;
973
974 for capture in match_.captures {
975 let capture_name: &str = &query.capture_names()[capture.index as usize];
976 match capture_name {
977 "import_path" => {
978 let raw_path = capture.node.utf8_text(source.as_bytes()).unwrap_or("");
980 import_path = Some(raw_path.trim_matches('"').to_string());
981 }
982 "import" => {
983 import_node = Some(capture.node);
984 }
985 _ => {}
986 }
987 }
988
989 if let (Some(path), Some(node)) = (import_path, import_node) {
990 let import_type = classify_go_import(&path);
991 let line_number = node.start_position().row + 1;
992
993 imports.push(ImportInfo {
994 imported_path: path,
995 import_type,
996 line_number,
997 imported_symbols: None, });
999 }
1000 }
1001
1002 Ok(imports)
1003}
1004
/// Reads `<root>/go.mod` and returns the module path from its `module` directive.
///
/// Returns `None` when the file is missing or unreadable, or when no line carries a
/// `module` directive with a non-empty path.
///
/// The keyword may be followed by any whitespace. A trailing `// comment` and
/// surrounding double quotes or backticks are removed from the path.
pub fn find_go_module_name(root: &std::path::Path) -> Option<String> {
    // Reading directly covers the "file does not exist" case without a separate probe.
    let content = std::fs::read_to_string(root.join("go.mod")).ok()?;

    content.lines().find_map(|line| {
        let rest = line.trim().strip_prefix("module")?;

        // Require a separator so identifiers that merely start with "module" don't match.
        if !rest.starts_with(char::is_whitespace) {
            return None;
        }

        let declared = rest.split("//").next().unwrap_or(rest).trim();
        let name = declared.trim_matches(|c| c == '"' || c == '`');

        if name.is_empty() {
            None
        } else {
            Some(name.to_string())
        }
    })
}
1027
/// Classifies `import_path`, optionally taking the current module's path into account.
///
/// Public entry point over `classify_go_import_impl`; both arguments are forwarded
/// unchanged.
pub fn reclassify_go_import(import_path: &str, module_prefix: Option<&str>) -> ImportType {
    classify_go_import_impl(import_path, module_prefix)
}
1033
/// Classifies `import_path` with no module prefix available.
fn classify_go_import(import_path: &str) -> ImportType {
    classify_go_import_impl(import_path, None)
}
1038
1039fn classify_go_import_impl(import_path: &str, module_prefix: Option<&str>) -> ImportType {
1041 if let Some(prefix) = module_prefix {
1043 if import_path.starts_with(prefix) {
1044 return ImportType::Internal;
1045 }
1046 if let Some(import_domain) = import_path.split('/').next() {
1049 if let Some(module_domain) = prefix.split('/').next() {
1050 if import_domain == module_domain && module_domain.contains('.') {
1052 return ImportType::Internal;
1053 }
1054 }
1055 }
1056 }
1057 if import_path.starts_with("./") || import_path.starts_with("../") {
1059 return ImportType::Internal;
1060 }
1061
1062 const STDLIB_MODULES: &[&str] = &[
1069 "fmt", "io", "os", "path", "strings", "bytes", "bufio", "errors",
1070 "context", "sync", "time", "encoding/json", "encoding/xml", "encoding/csv",
1071 "net/http", "net/url", "net", "crypto", "crypto/tls", "crypto/sha256",
1072 "database/sql", "log", "math", "regexp", "strconv", "sort", "reflect",
1073 "runtime", "testing", "flag", "filepath", "unicode", "html", "text/template",
1074 ];
1075
1076 if STDLIB_MODULES.contains(&import_path) {
1078 return ImportType::Stdlib;
1079 }
1080
1081 if import_path.contains('/') && import_path.split('/').next().unwrap_or("").contains('.') {
1083 return ImportType::External;
1084 }
1085
1086 if !import_path.contains('/') || import_path.split('/').count() <= 2 {
1088 return ImportType::Stdlib;
1089 }
1090
1091 ImportType::External
1093}
1094
/// A Go module discovered from a `go.mod` file.
#[derive(Debug, Clone)]
pub struct GoModule {
    /// Module path taken from the `module` directive of the `go.mod` file.
    pub name: String,
    /// Directory holding the `go.mod`, relative to the index root when it lies beneath
    /// it; empty when the module sits directly at the index root.
    pub project_root: String,
    /// Directory holding the `go.mod`, as produced by the directory walk.
    pub abs_project_root: std::path::PathBuf,
}
1109
1110pub fn find_all_go_mods(index_root: &std::path::Path) -> Result<Vec<std::path::PathBuf>> {
1112 use ignore::WalkBuilder;
1113
1114 let mut go_mod_files = Vec::new();
1115
1116 let walker = WalkBuilder::new(index_root)
1117 .follow_links(false)
1118 .git_ignore(true)
1119 .build();
1120
1121 for entry in walker {
1122 let entry = entry?;
1123 let path = entry.path();
1124
1125 if !path.is_file() {
1126 continue;
1127 }
1128
1129 let filename = path.file_name()
1130 .and_then(|n| n.to_str())
1131 .unwrap_or("");
1132
1133 if filename == "go.mod" {
1135 let path_str = path.to_string_lossy();
1137 if path_str.contains("/vendor/") {
1138 log::trace!("Skipping go.mod in vendor directory: {:?}", path);
1139 continue;
1140 }
1141
1142 go_mod_files.push(path.to_path_buf());
1143 }
1144 }
1145
1146 log::debug!("Found {} go.mod files", go_mod_files.len());
1147 Ok(go_mod_files)
1148}
1149
1150pub fn parse_all_go_modules(index_root: &std::path::Path) -> Result<Vec<GoModule>> {
1152 let go_mod_files = find_all_go_mods(index_root)?;
1153
1154 if go_mod_files.is_empty() {
1155 log::debug!("No go.mod files found in {:?}", index_root);
1156 return Ok(Vec::new());
1157 }
1158
1159 let mut modules = Vec::new();
1160 let mod_count = go_mod_files.len();
1161
1162 for go_mod_path in &go_mod_files {
1163 let project_root = go_mod_path
1164 .parent()
1165 .ok_or_else(|| anyhow::anyhow!("go.mod has no parent directory"))?;
1166
1167 if let Ok(content) = std::fs::read_to_string(go_mod_path) {
1169 for line in content.lines() {
1170 let trimmed = line.trim();
1171 if trimmed.starts_with("module ") {
1172 let module_name = trimmed["module ".len()..].trim().to_string();
1173
1174 let relative_project_root = project_root
1175 .strip_prefix(index_root)
1176 .unwrap_or(project_root)
1177 .to_string_lossy()
1178 .to_string();
1179
1180 log::debug!(
1181 "Found Go module '{}' at {:?}",
1182 module_name,
1183 relative_project_root
1184 );
1185
1186 modules.push(GoModule {
1187 name: module_name,
1188 project_root: relative_project_root,
1189 abs_project_root: project_root.to_path_buf(),
1190 });
1191 break;
1192 }
1193 }
1194 }
1195 }
1196
1197 log::info!(
1198 "Loaded {} Go modules from {} go.mod files",
1199 modules.len(),
1200 mod_count
1201 );
1202
1203 Ok(modules)
1204}
1205
1206pub fn resolve_go_import_to_path(
1213 import_path: &str,
1214 modules: &[GoModule],
1215 _current_file_path: Option<&str>,
1216) -> Option<String> {
1217 if import_path.starts_with("./") || import_path.starts_with("../") {
1219 return None;
1221 }
1222
1223 for module in modules {
1225 if import_path.starts_with(&module.name) {
1226 let sub_path = import_path.strip_prefix(&module.name)
1229 .unwrap_or(import_path)
1230 .trim_start_matches('/');
1231
1232 if sub_path.is_empty() {
1233 let basename = module.name.split('/').last().unwrap_or("main");
1236 let candidates = if module.project_root.is_empty() {
1237 vec![
1238 "main.go".to_string(),
1239 format!("{}.go", basename),
1240 ]
1241 } else {
1242 vec![
1243 format!("{}/main.go", module.project_root),
1244 format!("{}/{}.go", module.project_root, basename),
1245 ]
1246 };
1247
1248 for candidate in candidates {
1249 log::trace!("Checking Go module root: {}", candidate);
1250 return Some(candidate);
1251 }
1252 } else {
1253 let package_name = sub_path.split('/').last().unwrap_or(sub_path);
1256 let candidates = if module.project_root.is_empty() {
1257 vec![
1258 format!("{}.go", sub_path),
1259 format!("{}/{}.go", sub_path, package_name),
1260 ]
1261 } else {
1262 vec![
1263 format!("{}/{}.go", module.project_root, sub_path),
1264 format!("{}/{}/{}.go", module.project_root, sub_path, package_name),
1265 ]
1266 };
1267
1268 for candidate in candidates {
1269 log::trace!("Checking Go package path: {}", candidate);
1270 return Some(candidate);
1271 }
1272 }
1273 }
1274 }
1275
1276 None
1277}