Skip to main content

pecto_python/extractors/
scheduled.rs

1use super::common::*;
2use crate::context::ParsedFile;
3use pecto_core::model::*;
4
5/// Extract Celery tasks, APScheduler jobs, and other scheduled patterns from Python.
6pub fn extract(file: &ParsedFile) -> Option<Capability> {
7    let root = file.tree.root_node();
8    let source = file.source.as_bytes();
9    let full_text = &file.source;
10
11    // Quick check
12    if !full_text.contains("@app.task")
13        && !full_text.contains("@shared_task")
14        && !full_text.contains("@periodic_task")
15        && !full_text.contains("@celery")
16        && !full_text.contains("crontab")
17        && !full_text.contains("add_job")
18        && !full_text.contains("APScheduler")
19        && !full_text.contains("BlockingScheduler")
20        && !full_text.contains("BackgroundScheduler")
21        && !full_text.contains("AsyncIOScheduler")
22    {
23        return None;
24    }
25
26    let mut tasks = Vec::new();
27
28    // Extract Celery decorator-based tasks
29    for i in 0..root.named_child_count() {
30        let node = root.named_child(i).unwrap();
31
32        if node.kind() == "decorated_definition" {
33            let decorators = collect_decorators(&node, source);
34            let inner = match get_inner_definition(&node) {
35                Some(n) if n.kind() == "function_definition" => n,
36                _ => continue,
37            };
38
39            let is_task = decorators.iter().any(|d| {
40                d.name == "task"
41                    || d.name == "shared_task"
42                    || d.name == "periodic_task"
43                    || d.full_name.contains("celery")
44                    || d.full_name.contains("app.task")
45            });
46
47            if !is_task {
48                continue;
49            }
50
51            let func_name = get_def_name(&inner, source);
52
53            let schedule = decorators
54                .iter()
55                .find(|d| d.name == "periodic_task")
56                .and_then(|d| {
57                    d.args
58                        .iter()
59                        .find(|a| a.contains("crontab") || a.contains("schedule"))
60                        .cloned()
61                })
62                .unwrap_or_else(|| "celery-task".to_string());
63
64            tasks.push(ScheduledTask {
65                name: func_name,
66                schedule: clean_string_literal(&schedule),
67                description: Some("Celery task".to_string()),
68            });
69        }
70    }
71
72    // Extract APScheduler add_job() calls
73    if full_text.contains("add_job") {
74        extract_apscheduler_jobs(full_text, &mut tasks);
75    }
76
77    if tasks.is_empty() {
78        return None;
79    }
80
81    let file_stem = file
82        .path
83        .rsplit('/')
84        .next()
85        .unwrap_or(&file.path)
86        .trim_end_matches(".py");
87    let capability_name = to_kebab_case(file_stem);
88
89    let mut capability = Capability::new(capability_name, file.path.clone());
90    capability.scheduled_tasks = tasks;
91    Some(capability)
92}
93
94/// Extract APScheduler `scheduler.add_job(func, trigger, ...)` patterns via text scanning.
95/// Handles multi-line add_job() calls by joining lines between `add_job(` and closing `)`.
96fn extract_apscheduler_jobs(source: &str, tasks: &mut Vec<ScheduledTask>) {
97    // Collect multi-line add_job() blocks
98    let mut blocks = Vec::new();
99    let mut current_block = String::new();
100    let mut paren_depth = 0i32;
101    let mut in_add_job = false;
102
103    for line in source.lines() {
104        let trimmed = line.trim();
105
106        if !in_add_job {
107            if trimmed.contains("add_job(") {
108                in_add_job = true;
109                current_block.clear();
110                // Count parens from this line onward
111                for ch in trimmed.chars() {
112                    if ch == '(' {
113                        paren_depth += 1;
114                    } else if ch == ')' {
115                        paren_depth -= 1;
116                    }
117                }
118                current_block.push_str(trimmed);
119                if paren_depth <= 0 {
120                    blocks.push(current_block.clone());
121                    in_add_job = false;
122                    paren_depth = 0;
123                }
124            }
125        } else {
126            for ch in trimmed.chars() {
127                if ch == '(' {
128                    paren_depth += 1;
129                } else if ch == ')' {
130                    paren_depth -= 1;
131                }
132            }
133            current_block.push(' ');
134            current_block.push_str(trimmed);
135            if paren_depth <= 0 {
136                blocks.push(current_block.clone());
137                in_add_job = false;
138                paren_depth = 0;
139            }
140        }
141    }
142
143    for block in &blocks {
144        let after_add_job = match block.split("add_job(").nth(1) {
145            Some(s) => s,
146            None => continue,
147        };
148
149        // Remove only the outermost closing paren (the one matching add_job's open paren)
150        // by finding the balanced closing paren from the end
151        let inner = strip_outer_closing_paren(after_add_job.trim_end());
152
153        let args: Vec<&str> = inner.split(',').map(|s| s.trim()).collect();
154        if args.is_empty() {
155            continue;
156        }
157
158        // First arg is the function reference (e.g. analyst.run_cycle)
159        let func_ref = args[0].trim();
160        let task_name = func_ref
161            .rsplit('.')
162            .next()
163            .unwrap_or(func_ref)
164            .trim_matches('"')
165            .trim_matches('\'')
166            .to_string();
167
168        // Try to extract the id= or name= kwarg for a better task name
169        let display_name = args
170            .iter()
171            .find_map(|a| extract_simple_kwarg(a, "id"))
172            .map(|v| v.trim_matches('"').trim_matches('\'').to_string())
173            .unwrap_or(task_name);
174
175        // Try to find schedule info from trigger or keyword args
176        let mut schedule = "apscheduler-job".to_string();
177        let full_block = block.to_lowercase();
178
179        if full_block.contains("intervaltrigger")
180            || full_block.contains("'interval'")
181            || full_block.contains("\"interval\"")
182        {
183            // Look for minutes=, hours=, seconds= anywhere in the block
184            for arg in &args {
185                if let Some(v) = extract_simple_kwarg(arg, "minutes") {
186                    // Check if it's a numeric literal
187                    if v.chars().all(|c| c.is_ascii_digit()) {
188                        schedule = format!("every {}min", v);
189                    } else {
190                        schedule = "interval".to_string();
191                    }
192                    break;
193                } else if let Some(v) = extract_simple_kwarg(arg, "hours") {
194                    if v.chars().all(|c| c.is_ascii_digit()) {
195                        schedule = format!("every {}h", v);
196                    } else {
197                        schedule = "interval".to_string();
198                    }
199                    break;
200                } else if let Some(v) = extract_simple_kwarg(arg, "seconds") {
201                    if v.chars().all(|c| c.is_ascii_digit()) {
202                        schedule = format!("every {}s", v);
203                    } else {
204                        schedule = "interval".to_string();
205                    }
206                    break;
207                }
208            }
209        } else if full_block.contains("crontrigger")
210            || full_block.contains("'cron'")
211            || full_block.contains("\"cron\"")
212        {
213            schedule = "cron".to_string();
214        }
215
216        tasks.push(ScheduledTask {
217            name: display_name,
218            schedule,
219            description: Some("APScheduler job".to_string()),
220        });
221    }
222}
223
/// Strip the closing paren that matches `add_job(` from an argument list.
///
/// `s` is the text that follows `add_job(`; the returned slice ends just before the
/// parenthesis that balances that opening one, so anything after it is dropped too.
///
/// * `func, IntervalTrigger(hours=1))` → `func, IntervalTrigger(hours=1)`
/// * `func, trigger=IntervalTrigger(hours=1), )` → `func, trigger=IntervalTrigger(hours=1), `
/// * `func, id=":)")` → `func, id=":)"` (parens inside quoted strings are not counted)
///
/// When no balancing paren is found, the text before the last `)` is returned; when
/// there is no `)` at all, `s` is returned unchanged.
fn strip_outer_closing_paren(s: &str) -> &str {
    let mut depth = 1i32; // We start just after add_job(
    let mut open_quote: Option<char> = None;
    let mut escaped = false;

    for (i, ch) in s.char_indices() {
        // Inside a string literal only the matching, unescaped quote is significant.
        if let Some(quote) = open_quote {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == quote {
                open_quote = None;
            }
            continue;
        }

        match ch {
            '"' | '\'' => open_quote = Some(ch),
            '(' => depth += 1,
            ')' => {
                depth -= 1;
                if depth == 0 {
                    return &s[..i];
                }
            }
            _ => {}
        }
    }

    // Fallback for unbalanced input (or an unterminated quote): strip the last ')'.
    match s.rfind(')') {
        Some(pos) => &s[..pos],
        None => s,
    }
}
251
/// Find the value of the keyword argument `key` inside `text`.
///
/// Returns the text following `key=`:
///
/// * for a value that starts with `"` or `'` and has a matching closing quote, the
///   whole quoted literal including both quotes (escaped quotes are not interpreted);
/// * otherwise everything up to the next `,`, `)` or space.
///
/// `key` must start at a name boundary, so `uuid=7` does not satisfy `id`, and
/// `key==value` comparisons are ignored. Returns `None` when no occurrence of the
/// key has a non-empty value.
fn extract_simple_kwarg<'a>(text: &'a str, key: &str) -> Option<&'a str> {
    let pattern = format!("{}=", key);
    let mut search_from = 0;

    while let Some(offset) = text[search_from..].find(&pattern) {
        let key_start = search_from + offset;
        let value_start = key_start + pattern.len();
        // Always advance past this occurrence so the loop terminates.
        search_from = value_start;

        // Reject matches that are only the tail of a longer identifier.
        let inside_longer_name = text[..key_start]
            .chars()
            .next_back()
            .map_or(false, |c| c.is_alphanumeric() || c == '_');
        let remaining = &text[value_start..];
        if inside_longer_name || remaining.starts_with('=') {
            continue;
        }

        // Quoted literal: keep everything through the closing quote, so spaces and
        // parentheses inside the string do not cut the value short.
        let leading_quote = remaining.chars().next().filter(|c| *c == '"' || *c == '\'');
        if let Some(quote) = leading_quote {
            if let Some(close) = remaining[1..].find(quote) {
                return Some(&remaining[..close + 2]);
            }
        }

        let end = remaining.find([',', ')', ' ']).unwrap_or(remaining.len());
        let val = remaining[..end].trim();
        if !val.is_empty() {
            return Some(val);
        }
    }

    None
}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ParsedFile;

    /// Parse `source` as if it were the Python file located at `path`.
    fn parse_file(source: &str, path: &str) -> ParsedFile {
        ParsedFile::parse(source.to_string(), path.to_string()).unwrap()
    }

    /// Task names of a capability, in extraction order.
    fn task_names(capability: &Capability) -> Vec<&str> {
        capability
            .scheduled_tasks
            .iter()
            .map(|task| task.name.as_str())
            .collect()
    }

    /// Task schedules of a capability, in extraction order.
    fn task_schedules(capability: &Capability) -> Vec<&str> {
        capability
            .scheduled_tasks
            .iter()
            .map(|task| task.schedule.as_str())
            .collect()
    }

    // Plain `@shared_task` functions are reported under their function names.
    #[test]
    fn test_celery_task() {
        let source = r#"
from celery import shared_task

@shared_task
def send_email(to: str, subject: str):
    mail.send(to, subject)

@shared_task
def process_order(order_id: int):
    Order.process(order_id)
"#;

        let parsed = parse_file(source, "tasks/email_tasks.py");
        let capability = extract(&parsed).unwrap();

        assert_eq!(task_names(&capability), ["send_email", "process_order"]);
    }

    // A file without any scheduling marker produces no capability.
    #[test]
    fn test_no_tasks() {
        let source = r#"
def regular_function():
    pass
"#;
        let parsed = parse_file(source, "utils.py");
        assert!(extract(&parsed).is_none());
    }

    // Multi-line add_job() calls are joined; the id= kwarg names the task and the
    // IntervalTrigger argument determines the schedule label.
    #[test]
    fn test_apscheduler_multiline() {
        let source = r#"
from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.triggers.interval import IntervalTrigger

scheduler = BlockingScheduler()

scheduler.add_job(
    analyst.run_cycle,
    trigger=IntervalTrigger(minutes=30),
    id="agent_cycle",
    name="Agent Cycle",
    max_instances=1,
)

scheduler.add_job(
    portfolio.snapshot,
    trigger=IntervalTrigger(hours=1),
    id="portfolio_snapshot",
    name="Portfolio Snapshot",
)
"#;

        let parsed = parse_file(source, "scheduler.py");
        let capability = extract(&parsed).unwrap();

        assert_eq!(
            task_names(&capability),
            ["agent_cycle", "portfolio_snapshot"]
        );
        assert_eq!(task_schedules(&capability), ["every 30min", "every 1h"]);
    }

    // Edge case: compact add_job() without trailing comma. Only the paren closing
    // add_job( may be stripped, never the one closing IntervalTrigger(hours=2).
    #[test]
    fn test_apscheduler_compact_no_trailing_comma() {
        let source = r#"
from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.triggers.interval import IntervalTrigger

scheduler = BlockingScheduler()
scheduler.add_job(my_func, IntervalTrigger(hours=2), id="cleanup")
"#;

        let parsed = parse_file(source, "scheduler.py");
        let capability = extract(&parsed).unwrap();

        assert_eq!(task_names(&capability), ["cleanup"]);
        assert_eq!(task_schedules(&capability), ["every 2h"]);
    }
}