Skip to main content

capsule_core/wasm/
state.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use hyper::Request;
5use wasmtime::component::{ResourceTable, bindgen};
6use wasmtime::{ResourceLimiter, StoreLimits};
7use wasmtime_wasi::{WasiCtx, WasiView};
8use wasmtime_wasi_http::bindings::http::types::ErrorCode;
9use wasmtime_wasi_http::body::HyperOutgoingBody;
10use wasmtime_wasi_http::types::{
11    HostFutureIncomingResponse, OutgoingRequestConfig, default_send_request,
12};
13use wasmtime_wasi_http::{HttpResult, WasiHttpCtx, WasiHttpView};
14
15use crate::wasm::commands::create::CreateInstance;
16use crate::wasm::commands::run::RunInstance;
17use crate::wasm::execution_policy::ExecutionPolicy;
18use crate::wasm::runtime::Runtime;
19use crate::wasm::utilities::host_validator::is_host_allowed;
20use crate::wasm::utilities::task_config::{TaskConfig, TaskResult};
21
22use capsule::host::api::{Host, HttpError, HttpResponse, TaskError};
23
// Generates wasmtime component bindings for the `capsule-agent` world defined by the WIT
// package under `./capsule-wit`. `async: true` requests async host functions, which is why
// the `Host` impl in this file uses `async fn`.
// NOTE(review): the `capsule::host::api` items imported/re-exported in this file presumably
// come from this expansion — confirm against the WIT package.
bindgen!({
    path: "./capsule-wit",
    world: "capsule-agent",
    async: true,
});
29
30pub use capsule::host::api as host_api;
31
/// Host-side state for a capsule component.
///
/// Implements the WASI, `wasi:http`, capsule host API (`Host`) and `ResourceLimiter` traits
/// in this file. Fields are dropped in declaration order.
pub struct State {
    /// WASI context handed out by the `WasiView` impl.
    pub ctx: WasiCtx,
    /// `wasi:http` context handed out by the `WasiHttpView` impl.
    pub http_ctx: WasiHttpCtx,
    /// Resource table shared by both the `WasiView` and `WasiHttpView` impls.
    pub table: ResourceTable,
    /// Store limits; the `ResourceLimiter` impl delegates growth checks to it.
    pub limits: StoreLimits,
    /// Runtime used by `schedule_task` to launch nested task instances.
    /// When `None`, `schedule_task` returns `TaskError::InternalError`.
    pub runtime: Option<Arc<Runtime>>,
    /// Execution policy; its `allowed_hosts` gates outgoing HTTP requests.
    pub policy: ExecutionPolicy,
}
40
41impl WasiView for State {
42    fn ctx(&mut self) -> &mut WasiCtx {
43        &mut self.ctx
44    }
45    fn table(&mut self) -> &mut ResourceTable {
46        &mut self.table
47    }
48}
49
50impl WasiHttpView for State {
51    fn ctx(&mut self) -> &mut WasiHttpCtx {
52        &mut self.http_ctx
53    }
54
55    fn table(&mut self) -> &mut ResourceTable {
56        &mut self.table
57    }
58
59    fn send_request(
60        &mut self,
61        request: Request<HyperOutgoingBody>,
62        config: OutgoingRequestConfig,
63    ) -> HttpResult<HostFutureIncomingResponse> {
64        let host = request.uri().host().unwrap_or("unknown");
65        let allowed = is_host_allowed(host, &self.policy.allowed_hosts);
66
67        if !allowed {
68            return Err(ErrorCode::HttpRequestDenied.into());
69        }
70
71        Ok(default_send_request(request, config))
72    }
73}
74
75impl Host for State {
76    async fn schedule_task(
77        &mut self,
78        name: String,
79        args: String,
80        config: String,
81    ) -> Result<String, TaskError> {
82        let runtime = match &self.runtime {
83            Some(r) => Arc::clone(r),
84            None => {
85                return Err(TaskError::InternalError(
86                    "No runtime available for recursive task execution".to_string(),
87                ));
88            }
89        };
90
91        let task_config: TaskConfig = serde_json::from_str(&config).unwrap_or_default();
92        let policy = task_config.to_execution_policy(&runtime.capsule_toml);
93        let max_retries = policy.max_retries;
94
95        let mut last_error: Option<String> = None;
96
97        for attempt in 0..=max_retries {
98            let create_cmd = CreateInstance::new(policy.clone(), vec![]).task_name(&name);
99
100            let (store, instance, task_id) = match runtime.execute(create_cmd).await {
101                Ok(result) => result,
102                Err(e) => {
103                    runtime
104                        .task_reporter
105                        .lock()
106                        .await
107                        .task_failed(&name, &e.to_string());
108                    last_error = Some(format!("Failed to create instance: {}", e));
109                    continue;
110                }
111            };
112
113            let args_json = format!(
114                r#"{{"task_name": "{}", "args": {}, "kwargs": {{}}}}"#,
115                name, args
116            );
117
118            runtime
119                .task_reporter
120                .lock()
121                .await
122                .task_running(&name, &task_id);
123
124            let start_time = std::time::Instant::now();
125
126            let run_cmd = RunInstance::new(task_id, policy.clone(), store, instance, args_json);
127
128            match runtime.execute(run_cmd).await {
129                Ok(result) => {
130                    if result.is_empty() {
131                        last_error = Some("Task failed".to_string());
132                        if attempt < max_retries {
133                            continue;
134                        }
135                    } else {
136                        match serde_json::from_str::<TaskResult>(&result) {
137                            Ok(task_result) if task_result.success => {
138                                let elapsed = start_time.elapsed();
139                                runtime
140                                    .task_reporter
141                                    .lock()
142                                    .await
143                                    .task_completed_with_time(&name, elapsed);
144
145                                return Ok(result);
146                            }
147                            Ok(_) => {
148                                if attempt < max_retries {
149                                    continue;
150                                }
151
152                                return Ok(result);
153                            }
154                            Err(_) => {
155                                if attempt < max_retries {
156                                    continue;
157                                }
158                            }
159                        }
160                    }
161                }
162                Err(_) => {
163                    if attempt < max_retries {
164                        continue;
165                    }
166                }
167            }
168        }
169
170        Ok(last_error.unwrap_or_else(|| "Unknown error".to_string()))
171    }
172
173    async fn http_request(
174        &mut self,
175        method: String,
176        url: String,
177        headers: Vec<(String, String)>,
178        body: Option<String>,
179    ) -> Result<HttpResponse, HttpError> {
180        let allowed = is_host_allowed(&url, &self.policy.allowed_hosts);
181
182        if !allowed {
183            return Err(HttpError::InvalidUrl("Host not allowed".to_string()));
184        }
185
186        let client = reqwest::Client::new();
187
188        let mut request_builder = match method.to_uppercase().as_str() {
189            "GET" => client.get(&url),
190            "POST" => client.post(&url),
191            "PUT" => client.put(&url),
192            "DELETE" => client.delete(&url),
193            "PATCH" => client.patch(&url),
194            "HEAD" => client.head(&url),
195            _ => {
196                return Err(HttpError::InvalidUrl(format!(
197                    "Unsupported method: {}",
198                    method
199                )));
200            }
201        };
202
203        for (key, value) in headers {
204            request_builder = request_builder.header(key, value);
205        }
206
207        if let Some(body_content) = body {
208            request_builder = request_builder.body(body_content);
209        }
210
211        let response = request_builder
212            .send()
213            .await
214            .map_err(|e| HttpError::NetworkError(e.to_string()))?;
215
216        let status = response.status().as_u16();
217        let response_headers: Vec<(String, String)> = response
218            .headers()
219            .iter()
220            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
221            .collect();
222
223        let body_text = response
224            .text()
225            .await
226            .map_err(|e| HttpError::NetworkError(e.to_string()))?;
227
228        Ok(HttpResponse {
229            status,
230            headers: response_headers,
231            body: body_text,
232        })
233    }
234}
235
236impl ResourceLimiter for State {
237    fn memory_growing(
238        &mut self,
239        current: usize,
240        desired: usize,
241        maximum: Option<usize>,
242    ) -> Result<bool> {
243        self.limits.memory_growing(current, desired, maximum)
244    }
245
246    fn table_growing(
247        &mut self,
248        current: usize,
249        desired: usize,
250        maximum: Option<usize>,
251    ) -> Result<bool> {
252        self.limits.table_growing(current, desired, maximum)
253    }
254}