1use std::collections::{HashMap, HashSet};
2use std::ops::Range;
3use tree_sitter::{Language, Node};
4
/// A bag of free-form string metadata.
///
/// Nothing else in this file reads or writes a `Bone`.
// NOTE(review): intended consumer is not visible from here — confirm it is still used.
#[derive(Debug, Clone, Default)]
pub struct Bone {
    /// Arbitrary key/value pairs; `Default` yields an empty map.
    pub metadata: HashMap<String, String>,
}
9
/// Empty marker type with no fields.
///
/// This is distinct from `tree_sitter::Parser`, which `parse_file` refers to by its
/// fully-qualified path.
// NOTE(review): no code in this file constructs or uses this type — confirm its purpose.
pub struct Parser {}
11
/// The category assigned to an extracted [`Symbol`].
///
/// Each language spec maps its grammar's node kinds onto one of these variants.
/// The enum is fieldless, so it is `Copy` and `Hash`: it can be passed by value and
/// used as a map or set key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SymbolKind {
    /// A free-standing function.
    Function,
    /// A function declared as a member of a type.
    Method,
    /// A class, or a class-like grouping construct.
    Class,
    /// A struct / record type.
    Struct,
    /// An implementation or extension block attached to a type.
    Impl,
    /// An interface / protocol declaration.
    Interface,
}
23
/// One named definition found while walking a syntax tree.
#[derive(Debug, Clone)]
pub struct Symbol {
    /// Source text of the node that names the symbol.
    pub name: String,
    /// `name` prefixed with the enclosing symbol's qualified name and a `.`;
    /// equal to `name` when there is no enclosing symbol.
    pub qualified_name: String,
    /// Category taken from the language spec's `symbol_node_types` table.
    pub kind: SymbolKind,
    /// Byte range (`start_byte..end_byte`) of the whole definition node.
    pub full_range: Range<usize>,
    /// Byte range of the body, when a body node was found. The start may be pulled
    /// back to the end of the preceding sibling (see `walk_tree`).
    pub body_range: Option<Range<usize>>,
}
38
/// The result of parsing one source text.
#[derive(Debug, Clone)]
pub struct ParsedDocument {
    /// Path of the parsed file. `parse_file` leaves this as an empty string.
    pub file_path: String,
    /// Extracted symbols, in the order the tree walk encountered them
    /// (a parent is pushed before its nested symbols).
    pub symbols: Vec<Symbol>,
}
45
/// Per-language configuration that drives symbol extraction.
///
/// All string keys are tree-sitter node kind names for the grammar in `language`.
pub struct LanguageSpec {
    /// The tree-sitter grammar handed to the parser.
    pub language: Language,

    /// Node kinds that produce a symbol, mapped to the kind to record for them.
    pub symbol_node_types: HashMap<&'static str, SymbolKind>,

    /// For each symbol node kind, the field name whose child holds the symbol's name.
    /// When the lookup fails, the walker falls back to the first direct
    /// `identifier` / `type_identifier` child.
    pub name_fields: HashMap<&'static str, &'static str>,

    /// Node kinds regarded as containers of other symbols.
    // NOTE(review): not read by any code in this file — confirm who consumes it.
    pub container_node_types: HashSet<&'static str>,

