1pub mod bridge;
2pub mod loader;
3pub mod sandbox;
4
5use std::path::Path;
6use std::sync::{Arc, Mutex};
7
8use imp_core::config::LuaCapabilityPolicy;
9use imp_core::tools::ToolRegistry;
10
11pub use bridge::{json_to_lua_value, load_lua_tools, lua_value_to_json, setup_host_api, LuaTool};
12pub use loader::{discover_extensions, load_extensions, reload, LuaExtension};
13pub use sandbox::{
14 LuaCallContext, LuaCommandHandle, LuaError, LuaHookHandle, LuaRuntime, LuaToolHandle,
15};
16
17pub fn init_lua_extensions(
23 user_config_dir: &Path,
24 project_dir: Option<&Path>,
25 tools: &mut ToolRegistry,
26 policy: &LuaCapabilityPolicy,
27) -> Option<Arc<Mutex<LuaRuntime>>> {
28 let extensions = discover_extensions(user_config_dir, project_dir);
29 if extensions.is_empty() {
30 return None;
31 }
32
33 let rt = match LuaRuntime::new() {
34 Ok(rt) => rt,
35 Err(_e) => {
36 return None;
37 }
38 };
39 if let Err(_e) = setup_host_api(&rt) {
40 return None;
41 }
42 rt.apply_capability_policy(policy);
43
44 let results = load_extensions(&rt, &extensions);
45 for (_name, result) in &results {
46 if let Err(_e) = result {
47 }
50 }
51
52 rt.set_native_tools(tools.tools_map());
54
55 let runtime = Arc::new(Mutex::new(rt));
56 load_lua_tools(Arc::clone(&runtime), tools);
57 Some(runtime)
58}
59
#[cfg(test)]
mod tests {
    // Tests for extension discovery, the `imp` Lua host API, capability gating,
    // hot reload, and JSON <-> Lua value conversion.

    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::Ordering;
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use imp_core::config::LuaCapabilityPolicy;
    use imp_core::tools::{ToolContext, ToolRegistry};
    use imp_core::ui::NullInterface;
    use tempfile::TempDir;

    /// Baseline capability policy; individual tests flip the flags they need.
    fn make_policy() -> LuaCapabilityPolicy {
        LuaCapabilityPolicy::default()
    }

    /// Fresh runtime with the `imp` host API installed.
    ///
    /// No capability policy is applied. The `*_errors_when_disabled` tests
    /// below therefore depend on the gated APIs being off by default.
    fn make_runtime() -> LuaRuntime {
        let rt = LuaRuntime::new().expect("create runtime");
        setup_host_api(&rt).expect("setup host api");
        rt
    }

    /// Write `content` to `dir/name` and return the resulting path.
    fn write_lua(dir: &std::path::Path, name: &str, content: &str) -> PathBuf {
        let path = dir.join(name);
        std::fs::write(&path, content).unwrap();
        path
    }

    /// Headless `ToolContext` backed by `NullInterface`.
    ///
    /// The channel receivers are dropped immediately, so nothing consumes
    /// updates or commands sent through this context.
    fn test_ctx() -> ToolContext {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        ToolContext {
            cwd: PathBuf::from("/tmp/lua-tools"),
            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: Arc::new(NullInterface),
            file_cache: Arc::new(imp_core::tools::FileCache::new()),
            checkpoint_state: Arc::new(imp_core::tools::CheckpointState::new()),
            file_tracker: Arc::new(std::sync::Mutex::new(
                imp_core::tools::FileTracker::default(),
            )),
            anchor_store: Arc::new(imp_core::tools::AnchorStore::new()),
            mode: imp_core::config::AgentMode::Full,
            read_max_lines: 0,
            turn_mana_review: Arc::new(std::sync::Mutex::new(
                imp_core::mana_review::TurnManaReviewAccumulator::default(),
            )),
            run_policy: Default::default(),
            config: Arc::new(imp_core::config::Config::default()),
            lua_tool_loader: None,
            supporting_provenance: Vec::new(),
        }
    }

    #[test]
    fn init_lua_extensions_applies_capability_policy() {
        let user = TempDir::new().unwrap();
        let lua_dir = user.path().join("lua");
        std::fs::create_dir_all(&lua_dir).unwrap();
        write_lua(&lua_dir, "ext.lua", "-- extension present");

        let mut registry = ToolRegistry::new();
        let mut policy = make_policy();
        policy.allow_native_tool_calls = false;
        policy.allow_shell_exec = true;
        policy.allow_http = true;
        policy.allow_secrets = true;
        policy.allowed_env.insert("ALLOWED_ONE".to_string());

        let runtime = init_lua_extensions(user.path(), None, &mut registry, &policy)
            .expect("runtime should initialize");
        let guard = runtime.lock().unwrap();

        assert!(!guard.allow_native_tool_calls().load(Ordering::Relaxed));
        assert!(guard.allow_shell_exec().load(Ordering::Relaxed));
        assert!(guard.allow_http().load(Ordering::Relaxed));
        assert!(guard.allow_secrets().load(Ordering::Relaxed));
        assert!(guard.allowed_env().lock().unwrap().contains("ALLOWED_ONE"));
    }

