1use std::{collections::HashSet, sync::Arc, time::Duration};
2
3use anyhow::Result;
4use tokio::{sync::Semaphore, task::JoinSet};
5use tokio_util::sync::CancellationToken;
6use tracing::{error, info, warn};
7
8use super::{
9 executor::{PipelineExecutor, PipelineOutcome},
10 graph::DependencyGraph,
11};
12use crate::{
13 agents::Complexity,
14 db::graph::NodeState,
15 git,
16 issues::PipelineIssue,
17 pipeline::{executor::generate_run_id, graph::GraphNode},
18 process::CommandRunner,
19};
20
/// Mutable state owned by `polling_loop` and threaded through its helpers.
struct SchedulerState {
    /// Dependency graph of the active session (resumed from the database or newly created).
    graph: DependencyGraph,
    /// Caps concurrently running pipelines (sized from `pipeline.max_parallel`); each spawned
    /// task holds one permit while its pipeline runs.
    semaphore: Arc<Semaphore>,
    /// Spawned pipeline tasks; each yields the issue number it ran together with the
    /// pipeline result.
    tasks: JoinSet<(u32, Result<PipelineOutcome>)>,
}
30
31pub async fn run_batch<R: CommandRunner + 'static>(
38 executor: &Arc<PipelineExecutor<R>>,
39 issues: Vec<PipelineIssue>,
40 max_parallel: usize,
41 auto_merge: bool,
42) -> Result<()> {
43 let session_id = generate_run_id();
44 let mut graph = if let Some(plan) = executor.plan_issues(&issues, &[]).await {
45 info!(nodes = plan.nodes.len(), total = plan.total_issues, "planner produced a plan");
46 DependencyGraph::from_planner_output(&session_id, &plan, &issues)
47 } else {
48 warn!("planner failed, falling back to all-parallel execution");
49 let mut g = DependencyGraph::new(&session_id);
50 for issue in &issues {
51 g.add_node(standalone_node(issue));
52 }
53 g
54 };
55
56 save_graph(&graph, executor).await;
57
58 let semaphore = Arc::new(Semaphore::new(max_parallel));
59 let mut had_errors = false;
60
61 while !graph.all_terminal() {
62 let ready = graph.ready_issues();
63 if ready.is_empty() {
64 warn!("no ready issues but graph is not terminal, breaking to avoid infinite loop");
65 save_graph(&graph, executor).await;
66 break;
67 }
68
69 let mut tasks: JoinSet<(u32, Result<PipelineOutcome>)> = JoinSet::new();
70
71 for num in &ready {
72 graph.transition(*num, NodeState::InFlight);
73 }
74 save_graph(&graph, executor).await;
75
76 for num in ready {
77 let node = graph.node(num).expect("ready issue must exist in graph");
78 let issue = node.issue.clone().expect("batch issues have issue attached");
79 let complexity = node.complexity.parse::<Complexity>().ok();
80 let sem = Arc::clone(&semaphore);
81 let exec = Arc::clone(executor);
82
83 tasks.spawn(async move {
84 let permit = match sem.acquire_owned().await {
85 Ok(p) => p,
86 Err(e) => return (num, Err(anyhow::anyhow!("semaphore closed: {e}"))),
87 };
88 let result = exec.run_issue_pipeline(&issue, auto_merge, complexity).await;
89 let outcome = match result {
90 Ok(outcome) => {
91 if let Err(e) = exec.finalize_merge(&outcome, &issue).await {
92 warn!(issue = num, error = %e, "finalize_merge failed");
93 }
94 Ok(outcome)
95 }
96 Err(e) => Err(e),
97 };
98 drop(permit);
99 (num, outcome)
100 });
101 }
102
103 let mut layer_had_merges = false;
104 while let Some(join_result) = tasks.join_next().await {
105 match join_result {
106 Ok((number, Ok(ref outcome))) => {
107 info!(issue = number, "pipeline completed successfully");
108 graph.set_pr_number(number, outcome.pr_number);
109 graph.set_run_id(number, &outcome.run_id);
110 graph.transition(number, NodeState::Merged);
111 if auto_merge {
112 layer_had_merges = true;
113 }
114 }
115 Ok((number, Err(ref e))) => {
116 error!(issue = number, error = %e, "pipeline failed for issue");
117 graph.transition(number, NodeState::Failed);
118 let blocked = graph.propagate_failure(number);
119 for b in &blocked {
120 warn!(issue = b, blocked_by = number, "transitively failed");
121 }
122 had_errors = true;
123 }
124 Err(e) => {
125 error!(error = %e, "pipeline task panicked");
126 had_errors = true;
127 }
128 }
129 }
130
131 if layer_had_merges && !graph.all_terminal() {
134 fetch_base_branch(executor).await;
135 }
136
137 save_graph(&graph, executor).await;
138 }
139
140 if had_errors {
141 anyhow::bail!("one or more pipelines failed in batch");
142 }
143 Ok(())
144}
145
146pub async fn polling_loop<R: CommandRunner + 'static>(
153 executor: Arc<PipelineExecutor<R>>,
154 auto_merge: bool,
155 cancel_token: CancellationToken,
156) -> Result<()> {
157 let poll_interval = Duration::from_secs(executor.config.pipeline.poll_interval);
158 let max_parallel = executor.config.pipeline.max_parallel as usize;
159 let ready_label = executor.config.labels.ready.clone();
160
161 let graph = load_or_create_graph(&executor).await;
163
164 let mut sched = SchedulerState {
165 graph,
166 semaphore: Arc::new(Semaphore::new(max_parallel)),
167 tasks: JoinSet::new(),
168 };
169
170 info!(poll_interval_secs = poll_interval.as_secs(), max_parallel, "continuous polling started");
171
172 loop {
173 tokio::select! {
174 () = cancel_token.cancelled() => {
175 info!("shutdown signal received, waiting for in-flight pipelines");
176 drain_tasks(&mut sched, &executor).await;
177 break;
178 }
179 () = tokio::time::sleep(poll_interval) => {
180 poll_and_spawn(&executor, &ready_label, &mut sched, auto_merge).await;
181 }
182 Some(result) = sched.tasks.join_next(), if !sched.tasks.is_empty() => {
183 handle_task_result(result, &mut sched.graph, &executor).await;
184 }
185 }
186 }
187
188 Ok(())
189}
190
191async fn load_or_create_graph<R: CommandRunner>(
193 executor: &Arc<PipelineExecutor<R>>,
194) -> DependencyGraph {
195 let conn = executor.db.lock().await;
196 match crate::db::graph::get_active_session(&conn) {
197 Ok(Some(session_id)) => match DependencyGraph::from_db(&conn, &session_id) {
198 Ok(graph) => {
199 info!(session_id = %session_id, nodes = graph.node_count(), "resumed existing graph session");
200 return graph;
201 }
202 Err(e) => {
203 warn!(error = %e, "failed to load graph session, starting fresh");
204 }
205 },
206 Ok(None) => {}
207 Err(e) => {
208 warn!(error = %e, "failed to check for active graph session");
209 }
210 }
211 let session_id = generate_run_id();
212 info!(session_id = %session_id, "starting new graph session");
213 DependencyGraph::new(&session_id)
214}
215
216async fn drain_tasks<R: CommandRunner>(
218 sched: &mut SchedulerState,
219 executor: &Arc<PipelineExecutor<R>>,
220) {
221 while let Some(result) = sched.tasks.join_next().await {
222 handle_task_result(result, &mut sched.graph, executor).await;
223 }
224}
225
226async fn handle_task_result<R: CommandRunner>(
228 result: Result<(u32, Result<PipelineOutcome>), tokio::task::JoinError>,
229 graph: &mut DependencyGraph,
230 executor: &Arc<PipelineExecutor<R>>,
231) {
232 match result {
233 Ok((number, Ok(ref outcome))) => {
234 info!(issue = number, "pipeline completed successfully");
235 graph.set_pr_number(number, outcome.pr_number);
236 graph.set_run_id(number, &outcome.run_id);
237 graph.transition(number, NodeState::AwaitingMerge);
238 }
239 Ok((number, Err(ref e))) => {
240 error!(issue = number, error = %e, "pipeline failed for issue");
241 graph.transition(number, NodeState::Failed);
242 let blocked = graph.propagate_failure(number);
243 for b in &blocked {
244 warn!(issue = b, blocked_by = number, "transitively failed");
245 }
246 }
247 Err(e) => {
248 error!(error = %e, "pipeline task panicked");
249 return;
250 }
251 }
252 save_graph(graph, executor).await;
253}
254
255async fn poll_awaiting_merges<R: CommandRunner + 'static>(
258 graph: &mut DependencyGraph,
259 executor: &Arc<PipelineExecutor<R>>,
260) {
261 let awaiting = graph.awaiting_merge();
262 if awaiting.is_empty() {
263 return;
264 }
265
266 let mut any_merged = false;
267 for num in awaiting {
268 let Some(node) = graph.node(num) else { continue };
269 let Some(pr_number) = node.pr_number else {
270 warn!(issue = num, "AwaitingMerge node has no PR number, skipping");
271 continue;
272 };
273 let run_id = node.run_id.clone().unwrap_or_default();
274 let issue = node.issue.clone();
275 let target_repo = node.target_repo.clone();
276
277 let pr_repo_dir = match executor.resolve_target_dir(target_repo.as_ref()) {
280 Ok((dir, _)) => dir,
281 Err(e) => {
282 warn!(issue = num, error = %e, "failed to resolve target dir for PR state check");
283 continue;
284 }
285 };
286
287 let pr_state = match executor.github.get_pr_state_in(pr_number, &pr_repo_dir).await {
288 Ok(s) => s,
289 Err(e) => {
290 warn!(issue = num, pr = pr_number, error = %e, "failed to check PR state");
291 continue;
292 }
293 };
294
295 match pr_state {
296 crate::github::PrState::Merged => {
297 info!(issue = num, pr = pr_number, "PR merged, finalizing");
298 if let Some(ref issue) = issue {
299 match executor.reconstruct_outcome(issue, &run_id, pr_number) {
300 Ok(outcome) => {
301 if let Err(e) = executor.finalize_merge(&outcome, issue).await {
302 warn!(issue = num, error = %e, "finalize_merge after poll failed");
303 }
304 }
305 Err(e) => {
306 warn!(issue = num, error = %e, "failed to reconstruct outcome");
307 }
308 }
309 } else {
310 warn!(
311 issue = num,
312 pr = pr_number,
313 "node restored from DB has no PipelineIssue, \
314 skipping finalization (labels and worktree may need manual cleanup)"
315 );
316 }
317 graph.transition(num, NodeState::Merged);
318 any_merged = true;
319 }
320 crate::github::PrState::Closed => {
321 warn!(issue = num, pr = pr_number, "PR closed without merge, marking failed");
322 graph.transition(num, NodeState::Failed);
323 let blocked = graph.propagate_failure(num);
324 for b in &blocked {
325 warn!(issue = b, blocked_by = num, "transitively failed (PR closed)");
326 }
327 }
328 crate::github::PrState::Open => {
329 }
331 }
332 }
333
334 if any_merged {
337 fetch_base_branch(executor).await;
338 }
339
340 save_graph(graph, executor).await;
341}
342
343async fn poll_and_spawn<R: CommandRunner + 'static>(
345 executor: &Arc<PipelineExecutor<R>>,
346 ready_label: &str,
347 sched: &mut SchedulerState,
348 auto_merge: bool,
349) {
350 poll_awaiting_merges(&mut sched.graph, executor).await;
352
353 let ready_issues = match executor.issues.get_ready_issues(ready_label).await {
354 Ok(i) => i,
355 Err(e) => {
356 error!(error = %e, "failed to fetch issues");
357 return;
358 }
359 };
360
361 let ready_numbers: HashSet<u32> = ready_issues.iter().map(|i| i.number).collect();
362
363 clean_stale_nodes(&mut sched.graph, &ready_numbers);
365
366 let new_issues: Vec<_> =
368 ready_issues.into_iter().filter(|i| !sched.graph.contains(i.number)).collect();
369
370 if !new_issues.is_empty() {
372 info!(count = new_issues.len(), "found new issues to evaluate");
373 let graph_context = sched.graph.to_graph_context();
374
375 if let Some(plan) = executor.plan_issues(&new_issues, &graph_context).await {
376 info!(nodes = plan.nodes.len(), total = plan.total_issues, "planner produced a plan");
377 sched.graph.merge_planner_output(&plan, &new_issues);
378 } else {
379 warn!("planner failed, adding all new issues as independent nodes");
380 add_independent_nodes(&mut sched.graph, &new_issues);
381 }
382
383 save_graph(&sched.graph, executor).await;
384 }
385
386 let to_spawn = collect_ready_issues(&mut sched.graph);
388 if to_spawn.is_empty() {
389 if new_issues.is_empty() {
390 info!("no actionable issues, waiting");
391 }
392 return;
393 }
394
395 save_graph(&sched.graph, executor).await;
396 spawn_issues(to_spawn, executor, sched, auto_merge);
397}
398
399fn clean_stale_nodes(graph: &mut DependencyGraph, ready_numbers: &HashSet<u32>) {
401 let stale: Vec<u32> = graph
402 .all_issues()
403 .into_iter()
404 .filter(|num| {
405 !ready_numbers.contains(num)
406 && graph.node(*num).is_some_and(|n| n.state == NodeState::Pending)
407 })
408 .collect();
409 if !stale.is_empty() {
410 info!(count = stale.len(), "removing stale pending nodes");
411 for num in stale {
412 graph.remove_node(num);
413 }
414 }
415}
416
417fn add_independent_nodes(graph: &mut DependencyGraph, issues: &[PipelineIssue]) {
419 for issue in issues {
420 if !graph.contains(issue.number) {
421 graph.add_node(standalone_node(issue));
422 }
423 }
424}
425
426fn collect_ready_issues(graph: &mut DependencyGraph) -> Vec<(u32, PipelineIssue, Complexity)> {
428 let ready = graph.ready_issues();
429 let mut to_spawn = Vec::new();
430
431 for num in ready {
432 let Some(node) = graph.node(num) else { continue };
433 let Some(issue) = node.issue.clone() else {
434 warn!(issue = num, "ready node has no PipelineIssue attached, skipping");
435 continue;
436 };
437 let complexity = node.complexity.parse::<Complexity>().unwrap_or(Complexity::Full);
438 graph.transition(num, NodeState::InFlight);
439 to_spawn.push((num, issue, complexity));
440 }
441
442 to_spawn
443}
444
445fn spawn_issues<R: CommandRunner + 'static>(
447 to_spawn: Vec<(u32, PipelineIssue, Complexity)>,
448 executor: &Arc<PipelineExecutor<R>>,
449 sched: &mut SchedulerState,
450 auto_merge: bool,
451) {
452 for (number, issue, complexity) in to_spawn {
453 let sem = Arc::clone(&sched.semaphore);
454 let exec = Arc::clone(executor);
455
456 sched.tasks.spawn(async move {
457 let permit = match sem.acquire_owned().await {
458 Ok(p) => p,
459 Err(e) => return (number, Err(anyhow::anyhow!("semaphore closed: {e}"))),
460 };
461 let outcome = exec.run_issue_pipeline(&issue, auto_merge, Some(complexity)).await;
462 drop(permit);
463 (number, outcome)
464 });
465 }
466}
467
468fn standalone_node(issue: &PipelineIssue) -> GraphNode {
470 GraphNode {
471 issue_number: issue.number,
472 title: issue.title.clone(),
473 area: String::new(),
474 predicted_files: Vec::new(),
475 has_migration: false,
476 complexity: Complexity::Full.to_string(),
477 state: NodeState::Pending,
478 pr_number: None,
479 run_id: None,
480 target_repo: issue.target_repo.clone(),
481 issue: Some(issue.clone()),
482 }
483}
484
485async fn fetch_base_branch<R: CommandRunner>(executor: &Arc<PipelineExecutor<R>>) {
490 match git::default_branch(&executor.repo_dir).await {
491 Ok(branch) => {
492 if let Err(e) = git::fetch_branch(&executor.repo_dir, &branch).await {
493 warn!(error = %e, "failed to fetch base branch after merge");
494 } else {
495 info!(branch = %branch, "updated local base branch after merge");
496 }
497 }
498 Err(e) => {
499 warn!(error = %e, "failed to detect base branch for post-merge fetch");
500 }
501 }
502}
503
504async fn save_graph<R: CommandRunner>(
506 graph: &DependencyGraph,
507 executor: &Arc<PipelineExecutor<R>>,
508) {
509 let conn = executor.db.lock().await;
510 if let Err(e) = graph.save_to_db(&conn) {
511 warn!(error = %e, "failed to persist dependency graph");
512 }
513}
514
#[cfg(test)]
mod tests {
    use std::path::{Path, PathBuf};

