1use std::collections::{HashMap, HashSet};
8use std::path::Path;
9
10use graphify_core::confidence::Confidence;
11use graphify_core::id::make_id;
12use graphify_core::model::{ExtractionResult, GraphEdge, GraphNode, NodeType};
13use tracing::trace;
14use tree_sitter::{Language, Node, Parser};
15
/// Per-language table of tree-sitter node kinds and field names that drives
/// the generic extractor.
///
/// The `*_types` sets hold node kinds (as reported by `Node::kind`) that the
/// walker treats specially; the `*_field` entries are tree-sitter field names
/// looked up with `Node::child_by_field_name`.
#[derive(Debug, Clone)]
pub struct TsConfig {
    /// Node kinds treated as class-like containers (classes, structs, traits, impls, ...).
    pub class_types: HashSet<&'static str>,
    /// Node kinds treated as function / method definitions.
    pub function_types: HashSet<&'static str>,
    /// Node kinds treated as import statements.
    pub import_types: HashSet<&'static str>,
    /// Node kinds representing call expressions.
    // NOTE(review): not read by the extraction code in this file; call edges are
    // inferred textually from function bodies instead.
    pub call_types: HashSet<&'static str>,
    /// Field holding a function's name (and a class's name unless overridden).
    pub name_field: &'static str,
    /// Optional override of `name_field` used for class-like nodes.
    pub class_name_field: Option<&'static str>,
    /// Field holding the body of a class-like node or function.
    pub body_field: &'static str,
    /// Field holding the callee inside a call node.
    // NOTE(review): also not read by the extraction code in this file.
    pub call_function_field: &'static str,
}
36
37fn python_config() -> TsConfig {
38 TsConfig {
39 class_types: ["class_definition"].into_iter().collect(),
40 function_types: ["function_definition"].into_iter().collect(),
41 import_types: ["import_statement", "import_from_statement"]
42 .into_iter()
43 .collect(),
44 call_types: ["call"].into_iter().collect(),
45 name_field: "name",
46 class_name_field: None,
47 body_field: "body",
48 call_function_field: "function",
49 }
50}
51
52fn js_config() -> TsConfig {
53 TsConfig {
54 class_types: ["class_declaration", "class"].into_iter().collect(),
55 function_types: [
56 "function_declaration",
57 "method_definition",
58 "arrow_function",
59 "generator_function_declaration",
60 ]
61 .into_iter()
62 .collect(),
63 import_types: ["import_statement"].into_iter().collect(),
64 call_types: ["call_expression"].into_iter().collect(),
65 name_field: "name",
66 class_name_field: None,
67 body_field: "body",
68 call_function_field: "function",
69 }
70}
71
72fn rust_config() -> TsConfig {
73 TsConfig {
74 class_types: ["struct_item", "enum_item", "trait_item", "impl_item"]
75 .into_iter()
76 .collect(),
77 function_types: ["function_item"].into_iter().collect(),
78 import_types: ["use_declaration"].into_iter().collect(),
79 call_types: ["call_expression"].into_iter().collect(),
80 name_field: "name",
81 class_name_field: None,
82 body_field: "body",
83 call_function_field: "function",
84 }
85}
86
87fn go_config() -> TsConfig {
88 TsConfig {
89 class_types: ["type_declaration"].into_iter().collect(),
90 function_types: ["function_declaration", "method_declaration"]
91 .into_iter()
92 .collect(),
93 import_types: ["import_declaration"].into_iter().collect(),
94 call_types: ["call_expression"].into_iter().collect(),
95 name_field: "name",
96 class_name_field: None,
97 body_field: "body",
98 call_function_field: "function",
99 }
100}
101
102fn java_config() -> TsConfig {
103 TsConfig {
104 class_types: ["class_declaration", "interface_declaration"]
105 .into_iter()
106 .collect(),
107 function_types: ["method_declaration", "constructor_declaration"]
108 .into_iter()
109 .collect(),
110 import_types: ["import_declaration"].into_iter().collect(),
111 call_types: ["method_invocation"].into_iter().collect(),
112 name_field: "name",
113 class_name_field: None,
114 body_field: "body",
115 call_function_field: "name",
116 }
117}
118
119fn c_config() -> TsConfig {
120 TsConfig {
121 class_types: HashSet::new(),
122 function_types: ["function_definition"].into_iter().collect(),
123 import_types: ["preproc_include"].into_iter().collect(),
124 call_types: ["call_expression"].into_iter().collect(),
125 name_field: "declarator",
126 class_name_field: None,
127 body_field: "body",
128 call_function_field: "function",
129 }
130}
131
132fn cpp_config() -> TsConfig {
133 TsConfig {
134 class_types: ["class_specifier"].into_iter().collect(),
135 function_types: ["function_definition"].into_iter().collect(),
136 import_types: ["preproc_include"].into_iter().collect(),
137 call_types: ["call_expression"].into_iter().collect(),
138 name_field: "declarator",
139 class_name_field: Some("name"),
140 body_field: "body",
141 call_function_field: "function",
142 }
143}
144
145fn ruby_config() -> TsConfig {
146 TsConfig {
147 class_types: ["class"].into_iter().collect(),
148 function_types: ["method", "singleton_method"].into_iter().collect(),
149 import_types: HashSet::new(),
150 call_types: ["call"].into_iter().collect(),
151 name_field: "name",
152 class_name_field: None,
153 body_field: "body",
154 call_function_field: "method",
155 }
156}
157
158fn csharp_config() -> TsConfig {
159 TsConfig {
160 class_types: ["class_declaration", "interface_declaration"]
161 .into_iter()
162 .collect(),
163 function_types: ["method_declaration"].into_iter().collect(),
164 import_types: ["using_directive"].into_iter().collect(),
165 call_types: ["invocation_expression"].into_iter().collect(),
166 name_field: "name",
167 class_name_field: None,
168 body_field: "body",
169 call_function_field: "function",
170 }
171}
172
173pub fn try_extract(path: &Path, source: &[u8], lang: &str) -> Option<ExtractionResult> {
180 let (language, config) = match lang {
181 "python" => (tree_sitter_python::LANGUAGE.into(), python_config()),
182 "javascript" => (tree_sitter_javascript::LANGUAGE.into(), js_config()),
183 "typescript" => (
184 tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
185 js_config(),
186 ),
187 "rust" => (tree_sitter_rust::LANGUAGE.into(), rust_config()),
188 "go" => (tree_sitter_go::LANGUAGE.into(), go_config()),
189 "java" => (tree_sitter_java::LANGUAGE.into(), java_config()),
190 "c" => (tree_sitter_c::LANGUAGE.into(), c_config()),
191 "cpp" => (tree_sitter_cpp::LANGUAGE.into(), cpp_config()),
192 "ruby" => (tree_sitter_ruby::LANGUAGE.into(), ruby_config()),
193 "csharp" => (tree_sitter_c_sharp::LANGUAGE.into(), csharp_config()),
194 _ => return None,
195 };
196 extract_with_treesitter(path, source, language, &config, lang)
197}
198
199fn extract_with_treesitter(
205 path: &Path,
206 source: &[u8],
207 language: Language,
208 config: &TsConfig,
209 lang: &str,
210) -> Option<ExtractionResult> {
211 let mut parser = Parser::new();
212 parser.set_language(&language).ok()?;
213 let tree = parser.parse(source, None)?;
214 let root = tree.root_node();
215
216 let stem = path.file_stem()?.to_str()?;
217 let str_path = path.to_string_lossy();
218
219 let mut nodes = Vec::new();
220 let mut edges = Vec::new();
221 let mut seen_ids = HashSet::new();
222 let mut function_bodies: Vec<(String, usize, usize)> = Vec::new();
224
225 let file_nid = make_id(&[&str_path]);
227 seen_ids.insert(file_nid.clone());
228 nodes.push(GraphNode {
229 id: file_nid.clone(),
230 label: stem.to_string(),
231 source_file: str_path.to_string(),
232 source_location: None,
233 node_type: NodeType::File,
234 community: None,
235 extra: HashMap::new(),
236 });
237
238 walk_node(
240 root,
241 source,
242 config,
243 lang,
244 &file_nid,
245 stem,
246 &str_path,
247 &mut nodes,
248 &mut edges,
249 &mut seen_ids,
250 &mut function_bodies,
251 None,
252 );
253
254 let label_to_nid: HashMap<String, String> = nodes
257 .iter()
258 .filter(|n| matches!(n.node_type, NodeType::Function | NodeType::Method))
259 .map(|n| {
260 let normalized = n
261 .label
262 .trim_end_matches("()")
263 .trim_start_matches('.')
264 .to_lowercase();
265 (normalized, n.id.clone())
266 })
267 .collect();
268
269 let mut seen_calls: HashSet<(String, String)> = HashSet::new();
270 for (caller_nid, body_start, body_end) in &function_bodies {
271 let body_text = &source[*body_start..*body_end];
272 let body_str = String::from_utf8_lossy(body_text);
273 for (func_label, callee_nid) in &label_to_nid {
274 if callee_nid == caller_nid {
275 continue;
276 }
277 if body_str.to_lowercase().contains(&format!("{func_label}(")) {
279 let key = (caller_nid.clone(), callee_nid.clone());
280 if seen_calls.insert(key) {
281 edges.push(GraphEdge {
282 source: caller_nid.clone(),
283 target: callee_nid.clone(),
284 relation: "calls".to_string(),
285 confidence: Confidence::Inferred,
286 confidence_score: Confidence::Inferred.default_score(),
287 source_file: str_path.to_string(),
288 source_location: None,
289 weight: 1.0,
290 extra: HashMap::new(),
291 });
292 }
293 }
294 }
295 }
296
297 trace!(
298 "treesitter({}): {} nodes, {} edges from {}",
299 lang,
300 nodes.len(),
301 edges.len(),
302 str_path
303 );
304
305 Some(ExtractionResult {
306 nodes,
307 edges,
308 hyperedges: vec![],
309 })
310}
311
312#[allow(clippy::too_many_arguments)]
317fn walk_node(
318 node: Node,
319 source: &[u8],
320 config: &TsConfig,
321 lang: &str,
322 file_nid: &str,
323 stem: &str,
324 str_path: &str,
325 nodes: &mut Vec<GraphNode>,
326 edges: &mut Vec<GraphEdge>,
327 seen_ids: &mut HashSet<String>,
328 function_bodies: &mut Vec<(String, usize, usize)>,
329 parent_class_nid: Option<&str>,
330) {
331 let kind = node.kind();
332
333 if config.import_types.contains(kind) {
335 extract_import(node, source, file_nid, str_path, lang, edges, nodes);
336 return; }
338
339 if config.class_types.contains(kind) {
341 handle_class_like(
342 node,
343 source,
344 config,
345 lang,
346 file_nid,
347 stem,
348 str_path,
349 nodes,
350 edges,
351 seen_ids,
352 function_bodies,
353 );
354 return;
355 }
356
357 if config.function_types.contains(kind) {
359 handle_function(
360 node,
361 source,
362 config,
363 lang,
364 file_nid,
365 stem,
366 str_path,
367 nodes,
368 edges,
369 seen_ids,
370 function_bodies,
371 parent_class_nid,
372 );
373 return;
374 }
375
376 let mut cursor = node.walk();
378 for child in node.children(&mut cursor) {
379 walk_node(
380 child,
381 source,
382 config,
383 lang,
384 file_nid,
385 stem,
386 str_path,
387 nodes,
388 edges,
389 seen_ids,
390 function_bodies,
391 parent_class_nid,
392 );
393 }
394}
395
396#[allow(clippy::too_many_arguments)]
401fn handle_class_like(
402 node: Node,
403 source: &[u8],
404 config: &TsConfig,
405 lang: &str,
406 file_nid: &str,
407 stem: &str,
408 str_path: &str,
409 nodes: &mut Vec<GraphNode>,
410 edges: &mut Vec<GraphEdge>,
411 seen_ids: &mut HashSet<String>,
412 function_bodies: &mut Vec<(String, usize, usize)>,
413) {
414 let kind = node.kind();
415
416 if lang == "go" && kind == "type_declaration" {
418 let mut cursor = node.walk();
419 for child in node.children(&mut cursor) {
420 if child.kind() == "type_spec" {
421 handle_go_type_spec(
422 child,
423 source,
424 config,
425 lang,
426 file_nid,
427 stem,
428 str_path,
429 nodes,
430 edges,
431 seen_ids,
432 function_bodies,
433 );
434 }
435 }
436 return;
437 }
438
439 if lang == "rust" && kind == "impl_item" {
441 handle_rust_impl(
442 node,
443 source,
444 config,
445 lang,
446 file_nid,
447 stem,
448 str_path,
449 nodes,
450 edges,
451 seen_ids,
452 function_bodies,
453 );
454 return;
455 }
456
457 let class_field = config.class_name_field.unwrap_or(config.name_field);
459 let name = match get_name(node, source, class_field) {
460 Some(n) => n,
461 None => return,
462 };
463 let line = node.start_position().row + 1;
464 let class_nid = make_id(&[str_path, &name]);
465
466 let node_type = classify_class_kind(kind, lang);
467
468 if seen_ids.insert(class_nid.clone()) {
469 nodes.push(GraphNode {
470 id: class_nid.clone(),
471 label: name.clone(),
472 source_file: str_path.to_string(),
473 source_location: Some(format!("L{line}")),
474 node_type,
475 community: None,
476 extra: HashMap::new(),
477 });
478 edges.push(make_edge(file_nid, &class_nid, "defines", str_path, line));
479 }
480
481 if let Some(body) = node.child_by_field_name(config.body_field) {
483 let mut cursor = body.walk();
484 for child in body.children(&mut cursor) {
485 walk_node(
486 child,
487 source,
488 config,
489 lang,
490 file_nid,
491 stem,
492 str_path,
493 nodes,
494 edges,
495 seen_ids,
496 function_bodies,
497 Some(&class_nid),
498 );
499 }
500 }
501}
502
503fn classify_class_kind(kind: &str, lang: &str) -> NodeType {
504 match (kind, lang) {
505 ("struct_item", "rust") => NodeType::Struct,
506 ("enum_item", "rust") => NodeType::Enum,
507 ("trait_item", "rust") => NodeType::Trait,
508 _ => NodeType::Class,
509 }
510}
511
512#[allow(clippy::too_many_arguments)]
513fn handle_go_type_spec(
514 node: Node,
515 source: &[u8],
516 config: &TsConfig,
517 lang: &str,
518 file_nid: &str,
519 stem: &str,
520 str_path: &str,
521 nodes: &mut Vec<GraphNode>,
522 edges: &mut Vec<GraphEdge>,
523 seen_ids: &mut HashSet<String>,
524 function_bodies: &mut Vec<(String, usize, usize)>,
525) {
526 let name = match get_name(node, source, "name") {
527 Some(n) => n,
528 None => return,
529 };
530 let line = node.start_position().row + 1;
531 let nid = make_id(&[str_path, &name]);
532
533 let node_type = {
535 let mut nt = NodeType::Struct;
536 let mut cursor = node.walk();
537 for child in node.children(&mut cursor) {
538 match child.kind() {
539 "interface_type" => {
540 nt = NodeType::Interface;
541 break;
542 }
543 "struct_type" => {
544 nt = NodeType::Struct;
545 break;
546 }
547 _ => {}
548 }
549 }
550 nt
551 };
552
553 if seen_ids.insert(nid.clone()) {
554 nodes.push(GraphNode {
555 id: nid.clone(),
556 label: name.clone(),
557 source_file: str_path.to_string(),
558 source_location: Some(format!("L{line}")),
559 node_type,
560 community: None,
561 extra: HashMap::new(),
562 });
563 edges.push(make_edge(file_nid, &nid, "defines", str_path, line));
564 }
565
566 if let Some(body) = node.child_by_field_name(config.body_field) {
569 let mut cursor = body.walk();
570 for child in body.children(&mut cursor) {
571 walk_node(
572 child,
573 source,
574 config,
575 lang,
576 file_nid,
577 stem,
578 str_path,
579 nodes,
580 edges,
581 seen_ids,
582 function_bodies,
583 Some(&nid),
584 );
585 }
586 }
587}
588
589#[allow(clippy::too_many_arguments)]
590fn handle_rust_impl(
591 node: Node,
592 source: &[u8],
593 config: &TsConfig,
594 lang: &str,
595 file_nid: &str,
596 stem: &str,
597 str_path: &str,
598 nodes: &mut Vec<GraphNode>,
599 edges: &mut Vec<GraphEdge>,
600 seen_ids: &mut HashSet<String>,
601 function_bodies: &mut Vec<(String, usize, usize)>,
602) {
603 let type_name = node
606 .child_by_field_name("type")
607 .map(|n| node_text(n, source));
608 let trait_name = node
609 .child_by_field_name("trait")
610 .map(|n| node_text(n, source));
611
612 let impl_target_nid = type_name.as_ref().map(|tn| make_id(&[str_path, tn]));
613
614 if let (Some(trait_n), Some(target_nid)) = (&trait_name, &impl_target_nid) {
616 let line = node.start_position().row + 1;
617 let trait_nid = make_id(&[str_path, trait_n]);
618 edges.push(GraphEdge {
619 source: target_nid.clone(),
620 target: trait_nid,
621 relation: "implements".to_string(),
622 confidence: Confidence::Extracted,
623 confidence_score: Confidence::Extracted.default_score(),
624 source_file: str_path.to_string(),
625 source_location: Some(format!("L{line}")),
626 weight: 1.0,
627 extra: HashMap::new(),
628 });
629 }
630
631 if let Some(body) = node.child_by_field_name(config.body_field) {
633 let class_nid = impl_target_nid.as_deref();
634 let mut cursor = body.walk();
635 for child in body.children(&mut cursor) {
636 walk_node(
637 child,
638 source,
639 config,
640 lang,
641 file_nid,
642 stem,
643 str_path,
644 nodes,
645 edges,
646 seen_ids,
647 function_bodies,
648 class_nid,
649 );
650 }
651 }
652}
653
654#[allow(clippy::too_many_arguments)]
659fn handle_function(
660 node: Node,
661 source: &[u8],
662 config: &TsConfig,
663 _lang: &str,
664 file_nid: &str,
665 _stem: &str,
666 str_path: &str,
667 nodes: &mut Vec<GraphNode>,
668 edges: &mut Vec<GraphEdge>,
669 seen_ids: &mut HashSet<String>,
670 function_bodies: &mut Vec<(String, usize, usize)>,
671 parent_class_nid: Option<&str>,
672) {
673 let func_name = match get_name(node, source, config.name_field) {
677 Some(n) => n,
678 None => {
679 if node.kind() == "arrow_function" {
681 if let Some(parent) = node.parent() {
682 if parent.kind() == "variable_declarator" {
683 match get_name(parent, source, "name") {
684 Some(n) => n,
685 None => return,
686 }
687 } else {
688 return;
689 }
690 } else {
691 return;
692 }
693 } else {
694 return;
695 }
696 }
697 };
698
699 let line = node.start_position().row + 1;
700
701 let (func_nid, label, node_type, relation) = if let Some(class_nid) = parent_class_nid {
702 let nid = make_id(&[class_nid, &func_name]);
703 (
704 nid,
705 format!(".{}()", func_name),
706 NodeType::Method,
707 "defines",
708 )
709 } else {
710 let nid = make_id(&[str_path, &func_name]);
711 (
712 nid,
713 format!("{}()", func_name),
714 NodeType::Function,
715 "defines",
716 )
717 };
718
719 if seen_ids.insert(func_nid.clone()) {
720 nodes.push(GraphNode {
721 id: func_nid.clone(),
722 label,
723 source_file: str_path.to_string(),
724 source_location: Some(format!("L{line}")),
725 node_type,
726 community: None,
727 extra: HashMap::new(),
728 });
729
730 let parent_nid = parent_class_nid.unwrap_or(file_nid);
731 edges.push(make_edge(parent_nid, &func_nid, relation, str_path, line));
732 }
733
734 if let Some(body) = node.child_by_field_name(config.body_field) {
736 function_bodies.push((func_nid, body.start_byte(), body.end_byte()));
737 } else {
738 function_bodies.push((func_nid, node.start_byte(), node.end_byte()));
740 }
741}
742
743fn extract_import(
748 node: Node,
749 source: &[u8],
750 file_nid: &str,
751 str_path: &str,
752 lang: &str,
753 edges: &mut Vec<GraphEdge>,
754 nodes: &mut Vec<GraphNode>,
755) {
756 let line = node.start_position().row + 1;
757 let import_text = node_text(node, source);
758
759 match lang {
760 "python" => extract_python_import(node, source, file_nid, str_path, line, edges, nodes),
761 "javascript" | "typescript" => {
762 extract_js_import(node, source, file_nid, str_path, line, edges, nodes)
763 }
764 "rust" => {
765 let module = import_text
767 .strip_prefix("use ")
768 .unwrap_or(&import_text)
769 .trim_end_matches(';')
770 .trim();
771 add_import_node(
772 nodes,
773 edges,
774 file_nid,
775 str_path,
776 line,
777 module,
778 NodeType::Module,
779 );
780 }
781 "go" => {
782 extract_go_import(node, source, file_nid, str_path, line, edges, nodes);
783 }
784 "java" => {
785 let text = node_text(node, source);
787 let module = text
788 .trim()
789 .strip_prefix("import ")
790 .unwrap_or(&text)
791 .strip_prefix("static ")
792 .unwrap_or_else(|| text.trim().strip_prefix("import ").unwrap_or(&text))
793 .trim_end_matches(';')
794 .trim();
795 add_import_node(
796 nodes,
797 edges,
798 file_nid,
799 str_path,
800 line,
801 module,
802 NodeType::Module,
803 );
804 }
805 "c" | "cpp" => {
806 let text = node_text(node, source);
808 let module = text
809 .trim()
810 .strip_prefix("#include")
811 .unwrap_or(&text)
812 .trim()
813 .trim_matches(&['<', '>', '"'][..])
814 .trim();
815 add_import_node(
816 nodes,
817 edges,
818 file_nid,
819 str_path,
820 line,
821 module,
822 NodeType::Module,
823 );
824 }
825 "csharp" => {
826 let text = node_text(node, source);
828 let module = text
829 .trim()
830 .strip_prefix("using ")
831 .unwrap_or(&text)
832 .trim_end_matches(';')
833 .trim();
834 add_import_node(
835 nodes,
836 edges,
837 file_nid,
838 str_path,
839 line,
840 module,
841 NodeType::Module,
842 );
843 }
844 _ => {
845 add_import_node(
846 nodes,
847 edges,
848 file_nid,
849 str_path,
850 line,
851 &import_text,
852 NodeType::Module,
853 );
854 }
855 }
856}
857
858fn extract_python_import(
859 node: Node,
860 source: &[u8],
861 file_nid: &str,
862 str_path: &str,
863 line: usize,
864 edges: &mut Vec<GraphEdge>,
865 nodes: &mut Vec<GraphNode>,
866) {
867 let kind = node.kind();
870
871 if kind == "import_from_statement" {
872 let module = node
873 .child_by_field_name("module_name")
874 .map(|n| node_text(n, source))
875 .unwrap_or_default();
876 let mut cursor = node.walk();
878 for child in node.children(&mut cursor) {
879 if child.kind() == "dotted_name" || child.kind() == "aliased_import" {
880 let name_node = if child.kind() == "aliased_import" {
881 child.child_by_field_name("name")
882 } else {
883 Some(child)
884 };
885 if let Some(nn) = name_node {
886 let name = node_text(nn, source);
887 if name != module {
888 let full = if module.is_empty() {
889 name
890 } else {
891 format!("{module}.{name}")
892 };
893 add_import_node(
894 nodes,
895 edges,
896 file_nid,
897 str_path,
898 line,
899 &full,
900 NodeType::Module,
901 );
902 }
903 }
904 }
905 }
906 let import_count = edges.iter().filter(|e| e.relation == "imports").count();
908 if import_count == 0 && !module.is_empty() {
909 add_import_node(
910 nodes,
911 edges,
912 file_nid,
913 str_path,
914 line,
915 &module,
916 NodeType::Module,
917 );
918 }
919 } else {
920 let mut cursor = node.walk();
922 for child in node.children(&mut cursor) {
923 if child.kind() == "dotted_name" || child.kind() == "aliased_import" {
924 let name_node = if child.kind() == "aliased_import" {
925 child.child_by_field_name("name")
926 } else {
927 Some(child)
928 };
929 if let Some(nn) = name_node {
930 let name = node_text(nn, source);
931 add_import_node(
932 nodes,
933 edges,
934 file_nid,
935 str_path,
936 line,
937 &name,
938 NodeType::Module,
939 );
940 }
941 }
942 }
943 }
944}
945
946fn extract_js_import(
947 node: Node,
948 source: &[u8],
949 file_nid: &str,
950 str_path: &str,
951 line: usize,
952 edges: &mut Vec<GraphEdge>,
953 nodes: &mut Vec<GraphNode>,
954) {
955 let module = node
958 .child_by_field_name("source")
959 .map(|n| {
960 let t = node_text(n, source);
961 t.trim_matches(&['"', '\''][..]).to_string()
962 })
963 .unwrap_or_default();
964
965 let mut found_names = false;
967 let mut cursor = node.walk();
968 for child in node.children(&mut cursor) {
969 if child.kind() == "import_clause" {
970 let mut inner_cursor = child.walk();
971 for inner in child.children(&mut inner_cursor) {
972 match inner.kind() {
973 "identifier" => {
974 let name = node_text(inner, source);
975 let full = format!("{module}/{name}");
976 add_import_node(
977 nodes,
978 edges,
979 file_nid,
980 str_path,
981 line,
982 &full,
983 NodeType::Module,
984 );
985 found_names = true;
986 }
987 "named_imports" => {
988 let mut spec_cursor = inner.walk();
989 for spec in inner.children(&mut spec_cursor) {
990 if spec.kind() == "import_specifier" {
991 let name = spec
992 .child_by_field_name("name")
993 .map(|n| node_text(n, source))
994 .unwrap_or_else(|| node_text(spec, source));
995 let full = format!("{module}/{name}");
996 add_import_node(
997 nodes,
998 edges,
999 file_nid,
1000 str_path,
1001 line,
1002 &full,
1003 NodeType::Module,
1004 );
1005 found_names = true;
1006 }
1007 }
1008 }
1009 _ => {}
1010 }
1011 }
1012 }
1013 }
1014
1015 if !found_names && !module.is_empty() {
1016 add_import_node(
1017 nodes,
1018 edges,
1019 file_nid,
1020 str_path,
1021 line,
1022 &module,
1023 NodeType::Module,
1024 );
1025 }
1026}
1027
1028fn extract_go_import(
1029 node: Node,
1030 source: &[u8],
1031 file_nid: &str,
1032 str_path: &str,
1033 line: usize,
1034 edges: &mut Vec<GraphEdge>,
1035 nodes: &mut Vec<GraphNode>,
1036) {
1037 let mut cursor = node.walk();
1039 for child in node.children(&mut cursor) {
1040 match child.kind() {
1041 "import_spec" => {
1042 if let Some(path_node) = child.child_by_field_name("path") {
1043 let module = node_text(path_node, source).trim_matches('"').to_string();
1044 let spec_line = child.start_position().row + 1;
1045 add_import_node(
1046 nodes,
1047 edges,
1048 file_nid,
1049 str_path,
1050 spec_line,
1051 &module,
1052 NodeType::Package,
1053 );
1054 }
1055 }
1056 "import_spec_list" => {
1057 let mut inner = child.walk();
1058 for spec in child.children(&mut inner) {
1059 if spec.kind() == "import_spec"
1060 && let Some(path_node) = spec.child_by_field_name("path")
1061 {
1062 let module = node_text(path_node, source).trim_matches('"').to_string();
1063 let spec_line = spec.start_position().row + 1;
1064 add_import_node(
1065 nodes,
1066 edges,
1067 file_nid,
1068 str_path,
1069 spec_line,
1070 &module,
1071 NodeType::Package,
1072 );
1073 }
1074 }
1075 }
1076 "interpreted_string_literal" => {
1077 let module = node_text(child, source).trim_matches('"').to_string();
1079 add_import_node(
1080 nodes,
1081 edges,
1082 file_nid,
1083 str_path,
1084 line,
1085 &module,
1086 NodeType::Package,
1087 );
1088 }
1089 _ => {}
1090 }
1091 }
1092}
1093
1094fn node_text(node: Node, source: &[u8]) -> String {
1100 node.utf8_text(source).unwrap_or("").to_string()
1101}
1102
1103fn get_name(node: Node, source: &[u8], field: &str) -> Option<String> {
1105 let name_node = node.child_by_field_name(field)?;
1106 let text = unwrap_declarator_name(name_node, source);
1108 if text.is_empty() { None } else { Some(text) }
1109}
1110
1111fn unwrap_declarator_name(node: Node, source: &[u8]) -> String {
1114 match node.kind() {
1115 "function_declarator"
1116 | "pointer_declarator"
1117 | "reference_declarator"
1118 | "parenthesized_declarator" => {
1119 if let Some(inner) = node.child_by_field_name("declarator") {
1121 return unwrap_declarator_name(inner, source);
1122 }
1123 let mut cursor = node.walk();
1125 for child in node.children(&mut cursor) {
1126 if child.kind() == "identifier" || child.kind() == "field_identifier" {
1127 return node_text(child, source);
1128 }
1129 }
1130 node_text(node, source)
1131 }
1132 "qualified_identifier" | "scoped_identifier" => {
1133 if let Some(name) = node.child_by_field_name("name") {
1135 return node_text(name, source);
1136 }
1137 node_text(node, source)
1138 }
1139 _ => node_text(node, source),
1140 }
1141}
1142
1143fn add_import_node(
1144 nodes: &mut Vec<GraphNode>,
1145 edges: &mut Vec<GraphEdge>,
1146 file_nid: &str,
1147 str_path: &str,
1148 line: usize,
1149 module: &str,
1150 node_type: NodeType,
1151) {
1152 let import_id = make_id(&[str_path, "import", module]);
1153 nodes.push(GraphNode {
1154 id: import_id.clone(),
1155 label: module.to_string(),
1156 source_file: str_path.to_string(),
1157 source_location: Some(format!("L{line}")),
1158 node_type,
1159 community: None,
1160 extra: HashMap::new(),
1161 });
1162 edges.push(GraphEdge {
1163 source: file_nid.to_string(),
1164 target: import_id,
1165 relation: "imports".to_string(),
1166 confidence: Confidence::Extracted,
1167 confidence_score: Confidence::Extracted.default_score(),
1168 source_file: str_path.to_string(),
1169 source_location: Some(format!("L{line}")),
1170 weight: 1.0,
1171 extra: HashMap::new(),
1172 });
1173}
1174
1175fn make_edge(
1176 source_id: &str,
1177 target_id: &str,
1178 relation: &str,
1179 source_file: &str,
1180 line: usize,
1181) -> GraphEdge {
1182 GraphEdge {
1183 source: source_id.to_string(),
1184 target: target_id.to_string(),
1185 relation: relation.to_string(),
1186 confidence: Confidence::Extracted,
1187 confidence_score: Confidence::Extracted.default_score(),
1188 source_file: source_file.to_string(),
1189 source_location: Some(format!("L{line}")),
1190 weight: 1.0,
1191 extra: HashMap::new(),
1192 }
1193}
1194
1195#[cfg(test)]
1200mod tests {
1201 use super::*;
1202 use std::path::Path;
1203
1204 #[test]
1207 fn ts_python_extracts_class_and_methods() {
1208 let source = br#"
1209class MyClass:
1210 def __init__(self):
1211 pass
1212
1213 def greet(self, name):
1214 return f"Hello {name}"
1215
1216def standalone():
1217 pass
1218"#;
1219 let result = try_extract(Path::new("test.py"), source, "python").unwrap();
1220
1221 let labels: Vec<&str> = result.nodes.iter().map(|n| n.label.as_str()).collect();
1222 assert!(
1223 labels.iter().any(|l| l.contains("MyClass")),
1224 "missing MyClass: {labels:?}"
1225 );
1226 assert!(
1227 labels.iter().any(|l| l.contains("__init__")),
1228 "missing __init__: {labels:?}"
1229 );
1230 assert!(
1231 labels.iter().any(|l| l.contains("greet")),
1232 "missing greet: {labels:?}"
1233 );
1234 assert!(
1235 labels.iter().any(|l| l.contains("standalone")),
1236 "missing standalone: {labels:?}"
1237 );
1238 assert!(result.nodes.iter().any(|n| n.node_type == NodeType::File));
1239 assert!(result.nodes.iter().any(|n| n.node_type == NodeType::Class));
1240 }
1241
1242 #[test]
1243 fn ts_python_extracts_imports() {
1244 let source = br#"
1245import os
1246from pathlib import Path
1247from collections import defaultdict, OrderedDict
1248"#;
1249 let result = try_extract(Path::new("test.py"), source, "python").unwrap();
1250 let import_edges: Vec<_> = result
1251 .edges
1252 .iter()
1253 .filter(|e| e.relation == "imports")
1254 .collect();
1255 assert!(
1256 import_edges.len() >= 2,
1257 "expected >= 2 import edges, got {}",
1258 import_edges.len()
1259 );
1260 }
1261
1262 #[test]
1263 fn ts_python_infers_calls() {
1264 let source = br#"
1265def foo():
1266 bar()
1267
1268def bar():
1269 pass
1270"#;
1271 let result = try_extract(Path::new("test.py"), source, "python").unwrap();
1272 let call_edges: Vec<_> = result
1273 .edges
1274 .iter()
1275 .filter(|e| e.relation == "calls")
1276 .collect();
1277 assert!(!call_edges.is_empty(), "expected call edges");
1278 }
1279
1280 #[test]
1283 fn ts_rust_extracts_structs_and_functions() {
1284 let source = br#"
1285use std::collections::HashMap;
1286
1287pub struct Config {
1288 name: String,
1289}
1290
1291pub enum Status {
1292 Active,
1293 Inactive,
1294}
1295
1296pub trait Runnable {
1297 fn run(&self);
1298}
1299
1300impl Runnable for Config {
1301 fn run(&self) {
1302 println!("{}", self.name);
1303 }
1304}
1305
1306pub fn main() {
1307 let c = Config { name: "test".into() };
1308 c.run();
1309}
1310"#;
1311 let result = try_extract(Path::new("lib.rs"), source, "rust").unwrap();
1312 let labels: Vec<&str> = result.nodes.iter().map(|n| n.label.as_str()).collect();
1313 assert!(
1314 labels.iter().any(|l| l.contains("Config")),
1315 "missing Config: {labels:?}"
1316 );
1317 assert!(
1318 labels.iter().any(|l| l.contains("Status")),
1319 "missing Status: {labels:?}"
1320 );
1321 assert!(
1322 labels.iter().any(|l| l.contains("Runnable")),
1323 "missing Runnable: {labels:?}"
1324 );
1325 assert!(
1326 labels.iter().any(|l| l.contains("main")),
1327 "missing main: {labels:?}"
1328 );
1329 assert!(result.nodes.iter().any(|n| n.node_type == NodeType::Struct));
1330 assert!(result.nodes.iter().any(|n| n.node_type == NodeType::Enum));
1331 assert!(result.nodes.iter().any(|n| n.node_type == NodeType::Trait));
1332 assert!(
1333 result.edges.iter().any(|e| e.relation == "implements"),
1334 "missing implements edge"
1335 );
1336 }
1337
1338 #[test]
1341 fn ts_js_extracts_functions_and_classes() {
1342 let source = br#"
1343import { useState } from 'react';
1344import axios from 'axios';
1345
1346export class ApiClient {
1347 constructor(baseUrl) {
1348 this.baseUrl = baseUrl;
1349 }
1350}
1351
1352export function fetchData(url) {
1353 return axios.get(url);
1354}
1355"#;
1356 let result = try_extract(Path::new("api.js"), source, "javascript").unwrap();
1357 let labels: Vec<&str> = result.nodes.iter().map(|n| n.label.as_str()).collect();
1358 assert!(
1359 labels.iter().any(|l| l.contains("ApiClient")),
1360 "missing ApiClient: {labels:?}"
1361 );
1362 assert!(
1363 labels.iter().any(|l| l.contains("fetchData")),
1364 "missing fetchData: {labels:?}"
1365 );
1366
1367 let import_count = result
1368 .edges
1369 .iter()
1370 .filter(|e| e.relation == "imports")
1371 .count();
1372 assert!(
1373 import_count >= 2,
1374 "expected >=2 imports, got {import_count}"
1375 );
1376 }
1377
1378 #[test]
1381 fn ts_go_extracts_types_and_functions() {
1382 let source = br#"
1383package main
1384
1385import (
1386 "fmt"
1387 "os"
1388)
1389
1390type Server struct {
1391 host string
1392 port int
1393}
1394
1395type Handler interface {
1396 Handle()
1397}
1398
1399func (s *Server) Start() {
1400 fmt.Println("starting")
1401}
1402
1403func main() {
1404 s := Server{host: "localhost", port: 8080}
1405 s.Start()
1406}
1407"#;
1408 let result = try_extract(Path::new("main.go"), source, "go").unwrap();
1409 let labels: Vec<&str> = result.nodes.iter().map(|n| n.label.as_str()).collect();
1410 assert!(
1411 labels.iter().any(|l| l.contains("Server")),
1412 "missing Server: {labels:?}"
1413 );
1414 assert!(
1415 labels.iter().any(|l| l.contains("Handler")),
1416 "missing Handler: {labels:?}"
1417 );
1418 assert!(
1419 labels.iter().any(|l| l.contains("Start")),
1420 "missing Start: {labels:?}"
1421 );
1422 assert!(
1423 labels.iter().any(|l| l.contains("main")),
1424 "missing main: {labels:?}"
1425 );
1426 assert!(result.nodes.iter().any(|n| n.node_type == NodeType::Struct));
1427 assert!(
1428 result
1429 .nodes
1430 .iter()
1431 .any(|n| n.node_type == NodeType::Interface)
1432 );
1433 }
1434
1435 #[test]
1438 fn ts_unsupported_returns_none() {
1439 assert!(try_extract(Path::new("test.pl"), b"sub foo { 1 }", "perl").is_none());
1440 }
1441
1442 #[test]
1445 fn ts_python_at_least_as_many_nodes_as_regex() {
1446 let source_str = r#"
1447class MyClass:
1448 def __init__(self):
1449 pass
1450
1451 def greet(self, name):
1452 return f"Hello {name}"
1453
1454def standalone():
1455 pass
1456"#;
1457 let regex_result =
1458 crate::ast_extract::extract_file(Path::new("test.py"), source_str, "python");
1459 let ts_result = try_extract(Path::new("test.py"), source_str.as_bytes(), "python").unwrap();
1460
1461 assert!(
1462 ts_result.nodes.len() >= regex_result.nodes.len(),
1463 "tree-sitter ({}) should produce >= nodes than regex ({})",
1464 ts_result.nodes.len(),
1465 regex_result.nodes.len()
1466 );
1467 }
1468
1469 #[test]
1470 fn all_edges_have_source_file() {
1471 let source = b"def foo():\n bar()\ndef bar():\n pass\n";
1472 let result = try_extract(Path::new("x.py"), source, "python").unwrap();
1473 for edge in &result.edges {
1474 assert!(!edge.source_file.is_empty());
1475 }
1476 }
1477
1478 #[test]
1479 fn node_ids_are_deterministic() {
1480 let source = b"def foo():\n pass\n";
1481 let r1 = try_extract(Path::new("test.py"), source, "python").unwrap();
1482 let r2 = try_extract(Path::new("test.py"), source, "python").unwrap();
1483 assert_eq!(r1.nodes.len(), r2.nodes.len());
1484 for (a, b) in r1.nodes.iter().zip(r2.nodes.iter()) {
1485 assert_eq!(a.id, b.id);
1486 }
1487 }
1488
1489 #[test]
1492 fn ts_java_extracts_class_and_methods() {
1493 let source = br#"
1494import java.util.List;
1495
1496public class Foo {
1497 public void bar() {}
1498 public int baz(String s) { return 0; }
1499}
1500"#;
1501 let result = try_extract(Path::new("Foo.java"), source, "java").unwrap();
1502 let labels: Vec<&str> = result.nodes.iter().map(|n| n.label.as_str()).collect();
1503 assert!(
1504 labels.iter().any(|l| l.contains("Foo")),
1505 "missing Foo: {labels:?}"
1506 );
1507 assert!(
1508 labels.iter().any(|l| l.contains("bar")),
1509 "missing bar: {labels:?}"
1510 );
1511 assert!(
1512 labels.iter().any(|l| l.contains("baz")),
1513 "missing baz: {labels:?}"
1514 );
1515 let import_count = result
1516 .edges
1517 .iter()
1518 .filter(|e| e.relation == "imports")
1519 .count();
1520 assert!(
1521 import_count >= 1,
1522 "expected >=1 imports, got {import_count}"
1523 );
1524 }
1525
1526 #[test]
1527 fn ts_java_extracts_interface() {
1528 let source = br#"
1529public interface Runnable {
1530 void run();
1531}
1532"#;
1533 let result = try_extract(Path::new("Runnable.java"), source, "java").unwrap();
1534 let labels: Vec<&str> = result.nodes.iter().map(|n| n.label.as_str()).collect();
1535 assert!(
1536 labels.iter().any(|l| l.contains("Runnable")),
1537 "missing Runnable: {labels:?}"
1538 );
1539 }
1540
1541 #[test]
1544 fn ts_c_extracts_functions() {
1545 let source = br#"
1546#include <stdio.h>
1547
1548int main(int argc, char **argv) {
1549 printf("hello\n");
1550 return 0;
1551}
1552
1553void helper(void) {}
1554"#;
1555 let result = try_extract(Path::new("main.c"), source, "c").unwrap();
1556 let labels: Vec<&str> = result.nodes.iter().map(|n| n.label.as_str()).collect();
1557 assert!(
1558 labels.iter().any(|l| l.contains("main")),
1559 "missing main: {labels:?}"
1560 );
1561 assert!(
1562 labels.iter().any(|l| l.contains("helper")),
1563 "missing helper: {labels:?}"
1564 );
1565 let import_count = result
1566 .edges
1567 .iter()
1568 .filter(|e| e.relation == "imports")
1569 .count();
1570 assert!(
1571 import_count >= 1,
1572 "expected >=1 imports, got {import_count}"
1573 );
1574 }
1575
1576 #[test]
1579 fn ts_cpp_extracts_class_and_functions() {
1580 let source = br#"
1581#include <iostream>
1582
1583class Greeter {
1584public:
1585 void greet() {
1586 std::cout << "hello" << std::endl;
1587 }
1588};
1589
1590int main() {
1591 Greeter g;
1592 g.greet();
1593 return 0;
1594}
1595"#;
1596 let result = try_extract(Path::new("main.cpp"), source, "cpp").unwrap();
1597 let labels: Vec<&str> = result.nodes.iter().map(|n| n.label.as_str()).collect();
1598 assert!(
1599 labels.iter().any(|l| l.contains("Greeter")),
1600 "missing Greeter: {labels:?}"
1601 );
1602 assert!(
1603 labels.iter().any(|l| l.contains("main")),
1604 "missing main: {labels:?}"
1605 );
1606 }
1607
1608 #[test]
1611 fn ts_ruby_extracts_class_and_methods() {
1612 let source = br#"
1613class Dog
1614 def initialize(name)
1615 @name = name
1616 end
1617
1618 def bark
1619 puts "Woof!"
1620 end
1621end
1622"#;
1623 let result = try_extract(Path::new("dog.rb"), source, "ruby").unwrap();
1624 let labels: Vec<&str> = result.nodes.iter().map(|n| n.label.as_str()).collect();
1625 assert!(
1626 labels.iter().any(|l| l.contains("Dog")),
1627 "missing Dog: {labels:?}"
1628 );
1629 assert!(
1630 labels.iter().any(|l| l.contains("initialize")),
1631 "missing initialize: {labels:?}"
1632 );
1633 assert!(
1634 labels.iter().any(|l| l.contains("bark")),
1635 "missing bark: {labels:?}"
1636 );
1637 }
1638
1639 #[test]
1642 fn ts_csharp_extracts_class_and_methods() {
1643 let source = br#"
1644using System;
1645using System.Collections.Generic;
1646
1647public class Calculator {
1648 public int Add(int a, int b) {
1649 return a + b;
1650 }
1651
1652 public int Subtract(int a, int b) {
1653 return a - b;
1654 }
1655}
1656"#;
1657 let result = try_extract(Path::new("Calculator.cs"), source, "csharp").unwrap();
1658 let labels: Vec<&str> = result.nodes.iter().map(|n| n.label.as_str()).collect();
1659 assert!(
1660 labels.iter().any(|l| l.contains("Calculator")),
1661 "missing Calculator: {labels:?}"
1662 );
1663 assert!(
1664 labels.iter().any(|l| l.contains("Add")),
1665 "missing Add: {labels:?}"
1666 );
1667 assert!(
1668 labels.iter().any(|l| l.contains("Subtract")),
1669 "missing Subtract: {labels:?}"
1670 );
1671 let import_count = result
1672 .edges
1673 .iter()
1674 .filter(|e| e.relation == "imports")
1675 .count();
1676 assert!(
1677 import_count >= 2,
1678 "expected >=2 imports, got {import_count}"
1679 );
1680 }
1681}