Skip to main content

imp_lua/
sandbox.rs

1use std::collections::{HashMap, HashSet};
2use std::path::PathBuf;
3use std::sync::{
4    atomic::{AtomicBool, Ordering},
5    Arc, Mutex,
6};
7
8use imp_core::config::{AgentMode, Config, LuaCapabilityPolicy};
9use imp_core::tools::{FileCache, FileTracker, Tool, ToolContext, ToolUpdate};
10use imp_core::ui::UserInterface;
11use mlua::Lua;
12use thiserror::Error;
13
14#[derive(Debug, Error)]
15pub enum LuaError {
16    #[error("Lua error: {0}")]
17    Mlua(#[from] mlua::Error),
18
19    #[error("Extension error: {0}")]
20    Extension(String),
21}
22
23/// Handle to a Lua-registered tool.
24pub struct LuaToolHandle {
25    pub name: String,
26    pub label: String,
27    pub description: String,
28    pub readonly: bool,
29    pub params: serde_json::Value,
30    /// Registry key for the execute function stored in Lua.
31    pub execute_key: mlua::RegistryKey,
32}
33
34/// Handle to a Lua-registered hook.
35pub struct LuaHookHandle {
36    pub event: String,
37    /// Registry key for the handler function stored in Lua.
38    pub handler_key: mlua::RegistryKey,
39}
40
41/// Handle to a Lua-registered command.
42pub struct LuaCommandHandle {
43    pub name: String,
44    pub description: String,
45    pub handler_key: mlua::RegistryKey,
46}
47
48/// Context passed to Lua host API functions during tool execution.
49///
50/// Mirrors `ToolContext` but is stored separately so the Lua
51/// `imp.tool()` callback can construct a fresh `ToolContext` for
52/// each native tool call.
53pub struct LuaCallContext {
54    pub cwd: PathBuf,
55    pub cancelled: Arc<std::sync::atomic::AtomicBool>,
56    pub update_tx: tokio::sync::mpsc::Sender<ToolUpdate>,
57    pub command_tx: tokio::sync::mpsc::Sender<imp_core::agent::AgentCommand>,
58    pub ui: Arc<dyn UserInterface>,
59    pub file_cache: Arc<FileCache>,
60    pub checkpoint_state: Arc<imp_core::tools::CheckpointState>,
61    pub file_tracker: Arc<std::sync::Mutex<FileTracker>>,
62    pub anchor_store: Arc<imp_core::tools::AnchorStore>,
63    pub lua_tool_loader: Option<imp_core::tools::LuaToolLoader>,
64    pub mode: AgentMode,
65    pub read_max_lines: usize,
66    pub config: Arc<Config>,
67}
68
69impl LuaCallContext {
70    /// Build a `ToolContext` from the stored fields.
71    pub fn to_tool_context(&self) -> ToolContext {
72        ToolContext {
73            cwd: self.cwd.clone(),
74            cancelled: Arc::clone(&self.cancelled),
75            update_tx: self.update_tx.clone(),
76            command_tx: self.command_tx.clone(),
77            ui: Arc::clone(&self.ui),
78            file_cache: Arc::clone(&self.file_cache),
79            checkpoint_state: Arc::clone(&self.checkpoint_state),
80            file_tracker: Arc::clone(&self.file_tracker),
81            anchor_store: Arc::clone(&self.anchor_store),
82            lua_tool_loader: self.lua_tool_loader.clone(),
83            mode: self.mode,
84            read_max_lines: self.read_max_lines,
85            turn_mana_review: Arc::new(std::sync::Mutex::new(
86                imp_core::mana_review::TurnManaReviewAccumulator::default(),
87            )),
88            config: Arc::clone(&self.config),
89        }
90    }
91}
92
93/// Manages the Lua state for extensions.
94pub struct LuaRuntime {
95    lua: Lua,
96    tools: Arc<Mutex<Vec<LuaToolHandle>>>,
97    hooks: Arc<Mutex<Vec<LuaHookHandle>>>,
98    commands: Arc<Mutex<Vec<LuaCommandHandle>>>,
99    /// Native imp tools available via `imp.tool()` from Lua.
100    native_tools: Arc<Mutex<HashMap<String, Arc<dyn Tool>>>>,
101    /// Active execution context for `imp.tool()` calls.
102    call_context: Arc<Mutex<Option<LuaCallContext>>>,
103    /// Env vars this extension is allowed to read via `imp.env()`.
104    allowed_env: Arc<Mutex<HashSet<String>>>,
105    /// Whether `imp.tool()` calls are currently permitted.
106    allow_native_tool_calls: Arc<AtomicBool>,
107    /// Whether `imp.exec()` shell execution is permitted.
108    allow_shell_exec: Arc<AtomicBool>,
109    /// Whether `imp.http.*` calls are permitted.
110    allow_http: Arc<AtomicBool>,
111    /// Whether secret access is permitted.
112    allow_secrets: Arc<AtomicBool>,
113}
114
115impl LuaRuntime {
116    /// Create a new Lua runtime with standard libraries.
117    pub fn new() -> Result<Self, LuaError> {
118        let lua = Lua::new();
119        Ok(Self {
120            lua,
121            tools: Arc::new(Mutex::new(Vec::new())),
122            hooks: Arc::new(Mutex::new(Vec::new())),
123            commands: Arc::new(Mutex::new(Vec::new())),
124            native_tools: Arc::new(Mutex::new(HashMap::new())),
125            call_context: Arc::new(Mutex::new(None)),
126            allowed_env: Arc::new(Mutex::new(HashSet::new())),
127            allow_native_tool_calls: Arc::new(AtomicBool::new(true)),
128            allow_shell_exec: Arc::new(AtomicBool::new(false)),
129            allow_http: Arc::new(AtomicBool::new(false)),
130            allow_secrets: Arc::new(AtomicBool::new(false)),
131        })
132    }
133
134    /// Get a reference to the underlying Lua state.
135    pub fn lua(&self) -> &Lua {
136        &self.lua
137    }
138
139    /// Get a clone of the tools handle for external access.
140    pub fn tools(&self) -> Arc<Mutex<Vec<LuaToolHandle>>> {
141        Arc::clone(&self.tools)
142    }
143
144    /// Get a clone of the hooks handle for external access.
145    pub fn hooks(&self) -> Arc<Mutex<Vec<LuaHookHandle>>> {
146        Arc::clone(&self.hooks)
147    }
148
149    /// Get a clone of the commands handle for external access.
150    pub fn commands(&self) -> Arc<Mutex<Vec<LuaCommandHandle>>> {
151        Arc::clone(&self.commands)
152    }
153
154    /// Get a clone of the native tools map.
155    pub fn native_tools(&self) -> Arc<Mutex<HashMap<String, Arc<dyn Tool>>>> {
156        Arc::clone(&self.native_tools)
157    }
158
159    /// Get a clone of the call context handle.
160    pub fn call_context(&self) -> Arc<Mutex<Option<LuaCallContext>>> {
161        Arc::clone(&self.call_context)
162    }
163
164    /// Get a clone of the allowed-env handle.
165    pub fn allowed_env(&self) -> Arc<Mutex<HashSet<String>>> {
166        Arc::clone(&self.allowed_env)
167    }
168
169    /// Get whether `imp.exec()` calls are currently permitted.
170    pub fn allow_shell_exec(&self) -> Arc<AtomicBool> {
171        Arc::clone(&self.allow_shell_exec)
172    }
173
174    /// Get whether `imp.http.*` calls are currently permitted.
175    pub fn allow_http(&self) -> Arc<AtomicBool> {
176        Arc::clone(&self.allow_http)
177    }
178
179    /// Get whether secret access is currently permitted.
180    pub fn allow_secrets(&self) -> Arc<AtomicBool> {
181        Arc::clone(&self.allow_secrets)
182    }
183
184    /// Get whether `imp.tool()` calls are currently permitted.
185    pub fn allow_native_tool_calls(&self) -> Arc<AtomicBool> {
186        Arc::clone(&self.allow_native_tool_calls)
187    }
188
189    /// Populate the native tool registry (called once after tools are registered).
190    pub fn set_native_tools(&self, tools: HashMap<String, Arc<dyn Tool>>) {
191        *self.native_tools.lock().unwrap() = tools;
192    }
193
194    /// Set the call context before executing a Lua tool function.
195    pub fn set_call_context(&self, ctx: LuaCallContext) {
196        *self.call_context.lock().unwrap() = Some(ctx);
197    }
198
199    /// Clear the call context after execution.
200    pub fn clear_call_context(&self) {
201        *self.call_context.lock().unwrap() = None;
202    }
203
204    /// Set the allowed env vars for this extension.
205    pub fn set_allowed_env(&self, vars: HashSet<String>) {
206        *self.allowed_env.lock().unwrap() = vars;
207    }
208
209    /// Set whether `imp.exec()` calls are permitted for the current runtime.
210    pub fn set_allow_shell_exec(&self, allowed: bool) {
211        self.allow_shell_exec.store(allowed, Ordering::Relaxed);
212    }
213
214    /// Set whether `imp.http.*` calls are permitted for the current runtime.
215    pub fn set_allow_http(&self, allowed: bool) {
216        self.allow_http.store(allowed, Ordering::Relaxed);
217    }
218
219    /// Set whether secret access is permitted for the current runtime.
220    pub fn set_allow_secrets(&self, allowed: bool) {
221        self.allow_secrets.store(allowed, Ordering::Relaxed);
222    }
223
224    /// Set whether `imp.tool()` calls are permitted for the current runtime.
225    pub fn set_allow_native_tool_calls(&self, allowed: bool) {
226        self.allow_native_tool_calls
227            .store(allowed, Ordering::Relaxed);
228    }
229
230    /// Apply a shipped-runtime capability policy.
231    pub fn apply_capability_policy(&self, policy: &LuaCapabilityPolicy) {
232        self.set_allow_native_tool_calls(policy.allow_native_tool_calls);
233        self.set_allow_shell_exec(policy.allow_shell_exec);
234        self.set_allow_http(policy.allow_http);
235        self.set_allow_secrets(policy.allow_secrets);
236        self.set_allowed_env(policy.allowed_env.clone());
237    }
238
239    /// Register a tool handle (called from bridge).
240    pub fn register_tool(&self, handle: LuaToolHandle) {
241        self.tools.lock().unwrap().push(handle);
242    }
243
244    /// Register a hook handle (called from bridge).
245    pub fn register_hook(&self, handle: LuaHookHandle) {
246        self.hooks.lock().unwrap().push(handle);
247    }
248
249    /// Register a command handle (called from bridge).
250    pub fn register_command(&self, handle: LuaCommandHandle) {
251        self.commands.lock().unwrap().push(handle);
252    }
253
254    /// Execute a Lua script string.
255    pub fn exec(&self, source: &str) -> Result<(), LuaError> {
256        self.lua.load(source).exec()?;
257        Ok(())
258    }
259
260    /// Execute a Lua file.
261    pub fn exec_file(&self, path: &std::path::Path) -> Result<(), LuaError> {
262        let source = std::fs::read_to_string(path)
263            .map_err(|e| LuaError::Extension(format!("{}: {}", path.display(), e)))?;
264        self.lua
265            .load(&source)
266            .set_name(path.to_string_lossy())
267            .exec()?;
268        Ok(())
269    }
270
271    /// Clear all registered tools, hooks, and commands.
272    pub fn clear_registrations(&self) {
273        self.tools.lock().unwrap().clear();
274        self.hooks.lock().unwrap().clear();
275        self.commands.lock().unwrap().clear();
276    }
277
278    /// Number of registered tools.
279    pub fn tool_count(&self) -> usize {
280        self.tools.lock().unwrap().len()
281    }
282
283    /// Number of registered hooks.
284    pub fn hook_count(&self) -> usize {
285        self.hooks.lock().unwrap().len()
286    }
287
288    /// Number of registered commands.
289    pub fn command_count(&self) -> usize {
290        self.commands.lock().unwrap().len()
291    }
292
293    /// Get tool names.
294    pub fn tool_names(&self) -> Vec<String> {
295        self.tools
296            .lock()
297            .unwrap()
298            .iter()
299            .map(|t| t.name.clone())
300            .collect()
301    }
302
303    /// Get hook event names.
304    pub fn hook_events(&self) -> Vec<String> {
305        self.hooks
306            .lock()
307            .unwrap()
308            .iter()
309            .map(|h| h.event.clone())
310            .collect()
311    }
312
313    /// Execute a registered command by name, returning its string output.
314    ///
315    /// Returns `Ok(None)` if the command returned nil (silent success).
316    /// Returns `Ok(Some(text))` if the command returned a string or value.
317    /// Returns `Err` if the command handler or name wasn't found.
318    pub fn execute_command(&self, name: &str, args: &str) -> Result<Option<String>, LuaError> {
319        let commands = self.commands.lock().unwrap();
320        let handle = commands
321            .iter()
322            .find(|c| c.name == name)
323            .ok_or_else(|| LuaError::Extension(format!("command '{name}' not found")))?;
324
325        let handler: mlua::Function = self
326            .lua
327            .registry_value(&handle.handler_key)
328            .map_err(LuaError::Mlua)?;
329
330        let result: mlua::Value = handler.call(args.to_string()).map_err(LuaError::Mlua)?;
331
332        match result {
333            mlua::Value::Nil => Ok(None),
334            mlua::Value::String(s) => Ok(Some(
335                s.to_str()
336                    .map(|v| v.to_string())
337                    .unwrap_or_else(|_| "(non-utf8)".into()),
338            )),
339            other => {
340                let json = crate::bridge::lua_value_to_json(other);
341                Ok(Some(format!("{json}")))
342            }
343        }
344    }
345
346    /// Get command names.
347    pub fn command_names(&self) -> Vec<String> {
348        self.commands
349            .lock()
350            .unwrap()
351            .iter()
352            .map(|c| c.name.clone())
353            .collect()
354    }
355
356    /// Check if a command with the given name exists.
357    pub fn has_command(&self, name: &str) -> bool {
358        self.commands.lock().unwrap().iter().any(|c| c.name == name)
359    }
360}