1use std::collections::HashSet;
4use std::path::PathBuf;
5use std::process::{Command, Stdio};
6
7use rayon::prelude::*;
8
9use std::sync::Arc;
10
11use crate::command_validator::CommandValidator;
12use crate::error::{Error, Result};
13use crate::graph::DependencyGraph;
14use crate::package::Package;
15use crate::remote_cache::{Artifact, ArtifactVerifier, RemoteCache};
16use crate::simd_utils;
17use crate::streaming::StreamingTask;
18use crate::task_cache::TaskCache;
19
20pub struct TaskRunner {
22 packages_dir: PathBuf,
23 graph: DependencyGraph,
24 max_parallel: Option<usize>,
25 command_validator: CommandValidator,
26 task_cache: Option<TaskCache>,
27 remote_cache: Option<Arc<RemoteCache>>,
28}
29
30impl TaskRunner {
31 pub fn new(packages_dir: impl Into<PathBuf>, graph: DependencyGraph) -> Self {
32 Self {
33 packages_dir: packages_dir.into(),
34 graph,
35 max_parallel: None,
36 command_validator: CommandValidator::new(),
37 task_cache: None,
38 remote_cache: None,
39 }
40 }
41
42 pub fn with_command_validator(mut self, validator: CommandValidator) -> Self {
43 self.command_validator = validator;
44 self
45 }
46
47 pub fn with_task_cache(mut self, cache: TaskCache) -> Self {
48 self.task_cache = Some(cache);
49 self
50 }
51
52 pub fn with_max_parallel(mut self, max_parallel: Option<usize>) -> Self {
53 self.max_parallel = max_parallel;
54 self
55 }
56
57 pub fn with_remote_cache(mut self, remote_cache: Arc<RemoteCache>) -> Self {
58 self.remote_cache = Some(remote_cache);
59 self
60 }
61
62 pub fn run_task(
63 &self,
64 task_name: &str,
65 package_names: Option<&[String]>,
66 ) -> Result<Vec<TaskResult>> {
67 if let Some(names) = package_names {
68 if names.is_empty() {
69 return Ok(Vec::new());
70 }
71 if names.len() == 1 {
72 if let Some(package) = self.graph.get_package(&names[0]) {
73 let result = self.execute_task(package, task_name)?;
74 return Ok(vec![result]);
75 }
76 }
77 }
78
79 let packages_to_run = if let Some(names) = package_names {
80 names
81 .iter()
82 .filter_map(|name| self.graph.get_package(name))
83 .collect::<Vec<_>>()
84 } else {
85 self.graph.all_packages()
86 };
87
88 if packages_to_run.is_empty() {
89 return Ok(Vec::new());
90 }
91
92 let packages_set: HashSet<&str> = packages_to_run.iter().map(|p| p.name.as_str()).collect();
93
94 let levels = self.graph.dependency_levels();
95 let mut results = Vec::with_capacity(packages_to_run.len());
96
97 let thread_count = self.max_parallel.unwrap_or_else(rayon::current_num_threads);
98 let pool = rayon::ThreadPoolBuilder::new()
99 .num_threads(thread_count)
100 .build()
101 .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
102
103 for level in levels {
104 let level_packages: Vec<&Package> = level
105 .iter()
106 .filter(|name| packages_set.contains(name.as_str()))
107 .filter_map(|name| self.graph.get_package(name))
108 .collect();
109
110 if level_packages.is_empty() {
111 continue;
112 }
113
114 let level_results: Result<Vec<TaskResult>> = pool.install(|| {
115 level_packages
116 .into_par_iter()
117 .map(|package| self.execute_task(package, task_name))
118 .collect()
119 });
120
121 let mut level_results = level_results?;
122 results.append(&mut level_results);
123 }
124
125 Ok(results)
126 }
127
128 pub async fn run_task_streaming<F>(
129 &self,
130 task_name: &str,
131 package_names: Option<&[String]>,
132 on_output: F,
133 ) -> Result<Vec<TaskResult>>
134 where
135 F: Fn(&str, &str, bool) + Send + Sync + 'static,
136 {
137 let packages_to_run: Vec<Package> = if let Some(names) = package_names {
138 names
139 .iter()
140 .filter_map(|name| self.graph.get_package(name))
141 .cloned()
142 .collect()
143 } else {
144 self.graph.all_packages().into_iter().cloned().collect()
145 };
146
147 if packages_to_run.is_empty() {
148 return Ok(Vec::new());
149 }
150
151 let packages_set: HashSet<&str> = packages_to_run.iter().map(|p| p.name.as_str()).collect();
152
153 let levels = self.graph.dependency_levels();
154 let mut results = Vec::new();
155 use std::sync::{Arc, Mutex};
156 use tokio::sync::mpsc;
157
158 let output_handler = Arc::new(Mutex::new(on_output));
159
160 for level in levels {
161 let level_packages: Vec<Package> = level
162 .iter()
163 .filter(|name| packages_set.contains(name.as_str()))
164 .filter_map(|name| self.graph.get_package(name))
165 .cloned()
166 .collect();
167
168 if level_packages.is_empty() {
169 continue;
170 }
171
172 let (tx, mut rx) = mpsc::unbounded_channel::<(String, String, bool)>();
173 let output_handler_clone = Arc::clone(&output_handler);
174
175 let output_task = tokio::spawn(async move {
176 while let Some((package_name, line, is_stderr)) = rx.recv().await {
177 if let Ok(handler) = output_handler_clone.lock() {
178 handler(&package_name, &line, is_stderr);
179 }
180 }
181 });
182
183 let mut handles = Vec::new();
184 let packages_dir = self.packages_dir.clone();
185 for package in level_packages {
186 let package_name = package.name.clone();
187 let package_path = packages_dir.join(&package.path);
188 let task_name = task_name.to_string();
189 let tx_clone = tx.clone();
190
191 let handle = tokio::spawn(async move {
192 let streaming_task =
193 match StreamingTask::spawn(&package, &task_name, &package_path).await {
194 Ok(task) => task,
195 Err(e) => return Err(e),
196 };
197
198 let stdout = Arc::new(Mutex::new(String::new()));
199 let stderr = Arc::new(Mutex::new(String::new()));
200 let stdout_clone = Arc::clone(&stdout);
201 let stderr_clone = Arc::clone(&stderr);
202 let package_name_clone = package_name.clone();
203
204 let success = streaming_task
205 .stream_output(move |line, is_stderr| {
206 if is_stderr {
207 if let Ok(mut stderr_guard) = stderr_clone.lock() {
208 stderr_guard.push_str(line);
209 stderr_guard.push('\n');
210 }
211 } else if let Ok(mut stdout_guard) = stdout_clone.lock() {
212 stdout_guard.push_str(line);
213 stdout_guard.push('\n');
214 }
215 let _ = tx_clone.send((
216 package_name_clone.clone(),
217 line.to_string(),
218 is_stderr,
219 ));
220 })
221 .await?;
222
223 let stdout_result = Arc::try_unwrap(stdout)
224 .map_err(|_| Error::MutexLock("Failed to unwrap stdout Arc".to_string()))?
225 .into_inner()
226 .map_err(|e| {
227 Error::MutexLock(format!("Failed to unwrap stdout Mutex: {}", e))
228 })?;
229
230 let stderr_result = Arc::try_unwrap(stderr)
231 .map_err(|_| Error::MutexLock("Failed to unwrap stderr Arc".to_string()))?
232 .into_inner()
233 .map_err(|e| {
234 Error::MutexLock(format!("Failed to unwrap stderr Mutex: {}", e))
235 })?;
236
237 Ok(TaskResult {
238 package_name,
239 task_name,
240 success,
241 stdout: stdout_result,
242 stderr: stderr_result,
243 })
244 });
245 handles.push(handle);
246 }
247
248 drop(tx);
249
250 for handle in handles {
251 match handle.await {
252 Ok(Ok(result)) => results.push(result),
253 Ok(Err(e)) => return Err(e),
254 Err(e) => {
255 return Err(Error::TaskExecution {
256 package: "unknown".to_string(),
257 task: task_name.to_string(),
258 message: format!("Task execution failed: {}", e),
259 });
260 }
261 }
262 }
263
264 output_task.abort();
265 }
266
267 Ok(results)
268 }
269
270 fn build_task_dependency_order(
271 &self,
272 package: &Package,
273 task_name: &str,
274 ) -> Result<Vec<String>> {
275 let _task = package
276 .get_task(task_name)
277 .ok_or_else(|| Error::TaskExecution {
278 package: package.name.clone(),
279 task: task_name.to_string(),
280 message: format!("Task '{}' not found", task_name),
281 })?;
282
283 let mut order = Vec::new();
284 let mut visited = HashSet::new();
285 let mut visiting = HashSet::new();
286
287 fn visit_task(
288 package: &Package,
289 task_name: &str,
290 order: &mut Vec<String>,
291 visited: &mut HashSet<String>,
292 visiting: &mut HashSet<String>,
293 ) -> Result<()> {
294 if visiting.contains(task_name) {
295 return Err(Error::TaskExecution {
296 package: package.name.clone(),
297 task: task_name.to_string(),
298 message: format!(
299 "Circular task dependency detected involving '{}'",
300 task_name
301 ),
302 });
303 }
304
305 if visited.contains(task_name) {
306 return Ok(());
307 }
308
309 visiting.insert(task_name.to_string());
310 let task = package
311 .get_task(task_name)
312 .ok_or_else(|| Error::TaskExecution {
313 package: package.name.clone(),
314 task: task_name.to_string(),
315 message: format!("Task '{}' not found", task_name),
316 })?;
317
318 for dep in &task.depends_on {
319 visit_task(package, dep, order, visited, visiting)?;
320 }
321
322 visiting.remove(task_name);
323 visited.insert(task_name.to_string());
324 order.push(task_name.to_string());
325
326 Ok(())
327 }
328
329 visit_task(package, task_name, &mut order, &mut visited, &mut visiting)?;
330
331 Ok(order)
332 }
333
334 fn execute_task_with_deps(
335 &self,
336 package: &Package,
337 task_name: &str,
338 ) -> Result<Vec<TaskResult>> {
339 let task_order = self.build_task_dependency_order(package, task_name)?;
340 let mut results = Vec::with_capacity(task_order.len());
341
342 for task in &task_order {
343 let result = self.execute_task_internal(package, task)?;
344 let success = result.success;
345 results.push(result);
346 if !success && task == task_name {
347 return Ok(results);
348 }
349 }
350
351 Ok(results)
352 }
353
354 fn execute_task_internal(&self, package: &Package, task_name: &str) -> Result<TaskResult> {
355 let task = package.get_task(task_name).ok_or_else(|| {
356 let available_tasks: Vec<&str> =
357 package.tasks.iter().map(|t| t.name.as_str()).collect();
358 Error::TaskExecution {
359 package: package.name.clone(),
360 task: task_name.to_string(),
361 message: format!(
362 "Task '{}' not found. Available tasks: {}",
363 task_name,
364 available_tasks.join(", ")
365 ),
366 }
367 })?;
368
369 let package_path = self.packages_dir.join(&package.path);
370
371 if let Some(ref remote_cache) = self.remote_cache {
372 match self.check_remote_cache(
373 remote_cache,
374 package,
375 task_name,
376 &task.command,
377 &package_path,
378 ) {
379 Ok(Some(cached_result)) => return Ok(cached_result),
380 Ok(None) => {}
381 Err(_) => {}
382 }
383 }
384
385 if let Some(ref cache) = self.task_cache {
386 if let Some(cached_result) = cache.get(&package.name, task_name, &task.command)? {
387 return Ok(cached_result);
388 }
389 }
390
391 self.command_validator.validate(&task.command)?;
392
393 let output = Command::new("sh")
394 .arg("-c")
395 .arg(&task.command)
396 .current_dir(&package_path)
397 .stdout(Stdio::piped())
398 .stderr(Stdio::piped())
399 .output()
400 .map_err(|e| Error::TaskExecution {
401 package: package.name.clone(),
402 task: task_name.to_string(),
403 message: format!("Failed to execute task: {}", e),
404 })?;
405
406 let stdout = if simd_utils::is_ascii_fast(&output.stdout) {
407 unsafe { String::from_utf8_unchecked(output.stdout) }
408 } else {
409 String::from_utf8_lossy(&output.stdout).to_string()
410 };
411
412 let stderr = if simd_utils::is_ascii_fast(&output.stderr) {
413 unsafe { String::from_utf8_unchecked(output.stderr) }
414 } else {
415 String::from_utf8_lossy(&output.stderr).to_string()
416 };
417
418 let result = TaskResult {
419 package_name: package.name.clone(),
420 task_name: task_name.to_string(),
421 success: output.status.success(),
422 stdout,
423 stderr,
424 };
425
426 if let Some(ref cache) = self.task_cache {
428 let _ = cache.put(&package.name, task_name, &task.command, &result);
429 }
430
431 if result.success {
432 if let Some(ref remote_cache) = self.remote_cache {
433 let _ = self.upload_to_remote_cache(
434 remote_cache,
435 package,
436 task_name,
437 &task.command,
438 &package_path,
439 &result,
440 );
441 }
442 }
443
444 Ok(result)
445 }
446
447 fn check_remote_cache(
451 &self,
452 remote_cache: &RemoteCache,
453 package: &Package,
454 task_name: &str,
455 command: &str,
456 package_path: &std::path::Path,
457 ) -> Result<Option<TaskResult>> {
458 let rt = match tokio::runtime::Handle::try_current() {
459 Ok(handle) => handle,
460 Err(_) => {
461 tokio::runtime::Runtime::new()
463 .map_err(|e| Error::Adapter {
464 package: "remote-cache".to_string(),
465 message: format!("Failed to create tokio runtime: {}", e),
466 })?
467 .handle()
468 .clone()
469 }
470 };
471
472 let cache_key = rt.block_on(remote_cache.build_cache_key(
474 package,
475 task_name,
476 command,
477 &self.graph,
478 package_path,
479 ))?;
480
481 let artifact_opt = rt.block_on(remote_cache.fetch_artifact(&cache_key))?;
483
484 if let Some(artifact) = artifact_opt {
485 if ArtifactVerifier::verify(&artifact, None).is_err() {
486 return Err(Error::Adapter {
487 package: "remote-cache".to_string(),
488 message: "Artifact integrity verification failed".to_string(),
489 });
490 }
491
492 artifact.extract_outputs(package_path)?;
494
495 Ok(Some(TaskResult {
497 package_name: package.name.clone(),
498 task_name: task_name.to_string(),
499 success: true,
500 stdout: String::new(), stderr: String::new(),
502 }))
503 } else {
504 Ok(None)
505 }
506 }
507
508 fn upload_to_remote_cache(
510 &self,
511 remote_cache: &RemoteCache,
512 package: &Package,
513 task_name: &str,
514 command: &str,
515 package_path: &std::path::Path,
516 result: &TaskResult,
517 ) -> Result<()> {
518 use std::collections::BTreeMap;
519
520 let rt = match tokio::runtime::Handle::try_current() {
521 Ok(handle) => handle,
522 Err(_) => {
523 tokio::runtime::Runtime::new()
524 .map_err(|e| Error::Adapter {
525 package: "remote-cache".to_string(),
526 message: format!("Failed to create tokio runtime: {}", e),
527 })?
528 .handle()
529 .clone()
530 }
531 };
532
533 let cache_key = rt.block_on(remote_cache.build_cache_key(
535 package,
536 task_name,
537 command,
538 &self.graph,
539 package_path,
540 ))?;
541
542 let mut output_files = BTreeMap::new();
545 output_files.insert(
546 PathBuf::from("stdout.txt"),
547 result.stdout.as_bytes().to_vec(),
548 );
549 output_files.insert(
550 PathBuf::from("stderr.txt"),
551 result.stderr.as_bytes().to_vec(),
552 );
553
554 let artifact = Artifact::new(
556 package.name.clone(),
557 task_name.to_string(),
558 command.to_string(),
559 cache_key.as_string(),
560 output_files,
561 )?;
562
563 let _ = rt.block_on(remote_cache.upload_artifact(&cache_key, &artifact));
564
565 Ok(())
566 }
567
568 fn execute_task(&self, package: &Package, task_name: &str) -> Result<TaskResult> {
569 let results = self.execute_task_with_deps(package, task_name)?;
570 results
571 .into_iter()
572 .find(|r| r.task_name == task_name)
573 .ok_or_else(|| Error::TaskExecution {
574 package: package.name.clone(),
575 task: task_name.to_string(),
576 message: "Task execution failed".to_string(),
577 })
578 }
579}
580
/// Outcome of running one task in one package.
///
/// All fields are plain owned data, so results can be compared directly
/// (for example in tests or when de-duplicating cached results).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskResult {
    /// Name of the package the task belongs to.
    pub package_name: String,
    /// Name of the task that was run.
    pub task_name: String,
    /// Whether the task succeeded.
    pub success: bool,
    /// Captured standard output.
    pub stdout: String,
    /// Captured standard error.
    pub stderr: String,
}