    use tokio::sync::Mutex;

    use super::*;
    use crate::{
        agents::{PlannedNode, PlannerGraphOutput},
        config::Config,
        github::GhClient,
        issues::{IssueOrigin, IssueProvider, github::GithubIssueProvider},
        process::{AgentResult, CommandOutput, MockCommandRunner},
    };

    /// Mock runner whose `gh` calls print a PR URL and whose agent calls return a successful,
    /// finding-free review.
    fn mock_runner_for_batch() -> MockCommandRunner {
        let mut mock = MockCommandRunner::new();
        mock.expect_run_gh().returning(|_, _| {
            Box::pin(async {
                Ok(CommandOutput {
                    stdout: "https://github.com/user/repo/pull/1\n".to_string(),
                    stderr: String::new(),
                    success: true,
                })
            })
        });
        mock.expect_run_claude().returning(|_, _, _, _, _| {
            Box::pin(async {
                Ok(AgentResult {
                    cost_usd: 1.0,
                    duration: Duration::from_secs(5),
                    turns: 3,
                    output: r#"{"findings":[],"summary":"clean"}"#.to_string(),
                    session_id: "sess-1".to_string(),
                    success: true,
                })
            })
        });
        mock
    }

    fn make_github_provider(gh: &Arc<GhClient<MockCommandRunner>>) -> Arc<dyn IssueProvider> {
        Arc::new(GithubIssueProvider::new(Arc::clone(gh), "target_repo"))
    }

