1use std::{
2 collections::{HashMap, HashSet},
3 sync::Arc,
4 time::Duration,
5};
6
7use anyhow::Result;
8use tokio::{
9 sync::{Mutex, Semaphore},
10 task::JoinSet,
11};
12use tokio_util::sync::CancellationToken;
13use tracing::{error, info};
14
15use super::executor::PipelineExecutor;
16use crate::{agents::Complexity, issues::PipelineIssue, process::CommandRunner};
17
18pub async fn run_batch<R: CommandRunner + 'static>(
23 executor: &Arc<PipelineExecutor<R>>,
24 issues: Vec<PipelineIssue>,
25 max_parallel: usize,
26 auto_merge: bool,
27) -> Result<()> {
28 let semaphore = Arc::new(Semaphore::new(max_parallel));
29 let mut tasks = JoinSet::new();
30
31 for issue in issues {
32 let permit = semaphore
33 .clone()
34 .acquire_owned()
35 .await
36 .map_err(|e| anyhow::anyhow!("semaphore closed: {e}"))?;
37 let exec = Arc::clone(executor);
38 tasks.spawn(async move {
39 let number = issue.number;
40 let result = exec.run_issue(&issue, auto_merge).await;
41 drop(permit);
42 (number, result)
43 });
44 }
45
46 let mut had_errors = false;
47 while let Some(join_result) = tasks.join_next().await {
48 match join_result {
49 Ok((number, Ok(()))) => {
50 info!(issue = number, "pipeline completed successfully");
51 }
52 Ok((number, Err(e))) => {
53 error!(issue = number, error = %e, "pipeline failed for issue");
54 had_errors = true;
55 }
56 Err(e) => {
57 error!(error = %e, "pipeline task panicked");
58 had_errors = true;
59 }
60 }
61 }
62
63 if had_errors {
64 anyhow::bail!("one or more pipelines failed");
65 }
66
67 Ok(())
68}
69
70async fn get_complexity_map<R: CommandRunner + 'static>(
74 executor: &Arc<PipelineExecutor<R>>,
75 issues: &[PipelineIssue],
76) -> HashMap<u32, Complexity> {
77 let mut map = HashMap::new();
78 if let Some(plan) = executor.plan_issues(issues).await {
79 info!(batches = plan.batches.len(), total = plan.total_issues, "planner produced a plan");
80 for batch in &plan.batches {
81 for pi in &batch.issues {
82 map.insert(pi.number, pi.complexity.clone());
83 }
84 }
85 }
86 map
87}
88
89fn handle_task_result(result: Result<(u32, Result<()>), tokio::task::JoinError>) {
90 match result {
91 Ok((number, Ok(()))) => {
92 info!(issue = number, "pipeline completed successfully");
93 }
94 Ok((number, Err(e))) => {
95 error!(issue = number, error = %e, "pipeline failed for issue");
96 }
97 Err(e) => {
98 error!(error = %e, "pipeline task panicked");
99 }
100 }
101}
102
103pub async fn polling_loop<R: CommandRunner + 'static>(
109 executor: Arc<PipelineExecutor<R>>,
110 auto_merge: bool,
111 cancel_token: CancellationToken,
112) -> Result<()> {
113 let poll_interval = Duration::from_secs(executor.config.pipeline.poll_interval);
114 let max_parallel = executor.config.pipeline.max_parallel as usize;
115 let ready_label = executor.config.labels.ready.clone();
116 let semaphore = Arc::new(Semaphore::new(max_parallel));
117 let mut tasks = JoinSet::new();
118 let in_flight: Arc<Mutex<HashSet<u32>>> = Arc::new(Mutex::new(HashSet::new()));
119
120 info!(poll_interval_secs = poll_interval.as_secs(), max_parallel, "continuous polling started");
121
122 loop {
123 tokio::select! {
124 () = cancel_token.cancelled() => {
125 info!("shutdown signal received, waiting for in-flight pipelines");
126 while let Some(result) = tasks.join_next().await {
127 handle_task_result(result);
128 }
129 break;
130 }
131 () = tokio::time::sleep(poll_interval) => {
132 match executor.issues.get_ready_issues(&ready_label).await {
133 Ok(issues) => {
134 let in_flight_guard = in_flight.lock().await;
135 let new_issues: Vec<_> = issues
136 .into_iter()
137 .filter(|i| !in_flight_guard.contains(&i.number))
138 .collect();
139 drop(in_flight_guard);
140
141 if new_issues.is_empty() {
142 info!("no new issues found, waiting");
143 continue;
144 }
145
146 info!(count = new_issues.len(), "found new issues to process");
147
148 let complexity_map =
149 get_complexity_map(&executor, &new_issues).await;
150
151 for issue in new_issues {
152 let sem = Arc::clone(&semaphore);
153 let exec = Arc::clone(&executor);
154 let in_fl = Arc::clone(&in_flight);
155 let number = issue.number;
156 let complexity = complexity_map.get(&number).cloned();
157
158 in_fl.lock().await.insert(number);
159
160 tasks.spawn(async move {
161 let permit = match sem.acquire_owned().await {
162 Ok(p) => p,
163 Err(e) => {
164 in_fl.lock().await.remove(&number);
165 return (
166 number,
167 Err(anyhow::anyhow!(
168 "semaphore closed: {e}"
169 )),
170 );
171 }
172 };
173 let result = exec
174 .run_issue_with_complexity(
175 &issue,
176 auto_merge,
177 complexity,
178 )
179 .await;
180 in_fl.lock().await.remove(&number);
181 drop(permit);
182 (number, result)
183 });
184 }
185 }
186 Err(e) => {
187 error!(error = %e, "failed to fetch issues");
188 }
189 }
190 }
191 Some(result) = tasks.join_next(), if !tasks.is_empty() => {
192 handle_task_result(result);
193 }
194 }
195 }
196
197 Ok(())
198}
199
#[cfg(test)]
mod tests {
    use std::path::{Path, PathBuf};