    /// Node kinds accepted as a symbol's body when the node has no `body` field.
    pub body_node_types: HashSet<&'static str>,
}
63
64pub fn get_python_spec() -> LanguageSpec {
65 LanguageSpec {
66 language: tree_sitter_python::LANGUAGE.into(),
67 symbol_node_types: HashMap::from([
68 ("function_definition", SymbolKind::Function),
69 ("class_definition", SymbolKind::Class),
70 ]),
71 name_fields: HashMap::from([
72 ("function_definition", "name"),
73 ("class_definition", "name"),
74 ]),
75 container_node_types: HashSet::from(["class_definition"]),
76 body_node_types: HashSet::from(["block"]),
77 }
78}
79
80pub fn get_rust_spec() -> LanguageSpec {
81 LanguageSpec {
82 language: tree_sitter_rust::LANGUAGE.into(),
83 symbol_node_types: HashMap::from([
84 ("function_item", SymbolKind::Function),
85 ("struct_item", SymbolKind::Struct),
86 ("impl_item", SymbolKind::Impl),
87 ]),
88 name_fields: HashMap::from([
89 ("function_item", "name"),
90 ("struct_item", "name"),
91 ("impl_item", "type"),
92 ]),
93 container_node_types: HashSet::from(["impl_item"]),
94 body_node_types: HashSet::from(["block", "declaration_list"]),
95 }
96}
97
98pub fn get_java_spec() -> LanguageSpec {
99 LanguageSpec {
100 language: tree_sitter_java::LANGUAGE.into(),
101 symbol_node_types: std::collections::HashMap::from([
102 ("method_declaration", SymbolKind::Method),
103 ("class_declaration", SymbolKind::Class),
104 ("interface_declaration", SymbolKind::Interface),
105 ("enum_declaration", SymbolKind::Class),
106 ]),
107 name_fields: std::collections::HashMap::from([
108 ("method_declaration", "name"),
109 ("class_declaration", "name"),
110 ("interface_declaration", "name"),
111 ("enum_declaration", "name"),
112 ]),
113 container_node_types: std::collections::HashSet::from([
114 "class_declaration",
115 "interface_declaration",
116 ]),
117 body_node_types: std::collections::HashSet::from([
118 "block",
119 "class_body",
120 "interface_body",
121 "enum_body",
122 ]),
123 }
124}
125
126pub fn get_c_spec() -> LanguageSpec {
127 LanguageSpec {
128 language: tree_sitter_c::LANGUAGE.into(),
129 symbol_node_types: std::collections::HashMap::from([
130 ("function_definition", SymbolKind::Function),
131 ("struct_specifier", SymbolKind::Struct),
132 ("class_specifier", SymbolKind::Class),
133 ("namespace_definition", SymbolKind::Class),
134 ]),
135 name_fields: std::collections::HashMap::from([
136 ("function_definition", "declarator"),
137 ("struct_specifier", "name"),
138 ("class_specifier", "name"),
139 ("namespace_definition", "name"),
140 ]),
141 container_node_types: std::collections::HashSet::from([
142 "class_specifier",
143 "struct_specifier",
144 "namespace_definition",
145 ]),
146 body_node_types: std::collections::HashSet::from([
147 "compound_statement",
148 "field_declaration_list",
149 ]),
150 }
151}
152
153pub fn get_cpp_spec() -> LanguageSpec {
154 LanguageSpec {
155 language: tree_sitter_cpp::LANGUAGE.into(),
156 symbol_node_types: std::collections::HashMap::from([
157 ("function_definition", SymbolKind::Function),
158 ("struct_specifier", SymbolKind::Struct),
159 ("class_specifier", SymbolKind::Class),
160 ("namespace_definition", SymbolKind::Class),
161 ]),
162 name_fields: std::collections::HashMap::from([
163 ("function_definition", "declarator"),
164 ("struct_specifier", "name"),
165 ("class_specifier", "name"),
166 ("namespace_definition", "name"),
167 ]),
168 container_node_types: std::collections::HashSet::from([
169 "class_specifier",
170 "struct_specifier",
171 "namespace_definition",
172 ]),
173 body_node_types: std::collections::HashSet::from([
174 "compound_statement",
175 "field_declaration_list",
176 ]),
177 }
178}
179
180pub fn get_csharp_spec() -> LanguageSpec {
181 LanguageSpec {
182 language: tree_sitter_c_sharp::LANGUAGE.into(),
183 symbol_node_types: std::collections::HashMap::from([
184 ("method_declaration", SymbolKind::Method),
185 ("class_declaration", SymbolKind::Class),
186 ("interface_declaration", SymbolKind::Interface),
187 ("struct_declaration", SymbolKind::Struct),
188 ("namespace_declaration", SymbolKind::Class),
189 ]),
190 name_fields: std::collections::HashMap::from([
191 ("method_declaration", "name"),
192 ("class_declaration", "name"),
193 ("interface_declaration", "name"),
194 ("struct_declaration", "name"),
195 ("namespace_declaration", "name"),
196 ]),
197 container_node_types: std::collections::HashSet::from([
198 "class_declaration",
199 "interface_declaration",
200 "namespace_declaration",
201 "struct_declaration",
202 ]),
203 body_node_types: std::collections::HashSet::from(["block", "declaration_list"]),
204 }
205}
206
207pub fn get_ruby_spec() -> LanguageSpec {
208 LanguageSpec {
209 language: tree_sitter_ruby::LANGUAGE.into(),
210 symbol_node_types: std::collections::HashMap::from([
211 ("method", SymbolKind::Method),
212 ("singleton_method", SymbolKind::Method),
213 ("class", SymbolKind::Class),
214 ("module", SymbolKind::Class),
215 ]),
216 name_fields: std::collections::HashMap::from([
217 ("method", "name"),
218 ("singleton_method", "name"),
219 ("class", "name"),
220 ("module", "name"),
221 ]),
222 container_node_types: std::collections::HashSet::from(["class", "module"]),
223 body_node_types: std::collections::HashSet::from(["body", "do_block", "begin_block"]),
224 }
225}
226
227pub fn get_php_spec() -> LanguageSpec {
228 LanguageSpec {
229 language: tree_sitter_php::LANGUAGE_PHP.into(),
230 symbol_node_types: std::collections::HashMap::from([
231 ("function_definition", SymbolKind::Function),
232 ("method_declaration", SymbolKind::Method),
233 ("class_declaration", SymbolKind::Class),
234 ("interface_declaration", SymbolKind::Interface),
235 ("trait_declaration", SymbolKind::Class),
236 ]),
237 name_fields: std::collections::HashMap::from([
238 ("function_definition", "name"),
239 ("method_declaration", "name"),
240 ("class_declaration", "name"),
241 ("interface_declaration", "name"),
242 ("trait_declaration", "name"),
243 ]),
244 container_node_types: std::collections::HashSet::from([
245 "class_declaration",
246 "interface_declaration",
247 "trait_declaration",
248 ]),
249 body_node_types: std::collections::HashSet::from([
250 "compound_statement",
251 "declaration_list",
252 ]),
253 }
254}
255
256pub fn get_swift_spec() -> LanguageSpec {
257 LanguageSpec {
258 language: tree_sitter_swift::LANGUAGE.into(),
259 symbol_node_types: std::collections::HashMap::from([
260 ("function_declaration", SymbolKind::Function),
261 ("class_declaration", SymbolKind::Class),
262 ("struct_declaration", SymbolKind::Struct),
263 ("protocol_declaration", SymbolKind::Interface),
264 ("extension_declaration", SymbolKind::Impl),
265 ]),
266 name_fields: std::collections::HashMap::from([
267 ("function_declaration", "name"),
268 ("class_declaration", "name"),
269 ("struct_declaration", "name"),
270 ("protocol_declaration", "name"),
271 ("extension_declaration", "type"),
272 ]),
273 container_node_types: std::collections::HashSet::from([
274 "class_declaration",
275 "struct_declaration",
276 "protocol_declaration",
277 "extension_declaration",
278 ]),
279 body_node_types: std::collections::HashSet::from([
280 "class_body",
281 "function_body",
282 "code_block",
283 ]),
284 }
285}
286
287pub fn get_spec_for_extension(ext: &str) -> Option<LanguageSpec> {
288 match ext {
289 "rs" => Some(get_rust_spec()),
290 "py" => Some(get_python_spec()),
291 "go" => Some(get_go_spec()),
292 "ts" | "tsx" | "js" | "jsx" => Some(get_typescript_spec()),
293 "java" => Some(get_java_spec()),
294 "c" | "h" => Some(get_c_spec()),
295 "cpp" | "hpp" | "cc" | "cxx" => Some(get_cpp_spec()),
296 "cs" => Some(get_csharp_spec()),
297 "rb" => Some(get_ruby_spec()),
298 "php" => Some(get_php_spec()),
299 "swift" => Some(get_swift_spec()),
300 _ => None,
301 }
302}
303
304pub fn get_go_spec() -> LanguageSpec {
305 LanguageSpec {
306 language: tree_sitter_go::LANGUAGE.into(),
307 symbol_node_types: std::collections::HashMap::from([
308 ("function_declaration", SymbolKind::Function),
309 ("method_declaration", SymbolKind::Method),
310 ("type_declaration", SymbolKind::Struct),
311 ]),
312 name_fields: std::collections::HashMap::from([
313 ("function_declaration", "name"),
314 ("method_declaration", "name"),
315 ("type_declaration", "name"),
316 ]),
317 container_node_types: std::collections::HashSet::from(["type_declaration"]),
318 body_node_types: std::collections::HashSet::from(["block", "type_spec"]),
319 }
320}
321
322pub fn get_typescript_spec() -> LanguageSpec {
323 LanguageSpec {
324 language: tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
325 symbol_node_types: std::collections::HashMap::from([
326 ("function_declaration", SymbolKind::Function),
327 ("method_definition", SymbolKind::Method),
328 ("class_declaration", SymbolKind::Class),
329 ("interface_declaration", SymbolKind::Interface),
330 ]),
331 name_fields: std::collections::HashMap::from([
332 ("function_declaration", "name"),
333 ("method_definition", "name"),
334 ("class_declaration", "name"),
335 ("interface_declaration", "name"),
336 ]),
337 container_node_types: std::collections::HashSet::from([
338 "class_declaration",
339 "interface_declaration",
340 ]),
341 body_node_types: std::collections::HashSet::from([
342 "statement_block",
343 "class_body",
344 "object_type",
345 ]),
346 }
347}
348
349pub fn parse_file(source: &str, spec: &LanguageSpec) -> ParsedDocument {
350 let mut parser = tree_sitter::Parser::new();
351 parser
352 .set_language(&spec.language)
353 .expect("Error loading language");
354 let tree = parser.parse(source, None).expect("Error parsing source");
355 let root_node = tree.root_node();
356
357 let mut symbols = Vec::new();
358 walk_tree(root_node, source.as_bytes(), spec, None, &mut symbols);
359
360 ParsedDocument {
361 file_path: String::new(),
362 symbols,
363 }
364}
365
366fn walk_tree(
367 node: Node,
368 source: &[u8],
369 spec: &LanguageSpec,
370 parent_symbol: Option<&Symbol>,
371 symbols: &mut Vec<Symbol>,
372) {
373 let kind = node.kind();
374 let mut current_symbol = None;
375
376 if let Some(symbol_kind) = spec.symbol_node_types.get(kind) {
377 let mut name = None;
378
379 if let Some(name_field) = spec.name_fields.get(kind) {
380 if let Some(mut child) = node.child_by_field_name(name_field) {
381 while child.kind() == "function_declarator"
382 || child.kind() == "pointer_declarator"
383 || child.kind() == "reference_declarator"
384 {
385 if let Some(inner) = child.child_by_field_name("declarator") {
386 child = inner;
387 } else {
388 break;
389 }
390 }
391 if let Ok(text) = std::str::from_utf8(&source[child.start_byte()..child.end_byte()])
392 {
393 name = Some(text.to_string());
394 }
395 }
396 }
397
398 if name.is_none() {
399 let mut cursor = node.walk();
400 for child in node.children(&mut cursor) {
401 let child_kind = child.kind();
402 if child_kind == "identifier" || child_kind == "type_identifier" {
403 if let Ok(text) =
404 std::str::from_utf8(&source[child.start_byte()..child.end_byte()])
405 {
406 name = Some(text.to_string());
407 break;
408 }
409 }
410 }
411 }
412
413 if let Some(name) = name {
414 let qualified_name = if let Some(parent) = parent_symbol {
415 format!("{}.{}", parent.qualified_name, name)
416 } else {
417 name.clone()
418 };
419
420 let mut body_range = None;
421 let mut body_node_opt = node.child_by_field_name("body");
422
423 if body_node_opt.is_none() {
424 let mut cursor = node.walk();
425 for child in node.children(&mut cursor) {
426 if spec.body_node_types.contains(child.kind()) {
427 body_node_opt = Some(child);
428 break;
429 }
430 }
431 }
432
433 if let Some(body_node) = body_node_opt {
434 let mut start = body_node.start_byte();
435 if let Some(prev) = body_node.prev_sibling() {
436 if prev.kind() == ":" {
437 start = prev.end_byte();
438 } else {
439 let mut has_newline = false;
440 for i in prev.end_byte()..start {
441 if i < source.len() && (source[i] == b'\n' || source[i] == b'\r') {
442 has_newline = true;
443 break;
444 }
445 }
446 if has_newline {
447 start = prev.end_byte();
448 }
449 }
450 }
451 body_range = Some(start..body_node.end_byte());
452 }
453
454 let symbol = Symbol {
455 name,
456 qualified_name,
457 kind: symbol_kind.clone(),
458 full_range: node.start_byte()..node.end_byte(),
459 body_range,
460 };
461
462 symbols.push(symbol.clone());
463 current_symbol = Some(symbol);
464 }
465 }
466
467 let next_parent = current_symbol.as_ref().or(parent_symbol);
468
469 let mut cursor = node.walk();
470 for child in node.children(&mut cursor) {
471 walk_tree(child, source, spec, next_parent, symbols);
472 }
473}
474
// Unit tests: each one parses a small snippet with one language spec and checks the
// extracted symbols and/or the source with bodies elided.
#[cfg(test)]
mod tests {
    use super::*;

