Skip to main content

capsule_core/wasm/utilities/
task_config.rs

1use serde::{Deserialize, Serialize};
2
3use crate::config::manifest::CapsuleToml;
4use crate::wasm::execution_policy::{Compute, ExecutionPolicy};
5
6#[derive(Serialize, Deserialize)]
7pub struct TaskResult {
8    pub success: bool,
9    pub result: Option<serde_json::Value>,
10    pub error: Option<TaskError>,
11    pub execution: TaskExecution,
12}
13
14#[derive(Serialize, Deserialize)]
15pub struct TaskError {
16    pub error_type: String,
17    pub message: String,
18}
19
20#[derive(Serialize, Deserialize)]
21pub struct TaskExecution {
22    pub task_name: String,
23    pub duration_ms: u64,
24    pub retries: u64,
25    pub fuel_consumed: u64,
26}
27
28#[derive(Debug, Deserialize, Default)]
29pub struct TaskConfig {
30    name: Option<String>,
31    compute: Option<String>,
32    ram: Option<String>,
33    timeout: Option<String>,
34
35    #[serde(alias = "maxRetries")]
36    max_retries: Option<u64>,
37
38    #[serde(alias = "allowedFiles")]
39    allowed_files: Option<Vec<String>>,
40
41    #[serde(alias = "allowedHosts")]
42    allowed_hosts: Option<Vec<String>>,
43
44    #[serde(alias = "envVariables")]
45    env_variables: Option<Vec<String>>,
46}
47
48impl TaskConfig {
49    pub fn to_execution_policy(&self, capsule_toml: &CapsuleToml) -> ExecutionPolicy {
50        let default_policy = capsule_toml.tasks.as_ref();
51
52        let compute = self
53            .compute
54            .as_ref()
55            .map(|c| match c.to_uppercase().as_str() {
56                "LOW" => Compute::Low,
57                "MEDIUM" => Compute::Medium,
58                "HIGH" => Compute::High,
59                _ => c
60                    .parse::<u64>()
61                    .map(Compute::Custom)
62                    .unwrap_or(Compute::Medium),
63            })
64            .or_else(|| default_policy.and_then(|p| p.default_compute.clone()));
65
66        let ram = self
67            .ram
68            .as_ref()
69            .and_then(|r| Self::parse_ram_string(r))
70            .or_else(|| {
71                default_policy
72                    .and_then(|p| p.default_ram.as_ref())
73                    .and_then(|r| Self::parse_ram_string(r))
74            });
75
76        let timeout = self
77            .timeout
78            .clone()
79            .or_else(|| default_policy.and_then(|p| p.default_timeout.clone()));
80
81        let max_retries = self
82            .max_retries
83            .or_else(|| default_policy.and_then(|p| p.default_max_retries));
84
85        let allowed_files = self
86            .allowed_files
87            .clone()
88            .or_else(|| default_policy.and_then(|p| p.default_allowed_files.clone()))
89            .unwrap_or_default();
90
91        let allowed_hosts = self
92            .allowed_hosts
93            .clone()
94            .or_else(|| default_policy.and_then(|p| p.default_allowed_hosts.clone()))
95            .unwrap_or(vec!["*".to_string()]);
96
97        let env_variables = self
98            .env_variables
99            .clone()
100            .or_else(|| default_policy.and_then(|p| p.default_env_variables.clone()))
101            .unwrap_or_default();
102
103        ExecutionPolicy::new()
104            .name(self.name.clone())
105            .compute(compute)
106            .ram(ram)
107            .timeout(timeout)
108            .max_retries(max_retries)
109            .allowed_files(allowed_files)
110            .allowed_hosts(allowed_hosts)
111            .env_variables(env_variables)
112    }
113
114    pub fn parse_ram_string(s: &str) -> Option<u64> {
115        let s = s.trim().to_uppercase();
116        if s.ends_with("GB") {
117            s.trim_end_matches("GB")
118                .trim()
119                .parse::<u64>()
120                .ok()
121                .map(|v| v * 1024 * 1024 * 1024)
122        } else if s.ends_with("MB") {
123            s.trim_end_matches("MB")
124                .trim()
125                .parse::<u64>()
126                .ok()
127                .map(|v| v * 1024 * 1024)
128        } else if s.ends_with("KB") {
129            s.trim_end_matches("KB")
130                .trim()
131                .parse::<u64>()
132                .ok()
133                .map(|v| v * 1024)
134        } else {
135            s.parse::<u64>().ok()
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    //! Unit tests for RAM parsing and task-config → execution-policy resolution.

    use super::*;
    use crate::config::manifest::DefaultPolicy;

    /// Resolves `compute` alone against an empty manifest.
    fn compute_of(value: &str) -> Compute {
        TaskConfig {
            compute: Some(value.to_string()),
            ..Default::default()
        }
        .to_execution_policy(&CapsuleToml::default())
        .compute
    }

    #[test]
    fn test_parse_ram_string() {
        assert_eq!(
            TaskConfig::parse_ram_string("2GB"),
            Some(2 * 1024 * 1024 * 1024)
        );
        assert_eq!(
            TaskConfig::parse_ram_string("1 GB"),
            Some(1024 * 1024 * 1024)
        );
        assert_eq!(
            TaskConfig::parse_ram_string("4gb"),
            Some(4 * 1024 * 1024 * 1024)
        );

        assert_eq!(
            TaskConfig::parse_ram_string("512MB"),
            Some(512 * 1024 * 1024)
        );
        assert_eq!(
            TaskConfig::parse_ram_string("256 MB"),
            Some(256 * 1024 * 1024)
        );
        assert_eq!(
            TaskConfig::parse_ram_string("128mb"),
            Some(128 * 1024 * 1024)
        );

        assert_eq!(TaskConfig::parse_ram_string("1024KB"), Some(1024 * 1024));
        assert_eq!(TaskConfig::parse_ram_string("512 KB"), Some(512 * 1024));
        assert_eq!(TaskConfig::parse_ram_string("256kb"), Some(256 * 1024));

        assert_eq!(TaskConfig::parse_ram_string("1024"), Some(1024));
        assert_eq!(TaskConfig::parse_ram_string("512"), Some(512));
    }

    #[test]
    fn test_parse_ram_string_surrounding_whitespace() {
        assert_eq!(
            TaskConfig::parse_ram_string("  1 GB  "),
            Some(1024 * 1024 * 1024)
        );
        assert_eq!(TaskConfig::parse_ram_string(" 64 "), Some(64));
    }

    #[test]
    fn test_parse_ram_string_invalid() {
        assert_eq!(TaskConfig::parse_ram_string("invalid"), None);
        assert_eq!(TaskConfig::parse_ram_string(""), None);
        assert_eq!(TaskConfig::parse_ram_string("GB"), None);
        // Fractions, negatives and unknown units are rejected.
        assert_eq!(TaskConfig::parse_ram_string("1.5GB"), None);
        assert_eq!(TaskConfig::parse_ram_string("-1MB"), None);
        assert_eq!(TaskConfig::parse_ram_string("12TB"), None);
    }

    #[test]
    fn test_to_execution_policy_default() {
        let config = TaskConfig::default();
        let policy = config.to_execution_policy(&CapsuleToml::default());

        assert_eq!(policy.name, "default");
        assert_eq!(policy.compute, Compute::Medium);
        assert_eq!(policy.ram, None);
        assert_eq!(policy.timeout, None);
        assert_eq!(policy.max_retries, 0);
    }

    #[test]
    fn test_to_execution_policy_without_tasks_section_uses_builtin_lists() {
        let capsule_toml = CapsuleToml {
            workflow: None,
            tasks: None,
        };
        let policy = TaskConfig::default().to_execution_policy(&capsule_toml);

        assert!(policy.allowed_files.is_empty());
        assert_eq!(policy.allowed_hosts, vec!["*".to_string()]);
        assert!(policy.env_variables.is_empty());
    }

    #[test]
    fn test_to_execution_policy_with_values() {
        let config = TaskConfig {
            name: Some("test_task".to_string()),
            compute: Some("HIGH".to_string()),
            ram: Some("2GB".to_string()),
            timeout: Some("30s".to_string()),
            max_retries: Some(3),
            allowed_files: Some(vec!["./data".to_string()]),
            allowed_hosts: Some(vec!["https://example.com".to_string()]),
            env_variables: Some(vec!["FOO".to_string()]),
        };

        let policy = config.to_execution_policy(&CapsuleToml::default());

        assert_eq!(policy.name, "test_task");
        assert_eq!(policy.compute, Compute::High);
        assert_eq!(policy.ram, Some(2 * 1024 * 1024 * 1024));
        assert_eq!(policy.timeout, Some("30s".to_string()));
        assert_eq!(policy.max_retries, 3);
        assert_eq!(policy.allowed_files, vec!["./data".to_string()]);
        assert_eq!(
            policy.allowed_hosts,
            vec!["https://example.com".to_string()]
        );
        assert_eq!(policy.env_variables, vec!["FOO".to_string()]);
    }

    #[test]
    fn test_to_execution_policy_compute_variants() {
        assert_eq!(compute_of("LOW"), Compute::Low);
        assert_eq!(compute_of("MEDIUM"), Compute::Medium);
        assert_eq!(compute_of("HIGH"), Compute::High);
        assert_eq!(compute_of("INVALID"), Compute::Medium);
    }

    #[test]
    fn test_to_execution_policy_compute_case_insensitive_and_custom() {
        assert_eq!(compute_of("low"), Compute::Low);
        assert_eq!(compute_of("High"), Compute::High);
        assert_eq!(compute_of("512"), Compute::Custom(512));
    }

    #[test]
    fn test_to_execution_policy_uses_capsule_toml_defaults() {
        let capsule_toml = CapsuleToml {
            workflow: None,
            tasks: Some(DefaultPolicy {
                default_compute: Some(Compute::High),
                default_ram: Some("1GB".to_string()),
                default_timeout: Some("60s".to_string()),
                default_max_retries: Some(5),
                default_allowed_files: Some(vec!["./default".to_string()]),
                default_allowed_hosts: Some(vec!["https://default.com".to_string()]),
                default_env_variables: Some(vec!["FOO".to_string()]),
            }),
        };

        let config = TaskConfig::default();
        let policy = config.to_execution_policy(&capsule_toml);

        assert_eq!(policy.compute, Compute::High);
        assert_eq!(policy.ram, Some(1024 * 1024 * 1024));
        assert_eq!(policy.timeout, Some("60s".to_string()));
        assert_eq!(policy.max_retries, 5);
        assert_eq!(policy.allowed_files, vec!["./default".to_string()]);
        assert_eq!(
            policy.allowed_hosts,
            vec!["https://default.com".to_string()]
        );
        assert_eq!(policy.env_variables, vec!["FOO".to_string()]);
    }

    #[test]
    fn test_invalid_task_ram_falls_back_to_capsule_toml_default() {
        let capsule_toml = CapsuleToml {
            workflow: None,
            tasks: Some(DefaultPolicy {
                default_compute: None,
                default_ram: Some("1GB".to_string()),
                default_timeout: None,
                default_max_retries: None,
                default_allowed_files: None,
                default_allowed_hosts: None,
                default_env_variables: None,
            }),
        };

        let config = TaskConfig {
            ram: Some("lots of ram".to_string()),
            ..Default::default()
        };
        let policy = config.to_execution_policy(&capsule_toml);

        assert_eq!(policy.ram, Some(1024 * 1024 * 1024));
        assert_eq!(policy.allowed_hosts, vec!["*".to_string()]);
    }

    #[test]
    fn test_task_config_overrides_capsule_toml_defaults() {
        let capsule_toml = CapsuleToml {
            workflow: None,
            tasks: Some(DefaultPolicy {
                default_compute: Some(Compute::Low),
                default_ram: Some("512MB".to_string()),
                default_timeout: Some("30s".to_string()),
                default_max_retries: Some(2),
                default_allowed_files: Some(vec!["./default.txt".to_string()]),
                default_allowed_hosts: Some(vec!["*".to_string()]),
                default_env_variables: Some(vec!["FOO".to_string()]),
            }),
        };

        let config = TaskConfig {
            name: Some("override_task".to_string()),
            compute: Some("HIGH".to_string()),
            ram: Some("4GB".to_string()),
            timeout: Some("120s".to_string()),
            max_retries: Some(10),
            allowed_files: Some(vec!["./custom".to_string()]),
            allowed_hosts: Some(vec!["https://custom.com".to_string()]),
            env_variables: Some(vec!["BAR".to_string()]),
        };

        let policy = config.to_execution_policy(&capsule_toml);

        assert_eq!(policy.name, "override_task");
        assert_eq!(policy.compute, Compute::High);
        assert_eq!(policy.ram, Some(4 * 1024 * 1024 * 1024));
        assert_eq!(policy.timeout, Some("120s".to_string()));
        assert_eq!(policy.max_retries, 10);
        assert_eq!(policy.allowed_files, vec!["./custom".to_string()]);
        assert_eq!(policy.allowed_hosts, vec!["https://custom.com".to_string()]);
        assert_eq!(policy.env_variables, vec!["BAR".to_string()]);
    }

    #[test]
    fn test_partial_task_config_with_capsule_toml_defaults() {
        let capsule_toml = CapsuleToml {
            workflow: None,
            tasks: Some(DefaultPolicy {
                default_compute: Some(Compute::Medium),
                default_ram: Some("2GB".to_string()),
                default_timeout: Some("45s".to_string()),
                default_max_retries: Some(3),
                default_allowed_files: Some(vec!["./default".to_string()]),
                default_allowed_hosts: Some(vec!["*".to_string()]),
                default_env_variables: Some(vec!["FOO".to_string()]),
            }),
        };

        let config = TaskConfig {
            name: Some("partial_task".to_string()),
            compute: Some("LOW".to_string()),
            ram: None,
            timeout: None,
            max_retries: Some(1),
            allowed_files: None,
            allowed_hosts: None,
            env_variables: None,
        };

        let policy = config.to_execution_policy(&capsule_toml);

        assert_eq!(policy.name, "partial_task");
        assert_eq!(policy.compute, Compute::Low);
        assert_eq!(policy.ram, Some(2 * 1024 * 1024 * 1024));
        assert_eq!(policy.timeout, Some("45s".to_string()));
        assert_eq!(policy.max_retries, 1);
        assert_eq!(policy.allowed_files, vec!["./default".to_string()]);
        assert_eq!(policy.allowed_hosts, vec!["*".to_string()]);
        assert_eq!(policy.env_variables, vec!["FOO".to_string()]);
    }
}