1use std::{collections::HashSet, sync::Arc, time::Duration};
2
3use anyhow::Result;
4use tokio::{sync::Semaphore, task::JoinSet};
5use tokio_util::sync::CancellationToken;
6use tracing::{error, info, warn};
7
8use super::{
9 executor::{PipelineExecutor, PipelineOutcome},
10 graph::DependencyGraph,
11};
12use crate::{
13 agents::Complexity,
14 db::graph::NodeState,
15 issues::PipelineIssue,
16 pipeline::{executor::generate_run_id, graph::GraphNode},
17 process::CommandRunner,
18};
19
/// Mutable state owned by the continuous polling loop.
struct SchedulerState {
    /// Dependency graph of the issues tracked in the current session.
    graph: DependencyGraph,
    /// Shared permit pool; each spawned pipeline task acquires one permit before running.
    semaphore: Arc<Semaphore>,
    /// Spawned pipeline tasks; each yields its issue number together with the pipeline result.
    tasks: JoinSet<(u32, Result<PipelineOutcome>)>,
}
29
30pub async fn run_batch<R: CommandRunner + 'static>(
37 executor: &Arc<PipelineExecutor<R>>,
38 issues: Vec<PipelineIssue>,
39 max_parallel: usize,
40 auto_merge: bool,
41) -> Result<()> {
42 let session_id = generate_run_id();
43 let mut graph = if let Some(plan) = executor.plan_issues(&issues, &[]).await {
44 info!(nodes = plan.nodes.len(), total = plan.total_issues, "planner produced a plan");
45 DependencyGraph::from_planner_output(&session_id, &plan, &issues)
46 } else {
47 warn!("planner failed, falling back to all-parallel execution");
48 let mut g = DependencyGraph::new(&session_id);
49 for issue in &issues {
50 g.add_node(standalone_node(issue));
51 }
52 g
53 };
54
55 save_graph(&graph, executor).await;
56
57 let semaphore = Arc::new(Semaphore::new(max_parallel));
58 let mut had_errors = false;
59
60 while !graph.all_terminal() {
61 let ready = graph.ready_issues();
62 if ready.is_empty() {
63 warn!("no ready issues but graph is not terminal, breaking to avoid infinite loop");
64 save_graph(&graph, executor).await;
65 break;
66 }
67
68 let mut tasks: JoinSet<(u32, Result<PipelineOutcome>)> = JoinSet::new();
69
70 for num in &ready {
71 graph.transition(*num, NodeState::InFlight);
72 }
73 save_graph(&graph, executor).await;
74
75 for num in ready {
76 let node = graph.node(num).expect("ready issue must exist in graph");
77 let issue = node.issue.clone().expect("batch issues have issue attached");
78 let complexity = node.complexity.parse::<Complexity>().ok();
79 let sem = Arc::clone(&semaphore);
80 let exec = Arc::clone(executor);
81
82 tasks.spawn(async move {
83 let permit = match sem.acquire_owned().await {
84 Ok(p) => p,
85 Err(e) => return (num, Err(anyhow::anyhow!("semaphore closed: {e}"))),
86 };
87 let result = exec.run_issue_pipeline(&issue, auto_merge, complexity).await;
88 let outcome = match result {
89 Ok(outcome) => {
90 if let Err(e) = exec.finalize_merge(&outcome, &issue).await {
91 warn!(issue = num, error = %e, "finalize_merge failed");
92 }
93 Ok(outcome)
94 }
95 Err(e) => Err(e),
96 };
97 drop(permit);
98 (num, outcome)
99 });
100 }
101
102 while let Some(join_result) = tasks.join_next().await {
103 match join_result {
104 Ok((number, Ok(ref outcome))) => {
105 info!(issue = number, "pipeline completed successfully");
106 graph.set_pr_number(number, outcome.pr_number);
107 graph.set_run_id(number, &outcome.run_id);
108 graph.transition(number, NodeState::Merged);
109 }
110 Ok((number, Err(ref e))) => {
111 error!(issue = number, error = %e, "pipeline failed for issue");
112 graph.transition(number, NodeState::Failed);
113 let blocked = graph.propagate_failure(number);
114 for b in &blocked {
115 warn!(issue = b, blocked_by = number, "transitively failed");
116 }
117 had_errors = true;
118 }
119 Err(e) => {
120 error!(error = %e, "pipeline task panicked");
121 had_errors = true;
122 }
123 }
124 }
125
126 save_graph(&graph, executor).await;
127 }
128
129 if had_errors {
130 anyhow::bail!("one or more pipelines failed in batch");
131 }
132 Ok(())
133}
134
135pub async fn polling_loop<R: CommandRunner + 'static>(
142 executor: Arc<PipelineExecutor<R>>,
143 auto_merge: bool,
144 cancel_token: CancellationToken,
145) -> Result<()> {
146 let poll_interval = Duration::from_secs(executor.config.pipeline.poll_interval);
147 let max_parallel = executor.config.pipeline.max_parallel as usize;
148 let ready_label = executor.config.labels.ready.clone();
149
150 let graph = load_or_create_graph(&executor).await;
152
153 let mut sched = SchedulerState {
154 graph,
155 semaphore: Arc::new(Semaphore::new(max_parallel)),
156 tasks: JoinSet::new(),
157 };
158
159 info!(poll_interval_secs = poll_interval.as_secs(), max_parallel, "continuous polling started");
160
161 loop {
162 tokio::select! {
163 () = cancel_token.cancelled() => {
164 info!("shutdown signal received, waiting for in-flight pipelines");
165 drain_tasks(&mut sched, &executor).await;
166 break;
167 }
168 () = tokio::time::sleep(poll_interval) => {
169 poll_and_spawn(&executor, &ready_label, &mut sched, auto_merge).await;
170 }
171 Some(result) = sched.tasks.join_next(), if !sched.tasks.is_empty() => {
172 handle_task_result(result, &mut sched.graph, &executor).await;
173 }
174 }
175 }
176
177 Ok(())
178}
179
180async fn load_or_create_graph<R: CommandRunner>(
182 executor: &Arc<PipelineExecutor<R>>,
183) -> DependencyGraph {
184 let conn = executor.db.lock().await;
185 match crate::db::graph::get_active_session(&conn) {
186 Ok(Some(session_id)) => match DependencyGraph::from_db(&conn, &session_id) {
187 Ok(graph) => {
188 info!(session_id = %session_id, nodes = graph.node_count(), "resumed existing graph session");
189 return graph;
190 }
191 Err(e) => {
192 warn!(error = %e, "failed to load graph session, starting fresh");
193 }
194 },
195 Ok(None) => {}
196 Err(e) => {
197 warn!(error = %e, "failed to check for active graph session");
198 }
199 }
200 let session_id = generate_run_id();
201 info!(session_id = %session_id, "starting new graph session");
202 DependencyGraph::new(&session_id)
203}
204
205async fn drain_tasks<R: CommandRunner>(
207 sched: &mut SchedulerState,
208 executor: &Arc<PipelineExecutor<R>>,
209) {
210 while let Some(result) = sched.tasks.join_next().await {
211 handle_task_result(result, &mut sched.graph, executor).await;
212 }
213}
214
215async fn handle_task_result<R: CommandRunner>(
217 result: Result<(u32, Result<PipelineOutcome>), tokio::task::JoinError>,
218 graph: &mut DependencyGraph,
219 executor: &Arc<PipelineExecutor<R>>,
220) {
221 match result {
222 Ok((number, Ok(ref outcome))) => {
223 info!(issue = number, "pipeline completed successfully");
224 graph.set_pr_number(number, outcome.pr_number);
225 graph.set_run_id(number, &outcome.run_id);
226 graph.transition(number, NodeState::AwaitingMerge);
227 }
228 Ok((number, Err(ref e))) => {
229 error!(issue = number, error = %e, "pipeline failed for issue");
230 graph.transition(number, NodeState::Failed);
231 let blocked = graph.propagate_failure(number);
232 for b in &blocked {
233 warn!(issue = b, blocked_by = number, "transitively failed");
234 }
235 }
236 Err(e) => {
237 error!(error = %e, "pipeline task panicked");
238 return;
239 }
240 }
241 save_graph(graph, executor).await;
242}
243
244async fn poll_awaiting_merges<R: CommandRunner + 'static>(
247 graph: &mut DependencyGraph,
248 executor: &Arc<PipelineExecutor<R>>,
249) {
250 let awaiting = graph.awaiting_merge();
251 if awaiting.is_empty() {
252 return;
253 }
254
255 for num in awaiting {
256 let Some(node) = graph.node(num) else { continue };
257 let Some(pr_number) = node.pr_number else {
258 warn!(issue = num, "AwaitingMerge node has no PR number, skipping");
259 continue;
260 };
261 let run_id = node.run_id.clone().unwrap_or_default();
262 let issue = node.issue.clone();
263 let target_repo = node.target_repo.clone();
264
265 let pr_repo_dir = match executor.resolve_target_dir(target_repo.as_ref()) {
268 Ok((dir, _)) => dir,
269 Err(e) => {
270 warn!(issue = num, error = %e, "failed to resolve target dir for PR state check");
271 continue;
272 }
273 };
274
275 let pr_state = match executor.github.get_pr_state_in(pr_number, &pr_repo_dir).await {
276 Ok(s) => s,
277 Err(e) => {
278 warn!(issue = num, pr = pr_number, error = %e, "failed to check PR state");
279 continue;
280 }
281 };
282
283 match pr_state {
284 crate::github::PrState::Merged => {
285 info!(issue = num, pr = pr_number, "PR merged, finalizing");
286 if let Some(ref issue) = issue {
287 match executor.reconstruct_outcome(issue, &run_id, pr_number) {
288 Ok(outcome) => {
289 if let Err(e) = executor.finalize_merge(&outcome, issue).await {
290 warn!(issue = num, error = %e, "finalize_merge after poll failed");
291 }
292 }
293 Err(e) => {
294 warn!(issue = num, error = %e, "failed to reconstruct outcome");
295 }
296 }
297 } else {
298 warn!(
299 issue = num,
300 pr = pr_number,
301 "node restored from DB has no PipelineIssue, \
302 skipping finalization (labels and worktree may need manual cleanup)"
303 );
304 }
305 graph.transition(num, NodeState::Merged);
306 }
307 crate::github::PrState::Closed => {
308 warn!(issue = num, pr = pr_number, "PR closed without merge, marking failed");
309 graph.transition(num, NodeState::Failed);
310 let blocked = graph.propagate_failure(num);
311 for b in &blocked {
312 warn!(issue = b, blocked_by = num, "transitively failed (PR closed)");
313 }
314 }
315 crate::github::PrState::Open => {
316 }
318 }
319 }
320
321 save_graph(graph, executor).await;
322}
323
324async fn poll_and_spawn<R: CommandRunner + 'static>(
326 executor: &Arc<PipelineExecutor<R>>,
327 ready_label: &str,
328 sched: &mut SchedulerState,
329 auto_merge: bool,
330) {
331 poll_awaiting_merges(&mut sched.graph, executor).await;
333
334 let ready_issues = match executor.issues.get_ready_issues(ready_label).await {
335 Ok(i) => i,
336 Err(e) => {
337 error!(error = %e, "failed to fetch issues");
338 return;
339 }
340 };
341
342 let ready_numbers: HashSet<u32> = ready_issues.iter().map(|i| i.number).collect();
343
344 clean_stale_nodes(&mut sched.graph, &ready_numbers);
346
347 let new_issues: Vec<_> =
349 ready_issues.into_iter().filter(|i| !sched.graph.contains(i.number)).collect();
350
351 if !new_issues.is_empty() {
353 info!(count = new_issues.len(), "found new issues to evaluate");
354 let graph_context = sched.graph.to_graph_context();
355
356 if let Some(plan) = executor.plan_issues(&new_issues, &graph_context).await {
357 info!(nodes = plan.nodes.len(), total = plan.total_issues, "planner produced a plan");
358 sched.graph.merge_planner_output(&plan, &new_issues);
359 } else {
360 warn!("planner failed, adding all new issues as independent nodes");
361 add_independent_nodes(&mut sched.graph, &new_issues);
362 }
363
364 save_graph(&sched.graph, executor).await;
365 }
366
367 let to_spawn = collect_ready_issues(&mut sched.graph);
369 if to_spawn.is_empty() {
370 if new_issues.is_empty() {
371 info!("no actionable issues, waiting");
372 }
373 return;
374 }
375
376 save_graph(&sched.graph, executor).await;
377 spawn_issues(to_spawn, executor, sched, auto_merge);
378}
379
380fn clean_stale_nodes(graph: &mut DependencyGraph, ready_numbers: &HashSet<u32>) {
382 let stale: Vec<u32> = graph
383 .all_issues()
384 .into_iter()
385 .filter(|num| {
386 !ready_numbers.contains(num)
387 && graph.node(*num).is_some_and(|n| n.state == NodeState::Pending)
388 })
389 .collect();
390 if !stale.is_empty() {
391 info!(count = stale.len(), "removing stale pending nodes");
392 for num in stale {
393 graph.remove_node(num);
394 }
395 }
396}
397
398fn add_independent_nodes(graph: &mut DependencyGraph, issues: &[PipelineIssue]) {
400 for issue in issues {
401 if !graph.contains(issue.number) {
402 graph.add_node(standalone_node(issue));
403 }
404 }
405}
406
407fn collect_ready_issues(graph: &mut DependencyGraph) -> Vec<(u32, PipelineIssue, Complexity)> {
409 let ready = graph.ready_issues();
410 let mut to_spawn = Vec::new();
411
412 for num in ready {
413 let Some(node) = graph.node(num) else { continue };
414 let Some(issue) = node.issue.clone() else {
415 warn!(issue = num, "ready node has no PipelineIssue attached, skipping");
416 continue;
417 };
418 let complexity = node.complexity.parse::<Complexity>().unwrap_or(Complexity::Full);
419 graph.transition(num, NodeState::InFlight);
420 to_spawn.push((num, issue, complexity));
421 }
422
423 to_spawn
424}
425
426fn spawn_issues<R: CommandRunner + 'static>(
428 to_spawn: Vec<(u32, PipelineIssue, Complexity)>,
429 executor: &Arc<PipelineExecutor<R>>,
430 sched: &mut SchedulerState,
431 auto_merge: bool,
432) {
433 for (number, issue, complexity) in to_spawn {
434 let sem = Arc::clone(&sched.semaphore);
435 let exec = Arc::clone(executor);
436
437 sched.tasks.spawn(async move {
438 let permit = match sem.acquire_owned().await {
439 Ok(p) => p,
440 Err(e) => return (number, Err(anyhow::anyhow!("semaphore closed: {e}"))),
441 };
442 let outcome = exec.run_issue_pipeline(&issue, auto_merge, Some(complexity)).await;
443 drop(permit);
444 (number, outcome)
445 });
446 }
447}
448
449fn standalone_node(issue: &PipelineIssue) -> GraphNode {
451 GraphNode {
452 issue_number: issue.number,
453 title: issue.title.clone(),
454 area: String::new(),
455 predicted_files: Vec::new(),
456 has_migration: false,
457 complexity: Complexity::Full.to_string(),
458 state: NodeState::Pending,
459 pr_number: None,
460 run_id: None,
461 target_repo: issue.target_repo.clone(),
462 issue: Some(issue.clone()),
463 }
464}
465
466async fn save_graph<R: CommandRunner>(
468 graph: &DependencyGraph,
469 executor: &Arc<PipelineExecutor<R>>,
470) {
471 let conn = executor.db.lock().await;
472 if let Err(e) = graph.save_to_db(&conn) {
473 warn!(error = %e, "failed to persist dependency graph");
474 }
475}
476
477#[cfg(test)]
478mod tests {
479 use std::path::PathBuf;
480
481 use tokio::sync::Mutex;
482
483 use super::*;
484 use crate::{
485 agents::PlannerGraphOutput,
486 config::Config,
487 github::GhClient,
488 issues::{IssueOrigin, IssueProvider, github::GithubIssueProvider},
489 process::{AgentResult, CommandOutput, MockCommandRunner},
490 };
491
492 fn mock_runner_for_batch() -> MockCommandRunner {
493 let mut mock = MockCommandRunner::new();
494 mock.expect_run_gh().returning(|_, _| {
495 Box::pin(async {
496 Ok(CommandOutput {
497 stdout: "https://github.com/user/repo/pull/1\n".to_string(),
498 stderr: String::new(),
499 success: true,
500 })
501 })
502 });
503 mock.expect_run_claude().returning(|_, _, _, _, _| {
504 Box::pin(async {
505 Ok(AgentResult {
506 cost_usd: 1.0,
507 duration: Duration::from_secs(5),
508 turns: 3,
509 output: r#"{"findings":[],"summary":"clean"}"#.to_string(),
510 session_id: "sess-1".to_string(),
511 success: true,
512 })
513 })
514 });
515 mock
516 }
517
518 fn make_github_provider(gh: &Arc<GhClient<MockCommandRunner>>) -> Arc<dyn IssueProvider> {
519 Arc::new(GithubIssueProvider::new(Arc::clone(gh), "target_repo"))
520 }
521
522 fn make_issue(number: u32) -> PipelineIssue {
523 PipelineIssue {
524 number,
525 title: format!("Issue #{number}"),
526 body: String::new(),
527 source: IssueOrigin::Github,
528 target_repo: None,
529 author: None,
530 }
531 }
532
533 #[tokio::test]
534 async fn cancellation_stops_polling() {
535 let cancel = CancellationToken::new();
536 let runner = Arc::new(mock_runner_for_batch());
537 let github = Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
538 let issues = make_github_provider(&github);
539 let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));
540
541 let mut config = Config::default();
542 config.pipeline.poll_interval = 3600; let executor = Arc::new(PipelineExecutor {
545 runner,
546 github,
547 issues,
548 db,
549 config,
550 cancel_token: cancel.clone(),
551 repo_dir: PathBuf::from("/tmp"),
552 });
553
554 let cancel_clone = cancel.clone();
555 let handle = tokio::spawn(async move { polling_loop(executor, false, cancel_clone).await });
556
557 cancel.cancel();
559
560 let result = handle.await.unwrap();
561 assert!(result.is_ok());
562 }
563
564 #[tokio::test]
565 async fn cancellation_exits_within_timeout() {
566 let cancel = CancellationToken::new();
567 let runner = Arc::new(mock_runner_for_batch());
568 let github = Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
569 let issues = make_github_provider(&github);
570 let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));
571
572 let mut config = Config::default();
573 config.pipeline.poll_interval = 3600;
574
575 let executor = Arc::new(PipelineExecutor {
576 runner,
577 github,
578 issues,
579 db,
580 config,
581 cancel_token: cancel.clone(),
582 repo_dir: PathBuf::from("/tmp"),
583 });
584
585 let cancel_clone = cancel.clone();
586 let handle = tokio::spawn(async move { polling_loop(executor, false, cancel_clone).await });
587
588 cancel.cancel();
589
590 let result = tokio::time::timeout(Duration::from_secs(5), handle)
591 .await
592 .expect("polling loop should exit within timeout")
593 .unwrap();
594 assert!(result.is_ok());
595 }
596
597 #[test]
598 fn handle_task_success_transitions_to_awaiting_merge() {
599 let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
600 rt.block_on(async {
601 let executor = {
602 let runner = Arc::new(mock_runner_for_batch());
603 let github =
604 Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
605 let issues = make_github_provider(&github);
606 let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));
607 Arc::new(PipelineExecutor {
608 runner,
609 github,
610 issues,
611 db,
612 config: Config::default(),
613 cancel_token: CancellationToken::new(),
614 repo_dir: PathBuf::from("/tmp"),
615 })
616 };
617
618 let mut graph = DependencyGraph::new("test");
619 graph.add_node(standalone_node(&make_issue(1)));
620 graph.transition(1, NodeState::InFlight);
621
622 let outcome = PipelineOutcome {
623 run_id: "run-abc".to_string(),
624 pr_number: 42,
625 worktree_path: PathBuf::from("/tmp/wt"),
626 target_dir: PathBuf::from("/tmp"),
627 };
628
629 handle_task_result(Ok((1, Ok(outcome))), &mut graph, &executor).await;
630
631 assert_eq!(graph.node(1).unwrap().state, NodeState::AwaitingMerge);
632 assert_eq!(graph.node(1).unwrap().pr_number, Some(42));
633 assert_eq!(graph.node(1).unwrap().run_id.as_deref(), Some("run-abc"));
634 });
635 }
636
637 #[test]
638 fn handle_task_failure_propagates_to_dependents() {
639 let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
640 rt.block_on(async {
641 let executor = {
642 let runner = Arc::new(mock_runner_for_batch());
643 let github =
644 Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
645 let issues = make_github_provider(&github);
646 let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));
647 Arc::new(PipelineExecutor {
648 runner,
649 github,
650 issues,
651 db,
652 config: Config::default(),
653 cancel_token: CancellationToken::new(),
654 repo_dir: PathBuf::from("/tmp"),
655 })
656 };
657
658 let plan = PlannerGraphOutput {
659 nodes: vec![
660 crate::agents::PlannedNode {
661 number: 1,
662 title: "Root".to_string(),
663 area: "a".to_string(),
664 predicted_files: vec![],
665 has_migration: false,
666 complexity: Complexity::Full,
667 depends_on: vec![],
668 reasoning: String::new(),
669 },
670 crate::agents::PlannedNode {
671 number: 2,
672 title: "Dep".to_string(),
673 area: "b".to_string(),
674 predicted_files: vec![],
675 has_migration: false,
676 complexity: Complexity::Full,
677 depends_on: vec![1],
678 reasoning: String::new(),
679 },
680 ],
681 total_issues: 2,
682 parallel_capacity: 1,
683 };
684 let issues = vec![make_issue(1), make_issue(2)];
685 let mut graph = DependencyGraph::from_planner_output("test", &plan, &issues);
686 graph.transition(1, NodeState::InFlight);
687
688 handle_task_result(
689 Ok((1, Err(anyhow::anyhow!("pipeline failed")))),
690 &mut graph,
691 &executor,
692 )
693 .await;
694
695 assert_eq!(graph.node(1).unwrap().state, NodeState::Failed);
696 assert_eq!(graph.node(2).unwrap().state, NodeState::Failed);
697 });
698 }
699
700 #[test]
701 fn stale_node_removed_when_issue_disappears() {
702 let mut graph = DependencyGraph::new("test");
703 graph.add_node(standalone_node(&make_issue(1)));
704 graph.add_node(standalone_node(&make_issue(2)));
705 graph.add_node(standalone_node(&make_issue(3)));
706 graph.transition(2, NodeState::InFlight);
707
708 let ready_numbers: HashSet<u32> = HashSet::from([1, 2]);
710 clean_stale_nodes(&mut graph, &ready_numbers);
711
712 assert!(graph.contains(1)); assert!(graph.contains(2)); assert!(!graph.contains(3)); }
716
717 #[test]
718 fn collect_ready_issues_transitions_to_in_flight() {
719 let mut graph = DependencyGraph::new("test");
720 graph.add_node(standalone_node(&make_issue(1)));
721 graph.add_node(standalone_node(&make_issue(2)));
722
723 let spawnable = collect_ready_issues(&mut graph);
724 assert_eq!(spawnable.len(), 2);
725
726 assert_eq!(graph.node(1).unwrap().state, NodeState::InFlight);
728 assert_eq!(graph.node(2).unwrap().state, NodeState::InFlight);
729
730 assert!(collect_ready_issues(&mut graph).is_empty());
732 }
733
734 #[tokio::test]
735 async fn planner_failure_falls_back_to_all_parallel() {
736 let mut mock = MockCommandRunner::new();
737 mock.expect_run_gh().returning(|_, _| {
738 Box::pin(async {
739 Ok(CommandOutput { stdout: String::new(), stderr: String::new(), success: true })
740 })
741 });
742 mock.expect_run_claude().returning(|_, _, _, _, _| {
743 Box::pin(async {
744 Ok(AgentResult {
745 cost_usd: 0.5,
746 duration: Duration::from_secs(2),
747 turns: 1,
748 output: "I don't know how to plan".to_string(),
749 session_id: "sess-plan".to_string(),
750 success: true,
751 })
752 })
753 });
754
755 let runner = Arc::new(mock);
756 let github = Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
757 let issues_provider = make_github_provider(&github);
758 let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));
759
760 let executor = Arc::new(PipelineExecutor {
761 runner,
762 github,
763 issues: issues_provider,
764 db,
765 config: Config::default(),
766 cancel_token: CancellationToken::new(),
767 repo_dir: PathBuf::from("/tmp"),
768 });
769
770 let issues = vec![PipelineIssue {
771 number: 1,
772 title: "Test".to_string(),
773 body: "body".to_string(),
774 source: IssueOrigin::Github,
775 target_repo: None,
776 author: None,
777 }];
778
779 let plan = executor.plan_issues(&issues, &[]).await;
781 assert!(plan.is_none());
782 }
783
784 #[test]
785 fn graph_persisted_after_state_change() {
786 let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
787 rt.block_on(async {
788 let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));
789 let runner = Arc::new(mock_runner_for_batch());
790 let github =
791 Arc::new(GhClient::new(mock_runner_for_batch(), std::path::Path::new("/tmp")));
792 let issues = make_github_provider(&github);
793 let executor = Arc::new(PipelineExecutor {
794 runner,
795 github,
796 issues,
797 db: Arc::clone(&db),
798 config: Config::default(),
799 cancel_token: CancellationToken::new(),
800 repo_dir: PathBuf::from("/tmp"),
801 });
802
803 let mut graph = DependencyGraph::new("persist-test");
804 graph.add_node(standalone_node(&make_issue(1)));
805 graph.transition(1, NodeState::InFlight);
806
807 let outcome = PipelineOutcome {
808 run_id: "run-1".to_string(),
809 pr_number: 10,
810 worktree_path: PathBuf::from("/tmp/wt"),
811 target_dir: PathBuf::from("/tmp"),
812 };
813 handle_task_result(Ok((1, Ok(outcome))), &mut graph, &executor).await;
814
815 let loaded = DependencyGraph::from_db(&*db.lock().await, "persist-test").unwrap();
817 assert_eq!(loaded.node(1).unwrap().state, NodeState::AwaitingMerge);
818 assert_eq!(loaded.node(1).unwrap().pr_number, Some(10));
819 });
820 }
821
822 fn mock_runner_with_pr_state(state: &'static str) -> MockCommandRunner {
823 let mut mock = MockCommandRunner::new();
824 mock.expect_run_gh().returning(move |args, _| {
825 let args = args.to_vec();
826 Box::pin(async move {
827 if args.iter().any(|a| a == "view") {
828 Ok(CommandOutput {
829 stdout: format!(r#"{{"state":"{state}"}}"#),
830 stderr: String::new(),
831 success: true,
832 })
833 } else {
834 Ok(CommandOutput {
835 stdout: String::new(),
836 stderr: String::new(),
837 success: true,
838 })
839 }
840 })
841 });
842 mock.expect_run_claude().returning(|_, _, _, _, _| {
843 Box::pin(async {
844 Ok(AgentResult {
845 cost_usd: 0.0,
846 duration: Duration::from_secs(0),
847 turns: 0,
848 output: String::new(),
849 session_id: String::new(),
850 success: true,
851 })
852 })
853 });
854 mock
855 }
856
857 fn make_merge_poll_executor(state: &'static str) -> Arc<PipelineExecutor<MockCommandRunner>> {
858 let gh_mock = mock_runner_with_pr_state(state);
859 let github = Arc::new(GhClient::new(gh_mock, std::path::Path::new("/tmp")));
860 let issues = make_github_provider(&github);
861 let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));
862 let runner = Arc::new(mock_runner_with_pr_state(state));
863 Arc::new(PipelineExecutor {
864 runner,
865 github,
866 issues,
867 db,
868 config: Config::default(),
869 cancel_token: CancellationToken::new(),
870 repo_dir: PathBuf::from("/tmp"),
871 })
872 }
873
874 #[test]
875 fn merge_polling_transitions_merged_pr() {
876 let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
877 rt.block_on(async {
878 let executor = make_merge_poll_executor("MERGED");
879
880 let mut graph = DependencyGraph::new("merge-poll-test");
881 let mut node = standalone_node(&make_issue(1));
882 node.pr_number = Some(42);
883 node.run_id = Some("run-1".to_string());
884 graph.add_node(node);
885 graph.transition(1, NodeState::AwaitingMerge);
886
887 poll_awaiting_merges(&mut graph, &executor).await;
888
889 assert_eq!(graph.node(1).unwrap().state, NodeState::Merged);
890 });
891 }
892
893 #[test]
894 fn merge_polling_transitions_node_without_issue() {
895 let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
896 rt.block_on(async {
897 let executor = make_merge_poll_executor("MERGED");
898
899 let mut graph = DependencyGraph::new("db-restore-test");
900 let mut node = GraphNode {
902 issue_number: 1,
903 title: "Issue #1".to_string(),
904 area: "test".to_string(),
905 predicted_files: vec![],
906 has_migration: false,
907 complexity: "full".to_string(),
908 state: NodeState::Pending,
909 pr_number: Some(42),
910 run_id: Some("run-1".to_string()),
911 issue: None,
912 target_repo: None,
913 };
914 node.state = NodeState::Pending;
915 graph.add_node(node);
916 graph.transition(1, NodeState::AwaitingMerge);
917
918 poll_awaiting_merges(&mut graph, &executor).await;
919
920 assert_eq!(graph.node(1).unwrap().state, NodeState::Merged);
922 });
923 }
924
925 #[test]
926 fn merge_polling_handles_closed_pr() {
927 let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
928 rt.block_on(async {
929 let executor = make_merge_poll_executor("CLOSED");
930
931 let plan = PlannerGraphOutput {
932 nodes: vec![
933 crate::agents::PlannedNode {
934 number: 1,
935 title: "Root".to_string(),
936 area: "a".to_string(),
937 predicted_files: vec![],
938 has_migration: false,
939 complexity: Complexity::Full,
940 depends_on: vec![],
941 reasoning: String::new(),
942 },
943 crate::agents::PlannedNode {
944 number: 2,
945 title: "Dep".to_string(),
946 area: "b".to_string(),
947 predicted_files: vec![],
948 has_migration: false,
949 complexity: Complexity::Full,
950 depends_on: vec![1],
951 reasoning: String::new(),
952 },
953 ],
954 total_issues: 2,
955 parallel_capacity: 1,
956 };
957 let test_issues = vec![make_issue(1), make_issue(2)];
958 let mut graph =
959 DependencyGraph::from_planner_output("merge-poll-close", &plan, &test_issues);
960 graph.transition(1, NodeState::AwaitingMerge);
961 graph.set_pr_number(1, 42);
962 graph.set_run_id(1, "run-1");
963
964 poll_awaiting_merges(&mut graph, &executor).await;
965
966 assert_eq!(graph.node(1).unwrap().state, NodeState::Failed);
967 assert_eq!(graph.node(2).unwrap().state, NodeState::Failed);
969 });
970 }
971
972 #[test]
973 fn merge_unlocks_dependent() {
974 let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
975 rt.block_on(async {
976 let executor = make_merge_poll_executor("MERGED");
977
978 let plan = PlannerGraphOutput {
979 nodes: vec![
980 crate::agents::PlannedNode {
981 number: 1,
982 title: "Root".to_string(),
983 area: "a".to_string(),
984 predicted_files: vec![],
985 has_migration: false,
986 complexity: Complexity::Full,
987 depends_on: vec![],
988 reasoning: String::new(),
989 },
990 crate::agents::PlannedNode {
991 number: 2,
992 title: "Dep".to_string(),
993 area: "b".to_string(),
994 predicted_files: vec![],
995 has_migration: false,
996 complexity: Complexity::Full,
997 depends_on: vec![1],
998 reasoning: String::new(),
999 },
1000 ],
1001 total_issues: 2,
1002 parallel_capacity: 1,
1003 };
1004 let test_issues = vec![make_issue(1), make_issue(2)];
1005 let mut graph =
1006 DependencyGraph::from_planner_output("merge-unlock", &plan, &test_issues);
1007 graph.transition(1, NodeState::AwaitingMerge);
1008 graph.set_pr_number(1, 42);
1009 graph.set_run_id(1, "run-1");
1010
1011 assert!(graph.ready_issues().is_empty());
1013
1014 poll_awaiting_merges(&mut graph, &executor).await;
1015
1016 assert_eq!(graph.node(1).unwrap().state, NodeState::Merged);
1018 assert_eq!(graph.ready_issues(), vec![2]);
1019 });
1020 }
1021}