    /// Rebuilds `source` with each symbol body replaced by `...`.
    ///
    /// Symbols are visited in order of where their full range starts. A body is only
    /// elided when it begins at or after the end of the previously elided body, so
    /// bodies that start inside an already-elided one are skipped.
    fn elide_document(source: &str, doc: &ParsedDocument) -> String {
        let mut result = String::new();
        let mut last_end = 0;

        let mut sorted_symbols = doc.symbols.clone();
        sorted_symbols.sort_by_key(|s| s.full_range.start);

        for sym in sorted_symbols {
            if let Some(body_range) = &sym.body_range {
                if body_range.start >= last_end {
                    // Copy everything up to the body, then stand in for the body itself.
                    result.push_str(&source[last_end..body_range.start]);
                    result.push_str("...");
                    last_end = body_range.end;
                }
            }
        }
        // Keep whatever follows the last elided body.
        result.push_str(&source[last_end..]);
        result
    }

    // Python: a class and its method are both found; eliding stops right after the
    // class header's colon.
    #[test]
    fn test_extract_python_class_signature_and_elide_body() {
        let source = "class MyClass:\n def __init__(self):\n pass";
        let spec = get_python_spec();
        let doc = parse_file(source, &spec);

        assert_eq!(doc.symbols.len(), 2);

        let class_sym = doc.symbols.iter().find(|s| s.name == "MyClass").unwrap();
        assert_eq!(class_sym.kind, SymbolKind::Class);

        let elided = elide_document(source, &doc);
        assert_eq!(elided, "class MyClass:...");
    }