    /// Builds an executor around the given mocks, with an in-memory database and `/tmp` as the
    /// repository directory.
    fn make_executor(
        runner: MockCommandRunner,
        gh_runner: MockCommandRunner,
        config: Config,
        cancel_token: CancellationToken,
    ) -> Arc<PipelineExecutor<MockCommandRunner>> {
        let github = Arc::new(GhClient::new(gh_runner, Path::new("/tmp")));
        let issues = make_github_provider(&github);
        Arc::new(PipelineExecutor {
            runner: Arc::new(runner),
            github,
            issues,
            db: Arc::new(Mutex::new(crate::db::open_in_memory().unwrap())),
            config,
            cancel_token,
            repo_dir: PathBuf::from("/tmp"),
        })
    }

    /// Executor backed by `mock_runner_for_batch` with a default config.
    fn make_batch_executor() -> Arc<PipelineExecutor<MockCommandRunner>> {
        make_executor(
            mock_runner_for_batch(),
            mock_runner_for_batch(),
            Config::default(),
            CancellationToken::new(),
        )
    }

    /// Executor with a one-hour poll interval, so the poll timer cannot fire during a test and
    /// only cancellation ends `polling_loop`.
    fn make_long_poll_executor(
        cancel: &CancellationToken,
    ) -> Arc<PipelineExecutor<MockCommandRunner>> {
        let mut config = Config::default();
        config.pipeline.poll_interval = 3600;
        make_executor(mock_runner_for_batch(), mock_runner_for_batch(), config, cancel.clone())
    }

