Skip to main content

imp_lua/
lib.rs

1pub mod bridge;
2pub mod loader;
3pub mod sandbox;
4
5use std::path::Path;
6use std::sync::{Arc, Mutex};
7
8use imp_core::config::LuaCapabilityPolicy;
9use imp_core::tools::ToolRegistry;
10
11pub use bridge::{json_to_lua_value, load_lua_tools, lua_value_to_json, setup_host_api, LuaTool};
12pub use loader::{discover_extensions, load_extensions, reload, LuaExtension};
13pub use sandbox::{
14    LuaCallContext, LuaCommandHandle, LuaError, LuaHookHandle, LuaRuntime, LuaToolHandle,
15};
16
17/// Discover and load Lua extensions from user and project directories,
18/// registering any tools they define onto the given registry.
19///
20/// Returns the shared runtime handle (for command dispatch and hot-reload).
21/// Returns `None` if no extensions were found or the runtime failed to start.
22pub fn init_lua_extensions(
23    user_config_dir: &Path,
24    project_dir: Option<&Path>,
25    tools: &mut ToolRegistry,
26    policy: &LuaCapabilityPolicy,
27) -> Option<Arc<Mutex<LuaRuntime>>> {
28    let extensions = discover_extensions(user_config_dir, project_dir);
29    if extensions.is_empty() {
30        return None;
31    }
32
33    let rt = match LuaRuntime::new() {
34        Ok(rt) => rt,
35        Err(_e) => {
36            return None;
37        }
38    };
39    if let Err(_e) = setup_host_api(&rt) {
40        return None;
41    }
42    rt.apply_capability_policy(policy);
43
44    let results = load_extensions(&rt, &extensions);
45    for (_name, result) in &results {
46        if let Err(_e) = result {
47            // Keep extension bootstrap silent in embedded runtimes; failed extensions
48            // simply do not register their tools.
49        }
50    }
51
52    // Give the Lua runtime access to native tools for imp.tool() calls
53    rt.set_native_tools(tools.tools_map());
54
55    let runtime = Arc::new(Mutex::new(rt));
56    load_lua_tools(Arc::clone(&runtime), tools);
57    Some(runtime)
58}
59
60#[cfg(test)]
61mod tests {
62    use super::*;
63    use std::path::PathBuf;
64    use std::sync::{Arc, Mutex};
65
66    use imp_core::config::LuaCapabilityPolicy;
67    use imp_core::tools::{ToolContext, ToolRegistry};
68    use imp_core::ui::NullInterface;
69    use tempfile::TempDir;
70
71    fn make_policy() -> LuaCapabilityPolicy {
72        LuaCapabilityPolicy::default()
73    }
74
75    /// Helper: create a runtime with host API set up.
76    fn make_runtime() -> LuaRuntime {
77        let rt = LuaRuntime::new().expect("create runtime");
78        setup_host_api(&rt).expect("setup host api");
79        rt
80    }
81
82    /// Helper: write a Lua file into a directory.
83    fn write_lua(dir: &std::path::Path, name: &str, content: &str) -> PathBuf {
84        let path = dir.join(name);
85        std::fs::write(&path, content).unwrap();
86        path
87    }
88
89    fn test_ctx() -> ToolContext {
90        let (tx, _rx) = tokio::sync::mpsc::channel(16);
91        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
92        ToolContext {
93            cwd: PathBuf::from("/tmp/lua-tools"),
94            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
95            update_tx: tx,
96            command_tx: cmd_tx,
97            ui: Arc::new(NullInterface),
98            file_cache: Arc::new(imp_core::tools::FileCache::new()),
99            checkpoint_state: Arc::new(imp_core::tools::CheckpointState::new()),
100            file_tracker: Arc::new(std::sync::Mutex::new(
101                imp_core::tools::FileTracker::default(),
102            )),
103            anchor_store: Arc::new(imp_core::tools::AnchorStore::new()),
104            mode: imp_core::config::AgentMode::Full,
105            read_max_lines: 0,
106            turn_mana_review: Arc::new(std::sync::Mutex::new(
107                imp_core::mana_review::TurnManaReviewAccumulator::default(),
108            )),
109            config: Arc::new(imp_core::config::Config::default()),
110            lua_tool_loader: None,
111        }
112    }
113
114    // ── Discovery ────────────────────────────────────────────────
115
116    #[test]
117    fn init_lua_extensions_applies_capability_policy() {
118        let user = TempDir::new().unwrap();
119        let lua_dir = user.path().join("lua");
120        std::fs::create_dir_all(&lua_dir).unwrap();
121        write_lua(&lua_dir, "ext.lua", "-- extension present");
122
123        let mut registry = ToolRegistry::new();
124        let mut policy = make_policy();
125        policy.allow_native_tool_calls = false;
126        policy.allow_shell_exec = true;
127        policy.allow_http = true;
128        policy.allow_secrets = true;
129        policy.allowed_env.insert("ALLOWED_ONE".to_string());
130
131        let runtime = init_lua_extensions(user.path(), None, &mut registry, &policy)
132            .expect("runtime should initialize");
133        let guard = runtime.lock().unwrap();
134
135        assert!(!guard
136            .allow_native_tool_calls()
137            .load(std::sync::atomic::Ordering::Relaxed));
138        assert!(guard
139            .allow_shell_exec()
140            .load(std::sync::atomic::Ordering::Relaxed));
141        assert!(guard
142            .allow_http()
143            .load(std::sync::atomic::Ordering::Relaxed));
144        assert!(guard
145            .allow_secrets()
146            .load(std::sync::atomic::Ordering::Relaxed));
147        assert!(guard.allowed_env().lock().unwrap().contains("ALLOWED_ONE"));
148    }
149
150    #[test]
151    fn discover_user_lua_files() {
152        let tmp = TempDir::new().unwrap();
153        let lua_dir = tmp.path().join("lua");
154        std::fs::create_dir_all(&lua_dir).unwrap();
155        write_lua(&lua_dir, "greet.lua", "-- hello");
156        write_lua(&lua_dir, "utils.lua", "-- utils");
157
158        let exts = discover_extensions(tmp.path(), None);
159        assert_eq!(exts.len(), 2);
160
161        let names: Vec<&str> = exts.iter().map(|e| e.name.as_str()).collect();
162        assert!(names.contains(&"greet"));
163        assert!(names.contains(&"utils"));
164    }
165
166    #[test]
167    fn discover_directory_init_lua() {
168        let tmp = TempDir::new().unwrap();
169        let ext_dir = tmp.path().join("lua").join("my-ext");
170        std::fs::create_dir_all(&ext_dir).unwrap();
171        write_lua(&ext_dir, "init.lua", "-- init");
172
173        let exts = discover_extensions(tmp.path(), None);
174        assert_eq!(exts.len(), 1);
175        assert_eq!(exts[0].name, "my-ext");
176        assert!(exts[0].path.ends_with("init.lua"));
177    }
178
179    #[test]
180    fn discover_project_local() {
181        let user = TempDir::new().unwrap();
182        let project = TempDir::new().unwrap();
183        let proj_lua = project.path().join(".imp").join("lua");
184        std::fs::create_dir_all(&proj_lua).unwrap();
185        write_lua(&proj_lua, "local.lua", "-- local");
186
187        let exts = discover_extensions(user.path(), Some(project.path()));
188        assert_eq!(exts.len(), 1);
189        assert_eq!(exts[0].name, "local");
190    }
191
192    #[test]
193    fn discover_empty_dirs_return_nothing() {
194        let tmp = TempDir::new().unwrap();
195        let exts = discover_extensions(tmp.path(), None);
196        assert!(exts.is_empty());
197    }
198
199    // ── imp.on() — Hook registration ────────────────────────────
200
201    #[test]
202    fn on_registers_hook() {
203        let rt = make_runtime();
204        rt.exec(
205            r#"
206            imp.on("on_session_start", function(event, ctx)
207                -- handler
208            end)
209        "#,
210        )
211        .unwrap();
212
213        assert_eq!(rt.hook_count(), 1);
214        let events = rt.hook_events();
215        assert_eq!(events[0], "on_session_start");
216    }
217
218    #[test]
219    fn on_registers_multiple_hooks() {
220        let rt = make_runtime();
221        rt.exec(
222            r#"
223            imp.on("on_session_start", function() end)
224            imp.on("after_file_write", function() end)
225            imp.on("before_tool_call", function() end)
226        "#,
227        )
228        .unwrap();
229
230        assert_eq!(rt.hook_count(), 3);
231        let events = rt.hook_events();
232        assert!(events.contains(&"on_session_start".to_string()));
233        assert!(events.contains(&"after_file_write".to_string()));
234        assert!(events.contains(&"before_tool_call".to_string()));
235    }
236
237    #[test]
238    fn hook_handler_fires_on_correct_event() {
239        let rt = make_runtime();
240        rt.exec(
241            r#"
242            _test_fired = false
243            imp.on("on_session_start", function()
244                _test_fired = true
245            end)
246        "#,
247        )
248        .unwrap();
249
250        // Simulate firing the hook by calling the stored handler
251        let hooks = rt.hooks();
252        let hooks_guard = hooks.lock().unwrap();
253        assert_eq!(hooks_guard.len(), 1);
254
255        let handler: mlua::Function = rt
256            .lua()
257            .registry_value(&hooks_guard[0].handler_key)
258            .unwrap();
259        handler.call::<()>(()).unwrap();
260
261        let fired: bool = rt.lua().globals().get("_test_fired").unwrap();
262        assert!(fired);
263    }
264
265    // ── imp.register_tool() ─────────────────────────────────────
266
267    #[test]
268    fn register_tool_creates_handle() {
269        let rt = make_runtime();
270        rt.exec(
271            r#"
272            imp.register_tool({
273                name = "greet",
274                label = "Greeting Tool",
275                description = "Says hello",
276                readonly = true,
277                params = {
278                    type = "object",
279                    properties = {
280                        name = { type = "string", description = "Who to greet" }
281                    }
282                },
283                execute = function(call_id, params, ctx)
284                    return { content = "Hello, " .. (params.name or "world") }
285                end
286            })
287        "#,
288        )
289        .unwrap();
290
291        assert_eq!(rt.tool_count(), 1);
292        let names = rt.tool_names();
293        assert_eq!(names[0], "greet");
294    }
295
296    #[test]
297    fn register_tool_execute_callable() {
298        let rt = make_runtime();
299        rt.exec(
300            r#"
301            imp.register_tool({
302                name = "add",
303                execute = function(call_id, params, ctx)
304                    return { content = tostring(params.a + params.b), is_error = false }
305                end
306            })
307        "#,
308        )
309        .unwrap();
310
311        // Call the execute function directly
312        let tools = rt.tools();
313        let tools_guard = tools.lock().unwrap();
314        let execute_fn: mlua::Function = rt
315            .lua()
316            .registry_value(&tools_guard[0].execute_key)
317            .unwrap();
318
319        let params = rt.lua().create_table().unwrap();
320        params.set("a", 3).unwrap();
321        params.set("b", 4).unwrap();
322
323        let result: mlua::Table = execute_fn
324            .call(("call_1", params, mlua::Value::Nil))
325            .unwrap();
326        let content: String = result.get("content").unwrap();
327        assert_eq!(content, "7");
328    }
329
330    #[tokio::test]
331    async fn load_lua_tools_registers_and_executes_bridge() {
332        let rt = make_runtime();
333        rt.exec(
334            r#"
335            imp.register_tool({
336                name = "greet",
337                label = "Greeting Tool",
338                description = "Greets from Lua",
339                readonly = true,
340                params = {
341                    name = { type = "string", description = "Who to greet", required = true },
342                    excited = { type = "boolean" }
343                },
344                execute = function(call_id, params, ctx)
345                    local suffix = params.excited and "!" or "."
346                    return {
347                        content = {
348                            { type = "text", text = "hello " .. params.name .. suffix },
349                        },
350                        details = {
351                            call_id = call_id,
352                            cwd = ctx.cwd,
353                            cancelled = ctx.cancelled,
354                        },
355                    }
356                end
357            })
358        "#,
359        )
360        .unwrap();
361
362        let runtime = Arc::new(Mutex::new(rt));
363        let mut registry = ToolRegistry::new();
364        load_lua_tools(Arc::clone(&runtime), &mut registry);
365
366        let tool = registry
367            .get("greet")
368            .expect("lua tool should be registered");
369        assert_eq!(tool.label(), "Greeting Tool");
370        assert_eq!(tool.description(), "Greets from Lua");
371        assert!(tool.is_readonly());
372        assert_eq!(tool.parameters()["properties"]["name"]["type"], "string");
373        assert_eq!(tool.parameters()["required"], serde_json::json!(["name"]));
374
375        let output = tool
376            .execute(
377                "call_123",
378                serde_json::json!({ "name": "Ada", "excited": true }),
379                test_ctx(),
380            )
381            .await
382            .unwrap();
383
384        let text = output
385            .content
386            .iter()
387            .find_map(|block| match block {
388                imp_core::imp_llm::ContentBlock::Text { text } => Some(text.as_str()),
389                _ => None,
390            })
391            .expect("lua tool should return text");
392        assert_eq!(text, "hello Ada!");
393        assert_eq!(output.details["call_id"], "call_123");
394        assert_eq!(output.details["cwd"], "/tmp/lua-tools");
395        assert_eq!(output.details["cancelled"], false);
396    }
397
398    #[test]
399    fn imp_secret_helpers_exist() {
400        let rt = make_runtime();
401        rt.exec(
402            r#"
403            _has_secret = type(imp.secret) == "function"
404            _has_secret_fields = type(imp.secret_fields) == "function"
405        "#,
406        )
407        .unwrap();
408
409        let has_secret: bool = rt.lua().globals().get("_has_secret").unwrap();
410        let has_secret_fields: bool = rt.lua().globals().get("_has_secret_fields").unwrap();
411        assert!(has_secret);
412        assert!(has_secret_fields);
413    }
414
415    // ── imp.exec() — Shell execution ────────────────────────────
416
417    #[test]
418    fn exec_runs_command_returns_stdout() {
419        let rt = make_runtime();
420        rt.set_allow_shell_exec(true);
421        rt.exec(
422            r#"
423            local result = imp.exec("echo hello")
424            _test_stdout = result.stdout
425            _test_exit = result.exit_code
426        "#,
427        )
428        .unwrap();
429
430        let stdout: String = rt.lua().globals().get("_test_stdout").unwrap();
431        let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
432        assert_eq!(stdout.trim(), "hello");
433        assert_eq!(exit_code, 0);
434    }
435
436    #[test]
437    fn exec_captures_stderr() {
438        let rt = make_runtime();
439        rt.set_allow_shell_exec(true);
440        rt.exec(
441            r#"
442            local result = imp.exec("echo error >&2")
443            _test_stderr = result.stderr
444        "#,
445        )
446        .unwrap();
447
448        let stderr: String = rt.lua().globals().get("_test_stderr").unwrap();
449        assert_eq!(stderr.trim(), "error");
450    }
451
452    #[test]
453    fn exec_returns_nonzero_exit_code() {
454        let rt = make_runtime();
455        rt.set_allow_shell_exec(true);
456        rt.exec(
457            r#"
458            local result = imp.exec("exit 42")
459            _test_exit = result.exit_code
460        "#,
461        )
462        .unwrap();
463
464        let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
465        assert_eq!(exit_code, 42);
466    }
467
468    #[test]
469    fn exec_with_cwd() {
470        let rt = make_runtime();
471        rt.set_allow_shell_exec(true);
472        rt.exec(
473            r#"
474            local result = imp.exec("pwd", nil, { cwd = "/tmp" })
475            _test_cwd = result.stdout
476        "#,
477        )
478        .unwrap();
479
480        let cwd: String = rt.lua().globals().get("_test_cwd").unwrap();
481        // /tmp may resolve to /private/tmp on macOS
482        assert!(cwd.trim().contains("tmp"));
483    }
484
485    // ── ctx.ui.confirm() with NullInterface ─────────────────────
    // NullInterface is expected to return None for confirm, which the bridge maps to nil in Lua.
    // The test below does not call ctx.ui.confirm itself; it only checks that a nil value
    // compares equal to nil inside the runtime.
488
489    #[test]
490    fn null_interface_confirm_returns_nil() {
491        let rt = make_runtime();
492        // Simulate what ctx.ui.confirm would do with NullInterface — just return nil
493        rt.exec(
494            r#"
495            -- When NullInterface returns None, the bridge maps it to nil
496            _confirm_result = nil  -- This is what NullInterface.confirm() produces
497            _is_nil = (_confirm_result == nil)
498        "#,
499        )
500        .unwrap();
501
502        let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
503        assert!(is_nil);
504    }
505
506    // ── Hot reload ──────────────────────────────────────────────
507
508    #[test]
509    fn hot_reload_drops_and_recreates() {
510        let user_dir = TempDir::new().unwrap();
511        let lua_dir = user_dir.path().join("lua");
512        std::fs::create_dir_all(&lua_dir).unwrap();
513
514        write_lua(
515            &lua_dir,
516            "ext.lua",
517            r#"
518            imp.on("on_session_start", function() end)
519            imp.register_tool({ name = "my_tool", execute = function() end })
520        "#,
521        );
522
523        // First load
524        let (rt1, exts1) = reload(user_dir.path(), None).unwrap();
525        assert_eq!(rt1.hook_count(), 1);
526        assert_eq!(rt1.tool_count(), 1);
527        assert_eq!(exts1.len(), 1);
528
529        // Modify the extension
530        write_lua(
531            &lua_dir,
532            "ext.lua",
533            r#"
534            imp.on("on_session_start", function() end)
535            imp.on("after_file_write", function() end)
536            imp.register_tool({ name = "tool_a", execute = function() end })
537            imp.register_tool({ name = "tool_b", execute = function() end })
538        "#,
539        );
540
541        // Reload — old state is dropped, new state picks up changes
542        let (rt2, exts2) = reload(user_dir.path(), None).unwrap();
543        assert_eq!(rt2.hook_count(), 2);
544        assert_eq!(rt2.tool_count(), 2);
545        assert_eq!(exts2.len(), 1);
546
547        let tool_names = rt2.tool_names();
548        assert!(tool_names.contains(&"tool_a".to_string()));
549        assert!(tool_names.contains(&"tool_b".to_string()));
550    }
551
552    // ── Error handling ──────────────────────────────────────────
553
554    #[test]
555    fn lua_syntax_error_caught() {
556        let rt = make_runtime();
557        let result = rt.exec("this is not valid lua !!!");
558        assert!(result.is_err());
559        // Runtime is still usable after error
560        let result2 = rt.exec("_test_ok = true");
561        assert!(result2.is_ok());
562        let ok: bool = rt.lua().globals().get("_test_ok").unwrap();
563        assert!(ok);
564    }
565
566    #[test]
567    fn lua_runtime_error_caught() {
568        let rt = make_runtime();
569        let result = rt.exec("error('intentional error')");
570        assert!(result.is_err());
571        let err_msg = format!("{}", result.unwrap_err());
572        assert!(err_msg.contains("intentional error"));
573    }
574
575    #[test]
576    fn extension_error_doesnt_crash_runtime() {
577        let rt = make_runtime();
578
579        // First extension errors
580        let r1 = rt.exec("error('ext1 failed')");
581        assert!(r1.is_err());
582
583        // Second extension still loads fine
584        let r2 = rt.exec(
585            r#"
586            imp.on("on_session_start", function() end)
587            _ext2_loaded = true
588        "#,
589        );
590        assert!(r2.is_ok());
591        assert_eq!(rt.hook_count(), 1);
592
593        let loaded: bool = rt.lua().globals().get("_ext2_loaded").unwrap();
594        assert!(loaded);
595    }
596
597    // ── Multiple extensions coexist ─────────────────────────────
598
599    #[test]
600    fn multiple_extensions_coexist() {
601        let rt = make_runtime();
602
603        // Extension 1
604        rt.exec(
605            r#"
606            imp.on("on_session_start", function()
607                _ext1_fired = true
608            end)
609            imp.register_tool({ name = "ext1_tool", execute = function() end })
610        "#,
611        )
612        .unwrap();
613
614        // Extension 2
615        rt.exec(
616            r#"
617            imp.on("after_file_write", function()
618                _ext2_fired = true
619            end)
620            imp.register_tool({ name = "ext2_tool", execute = function() end })
621        "#,
622        )
623        .unwrap();
624
625        assert_eq!(rt.hook_count(), 2);
626        assert_eq!(rt.tool_count(), 2);
627
628        let names = rt.tool_names();
629        assert!(names.contains(&"ext1_tool".to_string()));
630        assert!(names.contains(&"ext2_tool".to_string()));
631
632        let events = rt.hook_events();
633        assert!(events.contains(&"on_session_start".to_string()));
634        assert!(events.contains(&"after_file_write".to_string()));
635    }
636
637    #[test]
638    fn extensions_share_state() {
639        let rt = make_runtime();
640
641        // Extension 1 sets a global
642        rt.exec("shared_counter = 1").unwrap();
643
644        // Extension 2 reads and increments it
645        rt.exec(
646            r#"
647            shared_counter = shared_counter + 1
648            _final = shared_counter
649        "#,
650        )
651        .unwrap();
652
653        let val: i64 = rt.lua().globals().get("_final").unwrap();
654        assert_eq!(val, 2);
655    }
656
657    // ── Inter-extension events ──────────────────────────────────
658
659    #[test]
660    fn events_on_and_emit() {
661        let rt = make_runtime();
662        rt.exec(
663            r#"
664            _received = nil
665            imp.events.on("custom_event", function(data)
666                _received = data
667            end)
668            imp.events.emit("custom_event", "hello from event")
669        "#,
670        )
671        .unwrap();
672
673        let received: String = rt.lua().globals().get("_received").unwrap();
674        assert_eq!(received, "hello from event");
675    }
676
677    #[test]
678    fn events_multiple_handlers() {
679        let rt = make_runtime();
680        rt.exec(
681            r#"
682            _count = 0
683            imp.events.on("tick", function() _count = _count + 1 end)
684            imp.events.on("tick", function() _count = _count + 1 end)
685            imp.events.emit("tick", nil)
686        "#,
687        )
688        .unwrap();
689
690        let count: i64 = rt.lua().globals().get("_count").unwrap();
691        assert_eq!(count, 2);
692    }
693
694    #[test]
695    fn events_handler_error_doesnt_crash() {
696        let rt = make_runtime();
697        rt.exec(
698            r#"
699            _after_error = false
700            imp.events.on("test", function() error("boom") end)
701            imp.events.on("test", function() _after_error = true end)
702            imp.events.emit("test", nil)
703        "#,
704        )
705        .unwrap();
706
707        let after: bool = rt.lua().globals().get("_after_error").unwrap();
708        assert!(after, "second handler should still fire after first errors");
709    }
710
711    // ── imp.register_command() ──────────────────────────────────
712
713    #[test]
714    fn register_command_creates_handle() {
715        let rt = make_runtime();
716        rt.exec(
717            r#"
718            imp.register_command("greet", {
719                description = "Say hello",
720                handler = function(args, ctx)
721                    return "Hello!"
722                end
723            })
724        "#,
725        )
726        .unwrap();
727
728        assert_eq!(rt.command_count(), 1);
729    }
730
731    // ── JSON conversion ─────────────────────────────────────────
732
733    #[test]
734    fn lua_value_to_json_primitives() {
735        let rt = make_runtime();
736        let lua = rt.lua();
737
738        assert_eq!(lua_value_to_json(mlua::Value::Nil), serde_json::Value::Null);
739        assert_eq!(
740            lua_value_to_json(mlua::Value::Boolean(true)),
741            serde_json::json!(true)
742        );
743        assert_eq!(
744            lua_value_to_json(mlua::Value::Integer(42)),
745            serde_json::json!(42)
746        );
747        assert_eq!(
748            lua_value_to_json(mlua::Value::Number(3.14)),
749            serde_json::json!(3.14)
750        );
751
752        let s = lua.create_string("hello").unwrap();
753        assert_eq!(
754            lua_value_to_json(mlua::Value::String(s)),
755            serde_json::json!("hello")
756        );
757    }
758
759    #[test]
760    fn lua_table_to_json_object() {
761        let rt = make_runtime();
762        rt.exec(
763            r#"
764            _test_table = { name = "Alice", age = 30 }
765        "#,
766        )
767        .unwrap();
768
769        let val: mlua::Value = rt.lua().globals().get("_test_table").unwrap();
770        let json = lua_value_to_json(val);
771        assert_eq!(json["name"], "Alice");
772        assert_eq!(json["age"], 30);
773    }
774
775    #[test]
776    fn lua_array_to_json_array() {
777        let rt = make_runtime();
778        rt.exec(
779            r#"
780            _test_arr = { 1, 2, 3 }
781        "#,
782        )
783        .unwrap();
784
785        let val: mlua::Value = rt.lua().globals().get("_test_arr").unwrap();
786        let json = lua_value_to_json(val);
787        assert_eq!(json, serde_json::json!([1, 2, 3]));
788    }
789
790    #[test]
791    fn json_to_lua_roundtrip() {
792        let rt = make_runtime();
793        let lua = rt.lua();
794
795        let original = serde_json::json!({
796            "name": "test",
797            "count": 42,
798            "active": true,
799            "tags": ["a", "b"],
800            "nested": { "x": 1 }
801        });
802
803        let lua_val = json_to_lua_value(lua, &original).unwrap();
804        let back = lua_value_to_json(lua_val);
805        assert_eq!(back, original);
806    }
807
808    // ── File loading ────────────────────────────────────────────
809
810    #[test]
811    fn load_extensions_from_files() {
812        let user_dir = TempDir::new().unwrap();
813        let lua_dir = user_dir.path().join("lua");
814        std::fs::create_dir_all(&lua_dir).unwrap();
815
816        write_lua(
817            &lua_dir,
818            "a.lua",
819            r#"imp.on("on_session_start", function() end)"#,
820        );
821        write_lua(
822            &lua_dir,
823            "b.lua",
824            r#"imp.register_tool({ name = "b_tool", execute = function() end })"#,
825        );
826
827        let exts = discover_extensions(user_dir.path(), None);
828        assert_eq!(exts.len(), 2);
829
830        let rt = make_runtime();
831        let results = load_extensions(&rt, &exts);
832
833        // Both should succeed
834        for (name, result) in &results {
835            assert!(result.is_ok(), "Extension {} failed: {:?}", name, result);
836        }
837
838        assert_eq!(rt.hook_count(), 1);
839        assert_eq!(rt.tool_count(), 1);
840    }
841
842    #[test]
843    fn load_extension_error_reported_not_fatal() {
844        let user_dir = TempDir::new().unwrap();
845        let lua_dir = user_dir.path().join("lua");
846        std::fs::create_dir_all(&lua_dir).unwrap();
847
848        write_lua(&lua_dir, "bad.lua", "error('bad extension')");
849        write_lua(
850            &lua_dir,
851            "good.lua",
852            r#"imp.on("on_session_start", function() end)"#,
853        );
854
855        let exts = discover_extensions(user_dir.path(), None);
856        let rt = make_runtime();
857        let results = load_extensions(&rt, &exts);
858
859        // One fails, one succeeds
860        let failures: Vec<_> = results.iter().filter(|(_, r)| r.is_err()).collect();
861        let successes: Vec<_> = results.iter().filter(|(_, r)| r.is_ok()).collect();
862
863        assert_eq!(failures.len(), 1);
864        assert_eq!(successes.len(), 1);
865
866        // Good extension's hook was registered despite the bad extension
867        assert_eq!(rt.hook_count(), 1);
868    }
869
870    // ── imp.tool() — call native tools from Lua ─────────────────
871
872    use async_trait::async_trait;
873
874    struct EchoTestTool;
875
876    #[async_trait]
877    impl imp_core::tools::Tool for EchoTestTool {
878        fn name(&self) -> &str {
879            "echo"
880        }
881        fn label(&self) -> &str {
882            "Echo"
883        }
884        fn description(&self) -> &str {
885            "Echoes text"
886        }
887        fn parameters(&self) -> serde_json::Value {
888            serde_json::json!({"type": "object", "properties": {"text": {"type": "string"}}})
889        }
890        fn is_readonly(&self) -> bool {
891            true
892        }
893        async fn execute(
894            &self,
895            _call_id: &str,
896            params: serde_json::Value,
897            _ctx: imp_core::tools::ToolContext,
898        ) -> imp_core::Result<imp_core::tools::ToolOutput> {
899            let text = params["text"].as_str().unwrap_or("no text");
900            Ok(imp_core::tools::ToolOutput::text(format!("echo: {text}")))
901        }
902    }
903
904    struct FailTestTool;
905
906    #[async_trait]
907    impl imp_core::tools::Tool for FailTestTool {
908        fn name(&self) -> &str {
909            "fail"
910        }
911        fn label(&self) -> &str {
912            "Fail"
913        }
914        fn description(&self) -> &str {
915            "Always fails"
916        }
917        fn parameters(&self) -> serde_json::Value {
918            serde_json::json!({"type": "object"})
919        }
920        fn is_readonly(&self) -> bool {
921            true
922        }
923        async fn execute(
924            &self,
925            _call_id: &str,
926            _params: serde_json::Value,
927            _ctx: imp_core::tools::ToolContext,
928        ) -> imp_core::Result<imp_core::tools::ToolOutput> {
929            Ok(imp_core::tools::ToolOutput::error("intentional failure"))
930        }
931    }
932
933    fn make_call_context() -> sandbox::LuaCallContext {
934        let (tx, _rx) = tokio::sync::mpsc::channel(16);
935        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
936        sandbox::LuaCallContext {
937            cwd: PathBuf::from("/tmp/lua-test"),
938            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
939            update_tx: tx,
940            command_tx: cmd_tx,
941            ui: Arc::new(NullInterface),
942            file_cache: Arc::new(imp_core::tools::FileCache::new()),
943            checkpoint_state: Arc::new(imp_core::tools::CheckpointState::new()),
944            file_tracker: Arc::new(std::sync::Mutex::new(
945                imp_core::tools::FileTracker::default(),
946            )),
947            anchor_store: Arc::new(imp_core::tools::AnchorStore::new()),
948            mode: imp_core::config::AgentMode::Full,
949            read_max_lines: 500,
950            lua_tool_loader: None,
951            config: Arc::new(imp_core::config::Config::default()),
952        }
953    }
954
955    #[tokio::test(flavor = "multi_thread")]
956    async fn imp_tool_calls_native_tool() {
957        let rt = make_runtime();
958
959        let mut native = std::collections::HashMap::new();
960        native.insert(
961            "echo".to_string(),
962            Arc::new(EchoTestTool) as Arc<dyn imp_core::tools::Tool>,
963        );
964        rt.set_native_tools(native);
965        rt.set_call_context(make_call_context());
966
967        let rt = Arc::new(Mutex::new(rt));
968        let rt2 = Arc::clone(&rt);
969
970        let result = tokio::task::spawn_blocking(move || {
971            let guard = rt2.lock().unwrap();
972            guard
973                .exec(
974                    r#"
975                _result, _err = imp.tool("echo", { text = "hello from lua" })
976            "#,
977                )
978                .unwrap();
979            let result: String = guard.lua().globals().get("_result").unwrap();
980            let err: mlua::Value = guard.lua().globals().get("_err").unwrap();
981            assert!(matches!(err, mlua::Value::Nil), "expected no error");
982            result
983        })
984        .await
985        .unwrap();
986
987        assert_eq!(result, "echo: hello from lua");
988    }
989
990    #[tokio::test(flavor = "multi_thread")]
991    async fn imp_tool_returns_error_on_failure() {
992        let rt = make_runtime();
993
994        let mut native = std::collections::HashMap::new();
995        native.insert(
996            "fail".to_string(),
997            Arc::new(FailTestTool) as Arc<dyn imp_core::tools::Tool>,
998        );
999        rt.set_native_tools(native);
1000        rt.set_call_context(make_call_context());
1001
1002        let rt = Arc::new(Mutex::new(rt));
1003        let rt2 = Arc::clone(&rt);
1004
1005        tokio::task::spawn_blocking(move || {
1006            let guard = rt2.lock().unwrap();
1007            guard
1008                .exec(
1009                    r#"
1010                _result, _err = imp.tool("fail", {})
1011            "#,
1012                )
1013                .unwrap();
1014            let result: mlua::Value = guard.lua().globals().get("_result").unwrap();
1015            assert!(matches!(result, mlua::Value::Nil), "expected nil result");
1016            let err: String = guard.lua().globals().get("_err").unwrap();
1017            assert!(
1018                err.contains("intentional failure"),
1019                "expected failure message, got: {err}"
1020            );
1021        })
1022        .await
1023        .unwrap();
1024    }
1025
1026    #[tokio::test(flavor = "multi_thread")]
1027    async fn imp_tool_errors_on_unknown_tool() {
1028        let rt = make_runtime();
1029        rt.set_native_tools(std::collections::HashMap::new());
1030        rt.set_call_context(make_call_context());
1031
1032        let rt = Arc::new(Mutex::new(rt));
1033        let rt2 = Arc::clone(&rt);
1034
1035        tokio::task::spawn_blocking(move || {
1036            let guard = rt2.lock().unwrap();
1037            let result = guard.exec(
1038                r#"
1039                imp.tool("nonexistent", {})
1040            "#,
1041            );
1042            assert!(result.is_err(), "should error on unknown tool");
1043            let err = format!("{}", result.unwrap_err());
1044            assert!(
1045                err.contains("not found"),
1046                "error should mention 'not found': {err}"
1047            );
1048        })
1049        .await
1050        .unwrap();
1051    }
1052
1053    #[tokio::test(flavor = "multi_thread")]
1054    async fn imp_tool_errors_when_disabled() {
1055        let rt = make_runtime();
1056
1057        let mut native = std::collections::HashMap::new();
1058        native.insert(
1059            "echo".to_string(),
1060            Arc::new(EchoTestTool) as Arc<dyn imp_core::tools::Tool>,
1061        );
1062        rt.set_native_tools(native);
1063        rt.set_call_context(make_call_context());
1064        rt.set_allow_native_tool_calls(false);
1065
1066        let rt = Arc::new(Mutex::new(rt));
1067        let rt2 = Arc::clone(&rt);
1068
1069        tokio::task::spawn_blocking(move || {
1070            let guard = rt2.lock().unwrap();
1071            let result = guard.exec(
1072                r#"
1073                imp.tool("echo", { text = "hello from lua" })
1074            "#,
1075            );
1076            assert!(result.is_err(), "disabled imp.tool() should error");
1077            let err = format!("{}", result.unwrap_err());
1078            assert!(
1079                err.contains("disabled"),
1080                "error should mention disabled state: {err}"
1081            );
1082        })
1083        .await
1084        .unwrap();
1085    }
1086
    // ── imp.exec() / imp.secret() / imp.http gating and imp.env() scoped env var access ──
1088
1089    #[test]
1090    fn imp_exec_errors_when_disabled() {
1091        let rt = make_runtime();
1092        let result = rt.exec(
1093            r#"
1094            local _ = imp.exec("echo hi")
1095        "#,
1096        );
1097        assert!(result.is_err(), "disabled imp.exec() should error");
1098        let err = format!("{}", result.unwrap_err());
1099        assert!(
1100            err.contains("imp.exec() is disabled"),
1101            "unexpected error: {err}"
1102        );
1103    }
1104
1105    #[test]
1106    fn imp_exec_runs_when_enabled() {
1107        let rt = make_runtime();
1108        rt.set_allow_shell_exec(true);
1109
1110        rt.exec(
1111            r#"
1112            local result = imp.exec("printf lua_exec_ok")
1113            _test_stdout = result.stdout
1114            _test_exit = result.exit_code
1115        "#,
1116        )
1117        .unwrap();
1118
1119        let stdout: String = rt.lua().globals().get("_test_stdout").unwrap();
1120        let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
1121        assert_eq!(stdout, "lua_exec_ok");
1122        assert_eq!(exit_code, 0);
1123    }
1124
1125    #[test]
1126    fn imp_exec_passes_scoped_env_to_child_process() {
1127        let rt = make_runtime();
1128        rt.set_allow_shell_exec(true);
1129
1130        rt.exec(
1131            r#"
1132            local result = imp.exec("printf %s \"$IMP_LUA_CHILD_SECRET\"", nil, {
1133                env = { IMP_LUA_CHILD_SECRET = "child-only-value" },
1134            })
1135            _test_stdout = result.stdout
1136            _test_exit = result.exit_code
1137        "#,
1138        )
1139        .unwrap();
1140
1141        let stdout: String = rt.lua().globals().get("_test_stdout").unwrap();
1142        let exit_code: i32 = rt.lua().globals().get("_test_exit").unwrap();
1143        assert_eq!(stdout, "child-only-value");
1144        assert_eq!(exit_code, 0);
1145    }
1146
1147    #[test]
1148    fn imp_secret_errors_when_disabled() {
1149        let rt = make_runtime();
1150        let result = rt.exec(
1151            r#"
1152            local _ = imp.secret("openai", "api_key")
1153        "#,
1154        );
1155        assert!(result.is_err(), "disabled imp.secret() should error");
1156        let err = format!("{}", result.unwrap_err());
1157        assert!(
1158            err.contains("imp.secret() is disabled"),
1159            "unexpected error: {err}"
1160        );
1161    }
1162
1163    #[test]
1164    fn imp_secret_fields_errors_when_disabled() {
1165        let rt = make_runtime();
1166        let result = rt.exec(
1167            r#"
1168            local _ = imp.secret_fields("openai")
1169        "#,
1170        );
1171        assert!(result.is_err(), "disabled imp.secret_fields() should error");
1172        let err = format!("{}", result.unwrap_err());
1173        assert!(
1174            err.contains("imp.secret_fields() is disabled"),
1175            "unexpected error: {err}"
1176        );
1177    }
1178
1179    #[tokio::test(flavor = "multi_thread")]
1180    async fn imp_http_get_errors_when_disabled() {
1181        let rt = Arc::new(Mutex::new(make_runtime()));
1182        let rt2 = Arc::clone(&rt);
1183        tokio::task::spawn_blocking(move || {
1184            let guard = rt2.lock().unwrap();
1185            let result = guard.exec(
1186                r#"
1187                local _ = imp.http.get("https://example.com")
1188            "#,
1189            );
1190            assert!(result.is_err(), "disabled imp.http.get() should error");
1191            let err = format!("{}", result.unwrap_err());
1192            assert!(
1193                err.contains("imp.http.get() is disabled"),
1194                "unexpected error: {err}"
1195            );
1196        })
1197        .await
1198        .unwrap();
1199    }
1200
1201    #[test]
1202    fn imp_env_reads_var_when_allowed() {
1203        let rt = make_runtime();
1204        std::env::set_var("IMP_LUA_TEST_VAR", "test_value");
1205
1206        let mut allowed = std::collections::HashSet::new();
1207        allowed.insert("IMP_LUA_TEST_VAR".to_string());
1208        rt.set_allowed_env(allowed);
1209
1210        rt.exec(
1211            r#"
1212            _env_val = imp.env("IMP_LUA_TEST_VAR")
1213        "#,
1214        )
1215        .unwrap();
1216
1217        let val: String = rt.lua().globals().get("_env_val").unwrap();
1218        assert_eq!(val, "test_value");
1219    }
1220
1221    #[test]
1222    fn imp_env_returns_nil_for_denied_var() {
1223        let rt = make_runtime();
1224        std::env::set_var("IMP_LUA_TEST_SECRET", "secret_value");
1225
1226        let mut allowed = std::collections::HashSet::new();
1227        allowed.insert("SOME_OTHER_VAR".to_string());
1228        rt.set_allowed_env(allowed);
1229
1230        rt.exec(
1231            r#"
1232            _env_val = imp.env("IMP_LUA_TEST_SECRET")
1233            _is_nil = (_env_val == nil)
1234        "#,
1235        )
1236        .unwrap();
1237
1238        let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
1239        assert!(is_nil, "denied env var should return nil");
1240    }
1241
1242    #[test]
1243    fn imp_env_allows_all_when_list_empty() {
1244        let rt = make_runtime();
1245        std::env::set_var("IMP_LUA_TEST_OPEN", "open_value");
1246
1247        // Empty allowed set should deny by default.
1248        rt.set_allowed_env(std::collections::HashSet::new());
1249
1250        rt.exec(
1251            r#"
1252            _env_val = imp.env("IMP_LUA_TEST_OPEN")
1253            _is_nil = (_env_val == nil)
1254        "#,
1255        )
1256        .unwrap();
1257
1258        let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
1259        assert!(is_nil, "empty allow-list should deny env access by default");
1260    }
1261
1262    #[test]
1263    fn imp_env_returns_nil_for_missing_var() {
1264        let rt = make_runtime();
1265        rt.set_allowed_env(std::collections::HashSet::new());
1266
1267        rt.exec(
1268            r#"
1269            _env_val = imp.env("DEFINITELY_NOT_SET_IMP_LUA_TEST")
1270            _is_nil = (_env_val == nil)
1271        "#,
1272        )
1273        .unwrap();
1274
1275        let is_nil: bool = rt.lua().globals().get("_is_nil").unwrap();
1276        assert!(is_nil, "missing env var should return nil");
1277    }
1278}