1use std::process::{Command, Stdio};
2
3use async_trait::async_trait;
4use imp_core::storage;
5use imp_core::tools::lua::{parameter_schema_from_lua, tool_output_from_lua_result};
6use imp_core::tools::{Tool, ToolContext, ToolOutput, ToolRegistry};
7use imp_core::Error as CoreError;
8use imp_llm::auth::AuthStore;
9use mlua::{Function, Lua, MultiValue, Table, Value};
10use serde_json::json;
11use std::sync::{Arc, Mutex};
12
13use crate::sandbox::{
14 LuaCallContext, LuaCommandHandle, LuaError, LuaHookHandle, LuaRuntime, LuaToolHandle,
15};
16
17pub struct LuaTool {
20 name: String,
21 label: String,
22 description: String,
23 readonly: bool,
24 params: serde_json::Value,
25 runtime: Arc<Mutex<LuaRuntime>>,
26 handle_index: usize,
27}
28
29#[async_trait]
30impl Tool for LuaTool {
31 fn name(&self) -> &str {
32 &self.name
33 }
34
35 fn label(&self) -> &str {
36 &self.label
37 }
38
39 fn description(&self) -> &str {
40 &self.description
41 }
42
43 fn parameters(&self) -> serde_json::Value {
44 parameter_schema_from_lua(&self.params)
45 }
46
47 fn is_readonly(&self) -> bool {
48 self.readonly
49 }
50
51 async fn execute(
52 &self,
53 call_id: &str,
54 params: serde_json::Value,
55 ctx: ToolContext,
56 ) -> imp_core::Result<ToolOutput> {
57 let runtime = Arc::clone(&self.runtime);
58 let handle_index = self.handle_index;
59 let call_id = call_id.to_string();
60 let ctx_json = json!({
61 "cwd": ctx.cwd.display().to_string(),
62 "cancelled": ctx.is_cancelled(),
63 });
64 let call_ctx = LuaCallContext {
65 cwd: ctx.cwd,
66 cancelled: ctx.cancelled,
67 update_tx: ctx.update_tx,
68 command_tx: ctx.command_tx,
69 ui: ctx.ui,
70 file_cache: ctx.file_cache,
71 checkpoint_state: ctx.checkpoint_state,
72 file_tracker: ctx.file_tracker,
73 anchor_store: ctx.anchor_store,
74 lua_tool_loader: ctx.lua_tool_loader,
75 mode: ctx.mode,
76 read_max_lines: ctx.read_max_lines,
77 config: ctx.config,
78 };
79
80 tokio::task::spawn_blocking(move || {
81 let runtime_guard = runtime
82 .lock()
83 .map_err(|_| CoreError::Tool("Lua runtime lock poisoned".into()))?;
84
85 runtime_guard.set_call_context(call_ctx);
87
88 let result = (|| {
89 let tools = runtime_guard.tools();
90 let handles = tools
91 .lock()
92 .map_err(|_| CoreError::Tool("Lua tool registry lock poisoned".into()))?;
93 let handle = handles.get(handle_index).ok_or_else(|| {
94 CoreError::Tool(format!("Lua tool handle {handle_index} not found"))
95 })?;
96
97 let execute_fn: Function = runtime_guard
98 .lua()
99 .registry_value(&handle.execute_key)
100 .map_err(lua_tool_error)?;
101 let lua_params =
102 json_to_lua_value(runtime_guard.lua(), ¶ms).map_err(lua_tool_error)?;
103 let lua_ctx =
104 json_to_lua_value(runtime_guard.lua(), &ctx_json).map_err(lua_tool_error)?;
105 let result: Value = execute_fn
106 .call((call_id.as_str(), lua_params, lua_ctx))
107 .map_err(lua_tool_error)?;
108
109 tool_output_from_lua_result(lua_value_to_json(result))
110 })();
111
112 runtime_guard.clear_call_context();
113 result
114 })
115 .await
116 .map_err(|error| CoreError::Tool(format!("Lua tool task failed: {error}")))?
117 }
118}
119
120pub fn load_lua_tools(runtime: Arc<Mutex<LuaRuntime>>, registry: &mut ToolRegistry) {
122 let handles = {
123 let runtime_guard = runtime
124 .lock()
125 .expect("Lua runtime lock poisoned while loading tools");
126 let tools = runtime_guard.tools();
127 let handles = tools
128 .lock()
129 .expect("Lua tool registry lock poisoned while loading tools");
130
131 handles
132 .iter()
133 .enumerate()
134 .map(|(index, handle)| LuaTool {
135 name: handle.name.clone(),
136 label: handle.label.clone(),
137 description: handle.description.clone(),
138 readonly: handle.readonly,
139 params: handle.params.clone(),
140 runtime: Arc::clone(&runtime),
141 handle_index: index,
142 })
143 .collect::<Vec<_>>()
144 };
145
146 for tool in handles {
147 registry.register(Arc::new(tool));
148 }
149}
150
151fn lua_tool_error(error: mlua::Error) -> CoreError {
152 CoreError::Tool(format!("Lua tool error: {error}"))
153}
154
155fn extract_header_pairs(headers: Option<Table>) -> mlua::Result<Vec<(String, String)>> {
157 let mut pairs = Vec::new();
158 if let Some(tbl) = headers {
159 for pair in tbl.pairs::<String, String>() {
160 let (k, v) = pair?;
161 pairs.push((k, v));
162 }
163 }
164 Ok(pairs)
165}
166
167pub fn setup_host_api(runtime: &LuaRuntime) -> Result<(), LuaError> {
182 let lua = runtime.lua();
183
184 let imp = lua.create_table()?;
185
186 let hooks = runtime.hooks();
188 let on_fn = lua.create_function(move |lua_inner, (event, handler): (String, Function)| {
189 let key = lua_inner.create_registry_value(handler)?;
190 let handle = LuaHookHandle {
191 event,
192 handler_key: key,
193 };
194 hooks.lock().unwrap().push(handle);
195 Ok(())
196 })?;
197 imp.set("on", on_fn)?;
198
199 let tools = runtime.tools();
201 let register_tool_fn = lua.create_function(move |lua_inner, def: Table| {
202 let name: String = def.get("name")?;
203 let label: String = def
204 .get::<Option<String>>("label")?
205 .unwrap_or_else(|| name.clone());
206 let description: String = def
207 .get::<Option<String>>("description")?
208 .unwrap_or_default();
209 let readonly: bool = def.get::<Option<bool>>("readonly")?.unwrap_or(false);
210
211 let params_val: Value = def.get("params")?;
212 let params = lua_value_to_json(params_val);
213
214 let execute_fn: Function = def.get("execute")?;
215 let key = lua_inner.create_registry_value(execute_fn)?;
216
217 let handle = LuaToolHandle {
218 name,
219 label,
220 description,
221 readonly,
222 params,
223 execute_key: key,
224 };
225 tools.lock().unwrap().push(handle);
226 Ok(())
227 })?;
228 imp.set("register_tool", register_tool_fn)?;
229
230 let allow_shell_exec = runtime.allow_shell_exec();
232 let exec_fn = lua.create_function(
233 move |lua_inner, (cmd, args, opts): (String, Option<Table>, Option<Table>)| {
234 if !allow_shell_exec.load(std::sync::atomic::Ordering::Relaxed) {
235 return Err(mlua::Error::external(
236 "imp.exec() is disabled for this runtime",
237 ));
238 }
239 let mut command = Command::new("sh");
240 command.arg("-c");
241
242 let full_cmd = if let Some(args_table) = args {
244 let mut parts = vec![cmd];
245 for pair in args_table.sequence_values::<String>() {
246 parts.push(pair?);
247 }
248 parts.join(" ")
249 } else {
250 cmd
251 };
252 command.stdin(Stdio::null()).arg(&full_cmd);
253
254 if let Some(opts_table) = &opts {
256 if let Ok(Some(cwd)) = opts_table.get::<Option<String>>("cwd") {
257 command.current_dir(cwd);
258 }
259 if let Ok(Some(env_table)) = opts_table.get::<Option<Table>>("env") {
260 for pair in env_table.pairs::<String, String>() {
261 let (name, value) = pair?;
262 command.env(name, value);
263 }
264 }
265 }
266
267 let output = command.output().map_err(mlua::Error::external)?;
268
269 let result = lua_inner.create_table()?;
270 result.set(
271 "stdout",
272 String::from_utf8_lossy(&output.stdout).to_string(),
273 )?;
274 result.set(
275 "stderr",
276 String::from_utf8_lossy(&output.stderr).to_string(),
277 )?;
278 result.set("exit_code", output.status.code().unwrap_or(-1))?;
279
280 Ok(result)
281 },
282 )?;
283 imp.set("exec", exec_fn)?;
284
285 let commands = runtime.commands();
287 let register_command_fn =
288 lua.create_function(move |lua_inner, (name, def): (String, Table)| {
289 let description: String = def
290 .get::<Option<String>>("description")?
291 .unwrap_or_default();
292 let handler: Function = def.get("handler")?;
293 let key = lua_inner.create_registry_value(handler)?;
294
295 let handle = LuaCommandHandle {
296 name,
297 description,
298 handler_key: key,
299 };
300 commands.lock().unwrap().push(handle);
301 Ok(())
302 })?;
303 imp.set("register_command", register_command_fn)?;
304
305 let events = lua.create_table()?;
307
308 let handlers_table = lua.create_table()?;
310 lua.set_named_registry_value("__imp_event_handlers", handlers_table)?;
311
312 let events_on = lua.create_function(|lua_inner, (name, handler): (String, Function)| {
313 let handlers: Table = lua_inner.named_registry_value("__imp_event_handlers")?;
314 let list: Table = match handlers.get::<Option<Table>>(name.as_str())? {
315 Some(t) => t,
316 None => {
317 let t = lua_inner.create_table()?;
318 handlers.set(name.as_str(), t.clone())?;
319 t
320 }
321 };
322 let len = list.raw_len();
323 list.set(len + 1, handler)?;
324 Ok(())
325 })?;
326 events.set("on", events_on)?;
327
328 let events_emit = lua.create_function(|lua_inner, (name, data): (String, Value)| {
329 let handlers: Table = lua_inner.named_registry_value("__imp_event_handlers")?;
330 if let Some(list) = handlers.get::<Option<Table>>(name.as_str())? {
331 for pair in list.sequence_values::<Function>() {
332 let handler = pair?;
333 let _ = handler.call::<()>(data.clone());
336 }
337 }
338 Ok(())
339 })?;
340 events.set("emit", events_emit)?;
341
342 imp.set("events", events)?;
343
344 let native_tools = runtime.native_tools();
346 let tool_call_ctx = runtime.call_context();
347 let allow_native_tool_calls = runtime.allow_native_tool_calls();
348 let imp_tool_fn = lua.create_function(
349 move |lua_inner, (name, params): (String, Value)| -> mlua::Result<MultiValue> {
350 if !allow_native_tool_calls.load(std::sync::atomic::Ordering::Relaxed) {
351 return Err(mlua::Error::external(
352 "imp.tool() is disabled for this runtime",
353 ));
354 }
355
356 let tool = {
358 let tools_guard = native_tools
359 .lock()
360 .map_err(|_| mlua::Error::external("native tools lock poisoned"))?;
361 tools_guard
362 .get(&name)
363 .cloned()
364 .ok_or_else(|| mlua::Error::external(format!("tool '{name}' not found")))?
365 };
366
367 let ctx = {
369 let ctx_guard = tool_call_ctx
370 .lock()
371 .map_err(|_| mlua::Error::external("call context lock poisoned"))?;
372 ctx_guard
373 .as_ref()
374 .ok_or_else(|| {
375 mlua::Error::external("imp.tool() called outside of tool execution context")
376 })?
377 .to_tool_context()
378 };
379
380 let params_json = lua_value_to_json(params);
381
382 let handle = tokio::runtime::Handle::try_current()
384 .map_err(|_| mlua::Error::external("imp.tool() requires a tokio runtime"))?;
385
386 let output = handle
387 .block_on(tool.execute("lua-call", params_json, ctx))
388 .map_err(|e| mlua::Error::external(format!("tool error: {e}")))?;
389
390 let mut mv = MultiValue::new();
392 if output.is_error {
393 let err_text = output
394 .text_content()
395 .unwrap_or("tool execution failed")
396 .to_string();
397 mv.push_back(Value::Nil);
398 mv.push_back(Value::String(lua_inner.create_string(&err_text)?));
399 } else if let Some(text) = output.text_content() {
400 mv.push_back(Value::String(lua_inner.create_string(text)?));
401 } else {
402 mv.push_back(Value::Nil);
403 }
404 Ok(mv)
405 },
406 )?;
407 imp.set("tool", imp_tool_fn)?;
408
409 let update_call_ctx = runtime.call_context();
411 let imp_update_fn = lua.create_function(move |_lua, text: String| {
412 let ctx_guard = update_call_ctx
413 .lock()
414 .map_err(|_| mlua::Error::external("call context lock poisoned"))?;
415 if let Some(ref ctx) = *ctx_guard {
416 let _ = ctx.update_tx.try_send(imp_core::tools::ToolUpdate {
417 content: vec![imp_core::imp_llm::ContentBlock::Text { text }],
418 details: serde_json::Value::Null,
419 });
420 }
421 Ok(())
422 })?;
423 imp.set("update", imp_update_fn)?;
424
425 let allow_secrets = runtime.allow_secrets();
427 let secret_fn = lua.create_function(
428 move |lua_inner, (provider, field): (String, Option<String>)| -> mlua::Result<Value> {
429 if !allow_secrets.load(std::sync::atomic::Ordering::Relaxed) {
430 return Err(mlua::Error::external(
431 "imp.secret() is disabled for this runtime",
432 ));
433 }
434 let auth_path =
435 storage::existing_global_auth_path().unwrap_or_else(storage::global_auth_path);
436 let auth_store =
437 AuthStore::load(&auth_path).unwrap_or_else(|_| AuthStore::new(auth_path.clone()));
438 let field = field.unwrap_or_else(|| "api_key".to_string());
439 match auth_store.resolve_secret_field(&provider, &field) {
440 Ok(value) => Ok(Value::String(lua_inner.create_string(&value)?)),
441 Err(error) => Err(mlua::Error::external(error.to_string())),
442 }
443 },
444 )?;
445 imp.set("secret", secret_fn)?;
446
447 let allow_secrets = runtime.allow_secrets();
449 let secret_fields_fn =
450 lua.create_function(move |lua_inner, provider: String| -> mlua::Result<Value> {
451 if !allow_secrets.load(std::sync::atomic::Ordering::Relaxed) {
452 return Err(mlua::Error::external(
453 "imp.secret_fields() is disabled for this runtime",
454 ));
455 }
456 let auth_path =
457 storage::existing_global_auth_path().unwrap_or_else(storage::global_auth_path);
458 let auth_store =
459 AuthStore::load(&auth_path).unwrap_or_else(|_| AuthStore::new(auth_path.clone()));
460 match auth_store.resolve_secret_fields(&provider) {
461 Ok(fields) => {
462 let table = lua_inner.create_table()?;
463 for (field, value) in fields {
464 table.set(field, value)?;
465 }
466 Ok(Value::Table(table))
467 }
468 Err(error) => Err(mlua::Error::external(error.to_string())),
469 }
470 })?;
471 imp.set("secret_fields", secret_fields_fn)?;
472
473 let allowed_env = runtime.allowed_env();
475 let env_fn = lua.create_function(move |lua_inner, name: String| {
476 let allowed = allowed_env
477 .lock()
478 .map_err(|_| mlua::Error::external("allowed_env lock poisoned"))?;
479 if !allowed.contains(&name) {
481 return Ok(Value::Nil);
482 }
483 match std::env::var(&name) {
484 Ok(val) => Ok(Value::String(lua_inner.create_string(&val)?)),
485 Err(_) => Ok(Value::Nil),
486 }
487 })?;
488 imp.set("env", env_fn)?;
489
490 let http = lua.create_table()?;
492 let allow_http = runtime.allow_http();
493
494 let http_get_fn =
495 lua.create_function(move |lua_inner, (url, headers): (String, Option<Table>)| {
496 if !allow_http.load(std::sync::atomic::Ordering::Relaxed) {
497 return Err(mlua::Error::external(
498 "imp.http.get() is disabled for this runtime",
499 ));
500 }
501 let header_pairs = extract_header_pairs(headers)?;
502
503 let handle = tokio::runtime::Handle::try_current()
504 .map_err(|_| mlua::Error::external("imp.http requires a tokio runtime"))?;
505
506 let (status, body) = handle
507 .block_on(async {
508 let client = reqwest::Client::new();
509 let mut builder = client.get(&url);
510 for (k, v) in &header_pairs {
511 builder = builder.header(k.as_str(), v.as_str());
512 }
513 let resp = builder.send().await.map_err(|e| e.to_string())?;
514 let status = resp.status().as_u16();
515 let body = resp.text().await.map_err(|e| e.to_string())?;
516 Ok::<_, String>((status, body))
517 })
518 .map_err(mlua::Error::external)?;
519
520 let result = lua_inner.create_table()?;
521 result.set("status", status)?;
522 result.set("body", body)?;
523 Ok(result)
524 })?;
525 http.set("get", http_get_fn)?;
526
527 let allow_http = runtime.allow_http();
528 let http_post_fn = lua.create_function(
529 move |lua_inner, (url, body, headers): (String, String, Option<Table>)| {
530 if !allow_http.load(std::sync::atomic::Ordering::Relaxed) {
531 return Err(mlua::Error::external(
532 "imp.http.post() is disabled for this runtime",
533 ));
534 }
535 let header_pairs = extract_header_pairs(headers)?;
536
537 let handle = tokio::runtime::Handle::try_current()
538 .map_err(|_| mlua::Error::external("imp.http requires a tokio runtime"))?;
539
540 let (status, resp_body) = handle
541 .block_on(async {
542 let client = reqwest::Client::new();
543 let mut builder = client.post(&url).body(body);
544 for (k, v) in &header_pairs {
545 builder = builder.header(k.as_str(), v.as_str());
546 }
547 let resp = builder.send().await.map_err(|e| e.to_string())?;
548 let status = resp.status().as_u16();
549 let resp_body = resp.text().await.map_err(|e| e.to_string())?;
550 Ok::<_, String>((status, resp_body))
551 })
552 .map_err(mlua::Error::external)?;
553
554 let result = lua_inner.create_table()?;
555 result.set("status", status)?;
556 result.set("body", resp_body)?;
557 Ok(result)
558 },
559 )?;
560 http.set("post", http_post_fn)?;
561
562 imp.set("http", http)?;
563
564 lua.globals().set("imp", imp)?;
566
567 Ok(())
568}
569
570pub fn lua_value_to_json(value: Value) -> serde_json::Value {
572 match value {
573 Value::Nil => serde_json::Value::Null,
574 Value::Boolean(b) => serde_json::Value::Bool(b),
575 Value::Integer(i) => serde_json::Value::Number(serde_json::Number::from(i)),
576 Value::Number(n) => serde_json::Number::from_f64(n)
577 .map(serde_json::Value::Number)
578 .unwrap_or(serde_json::Value::Null),
579 Value::String(s) => {
580 serde_json::Value::String(s.to_str().map(|s| s.to_string()).unwrap_or_default())
581 }
582 Value::Table(t) => {
583 let len = t.raw_len();
585 if len > 0 {
586 let is_array = (1..=len).all(|i| {
588 t.get::<Value>(i)
589 .ok()
590 .map(|v| !matches!(v, Value::Nil))
591 .unwrap_or(false)
592 });
593 if is_array {
594 let arr: Vec<serde_json::Value> = (1..=len)
595 .filter_map(|i| t.get::<Value>(i).ok().map(lua_value_to_json))
596 .collect();
597 return serde_json::Value::Array(arr);
598 }
599 }
600
601 let mut map = serde_json::Map::new();
603 if let Ok(pairs) = t.pairs::<String, Value>().collect::<Result<Vec<_>, _>>() {
604 for (k, v) in pairs {
605 map.insert(k, lua_value_to_json(v));
606 }
607 }
608 serde_json::Value::Object(map)
609 }
610 _ => serde_json::Value::Null,
611 }
612}
613
614pub fn json_to_lua_value(lua: &Lua, value: &serde_json::Value) -> mlua::Result<Value> {
616 match value {
617 serde_json::Value::Null => Ok(Value::Nil),
618 serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
619 serde_json::Value::Number(n) => {
620 if let Some(i) = n.as_i64() {
621 Ok(Value::Integer(i))
622 } else if let Some(f) = n.as_f64() {
623 Ok(Value::Number(f))
624 } else {
625 Ok(Value::Nil)
626 }
627 }
628 serde_json::Value::String(s) => Ok(Value::String(lua.create_string(s)?)),
629 serde_json::Value::Array(arr) => {
630 let table = lua.create_table()?;
631 for (i, v) in arr.iter().enumerate() {
632 table.set(i + 1, json_to_lua_value(lua, v)?)?;
633 }
634 Ok(Value::Table(table))
635 }
636 serde_json::Value::Object(map) => {
637 let table = lua.create_table()?;
638 for (k, v) in map {
639 table.set(k.as_str(), json_to_lua_value(lua, v)?)?;
640 }
641 Ok(Value::Table(table))
642 }
643 }
644}