Skip to main content

// capsule_core/wasm/utilities/task_config.rs

1use serde::{Deserialize, Serialize};
2
3use crate::config::manifest::CapsuleToml;
4use crate::wasm::execution_policy::{Compute, ExecutionPolicy};
5
6#[derive(Serialize, Deserialize)]
7pub struct TaskResult {
8    pub success: bool,
9    pub result: Option<serde_json::Value>,
10    pub error: Option<TaskError>,
11    pub execution: TaskExecution,
12}
13
14#[derive(Serialize, Deserialize)]
15pub struct TaskError {
16    pub error_type: String,
17    pub message: String,
18}
19
20#[derive(Serialize, Deserialize)]
21pub struct TaskExecution {
22    pub task_name: String,
23    pub duration_ms: u64,
24    pub retries: u64,
25    pub fuel_consumed: u64,
26}
27
28#[derive(Debug, Deserialize, Default)]
29pub struct TaskConfig {
30    name: Option<String>,
31    compute: Option<String>,
32    ram: Option<String>,
33    timeout: Option<String>,
34
35    #[serde(alias = "maxRetries")]
36    max_retries: Option<u64>,
37
38    #[serde(alias = "allowedFiles")]
39    allowed_files: Option<Vec<String>>,
40
41    #[serde(alias = "envVariables")]
42    env_variables: Option<Vec<String>>,
43}
44
45impl TaskConfig {
46    pub fn to_execution_policy(&self, capsule_toml: &CapsuleToml) -> ExecutionPolicy {
47        let default_policy = capsule_toml.tasks.as_ref();
48
49        let compute = self
50            .compute
51            .as_ref()
52            .map(|c| match c.to_uppercase().as_str() {
53                "LOW" => Compute::Low,
54                "MEDIUM" => Compute::Medium,
55                "HIGH" => Compute::High,
56                _ => c
57                    .parse::<u64>()
58                    .map(Compute::Custom)
59                    .unwrap_or(Compute::Medium),
60            })
61            .or_else(|| default_policy.and_then(|p| p.default_compute.clone()));
62
63        let ram = self
64            .ram
65            .as_ref()
66            .and_then(|r| Self::parse_ram_string(r))
67            .or_else(|| {
68                default_policy
69                    .and_then(|p| p.default_ram.as_ref())
70                    .and_then(|r| Self::parse_ram_string(r))
71            });
72
73        let timeout = self
74            .timeout
75            .clone()
76            .or_else(|| default_policy.and_then(|p| p.default_timeout.clone()));
77
78        let max_retries = self
79            .max_retries
80            .or_else(|| default_policy.and_then(|p| p.default_max_retries));
81
82        let allowed_files = self
83            .allowed_files
84            .clone()
85            .or_else(|| default_policy.and_then(|p| p.default_allowed_files.clone()))
86            .unwrap_or_default();
87
88        let env_variables = self
89            .env_variables
90            .clone()
91            .or_else(|| default_policy.and_then(|p| p.default_env_variables.clone()))
92            .unwrap_or_default();
93
94        ExecutionPolicy::new()
95            .name(self.name.clone())
96            .compute(compute)
97            .ram(ram)
98            .timeout(timeout)
99            .max_retries(max_retries)
100            .allowed_files(allowed_files)
101            .env_variables(env_variables)
102    }
103
/// Parses a memory size such as `"2GB"`, `"512 mb"` or `"1024"` into bytes.
///
/// Surrounding whitespace and letter case are ignored. A `GB`, `MB` or `KB`
/// suffix multiplies the number by the matching power of 1024; a value with
/// no suffix is taken as a byte count.
///
/// Returns `None` when the numeric part is missing or is not a valid `u64`,
/// or when the resulting byte count does not fit in a `u64`.
pub fn parse_ram_string(s: &str) -> Option<u64> {
    const KIB: u64 = 1024;
    const UNITS: [(&str, u64); 3] = [
        ("GB", KIB * KIB * KIB),
        ("MB", KIB * KIB),
        ("KB", KIB),
    ];

    let normalized = s.trim().to_uppercase();

    // `strip_suffix` removes the unit exactly once, so malformed input such
    // as "2GBGB" is rejected instead of being read as 2 GB.
    let (digits, multiplier) = UNITS
        .iter()
        .find_map(|&(suffix, factor)| {
            normalized
                .strip_suffix(suffix)
                .map(|rest| (rest.trim(), factor))
        })
        .unwrap_or((normalized.as_str(), 1));

    // `checked_mul` turns an overflowing size into `None` rather than a
    // debug-build panic or a silently wrapped value in release builds.
    digits.parse::<u64>().ok()?.checked_mul(multiplier)
}
128}
129
130#[cfg(test)]
131mod tests {
132    use super::*;
133    use crate::config::manifest::DefaultPolicy;
134
/// Valid sizes in every supported unit, with and without a space before the
/// unit and in both letter cases, plus bare byte counts.
#[test]
fn test_parse_ram_string() {
    const KIB: u64 = 1024;
    const MIB: u64 = 1024 * KIB;
    const GIB: u64 = 1024 * MIB;

    // (input, expected number of bytes)
    let cases = [
        ("2GB", 2 * GIB),
        ("1 GB", GIB),
        ("4gb", 4 * GIB),
        ("512MB", 512 * MIB),
        ("256 MB", 256 * MIB),
        ("128mb", 128 * MIB),
        ("1024KB", 1024 * KIB),
        ("512 KB", 512 * KIB),
        ("256kb", 256 * KIB),
        ("1024", 1024),
        ("512", 512),
    ];

    for &(input, expected) in &cases {
        assert_eq!(
            TaskConfig::parse_ram_string(input),
            Some(expected),
            "input: {:?}",
            input
        );
    }
}
170
/// Inputs without a usable number must yield `None`.
#[test]
fn test_parse_ram_string_invalid() {
    for &input in &["invalid", "", "GB"] {
        assert_eq!(
            TaskConfig::parse_ram_string(input),
            None,
            "input: {:?}",
            input
        );
    }
}
177
/// With an empty task config and an empty manifest, the policy carries the
/// builder's own defaults.
#[test]
fn test_to_execution_policy_default() {
    let manifest = CapsuleToml::default();
    let policy = TaskConfig::default().to_execution_policy(&manifest);

    assert_eq!(policy.name, "default");
    assert_eq!(policy.compute, Compute::Medium);
    assert_eq!(policy.ram, None);
    assert_eq!(policy.timeout, None);
    assert_eq!(policy.max_retries, 0);
}
189
/// Every field set on the task config must reach the resulting policy.
#[test]
fn test_to_execution_policy_with_values() {
    let config = TaskConfig {
        name: Some("test_task".to_string()),
        compute: Some("HIGH".to_string()),
        ram: Some("2GB".to_string()),
        timeout: Some("30s".to_string()),
        max_retries: Some(3),
        allowed_files: Some(vec!["./data".to_string()]),
        env_variables: Some(vec!["FOO".to_string()]),
    };

    let policy = config.to_execution_policy(&CapsuleToml::default());

    assert_eq!(policy.name, "test_task");
    assert_eq!(policy.compute, Compute::High);
    assert_eq!(policy.ram, Some(2 * 1024 * 1024 * 1024));
    assert_eq!(policy.timeout, Some("30s".to_string()));
    assert_eq!(policy.max_retries, 3);
    // The fixture sets both lists, so verify they are carried over as well.
    assert_eq!(policy.allowed_files, vec!["./data".to_string()]);
    assert_eq!(policy.env_variables, vec!["FOO".to_string()]);
}
210
/// Maps each accepted `compute` spelling to its expected tier, including
/// mixed-case names, a numeric custom amount and the fallback for text that
/// is neither a tier name nor a number.
#[test]
fn test_to_execution_policy_compute_variants() {
    let manifest = CapsuleToml::default();

    // (raw `compute` string, expected tier)
    let cases = [
        ("LOW", Compute::Low),
        ("MEDIUM", Compute::Medium),
        ("HIGH", Compute::High),
        // Tier names are matched case-insensitively.
        ("low", Compute::Low),
        ("High", Compute::High),
        // Any other text is tried as a number for a custom amount.
        ("2500", Compute::Custom(2500)),
        // Text that is not a number falls back to Medium.
        ("INVALID", Compute::Medium),
    ];

    for (raw, expected) in cases.iter() {
        let config = TaskConfig {
            compute: Some(raw.to_string()),
            ..Default::default()
        };
        let policy = config.to_execution_policy(&manifest);

        assert_eq!(policy.compute, *expected, "compute: {:?}", raw);
    }
}
249
/// An empty task config must inherit every value from the manifest's
/// `tasks` defaults.
#[test]
fn test_to_execution_policy_uses_capsule_toml_defaults() {
    let capsule_toml = CapsuleToml {
        workflow: None,
        tasks: Some(DefaultPolicy {
            default_compute: Some(Compute::High),
            default_ram: Some("1GB".to_string()),
            default_timeout: Some("60s".to_string()),
            default_max_retries: Some(5),
            default_allowed_files: Some(vec!["./default".to_string()]),
            default_env_variables: Some(vec!["FOO".to_string()]),
        }),
    };

    let config = TaskConfig::default();
    let policy = config.to_execution_policy(&capsule_toml);

    assert_eq!(policy.compute, Compute::High);
    assert_eq!(policy.ram, Some(1024 * 1024 * 1024));
    assert_eq!(policy.timeout, Some("60s".to_string()));
    assert_eq!(policy.max_retries, 5);
    assert_eq!(policy.allowed_files, vec!["./default".to_string()]);
    // The fixture also sets default env variables; check they are inherited.
    assert_eq!(policy.env_variables, vec!["FOO".to_string()]);
}
273
/// When both the task and the manifest set a value, the task's value wins.
#[test]
fn test_task_config_overrides_capsule_toml_defaults() {
    let manifest_defaults = DefaultPolicy {
        default_compute: Some(Compute::Low),
        default_ram: Some(String::from("512MB")),
        default_timeout: Some(String::from("30s")),
        default_max_retries: Some(2),
        default_allowed_files: Some(vec![String::from("./default.txt")]),
        default_env_variables: Some(vec![String::from("FOO")]),
    };
    let manifest = CapsuleToml {
        workflow: None,
        tasks: Some(manifest_defaults),
    };

    let task = TaskConfig {
        name: Some(String::from("override_task")),
        compute: Some(String::from("HIGH")),
        ram: Some(String::from("4GB")),
        timeout: Some(String::from("120s")),
        max_retries: Some(10),
        allowed_files: Some(vec![String::from("./custom")]),
        env_variables: Some(vec![String::from("BAR")]),
    };

    let policy = task.to_execution_policy(&manifest);

    assert_eq!(policy.name, "override_task");
    assert_eq!(policy.compute, Compute::High);
    assert_eq!(policy.ram, Some(4 * 1024 * 1024 * 1024));
    assert_eq!(policy.timeout, Some("120s".to_string()));
    assert_eq!(policy.max_retries, 10);
    assert_eq!(policy.allowed_files, vec!["./custom".to_string()]);
    assert_eq!(policy.env_variables, vec!["BAR".to_string()]);
}
308
/// Fields set on the task are used as-is; fields left unset are filled in
/// from the manifest defaults.
#[test]
fn test_partial_task_config_with_capsule_toml_defaults() {
    let manifest_defaults = DefaultPolicy {
        default_compute: Some(Compute::Medium),
        default_ram: Some(String::from("2GB")),
        default_timeout: Some(String::from("45s")),
        default_max_retries: Some(3),
        default_allowed_files: Some(vec![String::from("./default")]),
        default_env_variables: Some(vec![String::from("FOO")]),
    };
    let manifest = CapsuleToml {
        workflow: None,
        tasks: Some(manifest_defaults),
    };

    // Only name, compute and max_retries are set; the rest stay `None`.
    let task = TaskConfig {
        name: Some(String::from("partial_task")),
        compute: Some(String::from("LOW")),
        max_retries: Some(1),
        ..Default::default()
    };

    let policy = task.to_execution_policy(&manifest);

    // Taken from the task.
    assert_eq!(policy.name, "partial_task");
    assert_eq!(policy.compute, Compute::Low);
    assert_eq!(policy.max_retries, 1);

    // Inherited from the manifest.
    assert_eq!(policy.ram, Some(2 * 1024 * 1024 * 1024));
    assert_eq!(policy.timeout, Some("45s".to_string()));
    assert_eq!(policy.allowed_files, vec!["./default".to_string()]);
    assert_eq!(policy.env_variables, vec!["FOO".to_string()]);
}
343}