Skip to main content

orcs_lua/
component.rs

1//! LuaComponent implementation.
2//!
3//! Wraps a Lua script to implement the Component trait.
4
5mod ctx_fns;
6mod emitter_fns;
7
/// Returns the longest prefix of `s` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 character boundary.
///
/// If `s` already fits, it is returned unchanged. Because index 0 is always a
/// char boundary, the result is never longer than `max_bytes` and never
/// splits a multi-byte character.
fn truncate_utf8(s: &str, max_bytes: usize) -> &str {
    if max_bytes >= s.len() {
        return s;
    }
    // Walk backwards from the byte limit to the nearest boundary.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
19
20use crate::error::LuaError;
21use crate::lua_env::LuaEnv;
22use crate::types::{
23    parse_event_category, parse_signal_response, LuaRequest, LuaResponse, LuaSignal,
24};
25use mlua::{Function, IntoLua, Lua, LuaSerdeExt, RegistryKey, Table, Value as LuaValue};
26use orcs_component::{
27    ChildContext, Component, ComponentError, ComponentLoader, ComponentSnapshot, Emitter,
28    EventCategory, RuntimeHints, SnapshotError, SpawnError, Status, SubscriptionEntry,
29};
30use orcs_event::{Request, Signal, SignalResponse};
31use orcs_runtime::sandbox::SandboxPolicy;
32use orcs_types::ComponentId;
33use parking_lot::Mutex;
34use serde_json::Value as JsonValue;
35use std::path::Path;
36use std::sync::Arc;
37
38/// Extracts a [`ComponentError::Suspended`] from a potentially nested `mlua::Error` chain.
39///
40/// Lua wraps callback errors in `CallbackError { cause }` at each call-stack level.
41/// This helper recursively unwraps `CallbackError` and checks `ExternalError` for
42/// a boxed `ComponentError::Suspended`. Returns `None` if the error is not a Suspended.
43fn extract_suspended(err: &mlua::Error) -> Option<ComponentError> {
44    match err {
45        mlua::Error::ExternalError(ext) => ext
46            .downcast_ref::<ComponentError>()
47            .filter(|ce| matches!(ce, ComponentError::Suspended { .. }))
48            .cloned(),
49        mlua::Error::CallbackError { cause, .. } => extract_suspended(cause),
50        _ => None,
51    }
52}
53
/// A component implemented in Lua.
///
/// Loads a Lua script and delegates Component trait methods to Lua functions.
/// The Lua VM is kept behind a mutex; callbacks are held as registry keys and
/// looked up on each invocation.
///
/// # Script Format
///
/// The Lua script must return a table with the following structure:
///
/// ```lua
/// return {
///     id = "component-id",           -- Required: unique identifier
///     namespace = "lua",             -- Optional: ComponentId namespace (defaults to "lua")
///     subscriptions = {"Echo"},      -- Required: event categories; entries may also be
///                                    --   { category = "Extension", operations = {"op"} }
///
///     on_request = function(req)     -- Required: handle requests
///         return { success = true, data = ... }
///     end,
///
///     on_signal = function(sig)      -- Required: handle signals
///         return "Handled" | "Ignored" | "Abort"
///     end,
///
///     init = function(cfg)           -- Optional: initialization; `cfg` is the
///     end,                           --   component settings table, or nil if empty
///
///     shutdown = function()          -- Optional: cleanup
///     end,
///
///     snapshot = function()          -- Optional: return state to persist
///     end,
///
///     restore = function(state)      -- Optional: receive previously snapshotted state
///     end,
///
///     output_to_io = false,          -- Optional runtime hints (all default to false)
///     elevated = false,
///     child_spawner = false,
/// }
/// ```
pub struct LuaComponent {
    /// Lua runtime (wrapped in Mutex for Send+Sync).
    lua: Mutex<Lua>,
    /// Component identifier (built from the script's `namespace` and `id`).
    id: ComponentId,
    /// Subscribed event categories (for Component::subscriptions()).
    subscriptions: Vec<EventCategory>,
    /// Subscription entries with optional operation filters (for Component::subscription_entries()).
    subscription_entries: Vec<SubscriptionEntry>,
    /// Current status.
    status: Status,
    /// Registry key for on_request callback.
    on_request_key: RegistryKey,
    /// Registry key for on_signal callback.
    on_signal_key: RegistryKey,
    /// Registry key for init callback (optional).
    init_key: Option<RegistryKey>,
    /// Registry key for shutdown callback (optional).
    shutdown_key: Option<RegistryKey>,
    /// Registry key for snapshot callback (optional).
    snapshot_key: Option<RegistryKey>,
    /// Registry key for restore callback (optional).
    restore_key: Option<RegistryKey>,
    /// Script path (for hot reload). `None` for components built from a string.
    script_path: Option<String>,
    /// Event emitter for ChannelRunner mode.
    ///
    /// When set, allows Lua scripts to emit events via `orcs.output()`.
    /// This enables ChannelRunner-based execution (IO-less, event-only).
    emitter: Option<Arc<Mutex<Box<dyn Emitter>>>>,
    /// Child context for spawning and managing children.
    ///
    /// When set, allows Lua scripts to spawn children via `orcs.spawn_child()`.
    /// This enables the Manager-Worker pattern where Components manage Children.
    child_context: Option<Arc<Mutex<Box<dyn ChildContext>>>>,
    /// Sandbox policy for file operations.
    ///
    /// Injected at construction time. Used by `orcs.read/write/grep/glob` in Lua.
    /// Stored for use during `reload()`.
    sandbox: Arc<dyn SandboxPolicy>,
    /// Runtime hints declared by the Lua script
    /// (`output_to_io`, `elevated`, `child_spawner`).
    hints: RuntimeHints,
}
125
// SAFETY: LuaComponent can be safely sent between threads and accessed concurrently.
//
// Justification:
// 1. mlua is built with "send" feature (see Cargo.toml), which enables thread-safe
//    Lua state allocation and makes the allocator thread-safe.
// 2. The Lua runtime is wrapped in parking_lot::Mutex<Lua>, ensuring exclusive
//    mutable access. All methods that access the Lua state acquire the lock first.
// 3. All Lua callbacks are stored in the Lua registry via RegistryKey, which is
//    designed for this use case. RegistryKey itself is Send.
// 4. No raw Lua values (userdata, functions) escape the Mutex guard scope.
//    Values are converted to/from Rust types within the lock scope.
// 5. The remaining plain fields (id, subscriptions, subscription_entries, status,
//    script_path, hints) are all Send+Sync.
// 6. NOTE(review): `emitter` (Arc<Mutex<Box<dyn Emitter>>>), `child_context`
//    (Arc<Mutex<Box<dyn ChildContext>>>) and `sandbox` (Arc<dyn SandboxPolicy>) are
//    only genuinely Send+Sync if `Emitter`/`ChildContext` require `Send` and
//    `SandboxPolicy` requires `Send + Sync`. Those trait definitions are not visible
//    here — confirm the supertrait bounds, because this impl would otherwise hide a
//    missing bound from the compiler.
//
// The "send" feature documentation: https://docs.rs/mlua/latest/mlua/#async-send
unsafe impl Send for LuaComponent {}
unsafe impl Sync for LuaComponent {}
142
143impl LuaComponent {
144    /// Creates a new LuaComponent from a script file.
145    ///
146    /// # Arguments
147    ///
148    /// * `path` - Path to the Lua script
149    /// * `sandbox` - Sandbox policy for file operations and exec cwd
150    ///
151    /// # Errors
152    ///
153    /// Returns error if:
154    /// - Script file not found
155    /// - Script syntax error
156    /// - Missing required fields/callbacks
157    pub fn from_file<P: AsRef<Path>>(
158        path: P,
159        sandbox: Arc<dyn SandboxPolicy>,
160    ) -> Result<Self, LuaError> {
161        let path = path.as_ref();
162        let script = std::fs::read_to_string(path)
163            .map_err(|_| LuaError::ScriptNotFound(path.display().to_string()))?;
164
165        let script_dir = path.parent().map(|p| p.to_path_buf());
166        let mut component = Self::from_script_inner(&script, sandbox, script_dir.as_deref())?;
167        component.script_path = Some(path.display().to_string());
168        Ok(component)
169    }
170
171    /// Creates a new LuaComponent from a directory containing `init.lua`.
172    ///
173    /// The directory is added to Lua's `package.path`, enabling standard
174    /// `require()` for co-located modules (e.g. `require("lib.my_module")`).
175    ///
176    /// # Directory Structure
177    ///
178    /// ```text
179    /// components/my_component/
180    ///   init.lua              -- entry point (must return component table)
181    ///   lib/
182    ///     helper.lua          -- require("lib.helper")
183    ///   vendor/
184    ///     lua_solver/init.lua -- require("vendor.lua_solver")
185    /// ```
186    ///
187    /// # Errors
188    ///
189    /// Returns error if `init.lua` not found or script is invalid.
190    pub fn from_dir<P: AsRef<Path>>(
191        dir: P,
192        sandbox: Arc<dyn SandboxPolicy>,
193    ) -> Result<Self, LuaError> {
194        let dir = dir.as_ref();
195        let init_path = dir.join("init.lua");
196        let script = std::fs::read_to_string(&init_path)
197            .map_err(|_| LuaError::ScriptNotFound(init_path.display().to_string()))?;
198
199        let mut component = Self::from_script_inner(&script, sandbox, Some(dir))?;
200        component.script_path = Some(init_path.display().to_string());
201        Ok(component)
202    }
203
204    /// Creates a new LuaComponent from a script string.
205    ///
206    /// # Arguments
207    ///
208    /// * `script` - Lua script content
209    /// * `sandbox` - Sandbox policy for file operations and exec cwd
210    ///
211    /// # Errors
212    ///
213    /// Returns error if script is invalid.
214    pub fn from_script(script: &str, sandbox: Arc<dyn SandboxPolicy>) -> Result<Self, LuaError> {
215        Self::from_script_inner(script, sandbox, None)
216    }
217
218    /// Internal: creates a LuaComponent with optional search path setup.
219    ///
220    /// When `script_dir` is provided, it is added to `LuaEnv`'s search paths
221    /// so that `require()` resolves co-located modules with sandbox validation.
222    fn from_script_inner(
223        script: &str,
224        sandbox: Arc<dyn SandboxPolicy>,
225        script_dir: Option<&Path>,
226    ) -> Result<Self, LuaError> {
227        // Build LuaEnv with sandbox and optional script directory as search path.
228        let mut lua_env = LuaEnv::new(Arc::clone(&sandbox));
229        if let Some(dir) = script_dir {
230            lua_env = lua_env.with_search_path(dir);
231        }
232
233        // Create configured Lua VM (orcs.*, tools, sandboxed require).
234        let lua = lua_env.create_lua()?;
235
236        // Register Component-specific output placeholders.
237        // These are overridden by real emitter functions via set_emitter().
238        {
239            let orcs_table: Table = lua.globals().get("orcs")?;
240            let output_noop = lua.create_function(|_, msg: String| {
241                tracing::warn!(
242                    "[lua] orcs.output called without emitter (noop): {}",
243                    truncate_utf8(&msg, 100)
244                );
245                Ok(())
246            })?;
247            orcs_table.set("output", output_noop)?;
248
249            let output_level_noop = lua.create_function(|_, (msg, _level): (String, String)| {
250                tracing::warn!(
251                    "[lua] orcs.output_with_level called without emitter (noop): {}",
252                    truncate_utf8(&msg, 100)
253                );
254                Ok(())
255            })?;
256            orcs_table.set("output_with_level", output_level_noop)?;
257        }
258
259        // Execute script and get the returned table
260        let component_table: Table = lua
261            .load(script)
262            .eval()
263            .map_err(|e| LuaError::InvalidScript(e.to_string()))?;
264
265        // Extract id and namespace
266        let id_str: String = component_table
267            .get("id")
268            .map_err(|_| LuaError::MissingCallback("id".to_string()))?;
269        let namespace: String = component_table
270            .get("namespace")
271            .unwrap_or_else(|_| "lua".to_string());
272        let id = ComponentId::new(namespace, &id_str);
273
274        // Extract subscriptions (supports both string and table entries)
275        //
276        // String form (all operations):
277        //   subscriptions = { "Echo", "UserInput" }
278        //
279        // Table form (specific operations):
280        //   subscriptions = {
281        //       "UserInput",
282        //       { category = "Extension", operations = {"route_response"} },
283        //   }
284        let subs_table: Table = component_table
285            .get("subscriptions")
286            .map_err(|_| LuaError::MissingCallback("subscriptions".to_string()))?;
287
288        let mut subscriptions = Vec::new();
289        let mut subscription_entries = Vec::new();
290        for pair in subs_table.pairs::<i64, LuaValue>() {
291            let (_, value) = pair.map_err(|e| LuaError::TypeError(e.to_string()))?;
292            match &value {
293                LuaValue::String(s) => {
294                    let cat_str = s.to_str().map_err(|e| LuaError::TypeError(e.to_string()))?;
295                    if let Some(cat) = parse_event_category(&cat_str) {
296                        subscriptions.push(cat.clone());
297                        subscription_entries.push(SubscriptionEntry::all(cat));
298                    }
299                }
300                LuaValue::Table(tbl) => {
301                    // Table form: { category = "Extension", operations = {"op1", "op2"} }
302                    let cat_str: String = tbl.get("category").map_err(|e| {
303                        LuaError::TypeError(format!(
304                            "subscription table must have 'category' field: {e}"
305                        ))
306                    })?;
307                    if let Some(cat) = parse_event_category(&cat_str) {
308                        subscriptions.push(cat.clone());
309                        // Parse optional operations list
310                        let ops_table: Option<Table> = tbl.get("operations").ok();
311                        if let Some(ops) = ops_table {
312                            let mut op_names = Vec::new();
313                            for (_, op) in ops.pairs::<i64, String>().flatten() {
314                                op_names.push(op);
315                            }
316                            subscription_entries
317                                .push(SubscriptionEntry::with_operations(cat, op_names));
318                        } else {
319                            subscription_entries.push(SubscriptionEntry::all(cat));
320                        }
321                    }
322                }
323                _ => {
324                    tracing::warn!("subscription entry must be a string or table, ignoring");
325                }
326            }
327        }
328
329        // Extract required callbacks
330        let on_request_fn: Function = component_table
331            .get("on_request")
332            .map_err(|_| LuaError::MissingCallback("on_request".to_string()))?;
333
334        let on_signal_fn: Function = component_table
335            .get("on_signal")
336            .map_err(|_| LuaError::MissingCallback("on_signal".to_string()))?;
337
338        // Store callbacks in registry
339        let on_request_key = lua.create_registry_value(on_request_fn)?;
340        let on_signal_key = lua.create_registry_value(on_signal_fn)?;
341
342        // Extract optional callbacks
343        let init_key = component_table
344            .get::<Function>("init")
345            .ok()
346            .map(|f| lua.create_registry_value(f))
347            .transpose()?;
348
349        let shutdown_key = component_table
350            .get::<Function>("shutdown")
351            .ok()
352            .map(|f| lua.create_registry_value(f))
353            .transpose()?;
354
355        let snapshot_key = component_table
356            .get::<Function>("snapshot")
357            .ok()
358            .map(|f| lua.create_registry_value(f))
359            .transpose()?;
360
361        let restore_key = component_table
362            .get::<Function>("restore")
363            .ok()
364            .map(|f| lua.create_registry_value(f))
365            .transpose()?;
366
367        // Extract runtime hints (all optional, default false)
368        let hints = RuntimeHints {
369            output_to_io: component_table.get("output_to_io").unwrap_or(false),
370            elevated: component_table.get("elevated").unwrap_or(false),
371            child_spawner: component_table.get("child_spawner").unwrap_or(false),
372        };
373
374        Ok(Self {
375            lua: Mutex::new(lua),
376            id,
377            subscriptions,
378            subscription_entries,
379            status: Status::Idle,
380            on_request_key,
381            on_signal_key,
382            init_key,
383            shutdown_key,
384            snapshot_key,
385            restore_key,
386            script_path: None,
387            emitter: None,
388            child_context: None,
389            sandbox,
390            hints,
391        })
392    }
393
394    /// Provides closure-based access to the internal Lua state.
395    ///
396    /// Intended for test mock injection (e.g. overriding `orcs.llm()`).
397    #[cfg(any(test, feature = "test-utils"))]
398    pub(crate) fn with_lua<F, R>(&self, f: F) -> R
399    where
400        F: FnOnce(&Lua) -> R,
401    {
402        let lua = self.lua.lock();
403        f(&lua)
404    }
405
406    /// Returns the script path if loaded from file.
407    #[must_use]
408    pub fn script_path(&self) -> Option<&str> {
409        self.script_path.as_deref()
410    }
411
412    /// Reloads the script from file.
413    ///
414    /// # Errors
415    ///
416    /// Returns error if reload fails.
417    pub fn reload(&mut self) -> Result<(), LuaError> {
418        let Some(path) = &self.script_path else {
419            return Err(LuaError::InvalidScript("no script path".into()));
420        };
421
422        let new_component = Self::from_file(path, Arc::clone(&self.sandbox))?;
423
424        // Swap internals (preserve emitter)
425        self.lua = new_component.lua;
426        self.subscriptions = new_component.subscriptions;
427        self.on_request_key = new_component.on_request_key;
428        self.on_signal_key = new_component.on_signal_key;
429        self.init_key = new_component.init_key;
430        self.shutdown_key = new_component.shutdown_key;
431        self.snapshot_key = new_component.snapshot_key;
432        self.restore_key = new_component.restore_key;
433        // Note: emitter is preserved across reload
434
435        // Re-register orcs.output if emitter is set
436        if let Some(emitter) = &self.emitter {
437            let lua = self.lua.lock();
438            emitter_fns::register(&lua, Arc::clone(emitter))?;
439        }
440
441        // Re-register child context functions if child_context is set
442        if let Some(ctx) = &self.child_context {
443            let lua = self.lua.lock();
444            ctx_fns::register(&lua, Arc::clone(ctx), Arc::clone(&self.sandbox))?;
445        }
446
447        tracing::info!("Reloaded Lua component: {}", self.id);
448        Ok(())
449    }
450
451    /// Returns whether this component has an emitter set.
452    ///
453    /// When true, the component can emit events via `orcs.output()`.
454    #[must_use]
455    pub fn has_emitter(&self) -> bool {
456        self.emitter.is_some()
457    }
458
459    /// Returns whether this component has a child context set.
460    ///
461    /// When true, the component can spawn children via `orcs.spawn_child()`.
462    #[must_use]
463    pub fn has_child_context(&self) -> bool {
464        self.child_context.is_some()
465    }
466
467    /// Sets the child context for spawning and managing children.
468    ///
469    /// Once set, the Lua script can use:
470    /// - `orcs.spawn_child(config)` - Spawn a child
471    /// - `orcs.child_count()` - Get current child count
472    /// - `orcs.max_children()` - Get max allowed children
473    ///
474    /// # Arguments
475    ///
476    /// * `ctx` - The child context
477    pub fn set_child_context(&mut self, ctx: Box<dyn ChildContext>) {
478        self.install_child_context(ctx);
479    }
480
481    /// Shared implementation for `set_child_context` (inherent + trait).
482    ///
483    /// Extracts hook registry, registers ctx functions, and wires up hooks.
484    fn install_child_context(&mut self, ctx: Box<dyn ChildContext>) {
485        let hook_registry = ctx
486            .extension("hook_registry")
487            .and_then(|any| any.downcast::<orcs_hook::SharedHookRegistry>().ok())
488            .map(|boxed| *boxed);
489
490        let ctx_arc = Arc::new(Mutex::new(ctx));
491        self.child_context = Some(Arc::clone(&ctx_arc));
492
493        let lua = self.lua.lock();
494
495        if let Err(e) = ctx_fns::register(&lua, ctx_arc, Arc::clone(&self.sandbox)) {
496            tracing::warn!("Failed to register child context functions: {}", e);
497        }
498
499        let Some(registry) = hook_registry else {
500            return;
501        };
502
503        if let Err(e) =
504            crate::hook_helpers::register_hook_function(&lua, registry.clone(), self.id.clone())
505        {
506            tracing::warn!("Failed to register orcs.hook(): {}", e);
507        } else {
508            tracing::debug!(component = %self.id.fqn(), "orcs.hook() registered");
509        }
510
511        if let Err(e) = crate::hook_helpers::register_unhook_function(&lua, registry.clone()) {
512            tracing::warn!("Failed to register orcs.unhook(): {}", e);
513        }
514
515        lua.set_app_data(crate::tools::ToolHookContext {
516            registry,
517            component_id: self.id.clone(),
518        });
519        if let Err(e) = crate::tools::wrap_tools_with_hooks(&lua) {
520            tracing::warn!("Failed to wrap tools with hooks: {}", e);
521        }
522    }
523}
524
525impl Component for LuaComponent {
526    fn id(&self) -> &ComponentId {
527        &self.id
528    }
529
530    fn subscriptions(&self) -> &[EventCategory] {
531        &self.subscriptions
532    }
533
534    fn subscription_entries(&self) -> Vec<SubscriptionEntry> {
535        self.subscription_entries.clone()
536    }
537
538    fn runtime_hints(&self) -> RuntimeHints {
539        self.hints.clone()
540    }
541
542    fn status(&self) -> Status {
543        self.status
544    }
545
546    #[tracing::instrument(
547        skip(self, request),
548        fields(component = %self.id.fqn(), operation = %request.operation)
549    )]
550    fn on_request(&mut self, request: &Request) -> Result<JsonValue, ComponentError> {
551        if self.status == Status::Aborted {
552            return Err(ComponentError::ExecutionFailed(
553                "component is aborted".to_string(),
554            ));
555        }
556        self.status = Status::Running;
557
558        let lua = self.lua.lock();
559
560        // Get callback from registry
561        let on_request: Function = lua.registry_value(&self.on_request_key).map_err(|e| {
562            tracing::debug!("Failed to get on_request from registry: {}", e);
563            ComponentError::ExecutionFailed("lua callback not found".to_string())
564        })?;
565
566        // Convert request to Lua
567        let lua_req = LuaRequest::from_request(request);
568
569        // Call Lua function
570        let result: LuaResponse = on_request.call(lua_req).map_err(|e| {
571            // Propagate Suspended errors transparently — ChannelRunner needs
572            // the approval_id and grant_pattern to drive the HIL flow.
573            if let Some(suspended) = extract_suspended(&e) {
574                return suspended;
575            }
576            // Sanitize other error messages to avoid leaking internal details
577            tracing::debug!("Lua on_request error: {}", e);
578            ComponentError::ExecutionFailed("lua script execution failed".to_string())
579        })?;
580
581        drop(lua);
582        self.status = Status::Idle;
583
584        if result.success {
585            Ok(result.data.unwrap_or(JsonValue::Null))
586        } else {
587            Err(ComponentError::ExecutionFailed(
588                result.error.unwrap_or_else(|| "unknown error".into()),
589            ))
590        }
591    }
592
593    #[tracing::instrument(
594        skip(self, signal),
595        fields(component = %self.id.fqn(), signal_kind = ?signal.kind)
596    )]
597    fn on_signal(&mut self, signal: &Signal) -> SignalResponse {
598        let lua = self.lua.lock();
599
600        let Ok(on_signal): Result<Function, _> = lua.registry_value(&self.on_signal_key) else {
601            return SignalResponse::Ignored;
602        };
603
604        let lua_sig = LuaSignal::from_signal(signal);
605
606        let result: Result<String, _> = on_signal.call(lua_sig);
607
608        match result {
609            Ok(response_str) => {
610                let response = parse_signal_response(&response_str);
611                if matches!(response, SignalResponse::Abort) {
612                    drop(lua);
613                    self.status = Status::Aborted;
614                }
615                response
616            }
617            Err(e) => {
618                tracing::warn!("Lua on_signal error: {}", e);
619                SignalResponse::Ignored
620            }
621        }
622    }
623
624    fn abort(&mut self) {
625        self.status = Status::Aborted;
626    }
627
628    /// Calls the Lua `init(cfg)` callback with per-component settings.
629    ///
630    /// `config` contains `[components.settings.<name>]` from config.toml,
631    /// plus `_global` (injected by builder) with global config fields.
632    /// Null or empty objects are passed as `nil` to Lua.
633    #[tracing::instrument(skip(self, config), fields(component = %self.id.fqn()))]
634    fn init(&mut self, config: &serde_json::Value) -> Result<(), ComponentError> {
635        let Some(init_key) = &self.init_key else {
636            return Ok(());
637        };
638
639        let lua = self.lua.lock();
640
641        let init_fn: Function = lua.registry_value(init_key).map_err(|e| {
642            tracing::debug!("Failed to get init from registry: {}", e);
643            ComponentError::ExecutionFailed("lua init callback not found".to_string())
644        })?;
645
646        // Convert JSON config to Lua value; pass nil if null or empty object
647        let lua_config = if config.is_null()
648            || (config.is_object() && config.as_object().map_or(true, serde_json::Map::is_empty))
649        {
650            mlua::Value::Nil
651        } else {
652            lua.to_value(config).map_err(|e| {
653                tracing::debug!("Failed to convert config to Lua: {}", e);
654                ComponentError::ExecutionFailed("config conversion failed".to_string())
655            })?
656        };
657
658        init_fn.call::<()>(lua_config).map_err(|e| {
659            tracing::debug!("Lua init error: {}", e);
660            ComponentError::ExecutionFailed("lua init callback failed".to_string())
661        })?;
662
663        Ok(())
664    }
665
666    #[tracing::instrument(skip(self), fields(component = %self.id.fqn()))]
667    fn shutdown(&mut self) {
668        let Some(shutdown_key) = &self.shutdown_key else {
669            return;
670        };
671
672        let lua = self.lua.lock();
673
674        if let Ok(shutdown_fn) = lua.registry_value::<Function>(shutdown_key) {
675            if let Err(e) = shutdown_fn.call::<()>(()) {
676                tracing::warn!("Lua shutdown error: {}", e);
677            }
678        }
679    }
680
681    fn snapshot(&self) -> Result<ComponentSnapshot, SnapshotError> {
682        let Some(snapshot_key) = &self.snapshot_key else {
683            return Err(SnapshotError::NotSupported(self.id.fqn()));
684        };
685
686        let lua = self.lua.lock();
687
688        let snapshot_fn: Function = lua
689            .registry_value(snapshot_key)
690            .map_err(|e| SnapshotError::InvalidData(format!("snapshot callback not found: {e}")))?;
691
692        let lua_result: LuaValue = snapshot_fn
693            .call(())
694            .map_err(|e| SnapshotError::InvalidData(format!("snapshot callback failed: {e}")))?;
695
696        let json_value = lua_value_to_json(&lua_result);
697        ComponentSnapshot::from_state(self.id.fqn(), &json_value)
698    }
699
700    fn restore(&mut self, snapshot: &ComponentSnapshot) -> Result<(), SnapshotError> {
701        let Some(restore_key) = &self.restore_key else {
702            return Err(SnapshotError::NotSupported(self.id.fqn()));
703        };
704
705        snapshot.validate(&self.id.fqn())?;
706
707        let lua = self.lua.lock();
708
709        let restore_fn: Function = lua
710            .registry_value(restore_key)
711            .map_err(|e| SnapshotError::InvalidData(format!("restore callback not found: {e}")))?;
712
713        let lua_value = json_to_lua_value(&lua, &snapshot.state).map_err(|e| {
714            SnapshotError::InvalidData(format!("failed to convert snapshot to lua: {e}"))
715        })?;
716
717        restore_fn
718            .call::<()>(lua_value)
719            .map_err(|e| SnapshotError::RestoreFailed {
720                component: self.id.fqn(),
721                reason: format!("restore callback failed: {e}"),
722            })?;
723
724        Ok(())
725    }
726
727    fn set_emitter(&mut self, emitter: Box<dyn Emitter>) {
728        let emitter_arc = Arc::new(Mutex::new(emitter));
729        self.emitter = Some(Arc::clone(&emitter_arc));
730
731        // Register emitter-backed Lua functions (orcs.output, orcs.emit_event)
732        let lua = self.lua.lock();
733        if let Err(e) = emitter_fns::register(&lua, emitter_arc) {
734            tracing::warn!("Failed to register emitter functions: {}", e);
735        }
736    }
737
738    fn set_child_context(&mut self, ctx: Box<dyn ChildContext>) {
739        self.install_child_context(ctx);
740    }
741}
742
/// ComponentLoader implementation for Lua components.
///
/// Allows creating LuaComponent instances from inline script content
/// for use with ChildContext::spawn_runner_from_script().
#[derive(Clone)]
pub struct LuaComponentLoader {
    /// Sandbox policy handed to every component this loader creates.
    sandbox: Arc<dyn SandboxPolicy>,
}
751
752impl LuaComponentLoader {
753    /// Creates a new LuaComponentLoader with the given sandbox policy.
754    #[must_use]
755    pub fn new(sandbox: Arc<dyn SandboxPolicy>) -> Self {
756        Self { sandbox }
757    }
758}
759
760impl ComponentLoader for LuaComponentLoader {
761    fn load_from_script(
762        &self,
763        script: &str,
764        _id: Option<&str>,
765    ) -> Result<Box<dyn Component>, SpawnError> {
766        // Note: id parameter is ignored; LuaComponent extracts ID from script
767        LuaComponent::from_script(script, Arc::clone(&self.sandbox))
768            .map(|c| Box::new(c) as Box<dyn Component>)
769            .map_err(|e| SpawnError::InvalidScript(e.to_string()))
770    }
771}
772
773// === JSON ↔ Lua conversion helpers for snapshot/restore ===
774
775/// Converts a Lua value to a serde_json::Value.
776fn lua_value_to_json(value: &LuaValue) -> JsonValue {
777    match value {
778        LuaValue::Nil => JsonValue::Null,
779        LuaValue::Boolean(b) => JsonValue::Bool(*b),
780        LuaValue::Integer(i) => JsonValue::Number((*i).into()),
781        LuaValue::Number(n) => serde_json::Number::from_f64(*n)
782            .map(JsonValue::Number)
783            .unwrap_or(JsonValue::Null),
784        LuaValue::String(s) => JsonValue::String(s.to_string_lossy().to_string()),
785        LuaValue::Table(table) => {
786            // Detect array vs object: check if sequential integer keys starting at 1
787            let len = table.raw_len();
788            let is_array = len > 0
789                && table
790                    .clone()
791                    .pairs::<i64, LuaValue>()
792                    .enumerate()
793                    .all(|(idx, pair)| pair.map(|(k, _)| k == (idx as i64 + 1)).unwrap_or(false));
794
795            if is_array {
796                let arr: Vec<JsonValue> = table
797                    .clone()
798                    .sequence_values::<LuaValue>()
799                    .filter_map(|v| v.ok())
800                    .map(|v| lua_value_to_json(&v))
801                    .collect();
802                JsonValue::Array(arr)
803            } else {
804                let mut map = serde_json::Map::new();
805                if let Ok(pairs) = table
806                    .clone()
807                    .pairs::<LuaValue, LuaValue>()
808                    .collect::<Result<Vec<_>, _>>()
809                {
810                    for (k, v) in pairs {
811                        let key = match &k {
812                            LuaValue::String(s) => s.to_string_lossy().to_string(),
813                            LuaValue::Integer(i) => i.to_string(),
814                            _ => continue,
815                        };
816                        map.insert(key, lua_value_to_json(&v));
817                    }
818                }
819                JsonValue::Object(map)
820            }
821        }
822        _ => JsonValue::Null,
823    }
824}
825
826/// Converts a serde_json::Value to a Lua value.
827fn json_to_lua_value(lua: &Lua, value: &JsonValue) -> Result<LuaValue, mlua::Error> {
828    match value {
829        JsonValue::Null => Ok(LuaValue::Nil),
830        JsonValue::Bool(b) => Ok(LuaValue::Boolean(*b)),
831        JsonValue::Number(n) => {
832            if let Some(i) = n.as_i64() {
833                Ok(LuaValue::Integer(i))
834            } else if let Some(f) = n.as_f64() {
835                Ok(LuaValue::Number(f))
836            } else {
837                Ok(LuaValue::Nil)
838            }
839        }
840        JsonValue::String(s) => s.as_str().into_lua(lua),
841        JsonValue::Array(arr) => {
842            let table = lua.create_table()?;
843            for (i, v) in arr.iter().enumerate() {
844                let lua_val = json_to_lua_value(lua, v)?;
845                table.raw_set(i + 1, lua_val)?;
846            }
847            Ok(LuaValue::Table(table))
848        }
849        JsonValue::Object(map) => {
850            let table = lua.create_table()?;
851            for (k, v) in map {
852                let lua_val = json_to_lua_value(lua, v)?;
853                table.raw_set(k.as_str(), lua_val)?;
854            }
855            Ok(LuaValue::Table(table))
856        }
857    }
858}
859
860#[cfg(test)]
861mod tests;
862
#[cfg(test)]
mod extract_suspended_tests {
    use super::*;

