miyabi_worktree/
pool.rs

//! Worktree execution pool for parallel task processing
//!
//! Provides high-level abstractions for executing multiple tasks in parallel worktrees.
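//!
//! # Example
//!
//! A minimal, illustrative sketch rather than a verbatim recipe: it assumes the process
//! runs inside a Git repository, that a Tokio runtime is available, and that this module
//! is exported as `miyabi_worktree::pool`.
//!
//! ```rust,ignore
//! use miyabi_worktree::pool::{PoolConfig, WorktreePool};
//!
//! # async fn demo() -> miyabi_types::error::Result<()> {
//! // Default config: 3 concurrent worktrees, 30-minute timeout, auto-cleanup enabled.
//! let pool = WorktreePool::new(PoolConfig::default(), None)?;
//!
//! // Run one task per issue; the closure receives the worktree path and issue number.
//! let result = pool
//!     .execute_simple(vec![101, 102, 103], |worktree_path, issue_number| async move {
//!         println!("working on #{} in {}", issue_number, worktree_path.display());
//!         Ok(())
//!     })
//!     .await;
//!
//! assert_eq!(result.total_tasks, 3);
//! # Ok(())
//! # }
//! ```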

use crate::manager::{WorktreeInfo, WorktreeManager, WorktreeStatus};
use crate::paths::normalize_path;
use miyabi_types::error::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{error, info, warn};

/// Configuration for worktree pool execution
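///
/// # Example
///
/// An illustrative sketch of overriding selected defaults (the values are arbitrary):
///
/// ```rust,ignore
/// let config = PoolConfig {
///     max_concurrency: 5,
///     fail_fast: true,
///     ..PoolConfig::default()
/// };
/// ```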
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolConfig {
    /// Maximum number of concurrent worktrees
    pub max_concurrency: usize,
    /// Timeout for individual task execution (seconds)
    pub timeout_seconds: u64,
    /// Whether to fail fast on first error
    pub fail_fast: bool,
    /// Whether to cleanup worktrees after execution
    pub auto_cleanup: bool,
}

impl Default for PoolConfig {
    fn default() -> Self {
        Self {
            max_concurrency: 3,
            timeout_seconds: 1800, // 30 minutes
            fail_fast: false,
            auto_cleanup: true,
        }
    }
}

/// Task to be executed in a worktree
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeTask {
    /// Issue number for the task
    pub issue_number: u64,
    /// Task description
    pub description: String,
    /// Optional agent type to execute
    pub agent_type: Option<String>,
    /// Additional metadata
    pub metadata: Option<serde_json::Value>,
}

/// Result of a worktree task execution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskResult {
    /// Issue number
    pub issue_number: u64,
    /// Worktree ID
    pub worktree_id: String,
    /// Execution status
    pub status: TaskStatus,
    /// Execution duration in milliseconds
    pub duration_ms: u64,
    /// Error message if failed
    pub error: Option<String>,
    /// Output data
    pub output: Option<serde_json::Value>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
    Success,
    Failed,
    Timeout,
    Cancelled,
}

/// Worktree execution pool
pub struct WorktreePool {
    manager: Arc<WorktreeManager>,
    config: PoolConfig,
    active_tasks: Arc<Mutex<HashMap<String, WorktreeTask>>>,
}

impl WorktreePool {
    /// Create a new worktree pool with automatic repository discovery
    ///
    /// # Arguments
    /// * `config` - Pool configuration
    /// * `worktree_base` - Optional override for worktree base directory
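    ///
    /// # Example
    ///
    /// A minimal sketch (assumes the current working directory is inside a Git
    /// repository; `None` falls back to the default worktree base directory):
    ///
    /// ```rust,ignore
    /// let pool = WorktreePool::new(PoolConfig::default(), None)?;
    /// ```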
    pub fn new(config: PoolConfig, worktree_base: Option<PathBuf>) -> Result<Self> {
        let repo_path = miyabi_core::find_git_root(None)?;
        let base = worktree_base.unwrap_or_else(default_worktree_base);
        let resolved_base = if base.is_absolute() {
            base
        } else {
            repo_path.join(base)
        };
        let resolved_base = normalize_path(resolved_base);

        let manager =
            Arc::new(WorktreeManager::new(&repo_path, &resolved_base, config.max_concurrency)?);

        Ok(Self {
            manager,
            config,
            active_tasks: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Create a worktree pool with explicit repository path
    pub fn new_with_path(
        repo_path: impl AsRef<std::path::Path>,
        worktree_base: impl AsRef<std::path::Path>,
        config: PoolConfig,
    ) -> Result<Self> {
        let manager =
            Arc::new(WorktreeManager::new(repo_path, worktree_base, config.max_concurrency)?);

        Ok(Self {
            manager,
            config,
            active_tasks: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Execute multiple tasks in parallel worktrees
    ///
    /// # Arguments
    /// * `tasks` - Vector of tasks to execute
    /// * `executor` - Async function to execute in each worktree
    ///
    /// # Returns
    /// Pool execution result with individual task results
    ///
    /// # Fail-Fast Behavior
    /// If `fail_fast` is enabled in the config, the first failure or timeout triggers
    /// cancellation: tasks that have not yet started return `TaskStatus::Cancelled`,
    /// while tasks already running are allowed to finish.
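    ///
    /// # Example
    ///
    /// An illustrative sketch; the issue number, description, and executor body are
    /// placeholders:
    ///
    /// ```rust,ignore
    /// let tasks = vec![WorktreeTask {
    ///     issue_number: 42,
    ///     description: "Implement feature".to_string(),
    ///     agent_type: None,
    ///     metadata: None,
    /// }];
    ///
    /// let result = pool
    ///     .execute_parallel(tasks, |_worktree_info, task| async move {
    ///         // Real work against the checked-out worktree would go here.
    ///         Ok(serde_json::json!({ "issue": task.issue_number }))
    ///     })
    ///     .await;
    ///
    /// assert_eq!(result.total_tasks, 1);
    /// ```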
    pub async fn execute_parallel<F, Fut>(
        &self,
        tasks: Vec<WorktreeTask>,
        executor: F,
    ) -> PoolExecutionResult
    where
        F: Fn(WorktreeInfo, WorktreeTask) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = Result<serde_json::Value>> + Send + 'static,
    {
        let start_time = std::time::Instant::now();
        let task_count = tasks.len();

        info!(
            "Starting parallel execution of {} tasks with max concurrency: {}, fail_fast: {}",
            task_count, self.config.max_concurrency, self.config.fail_fast
        );

        // Use futures::stream with buffer_unordered to respect concurrency limit
        use futures::stream::{self, StreamExt};
        use tokio::sync::watch;

        let manager = self.manager.clone();
        let active_tasks = self.active_tasks.clone();
        let timeout_seconds = self.config.timeout_seconds;
        let max_concurrency = self.config.max_concurrency;
        let fail_fast = self.config.fail_fast;

        // Create cancellation channel for fail-fast
        let (cancel_tx, cancel_rx) = watch::channel(false);

        let results: Vec<TaskResult> = stream::iter(tasks)
            .map(|task| {
                let manager = manager.clone();
                let active_tasks = active_tasks.clone();
                let executor = executor.clone();
                let cancel_tx = cancel_tx.clone();
                let cancel_rx = cancel_rx.clone();

                async move {
                    // Check if we should cancel (fail-fast mode)
                    if *cancel_rx.borrow() {
                        warn!("Task for issue #{} cancelled due to fail-fast", task.issue_number);
                        return TaskResult {
                            issue_number: task.issue_number,
                            worktree_id: String::new(),
                            status: TaskStatus::Cancelled,
                            duration_ms: 0,
                            error: Some("Cancelled due to fail-fast".to_string()),
                            output: None,
                        };
                    }

                    let task_start = std::time::Instant::now();

                    // Create worktree (semaphore is acquired inside)
                    let worktree_info = match manager.create_worktree(task.issue_number).await {
                        Ok(info) => {
                            // Track active task
                            {
                                let mut tasks = active_tasks.lock().await;
                                tasks.insert(info.id.clone(), task.clone());
                            }
                            info
                        },
                        Err(e) => {
                            error!(
                                "Failed to create worktree for issue #{}: {}",
                                task.issue_number, e
                            );
                            return TaskResult {
                                issue_number: task.issue_number,
                                worktree_id: String::new(),
                                status: TaskStatus::Failed,
                                duration_ms: task_start.elapsed().as_millis() as u64,
                                error: Some(e.to_string()),
                                output: None,
                            };
                        },
                    };

                    // Execute task with timeout
                    let execution_result = tokio::time::timeout(
                        std::time::Duration::from_secs(timeout_seconds),
                        executor(worktree_info.clone(), task.clone()),
                    )
                    .await;

                    // Process result
                    let task_result = match execution_result {
                        Ok(Ok(output)) => {
                            info!("Task for issue #{} completed successfully", task.issue_number);
                            // Update worktree status
                            let _ = manager
                                .update_status(&worktree_info.id, WorktreeStatus::Completed)
                                .await;
                            TaskResult {
                                issue_number: task.issue_number,
                                worktree_id: worktree_info.id.clone(),
                                status: TaskStatus::Success,
                                duration_ms: task_start.elapsed().as_millis() as u64,
                                error: None,
                                output: Some(output),
                            }
                        },
                        Ok(Err(e)) => {
                            error!("Task for issue #{} failed: {}", task.issue_number, e);
                            let _ = manager
                                .update_status(&worktree_info.id, WorktreeStatus::Failed)
                                .await;

                            // Trigger cancellation if fail_fast is enabled
                            if fail_fast {
                                warn!("Triggering fail-fast cancellation due to task failure");
                                let _ = cancel_tx.send(true);
                            }

                            TaskResult {
                                issue_number: task.issue_number,
                                worktree_id: worktree_info.id.clone(),
                                status: TaskStatus::Failed,
                                duration_ms: task_start.elapsed().as_millis() as u64,
                                error: Some(e.to_string()),
                                output: None,
                            }
                        },
                        Err(_) => {
                            warn!(
                                "Task for issue #{} timed out after {} seconds",
                                task.issue_number, timeout_seconds
                            );
                            let _ = manager
                                .update_status(&worktree_info.id, WorktreeStatus::Failed)
                                .await;

                            // Trigger cancellation if fail_fast is enabled
                            if fail_fast {
                                warn!("Triggering fail-fast cancellation due to task timeout");
                                let _ = cancel_tx.send(true);
                            }

                            TaskResult {
                                issue_number: task.issue_number,
                                worktree_id: worktree_info.id.clone(),
                                status: TaskStatus::Timeout,
                                duration_ms: task_start.elapsed().as_millis() as u64,
                                error: Some(format!("Timeout after {} seconds", timeout_seconds)),
                                output: None,
                            }
                        },
                    };

                    // Remove from active tasks
                    {
                        let mut tasks = active_tasks.lock().await;
                        tasks.remove(&worktree_info.id);
                    }

                    task_result
                }
            })
            .buffer_unordered(max_concurrency)
            .collect()
            .await;

        let total_duration = start_time.elapsed().as_millis() as u64;

        // Calculate statistics
        let success_count = results.iter().filter(|r| r.status == TaskStatus::Success).count();
        let failed_count = results.iter().filter(|r| r.status == TaskStatus::Failed).count();
        let timeout_count = results.iter().filter(|r| r.status == TaskStatus::Timeout).count();
        let cancelled_count = results.iter().filter(|r| r.status == TaskStatus::Cancelled).count();

        info!(
            "Parallel execution completed: {} successful, {} failed, {} timed out, {} cancelled, {}ms total",
            success_count, failed_count, timeout_count, cancelled_count, total_duration
        );

        // Cleanup if configured
        if self.config.auto_cleanup {
            info!("Auto-cleanup enabled, removing worktrees");
            if let Err(e) = self.manager.cleanup_all().await {
                warn!("Cleanup failed: {}", e);
            }
        }

        PoolExecutionResult {
            total_tasks: task_count,
            results,
            total_duration_ms: total_duration,
            success_count,
            failed_count,
            timeout_count,
            cancelled_count,
        }
    }

    /// Execute tasks with automatic worktree lifecycle management
    ///
    /// This is a simplified version that automatically creates, executes, and cleans up worktrees
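    ///
    /// # Example
    ///
    /// A minimal sketch; the closure receives the worktree path and the issue number:
    ///
    /// ```rust,ignore
    /// let result = pool
    ///     .execute_simple(vec![7, 8], |path, issue| async move {
    ///         println!("issue #{} checked out at {}", issue, path.display());
    ///         Ok(())
    ///     })
    ///     .await;
    /// assert_eq!(result.total_tasks, 2);
    /// ```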
    pub async fn execute_simple<F, Fut>(
        &self,
        issue_numbers: Vec<u64>,
        executor: F,
    ) -> PoolExecutionResult
    where
        F: Fn(PathBuf, u64) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
    {
        let tasks: Vec<WorktreeTask> = issue_numbers
            .into_iter()
            .map(|issue_number| WorktreeTask {
                issue_number,
                description: format!("Task for issue #{}", issue_number),
                agent_type: None,
                metadata: None,
            })
            .collect();

        self.execute_parallel(tasks, move |worktree_info, _task| {
            let executor = executor.clone();
            let worktree_path = worktree_info.path.clone();
            let issue_number = worktree_info.issue_number;

            async move {
                executor(worktree_path, issue_number).await?;
                Ok(serde_json::json!({"status": "completed"}))
            }
        })
        .await
    }

    /// Get current pool statistics
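    ///
    /// # Example
    ///
    /// Illustrative only; the reported numbers depend on the pool's current state:
    ///
    /// ```rust,ignore
    /// let stats = pool.stats().await;
    /// println!("{}/{} worktrees active", stats.active_worktrees, stats.max_concurrency);
    /// ```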
    pub async fn stats(&self) -> PoolStats {
        let worktree_stats = self.manager.stats().await;
        let active_tasks = self.active_tasks.lock().await;

        PoolStats {
            max_concurrency: self.config.max_concurrency,
            active_worktrees: worktree_stats.active,
            idle_worktrees: worktree_stats.idle,
            completed_worktrees: worktree_stats.completed,
            failed_worktrees: worktree_stats.failed,
            active_tasks: active_tasks.len(),
            available_slots: worktree_stats.available_slots,
        }
    }

    /// Get reference to underlying manager
    pub fn manager(&self) -> &Arc<WorktreeManager> {
        &self.manager
    }
}

fn default_worktree_base() -> PathBuf {
    if cfg!(windows) {
        match env::var("LOCALAPPDATA") {
            Ok(dir) => PathBuf::from(dir).join("Miyabi").join("wt"),
            Err(_) => PathBuf::from(".worktrees"),
        }
    } else {
        PathBuf::from(".worktrees")
    }
}

/// Result of pool execution
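///
/// # Example
///
/// A sketch of inspecting a result (assumes `result` came from a prior call to
/// `execute_parallel` or `execute_simple`):
///
/// ```rust,ignore
/// if !result.all_successful() {
///     for failed in result.failed_tasks() {
///         eprintln!("issue #{} failed: {:?}", failed.issue_number, failed.error);
///     }
/// }
/// println!("success rate: {:.1}%", result.success_rate());
/// ```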
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolExecutionResult {
    /// Total number of tasks
    pub total_tasks: usize,
    /// Individual task results
    pub results: Vec<TaskResult>,
    /// Total execution time in milliseconds
    pub total_duration_ms: u64,
    /// Number of successful tasks
    pub success_count: usize,
    /// Number of failed tasks
    pub failed_count: usize,
    /// Number of timed out tasks
    pub timeout_count: usize,
    /// Number of cancelled tasks
    pub cancelled_count: usize,
}

impl PoolExecutionResult {
    /// Check if all tasks were successful
    pub fn all_successful(&self) -> bool {
        self.success_count == self.total_tasks
    }

    /// Check if any tasks failed
    pub fn has_failures(&self) -> bool {
        self.failed_count > 0 || self.timeout_count > 0
    }

    /// Check if any tasks were cancelled
    pub fn has_cancellations(&self) -> bool {
        self.cancelled_count > 0
    }

    /// Get success rate as percentage
    pub fn success_rate(&self) -> f64 {
        if self.total_tasks == 0 {
            0.0
        } else {
            (self.success_count as f64 / self.total_tasks as f64) * 100.0
        }
    }

    /// Get failure rate as percentage
    pub fn failure_rate(&self) -> f64 {
        if self.total_tasks == 0 {
            0.0
        } else {
            ((self.failed_count + self.timeout_count) as f64 / self.total_tasks as f64) * 100.0
        }
    }

    /// Get average task duration in milliseconds
    pub fn average_duration_ms(&self) -> f64 {
        if self.results.is_empty() {
            0.0
        } else {
            let total: u64 = self.results.iter().map(|r| r.duration_ms).sum();
            total as f64 / self.results.len() as f64
        }
    }

    /// Get minimum task duration in milliseconds
    pub fn min_duration_ms(&self) -> u64 {
        self.results.iter().map(|r| r.duration_ms).min().unwrap_or(0)
    }

    /// Get maximum task duration in milliseconds
    pub fn max_duration_ms(&self) -> u64 {
        self.results.iter().map(|r| r.duration_ms).max().unwrap_or(0)
    }

    /// Get throughput (tasks per second)
    pub fn throughput(&self) -> f64 {
        if self.total_duration_ms == 0 {
            0.0
        } else {
            (self.total_tasks as f64) / (self.total_duration_ms as f64 / 1000.0)
        }
    }

    /// Get effective concurrency: the sum of individual task durations divided by the
    /// total wall-clock duration of the run
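    ///
    /// For example, three tasks of 1000 ms each that complete within 1500 ms of
    /// wall-clock time yield an effective concurrency of 3000 / 1500 = 2.0.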
    pub fn effective_concurrency(&self) -> f64 {
        if self.total_duration_ms == 0 {
            0.0
        } else {
            let total_work: u64 = self.results.iter().map(|r| r.duration_ms).sum();
            (total_work as f64) / (self.total_duration_ms as f64)
        }
    }

    /// Get failed tasks
    pub fn failed_tasks(&self) -> Vec<&TaskResult> {
        self.results.iter().filter(|r| r.status == TaskStatus::Failed).collect()
    }

    /// Get timed out tasks
    pub fn timed_out_tasks(&self) -> Vec<&TaskResult> {
        self.results.iter().filter(|r| r.status == TaskStatus::Timeout).collect()
    }

    /// Get cancelled tasks
    pub fn cancelled_tasks(&self) -> Vec<&TaskResult> {
        self.results.iter().filter(|r| r.status == TaskStatus::Cancelled).collect()
    }

    /// Get successful tasks
    pub fn successful_tasks(&self) -> Vec<&TaskResult> {
        self.results.iter().filter(|r| r.status == TaskStatus::Success).collect()
    }
}

/// Pool statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolStats {
    /// Maximum concurrency setting
    pub max_concurrency: usize,
    /// Number of active worktrees
    pub active_worktrees: usize,
    /// Number of idle worktrees
    pub idle_worktrees: usize,
    /// Number of completed worktrees
    pub completed_worktrees: usize,
    /// Number of failed worktrees
    pub failed_worktrees: usize,
    /// Number of active tasks
    pub active_tasks: usize,
    /// Number of available slots
    pub available_slots: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.max_concurrency, 3);
        assert_eq!(config.timeout_seconds, 1800);
        assert!(!config.fail_fast);
        assert!(config.auto_cleanup);
    }

    #[test]
    fn test_worktree_task_creation() {
        let task = WorktreeTask {
            issue_number: 123,
            description: "Test task".to_string(),
            agent_type: Some("CodeGenAgent".to_string()),
            metadata: None,
        };

        assert_eq!(task.issue_number, 123);
        assert_eq!(task.description, "Test task");
        assert_eq!(task.agent_type, Some("CodeGenAgent".to_string()));
    }

    #[test]
    fn test_task_result_serialization() {
        let result = TaskResult {
            issue_number: 123,
            worktree_id: "test-id".to_string(),
            status: TaskStatus::Success,
            duration_ms: 5000,
            error: None,
            output: Some(serde_json::json!({"test": true})),
        };

        let json = serde_json::to_string(&result).unwrap();
        let deserialized: TaskResult = serde_json::from_str(&json).unwrap();

        assert_eq!(result.issue_number, deserialized.issue_number);
        assert_eq!(result.status, deserialized.status);
        assert_eq!(result.duration_ms, deserialized.duration_ms);
    }

    #[test]
    fn test_pool_execution_result_methods() {
        let result = PoolExecutionResult {
            total_tasks: 5,
            results: vec![
                TaskResult {
                    issue_number: 1,
                    worktree_id: "id1".to_string(),
                    status: TaskStatus::Success,
                    duration_ms: 1000,
                    error: None,
                    output: None,
                },
                TaskResult {
                    issue_number: 2,
                    worktree_id: "id2".to_string(),
                    status: TaskStatus::Success,
                    duration_ms: 2000,
                    error: None,
                    output: None,
                },
                TaskResult {
                    issue_number: 3,
                    worktree_id: "id3".to_string(),
                    status: TaskStatus::Failed,
                    duration_ms: 3000,
                    error: Some("Error".to_string()),
                    output: None,
                },
            ],
            total_duration_ms: 10000,
            success_count: 2,
            failed_count: 1,
            timeout_count: 0,
            cancelled_count: 2,
        };

        assert!(!result.all_successful());
        assert_eq!(result.success_rate(), 40.0);
        assert_eq!(result.average_duration_ms(), 2000.0);
    }

    #[test]
    fn test_task_status_equality() {
        assert_eq!(TaskStatus::Success, TaskStatus::Success);
        assert_ne!(TaskStatus::Success, TaskStatus::Failed);
        assert_ne!(TaskStatus::Failed, TaskStatus::Timeout);
    }
}
629}