    use tokio::sync::Mutex;

    use super::*;
    use crate::{
        config::Config,
        github::GhClient,
        issues::{IssueOrigin, IssueProvider, github::GithubIssueProvider},
        process::{AgentResult, CommandOutput, MockCommandRunner},
    };

    /// Mock runner whose `gh` calls print a pull-request URL and whose agent
    /// calls return a findings-free JSON payload.
    fn mock_runner_for_batch() -> MockCommandRunner {
        let mut mock = MockCommandRunner::new();
        mock.expect_run_gh().returning(|_, _| {
            Box::pin(async {
                Ok(CommandOutput {
                    stdout: "https://github.com/user/repo/pull/1\n".to_string(),
                    stderr: String::new(),
                    success: true,
                })
            })
        });
        mock.expect_run_claude().returning(|_, _, _, _| {
            Box::pin(async {
                Ok(AgentResult {
                    cost_usd: 1.0,
                    duration: Duration::from_secs(5),
                    turns: 3,
                    output: r#"{"findings":[],"summary":"clean"}"#.to_string(),
                    session_id: "sess-1".to_string(),
                    success: true,
                })
            })
        });
        mock
    }

    /// Mock runner whose `gh` calls succeed with empty output and whose agent
    /// calls always answer with `output`.
    fn runner_with_agent_output(output: &'static str) -> MockCommandRunner {
        let mut mock = MockCommandRunner::new();
        mock.expect_run_gh().returning(|_, _| {
            Box::pin(async {
                Ok(CommandOutput { stdout: String::new(), stderr: String::new(), success: true })
            })
        });
        mock.expect_run_claude().returning(move |_, _, _, _| {
            Box::pin(async move {
                Ok(AgentResult {
                    cost_usd: 0.5,
                    duration: Duration::from_secs(2),
                    turns: 1,
                    output: output.to_string(),
                    session_id: "sess-plan".to_string(),
                    success: true,
                })
            })
        });
        mock
    }

    fn make_github_provider(gh: &Arc<GhClient<MockCommandRunner>>) -> Arc<dyn IssueProvider> {
        Arc::new(GithubIssueProvider::new(Arc::clone(gh), "target_repo"))
    }

