agentroot_core/index/ast_chunker/strategies/
rust.rs1use super::{get_breadcrumb, line_numbers, ChunkingStrategy};
4use crate::error::Result;
5use crate::index::ast_chunker::types::{
6 compute_chunk_hash, ChunkMetadata, ChunkType, SemanticChunk,
7};
8use tree_sitter::Node;
9
/// Tree-sitter node kinds that the Rust strategy emits as standalone chunks.
///
/// The syntax-tree walk in this file pushes a chunk for every node whose
/// `kind()` appears in this list; the same slice is what
/// `RustStrategy::semantic_node_types` returns.
const RUST_SEMANTIC_NODES: &[&str] = &[
    // Callable and implementation items.
    "function_item",
    "impl_item",
    // Type-level definitions.
    "struct_item",
    "enum_item",
    "trait_item",
    // Namespacing and aliases.
    "mod_item",
    "type_item",
    // Values and macros.
    "const_item",
    "static_item",
    "macro_definition",
];
22
/// Chunking strategy for Rust source files.
///
/// A stateless unit type; all behaviour comes from its `ChunkingStrategy`
/// implementation.
pub struct RustStrategy;
24
25impl ChunkingStrategy for RustStrategy {
26 fn semantic_node_types(&self) -> &[&str] {
27 RUST_SEMANTIC_NODES
28 }
29
30 fn extract_chunks(&self, source: &str, root: Node) -> Result<Vec<SemanticChunk>> {
31 let mut chunks = Vec::new();
32 let mut cursor = root.walk();
33 extract_rust_chunks(source, &mut cursor, &mut chunks, self);
34
35 if chunks.is_empty() {
36 chunks.push(SemanticChunk::new(source.to_string(), ChunkType::Text, 0));
37 }
38
39 Ok(chunks)
40 }
41
42 fn chunk_type_for_node(&self, node: Node) -> ChunkType {
43 match node.kind() {
44 "function_item" => ChunkType::Function,
45 "impl_item" => {
46 if has_child_kind(node, "trait") {
47 ChunkType::Trait
48 } else {
49 ChunkType::Method
50 }
51 }
52 "struct_item" => ChunkType::Struct,
53 "enum_item" => ChunkType::Enum,
54 "trait_item" => ChunkType::Trait,
55 "mod_item" => ChunkType::Module,
56 _ => ChunkType::Function,
57 }
58 }
59}
60
61fn extract_rust_chunks(
62 source: &str,
63 cursor: &mut tree_sitter::TreeCursor,
64 chunks: &mut Vec<SemanticChunk>,
65 strategy: &RustStrategy,
66) {
67 loop {
68 let node = cursor.node();
69 let kind = node.kind();
70
71 if RUST_SEMANTIC_NODES.contains(&kind) {
72 let leading = strategy.extract_leading_trivia(source, node);
73 let trailing = strategy.extract_trailing_trivia(source, node);
74 let text = source[node.start_byte()..node.end_byte()].to_string();
75 let (start_line, end_line) = line_numbers(source, node.start_byte(), node.end_byte());
76 let breadcrumb = get_breadcrumb(source, node);
77 let chunk_hash = compute_chunk_hash(&text, &leading, &trailing);
78
79 let chunk = SemanticChunk {
80 text,
81 chunk_type: strategy.chunk_type_for_node(node),
82 chunk_hash,
83 position: node.start_byte(),
84 token_count: None,
85 metadata: ChunkMetadata {
86 leading_trivia: leading,
87 trailing_trivia: trailing,
88 breadcrumb,
89 language: Some("rust"),
90 start_line,
91 end_line,
92 },
93 };
94 chunks.push(chunk);
95
96 if kind == "impl_item" && cursor.goto_first_child() {
97 extract_rust_chunks(source, cursor, chunks, strategy);
98 cursor.goto_parent();
99 }
100 } else if cursor.goto_first_child() {
101 extract_rust_chunks(source, cursor, chunks, strategy);
102 cursor.goto_parent();
103 }
104
105 if !cursor.goto_next_sibling() {
106 break;
107 }
108 }
109}
110
111fn has_child_kind(node: Node, kind: &str) -> bool {
112 let mut cursor = node.walk();
113 if cursor.goto_first_child() {
114 loop {
115 if cursor.node().kind() == kind {
116 return true;
117 }
118 if !cursor.goto_next_sibling() {
119 break;
120 }
121 }
122 }
123 false
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::ast_chunker::language::Language;
    use crate::index::ast_chunker::parser::parse;

    /// Parses `source` as Rust and runs `RustStrategy` over the resulting tree.
    fn chunk_rust(source: &str) -> Vec<SemanticChunk> {
        let tree = parse(source, Language::Rust).unwrap();
        RustStrategy
            .extract_chunks(source, tree.root_node())
            .unwrap()
    }

    /// A documented free function yields at least one `Function` chunk.
    #[test]
    fn test_extract_function() {
        let chunks = chunk_rust(
            r#"
/// A test function
fn test_fn() {
    println!("hello");
}
"#,
        );

        assert!(!chunks.is_empty());
        assert!(chunks
            .iter()
            .any(|chunk| chunk.chunk_type == ChunkType::Function));
    }

    /// A documented struct yields a `Struct` chunk.
    #[test]
    fn test_extract_struct() {
        let chunks = chunk_rust(
            r#"
/// My struct
struct MyStruct {
    field: i32,
}
"#,
        );

        assert!(chunks
            .iter()
            .any(|chunk| chunk.chunk_type == ChunkType::Struct));
    }

    /// An inherent impl block produces at least one chunk.
    #[test]
    fn test_extract_impl() {
        let chunks = chunk_rust(
            r#"
impl MyStruct {
    fn new() -> Self {
        Self { field: 0 }
    }
}
"#,
        );

        assert!(!chunks.is_empty());
    }
}