    /// Builds a `ComponentError::Suspended` with the given fields.
    fn suspended(
        approval_id: &str,
        grant_pattern: &str,
        pending_request: JsonValue,
    ) -> ComponentError {
        ComponentError::Suspended {
            approval_id: approval_id.into(),
            grant_pattern: grant_pattern.into(),
            pending_request,
        }
    }

    /// Wraps `cause` in one `CallbackError` layer, mimicking a single Lua call-stack level.
    fn wrap_callback(cause: mlua::Error, traceback: &str) -> mlua::Error {
        mlua::Error::CallbackError {
            traceback: traceback.into(),
            cause: Arc::new(cause),
        }
    }

    /// Asserts that `result` is the `Suspended` error carrying `expected_id`, so the
    /// tests prove that the *original* error was extracted, not merely some error.
    fn assert_suspended_with_id(result: Option<ComponentError>, expected_id: &str) {
        match result {
            Some(ComponentError::Suspended { approval_id, .. }) => {
                assert_eq!(approval_id, expected_id);
            }
            other => panic!(
                "Expected Suspended with approval_id {:?}, got {:?}",
                expected_id, other
            ),
        }
    }

    #[test]
    fn extracts_suspended_from_external_error() {
        let err = mlua::Error::ExternalError(Arc::new(suspended(
            "ap-1",
            "shell:*",
            serde_json::json!({"cmd": "ls"}),
        )));
        // Every field must survive the downcast + clone, not just the id.
        match extract_suspended(&err) {
            Some(ComponentError::Suspended {
                approval_id,
                grant_pattern,
                pending_request,
            }) => {
                assert_eq!(approval_id, "ap-1");
                assert_eq!(grant_pattern, "shell:*");
                assert_eq!(pending_request, serde_json::json!({"cmd": "ls"}));
            }
            other => panic!("Expected Suspended, got {:?}", other),
        }
    }

