Skip to main content

mlua_swarm/middleware/
lua_layer.rs

//! LuaSpawnerLayer (`LuaMiddleware`) — inserts middleware written in Lua into
//! the `SpawnerStack`.
//!
//! Shape:
//!   - `before_src` = a Lua source that evaluates to a function. Called
//!     immediately before the spawn as `function(ctx_table) ... end`. If it
//!     raises, the spawn is rejected with
//!     `SpawnError::RejectedByMiddleware` (a gate).
//!   - `after_src` = a Lua source that evaluates to
//!     `function(ctx_table, result_table) return result' end`, called after
//!     the worker finishes; the return value becomes the new result flowing
//!     downstream (a transform).
//!
//! # Implementation axis
//!
//! `mlua::Lua` is `!Send`, so the initial form built a fresh `Lua::new()`
//! per call and dropped it (per-call VM). Under high-frequency spawn the
//! overhead becomes visible. This version switches to
//! `mlua_isle::AsyncIslePool` (thread-isolated Lua VM + async + pool) and
//! reuses VMs from the pool. The fully-async chain stays the same shape
//! (no `block_on` / `spawn_blocking`); the `!Send` constraint is resolved
//! inside the isle so it rides on the caller's tokio runtime directly.
22
23use crate::core::ctx::Ctx;
24use crate::core::engine::Engine;
25use crate::middleware::SpawnerLayer;
26use crate::types::{CapToken, TaskId};
27use crate::worker::adapter::{SpawnError, SpawnerAdapter, WorkerError};
28use crate::worker::{wrap_join, Worker};
29use async_trait::async_trait;
30use mlua::LuaSerdeExt;
31use mlua_isle::{AsyncIslePool, IsleError, PoolConfig, PoolStrategy};
32use serde_json::Value;
33use std::sync::{Arc, OnceLock};
34
35/// Default pool config (4 warm VMs reused). Callers that want their own
36/// pool can swap it in via `LuaMiddleware::with_pool`.
37fn default_pool_config() -> PoolConfig {
38    PoolConfig {
39        max_size: 4,
40        strategy: PoolStrategy::Warm,
41    }
42}
43
44fn build_default_pool() -> Arc<AsyncIslePool> {
45    Arc::new(
46        AsyncIslePool::new(|_lua| Ok(()), default_pool_config())
47            .expect("AsyncIslePool::new (no-op factory) must succeed"),
48    )
49}
50
51/// `SpawnerLayer` that runs Lua source as a before-gate and/or an
52/// after-transform around a spawn, executed on a pooled `AsyncIslePool`
53/// VM. See the module doc for the exact function shapes expected of
54/// `before_src` / `after_src`.
55#[derive(Clone, Default)]
56pub struct LuaMiddleware {
57    before_src: Option<String>,
58    after_src: Option<String>,
59    pool: Option<Arc<AsyncIslePool>>,
60}
61
62impl LuaMiddleware {
63    /// Empty layer — no before/after hooks, default pool.
64    pub fn new() -> Self {
65        Self::default()
66    }
67    /// Sets the before-hook source (`function(ctx_table) ... end`).
68    /// Raising from this function rejects the spawn.
69    pub fn before(mut self, src: impl Into<String>) -> Self {
70        self.before_src = Some(src.into());
71        self
72    }
73    /// Sets the after-hook source
74    /// (`function(ctx_table, result_table) return result end`). Its
75    /// return value replaces the worker's result.
76    pub fn after(mut self, src: impl Into<String>) -> Self {
77        self.after_src = Some(src.into());
78        self
79    }
80    /// Inject an externally-built AsyncIslePool. Useful when the caller
81    /// wants to share a pool with another Lua layer or carry VM
82    /// initialisation across calls.
83    pub fn with_pool(mut self, pool: Arc<AsyncIslePool>) -> Self {
84        self.pool = Some(pool);
85        self
86    }
87}
88
89impl SpawnerLayer for LuaMiddleware {
90    fn wrap(&self, inner: Arc<dyn SpawnerAdapter>) -> Arc<dyn SpawnerAdapter> {
91        // Pool resolution: caller injection wins over the process-wide default (built lazily, once).
92        static DEFAULT_POOL: OnceLock<Arc<AsyncIslePool>> = OnceLock::new();
93        let pool = self
94            .pool
95            .clone()
96            .unwrap_or_else(|| DEFAULT_POOL.get_or_init(build_default_pool).clone());
97        Arc::new(LuaWrapped {
98            inner,
99            before_src: self.before_src.clone(),
100            after_src: self.after_src.clone(),
101            pool,
102        })
103    }
104}
105
106struct LuaWrapped {
107    inner: Arc<dyn SpawnerAdapter>,
108    before_src: Option<String>,
109    after_src: Option<String>,
110    pool: Arc<AsyncIslePool>,
111}
112
113/// Serializable view of `Ctx` handed to Lua functions (`Arc<dyn>` fields are excluded since they cannot serde).
114fn ctx_to_serializable(ctx: &Ctx) -> Value {
115    serde_json::json!({
116        "task_id": ctx.task_id.0,
117        "attempt": ctx.attempt,
118        "agent": ctx.agent,
119        "operator": {
120            "kind": format!("{:?}", ctx.operator.kind),
121            "id": ctx.operator.id,
122        },
123    })
124}
125
126/// Helper that returns a closure running the "before hook" inside AsyncIsle.exec.
127/// Return value is fixed to the literal string "ok" — success / failure is
128/// expressed through IsleError.
129fn make_before_exec(
130    src: String,
131    ctx_json: String,
132) -> impl FnOnce(&mlua::Lua) -> Result<String, IsleError> + Send + 'static {
133    move |lua| {
134        let ctx_val: Value =
135            serde_json::from_str(&ctx_json).map_err(|e| IsleError::Lua(e.to_string()))?;
136        let ctx_lua: mlua::Value = lua
137            .to_value(&ctx_val)
138            .map_err(|e| IsleError::Lua(e.to_string()))?;
139        let f: mlua::Function = lua
140            .load(&src)
141            .eval()
142            .map_err(|e| IsleError::Lua(e.to_string()))?;
143        let _: mlua::Value = f.call(ctx_lua).map_err(|e| IsleError::Lua(e.to_string()))?;
144        Ok("ok".to_string())
145    }
146}
147
148/// Returns a closure that runs the "after hook" inside AsyncIsle.exec.
149/// Return value is the new result as a JSON string (`{"value": ..., "ok": bool}`).
150fn make_after_exec(
151    src: String,
152    ctx_json: String,
153    result_json: String,
154) -> impl FnOnce(&mlua::Lua) -> Result<String, IsleError> + Send + 'static {
155    move |lua| {
156        let ctx_val: Value =
157            serde_json::from_str(&ctx_json).map_err(|e| IsleError::Lua(e.to_string()))?;
158        let result_val: Value =
159            serde_json::from_str(&result_json).map_err(|e| IsleError::Lua(e.to_string()))?;
160        let ctx_lua: mlua::Value = lua
161            .to_value(&ctx_val)
162            .map_err(|e| IsleError::Lua(e.to_string()))?;
163        let result_lua: mlua::Value = lua
164            .to_value(&result_val)
165            .map_err(|e| IsleError::Lua(e.to_string()))?;
166        let f: mlua::Function = lua
167            .load(&src)
168            .eval()
169            .map_err(|e| IsleError::Lua(e.to_string()))?;
170        let returned: mlua::Value = f
171            .call((ctx_lua, result_lua))
172            .map_err(|e| IsleError::Lua(e.to_string()))?;
173        let new_result: Value = lua
174            .from_value(returned)
175            .map_err(|e| IsleError::Lua(e.to_string()))?;
176        serde_json::to_string(&new_result).map_err(|e| IsleError::Lua(e.to_string()))
177    }
178}
179
180#[async_trait]
181impl SpawnerAdapter for LuaWrapped {
182    async fn spawn(
183        &self,
184        engine: &Engine,
185        ctx: &Ctx,
186        task_id: TaskId,
187        attempt: u32,
188        token: CapToken,
189    ) -> Result<Box<dyn Worker>, SpawnError> {
190        // ─── before hook (= pool checkout → exec → return) ────────────────
191        if let Some(src) = &self.before_src {
192            let ctx_json = serde_json::to_string(&ctx_to_serializable(ctx))
193                .map_err(|e| SpawnError::Internal(format!("ctx serialize: {e}")))?;
194            let isle = self
195                .pool
196                .checkout()
197                .await
198                .map_err(|e| SpawnError::Internal(format!("isle pool checkout: {e}")))?;
199            let f = make_before_exec(src.clone(), ctx_json);
200            isle.exec(f)
201                .await
202                .map_err(|e| SpawnError::RejectedByMiddleware(format!("lua before: {e}")))?;
203            // The isle is dropped here — either returned to the pool (Warm) or shut down (Cold).
204        }
205
206        let engine_clone = engine.clone();
207        let token_clone = token.clone();
208        let task_id_clone = task_id.clone();
209        let handle = self
210            .inner
211            .spawn(engine, ctx, task_id, attempt, token)
212            .await?;
213
214        // ─── after hook ───────────────────────────────────────────────────
215        let Some(after_src) = self.after_src.clone() else {
216            return Ok(handle);
217        };
218        let ctx_val = ctx_to_serializable(ctx);
219        Ok(wrap_completion_with_lua_pool(
220            handle,
221            after_src,
222            ctx_val,
223            self.pool.clone(),
224            engine_clone,
225            token_clone,
226            task_id_clone,
227            attempt,
228        ))
229    }
230}
231
232/// Helper that wraps the completion signal and drives the Lua after-hook
233/// through the pool. Follows the signal-only design: pulls the value from
234/// `engine.output_tail`, and pushes the post-Lua `{value, ok}` (Lua-wire
235/// JSON) as an override Final via `engine.submit_output`.
236#[allow(clippy::too_many_arguments)]
237fn wrap_completion_with_lua_pool(
238    handle: Box<dyn Worker>,
239    after_src: String,
240    ctx_val: Value,
241    pool: Arc<AsyncIslePool>,
242    engine: Engine,
243    token: crate::types::CapToken,
244    task_id: TaskId,
245    attempt: u32,
246) -> Box<dyn Worker> {
247    wrap_join(handle, move |signal| async move {
248        match signal {
249            Ok(()) => {
250                let _ = apply_lua_after_pool(
251                    &after_src, &ctx_val, &pool, &engine, &token, &task_id, attempt,
252                )
253                .await;
254                Ok(())
255            }
256            Err(e) => Err(e),
257        }
258    })
259}
260
261async fn apply_lua_after_pool(
262    after_src: &str,
263    ctx_val: &Value,
264    pool: &AsyncIslePool,
265    engine: &Engine,
266    token: &crate::types::CapToken,
267    task_id: &TaskId,
268    attempt: u32,
269) -> Result<(), WorkerError> {
270    // Pull the existing Final from the tail (only Inline is fed through Lua; FileRef passes through as-is).
271    let tail = engine.output_tail(task_id, attempt).await;
272    let (value, ok) = match tail.iter().rev().find_map(|ev| match ev {
273        crate::worker::output::OutputEvent::Final {
274            content: crate::worker::output::ContentRef::Inline { value },
275            ok,
276        } => Some((value.clone(), *ok)),
277        _ => None,
278    }) {
279        Some(v) => v,
280        None => return Ok(()), // No Inline Final: do nothing.
281    };
282
283    let ctx_json = serde_json::to_string(ctx_val)
284        .map_err(|e| WorkerError::Failed(format!("ctx serialize: {e}")))?;
285    let result_json = serde_json::to_string(&serde_json::json!({"value": value, "ok": ok}))
286        .map_err(|e| WorkerError::Failed(format!("result serialize: {e}")))?;
287
288    let isle = pool
289        .checkout()
290        .await
291        .map_err(|e| WorkerError::Failed(format!("isle pool checkout: {e}")))?;
292    let f = make_after_exec(after_src.to_string(), ctx_json, result_json);
293    let new_json = isle
294        .exec(f)
295        .await
296        .map_err(|e| WorkerError::Failed(format!("lua after: {e}")))?;
297    let new_result: Value = serde_json::from_str(&new_json)
298        .map_err(|e| WorkerError::Failed(format!("lua after decode: {e}")))?;
299    let new_value = new_result.get("value").cloned().unwrap_or(Value::Null);
300    let new_ok = new_result.get("ok").and_then(|v| v.as_bool()).unwrap_or(ok);
301
302    // Push the override Final to the engine so the downstream dispatch's rev().find picks up the latest.
303    let _ = engine
304        .submit_output(
305            token,
306            task_id,
307            attempt,
308            crate::worker::output::OutputEvent::Final {
309                content: crate::worker::output::ContentRef::Inline { value: new_value },
310                ok: new_ok,
311            },
312        )
313        .await;
314    Ok(())
315}