1use std::collections::{HashMap, HashSet};
2use std::path::PathBuf;
3use std::sync::{
4 atomic::{AtomicBool, Ordering},
5 Arc, Mutex,
6};
7
8use imp_core::config::{AgentMode, Config, LuaCapabilityPolicy};
9use imp_core::tools::{FileCache, FileTracker, Tool, ToolContext, ToolUpdate};
10use imp_core::ui::UserInterface;
11use mlua::Lua;
12use thiserror::Error;
13
14#[derive(Debug, Error)]
15pub enum LuaError {
16 #[error("Lua error: {0}")]
17 Mlua(#[from] mlua::Error),
18
19 #[error("Extension error: {0}")]
20 Extension(String),
21}
22
23pub struct LuaToolHandle {
25 pub name: String,
26 pub label: String,
27 pub description: String,
28 pub readonly: bool,
29 pub params: serde_json::Value,
30 pub execute_key: mlua::RegistryKey,
32}
33
34pub struct LuaHookHandle {
36 pub event: String,
37 pub handler_key: mlua::RegistryKey,
39}
40
41pub struct LuaCommandHandle {
43 pub name: String,
44 pub description: String,
45 pub handler_key: mlua::RegistryKey,
46}
47
48pub struct LuaCallContext {
54 pub cwd: PathBuf,
55 pub cancelled: Arc<std::sync::atomic::AtomicBool>,
56 pub update_tx: tokio::sync::mpsc::Sender<ToolUpdate>,
57 pub command_tx: tokio::sync::mpsc::Sender<imp_core::agent::AgentCommand>,
58 pub ui: Arc<dyn UserInterface>,
59 pub file_cache: Arc<FileCache>,
60 pub checkpoint_state: Arc<imp_core::tools::CheckpointState>,
61 pub file_tracker: Arc<std::sync::Mutex<FileTracker>>,
62 pub anchor_store: Arc<imp_core::tools::AnchorStore>,
63 pub lua_tool_loader: Option<imp_core::tools::LuaToolLoader>,
64 pub mode: AgentMode,
65 pub read_max_lines: usize,
66 pub config: Arc<Config>,
67}
68
69impl LuaCallContext {
70 pub fn to_tool_context(&self) -> ToolContext {
72 ToolContext {
73 cwd: self.cwd.clone(),
74 cancelled: Arc::clone(&self.cancelled),
75 update_tx: self.update_tx.clone(),
76 command_tx: self.command_tx.clone(),
77 ui: Arc::clone(&self.ui),
78 file_cache: Arc::clone(&self.file_cache),
79 checkpoint_state: Arc::clone(&self.checkpoint_state),
80 file_tracker: Arc::clone(&self.file_tracker),
81 anchor_store: Arc::clone(&self.anchor_store),
82 lua_tool_loader: self.lua_tool_loader.clone(),
83 mode: self.mode,
84 read_max_lines: self.read_max_lines,
85 turn_mana_review: Arc::new(std::sync::Mutex::new(
86 imp_core::mana_review::TurnManaReviewAccumulator::default(),
87 )),
88 config: Arc::clone(&self.config),
89 }
90 }
91}
92
93pub struct LuaRuntime {
95 lua: Lua,
96 tools: Arc<Mutex<Vec<LuaToolHandle>>>,
97 hooks: Arc<Mutex<Vec<LuaHookHandle>>>,
98 commands: Arc<Mutex<Vec<LuaCommandHandle>>>,
99 native_tools: Arc<Mutex<HashMap<String, Arc<dyn Tool>>>>,
101 call_context: Arc<Mutex<Option<LuaCallContext>>>,
103 allowed_env: Arc<Mutex<HashSet<String>>>,
105 allow_native_tool_calls: Arc<AtomicBool>,
107 allow_shell_exec: Arc<AtomicBool>,
109 allow_http: Arc<AtomicBool>,
111 allow_secrets: Arc<AtomicBool>,
113}
114
115impl LuaRuntime {
116 pub fn new() -> Result<Self, LuaError> {
118 let lua = Lua::new();
119 Ok(Self {
120 lua,
121 tools: Arc::new(Mutex::new(Vec::new())),
122 hooks: Arc::new(Mutex::new(Vec::new())),
123 commands: Arc::new(Mutex::new(Vec::new())),
124 native_tools: Arc::new(Mutex::new(HashMap::new())),
125 call_context: Arc::new(Mutex::new(None)),
126 allowed_env: Arc::new(Mutex::new(HashSet::new())),
127 allow_native_tool_calls: Arc::new(AtomicBool::new(true)),
128 allow_shell_exec: Arc::new(AtomicBool::new(false)),
129 allow_http: Arc::new(AtomicBool::new(false)),
130 allow_secrets: Arc::new(AtomicBool::new(false)),
131 })
132 }
133
134 pub fn lua(&self) -> &Lua {
136 &self.lua
137 }
138
139 pub fn tools(&self) -> Arc<Mutex<Vec<LuaToolHandle>>> {
141 Arc::clone(&self.tools)
142 }
143
144 pub fn hooks(&self) -> Arc<Mutex<Vec<LuaHookHandle>>> {
146 Arc::clone(&self.hooks)
147 }
148
149 pub fn commands(&self) -> Arc<Mutex<Vec<LuaCommandHandle>>> {
151 Arc::clone(&self.commands)
152 }
153
154 pub fn native_tools(&self) -> Arc<Mutex<HashMap<String, Arc<dyn Tool>>>> {
156 Arc::clone(&self.native_tools)
157 }
158
159 pub fn call_context(&self) -> Arc<Mutex<Option<LuaCallContext>>> {
161 Arc::clone(&self.call_context)
162 }
163
164 pub fn allowed_env(&self) -> Arc<Mutex<HashSet<String>>> {
166 Arc::clone(&self.allowed_env)
167 }
168
169 pub fn allow_shell_exec(&self) -> Arc<AtomicBool> {
171 Arc::clone(&self.allow_shell_exec)
172 }
173
174 pub fn allow_http(&self) -> Arc<AtomicBool> {
176 Arc::clone(&self.allow_http)
177 }
178
179 pub fn allow_secrets(&self) -> Arc<AtomicBool> {
181 Arc::clone(&self.allow_secrets)
182 }
183
184 pub fn allow_native_tool_calls(&self) -> Arc<AtomicBool> {
186 Arc::clone(&self.allow_native_tool_calls)
187 }
188
189 pub fn set_native_tools(&self, tools: HashMap<String, Arc<dyn Tool>>) {
191 *self.native_tools.lock().unwrap() = tools;
192 }
193
194 pub fn set_call_context(&self, ctx: LuaCallContext) {
196 *self.call_context.lock().unwrap() = Some(ctx);
197 }
198
199 pub fn clear_call_context(&self) {
201 *self.call_context.lock().unwrap() = None;
202 }
203
204 pub fn set_allowed_env(&self, vars: HashSet<String>) {
206 *self.allowed_env.lock().unwrap() = vars;
207 }
208
209 pub fn set_allow_shell_exec(&self, allowed: bool) {
211 self.allow_shell_exec.store(allowed, Ordering::Relaxed);
212 }
213
214 pub fn set_allow_http(&self, allowed: bool) {
216 self.allow_http.store(allowed, Ordering::Relaxed);
217 }
218
219 pub fn set_allow_secrets(&self, allowed: bool) {
221 self.allow_secrets.store(allowed, Ordering::Relaxed);
222 }
223
224 pub fn set_allow_native_tool_calls(&self, allowed: bool) {
226 self.allow_native_tool_calls
227 .store(allowed, Ordering::Relaxed);
228 }
229
230 pub fn apply_capability_policy(&self, policy: &LuaCapabilityPolicy) {
232 self.set_allow_native_tool_calls(policy.allow_native_tool_calls);
233 self.set_allow_shell_exec(policy.allow_shell_exec);
234 self.set_allow_http(policy.allow_http);
235 self.set_allow_secrets(policy.allow_secrets);
236 self.set_allowed_env(policy.allowed_env.clone());
237 }
238
239 pub fn register_tool(&self, handle: LuaToolHandle) {
241 self.tools.lock().unwrap().push(handle);
242 }
243
244 pub fn register_hook(&self, handle: LuaHookHandle) {
246 self.hooks.lock().unwrap().push(handle);
247 }
248
249 pub fn register_command(&self, handle: LuaCommandHandle) {
251 self.commands.lock().unwrap().push(handle);
252 }
253
254 pub fn exec(&self, source: &str) -> Result<(), LuaError> {
256 self.lua.load(source).exec()?;
257 Ok(())
258 }
259
260 pub fn exec_file(&self, path: &std::path::Path) -> Result<(), LuaError> {
262 let source = std::fs::read_to_string(path)
263 .map_err(|e| LuaError::Extension(format!("{}: {}", path.display(), e)))?;
264 self.lua
265 .load(&source)
266 .set_name(path.to_string_lossy())
267 .exec()?;
268 Ok(())
269 }
270
271 pub fn clear_registrations(&self) {
273 self.tools.lock().unwrap().clear();
274 self.hooks.lock().unwrap().clear();
275 self.commands.lock().unwrap().clear();
276 }
277
278 pub fn tool_count(&self) -> usize {
280 self.tools.lock().unwrap().len()
281 }
282
283 pub fn hook_count(&self) -> usize {
285 self.hooks.lock().unwrap().len()
286 }
287
288 pub fn command_count(&self) -> usize {
290 self.commands.lock().unwrap().len()
291 }
292
293 pub fn tool_names(&self) -> Vec<String> {
295 self.tools
296 .lock()
297 .unwrap()
298 .iter()
299 .map(|t| t.name.clone())
300 .collect()
301 }
302
303 pub fn hook_events(&self) -> Vec<String> {
305 self.hooks
306 .lock()
307 .unwrap()
308 .iter()
309 .map(|h| h.event.clone())
310 .collect()
311 }
312
313 pub fn execute_command(&self, name: &str, args: &str) -> Result<Option<String>, LuaError> {
319 let commands = self.commands.lock().unwrap();
320 let handle = commands
321 .iter()
322 .find(|c| c.name == name)
323 .ok_or_else(|| LuaError::Extension(format!("command '{name}' not found")))?;
324
325 let handler: mlua::Function = self
326 .lua
327 .registry_value(&handle.handler_key)
328 .map_err(LuaError::Mlua)?;
329
330 let result: mlua::Value = handler.call(args.to_string()).map_err(LuaError::Mlua)?;
331
332 match result {
333 mlua::Value::Nil => Ok(None),
334 mlua::Value::String(s) => Ok(Some(
335 s.to_str()
336 .map(|v| v.to_string())
337 .unwrap_or_else(|_| "(non-utf8)".into()),
338 )),
339 other => {
340 let json = crate::bridge::lua_value_to_json(other);
341 Ok(Some(format!("{json}")))
342 }
343 }
344 }
345
346 pub fn command_names(&self) -> Vec<String> {
348 self.commands
349 .lock()
350 .unwrap()
351 .iter()
352 .map(|c| c.name.clone())
353 .collect()
354 }
355
356 pub fn has_command(&self, name: &str) -> bool {
358 self.commands.lock().unwrap().iter().any(|c| c.name == name)
359 }
360}