1pub mod bridge;
2pub mod loader;
3pub mod sandbox;
4
5use std::path::Path;
6use std::sync::{Arc, Mutex};
7
8use imp_core::config::LuaCapabilityPolicy;
9use imp_core::tools::ToolRegistry;
10
11pub use bridge::{json_to_lua_value, load_lua_tools, lua_value_to_json, setup_host_api, LuaTool};
12pub use loader::{discover_extensions, load_extensions, reload, LuaExtension};
13pub use sandbox::{
14 LuaCallContext, LuaCommandHandle, LuaError, LuaHookHandle, LuaRuntime, LuaToolHandle,
15};
16
17pub fn init_lua_extensions(
23 user_config_dir: &Path,
24 project_dir: Option<&Path>,
25 tools: &mut ToolRegistry,
26 policy: &LuaCapabilityPolicy,
27) -> Option<Arc<Mutex<LuaRuntime>>> {
28 let extensions = discover_extensions(user_config_dir, project_dir);
29 if extensions.is_empty() {
30 return None;
31 }
32
33 let rt = match LuaRuntime::new() {
34 Ok(rt) => rt,
35 Err(_e) => {
36 return None;
37 }
38 };
39 if let Err(_e) = setup_host_api(&rt) {
40 return None;
41 }
42 rt.apply_capability_policy(policy);
43
44 let results = load_extensions(&rt, &extensions);
45 for (_name, result) in &results {
46 if let Err(_e) = result {
47 }
50 }
51
52 rt.set_native_tools(tools.tools_map());
54
55 let runtime = Arc::new(Mutex::new(rt));
56 load_lua_tools(Arc::clone(&runtime), tools);
57 Some(runtime)
58}
59
60#[cfg(test)]
61mod tests {
62 use super::*;
63 use std::path::PathBuf;
64 use std::sync::{Arc, Mutex};
65
66 use imp_core::config::LuaCapabilityPolicy;
67 use imp_core::tools::{ToolContext, ToolRegistry};
68 use imp_core::ui::NullInterface;
69 use tempfile::TempDir;
70
71 fn make_policy() -> LuaCapabilityPolicy {
72 LuaCapabilityPolicy::default()
73 }
74
75 fn make_runtime() -> LuaRuntime {
77 let rt = LuaRuntime::new().expect("create runtime");
78 setup_host_api(&rt).expect("setup host api");
79 rt
80 }
81
82 fn write_lua(dir: &std::path::Path, name: &str, content: &str) -> PathBuf {
84 let path = dir.join(name);
85 std::fs::write(&path, content).unwrap();
86 path
87 }
88
89 fn test_ctx() -> ToolContext {
90 let (tx, _rx) = tokio::sync::mpsc::channel(16);
91 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
92 ToolContext {
93 cwd: PathBuf::from("/tmp/lua-tools"),
94 cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
95 update_tx: tx,
96 command_tx: cmd_tx,
97 ui: Arc::new(NullInterface),
98 file_cache: Arc::new(imp_core::tools::FileCache::new()),
99 checkpoint_state: Arc::new(imp_core::tools::CheckpointState::new()),
100 file_tracker: Arc::new(std::sync::Mutex::new(
101 imp_core::tools::FileTracker::default(),
102 )),
103 anchor_store: Arc::new(imp_core::tools::AnchorStore::new()),
104 mode: imp_core::config::AgentMode::Full,
105 read_max_lines: 0,
106 turn_mana_review: Arc::new(std::sync::Mutex::new(
107 imp_core::mana_review::TurnManaReviewAccumulator::default(),
108 )),
109 config: Arc::new(imp_core::config::Config::default()),
110 lua_tool_loader: None,
111 }
112 }
113
114 #[test]
117 fn init_lua_extensions_applies_capability_policy() {
118 let user = TempDir::new().unwrap();
119 let lua_dir = user.path().join("lua");
120 std::fs::create_dir_all(&lua_dir).unwrap();
121 write_lua(&lua_dir, "ext.lua", "-- extension present");
122
123 let mut registry = ToolRegistry::new();
124 let mut policy = make_policy();
125 policy.allow_native_tool_calls = false;
126 policy.allow_shell_exec = true;
127 policy.allow_http = true;
128 policy.allow_secrets = true;
129 policy.allowed_env.insert("ALLOWED_ONE".to_string());
130
131 let runtime = init_lua_extensions(user.path(), None, &mut registry, &policy)
132 .expect("runtime should initialize");
133 let guard = runtime.lock().unwrap();
134
135 assert!(!guard
136 .allow_native_tool_calls()
137 .load(std::sync::atomic::Ordering::Relaxed));
138 assert!(guard
139 .allow_shell_exec()
140 .load(std::sync::atomic::Ordering::Relaxed));
141 assert!(guard
142 .allow_http()
143 .load(std::sync::atomic::Ordering::Relaxed));
144 assert!(guard
145 .allow_secrets()
146 .load(std::sync::atomic::Ordering::Relaxed));
147 assert!(guard.allowed_env().lock().unwrap().contains("ALLOWED_ONE"));
148 }
149
150 #[test]
151 fn discover_user_lua_files() {
152 let tmp = TempDir::new().unwrap();
153 let lua_dir = tmp.path().join("lua");
154 std::fs::create_dir_all(&lua_dir).unwrap();
155 write_lua(&lua_dir, "greet.lua", "-- hello");
156 write_lua(&lua_dir, "utils.lua", "-- utils");
157
158 let exts = discover_extensions(tmp.path(), None);
159 assert_eq!(exts.len(), 2);
160
161 let names: Vec<&str> = exts.iter().map(|e| e.name.as_str()).collect();
162 assert!(names.contains(&"greet"));
163 assert!(names.contains(&"utils"));
164 }
165
166 #[test]
167 fn discover_directory_init_lua() {
168 let tmp = TempDir::new().unwrap();
169 let ext_dir = tmp.path().join("lua").join("my-ext");
170 std::fs::create_dir_all(&ext_dir).unwrap();
171 write_lua(&ext_dir, "init.lua", "-- init");
172
173 let exts = discover_extensions(tmp.path(), None);
174 assert_eq!(exts.len(), 1);
175 assert_eq!(exts[0].name, "my-ext");
176 assert!(exts[0].path.ends_with("init.lua"));
177 }
178
179 #[test]
180 fn discover_project_local() {
181 let user = TempDir::new().unwrap();
182 let project = TempDir::new().unwrap();
183 let proj_lua = project.path().join(".imp").join("lua");
184 std::fs::create_dir_all(&proj_lua).unwrap();
185 write_lua(&proj_lua, "local.lua", "-- local");
186
187 let exts = discover_extensions(user.path(), Some(project.path()));
188 assert_eq!(exts.len(), 1);
189 assert_eq!(exts[0].name, "local");
190 }
191
192 #[test]
193 fn discover_empty_dirs_return_nothing() {
194 let tmp = TempDir::new().unwrap();
195 let exts = discover_extensions(tmp.path(), None);
196 assert!(exts.is_empty());
197 }
198
199 #[test]
202 fn on_registers_hook() {
203 let rt = make_runtime();
204 rt.exec(
205 r#"
206 imp.on("on_session_start", function(event, ctx)
207 -- handler
208 end)
209 "#,
210 )
211 .unwrap();
212
213 assert_eq!(rt.hook_count(), 1);
214 let events = rt.hook_events();
215 assert_eq!(events[0], "on_session_start");
216 }
217
218 #[test]
219 fn on_registers_multiple_hooks() {
220 let rt = make_runtime();
221 rt.exec(
222 r#"
223 imp.on("on_session_start", function() end)
224 imp.on("after_file_write", function() end)
225 imp.on("before_tool_call", function() end)
226 "#,
227 )
228 .unwrap();
229
230 assert_eq!(rt.hook_count(), 3);
231 let events = rt.hook_events();
232 assert!(events.contains(&"on_session_start".to_string()));
233 assert!(events.contains(&"after_file_write".to_string()));
234 assert!(events.contains(&"before_tool_call".to_string()));
235 }
236
237 #[test]
238 fn hook_handler_fires_on_correct_event() {
239 let rt = make_runtime();
240 rt.exec(
241 r#"
242 _test_fired = false
243 imp.on("on_session_start", function()
244 _test_fired = true
245 end)
246 "#,
247 )
248 .unwrap();
249
250 let hooks = rt.hooks();
252 let hooks_guard = hooks.lock().unwrap();
253 assert_eq!(hooks_guard.len(), 1);
254
255 let handler: mlua::Function = rt
256 .lua()
257 .registry_value(&hooks_guard[0].handler_key)
258 .unwrap();
259 handler.call::<()>(()).unwrap();
260
261 let fired: bool = rt.lua().globals().get("_test_fired").unwrap();
262 assert!(fired);
263 }
264
265 #[test]
268 fn register_tool_creates_handle() {
269 let rt = make_runtime();
270 rt.exec(
271 r#"
272 imp.register_tool({
273 name = "greet",
274 label = "Greeting Tool",
275 description = "Says hello",
276 readonly = true,
277 params = {
278 type = "object",
279 properties = {
280 name = { type = "string", description = "Who to greet" }
281 }
282 },
283 execute = function(call_id, params, ctx)
284 return { content = "Hello, " .. (params.name or "world") }
285 end
286 })
287 "#,
288 )
289 .unwrap();
290
291 assert_eq!(rt.tool_count(), 1);
292 let names = rt.tool_names();
293 assert_eq!(names[0], "greet");
294 }
295
296 #[test]
297 fn register_tool_execute_callable() {
298 let rt = make_runtime();
299 rt.exec(
300 r#"
301 imp.register_tool({
302 name = "add",
303 execute = function(call_id, params, ctx)
304 return { content = tostring(params.a + params.b), is_error = false }
305 end
306 })
307 "#,
308 )
309 .unwrap();
310
311 let tools = rt.tools();
313 let tools_guard = tools.lock().unwrap();
314 let execute_fn: mlua::Function = rt
315 .lua()
316 .registry_value(&tools_guard[0].execute_key)
317 .unwrap();
318
319 let params = rt.lua().create_table().unwrap();
320 params.set("a", 3).unwrap();
321 params.set("b", 4).unwrap();
322
323 let result: mlua::Table = execute_fn
324 .call(("call_1", params, mlua::Value::Nil))
325 .unwrap();
326 let content: String = result.get("content").unwrap();
327 assert_eq!(content, "7");
328 }
329
330 #[tokio::test]
331 async fn load_lua_tools_registers_and_executes_bridge() {
332 let rt = make_runtime();
333 rt.exec(
334 r#"
335 imp.register_tool({
336 name = "greet",
337 label = "Greeting Tool",
338 description = "Greets from Lua",
339 readonly = true,
340 params = {
341 name = { type = "string", description = "Who to greet", required = true },
342 excited = { type = "boolean" }
343 },
344 execute = function(call_id, params, ctx)
345 local suffix = params.excited and "!" or "."
346 return {
347 content = {
348 { type = "text", text = "hello " .. params.name .. suffix },
349 },
350 details = {
351 call_id = call_id,
352 cwd = ctx.cwd,
353 cancelled = ctx.cancelled,
354 },
355 }
356 end
357 })
358 "#,
359 )
360 .unwrap();
361
362 let runtime = Arc::new(Mutex::new(rt));
363 let mut registry = ToolRegistry::new();
364 load_lua_tools(Arc::clone(&runtime), &mut registry);
365
366 let tool = registry
367 .get("greet")
368 .expect("lua tool should be registered");
369 assert_eq!(tool.label(), "Greeting Tool");
370 assert_eq!(tool.description(), "Greets from Lua");
371 assert!(tool.is_readonly());
372 assert_eq!(tool.parameters()["properties"]["name"]["type"], "string");
373 assert_eq!(tool.parameters()["required"], serde_json::json!(["name"]));
374
375 let output = tool
376 .execute(
377 "call_123",
378 serde_json::json!({ "name": "Ada", "excited": true }),
379 test_ctx(),
380 )
381 .await
382 .unwrap();
383
384 let text = output
385 .content
386 .iter()
387 .find_map(|block| match block {
388 imp_core::imp_llm::ContentBlock::Text { text } => Some(text.as_str()),
389 _ => None,
390 })
391 .expect("lua tool should return text");
392 assert_eq!(text, "hello Ada!");
393 assert_eq!(output.details["call_id"], "call_123");
394 assert_eq!(output.details["cwd"], "/tmp/lua-tools");
395 assert_eq!(output.details["cancelled"], false);
396 }
397
398 #[test]
399 fn imp_secret_helpers_exist() {
400 let rt = make_runtime();
401 rt.exec(
402 r#"
403 _has_secret = type(imp.secret) == "function"
404 _has_secret_fields = type(imp.secret_fields) == "function"
405 "#,
406 )
407 .unwrap();
408
409 let has_secret: bool = rt.lua().globals().get("_has_secret").unwrap();
410 let has_secret_fields: bool = rt.lua().globals().get("_has_secret_fields").unwrap();
411 assert!(has_secret);
412 assert!(has_secret_fields);
413 }
414
415 #[test]
418 fn exec_runs_command_returns_stdout() {
419 let rt = make_runtime();
420 rt.set_allow_shell_exec(true);
421 rt.exec(
422 r#"
423 local result = imp.exec("echo hello")
424 _test_stdout = result.stdout
425 _test_exit = result.exit_code
426 "#,
427 )
428 .unwrap();
429
430 let stdout: String = rt.lua().globals().get("_test_stdout").unwrap();
431 let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
432 assert_eq!(stdout.trim(), "hello");
433 assert_eq!(exit_code, 0);
434 }
435
436 #[test]
437 fn exec_captures_stderr() {
438 let rt = make_runtime();
439 rt.set_allow_shell_exec(true);
440 rt.exec(
441 r#"
442 local result = imp.exec("echo error >&2")
443 _test_stderr = result.stderr
444 "#,
445 )
446 .unwrap();
447
448 let stderr: String = rt.lua().globals().get("_test_stderr").unwrap();
449 assert_eq!(stderr.trim(), "error");
450 }
451
452 #[test]
453 fn exec_returns_nonzero_exit_code() {
454 let rt = make_runtime();
455 rt.set_allow_shell_exec(true);
456 rt.exec(
457 r#"
458 local result = imp.exec("exit 42")
459 _test_exit = result.exit_code
460 "#,
461 )
462 .unwrap();
463
464 let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
465 assert_eq!(exit_code, 42);
466 }
467
468 #[test]
469 fn exec_with_cwd() {
470 let rt = make_runtime();
471 rt.set_allow_shell_exec(true);
472 rt.exec(
473 r#"
474 local result = imp.exec("pwd", nil, { cwd = "/tmp" })
475 _test_cwd = result.stdout
476 "#,
477 )
478 .unwrap();
479
480 let cwd: String = rt.lua().globals().get("_test_cwd").unwrap();
481 assert!(cwd.trim().contains("tmp"));
483 }
484
485 #[test]
490 fn null_interface_confirm_returns_nil() {
491 let rt = make_runtime();
492 rt.exec(
494 r#"
495 -- When NullInterface returns None, the bridge maps it to nil
496 _confirm_result = nil -- This is what NullInterface.confirm() produces
497 _is_nil = (_confirm_result == nil)
498 "#,
499 )
500 .unwrap();
501
502 let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
503 assert!(is_nil);
504 }
505
506 #[test]
509 fn hot_reload_drops_and_recreates() {
510 let user_dir = TempDir::new().unwrap();
511 let lua_dir = user_dir.path().join("lua");
512 std::fs::create_dir_all(&lua_dir).unwrap();
513
514 write_lua(
515 &lua_dir,
516 "ext.lua",
517 r#"
518 imp.on("on_session_start", function() end)
519 imp.register_tool({ name = "my_tool", execute = function() end })
520 "#,
521 );
522
523 let (rt1, exts1) = reload(user_dir.path(), None).unwrap();
525 assert_eq!(rt1.hook_count(), 1);
526 assert_eq!(rt1.tool_count(), 1);
527 assert_eq!(exts1.len(), 1);
528
529 write_lua(
531 &lua_dir,
532 "ext.lua",
533 r#"
534 imp.on("on_session_start", function() end)
535 imp.on("after_file_write", function() end)
536 imp.register_tool({ name = "tool_a", execute = function() end })
537 imp.register_tool({ name = "tool_b", execute = function() end })
538 "#,
539 );
540
541 let (rt2, exts2) = reload(user_dir.path(), None).unwrap();
543 assert_eq!(rt2.hook_count(), 2);
544 assert_eq!(rt2.tool_count(), 2);
545 assert_eq!(exts2.len(), 1);
546
547 let tool_names = rt2.tool_names();
548 assert!(tool_names.contains(&"tool_a".to_string()));
549 assert!(tool_names.contains(&"tool_b".to_string()));
550 }
551
552 #[test]
555 fn lua_syntax_error_caught() {
556 let rt = make_runtime();
557 let result = rt.exec("this is not valid lua !!!");
558 assert!(result.is_err());
559 let result2 = rt.exec("_test_ok = true");
561 assert!(result2.is_ok());
562 let ok: bool = rt.lua().globals().get("_test_ok").unwrap();
563 assert!(ok);
564 }
565
566 #[test]
567 fn lua_runtime_error_caught() {
568 let rt = make_runtime();
569 let result = rt.exec("error('intentional error')");
570 assert!(result.is_err());
571 let err_msg = format!("{}", result.unwrap_err());
572 assert!(err_msg.contains("intentional error"));
573 }
574
575 #[test]
576 fn extension_error_doesnt_crash_runtime() {
577 let rt = make_runtime();
578
579 let r1 = rt.exec("error('ext1 failed')");
581 assert!(r1.is_err());
582
583 let r2 = rt.exec(
585 r#"
586 imp.on("on_session_start", function() end)
587 _ext2_loaded = true
588 "#,
589 );
590 assert!(r2.is_ok());
591 assert_eq!(rt.hook_count(), 1);
592
593 let loaded: bool = rt.lua().globals().get("_ext2_loaded").unwrap();
594 assert!(loaded);
595 }
596
597 #[test]
600 fn multiple_extensions_coexist() {
601 let rt = make_runtime();
602
603 rt.exec(
605 r#"
606 imp.on("on_session_start", function()
607 _ext1_fired = true
608 end)
609 imp.register_tool({ name = "ext1_tool", execute = function() end })
610 "#,
611 )
612 .unwrap();
613
614 rt.exec(
616 r#"
617 imp.on("after_file_write", function()
618 _ext2_fired = true
619 end)
620 imp.register_tool({ name = "ext2_tool", execute = function() end })
621 "#,
622 )
623 .unwrap();
624
625 assert_eq!(rt.hook_count(), 2);
626 assert_eq!(rt.tool_count(), 2);
627
628 let names = rt.tool_names();
629 assert!(names.contains(&"ext1_tool".to_string()));
630 assert!(names.contains(&"ext2_tool".to_string()));
631
632 let events = rt.hook_events();
633 assert!(events.contains(&"on_session_start".to_string()));
634 assert!(events.contains(&"after_file_write".to_string()));
635 }
636
637 #[test]
638 fn extensions_share_state() {
639 let rt = make_runtime();
640
641 rt.exec("shared_counter = 1").unwrap();
643
644 rt.exec(
646 r#"
647 shared_counter = shared_counter + 1
648 _final = shared_counter
649 "#,
650 )
651 .unwrap();
652
653 let val: i64 = rt.lua().globals().get("_final").unwrap();
654 assert_eq!(val, 2);
655 }
656
657 #[test]
660 fn events_on_and_emit() {
661 let rt = make_runtime();
662 rt.exec(
663 r#"
664 _received = nil
665 imp.events.on("custom_event", function(data)
666 _received = data
667 end)
668 imp.events.emit("custom_event", "hello from event")
669 "#,
670 )
671 .unwrap();
672
673 let received: String = rt.lua().globals().get("_received").unwrap();
674 assert_eq!(received, "hello from event");
675 }
676
677 #[test]
678 fn events_multiple_handlers() {
679 let rt = make_runtime();
680 rt.exec(
681 r#"
682 _count = 0
683 imp.events.on("tick", function() _count = _count + 1 end)
684 imp.events.on("tick", function() _count = _count + 1 end)
685 imp.events.emit("tick", nil)
686 "#,
687 )
688 .unwrap();
689
690 let count: i64 = rt.lua().globals().get("_count").unwrap();
691 assert_eq!(count, 2);
692 }
693
694 #[test]
695 fn events_handler_error_doesnt_crash() {
696 let rt = make_runtime();
697 rt.exec(
698 r#"
699 _after_error = false
700 imp.events.on("test", function() error("boom") end)
701 imp.events.on("test", function() _after_error = true end)
702 imp.events.emit("test", nil)
703 "#,
704 )
705 .unwrap();
706
707 let after: bool = rt.lua().globals().get("_after_error").unwrap();
708 assert!(after, "second handler should still fire after first errors");
709 }
710
711 #[test]
714 fn register_command_creates_handle() {
715 let rt = make_runtime();
716 rt.exec(
717 r#"
718 imp.register_command("greet", {
719 description = "Say hello",
720 handler = function(args, ctx)
721 return "Hello!"
722 end
723 })
724 "#,
725 )
726 .unwrap();
727
728 assert_eq!(rt.command_count(), 1);
729 }
730
731 #[test]
734 fn lua_value_to_json_primitives() {
735 let rt = make_runtime();
736 let lua = rt.lua();
737
738 assert_eq!(lua_value_to_json(mlua::Value::Nil), serde_json::Value::Null);
739 assert_eq!(
740 lua_value_to_json(mlua::Value::Boolean(true)),
741 serde_json::json!(true)
742 );
743 assert_eq!(
744 lua_value_to_json(mlua::Value::Integer(42)),
745 serde_json::json!(42)
746 );
747 assert_eq!(
748 lua_value_to_json(mlua::Value::Number(3.14)),
749 serde_json::json!(3.14)
750 );
751
752 let s = lua.create_string("hello").unwrap();
753 assert_eq!(
754 lua_value_to_json(mlua::Value::String(s)),
755 serde_json::json!("hello")
756 );
757 }
758
759 #[test]
760 fn lua_table_to_json_object() {
761 let rt = make_runtime();
762 rt.exec(
763 r#"
764 _test_table = { name = "Alice", age = 30 }
765 "#,
766 )
767 .unwrap();
768
769 let val: mlua::Value = rt.lua().globals().get("_test_table").unwrap();
770 let json = lua_value_to_json(val);
771 assert_eq!(json["name"], "Alice");
772 assert_eq!(json["age"], 30);
773 }
774
775 #[test]
776 fn lua_array_to_json_array() {
777 let rt = make_runtime();
778 rt.exec(
779 r#"
780 _test_arr = { 1, 2, 3 }
781 "#,
782 )
783 .unwrap();
784
785 let val: mlua::Value = rt.lua().globals().get("_test_arr").unwrap();
786 let json = lua_value_to_json(val);
787 assert_eq!(json, serde_json::json!([1, 2, 3]));
788 }
789
790 #[test]
791 fn json_to_lua_roundtrip() {
792 let rt = make_runtime();
793 let lua = rt.lua();
794
795 let original = serde_json::json!({
796 "name": "test",
797 "count": 42,
798 "active": true,
799 "tags": ["a", "b"],
800 "nested": { "x": 1 }
801 });
802
803 let lua_val = json_to_lua_value(lua, &original).unwrap();
804 let back = lua_value_to_json(lua_val);
805 assert_eq!(back, original);
806 }
807
808 #[test]
811 fn load_extensions_from_files() {
812 let user_dir = TempDir::new().unwrap();
813 let lua_dir = user_dir.path().join("lua");
814 std::fs::create_dir_all(&lua_dir).unwrap();
815
816 write_lua(
817 &lua_dir,
818 "a.lua",
819 r#"imp.on("on_session_start", function() end)"#,
820 );
821 write_lua(
822 &lua_dir,
823 "b.lua",
824 r#"imp.register_tool({ name = "b_tool", execute = function() end })"#,
825 );
826
827 let exts = discover_extensions(user_dir.path(), None);
828 assert_eq!(exts.len(), 2);
829
830 let rt = make_runtime();
831 let results = load_extensions(&rt, &exts);
832
833 for (name, result) in &results {
835 assert!(result.is_ok(), "Extension {} failed: {:?}", name, result);
836 }
837
838 assert_eq!(rt.hook_count(), 1);
839 assert_eq!(rt.tool_count(), 1);
840 }
841
842 #[test]
843 fn load_extension_error_reported_not_fatal() {
844 let user_dir = TempDir::new().unwrap();
845 let lua_dir = user_dir.path().join("lua");
846 std::fs::create_dir_all(&lua_dir).unwrap();
847
848 write_lua(&lua_dir, "bad.lua", "error('bad extension')");
849 write_lua(
850 &lua_dir,
851 "good.lua",
852 r#"imp.on("on_session_start", function() end)"#,
853 );
854
855 let exts = discover_extensions(user_dir.path(), None);
856 let rt = make_runtime();
857 let results = load_extensions(&rt, &exts);
858
859 let failures: Vec<_> = results.iter().filter(|(_, r)| r.is_err()).collect();
861 let successes: Vec<_> = results.iter().filter(|(_, r)| r.is_ok()).collect();
862
863 assert_eq!(failures.len(), 1);
864 assert_eq!(successes.len(), 1);
865
866 assert_eq!(rt.hook_count(), 1);
868 }
869
870 use async_trait::async_trait;
873
874 struct EchoTestTool;
875
876 #[async_trait]
877 impl imp_core::tools::Tool for EchoTestTool {
878 fn name(&self) -> &str {
879 "echo"
880 }
881 fn label(&self) -> &str {
882 "Echo"
883 }
884 fn description(&self) -> &str {
885 "Echoes text"
886 }
887 fn parameters(&self) -> serde_json::Value {
888 serde_json::json!({"type": "object", "properties": {"text": {"type": "string"}}})
889 }
890 fn is_readonly(&self) -> bool {
891 true
892 }
893 async fn execute(
894 &self,
895 _call_id: &str,
896 params: serde_json::Value,
897 _ctx: imp_core::tools::ToolContext,
898 ) -> imp_core::Result<imp_core::tools::ToolOutput> {
899 let text = params["text"].as_str().unwrap_or("no text");
900 Ok(imp_core::tools::ToolOutput::text(format!("echo: {text}")))
901 }
902 }
903
904 struct FailTestTool;
905
906 #[async_trait]
907 impl imp_core::tools::Tool for FailTestTool {
908 fn name(&self) -> &str {
909 "fail"
910 }
911 fn label(&self) -> &str {
912 "Fail"
913 }
914 fn description(&self) -> &str {
915 "Always fails"
916 }
917 fn parameters(&self) -> serde_json::Value {
918 serde_json::json!({"type": "object"})
919 }
920 fn is_readonly(&self) -> bool {
921 true
922 }
923 async fn execute(
924 &self,
925 _call_id: &str,
926 _params: serde_json::Value,
927 _ctx: imp_core::tools::ToolContext,
928 ) -> imp_core::Result<imp_core::tools::ToolOutput> {
929 Ok(imp_core::tools::ToolOutput::error("intentional failure"))
930 }
931 }
932
933 fn make_call_context() -> sandbox::LuaCallContext {
934 let (tx, _rx) = tokio::sync::mpsc::channel(16);
935 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
936 sandbox::LuaCallContext {
937 cwd: PathBuf::from("/tmp/lua-test"),
938 cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
939 update_tx: tx,
940 command_tx: cmd_tx,
941 ui: Arc::new(NullInterface),
942 file_cache: Arc::new(imp_core::tools::FileCache::new()),
943 checkpoint_state: Arc::new(imp_core::tools::CheckpointState::new()),
944 file_tracker: Arc::new(std::sync::Mutex::new(
945 imp_core::tools::FileTracker::default(),
946 )),
947 anchor_store: Arc::new(imp_core::tools::AnchorStore::new()),
948 mode: imp_core::config::AgentMode::Full,
949 read_max_lines: 500,
950 lua_tool_loader: None,
951 config: Arc::new(imp_core::config::Config::default()),
952 }
953 }
954
955 #[tokio::test(flavor = "multi_thread")]
956 async fn imp_tool_calls_native_tool() {
957 let rt = make_runtime();
958
959 let mut native = std::collections::HashMap::new();
960 native.insert(
961 "echo".to_string(),
962 Arc::new(EchoTestTool) as Arc<dyn imp_core::tools::Tool>,
963 );
964 rt.set_native_tools(native);
965 rt.set_call_context(make_call_context());
966
967 let rt = Arc::new(Mutex::new(rt));
968 let rt2 = Arc::clone(&rt);
969
970 let result = tokio::task::spawn_blocking(move || {
971 let guard = rt2.lock().unwrap();
972 guard
973 .exec(
974 r#"
975 _result, _err = imp.tool("echo", { text = "hello from lua" })
976 "#,
977 )
978 .unwrap();
979 let result: String = guard.lua().globals().get("_result").unwrap();
980 let err: mlua::Value = guard.lua().globals().get("_err").unwrap();
981 assert!(matches!(err, mlua::Value::Nil), "expected no error");
982 result
983 })
984 .await
985 .unwrap();
986
987 assert_eq!(result, "echo: hello from lua");
988 }
989
990 #[tokio::test(flavor = "multi_thread")]
991 async fn imp_tool_returns_error_on_failure() {
992 let rt = make_runtime();
993
994 let mut native = std::collections::HashMap::new();
995 native.insert(
996 "fail".to_string(),
997 Arc::new(FailTestTool) as Arc<dyn imp_core::tools::Tool>,
998 );
999 rt.set_native_tools(native);
1000 rt.set_call_context(make_call_context());
1001
1002 let rt = Arc::new(Mutex::new(rt));
1003 let rt2 = Arc::clone(&rt);
1004
1005 tokio::task::spawn_blocking(move || {
1006 let guard = rt2.lock().unwrap();
1007 guard
1008 .exec(
1009 r#"
1010 _result, _err = imp.tool("fail", {})
1011 "#,
1012 )
1013 .unwrap();
1014 let result: mlua::Value = guard.lua().globals().get("_result").unwrap();
1015 assert!(matches!(result, mlua::Value::Nil), "expected nil result");
1016 let err: String = guard.lua().globals().get("_err").unwrap();
1017 assert!(
1018 err.contains("intentional failure"),
1019 "expected failure message, got: {err}"
1020 );
1021 })
1022 .await
1023 .unwrap();
1024 }
1025
1026 #[tokio::test(flavor = "multi_thread")]
1027 async fn imp_tool_errors_on_unknown_tool() {
1028 let rt = make_runtime();
1029 rt.set_native_tools(std::collections::HashMap::new());
1030 rt.set_call_context(make_call_context());
1031
1032 let rt = Arc::new(Mutex::new(rt));
1033 let rt2 = Arc::clone(&rt);
1034
1035 tokio::task::spawn_blocking(move || {
1036 let guard = rt2.lock().unwrap();
1037 let result = guard.exec(
1038 r#"
1039 imp.tool("nonexistent", {})
1040 "#,
1041 );
1042 assert!(result.is_err(), "should error on unknown tool");
1043 let err = format!("{}", result.unwrap_err());
1044 assert!(
1045 err.contains("not found"),
1046 "error should mention 'not found': {err}"
1047 );
1048 })
1049 .await
1050 .unwrap();
1051 }
1052
1053 #[tokio::test(flavor = "multi_thread")]
1054 async fn imp_tool_errors_when_disabled() {
1055 let rt = make_runtime();
1056
1057 let mut native = std::collections::HashMap::new();
1058 native.insert(
1059 "echo".to_string(),
1060 Arc::new(EchoTestTool) as Arc<dyn imp_core::tools::Tool>,
1061 );
1062 rt.set_native_tools(native);
1063 rt.set_call_context(make_call_context());
1064 rt.set_allow_native_tool_calls(false);
1065
1066 let rt = Arc::new(Mutex::new(rt));
1067 let rt2 = Arc::clone(&rt);
1068
1069 tokio::task::spawn_blocking(move || {
1070 let guard = rt2.lock().unwrap();
1071 let result = guard.exec(
1072 r#"
1073 imp.tool("echo", { text = "hello from lua" })
1074 "#,
1075 );
1076 assert!(result.is_err(), "disabled imp.tool() should error");
1077 let err = format!("{}", result.unwrap_err());
1078 assert!(
1079 err.contains("disabled"),
1080 "error should mention disabled state: {err}"
1081 );
1082 })
1083 .await
1084 .unwrap();
1085 }
1086
1087 #[test]
1090 fn imp_exec_errors_when_disabled() {
1091 let rt = make_runtime();
1092 let result = rt.exec(
1093 r#"
1094 local _ = imp.exec("echo hi")
1095 "#,
1096 );
1097 assert!(result.is_err(), "disabled imp.exec() should error");
1098 let err = format!("{}", result.unwrap_err());
1099 assert!(
1100 err.contains("imp.exec() is disabled"),
1101 "unexpected error: {err}"
1102 );
1103 }
1104
1105 #[test]
1106 fn imp_exec_runs_when_enabled() {
1107 let rt = make_runtime();
1108 rt.set_allow_shell_exec(true);
1109
1110 rt.exec(
1111 r#"
1112 local result = imp.exec("printf lua_exec_ok")
1113 _test_stdout = result.stdout
1114 _test_exit = result.exit_code
1115 "#,
1116 )
1117 .unwrap();
1118
1119 let stdout: String = rt.lua().globals().get("_test_stdout").unwrap();
1120 let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
1121 assert_eq!(stdout, "lua_exec_ok");
1122 assert_eq!(exit_code, 0);
1123 }
1124
1125 #[test]
1126 fn imp_exec_passes_scoped_env_to_child_process() {
1127 let rt = make_runtime();
1128 rt.set_allow_shell_exec(true);
1129
1130 rt.exec(
1131 r#"
1132 local result = imp.exec("printf %s \"$IMP_LUA_CHILD_SECRET\"", nil, {
1133 env = { IMP_LUA_CHILD_SECRET = "child-only-value" },
1134 })
1135 _test_stdout = result.stdout
1136 _test_exit = result.exit_code
1137 "#,
1138 )
1139 .unwrap();
1140
1141 let stdout: String = rt.lua().globals().get("_test_stdout").unwrap();
1142 let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
1143 assert_eq!(stdout, "child-only-value");
1144 assert_eq!(exit_code, 0);
1145 }
1146
1147 #[test]
1148 fn imp_secret_errors_when_disabled() {
1149 let rt = make_runtime();
1150 let result = rt.exec(
1151 r#"
1152 local _ = imp.secret("openai", "api_key")
1153 "#,
1154 );
1155 assert!(result.is_err(), "disabled imp.secret() should error");
1156 let err = format!("{}", result.unwrap_err());
1157 assert!(
1158 err.contains("imp.secret() is disabled"),
1159 "unexpected error: {err}"
1160 );
1161 }
1162
1163 #[test]
1164 fn imp_secret_fields_errors_when_disabled() {
1165 let rt = make_runtime();
1166 let result = rt.exec(
1167 r#"
1168 local _ = imp.secret_fields("openai")
1169 "#,
1170 );
1171 assert!(result.is_err(), "disabled imp.secret_fields() should error");
1172 let err = format!("{}", result.unwrap_err());
1173 assert!(
1174 err.contains("imp.secret_fields() is disabled"),
1175 "unexpected error: {err}"
1176 );
1177 }
1178
1179 #[tokio::test(flavor = "multi_thread")]
1180 async fn imp_http_get_errors_when_disabled() {
1181 let rt = Arc::new(Mutex::new(make_runtime()));
1182 let rt2 = Arc::clone(&rt);
1183 tokio::task::spawn_blocking(move || {
1184 let guard = rt2.lock().unwrap();
1185 let result = guard.exec(
1186 r#"
1187 local _ = imp.http.get("https://example.com")
1188 "#,
1189 );
1190 assert!(result.is_err(), "disabled imp.http.get() should error");
1191 let err = format!("{}", result.unwrap_err());
1192 assert!(
1193 err.contains("imp.http.get() is disabled"),
1194 "unexpected error: {err}"
1195 );
1196 })
1197 .await
1198 .unwrap();
1199 }
1200
1201 #[test]
1202 fn imp_env_reads_var_when_allowed() {
1203 let rt = make_runtime();
1204 std::env::set_var("IMP_LUA_TEST_VAR", "test_value");
1205
1206 let mut allowed = std::collections::HashSet::new();
1207 allowed.insert("IMP_LUA_TEST_VAR".to_string());
1208 rt.set_allowed_env(allowed);
1209
1210 rt.exec(
1211 r#"
1212 _env_val = imp.env("IMP_LUA_TEST_VAR")
1213 "#,
1214 )
1215 .unwrap();
1216
1217 let val: String = rt.lua().globals().get("_env_val").unwrap();
1218 assert_eq!(val, "test_value");
1219 }
1220
1221 #[test]
1222 fn imp_env_returns_nil_for_denied_var() {
1223 let rt = make_runtime();
1224 std::env::set_var("IMP_LUA_TEST_SECRET", "secret_value");
1225
1226 let mut allowed = std::collections::HashSet::new();
1227 allowed.insert("SOME_OTHER_VAR".to_string());
1228 rt.set_allowed_env(allowed);
1229
1230 rt.exec(
1231 r#"
1232 _env_val = imp.env("IMP_LUA_TEST_SECRET")
1233 _is_nil = (_env_val == nil)
1234 "#,
1235 )
1236 .unwrap();
1237
1238 let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
1239 assert!(is_nil, "denied env var should return nil");
1240 }
1241
1242 #[test]
1243 fn imp_env_allows_all_when_list_empty() {
1244 let rt = make_runtime();
1245 std::env::set_var("IMP_LUA_TEST_OPEN", "open_value");
1246
1247 rt.set_allowed_env(std::collections::HashSet::new());
1249
1250 rt.exec(
1251 r#"
1252 _env_val = imp.env("IMP_LUA_TEST_OPEN")
1253 _is_nil = (_env_val == nil)
1254 "#,
1255 )
1256 .unwrap();
1257
1258 let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
1259 assert!(is_nil, "empty allow-list should deny env access by default");
1260 }
1261
1262 #[test]
1263 fn imp_env_returns_nil_for_missing_var() {
1264 let rt = make_runtime();
1265 rt.set_allowed_env(std::collections::HashSet::new());
1266
1267 rt.exec(
1268 r#"
1269 _env_val = imp.env("DEFINITELY_NOT_SET_IMP_LUA_TEST")
1270 _is_nil = (_env_val == nil)
1271 "#,
1272 )
1273 .unwrap();
1274
1275 let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
1276 assert!(is_nil, "missing env var should return nil");
1277 }
1278}