1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use anyhow::Result;
4use tokio::{
5 sync::{Mutex, Semaphore},
6 task::JoinSet,
7};
8use tokio_util::sync::CancellationToken;
9use tracing::{error, info, warn};
10
11use super::executor::PipelineExecutor;
12use crate::{
13 agents::{InFlightIssue, PlannerOutput},
14 issues::PipelineIssue,
15 process::CommandRunner,
16};
17
18pub async fn run_batch<R: CommandRunner + 'static>(
24 executor: &Arc<PipelineExecutor<R>>,
25 issues: Vec<PipelineIssue>,
26 max_parallel: usize,
27 auto_merge: bool,
28) -> Result<()> {
29 if let Some(plan) = executor.plan_issues(&issues, &[]).await {
30 info!(
31 batches = plan.batches.len(),
32 total = plan.total_issues,
33 "planner produced a plan, running batches sequentially"
34 );
35 run_batches_sequentially(executor, &issues, &plan, max_parallel, auto_merge).await
36 } else {
37 warn!("planner failed, falling back to all-parallel execution");
38 run_all_parallel(executor, issues, max_parallel, auto_merge).await
39 }
40}
41
42async fn run_batches_sequentially<R: CommandRunner + 'static>(
45 executor: &Arc<PipelineExecutor<R>>,
46 issues: &[PipelineIssue],
47 plan: &PlannerOutput,
48 max_parallel: usize,
49 auto_merge: bool,
50) -> Result<()> {
51 let issue_map: HashMap<u32, &PipelineIssue> = issues.iter().map(|i| (i.number, i)).collect();
52
53 for batch in &plan.batches {
54 let batch_issues: Vec<PipelineIssue> = batch
55 .issues
56 .iter()
57 .filter_map(|pi| issue_map.get(&pi.number).map(|i| (*i).clone()))
58 .collect();
59
60 if batch_issues.is_empty() {
61 continue;
62 }
63
64 info!(
65 batch = batch.batch,
66 count = batch_issues.len(),
67 reasoning = %batch.reasoning,
68 "starting batch"
69 );
70
71 run_single_batch(executor, batch_issues, &batch.issues, max_parallel, auto_merge).await?;
72 }
73
74 Ok(())
75}
76
77async fn run_single_batch<R: CommandRunner + 'static>(
79 executor: &Arc<PipelineExecutor<R>>,
80 issues: Vec<PipelineIssue>,
81 planned: &[crate::agents::PlannedIssue],
82 max_parallel: usize,
83 auto_merge: bool,
84) -> Result<()> {
85 let complexity_map: HashMap<u32, crate::agents::Complexity> =
86 planned.iter().map(|pi| (pi.number, pi.complexity.clone())).collect();
87 let semaphore = Arc::new(Semaphore::new(max_parallel));
88 let mut tasks = JoinSet::new();
89
90 for issue in issues {
91 let permit = semaphore
92 .clone()
93 .acquire_owned()
94 .await
95 .map_err(|e| anyhow::anyhow!("semaphore closed: {e}"))?;
96 let exec = Arc::clone(executor);
97 let complexity = complexity_map.get(&issue.number).cloned();
98 tasks.spawn(async move {
99 let number = issue.number;
100 let result = exec.run_issue_with_complexity(&issue, auto_merge, complexity).await;
101 drop(permit);
102 (number, result)
103 });
104 }
105
106 let mut had_errors = false;
107 while let Some(join_result) = tasks.join_next().await {
108 match join_result {
109 Ok((number, Err(e))) => {
110 error!(issue = number, error = %e, "pipeline failed for issue");
111 had_errors = true;
112 }
113 Err(e) => {
114 error!(error = %e, "pipeline task panicked");
115 had_errors = true;
116 }
117 Ok((number, Ok(()))) => {
118 info!(issue = number, "pipeline completed successfully");
119 }
120 }
121 }
122
123 if had_errors { Err(anyhow::anyhow!("one or more pipelines failed in batch")) } else { Ok(()) }
124}
125
126async fn run_all_parallel<R: CommandRunner + 'static>(
128 executor: &Arc<PipelineExecutor<R>>,
129 issues: Vec<PipelineIssue>,
130 max_parallel: usize,
131 auto_merge: bool,
132) -> Result<()> {
133 let semaphore = Arc::new(Semaphore::new(max_parallel));
134 let mut tasks = JoinSet::new();
135
136 for issue in issues {
137 let permit = semaphore
138 .clone()
139 .acquire_owned()
140 .await
141 .map_err(|e| anyhow::anyhow!("semaphore closed: {e}"))?;
142 let exec = Arc::clone(executor);
143 tasks.spawn(async move {
144 let number = issue.number;
145 let result = exec.run_issue(&issue, auto_merge).await;
146 drop(permit);
147 (number, result)
148 });
149 }
150
151 let mut had_errors = false;
152 while let Some(join_result) = tasks.join_next().await {
153 match join_result {
154 Ok((number, Ok(()))) => {
155 info!(issue = number, "pipeline completed successfully");
156 }
157 Ok((number, Err(e))) => {
158 error!(issue = number, error = %e, "pipeline failed for issue");
159 had_errors = true;
160 }
161 Err(e) => {
162 error!(error = %e, "pipeline task panicked");
163 had_errors = true;
164 }
165 }
166 }
167
168 if had_errors {
169 anyhow::bail!("one or more pipelines failed");
170 }
171 Ok(())
172}
173
174fn handle_task_result(result: Result<(u32, Result<()>), tokio::task::JoinError>) {
175 match result {
176 Ok((number, Ok(()))) => {
177 info!(issue = number, "pipeline completed successfully");
178 }
179 Ok((number, Err(e))) => {
180 error!(issue = number, error = %e, "pipeline failed for issue");
181 }
182 Err(e) => {
183 error!(error = %e, "pipeline task panicked");
184 }
185 }
186}
187
/// Long-running poll loop: periodically fetches ready issues, spawns
/// pipelines for them, and reaps finished tasks until cancellation.
///
/// On cancellation the loop stops polling but drains every in-flight
/// pipeline before returning, so shutdown never abandons running work.
pub async fn polling_loop<R: CommandRunner + 'static>(
    executor: Arc<PipelineExecutor<R>>,
    auto_merge: bool,
    cancel_token: CancellationToken,
) -> Result<()> {
    let poll_interval = Duration::from_secs(executor.config.pipeline.poll_interval);
    let max_parallel = executor.config.pipeline.max_parallel as usize;
    let ready_label = executor.config.labels.ready.clone();
    let semaphore = Arc::new(Semaphore::new(max_parallel));
    let mut tasks = JoinSet::new();
    // Issues currently being worked on, keyed by number; shared with spawned
    // tasks so they can deregister themselves when they finish.
    let in_flight: Arc<Mutex<HashMap<u32, InFlightIssue>>> = Arc::new(Mutex::new(HashMap::new()));

    info!(poll_interval_secs = poll_interval.as_secs(), max_parallel, "continuous polling started");

    loop {
        tokio::select! {
            // Graceful shutdown: drain outstanding pipelines, then exit.
            () = cancel_token.cancelled() => {
                info!("shutdown signal received, waiting for in-flight pipelines");
                while let Some(result) = tasks.join_next().await {
                    handle_task_result(result);
                }
                break;
            }
            // Periodic poll tick. Note the first poll happens only after one
            // full interval has elapsed, not immediately at startup.
            () = tokio::time::sleep(poll_interval) => {
                poll_and_spawn(
                    &executor, &ready_label, &semaphore, &in_flight,
                    &mut tasks, auto_merge,
                ).await;
            }
            // Reap finished pipelines as they complete; the guard keeps this
            // arm disabled while the task set is empty.
            Some(result) = tasks.join_next(), if !tasks.is_empty() => {
                handle_task_result(result);
            }
        }
    }

    Ok(())
}
234
/// One poll cycle: fetch ready issues, drop those already in flight, plan the
/// remainder, and spawn pipelines for the plan's first batch only.
///
/// Errors are logged and swallowed so the enclosing polling loop keeps
/// running regardless of a bad cycle.
async fn poll_and_spawn<R: CommandRunner + 'static>(
    executor: &Arc<PipelineExecutor<R>>,
    ready_label: &str,
    semaphore: &Arc<Semaphore>,
    in_flight: &Arc<Mutex<HashMap<u32, InFlightIssue>>>,
    tasks: &mut JoinSet<(u32, Result<()>)>,
    auto_merge: bool,
) {
    let issues = match executor.issues.get_ready_issues(ready_label).await {
        Ok(i) => i,
        Err(e) => {
            error!(error = %e, "failed to fetch issues");
            return;
        }
    };

    // One short lock scope: filter out issues already being worked on, and
    // snapshot the current in-flight metadata for the planner.
    let in_flight_guard = in_flight.lock().await;
    let new_issues: Vec<_> =
        issues.into_iter().filter(|i| !in_flight_guard.contains_key(&i.number)).collect();
    let in_flight_snapshot: Vec<InFlightIssue> = in_flight_guard.values().cloned().collect();
    drop(in_flight_guard);

    if new_issues.is_empty() {
        info!("no new issues found, waiting");
        return;
    }

    info!(count = new_issues.len(), "found new issues to process");

    // Only batch 1 of the plan is spawned now; deferred issues are re-fetched
    // (and re-planned) on a later poll cycle. On planner failure, every new
    // issue is treated as batch 1.
    let (batch1_issues, metadata_map) =
        if let Some(plan) = executor.plan_issues(&new_issues, &in_flight_snapshot).await {
            info!(
                batches = plan.batches.len(),
                total = plan.total_issues,
                "planner produced a plan, spawning batch 1 only"
            );
            extract_batch1(&plan)
        } else {
            warn!("planner failed, falling back to spawning all issues");
            let all: HashMap<u32, InFlightIssue> =
                new_issues.iter().map(|i| (i.number, InFlightIssue::from_issue(i))).collect();
            let numbers: Vec<u32> = all.keys().copied().collect();
            (numbers, all)
        };

    for issue in new_issues {
        if !batch1_issues.contains(&issue.number) {
            info!(issue = issue.number, "deferring issue to next poll cycle (not in batch 1)");
            continue;
        }

        let sem = Arc::clone(semaphore);
        let exec = Arc::clone(executor);
        let in_fl = Arc::clone(in_flight);
        let number = issue.number;
        let complexity = metadata_map.get(&number).map(|m| m.complexity.clone());

        // Register the issue as in-flight BEFORE spawning, so the next poll
        // cannot pick it up again while the task is still waiting for a permit.
        let metadata =
            metadata_map.get(&number).cloned().unwrap_or_else(|| InFlightIssue::from_issue(&issue));
        in_fl.lock().await.insert(number, metadata);

        tasks.spawn(async move {
            // The permit is acquired inside the task: spawning itself is
            // unbounded, but actual pipeline work is capped by the semaphore.
            let permit = match sem.acquire_owned().await {
                Ok(p) => p,
                Err(e) => {
                    // Semaphore closed (shutdown) — deregister and bail out.
                    in_fl.lock().await.remove(&number);
                    return (number, Err(anyhow::anyhow!("semaphore closed: {e}")));
                }
            };
            let result = exec.run_issue_with_complexity(&issue, auto_merge, complexity).await;
            // Deregister before releasing the permit so a concurrent poll
            // never sees a free slot while the issue still looks in-flight.
            in_fl.lock().await.remove(&number);
            drop(permit);
            (number, result)
        });
    }
}
312
313fn extract_batch1(plan: &PlannerOutput) -> (Vec<u32>, HashMap<u32, InFlightIssue>) {
315 let mut batch1_numbers = Vec::new();
316 let mut metadata_map = HashMap::new();
317
318 if let Some(batch) = plan.batches.first() {
319 for pi in &batch.issues {
320 batch1_numbers.push(pi.number);
321 metadata_map.insert(pi.number, InFlightIssue::from(pi));
322 }
323 }
324
325 (batch1_numbers, metadata_map)
326}
327
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use tokio::sync::Mutex;

    use super::*;
    use crate::{
        agents::{Complexity, InFlightIssue},
        config::Config,
        github::GhClient,
        issues::{IssueOrigin, IssueProvider, github::GithubIssueProvider},
        process::{AgentResult, CommandOutput, MockCommandRunner},
    };

    /// Mock runner whose `gh` calls succeed with a PR URL and whose agent
    /// calls succeed with an empty-findings JSON payload.
    fn mock_runner_for_batch() -> MockCommandRunner {
        let mut mock = MockCommandRunner::new();
        mock.expect_run_gh().returning(|_, _| {
            Box::pin(async {
                Ok(CommandOutput {
                    stdout: "https://github.com/user/repo/pull/1\n".to_string(),
                    stderr: String::new(),
                    success: true,
                })
            })
        });
        mock.expect_run_claude().returning(|_, _, _, _| {
            Box::pin(async {
                Ok(AgentResult {
                    cost_usd: 1.0,
                    duration: Duration::from_secs(5),
                    turns: 3,
                    output: r#"{"findings":[],"summary":"clean"}"#.to_string(),
                    session_id: "sess-1".to_string(),
                    success: true,
                })
            })
        });
        mock
    }

    /// Wraps a mocked gh client in the GitHub-backed issue provider.
    fn make_github_provider(gh: &Arc<GhClient<MockCommandRunner>>) -> Arc<dyn IssueProvider> {
        Arc::new(GithubIssueProvider::new(Arc::clone(gh), "target_repo"))
    }

    /// Cancelling the token makes `polling_loop` return Ok. The poll interval
    /// is set very high so the loop never actually polls before cancellation.
    #[tokio::test]
    async fn cancellation_stops_polling() {
        let cancel = CancellationToken::new();
        let runner = Arc::new(mock_runner_for_batch());
        let github = Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
        let issues = make_github_provider(&github);
        let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));

        let mut config = Config::default();
        config.pipeline.poll_interval = 3600;

        let executor = Arc::new(PipelineExecutor {
            runner,
            github,
            issues,
            db,
            config,
            cancel_token: cancel.clone(),
            repo_dir: PathBuf::from("/tmp"),
        });

        let cancel_clone = cancel.clone();
        let handle = tokio::spawn(async move { polling_loop(executor, false, cancel_clone).await });

        cancel.cancel();

        let result = handle.await.unwrap();
        assert!(result.is_ok());
    }

    /// Same as above but bounds shutdown latency: the loop must exit within
    /// five seconds of cancellation rather than merely eventually.
    #[tokio::test]
    async fn cancellation_exits_within_timeout() {
        let cancel = CancellationToken::new();
        let runner = Arc::new(mock_runner_for_batch());
        let github = Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
        let issues = make_github_provider(&github);
        let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));

        let mut config = Config::default();
        config.pipeline.poll_interval = 3600;

        let executor = Arc::new(PipelineExecutor {
            runner,
            github,
            issues,
            db,
            config,
            cancel_token: cancel.clone(),
            repo_dir: PathBuf::from("/tmp"),
        });

        let cancel_clone = cancel.clone();
        let handle = tokio::spawn(async move { polling_loop(executor, false, cancel_clone).await });

        cancel.cancel();

        let result = tokio::time::timeout(Duration::from_secs(5), handle)
            .await
            .expect("polling loop should exit within timeout")
            .unwrap();
        assert!(result.is_ok());
    }

    /// Replays the de-duplication filter used by `poll_and_spawn`: issues
    /// whose number is already in the in-flight map are dropped.
    #[tokio::test]
    async fn in_flight_map_filters_duplicate_issues() {
        let in_flight: Arc<Mutex<HashMap<u32, InFlightIssue>>> =
            Arc::new(Mutex::new(HashMap::new()));

        // Issue 1 is already running.
        in_flight.lock().await.insert(
            1,
            InFlightIssue {
                number: 1,
                title: "Already running".to_string(),
                area: "auth".to_string(),
                predicted_files: vec!["src/auth.rs".to_string()],
                has_migration: false,
                complexity: Complexity::Full,
            },
        );

        let issues = vec![
            PipelineIssue {
                number: 1,
                title: "Already running".to_string(),
                body: String::new(),
                source: IssueOrigin::Github,
                target_repo: None,
            },
            PipelineIssue {
                number: 2,
                title: "New issue".to_string(),
                body: String::new(),
                source: IssueOrigin::Github,
                target_repo: None,
            },
            PipelineIssue {
                number: 3,
                title: "Another new".to_string(),
                body: String::new(),
                source: IssueOrigin::Github,
                target_repo: None,
            },
        ];

        let guard = in_flight.lock().await;
        let new_issues: Vec<_> =
            issues.into_iter().filter(|i| !guard.contains_key(&i.number)).collect();
        drop(guard);

        // Only issues 2 and 3 survive the filter.
        assert_eq!(new_issues.len(), 2);
        assert_eq!(new_issues[0].number, 2);
        assert_eq!(new_issues[1].number, 3);
    }

    // handle_task_result only logs; these two tests just assert it is total
    // (no panic) for the success and per-issue-error shapes.
    #[test]
    fn handle_task_result_does_not_panic_on_success() {
        handle_task_result(Ok((1, Ok(()))));
    }

    #[test]
    fn handle_task_result_does_not_panic_on_error() {
        handle_task_result(Ok((1, Err(anyhow::anyhow!("test error")))));
    }

    /// extract_batch1 must return only the first batch's numbers/metadata and
    /// preserve each planned issue's complexity and area.
    #[test]
    fn extract_batch1_returns_first_batch_only() {
        let plan = crate::agents::PlannerOutput {
            batches: vec![
                crate::agents::Batch {
                    batch: 1,
                    issues: vec![
                        crate::agents::PlannedIssue {
                            number: 1,
                            title: "First".to_string(),
                            area: "cli".to_string(),
                            predicted_files: vec!["src/cli.rs".to_string()],
                            has_migration: false,
                            complexity: Complexity::Simple,
                        },
                        crate::agents::PlannedIssue {
                            number: 2,
                            title: "Second".to_string(),
                            area: "config".to_string(),
                            predicted_files: vec!["src/config.rs".to_string()],
                            has_migration: false,
                            complexity: Complexity::Full,
                        },
                    ],
                    reasoning: "independent".to_string(),
                },
                crate::agents::Batch {
                    batch: 2,
                    issues: vec![crate::agents::PlannedIssue {
                        number: 3,
                        title: "Third".to_string(),
                        area: "db".to_string(),
                        predicted_files: vec!["src/db.rs".to_string()],
                        has_migration: true,
                        complexity: Complexity::Full,
                    }],
                    reasoning: "depends on batch 1".to_string(),
                },
            ],
            total_issues: 3,
            parallel_capacity: 2,
        };

        let (batch1_numbers, metadata_map) = extract_batch1(&plan);
        assert_eq!(batch1_numbers, vec![1, 2]);
        assert!(!batch1_numbers.contains(&3));
        assert_eq!(metadata_map.get(&1).unwrap().complexity, Complexity::Simple);
        assert_eq!(metadata_map.get(&1).unwrap().area, "cli");
        assert_eq!(metadata_map.get(&2).unwrap().complexity, Complexity::Full);
        assert!(!metadata_map.contains_key(&3));
    }

    /// A plan with no batches yields empty numbers and metadata.
    #[test]
    fn extract_batch1_empty_plan() {
        let plan =
            crate::agents::PlannerOutput { batches: vec![], total_issues: 0, parallel_capacity: 0 };
        let (batch1, metadata) = extract_batch1(&plan);
        assert!(batch1.is_empty());
        assert!(metadata.is_empty());
    }

    /// When the agent returns unparseable planner output, `plan_issues`
    /// returns None — the condition that triggers the all-parallel fallback.
    #[tokio::test]
    async fn planner_failure_falls_back_to_all_parallel() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run_gh().returning(|_, _| {
            Box::pin(async {
                Ok(CommandOutput { stdout: String::new(), stderr: String::new(), success: true })
            })
        });
        // Agent "succeeds" but emits prose instead of a plan.
        mock.expect_run_claude().returning(|_, _, _, _| {
            Box::pin(async {
                Ok(AgentResult {
                    cost_usd: 0.5,
                    duration: Duration::from_secs(2),
                    turns: 1,
                    output: "I don't know how to plan".to_string(),
                    session_id: "sess-plan".to_string(),
                    success: true,
                })
            })
        });

        let runner = Arc::new(mock);
        let github = Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
        let issues_provider = make_github_provider(&github);
        let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));

        let executor = Arc::new(PipelineExecutor {
            runner,
            github,
            issues: issues_provider,
            db,
            config: Config::default(),
            cancel_token: CancellationToken::new(),
            repo_dir: PathBuf::from("/tmp"),
        });

        let issues = vec![PipelineIssue {
            number: 1,
            title: "Test".to_string(),
            body: "body".to_string(),
            source: IssueOrigin::Github,
            target_repo: None,
        }];

        let plan = executor.plan_issues(&issues, &[]).await;
        assert!(plan.is_none());
    }
}