polykit_core/
runner.rs

1//! Task execution engine and orchestration.
2
3use std::collections::HashSet;
4use std::path::PathBuf;
5
6use rayon::prelude::*;
7use crossbeam::channel;
8
9use std::sync::Arc;
10
11use crate::command_validator::CommandValidator;
12use crate::error::{Error, Result};
13use crate::executor::TaskExecutor;
14use crate::graph::DependencyGraph;
15use crate::package::Package;
16use crate::remote_cache::RemoteCache;
17use crate::streaming::StreamingTask;
18use crate::task_cache::TaskCache;
19
20/// Executes tasks across packages respecting dependency order.
21pub struct TaskRunner {
22    packages_dir: PathBuf,
23    graph: DependencyGraph,
24    max_parallel: Option<usize>,
25    command_validator: CommandValidator,
26    task_cache: Option<TaskCache>,
27    remote_cache: Option<Arc<RemoteCache>>,
28    thread_pool: Arc<rayon::ThreadPool>,
29    executor: TaskExecutor,
30}
31
32impl TaskRunner {
33    pub fn new(packages_dir: impl Into<PathBuf>, graph: DependencyGraph) -> Self {
34        let pool = rayon::ThreadPoolBuilder::new()
35            .num_threads(rayon::current_num_threads())
36            .thread_name(|i| format!("polykit-worker-{}", i))
37            .build()
38            .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
39
40        let packages_dir_path = packages_dir.into();
41        let executor = TaskExecutor::new(
42            packages_dir_path.clone(),
43            graph.clone(),
44            CommandValidator::new(),
45            None,
46            None,
47        );
48
49        Self {
50            packages_dir: packages_dir_path,
51            graph,
52            max_parallel: None,
53            command_validator: CommandValidator::new(),
54            task_cache: None,
55            remote_cache: None,
56            thread_pool: Arc::new(pool),
57            executor,
58        }
59    }
60
61    pub fn with_command_validator(mut self, validator: CommandValidator) -> Self {
62        self.command_validator = validator.clone();
63        self.executor = TaskExecutor::new(
64            self.packages_dir.clone(),
65            self.graph.clone(),
66            validator,
67            self.task_cache.clone(),
68            self.remote_cache.clone(),
69        );
70        self
71    }
72
73    pub fn with_task_cache(mut self, cache: TaskCache) -> Self {
74        self.task_cache = Some(cache.clone());
75        self.executor = TaskExecutor::new(
76            self.packages_dir.clone(),
77            self.graph.clone(),
78            self.command_validator.clone(),
79            Some(cache),
80            self.remote_cache.clone(),
81        );
82        self
83    }
84
85    pub fn with_max_parallel(mut self, max_parallel: Option<usize>) -> Self {
86        self.max_parallel = max_parallel;
87        self
88    }
89
90    pub fn with_remote_cache(mut self, remote_cache: Arc<RemoteCache>) -> Self {
91        self.remote_cache = Some(remote_cache.clone());
92        self.executor = TaskExecutor::new(
93            self.packages_dir.clone(),
94            self.graph.clone(),
95            self.command_validator.clone(),
96            self.task_cache.clone(),
97            Some(remote_cache),
98        );
99        self
100    }
101
102    pub fn run_task(
103        &self,
104        task_name: &str,
105        package_names: Option<&[String]>,
106    ) -> Result<Vec<TaskResult>> {
107        if let Some(names) = package_names {
108            if names.is_empty() {
109                return Ok(Vec::new());
110            }
111            if names.len() == 1 {
112                if let Some(package) = self.graph.get_package(&names[0]) {
113                    let result = self.executor.execute_task(package, task_name)?;
114                    return Ok(vec![result]);
115                }
116            }
117        }
118
119        let packages_to_run = if let Some(names) = package_names {
120            names
121                .iter()
122                .filter_map(|name| self.graph.get_package(name))
123                .collect::<Vec<_>>()
124        } else {
125            self.graph.all_packages()
126        };
127
128        if packages_to_run.is_empty() {
129            return Ok(Vec::new());
130        }
131
132        let packages_set: HashSet<&str> = packages_to_run.iter().map(|p| p.name.as_str()).collect();
133
134        let levels = self.graph.dependency_levels();
135        let mut results = Vec::with_capacity(packages_to_run.len());
136
137        for level in levels {
138            let level_packages: Vec<&Package> = level
139                .iter()
140                .filter(|name| packages_set.contains(name.as_str()))
141                .filter_map(|name| self.graph.get_package(name))
142                .collect();
143
144            if level_packages.is_empty() {
145                continue;
146            }
147
148            let (tx, rx) = channel::unbounded();
149            let executor = &self.executor;
150            self.thread_pool.install(|| {
151                level_packages
152                    .into_par_iter()
153                    .for_each(|package| {
154                        let result = executor.execute_task(package, task_name);
155                        let _ = tx.send(result);
156                    });
157            });
158            drop(tx);
159
160            let level_results: Result<Vec<TaskResult>> = rx
161                .iter()
162                .collect::<std::result::Result<Vec<_>, _>>()
163                .map_err(|e| Error::TaskExecution {
164                    package: "unknown".to_string(),
165                    task: task_name.to_string(),
166                    message: format!("Task execution failed: {}", e),
167                });
168
169            let mut level_results = level_results?;
170            results.append(&mut level_results);
171        }
172
173        Ok(results)
174    }
175
176    pub async fn run_task_streaming<F>(
177        &self,
178        task_name: &str,
179        package_names: Option<&[String]>,
180        on_output: F,
181    ) -> Result<Vec<TaskResult>>
182    where
183        F: Fn(&str, &str, bool) + Send + Sync + 'static,
184    {
185        let packages_to_run: Vec<Package> = if let Some(names) = package_names {
186            names
187                .iter()
188                .filter_map(|name| self.graph.get_package(name))
189                .cloned()
190                .collect()
191        } else {
192            self.graph.all_packages().into_iter().cloned().collect()
193        };
194
195        if packages_to_run.is_empty() {
196            return Ok(Vec::new());
197        }
198
199        let packages_set: HashSet<&str> = packages_to_run.iter().map(|p| p.name.as_str()).collect();
200
201        let levels = self.graph.dependency_levels();
202        let mut results = Vec::new();
203        use std::sync::{Arc, Mutex};
204        use tokio::sync::mpsc;
205
206        let output_handler = Arc::new(Mutex::new(on_output));
207
208        for level in levels {
209            let level_packages: Vec<Package> = level
210                .iter()
211                .filter(|name| packages_set.contains(name.as_str()))
212                .filter_map(|name| self.graph.get_package(name))
213                .cloned()
214                .collect();
215
216            if level_packages.is_empty() {
217                continue;
218            }
219
220            let (tx, mut rx) = mpsc::unbounded_channel::<(String, String, bool)>();
221            let output_handler_clone = Arc::clone(&output_handler);
222
223            let output_task = tokio::spawn(async move {
224                while let Some((package_name, line, is_stderr)) = rx.recv().await {
225                    if let Ok(handler) = output_handler_clone.lock() {
226                        handler(&package_name, &line, is_stderr);
227                    }
228                }
229            });
230
231            let mut handles = Vec::new();
232            let packages_dir = self.packages_dir.clone();
233            for package in level_packages {
234                let package_name = package.name.clone();
235                let package_path = packages_dir.join(&package.path);
236                let task_name = task_name.to_string();
237                let tx_clone = tx.clone();
238
239                let handle = tokio::spawn(async move {
240                    let streaming_task =
241                        match StreamingTask::spawn(&package, &task_name, &package_path).await {
242                            Ok(task) => task,
243                            Err(e) => return Err(e),
244                        };
245
246                    let stdout = Arc::new(Mutex::new(String::new()));
247                    let stderr = Arc::new(Mutex::new(String::new()));
248                    let stdout_clone = Arc::clone(&stdout);
249                    let stderr_clone = Arc::clone(&stderr);
250                    let package_name_clone = package_name.clone();
251
252                    let success = streaming_task
253                        .stream_output(move |line, is_stderr| {
254                            if is_stderr {
255                                if let Ok(mut stderr_guard) = stderr_clone.lock() {
256                                    stderr_guard.push_str(line);
257                                    stderr_guard.push('\n');
258                                }
259                            } else if let Ok(mut stdout_guard) = stdout_clone.lock() {
260                                stdout_guard.push_str(line);
261                                stdout_guard.push('\n');
262                            }
263                            let _ = tx_clone.send((
264                                package_name_clone.clone(),
265                                line.to_string(),
266                                is_stderr,
267                            ));
268                        })
269                        .await?;
270
271                    let stdout_result = Arc::try_unwrap(stdout)
272                        .map_err(|_| Error::MutexLock("Failed to unwrap stdout Arc".to_string()))?
273                        .into_inner()
274                        .map_err(|e| {
275                            Error::MutexLock(format!("Failed to unwrap stdout Mutex: {}", e))
276                        })?;
277
278                    let stderr_result = Arc::try_unwrap(stderr)
279                        .map_err(|_| Error::MutexLock("Failed to unwrap stderr Arc".to_string()))?
280                        .into_inner()
281                        .map_err(|e| {
282                            Error::MutexLock(format!("Failed to unwrap stderr Mutex: {}", e))
283                        })?;
284
285                    Ok(TaskResult {
286                        package_name,
287                        task_name,
288                        success,
289                        stdout: stdout_result,
290                        stderr: stderr_result,
291                    })
292                });
293                handles.push(handle);
294            }
295
296            drop(tx);
297
298            for handle in handles {
299                match handle.await {
300                    Ok(Ok(result)) => results.push(result),
301                    Ok(Err(e)) => return Err(e),
302                    Err(e) => {
303                        return Err(Error::TaskExecution {
304                            package: "unknown".to_string(),
305                            task: task_name.to_string(),
306                            message: format!("Task execution failed: {}", e),
307                        });
308                    }
309                }
310            }
311
312            output_task.abort();
313        }
314
315        Ok(results)
316    }
317
318}
319
/// Outcome of running one task for one package.
///
/// Values can be compared with `==`, which makes them easy to assert on in
/// tests and to deduplicate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskResult {
    /// Name of the package the task ran in.
    pub package_name: String,
    /// Name of the task that was run.
    pub task_name: String,
    /// Whether the task completed successfully.
    pub success: bool,
    /// Captured standard output of the task.
    pub stdout: String,
    /// Captured standard error of the task.
    pub stderr: String,
}