polykit_core/
executor.rs

1//! Internal task execution logic.
2
3use std::collections::HashSet;
4use std::path::PathBuf;
5use std::process::{Command, Stdio};
6use std::sync::Arc;
7
8use crate::command_validator::CommandValidator;
9use crate::error::{Error, Result};
10use crate::graph::DependencyGraph;
11use crate::package::Package;
12use crate::remote_cache::{Artifact, ArtifactVerifier, RemoteCache};
13use crate::runner::TaskResult;
14use crate::simd_utils;
15use crate::task_cache::TaskCache;
16
17pub struct TaskExecutor {
18    packages_dir: PathBuf,
19    graph: Arc<DependencyGraph>,
20    command_validator: CommandValidator,
21    task_cache: Option<TaskCache>,
22    remote_cache: Option<Arc<RemoteCache>>,
23}
24
25impl TaskExecutor {
26    pub fn new(
27        packages_dir: PathBuf,
28        graph: DependencyGraph,
29        command_validator: CommandValidator,
30        task_cache: Option<TaskCache>,
31        remote_cache: Option<Arc<RemoteCache>>,
32    ) -> Self {
33        Self {
34            packages_dir,
35            graph: Arc::new(graph),
36            command_validator,
37            task_cache,
38            remote_cache,
39        }
40    }
41
42    pub fn build_task_dependency_order(
43        &self,
44        package: &Package,
45        task_name: &str,
46    ) -> Result<Vec<String>> {
47        let _task = package
48            .get_task(task_name)
49            .ok_or_else(|| Error::TaskExecution {
50                package: package.name.clone(),
51                task: task_name.to_string(),
52                message: format!("Task '{}' not found", task_name),
53            })?;
54
55        let mut order = Vec::new();
56        let mut visited = HashSet::new();
57        let mut visiting = HashSet::new();
58
59        fn visit_task(
60            package: &Package,
61            task_name: &str,
62            order: &mut Vec<String>,
63            visited: &mut HashSet<String>,
64            visiting: &mut HashSet<String>,
65        ) -> Result<()> {
66            if visiting.contains(task_name) {
67                return Err(Error::TaskExecution {
68                    package: package.name.clone(),
69                    task: task_name.to_string(),
70                    message: format!(
71                        "Circular task dependency detected involving '{}'",
72                        task_name
73                    ),
74                });
75            }
76
77            if visited.contains(task_name) {
78                return Ok(());
79            }
80
81            visiting.insert(task_name.to_string());
82            let task = package
83                .get_task(task_name)
84                .ok_or_else(|| Error::TaskExecution {
85                    package: package.name.clone(),
86                    task: task_name.to_string(),
87                    message: format!("Task '{}' not found", task_name),
88                })?;
89
90            for dep in &task.depends_on {
91                visit_task(package, dep, order, visited, visiting)?;
92            }
93
94            visiting.remove(task_name);
95            visited.insert(task_name.to_string());
96            order.push(task_name.to_string());
97
98            Ok(())
99        }
100
101        visit_task(package, task_name, &mut order, &mut visited, &mut visiting)?;
102
103        Ok(order)
104    }
105
106    pub fn execute_task_with_deps(
107        &self,
108        package: &Package,
109        task_name: &str,
110    ) -> Result<Vec<TaskResult>> {
111        let task_order = self.build_task_dependency_order(package, task_name)?;
112        let mut results = Vec::with_capacity(task_order.len());
113
114        for task in &task_order {
115            let result = self.execute_task_internal(package, task)?;
116            let success = result.success;
117            results.push(result);
118            if !success && task == task_name {
119                return Ok(results);
120            }
121        }
122
123        Ok(results)
124    }
125
126    pub fn execute_task_internal(&self, package: &Package, task_name: &str) -> Result<TaskResult> {
127        let task = package.get_task(task_name).ok_or_else(|| {
128            let available_tasks: Vec<&str> =
129                package.tasks.iter().map(|t| t.name.as_str()).collect();
130            Error::TaskExecution {
131                package: package.name.clone(),
132                task: task_name.to_string(),
133                message: format!(
134                    "Task '{}' not found. Available tasks: {}",
135                    task_name,
136                    available_tasks.join(", ")
137                ),
138            }
139        })?;
140
141        let package_path = self.packages_dir.join(&package.path);
142
143        if let Some(ref remote_cache) = self.remote_cache {
144            match self.check_remote_cache(
145                remote_cache,
146                package,
147                task_name,
148                &task.command,
149                &package_path,
150            ) {
151                Ok(Some(cached_result)) => return Ok(cached_result),
152                Ok(None) => {}
153                Err(_) => {}
154            }
155        }
156
157        if let Some(ref cache) = self.task_cache {
158            if let Some(cached_result) = cache.get(&package.name, task_name, &task.command)? {
159                return Ok(cached_result);
160            }
161        }
162
163        self.command_validator.validate(&task.command)?;
164
165        let output = Command::new("sh")
166            .arg("-c")
167            .arg(&task.command)
168            .current_dir(&package_path)
169            .stdout(Stdio::piped())
170            .stderr(Stdio::piped())
171            .output()
172            .map_err(|e| Error::TaskExecution {
173                package: package.name.clone(),
174                task: task_name.to_string(),
175                message: format!("Failed to execute task: {}", e),
176            })?;
177
178        let stdout = if simd_utils::is_ascii_fast(&output.stdout) {
179            unsafe { String::from_utf8_unchecked(output.stdout) }
180        } else {
181            String::from_utf8_lossy(&output.stdout).to_string()
182        };
183
184        let stderr = if simd_utils::is_ascii_fast(&output.stderr) {
185            unsafe { String::from_utf8_unchecked(output.stderr) }
186        } else {
187            String::from_utf8_lossy(&output.stderr).to_string()
188        };
189
190        let result = TaskResult {
191            package_name: package.name.clone(),
192            task_name: task_name.to_string(),
193            success: output.status.success(),
194            stdout,
195            stderr,
196        };
197
198        // Store in local cache
199        if let Some(ref cache) = self.task_cache {
200            let _ = cache.put(&package.name, task_name, &task.command, &result);
201        }
202
203        if result.success {
204            if let Some(ref remote_cache) = self.remote_cache {
205                let remote_cache = Arc::clone(remote_cache);
206                let package = package.clone();
207                let task_name = task_name.to_string();
208                let command = task.command.clone();
209                let package_path = package_path.to_path_buf();
210                let result = result.clone();
211
212                tokio::spawn(async move {
213                    let rt = tokio::runtime::Handle::try_current()
214                        .unwrap_or_else(|_| tokio::runtime::Runtime::new().unwrap().handle().clone());
215                    rt.block_on(async {
216                        use std::collections::BTreeMap;
217                        let mut output_files = BTreeMap::new();
218                        output_files.insert(
219                            PathBuf::from("stdout.txt"),
220                            result.stdout.as_bytes().to_vec(),
221                        );
222                        output_files.insert(
223                            PathBuf::from("stderr.txt"),
224                            result.stderr.as_bytes().to_vec(),
225                        );
226                        let temp_graph = DependencyGraph::new(vec![package.clone()]).ok();
227                        if let Some(ref graph) = temp_graph {
228                            if let Ok(cache_key) = remote_cache
229                                .build_cache_key(&package, &task_name, &command, graph, &package_path)
230                                .await
231                            {
232                                if let Ok(artifact) = Artifact::new(
233                                    package.name.clone(),
234                                    task_name.clone(),
235                                    command.clone(),
236                                    cache_key.as_string(),
237                                    output_files,
238                                ) {
239                                    let _ = remote_cache.upload_artifact(&cache_key, &artifact).await;
240                                }
241                            }
242                        }
243                    });
244                });
245            }
246        }
247
248        Ok(result)
249    }
250
251    /// Checks remote cache for a task result.
252    ///
253    /// Returns `Ok(Some(result))` if found, `Ok(None)` if not found, or `Err` on error.
254    pub fn check_remote_cache(
255        &self,
256        remote_cache: &RemoteCache,
257        package: &Package,
258        task_name: &str,
259        command: &str,
260        package_path: &std::path::Path,
261    ) -> Result<Option<TaskResult>> {
262        let rt = match tokio::runtime::Handle::try_current() {
263            Ok(handle) => handle,
264            Err(_) => {
265                // Create a new runtime if we're not in an async context
266                tokio::runtime::Runtime::new()
267                    .map_err(|e| Error::Adapter {
268                        package: "remote-cache".to_string(),
269                        message: format!("Failed to create tokio runtime: {}", e),
270                    })?
271                    .handle()
272                    .clone()
273            }
274        };
275
276        // Build cache key
277        let cache_key = rt.block_on(remote_cache.build_cache_key(
278            package,
279            task_name,
280            command,
281            self.graph.as_ref(),
282            package_path,
283        ))?;
284
285        // Fetch artifact
286        let artifact_opt = rt.block_on(remote_cache.fetch_artifact(&cache_key))?;
287
288        if let Some(artifact) = artifact_opt {
289            if ArtifactVerifier::verify(&artifact, None).is_err() {
290                return Err(Error::Adapter {
291                    package: "remote-cache".to_string(),
292                    message: "Artifact integrity verification failed".to_string(),
293                });
294            }
295
296            // Extract outputs
297            artifact.extract_outputs(package_path)?;
298
299            // Return cached result
300            Ok(Some(TaskResult {
301                package_name: package.name.clone(),
302                task_name: task_name.to_string(),
303                success: true,
304                stdout: String::new(), // Outputs are in files, not stdout
305                stderr: String::new(),
306            }))
307        } else {
308            Ok(None)
309        }
310    }
311
312    pub fn execute_task(&self, package: &Package, task_name: &str) -> Result<TaskResult> {
313        let results = self.execute_task_with_deps(package, task_name)?;
314        results
315            .into_iter()
316            .find(|r| r.task_name == task_name)
317            .ok_or_else(|| Error::TaskExecution {
318                package: package.name.clone(),
319                task: task_name.to_string(),
320                message: "Task execution failed".to_string(),
321            })
322    }
323}