agentroot_core/index/ast_chunker/strategies/
python.rs

1//! Python-specific chunking strategy
2
3use super::{get_breadcrumb, line_numbers, ChunkingStrategy};
4use crate::error::Result;
5use crate::index::ast_chunker::types::{
6    compute_chunk_hash, ChunkMetadata, ChunkType, SemanticChunk,
7};
8use tree_sitter::Node;
9
/// Tree-sitter node kinds that this strategy turns into chunks of their own.
///
/// `decorated_definition` is the wrapper node around a decorated function or
/// class; listing it keeps the decorators inside the emitted chunk text.
const PYTHON_SEMANTIC_NODES: &[&str] = &[
    "function_definition",
    "class_definition",
    "decorated_definition",
];
15
/// Chunking strategy for Python source files.
///
/// The type carries no state, so it is cheap to copy and can be created with
/// either `PythonStrategy` or `PythonStrategy::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PythonStrategy;
17
18impl ChunkingStrategy for PythonStrategy {
19    fn semantic_node_types(&self) -> &[&str] {
20        PYTHON_SEMANTIC_NODES
21    }
22
23    fn extract_chunks(&self, source: &str, root: Node) -> Result<Vec<SemanticChunk>> {
24        let mut chunks = Vec::new();
25        let mut cursor = root.walk();
26        extract_python_chunks(source, &mut cursor, &mut chunks, self, None);
27
28        if chunks.is_empty() {
29            chunks.push(SemanticChunk::new(source.to_string(), ChunkType::Text, 0));
30        }
31
32        Ok(chunks)
33    }
34
35    fn chunk_type_for_node(&self, node: Node) -> ChunkType {
36        match node.kind() {
37            "function_definition" => ChunkType::Function,
38            "class_definition" => ChunkType::Class,
39            "decorated_definition" => {
40                if let Some(inner) = get_decorated_inner(node) {
41                    match inner.kind() {
42                        "function_definition" => ChunkType::Function,
43                        "class_definition" => ChunkType::Class,
44                        _ => ChunkType::Function,
45                    }
46                } else {
47                    ChunkType::Function
48                }
49            }
50            _ => ChunkType::Text,
51        }
52    }
53
54    fn extract_leading_trivia(&self, source: &str, node: Node) -> String {
55        let mut trivia = super::extract_leading_comments(source, node);
56
57        if let Some(docstring) = extract_docstring(source, node) {
58            if !trivia.is_empty() {
59                trivia.push('\n');
60            }
61            trivia.push_str(&docstring);
62        }
63
64        trivia
65    }
66}
67
68fn extract_python_chunks(
69    source: &str,
70    cursor: &mut tree_sitter::TreeCursor,
71    chunks: &mut Vec<SemanticChunk>,
72    strategy: &PythonStrategy,
73    parent_class: Option<&str>,
74) {
75    loop {
76        let node = cursor.node();
77        let kind = node.kind();
78
79        if PYTHON_SEMANTIC_NODES.contains(&kind) {
80            let actual_node = if kind == "decorated_definition" {
81                get_decorated_inner(node).unwrap_or(node)
82            } else {
83                node
84            };
85
86            let leading = strategy.extract_leading_trivia(source, node);
87            let trailing = strategy.extract_trailing_trivia(source, node);
88            let text = source[node.start_byte()..node.end_byte()].to_string();
89            let (start_line, end_line) = line_numbers(source, node.start_byte(), node.end_byte());
90
91            let name = actual_node
92                .child_by_field_name("name")
93                .map(|n| source[n.start_byte()..n.end_byte()].to_string());
94
95            let breadcrumb = match (parent_class, &name) {
96                (Some(cls), Some(n)) => Some(format!("{}::{}", cls, n)),
97                (None, Some(n)) => Some(n.clone()),
98                _ => get_breadcrumb(source, node),
99            };
100
101            let chunk_type =
102                if parent_class.is_some() && actual_node.kind() == "function_definition" {
103                    ChunkType::Method
104                } else {
105                    strategy.chunk_type_for_node(node)
106                };
107
108            let chunk_hash = compute_chunk_hash(&text, &leading, &trailing);
109
110            let chunk = SemanticChunk {
111                text,
112                chunk_type,
113                chunk_hash,
114                position: node.start_byte(),
115                token_count: None,
116                metadata: ChunkMetadata {
117                    leading_trivia: leading,
118                    trailing_trivia: trailing,
119                    breadcrumb,
120                    language: Some("python"),
121                    start_line,
122                    end_line,
123                },
124            };
125            chunks.push(chunk);
126
127            if actual_node.kind() == "class_definition" {
128                let class_name = name.as_deref();
129                if cursor.goto_first_child() {
130                    extract_python_chunks(source, cursor, chunks, strategy, class_name);
131                    cursor.goto_parent();
132                }
133            }
134        } else if cursor.goto_first_child() {
135            extract_python_chunks(source, cursor, chunks, strategy, parent_class);
136            cursor.goto_parent();
137        }
138
139        if !cursor.goto_next_sibling() {
140            break;
141        }
142    }
143}
144
145fn get_decorated_inner(node: Node) -> Option<Node> {
146    node.child_by_field_name("definition")
147}
148
149fn extract_docstring(source: &str, node: Node) -> Option<String> {
150    let body = match node.kind() {
151        "function_definition" | "class_definition" => node.child_by_field_name("body"),
152        "decorated_definition" => {
153            get_decorated_inner(node).and_then(|n| n.child_by_field_name("body"))
154        }
155        _ => None,
156    }?;
157
158    let mut cursor = body.walk();
159    if cursor.goto_first_child() {
160        let first = cursor.node();
161        if first.kind() == "expression_statement" {
162            if let Some(string) = first.child(0) {
163                if string.kind() == "string" {
164                    return Some(source[string.start_byte()..string.end_byte()].to_string());
165                }
166            }
167        }
168    }
169    None
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::ast_chunker::language::Language;
    use crate::index::ast_chunker::parser::parse;

    /// Parses `source` as Python and runs the strategy over the whole tree.
    fn chunks_for(source: &str) -> Vec<SemanticChunk> {
        let tree = parse(source, Language::Python).expect("test source should parse");
        PythonStrategy
            .extract_chunks(source, tree.root_node())
            .expect("chunk extraction should succeed")
    }

    #[test]
    fn test_extract_function() {
        let source = r#"
def hello():
    """Say hello."""
    print("hello")
"#;
        let chunks = chunks_for(source);

        // `extract_chunks` never returns an empty list (it falls back to a
        // Text chunk), so look for the function chunk explicitly.
        let function = chunks
            .iter()
            .find(|c| c.chunk_type == ChunkType::Function)
            .expect("expected a Function chunk");
        assert!(function.text.starts_with("def hello"));
        assert!(function.metadata.breadcrumb.as_deref() == Some("hello"));
        assert!(function.metadata.leading_trivia.contains("Say hello."));
    }

    #[test]
    fn test_extract_class() {
        let source = r#"
class MyClass:
    """A class."""

    def method(self):
        pass
"#;
        let chunks = chunks_for(source);

        let class = chunks
            .iter()
            .find(|c| c.chunk_type == ChunkType::Class)
            .expect("expected a Class chunk");
        assert!(class.metadata.breadcrumb.as_deref() == Some("MyClass"));

        let method = chunks
            .iter()
            .find(|c| c.chunk_type == ChunkType::Method)
            .expect("expected a Method chunk");
        assert!(method.metadata.breadcrumb.as_deref() == Some("MyClass::method"));
    }

    #[test]
    fn test_extract_decorated() {
        let source = r#"
@decorator
def decorated_fn():
    pass
"#;
        let chunks = chunks_for(source);

        let decorated = chunks
            .iter()
            .find(|c| c.chunk_type == ChunkType::Function)
            .expect("expected a Function chunk for the decorated definition");
        // The chunk covers the decorator, but is named after the function.
        assert!(decorated.text.starts_with("@decorator"));
        assert!(decorated.metadata.breadcrumb.as_deref() == Some("decorated_fn"));
    }
}