    // Python: a top-level function keeps its full signature, including annotations.
    #[test]
    fn test_extract_python_function_signature_and_elide_body() {
        let source = "def calculate_total(a: int, b: int) -> int:\n return a + b";
        let spec = get_python_spec();
        let doc = parse_file(source, &spec);

        assert_eq!(doc.symbols.len(), 1);
        let sym = &doc.symbols[0];
        assert_eq!(sym.name, "calculate_total");

        let elided = elide_document(source, &doc);
        assert_eq!(elided, "def calculate_total(a: int, b: int) -> int:...");
    }

    // Rust: the struct's field list counts as its body.
    #[test]
    fn test_extract_rust_struct_signature_and_elide_body() {
        let source = "pub struct User {\n pub id: i32,\n pub name: String,\n}";
        let spec = get_rust_spec();
        let doc = parse_file(source, &spec);

        assert_eq!(doc.symbols.len(), 1);
        let sym = &doc.symbols[0];
        assert_eq!(sym.name, "User");

        let elided = elide_document(source, &doc);
        assert_eq!(elided, "pub struct User ...");
    }

    // Rust: the function block is elided and the signature is kept.
    #[test]
    fn test_extract_rust_function_signature_and_elide_body() {
        let source = "pub fn process_data(data: &[u8]) -> Result<(), Error> {\n // do work\n Ok(())\n}";
        let spec = get_rust_spec();
        let doc = parse_file(source, &spec);

        assert_eq!(doc.symbols.len(), 1);
        let sym = &doc.symbols[0];
        assert_eq!(sym.name, "process_data");

        let elided = elide_document(source, &doc);
        assert_eq!(
            elided,
            "pub fn process_data(data: &[u8]) -> Result<(), Error> ..."
        );
    }

