Skip to main content

palladium_plugin/
wasm_actor.rs

1//! [`WasmActor`] — an `Actor` implementation backed by a WASM module instance.
2//!
3//! # Lifecycle
4//!
5//! 1. A `WasmActor` is constructed by [`PluginRegistry::create_actor`] (or
6//!    directly in tests) with a compiled [`WasmModule`] and a [`WasmHost`].
7//! 2. `on_start` instantiates the module with real [`WasmImports`] wired to
8//!    the outbox, sets the initial fuel budget, and calls the module's
9//!    `pd_actor_on_start` export (if present).
10//! 3. `on_message` writes the incoming envelope + payload into WASM linear
11//!    memory, calls the module's `pd_actor_on_message` export, then drains the
12//!    outbox by forwarding each queued message through `ctx.send_raw()`.
13//! 4. `on_stop` calls the module's `pd_actor_on_stop` export (if present).
14//!
15//! # Outbox pattern
16//!
//! The wasmtime linker requires the `pd_send` host import to be
//! `Send + Sync + 'static`, so it cannot borrow the actor directly. Instead it
//! pushes into an `Arc<Mutex<…>>` outbox shared with the actor, which is itself
//! driven from a single-threaded local executor. The actor drains this outbox
//! synchronously, forwarding each queued message through `ctx.send_raw()`.
21//!
22//! # Memory layout
23//!
24//! The host reserves two slots in WASM linear memory for passing messages:
25//!
26//! | Offset | Size | Contents |
27//! |--------|------|----------|
28//! | 0      | 80   | Serialised `Envelope` (`Envelope::to_bytes()`) |
//! | 80     | payload length | Raw payload bytes (written only when non-empty) |
30//!
31//! These offsets are passed as `i32` arguments to the WASM exports.
32
33use std::sync::{Arc, Mutex};
34
35use palladium_actor::{Actor, ActorContext, ActorError, Envelope, MessagePayload, StopReason};
36
37use crate::wasm::{WasmError, WasmHost, WasmImports, WasmInstance, WasmModule, WasmVal};
38
39// ── Memory layout constants ───────────────────────────────────────────────────
40
41/// Offset in WASM linear memory where the host writes the incoming envelope.
42const ENVELOPE_OFFSET: u32 = 0;
43
44/// Offset in WASM linear memory where the host writes the incoming payload.
45const PAYLOAD_OFFSET: u32 = 80; // immediately after the 80-byte envelope
46
47// ── WasmActor ────────────────────────────────────────────────────────────────
48
49/// An actor whose lifecycle hooks are implemented in a WASM module.
50///
51/// Construct via [`WasmActor::new`] and spawn with the normal `ChildSpec`
52/// factory API.  The actor creates its WASM instance lazily in `on_start`.
53pub struct WasmActor<R: palladium_actor::Reactor> {
54    /// Shared compiled module (may be reused across multiple actor instances).
55    module: Arc<Box<dyn WasmModule>>,
56    /// WASM runtime host — used to instantiate the module in `on_start`.
57    host: Arc<dyn WasmHost>,
58    /// Live WASM instance; `None` until `on_start` has run.
59    instance: Option<Box<dyn WasmInstance>>,
60    /// Opaque state handle passed to every WASM lifecycle call.
61    /// Returned by `pd_actor_create` if that export exists; otherwise 0.
62    state_handle: i32,
63    /// Buffer populated by the `pd_send` import during a WASM call.
64    /// Drained by `on_message` after the call returns.
65    outbox: Arc<Mutex<Vec<(Envelope, MessagePayload)>>>,
66    /// Fuel units granted per `on_message` call.  0 = use the instance default.
67    pub fuel_per_message: u64,
68    _phantom: std::marker::PhantomData<R>,
69}
70
71// SAFETY: WasmActor is managed by the pd-runtime which ensures it only ever
72// runs on a single core's local executor. Box<dyn Actor> requires Send so
73// we can move it into the spawn_local task.
74unsafe impl<R: palladium_actor::Reactor> Send for WasmActor<R> {}
75
76impl<R: palladium_actor::Reactor> WasmActor<R> {
77    /// Create a new actor backed by `module`.
78    ///
79    /// * `module` — a compiled (but not yet instantiated) WASM module.
80    /// * `host` — the WASM runtime that will instantiate the module in
81    ///   `on_start`.
82    /// * `fuel_per_message` — fuel budget granted before each `on_message`
83    ///   call. Use `0` to rely on the instance's built-in default.
84    pub fn new(
85        module: Arc<Box<dyn WasmModule>>,
86        host: Arc<dyn WasmHost>,
87        fuel_per_message: u64,
88    ) -> Self {
89        Self {
90            module,
91            host,
92            instance: None,
93            state_handle: 0,
94            outbox: Arc::new(Mutex::new(Vec::new())),
95            fuel_per_message,
96            _phantom: std::marker::PhantomData,
97        }
98    }
99
100    /// Build a [`WasmImports`] struct with the `pd_send` closure wired to this
101    /// actor's outbox.
102    fn make_imports(&self) -> WasmImports {
103        let outbox = Arc::clone(&self.outbox);
104        WasmImports {
105            pd_send: Box::new(move |env_ptr, env_len, payload_ptr, payload_len| {
106                if env_ptr.is_null() || env_len as usize != Envelope::SIZE {
107                    return -1;
108                }
109                // Safety: env_ptr is valid for `env_len` bytes — guaranteed by
110                // the WasmtimeInstance func_wrap wrapper, which copies the bytes
111                // from WASM linear memory immediately before this call.
112                let env_bytes: [u8; Envelope::SIZE] =
113                    unsafe { *(env_ptr as *const [u8; Envelope::SIZE]) };
114                let envelope = Envelope::from_bytes(&env_bytes);
115
116                let payload = if payload_len > 0 && !payload_ptr.is_null() {
117                    // Safety: same guarantee — wasmtime validated bounds.
118                    let slice =
119                        unsafe { std::slice::from_raw_parts(payload_ptr, payload_len as usize) };
120                    MessagePayload::serialized(slice.to_vec())
121                } else {
122                    MessagePayload::serialized(Vec::new())
123                };
124
125                outbox.lock().unwrap().push((envelope, payload));
126                0
127            }),
128            pd_now_micros: Box::new(move || {
129                // In simulation, this should ideally be virtual wall-clock time.
130                // For now, we return a value that is deterministic if the reactor is deterministic.
131                // TODO: Add wall-clock support to Reactor trait.
132                0
133            }),
134            pd_log: Box::new(|_level, _msg_ptr, _msg_len| {}),
135        }
136    }
137
138    /// Return the exports of the compiled module.
139    fn exports(&self) -> Vec<String> {
140        self.module.exports()
141    }
142
143    /// Drain the outbox and forward each message through `ctx.send_raw()`.
144    /// Returns Ok(()) if all messages were sent, or Err(ActorError::Handler)
145    /// if any send failed (e.g. mailbox full).
146    fn drain_outbox(&self, ctx: &mut ActorContext<R>) -> Result<(), ActorError> {
147        let messages = {
148            let mut lock = self.outbox.lock().unwrap();
149            std::mem::take(&mut *lock)
150        };
151        for (env, payload) in messages.into_iter() {
152            ctx.send_raw(env, payload)
153                .map_err(|_| ActorError::Handler)?;
154        }
155        Ok(())
156    }
157
158    /// Map a [`WasmError`] from a lifecycle call to an [`ActorError`].
159    fn map_wasm_err(e: WasmError) -> ActorError {
160        match e {
161            WasmError::FuelExhausted => ActorError::ResourceExhausted,
162            WasmError::Trap(msg) => ActorError::WasmTrap(msg),
163            _ => ActorError::Handler,
164        }
165    }
166}
167
168impl<R: palladium_actor::Reactor> Actor<R> for WasmActor<R> {
169    fn on_start(&mut self, _ctx: &mut ActorContext<R>) -> Result<(), ActorError> {
170        // Pre-check exports before borrowing self.instance mutably.
171        let has_on_start = self.exports().contains(&"pd_actor_on_start".to_string());
172
173        // Instantiate with wired-up imports.
174        let imports = self.make_imports();
175        let instance = self
176            .host
177            .instantiate(self.module.as_ref().as_ref(), imports)
178            .map_err(|_| ActorError::Init)?;
179        self.instance = Some(instance);
180
181        let inst = self.instance.as_mut().unwrap();
182
183        // Set initial fuel budget if requested.
184        if self.fuel_per_message > 0 {
185            inst.set_fuel(self.fuel_per_message)
186                .map_err(|_| ActorError::Init)?;
187        }
188
189        // Call pd_actor_on_start if the module exports it.
190        if has_on_start {
191            let rets = inst
192                .call(
193                    "pd_actor_on_start",
194                    &[WasmVal::I32(self.state_handle), WasmVal::I32(0)],
195                )
196                .map_err(Self::map_wasm_err)?;
197            let code = first_i32(&rets);
198            if code != 0 {
199                return Err(ActorError::Init);
200            }
201        }
202
203        Ok(())
204    }
205
206    fn on_message(
207        &mut self,
208        ctx: &mut ActorContext<R>,
209        envelope: &Envelope,
210        payload: MessagePayload,
211    ) -> Result<(), ActorError> {
212        let inst = match self.instance.as_mut() {
213            Some(i) => i,
214            None => return Err(ActorError::Init),
215        };
216
217        // Refresh fuel for this message.
218        if self.fuel_per_message > 0 {
219            inst.set_fuel(self.fuel_per_message)
220                .map_err(|_| ActorError::ResourceExhausted)?;
221        }
222
223        // Write envelope to WASM memory at ENVELOPE_OFFSET.
224        let env_bytes = envelope.to_bytes();
225        inst.memory_write(ENVELOPE_OFFSET, &env_bytes)
226            .map_err(|_| ActorError::Handler)?;
227
228        // Write payload bytes to WASM memory at PAYLOAD_OFFSET.
229        let payload_slice: &[u8] = match &payload {
230            MessagePayload::Serialized(b) => b.as_ref(),
231            MessagePayload::Local(_) => &[],
232        };
233        let payload_len = payload_slice.len() as i32;
234        if payload_len > 0 {
235            inst.memory_write(PAYLOAD_OFFSET, payload_slice)
236                .map_err(|_| ActorError::Handler)?;
237        }
238
239        // Call pd_actor_on_message.
240        let payload_ptr = if payload_len > 0 {
241            PAYLOAD_OFFSET as i32
242        } else {
243            0
244        };
245        let rets = inst
246            .call(
247                "pd_actor_on_message",
248                &[
249                    WasmVal::I32(self.state_handle),
250                    WasmVal::I32(0), // ctx: unused in WASM
251                    WasmVal::I32(ENVELOPE_OFFSET as i32),
252                    WasmVal::I32(payload_ptr),
253                    WasmVal::I32(payload_len),
254                ],
255            )
256            .map_err(Self::map_wasm_err)?;
257
258        if first_i32(&rets) != 0 {
259            return Err(ActorError::Handler);
260        }
261
262        // Drain outbox: deliver queued sends through the runtime bridge.
263        self.drain_outbox(ctx)
264    }
265
266    fn on_stop(&mut self, _ctx: &mut ActorContext<R>, reason: StopReason) {
267        // Pre-check before borrowing self.instance mutably.
268        let has_on_stop = self.exports().contains(&"pd_actor_on_stop".to_string());
269
270        let inst = match self.instance.as_mut() {
271            Some(i) => i,
272            None => return,
273        };
274
275        if !has_on_stop {
276            return;
277        }
278
279        let reason_code = match reason {
280            StopReason::Normal => 0i32,
281            StopReason::Requested => 1,
282            StopReason::Supervisor => 2,
283            StopReason::Hierarchical => 2,
284            StopReason::Shutdown => 3,
285            StopReason::Killed => 4,
286            StopReason::Error(_) => -1,
287        };
288
289        // Best-effort: ignore errors in on_stop.
290        let _ = inst.call(
291            "pd_actor_on_stop",
292            &[
293                WasmVal::I32(self.state_handle),
294                WasmVal::I32(0),
295                WasmVal::I32(reason_code),
296            ],
297        );
298    }
299}
300
301// ── Helpers ───────────────────────────────────────────────────────────────────
302
303/// Extract the first `i32` return value, or 0 if absent.
304fn first_i32(rets: &[WasmVal]) -> i32 {
305    match rets.first() {
306        Some(WasmVal::I32(n)) => *n,
307        _ => 0,
308    }
309}