// smg_wasm/runtime.rs
1//! WASM Runtime
2//!
3//! Manages WASM component execution using wasmtime with async support.
4//! Provides a thread pool for concurrent WASM execution and metrics tracking.
5
6use std::{
7    num::NonZeroUsize,
8    sync::{
9        atomic::{AtomicU64, Ordering},
10        Arc,
11    },
12    time::Duration,
13};
14
15use lru::LruCache;
16use tokio::sync::oneshot;
17use tracing::{debug, error};
18use wasmtime::{
19    component::{Component, Linker, ResourceTable},
20    Config, Engine, InstanceAllocationStrategy, PoolingAllocationConfig, Store, StoreLimitsBuilder,
21};
22use wasmtime_wasi::WasiCtx;
23
/// Epoch increment interval in milliseconds.
///
/// Epochs are used for cooperative timeout enforcement in WASM execution:
/// a per-worker background task bumps the engine's epoch counter every
/// interval, and each `Store` sets a deadline measured in epoch ticks.
/// A smaller interval gives finer-grained timeout control but slightly
/// more wakeup overhead.
const EPOCH_INTERVAL_MS: u64 = 100;
28
29use crate::{
30    config::WasmRuntimeConfig,
31    errors::{Result, WasmError, WasmRuntimeError},
32    module::{MiddlewareAttachPoint, WasmModuleAttachPoint},
33    spec::Smg,
34    types::{WasiState, WasmComponentInput, WasmComponentOutput},
35};
36
/// Entry point for executing WASM components.
///
/// Owns a shared [`WasmThreadPool`] of dedicated worker threads and tracks
/// aggregate execution metrics with lock-free atomic counters (all updated
/// with `Ordering::Relaxed`, so values are approximate under concurrency).
pub struct WasmRuntime {
    config: WasmRuntimeConfig,
    thread_pool: Arc<WasmThreadPool>,
    // Metrics
    // Count of all executions, successful or not.
    total_executions: AtomicU64,
    successful_executions: AtomicU64,
    failed_executions: AtomicU64,
    // Sum of wall-clock execution times; divide by total_executions for an average.
    total_execution_time_ms: AtomicU64,
    // High-water mark of any single execution's wall-clock time.
    max_execution_time_ms: AtomicU64,
}
47
/// Fixed-size pool of OS threads, each running its own tokio runtime and
/// wasmtime `Engine`, consuming [`WasmTask`]s from a shared async channel.
pub struct WasmThreadPool {
    // Producer side used by `WasmRuntime::execute_component_async`.
    sender: async_channel::Sender<WasmTask>,
    // Kept so `Drop` can close the consumer side and unblock idle workers.
    receiver: async_channel::Receiver<WasmTask>,
    workers: Vec<std::thread::JoinHandle<()>>,
    // Metrics
    // NOTE(review): these counters are never incremented anywhere in this
    // file, so `get_metrics` currently always reports zeros. Wiring them up
    // would require sharing them with the workers (e.g. `Arc<AtomicU64>`
    // passed into `worker_loop`) — confirm whether that is intended.
    total_tasks: AtomicU64,
    completed_tasks: AtomicU64,
    failed_tasks: AtomicU64,
}
57
/// Unit of work sent from the runtime to a pool worker.
pub enum WasmTask {
    /// Execute a single WASM component call and report the result back
    /// through the `response` oneshot channel.
    ExecuteComponent {
        /// SHA256 hash of the WASM bytes, used as the cache key for compiled
        /// components. This avoids hashing the full `Vec<u8>` on every LRU lookup.
        sha256_hash: [u8; 32],
        /// WASM component bytes wrapped in Arc to avoid cloning the full bytes
        /// on every request. Only read on cache miss (first compilation).
        wasm_bytes: Arc<Vec<u8>>,
        /// Which middleware hook to invoke; also dictates the expected `input` variant.
        attach_point: WasmModuleAttachPoint,
        /// Request or response payload, matching `attach_point`.
        input: WasmComponentInput,
        /// Channel to deliver the execution result; the send result is ignored
        /// if the caller has already dropped the receiver.
        response: oneshot::Sender<Result<WasmComponentOutput>>,
    },
}
71
72impl WasmRuntime {
73    pub fn new(config: WasmRuntimeConfig) -> Self {
74        let thread_pool = Arc::new(WasmThreadPool::new(config.clone()));
75
76        Self {
77            config,
78            thread_pool,
79            total_executions: AtomicU64::new(0),
80            successful_executions: AtomicU64::new(0),
81            failed_executions: AtomicU64::new(0),
82            total_execution_time_ms: AtomicU64::new(0),
83            max_execution_time_ms: AtomicU64::new(0),
84        }
85    }
86
87    pub fn with_default_config() -> Self {
88        Self::new(WasmRuntimeConfig::default())
89    }
90
91    pub fn get_config(&self) -> &WasmRuntimeConfig {
92        &self.config
93    }
94
95    /// get available cpu count and max recommended cpu count
96    pub fn get_cpu_info() -> (usize, usize) {
97        let cpu_count = std::thread::available_parallelism()
98            .map(|n| n.get())
99            .unwrap_or(4);
100        let max_recommended = cpu_count.max(1);
101        (cpu_count, max_recommended)
102    }
103
104    /// get current thread pool status
105    pub fn get_thread_pool_info(&self) -> (usize, usize) {
106        let (_cpu_count, max_recommended) = Self::get_cpu_info();
107        let current_workers = self.thread_pool.workers.len();
108        (current_workers, max_recommended)
109    }
110
111    /// Execute WASM component using WASM interface based on attach_point
112    pub async fn execute_component_async(
113        &self,
114        sha256_hash: [u8; 32],
115        wasm_bytes: Arc<Vec<u8>>,
116        attach_point: WasmModuleAttachPoint,
117        input: WasmComponentInput,
118    ) -> Result<WasmComponentOutput> {
119        let start_time = std::time::Instant::now();
120        let (response_tx, response_rx) = oneshot::channel();
121
122        let task = WasmTask::ExecuteComponent {
123            sha256_hash,
124            wasm_bytes,
125            attach_point,
126            input,
127            response: response_tx,
128        };
129
130        self.thread_pool.sender.send(task).await.map_err(|e| {
131            WasmRuntimeError::CallFailed(format!("Failed to send task to thread pool: {e}"))
132        })?;
133
134        let result = response_rx.await.map_err(|e| {
135            WasmRuntimeError::CallFailed(format!(
136                "Failed to receive response from thread pool: {e}"
137            ))
138        })?;
139
140        let execution_time_ms = start_time.elapsed().as_millis() as u64;
141        self.total_executions.fetch_add(1, Ordering::Relaxed);
142        self.total_execution_time_ms
143            .fetch_add(execution_time_ms, Ordering::Relaxed);
144        // Update max execution time
145        self.max_execution_time_ms
146            .fetch_max(execution_time_ms, Ordering::Relaxed);
147
148        if result.is_ok() {
149            self.successful_executions.fetch_add(1, Ordering::Relaxed);
150        } else {
151            self.failed_executions.fetch_add(1, Ordering::Relaxed);
152        }
153
154        result
155    }
156
157    /// Get current metrics
158    pub fn get_metrics(&self) -> (u64, u64, u64, u64, u64) {
159        (
160            self.total_executions.load(Ordering::Relaxed),
161            self.successful_executions.load(Ordering::Relaxed),
162            self.failed_executions.load(Ordering::Relaxed),
163            self.total_execution_time_ms.load(Ordering::Relaxed),
164            self.max_execution_time_ms.load(Ordering::Relaxed),
165        )
166    }
167}
168
169/// Maps a wasmtime error to a WasmError, detecting epoch interruption (timeout) traps.
170fn map_wasm_error(e: wasmtime::Error, timeout_ms: u64) -> WasmError {
171    // Use proper trap code detection instead of brittle string matching.
172    // Wasmtime uses Trap::Interrupt for epoch-based interruptions.
173    if e.downcast_ref::<wasmtime::Trap>() == Some(&wasmtime::Trap::Interrupt) {
174        WasmError::from(WasmRuntimeError::Timeout(timeout_ms))
175    } else {
176        WasmError::from(WasmRuntimeError::CallFailed(e.to_string()))
177    }
178}
179
180impl WasmThreadPool {
181    pub fn new(config: WasmRuntimeConfig) -> Self {
182        let (sender, receiver) = async_channel::unbounded();
183
184        let mut workers = Vec::new();
185        // set thread pool size based on cpu count
186        let max_workers = std::thread::available_parallelism()
187            .map(|n| n.get())
188            .unwrap_or(4)
189            .max(1);
190        let num_workers = config.thread_pool_size.clamp(1, max_workers);
191
192        debug!(
193            target: "smg::wasm::runtime",
194            "Initializing WASM runtime with {} workers",
195            num_workers
196        );
197
198        for worker_id in 0..num_workers {
199            let receiver = receiver.clone();
200            let config = config.clone();
201
202            let worker = std::thread::spawn(move || {
203                // create independent tokio runtime for this thread
204                let rt = match tokio::runtime::Runtime::new() {
205                    Ok(rt) => rt,
206                    Err(e) => {
207                        error!(
208                            target: "smg::wasm::runtime",
209                            worker_id = worker_id,
210                            "Failed to create tokio runtime: {}",
211                            e
212                        );
213                        return;
214                    }
215                };
216
217                rt.block_on(async {
218                    Self::worker_loop(worker_id, receiver, config).await;
219                });
220            });
221
222            workers.push(worker);
223        }
224
225        Self {
226            sender,
227            receiver,
228            workers,
229            total_tasks: AtomicU64::new(0),
230            completed_tasks: AtomicU64::new(0),
231            failed_tasks: AtomicU64::new(0),
232        }
233    }
234
235    /// Get current thread pool metrics
236    pub fn get_metrics(&self) -> (u64, u64, u64) {
237        (
238            self.total_tasks.load(Ordering::Relaxed),
239            self.completed_tasks.load(Ordering::Relaxed),
240            self.failed_tasks.load(Ordering::Relaxed),
241        )
242    }
243
244    async fn worker_loop(
245        worker_id: usize,
246        receiver: async_channel::Receiver<WasmTask>,
247        config: WasmRuntimeConfig,
248    ) {
249        debug!(
250            target: "smg::wasm::runtime",
251            worker_id = worker_id,
252            thread_id = ?std::thread::current().id(),
253            "Worker started"
254        );
255
256        let mut pool_config = PoolingAllocationConfig::default();
257        let max_memory_bytes = (config.max_memory_pages as usize) * 65536;
258
259        // Since this thread handles tasks sequentially, we don't need a large pool per thread.
260        // A pool size of 20 allows for efficient reuse without hogging memory.
261        pool_config.total_core_instances(20);
262        pool_config.max_memory_size(max_memory_bytes);
263        pool_config.max_component_instance_size(max_memory_bytes);
264        pool_config.max_tables_per_component(5);
265
266        let mut wasmtime_config = Config::new();
267        wasmtime_config.allocation_strategy(InstanceAllocationStrategy::Pooling(pool_config));
268
269        wasmtime_config.async_stack_size(config.max_stack_size);
270        wasmtime_config.async_support(true);
271        wasmtime_config.wasm_component_model(true); // Enable component model
272        wasmtime_config.epoch_interruption(true); // Enable epoch-based timeout interruption
273
274        let engine = match Engine::new(&wasmtime_config) {
275            Ok(engine) => engine,
276            Err(e) => {
277                error!(
278                    target: "smg::wasm::runtime",
279                    worker_id = worker_id,
280                    "Failed to create engine: {}",
281                    e
282                );
283                return;
284            }
285        };
286        let mut linker = Linker::<WasiState>::new(&engine);
287        if let Err(e) = wasmtime_wasi::p2::add_to_linker_async(&mut linker) {
288            error!(
289                target: "smg::wasm::runtime",
290                worker_id = worker_id,
291                "Failed to add WASI to linker: {}",
292                e
293            );
294            return;
295        }
296
297        // SAFETY: 10 is a non-zero literal, so NonZeroUsize::new(10) always returns Some.
298        let default_capacity = NonZeroUsize::new(10).unwrap_or(NonZeroUsize::MIN);
299        let cache_capacity =
300            NonZeroUsize::new(config.module_cache_size).unwrap_or(default_capacity);
301        let mut component_cache: LruCache<[u8; 32], Component> = LruCache::new(cache_capacity);
302
303        // Start epoch incrementer for timeout enforcement.
304        // The engine's epoch counter is incremented periodically, and each Store
305        // can set a deadline (number of epochs). When the deadline is reached,
306        // WASM execution is interrupted with a trap.
307        let engine_for_epoch = engine.clone();
308        #[expect(
309            clippy::disallowed_methods,
310            reason = "epoch interrupt handler must run as independent background task; abort on drop ensures cleanup"
311        )]
312        let epoch_handle = tokio::spawn(async move {
313            let mut interval = tokio::time::interval(Duration::from_millis(EPOCH_INTERVAL_MS));
314            loop {
315                interval.tick().await;
316                engine_for_epoch.increment_epoch();
317            }
318        });
319
320        debug!(
321            target: "smg::wasm::runtime",
322            worker_id = worker_id,
323            epoch_interval_ms = EPOCH_INTERVAL_MS,
324            "Epoch incrementer started for timeout enforcement"
325        );
326
327        loop {
328            let task = match receiver.recv().await {
329                Ok(task) => task,
330                Err(_) => {
331                    debug!(
332                        target: "smg::wasm::runtime",
333                        worker_id = worker_id,
334                        "Worker shutting down"
335                    );
336                    epoch_handle.abort(); // Stop the epoch incrementer
337                    break; // channel closed, exit loop
338                }
339            };
340
341            match task {
342                WasmTask::ExecuteComponent {
343                    sha256_hash,
344                    wasm_bytes,
345                    attach_point,
346                    input,
347                    response,
348                } => {
349                    let result = Self::execute_component_in_worker(
350                        &engine,
351                        &linker,
352                        &mut component_cache,
353                        sha256_hash,
354                        &wasm_bytes,
355                        attach_point,
356                        input,
357                        &config,
358                    )
359                    .await;
360
361                    let _ = response.send(result);
362                }
363            }
364        }
365    }
366
367    #[expect(clippy::too_many_arguments)]
368    async fn execute_component_in_worker(
369        engine: &Engine,
370        linker: &Linker<WasiState>,
371        cache: &mut LruCache<[u8; 32], Component>,
372        sha256_hash: [u8; 32],
373        wasm_bytes: &[u8],
374        attach_point: WasmModuleAttachPoint,
375        input: WasmComponentInput,
376        config: &WasmRuntimeConfig,
377    ) -> Result<WasmComponentOutput> {
378        // Compile component from bytes OR retrieve from cache.
379        // Cache is keyed by SHA256 hash (~20ns lookup) instead of raw Vec<u8>
380        // (~24µs for a 500KB module), a 1200× improvement.
381        let component = if let Some(comp) = cache.get(&sha256_hash) {
382            comp.clone() // Component is just a handle (cheap clone)
383        } else {
384            // Compile new component
385            let comp = Component::new(engine, wasm_bytes).map_err(|e| {
386                WasmRuntimeError::CompileFailed(format!(
387                    "failed to parse WebAssembly component: {e}. \
388                     Hint: The WASM file must be in component format. \
389                     If you're using wit-bindgen, use 'wasm-tools component new' to wrap the WASM module into a component."
390                ))
391            })?;
392
393            cache.push(sha256_hash, comp.clone());
394            comp
395        };
396
397        let mut builder = WasiCtx::builder();
398
399        // Create memory limits from config.
400        // Use the config helper to get total bytes, then safely convert to usize.
401        let memory_limit_bytes =
402            usize::try_from(config.get_total_memory_bytes()).map_err(|_| {
403                WasmError::from(WasmRuntimeError::CallFailed(
404                    "Configured WASM memory limit exceeds addressable space on this platform."
405                        .to_string(),
406                ))
407            })?;
408        let limits = StoreLimitsBuilder::new()
409            .memory_size(memory_limit_bytes)
410            .trap_on_grow_failure(true) // Trap instead of returning -1 for easier debugging
411            .build();
412
413        let mut store = Store::new(
414            engine,
415            WasiState {
416                ctx: builder.build(),
417                table: ResourceTable::new(),
418                limits,
419            },
420        );
421
422        // Apply resource limits to the store.
423        // This enforces max_memory_pages by preventing memory.grow beyond the limit.
424        store.limiter(|state| &mut state.limits);
425
426        // Set epoch deadline for timeout enforcement.
427        // The deadline is the number of epoch ticks before execution is interrupted.
428        // With EPOCH_INTERVAL_MS=100ms and max_execution_time_ms=1000ms, deadline=10 epochs.
429        let deadline_epochs = (config.max_execution_time_ms / EPOCH_INTERVAL_MS).max(1);
430        store.set_epoch_deadline(deadline_epochs);
431
432        // When the epoch deadline is reached, trap to enforce the execution timeout.
433        store.epoch_deadline_callback(|_store| {
434            Err(wasmtime::Error::msg("execution time limit exceeded"))
435        });
436
437        let output = match attach_point {
438            WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnRequest) => {
439                let request = match input {
440                    WasmComponentInput::MiddlewareRequest(req) => req,
441                    WasmComponentInput::MiddlewareResponse(_) => {
442                        return Err(WasmError::from(WasmRuntimeError::CallFailed(
443                            "Expected MiddlewareRequest input for OnRequest attach point"
444                                .to_string(),
445                        )));
446                    }
447                };
448
449                // Instantiate component (must use async instantiation when async support is enabled)
450                let bindings = Smg::instantiate_async(&mut store, &component, linker)
451                    .await
452                    .map_err(|e| {
453                        WasmError::from(WasmRuntimeError::InstanceCreateFailed(e.to_string()))
454                    })?;
455
456                // Call on-request (async call when async support is enabled)
457                let action_result = bindings
458                    .smg_gateway_middleware_on_request()
459                    .call_on_request(&mut store, &request)
460                    .await
461                    .map_err(|e| map_wasm_error(e, config.max_execution_time_ms))?;
462
463                WasmComponentOutput::MiddlewareAction(action_result)
464            }
465            WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnResponse) => {
466                // Extract Response input
467                let response = match input {
468                    WasmComponentInput::MiddlewareResponse(resp) => resp,
469                    WasmComponentInput::MiddlewareRequest(_) => {
470                        return Err(WasmError::from(WasmRuntimeError::CallFailed(
471                            "Expected MiddlewareResponse input for OnResponse attach point"
472                                .to_string(),
473                        )));
474                    }
475                };
476
477                // Instantiate component (must use async instantiation when async support is enabled)
478                let bindings = Smg::instantiate_async(&mut store, &component, linker)
479                    .await
480                    .map_err(|e| {
481                        WasmError::from(WasmRuntimeError::InstanceCreateFailed(e.to_string()))
482                    })?;
483
484                // Call on-response (async call when async support is enabled)
485                let action_result = bindings
486                    .smg_gateway_middleware_on_response()
487                    .call_on_response(&mut store, &response)
488                    .await
489                    .map_err(|e| map_wasm_error(e, config.max_execution_time_ms))?;
490
491                WasmComponentOutput::MiddlewareAction(action_result)
492            }
493            WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnError) => {
494                return Err(WasmError::from(WasmRuntimeError::CallFailed(
495                    "OnError attach point not yet implemented".to_string(),
496                )));
497            }
498        };
499
500        Ok(output)
501    }
502}
503
504impl Drop for WasmThreadPool {
505    fn drop(&mut self) {
506        // close sender and receiver
507        self.sender.close();
508        self.receiver.close();
509
510        // wait for all workers to complete
511        for worker in self.workers.drain(..) {
512            let _ = worker.join();
513        }
514    }
515}
516
#[cfg(test)]
mod tests {
    use std::{num::NonZeroUsize, time::Instant};

