Skip to main content

imp_lua/
bridge.rs

1use std::process::{Command, Stdio};
2
3use async_trait::async_trait;
4use imp_core::storage;
5use imp_core::tools::lua::{parameter_schema_from_lua, tool_output_from_lua_result};
6use imp_core::tools::{Tool, ToolContext, ToolOutput, ToolRegistry};
7use imp_core::Error as CoreError;
8use imp_llm::auth::AuthStore;
9use mlua::{Function, Lua, MultiValue, Table, Value};
10use serde_json::json;
11use std::sync::{Arc, Mutex};
12
13use crate::sandbox::{
14    LuaCallContext, LuaCommandHandle, LuaError, LuaHookHandle, LuaRuntime, LuaToolHandle,
15};
16
17/// A `Tool` implementation backed by a Lua function registered with
18/// `imp.register_tool()`.
19pub struct LuaTool {
20    name: String,
21    label: String,
22    description: String,
23    readonly: bool,
24    params: serde_json::Value,
25    runtime: Arc<Mutex<LuaRuntime>>,
26    handle_index: usize,
27}
28
29#[async_trait]
30impl Tool for LuaTool {
31    fn name(&self) -> &str {
32        &self.name
33    }
34
35    fn label(&self) -> &str {
36        &self.label
37    }
38
39    fn description(&self) -> &str {
40        &self.description
41    }
42
43    fn parameters(&self) -> serde_json::Value {
44        parameter_schema_from_lua(&self.params)
45    }
46
47    fn is_readonly(&self) -> bool {
48        self.readonly
49    }
50
51    async fn execute(
52        &self,
53        call_id: &str,
54        params: serde_json::Value,
55        ctx: ToolContext,
56    ) -> imp_core::Result<ToolOutput> {
57        let runtime = Arc::clone(&self.runtime);
58        let handle_index = self.handle_index;
59        let call_id = call_id.to_string();
60        let ctx_json = json!({
61            "cwd": ctx.cwd.display().to_string(),
62            "cancelled": ctx.is_cancelled(),
63        });
64        let call_ctx = LuaCallContext {
65            cwd: ctx.cwd,
66            cancelled: ctx.cancelled,
67            update_tx: ctx.update_tx,
68            command_tx: ctx.command_tx,
69            ui: ctx.ui,
70            file_cache: ctx.file_cache,
71            checkpoint_state: ctx.checkpoint_state,
72            file_tracker: ctx.file_tracker,
73            anchor_store: ctx.anchor_store,
74            lua_tool_loader: ctx.lua_tool_loader,
75            mode: ctx.mode,
76            read_max_lines: ctx.read_max_lines,
77            config: ctx.config,
78        };
79
80        tokio::task::spawn_blocking(move || {
81            let runtime_guard = runtime
82                .lock()
83                .map_err(|_| CoreError::Tool("Lua runtime lock poisoned".into()))?;
84
85            // Make the ToolContext available to imp.tool() during this execution.
86            runtime_guard.set_call_context(call_ctx);
87
88            let result = (|| {
89                let tools = runtime_guard.tools();
90                let handles = tools
91                    .lock()
92                    .map_err(|_| CoreError::Tool("Lua tool registry lock poisoned".into()))?;
93                let handle = handles.get(handle_index).ok_or_else(|| {
94                    CoreError::Tool(format!("Lua tool handle {handle_index} not found"))
95                })?;
96
97                let execute_fn: Function = runtime_guard
98                    .lua()
99                    .registry_value(&handle.execute_key)
100                    .map_err(lua_tool_error)?;
101                let lua_params =
102                    json_to_lua_value(runtime_guard.lua(), &params).map_err(lua_tool_error)?;
103                let lua_ctx =
104                    json_to_lua_value(runtime_guard.lua(), &ctx_json).map_err(lua_tool_error)?;
105                let result: Value = execute_fn
106                    .call((call_id.as_str(), lua_params, lua_ctx))
107                    .map_err(lua_tool_error)?;
108
109                tool_output_from_lua_result(lua_value_to_json(result))
110            })();
111
112            runtime_guard.clear_call_context();
113            result
114        })
115        .await
116        .map_err(|error| CoreError::Tool(format!("Lua tool task failed: {error}")))?
117    }
118}
119
120/// Register all currently loaded Lua tools with imp-core's tool registry.
121pub fn load_lua_tools(runtime: Arc<Mutex<LuaRuntime>>, registry: &mut ToolRegistry) {
122    let handles = {
123        let runtime_guard = runtime
124            .lock()
125            .expect("Lua runtime lock poisoned while loading tools");
126        let tools = runtime_guard.tools();
127        let handles = tools
128            .lock()
129            .expect("Lua tool registry lock poisoned while loading tools");
130
131        handles
132            .iter()
133            .enumerate()
134            .map(|(index, handle)| LuaTool {
135                name: handle.name.clone(),
136                label: handle.label.clone(),
137                description: handle.description.clone(),
138                readonly: handle.readonly,
139                params: handle.params.clone(),
140                runtime: Arc::clone(&runtime),
141                handle_index: index,
142            })
143            .collect::<Vec<_>>()
144    };
145
146    for tool in handles {
147        registry.register(Arc::new(tool));
148    }
149}
150
151fn lua_tool_error(error: mlua::Error) -> CoreError {
152    CoreError::Tool(format!("Lua tool error: {error}"))
153}
154
155/// Extract header key-value pairs from an optional Lua table.
156fn extract_header_pairs(headers: Option<Table>) -> mlua::Result<Vec<(String, String)>> {
157    let mut pairs = Vec::new();
158    if let Some(tbl) = headers {
159        for pair in tbl.pairs::<String, String>() {
160            let (k, v) = pair?;
161            pairs.push((k, v));
162        }
163    }
164    Ok(pairs)
165}
166
167/// Set up the `imp` global table with host API functions.
168///
169/// Exposes to Lua:
170/// - imp.on(event, handler)           — subscribe to hook events
171/// - imp.register_tool(def)           — register a custom tool
172/// - imp.exec(command, args, opts)    — run a shell command
173/// - imp.register_command(name, def)  — register a slash command
174/// - imp.events.on() / imp.events.emit() — inter-extension event bus
175/// - imp.tool(name, params)           — call a native imp tool
176/// - imp.secret(provider, field?)     — read a saved imp secret field
177/// - imp.secret_fields(provider)      — read all saved fields for a provider
178/// - imp.env(name)                    — read an env var (scoped by allowed list)
179/// - imp.http.get(url, headers?)      — HTTP GET
180/// - imp.http.post(url, body, headers?) — HTTP POST
181pub fn setup_host_api(runtime: &LuaRuntime) -> Result<(), LuaError> {
182    let lua = runtime.lua();
183
184    let imp = lua.create_table()?;
185
186    // ── imp.on(event_name, handler) ──────────────────────────────
187    let hooks = runtime.hooks();
188    let on_fn = lua.create_function(move |lua_inner, (event, handler): (String, Function)| {
189        let key = lua_inner.create_registry_value(handler)?;
190        let handle = LuaHookHandle {
191            event,
192            handler_key: key,
193        };
194        hooks.lock().unwrap().push(handle);
195        Ok(())
196    })?;
197    imp.set("on", on_fn)?;
198
199    // ── imp.register_tool(definition) ────────────────────────────
200    let tools = runtime.tools();
201    let register_tool_fn = lua.create_function(move |lua_inner, def: Table| {
202        let name: String = def.get("name")?;
203        let label: String = def
204            .get::<Option<String>>("label")?
205            .unwrap_or_else(|| name.clone());
206        let description: String = def
207            .get::<Option<String>>("description")?
208            .unwrap_or_default();
209        let readonly: bool = def.get::<Option<bool>>("readonly")?.unwrap_or(false);
210
211        let params_val: Value = def.get("params")?;
212        let params = lua_value_to_json(params_val);
213
214        let execute_fn: Function = def.get("execute")?;
215        let key = lua_inner.create_registry_value(execute_fn)?;
216
217        let handle = LuaToolHandle {
218            name,
219            label,
220            description,
221            readonly,
222            params,
223            execute_key: key,
224        };
225        tools.lock().unwrap().push(handle);
226        Ok(())
227    })?;
228    imp.set("register_tool", register_tool_fn)?;
229
230    // ── imp.exec(command, args, opts) ────────────────────────────
231    let allow_shell_exec = runtime.allow_shell_exec();
232    let exec_fn = lua.create_function(
233        move |lua_inner, (cmd, args, opts): (String, Option<Table>, Option<Table>)| {
234            if !allow_shell_exec.load(std::sync::atomic::Ordering::Relaxed) {
235                return Err(mlua::Error::external(
236                    "imp.exec() is disabled for this runtime",
237                ));
238            }
239            let mut command = Command::new("sh");
240            command.arg("-c");
241
242            // Build the full command string
243            let full_cmd = if let Some(args_table) = args {
244                let mut parts = vec![cmd];
245                for pair in args_table.sequence_values::<String>() {
246                    parts.push(pair?);
247                }
248                parts.join(" ")
249            } else {
250                cmd
251            };
252            command.stdin(Stdio::null()).arg(&full_cmd);
253
254            // Apply opts
255            if let Some(opts_table) = &opts {
256                if let Ok(Some(cwd)) = opts_table.get::<Option<String>>("cwd") {
257                    command.current_dir(cwd);
258                }
259                if let Ok(Some(env_table)) = opts_table.get::<Option<Table>>("env") {
260                    for pair in env_table.pairs::<String, String>() {
261                        let (name, value) = pair?;
262                        command.env(name, value);
263                    }
264                }
265            }
266
267            let output = command.output().map_err(mlua::Error::external)?;
268
269            let result = lua_inner.create_table()?;
270            result.set(
271                "stdout",
272                String::from_utf8_lossy(&output.stdout).to_string(),
273            )?;
274            result.set(
275                "stderr",
276                String::from_utf8_lossy(&output.stderr).to_string(),
277            )?;
278            result.set("exit_code", output.status.code().unwrap_or(-1))?;
279
280            Ok(result)
281        },
282    )?;
283    imp.set("exec", exec_fn)?;
284
285    // ── imp.register_command(name, definition) ───────────────────
286    let commands = runtime.commands();
287    let register_command_fn =
288        lua.create_function(move |lua_inner, (name, def): (String, Table)| {
289            let description: String = def
290                .get::<Option<String>>("description")?
291                .unwrap_or_default();
292            let handler: Function = def.get("handler")?;
293            let key = lua_inner.create_registry_value(handler)?;
294
295            let handle = LuaCommandHandle {
296                name,
297                description,
298                handler_key: key,
299            };
300            commands.lock().unwrap().push(handle);
301            Ok(())
302        })?;
303    imp.set("register_command", register_command_fn)?;
304
305    // ── imp.events (inter-extension event bus) ───────────────────
306    let events = lua.create_table()?;
307
308    // Store handlers in a Lua table: { event_name = { handler1, handler2, ... } }
309    let handlers_table = lua.create_table()?;
310    lua.set_named_registry_value("__imp_event_handlers", handlers_table)?;
311
312    let events_on = lua.create_function(|lua_inner, (name, handler): (String, Function)| {
313        let handlers: Table = lua_inner.named_registry_value("__imp_event_handlers")?;
314        let list: Table = match handlers.get::<Option<Table>>(name.as_str())? {
315            Some(t) => t,
316            None => {
317                let t = lua_inner.create_table()?;
318                handlers.set(name.as_str(), t.clone())?;
319                t
320            }
321        };
322        let len = list.raw_len();
323        list.set(len + 1, handler)?;
324        Ok(())
325    })?;
326    events.set("on", events_on)?;
327
328    let events_emit = lua.create_function(|lua_inner, (name, data): (String, Value)| {
329        let handlers: Table = lua_inner.named_registry_value("__imp_event_handlers")?;
330        if let Some(list) = handlers.get::<Option<Table>>(name.as_str())? {
331            for pair in list.sequence_values::<Function>() {
332                let handler = pair?;
333                // Errors in event handlers are caught and ignored so one bad
334                // extension callback cannot destabilize the host runtime.
335                let _ = handler.call::<()>(data.clone());
336            }
337        }
338        Ok(())
339    })?;
340    events.set("emit", events_emit)?;
341
342    imp.set("events", events)?;
343
344    // ── imp.tool(name, params) — call a native imp tool ──────────
345    let native_tools = runtime.native_tools();
346    let tool_call_ctx = runtime.call_context();
347    let allow_native_tool_calls = runtime.allow_native_tool_calls();
348    let imp_tool_fn = lua.create_function(
349        move |lua_inner, (name, params): (String, Value)| -> mlua::Result<MultiValue> {
350            if !allow_native_tool_calls.load(std::sync::atomic::Ordering::Relaxed) {
351                return Err(mlua::Error::external(
352                    "imp.tool() is disabled for this runtime",
353                ));
354            }
355
356            // Look up the tool.
357            let tool = {
358                let tools_guard = native_tools
359                    .lock()
360                    .map_err(|_| mlua::Error::external("native tools lock poisoned"))?;
361                tools_guard
362                    .get(&name)
363                    .cloned()
364                    .ok_or_else(|| mlua::Error::external(format!("tool '{name}' not found")))?
365            };
366
367            // Build a ToolContext from the stored call context.
368            let ctx = {
369                let ctx_guard = tool_call_ctx
370                    .lock()
371                    .map_err(|_| mlua::Error::external("call context lock poisoned"))?;
372                ctx_guard
373                    .as_ref()
374                    .ok_or_else(|| {
375                        mlua::Error::external("imp.tool() called outside of tool execution context")
376                    })?
377                    .to_tool_context()
378            };
379
380            let params_json = lua_value_to_json(params);
381
382            // Execute the tool — async via block_on (safe from spawn_blocking).
383            let handle = tokio::runtime::Handle::try_current()
384                .map_err(|_| mlua::Error::external("imp.tool() requires a tokio runtime"))?;
385
386            let output = handle
387                .block_on(tool.execute("lua-call", params_json, ctx))
388                .map_err(|e| mlua::Error::external(format!("tool error: {e}")))?;
389
390            // Convert ToolOutput → Lua multi-return: (result, err).
391            let mut mv = MultiValue::new();
392            if output.is_error {
393                let err_text = output
394                    .text_content()
395                    .unwrap_or("tool execution failed")
396                    .to_string();
397                mv.push_back(Value::Nil);
398                mv.push_back(Value::String(lua_inner.create_string(&err_text)?));
399            } else if let Some(text) = output.text_content() {
400                mv.push_back(Value::String(lua_inner.create_string(text)?));
401            } else {
402                mv.push_back(Value::Nil);
403            }
404            Ok(mv)
405        },
406    )?;
407    imp.set("tool", imp_tool_fn)?;
408
409    // ── imp.update(text) — stream progress to the TUI ─────────────
410    let update_call_ctx = runtime.call_context();
411    let imp_update_fn = lua.create_function(move |_lua, text: String| {
412        let ctx_guard = update_call_ctx
413            .lock()
414            .map_err(|_| mlua::Error::external("call context lock poisoned"))?;
415        if let Some(ref ctx) = *ctx_guard {
416            let _ = ctx.update_tx.try_send(imp_core::tools::ToolUpdate {
417                content: vec![imp_core::imp_llm::ContentBlock::Text { text }],
418                details: serde_json::Value::Null,
419            });
420        }
421        Ok(())
422    })?;
423    imp.set("update", imp_update_fn)?;
424
425    // ── imp.secret(provider, field?) — read a saved secret field ──────────
426    let allow_secrets = runtime.allow_secrets();
427    let secret_fn = lua.create_function(
428        move |lua_inner, (provider, field): (String, Option<String>)| -> mlua::Result<Value> {
429            if !allow_secrets.load(std::sync::atomic::Ordering::Relaxed) {
430                return Err(mlua::Error::external(
431                    "imp.secret() is disabled for this runtime",
432                ));
433            }
434            let auth_path =
435                storage::existing_global_auth_path().unwrap_or_else(storage::global_auth_path);
436            let auth_store =
437                AuthStore::load(&auth_path).unwrap_or_else(|_| AuthStore::new(auth_path.clone()));
438            let field = field.unwrap_or_else(|| "api_key".to_string());
439            match auth_store.resolve_secret_field(&provider, &field) {
440                Ok(value) => Ok(Value::String(lua_inner.create_string(&value)?)),
441                Err(error) => Err(mlua::Error::external(error.to_string())),
442            }
443        },
444    )?;
445    imp.set("secret", secret_fn)?;
446
447    // ── imp.secret_fields(provider) — read all saved secret fields ─────────
448    let allow_secrets = runtime.allow_secrets();
449    let secret_fields_fn =
450        lua.create_function(move |lua_inner, provider: String| -> mlua::Result<Value> {
451            if !allow_secrets.load(std::sync::atomic::Ordering::Relaxed) {
452                return Err(mlua::Error::external(
453                    "imp.secret_fields() is disabled for this runtime",
454                ));
455            }
456            let auth_path =
457                storage::existing_global_auth_path().unwrap_or_else(storage::global_auth_path);
458            let auth_store =
459                AuthStore::load(&auth_path).unwrap_or_else(|_| AuthStore::new(auth_path.clone()));
460            match auth_store.resolve_secret_fields(&provider) {
461                Ok(fields) => {
462                    let table = lua_inner.create_table()?;
463                    for (field, value) in fields {
464                        table.set(field, value)?;
465                    }
466                    Ok(Value::Table(table))
467                }
468                Err(error) => Err(mlua::Error::external(error.to_string())),
469            }
470        })?;
471    imp.set("secret_fields", secret_fields_fn)?;
472
473    // ── imp.env(name) — read a scoped env var ────────────────────
474    let allowed_env = runtime.allowed_env();
475    let env_fn = lua.create_function(move |lua_inner, name: String| {
476        let allowed = allowed_env
477            .lock()
478            .map_err(|_| mlua::Error::external("allowed_env lock poisoned"))?;
479        // If the allow-list is empty or the var is not listed, deny access.
480        if !allowed.contains(&name) {
481            return Ok(Value::Nil);
482        }
483        match std::env::var(&name) {
484            Ok(val) => Ok(Value::String(lua_inner.create_string(&val)?)),
485            Err(_) => Ok(Value::Nil),
486        }
487    })?;
488    imp.set("env", env_fn)?;
489
490    // ── imp.http — HTTP GET / POST via reqwest ───────────────────
491    let http = lua.create_table()?;
492    let allow_http = runtime.allow_http();
493
494    let http_get_fn =
495        lua.create_function(move |lua_inner, (url, headers): (String, Option<Table>)| {
496            if !allow_http.load(std::sync::atomic::Ordering::Relaxed) {
497                return Err(mlua::Error::external(
498                    "imp.http.get() is disabled for this runtime",
499                ));
500            }
501            let header_pairs = extract_header_pairs(headers)?;
502
503            let handle = tokio::runtime::Handle::try_current()
504                .map_err(|_| mlua::Error::external("imp.http requires a tokio runtime"))?;
505
506            let (status, body) = handle
507                .block_on(async {
508                    let client = reqwest::Client::new();
509                    let mut builder = client.get(&url);
510                    for (k, v) in &header_pairs {
511                        builder = builder.header(k.as_str(), v.as_str());
512                    }
513                    let resp = builder.send().await.map_err(|e| e.to_string())?;
514                    let status = resp.status().as_u16();
515                    let body = resp.text().await.map_err(|e| e.to_string())?;
516                    Ok::<_, String>((status, body))
517                })
518                .map_err(mlua::Error::external)?;
519
520            let result = lua_inner.create_table()?;
521            result.set("status", status)?;
522            result.set("body", body)?;
523            Ok(result)
524        })?;
525    http.set("get", http_get_fn)?;
526
527    let allow_http = runtime.allow_http();
528    let http_post_fn = lua.create_function(
529        move |lua_inner, (url, body, headers): (String, String, Option<Table>)| {
530            if !allow_http.load(std::sync::atomic::Ordering::Relaxed) {
531                return Err(mlua::Error::external(
532                    "imp.http.post() is disabled for this runtime",
533                ));
534            }
535            let header_pairs = extract_header_pairs(headers)?;
536
537            let handle = tokio::runtime::Handle::try_current()
538                .map_err(|_| mlua::Error::external("imp.http requires a tokio runtime"))?;
539
540            let (status, resp_body) = handle
541                .block_on(async {
542                    let client = reqwest::Client::new();
543                    let mut builder = client.post(&url).body(body);
544                    for (k, v) in &header_pairs {
545                        builder = builder.header(k.as_str(), v.as_str());
546                    }
547                    let resp = builder.send().await.map_err(|e| e.to_string())?;
548                    let status = resp.status().as_u16();
549                    let resp_body = resp.text().await.map_err(|e| e.to_string())?;
550                    Ok::<_, String>((status, resp_body))
551                })
552                .map_err(mlua::Error::external)?;
553
554            let result = lua_inner.create_table()?;
555            result.set("status", status)?;
556            result.set("body", resp_body)?;
557            Ok(result)
558        },
559    )?;
560    http.set("post", http_post_fn)?;
561
562    imp.set("http", http)?;
563
564    // ── Set the global ───────────────────────────────────────────
565    lua.globals().set("imp", imp)?;
566
567    Ok(())
568}
569
570/// Convert a Lua value to serde_json::Value.
571pub fn lua_value_to_json(value: Value) -> serde_json::Value {
572    match value {
573        Value::Nil => serde_json::Value::Null,
574        Value::Boolean(b) => serde_json::Value::Bool(b),
575        Value::Integer(i) => serde_json::Value::Number(serde_json::Number::from(i)),
576        Value::Number(n) => serde_json::Number::from_f64(n)
577            .map(serde_json::Value::Number)
578            .unwrap_or(serde_json::Value::Null),
579        Value::String(s) => {
580            serde_json::Value::String(s.to_str().map(|s| s.to_string()).unwrap_or_default())
581        }
582        Value::Table(t) => {
583            // Check if it's an array (sequential integer keys starting at 1)
584            let len = t.raw_len();
585            if len > 0 {
586                // Check if all keys 1..=len exist (it's an array)
587                let is_array = (1..=len).all(|i| {
588                    t.get::<Value>(i)
589                        .ok()
590                        .map(|v| !matches!(v, Value::Nil))
591                        .unwrap_or(false)
592                });
593                if is_array {
594                    let arr: Vec<serde_json::Value> = (1..=len)
595                        .filter_map(|i| t.get::<Value>(i).ok().map(lua_value_to_json))
596                        .collect();
597                    return serde_json::Value::Array(arr);
598                }
599            }
600
601            // Otherwise it's an object
602            let mut map = serde_json::Map::new();
603            if let Ok(pairs) = t.pairs::<String, Value>().collect::<Result<Vec<_>, _>>() {
604                for (k, v) in pairs {
605                    map.insert(k, lua_value_to_json(v));
606                }
607            }
608            serde_json::Value::Object(map)
609        }
610        _ => serde_json::Value::Null,
611    }
612}
613
614/// Convert a serde_json::Value to a Lua value.
615pub fn json_to_lua_value(lua: &Lua, value: &serde_json::Value) -> mlua::Result<Value> {
616    match value {
617        serde_json::Value::Null => Ok(Value::Nil),
618        serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
619        serde_json::Value::Number(n) => {
620            if let Some(i) = n.as_i64() {
621                Ok(Value::Integer(i))
622            } else if let Some(f) = n.as_f64() {
623                Ok(Value::Number(f))
624            } else {
625                Ok(Value::Nil)
626            }
627        }
628        serde_json::Value::String(s) => Ok(Value::String(lua.create_string(s)?)),
629        serde_json::Value::Array(arr) => {
630            let table = lua.create_table()?;
631            for (i, v) in arr.iter().enumerate() {
632                table.set(i + 1, json_to_lua_value(lua, v)?)?;
633            }
634            Ok(Value::Table(table))
635        }
636        serde_json::Value::Object(map) => {
637            let table = lua.create_table()?;
638            for (k, v) in map {
639                table.set(k.as_str(), json_to_lua_value(lua, v)?)?;
640            }
641            Ok(Value::Table(table))
642        }
643    }
644}