use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
};
use imp_core::config::{AgentMode, Config, LuaCapabilityPolicy};
use imp_core::tools::{FileCache, FileTracker, Tool, ToolContext, ToolUpdate};
use imp_core::ui::UserInterface;
use mlua::Lua;
use thiserror::Error;
/// Errors produced by the Lua extension runtime.
#[derive(Debug, Error)]
pub enum LuaError {
    /// An error raised by the Lua VM or the `mlua` bindings; converted
    /// automatically from [`mlua::Error`] by `?`.
    #[error("Lua error: {0}")]
    Mlua(#[from] mlua::Error),

    /// A runtime-level failure carrying a human-readable message, such as an
    /// unreadable script file or an unknown command name.
    #[error("Extension error: {0}")]
    Extension(String),
}
/// A tool registered by a Lua extension.
pub struct LuaToolHandle {
    /// Identifier of the tool.
    pub name: String,
    /// Display label for the tool.
    pub label: String,
    /// Description of what the tool does.
    pub description: String,
    /// Read-only flag declared by the extension (how it is enforced is not
    /// visible here — confirm against the tool adapter).
    pub readonly: bool,
    /// Parameter description as JSON (presumably a JSON Schema object —
    /// TODO confirm).
    pub params: serde_json::Value,
    /// Lua registry key for the tool's implementation (presumably the Lua
    /// `execute` function — verify against the registration bridge).
    pub execute_key: mlua::RegistryKey,
}
/// An event hook registered by a Lua extension.
pub struct LuaHookHandle {
    /// Name of the event this hook is attached to.
    pub event: String,
    /// Lua registry key for the hook's handler (presumably a Lua function —
    /// confirm against the registration bridge).
    pub handler_key: mlua::RegistryKey,
}
/// A command registered by a Lua extension.
pub struct LuaCommandHandle {
    /// Command name used for lookup by the runtime.
    pub name: String,
    /// Short description of the command.
    pub description: String,
    /// Lua registry key of the function invoked when the command runs; the
    /// function receives the raw argument string.
    pub handler_key: mlua::RegistryKey,
}
pub struct LuaCallContext {
pub cwd: PathBuf,
pub cancelled: Arc<std::sync::atomic::AtomicBool>,
pub update_tx: tokio::sync::mpsc::Sender<ToolUpdate>,
pub command_tx: tokio::sync::mpsc::Sender<imp_core::agent::AgentCommand>,
pub ui: Arc<dyn UserInterface>,
pub file_cache: Arc<FileCache>,
pub checkpoint_state: Arc<imp_core::tools::CheckpointState>,
pub file_tracker: Arc<std::sync::Mutex<FileTracker>>,
pub anchor_store: Arc<imp_core::tools::AnchorStore>,
pub lua_tool_loader: Option<imp_core::tools::LuaToolLoader>,
pub mode: AgentMode,
pub read_max_lines: usize,
pub run_policy: imp_core::policy::RunPolicy,
pub config: Arc<Config>,
}
impl LuaCallContext {
pub fn to_tool_context(&self) -> ToolContext {
ToolContext {
cwd: self.cwd.clone(),
cancelled: Arc::clone(&self.cancelled),
update_tx: self.update_tx.clone(),
command_tx: self.command_tx.clone(),
ui: Arc::clone(&self.ui),
file_cache: Arc::clone(&self.file_cache),
checkpoint_state: Arc::clone(&self.checkpoint_state),
file_tracker: Arc::clone(&self.file_tracker),
anchor_store: Arc::clone(&self.anchor_store),
lua_tool_loader: self.lua_tool_loader.clone(),
mode: self.mode,
read_max_lines: self.read_max_lines,
turn_mana_review: Arc::new(std::sync::Mutex::new(
imp_core::mana_review::TurnManaReviewAccumulator::default(),
)),
run_policy: self.run_policy.clone(),
config: Arc::clone(&self.config),
supporting_provenance: Vec::new(),
}
}
}
impl From<ToolContext> for LuaCallContext {
fn from(ctx: ToolContext) -> Self {
Self {
cwd: ctx.cwd,
cancelled: ctx.cancelled,
update_tx: ctx.update_tx,
command_tx: ctx.command_tx,
ui: ctx.ui,
file_cache: ctx.file_cache,
checkpoint_state: ctx.checkpoint_state,
file_tracker: ctx.file_tracker,
anchor_store: ctx.anchor_store,
lua_tool_loader: ctx.lua_tool_loader,
mode: ctx.mode,
read_max_lines: ctx.read_max_lines,
run_policy: ctx.run_policy,
config: ctx.config,
}
}
}
/// An embedded Lua state plus the registries and capability flags that Lua
/// extensions interact with.
///
/// Registries and flags live behind `Arc`. The accessor methods hand out
/// clones of those `Arc`s, so other components can share the same state.
pub struct LuaRuntime {
    /// The Lua interpreter state.
    lua: Lua,
    /// Tools registered by extensions.
    tools: Arc<Mutex<Vec<LuaToolHandle>>>,
    /// Event hooks registered by extensions.
    hooks: Arc<Mutex<Vec<LuaHookHandle>>>,
    /// Commands registered by extensions.
    commands: Arc<Mutex<Vec<LuaCommandHandle>>>,
    /// Native (Rust) tools keyed by name.
    native_tools: Arc<Mutex<HashMap<String, Arc<dyn Tool>>>>,
    /// Context for the call currently in progress, if any.
    call_context: Arc<Mutex<Option<LuaCallContext>>>,
    /// Environment variable names on the allow list (enforcement presumably
    /// lives in the bridge — confirm).
    allowed_env: Arc<Mutex<HashSet<String>>>,
    /// Capability flag: native tool calls.
    allow_native_tool_calls: Arc<AtomicBool>,
    /// Capability flag: shell execution.
    allow_shell_exec: Arc<AtomicBool>,
    /// Capability flag: HTTP access.
    allow_http: Arc<AtomicBool>,
    /// Capability flag: secrets access.
    allow_secrets: Arc<AtomicBool>,
}
impl LuaRuntime {
pub fn new() -> Result<Self, LuaError> {
let lua = Lua::new();
Ok(Self {
lua,
tools: Arc::new(Mutex::new(Vec::new())),
hooks: Arc::new(Mutex::new(Vec::new())),
commands: Arc::new(Mutex::new(Vec::new())),
native_tools: Arc::new(Mutex::new(HashMap::new())),
call_context: Arc::new(Mutex::new(None)),
allowed_env: Arc::new(Mutex::new(HashSet::new())),
allow_native_tool_calls: Arc::new(AtomicBool::new(true)),
allow_shell_exec: Arc::new(AtomicBool::new(false)),
allow_http: Arc::new(AtomicBool::new(false)),
allow_secrets: Arc::new(AtomicBool::new(false)),
})
}
pub fn lua(&self) -> &Lua {
&self.lua
}
pub fn tools(&self) -> Arc<Mutex<Vec<LuaToolHandle>>> {
Arc::clone(&self.tools)
}
pub fn hooks(&self) -> Arc<Mutex<Vec<LuaHookHandle>>> {
Arc::clone(&self.hooks)
}
pub fn commands(&self) -> Arc<Mutex<Vec<LuaCommandHandle>>> {
Arc::clone(&self.commands)
}
pub fn native_tools(&self) -> Arc<Mutex<HashMap<String, Arc<dyn Tool>>>> {
Arc::clone(&self.native_tools)
}
pub fn call_context(&self) -> Arc<Mutex<Option<LuaCallContext>>> {
Arc::clone(&self.call_context)
}
pub fn allowed_env(&self) -> Arc<Mutex<HashSet<String>>> {
Arc::clone(&self.allowed_env)
}
pub fn allow_shell_exec(&self) -> Arc<AtomicBool> {
Arc::clone(&self.allow_shell_exec)
}
pub fn allow_http(&self) -> Arc<AtomicBool> {
Arc::clone(&self.allow_http)
}
pub fn allow_secrets(&self) -> Arc<AtomicBool> {
Arc::clone(&self.allow_secrets)
}
pub fn allow_native_tool_calls(&self) -> Arc<AtomicBool> {
Arc::clone(&self.allow_native_tool_calls)
}
pub fn set_native_tools(&self, tools: HashMap<String, Arc<dyn Tool>>) {
*self.native_tools.lock().unwrap() = tools;
}
pub fn set_call_context(&self, ctx: LuaCallContext) {
*self.call_context.lock().unwrap() = Some(ctx);
}
pub fn clear_call_context(&self) {
*self.call_context.lock().unwrap() = None;
}
pub fn set_allowed_env(&self, vars: HashSet<String>) {
*self.allowed_env.lock().unwrap() = vars;
}
pub fn set_allow_shell_exec(&self, allowed: bool) {
self.allow_shell_exec.store(allowed, Ordering::Relaxed);
}
pub fn set_allow_http(&self, allowed: bool) {
self.allow_http.store(allowed, Ordering::Relaxed);
}
pub fn set_allow_secrets(&self, allowed: bool) {
self.allow_secrets.store(allowed, Ordering::Relaxed);
}
pub fn set_allow_native_tool_calls(&self, allowed: bool) {
self.allow_native_tool_calls
.store(allowed, Ordering::Relaxed);
}
pub fn apply_capability_policy(&self, policy: &LuaCapabilityPolicy) {
self.set_allow_native_tool_calls(policy.allow_native_tool_calls);
self.set_allow_shell_exec(policy.allow_shell_exec);
self.set_allow_http(policy.allow_http);
self.set_allow_secrets(policy.allow_secrets);
self.set_allowed_env(policy.allowed_env.clone());
}
pub fn register_tool(&self, handle: LuaToolHandle) {
self.tools.lock().unwrap().push(handle);
}
pub fn register_hook(&self, handle: LuaHookHandle) {
self.hooks.lock().unwrap().push(handle);
}
pub fn register_command(&self, handle: LuaCommandHandle) {
self.commands.lock().unwrap().push(handle);
}
pub fn exec(&self, source: &str) -> Result<(), LuaError> {
self.lua.load(source).exec()?;
Ok(())
}
pub fn exec_file(&self, path: &std::path::Path) -> Result<(), LuaError> {
let source = std::fs::read_to_string(path)
.map_err(|e| LuaError::Extension(format!("{}: {}", path.display(), e)))?;
self.lua
.load(&source)
.set_name(path.to_string_lossy())
.exec()?;
Ok(())
}
pub fn clear_registrations(&self) {
self.tools.lock().unwrap().clear();
self.hooks.lock().unwrap().clear();
self.commands.lock().unwrap().clear();
}
pub fn tool_count(&self) -> usize {
self.tools.lock().unwrap().len()
}
pub fn hook_count(&self) -> usize {
self.hooks.lock().unwrap().len()
}
pub fn command_count(&self) -> usize {
self.commands.lock().unwrap().len()
}
pub fn tool_names(&self) -> Vec<String> {
self.tools
.lock()
.unwrap()
.iter()
.map(|t| t.name.clone())
.collect()
}
pub fn hook_events(&self) -> Vec<String> {
self.hooks
.lock()
.unwrap()
.iter()
.map(|h| h.event.clone())
.collect()
}
pub fn execute_command(&self, name: &str, args: &str) -> Result<Option<String>, LuaError> {
self.execute_command_with_context(name, args, None)
}
pub fn execute_command_with_context(
&self,
name: &str,
args: &str,
call_ctx: Option<LuaCallContext>,
) -> Result<Option<String>, LuaError> {
if let Some(ctx) = call_ctx {
self.set_call_context(ctx);
}
let result = self.execute_command_inner(name, args);
self.clear_call_context();
result
}
fn execute_command_inner(&self, name: &str, args: &str) -> Result<Option<String>, LuaError> {
let commands = self.commands.lock().unwrap();
let handle = commands
.iter()
.find(|c| c.name == name)
.ok_or_else(|| LuaError::Extension(format!("command '{name}' not found")))?;
let handler: mlua::Function = self
.lua
.registry_value(&handle.handler_key)
.map_err(LuaError::Mlua)?;
let result: mlua::Value = handler.call(args.to_string()).map_err(LuaError::Mlua)?;
match result {
mlua::Value::Nil => Ok(None),
mlua::Value::String(s) => Ok(Some(
s.to_str()
.map(|v| v.to_string())
.unwrap_or_else(|_| "(non-utf8)".into()),
)),
other => {
let json = crate::bridge::lua_value_to_json(other);
Ok(Some(format!("{json}")))
}
}
}
pub fn command_names(&self) -> Vec<String> {
self.commands
.lock()
.unwrap()
.iter()
.map(|c| c.name.clone())
.collect()
}
pub fn command_summaries(&self) -> Vec<(String, String)> {
self.commands
.lock()
.unwrap()
.iter()
.map(|c| (c.name.clone(), c.description.clone()))
.collect()
}
pub fn has_command(&self, name: &str) -> bool {
self.commands.lock().unwrap().iter().any(|c| c.name == name)
}
}