1use super::common::*;
2use crate::context::ParsedFile;
3use pecto_core::model::*;
4
5pub fn extract(file: &ParsedFile) -> Option<Capability> {
8 let root = file.tree.root_node();
9 let source = file.source.as_bytes();
10
11 if let Some(cap) = extract_class_service(root, source, file) {
13 return Some(cap);
14 }
15
16 extract_module_functions(root, source, file)
18}
19
20fn extract_class_service(
21 root: tree_sitter::Node,
22 source: &[u8],
23 file: &ParsedFile,
24) -> Option<Capability> {
25 let mut operations = Vec::new();
26 let mut class_name = String::new();
27
28 for i in 0..root.named_child_count() {
29 let node = root.named_child(i).unwrap();
30
31 let (decorators, class_node) = if node.kind() == "class_definition" {
32 (Vec::new(), node)
33 } else if node.kind() == "decorated_definition" {
34 let decs = collect_decorators(&node, source);
35 match get_inner_definition(&node) {
36 Some(n) if n.kind() == "class_definition" => (decs, n),
37 _ => continue,
38 }
39 } else {
40 continue;
41 };
42
43 let name = get_def_name(&class_node, source);
44 let is_service = name.ends_with("Service")
45 || name.ends_with("Repository")
46 || name.ends_with("UseCase")
47 || decorators
48 .iter()
49 .any(|d| d.name == "inject" || d.name == "injectable" || d.name == "service");
50
51 if !is_service {
52 continue;
53 }
54
55 class_name = name.clone();
56
57 if let Some(body) = class_node.child_by_field_name("body") {
59 for j in 0..body.named_child_count() {
60 let member = body.named_child(j).unwrap();
61
62 let func = if member.kind() == "function_definition" {
63 member
64 } else if member.kind() == "decorated_definition" {
65 match get_inner_definition(&member) {
66 Some(n) if n.kind() == "function_definition" => n,
67 _ => continue,
68 }
69 } else {
70 continue;
71 };
72
73 let method_name = get_def_name(&func, source);
74
75 if method_name.starts_with('_') {
77 continue;
78 }
79
80 operations.push(make_operation(&func, source, &name, &method_name));
81 }
82 }
83
84 break; }
86
87 if operations.is_empty() {
88 return None;
89 }
90
91 let capability_name = to_kebab_case(
92 &class_name
93 .replace("Service", "")
94 .replace("Repository", "")
95 .replace("UseCase", ""),
96 );
97 let mut capability = Capability::new(format!("{}-service", capability_name), file.path.clone());
98 capability.operations = operations;
99 Some(capability)
100}
101
102fn extract_module_functions(
105 root: tree_sitter::Node,
106 source: &[u8],
107 file: &ParsedFile,
108) -> Option<Capability> {
109 let file_stem = file
110 .path
111 .rsplit('/')
112 .next()
113 .unwrap_or(&file.path)
114 .trim_end_matches(".py")
115 .to_lowercase();
116
117 let is_service_file = file_stem == "crud"
118 || file_stem == "services"
119 || file_stem == "service"
120 || file_stem == "repository"
121 || file_stem == "queries"
122 || file_stem == "actions"
123 || file_stem.ends_with("_service")
124 || file_stem.ends_with("_crud")
125 || file_stem.ends_with("_repository");
126
127 if !is_service_file {
128 return None;
129 }
130
131 let mut operations = Vec::new();
132 let module_name = to_kebab_case(&file_stem);
133
134 for i in 0..root.named_child_count() {
135 let node = root.named_child(i).unwrap();
136
137 let func = if node.kind() == "function_definition" {
138 node
139 } else if node.kind() == "decorated_definition" {
140 match get_inner_definition(&node) {
141 Some(n) if n.kind() == "function_definition" => n,
142 _ => continue,
143 }
144 } else {
145 continue;
146 };
147
148 let func_name = get_def_name(&func, source);
149
150 if func_name.starts_with('_') {
152 continue;
153 }
154
155 operations.push(make_operation(&func, source, &file_stem, &func_name));
156 }
157
158 if operations.is_empty() {
159 return None;
160 }
161
162 let mut capability = Capability::new(format!("{}-service", module_name), file.path.clone());
163 capability.operations = operations;
164 Some(capability)
165}
166
167fn make_operation(
168 func: &tree_sitter::Node,
169 source: &[u8],
170 owner_name: &str,
171 method_name: &str,
172) -> Operation {
173 let input = func
174 .child_by_field_name("parameters")
175 .and_then(|p| extract_first_non_self_param(&p, source));
176
177 let return_type = func
178 .child_by_field_name("return_type")
179 .map(|t| node_text(&t, source))
180 .filter(|t| t != "None" && !t.is_empty());
181
182 Operation {
183 name: method_name.to_string(),
184 source_method: format!("{}#{}", owner_name, method_name),
185 input: input.map(|t| TypeRef {
186 name: t,
187 fields: std::collections::BTreeMap::new(),
188 }),
189 behaviors: vec![Behavior {
190 name: "success".to_string(),
191 condition: None,
192 returns: ResponseSpec {
193 status: 200,
194 body: return_type.map(|t| TypeRef {
195 name: t,
196 fields: std::collections::BTreeMap::new(),
197 }),
198 },
199 side_effects: Vec::new(),
200 }],
201 transaction: None,
202 }
203}
204
205fn extract_first_non_self_param(params: &tree_sitter::Node, source: &[u8]) -> Option<String> {
206 for i in 0..params.named_child_count() {
207 let param = params.named_child(i).unwrap();
208 let name = match param.kind() {
209 "typed_parameter" | "typed_default_parameter" => param
210 .child_by_field_name("name")
211 .map(|n| node_text(&n, source))
212 .unwrap_or_default(),
213 "identifier" => node_text(¶m, source),
214 _ => continue,
215 };
216
217 if name == "self" || name == "cls" {
218 continue;
219 }
220
221 if let Some(type_node) = param.child_by_field_name("type") {
223 return Some(node_text(&type_node, source));
224 }
225 return Some(name);
226 }
227 None
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ParsedFile;

    /// Parses `source` as the Python file located at `path`.
    fn parse_file(source: &str, path: &str) -> ParsedFile {
        ParsedFile::parse(source.to_string(), path.to_string())
            .expect("test fixture should parse")
    }

    /// Names of the extracted operations, in extraction order.
    fn operation_names(capability: &Capability) -> Vec<&str> {
        capability
            .operations
            .iter()
            .map(|op| op.name.as_str())
            .collect()
    }

    #[test]
    fn test_service_extraction() {
        let source = r#"
class UserService:
    def __init__(self, db: Database):
        self.db = db

    def find_by_id(self, user_id: int) -> User:
        return self.db.get(user_id)

    def create(self, data: UserCreate) -> User:
        return self.db.create(data)

    def _private_helper(self):
        pass
"#;

        let file = parse_file(source, "services/user_service.py");
        let capability = extract(&file).expect("UserService should be extracted");

        assert_eq!(capability.name, "user-service");
        // `__init__` and `_private_helper` are private and must be skipped.
        assert_eq!(operation_names(&capability), ["find_by_id", "create"]);

        let find_by_id = &capability.operations[0];
        assert_eq!(find_by_id.source_method, "UserService#find_by_id");
        assert_eq!(
            find_by_id.input.as_ref().map(|t| t.name.as_str()),
            Some("int")
        );
        assert_eq!(
            find_by_id.behaviors[0]
                .returns
                .body
                .as_ref()
                .map(|t| t.name.as_str()),
            Some("User")
        );
    }

    #[test]
    fn test_non_service() {
        let source = r#"
class Helper:
    def do_thing(self):
        pass
"#;
        let file = parse_file(source, "utils.py");
        assert!(extract(&file).is_none());
    }

    #[test]
    fn test_module_level_crud() {
        let source = r#"
from sqlmodel import Session
from app.models import User, UserCreate

def create_user(session: Session, user_create: UserCreate) -> User:
    db_obj = User.model_validate(user_create)
    session.add(db_obj)
    session.commit()
    return db_obj

def get_user_by_email(session: Session, email: str) -> User:
    return session.exec(select(User).where(User.email == email)).first()

def _private_helper():
    pass
"#;

        let file = parse_file(source, "app/crud.py");
        let capability = extract(&file).expect("crud module should be extracted");

        assert_eq!(capability.name, "crud-service");
        assert_eq!(
            operation_names(&capability),
            ["create_user", "get_user_by_email"]
        );

        let create_user = &capability.operations[0];
        assert_eq!(create_user.source_method, "crud#create_user");
        assert_eq!(
            create_user.input.as_ref().map(|t| t.name.as_str()),
            Some("Session")
        );
    }

    #[test]
    fn test_module_level_non_crud_file() {
        let source = r#"
def helper_func():
    pass
"#;
        let file = parse_file(source, "utils.py");
        assert!(extract(&file).is_none());
    }

    #[test]
    fn test_module_level_directory_name_is_ignored() {
        let source = r#"
def list_items():
    pass
"#;
        // Only the file name, not the directory, decides whether module-level
        // functions are treated as a service.
        let file = parse_file(source, "app/crud/helpers.py");
        assert!(extract(&file).is_none());
    }
}
314}