agentroot_core/index/ast_chunker/strategies/
rust.rs

1//! Rust-specific chunking strategy
2
3use super::{get_breadcrumb, line_numbers, ChunkingStrategy};
4use crate::error::Result;
5use crate::index::ast_chunker::types::{
6    compute_chunk_hash, ChunkMetadata, ChunkType, SemanticChunk,
7};
8use tree_sitter::Node;
9
/// Tree-sitter node kinds that are emitted as standalone chunks for Rust sources.
///
/// The order is part of the value handed out by `semantic_node_types`, so it is
/// kept stable.
const RUST_SEMANTIC_NODES: &[&str] = &[
    // Functions and impl blocks.
    "function_item",
    "impl_item",
    // Type-like definitions.
    "struct_item",
    "enum_item",
    "trait_item",
    // Modules.
    "mod_item",
    // Aliases, constants and statics.
    "type_item",
    "const_item",
    "static_item",
    // `macro_rules!` definitions.
    "macro_definition",
];
22
/// Chunking strategy for Rust source files.
///
/// The type carries no state; all behavior lives in its `ChunkingStrategy`
/// implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct RustStrategy;
24
25impl ChunkingStrategy for RustStrategy {
26    fn semantic_node_types(&self) -> &[&str] {
27        RUST_SEMANTIC_NODES
28    }
29
30    fn extract_chunks(&self, source: &str, root: Node) -> Result<Vec<SemanticChunk>> {
31        let mut chunks = Vec::new();
32        let mut cursor = root.walk();
33        extract_rust_chunks(source, &mut cursor, &mut chunks, self);
34
35        if chunks.is_empty() {
36            chunks.push(SemanticChunk::new(source.to_string(), ChunkType::Text, 0));
37        }
38
39        Ok(chunks)
40    }
41
42    fn chunk_type_for_node(&self, node: Node) -> ChunkType {
43        match node.kind() {
44            "function_item" => ChunkType::Function,
45            "impl_item" => {
46                if has_child_kind(node, "trait") {
47                    ChunkType::Trait
48                } else {
49                    ChunkType::Method
50                }
51            }
52            "struct_item" => ChunkType::Struct,
53            "enum_item" => ChunkType::Enum,
54            "trait_item" => ChunkType::Trait,
55            "mod_item" => ChunkType::Module,
56            _ => ChunkType::Function,
57        }
58    }
59}
60
61fn extract_rust_chunks(
62    source: &str,
63    cursor: &mut tree_sitter::TreeCursor,
64    chunks: &mut Vec<SemanticChunk>,
65    strategy: &RustStrategy,
66) {
67    loop {
68        let node = cursor.node();
69        let kind = node.kind();
70
71        if RUST_SEMANTIC_NODES.contains(&kind) {
72            let leading = strategy.extract_leading_trivia(source, node);
73            let trailing = strategy.extract_trailing_trivia(source, node);
74            let text = source[node.start_byte()..node.end_byte()].to_string();
75            let (start_line, end_line) = line_numbers(source, node.start_byte(), node.end_byte());
76            let breadcrumb = get_breadcrumb(source, node);
77            let chunk_hash = compute_chunk_hash(&text, &leading, &trailing);
78
79            let chunk = SemanticChunk {
80                text,
81                chunk_type: strategy.chunk_type_for_node(node),
82                chunk_hash,
83                position: node.start_byte(),
84                token_count: None,
85                metadata: ChunkMetadata {
86                    leading_trivia: leading,
87                    trailing_trivia: trailing,
88                    breadcrumb,
89                    language: Some("rust"),
90                    start_line,
91                    end_line,
92                },
93            };
94            chunks.push(chunk);
95
96            if kind == "impl_item" && cursor.goto_first_child() {
97                extract_rust_chunks(source, cursor, chunks, strategy);
98                cursor.goto_parent();
99            }
100        } else if cursor.goto_first_child() {
101            extract_rust_chunks(source, cursor, chunks, strategy);
102            cursor.goto_parent();
103        }
104
105        if !cursor.goto_next_sibling() {
106            break;
107        }
108    }
109}
110
111fn has_child_kind(node: Node, kind: &str) -> bool {
112    let mut cursor = node.walk();
113    if cursor.goto_first_child() {
114        loop {
115            if cursor.node().kind() == kind {
116                return true;
117            }
118            if !cursor.goto_next_sibling() {
119                break;
120            }
121        }
122    }
123    false
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::ast_chunker::language::Language;
    use crate::index::ast_chunker::parser::parse;

    /// Parses `source` as Rust and runs `RustStrategy` over the resulting tree.
    fn chunk_source(source: &str) -> Vec<SemanticChunk> {
        let tree = parse(source, Language::Rust).expect("test source should parse");
        RustStrategy
            .extract_chunks(source, tree.root_node())
            .expect("chunk extraction should succeed")
    }

    #[test]
    fn test_extract_function() {
        let source = r#"
/// A test function
fn test_fn() {
    println!("hello");
}
"#;
        let chunks = chunk_source(source);

        assert!(chunks.iter().any(|c| c.chunk_type == ChunkType::Function));
    }

    #[test]
    fn test_extract_struct() {
        let source = r#"
/// My struct
struct MyStruct {
    field: i32,
}
"#;
        let chunks = chunk_source(source);

        assert!(chunks.iter().any(|c| c.chunk_type == ChunkType::Struct));
    }

    #[test]
    fn test_extract_impl() {
        let source = r#"
impl MyStruct {
    fn new() -> Self {
        Self { field: 0 }
    }
}
"#;
        let chunks = chunk_source(source);

        // A bare non-emptiness check is vacuous: `extract_chunks` falls back to a
        // single `Text` chunk when nothing is found. Check the actual chunk kinds:
        // the inherent impl block itself and the function nested inside it.
        assert!(chunks.iter().any(|c| c.chunk_type == ChunkType::Method));
        assert!(chunks.iter().any(|c| c.chunk_type == ChunkType::Function));
    }
}