    use lru::LruCache;

    use super::*;
    use crate::config::WasmRuntimeConfig;

    /// CPU detection must report at least one CPU, and the recommendation
    /// must never undercut the detected count.
    #[test]
    fn test_get_cpu_info() {
        let (cpu_count, max_recommended) = WasmRuntime::get_cpu_info();
        assert!(cpu_count > 0);
        assert!(max_recommended > 0);
        assert!(max_recommended >= cpu_count);
    }

    /// Guards the documented default configuration values; update this test
    /// alongside any change to `WasmRuntimeConfig::default()`.
    #[test]
    fn test_config_default_values() {
        let config = WasmRuntimeConfig::default();

        assert_eq!(config.max_memory_pages, 1024);
        assert_eq!(config.max_execution_time_ms, 1000);
        assert_eq!(config.max_stack_size, 1024 * 1024);
        assert!(config.thread_pool_size > 0);
        assert_eq!(config.module_cache_size, 10);
    }

    /// A cloned config must be field-for-field equal to the original.
    #[test]
    fn test_config_clone() {
        let config = WasmRuntimeConfig::default();
        let cloned_config = config.clone();

        assert_eq!(config.max_memory_pages, cloned_config.max_memory_pages);
        assert_eq!(
            config.max_execution_time_ms,
            cloned_config.max_execution_time_ms
        );
        assert_eq!(config.max_stack_size, cloned_config.max_stack_size);
        assert_eq!(config.thread_pool_size, cloned_config.thread_pool_size);
        assert_eq!(config.module_cache_size, cloned_config.module_cache_size);
    }