    // Qualified names chain through every enclosing symbol, not just classes.
    #[test]
    fn test_handle_nested_functions_classes() {
        let source =
            "class MyClass:\n def my_method(self):\n def nested():\n pass";
        let spec = get_python_spec();
        let doc = parse_file(source, &spec);

        assert_eq!(doc.symbols.len(), 3);

        let method_sym = doc.symbols.iter().find(|s| s.name == "my_method").unwrap();
        assert_eq!(method_sym.qualified_name, "MyClass.my_method");

        let nested_sym = doc.symbols.iter().find(|s| s.name == "nested").unwrap();
        assert_eq!(nested_sym.qualified_name, "MyClass.my_method.nested");
    }

    // With the name-field entry removed, the name must come from the identifier-child
    // fallback.
    #[test]
    fn test_use_fallback_name_extraction() {
        let source = "def calculate_total(a: int, b: int) -> int:\n return a + b";
        let mut spec = get_python_spec();
        spec.name_fields.remove("function_definition");

        let doc = parse_file(source, &spec);

        assert_eq!(doc.symbols.len(), 1);
        let sym = &doc.symbols[0];
        assert_eq!(sym.name, "calculate_total");
    }

    // A source with only a comment yields no symbols.
    #[test]
    fn test_ignore_empty_files_or_no_symbols() {
        let source = "# just a comment\n\n";
        let spec = get_python_spec();
        let doc = parse_file(source, &spec);

        assert!(doc.symbols.is_empty());
    }