    #[test]
    fn extracts_suspended_from_callback_error() {
        let inner = mlua::Error::ExternalError(Arc::new(suspended(
            "ap-2",
            "tool:*",
            serde_json::Value::Null,
        )));
        let err = wrap_callback(inner, "stack trace");
        assert_suspended_with_id(extract_suspended(&err), "ap-2");
    }

    #[test]
    fn extracts_suspended_from_nested_callback_errors() {
        let inner = mlua::Error::ExternalError(Arc::new(suspended(
            "ap-3",
            "exec:*",
            serde_json::Value::Null,
        )));
        let outer = wrap_callback(wrap_callback(inner, "level 1"), "level 2");
        assert_suspended_with_id(extract_suspended(&outer), "ap-3");
    }

    #[test]
    fn returns_none_for_non_suspended_component_error() {
        let err =
            mlua::Error::ExternalError(Arc::new(ComponentError::ExecutionFailed("timeout".into())));
        assert!(
            extract_suspended(&err).is_none(),
            "ExecutionFailed should not match"
        );
    }

    #[test]
    fn returns_none_for_non_suspended_error_inside_callback_error() {
        let inner =
            mlua::Error::ExternalError(Arc::new(ComponentError::ExecutionFailed("timeout".into())));
        let err = wrap_callback(inner, "stack trace");
        assert!(
            extract_suspended(&err).is_none(),
            "ExecutionFailed wrapped in CallbackError should not match"
        );
    }

    #[test]
    fn returns_none_for_foreign_external_error() {
        let io_err = std::io::Error::new(std::io::ErrorKind::Other, "boom");
        let err = mlua::Error::ExternalError(Arc::new(io_err));
        assert!(
            extract_suspended(&err).is_none(),
            "a non-ComponentError external error should not match"
        );
    }

    #[test]
    fn returns_none_for_runtime_error() {
        let err = mlua::Error::RuntimeError("some error".into());
        assert!(
            extract_suspended(&err).is_none(),
            "RuntimeError should not match"
        );
    }
}
948}