use crate::core::ctx::Ctx;
use crate::core::engine::Engine;
use crate::middleware::SpawnerLayer;
use crate::types::{CapToken, TaskId};
use crate::worker::adapter::{SpawnError, SpawnerAdapter, WorkerError};
use crate::worker::{wrap_join, Worker};
use async_trait::async_trait;
use mlua::LuaSerdeExt;
use mlua_isle::{AsyncIslePool, IsleError, PoolConfig, PoolStrategy};
use serde_json::Value;
use std::sync::{Arc, OnceLock};
/// Pool settings for the shared fallback pool used when a `LuaMiddleware`
/// has no explicit pool: at most four isles, `PoolStrategy::Warm`.
fn default_pool_config() -> PoolConfig {
    let max_size = 4;
    let strategy = PoolStrategy::Warm;
    PoolConfig { max_size, strategy }
}
/// Builds the shared fallback pool from [`default_pool_config`].
///
/// The initialiser closure handed to the pool does no setup on the Lua state it
/// receives.
///
/// # Panics
/// Panics if `AsyncIslePool::new` returns an error.
fn build_default_pool() -> Arc<AsyncIslePool> {
    let pool = AsyncIslePool::new(|_lua| Ok(()), default_pool_config())
        .expect("AsyncIslePool::new (no-op factory) must succeed");
    Arc::new(pool)
}
/// Spawner layer that runs optional Lua hooks around the wrapped adapter.
///
/// - The *before* hook runs before the inner adapter spawns. An error from it
///   rejects the spawn.
/// - The *after* hook runs once the spawned worker finishes successfully, and is
///   given the task context plus the worker's final result.
///
/// Hooks execute on `pool`. When no pool is supplied, a lazily created,
/// process-wide default pool is shared by all layers.
#[derive(Default, Clone)]
pub struct LuaMiddleware {
    /// Lua chunk for the pre-spawn hook; must evaluate to a function.
    before_src: Option<String>,
    /// Lua chunk for the post-completion hook; must evaluate to a function.
    after_src: Option<String>,
    /// Explicit isle pool; `None` selects the shared default pool.
    pool: Option<Arc<AsyncIslePool>>,
}
impl LuaMiddleware {
    /// Creates a middleware with no hooks and no explicit pool.
    ///
    /// With no hooks configured, the wrapped adapter's `spawn` result is passed
    /// through unchanged.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the Lua source of the pre-spawn hook.
    ///
    /// The chunk must evaluate to a function. It is loaded on every spawn and
    /// called with a table describing the task (`task_id`, `attempt`, `agent`,
    /// `operator`). Any error while running it rejects the spawn.
    #[must_use]
    pub fn before(mut self, src: impl Into<String>) -> Self {
        self.before_src = Some(src.into());
        self
    }

    /// Sets the Lua source of the post-completion hook.
    ///
    /// The chunk must evaluate to a function. It is called as `f(ctx, result)`,
    /// where `result` is `{ value = ..., ok = ... }` built from the worker's
    /// latest inline final output. Its return value is used to build a
    /// replacement final output. The hook only runs for workers that complete
    /// successfully.
    #[must_use]
    pub fn after(mut self, src: impl Into<String>) -> Self {
        self.after_src = Some(src.into());
        self
    }