    // Java: class and method are found, both with a body, and the method is qualified
    // by the class name.
    #[test]
    fn test_extract_java_class_and_method_elide_body() {
        let source = "public class MyClass {\n public void doWork() {\n System.out.println(\"work\");\n }\n}";
        let spec = get_java_spec();
        let doc = parse_file(source, &spec);

        assert_eq!(doc.symbols.len(), 2);

        let class_sym = doc.symbols.iter().find(|s| s.name == "MyClass").unwrap();
        assert_eq!(class_sym.kind, SymbolKind::Class);
        assert!(class_sym.body_range.is_some());

        let method_sym = doc.symbols.iter().find(|s| s.name == "doWork").unwrap();
        assert_eq!(method_sym.kind, SymbolKind::Method);
        assert_eq!(method_sym.qualified_name, "MyClass.doWork");
        assert!(method_sym.body_range.is_some());

        let elided = elide_document(source, &doc);
        assert!(elided.starts_with("public class MyClass ..."));
    }

    // C: the function name is taken from inside the declarator. Only the C spec is
    // exercised here, despite the test's name.
    #[test]
    fn test_extract_c_cpp_function_elide_body() {
        let source = "int calculate(int a, int b) {\n return a + b;\n}";
        let spec = get_c_spec();
        let doc = parse_file(source, &spec);

        assert_eq!(doc.symbols.len(), 1);

        let func_sym = doc.symbols.iter().find(|s| s.name == "calculate").unwrap();
        assert_eq!(func_sym.kind, SymbolKind::Function);
        assert!(func_sym.body_range.is_some());

        let elided = elide_document(source, &doc);
        assert!(elided.starts_with("int calculate(int a, int b) ..."));
    }

