1use crate::core::ctx::Ctx;
24use crate::core::engine::Engine;
25use crate::middleware::SpawnerLayer;
26use crate::types::{CapToken, TaskId};
27use crate::worker::adapter::{SpawnError, SpawnerAdapter, WorkerError};
28use crate::worker::{wrap_join, Worker};
29use async_trait::async_trait;
30use mlua::LuaSerdeExt;
31use mlua_isle::{AsyncIslePool, IsleError, PoolConfig, PoolStrategy};
32use serde_json::Value;
33use std::sync::{Arc, OnceLock};
34
35fn default_pool_config() -> PoolConfig {
38 PoolConfig {
39 max_size: 4,
40 strategy: PoolStrategy::Warm,
41 }
42}
43
44fn build_default_pool() -> Arc<AsyncIslePool> {
45 Arc::new(
46 AsyncIslePool::new(|_lua| Ok(()), default_pool_config())
47 .expect("AsyncIslePool::new (no-op factory) must succeed"),
48 )
49}
50
51#[derive(Clone, Default)]
56pub struct LuaMiddleware {
57 before_src: Option<String>,
58 after_src: Option<String>,
59 pool: Option<Arc<AsyncIslePool>>,
60}
61
62impl LuaMiddleware {
63 pub fn new() -> Self {
65 Self::default()
66 }
67 pub fn before(mut self, src: impl Into<String>) -> Self {
70 self.before_src = Some(src.into());
71 self
72 }
73 pub fn after(mut self, src: impl Into<String>) -> Self {
77 self.after_src = Some(src.into());
78 self
79 }
80 pub fn with_pool(mut self, pool: Arc<AsyncIslePool>) -> Self {
84 self.pool = Some(pool);
85 self
86 }
87}
88
89impl SpawnerLayer for LuaMiddleware {
90 fn wrap(&self, inner: Arc<dyn SpawnerAdapter>) -> Arc<dyn SpawnerAdapter> {
91 static DEFAULT_POOL: OnceLock<Arc<AsyncIslePool>> = OnceLock::new();
93 let pool = self
94 .pool
95 .clone()
96 .unwrap_or_else(|| DEFAULT_POOL.get_or_init(build_default_pool).clone());
97 Arc::new(LuaWrapped {
98 inner,
99 before_src: self.before_src.clone(),
100 after_src: self.after_src.clone(),
101 pool,
102 })
103 }
104}
105
/// Adapter produced by `LuaMiddleware::wrap`: runs the optional Lua hooks
/// around the wrapped adapter's `spawn`.
struct LuaWrapped {
    /// Adapter that actually spawns the worker.
    inner: Arc<dyn SpawnerAdapter>,
    /// Lua source run before `inner.spawn`; an error from it rejects the spawn.
    before_src: Option<String>,
    /// Lua source run after the spawned worker completes successfully.
    after_src: Option<String>,
    /// Pool providing the Lua isles the hooks execute on.
    pool: Arc<AsyncIslePool>,
}
112
113fn ctx_to_serializable(ctx: &Ctx) -> Value {
115 serde_json::json!({
116 "task_id": ctx.task_id.0,
117 "attempt": ctx.attempt,
118 "agent": ctx.agent,
119 "operator": {
120 "kind": format!("{:?}", ctx.operator.kind),
121 "id": ctx.operator.id,
122 },
123 })
124}
125
126fn make_before_exec(
130 src: String,
131 ctx_json: String,
132) -> impl FnOnce(&mlua::Lua) -> Result<String, IsleError> + Send + 'static {
133 move |lua| {
134 let ctx_val: Value =
135 serde_json::from_str(&ctx_json).map_err(|e| IsleError::Lua(e.to_string()))?;
136 let ctx_lua: mlua::Value = lua
137 .to_value(&ctx_val)
138 .map_err(|e| IsleError::Lua(e.to_string()))?;
139 let f: mlua::Function = lua
140 .load(&src)
141 .eval()
142 .map_err(|e| IsleError::Lua(e.to_string()))?;
143 let _: mlua::Value = f.call(ctx_lua).map_err(|e| IsleError::Lua(e.to_string()))?;
144 Ok("ok".to_string())
145 }
146}
147
148fn make_after_exec(
151 src: String,
152 ctx_json: String,
153 result_json: String,
154) -> impl FnOnce(&mlua::Lua) -> Result<String, IsleError> + Send + 'static {
155 move |lua| {
156 let ctx_val: Value =
157 serde_json::from_str(&ctx_json).map_err(|e| IsleError::Lua(e.to_string()))?;
158 let result_val: Value =
159 serde_json::from_str(&result_json).map_err(|e| IsleError::Lua(e.to_string()))?;
160 let ctx_lua: mlua::Value = lua
161 .to_value(&ctx_val)
162 .map_err(|e| IsleError::Lua(e.to_string()))?;
163 let result_lua: mlua::Value = lua
164 .to_value(&result_val)
165 .map_err(|e| IsleError::Lua(e.to_string()))?;
166 let f: mlua::Function = lua
167 .load(&src)
168 .eval()
169 .map_err(|e| IsleError::Lua(e.to_string()))?;
170 let returned: mlua::Value = f
171 .call((ctx_lua, result_lua))
172 .map_err(|e| IsleError::Lua(e.to_string()))?;
173 let new_result: Value = lua
174 .from_value(returned)
175 .map_err(|e| IsleError::Lua(e.to_string()))?;
176 serde_json::to_string(&new_result).map_err(|e| IsleError::Lua(e.to_string()))
177 }
178}
179
180#[async_trait]
181impl SpawnerAdapter for LuaWrapped {
182 async fn spawn(
183 &self,
184 engine: &Engine,
185 ctx: &Ctx,
186 task_id: TaskId,
187 attempt: u32,
188 token: CapToken,
189 ) -> Result<Box<dyn Worker>, SpawnError> {
190 if let Some(src) = &self.before_src {
192 let ctx_json = serde_json::to_string(&ctx_to_serializable(ctx))
193 .map_err(|e| SpawnError::Internal(format!("ctx serialize: {e}")))?;
194 let isle = self
195 .pool
196 .checkout()
197 .await
198 .map_err(|e| SpawnError::Internal(format!("isle pool checkout: {e}")))?;
199 let f = make_before_exec(src.clone(), ctx_json);
200 isle.exec(f)
201 .await
202 .map_err(|e| SpawnError::RejectedByMiddleware(format!("lua before: {e}")))?;
203 }
205
206 let engine_clone = engine.clone();
207 let token_clone = token.clone();
208 let task_id_clone = task_id.clone();
209 let handle = self
210 .inner
211 .spawn(engine, ctx, task_id, attempt, token)
212 .await?;
213
214 let Some(after_src) = self.after_src.clone() else {
216 return Ok(handle);
217 };
218 let ctx_val = ctx_to_serializable(ctx);
219 Ok(wrap_completion_with_lua_pool(
220 handle,
221 after_src,
222 ctx_val,
223 self.pool.clone(),
224 engine_clone,
225 token_clone,
226 task_id_clone,
227 attempt,
228 ))
229 }
230}
231
232#[allow(clippy::too_many_arguments)]
237fn wrap_completion_with_lua_pool(
238 handle: Box<dyn Worker>,
239 after_src: String,
240 ctx_val: Value,
241 pool: Arc<AsyncIslePool>,
242 engine: Engine,
243 token: crate::types::CapToken,
244 task_id: TaskId,
245 attempt: u32,
246) -> Box<dyn Worker> {
247 wrap_join(handle, move |signal| async move {
248 match signal {
249 Ok(()) => {
250 let _ = apply_lua_after_pool(
251 &after_src, &ctx_val, &pool, &engine, &token, &task_id, attempt,
252 )
253 .await;
254 Ok(())
255 }
256 Err(e) => Err(e),
257 }
258 })
259}
260
261async fn apply_lua_after_pool(
262 after_src: &str,
263 ctx_val: &Value,
264 pool: &AsyncIslePool,
265 engine: &Engine,
266 token: &crate::types::CapToken,
267 task_id: &TaskId,
268 attempt: u32,
269) -> Result<(), WorkerError> {
270 let tail = engine.output_tail(task_id, attempt).await;
272 let (value, ok) = match tail.iter().rev().find_map(|ev| match ev {
273 crate::worker::output::OutputEvent::Final {
274 content: crate::worker::output::ContentRef::Inline { value },
275 ok,
276 } => Some((value.clone(), *ok)),
277 _ => None,
278 }) {
279 Some(v) => v,
280 None => return Ok(()), };
282
283 let ctx_json = serde_json::to_string(ctx_val)
284 .map_err(|e| WorkerError::Failed(format!("ctx serialize: {e}")))?;
285 let result_json = serde_json::to_string(&serde_json::json!({"value": value, "ok": ok}))
286 .map_err(|e| WorkerError::Failed(format!("result serialize: {e}")))?;
287
288 let isle = pool
289 .checkout()
290 .await
291 .map_err(|e| WorkerError::Failed(format!("isle pool checkout: {e}")))?;
292 let f = make_after_exec(after_src.to_string(), ctx_json, result_json);
293 let new_json = isle
294 .exec(f)
295 .await
296 .map_err(|e| WorkerError::Failed(format!("lua after: {e}")))?;
297 let new_result: Value = serde_json::from_str(&new_json)
298 .map_err(|e| WorkerError::Failed(format!("lua after decode: {e}")))?;
299 let new_value = new_result.get("value").cloned().unwrap_or(Value::Null);
300 let new_ok = new_result.get("ok").and_then(|v| v.as_bool()).unwrap_or(ok);
301
302 let _ = engine
304 .submit_output(
305 token,
306 task_id,
307 attempt,
308 crate::worker::output::OutputEvent::Final {
309 content: crate::worker::output::ContentRef::Inline { value: new_value },
310 ok: new_ok,
311 },
312 )
313 .await;
314 Ok(())
315}