    /// Assembles an executor around `runner`, backed by an in-memory database
    /// and a GitHub client that uses `mock_runner_for_batch`.
    fn build_executor(
        runner: MockCommandRunner,
        config: Config,
        cancel_token: CancellationToken,
    ) -> Arc<PipelineExecutor<MockCommandRunner>> {
        let github = Arc::new(GhClient::new(mock_runner_for_batch(), Path::new("/tmp")));
        let issues = make_github_provider(&github);
        let db = Arc::new(Mutex::new(crate::db::open_in_memory().unwrap()));

        Arc::new(PipelineExecutor {
            runner: Arc::new(runner),
            github,
            issues,
            db,
            config,
            cancel_token,
            repo_dir: PathBuf::from("/tmp"),
        })
    }

    /// Default configuration with a one-hour poll interval, so that no poll
    /// fires during a test.
    fn long_poll_config() -> Config {
        let mut config = Config::default();
        config.pipeline.poll_interval = 3600;
        config
    }

    /// Builds a GitHub-sourced issue with no target repository.
    fn github_issue(number: u32, title: &str, body: &str) -> PipelineIssue {
        PipelineIssue {
            number,
            title: title.to_string(),
            body: body.to_string(),
            source: IssueOrigin::Github,
            target_repo: None,
        }
    }

    #[tokio::test]
    async fn cancellation_stops_polling() {
        let cancel = CancellationToken::new();
        let executor = build_executor(mock_runner_for_batch(), long_poll_config(), cancel.clone());

        let loop_token = cancel.clone();
        let handle = tokio::spawn(async move { polling_loop(executor, false, loop_token).await });

        cancel.cancel();

        let outcome = handle.await.unwrap();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn cancellation_exits_within_timeout() {
        let cancel = CancellationToken::new();
        let executor = build_executor(mock_runner_for_batch(), long_poll_config(), cancel.clone());

        let loop_token = cancel.clone();
        let handle = tokio::spawn(async move { polling_loop(executor, false, loop_token).await });

        cancel.cancel();

        let outcome = tokio::time::timeout(Duration::from_secs(5), handle)
            .await
            .expect("polling loop should exit within timeout")
            .unwrap();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn in_flight_set_filters_duplicate_issues() {
        let in_flight: Arc<Mutex<HashSet<u32>>> = Arc::new(Mutex::new(HashSet::new()));

        in_flight.lock().await.insert(1);

        let issues = vec![
            github_issue(1, "Already running", ""),
            github_issue(2, "New issue", ""),
            github_issue(3, "Another new", ""),
        ];

        let running = in_flight.lock().await;
        let new_issues: Vec<_> =
            issues.into_iter().filter(|candidate| !running.contains(&candidate.number)).collect();
        drop(running);

        let numbers: Vec<u32> = new_issues.iter().map(|issue| issue.number).collect();
        assert_eq!(numbers, [2, 3]);
    }

    #[test]
    fn handle_task_result_does_not_panic_on_success() {
        let finished = Ok((1, Ok(())));
        handle_task_result(finished);
    }

    #[test]
    fn handle_task_result_does_not_panic_on_error() {
        let failed = Ok((1, Err(anyhow::anyhow!("test error"))));
        handle_task_result(failed);
    }

    #[tokio::test]
    async fn get_complexity_map_returns_empty_on_planner_failure() {
        let executor = build_executor(
            runner_with_agent_output("I don't know how to plan"),
            Config::default(),
            CancellationToken::new(),
        );

        let issues = vec![github_issue(1, "Test", "body")];

        let map = get_complexity_map(&executor, &issues).await;
        assert!(map.is_empty());
    }

    #[tokio::test]
    async fn get_complexity_map_extracts_complexity() {
        let executor = build_executor(
            runner_with_agent_output(
                r#"{"batches":[{"batch":1,"issues":[{"number":1,"complexity":"simple"},{"number":2,"complexity":"full"}],"reasoning":"ok"}],"total_issues":2,"parallel_capacity":2}"#,
            ),
            Config::default(),
            CancellationToken::new(),
        );

        let issues = vec![
            github_issue(1, "Simple", "simple change"),
            github_issue(2, "Complex", "big feature"),
        ];

        let map = get_complexity_map(&executor, &issues).await;
        assert_eq!(map.get(&1), Some(&Complexity::Simple));
        assert_eq!(map.get(&2), Some(&Complexity::Full));
    }
}