    #[test]
    fn discover_user_lua_files() {
        let tmp = TempDir::new().unwrap();
        let lua_dir = tmp.path().join("lua");
        std::fs::create_dir_all(&lua_dir).unwrap();
        write_lua(&lua_dir, "greet.lua", "-- hello");
        write_lua(&lua_dir, "utils.lua", "-- utils");

        let exts = discover_extensions(tmp.path(), None);
        assert_eq!(exts.len(), 2);

        let names: Vec<&str> = exts.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"greet"));
        assert!(names.contains(&"utils"));
    }

    #[test]
    fn discover_directory_init_lua() {
        let tmp = TempDir::new().unwrap();
        let ext_dir = tmp.path().join("lua").join("my-ext");
        std::fs::create_dir_all(&ext_dir).unwrap();
        write_lua(&ext_dir, "init.lua", "-- init");

        let exts = discover_extensions(tmp.path(), None);
        assert_eq!(exts.len(), 1);
        assert_eq!(exts[0].name, "my-ext");
        assert!(exts[0].path.ends_with("init.lua"));
    }

    #[test]
    fn discover_project_local() {
        let user = TempDir::new().unwrap();
        let project = TempDir::new().unwrap();
        let proj_lua = project.path().join(".imp").join("lua");
        std::fs::create_dir_all(&proj_lua).unwrap();
        write_lua(&proj_lua, "local.lua", "-- local");

        let exts = discover_extensions(user.path(), Some(project.path()));
        assert_eq!(exts.len(), 1);
        assert_eq!(exts[0].name, "local");
    }

    #[test]
    fn discover_empty_dirs_return_nothing() {
        let tmp = TempDir::new().unwrap();
        let exts = discover_extensions(tmp.path(), None);
        assert!(exts.is_empty());
    }

    #[test]
    fn on_registers_hook() {
        let rt = make_runtime();
        rt.exec(
            r#"
            imp.on("on_session_start", function(event, ctx)
                -- handler
            end)
            "#,
        )
        .unwrap();

        assert_eq!(rt.hook_count(), 1);
        let events = rt.hook_events();
        assert_eq!(events[0], "on_session_start");
    }

    #[test]
    fn on_registers_multiple_hooks() {
        let rt = make_runtime();
        rt.exec(
            r#"
            imp.on("on_session_start", function() end)
            imp.on("after_file_write", function() end)
            imp.on("before_tool_call", function() end)
            "#,
        )
        .unwrap();

        assert_eq!(rt.hook_count(), 3);
        let events = rt.hook_events();
        assert!(events.contains(&"on_session_start".to_string()));
        assert!(events.contains(&"after_file_write".to_string()));
        assert!(events.contains(&"before_tool_call".to_string()));
    }

    /// The handler stored for a hook can be fetched from the Lua registry and
    /// invoked. This calls the handler directly; it does not exercise
    /// event dispatch.
    #[test]
    fn hook_handler_fires_on_correct_event() {
        let rt = make_runtime();
        rt.exec(
            r#"
            _test_fired = false
            imp.on("on_session_start", function()
                _test_fired = true
            end)
            "#,
        )
        .unwrap();

        let hooks = rt.hooks();
        let hooks_guard = hooks.lock().unwrap();
        assert_eq!(hooks_guard.len(), 1);

        let handler: mlua::Function = rt
            .lua()
            .registry_value(&hooks_guard[0].handler_key)
            .unwrap();
        handler.call::<()>(()).unwrap();

        let fired: bool = rt.lua().globals().get("_test_fired").unwrap();
        assert!(fired);
    }

    #[test]
    fn register_tool_creates_handle() {
        let rt = make_runtime();
        rt.exec(
            r#"
            imp.register_tool({
                name = "greet",
                label = "Greeting Tool",
                description = "Says hello",
                readonly = true,
                params = {
                    type = "object",
                    properties = {
                        name = { type = "string", description = "Who to greet" }
                    }
                },
                execute = function(call_id, params, ctx)
                    return { content = "Hello, " .. (params.name or "world") }
                end
            })
            "#,
        )
        .unwrap();

        assert_eq!(rt.tool_count(), 1);
        let names = rt.tool_names();
        assert_eq!(names[0], "greet");
    }

    #[test]
    fn register_tool_execute_callable() {
        let rt = make_runtime();
        rt.exec(
            r#"
            imp.register_tool({
                name = "add",
                execute = function(call_id, params, ctx)
                    return { content = tostring(params.a + params.b), is_error = false }
                end
            })
            "#,
        )
        .unwrap();

        let tools = rt.tools();
        let tools_guard = tools.lock().unwrap();
        let execute_fn: mlua::Function = rt
            .lua()
            .registry_value(&tools_guard[0].execute_key)
            .unwrap();

        let params = rt.lua().create_table().unwrap();
        params.set("a", 3).unwrap();
        params.set("b", 4).unwrap();

        let result: mlua::Table = execute_fn
            .call(("call_1", params, mlua::Value::Nil))
            .unwrap();
        let content: String = result.get("content").unwrap();
        assert_eq!(content, "7");
    }

    /// End-to-end: a Lua tool is bridged into the `ToolRegistry`. The test checks
    /// its metadata, its generated JSON schema (including `required`), and that
    /// execution receives the call id and context fields.
    #[tokio::test]
    async fn load_lua_tools_registers_and_executes_bridge() {
        let rt = make_runtime();
        rt.exec(
            r#"
            imp.register_tool({
                name = "greet",
                label = "Greeting Tool",
                description = "Greets from Lua",
                readonly = true,
                params = {
                    name = { type = "string", description = "Who to greet", required = true },
                    excited = { type = "boolean" }
                },
                execute = function(call_id, params, ctx)
                    local suffix = params.excited and "!" or "."
                    return {
                        content = {
                            { type = "text", text = "hello " .. params.name .. suffix },
                        },
                        details = {
                            call_id = call_id,
                            cwd = ctx.cwd,
                            cancelled = ctx.cancelled,
                        },
                    }
                end
            })
            "#,
        )
        .unwrap();

        let runtime = Arc::new(Mutex::new(rt));
        let mut registry = ToolRegistry::new();
        load_lua_tools(Arc::clone(&runtime), &mut registry);

        let tool = registry
            .get("greet")
            .expect("lua tool should be registered");
        assert_eq!(tool.label(), "Greeting Tool");
        assert_eq!(tool.description(), "Greets from Lua");
        assert!(tool.is_readonly());
        assert_eq!(tool.parameters()["properties"]["name"]["type"], "string");
        assert_eq!(tool.parameters()["required"], serde_json::json!(["name"]));

        let output = tool
            .execute(
                "call_123",
                serde_json::json!({ "name": "Ada", "excited": true }),
                test_ctx(),
            )
            .await
            .unwrap();

        let text = output
            .content
            .iter()
            .find_map(|block| match block {
                imp_core::imp_llm::ContentBlock::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .expect("lua tool should return text");
        assert_eq!(text, "hello Ada!");
        assert_eq!(output.details["call_id"], "call_123");
        assert_eq!(output.details["cwd"], "/tmp/lua-tools");
        assert_eq!(output.details["cancelled"], false);
    }

    #[test]
    fn imp_secret_helpers_exist() {
        let rt = make_runtime();
        rt.exec(
            r#"
            _has_secret = type(imp.secret) == "function"
            _has_secret_fields = type(imp.secret_fields) == "function"
            "#,
        )
        .unwrap();

        let has_secret: bool = rt.lua().globals().get("_has_secret").unwrap();
        let has_secret_fields: bool = rt.lua().globals().get("_has_secret_fields").unwrap();
        assert!(has_secret);
        assert!(has_secret_fields);
    }

    #[test]
    fn exec_runs_command_returns_stdout() {
        let rt = make_runtime();
        rt.set_allow_shell_exec(true);
        rt.exec(
            r#"
            local result = imp.exec("echo hello")
            _test_stdout = result.stdout
            _test_exit = result.exit_code
            "#,
        )
        .unwrap();

        let stdout: String = rt.lua().globals().get("_test_stdout").unwrap();
        let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
        assert_eq!(stdout.trim(), "hello");
        assert_eq!(exit_code, 0);
    }

    #[test]
    fn exec_captures_stderr() {
        let rt = make_runtime();
        rt.set_allow_shell_exec(true);
        rt.exec(
            r#"
            local result = imp.exec("echo error >&2")
            _test_stderr = result.stderr
            "#,
        )
        .unwrap();

        let stderr: String = rt.lua().globals().get("_test_stderr").unwrap();
        assert_eq!(stderr.trim(), "error");
    }

    #[test]
    fn exec_returns_nonzero_exit_code() {
        let rt = make_runtime();
        rt.set_allow_shell_exec(true);
        rt.exec(
            r#"
            local result = imp.exec("exit 42")
            _test_exit = result.exit_code
            "#,
        )
        .unwrap();

        let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
        assert_eq!(exit_code, 42);
    }

    /// With an explicit argument list, each argument reaches the program
    /// as-is. Here "hello world" stays a single `$1`.
    #[test]
    fn exec_with_args_runs_command_without_shell_joining() {
        let rt = make_runtime();
        rt.set_allow_shell_exec(true);
        rt.exec(
            r#"
            local result = imp.exec("sh", { "-lc", "printf '%s' \"$1\"", "--", "hello world" })
            _test_stdout = result.stdout
            _test_exit = result.exit_code
            "#,
        )
        .unwrap();

        let stdout: String = rt.lua().globals().get("_test_stdout").unwrap();
        let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
        assert_eq!(stdout, "hello world");
        assert_eq!(exit_code, 0);
    }

    #[test]
    fn exec_with_cwd() {
        let rt = make_runtime();
        rt.set_allow_shell_exec(true);
        rt.exec(
            r#"
            local result = imp.exec("pwd", nil, { cwd = "/tmp" })
            _test_cwd = result.stdout
            "#,
        )
        .unwrap();

        let cwd: String = rt.lua().globals().get("_test_cwd").unwrap();
        // Checked with `contains` so the assertion still holds where /tmp
        // resolves through a symlink.
        assert!(cwd.trim().contains("tmp"));
    }

    /// With a call context whose UI is `NullInterface`, `imp.ui.confirm`
    /// yields `nil` instead of a boolean.
    ///
    /// The previous version of this test only assigned `nil` in Lua and
    /// compared it to `nil`, so it never called the host API.
    #[test]
    fn null_interface_confirm_returns_nil() {
        let rt = make_runtime();
        rt.set_call_context(make_call_context());
        rt.exec(
            r#"
            _confirm_result = imp.ui.confirm("Proceed?", "NullInterface has no user to ask")
            _is_nil = (_confirm_result == nil)
            "#,
        )
        .unwrap();

        let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
        assert!(is_nil, "NullInterface-backed confirm should yield nil");
    }

    #[test]
    fn imp_ui_confirm_returns_nil_without_call_context() {
        let rt = make_runtime();
        rt.exec(
            r#"
            _confirm_result = imp.ui.confirm("Proceed?", "No UI should be unavailable")
            _is_nil = (_confirm_result == nil)
            "#,
        )
        .unwrap();

        let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
        assert!(is_nil);
    }

    #[test]
    fn imp_ui_request_reports_unavailable_without_call_context() {
        let rt = make_runtime();
        rt.exec(
            r#"
            _ui_result = imp.ui.request({ kind = "confirm", title = "Proceed?" })
            "#,
        )
        .unwrap();

        let result: mlua::Table = rt.lua().globals().get("_ui_result").unwrap();
        let ok: bool = result.get("ok").unwrap();
        let reason: String = result.get("reason").unwrap();
        assert!(!ok);
        assert_eq!(reason, "unavailable");
    }

    #[test]
    fn lua_command_can_call_imp_ui_confirm_safely_headless() {
        let rt = make_runtime();
        rt.exec(
            r#"
            imp.register_command("confirm-test", {
                handler = function(args)
                    local yes = imp.ui.confirm("Proceed?", args)
                    if yes == nil then
                        return "unavailable"
                    end
                    return tostring(yes)
                end
            })
            "#,
        )
        .unwrap();

        let output = rt
            .execute_command_with_context("confirm-test", "headless", Some(test_ctx().into()))
            .unwrap();
        assert_eq!(output.as_deref(), Some("unavailable"));
    }

    /// `reload` builds a fresh runtime from disk each time. The test rewrites the
    /// extension file and checks that the second runtime reflects the new
    /// contents rather than accumulating the old registrations.
    #[test]
    fn hot_reload_drops_and_recreates() {
        let user_dir = TempDir::new().unwrap();
        let lua_dir = user_dir.path().join("lua");
        std::fs::create_dir_all(&lua_dir).unwrap();

        write_lua(
            &lua_dir,
            "ext.lua",
            r#"
            imp.on("on_session_start", function() end)
            imp.register_tool({ name = "my_tool", execute = function() end })
            "#,
        );

        let policy = make_policy();
        let (rt1, exts1) = reload(user_dir.path(), None, &policy).unwrap();
        assert_eq!(rt1.hook_count(), 1);
        assert_eq!(rt1.tool_count(), 1);
        assert_eq!(exts1.len(), 1);

        write_lua(
            &lua_dir,
            "ext.lua",
            r#"
            imp.on("on_session_start", function() end)
            imp.on("after_file_write", function() end)
            imp.register_tool({ name = "tool_a", execute = function() end })
            imp.register_tool({ name = "tool_b", execute = function() end })
            "#,
        );

        let (rt2, exts2) = reload(user_dir.path(), None, &policy).unwrap();
        assert_eq!(rt2.hook_count(), 2);
        assert_eq!(rt2.tool_count(), 2);
        assert_eq!(exts2.len(), 1);

        let tool_names = rt2.tool_names();
        assert!(tool_names.contains(&"tool_a".to_string()));
        assert!(tool_names.contains(&"tool_b".to_string()));
    }

    #[test]
    fn hot_reload_applies_capability_policy() {
        let user_dir = TempDir::new().unwrap();
        let lua_dir = user_dir.path().join("lua");
        std::fs::create_dir_all(&lua_dir).unwrap();
        write_lua(&lua_dir, "ext.lua", "-- extension present");

        let mut policy = make_policy();
        policy.allow_shell_exec = true;

        let (rt, _exts) = reload(user_dir.path(), None, &policy).unwrap();
        assert!(rt.allow_shell_exec().load(Ordering::Relaxed));
    }

    /// A syntax error is reported, and the runtime remains usable afterwards.
    #[test]
    fn lua_syntax_error_caught() {
        let rt = make_runtime();
        let result = rt.exec("this is not valid lua !!!");
        assert!(result.is_err());
        let result2 = rt.exec("_test_ok = true");
        assert!(result2.is_ok());
        let ok: bool = rt.lua().globals().get("_test_ok").unwrap();
        assert!(ok);
    }

    #[test]
    fn lua_runtime_error_caught() {
        let rt = make_runtime();
        let result = rt.exec("error('intentional error')");
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("intentional error"));
    }

    #[test]
    fn extension_error_doesnt_crash_runtime() {
        let rt = make_runtime();

        let r1 = rt.exec("error('ext1 failed')");
        assert!(r1.is_err());

        let r2 = rt.exec(
            r#"
            imp.on("on_session_start", function() end)
            _ext2_loaded = true
            "#,
        );
        assert!(r2.is_ok());
        assert_eq!(rt.hook_count(), 1);

        let loaded: bool = rt.lua().globals().get("_ext2_loaded").unwrap();
        assert!(loaded);
    }

    #[test]
    fn multiple_extensions_coexist() {
        let rt = make_runtime();

        rt.exec(
            r#"
            imp.on("on_session_start", function()
                _ext1_fired = true
            end)
            imp.register_tool({ name = "ext1_tool", execute = function() end })
            "#,
        )
        .unwrap();

        rt.exec(
            r#"
            imp.on("after_file_write", function()
                _ext2_fired = true
            end)
            imp.register_tool({ name = "ext2_tool", execute = function() end })
            "#,
        )
        .unwrap();

        assert_eq!(rt.hook_count(), 2);
        assert_eq!(rt.tool_count(), 2);

        let names = rt.tool_names();
        assert!(names.contains(&"ext1_tool".to_string()));
        assert!(names.contains(&"ext2_tool".to_string()));

        let events = rt.hook_events();
        assert!(events.contains(&"on_session_start".to_string()));
        assert!(events.contains(&"after_file_write".to_string()));
    }

    /// Chunks executed on the same runtime share one global environment.
    #[test]
    fn extensions_share_state() {
        let rt = make_runtime();

        rt.exec("shared_counter = 1").unwrap();

        rt.exec(
            r#"
            shared_counter = shared_counter + 1
            _final = shared_counter
            "#,
        )
        .unwrap();

        let val: i64 = rt.lua().globals().get("_final").unwrap();
        assert_eq!(val, 2);
    }

    #[test]
    fn events_on_and_emit() {
        let rt = make_runtime();
        rt.exec(
            r#"
            _received = nil
            imp.events.on("custom_event", function(data)
                _received = data
            end)
            imp.events.emit("custom_event", "hello from event")
            "#,
        )
        .unwrap();

        let received: String = rt.lua().globals().get("_received").unwrap();
        assert_eq!(received, "hello from event");
    }

    #[test]
    fn events_multiple_handlers() {
        let rt = make_runtime();
        rt.exec(
            r#"
            _count = 0
            imp.events.on("tick", function() _count = _count + 1 end)
            imp.events.on("tick", function() _count = _count + 1 end)
            imp.events.emit("tick", nil)
            "#,
        )
        .unwrap();

        let count: i64 = rt.lua().globals().get("_count").unwrap();
        assert_eq!(count, 2);
    }

    #[test]
    fn events_handler_error_doesnt_crash() {
        let rt = make_runtime();
        rt.exec(
            r#"
            _after_error = false
            imp.events.on("test", function() error("boom") end)
            imp.events.on("test", function() _after_error = true end)
            imp.events.emit("test", nil)
            "#,
        )
        .unwrap();

        let after: bool = rt.lua().globals().get("_after_error").unwrap();
        assert!(after, "second handler should still fire after first errors");
    }

    #[test]
    fn register_command_creates_handle() {
        let rt = make_runtime();
        rt.exec(
            r#"
            imp.register_command("greet", {
                description = "Say hello",
                handler = function(args, ctx)
                    return "Hello!"
                end
            })
            "#,
        )
        .unwrap();

        assert_eq!(rt.command_count(), 1);
    }

    #[test]
    fn lua_value_to_json_primitives() {
        let rt = make_runtime();
        let lua = rt.lua();

        assert_eq!(lua_value_to_json(mlua::Value::Nil), serde_json::Value::Null);
        assert_eq!(
            lua_value_to_json(mlua::Value::Boolean(true)),
            serde_json::json!(true)
        );
        assert_eq!(
            lua_value_to_json(mlua::Value::Integer(42)),
            serde_json::json!(42)
        );
        assert_eq!(
            lua_value_to_json(mlua::Value::Number(1.25)),
            serde_json::json!(1.25)
        );

        let s = lua.create_string("hello").unwrap();
        assert_eq!(
            lua_value_to_json(mlua::Value::String(s)),
            serde_json::json!("hello")
        );
    }

    #[test]
    fn lua_table_to_json_object() {
        let rt = make_runtime();
        rt.exec(
            r#"
            _test_table = { name = "Alice", age = 30 }
            "#,
        )
        .unwrap();

        let val: mlua::Value = rt.lua().globals().get("_test_table").unwrap();
        let json = lua_value_to_json(val);
        assert_eq!(json["name"], "Alice");
        assert_eq!(json["age"], 30);
    }

    #[test]
    fn lua_array_to_json_array() {
        let rt = make_runtime();
        rt.exec(
            r#"
            _test_arr = { 1, 2, 3 }
            "#,
        )
        .unwrap();

        let val: mlua::Value = rt.lua().globals().get("_test_arr").unwrap();
        let json = lua_value_to_json(val);
        assert_eq!(json, serde_json::json!([1, 2, 3]));
    }

    #[test]
    fn json_to_lua_roundtrip() {
        let rt = make_runtime();
        let lua = rt.lua();

        let original = serde_json::json!({
            "name": "test",
            "count": 42,
            "active": true,
            "tags": ["a", "b"],
            "nested": { "x": 1 }
        });

        let lua_val = json_to_lua_value(lua, &original).unwrap();
        let back = lua_value_to_json(lua_val);
        assert_eq!(back, original);
    }

    #[test]
    fn load_extensions_from_files() {
        let user_dir = TempDir::new().unwrap();
        let lua_dir = user_dir.path().join("lua");
        std::fs::create_dir_all(&lua_dir).unwrap();

        write_lua(
            &lua_dir,
            "a.lua",
            r#"imp.on("on_session_start", function() end)"#,
        );
        write_lua(
            &lua_dir,
            "b.lua",
            r#"imp.register_tool({ name = "b_tool", execute = function() end })"#,
        );

        let exts = discover_extensions(user_dir.path(), None);
        assert_eq!(exts.len(), 2);

        let rt = make_runtime();
        let results = load_extensions(&rt, &exts);

        for (name, result) in &results {
            assert!(result.is_ok(), "Extension {} failed: {:?}", name, result);
        }

        assert_eq!(rt.hook_count(), 1);
        assert_eq!(rt.tool_count(), 1);
    }

    /// One failing extension is reported in the results, and the other
    /// extensions still load.
    #[test]
    fn load_extension_error_reported_not_fatal() {
        let user_dir = TempDir::new().unwrap();
        let lua_dir = user_dir.path().join("lua");
        std::fs::create_dir_all(&lua_dir).unwrap();

        write_lua(&lua_dir, "bad.lua", "error('bad extension')");
        write_lua(
            &lua_dir,
            "good.lua",
            r#"imp.on("on_session_start", function() end)"#,
        );

        let exts = discover_extensions(user_dir.path(), None);
        let rt = make_runtime();
        let results = load_extensions(&rt, &exts);

        let failures: Vec<_> = results.iter().filter(|(_, r)| r.is_err()).collect();
        let successes: Vec<_> = results.iter().filter(|(_, r)| r.is_ok()).collect();

        assert_eq!(failures.len(), 1);
        assert_eq!(successes.len(), 1);

        assert_eq!(rt.hook_count(), 1);
    }

    /// Native test tool that returns `"echo: <text>"` for `params.text`.
    struct EchoTestTool;

    #[async_trait]
    impl imp_core::tools::Tool for EchoTestTool {
        fn name(&self) -> &str {
            "echo"
        }
        fn label(&self) -> &str {
            "Echo"
        }
        fn description(&self) -> &str {
            "Echoes text"
        }
        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({"type": "object", "properties": {"text": {"type": "string"}}})
        }
        fn is_readonly(&self) -> bool {
            true
        }
        async fn execute(
            &self,
            _call_id: &str,
            params: serde_json::Value,
            _ctx: imp_core::tools::ToolContext,
        ) -> imp_core::Result<imp_core::tools::ToolOutput> {
            let text = params["text"].as_str().unwrap_or("no text");
            Ok(imp_core::tools::ToolOutput::text(format!("echo: {text}")))
        }
    }

    /// Native test tool whose output is always a tool-level error. The
    /// `execute` call itself still returns `Ok`.
    struct FailTestTool;

    #[async_trait]
    impl imp_core::tools::Tool for FailTestTool {
        fn name(&self) -> &str {
            "fail"
        }
        fn label(&self) -> &str {
            "Fail"
        }
        fn description(&self) -> &str {
            "Always fails"
        }
        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
        fn is_readonly(&self) -> bool {
            true
        }
        async fn execute(
            &self,
            _call_id: &str,
            _params: serde_json::Value,
            _ctx: imp_core::tools::ToolContext,
        ) -> imp_core::Result<imp_core::tools::ToolOutput> {
            Ok(imp_core::tools::ToolOutput::error("intentional failure"))
        }
    }

    /// Headless `LuaCallContext` backed by `NullInterface`, for tests that
    /// call `rt.set_call_context` directly. The channel receivers are dropped
    /// immediately.
    fn make_call_context() -> sandbox::LuaCallContext {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        sandbox::LuaCallContext {
            cwd: PathBuf::from("/tmp/lua-test"),
            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: Arc::new(NullInterface),
            file_cache: Arc::new(imp_core::tools::FileCache::new()),
            checkpoint_state: Arc::new(imp_core::tools::CheckpointState::new()),
            file_tracker: Arc::new(std::sync::Mutex::new(
                imp_core::tools::FileTracker::default(),
            )),
            anchor_store: Arc::new(imp_core::tools::AnchorStore::new()),
            mode: imp_core::config::AgentMode::Full,
            read_max_lines: 500,
            run_policy: Default::default(),
            lua_tool_loader: None,
            config: Arc::new(imp_core::config::Config::default()),
        }
    }

    // The `imp.tool` tests run the Lua chunk inside `spawn_blocking` on a
    // multi-threaded runtime, matching how the runtime is driven with a tokio
    // runtime available.

    #[tokio::test(flavor = "multi_thread")]
    async fn imp_tool_calls_native_tool() {
        let rt = make_runtime();

        let mut native = std::collections::HashMap::new();
        native.insert(
            "echo".to_string(),
            Arc::new(EchoTestTool) as Arc<dyn imp_core::tools::Tool>,
        );
        rt.set_native_tools(native);
        rt.set_call_context(make_call_context());

        let rt = Arc::new(Mutex::new(rt));
        let rt2 = Arc::clone(&rt);

        let result = tokio::task::spawn_blocking(move || {
            let guard = rt2.lock().unwrap();
            guard
                .exec(
                    r#"
                    _result, _err = imp.tool("echo", { text = "hello from lua" })
                    "#,
                )
                .unwrap();
            let result: String = guard.lua().globals().get("_result").unwrap();
            let err: mlua::Value = guard.lua().globals().get("_err").unwrap();
            assert!(matches!(err, mlua::Value::Nil), "expected no error");
            result
        })
        .await
        .unwrap();

        assert_eq!(result, "echo: hello from lua");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn imp_tool_returns_error_on_failure() {
        let rt = make_runtime();

        let mut native = std::collections::HashMap::new();
        native.insert(
            "fail".to_string(),
            Arc::new(FailTestTool) as Arc<dyn imp_core::tools::Tool>,
        );
        rt.set_native_tools(native);
        rt.set_call_context(make_call_context());

        let rt = Arc::new(Mutex::new(rt));
        let rt2 = Arc::clone(&rt);

        tokio::task::spawn_blocking(move || {
            let guard = rt2.lock().unwrap();
            guard
                .exec(
                    r#"
                    _result, _err = imp.tool("fail", {})
                    "#,
                )
                .unwrap();
            let result: mlua::Value = guard.lua().globals().get("_result").unwrap();
            assert!(matches!(result, mlua::Value::Nil), "expected nil result");
            let err: String = guard.lua().globals().get("_err").unwrap();
            assert!(
                err.contains("intentional failure"),
                "expected failure message, got: {err}"
            );
        })
        .await
        .unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn imp_tool_errors_on_unknown_tool() {
        let rt = make_runtime();
        rt.set_native_tools(std::collections::HashMap::new());
        rt.set_call_context(make_call_context());

        let rt = Arc::new(Mutex::new(rt));
        let rt2 = Arc::clone(&rt);

        tokio::task::spawn_blocking(move || {
            let guard = rt2.lock().unwrap();
            let result = guard.exec(
                r#"
                imp.tool("nonexistent", {})
                "#,
            );
            assert!(result.is_err(), "should error on unknown tool");
            let err = format!("{}", result.unwrap_err());
            assert!(
                err.contains("not found"),
                "error should mention 'not found': {err}"
            );
        })
        .await
        .unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn imp_tool_errors_when_disabled() {
        let rt = make_runtime();

        let mut native = std::collections::HashMap::new();
        native.insert(
            "echo".to_string(),
            Arc::new(EchoTestTool) as Arc<dyn imp_core::tools::Tool>,
        );
        rt.set_native_tools(native);
        rt.set_call_context(make_call_context());
        rt.set_allow_native_tool_calls(false);

        let rt = Arc::new(Mutex::new(rt));
        let rt2 = Arc::clone(&rt);

        tokio::task::spawn_blocking(move || {
            let guard = rt2.lock().unwrap();
            let result = guard.exec(
                r#"
                imp.tool("echo", { text = "hello from lua" })
                "#,
            );
            assert!(result.is_err(), "disabled imp.tool() should error");
            let err = format!("{}", result.unwrap_err());
            assert!(
                err.contains("disabled"),
                "error should mention disabled state: {err}"
            );
        })
        .await
        .unwrap();
    }

    #[test]
    fn imp_exec_errors_when_disabled() {
        let rt = make_runtime();
        let result = rt.exec(
            r#"
            local _ = imp.exec("echo hi")
            "#,
        );
        assert!(result.is_err(), "disabled imp.exec() should error");
        let err = format!("{}", result.unwrap_err());
        assert!(
            err.contains("imp.exec() is disabled"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn imp_exec_runs_when_enabled() {
        let rt = make_runtime();
        rt.set_allow_shell_exec(true);

        rt.exec(
            r#"
            local result = imp.exec("printf lua_exec_ok")
            _test_stdout = result.stdout
            _test_exit = result.exit_code
            "#,
        )
        .unwrap();

        let stdout: String = rt.lua().globals().get("_test_stdout").unwrap();
        let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
        assert_eq!(stdout, "lua_exec_ok");
        assert_eq!(exit_code, 0);
    }

    /// `opts.env` entries are visible to the spawned child process.
    #[test]
    fn imp_exec_passes_scoped_env_to_child_process() {
        let rt = make_runtime();
        rt.set_allow_shell_exec(true);

        rt.exec(
            r#"
            local result = imp.exec("printf %s \"$IMP_LUA_CHILD_SECRET\"", nil, {
                env = { IMP_LUA_CHILD_SECRET = "child-only-value" },
            })
            _test_stdout = result.stdout
            _test_exit = result.exit_code
            "#,
        )
        .unwrap();

        let stdout: String = rt.lua().globals().get("_test_stdout").unwrap();
        let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
        assert_eq!(stdout, "child-only-value");
        assert_eq!(exit_code, 0);
    }

    #[test]
    fn imp_secret_errors_when_disabled() {
        let rt = make_runtime();
        let result = rt.exec(
            r#"
            local _ = imp.secret("openai", "api_key")
            "#,
        );
        assert!(result.is_err(), "disabled imp.secret() should error");
        let err = format!("{}", result.unwrap_err());
        assert!(
            err.contains("imp.secret() is disabled"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn imp_secret_fields_errors_when_disabled() {
        let rt = make_runtime();
        let result = rt.exec(
            r#"
            local _ = imp.secret_fields("openai")
            "#,
        );
        assert!(result.is_err(), "disabled imp.secret_fields() should error");
        let err = format!("{}", result.unwrap_err());
        assert!(
            err.contains("imp.secret_fields() is disabled"),
            "unexpected error: {err}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn imp_http_get_errors_when_disabled() {
        let rt = Arc::new(Mutex::new(make_runtime()));
        let rt2 = Arc::clone(&rt);
        tokio::task::spawn_blocking(move || {
            let guard = rt2.lock().unwrap();
            let result = guard.exec(
                r#"
                local _ = imp.http.get("https://example.com")
                "#,
            );
            assert!(result.is_err(), "disabled imp.http.get() should error");
            let err = format!("{}", result.unwrap_err());
            assert!(
                err.contains("imp.http.get() is disabled"),
                "unexpected error: {err}"
            );
        })
        .await
        .unwrap();
    }

    // The `imp.env` tests mutate the process environment. Each uses a
    // distinct variable name so that parallel tests do not interfere.

    #[test]
    fn imp_env_reads_var_when_allowed() {
        let rt = make_runtime();
        std::env::set_var("IMP_LUA_TEST_VAR", "test_value");

        let mut allowed = std::collections::HashSet::new();
        allowed.insert("IMP_LUA_TEST_VAR".to_string());
        rt.set_allowed_env(allowed);

        rt.exec(
            r#"
            _env_val = imp.env("IMP_LUA_TEST_VAR")
            "#,
        )
        .unwrap();

        let val: String = rt.lua().globals().get("_env_val").unwrap();
        assert_eq!(val, "test_value");
    }

    #[test]
    fn imp_env_returns_nil_for_denied_var() {
        let rt = make_runtime();
        std::env::set_var("IMP_LUA_TEST_SECRET", "secret_value");

        let mut allowed = std::collections::HashSet::new();
        allowed.insert("SOME_OTHER_VAR".to_string());
        rt.set_allowed_env(allowed);

        rt.exec(
            r#"
            _env_val = imp.env("IMP_LUA_TEST_SECRET")
            _is_nil = (_env_val == nil)
            "#,
        )
        .unwrap();

        let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
        assert!(is_nil, "denied env var should return nil");
    }

    /// An empty allow-list denies all access. It does not mean "allow
    /// everything". (This test was previously named
    /// `imp_env_allows_all_when_list_empty`, which contradicted its assertion.)
    #[test]
    fn imp_env_denies_all_when_list_empty() {
        let rt = make_runtime();
        std::env::set_var("IMP_LUA_TEST_OPEN", "open_value");

        rt.set_allowed_env(std::collections::HashSet::new());

        rt.exec(
            r#"
            _env_val = imp.env("IMP_LUA_TEST_OPEN")
            _is_nil = (_env_val == nil)
            "#,
        )
        .unwrap();

        let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
        assert!(is_nil, "empty allow-list should deny env access by default");
    }

    #[test]
    fn imp_env_returns_nil_for_missing_var() {
        let rt = make_runtime();
        rt.set_allowed_env(std::collections::HashSet::new());

        rt.exec(
            r#"
            _env_val = imp.env("DEFINITELY_NOT_SET_IMP_LUA_TEST")
            _is_nil = (_env_val == nil)
            "#,
        )
        .unwrap();

        let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
        assert!(is_nil, "missing env var should return nil");
    }
}