1use std::collections::{HashMap, HashSet};
2use std::ops::Range;
3use tree_sitter::{Language, Node};
4
/// A bag of free-form string metadata.
///
/// Nothing else in this file reads or writes a `Bone`; the meaning of the
/// keys and values is defined by whoever constructs it.
#[derive(Debug, Clone, Default)]
pub struct Bone {
    /// Arbitrary key/value pairs; empty for `Bone::default()`.
    pub metadata: HashMap<String, String>,
}
9
/// Field-less placeholder type.
///
/// No behaviour is attached to it in this file; the actual parsing is done by
/// `tree_sitter::Parser` inside `parse_file`.
pub struct Parser {}
13
/// Coarse classification of an extracted symbol.
///
/// The enum carries no data, so it is `Copy` and `Hash` in addition to the
/// comparison traits: it can be passed by value and used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SymbolKind {
    /// A free-standing function.
    Function,
    /// A function declared as a member of a type.
    Method,
    /// A class, or a class-like grouping the language specs map to it
    /// (enums, namespaces, modules, traits).
    Class,
    /// A struct or record type.
    Struct,
    /// An implementation or extension block attached to an existing type.
    Impl,
    /// An interface or protocol.
    Interface,
}
25
26#[derive(Debug, Clone)]
28pub struct Symbol {
29 pub name: String,
31 pub qualified_name: String,
33 pub kind: SymbolKind,
35 pub full_range: Range<usize>,
37 pub body_range: Option<Range<usize>>,
39}
40
41#[derive(Debug, Clone)]
43pub struct ParsedDocument {
44 pub file_path: String,
45 pub symbols: Vec<Symbol>,
46}
47
48pub struct LanguageSpec {
50 pub language: Language,
52
53 pub symbol_node_types: HashMap<&'static str, SymbolKind>,
55
56 pub name_fields: HashMap<&'static str, &'static str>,
58
59 pub container_node_types: HashSet<&'static str>,
61
62 pub body_node_types: HashSet<&'static str>,
64}
65
66pub fn get_python_spec() -> LanguageSpec {
67 LanguageSpec {
68 language: tree_sitter_python::LANGUAGE.into(),
69 symbol_node_types: HashMap::from([
70 ("function_definition", SymbolKind::Function),
71 ("class_definition", SymbolKind::Class),
72 ]),
73 name_fields: HashMap::from([
74 ("function_definition", "name"),
75 ("class_definition", "name"),
76 ]),
77 container_node_types: HashSet::from(["class_definition"]),
78 body_node_types: HashSet::from(["block"]),
79 }
80}
81
82pub fn get_rust_spec() -> LanguageSpec {
83 LanguageSpec {
84 language: tree_sitter_rust::LANGUAGE.into(),
85 symbol_node_types: HashMap::from([
86 ("function_item", SymbolKind::Function),
87 ("struct_item", SymbolKind::Struct),
88 ("impl_item", SymbolKind::Impl),
89 ]),
90 name_fields: HashMap::from([
91 ("function_item", "name"),
92 ("struct_item", "name"),
93 ("impl_item", "type"),
94 ]),
95 container_node_types: HashSet::from(["impl_item"]),
96 body_node_types: HashSet::from(["block", "declaration_list"]),
97 }
98}
99
100pub fn get_java_spec() -> LanguageSpec {
101 LanguageSpec {
102 language: tree_sitter_java::LANGUAGE.into(),
103 symbol_node_types: std::collections::HashMap::from([
104 ("method_declaration", SymbolKind::Method),
105 ("class_declaration", SymbolKind::Class),
106 ("interface_declaration", SymbolKind::Interface),
107 ("enum_declaration", SymbolKind::Class),
108 ]),
109 name_fields: std::collections::HashMap::from([
110 ("method_declaration", "name"),
111 ("class_declaration", "name"),
112 ("interface_declaration", "name"),
113 ("enum_declaration", "name"),
114 ]),
115 container_node_types: std::collections::HashSet::from([
116 "class_declaration",
117 "interface_declaration",
118 ]),
119 body_node_types: std::collections::HashSet::from([
120 "block",
121 "class_body",
122 "interface_body",
123 "enum_body",
124 ]),
125 }
126}
127
128pub fn get_c_spec() -> LanguageSpec {
129 LanguageSpec {
130 language: tree_sitter_c::LANGUAGE.into(),
131 symbol_node_types: std::collections::HashMap::from([
132 ("function_definition", SymbolKind::Function),
133 ("struct_specifier", SymbolKind::Struct),
134 ("class_specifier", SymbolKind::Class),
135 ("namespace_definition", SymbolKind::Class),
136 ]),
137 name_fields: std::collections::HashMap::from([
138 ("function_definition", "declarator"),
139 ("struct_specifier", "name"),
140 ("class_specifier", "name"),
141 ("namespace_definition", "name"),
142 ]),
143 container_node_types: std::collections::HashSet::from([
144 "class_specifier",
145 "struct_specifier",
146 "namespace_definition",
147 ]),
148 body_node_types: std::collections::HashSet::from([
149 "compound_statement",
150 "field_declaration_list",
151 ]),
152 }
153}
154
155pub fn get_cpp_spec() -> LanguageSpec {
156 LanguageSpec {
157 language: tree_sitter_cpp::LANGUAGE.into(),
158 symbol_node_types: std::collections::HashMap::from([
159 ("function_definition", SymbolKind::Function),
160 ("struct_specifier", SymbolKind::Struct),
161 ("class_specifier", SymbolKind::Class),
162 ("namespace_definition", SymbolKind::Class),
163 ]),
164 name_fields: std::collections::HashMap::from([
165 ("function_definition", "declarator"),
166 ("struct_specifier", "name"),
167 ("class_specifier", "name"),
168 ("namespace_definition", "name"),
169 ]),
170 container_node_types: std::collections::HashSet::from([
171 "class_specifier",
172 "struct_specifier",
173 "namespace_definition",
174 ]),
175 body_node_types: std::collections::HashSet::from([
176 "compound_statement",
177 "field_declaration_list",
178 ]),
179 }
180}
181
182pub fn get_csharp_spec() -> LanguageSpec {
183 LanguageSpec {
184 language: tree_sitter_c_sharp::LANGUAGE.into(),
185 symbol_node_types: std::collections::HashMap::from([
186 ("method_declaration", SymbolKind::Method),
187 ("class_declaration", SymbolKind::Class),
188 ("interface_declaration", SymbolKind::Interface),
189 ("struct_declaration", SymbolKind::Struct),
190 ("namespace_declaration", SymbolKind::Class),
191 ]),
192 name_fields: std::collections::HashMap::from([
193 ("method_declaration", "name"),
194 ("class_declaration", "name"),
195 ("interface_declaration", "name"),
196 ("struct_declaration", "name"),
197 ("namespace_declaration", "name"),
198 ]),
199 container_node_types: std::collections::HashSet::from([
200 "class_declaration",
201 "interface_declaration",
202 "namespace_declaration",
203 "struct_declaration",
204 ]),
205 body_node_types: std::collections::HashSet::from(["block", "declaration_list"]),
206 }
207}
208
209pub fn get_ruby_spec() -> LanguageSpec {
210 LanguageSpec {
211 language: tree_sitter_ruby::LANGUAGE.into(),
212 symbol_node_types: std::collections::HashMap::from([
213 ("method", SymbolKind::Method),
214 ("singleton_method", SymbolKind::Method),
215 ("class", SymbolKind::Class),
216 ("module", SymbolKind::Class),
217 ]),
218 name_fields: std::collections::HashMap::from([
219 ("method", "name"),
220 ("singleton_method", "name"),
221 ("class", "name"),
222 ("module", "name"),
223 ]),
224 container_node_types: std::collections::HashSet::from(["class", "module"]),
225 body_node_types: std::collections::HashSet::from(["body", "do_block", "begin_block"]),
226 }
227}
228
229pub fn get_php_spec() -> LanguageSpec {
230 LanguageSpec {
231 language: tree_sitter_php::LANGUAGE_PHP.into(),
232 symbol_node_types: std::collections::HashMap::from([
233 ("function_definition", SymbolKind::Function),
234 ("method_declaration", SymbolKind::Method),
235 ("class_declaration", SymbolKind::Class),
236 ("interface_declaration", SymbolKind::Interface),
237 ("trait_declaration", SymbolKind::Class),
238 ]),
239 name_fields: std::collections::HashMap::from([
240 ("function_definition", "name"),
241 ("method_declaration", "name"),
242 ("class_declaration", "name"),
243 ("interface_declaration", "name"),
244 ("trait_declaration", "name"),
245 ]),
246 container_node_types: std::collections::HashSet::from([
247 "class_declaration",
248 "interface_declaration",
249 "trait_declaration",
250 ]),
251 body_node_types: std::collections::HashSet::from([
252 "compound_statement",
253 "declaration_list",
254 ]),
255 }
256}
257
258pub fn get_swift_spec() -> LanguageSpec {
259 LanguageSpec {
260 language: tree_sitter_swift::LANGUAGE.into(),
261 symbol_node_types: std::collections::HashMap::from([
262 ("function_declaration", SymbolKind::Function),
263 ("class_declaration", SymbolKind::Class),
264 ("struct_declaration", SymbolKind::Struct),
265 ("protocol_declaration", SymbolKind::Interface),
266 ("extension_declaration", SymbolKind::Impl),
267 ]),
268 name_fields: std::collections::HashMap::from([
269 ("function_declaration", "name"),
270 ("class_declaration", "name"),
271 ("struct_declaration", "name"),
272 ("protocol_declaration", "name"),
273 ("extension_declaration", "type"),
274 ]),
275 container_node_types: std::collections::HashSet::from([
276 "class_declaration",
277 "struct_declaration",
278 "protocol_declaration",
279 "extension_declaration",
280 ]),
281 body_node_types: std::collections::HashSet::from([
282 "class_body",
283 "function_body",
284 "code_block",
285 ]),
286 }
287}
288
289pub fn get_spec_for_extension(ext: &str) -> Option<LanguageSpec> {
290 match ext {
291 "rs" => Some(get_rust_spec()),
292 "py" => Some(get_python_spec()),
293 "go" => Some(get_go_spec()),
294 "ts" | "tsx" | "js" | "jsx" => Some(get_typescript_spec()),
295 "java" => Some(get_java_spec()),
296 "c" | "h" => Some(get_c_spec()),
297 "cpp" | "hpp" | "cc" | "cxx" => Some(get_cpp_spec()),
298 "cs" => Some(get_csharp_spec()),
299 "rb" => Some(get_ruby_spec()),
300 "php" => Some(get_php_spec()),
301 "swift" => Some(get_swift_spec()),
302 _ => None,
303 }
304}
305
306pub fn get_go_spec() -> LanguageSpec {
307 LanguageSpec {
308 language: tree_sitter_go::LANGUAGE.into(),
309 symbol_node_types: std::collections::HashMap::from([
310 ("function_declaration", SymbolKind::Function),
311 ("method_declaration", SymbolKind::Method),
312 ("type_declaration", SymbolKind::Struct),
313 ]),
314 name_fields: std::collections::HashMap::from([
315 ("function_declaration", "name"),
316 ("method_declaration", "name"),
317 ("type_declaration", "name"),
318 ]),
319 container_node_types: std::collections::HashSet::from(["type_declaration"]),
320 body_node_types: std::collections::HashSet::from(["block", "type_spec"]),
321 }
322}
323
324pub fn get_typescript_spec() -> LanguageSpec {
325 LanguageSpec {
326 language: tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
327 symbol_node_types: std::collections::HashMap::from([
328 ("function_declaration", SymbolKind::Function),
329 ("method_definition", SymbolKind::Method),
330 ("class_declaration", SymbolKind::Class),
331 ("interface_declaration", SymbolKind::Interface),
332 ]),
333 name_fields: std::collections::HashMap::from([
334 ("function_declaration", "name"),
335 ("method_definition", "name"),
336 ("class_declaration", "name"),
337 ("interface_declaration", "name"),
338 ]),
339 container_node_types: std::collections::HashSet::from([
340 "class_declaration",
341 "interface_declaration",
342 ]),
343 body_node_types: std::collections::HashSet::from([
344 "statement_block",
345 "class_body",
346 "object_type",
347 ]),
348 }
349}
350
351pub fn parse_file(source: &str, spec: &LanguageSpec) -> ParsedDocument {
352 let mut parser = tree_sitter::Parser::new();
353 parser
354 .set_language(&spec.language)
355 .expect("Error loading language");
356 let tree = parser.parse(source, None).expect("Error parsing source");
357 let root_node = tree.root_node();
358
359 let mut symbols = Vec::new();
360 walk_tree(root_node, source.as_bytes(), spec, None, &mut symbols);
361
362 ParsedDocument {
363 file_path: String::new(),
364 symbols,
365 }
366}
367
368fn walk_tree(
369 node: Node,
370 source: &[u8],
371 spec: &LanguageSpec,
372 parent_symbol: Option<&Symbol>,
373 symbols: &mut Vec<Symbol>,
374) {
375 let kind = node.kind();
376 let mut current_symbol = None;
377
378 if let Some(symbol_kind) = spec.symbol_node_types.get(kind) {
379 let mut name = None;
380
381 if let Some(name_field) = spec.name_fields.get(kind) {
382 if let Some(mut child) = node.child_by_field_name(name_field) {
383 while child.kind() == "function_declarator"
384 || child.kind() == "pointer_declarator"
385 || child.kind() == "reference_declarator"
386 {
387 if let Some(inner) = child.child_by_field_name("declarator") {
388 child = inner;
389 } else {
390 break;
391 }
392 }
393 if let Ok(text) = std::str::from_utf8(&source[child.start_byte()..child.end_byte()])
394 {
395 name = Some(text.to_string());
396 }
397 }
398 }
399
400 if name.is_none() {
401 let mut cursor = node.walk();
402 for child in node.children(&mut cursor) {
403 let child_kind = child.kind();
404 if child_kind == "identifier" || child_kind == "type_identifier" {
405 if let Ok(text) =
406 std::str::from_utf8(&source[child.start_byte()..child.end_byte()])
407 {
408 name = Some(text.to_string());
409 break;
410 }
411 }
412 }
413 }
414
415 if let Some(name) = name {
416 let qualified_name = if let Some(parent) = parent_symbol {
417 format!("{}.{}", parent.qualified_name, name)
418 } else {
419 name.clone()
420 };
421
422 let mut body_range = None;
423 let mut body_node_opt = node.child_by_field_name("body");
424
425 if body_node_opt.is_none() {
426 let mut cursor = node.walk();
427 for child in node.children(&mut cursor) {
428 if spec.body_node_types.contains(child.kind()) {
429 body_node_opt = Some(child);
430 break;
431 }
432 }
433 }
434
435 if let Some(body_node) = body_node_opt {
436 let mut start = body_node.start_byte();
437 if let Some(prev) = body_node.prev_sibling() {
438 if prev.kind() == ":" {
439 start = prev.end_byte();
440 } else {
441 let mut has_newline = false;
442 for i in prev.end_byte()..start {
443 if i < source.len() && (source[i] == b'\n' || source[i] == b'\r') {
444 has_newline = true;
445 break;
446 }
447 }
448 if has_newline {
449 start = prev.end_byte();
450 }
451 }
452 }
453 body_range = Some(start..body_node.end_byte());
454 }
455
456 let symbol = Symbol {
457 name,
458 qualified_name,
459 kind: symbol_kind.clone(),
460 full_range: node.start_byte()..node.end_byte(),
461 body_range,
462 };
463
464 symbols.push(symbol.clone());
465 current_symbol = Some(symbol);
466 }
467 }
468
469 let next_parent = current_symbol.as_ref().or(parent_symbol);
470
471 let mut cursor = node.walk();
472 for child in node.children(&mut cursor) {
473 walk_tree(child, source, spec, next_parent, symbols);
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    /// Rebuilds `source` with the body of every outermost symbol replaced by `...`.
    ///
    /// Symbols are visited in order of where they start. A body that begins
    /// inside an already-elided region (a nested symbol) is skipped.
    fn elide_document(source: &str, doc: &ParsedDocument) -> String {
        // Sort references rather than cloning every symbol.
        let mut ordered: Vec<&Symbol> = doc.symbols.iter().collect();
        ordered.sort_by_key(|sym| sym.full_range.start);

        let mut elided = String::new();
        let mut copied_up_to = 0;
        for body in ordered.iter().filter_map(|sym| sym.body_range.as_ref()) {
            if body.start < copied_up_to {
                continue;
            }
            elided.push_str(&source[copied_up_to..body.start]);
            elided.push_str("...");
            copied_up_to = body.end;
        }
        elided.push_str(&source[copied_up_to..]);
        elided
    }

    /// Returns the first symbol called `name`; panics if there is none.
    fn symbol_named<'a>(doc: &'a ParsedDocument, name: &str) -> &'a Symbol {
        doc.symbols.iter().find(|s| s.name == name).unwrap()
    }

    #[test]
    fn test_extract_python_class_signature_and_elide_body() {
        // Indentation is significant in Python: the method must be indented
        // inside the class, and its body inside the method.
        let source = "class MyClass:\n    def __init__(self):\n        pass";
        let doc = parse_file(source, &get_python_spec());

        assert_eq!(doc.symbols.len(), 2);
        assert_eq!(symbol_named(&doc, "MyClass").kind, SymbolKind::Class);

        let elided = elide_document(source, &doc);
        assert_eq!(elided, "class MyClass:...");
    }

    #[test]
    fn test_extract_python_function_signature_and_elide_body() {
        let source = "def calculate_total(a: int, b: int) -> int:\n return a + b";
        let doc = parse_file(source, &get_python_spec());

        assert_eq!(doc.symbols.len(), 1);
        assert_eq!(doc.symbols[0].name, "calculate_total");

        let elided = elide_document(source, &doc);
        assert_eq!(elided, "def calculate_total(a: int, b: int) -> int:...");
    }

    #[test]
    fn test_extract_rust_struct_signature_and_elide_body() {
        let source = "pub struct User {\n pub id: i32,\n pub name: String,\n}";
        let doc = parse_file(source, &get_rust_spec());

        assert_eq!(doc.symbols.len(), 1);
        assert_eq!(doc.symbols[0].name, "User");

        let elided = elide_document(source, &doc);
        assert_eq!(elided, "pub struct User ...");
    }

    #[test]
    fn test_extract_rust_function_signature_and_elide_body() {
        let source = "pub fn process_data(data: &[u8]) -> Result<(), Error> {\n // do work\n Ok(())\n}";
        let doc = parse_file(source, &get_rust_spec());

        assert_eq!(doc.symbols.len(), 1);
        assert_eq!(doc.symbols[0].name, "process_data");

        let elided = elide_document(source, &doc);
        assert_eq!(
            elided,
            "pub fn process_data(data: &[u8]) -> Result<(), Error> ..."
        );
    }

    #[test]
    fn test_handle_nested_functions_classes() {
        // Each level is indented one step further so the definitions are
        // genuinely nested; the qualified names below depend on that.
        let source =
            "class MyClass:\n    def my_method(self):\n        def nested():\n            pass";
        let doc = parse_file(source, &get_python_spec());

        assert_eq!(doc.symbols.len(), 3);
        assert_eq!(
            symbol_named(&doc, "my_method").qualified_name,
            "MyClass.my_method"
        );
        assert_eq!(
            symbol_named(&doc, "nested").qualified_name,
            "MyClass.my_method.nested"
        );
    }

    #[test]
    fn test_use_fallback_name_extraction() {
        let source = "def calculate_total(a: int, b: int) -> int:\n return a + b";
        // With the name field removed, the name has to come from the
        // identifier-child fallback.
        let mut spec = get_python_spec();
        spec.name_fields.remove("function_definition");

        let doc = parse_file(source, &spec);

        assert_eq!(doc.symbols.len(), 1);
        assert_eq!(doc.symbols[0].name, "calculate_total");
    }

    #[test]
    fn test_ignore_empty_files_or_no_symbols() {
        let source = "# just a comment\n\n";
        let doc = parse_file(source, &get_python_spec());

        assert!(doc.symbols.is_empty());
    }

    #[test]
    fn test_extract_java_class_and_method_elide_body() {
        let source = "public class MyClass {\n public void doWork() {\n System.out.println(\"work\");\n }\n}";
        let doc = parse_file(source, &get_java_spec());

        assert_eq!(doc.symbols.len(), 2);

        let class_sym = symbol_named(&doc, "MyClass");
        assert_eq!(class_sym.kind, SymbolKind::Class);
        assert!(class_sym.body_range.is_some());

        let method_sym = symbol_named(&doc, "doWork");
        assert_eq!(method_sym.kind, SymbolKind::Method);
        assert_eq!(method_sym.qualified_name, "MyClass.doWork");
        assert!(method_sym.body_range.is_some());

        let elided = elide_document(source, &doc);
        assert!(elided.starts_with("public class MyClass ..."));
    }

    #[test]
    fn test_extract_c_cpp_function_elide_body() {
        let source = "int calculate(int a, int b) {\n return a + b;\n}";
        let doc = parse_file(source, &get_c_spec());

        assert_eq!(doc.symbols.len(), 1);

        let func_sym = symbol_named(&doc, "calculate");
        assert_eq!(func_sym.kind, SymbolKind::Function);
        assert!(func_sym.body_range.is_some());

        let elided = elide_document(source, &doc);
        assert!(elided.starts_with("int calculate(int a, int b) ..."));
    }

    #[test]
    fn test_extract_csharp_class_and_method_elide_body() {
        let source = "public class Server {\n public async Task StartAsync() {\n await Task.Delay(10);\n }\n}";
        let doc = parse_file(source, &get_csharp_spec());

        assert_eq!(doc.symbols.len(), 2);
        assert_eq!(symbol_named(&doc, "Server").kind, SymbolKind::Class);

        let method_sym = symbol_named(&doc, "StartAsync");
        assert_eq!(method_sym.kind, SymbolKind::Method);
        assert_eq!(method_sym.qualified_name, "Server.StartAsync");

        let elided = elide_document(source, &doc);
        assert!(elided.starts_with("public class Server ..."));
    }

    #[test]
    fn test_extract_ruby_class_and_method_elide_body() {
        let source = "class User\n def login(email)\n puts 'login'\n end\nend";
        let doc = parse_file(source, &get_ruby_spec());

        assert_eq!(doc.symbols.len(), 2);
        assert_eq!(symbol_named(&doc, "User").kind, SymbolKind::Class);

        let method_sym = symbol_named(&doc, "login");
        assert_eq!(method_sym.kind, SymbolKind::Method);
        assert_eq!(method_sym.qualified_name, "User.login");

        let elided = elide_document(source, &doc);
        assert!(elided.starts_with("class User..."));
    }

    #[test]
    fn test_extract_php_class_and_method_elide_body() {
        let source = "<?php\nclass Controller {\n public function handle($req) {\n return true;\n }\n}";
        let doc = parse_file(source, &get_php_spec());

        assert_eq!(doc.symbols.len(), 2);
        assert_eq!(symbol_named(&doc, "Controller").kind, SymbolKind::Class);

        let method_sym = symbol_named(&doc, "handle");
        assert_eq!(method_sym.kind, SymbolKind::Method);
        assert_eq!(method_sym.qualified_name, "Controller.handle");

        let elided = elide_document(source, &doc);
        assert!(elided.contains("class Controller ..."));
    }

    #[test]
    fn test_extract_swift_class_and_function_elide_body() {
        let source =
            "class ViewModel {\n func loadData(with id: String) {\n print(id)\n }\n}";
        let doc = parse_file(source, &get_swift_spec());

        assert_eq!(doc.symbols.len(), 2);
        assert_eq!(symbol_named(&doc, "ViewModel").kind, SymbolKind::Class);

        let method_sym = symbol_named(&doc, "loadData");
        assert_eq!(method_sym.kind, SymbolKind::Function);
        assert_eq!(method_sym.qualified_name, "ViewModel.loadData");

        let elided = elide_document(source, &doc);
        assert!(elided.starts_with("class ViewModel ..."));
    }
}