1use std::collections::{HashMap, HashSet};
2use std::path::PathBuf;
3use std::sync::{
4 atomic::{AtomicBool, Ordering},
5 Arc, Mutex,
6};
7
8use imp_core::config::{AgentMode, Config, LuaCapabilityPolicy};
9use imp_core::tools::{FileCache, FileTracker, Tool, ToolContext, ToolUpdate};
10use imp_core::ui::UserInterface;
11use mlua::Lua;
12use thiserror::Error;
13
/// Errors produced by the Lua extension runtime.
#[derive(Debug, Error)]
pub enum LuaError {
    /// An error returned by `mlua` while loading, converting, or running Lua code.
    #[error("Lua error: {0}")]
    Mlua(#[from] mlua::Error),

    /// An error raised by the host side of the runtime, e.g. a command name
    /// that is not registered or a script file that could not be read.
    #[error("Extension error: {0}")]
    Extension(String),
}
22
/// A tool registered by a Lua extension.
pub struct LuaToolHandle {
    /// Tool name, as reported by `LuaRuntime::tool_names`.
    pub name: String,
    /// Display label for the tool.
    pub label: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// NOTE(review): presumably marks the tool as not modifying state — confirm
    /// against the bridge code that sets it.
    pub readonly: bool,
    /// Parameter description as JSON. NOTE(review): presumably a JSON Schema
    /// object — confirm with the registration code.
    pub params: serde_json::Value,
    /// Registry key of the value that executes the tool; presumably a Lua
    /// function — TODO confirm with the caller that invokes it.
    pub execute_key: mlua::RegistryKey,
}
33
/// A hook registered by a Lua extension.
pub struct LuaHookHandle {
    /// Event name the hook is attached to, as reported by `LuaRuntime::hook_events`.
    pub event: String,
    /// Registry key of the hook's handler; presumably a Lua function — TODO
    /// confirm with the code that dispatches hooks.
    pub handler_key: mlua::RegistryKey,
}
40
/// A command registered by a Lua extension.
pub struct LuaCommandHandle {
    /// Command name; `LuaRuntime::execute_command` matches it exactly.
    pub name: String,
    /// Short description, returned by `LuaRuntime::command_summaries`.
    pub description: String,
    /// Registry key of the Lua function that `LuaRuntime::execute_command`
    /// calls with the command's argument string.
    pub handler_key: mlua::RegistryKey,
}
47
48pub struct LuaCallContext {
54 pub cwd: PathBuf,
55 pub cancelled: Arc<std::sync::atomic::AtomicBool>,
56 pub update_tx: tokio::sync::mpsc::Sender<ToolUpdate>,
57 pub command_tx: tokio::sync::mpsc::Sender<imp_core::agent::AgentCommand>,
58 pub ui: Arc<dyn UserInterface>,
59 pub file_cache: Arc<FileCache>,
60 pub checkpoint_state: Arc<imp_core::tools::CheckpointState>,
61 pub file_tracker: Arc<std::sync::Mutex<FileTracker>>,
62 pub anchor_store: Arc<imp_core::tools::AnchorStore>,
63 pub lua_tool_loader: Option<imp_core::tools::LuaToolLoader>,
64 pub mode: AgentMode,
65 pub read_max_lines: usize,
66 pub run_policy: imp_core::policy::RunPolicy,
67 pub config: Arc<Config>,
68}
69
70impl LuaCallContext {
71 pub fn to_tool_context(&self) -> ToolContext {
73 ToolContext {
74 cwd: self.cwd.clone(),
75 cancelled: Arc::clone(&self.cancelled),
76 update_tx: self.update_tx.clone(),
77 command_tx: self.command_tx.clone(),
78 ui: Arc::clone(&self.ui),
79 file_cache: Arc::clone(&self.file_cache),
80 checkpoint_state: Arc::clone(&self.checkpoint_state),
81 file_tracker: Arc::clone(&self.file_tracker),
82 anchor_store: Arc::clone(&self.anchor_store),
83 lua_tool_loader: self.lua_tool_loader.clone(),
84 mode: self.mode,
85 read_max_lines: self.read_max_lines,
86 turn_mana_review: Arc::new(std::sync::Mutex::new(
87 imp_core::mana_review::TurnManaReviewAccumulator::default(),
88 )),
89 run_policy: self.run_policy.clone(),
90 config: Arc::clone(&self.config),
91 supporting_provenance: Vec::new(),
92 }
93 }
94}
95
96impl From<ToolContext> for LuaCallContext {
97 fn from(ctx: ToolContext) -> Self {
98 Self {
99 cwd: ctx.cwd,
100 cancelled: ctx.cancelled,
101 update_tx: ctx.update_tx,
102 command_tx: ctx.command_tx,
103 ui: ctx.ui,
104 file_cache: ctx.file_cache,
105 checkpoint_state: ctx.checkpoint_state,
106 file_tracker: ctx.file_tracker,
107 anchor_store: ctx.anchor_store,
108 lua_tool_loader: ctx.lua_tool_loader,
109 mode: ctx.mode,
110 read_max_lines: ctx.read_max_lines,
111 run_policy: ctx.run_policy,
112 config: ctx.config,
113 }
114 }
115}
116
/// Host-side state for a Lua extension environment.
///
/// This holds the following:
/// - the Lua interpreter;
/// - the tools, hooks, and commands registered into it;
/// - the native tool map;
/// - the current call context;
/// - capability switches, which are set from a `LuaCapabilityPolicy` by
///   `apply_capability_policy`.
///
/// Collections live behind `Arc<Mutex<_>>` and flags behind `Arc<AtomicBool>`.
/// The getters hand out `Arc` clones, so other code can hold its own handle to
/// the same shared state.
pub struct LuaRuntime {
    /// The Lua interpreter state.
    lua: Lua,
    /// Tools added with `register_tool`, in registration order.
    tools: Arc<Mutex<Vec<LuaToolHandle>>>,
    /// Hooks added with `register_hook`, in registration order.
    hooks: Arc<Mutex<Vec<LuaHookHandle>>>,
    /// Commands added with `register_command`, in registration order.
    commands: Arc<Mutex<Vec<LuaCommandHandle>>>,
    /// Native (Rust) tools keyed by name, replaced wholesale by `set_native_tools`.
    native_tools: Arc<Mutex<HashMap<String, Arc<dyn Tool>>>>,
    /// Context for the call in progress, if any.
    call_context: Arc<Mutex<Option<LuaCallContext>>>,
    /// Allow-listed environment variable names. NOTE(review): enforcement is
    /// not in this file — presumably done by the bridge; confirm.
    allowed_env: Arc<Mutex<HashSet<String>>>,
    /// Whether native tool calls are permitted (`true` after `new`).
    allow_native_tool_calls: Arc<AtomicBool>,
    /// Whether shell execution is permitted (`false` after `new`).
    allow_shell_exec: Arc<AtomicBool>,
    /// Whether HTTP access is permitted (`false` after `new`).
    allow_http: Arc<AtomicBool>,
    /// Whether secret access is permitted (`false` after `new`).
    allow_secrets: Arc<AtomicBool>,
}
138
139impl LuaRuntime {
140 pub fn new() -> Result<Self, LuaError> {
142 let lua = Lua::new();
143 Ok(Self {
144 lua,
145 tools: Arc::new(Mutex::new(Vec::new())),
146 hooks: Arc::new(Mutex::new(Vec::new())),
147 commands: Arc::new(Mutex::new(Vec::new())),
148 native_tools: Arc::new(Mutex::new(HashMap::new())),
149 call_context: Arc::new(Mutex::new(None)),
150 allowed_env: Arc::new(Mutex::new(HashSet::new())),
151 allow_native_tool_calls: Arc::new(AtomicBool::new(true)),
152 allow_shell_exec: Arc::new(AtomicBool::new(false)),
153 allow_http: Arc::new(AtomicBool::new(false)),
154 allow_secrets: Arc::new(AtomicBool::new(false)),
155 })
156 }
157
158 pub fn lua(&self) -> &Lua {
160 &self.lua
161 }
162
163 pub fn tools(&self) -> Arc<Mutex<Vec<LuaToolHandle>>> {
165 Arc::clone(&self.tools)
166 }
167
168 pub fn hooks(&self) -> Arc<Mutex<Vec<LuaHookHandle>>> {
170 Arc::clone(&self.hooks)
171 }
172
173 pub fn commands(&self) -> Arc<Mutex<Vec<LuaCommandHandle>>> {
175 Arc::clone(&self.commands)
176 }
177
178 pub fn native_tools(&self) -> Arc<Mutex<HashMap<String, Arc<dyn Tool>>>> {
180 Arc::clone(&self.native_tools)
181 }
182
183 pub fn call_context(&self) -> Arc<Mutex<Option<LuaCallContext>>> {
185 Arc::clone(&self.call_context)
186 }
187
188 pub fn allowed_env(&self) -> Arc<Mutex<HashSet<String>>> {
190 Arc::clone(&self.allowed_env)
191 }
192
193 pub fn allow_shell_exec(&self) -> Arc<AtomicBool> {
195 Arc::clone(&self.allow_shell_exec)
196 }
197
198 pub fn allow_http(&self) -> Arc<AtomicBool> {
200 Arc::clone(&self.allow_http)
201 }
202
203 pub fn allow_secrets(&self) -> Arc<AtomicBool> {
205 Arc::clone(&self.allow_secrets)
206 }
207
208 pub fn allow_native_tool_calls(&self) -> Arc<AtomicBool> {
210 Arc::clone(&self.allow_native_tool_calls)
211 }
212
213 pub fn set_native_tools(&self, tools: HashMap<String, Arc<dyn Tool>>) {
215 *self.native_tools.lock().unwrap() = tools;
216 }
217
218 pub fn set_call_context(&self, ctx: LuaCallContext) {
220 *self.call_context.lock().unwrap() = Some(ctx);
221 }
222
223 pub fn clear_call_context(&self) {
225 *self.call_context.lock().unwrap() = None;
226 }
227
228 pub fn set_allowed_env(&self, vars: HashSet<String>) {
230 *self.allowed_env.lock().unwrap() = vars;
231 }
232
233 pub fn set_allow_shell_exec(&self, allowed: bool) {
235 self.allow_shell_exec.store(allowed, Ordering::Relaxed);
236 }
237
238 pub fn set_allow_http(&self, allowed: bool) {
240 self.allow_http.store(allowed, Ordering::Relaxed);
241 }
242
243 pub fn set_allow_secrets(&self, allowed: bool) {
245 self.allow_secrets.store(allowed, Ordering::Relaxed);
246 }
247
248 pub fn set_allow_native_tool_calls(&self, allowed: bool) {
250 self.allow_native_tool_calls
251 .store(allowed, Ordering::Relaxed);
252 }
253
254 pub fn apply_capability_policy(&self, policy: &LuaCapabilityPolicy) {
256 self.set_allow_native_tool_calls(policy.allow_native_tool_calls);
257 self.set_allow_shell_exec(policy.allow_shell_exec);
258 self.set_allow_http(policy.allow_http);
259 self.set_allow_secrets(policy.allow_secrets);
260 self.set_allowed_env(policy.allowed_env.clone());
261 }
262
263 pub fn register_tool(&self, handle: LuaToolHandle) {
265 self.tools.lock().unwrap().push(handle);
266 }
267
268 pub fn register_hook(&self, handle: LuaHookHandle) {
270 self.hooks.lock().unwrap().push(handle);
271 }
272
273 pub fn register_command(&self, handle: LuaCommandHandle) {
275 self.commands.lock().unwrap().push(handle);
276 }
277
278 pub fn exec(&self, source: &str) -> Result<(), LuaError> {
280 self.lua.load(source).exec()?;
281 Ok(())
282 }
283
284 pub fn exec_file(&self, path: &std::path::Path) -> Result<(), LuaError> {
286 let source = std::fs::read_to_string(path)
287 .map_err(|e| LuaError::Extension(format!("{}: {}", path.display(), e)))?;
288 self.lua
289 .load(&source)
290 .set_name(path.to_string_lossy())
291 .exec()?;
292 Ok(())
293 }
294
295 pub fn clear_registrations(&self) {
297 self.tools.lock().unwrap().clear();
298 self.hooks.lock().unwrap().clear();
299 self.commands.lock().unwrap().clear();
300 }
301
302 pub fn tool_count(&self) -> usize {
304 self.tools.lock().unwrap().len()
305 }
306
307 pub fn hook_count(&self) -> usize {
309 self.hooks.lock().unwrap().len()
310 }
311
312 pub fn command_count(&self) -> usize {
314 self.commands.lock().unwrap().len()
315 }
316
317 pub fn tool_names(&self) -> Vec<String> {
319 self.tools
320 .lock()
321 .unwrap()
322 .iter()
323 .map(|t| t.name.clone())
324 .collect()
325 }
326
327 pub fn hook_events(&self) -> Vec<String> {
329 self.hooks
330 .lock()
331 .unwrap()
332 .iter()
333 .map(|h| h.event.clone())
334 .collect()
335 }
336
337 pub fn execute_command(&self, name: &str, args: &str) -> Result<Option<String>, LuaError> {
343 self.execute_command_with_context(name, args, None)
344 }
345
346 pub fn execute_command_with_context(
348 &self,
349 name: &str,
350 args: &str,
351 call_ctx: Option<LuaCallContext>,
352 ) -> Result<Option<String>, LuaError> {
353 if let Some(ctx) = call_ctx {
354 self.set_call_context(ctx);
355 }
356 let result = self.execute_command_inner(name, args);
357 self.clear_call_context();
358 result
359 }
360
361 fn execute_command_inner(&self, name: &str, args: &str) -> Result<Option<String>, LuaError> {
362 let commands = self.commands.lock().unwrap();
363 let handle = commands
364 .iter()
365 .find(|c| c.name == name)
366 .ok_or_else(|| LuaError::Extension(format!("command '{name}' not found")))?;
367
368 let handler: mlua::Function = self
369 .lua
370 .registry_value(&handle.handler_key)
371 .map_err(LuaError::Mlua)?;
372
373 let result: mlua::Value = handler.call(args.to_string()).map_err(LuaError::Mlua)?;
374
375 match result {
376 mlua::Value::Nil => Ok(None),
377 mlua::Value::String(s) => Ok(Some(
378 s.to_str()
379 .map(|v| v.to_string())
380 .unwrap_or_else(|_| "(non-utf8)".into()),
381 )),
382 other => {
383 let json = crate::bridge::lua_value_to_json(other);
384 Ok(Some(format!("{json}")))
385 }
386 }
387 }
388
389 pub fn command_names(&self) -> Vec<String> {
391 self.commands
392 .lock()
393 .unwrap()
394 .iter()
395 .map(|c| c.name.clone())
396 .collect()
397 }
398
399 pub fn command_summaries(&self) -> Vec<(String, String)> {
401 self.commands
402 .lock()
403 .unwrap()
404 .iter()
405 .map(|c| (c.name.clone(), c.description.clone()))
406 .collect()
407 }
408
409 pub fn has_command(&self, name: &str) -> bool {
411 self.commands.lock().unwrap().iter().any(|c| c.name == name)
412 }
413}