//! Pipeline runner for `oven_cli`: planner-driven batch execution
//! (`run_batch`) and the continuous polling loop (`polling_loop`).
1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use anyhow::Result;
4use tokio::{
5    sync::{Mutex, Semaphore},
6    task::JoinSet,
7};
8use tokio_util::sync::CancellationToken;
9use tracing::{error, info, warn};
10
11use super::executor::PipelineExecutor;
12use crate::{
13    agents::{InFlightIssue, PlannerOutput},
14    issues::PipelineIssue,
15    process::CommandRunner,
16};
17
18/// Run the pipeline for a batch of issues using planner-driven sequencing.
19///
20/// Used for the explicit-IDs path (`oven on 42,43`). Calls the planner with no
21/// in-flight context, then runs batches sequentially (issues within each batch
22/// run in parallel). Falls back to all-parallel if the planner fails.
23pub async fn run_batch<R: CommandRunner + 'static>(
24    executor: &Arc<PipelineExecutor<R>>,
25    issues: Vec<PipelineIssue>,
26    max_parallel: usize,
27    auto_merge: bool,
28) -> Result<()> {
29    if let Some(plan) = executor.plan_issues(&issues, &[]).await {
30        info!(
31            batches = plan.batches.len(),
32            total = plan.total_issues,
33            "planner produced a plan, running batches sequentially"
34        );
35        run_batches_sequentially(executor, &issues, &plan, max_parallel, auto_merge).await
36    } else {
37        warn!("planner failed, falling back to all-parallel execution");
38        run_all_parallel(executor, issues, max_parallel, auto_merge).await
39    }
40}
41
42/// Run planner batches in sequence: wait for batch N to complete before starting batch N+1.
43/// Issues within each batch run in parallel.
44async fn run_batches_sequentially<R: CommandRunner + 'static>(
45    executor: &Arc<PipelineExecutor<R>>,
46    issues: &[PipelineIssue],
47    plan: &PlannerOutput,
48    max_parallel: usize,
49    auto_merge: bool,
50) -> Result<()> {
51    let issue_map: HashMap<u32, &PipelineIssue> = issues.iter().map(|i| (i.number, i)).collect();
52
53    for batch in &plan.batches {
54        let batch_issues: Vec<PipelineIssue> = batch
55            .issues
56            .iter()
57            .filter_map(|pi| issue_map.get(&pi.number).map(|i| (*i).clone()))
58            .collect();
59
60        if batch_issues.is_empty() {
61            continue;
62        }
63
64        info!(
65            batch = batch.batch,
66            count = batch_issues.len(),
67            reasoning = %batch.reasoning,
68            "starting batch"
69        );
70
71        run_single_batch(executor, batch_issues, &batch.issues, max_parallel, auto_merge).await?;
72    }
73
74    Ok(())
75}
76
77/// Run a single batch of issues in parallel with complexity from planner output.
78async fn run_single_batch<R: CommandRunner + 'static>(
79    executor: &Arc<PipelineExecutor<R>>,
80    issues: Vec<PipelineIssue>,
81    planned: &[crate::agents::PlannedIssue],
82    max_parallel: usize,
83    auto_merge: bool,
84) -> Result<()> {
85    let complexity_map: HashMap<u32, crate::agents::Complexity> =
86        planned.iter().map(|pi| (pi.number, pi.complexity.clone())).collect();
87    let semaphore = Arc::new(Semaphore::new(max_parallel));
88    let mut tasks = JoinSet::new();
89
90    for issue in issues {
91        let permit = semaphore
92            .clone()
93            .acquire_owned()
94            .await
95            .map_err(|e| anyhow::anyhow!("semaphore closed: {e}"))?;
96        let exec = Arc::clone(executor);
97        let complexity = complexity_map.get(&issue.number).cloned();
98        tasks.spawn(async move {
99            let number = issue.number;
100            let result = exec.run_issue_with_complexity(&issue, auto_merge, complexity).await;
101            drop(permit);
102            (number, result)
103        });
104    }
105
106    let mut had_errors = false;
107    while let Some(join_result) = tasks.join_next().await {
108        match join_result {
109            Ok((number, Err(e))) => {
110                error!(issue = number, error = %e, "pipeline failed for issue");
111                had_errors = true;
112            }
113            Err(e) => {
114                error!(error = %e, "pipeline task panicked");
115                had_errors = true;
116            }
117            Ok((number, Ok(()))) => {
118                info!(issue = number, "pipeline completed successfully");
119            }
120        }
121    }
122
123    if had_errors { Err(anyhow::anyhow!("one or more pipelines failed in batch")) } else { Ok(()) }
124}
125
126/// Fallback: run all issues in parallel behind a semaphore (no planner guidance).
127async fn run_all_parallel<R: CommandRunner + 'static>(
128    executor: &Arc<PipelineExecutor<R>>,
129    issues: Vec<PipelineIssue>,
130    max_parallel: usize,
131    auto_merge: bool,
132) -> Result<()> {
133    let semaphore = Arc::new(Semaphore::new(max_parallel));
134    let mut tasks = JoinSet::new();
135
136    for issue in issues {
137        let permit = semaphore
138            .clone()
139            .acquire_owned()
140            .await
141            .map_err(|e| anyhow::anyhow!("semaphore closed: {e}"))?;
142        let exec = Arc::clone(executor);
143        tasks.spawn(async move {
144            let number = issue.number;
145            let result = exec.run_issue(&issue, auto_merge).await;
146            drop(permit);
147            (number, result)
148        });
149    }
150
151    let mut had_errors = false;
152    while let Some(join_result) = tasks.join_next().await {
153        match join_result {
154            Ok((number, Ok(()))) => {
155                info!(issue = number, "pipeline completed successfully");
156            }
157            Ok((number, Err(e))) => {
158                error!(issue = number, error = %e, "pipeline failed for issue");
159                had_errors = true;
160            }
161            Err(e) => {
162                error!(error = %e, "pipeline task panicked");
163                had_errors = true;
164            }
165        }
166    }
167
168    if had_errors {
169        anyhow::bail!("one or more pipelines failed");
170    }
171    Ok(())
172}
173
174fn handle_task_result(result: Result<(u32, Result<()>), tokio::task::JoinError>) {
175    match result {
176        Ok((number, Ok(()))) => {
177            info!(issue = number, "pipeline completed successfully");
178        }
179        Ok((number, Err(e))) => {
180            error!(issue = number, error = %e, "pipeline failed for issue");
181        }
182        Err(e) => {
183            error!(error = %e, "pipeline task panicked");
184        }
185    }
186}
187
188/// Poll for new issues and run them through the pipeline.
189///
190/// Unlike `run_batch`, this function continuously polls for new issues even while
191/// existing pipelines are running. Uses a shared semaphore and `JoinSet` that persist
192/// across poll cycles, with in-flight tracking to prevent double-spawning.
193///
194/// The planner receives in-flight metadata so it can avoid scheduling conflicting work
195/// in batch 1. Only batch 1 issues are spawned each cycle; deferred issues keep `o-ready`
196/// and naturally reappear on the next poll.
197pub async fn polling_loop<R: CommandRunner + 'static>(
198    executor: Arc<PipelineExecutor<R>>,
199    auto_merge: bool,
200    cancel_token: CancellationToken,
201) -> Result<()> {
202    let poll_interval = Duration::from_secs(executor.config.pipeline.poll_interval);
203    let max_parallel = executor.config.pipeline.max_parallel as usize;
204    let ready_label = executor.config.labels.ready.clone();
205    let semaphore = Arc::new(Semaphore::new(max_parallel));
206    let mut tasks = JoinSet::new();
207    let in_flight: Arc<Mutex<HashMap<u32, InFlightIssue>>> = Arc::new(Mutex::new(HashMap::new()));
208
209    info!(poll_interval_secs = poll_interval.as_secs(), max_parallel, "continuous polling started");
210
211    loop {
212        tokio::select! {
213            () = cancel_token.cancelled() => {
214                info!("shutdown signal received, waiting for in-flight pipelines");
215                while let Some(result) = tasks.join_next().await {
216                    handle_task_result(result);
217                }
218                break;
219            }
220            () = tokio::time::sleep(poll_interval) => {
221                poll_and_spawn(
222                    &executor, &ready_label, &semaphore, &in_flight,
223                    &mut tasks, auto_merge,
224                ).await;
225            }
226            Some(result) = tasks.join_next(), if !tasks.is_empty() => {
227                handle_task_result(result);
228            }
229        }
230    }
231
232    Ok(())
233}
234
235/// Single poll cycle: fetch ready issues, plan, and spawn batch 1.
236async fn poll_and_spawn<R: CommandRunner + 'static>(
237    executor: &Arc<PipelineExecutor<R>>,
238    ready_label: &str,
239    semaphore: &Arc<Semaphore>,
240    in_flight: &Arc<Mutex<HashMap<u32, InFlightIssue>>>,
241    tasks: &mut JoinSet<(u32, Result<()>)>,
242    auto_merge: bool,
243) {
244    let issues = match executor.issues.get_ready_issues(ready_label).await {
245        Ok(i) => i,
246        Err(e) => {
247            error!(error = %e, "failed to fetch issues");
248            return;
249        }
250    };
251
252    let in_flight_guard = in_flight.lock().await;
253    let new_issues: Vec<_> =
254        issues.into_iter().filter(|i| !in_flight_guard.contains_key(&i.number)).collect();
255    let in_flight_snapshot: Vec<InFlightIssue> = in_flight_guard.values().cloned().collect();
256    drop(in_flight_guard);
257
258    if new_issues.is_empty() {
259        info!("no new issues found, waiting");
260        return;
261    }
262
263    info!(count = new_issues.len(), "found new issues to process");
264
265    let (batch1_issues, metadata_map) =
266        if let Some(plan) = executor.plan_issues(&new_issues, &in_flight_snapshot).await {
267            info!(
268                batches = plan.batches.len(),
269                total = plan.total_issues,
270                "planner produced a plan, spawning batch 1 only"
271            );
272            extract_batch1(&plan)
273        } else {
274            warn!("planner failed, falling back to spawning all issues");
275            let all: HashMap<u32, InFlightIssue> =
276                new_issues.iter().map(|i| (i.number, InFlightIssue::from_issue(i))).collect();
277            let numbers: Vec<u32> = all.keys().copied().collect();
278            (numbers, all)
279        };
280
281    for issue in new_issues {
282        if !batch1_issues.contains(&issue.number) {
283            info!(issue = issue.number, "deferring issue to next poll cycle (not in batch 1)");
284            continue;
285        }
286
287        let sem = Arc::clone(semaphore);
288        let exec = Arc::clone(executor);
289        let in_fl = Arc::clone(in_flight);
290        let number = issue.number;
291        let complexity = metadata_map.get(&number).map(|m| m.complexity.clone());
292
293        let metadata =
294            metadata_map.get(&number).cloned().unwrap_or_else(|| InFlightIssue::from_issue(&issue));
295        in_fl.lock().await.insert(number, metadata);
296
297        tasks.spawn(async move {
298            let permit = match sem.acquire_owned().await {
299                Ok(p) => p,
300                Err(e) => {
301                    in_fl.lock().await.remove(&number);
302                    return (number, Err(anyhow::anyhow!("semaphore closed: {e}")));
303                }
304            };
305            let result = exec.run_issue_with_complexity(&issue, auto_merge, complexity).await;
306            in_fl.lock().await.remove(&number);
307            drop(permit);
308            (number, result)
309        });
310    }
311}
312
313/// Extract batch 1 issue numbers and their planner metadata from a planner output.
314fn extract_batch1(plan: &PlannerOutput) -> (Vec<u32>, HashMap<u32, InFlightIssue>) {
315    let mut batch1_numbers = Vec::new();
316    let mut metadata_map = HashMap::new();
317
318    if let Some(batch) = plan.batches.first() {
319        for pi in &batch.issues {
320            batch1_numbers.push(pi.number);
321            metadata_map.insert(pi.number, InFlightIssue::from(pi));
322        }
323    }
324
325    (batch1_numbers, metadata_map)
326}
327
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use tokio::sync::Mutex;

    use super::*;
    use crate::{
        agents::{Complexity, InFlightIssue},
        config::Config,
        github::GhClient,
        issues::{IssueOrigin, IssueProvider, github::GithubIssueProvider},
        process::{AgentResult, CommandOutput, MockCommandRunner},
    };

    /// Mock runner whose `gh` calls return a PR URL and whose agent calls
    /// return a clean JSON review — enough to drive a pipeline to success.
    fn mock_runner_for_batch() -> MockCommandRunner {
        let mut mock = MockCommandRunner::new();
        mock.expect_run_gh().returning(|_, _| {
            Box::pin(async {
                Ok(CommandOutput {
                    stdout: "https://github.com/user/repo/pull/1\n".to_string(),
                    stderr: String::new(),
                    success: true,
                })
            })
        });
        mock.expect_run_claude().returning(|_, _, _, _| {
            Box::pin(async {
                Ok(AgentResult {
                    cost_usd: 1.0,
                    duration: Duration::from_secs(5),
                    turns: 3,
                    output: r#"{"findings":[],"summary":"clean"}"#.to_string(),
                    session_id: "sess-1".to_string(),
                    success: true,
                })
            })
        });
        mock
    }

    /// Wrap a mocked `GhClient` in the GitHub-backed issue provider.
    fn make_github_provider(gh: &Arc<GhClient<MockCommandRunner>>) -> Arc<dyn IssueProvider> {
        Arc::new(GithubIssueProvider::new(Arc::clone(gh), "target_repo"))
    }

    // Cancelling before the first poll fires should make the loop exit with Ok.
    #[tokio::test]
    async fn cancellation_stops_polling() {
        let cancel = CancellationToken::new();
        let runner = Arc::new(mock_runner_for_batch());
        let github = Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
        let issues = make_github_provider(&github);
        let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));

        let mut config = Config::default();
        config.pipeline.poll_interval = 3600; // very long so we don't actually poll

        let executor = Arc::new(PipelineExecutor {
            runner,
            github,
            issues,
            db,
            config,
            cancel_token: cancel.clone(),
            repo_dir: PathBuf::from("/tmp"),
        });

        let cancel_clone = cancel.clone();
        let handle = tokio::spawn(async move { polling_loop(executor, false, cancel_clone).await });

        // Cancel immediately
        cancel.cancel();

        let result = handle.await.unwrap();
        assert!(result.is_ok());
    }

    // Same setup as above, but additionally bounds the shutdown latency:
    // the loop must observe cancellation and return within 5 seconds.
    #[tokio::test]
    async fn cancellation_exits_within_timeout() {
        let cancel = CancellationToken::new();
        let runner = Arc::new(mock_runner_for_batch());
        let github = Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
        let issues = make_github_provider(&github);
        let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));

        let mut config = Config::default();
        config.pipeline.poll_interval = 3600;

        let executor = Arc::new(PipelineExecutor {
            runner,
            github,
            issues,
            db,
            config,
            cancel_token: cancel.clone(),
            repo_dir: PathBuf::from("/tmp"),
        });

        let cancel_clone = cancel.clone();
        let handle = tokio::spawn(async move { polling_loop(executor, false, cancel_clone).await });

        cancel.cancel();

        let result = tokio::time::timeout(Duration::from_secs(5), handle)
            .await
            .expect("polling loop should exit within timeout")
            .unwrap();
        assert!(result.is_ok());
    }

    // Replicates the filtering done in `poll_and_spawn`: issues already present
    // in the in-flight map must be excluded from the next poll's work set.
    #[tokio::test]
    async fn in_flight_map_filters_duplicate_issues() {
        let in_flight: Arc<Mutex<HashMap<u32, InFlightIssue>>> =
            Arc::new(Mutex::new(HashMap::new()));

        // Simulate issue 1 already in flight
        in_flight.lock().await.insert(
            1,
            InFlightIssue {
                number: 1,
                title: "Already running".to_string(),
                area: "auth".to_string(),
                predicted_files: vec!["src/auth.rs".to_string()],
                has_migration: false,
                complexity: Complexity::Full,
            },
        );

        let issues = vec![
            PipelineIssue {
                number: 1,
                title: "Already running".to_string(),
                body: String::new(),
                source: IssueOrigin::Github,
                target_repo: None,
            },
            PipelineIssue {
                number: 2,
                title: "New issue".to_string(),
                body: String::new(),
                source: IssueOrigin::Github,
                target_repo: None,
            },
            PipelineIssue {
                number: 3,
                title: "Another new".to_string(),
                body: String::new(),
                source: IssueOrigin::Github,
                target_repo: None,
            },
        ];

        let guard = in_flight.lock().await;
        let new_issues: Vec<_> =
            issues.into_iter().filter(|i| !guard.contains_key(&i.number)).collect();
        drop(guard);

        // Only issues 2 and 3 survive the filter; issue 1 was in flight.
        assert_eq!(new_issues.len(), 2);
        assert_eq!(new_issues[0].number, 2);
        assert_eq!(new_issues[1].number, 3);
    }

    // `handle_task_result` only logs; these verify it is total over its inputs.
    #[test]
    fn handle_task_result_does_not_panic_on_success() {
        handle_task_result(Ok((1, Ok(()))));
    }

    #[test]
    fn handle_task_result_does_not_panic_on_error() {
        handle_task_result(Ok((1, Err(anyhow::anyhow!("test error")))));
    }

    // Batch 1 (issues 1 and 2) should be extracted with its metadata;
    // batch 2 (issue 3) must be left for later poll cycles.
    #[test]
    fn extract_batch1_returns_first_batch_only() {
        let plan = crate::agents::PlannerOutput {
            batches: vec![
                crate::agents::Batch {
                    batch: 1,
                    issues: vec![
                        crate::agents::PlannedIssue {
                            number: 1,
                            title: "First".to_string(),
                            area: "cli".to_string(),
                            predicted_files: vec!["src/cli.rs".to_string()],
                            has_migration: false,
                            complexity: Complexity::Simple,
                        },
                        crate::agents::PlannedIssue {
                            number: 2,
                            title: "Second".to_string(),
                            area: "config".to_string(),
                            predicted_files: vec!["src/config.rs".to_string()],
                            has_migration: false,
                            complexity: Complexity::Full,
                        },
                    ],
                    reasoning: "independent".to_string(),
                },
                crate::agents::Batch {
                    batch: 2,
                    issues: vec![crate::agents::PlannedIssue {
                        number: 3,
                        title: "Third".to_string(),
                        area: "db".to_string(),
                        predicted_files: vec!["src/db.rs".to_string()],
                        has_migration: true,
                        complexity: Complexity::Full,
                    }],
                    reasoning: "depends on batch 1".to_string(),
                },
            ],
            total_issues: 3,
            parallel_capacity: 2,
        };

        let (batch1_numbers, metadata_map) = extract_batch1(&plan);
        assert_eq!(batch1_numbers, vec![1, 2]);
        assert!(!batch1_numbers.contains(&3));
        assert_eq!(metadata_map.get(&1).unwrap().complexity, Complexity::Simple);
        assert_eq!(metadata_map.get(&1).unwrap().area, "cli");
        assert_eq!(metadata_map.get(&2).unwrap().complexity, Complexity::Full);
        assert!(!metadata_map.contains_key(&3));
    }

    // An empty plan yields empty batch-1 output rather than panicking.
    #[test]
    fn extract_batch1_empty_plan() {
        let plan =
            crate::agents::PlannerOutput { batches: vec![], total_issues: 0, parallel_capacity: 0 };
        let (batch1, metadata) = extract_batch1(&plan);
        assert!(batch1.is_empty());
        assert!(metadata.is_empty());
    }

    // When the agent returns prose instead of JSON, `plan_issues` must yield
    // `None`, which triggers the all-parallel fallback in the callers.
    #[tokio::test]
    async fn planner_failure_falls_back_to_all_parallel() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run_gh().returning(|_, _| {
            Box::pin(async {
                Ok(CommandOutput { stdout: String::new(), stderr: String::new(), success: true })
            })
        });
        mock.expect_run_claude().returning(|_, _, _, _| {
            Box::pin(async {
                Ok(AgentResult {
                    cost_usd: 0.5,
                    duration: Duration::from_secs(2),
                    turns: 1,
                    output: "I don't know how to plan".to_string(),
                    session_id: "sess-plan".to_string(),
                    success: true,
                })
            })
        });

        let runner = Arc::new(mock);
        let github = Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
        let issues_provider = make_github_provider(&github);
        let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));

        let executor = Arc::new(PipelineExecutor {
            runner,
            github,
            issues: issues_provider,
            db,
            config: Config::default(),
            cancel_token: CancellationToken::new(),
            repo_dir: PathBuf::from("/tmp"),
        });

        let issues = vec![PipelineIssue {
            number: 1,
            title: "Test".to_string(),
            body: "body".to_string(),
            source: IssueOrigin::Github,
            target_repo: None,
        }];

        // plan_issues returns None for unparseable output
        let plan = executor.plan_issues(&issues, &[]).await;
        assert!(plan.is_none());
    }
}