    /// Runs hooks on `pool` instead of the shared default pool.
    #[must_use]
    pub fn with_pool(mut self, pool: Arc<AsyncIslePool>) -> Self {
        self.pool = Some(pool);
        self
    }
}
impl SpawnerLayer for LuaMiddleware {
    /// Wraps `inner` in a `LuaWrapped` adapter carrying this layer's hooks.
    ///
    /// Uses the configured pool if present. Otherwise it uses the shared default
    /// pool, which is built on first use and reused afterwards.
    fn wrap(&self, inner: Arc<dyn SpawnerAdapter>) -> Arc<dyn SpawnerAdapter> {
        // One fallback pool per process, created only if some layer needs it.
        static DEFAULT_POOL: OnceLock<Arc<AsyncIslePool>> = OnceLock::new();

        let pool = match &self.pool {
            Some(explicit) => Arc::clone(explicit),
            None => Arc::clone(DEFAULT_POOL.get_or_init(build_default_pool)),
        };

        let wrapped = LuaWrapped {
            inner,
            before_src: self.before_src.clone(),
            after_src: self.after_src.clone(),
            pool,
        };
        Arc::new(wrapped)
    }
}
/// Adapter produced by `LuaMiddleware`'s `SpawnerLayer::wrap`.
///
/// It delegates spawning to `inner` and runs the configured Lua hooks on `pool`.
struct LuaWrapped {
    /// Adapter that performs the actual spawn.
    inner: Arc<dyn SpawnerAdapter>,
    /// Pre-spawn hook source, if any.
    before_src: Option<String>,
    /// Post-completion hook source, if any.
    after_src: Option<String>,
    /// Pool the hooks are executed on (explicit or the shared default).
    pool: Arc<AsyncIslePool>,
}
/// Projects the parts of `ctx` exposed to Lua hooks into a JSON object.
///
/// Shape: `{ task_id, attempt, agent, operator: { kind, id } }`. The operator
/// kind is rendered with its `Debug` formatting.
fn ctx_to_serializable(ctx: &Ctx) -> Value {
    let operator = serde_json::json!({
        "kind": format!("{:?}", ctx.operator.kind),
        "id": ctx.operator.id,
    });
    serde_json::json!({
        "task_id": ctx.task_id.0,
        "attempt": ctx.attempt,
        "agent": ctx.agent,
        "operator": operator,
    })
}
/// Builds the isle job that runs a pre-spawn hook.
///
/// The job decodes `ctx_json` and converts it to a Lua value. It then evaluates
/// `src`, which must yield a function, and calls that function with the context.
/// The function's return value is ignored. On success the job returns the
/// string `"ok"`. Every failure is reported as `IsleError::Lua` carrying the
/// underlying error's message.
fn make_before_exec(
    src: String,
    ctx_json: String,
) -> impl FnOnce(&mlua::Lua) -> Result<String, IsleError> + Send + 'static {
    fn to_isle(err: impl std::fmt::Display) -> IsleError {
        IsleError::Lua(err.to_string())
    }

    move |lua| {
        let parsed: Value = serde_json::from_str(&ctx_json).map_err(to_isle)?;
        let arg: mlua::Value = lua.to_value(&parsed).map_err(to_isle)?;
        let hook: mlua::Function = lua.load(&src).eval().map_err(to_isle)?;
        let _ignored: mlua::Value = hook.call(arg).map_err(to_isle)?;
        Ok("ok".to_string())
    }
}
/// Builds the isle job that runs a post-completion hook.
///
/// The job decodes `ctx_json` and `result_json` and converts both to Lua
/// values. It evaluates `src`, which must yield a function, and calls it as
/// `f(ctx, result)`. The function's return value is converted back to JSON and
/// returned as a string. Every failure is reported as `IsleError::Lua` carrying
/// the underlying error's message.
fn make_after_exec(
    src: String,
    ctx_json: String,
    result_json: String,
) -> impl FnOnce(&mlua::Lua) -> Result<String, IsleError> + Send + 'static {
    fn to_isle(err: impl std::fmt::Display) -> IsleError {
        IsleError::Lua(err.to_string())
    }

    move |lua| {
        let ctx_val: Value = serde_json::from_str(&ctx_json).map_err(to_isle)?;
        let result_val: Value = serde_json::from_str(&result_json).map_err(to_isle)?;
        let ctx_arg: mlua::Value = lua.to_value(&ctx_val).map_err(to_isle)?;
        let result_arg: mlua::Value = lua.to_value(&result_val).map_err(to_isle)?;
        let hook: mlua::Function = lua.load(&src).eval().map_err(to_isle)?;
        let returned: mlua::Value = hook.call((ctx_arg, result_arg)).map_err(to_isle)?;
        let decoded: Value = lua.from_value(returned).map_err(to_isle)?;
        serde_json::to_string(&decoded).map_err(to_isle)
    }
}
#[async_trait]
impl SpawnerAdapter for LuaWrapped {
async fn spawn(
&self,
engine: &Engine,
ctx: &Ctx,
task_id: TaskId,
attempt: u32,
token: CapToken,
) -> Result<Box<dyn Worker>, SpawnError> {
if let Some(src) = &self.before_src {
let ctx_json = serde_json::to_string(&ctx_to_serializable(ctx))
.map_err(|e| SpawnError::Internal(format!("ctx serialize: {e}")))?;
let isle = self
.pool
.checkout()
.await
.map_err(|e| SpawnError::Internal(format!("isle pool checkout: {e}")))?;
let f = make_before_exec(src.clone(), ctx_json);
isle.exec(f)
.await
.map_err(|e| SpawnError::RejectedByMiddleware(format!("lua before: {e}")))?;
}
let engine_clone = engine.clone();
let token_clone = token.clone();
let task_id_clone = task_id.clone();
let handle = self
.inner
.spawn(engine, ctx, task_id, attempt, token)
.await?;
let Some(after_src) = self.after_src.clone() else {
return Ok(handle);
};
let ctx_val = ctx_to_serializable(ctx);
Ok(wrap_completion_with_lua_pool(
handle,
after_src,
ctx_val,
self.pool.clone(),
engine_clone,
token_clone,
task_id_clone,
attempt,
))
}
}
/// Wraps `handle` so that the Lua after-hook runs against the worker's final
/// output once the worker finishes successfully.
///
/// A failed worker's error is passed through untouched and the hook is not run.
/// The hook's own result is discarded, so a failing after-hook never changes
/// the outcome reported for the worker.
// Every argument is state the completion closure must own; bundling them into
// a struct would only move the same list elsewhere.
#[allow(clippy::too_many_arguments)]
fn wrap_completion_with_lua_pool(
    handle: Box<dyn Worker>,
    after_src: String,
    ctx_val: Value,
    pool: Arc<AsyncIslePool>,
    engine: Engine,
    token: crate::types::CapToken,
    task_id: TaskId,
    attempt: u32,
) -> Box<dyn Worker> {
    wrap_join(handle, move |signal| async move {
        if let Err(err) = signal {
            return Err(err);
        }
        // Best-effort post-processing: the worker already succeeded.
        let _ = apply_lua_after_pool(
            &after_src, &ctx_val, &pool, &engine, &token, &task_id, attempt,
        )
        .await;
        Ok(())
    })
}
/// Runs the after-hook against the task's latest final output and submits the
/// hook's result as a new final output.
///
/// 1. Read the output tail for `(task_id, attempt)` and take the most recent
///    `Final` event. If there is none, or its content is not inline, nothing
///    is done.
/// 2. Call the hook as `f(ctx, { value = ..., ok = ... })` on an isle from
///    `pool`.
/// 3. If the hook returns `nil`, it made no change and nothing is submitted.
///    Otherwise `value` and `ok` are read from the returned table; a field the
///    hook did not set keeps its original value.
/// 4. Submit the result as a new inline `Final` event. The outcome of the
///    submission is not checked.
///
/// # Errors
/// Returns `WorkerError::Failed` if any of the following fails: serialising the
/// context or result, checking out an isle, executing the hook, or decoding the
/// hook's result.
async fn apply_lua_after_pool(
    after_src: &str,
    ctx_val: &Value,
    pool: &AsyncIslePool,
    engine: &Engine,
    token: &crate::types::CapToken,
    task_id: &TaskId,
    attempt: u32,
) -> Result<(), WorkerError> {
    let tail = engine.output_tail(task_id, attempt).await;

    // Only the *newest* `Final` is eligible. Searching for the newest *inline*
    // `Final` instead would pick up an older inline result whenever the newest
    // one is stored by reference, and resubmit that stale value as current.
    let (value, ok) = {
        let latest_final = tail.iter().rev().find_map(|ev| match ev {
            crate::worker::output::OutputEvent::Final { content, ok } => Some((content, *ok)),
            _ => None,
        });
        match latest_final {
            Some((crate::worker::output::ContentRef::Inline { value }, ok)) => {
                (value.clone(), ok)
            }
            _ => return Ok(()),
        }
    };

    let ctx_json = serde_json::to_string(ctx_val)
        .map_err(|e| WorkerError::Failed(format!("ctx serialize: {e}")))?;
    let result_json = serde_json::to_string(&serde_json::json!({"value": &value, "ok": ok}))
        .map_err(|e| WorkerError::Failed(format!("result serialize: {e}")))?;

    let isle = pool
        .checkout()
        .await
        .map_err(|e| WorkerError::Failed(format!("isle pool checkout: {e}")))?;
    let f = make_after_exec(after_src.to_string(), ctx_json, result_json);
    let new_json = isle
        .exec(f)
        .await
        .map_err(|e| WorkerError::Failed(format!("lua after: {e}")))?;
    let new_result: Value = serde_json::from_str(&new_json)
        .map_err(|e| WorkerError::Failed(format!("lua after decode: {e}")))?;

    // A hook that returns nothing (nil -> JSON null) is an observer: keep the
    // existing final output instead of overwriting it with null.
    if new_result.is_null() {
        return Ok(());
    }
    // Missing fields fall back to the original, for `value` exactly as for `ok`,
    // so e.g. `return { ok = false }` does not erase the output value.
    let new_value = new_result.get("value").cloned().unwrap_or(value);
    let new_ok = new_result.get("ok").and_then(Value::as_bool).unwrap_or(ok);

    // Submission result is deliberately ignored (best-effort, as before).
    // NOTE(review): consider surfacing submission failures once the error type
    // of `submit_output` is confirmed.
    let _ = engine
        .submit_output(
            token,
            task_id,
            attempt,
            crate::worker::output::OutputEvent::Final {
                content: crate::worker::output::ContentRef::Inline { value: new_value },
                ok: new_ok,
            },
        )
        .await;
    Ok(())
}