Skip to main content

capsule_core/wasm/
state.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use wasmtime::component::{ResourceTable, bindgen};
5use wasmtime::{ResourceLimiter, StoreLimits};
6use wasmtime_wasi::{WasiCtx, WasiView};
7use wasmtime_wasi_http::{WasiHttpCtx, WasiHttpView};
8
9use crate::wasm::commands::create::CreateInstance;
10use crate::wasm::commands::run::RunInstance;
11use crate::wasm::runtime::Runtime;
12use crate::wasm::utilities::task_config::{TaskConfig, TaskResult};
13
14use capsule::host::api::{Host, HttpError, HttpResponse, TaskError};
15
16bindgen!({
17    path: "./capsule-wit",
18    world: "capsule-agent",
19    async: true,
20});
21
22pub use capsule::host::api as host_api;
23
24pub struct State {
25    pub ctx: WasiCtx,
26    pub http_ctx: WasiHttpCtx,
27    pub table: ResourceTable,
28    pub limits: StoreLimits,
29    pub runtime: Option<Arc<Runtime>>,
30}
31
32impl WasiView for State {
33    fn ctx(&mut self) -> &mut WasiCtx {
34        &mut self.ctx
35    }
36    fn table(&mut self) -> &mut ResourceTable {
37        &mut self.table
38    }
39}
40
41impl WasiHttpView for State {
42    fn ctx(&mut self) -> &mut WasiHttpCtx {
43        &mut self.http_ctx
44    }
45    fn table(&mut self) -> &mut ResourceTable {
46        &mut self.table
47    }
48}
49
50impl Host for State {
51    async fn schedule_task(
52        &mut self,
53        name: String,
54        args: String,
55        config: String,
56    ) -> Result<String, TaskError> {
57        let runtime = match &self.runtime {
58            Some(r) => Arc::clone(r),
59            None => {
60                return Err(TaskError::InternalError(
61                    "No runtime available for recursive task execution".to_string(),
62                ));
63            }
64        };
65
66        let task_config: TaskConfig = serde_json::from_str(&config).unwrap_or_default();
67        let policy = task_config.to_execution_policy(&runtime.capsule_toml);
68        let max_retries = policy.max_retries;
69
70        let mut last_error: Option<String> = None;
71
72        for attempt in 0..=max_retries {
73            let create_cmd = CreateInstance::new(policy.clone(), vec![]).task_name(&name);
74
75            let (store, instance, task_id) = match runtime.execute(create_cmd).await {
76                Ok(result) => result,
77                Err(e) => {
78                    runtime
79                        .task_reporter
80                        .lock()
81                        .await
82                        .task_failed(&name, &e.to_string());
83                    last_error = Some(format!("Failed to create instance: {}", e));
84                    continue;
85                }
86            };
87
88            let args_json = format!(
89                r#"{{"task_name": "{}", "args": {}, "kwargs": {{}}}}"#,
90                name, args
91            );
92
93            runtime
94                .task_reporter
95                .lock()
96                .await
97                .task_running(&name, &task_id);
98
99            let start_time = std::time::Instant::now();
100
101            let run_cmd = RunInstance::new(task_id, policy.clone(), store, instance, args_json);
102
103            match runtime.execute(run_cmd).await {
104                Ok(result) => {
105                    if result.is_empty() {
106                        last_error = Some("Task failed".to_string());
107                        if attempt < max_retries {
108                            continue;
109                        }
110                    } else {
111                        match serde_json::from_str::<TaskResult>(&result) {
112                            Ok(task_result) if task_result.success => {
113                                let elapsed = start_time.elapsed();
114                                runtime
115                                    .task_reporter
116                                    .lock()
117                                    .await
118                                    .task_completed_with_time(&name, elapsed);
119
120                                return Ok(result);
121                            }
122                            Ok(_) => {
123                                if attempt < max_retries {
124                                    continue;
125                                }
126
127                                return Ok(result);
128                            }
129                            Err(_) => {
130                                if attempt < max_retries {
131                                    continue;
132                                }
133                            }
134                        }
135                    }
136                }
137                Err(_) => {
138                    if attempt < max_retries {
139                        continue;
140                    }
141                }
142            }
143        }
144
145        Ok(last_error.unwrap_or_else(|| "Unknown error".to_string()))
146    }
147
148    async fn http_request(
149        &mut self,
150        method: String,
151        url: String,
152        headers: Vec<(String, String)>,
153        body: Option<String>,
154    ) -> Result<HttpResponse, HttpError> {
155        let client = reqwest::Client::new();
156
157        let mut request_builder = match method.to_uppercase().as_str() {
158            "GET" => client.get(&url),
159            "POST" => client.post(&url),
160            "PUT" => client.put(&url),
161            "DELETE" => client.delete(&url),
162            "PATCH" => client.patch(&url),
163            "HEAD" => client.head(&url),
164            _ => {
165                return Err(HttpError::InvalidUrl(format!(
166                    "Unsupported method: {}",
167                    method
168                )));
169            }
170        };
171
172        for (key, value) in headers {
173            request_builder = request_builder.header(key, value);
174        }
175
176        if let Some(body_content) = body {
177            request_builder = request_builder.body(body_content);
178        }
179
180        let response = request_builder
181            .send()
182            .await
183            .map_err(|e| HttpError::NetworkError(e.to_string()))?;
184
185        let status = response.status().as_u16();
186        let response_headers: Vec<(String, String)> = response
187            .headers()
188            .iter()
189            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
190            .collect();
191
192        let body_text = response
193            .text()
194            .await
195            .map_err(|e| HttpError::NetworkError(e.to_string()))?;
196
197        Ok(HttpResponse {
198            status,
199            headers: response_headers,
200            body: body_text,
201        })
202    }
203}
204
205impl ResourceLimiter for State {
206    fn memory_growing(
207        &mut self,
208        current: usize,
209        desired: usize,
210        maximum: Option<usize>,
211    ) -> Result<bool> {
212        self.limits.memory_growing(current, desired, maximum)
213    }
214
215    fn table_growing(
216        &mut self,
217        current: usize,
218        desired: usize,
219        maximum: Option<usize>,
220    ) -> Result<bool> {
221        self.limits.table_growing(current, desired, maximum)
222    }
223}