    // C#: class and async method are found; the method is qualified by the class name.
    #[test]
    fn test_extract_csharp_class_and_method_elide_body() {
        let source = "public class Server {\n public async Task StartAsync() {\n await Task.Delay(10);\n }\n}";
        let spec = get_csharp_spec();
        let doc = parse_file(source, &spec);

        assert_eq!(doc.symbols.len(), 2);

        let class_sym = doc.symbols.iter().find(|s| s.name == "Server").unwrap();
        assert_eq!(class_sym.kind, SymbolKind::Class);

        let method_sym = doc.symbols.iter().find(|s| s.name == "StartAsync").unwrap();
        assert_eq!(method_sym.kind, SymbolKind::Method);
        assert_eq!(method_sym.qualified_name, "Server.StartAsync");

        let elided = elide_document(source, &doc);
        assert!(elided.starts_with("public class Server ..."));
    }

    // Ruby: there is no delimiter before the class body, so the elision marker follows
    // the class name directly.
    #[test]
    fn test_extract_ruby_class_and_method_elide_body() {
        let source = "class User\n def login(email)\n puts 'login'\n end\nend";
        let spec = get_ruby_spec();
        let doc = parse_file(source, &spec);

        assert_eq!(doc.symbols.len(), 2);

        let class_sym = doc.symbols.iter().find(|s| s.name == "User").unwrap();
        assert_eq!(class_sym.kind, SymbolKind::Class);

        let method_sym = doc.symbols.iter().find(|s| s.name == "login").unwrap();
        assert_eq!(method_sym.kind, SymbolKind::Method);
        assert_eq!(method_sym.qualified_name, "User.login");

        let elided = elide_document(source, &doc);
        assert!(elided.starts_with("class User..."));
    }

    // PHP: the source starts with the `<?php` tag, so the elided class header is
    // checked with `contains` rather than `starts_with`.
    #[test]
    fn test_extract_php_class_and_method_elide_body() {
        let source = "<?php\nclass Controller {\n public function handle($req) {\n return true;\n }\n}";
        let spec = get_php_spec();
        let doc = parse_file(source, &spec);

        assert_eq!(doc.symbols.len(), 2);

        let class_sym = doc.symbols.iter().find(|s| s.name == "Controller").unwrap();
        assert_eq!(class_sym.kind, SymbolKind::Class);

        let method_sym = doc.symbols.iter().find(|s| s.name == "handle").unwrap();
        assert_eq!(method_sym.kind, SymbolKind::Method);
        assert_eq!(method_sym.qualified_name, "Controller.handle");

        let elided = elide_document(source, &doc);
        assert!(elided.contains("class Controller ..."));
    }

    // Swift: a function declared inside a class is still recorded as
    // `SymbolKind::Function`, qualified by the class name.
    #[test]
    fn test_extract_swift_class_and_function_elide_body() {
        let source =
            "class ViewModel {\n func loadData(with id: String) {\n print(id)\n }\n}";
        let spec = get_swift_spec();
        let doc = parse_file(source, &spec);

        assert_eq!(doc.symbols.len(), 2);

        let class_sym = doc.symbols.iter().find(|s| s.name == "ViewModel").unwrap();
        assert_eq!(class_sym.kind, SymbolKind::Class);

        let method_sym = doc.symbols.iter().find(|s| s.name == "loadData").unwrap();
        assert_eq!(method_sym.kind, SymbolKind::Function);
        assert_eq!(method_sym.qualified_name, "ViewModel.loadData");

        let elided = elide_document(source, &doc);
        assert!(elided.starts_with("class ViewModel ..."));
    }
}