//! Pipeline runner: batch execution (`run_batch`) and continuous polling
//! (`polling_loop`) for oven_cli — oven_cli/pipeline/runner.rs.

1use std::{
2    collections::{HashMap, HashSet},
3    sync::Arc,
4    time::Duration,
5};
6
7use anyhow::Result;
8use tokio::{
9    sync::{Mutex, Semaphore},
10    task::JoinSet,
11};
12use tokio_util::sync::CancellationToken;
13use tracing::{error, info};
14
15use super::executor::PipelineExecutor;
16use crate::{agents::Complexity, issues::PipelineIssue, process::CommandRunner};
17
18/// Run the pipeline for a batch of issues, limiting parallelism with a semaphore.
19///
20/// Used for the explicit-IDs path (`oven on 42,43`). For the polling path, see
21/// [`polling_loop`] which handles continuous issue discovery.
22pub async fn run_batch<R: CommandRunner + 'static>(
23    executor: &Arc<PipelineExecutor<R>>,
24    issues: Vec<PipelineIssue>,
25    max_parallel: usize,
26    auto_merge: bool,
27) -> Result<()> {
28    let semaphore = Arc::new(Semaphore::new(max_parallel));
29    let mut tasks = JoinSet::new();
30
31    for issue in issues {
32        let permit = semaphore
33            .clone()
34            .acquire_owned()
35            .await
36            .map_err(|e| anyhow::anyhow!("semaphore closed: {e}"))?;
37        let exec = Arc::clone(executor);
38        tasks.spawn(async move {
39            let number = issue.number;
40            let result = exec.run_issue(&issue, auto_merge).await;
41            drop(permit);
42            (number, result)
43        });
44    }
45
46    let mut had_errors = false;
47    while let Some(join_result) = tasks.join_next().await {
48        match join_result {
49            Ok((number, Ok(()))) => {
50                info!(issue = number, "pipeline completed successfully");
51            }
52            Ok((number, Err(e))) => {
53                error!(issue = number, error = %e, "pipeline failed for issue");
54                had_errors = true;
55            }
56            Err(e) => {
57                error!(error = %e, "pipeline task panicked");
58                had_errors = true;
59            }
60        }
61    }
62
63    if had_errors {
64        anyhow::bail!("one or more pipelines failed");
65    }
66
67    Ok(())
68}
69
70/// Extract per-issue complexity from the planner, if available.
71///
72/// Returns an empty map if the planner fails or returns unparseable output.
73async fn get_complexity_map<R: CommandRunner + 'static>(
74    executor: &Arc<PipelineExecutor<R>>,
75    issues: &[PipelineIssue],
76) -> HashMap<u32, Complexity> {
77    let mut map = HashMap::new();
78    if let Some(plan) = executor.plan_issues(issues).await {
79        info!(batches = plan.batches.len(), total = plan.total_issues, "planner produced a plan");
80        for batch in &plan.batches {
81            for pi in &batch.issues {
82                map.insert(pi.number, pi.complexity.clone());
83            }
84        }
85    }
86    map
87}
88
89fn handle_task_result(result: Result<(u32, Result<()>), tokio::task::JoinError>) {
90    match result {
91        Ok((number, Ok(()))) => {
92            info!(issue = number, "pipeline completed successfully");
93        }
94        Ok((number, Err(e))) => {
95            error!(issue = number, error = %e, "pipeline failed for issue");
96        }
97        Err(e) => {
98            error!(error = %e, "pipeline task panicked");
99        }
100    }
101}
102
/// Poll for new issues and run them through the pipeline.
///
/// Unlike `run_batch`, this function continuously polls for new issues even while
/// existing pipelines are running. Uses a shared semaphore and `JoinSet` that persist
/// across poll cycles, with in-flight tracking to prevent double-spawning.
///
/// On cancellation, no new work is spawned and the loop drains the `JoinSet`
/// so already-running pipelines finish before this function returns `Ok(())`.
pub async fn polling_loop<R: CommandRunner + 'static>(
    executor: Arc<PipelineExecutor<R>>,
    auto_merge: bool,
    cancel_token: CancellationToken,
) -> Result<()> {
    let poll_interval = Duration::from_secs(executor.config.pipeline.poll_interval);
    let max_parallel = executor.config.pipeline.max_parallel as usize;
    let ready_label = executor.config.labels.ready.clone();
    // Shared across all poll cycles: the semaphore bounds concurrent pipelines,
    // the JoinSet owns every spawned task, and `in_flight` records issue numbers
    // currently being processed so a later poll doesn't spawn them twice.
    let semaphore = Arc::new(Semaphore::new(max_parallel));
    let mut tasks = JoinSet::new();
    let in_flight: Arc<Mutex<HashSet<u32>>> = Arc::new(Mutex::new(HashSet::new()));

    info!(poll_interval_secs = poll_interval.as_secs(), max_parallel, "continuous polling started");

    loop {
        tokio::select! {
            // Graceful shutdown: stop polling and wait for in-flight pipelines.
            () = cancel_token.cancelled() => {
                info!("shutdown signal received, waiting for in-flight pipelines");
                while let Some(result) = tasks.join_next().await {
                    handle_task_result(result);
                }
                break;
            }
            // Poll tick: fetch ready issues and spawn pipelines for unseen ones.
            // Note the first fetch happens only after one full interval.
            () = tokio::time::sleep(poll_interval) => {
                match executor.issues.get_ready_issues(&ready_label).await {
                    Ok(issues) => {
                        // Drop already-in-flight issues; the lock is released
                        // before any await so it is never held across one.
                        let in_flight_guard = in_flight.lock().await;
                        let new_issues: Vec<_> = issues
                            .into_iter()
                            .filter(|i| !in_flight_guard.contains(&i.number))
                            .collect();
                        drop(in_flight_guard);

                        if new_issues.is_empty() {
                            info!("no new issues found, waiting");
                            continue;
                        }

                        info!(count = new_issues.len(), "found new issues to process");

                        // Best-effort planner pass; an empty map just means every
                        // issue runs with `complexity = None`.
                        let complexity_map =
                            get_complexity_map(&executor, &new_issues).await;

                        for issue in new_issues {
                            let sem = Arc::clone(&semaphore);
                            let exec = Arc::clone(&executor);
                            let in_fl = Arc::clone(&in_flight);
                            let number = issue.number;
                            let complexity = complexity_map.get(&number).cloned();

                            // Mark in-flight before spawning so the very next poll
                            // cannot race and spawn the same issue again.
                            in_fl.lock().await.insert(number);

                            // NOTE(review): if the spawned task panics, this
                            // `in_flight` entry is never removed (removal happens
                            // only on the return paths below), so a panicking
                            // issue could never be retried — consider a drop
                            // guard or id-keyed cleanup at join. TODO confirm.
                            tasks.spawn(async move {
                                let permit = match sem.acquire_owned().await {
                                    Ok(p) => p,
                                    Err(e) => {
                                        // Semaphore closed: un-mark the issue and
                                        // surface the error through the JoinSet.
                                        in_fl.lock().await.remove(&number);
                                        return (
                                            number,
                                            Err(anyhow::anyhow!(
                                                "semaphore closed: {e}"
                                            )),
                                        );
                                    }
                                };
                                let result = exec
                                    .run_issue_with_complexity(
                                        &issue,
                                        auto_merge,
                                        complexity,
                                    )
                                    .await;
                                in_fl.lock().await.remove(&number);
                                drop(permit);
                                (number, result)
                            });
                        }
                    }
                    Err(e) => {
                        // Transient fetch failure: log and retry next tick.
                        error!(error = %e, "failed to fetch issues");
                    }
                }
            }
            // Reap finished pipelines as they complete; the guard keeps this
            // branch disabled while no tasks exist (an empty JoinSet yields
            // `None` immediately, which would not match `Some(result)`).
            Some(result) = tasks.join_next(), if !tasks.is_empty() => {
                handle_task_result(result);
            }
        }
    }

    Ok(())
}
199
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use tokio::sync::Mutex;

    use super::*;
    use crate::{
        config::Config,
        github::GhClient,
        issues::{IssueOrigin, IssueProvider, github::GithubIssueProvider},
        process::{AgentResult, CommandOutput, MockCommandRunner},
    };

    /// Build a `MockCommandRunner` whose `gh` calls succeed with a PR URL and
    /// whose `claude` calls succeed with a clean-findings agent result.
    fn mock_runner_for_batch() -> MockCommandRunner {
        let mut mock = MockCommandRunner::new();
        mock.expect_run_gh().returning(|_, _| {
            Box::pin(async {
                Ok(CommandOutput {
                    stdout: "https://github.com/user/repo/pull/1\n".to_string(),
                    stderr: String::new(),
                    success: true,
                })
            })
        });
        mock.expect_run_claude().returning(|_, _, _, _| {
            Box::pin(async {
                Ok(AgentResult {
                    cost_usd: 1.0,
                    duration: Duration::from_secs(5),
                    turns: 3,
                    output: r#"{"findings":[],"summary":"clean"}"#.to_string(),
                    session_id: "sess-1".to_string(),
                    success: true,
                })
            })
        });
        mock
    }

    /// Wrap a mocked `GhClient` in the GitHub-backed issue provider.
    fn make_github_provider(gh: &Arc<GhClient<MockCommandRunner>>) -> Arc<dyn IssueProvider> {
        Arc::new(GithubIssueProvider::new(Arc::clone(gh), "target_repo"))
    }

    /// Cancelling right after startup should make `polling_loop` drain and
    /// return `Ok(())` without ever fetching issues.
    #[tokio::test]
    async fn cancellation_stops_polling() {
        let cancel = CancellationToken::new();
        let runner = Arc::new(mock_runner_for_batch());
        let github = Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
        let issues = make_github_provider(&github);
        let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));

        let mut config = Config::default();
        config.pipeline.poll_interval = 3600; // very long so we don't actually poll

        let executor = Arc::new(PipelineExecutor {
            runner,
            github,
            issues,
            db,
            config,
            cancel_token: cancel.clone(),
            repo_dir: PathBuf::from("/tmp"),
        });

        let cancel_clone = cancel.clone();
        let handle = tokio::spawn(async move { polling_loop(executor, false, cancel_clone).await });

        // Cancel immediately
        cancel.cancel();

        let result = handle.await.unwrap();
        assert!(result.is_ok());
    }

    /// Same as above, but also asserts the loop exits promptly (within 5s)
    /// rather than waiting out the (deliberately huge) poll interval.
    #[tokio::test]
    async fn cancellation_exits_within_timeout() {
        let cancel = CancellationToken::new();
        let runner = Arc::new(mock_runner_for_batch());
        let github = Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
        let issues = make_github_provider(&github);
        let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));

        let mut config = Config::default();
        config.pipeline.poll_interval = 3600;

        let executor = Arc::new(PipelineExecutor {
            runner,
            github,
            issues,
            db,
            config,
            cancel_token: cancel.clone(),
            repo_dir: PathBuf::from("/tmp"),
        });

        let cancel_clone = cancel.clone();
        let handle = tokio::spawn(async move { polling_loop(executor, false, cancel_clone).await });

        cancel.cancel();

        let result = tokio::time::timeout(Duration::from_secs(5), handle)
            .await
            .expect("polling loop should exit within timeout")
            .unwrap();
        assert!(result.is_ok());
    }

    /// Re-creates the in-flight filtering step of `polling_loop` in isolation
    /// and checks that already-running issue numbers are skipped.
    #[tokio::test]
    async fn in_flight_set_filters_duplicate_issues() {
        let in_flight: Arc<Mutex<HashSet<u32>>> = Arc::new(Mutex::new(HashSet::new()));

        // Simulate issue 1 already in flight
        in_flight.lock().await.insert(1);

        let issues = vec![
            PipelineIssue {
                number: 1,
                title: "Already running".to_string(),
                body: String::new(),
                source: IssueOrigin::Github,
                target_repo: None,
            },
            PipelineIssue {
                number: 2,
                title: "New issue".to_string(),
                body: String::new(),
                source: IssueOrigin::Github,
                target_repo: None,
            },
            PipelineIssue {
                number: 3,
                title: "Another new".to_string(),
                body: String::new(),
                source: IssueOrigin::Github,
                target_repo: None,
            },
        ];

        // Same filter expression as the polling loop's poll branch.
        let guard = in_flight.lock().await;
        let new_issues: Vec<_> =
            issues.into_iter().filter(|i| !guard.contains(&i.number)).collect();
        drop(guard);

        assert_eq!(new_issues.len(), 2);
        assert_eq!(new_issues[0].number, 2);
        assert_eq!(new_issues[1].number, 3);
    }

    /// `handle_task_result` only logs; it must never panic on a success value.
    #[test]
    fn handle_task_result_does_not_panic_on_success() {
        handle_task_result(Ok((1, Ok(()))));
    }

    /// …nor on an issue-level error value.
    #[test]
    fn handle_task_result_does_not_panic_on_error() {
        handle_task_result(Ok((1, Err(anyhow::anyhow!("test error")))));
    }

    /// A planner that returns non-JSON prose should yield an empty map rather
    /// than an error.
    #[tokio::test]
    async fn get_complexity_map_returns_empty_on_planner_failure() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run_gh().returning(|_, _| {
            Box::pin(async {
                Ok(CommandOutput { stdout: String::new(), stderr: String::new(), success: true })
            })
        });
        mock.expect_run_claude().returning(|_, _, _, _| {
            Box::pin(async {
                Ok(AgentResult {
                    cost_usd: 0.5,
                    duration: Duration::from_secs(2),
                    turns: 1,
                    // Deliberately unparseable planner output.
                    output: "I don't know how to plan".to_string(),
                    session_id: "sess-plan".to_string(),
                    success: true,
                })
            })
        });

        let runner = Arc::new(mock);
        let github = Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
        let issues_provider = make_github_provider(&github);
        let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));

        let executor = Arc::new(PipelineExecutor {
            runner,
            github,
            issues: issues_provider,
            db,
            config: Config::default(),
            cancel_token: CancellationToken::new(),
            repo_dir: PathBuf::from("/tmp"),
        });

        let issues = vec![PipelineIssue {
            number: 1,
            title: "Test".to_string(),
            body: "body".to_string(),
            source: IssueOrigin::Github,
            target_repo: None,
        }];

        let map = get_complexity_map(&executor, &issues).await;
        assert!(map.is_empty());
    }

    /// A well-formed plan should map each issue number to its parsed
    /// complexity ("simple" → `Complexity::Simple`, "full" → `Complexity::Full`).
    #[tokio::test]
    async fn get_complexity_map_extracts_complexity() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run_gh().returning(|_, _| {
            Box::pin(async {
                Ok(CommandOutput { stdout: String::new(), stderr: String::new(), success: true })
            })
        });
        mock.expect_run_claude().returning(|_, _, _, _| {
            Box::pin(async {
                Ok(AgentResult {
                    cost_usd: 0.5,
                    duration: Duration::from_secs(2),
                    turns: 1,
                    output: r#"{"batches":[{"batch":1,"issues":[{"number":1,"complexity":"simple"},{"number":2,"complexity":"full"}],"reasoning":"ok"}],"total_issues":2,"parallel_capacity":2}"#.to_string(),
                    session_id: "sess-plan".to_string(),
                    success: true,
                })
            })
        });

        let runner = Arc::new(mock);
        let github = Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
        let issues_provider = make_github_provider(&github);
        let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));

        let executor = Arc::new(PipelineExecutor {
            runner,
            github,
            issues: issues_provider,
            db,
            config: Config::default(),
            cancel_token: CancellationToken::new(),
            repo_dir: PathBuf::from("/tmp"),
        });

        let issues = vec![
            PipelineIssue {
                number: 1,
                title: "Simple".to_string(),
                body: "simple change".to_string(),
                source: IssueOrigin::Github,
                target_repo: None,
            },
            PipelineIssue {
                number: 2,
                title: "Complex".to_string(),
                body: "big feature".to_string(),
                source: IssueOrigin::Github,
                target_repo: None,
            },
        ];

        let map = get_complexity_map(&executor, &issues).await;
        assert_eq!(map.get(&1), Some(&Complexity::Simple));
        assert_eq!(map.get(&2), Some(&Complexity::Full));
    }
}