Skip to main content

imp_lua/
sandbox.rs

1use std::collections::{HashMap, HashSet};
2use std::path::PathBuf;
3use std::sync::{
4    atomic::{AtomicBool, Ordering},
5    Arc, Mutex,
6};
7
8use imp_core::config::{AgentMode, Config, LuaCapabilityPolicy};
9use imp_core::tools::{FileCache, FileTracker, Tool, ToolContext, ToolUpdate};
10use imp_core::ui::UserInterface;
11use mlua::Lua;
12use thiserror::Error;
13
14#[derive(Debug, Error)]
15pub enum LuaError {
16    #[error("Lua error: {0}")]
17    Mlua(#[from] mlua::Error),
18
19    #[error("Extension error: {0}")]
20    Extension(String),
21}
22
23/// Handle to a Lua-registered tool.
24pub struct LuaToolHandle {
25    pub name: String,
26    pub label: String,
27    pub description: String,
28    pub readonly: bool,
29    pub params: serde_json::Value,
30    /// Registry key for the execute function stored in Lua.
31    pub execute_key: mlua::RegistryKey,
32}
33
34/// Handle to a Lua-registered hook.
35pub struct LuaHookHandle {
36    pub event: String,
37    /// Registry key for the handler function stored in Lua.
38    pub handler_key: mlua::RegistryKey,
39}
40
41/// Handle to a Lua-registered command.
42pub struct LuaCommandHandle {
43    pub name: String,
44    pub description: String,
45    pub handler_key: mlua::RegistryKey,
46}
47
48/// Context passed to Lua host API functions during tool execution.
49///
50/// Mirrors `ToolContext` but is stored separately so the Lua
51/// `imp.tool()` callback can construct a fresh `ToolContext` for
52/// each native tool call.
53pub struct LuaCallContext {
54    pub cwd: PathBuf,
55    pub cancelled: Arc<std::sync::atomic::AtomicBool>,
56    pub update_tx: tokio::sync::mpsc::Sender<ToolUpdate>,
57    pub command_tx: tokio::sync::mpsc::Sender<imp_core::agent::AgentCommand>,
58    pub ui: Arc<dyn UserInterface>,
59    pub file_cache: Arc<FileCache>,
60    pub checkpoint_state: Arc<imp_core::tools::CheckpointState>,
61    pub file_tracker: Arc<std::sync::Mutex<FileTracker>>,
62    pub anchor_store: Arc<imp_core::tools::AnchorStore>,
63    pub lua_tool_loader: Option<imp_core::tools::LuaToolLoader>,
64    pub mode: AgentMode,
65    pub read_max_lines: usize,
66    pub run_policy: imp_core::policy::RunPolicy,
67    pub config: Arc<Config>,
68}
69
70impl LuaCallContext {
71    /// Build a `ToolContext` from the stored fields.
72    pub fn to_tool_context(&self) -> ToolContext {
73        ToolContext {
74            cwd: self.cwd.clone(),
75            cancelled: Arc::clone(&self.cancelled),
76            update_tx: self.update_tx.clone(),
77            command_tx: self.command_tx.clone(),
78            ui: Arc::clone(&self.ui),
79            file_cache: Arc::clone(&self.file_cache),
80            checkpoint_state: Arc::clone(&self.checkpoint_state),
81            file_tracker: Arc::clone(&self.file_tracker),
82            anchor_store: Arc::clone(&self.anchor_store),
83            lua_tool_loader: self.lua_tool_loader.clone(),
84            mode: self.mode,
85            read_max_lines: self.read_max_lines,
86            turn_mana_review: Arc::new(std::sync::Mutex::new(
87                imp_core::mana_review::TurnManaReviewAccumulator::default(),
88            )),
89            run_policy: self.run_policy.clone(),
90            config: Arc::clone(&self.config),
91            supporting_provenance: Vec::new(),
92        }
93    }
94}
95
96impl From<ToolContext> for LuaCallContext {
97    fn from(ctx: ToolContext) -> Self {
98        Self {
99            cwd: ctx.cwd,
100            cancelled: ctx.cancelled,
101            update_tx: ctx.update_tx,
102            command_tx: ctx.command_tx,
103            ui: ctx.ui,
104            file_cache: ctx.file_cache,
105            checkpoint_state: ctx.checkpoint_state,
106            file_tracker: ctx.file_tracker,
107            anchor_store: ctx.anchor_store,
108            lua_tool_loader: ctx.lua_tool_loader,
109            mode: ctx.mode,
110            read_max_lines: ctx.read_max_lines,
111            run_policy: ctx.run_policy,
112            config: ctx.config,
113        }
114    }
115}
116
117/// Manages the Lua state for extensions.
118pub struct LuaRuntime {
119    lua: Lua,
120    tools: Arc<Mutex<Vec<LuaToolHandle>>>,
121    hooks: Arc<Mutex<Vec<LuaHookHandle>>>,
122    commands: Arc<Mutex<Vec<LuaCommandHandle>>>,
123    /// Native imp tools available via `imp.tool()` from Lua.
124    native_tools: Arc<Mutex<HashMap<String, Arc<dyn Tool>>>>,
125    /// Active execution context for `imp.tool()` calls.
126    call_context: Arc<Mutex<Option<LuaCallContext>>>,
127    /// Env vars this extension is allowed to read via `imp.env()`.
128    allowed_env: Arc<Mutex<HashSet<String>>>,
129    /// Whether `imp.tool()` calls are currently permitted.
130    allow_native_tool_calls: Arc<AtomicBool>,
131    /// Whether `imp.exec()` shell execution is permitted.
132    allow_shell_exec: Arc<AtomicBool>,
133    /// Whether `imp.http.*` calls are permitted.
134    allow_http: Arc<AtomicBool>,
135    /// Whether secret access is permitted.
136    allow_secrets: Arc<AtomicBool>,
137}
138
139impl LuaRuntime {
140    /// Create a new Lua runtime with standard libraries.
141    pub fn new() -> Result<Self, LuaError> {
142        let lua = Lua::new();
143        Ok(Self {
144            lua,
145            tools: Arc::new(Mutex::new(Vec::new())),
146            hooks: Arc::new(Mutex::new(Vec::new())),
147            commands: Arc::new(Mutex::new(Vec::new())),
148            native_tools: Arc::new(Mutex::new(HashMap::new())),
149            call_context: Arc::new(Mutex::new(None)),
150            allowed_env: Arc::new(Mutex::new(HashSet::new())),
151            allow_native_tool_calls: Arc::new(AtomicBool::new(true)),
152            allow_shell_exec: Arc::new(AtomicBool::new(false)),
153            allow_http: Arc::new(AtomicBool::new(false)),
154            allow_secrets: Arc::new(AtomicBool::new(false)),
155        })
156    }
157
158    /// Get a reference to the underlying Lua state.
159    pub fn lua(&self) -> &Lua {
160        &self.lua
161    }
162
163    /// Get a clone of the tools handle for external access.
164    pub fn tools(&self) -> Arc<Mutex<Vec<LuaToolHandle>>> {
165        Arc::clone(&self.tools)
166    }
167
168    /// Get a clone of the hooks handle for external access.
169    pub fn hooks(&self) -> Arc<Mutex<Vec<LuaHookHandle>>> {
170        Arc::clone(&self.hooks)
171    }
172
173    /// Get a clone of the commands handle for external access.
174    pub fn commands(&self) -> Arc<Mutex<Vec<LuaCommandHandle>>> {
175        Arc::clone(&self.commands)
176    }
177
178    /// Get a clone of the native tools map.
179    pub fn native_tools(&self) -> Arc<Mutex<HashMap<String, Arc<dyn Tool>>>> {
180        Arc::clone(&self.native_tools)
181    }
182
183    /// Get a clone of the call context handle.
184    pub fn call_context(&self) -> Arc<Mutex<Option<LuaCallContext>>> {
185        Arc::clone(&self.call_context)
186    }
187
188    /// Get a clone of the allowed-env handle.
189    pub fn allowed_env(&self) -> Arc<Mutex<HashSet<String>>> {
190        Arc::clone(&self.allowed_env)
191    }
192
193    /// Get whether `imp.exec()` calls are currently permitted.
194    pub fn allow_shell_exec(&self) -> Arc<AtomicBool> {
195        Arc::clone(&self.allow_shell_exec)
196    }
197
198    /// Get whether `imp.http.*` calls are currently permitted.
199    pub fn allow_http(&self) -> Arc<AtomicBool> {
200        Arc::clone(&self.allow_http)
201    }
202
203    /// Get whether secret access is currently permitted.
204    pub fn allow_secrets(&self) -> Arc<AtomicBool> {
205        Arc::clone(&self.allow_secrets)
206    }
207
208    /// Get whether `imp.tool()` calls are currently permitted.
209    pub fn allow_native_tool_calls(&self) -> Arc<AtomicBool> {
210        Arc::clone(&self.allow_native_tool_calls)
211    }
212
213    /// Populate the native tool registry (called once after tools are registered).
214    pub fn set_native_tools(&self, tools: HashMap<String, Arc<dyn Tool>>) {
215        *self.native_tools.lock().unwrap() = tools;
216    }
217
218    /// Set the call context before executing a Lua tool function.
219    pub fn set_call_context(&self, ctx: LuaCallContext) {
220        *self.call_context.lock().unwrap() = Some(ctx);
221    }
222
223    /// Clear the call context after execution.
224    pub fn clear_call_context(&self) {
225        *self.call_context.lock().unwrap() = None;
226    }
227
228    /// Set the allowed env vars for this extension.
229    pub fn set_allowed_env(&self, vars: HashSet<String>) {
230        *self.allowed_env.lock().unwrap() = vars;
231    }
232
233    /// Set whether `imp.exec()` calls are permitted for the current runtime.
234    pub fn set_allow_shell_exec(&self, allowed: bool) {
235        self.allow_shell_exec.store(allowed, Ordering::Relaxed);
236    }
237
238    /// Set whether `imp.http.*` calls are permitted for the current runtime.
239    pub fn set_allow_http(&self, allowed: bool) {
240        self.allow_http.store(allowed, Ordering::Relaxed);
241    }
242
243    /// Set whether secret access is permitted for the current runtime.
244    pub fn set_allow_secrets(&self, allowed: bool) {
245        self.allow_secrets.store(allowed, Ordering::Relaxed);
246    }
247
248    /// Set whether `imp.tool()` calls are permitted for the current runtime.
249    pub fn set_allow_native_tool_calls(&self, allowed: bool) {
250        self.allow_native_tool_calls
251            .store(allowed, Ordering::Relaxed);
252    }
253
254    /// Apply a shipped-runtime capability policy.
255    pub fn apply_capability_policy(&self, policy: &LuaCapabilityPolicy) {
256        self.set_allow_native_tool_calls(policy.allow_native_tool_calls);
257        self.set_allow_shell_exec(policy.allow_shell_exec);
258        self.set_allow_http(policy.allow_http);
259        self.set_allow_secrets(policy.allow_secrets);
260        self.set_allowed_env(policy.allowed_env.clone());
261    }
262
263    /// Register a tool handle (called from bridge).
264    pub fn register_tool(&self, handle: LuaToolHandle) {
265        self.tools.lock().unwrap().push(handle);
266    }
267
268    /// Register a hook handle (called from bridge).
269    pub fn register_hook(&self, handle: LuaHookHandle) {
270        self.hooks.lock().unwrap().push(handle);
271    }
272
273    /// Register a command handle (called from bridge).
274    pub fn register_command(&self, handle: LuaCommandHandle) {
275        self.commands.lock().unwrap().push(handle);
276    }
277
278    /// Execute a Lua script string.
279    pub fn exec(&self, source: &str) -> Result<(), LuaError> {
280        self.lua.load(source).exec()?;
281        Ok(())
282    }
283
284    /// Execute a Lua file.
285    pub fn exec_file(&self, path: &std::path::Path) -> Result<(), LuaError> {
286        let source = std::fs::read_to_string(path)
287            .map_err(|e| LuaError::Extension(format!("{}: {}", path.display(), e)))?;
288        self.lua
289            .load(&source)
290            .set_name(path.to_string_lossy())
291            .exec()?;
292        Ok(())
293    }
294
295    /// Clear all registered tools, hooks, and commands.
296    pub fn clear_registrations(&self) {
297        self.tools.lock().unwrap().clear();
298        self.hooks.lock().unwrap().clear();
299        self.commands.lock().unwrap().clear();
300    }
301
302    /// Number of registered tools.
303    pub fn tool_count(&self) -> usize {
304        self.tools.lock().unwrap().len()
305    }
306
307    /// Number of registered hooks.
308    pub fn hook_count(&self) -> usize {
309        self.hooks.lock().unwrap().len()
310    }
311
312    /// Number of registered commands.
313    pub fn command_count(&self) -> usize {
314        self.commands.lock().unwrap().len()
315    }
316
317    /// Get tool names.
318    pub fn tool_names(&self) -> Vec<String> {
319        self.tools
320            .lock()
321            .unwrap()
322            .iter()
323            .map(|t| t.name.clone())
324            .collect()
325    }
326
327    /// Get hook event names.
328    pub fn hook_events(&self) -> Vec<String> {
329        self.hooks
330            .lock()
331            .unwrap()
332            .iter()
333            .map(|h| h.event.clone())
334            .collect()
335    }
336
337    /// Execute a registered command by name, returning its string output.
338    ///
339    /// Returns `Ok(None)` if the command returned nil (silent success).
340    /// Returns `Ok(Some(text))` if the command returned a string or value.
341    /// Returns `Err` if the command handler or name wasn't found.
342    pub fn execute_command(&self, name: &str, args: &str) -> Result<Option<String>, LuaError> {
343        self.execute_command_with_context(name, args, None)
344    }
345
346    /// Execute a registered command with an optional host call context.
347    pub fn execute_command_with_context(
348        &self,
349        name: &str,
350        args: &str,
351        call_ctx: Option<LuaCallContext>,
352    ) -> Result<Option<String>, LuaError> {
353        if let Some(ctx) = call_ctx {
354            self.set_call_context(ctx);
355        }
356        let result = self.execute_command_inner(name, args);
357        self.clear_call_context();
358        result
359    }
360
361    fn execute_command_inner(&self, name: &str, args: &str) -> Result<Option<String>, LuaError> {
362        let commands = self.commands.lock().unwrap();
363        let handle = commands
364            .iter()
365            .find(|c| c.name == name)
366            .ok_or_else(|| LuaError::Extension(format!("command '{name}' not found")))?;
367
368        let handler: mlua::Function = self
369            .lua
370            .registry_value(&handle.handler_key)
371            .map_err(LuaError::Mlua)?;
372
373        let result: mlua::Value = handler.call(args.to_string()).map_err(LuaError::Mlua)?;
374
375        match result {
376            mlua::Value::Nil => Ok(None),
377            mlua::Value::String(s) => Ok(Some(
378                s.to_str()
379                    .map(|v| v.to_string())
380                    .unwrap_or_else(|_| "(non-utf8)".into()),
381            )),
382            other => {
383                let json = crate::bridge::lua_value_to_json(other);
384                Ok(Some(format!("{json}")))
385            }
386        }
387    }
388
389    /// Get command names.
390    pub fn command_names(&self) -> Vec<String> {
391        self.commands
392            .lock()
393            .unwrap()
394            .iter()
395            .map(|c| c.name.clone())
396            .collect()
397    }
398
399    /// Get command names with descriptions for menus and discovery.
400    pub fn command_summaries(&self) -> Vec<(String, String)> {
401        self.commands
402            .lock()
403            .unwrap()
404            .iter()
405            .map(|c| (c.name.clone(), c.description.clone()))
406            .collect()
407    }
408
409    /// Check if a command with the given name exists.
410    pub fn has_command(&self, name: &str) -> bool {
411        self.commands.lock().unwrap().iter().any(|c| c.name == name)
412    }
413}