Skip to main content

pecto_python/extractors/
common.rs

1use tree_sitter::Node;
2
/// Parsed decorator info for a Python `@decorator` or `@decorator(args)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecoratorInfo {
    /// Last dotted segment of the decorator expression, e.g. `"get"` for `@app.get(...)`.
    pub name: String,
    /// Decorator expression without the call arguments, e.g. `"app.get"` or `"router.post"`.
    pub full_name: String,
    /// Source text of each call argument, verbatim (string literals keep their quotes).
    /// Empty when the decorator is not called.
    pub args: Vec<String>,
}
10
11/// Collect decorators from a decorated_definition or directly from a node.
12pub fn collect_decorators(node: &Node, source: &[u8]) -> Vec<DecoratorInfo> {
13    let mut decorators = Vec::new();
14
15    // If this is a decorated_definition, decorators are children
16    if node.kind() == "decorated_definition" {
17        for i in 0..node.named_child_count() {
18            let child = node.named_child(i).unwrap();
19            if child.kind() == "decorator"
20                && let Some(info) = parse_decorator(&child, source)
21            {
22                decorators.push(info);
23            }
24        }
25    }
26
27    decorators
28}
29
30/// Parse a single decorator node.
31fn parse_decorator(node: &Node, source: &[u8]) -> Option<DecoratorInfo> {
32    // Decorator structure: @ expression
33    // The expression can be: identifier, attribute (a.b), or call (a.b(...))
34    let mut full_name = String::new();
35    let mut args = Vec::new();
36
37    for i in 0..node.named_child_count() {
38        let child = node.named_child(i).unwrap();
39        match child.kind() {
40            "identifier" => {
41                full_name = node_text(&child, source);
42            }
43            "attribute" => {
44                full_name = node_text(&child, source);
45            }
46            "call" => {
47                // call has function + arguments
48                if let Some(func) = child.child_by_field_name("function") {
49                    full_name = node_text(&func, source);
50                }
51                if let Some(arg_list) = child.child_by_field_name("arguments") {
52                    for j in 0..arg_list.named_child_count() {
53                        let arg = arg_list.named_child(j).unwrap();
54                        args.push(node_text(&arg, source));
55                    }
56                }
57            }
58            _ => {}
59        }
60    }
61
62    if full_name.is_empty() {
63        return None;
64    }
65
66    // Extract simple name: "app.get" → "get", "router.post" → "post"
67    let name = full_name
68        .rsplit('.')
69        .next()
70        .unwrap_or(&full_name)
71        .to_string();
72
73    Some(DecoratorInfo {
74        name,
75        full_name,
76        args,
77    })
78}
79
80/// Extract text content from a tree-sitter node.
81pub fn node_text(node: &Node, source: &[u8]) -> String {
82    node.utf8_text(source).unwrap_or("").to_string()
83}
84
/// Strip the delimiters (and any string prefix) from a Python string literal.
///
/// Handles single, double and triple quotes, preceded by an optional one- or
/// two-letter prefix built from `r`/`b`/`u`/`f`/`t` in any case
/// (`r'^users/$'`, `f"/users/{id}"`, `Rb"..."`). Exactly one matching pair of
/// delimiters is removed, so quotes that belong to the content survive
/// (`"'hi'"` -> `'hi'`, `"say \""` -> `say \"`). Escape sequences are left
/// untouched.
///
/// Input that is not a well-formed quoted literal (an identifier, an
/// unbalanced quote) only has stray leading/trailing quote characters removed.
pub fn clean_string_literal(s: &str) -> String {
    let is_quote = |c: char| c == '"' || c == '\'';

    // Byte length of a string prefix such as `f`, `rb` or `Rb` before the first quote.
    let prefix_len = s
        .find(is_quote)
        .filter(|&i| {
            i <= 2
                && s[..i]
                    .chars()
                    .all(|c| matches!(c.to_ascii_lowercase(), 'r' | 'b' | 'u' | 'f' | 't'))
        })
        .unwrap_or(0);
    let body = &s[prefix_len..];

    // Triple quotes first so `"""doc"""` is not read as `"` + `""doc""` + `"`.
    for &delim in &["\"\"\"", "'''", "\"", "'"] {
        if let Some(inner) = body
            .strip_prefix(delim)
            .and_then(|rest| rest.strip_suffix(delim))
        {
            return inner.to_string();
        }
    }

    body.trim_matches(is_quote).to_string()
}
93
/// Convert a PascalCase, camelCase or snake_case identifier to kebab-case.
///
/// A word boundary is one of:
/// - a `_` or `-` separator,
/// - a lower-case letter or digit followed by an upper-case letter
///   (`userService` -> `user-service`),
/// - the last capital of an acronym that starts a new word
///   (`HTTPServer` -> `http-server`).
///
/// Runs of separators collapse to a single `-`, and leading or trailing
/// separators are dropped (`__init__` -> `init`). Letters are lowercased with
/// full Unicode rules.
pub fn to_kebab_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);
    // The hyphen is emitted lazily, just before the next word character, so
    // that repeated/leading/trailing separators never produce stray hyphens.
    let mut pending_sep = false;

    for (i, &c) in chars.iter().enumerate() {
        if c == '_' || c == '-' {
            pending_sep = true;
            continue;
        }
        if c.is_uppercase() && i > 0 {
            let prev = chars[i - 1];
            let next_is_lower = chars.get(i + 1).map_or(false, |n| n.is_lowercase());
            if prev.is_lowercase() || prev.is_numeric() || (prev.is_uppercase() && next_is_lower) {
                pending_sep = true;
            }
        }
        if pending_sep && !result.is_empty() {
            result.push('-');
        }
        pending_sep = false;
        result.extend(c.to_lowercase());
    }
    result
}
109
110/// Get the function/class name from a definition node.
111pub fn get_def_name(node: &Node, source: &[u8]) -> String {
112    node.child_by_field_name("name")
113        .map(|n| node_text(&n, source))
114        .unwrap_or_else(|| "unknown".to_string())
115}
116
117/// Get the inner definition from a decorated_definition node.
118pub fn get_inner_definition<'a>(node: &'a Node<'a>) -> Option<Node<'a>> {
119    if node.kind() == "decorated_definition" {
120        node.named_children(&mut node.walk())
121            .find(|c| c.kind() == "function_definition" || c.kind() == "class_definition")
122    } else {
123        Some(*node)
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ParsedFile;

    /// Parse `source` as the file at `path`; panics if parsing fails.
    fn parse_file(source: &str, path: &str) -> ParsedFile {
        ParsedFile::parse(source.to_string(), path.to_string()).unwrap()
    }

    #[test]
    fn test_parse_python_decorators() {
        let source = r#"
from fastapi import APIRouter

router = APIRouter()

@router.get("/users/{user_id}")
async def get_user(user_id: int):
    return {"user_id": user_id}

@router.post("/users")
async def create_user(user: UserCreate):
    return user
"#;

        let file = parse_file(source, "routes.py");
        let root = file.tree.root_node();
        let src = file.source.as_bytes();

        // Gather decorators from every top-level decorated definition, in source order.
        let mut cursor = root.walk();
        let decorators: Vec<DecoratorInfo> = root
            .named_children(&mut cursor)
            .filter(|node| node.kind() == "decorated_definition")
            .flat_map(|node| collect_decorators(&node, src))
            .collect();

        assert_eq!(decorators.len(), 2);

        let get = &decorators[0];
        assert_eq!(get.name, "get");
        assert_eq!(get.full_name, "router.get");
        assert!(get.args[0].contains("/users/"), "first arg should be the path");

        let post = &decorators[1];
        assert_eq!(post.name, "post");
        assert_eq!(post.full_name, "router.post");
    }

    #[test]
    fn test_to_kebab_case() {
        let cases = [
            ("UserService", "user-service"),
            ("user_service", "user-service"),
            ("get_users", "get-users"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_kebab_case(input), expected, "to_kebab_case({input:?})");
        }
    }

    #[test]
    fn test_clean_string_literal() {
        let cases = [("\"hello\"", "hello"), ("'/api/users'", "/api/users")];
        for (input, expected) in cases {
            assert_eq!(
                clean_string_literal(input),
                expected,
                "clean_string_literal({input:?})"
            );
        }
    }
}