1use crate::error::WasmError;
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::collections::HashSet;
7use std::fs;
8use std::path::Path;
9use wasmtime::{Engine, Module};
10
11#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
13pub enum WasmCapability {
14 WasiFs,
16 WasiEnv,
18 WasiArgs,
20 WasiStdio,
22 WasiNet,
24 HostFunctions,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ModuleMetadata {
31 pub hash: String,
33 pub size: usize,
35 pub capabilities: HashSet<WasmCapability>,
37 pub exports: Vec<String>,
39 pub imports: Vec<WasmImport>,
41 pub is_wasi: bool,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct WasmImport {
48 pub module: String,
50 pub name: String,
52}
53
54#[derive(Debug, Clone)]
56pub struct WasmModule {
57 pub bytes: Vec<u8>,
59 pub metadata: ModuleMetadata,
61 compiled: Option<Module>,
63}
64
65impl WasmModule {
66 pub fn from_bytes(bytes: Vec<u8>) -> Result<Self, WasmError> {
68 Self::validate_basic_format(&bytes)?;
70
71 let metadata = Self::extract_metadata(&bytes)?;
72 Self::validate_module(&bytes, &metadata)?;
73
74 Ok(WasmModule {
75 bytes,
76 metadata,
77 compiled: None,
78 })
79 }
80
81 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, WasmError> {
83 let bytes = fs::read(path)?;
84 Self::from_bytes(bytes)
85 }
86
87 pub fn get_compiled(&mut self, engine: &Engine) -> Result<&Module, WasmError> {
89 if self.compiled.is_none() {
90 let module = Module::from_binary(engine, &self.bytes)?;
91 self.compiled = Some(module);
92 }
93 Ok(self.compiled.as_ref().unwrap())
94 }
95
96 pub fn hash(&self) -> &str {
98 &self.metadata.hash
99 }
100
101 pub fn requires_capability(&self, capability: &WasmCapability) -> bool {
103 self.metadata.capabilities.contains(capability)
104 }
105
106 pub fn is_wasi(&self) -> bool {
108 self.metadata.is_wasi
109 }
110
111 fn extract_metadata(bytes: &[u8]) -> Result<ModuleMetadata, WasmError> {
113 let mut hasher = Sha256::new();
115 hasher.update(bytes);
116 let hash = format!("{:x}", hasher.finalize());
117
118 let engine = Engine::default();
120 let module = Module::from_binary(&engine, bytes)
121 .map_err(|e| WasmError::ModuleLoad(e.to_string()))?;
122
123 let mut capabilities = HashSet::new();
124 let mut exports = Vec::new();
125 let mut imports = Vec::new();
126 let mut is_wasi = false;
127
128 for export in module.exports() {
130 exports.push(export.name().to_string());
131 }
132
133 for import in module.imports() {
135 let import_info = WasmImport {
136 module: import.module().to_string(),
137 name: import.name().to_string(),
138 };
139
140 if import.module().starts_with("wasi_") {
142 is_wasi = true;
143
144 match import.name() {
146 name if name.starts_with("fd_") => {
147 capabilities.insert(WasmCapability::WasiFs);
148 capabilities.insert(WasmCapability::WasiStdio);
149 }
150 name if name.starts_with("environ_") => {
151 capabilities.insert(WasmCapability::WasiEnv);
152 }
153 name if name.starts_with("args_") => {
154 capabilities.insert(WasmCapability::WasiArgs);
155 }
156 name if name.starts_with("sock_") => {
157 capabilities.insert(WasmCapability::WasiNet);
158 }
159 _ => {}
160 }
161 } else if import.module() != "env" {
162 capabilities.insert(WasmCapability::HostFunctions);
164 }
165
166 imports.push(import_info);
167 }
168
169 if is_wasi {
171 capabilities.insert(WasmCapability::WasiStdio);
172 }
173
174 Ok(ModuleMetadata {
175 hash,
176 size: bytes.len(),
177 capabilities,
178 exports,
179 imports,
180 is_wasi,
181 })
182 }
183
184 fn validate_basic_format(bytes: &[u8]) -> Result<(), WasmError> {
186 if bytes.len() < 8 {
188 return Err(WasmError::InvalidFormat(
189 "WASM module too small (minimum 8 bytes)".to_string()
190 ));
191 }
192
193 if &bytes[0..4] != b"\0asm" {
195 return Err(WasmError::InvalidFormat(
196 "Invalid WASM magic number".to_string()
197 ));
198 }
199
200 const MAX_MODULE_SIZE: usize = 64 * 1024 * 1024;
202 if bytes.len() > MAX_MODULE_SIZE {
203 return Err(WasmError::ModuleValidation(format!(
204 "Module too large: {} bytes (max: {} bytes)",
205 bytes.len(), MAX_MODULE_SIZE
206 )));
207 }
208
209 Ok(())
210 }
211
212 fn validate_module(_bytes: &[u8], metadata: &ModuleMetadata) -> Result<(), WasmError> {
214 if metadata.capabilities.contains(&WasmCapability::WasiNet) {
216 return Err(WasmError::UnsupportedCapability(
217 "WASI networking is not supported".to_string()
218 ));
219 }
220
221 if metadata.is_wasi && !metadata.exports.contains(&"_start".to_string()) {
223 return Err(WasmError::ModuleValidation(
224 "WASI module must export '_start' function".to_string()
225 ));
226 }
227
228 Ok(())
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::test_modules::{
        minimal_wasm, simple_function_wasm, wasi_hello_wasm, INVALID_MAGIC_WASM,
    };

    /// Loads `bytes`, failing the test if the loader rejects them.
    fn load(bytes: Vec<u8>) -> WasmModule {
        WasmModule::from_bytes(bytes).expect("fixture module should load")
    }

    /// Loads `bytes`, failing the test if the loader accepts them.
    fn load_err(bytes: Vec<u8>) -> WasmError {
        WasmModule::from_bytes(bytes).expect_err("loader should reject these bytes")
    }

    /// Whether `module` imports `name` from the `namespace` module.
    fn has_import(module: &WasmModule, namespace: &str, name: &str) -> bool {
        module
            .metadata
            .imports
            .iter()
            .any(|import| import.module == namespace && import.name == name)
    }

    #[test]
    fn test_minimal_wasm_module() {
        let bytes = minimal_wasm().to_vec();
        let expected_size = bytes.len();
        let module = load(bytes);

        assert_eq!(module.metadata.size, expected_size);
        assert!(!module.is_wasi());
        assert!(module.metadata.exports.is_empty());
        assert!(module.metadata.imports.is_empty());
    }

    #[test]
    fn test_simple_function_wasm() {
        let module = load(simple_function_wasm().to_vec());

        assert!(!module.is_wasi());
        assert!(module.metadata.exports.iter().any(|export| export == "add"));
        assert!(!module.requires_capability(&WasmCapability::WasiStdio));
    }

    #[test]
    fn test_wasi_module_detection() {
        let module = load(wasi_hello_wasm().to_vec());
        let exports = &module.metadata.exports;

        assert!(module.is_wasi());
        assert!(exports.iter().any(|export| export == "_start"));
        assert!(exports.iter().any(|export| export == "memory"));
        assert!(module.requires_capability(&WasmCapability::WasiStdio));

        assert!(has_import(&module, "wasi_snapshot_preview1", "fd_write"));
        assert!(has_import(&module, "wasi_snapshot_preview1", "environ_get"));

        assert!(module.requires_capability(&WasmCapability::WasiEnv));
    }

    #[test]
    fn test_invalid_wasm_magic() {
        let err = load_err(INVALID_MAGIC_WASM.to_vec());
        assert!(matches!(err, WasmError::InvalidFormat(_)));
    }

    #[test]
    fn test_empty_bytes() {
        let err = load_err(Vec::new());
        assert!(matches!(err, WasmError::InvalidFormat(_)));
    }

    #[test]
    fn test_module_too_large() {
        const PAYLOAD_LEN: usize = 65 * 1024 * 1024;

        let mut oversized = Vec::with_capacity(8 + PAYLOAD_LEN);
        oversized.extend_from_slice(b"\0asm"); // magic number
        oversized.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]); // version 1
        oversized.resize(8 + PAYLOAD_LEN, 0x00);

        let err = load_err(oversized);
        assert!(matches!(err, WasmError::ModuleValidation(_)));
    }

    #[test]
    fn test_hash_calculation() {
        let minimal = load(minimal_wasm().to_vec());
        let minimal_again = load(minimal_wasm().to_vec());
        let simple = load(simple_function_wasm().to_vec());

        // Different bytes hash differently; identical bytes hash identically.
        assert_ne!(minimal.hash(), simple.hash());
        assert_eq!(minimal.hash(), minimal_again.hash());
    }

    #[test]
    fn test_capability_detection() {
        let wasi_module = load(wasi_hello_wasm().to_vec());
        let expectations = [
            (WasmCapability::WasiStdio, true),
            (WasmCapability::WasiEnv, true),
            (WasmCapability::WasiNet, false),
        ];
        for (capability, expected) in expectations.iter() {
            assert_eq!(
                wasi_module.requires_capability(capability),
                *expected,
                "unexpected result for {:?}",
                capability
            );
        }

        let simple_module = load(simple_function_wasm().to_vec());
        assert!(!simple_module.requires_capability(&WasmCapability::WasiStdio));
    }

    #[test]
    fn test_compiled_module_caching() {
        let engine = wasmtime::Engine::default();
        let mut module = load(minimal_wasm().to_vec());

        module.get_compiled(&engine).unwrap();
        assert!(module.compiled.is_some());

        // A second request must still leave the cache populated.
        module.get_compiled(&engine).unwrap();
        assert!(module.compiled.is_some());
    }

    #[test]
    fn test_from_file_nonexistent() {
        let err = WasmModule::from_file("/nonexistent/path/module.wasm")
            .expect_err("reading a missing file should fail");
        assert!(matches!(err, WasmError::Io(_)));
    }
}