    fn make_issue(number: u32) -> PipelineIssue {
        PipelineIssue {
            number,
            title: format!("Issue #{number}"),
            body: String::new(),
            source: IssueOrigin::Github,
            target_repo: None,
            author: None,
        }
    }

    /// Planner output with two nodes where #2 depends on #1.
    fn two_node_chain_plan() -> PlannerGraphOutput {
        PlannerGraphOutput {
            nodes: vec![
                PlannedNode {
                    number: 1,
                    title: "Root".to_string(),
                    area: "a".to_string(),
                    predicted_files: vec![],
                    has_migration: false,
                    complexity: Complexity::Full,
                    depends_on: vec![],
                    reasoning: String::new(),
                },
                PlannedNode {
                    number: 2,
                    title: "Dep".to_string(),
                    area: "b".to_string(),
                    predicted_files: vec![],
                    has_migration: false,
                    complexity: Complexity::Full,
                    depends_on: vec![1],
                    reasoning: String::new(),
                },
            ],
            total_issues: 2,
            parallel_capacity: 1,
        }
    }

    #[tokio::test]
    async fn cancellation_stops_polling() {
        let cancel = CancellationToken::new();
        let executor = make_long_poll_executor(&cancel);

        let cancel_clone = cancel.clone();
        let handle = tokio::spawn(async move { polling_loop(executor, false, cancel_clone).await });

        cancel.cancel();

        let result = handle.await.unwrap();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn cancellation_exits_within_timeout() {
        let cancel = CancellationToken::new();
        let executor = make_long_poll_executor(&cancel);

        let cancel_clone = cancel.clone();
        let handle = tokio::spawn(async move { polling_loop(executor, false, cancel_clone).await });

        cancel.cancel();

        let result = tokio::time::timeout(Duration::from_secs(5), handle)
            .await
            .expect("polling loop should exit within timeout")
            .unwrap();
        assert!(result.is_ok());
    }

    #[test]
    fn handle_task_success_transitions_to_awaiting_merge() {
        let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
        rt.block_on(async {
            let executor = make_batch_executor();

            let mut graph = DependencyGraph::new("test");
            graph.add_node(standalone_node(&make_issue(1)));
            graph.transition(1, NodeState::InFlight);

            let outcome = PipelineOutcome {
                run_id: "run-abc".to_string(),
                pr_number: 42,
                worktree_path: PathBuf::from("/tmp/wt"),
                target_dir: PathBuf::from("/tmp"),
            };

            handle_task_result(Ok((1, Ok(outcome))), &mut graph, &executor).await;

            assert_eq!(graph.node(1).unwrap().state, NodeState::AwaitingMerge);
            assert_eq!(graph.node(1).unwrap().pr_number, Some(42));
            assert_eq!(graph.node(1).unwrap().run_id.as_deref(), Some("run-abc"));
        });
    }

    #[test]
    fn handle_task_failure_propagates_to_dependents() {
        let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
        rt.block_on(async {
            let executor = make_batch_executor();

            let plan = two_node_chain_plan();
            let issues = vec![make_issue(1), make_issue(2)];
            let mut graph = DependencyGraph::from_planner_output("test", &plan, &issues);
            graph.transition(1, NodeState::InFlight);

            handle_task_result(
                Ok((1, Err(anyhow::anyhow!("pipeline failed")))),
                &mut graph,
                &executor,
            )
            .await;

            assert_eq!(graph.node(1).unwrap().state, NodeState::Failed);
            assert_eq!(graph.node(2).unwrap().state, NodeState::Failed);
        });
    }

    #[test]
    fn stale_node_removed_when_issue_disappears() {
        let mut graph = DependencyGraph::new("test");
        graph.add_node(standalone_node(&make_issue(1)));
        graph.add_node(standalone_node(&make_issue(2)));
        graph.add_node(standalone_node(&make_issue(3)));
        graph.transition(2, NodeState::InFlight);

        let ready_numbers: HashSet<u32> = HashSet::from([1, 2]);
        clean_stale_nodes(&mut graph, &ready_numbers);

        // Pending and still ready: kept.
        assert!(graph.contains(1));
        // In flight: never treated as stale.
        assert!(graph.contains(2));
        // Pending but no longer ready: removed.
        assert!(!graph.contains(3));
    }

    #[test]
    fn collect_ready_issues_transitions_to_in_flight() {
        let mut graph = DependencyGraph::new("test");
        graph.add_node(standalone_node(&make_issue(1)));
        graph.add_node(standalone_node(&make_issue(2)));

        let spawnable = collect_ready_issues(&mut graph);
        assert_eq!(spawnable.len(), 2);

        assert_eq!(graph.node(1).unwrap().state, NodeState::InFlight);
        assert_eq!(graph.node(2).unwrap().state, NodeState::InFlight);

        // Already claimed nodes are not handed out twice.
        assert!(collect_ready_issues(&mut graph).is_empty());
    }

    #[tokio::test]
    async fn planner_failure_falls_back_to_all_parallel() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run_gh().returning(|_, _| {
            Box::pin(async {
                Ok(CommandOutput { stdout: String::new(), stderr: String::new(), success: true })
            })
        });
        mock.expect_run_claude().returning(|_, _, _, _, _| {
            Box::pin(async {
                Ok(AgentResult {
                    cost_usd: 0.5,
                    duration: Duration::from_secs(2),
                    turns: 1,
                    output: "I don't know how to plan".to_string(),
                    session_id: "sess-plan".to_string(),
                    success: true,
                })
            })
        });

        let executor = make_executor(
            mock,
            mock_runner_for_batch(),
            Config::default(),
            CancellationToken::new(),
        );

        let issues = vec![PipelineIssue {
            number: 1,
            title: "Test".to_string(),
            body: "body".to_string(),
            source: IssueOrigin::Github,
            target_repo: None,
            author: None,
        }];

        let plan = executor.plan_issues(&issues, &[]).await;
        assert!(plan.is_none());
    }