    /// Benchmarks pooled-allocator + module-cache instantiation against the
    /// naive compile-per-call baseline and asserts a minimum speedup.
    ///
    /// NOTE(review): wall-clock ratio assertions like this can be flaky on
    /// loaded CI machines; consider `#[ignore]` plus a dedicated perf job if
    /// this starts failing intermittently.
    #[test]
    fn test_wasm_instantiation_performance_threshold() {
        // A simple WASM module forcing memory allocation
        const WASM_WAT: &str = r#"
            (module
                (memory (export "memory") 1)
                (func (export "run") (param i32 i32) (result i32)
                    local.get 0
                    local.get 1
                    i32.add)
            )
        "#;

        let iterations = 1000;

        //  Scenario A: Baseline (No Pool, No Cache)
        let engine_standard = Engine::default();
        let start_standard = Instant::now();
        for _ in 0..iterations {
            // Simulate compilation + instantiation overhead
            let module = wasmtime::Module::new(&engine_standard, WASM_WAT).unwrap();
            let mut store = Store::new(&engine_standard, ());
            let instance = wasmtime::Instance::new(&mut store, &module, &[]).unwrap();
            let run_func = instance
                .get_typed_func::<(i32, i32), i32>(&mut store, "run")
                .unwrap();
            let _ = run_func.call(&mut store, (10, 20)).unwrap();
        }
        let duration_standard = start_standard.elapsed();

        // --- Scenario B: Optimized (Pool + Cache)
        let mut pool_config = PoolingAllocationConfig::default();

        pool_config.total_core_instances(100);

        let mut config = Config::new();
        config.allocation_strategy(InstanceAllocationStrategy::Pooling(pool_config));

        let engine_pooled = Engine::new(&config).unwrap();

        // Setup LRU Cache
        let cache_capacity = NonZeroUsize::new(100).unwrap();
        let mut cache: LruCache<Vec<u8>, wasmtime::Module> = LruCache::new(cache_capacity);

        // Pre-warm cache (simulating the "cached" state)
        let key = WASM_WAT.as_bytes().to_vec();
        let module_compiled = wasmtime::Module::new(&engine_pooled, WASM_WAT).unwrap();
        cache.push(key.clone(), module_compiled);

        let start_pooled = Instant::now();
        for _ in 0..iterations {
            // Cache hit: Module::clone is a cheap handle copy, no recompile.
            let module = cache.get(&key).unwrap().clone();
            let mut store = Store::new(&engine_pooled, ());
            let instance = wasmtime::Instance::new(&mut store, &module, &[]).unwrap();
            let run_func = instance
                .get_typed_func::<(i32, i32), i32>(&mut store, "run")
                .unwrap();
            let _ = run_func.call(&mut store, (10, 20)).unwrap();
        }
        let duration_pooled = start_pooled.elapsed();

        // Verify Speedup
        let standard_secs = duration_standard.as_secs_f64();
        let pooled_secs = duration_pooled.as_secs_f64();

        // Guard against division by zero on extremely fast runs.
        if pooled_secs > 0.0 {
            let speedup = standard_secs / pooled_secs;

            assert!(
                speedup > 5.0,
                "Optimization regression: Pooling+Caching was only {speedup:.2}x faster",
            );
        }
    }
}