Skip to main content

polykit_core/
runner.rs

1//! Task execution engine.
2
3use std::collections::HashSet;
4use std::path::PathBuf;
5use std::process::{Command, Stdio};
6
7use rayon::prelude::*;
8
9use std::sync::Arc;
10
11use crate::command_validator::CommandValidator;
12use crate::error::{Error, Result};
13use crate::graph::DependencyGraph;
14use crate::package::Package;
15use crate::remote_cache::{Artifact, ArtifactVerifier, RemoteCache};
16use crate::simd_utils;
17use crate::streaming::StreamingTask;
18use crate::task_cache::TaskCache;
19
/// Executes tasks across packages respecting dependency order.
///
/// A runner is created with [`TaskRunner::new`] and then customised through the
/// consuming `with_*` builder methods before calling `run_task` or
/// `run_task_streaming`.
pub struct TaskRunner {
    /// Base directory; each package's `path` is joined onto it to obtain the
    /// working directory in which that package's commands run.
    packages_dir: PathBuf,
    /// Package graph used to look packages up by name and to compute the
    /// dependency levels that drive execution order.
    graph: DependencyGraph,
    /// Thread count for the rayon pool built by `run_task`; `None` falls back
    /// to `rayon::current_num_threads()`.
    max_parallel: Option<usize>,
    /// Validator applied to a task's command before it is handed to `sh -c`.
    command_validator: CommandValidator,
    /// Optional local cache consulted before a command is executed and
    /// updated after it has run.
    task_cache: Option<TaskCache>,
    /// Optional shared remote cache; it is consulted before the local cache
    /// and receives the results of successful tasks.
    remote_cache: Option<Arc<RemoteCache>>,
}
29
30impl TaskRunner {
31    pub fn new(packages_dir: impl Into<PathBuf>, graph: DependencyGraph) -> Self {
32        Self {
33            packages_dir: packages_dir.into(),
34            graph,
35            max_parallel: None,
36            command_validator: CommandValidator::new(),
37            task_cache: None,
38            remote_cache: None,
39        }
40    }
41
42    pub fn with_command_validator(mut self, validator: CommandValidator) -> Self {
43        self.command_validator = validator;
44        self
45    }
46
47    pub fn with_task_cache(mut self, cache: TaskCache) -> Self {
48        self.task_cache = Some(cache);
49        self
50    }
51
52    pub fn with_max_parallel(mut self, max_parallel: Option<usize>) -> Self {
53        self.max_parallel = max_parallel;
54        self
55    }
56
57    pub fn with_remote_cache(mut self, remote_cache: Arc<RemoteCache>) -> Self {
58        self.remote_cache = Some(remote_cache);
59        self
60    }
61
62    pub fn run_task(
63        &self,
64        task_name: &str,
65        package_names: Option<&[String]>,
66    ) -> Result<Vec<TaskResult>> {
67        if let Some(names) = package_names {
68            if names.is_empty() {
69                return Ok(Vec::new());
70            }
71            if names.len() == 1 {
72                if let Some(package) = self.graph.get_package(&names[0]) {
73                    let result = self.execute_task(package, task_name)?;
74                    return Ok(vec![result]);
75                }
76            }
77        }
78
79        let packages_to_run = if let Some(names) = package_names {
80            names
81                .iter()
82                .filter_map(|name| self.graph.get_package(name))
83                .collect::<Vec<_>>()
84        } else {
85            self.graph.all_packages()
86        };
87
88        if packages_to_run.is_empty() {
89            return Ok(Vec::new());
90        }
91
92        let packages_set: HashSet<&str> = packages_to_run.iter().map(|p| p.name.as_str()).collect();
93
94        let levels = self.graph.dependency_levels();
95        let mut results = Vec::with_capacity(packages_to_run.len());
96
97        let thread_count = self.max_parallel.unwrap_or_else(rayon::current_num_threads);
98        let pool = rayon::ThreadPoolBuilder::new()
99            .num_threads(thread_count)
100            .build()
101            .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
102
103        for level in levels {
104            let level_packages: Vec<&Package> = level
105                .iter()
106                .filter(|name| packages_set.contains(name.as_str()))
107                .filter_map(|name| self.graph.get_package(name))
108                .collect();
109
110            if level_packages.is_empty() {
111                continue;
112            }
113
114            let level_results: Result<Vec<TaskResult>> = pool.install(|| {
115                level_packages
116                    .into_par_iter()
117                    .map(|package| self.execute_task(package, task_name))
118                    .collect()
119            });
120
121            let mut level_results = level_results?;
122            results.append(&mut level_results);
123        }
124
125        Ok(results)
126    }
127
128    pub async fn run_task_streaming<F>(
129        &self,
130        task_name: &str,
131        package_names: Option<&[String]>,
132        on_output: F,
133    ) -> Result<Vec<TaskResult>>
134    where
135        F: Fn(&str, &str, bool) + Send + Sync + 'static,
136    {
137        let packages_to_run: Vec<Package> = if let Some(names) = package_names {
138            names
139                .iter()
140                .filter_map(|name| self.graph.get_package(name))
141                .cloned()
142                .collect()
143        } else {
144            self.graph.all_packages().into_iter().cloned().collect()
145        };
146
147        if packages_to_run.is_empty() {
148            return Ok(Vec::new());
149        }
150
151        let packages_set: HashSet<&str> = packages_to_run.iter().map(|p| p.name.as_str()).collect();
152
153        let levels = self.graph.dependency_levels();
154        let mut results = Vec::new();
155        use std::sync::{Arc, Mutex};
156        use tokio::sync::mpsc;
157
158        let output_handler = Arc::new(Mutex::new(on_output));
159
160        for level in levels {
161            let level_packages: Vec<Package> = level
162                .iter()
163                .filter(|name| packages_set.contains(name.as_str()))
164                .filter_map(|name| self.graph.get_package(name))
165                .cloned()
166                .collect();
167
168            if level_packages.is_empty() {
169                continue;
170            }
171
172            let (tx, mut rx) = mpsc::unbounded_channel::<(String, String, bool)>();
173            let output_handler_clone = Arc::clone(&output_handler);
174
175            let output_task = tokio::spawn(async move {
176                while let Some((package_name, line, is_stderr)) = rx.recv().await {
177                    if let Ok(handler) = output_handler_clone.lock() {
178                        handler(&package_name, &line, is_stderr);
179                    }
180                }
181            });
182
183            let mut handles = Vec::new();
184            let packages_dir = self.packages_dir.clone();
185            for package in level_packages {
186                let package_name = package.name.clone();
187                let package_path = packages_dir.join(&package.path);
188                let task_name = task_name.to_string();
189                let tx_clone = tx.clone();
190
191                let handle = tokio::spawn(async move {
192                    let streaming_task =
193                        match StreamingTask::spawn(&package, &task_name, &package_path).await {
194                            Ok(task) => task,
195                            Err(e) => return Err(e),
196                        };
197
198                    let stdout = Arc::new(Mutex::new(String::new()));
199                    let stderr = Arc::new(Mutex::new(String::new()));
200                    let stdout_clone = Arc::clone(&stdout);
201                    let stderr_clone = Arc::clone(&stderr);
202                    let package_name_clone = package_name.clone();
203
204                    let success = streaming_task
205                        .stream_output(move |line, is_stderr| {
206                            if is_stderr {
207                                if let Ok(mut stderr_guard) = stderr_clone.lock() {
208                                    stderr_guard.push_str(line);
209                                    stderr_guard.push('\n');
210                                }
211                            } else if let Ok(mut stdout_guard) = stdout_clone.lock() {
212                                stdout_guard.push_str(line);
213                                stdout_guard.push('\n');
214                            }
215                            let _ = tx_clone.send((
216                                package_name_clone.clone(),
217                                line.to_string(),
218                                is_stderr,
219                            ));
220                        })
221                        .await?;
222
223                    let stdout_result = Arc::try_unwrap(stdout)
224                        .map_err(|_| Error::MutexLock("Failed to unwrap stdout Arc".to_string()))?
225                        .into_inner()
226                        .map_err(|e| {
227                            Error::MutexLock(format!("Failed to unwrap stdout Mutex: {}", e))
228                        })?;
229
230                    let stderr_result = Arc::try_unwrap(stderr)
231                        .map_err(|_| Error::MutexLock("Failed to unwrap stderr Arc".to_string()))?
232                        .into_inner()
233                        .map_err(|e| {
234                            Error::MutexLock(format!("Failed to unwrap stderr Mutex: {}", e))
235                        })?;
236
237                    Ok(TaskResult {
238                        package_name,
239                        task_name,
240                        success,
241                        stdout: stdout_result,
242                        stderr: stderr_result,
243                    })
244                });
245                handles.push(handle);
246            }
247
248            drop(tx);
249
250            for handle in handles {
251                match handle.await {
252                    Ok(Ok(result)) => results.push(result),
253                    Ok(Err(e)) => return Err(e),
254                    Err(e) => {
255                        return Err(Error::TaskExecution {
256                            package: "unknown".to_string(),
257                            task: task_name.to_string(),
258                            message: format!("Task execution failed: {}", e),
259                        });
260                    }
261                }
262            }
263
264            output_task.abort();
265        }
266
267        Ok(results)
268    }
269
270    fn build_task_dependency_order(
271        &self,
272        package: &Package,
273        task_name: &str,
274    ) -> Result<Vec<String>> {
275        let _task = package
276            .get_task(task_name)
277            .ok_or_else(|| Error::TaskExecution {
278                package: package.name.clone(),
279                task: task_name.to_string(),
280                message: format!("Task '{}' not found", task_name),
281            })?;
282
283        let mut order = Vec::new();
284        let mut visited = HashSet::new();
285        let mut visiting = HashSet::new();
286
287        fn visit_task(
288            package: &Package,
289            task_name: &str,
290            order: &mut Vec<String>,
291            visited: &mut HashSet<String>,
292            visiting: &mut HashSet<String>,
293        ) -> Result<()> {
294            if visiting.contains(task_name) {
295                return Err(Error::TaskExecution {
296                    package: package.name.clone(),
297                    task: task_name.to_string(),
298                    message: format!(
299                        "Circular task dependency detected involving '{}'",
300                        task_name
301                    ),
302                });
303            }
304
305            if visited.contains(task_name) {
306                return Ok(());
307            }
308
309            visiting.insert(task_name.to_string());
310            let task = package
311                .get_task(task_name)
312                .ok_or_else(|| Error::TaskExecution {
313                    package: package.name.clone(),
314                    task: task_name.to_string(),
315                    message: format!("Task '{}' not found", task_name),
316                })?;
317
318            for dep in &task.depends_on {
319                visit_task(package, dep, order, visited, visiting)?;
320            }
321
322            visiting.remove(task_name);
323            visited.insert(task_name.to_string());
324            order.push(task_name.to_string());
325
326            Ok(())
327        }
328
329        visit_task(package, task_name, &mut order, &mut visited, &mut visiting)?;
330
331        Ok(order)
332    }
333
334    fn execute_task_with_deps(
335        &self,
336        package: &Package,
337        task_name: &str,
338    ) -> Result<Vec<TaskResult>> {
339        let task_order = self.build_task_dependency_order(package, task_name)?;
340        let mut results = Vec::with_capacity(task_order.len());
341
342        for task in &task_order {
343            let result = self.execute_task_internal(package, task)?;
344            let success = result.success;
345            results.push(result);
346            if !success && task == task_name {
347                return Ok(results);
348            }
349        }
350
351        Ok(results)
352    }
353
354    fn execute_task_internal(&self, package: &Package, task_name: &str) -> Result<TaskResult> {
355        let task = package.get_task(task_name).ok_or_else(|| {
356            let available_tasks: Vec<&str> =
357                package.tasks.iter().map(|t| t.name.as_str()).collect();
358            Error::TaskExecution {
359                package: package.name.clone(),
360                task: task_name.to_string(),
361                message: format!(
362                    "Task '{}' not found. Available tasks: {}",
363                    task_name,
364                    available_tasks.join(", ")
365                ),
366            }
367        })?;
368
369        let package_path = self.packages_dir.join(&package.path);
370
371        if let Some(ref remote_cache) = self.remote_cache {
372            match self.check_remote_cache(
373                remote_cache,
374                package,
375                task_name,
376                &task.command,
377                &package_path,
378            ) {
379                Ok(Some(cached_result)) => return Ok(cached_result),
380                Ok(None) => {}
381                Err(_) => {}
382            }
383        }
384
385        if let Some(ref cache) = self.task_cache {
386            if let Some(cached_result) = cache.get(&package.name, task_name, &task.command)? {
387                return Ok(cached_result);
388            }
389        }
390
391        self.command_validator.validate(&task.command)?;
392
393        let output = Command::new("sh")
394            .arg("-c")
395            .arg(&task.command)
396            .current_dir(&package_path)
397            .stdout(Stdio::piped())
398            .stderr(Stdio::piped())
399            .output()
400            .map_err(|e| Error::TaskExecution {
401                package: package.name.clone(),
402                task: task_name.to_string(),
403                message: format!("Failed to execute task: {}", e),
404            })?;
405
406        let stdout = if simd_utils::is_ascii_fast(&output.stdout) {
407            unsafe { String::from_utf8_unchecked(output.stdout) }
408        } else {
409            String::from_utf8_lossy(&output.stdout).to_string()
410        };
411
412        let stderr = if simd_utils::is_ascii_fast(&output.stderr) {
413            unsafe { String::from_utf8_unchecked(output.stderr) }
414        } else {
415            String::from_utf8_lossy(&output.stderr).to_string()
416        };
417
418        let result = TaskResult {
419            package_name: package.name.clone(),
420            task_name: task_name.to_string(),
421            success: output.status.success(),
422            stdout,
423            stderr,
424        };
425
426        // Store in local cache
427        if let Some(ref cache) = self.task_cache {
428            let _ = cache.put(&package.name, task_name, &task.command, &result);
429        }
430
431        if result.success {
432            if let Some(ref remote_cache) = self.remote_cache {
433                let _ = self.upload_to_remote_cache(
434                    remote_cache,
435                    package,
436                    task_name,
437                    &task.command,
438                    &package_path,
439                    &result,
440                );
441            }
442        }
443
444        Ok(result)
445    }
446
447    /// Checks remote cache for a task result.
448    ///
449    /// Returns `Ok(Some(result))` if found, `Ok(None)` if not found, or `Err` on error.
450    fn check_remote_cache(
451        &self,
452        remote_cache: &RemoteCache,
453        package: &Package,
454        task_name: &str,
455        command: &str,
456        package_path: &std::path::Path,
457    ) -> Result<Option<TaskResult>> {
458        let rt = match tokio::runtime::Handle::try_current() {
459            Ok(handle) => handle,
460            Err(_) => {
461                // Create a new runtime if we're not in an async context
462                tokio::runtime::Runtime::new()
463                    .map_err(|e| Error::Adapter {
464                        package: "remote-cache".to_string(),
465                        message: format!("Failed to create tokio runtime: {}", e),
466                    })?
467                    .handle()
468                    .clone()
469            }
470        };
471
472        // Build cache key
473        let cache_key = rt.block_on(remote_cache.build_cache_key(
474            package,
475            task_name,
476            command,
477            &self.graph,
478            package_path,
479        ))?;
480
481        // Fetch artifact
482        let artifact_opt = rt.block_on(remote_cache.fetch_artifact(&cache_key))?;
483
484        if let Some(artifact) = artifact_opt {
485            if ArtifactVerifier::verify(&artifact, None).is_err() {
486                return Err(Error::Adapter {
487                    package: "remote-cache".to_string(),
488                    message: "Artifact integrity verification failed".to_string(),
489                });
490            }
491
492            // Extract outputs
493            artifact.extract_outputs(package_path)?;
494
495            // Return cached result
496            Ok(Some(TaskResult {
497                package_name: package.name.clone(),
498                task_name: task_name.to_string(),
499                success: true,
500                stdout: String::new(), // Outputs are in files, not stdout
501                stderr: String::new(),
502            }))
503        } else {
504            Ok(None)
505        }
506    }
507
508    /// Uploads task result to remote cache.
509    fn upload_to_remote_cache(
510        &self,
511        remote_cache: &RemoteCache,
512        package: &Package,
513        task_name: &str,
514        command: &str,
515        package_path: &std::path::Path,
516        result: &TaskResult,
517    ) -> Result<()> {
518        use std::collections::BTreeMap;
519
520        let rt = match tokio::runtime::Handle::try_current() {
521            Ok(handle) => handle,
522            Err(_) => {
523                tokio::runtime::Runtime::new()
524                    .map_err(|e| Error::Adapter {
525                        package: "remote-cache".to_string(),
526                        message: format!("Failed to create tokio runtime: {}", e),
527                    })?
528                    .handle()
529                    .clone()
530            }
531        };
532
533        // Build cache key
534        let cache_key = rt.block_on(remote_cache.build_cache_key(
535            package,
536            task_name,
537            command,
538            &self.graph,
539            package_path,
540        ))?;
541
542        // Collect output files (simplified - in practice, we'd track what files were created)
543        // For now, we'll create a minimal artifact with stdout/stderr
544        let mut output_files = BTreeMap::new();
545        output_files.insert(
546            PathBuf::from("stdout.txt"),
547            result.stdout.as_bytes().to_vec(),
548        );
549        output_files.insert(
550            PathBuf::from("stderr.txt"),
551            result.stderr.as_bytes().to_vec(),
552        );
553
554        // Create artifact
555        let artifact = Artifact::new(
556            package.name.clone(),
557            task_name.to_string(),
558            command.to_string(),
559            cache_key.as_string(),
560            output_files,
561        )?;
562
563        let _ = rt.block_on(remote_cache.upload_artifact(&cache_key, &artifact));
564
565        Ok(())
566    }
567
568    fn execute_task(&self, package: &Package, task_name: &str) -> Result<TaskResult> {
569        let results = self.execute_task_with_deps(package, task_name)?;
570        results
571            .into_iter()
572            .find(|r| r.task_name == task_name)
573            .ok_or_else(|| Error::TaskExecution {
574                package: package.name.clone(),
575                task: task_name.to_string(),
576                message: "Task execution failed".to_string(),
577            })
578    }
579}
580
/// Result of executing a task for a package.
///
/// A task whose command ran but exited unsuccessfully is still represented by
/// a `TaskResult` (with `success == false`) rather than by an error.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// Name of the package the task was executed for.
    pub package_name: String,
    /// Name of the task that was executed.
    pub task_name: String,
    /// Whether the task succeeded (the command's exit status reported success,
    /// or the result was restored from the remote cache).
    pub success: bool,
    /// Captured standard output of the task; empty for results restored from
    /// the remote cache.
    pub stdout: String,
    /// Captured standard error of the task; empty for results restored from
    /// the remote cache.
    pub stderr: String,
}
594}