use crate::manager::{WorktreeInfo, WorktreeManager, WorktreeStatus};
use crate::paths::normalize_path;
use miyabi_types::error::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{error, info, warn};

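/// Configuration for a [`WorktreePool`] run.
///
/// A minimal construction sketch (the values shown are illustrative, not
/// recommendations):
///
/// ```ignore
/// let config = PoolConfig {
///     max_concurrency: 5,
///     ..PoolConfig::default()
/// };
/// ```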
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolConfig {
    /// Maximum number of tasks executed concurrently.
    pub max_concurrency: usize,
    /// Per-task timeout in seconds.
    pub timeout_seconds: u64,
    /// Cancel not-yet-started tasks after the first failure or timeout.
    pub fail_fast: bool,
    /// Remove all worktrees once the run finishes.
    pub auto_cleanup: bool,
}

impl Default for PoolConfig {
    fn default() -> Self {
        Self {
            max_concurrency: 3,
            timeout_seconds: 1800,
            fail_fast: false,
            auto_cleanup: true,
        }
    }
}

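/// A unit of work to run inside its own isolated git worktree.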
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeTask {
    /// Issue number this task addresses.
    pub issue_number: u64,
    /// Human-readable description of the task.
    pub description: String,
    /// Agent type expected to handle the task, if any.
    pub agent_type: Option<String>,
    /// Arbitrary task metadata.
    pub metadata: Option<serde_json::Value>,
}

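/// Outcome of a single [`WorktreeTask`], including timing and any error.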
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskResult {
    /// Issue number the task addressed.
    pub issue_number: u64,
    /// Identifier of the worktree the task ran in (empty if worktree creation
    /// failed or the task was cancelled before starting).
    pub worktree_id: String,
    /// Terminal status of the task.
    pub status: TaskStatus,
    /// Task duration in milliseconds.
    pub duration_ms: u64,
    /// Error message, if the task did not succeed.
    pub error: Option<String>,
    /// Executor output, if the task succeeded.
    pub output: Option<serde_json::Value>,
}

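/// Terminal state of a task run.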
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
    /// The executor returned `Ok`.
    Success,
    /// Worktree creation or the executor itself failed.
    Failed,
    /// The executor exceeded the configured timeout.
    Timeout,
    /// The task was skipped because fail-fast had been triggered.
    Cancelled,
}

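/// Executes tasks in parallel, each inside an isolated git worktree, bounded
/// by [`PoolConfig::max_concurrency`].
///
/// A usage sketch, assuming a tokio runtime and a context that can propagate
/// errors with `?` (issue numbers are illustrative):
///
/// ```ignore
/// let pool = WorktreePool::new(PoolConfig::default(), None)?;
/// let result = pool
///     .execute_simple(vec![101, 102], |path, issue| async move {
///         // ... run an agent for `issue` inside `path` ...
///         Ok(())
///     })
///     .await;
/// assert!(result.all_successful());
/// ```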
pub struct WorktreePool {
    /// Shared worktree manager.
    manager: Arc<WorktreeManager>,
    /// Pool configuration.
    config: PoolConfig,
    /// Tasks currently executing, keyed by worktree id.
    active_tasks: Arc<Mutex<HashMap<String, WorktreeTask>>>,
}

impl WorktreePool {
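    /// Creates a pool rooted at the enclosing git repository.
    ///
    /// `worktree_base` may be absolute or relative; a relative base is
    /// resolved against the repository root. `None` falls back to
    /// [`default_worktree_base`].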
    pub fn new(config: PoolConfig, worktree_base: Option<PathBuf>) -> Result<Self> {
        let repo_path = miyabi_core::find_git_root(None)?;
        let base = worktree_base.unwrap_or_else(default_worktree_base);
        // Resolve a relative base against the repository root.
        let resolved_base = if base.is_absolute() {
            base
        } else {
            repo_path.join(base)
        };
        let resolved_base = normalize_path(resolved_base);

        let manager =
            Arc::new(WorktreeManager::new(&repo_path, &resolved_base, config.max_concurrency)?);

        Ok(Self {
            manager,
            config,
            active_tasks: Arc::new(Mutex::new(HashMap::new())),
        })
    }

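    /// Creates a pool from an explicit repository path and worktree base,
    /// bypassing git-root discovery.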
    pub fn new_with_path(
        repo_path: impl AsRef<std::path::Path>,
        worktree_base: impl AsRef<std::path::Path>,
        config: PoolConfig,
    ) -> Result<Self> {
        let manager =
            Arc::new(WorktreeManager::new(repo_path, worktree_base, config.max_concurrency)?);

        Ok(Self {
            manager,
            config,
            active_tasks: Arc::new(Mutex::new(HashMap::new())),
        })
    }

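    /// Runs `tasks` in parallel, at most `max_concurrency` at a time.
    ///
    /// Each task gets a freshly created worktree, and its `executor` future
    /// runs under the configured timeout. With `fail_fast` enabled, the first
    /// failure or timeout cancels tasks that have not yet started; in-flight
    /// tasks still run to completion.
    ///
    /// A sketch of an executor closure (the JSON payload is illustrative):
    ///
    /// ```ignore
    /// let result = pool
    ///     .execute_parallel(tasks, |worktree, task| async move {
    ///         // ... do the work in `worktree.path` ...
    ///         Ok(serde_json::json!({ "issue": task.issue_number }))
    ///     })
    ///     .await;
    /// ```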
    pub async fn execute_parallel<F, Fut>(
        &self,
        tasks: Vec<WorktreeTask>,
        executor: F,
    ) -> PoolExecutionResult
    where
        F: Fn(WorktreeInfo, WorktreeTask) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = Result<serde_json::Value>> + Send + 'static,
    {
        let start_time = std::time::Instant::now();
        let task_count = tasks.len();

        info!(
            "Starting parallel execution of {} tasks with max concurrency: {}, fail_fast: {}",
            task_count, self.config.max_concurrency, self.config.fail_fast
        );

        use futures::stream::{self, StreamExt};
        use tokio::sync::watch;

        let manager = self.manager.clone();
        let active_tasks = self.active_tasks.clone();
        let timeout_seconds = self.config.timeout_seconds;
        let max_concurrency = self.config.max_concurrency;
        let fail_fast = self.config.fail_fast;

        // Watch channel used to broadcast a fail-fast cancellation signal.
        let (cancel_tx, cancel_rx) = watch::channel(false);

        let results: Vec<TaskResult> = stream::iter(tasks)
            .map(|task| {
                let manager = manager.clone();
                let active_tasks = active_tasks.clone();
                let executor = executor.clone();
                let cancel_tx = cancel_tx.clone();
                let cancel_rx = cancel_rx.clone();

                async move {
                    // Skip tasks that have not started once fail-fast fires.
                    if *cancel_rx.borrow() {
                        warn!("Task for issue #{} cancelled due to fail-fast", task.issue_number);
                        return TaskResult {
                            issue_number: task.issue_number,
                            worktree_id: String::new(),
                            status: TaskStatus::Cancelled,
                            duration_ms: 0,
                            error: Some("Cancelled due to fail-fast".to_string()),
                            output: None,
                        };
                    }

                    let task_start = std::time::Instant::now();

                    // Create an isolated worktree and register the task as active.
                    let worktree_info = match manager.create_worktree(task.issue_number).await {
                        Ok(info) => {
                            {
                                let mut tasks = active_tasks.lock().await;
                                tasks.insert(info.id.clone(), task.clone());
                            }
                            info
                        },
                        Err(e) => {
                            error!(
                                "Failed to create worktree for issue #{}: {}",
                                task.issue_number, e
                            );
                            return TaskResult {
                                issue_number: task.issue_number,
                                worktree_id: String::new(),
                                status: TaskStatus::Failed,
                                duration_ms: task_start.elapsed().as_millis() as u64,
                                error: Some(e.to_string()),
                                output: None,
                            };
                        },
                    };

                    // Enforce the per-task timeout around the executor future.
                    let execution_result = tokio::time::timeout(
                        std::time::Duration::from_secs(timeout_seconds),
                        executor(worktree_info.clone(), task.clone()),
                    )
                    .await;

                    let task_result = match execution_result {
                        Ok(Ok(output)) => {
                            info!("Task for issue #{} completed successfully", task.issue_number);
                            let _ = manager
                                .update_status(&worktree_info.id, WorktreeStatus::Completed)
                                .await;
                            TaskResult {
                                issue_number: task.issue_number,
                                worktree_id: worktree_info.id.clone(),
                                status: TaskStatus::Success,
                                duration_ms: task_start.elapsed().as_millis() as u64,
                                error: None,
                                output: Some(output),
                            }
                        },
                        Ok(Err(e)) => {
                            error!("Task for issue #{} failed: {}", task.issue_number, e);
                            let _ = manager
                                .update_status(&worktree_info.id, WorktreeStatus::Failed)
                                .await;

                            if fail_fast {
                                warn!("Triggering fail-fast cancellation due to task failure");
                                let _ = cancel_tx.send(true);
                            }

                            TaskResult {
                                issue_number: task.issue_number,
                                worktree_id: worktree_info.id.clone(),
                                status: TaskStatus::Failed,
                                duration_ms: task_start.elapsed().as_millis() as u64,
                                error: Some(e.to_string()),
                                output: None,
                            }
                        },
                        Err(_) => {
                            warn!(
                                "Task for issue #{} timed out after {} seconds",
                                task.issue_number, timeout_seconds
                            );
                            let _ = manager
                                .update_status(&worktree_info.id, WorktreeStatus::Failed)
                                .await;

                            if fail_fast {
                                warn!("Triggering fail-fast cancellation due to task timeout");
                                let _ = cancel_tx.send(true);
                            }

                            TaskResult {
                                issue_number: task.issue_number,
                                worktree_id: worktree_info.id.clone(),
                                status: TaskStatus::Timeout,
                                duration_ms: task_start.elapsed().as_millis() as u64,
                                error: Some(format!("Timeout after {} seconds", timeout_seconds)),
                                output: None,
                            }
                        },
                    };

                    // Deregister the task before reporting its result.
                    {
                        let mut tasks = active_tasks.lock().await;
                        tasks.remove(&worktree_info.id);
                    }

                    task_result
                }
            })
            // Run at most `max_concurrency` task futures at once.
            .buffer_unordered(max_concurrency)
            .collect()
            .await;

        let total_duration = start_time.elapsed().as_millis() as u64;

        let success_count = results.iter().filter(|r| r.status == TaskStatus::Success).count();
        let failed_count = results.iter().filter(|r| r.status == TaskStatus::Failed).count();
        let timeout_count = results.iter().filter(|r| r.status == TaskStatus::Timeout).count();
        let cancelled_count = results.iter().filter(|r| r.status == TaskStatus::Cancelled).count();

        info!(
            "Parallel execution completed: {} successful, {} failed, {} timed out, {} cancelled, {}ms total",
            success_count, failed_count, timeout_count, cancelled_count, total_duration
        );

        if self.config.auto_cleanup {
            info!("Auto-cleanup enabled, removing worktrees");
            if let Err(e) = self.manager.cleanup_all().await {
                warn!("Cleanup failed: {}", e);
            }
        }

        PoolExecutionResult {
            total_tasks: task_count,
            results,
            total_duration_ms: total_duration,
            success_count,
            failed_count,
            timeout_count,
            cancelled_count,
        }
    }

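    /// Convenience wrapper around [`execute_parallel`](Self::execute_parallel)
    /// for executors that only need the worktree path and issue number; each
    /// successful task reports `{"status": "completed"}` as its output.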
    pub async fn execute_simple<F, Fut>(
        &self,
        issue_numbers: Vec<u64>,
        executor: F,
    ) -> PoolExecutionResult
    where
        F: Fn(PathBuf, u64) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
    {
        let tasks: Vec<WorktreeTask> = issue_numbers
            .into_iter()
            .map(|issue_number| WorktreeTask {
                issue_number,
                description: format!("Task for issue #{}", issue_number),
                agent_type: None,
                metadata: None,
            })
            .collect();

        self.execute_parallel(tasks, move |worktree_info, _task| {
            let executor = executor.clone();
            let worktree_path = worktree_info.path.clone();
            let issue_number = worktree_info.issue_number;

            async move {
                executor(worktree_path, issue_number).await?;
                Ok(serde_json::json!({"status": "completed"}))
            }
        })
        .await
    }

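    /// Returns a snapshot of pool configuration and worktree counters.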
    pub async fn stats(&self) -> PoolStats {
        let worktree_stats = self.manager.stats().await;
        let active_tasks = self.active_tasks.lock().await;

        PoolStats {
            max_concurrency: self.config.max_concurrency,
            active_worktrees: worktree_stats.active,
            idle_worktrees: worktree_stats.idle,
            completed_worktrees: worktree_stats.completed,
            failed_worktrees: worktree_stats.failed,
            active_tasks: active_tasks.len(),
            available_slots: worktree_stats.available_slots,
        }
    }

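    /// Returns a handle to the underlying [`WorktreeManager`].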
    pub fn manager(&self) -> &Arc<WorktreeManager> {
        &self.manager
    }
}

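/// Default worktree base: `%LOCALAPPDATA%\Miyabi\wt` on Windows (falling back
/// to `.worktrees` if the variable is unset), `.worktrees` elsewhere. Relative
/// paths are later resolved against the repository root by [`WorktreePool::new`].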
fn default_worktree_base() -> PathBuf {
    if cfg!(windows) {
        match env::var("LOCALAPPDATA") {
            Ok(dir) => PathBuf::from(dir).join("Miyabi").join("wt"),
            Err(_) => PathBuf::from(".worktrees"),
        }
    } else {
        PathBuf::from(".worktrees")
    }
}

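/// Aggregate outcome of a pool run: per-task results plus derived metrics.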
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolExecutionResult {
    /// Number of tasks submitted to the run.
    pub total_tasks: usize,
    /// Per-task results.
    pub results: Vec<TaskResult>,
    /// Wall-clock duration of the whole run in milliseconds.
    pub total_duration_ms: u64,
    /// Number of tasks that succeeded.
    pub success_count: usize,
    /// Number of tasks that failed.
    pub failed_count: usize,
    /// Number of tasks that timed out.
    pub timeout_count: usize,
    /// Number of tasks cancelled by fail-fast.
    pub cancelled_count: usize,
}

impl PoolExecutionResult {
    /// Returns `true` if every task succeeded.
    pub fn all_successful(&self) -> bool {
        self.success_count == self.total_tasks
    }

    /// Returns `true` if any task failed or timed out.
    pub fn has_failures(&self) -> bool {
        self.failed_count > 0 || self.timeout_count > 0
    }

    /// Returns `true` if any task was cancelled.
    pub fn has_cancellations(&self) -> bool {
        self.cancelled_count > 0
    }

    /// Percentage of tasks that succeeded (0.0 to 100.0).
    pub fn success_rate(&self) -> f64 {
        if self.total_tasks == 0 {
            0.0
        } else {
            (self.success_count as f64 / self.total_tasks as f64) * 100.0
        }
    }

    /// Percentage of tasks that failed or timed out (0.0 to 100.0).
    pub fn failure_rate(&self) -> f64 {
        if self.total_tasks == 0 {
            0.0
        } else {
            ((self.failed_count + self.timeout_count) as f64 / self.total_tasks as f64) * 100.0
        }
    }

    /// Mean task duration in milliseconds, over all recorded results.
    pub fn average_duration_ms(&self) -> f64 {
        if self.results.is_empty() {
            0.0
        } else {
            let total: u64 = self.results.iter().map(|r| r.duration_ms).sum();
            total as f64 / self.results.len() as f64
        }
    }

    /// Shortest task duration in milliseconds (0 if there are no results).
    pub fn min_duration_ms(&self) -> u64 {
        self.results.iter().map(|r| r.duration_ms).min().unwrap_or(0)
    }

    /// Longest task duration in milliseconds (0 if there are no results).
    pub fn max_duration_ms(&self) -> u64 {
        self.results.iter().map(|r| r.duration_ms).max().unwrap_or(0)
    }

    /// Tasks completed per second of wall-clock time.
    pub fn throughput(&self) -> f64 {
        if self.total_duration_ms == 0 {
            0.0
        } else {
            (self.total_tasks as f64) / (self.total_duration_ms as f64 / 1000.0)
        }
    }

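    /// Average parallelism achieved: total task time divided by wall-clock
    /// time. For example, four 5-second tasks finishing within 10 seconds of
    /// wall-clock time give an effective concurrency of 20 / 10 = 2.0.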
    pub fn effective_concurrency(&self) -> f64 {
        if self.total_duration_ms == 0 {
            0.0
        } else {
            let total_work: u64 = self.results.iter().map(|r| r.duration_ms).sum();
            (total_work as f64) / (self.total_duration_ms as f64)
        }
    }

    /// Results for tasks that failed.
    pub fn failed_tasks(&self) -> Vec<&TaskResult> {
        self.results.iter().filter(|r| r.status == TaskStatus::Failed).collect()
    }

    /// Results for tasks that timed out.
    pub fn timed_out_tasks(&self) -> Vec<&TaskResult> {
        self.results.iter().filter(|r| r.status == TaskStatus::Timeout).collect()
    }

    /// Results for tasks that were cancelled.
    pub fn cancelled_tasks(&self) -> Vec<&TaskResult> {
        self.results.iter().filter(|r| r.status == TaskStatus::Cancelled).collect()
    }

    /// Results for tasks that succeeded.
    pub fn successful_tasks(&self) -> Vec<&TaskResult> {
        self.results.iter().filter(|r| r.status == TaskStatus::Success).collect()
    }
}

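/// Point-in-time counters for a [`WorktreePool`], as returned by
/// [`WorktreePool::stats`].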
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolStats {
    /// Configured concurrency limit.
    pub max_concurrency: usize,
    /// Worktrees currently executing tasks.
    pub active_worktrees: usize,
    /// Worktrees created but not currently in use.
    pub idle_worktrees: usize,
    /// Worktrees whose tasks completed successfully.
    pub completed_worktrees: usize,
    /// Worktrees whose tasks failed.
    pub failed_worktrees: usize,
    /// Tasks currently tracked as active by the pool.
    pub active_tasks: usize,
    /// Remaining capacity before the concurrency limit is reached.
    pub available_slots: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.max_concurrency, 3);
        assert_eq!(config.timeout_seconds, 1800);
        assert!(!config.fail_fast);
        assert!(config.auto_cleanup);
    }

    #[test]
    fn test_worktree_task_creation() {
        let task = WorktreeTask {
            issue_number: 123,
            description: "Test task".to_string(),
            agent_type: Some("CodeGenAgent".to_string()),
            metadata: None,
        };

        assert_eq!(task.issue_number, 123);
        assert_eq!(task.description, "Test task");
        assert_eq!(task.agent_type, Some("CodeGenAgent".to_string()));
    }

    #[test]
    fn test_task_result_serialization() {
        let result = TaskResult {
            issue_number: 123,
            worktree_id: "test-id".to_string(),
            status: TaskStatus::Success,
            duration_ms: 5000,
            error: None,
            output: Some(serde_json::json!({"test": true})),
        };

        let json = serde_json::to_string(&result).unwrap();
        let deserialized: TaskResult = serde_json::from_str(&json).unwrap();

        assert_eq!(result.issue_number, deserialized.issue_number);
        assert_eq!(result.status, deserialized.status);
        assert_eq!(result.duration_ms, deserialized.duration_ms);
    }

    #[test]
    fn test_pool_execution_result_methods() {
        let result = PoolExecutionResult {
            total_tasks: 5,
            results: vec![
                TaskResult {
                    issue_number: 1,
                    worktree_id: "id1".to_string(),
                    status: TaskStatus::Success,
                    duration_ms: 1000,
                    error: None,
                    output: None,
                },
                TaskResult {
                    issue_number: 2,
                    worktree_id: "id2".to_string(),
                    status: TaskStatus::Success,
                    duration_ms: 2000,
                    error: None,
                    output: None,
                },
                TaskResult {
                    issue_number: 3,
                    worktree_id: "id3".to_string(),
                    status: TaskStatus::Failed,
                    duration_ms: 3000,
                    error: Some("Error".to_string()),
                    output: None,
                },
            ],
            total_duration_ms: 10000,
            success_count: 2,
            failed_count: 1,
            timeout_count: 0,
            cancelled_count: 2,
        };

        assert!(!result.all_successful());
        assert_eq!(result.success_rate(), 40.0);
        assert_eq!(result.average_duration_ms(), 2000.0);
    }

    #[test]
    fn test_task_status_equality() {
        assert_eq!(TaskStatus::Success, TaskStatus::Success);
        assert_ne!(TaskStatus::Success, TaskStatus::Failed);
        assert_ne!(TaskStatus::Failed, TaskStatus::Timeout);
    }
}
629}