agentroot_core/index/ast_chunker/strategies/
python.rs1use super::{get_breadcrumb, line_numbers, ChunkingStrategy};
4use crate::error::Result;
5use crate::index::ast_chunker::types::{
6 compute_chunk_hash, ChunkMetadata, ChunkType, SemanticChunk,
7};
8use tree_sitter::Node;
9
/// Tree-sitter node kinds at which a Python semantic chunk starts.
///
/// `decorated_definition` is listed alongside the plain definitions so that a chunk for a
/// decorated function or class spans its decorators as well.
const PYTHON_SEMANTIC_NODES: &[&str] = &[
    "function_definition",
    "class_definition",
    "decorated_definition",
];

/// [`ChunkingStrategy`] that splits Python source into function, class, and method chunks.
pub struct PythonStrategy;
17
18impl ChunkingStrategy for PythonStrategy {
19 fn semantic_node_types(&self) -> &[&str] {
20 PYTHON_SEMANTIC_NODES
21 }
22
23 fn extract_chunks(&self, source: &str, root: Node) -> Result<Vec<SemanticChunk>> {
24 let mut chunks = Vec::new();
25 let mut cursor = root.walk();
26 extract_python_chunks(source, &mut cursor, &mut chunks, self, None);
27
28 if chunks.is_empty() {
29 chunks.push(SemanticChunk::new(source.to_string(), ChunkType::Text, 0));
30 }
31
32 Ok(chunks)
33 }
34
35 fn chunk_type_for_node(&self, node: Node) -> ChunkType {
36 match node.kind() {
37 "function_definition" => ChunkType::Function,
38 "class_definition" => ChunkType::Class,
39 "decorated_definition" => {
40 if let Some(inner) = get_decorated_inner(node) {
41 match inner.kind() {
42 "function_definition" => ChunkType::Function,
43 "class_definition" => ChunkType::Class,
44 _ => ChunkType::Function,
45 }
46 } else {
47 ChunkType::Function
48 }
49 }
50 _ => ChunkType::Text,
51 }
52 }
53
54 fn extract_leading_trivia(&self, source: &str, node: Node) -> String {
55 let mut trivia = super::extract_leading_comments(source, node);
56
57 if let Some(docstring) = extract_docstring(source, node) {
58 if !trivia.is_empty() {
59 trivia.push('\n');
60 }
61 trivia.push_str(&docstring);
62 }
63
64 trivia
65 }
66}
67
68fn extract_python_chunks(
69 source: &str,
70 cursor: &mut tree_sitter::TreeCursor,
71 chunks: &mut Vec<SemanticChunk>,
72 strategy: &PythonStrategy,
73 parent_class: Option<&str>,
74) {
75 loop {
76 let node = cursor.node();
77 let kind = node.kind();
78
79 if PYTHON_SEMANTIC_NODES.contains(&kind) {
80 let actual_node = if kind == "decorated_definition" {
81 get_decorated_inner(node).unwrap_or(node)
82 } else {
83 node
84 };
85
86 let leading = strategy.extract_leading_trivia(source, node);
87 let trailing = strategy.extract_trailing_trivia(source, node);
88 let text = source[node.start_byte()..node.end_byte()].to_string();
89 let (start_line, end_line) = line_numbers(source, node.start_byte(), node.end_byte());
90
91 let name = actual_node
92 .child_by_field_name("name")
93 .map(|n| source[n.start_byte()..n.end_byte()].to_string());
94
95 let breadcrumb = match (parent_class, &name) {
96 (Some(cls), Some(n)) => Some(format!("{}::{}", cls, n)),
97 (None, Some(n)) => Some(n.clone()),
98 _ => get_breadcrumb(source, node),
99 };
100
101 let chunk_type =
102 if parent_class.is_some() && actual_node.kind() == "function_definition" {
103 ChunkType::Method
104 } else {
105 strategy.chunk_type_for_node(node)
106 };
107
108 let chunk_hash = compute_chunk_hash(&text, &leading, &trailing);
109
110 let chunk = SemanticChunk {
111 text,
112 chunk_type,
113 chunk_hash,
114 position: node.start_byte(),
115 token_count: None,
116 metadata: ChunkMetadata {
117 leading_trivia: leading,
118 trailing_trivia: trailing,
119 breadcrumb,
120 language: Some("python"),
121 start_line,
122 end_line,
123 },
124 };
125 chunks.push(chunk);
126
127 if actual_node.kind() == "class_definition" {
128 let class_name = name.as_deref();
129 if cursor.goto_first_child() {
130 extract_python_chunks(source, cursor, chunks, strategy, class_name);
131 cursor.goto_parent();
132 }
133 }
134 } else if cursor.goto_first_child() {
135 extract_python_chunks(source, cursor, chunks, strategy, parent_class);
136 cursor.goto_parent();
137 }
138
139 if !cursor.goto_next_sibling() {
140 break;
141 }
142 }
143}
144
145fn get_decorated_inner(node: Node) -> Option<Node> {
146 node.child_by_field_name("definition")
147}
148
149fn extract_docstring(source: &str, node: Node) -> Option<String> {
150 let body = match node.kind() {
151 "function_definition" | "class_definition" => node.child_by_field_name("body"),
152 "decorated_definition" => {
153 get_decorated_inner(node).and_then(|n| n.child_by_field_name("body"))
154 }
155 _ => None,
156 }?;
157
158 let mut cursor = body.walk();
159 if cursor.goto_first_child() {
160 let first = cursor.node();
161 if first.kind() == "expression_statement" {
162 if let Some(string) = first.child(0) {
163 if string.kind() == "string" {
164 return Some(source[string.start_byte()..string.end_byte()].to_string());
165 }
166 }
167 }
168 }
169 None
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::ast_chunker::language::Language;
    use crate::index::ast_chunker::parser::parse;

    #[test]
    fn test_extract_function() {
        let source = r#"
def hello():
    """Say hello."""
    print("hello")
"#;
        let tree = parse(source, Language::Python).unwrap();
        let strategy = PythonStrategy;
        let chunks = strategy.extract_chunks(source, tree.root_node()).unwrap();

        let function = chunks
            .iter()
            .find(|c| c.chunk_type == ChunkType::Function)
            .expect("a Function chunk for `hello`");
        assert_eq!(function.metadata.breadcrumb.as_deref(), Some("hello"));
        // The docstring is surfaced as leading trivia.
        assert!(function.metadata.leading_trivia.contains("Say hello."));
    }

    #[test]
    fn test_extract_class() {
        let source = r#"
class MyClass:
    """A class."""

    def method(self):
        pass
"#;
        let tree = parse(source, Language::Python).unwrap();
        let strategy = PythonStrategy;
        let chunks = strategy.extract_chunks(source, tree.root_node()).unwrap();

        assert!(chunks.iter().any(|c| {
            c.chunk_type == ChunkType::Class
                && c.metadata.breadcrumb.as_deref() == Some("MyClass")
        }));
        // Methods are typed as such and qualified with their enclosing class.
        assert!(chunks.iter().any(|c| {
            c.chunk_type == ChunkType::Method
                && c.metadata.breadcrumb.as_deref() == Some("MyClass::method")
        }));
    }

    #[test]
    fn test_extract_decorated() {
        let source = r#"
@decorator
def decorated_fn():
    pass
"#;
        let tree = parse(source, Language::Python).unwrap();
        let strategy = PythonStrategy;
        let chunks = strategy.extract_chunks(source, tree.root_node()).unwrap();

        // `extract_chunks` falls back to a whole-source Text chunk, so a non-empty result
        // alone proves nothing; check that the decorated function itself was chunked.
        let functions: Vec<_> = chunks
            .iter()
            .filter(|c| c.chunk_type == ChunkType::Function)
            .collect();
        assert_eq!(functions.len(), 1);
        assert!(functions[0].text.starts_with("@decorator"));
        assert_eq!(
            functions[0].metadata.breadcrumb.as_deref(),
            Some("decorated_fn")
        );
    }
}