1use std::collections::HashSet;
4use std::path::PathBuf;
5
6use rayon::prelude::*;
7use crossbeam::channel;
8
9use std::sync::Arc;
10
11use crate::command_validator::CommandValidator;
12use crate::error::{Error, Result};
13use crate::executor::TaskExecutor;
14use crate::graph::DependencyGraph;
15use crate::package::Package;
16use crate::remote_cache::RemoteCache;
17use crate::streaming::StreamingTask;
18use crate::task_cache::TaskCache;
19
/// Orchestrates task execution across a workspace of packages,
/// running tasks level-by-level according to the dependency graph.
pub struct TaskRunner {
    // Root directory that contains all package directories.
    packages_dir: PathBuf,
    // Dependency graph used to resolve packages and compute execution levels.
    graph: DependencyGraph,
    // Optional cap on parallel task execution; None means use the pool default.
    max_parallel: Option<usize>,
    // Validates task commands before they are executed.
    command_validator: CommandValidator,
    // Optional local cache of task results.
    task_cache: Option<TaskCache>,
    // Optional shared remote cache of task results.
    remote_cache: Option<Arc<RemoteCache>>,
    // Dedicated rayon pool so task execution does not contend with the global pool.
    thread_pool: Arc<rayon::ThreadPool>,
    // Executes individual tasks; rebuilt whenever configuration changes.
    executor: TaskExecutor,
}
31
32impl TaskRunner {
33 pub fn new(packages_dir: impl Into<PathBuf>, graph: DependencyGraph) -> Self {
34 let pool = rayon::ThreadPoolBuilder::new()
35 .num_threads(rayon::current_num_threads())
36 .thread_name(|i| format!("polykit-worker-{}", i))
37 .build()
38 .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
39
40 let packages_dir_path = packages_dir.into();
41 let executor = TaskExecutor::new(
42 packages_dir_path.clone(),
43 graph.clone(),
44 CommandValidator::new(),
45 None,
46 None,
47 );
48
49 Self {
50 packages_dir: packages_dir_path,
51 graph,
52 max_parallel: None,
53 command_validator: CommandValidator::new(),
54 task_cache: None,
55 remote_cache: None,
56 thread_pool: Arc::new(pool),
57 executor,
58 }
59 }
60
61 pub fn with_command_validator(mut self, validator: CommandValidator) -> Self {
62 self.command_validator = validator.clone();
63 self.executor = TaskExecutor::new(
64 self.packages_dir.clone(),
65 self.graph.clone(),
66 validator,
67 self.task_cache.clone(),
68 self.remote_cache.clone(),
69 );
70 self
71 }
72
73 pub fn with_task_cache(mut self, cache: TaskCache) -> Self {
74 self.task_cache = Some(cache.clone());
75 self.executor = TaskExecutor::new(
76 self.packages_dir.clone(),
77 self.graph.clone(),
78 self.command_validator.clone(),
79 Some(cache),
80 self.remote_cache.clone(),
81 );
82 self
83 }
84
85 pub fn with_max_parallel(mut self, max_parallel: Option<usize>) -> Self {
86 self.max_parallel = max_parallel;
87 self
88 }
89
90 pub fn with_remote_cache(mut self, remote_cache: Arc<RemoteCache>) -> Self {
91 self.remote_cache = Some(remote_cache.clone());
92 self.executor = TaskExecutor::new(
93 self.packages_dir.clone(),
94 self.graph.clone(),
95 self.command_validator.clone(),
96 self.task_cache.clone(),
97 Some(remote_cache),
98 );
99 self
100 }
101
102 pub fn run_task(
103 &self,
104 task_name: &str,
105 package_names: Option<&[String]>,
106 ) -> Result<Vec<TaskResult>> {
107 if let Some(names) = package_names {
108 if names.is_empty() {
109 return Ok(Vec::new());
110 }
111 if names.len() == 1 {
112 if let Some(package) = self.graph.get_package(&names[0]) {
113 let result = self.executor.execute_task(package, task_name)?;
114 return Ok(vec![result]);
115 }
116 }
117 }
118
119 let packages_to_run = if let Some(names) = package_names {
120 names
121 .iter()
122 .filter_map(|name| self.graph.get_package(name))
123 .collect::<Vec<_>>()
124 } else {
125 self.graph.all_packages()
126 };
127
128 if packages_to_run.is_empty() {
129 return Ok(Vec::new());
130 }
131
132 let packages_set: HashSet<&str> = packages_to_run.iter().map(|p| p.name.as_str()).collect();
133
134 let levels = self.graph.dependency_levels();
135 let mut results = Vec::with_capacity(packages_to_run.len());
136
137 for level in levels {
138 let level_packages: Vec<&Package> = level
139 .iter()
140 .filter(|name| packages_set.contains(name.as_str()))
141 .filter_map(|name| self.graph.get_package(name))
142 .collect();
143
144 if level_packages.is_empty() {
145 continue;
146 }
147
148 let (tx, rx) = channel::unbounded();
149 let executor = &self.executor;
150 self.thread_pool.install(|| {
151 level_packages
152 .into_par_iter()
153 .for_each(|package| {
154 let result = executor.execute_task(package, task_name);
155 let _ = tx.send(result);
156 });
157 });
158 drop(tx);
159
160 let level_results: Result<Vec<TaskResult>> = rx
161 .iter()
162 .collect::<std::result::Result<Vec<_>, _>>()
163 .map_err(|e| Error::TaskExecution {
164 package: "unknown".to_string(),
165 task: task_name.to_string(),
166 message: format!("Task execution failed: {}", e),
167 });
168
169 let mut level_results = level_results?;
170 results.append(&mut level_results);
171 }
172
173 Ok(results)
174 }
175
176 pub async fn run_task_streaming<F>(
177 &self,
178 task_name: &str,
179 package_names: Option<&[String]>,
180 on_output: F,
181 ) -> Result<Vec<TaskResult>>
182 where
183 F: Fn(&str, &str, bool) + Send + Sync + 'static,
184 {
185 let packages_to_run: Vec<Package> = if let Some(names) = package_names {
186 names
187 .iter()
188 .filter_map(|name| self.graph.get_package(name))
189 .cloned()
190 .collect()
191 } else {
192 self.graph.all_packages().into_iter().cloned().collect()
193 };
194
195 if packages_to_run.is_empty() {
196 return Ok(Vec::new());
197 }
198
199 let packages_set: HashSet<&str> = packages_to_run.iter().map(|p| p.name.as_str()).collect();
200
201 let levels = self.graph.dependency_levels();
202 let mut results = Vec::new();
203 use std::sync::{Arc, Mutex};
204 use tokio::sync::mpsc;
205
206 let output_handler = Arc::new(Mutex::new(on_output));
207
208 for level in levels {
209 let level_packages: Vec<Package> = level
210 .iter()
211 .filter(|name| packages_set.contains(name.as_str()))
212 .filter_map(|name| self.graph.get_package(name))
213 .cloned()
214 .collect();
215
216 if level_packages.is_empty() {
217 continue;
218 }
219
220 let (tx, mut rx) = mpsc::unbounded_channel::<(String, String, bool)>();
221 let output_handler_clone = Arc::clone(&output_handler);
222
223 let output_task = tokio::spawn(async move {
224 while let Some((package_name, line, is_stderr)) = rx.recv().await {
225 if let Ok(handler) = output_handler_clone.lock() {
226 handler(&package_name, &line, is_stderr);
227 }
228 }
229 });
230
231 let mut handles = Vec::new();
232 let packages_dir = self.packages_dir.clone();
233 for package in level_packages {
234 let package_name = package.name.clone();
235 let package_path = packages_dir.join(&package.path);
236 let task_name = task_name.to_string();
237 let tx_clone = tx.clone();
238
239 let handle = tokio::spawn(async move {
240 let streaming_task =
241 match StreamingTask::spawn(&package, &task_name, &package_path).await {
242 Ok(task) => task,
243 Err(e) => return Err(e),
244 };
245
246 let stdout = Arc::new(Mutex::new(String::new()));
247 let stderr = Arc::new(Mutex::new(String::new()));
248 let stdout_clone = Arc::clone(&stdout);
249 let stderr_clone = Arc::clone(&stderr);
250 let package_name_clone = package_name.clone();
251
252 let success = streaming_task
253 .stream_output(move |line, is_stderr| {
254 if is_stderr {
255 if let Ok(mut stderr_guard) = stderr_clone.lock() {
256 stderr_guard.push_str(line);
257 stderr_guard.push('\n');
258 }
259 } else if let Ok(mut stdout_guard) = stdout_clone.lock() {
260 stdout_guard.push_str(line);
261 stdout_guard.push('\n');
262 }
263 let _ = tx_clone.send((
264 package_name_clone.clone(),
265 line.to_string(),
266 is_stderr,
267 ));
268 })
269 .await?;
270
271 let stdout_result = Arc::try_unwrap(stdout)
272 .map_err(|_| Error::MutexLock("Failed to unwrap stdout Arc".to_string()))?
273 .into_inner()
274 .map_err(|e| {
275 Error::MutexLock(format!("Failed to unwrap stdout Mutex: {}", e))
276 })?;
277
278 let stderr_result = Arc::try_unwrap(stderr)
279 .map_err(|_| Error::MutexLock("Failed to unwrap stderr Arc".to_string()))?
280 .into_inner()
281 .map_err(|e| {
282 Error::MutexLock(format!("Failed to unwrap stderr Mutex: {}", e))
283 })?;
284
285 Ok(TaskResult {
286 package_name,
287 task_name,
288 success,
289 stdout: stdout_result,
290 stderr: stderr_result,
291 })
292 });
293 handles.push(handle);
294 }
295
296 drop(tx);
297
298 for handle in handles {
299 match handle.await {
300 Ok(Ok(result)) => results.push(result),
301 Ok(Err(e)) => return Err(e),
302 Err(e) => {
303 return Err(Error::TaskExecution {
304 package: "unknown".to_string(),
305 task: task_name.to_string(),
306 message: format!("Task execution failed: {}", e),
307 });
308 }
309 }
310 }
311
312 output_task.abort();
313 }
314
315 Ok(results)
316 }
317
318}
319
/// Outcome of running a single task in a single package.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// Name of the package the task ran in.
    pub package_name: String,
    /// Name of the task that was executed.
    pub task_name: String,
    /// Whether the task completed successfully.
    pub success: bool,
    /// Captured standard output of the task.
    pub stdout: String,
    /// Captured standard error of the task.
    pub stderr: String,
}