Skip to main content

imp_lua/
lib.rs

1pub mod bridge;
2pub mod loader;
3pub mod sandbox;
4
5use std::path::Path;
6use std::sync::{Arc, Mutex};
7
8use imp_core::config::LuaCapabilityPolicy;
9use imp_core::tools::ToolRegistry;
10
11pub use bridge::{json_to_lua_value, load_lua_tools, lua_value_to_json, setup_host_api, LuaTool};
12pub use loader::{discover_extensions, load_extensions, reload, LuaExtension};
13pub use sandbox::{
14    LuaCallContext, LuaCommandHandle, LuaError, LuaHookHandle, LuaRuntime, LuaToolHandle,
15};
16
17/// Discover and load Lua extensions from user and project directories,
18/// registering any tools they define onto the given registry.
19///
20/// Returns the shared runtime handle (for command dispatch and hot-reload).
21/// Returns `None` if no extensions were found or the runtime failed to start.
22pub fn init_lua_extensions(
23    user_config_dir: &Path,
24    project_dir: Option<&Path>,
25    tools: &mut ToolRegistry,
26    policy: &LuaCapabilityPolicy,
27) -> Option<Arc<Mutex<LuaRuntime>>> {
28    let extensions = discover_extensions(user_config_dir, project_dir);
29    if extensions.is_empty() {
30        return None;
31    }
32
33    let rt = match LuaRuntime::new() {
34        Ok(rt) => rt,
35        Err(_e) => {
36            return None;
37        }
38    };
39    if let Err(_e) = setup_host_api(&rt) {
40        return None;
41    }
42    rt.apply_capability_policy(policy);
43
44    let results = load_extensions(&rt, &extensions);
45    for (_name, result) in &results {
46        if let Err(_e) = result {
47            // Keep extension bootstrap silent in embedded runtimes; failed extensions
48            // simply do not register their tools.
49        }
50    }
51
52    // Give the Lua runtime access to native tools for imp.tool() calls
53    rt.set_native_tools(tools.tools_map());
54
55    let runtime = Arc::new(Mutex::new(rt));
56    load_lua_tools(Arc::clone(&runtime), tools);
57    Some(runtime)
58}
59
60#[cfg(test)]
61mod tests {
62    use super::*;
63    use std::path::PathBuf;
64    use std::sync::{Arc, Mutex};
65
66    use imp_core::config::LuaCapabilityPolicy;
67    use imp_core::tools::{ToolContext, ToolRegistry};
68    use imp_core::ui::NullInterface;
69    use tempfile::TempDir;
70
71    fn make_policy() -> LuaCapabilityPolicy {
72        LuaCapabilityPolicy::default()
73    }
74
75    /// Helper: create a runtime with host API set up.
76    fn make_runtime() -> LuaRuntime {
77        let rt = LuaRuntime::new().expect("create runtime");
78        setup_host_api(&rt).expect("setup host api");
79        rt
80    }
81
82    /// Helper: write a Lua file into a directory.
83    fn write_lua(dir: &std::path::Path, name: &str, content: &str) -> PathBuf {
84        let path = dir.join(name);
85        std::fs::write(&path, content).unwrap();
86        path
87    }
88
89    fn test_ctx() -> ToolContext {
90        let (tx, _rx) = tokio::sync::mpsc::channel(16);
91        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
92        ToolContext {
93            cwd: PathBuf::from("/tmp/lua-tools"),
94            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
95            update_tx: tx,
96            command_tx: cmd_tx,
97            ui: Arc::new(NullInterface),
98            file_cache: Arc::new(imp_core::tools::FileCache::new()),
99            checkpoint_state: Arc::new(imp_core::tools::CheckpointState::new()),
100            file_tracker: Arc::new(std::sync::Mutex::new(
101                imp_core::tools::FileTracker::default(),
102            )),
103            anchor_store: Arc::new(imp_core::tools::AnchorStore::new()),
104            mode: imp_core::config::AgentMode::Full,
105            read_max_lines: 0,
106            turn_mana_review: Arc::new(std::sync::Mutex::new(
107                imp_core::mana_review::TurnManaReviewAccumulator::default(),
108            )),
109            run_policy: Default::default(),
110            config: Arc::new(imp_core::config::Config::default()),
111            lua_tool_loader: None,
112            supporting_provenance: Vec::new(),
113        }
114    }
115
116    // ── Discovery ────────────────────────────────────────────────
117
118    #[test]
119    fn init_lua_extensions_applies_capability_policy() {
120        let user = TempDir::new().unwrap();
121        let lua_dir = user.path().join("lua");
122        std::fs::create_dir_all(&lua_dir).unwrap();
123        write_lua(&lua_dir, "ext.lua", "-- extension present");
124
125        let mut registry = ToolRegistry::new();
126        let mut policy = make_policy();
127        policy.allow_native_tool_calls = false;
128        policy.allow_shell_exec = true;
129        policy.allow_http = true;
130        policy.allow_secrets = true;
131        policy.allowed_env.insert("ALLOWED_ONE".to_string());
132
133        let runtime = init_lua_extensions(user.path(), None, &mut registry, &policy)
134            .expect("runtime should initialize");
135        let guard = runtime.lock().unwrap();
136
137        assert!(!guard
138            .allow_native_tool_calls()
139            .load(std::sync::atomic::Ordering::Relaxed));
140        assert!(guard
141            .allow_shell_exec()
142            .load(std::sync::atomic::Ordering::Relaxed));
143        assert!(guard
144            .allow_http()
145            .load(std::sync::atomic::Ordering::Relaxed));
146        assert!(guard
147            .allow_secrets()
148            .load(std::sync::atomic::Ordering::Relaxed));
149        assert!(guard.allowed_env().lock().unwrap().contains("ALLOWED_ONE"));
150    }
151
152    #[test]
153    fn discover_user_lua_files() {
154        let tmp = TempDir::new().unwrap();
155        let lua_dir = tmp.path().join("lua");
156        std::fs::create_dir_all(&lua_dir).unwrap();
157        write_lua(&lua_dir, "greet.lua", "-- hello");
158        write_lua(&lua_dir, "utils.lua", "-- utils");
159
160        let exts = discover_extensions(tmp.path(), None);
161        assert_eq!(exts.len(), 2);
162
163        let names: Vec<&str> = exts.iter().map(|e| e.name.as_str()).collect();
164        assert!(names.contains(&"greet"));
165        assert!(names.contains(&"utils"));
166    }
167
168    #[test]
169    fn discover_directory_init_lua() {
170        let tmp = TempDir::new().unwrap();
171        let ext_dir = tmp.path().join("lua").join("my-ext");
172        std::fs::create_dir_all(&ext_dir).unwrap();
173        write_lua(&ext_dir, "init.lua", "-- init");
174
175        let exts = discover_extensions(tmp.path(), None);
176        assert_eq!(exts.len(), 1);
177        assert_eq!(exts[0].name, "my-ext");
178        assert!(exts[0].path.ends_with("init.lua"));
179    }
180
181    #[test]
182    fn discover_project_local() {
183        let user = TempDir::new().unwrap();
184        let project = TempDir::new().unwrap();
185        let proj_lua = project.path().join(".imp").join("lua");
186        std::fs::create_dir_all(&proj_lua).unwrap();
187        write_lua(&proj_lua, "local.lua", "-- local");
188
189        let exts = discover_extensions(user.path(), Some(project.path()));
190        assert_eq!(exts.len(), 1);
191        assert_eq!(exts[0].name, "local");
192    }
193
194    #[test]
195    fn discover_empty_dirs_return_nothing() {
196        let tmp = TempDir::new().unwrap();
197        let exts = discover_extensions(tmp.path(), None);
198        assert!(exts.is_empty());
199    }
200
201    // ── imp.on() — Hook registration ────────────────────────────
202
203    #[test]
204    fn on_registers_hook() {
205        let rt = make_runtime();
206        rt.exec(
207            r#"
208            imp.on("on_session_start", function(event, ctx)
209                -- handler
210            end)
211        "#,
212        )
213        .unwrap();
214
215        assert_eq!(rt.hook_count(), 1);
216        let events = rt.hook_events();
217        assert_eq!(events[0], "on_session_start");
218    }
219
220    #[test]
221    fn on_registers_multiple_hooks() {
222        let rt = make_runtime();
223        rt.exec(
224            r#"
225            imp.on("on_session_start", function() end)
226            imp.on("after_file_write", function() end)
227            imp.on("before_tool_call", function() end)
228        "#,
229        )
230        .unwrap();
231
232        assert_eq!(rt.hook_count(), 3);
233        let events = rt.hook_events();
234        assert!(events.contains(&"on_session_start".to_string()));
235        assert!(events.contains(&"after_file_write".to_string()));
236        assert!(events.contains(&"before_tool_call".to_string()));
237    }
238
239    #[test]
240    fn hook_handler_fires_on_correct_event() {
241        let rt = make_runtime();
242        rt.exec(
243            r#"
244            _test_fired = false
245            imp.on("on_session_start", function()
246                _test_fired = true
247            end)
248        "#,
249        )
250        .unwrap();
251
252        // Simulate firing the hook by calling the stored handler
253        let hooks = rt.hooks();
254        let hooks_guard = hooks.lock().unwrap();
255        assert_eq!(hooks_guard.len(), 1);
256
257        let handler: mlua::Function = rt
258            .lua()
259            .registry_value(&hooks_guard[0].handler_key)
260            .unwrap();
261        handler.call::<()>(()).unwrap();
262
263        let fired: bool = rt.lua().globals().get("_test_fired").unwrap();
264        assert!(fired);
265    }
266
267    // ── imp.register_tool() ─────────────────────────────────────
268
269    #[test]
270    fn register_tool_creates_handle() {
271        let rt = make_runtime();
272        rt.exec(
273            r#"
274            imp.register_tool({
275                name = "greet",
276                label = "Greeting Tool",
277                description = "Says hello",
278                readonly = true,
279                params = {
280                    type = "object",
281                    properties = {
282                        name = { type = "string", description = "Who to greet" }
283                    }
284                },
285                execute = function(call_id, params, ctx)
286                    return { content = "Hello, " .. (params.name or "world") }
287                end
288            })
289        "#,
290        )
291        .unwrap();
292
293        assert_eq!(rt.tool_count(), 1);
294        let names = rt.tool_names();
295        assert_eq!(names[0], "greet");
296    }
297
298    #[test]
299    fn register_tool_execute_callable() {
300        let rt = make_runtime();
301        rt.exec(
302            r#"
303            imp.register_tool({
304                name = "add",
305                execute = function(call_id, params, ctx)
306                    return { content = tostring(params.a + params.b), is_error = false }
307                end
308            })
309        "#,
310        )
311        .unwrap();
312
313        // Call the execute function directly
314        let tools = rt.tools();
315        let tools_guard = tools.lock().unwrap();
316        let execute_fn: mlua::Function = rt
317            .lua()
318            .registry_value(&tools_guard[0].execute_key)
319            .unwrap();
320
321        let params = rt.lua().create_table().unwrap();
322        params.set("a", 3).unwrap();
323        params.set("b", 4).unwrap();
324
325        let result: mlua::Table = execute_fn
326            .call(("call_1", params, mlua::Value::Nil))
327            .unwrap();
328        let content: String = result.get("content").unwrap();
329        assert_eq!(content, "7");
330    }
331
332    #[tokio::test]
333    async fn load_lua_tools_registers_and_executes_bridge() {
334        let rt = make_runtime();
335        rt.exec(
336            r#"
337            imp.register_tool({
338                name = "greet",
339                label = "Greeting Tool",
340                description = "Greets from Lua",
341                readonly = true,
342                params = {
343                    name = { type = "string", description = "Who to greet", required = true },
344                    excited = { type = "boolean" }
345                },
346                execute = function(call_id, params, ctx)
347                    local suffix = params.excited and "!" or "."
348                    return {
349                        content = {
350                            { type = "text", text = "hello " .. params.name .. suffix },
351                        },
352                        details = {
353                            call_id = call_id,
354                            cwd = ctx.cwd,
355                            cancelled = ctx.cancelled,
356                        },
357                    }
358                end
359            })
360        "#,
361        )
362        .unwrap();
363
364        let runtime = Arc::new(Mutex::new(rt));
365        let mut registry = ToolRegistry::new();
366        load_lua_tools(Arc::clone(&runtime), &mut registry);
367
368        let tool = registry
369            .get("greet")
370            .expect("lua tool should be registered");
371        assert_eq!(tool.label(), "Greeting Tool");
372        assert_eq!(tool.description(), "Greets from Lua");
373        assert!(tool.is_readonly());
374        assert_eq!(tool.parameters()["properties"]["name"]["type"], "string");
375        assert_eq!(tool.parameters()["required"], serde_json::json!(["name"]));
376
377        let output = tool
378            .execute(
379                "call_123",
380                serde_json::json!({ "name": "Ada", "excited": true }),
381                test_ctx(),
382            )
383            .await
384            .unwrap();
385
386        let text = output
387            .content
388            .iter()
389            .find_map(|block| match block {
390                imp_core::imp_llm::ContentBlock::Text { text } => Some(text.as_str()),
391                _ => None,
392            })
393            .expect("lua tool should return text");
394        assert_eq!(text, "hello Ada!");
395        assert_eq!(output.details["call_id"], "call_123");
396        assert_eq!(output.details["cwd"], "/tmp/lua-tools");
397        assert_eq!(output.details["cancelled"], false);
398    }
399
400    #[test]
401    fn imp_secret_helpers_exist() {
402        let rt = make_runtime();
403        rt.exec(
404            r#"
405            _has_secret = type(imp.secret) == "function"
406            _has_secret_fields = type(imp.secret_fields) == "function"
407        "#,
408        )
409        .unwrap();
410
411        let has_secret: bool = rt.lua().globals().get("_has_secret").unwrap();
412        let has_secret_fields: bool = rt.lua().globals().get("_has_secret_fields").unwrap();
413        assert!(has_secret);
414        assert!(has_secret_fields);
415    }
416
417    // ── imp.exec() — Shell execution ────────────────────────────
418
419    #[test]
420    fn exec_runs_command_returns_stdout() {
421        let rt = make_runtime();
422        rt.set_allow_shell_exec(true);
423        rt.exec(
424            r#"
425            local result = imp.exec("echo hello")
426            _test_stdout = result.stdout
427            _test_exit = result.exit_code
428        "#,
429        )
430        .unwrap();
431
432        let stdout: String = rt.lua().globals().get("_test_stdout").unwrap();
433        let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
434        assert_eq!(stdout.trim(), "hello");
435        assert_eq!(exit_code, 0);
436    }
437
438    #[test]
439    fn exec_captures_stderr() {
440        let rt = make_runtime();
441        rt.set_allow_shell_exec(true);
442        rt.exec(
443            r#"
444            local result = imp.exec("echo error >&2")
445            _test_stderr = result.stderr
446        "#,
447        )
448        .unwrap();
449
450        let stderr: String = rt.lua().globals().get("_test_stderr").unwrap();
451        assert_eq!(stderr.trim(), "error");
452    }
453
454    #[test]
455    fn exec_returns_nonzero_exit_code() {
456        let rt = make_runtime();
457        rt.set_allow_shell_exec(true);
458        rt.exec(
459            r#"
460            local result = imp.exec("exit 42")
461            _test_exit = result.exit_code
462        "#,
463        )
464        .unwrap();
465
466        let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
467        assert_eq!(exit_code, 42);
468    }
469
470    #[test]
471    fn exec_with_args_runs_command_without_shell_joining() {
472        let rt = make_runtime();
473        rt.set_allow_shell_exec(true);
474        rt.exec(
475            r#"
476            local result = imp.exec("sh", { "-lc", "printf '%s' \"$1\"", "--", "hello world" })
477            _test_stdout = result.stdout
478            _test_exit = result.exit_code
479        "#,
480        )
481        .unwrap();
482
483        let stdout: String = rt.lua().globals().get("_test_stdout").unwrap();
484        let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
485        assert_eq!(stdout, "hello world");
486        assert_eq!(exit_code, 0);
487    }
488
489    #[test]
490    fn exec_with_cwd() {
491        let rt = make_runtime();
492        rt.set_allow_shell_exec(true);
493        rt.exec(
494            r#"
495            local result = imp.exec("pwd", nil, { cwd = "/tmp" })
496            _test_cwd = result.stdout
497        "#,
498        )
499        .unwrap();
500
501        let cwd: String = rt.lua().globals().get("_test_cwd").unwrap();
502        // /tmp may resolve to /private/tmp on macOS
503        assert!(cwd.trim().contains("tmp"));
504    }
505
506    // ── ctx.ui.confirm() with NullInterface ─────────────────────
507    // The NullInterface returns None for confirm, which maps to nil in Lua.
508    // We simulate this by testing that the bridge correctly handles nil returns.
509
510    #[test]
511    fn null_interface_confirm_returns_nil() {
512        let rt = make_runtime();
513        // Simulate what ctx.ui.confirm would do with NullInterface — just return nil
514        rt.exec(
515            r#"
516            -- When NullInterface returns None, the bridge maps it to nil
517            _confirm_result = nil  -- This is what NullInterface.confirm() produces
518            _is_nil = (_confirm_result == nil)
519        "#,
520        )
521        .unwrap();
522
523        let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
524        assert!(is_nil);
525    }
526
527    // ── imp.ui ─────────────────────────────────────────────────
528
529    #[test]
530    fn imp_ui_confirm_returns_nil_without_call_context() {
531        let rt = make_runtime();
532        rt.exec(
533            r#"
534            _confirm_result = imp.ui.confirm("Proceed?", "No UI should be unavailable")
535            _is_nil = (_confirm_result == nil)
536        "#,
537        )
538        .unwrap();
539
540        let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
541        assert!(is_nil);
542    }
543
544    #[test]
545    fn imp_ui_request_reports_unavailable_without_call_context() {
546        let rt = make_runtime();
547        rt.exec(
548            r#"
549            _ui_result = imp.ui.request({ kind = "confirm", title = "Proceed?" })
550        "#,
551        )
552        .unwrap();
553
554        let result: mlua::Table = rt.lua().globals().get("_ui_result").unwrap();
555        let ok: bool = result.get("ok").unwrap();
556        let reason: String = result.get("reason").unwrap();
557        assert!(!ok);
558        assert_eq!(reason, "unavailable");
559    }
560
561    #[test]
562    fn lua_command_can_call_imp_ui_confirm_safely_headless() {
563        let rt = make_runtime();
564        rt.exec(
565            r#"
566            imp.register_command("confirm-test", {
567                handler = function(args)
568                    local yes = imp.ui.confirm("Proceed?", args)
569                    if yes == nil then
570                        return "unavailable"
571                    end
572                    return tostring(yes)
573                end
574            })
575        "#,
576        )
577        .unwrap();
578
579        let output = rt
580            .execute_command_with_context("confirm-test", "headless", Some(test_ctx().into()))
581            .unwrap();
582        assert_eq!(output.as_deref(), Some("unavailable"));
583    }
584
585    // ── Hot reload ──────────────────────────────────────────────
586
587    #[test]
588    fn hot_reload_drops_and_recreates() {
589        let user_dir = TempDir::new().unwrap();
590        let lua_dir = user_dir.path().join("lua");
591        std::fs::create_dir_all(&lua_dir).unwrap();
592
593        write_lua(
594            &lua_dir,
595            "ext.lua",
596            r#"
597            imp.on("on_session_start", function() end)
598            imp.register_tool({ name = "my_tool", execute = function() end })
599        "#,
600        );
601
602        let policy = make_policy();
603        // First load
604        let (rt1, exts1) = reload(user_dir.path(), None, &policy).unwrap();
605        assert_eq!(rt1.hook_count(), 1);
606        assert_eq!(rt1.tool_count(), 1);
607        assert_eq!(exts1.len(), 1);
608
609        // Modify the extension
610        write_lua(
611            &lua_dir,
612            "ext.lua",
613            r#"
614            imp.on("on_session_start", function() end)
615            imp.on("after_file_write", function() end)
616            imp.register_tool({ name = "tool_a", execute = function() end })
617            imp.register_tool({ name = "tool_b", execute = function() end })
618        "#,
619        );
620
621        // Reload — old state is dropped, new state picks up changes
622        let (rt2, exts2) = reload(user_dir.path(), None, &policy).unwrap();
623        assert_eq!(rt2.hook_count(), 2);
624        assert_eq!(rt2.tool_count(), 2);
625        assert_eq!(exts2.len(), 1);
626
627        let tool_names = rt2.tool_names();
628        assert!(tool_names.contains(&"tool_a".to_string()));
629        assert!(tool_names.contains(&"tool_b".to_string()));
630    }
631
632    #[test]
633    fn hot_reload_applies_capability_policy() {
634        let user_dir = TempDir::new().unwrap();
635        let lua_dir = user_dir.path().join("lua");
636        std::fs::create_dir_all(&lua_dir).unwrap();
637        write_lua(&lua_dir, "ext.lua", "-- extension present");
638
639        let mut policy = make_policy();
640        policy.allow_shell_exec = true;
641
642        let (rt, _exts) = reload(user_dir.path(), None, &policy).unwrap();
643        assert!(rt
644            .allow_shell_exec()
645            .load(std::sync::atomic::Ordering::Relaxed));
646    }
647
648    // ── Error handling ──────────────────────────────────────────
649
650    #[test]
651    fn lua_syntax_error_caught() {
652        let rt = make_runtime();
653        let result = rt.exec("this is not valid lua !!!");
654        assert!(result.is_err());
655        // Runtime is still usable after error
656        let result2 = rt.exec("_test_ok = true");
657        assert!(result2.is_ok());
658        let ok: bool = rt.lua().globals().get("_test_ok").unwrap();
659        assert!(ok);
660    }
661
662    #[test]
663    fn lua_runtime_error_caught() {
664        let rt = make_runtime();
665        let result = rt.exec("error('intentional error')");
666        assert!(result.is_err());
667        let err_msg = format!("{}", result.unwrap_err());
668        assert!(err_msg.contains("intentional error"));
669    }
670
671    #[test]
672    fn extension_error_doesnt_crash_runtime() {
673        let rt = make_runtime();
674
675        // First extension errors
676        let r1 = rt.exec("error('ext1 failed')");
677        assert!(r1.is_err());
678
679        // Second extension still loads fine
680        let r2 = rt.exec(
681            r#"
682            imp.on("on_session_start", function() end)
683            _ext2_loaded = true
684        "#,
685        );
686        assert!(r2.is_ok());
687        assert_eq!(rt.hook_count(), 1);
688
689        let loaded: bool = rt.lua().globals().get("_ext2_loaded").unwrap();
690        assert!(loaded);
691    }
692
693    // ── Multiple extensions coexist ─────────────────────────────
694
695    #[test]
696    fn multiple_extensions_coexist() {
697        let rt = make_runtime();
698
699        // Extension 1
700        rt.exec(
701            r#"
702            imp.on("on_session_start", function()
703                _ext1_fired = true
704            end)
705            imp.register_tool({ name = "ext1_tool", execute = function() end })
706        "#,
707        )
708        .unwrap();
709
710        // Extension 2
711        rt.exec(
712            r#"
713            imp.on("after_file_write", function()
714                _ext2_fired = true
715            end)
716            imp.register_tool({ name = "ext2_tool", execute = function() end })
717        "#,
718        )
719        .unwrap();
720
721        assert_eq!(rt.hook_count(), 2);
722        assert_eq!(rt.tool_count(), 2);
723
724        let names = rt.tool_names();
725        assert!(names.contains(&"ext1_tool".to_string()));
726        assert!(names.contains(&"ext2_tool".to_string()));
727
728        let events = rt.hook_events();
729        assert!(events.contains(&"on_session_start".to_string()));
730        assert!(events.contains(&"after_file_write".to_string()));
731    }
732
733    #[test]
734    fn extensions_share_state() {
735        let rt = make_runtime();
736
737        // Extension 1 sets a global
738        rt.exec("shared_counter = 1").unwrap();
739
740        // Extension 2 reads and increments it
741        rt.exec(
742            r#"
743            shared_counter = shared_counter + 1
744            _final = shared_counter
745        "#,
746        )
747        .unwrap();
748
749        let val: i64 = rt.lua().globals().get("_final").unwrap();
750        assert_eq!(val, 2);
751    }
752
753    // ── Inter-extension events ──────────────────────────────────
754
755    #[test]
756    fn events_on_and_emit() {
757        let rt = make_runtime();
758        rt.exec(
759            r#"
760            _received = nil
761            imp.events.on("custom_event", function(data)
762                _received = data
763            end)
764            imp.events.emit("custom_event", "hello from event")
765        "#,
766        )
767        .unwrap();
768
769        let received: String = rt.lua().globals().get("_received").unwrap();
770        assert_eq!(received, "hello from event");
771    }
772
773    #[test]
774    fn events_multiple_handlers() {
775        let rt = make_runtime();
776        rt.exec(
777            r#"
778            _count = 0
779            imp.events.on("tick", function() _count = _count + 1 end)
780            imp.events.on("tick", function() _count = _count + 1 end)
781            imp.events.emit("tick", nil)
782        "#,
783        )
784        .unwrap();
785
786        let count: i64 = rt.lua().globals().get("_count").unwrap();
787        assert_eq!(count, 2);
788    }
789
790    #[test]
791    fn events_handler_error_doesnt_crash() {
792        let rt = make_runtime();
793        rt.exec(
794            r#"
795            _after_error = false
796            imp.events.on("test", function() error("boom") end)
797            imp.events.on("test", function() _after_error = true end)
798            imp.events.emit("test", nil)
799        "#,
800        )
801        .unwrap();
802
803        let after: bool = rt.lua().globals().get("_after_error").unwrap();
804        assert!(after, "second handler should still fire after first errors");
805    }
806
807    // ── imp.register_command() ──────────────────────────────────
808
809    #[test]
810    fn register_command_creates_handle() {
811        let rt = make_runtime();
812        rt.exec(
813            r#"
814            imp.register_command("greet", {
815                description = "Say hello",
816                handler = function(args, ctx)
817                    return "Hello!"
818                end
819            })
820        "#,
821        )
822        .unwrap();
823
824        assert_eq!(rt.command_count(), 1);
825    }
826
827    // ── JSON conversion ─────────────────────────────────────────
828
829    #[test]
830    fn lua_value_to_json_primitives() {
831        let rt = make_runtime();
832        let lua = rt.lua();
833
834        assert_eq!(lua_value_to_json(mlua::Value::Nil), serde_json::Value::Null);
835        assert_eq!(
836            lua_value_to_json(mlua::Value::Boolean(true)),
837            serde_json::json!(true)
838        );
839        assert_eq!(
840            lua_value_to_json(mlua::Value::Integer(42)),
841            serde_json::json!(42)
842        );
843        assert_eq!(
844            lua_value_to_json(mlua::Value::Number(1.25)),
845            serde_json::json!(1.25)
846        );
847
848        let s = lua.create_string("hello").unwrap();
849        assert_eq!(
850            lua_value_to_json(mlua::Value::String(s)),
851            serde_json::json!("hello")
852        );
853    }
854
855    #[test]
856    fn lua_table_to_json_object() {
857        let rt = make_runtime();
858        rt.exec(
859            r#"
860            _test_table = { name = "Alice", age = 30 }
861        "#,
862        )
863        .unwrap();
864
865        let val: mlua::Value = rt.lua().globals().get("_test_table").unwrap();
866        let json = lua_value_to_json(val);
867        assert_eq!(json["name"], "Alice");
868        assert_eq!(json["age"], 30);
869    }
870
871    #[test]
872    fn lua_array_to_json_array() {
873        let rt = make_runtime();
874        rt.exec(
875            r#"
876            _test_arr = { 1, 2, 3 }
877        "#,
878        )
879        .unwrap();
880
881        let val: mlua::Value = rt.lua().globals().get("_test_arr").unwrap();
882        let json = lua_value_to_json(val);
883        assert_eq!(json, serde_json::json!([1, 2, 3]));
884    }
885
886    #[test]
887    fn json_to_lua_roundtrip() {
888        let rt = make_runtime();
889        let lua = rt.lua();
890
891        let original = serde_json::json!({
892            "name": "test",
893            "count": 42,
894            "active": true,
895            "tags": ["a", "b"],
896            "nested": { "x": 1 }
897        });
898
899        let lua_val = json_to_lua_value(lua, &original).unwrap();
900        let back = lua_value_to_json(lua_val);
901        assert_eq!(back, original);
902    }
903
904    // ── File loading ────────────────────────────────────────────
905
906    #[test]
907    fn load_extensions_from_files() {
908        let user_dir = TempDir::new().unwrap();
909        let lua_dir = user_dir.path().join("lua");
910        std::fs::create_dir_all(&lua_dir).unwrap();
911
912        write_lua(
913            &lua_dir,
914            "a.lua",
915            r#"imp.on("on_session_start", function() end)"#,
916        );
917        write_lua(
918            &lua_dir,
919            "b.lua",
920            r#"imp.register_tool({ name = "b_tool", execute = function() end })"#,
921        );
922
923        let exts = discover_extensions(user_dir.path(), None);
924        assert_eq!(exts.len(), 2);
925
926        let rt = make_runtime();
927        let results = load_extensions(&rt, &exts);
928
929        // Both should succeed
930        for (name, result) in &results {
931            assert!(result.is_ok(), "Extension {} failed: {:?}", name, result);
932        }
933
934        assert_eq!(rt.hook_count(), 1);
935        assert_eq!(rt.tool_count(), 1);
936    }
937
938    /// A script that raises while loading is reported as a failed result, and the
939    /// other extension still loads and registers its hook.
940    #[test]
941    fn load_extension_error_reported_not_fatal() {
942        let config_root = TempDir::new().unwrap();
943        let scripts = config_root.path().join("lua");
944        std::fs::create_dir_all(&scripts).unwrap();
945
946        // bad.lua raises immediately; good.lua registers a single hook.
947        let hook_src = r#"imp.on("on_session_start", function() end)"#;
948        write_lua(&scripts, "bad.lua", "error('bad extension')");
949        write_lua(&scripts, "good.lua", hook_src);
950
951        let discovered = discover_extensions(config_root.path(), None);
952        let runtime = make_runtime();
953        let outcomes = load_extensions(&runtime, &discovered);
954
955        // Tally the per-extension results: exactly one error and one success.
956        let failed = outcomes.iter().filter(|(_, res)| res.is_err()).count();
957        let loaded = outcomes.iter().filter(|(_, res)| res.is_ok()).count();
958
959        assert_eq!(failed, 1);
960        assert_eq!(loaded, 1);
961
962        // The failing script must not stop the good script's hook from registering.
963        assert_eq!(runtime.hook_count(), 1);
964    }
965
966    // ── imp.tool() — call native tools from Lua ─────────────────
967
968    use async_trait::async_trait;
969
970    /// Read-only tool stub whose `execute` answers `echo: <text>`, defaulting to `no text`.
971    struct EchoTestTool;
972    #[async_trait]
973    impl imp_core::tools::Tool for EchoTestTool {
974        fn name(&self) -> &str {
975            "echo"
976        }
977        fn label(&self) -> &str {
978            "Echo"
979        }
980        fn description(&self) -> &str {
981            "Echoes text"
982        }
983        fn is_readonly(&self) -> bool {
984            true
985        }
986        fn parameters(&self) -> serde_json::Value {
987            serde_json::json!({"type": "object", "properties": {"text": {"type": "string"}}})
988        }
989        async fn execute(
990            &self,
991            _call_id: &str,
992            params: serde_json::Value,
993            _ctx: imp_core::tools::ToolContext,
994        ) -> imp_core::Result<imp_core::tools::ToolOutput> {
995            let text = params.get("text").and_then(|v| v.as_str()).unwrap_or("no text");
996            Ok(imp_core::tools::ToolOutput::text(format!("echo: {text}")))
997        }
998    }
999
1000    struct FailTestTool;
1001
1002    #[async_trait]
1003    impl imp_core::tools::Tool for FailTestTool {
1004        fn name(&self) -> &str {
1005            "fail"
1006        }
1007        fn label(&self) -> &str {
1008            "Fail"
1009        }
1010        fn description(&self) -> &str {
1011            "Always fails"
1012        }
1013        fn parameters(&self) -> serde_json::Value {
1014            serde_json::json!({"type": "object"})
1015        }
1016        fn is_readonly(&self) -> bool {
1017            true
1018        }
1019        async fn execute(
1020            &self,
1021            _call_id: &str,
1022            _params: serde_json::Value,
1023            _ctx: imp_core::tools::ToolContext,
1024        ) -> imp_core::Result<imp_core::tools::ToolOutput> {
1025            Ok(imp_core::tools::ToolOutput::error("intentional failure"))
1026        }
1027    }
1028
1029    /// Builds a `LuaCallContext` for tests: fixed cwd, not cancelled, default config.
1030    /// Both channel receivers are dropped on return, so nothing reads what is sent.
1031    fn make_call_context() -> sandbox::LuaCallContext {
1032        let (update_tx, _update_rx) = tokio::sync::mpsc::channel(16);
1033        let (command_tx, _command_rx) = tokio::sync::mpsc::channel(16);
1034        sandbox::LuaCallContext {
1035            cwd: PathBuf::from("/tmp/lua-test"),
1036            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1037            update_tx,
1038            command_tx,
1039            ui: Arc::new(NullInterface),
1040            file_cache: Arc::new(imp_core::tools::FileCache::new()),
1041            checkpoint_state: Arc::new(imp_core::tools::CheckpointState::new()),
1042            file_tracker: Arc::new(Mutex::new(imp_core::tools::FileTracker::default())),
1043            anchor_store: Arc::new(imp_core::tools::AnchorStore::new()),
1044            mode: imp_core::config::AgentMode::Full,
1045            read_max_lines: 500,
1046            run_policy: Default::default(),
1047            lua_tool_loader: None,
1048            config: Arc::new(imp_core::config::Config::default()),
1049        }
1050    }
1051
1052    /// `imp.tool("echo", ...)` invokes the registered native tool and returns its text
1053    /// to the script as the first result, leaving the error slot `nil`.
1054    #[tokio::test(flavor = "multi_thread")]
1055    async fn imp_tool_calls_native_tool() {
1056        let runtime = make_runtime();
1057
1058        let echo: Arc<dyn imp_core::tools::Tool> = Arc::new(EchoTestTool);
1059        let native = std::collections::HashMap::from([("echo".to_string(), echo)]);
1060        runtime.set_native_tools(native);
1061        runtime.set_call_context(make_call_context());
1062
1063        let shared = Arc::new(Mutex::new(runtime));
1064        let for_worker = Arc::clone(&shared);
1065
1066        // The script executes on tokio's blocking pool; the tool's reply is returned here.
1067        let reply = tokio::task::spawn_blocking(move || {
1068            let locked = for_worker.lock().unwrap();
1069            let script = r#"
1070                _result, _err = imp.tool("echo", { text = "hello from lua" })
1071            "#;
1072            locked.exec(script).unwrap();
1073
1074            // Both globals are written by the script's multiple assignment.
1075            let globals = locked.lua().globals();
1076            let reply: String = globals.get("_result").unwrap();
1077            let failure: mlua::Value = globals.get("_err").unwrap();
1078            assert!(matches!(failure, mlua::Value::Nil), "expected no error");
1079            reply
1080        })
1081        .await
1082        .unwrap();
1083
1084        assert_eq!(reply, "echo: hello from lua");
1085    }
1086
1087    #[tokio::test(flavor = "multi_thread")]
1088    async fn imp_tool_returns_error_on_failure() {
1089        let rt = make_runtime();
1090
1091        let mut native = std::collections::HashMap::new();
1092        native.insert(
1093            "fail".to_string(),
1094            Arc::new(FailTestTool) as Arc<dyn imp_core::tools::Tool>,
1095        );
1096        rt.set_native_tools(native);
1097        rt.set_call_context(make_call_context());
1098
1099        let rt = Arc::new(Mutex::new(rt));
1100        let rt2 = Arc::clone(&rt);
1101
1102        tokio::task::spawn_blocking(move || {
1103            let guard = rt2.lock().unwrap();
1104            guard
1105                .exec(
1106                    r#"
1107                _result, _err = imp.tool("fail", {})
1108            "#,
1109                )
1110                .unwrap();
1111            let result: mlua::Value = guard.lua().globals().get("_result").unwrap();
1112            assert!(matches!(result, mlua::Value::Nil), "expected nil result");
1113            let err: String = guard.lua().globals().get("_err").unwrap();
1114            assert!(
1115                err.contains("intentional failure"),
1116                "expected failure message, got: {err}"
1117            );
1118        })
1119        .await
1120        .unwrap();
1121    }
1122
1123    /// Asking `imp.tool()` for a name that is not registered fails `exec` with "not found".
1124    #[tokio::test(flavor = "multi_thread")]
1125    async fn imp_tool_errors_on_unknown_tool() {
1126        let runtime = make_runtime();
1127        runtime.set_native_tools(std::collections::HashMap::new());
1128        runtime.set_call_context(make_call_context());
1129
1130        let shared = Arc::new(Mutex::new(runtime));
1131        let for_worker = Arc::clone(&shared);
1132
1133        tokio::task::spawn_blocking(move || {
1134            let locked = for_worker.lock().unwrap();
1135            let script = r#"
1136                imp.tool("nonexistent", {})
1137            "#;
1138            let outcome = locked.exec(script);
1139            assert!(outcome.is_err(), "should error on unknown tool");
1140            let err = outcome.unwrap_err().to_string();
1141            assert!(
1142                err.contains("not found"),
1143                "error should mention 'not found': {err}"
1144            );
1145        })
1146        .await
1147        .unwrap();
1148    }
1149
1150    #[tokio::test(flavor = "multi_thread")]
1151    async fn imp_tool_errors_when_disabled() {
1152        let rt = make_runtime();
1153
1154        let mut native = std::collections::HashMap::new();
1155        native.insert(
1156            "echo".to_string(),
1157            Arc::new(EchoTestTool) as Arc<dyn imp_core::tools::Tool>,
1158        );
1159        rt.set_native_tools(native);
1160        rt.set_call_context(make_call_context());
1161        rt.set_allow_native_tool_calls(false);
1162
1163        let rt = Arc::new(Mutex::new(rt));
1164        let rt2 = Arc::clone(&rt);
1165
1166        tokio::task::spawn_blocking(move || {
1167            let guard = rt2.lock().unwrap();
1168            let result = guard.exec(
1169                r#"
1170                imp.tool("echo", { text = "hello from lua" })
1171            "#,
1172            );
1173            assert!(result.is_err(), "disabled imp.tool() should error");
1174            let err = format!("{}", result.unwrap_err());
1175            assert!(
1176                err.contains("disabled"),
1177                "error should mention disabled state: {err}"
1178            );
1179        })
1180        .await
1181        .unwrap();
1182    }
1183
1184    // ── Capability gates (imp.exec, imp.secret, imp.http) and imp.env() ──
1185
1186    /// `imp.exec()` is rejected on a runtime where shell exec was never switched on.
1187    #[test]
1188    fn imp_exec_errors_when_disabled() {
1189        let runtime = make_runtime();
1190        let script = r#"
1191            local _ = imp.exec("echo hi")
1192        "#;
1193        let outcome = runtime.exec(script);
1194        assert!(outcome.is_err(), "disabled imp.exec() should error");
1195        let err = outcome.unwrap_err().to_string();
1196        assert!(
1197            err.contains("imp.exec() is disabled"),
1198            "unexpected error: {err}"
1199        );
1200    }
1201
1202    /// Once shell exec is allowed, `imp.exec()` reports the command's stdout and exit code.
1203    #[test]
1204    fn imp_exec_runs_when_enabled() {
1205        let runtime = make_runtime();
1206        runtime.set_allow_shell_exec(true);
1207
1208        let script = r#"
1209            local result = imp.exec("printf lua_exec_ok")
1210            _test_stdout = result.stdout
1211            _test_exit = result.exit_code
1212        "#;
1213        runtime.exec(script).unwrap();
1214
1215        let globals = runtime.lua().globals();
1216        let captured: String = globals.get("_test_stdout").unwrap();
1217        let status: i32 = globals.get("_test_exit").unwrap();
1218        assert_eq!(captured, "lua_exec_ok");
1219        assert_eq!(status, 0);
1220    }
1221
1222    /// Variables passed in the `env` option of `imp.exec()` are visible to the spawned command.
1223    #[test]
1224    fn imp_exec_passes_scoped_env_to_child_process() {
1225        let runtime = make_runtime();
1226        runtime.set_allow_shell_exec(true);
1227
1228        let script = r#"
1229            local result = imp.exec("printf %s \"$IMP_LUA_CHILD_SECRET\"", nil, {
1230                env = { IMP_LUA_CHILD_SECRET = "child-only-value" },
1231            })
1232            _test_stdout = result.stdout
1233            _test_exit = result.exit_code
1234        "#;
1235        runtime.exec(script).unwrap();
1236
1237        let globals = runtime.lua().globals();
1238        let captured: String = globals.get("_test_stdout").unwrap();
1239        let status: i32 = globals.get("_test_exit").unwrap();
1240        assert_eq!(captured, "child-only-value");
1241        assert_eq!(status, 0);
1242    }
1243
1244    /// `imp.secret()` is rejected on a runtime straight from `make_runtime()`.
1245    #[test]
1246    fn imp_secret_errors_when_disabled() {
1247        let runtime = make_runtime();
1248        let script = r#"
1249            local _ = imp.secret("openai", "api_key")
1250        "#;
1251        let outcome = runtime.exec(script);
1252        assert!(outcome.is_err(), "disabled imp.secret() should error");
1253        let err = outcome.unwrap_err().to_string();
1254        assert!(
1255            err.contains("imp.secret() is disabled"),
1256            "unexpected error: {err}"
1257        );
1258    }
1259
1260    /// `imp.secret_fields()` is rejected on a runtime straight from `make_runtime()`.
1261    #[test]
1262    fn imp_secret_fields_errors_when_disabled() {
1263        let runtime = make_runtime();
1264        let script = r#"
1265            local _ = imp.secret_fields("openai")
1266        "#;
1267        let outcome = runtime.exec(script);
1268        assert!(outcome.is_err(), "disabled imp.secret_fields() should error");
1269        let err = outcome.unwrap_err().to_string();
1270        assert!(
1271            err.contains("imp.secret_fields() is disabled"),
1272            "unexpected error: {err}"
1273        );
1274    }
1275
1276    /// `imp.http.get()` is rejected on a runtime straight from `make_runtime()`.
1277    #[tokio::test(flavor = "multi_thread")]
1278    async fn imp_http_get_errors_when_disabled() {
1279        let shared = Arc::new(Mutex::new(make_runtime()));
1280        let for_worker = Arc::clone(&shared);
1281        tokio::task::spawn_blocking(move || {
1282            let locked = for_worker.lock().unwrap();
1283            let script = r#"
1284                local _ = imp.http.get("https://example.com")
1285            "#;
1286            let outcome = locked.exec(script);
1287            assert!(outcome.is_err(), "disabled imp.http.get() should error");
1288            let err = outcome.unwrap_err().to_string();
1289            assert!(
1290                err.contains("imp.http.get() is disabled"),
1291                "unexpected error: {err}"
1292            );
1293        })
1294        .await
1295        .unwrap();
1296    }
1297
1298    /// `imp.env()` hands back the value of a process environment variable whose
1299    /// name is present in the allow-list.
1300    #[test]
1301    fn imp_env_reads_var_when_allowed() {
1302        const VAR: &str = "IMP_LUA_TEST_VAR";
1303        let runtime = make_runtime();
1304
1305        // Process-wide mutation; this test does not unset the variable afterwards.
1306        std::env::set_var(VAR, "test_value");
1307        runtime.set_allowed_env(std::collections::HashSet::from([VAR.to_string()]));
1308
1309        let script = r#"
1310            _env_val = imp.env("IMP_LUA_TEST_VAR")
1311        "#;
1312        runtime.exec(script).unwrap();
1313
1314        let seen: String = runtime.lua().globals().get("_env_val").unwrap();
1315        assert_eq!(seen, "test_value");
1316    }
1317
1318    /// A variable that is set in the process but absent from the allow-list
1319    /// reads as `nil` from `imp.env()`.
1320    #[test]
1321    fn imp_env_returns_nil_for_denied_var() {
1322        let runtime = make_runtime();
1323        std::env::set_var("IMP_LUA_TEST_SECRET", "secret_value");
1324
1325        // Allow-list a different name so the requested variable is not covered.
1326        let only_other = std::collections::HashSet::from(["SOME_OTHER_VAR".to_string()]);
1327        runtime.set_allowed_env(only_other);
1328
1329        let script = r#"
1330            _env_val = imp.env("IMP_LUA_TEST_SECRET")
1331            _is_nil = (_env_val == nil)
1332        "#;
1333        runtime.exec(script).unwrap();
1334
1335        let hidden: bool = runtime.lua().globals().get("_is_nil").unwrap();
1336        assert!(hidden, "denied env var should return nil");
1337    }
1338
1339    /// With an empty allow-list, `imp.env()` yields `nil` even for a variable that is set.
1340    /// NOTE(review): the name says "allows all" but the assertion requires denial;
1341    /// consider renaming to `imp_env_denies_all_when_list_empty`.
1342    #[test]
1343    fn imp_env_allows_all_when_list_empty() {
1344        let runtime = make_runtime();
1345        std::env::set_var("IMP_LUA_TEST_OPEN", "open_value");
1346        let nothing_allowed = std::collections::HashSet::new();
1347        runtime.set_allowed_env(nothing_allowed);
1348
1349        let script = r#"
1350            _env_val = imp.env("IMP_LUA_TEST_OPEN")
1351            _is_nil = (_env_val == nil)
1352        "#;
1353        runtime.exec(script).unwrap();
1354
1355        let denied: bool = runtime.lua().globals().get("_is_nil").unwrap();
1356        assert!(denied, "empty allow-list should deny env access by default");
1357    }
1358
1359    /// An allow-listed variable that is not set in the process reads as `nil`.
1360    #[test]
1361    fn imp_env_returns_nil_for_missing_var() {
1362        let rt = make_runtime();
1363        // Allow-list the name: an empty list denies every lookup and would mask this path.
1364        let allowed = ["DEFINITELY_NOT_SET_IMP_LUA_TEST".to_string()];
1365        rt.set_allowed_env(std::collections::HashSet::from(allowed));
1366
1367        let script = r#"
1368            _env_val = imp.env("DEFINITELY_NOT_SET_IMP_LUA_TEST")
1369            _is_nil = (_env_val == nil)
1370        "#;
1371        rt.exec(script).unwrap();
1372        let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
1373        assert!(is_nil, "missing env var should return nil");
1374    }
1375}