    #[test]
    fn graph_persisted_after_state_change() {
        let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
        rt.block_on(async {
            let executor = make_batch_executor();

            let mut graph = DependencyGraph::new("persist-test");
            graph.add_node(standalone_node(&make_issue(1)));
            graph.transition(1, NodeState::InFlight);

            let outcome = PipelineOutcome {
                run_id: "run-1".to_string(),
                pr_number: 10,
                worktree_path: PathBuf::from("/tmp/wt"),
                target_dir: PathBuf::from("/tmp"),
            };
            handle_task_result(Ok((1, Ok(outcome))), &mut graph, &executor).await;

            let loaded =
                DependencyGraph::from_db(&*executor.db.lock().await, "persist-test").unwrap();
            assert_eq!(loaded.node(1).unwrap().state, NodeState::AwaitingMerge);
            assert_eq!(loaded.node(1).unwrap().pr_number, Some(10));
        });
    }

    /// Mock runner whose `gh ... view` calls report `state` as the PR state; every other `gh`
    /// call and every agent call succeeds with empty output.
    fn mock_runner_with_pr_state(state: &'static str) -> MockCommandRunner {
        let mut mock = MockCommandRunner::new();
        mock.expect_run_gh().returning(move |args, _| {
            let args = args.to_vec();
            Box::pin(async move {
                if args.iter().any(|a| a == "view") {
                    Ok(CommandOutput {
                        stdout: format!(r#"{{"state":"{state}"}}"#),
                        stderr: String::new(),
                        success: true,
                    })
                } else {
                    Ok(CommandOutput {
                        stdout: String::new(),
                        stderr: String::new(),
                        success: true,
                    })
                }
            })
        });
        mock.expect_run_claude().returning(|_, _, _, _, _| {
            Box::pin(async {
                Ok(AgentResult {
                    cost_usd: 0.0,
                    duration: Duration::from_secs(0),
                    turns: 0,
                    output: String::new(),
                    session_id: String::new(),
                    success: true,
                })
            })
        });
        mock
    }

    fn make_merge_poll_executor(state: &'static str) -> Arc<PipelineExecutor<MockCommandRunner>> {
        make_executor(
            mock_runner_with_pr_state(state),
            mock_runner_with_pr_state(state),
            Config::default(),
            CancellationToken::new(),
        )
    }

    #[test]
    fn merge_polling_transitions_merged_pr() {
        let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
        rt.block_on(async {
            let executor = make_merge_poll_executor("MERGED");

            let mut graph = DependencyGraph::new("merge-poll-test");
            let mut node = standalone_node(&make_issue(1));
            node.pr_number = Some(42);
            node.run_id = Some("run-1".to_string());
            graph.add_node(node);
            graph.transition(1, NodeState::AwaitingMerge);

            poll_awaiting_merges(&mut graph, &executor).await;

            assert_eq!(graph.node(1).unwrap().state, NodeState::Merged);
        });
    }

    #[test]
    fn merge_polling_transitions_node_without_issue() {
        let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
        rt.block_on(async {
            let executor = make_merge_poll_executor("MERGED");

            // Mimics a node restored from the database, which carries no PipelineIssue.
            let mut graph = DependencyGraph::new("db-restore-test");
            let node = GraphNode {
                issue_number: 1,
                title: "Issue #1".to_string(),
                area: "test".to_string(),
                predicted_files: vec![],
                has_migration: false,
                complexity: "full".to_string(),
                state: NodeState::Pending,
                pr_number: Some(42),
                run_id: Some("run-1".to_string()),
                issue: None,
                target_repo: None,
            };
            graph.add_node(node);
            graph.transition(1, NodeState::AwaitingMerge);

            poll_awaiting_merges(&mut graph, &executor).await;

            assert_eq!(graph.node(1).unwrap().state, NodeState::Merged);
        });
    }

    #[test]
    fn merge_polling_handles_closed_pr() {
        let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
        rt.block_on(async {
            let executor = make_merge_poll_executor("CLOSED");

            let plan = two_node_chain_plan();
            let test_issues = vec![make_issue(1), make_issue(2)];
            let mut graph =
                DependencyGraph::from_planner_output("merge-poll-close", &plan, &test_issues);
            graph.transition(1, NodeState::AwaitingMerge);
            graph.set_pr_number(1, 42);
            graph.set_run_id(1, "run-1");

            poll_awaiting_merges(&mut graph, &executor).await;

            assert_eq!(graph.node(1).unwrap().state, NodeState::Failed);
            assert_eq!(graph.node(2).unwrap().state, NodeState::Failed);
        });
    }

    #[test]
    fn merge_unlocks_dependent() {
        let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
        rt.block_on(async {
            let executor = make_merge_poll_executor("MERGED");

            let plan = two_node_chain_plan();
            let test_issues = vec![make_issue(1), make_issue(2)];
            let mut graph =
                DependencyGraph::from_planner_output("merge-unlock", &plan, &test_issues);
            graph.transition(1, NodeState::AwaitingMerge);
            graph.set_pr_number(1, 42);
            graph.set_run_id(1, "run-1");

            // #2 stays blocked while its dependency is only awaiting merge.
            assert!(graph.ready_issues().is_empty());

            poll_awaiting_merges(&mut graph, &executor).await;

            assert_eq!(graph.node(1).unwrap().state, NodeState::Merged);
            assert_eq!(graph.ready_issues(), vec![2]);
        });
    }
}