Skip to main content

coralstack_cmd_ipc/
registry.rs

1//! The [`CommandRegistry`] — core routing and execution hub.
2//!
3//! Mirrors the TypeScript `CommandRegistry` in
4//! `packages/cmd-ipc/src/registry/command-registry.ts`: it owns a local
5//! command table, a remote command table (command → owning channel), a
6//! set of connected channels, and a handful of [`TtlMap`]s correlating
7//! in-flight requests, forwarded routes, and recently-seen events.
8//!
9//! # Topology
10//!
11//! * **Root** registries have no `router_channel`. Unknown commands
12//!   produce `NotFound` errors.
13//! * **Child** registries set `router_channel = Some(peer_id)`. Unknown
14//!   commands, and new local registrations, are escalated upstream.
15//! * Events fan out across every connected channel; dedup by message
16//!   id prevents echo loops in meshes.
17//!
18//! Private commands/events (identifiers starting with `_`) stay local
19//! — never escalated, never advertised, never broadcast.
20
21use std::collections::{BTreeMap, HashMap};
22use std::future::Future;
23use std::sync::atomic::{AtomicU64, Ordering};
24use std::sync::Arc;
25use std::time::Duration;
26
27use futures::channel::oneshot;
28use futures::future::BoxFuture;
29use futures::stream::FuturesUnordered;
30use futures::{FutureExt, StreamExt};
31use parking_lot::Mutex;
32use serde::Serialize;
33use serde_json::Value;
34use uuid::Uuid;
35
36use crate::channel::CommandChannel;
37use crate::command::Command;
38use crate::error::{ChannelError, CommandError, ExecuteErrorCode, RegisterErrorCode};
39use crate::event::Event;
40use crate::message::{
41    CommandDef, ExecuteError, ExecuteResult, False, Message, MessageId, RegisterResult, True,
42};
43use crate::ttl_map::TtlMap;
44
/// Configuration for a [`CommandRegistry`].
///
/// Every field has a sensible default; build with
/// `Config { router_channel: Some(id), ..Default::default() }` to
/// override only what you need. The struct is `Debug` (for logging the
/// effective configuration) and `Clone` (so one template can seed
/// several registries).
#[derive(Debug, Clone)]
pub struct Config {
    /// Registry identifier used in log messages. Defaults to a random
    /// UUID.
    pub id: Option<String>,
    /// Channel id to escalate unknown commands and new registrations
    /// to. Leave `None` for a root registry.
    pub router_channel: Option<String>,
    /// How long a pending `execute` / `register` reply can wait before
    /// being rejected with [`CommandError::Timeout`]. Zero disables
    /// the TTL check (request hangs until the channel closes).
    pub request_ttl: Duration,
    /// How long a seen event id is remembered for dedup purposes.
    pub event_ttl: Duration,
    /// Maximum number of handler futures that may be in flight
    /// concurrently on a single channel's pump. Once this cap is
    /// reached, the pump stops pulling new messages off the channel
    /// (applying upstream backpressure) until an in-flight handler
    /// finishes. `0` disables the cap.
    ///
    /// **Optional in practice** — `Config::default()` sets this to
    /// `256`, so callers using `..Default::default()` never need to
    /// supply it. It is only *syntactically* required when you build
    /// `Config { … }` field-by-field (Rust struct literals must list
    /// every field). The cap exists so a misbehaving peer can't cause
    /// unbounded handler-future buildup; the default is fine for
    /// almost every workload.
    pub max_in_flight_per_channel: usize,
}
74
75impl Default for Config {
76    fn default() -> Self {
77        Self {
78            id: None,
79            router_channel: None,
80            request_ttl: Duration::from_secs(30),
81            event_ttl: Duration::from_secs(5),
82            max_in_flight_per_channel: 256,
83        }
84    }
85}
86
/// Type-erased command handler: takes the JSON request and returns a
/// boxed future resolving to the JSON response or an [`ExecuteError`].
type HandlerFn = dyn Fn(Value) -> BoxFuture<'static, Result<Value, ExecuteError>> + Send + Sync;
/// Type-erased event listener invoked with the event's JSON payload.
type EventListener = Arc<dyn Fn(Value) + Send + Sync>;
89
/// One row of the local command table.
struct LocalEntry {
    /// Type-erased handler invoked when the command is executed.
    handler: Arc<HandlerFn>,
    /// Definition (id, description, schema) reported for this command.
    def: CommandDef,
    /// `true` when the command id starts with `_`; private entries are
    /// filtered out of command listings.
    is_private: bool,
}
95
/// Bookkeeping for an `execute` request forwarded to a peer and still
/// awaiting its reply.
struct PendingExecute {
    /// Resolves the caller's pending future with the peer's result (or
    /// a synthesized timeout error on TTL expiry).
    tx: oneshot::Sender<ExecuteResult>,
    /// Id of the channel the request was sent on.
    target_channel: String,
}
100
/// Outcome delivered to the caller awaiting a `register.command.request`
/// reply. The peer's `RegisterResult` is wrapped so the registry can
/// synthesize timeout / disconnect errors distinguishable from the
/// wire-level duplicate-command case.
enum RegisterOutcome {
    /// The peer answered; carries its reply verbatim.
    Wire(RegisterResult),
    /// The pending entry outlived the request TTL before a reply arrived.
    Timeout,
    /// Surfaced to the caller as `CommandError::ChannelDisconnected`.
    /// NOTE(review): presumably produced by the channel-close path,
    /// which is outside this view — confirm.
    Disconnected,
}
110
/// Bookkeeping for a registration escalated to the router and still
/// awaiting its acknowledgement.
struct PendingRegister {
    /// Resolves the caller's pending future with the outcome.
    tx: oneshot::Sender<RegisterOutcome>,
    /// Id of the channel the registration request was sent on.
    target_channel: String,
}
115
/// Entry in the forwarded-routes table.
///
/// NOTE(review): the code that reads this is outside the visible part
/// of the file; the field descriptions below are inferred from the
/// names and the module docs — confirm against the forwarding path.
struct RouteEntry {
    /// Presumably the channel the forwarded request arrived on.
    origin_channel: String,
    /// Presumably the channel the request was forwarded to.
    target_channel: String,
}
120
/// Shared state behind the [`CommandRegistry`] Arc.
struct Inner {
    /// Registry identifier; a random UUID when the config supplied none.
    id: String,
    /// Channel id that unknown commands and new non-private
    /// registrations are escalated to; `None` for a root registry.
    router_channel: Option<String>,
    /// command id -> locally registered handler and definition
    local: Mutex<HashMap<String, LocalEntry>>,
    /// command id -> owning channel id
    remote: Mutex<HashMap<String, String>>,
    /// command id -> advertised definition (description + schema)
    /// kept parallel to `remote` so `list_commands` can render
    /// the same richness for remote entries as for local ones.
    remote_defs: Mutex<HashMap<String, CommandDef>>,
    /// channel id -> connected channel
    channels: Mutex<HashMap<String, Arc<dyn CommandChannel>>>,
    /// request id -> caller awaiting an execute reply. Expired entries
    /// are resolved with a timeout error.
    execute_replies: TtlMap<MessageId, PendingExecute>,
    /// request id -> caller awaiting a register reply. Expired entries
    /// are resolved with `RegisterOutcome::Timeout`.
    register_replies: TtlMap<MessageId, PendingRegister>,
    /// Forwarded-request routes, expiring with the request TTL.
    /// NOTE(review): populated outside the visible part of the file.
    routes: TtlMap<MessageId, RouteEntry>,
    /// Ids of events already emitted here, remembered for the
    /// event TTL so duplicates can be recognised.
    seen_events: TtlMap<MessageId, ()>,
    /// Event listeners keyed by event id then by a monotonic token, so
    /// `add_event_listener` can return an unsubscribe closure that
    /// removes just that one listener. `BTreeMap` preserves insertion
    /// order (tokens are monotonically increasing) for dispatch.
    event_listeners: Mutex<HashMap<String, BTreeMap<u64, EventListener>>>,
    /// Monotonically-increasing token used to key event listeners.
    next_listener_token: AtomicU64,
    /// Per-channel in-flight cap, copied from [`Config`].
    max_in_flight_per_channel: usize,
}
147
/// The main entry point of the crate.
///
/// A registry is cheap to clone: internally it's an `Arc<Inner>`.
/// Every clone shares the same command tables, channels, and event
/// listeners.
#[derive(Clone)]
pub struct CommandRegistry {
    /// Shared state; all clones point at the same allocation.
    inner: Arc<Inner>,
}
155
156impl CommandRegistry {
157    pub fn new(cfg: Config) -> Self {
158        // On TTL expiry, deliver an explicit Timeout outcome on the
159        // pending oneshot so the caller's `rx.await` resolves to
160        // `CommandError::Timeout` rather than blocking forever (or
161        // collapsing to a generic `ChannelDisconnected` when the tx is
162        // dropped). Eviction is lazy — driven by `sweep_expired()`
163        // calls in the driver loop and registry hot paths.
164        let execute_replies =
165            TtlMap::new(cfg.request_ttl).with_on_expire(|_, pending: PendingExecute| {
166                let _ = pending.tx.send(ExecuteResult::Err {
167                    ok: False,
168                    error: ExecuteError {
169                        code: ExecuteErrorCode::Timeout,
170                        message: "request timed out".into(),
171                    },
172                });
173            });
174        let register_replies =
175            TtlMap::new(cfg.request_ttl).with_on_expire(|_, pending: PendingRegister| {
176                let _ = pending.tx.send(RegisterOutcome::Timeout);
177            });
178
179        let inner = Arc::new(Inner {
180            id: cfg.id.unwrap_or_else(|| Uuid::new_v4().to_string()),
181            router_channel: cfg.router_channel,
182            local: Mutex::new(HashMap::new()),
183            remote: Mutex::new(HashMap::new()),
184            remote_defs: Mutex::new(HashMap::new()),
185            channels: Mutex::new(HashMap::new()),
186            execute_replies,
187            register_replies,
188            routes: TtlMap::new(cfg.request_ttl),
189            seen_events: TtlMap::new(cfg.event_ttl),
190            event_listeners: Mutex::new(HashMap::new()),
191            next_listener_token: AtomicU64::new(0),
192            max_in_flight_per_channel: cfg.max_in_flight_per_channel,
193        });
194        Self { inner }
195    }
196
197    /// Returns this registry's identifier.
198    pub fn id(&self) -> &str {
199        &self.inner.id
200    }
201
202    /// Returns the ids of every currently-registered channel, sorted.
203    ///
204    /// Mirrors the TypeScript library's `listChannels()` method.
205    pub fn list_channels(&self) -> Vec<String> {
206        let mut ids: Vec<String> = self.inner.channels.lock().keys().cloned().collect();
207        ids.sort();
208        ids
209    }
210
211    /// Returns the full [`CommandDef`] (id + description + schema) for
212    /// every reachable command — local (non-private) and remote. Remote
213    /// defs are those advertised via `register.command.request` or
214    /// `list.commands.response` on the channel.
215    ///
216    /// Mirrors the TypeScript library's `listCommands()` method.
217    /// Results are sorted by id. A command id is only included once even
218    /// if both a local and remote entry exist (local wins).
219    pub fn list_commands(&self) -> Vec<CommandDef> {
220        let mut out: HashMap<String, CommandDef> = HashMap::new();
221        for (id, entry) in self.inner.local.lock().iter() {
222            if !entry.is_private {
223                out.insert(id.clone(), entry.def.clone());
224            }
225        }
226        for (id, def) in self.inner.remote_defs.lock().iter() {
227            out.entry(id.clone()).or_insert_with(|| def.clone());
228        }
229        let mut v: Vec<CommandDef> = out.into_values().collect();
230        v.sort_by(|a, b| a.id.cmp(&b.id));
231        v
232    }
233
234    /// Register a command on this registry.
235    ///
236    /// The single registration entry point, covering both compile-time
237    /// and runtime commands:
238    ///
239    /// - **Compile-time**: pass an instance of a type that implements
240    ///   [`Command`]. The `#[command]` / `#[command_service]` macros
241    ///   generate such types from a plain `async fn`.
242    /// - **Runtime**: pass a [`DynCommand`] carrying owned id /
243    ///   description / schema and a closure handler.
244    ///
245    /// Mirrors the TypeScript library's `registerCommand`.
246    ///
247    /// - Commands whose id starts with `_` stay local: they are never
248    ///   escalated to a `router_channel` and never advertised to peers
249    ///   via `list.commands.response`.
250    /// - Non-private commands are escalated upstream if this registry
251    ///   has a `router_channel`; the local entry is only committed
252    ///   after the router acks.
253    /// - The advertised schema is normalized via
254    ///   [`crate::schema::normalize_schema`] on the way in, so every
255    ///   schema leaving the registry is language-agnostic JSON Schema
256    ///   regardless of how the caller built it.
257    pub async fn register_command<C: Command>(&self, cmd: C) -> Result<(), CommandError> {
258        let id = cmd.id().to_string();
259        let description = cmd.description().map(str::to_string);
260        let schema = cmd.schema().map(crate::schema::normalize_command_schema);
261        let is_private = id.starts_with('_');
262        let def = CommandDef {
263            id: id.clone(),
264            description,
265            schema,
266        };
267        let handler: Arc<HandlerFn> = Arc::new({
268            let cmd = Arc::new(cmd);
269            move |value: Value| {
270                let cmd = cmd.clone();
271                async move {
272                    let req: C::Request =
273                        serde_json::from_value(value).map_err(|e| ExecuteError {
274                            code: ExecuteErrorCode::InvalidRequest,
275                            message: e.to_string(),
276                        })?;
277                    let res = cmd
278                        .handle(req)
279                        .await
280                        .map_err(|e| command_error_to_execute(&e, cmd.id()))?;
281                    serde_json::to_value(res).map_err(|e| ExecuteError {
282                        code: ExecuteErrorCode::InternalError,
283                        message: e.to_string(),
284                    })
285                }
286                .boxed()
287            }
288        });
289        self.register_inner(id, handler, def, is_private).await
290    }
291
292    async fn register_inner(
293        &self,
294        id: String,
295        handler: Arc<HandlerFn>,
296        def: CommandDef,
297        is_private: bool,
298    ) -> Result<(), CommandError> {
299        self.inner.execute_replies.sweep_expired();
300        self.inner.register_replies.sweep_expired();
301        // Duplicate check against the local table.
302        if self.inner.local.lock().contains_key(&id) {
303            return Err(CommandError::DuplicateCommand(id));
304        }
305
306        // Non-private commands escalate to the router before being added.
307        if !is_private {
308            if let Some(router_id) = self.inner.router_channel.clone() {
309                let router_ch = self.inner.channels.lock().get(&router_id).cloned();
310                if let Some(router_ch) = router_ch {
311                    let req_id = MessageId::new_v4();
312                    let (tx, rx) = oneshot::channel();
313                    self.inner.register_replies.insert(
314                        req_id,
315                        PendingRegister {
316                            tx,
317                            target_channel: router_id.clone(),
318                        },
319                    );
320                    router_ch
321                        .send(Message::RegisterCommandRequest {
322                            id: req_id,
323                            meta: None,
324                            command: def.clone(),
325                        })
326                        .map_err(|_| CommandError::ChannelDisconnected)?;
327                    match rx.await {
328                        Ok(RegisterOutcome::Wire(RegisterResult::Ok { .. })) => {}
329                        Ok(RegisterOutcome::Wire(RegisterResult::Err { error, .. })) => {
330                            return Err(match error {
331                                RegisterErrorCode::DuplicateCommand => {
332                                    CommandError::DuplicateCommand(id)
333                                }
334                            });
335                        }
336                        Ok(RegisterOutcome::Timeout) => return Err(CommandError::Timeout),
337                        Ok(RegisterOutcome::Disconnected) | Err(_) => {
338                            return Err(CommandError::ChannelDisconnected);
339                        }
340                    }
341                }
342            }
343        }
344
345        self.inner.local.lock().insert(
346            id,
347            LocalEntry {
348                handler,
349                def,
350                is_private,
351            },
352        );
353        Ok(())
354    }
355
356    /// Connects a [`CommandChannel`] to this registry.
357    ///
358    /// Returns a driver future which must be polled by the caller's
359    /// executor (via `tokio::spawn`, `smol::spawn`,
360    /// `futures::executor::block_on`, …) for the registry to exchange
361    /// messages with the peer. The future completes when the channel
362    /// closes.
363    ///
364    /// Handler dispatch is **concurrent within the driver task**: the
365    /// pump pushes each incoming message's handler future into a
366    /// [`FuturesUnordered`] and cooperatively interleaves them with
367    /// the next `recv`, so a slow handler that awaits external work
368    /// no longer blocks subsequent messages on the same channel. The
369    /// number of simultaneously in-flight handlers is capped by
370    /// [`Config::max_in_flight_per_channel`] (default 256); at the
371    /// cap the pump applies backpressure to the channel rather than
372    /// dropping messages. For true multi-thread parallelism, wrap
373    /// the handler body in your runtime's `spawn` (e.g.
374    /// `tokio::spawn`) — the crate itself stays runtime-agnostic.
375    pub async fn register_channel(
376        &self,
377        channel: Arc<dyn CommandChannel>,
378    ) -> Result<impl Future<Output = ()> + Send + 'static, ChannelError> {
379        let id = channel.id().to_string();
380        {
381            let mut chans = self.inner.channels.lock();
382            if chans.contains_key(&id) {
383                return Err(ChannelError::Other(format!(
384                    "channel with id `{id}` already registered"
385                )));
386            }
387            chans.insert(id.clone(), channel.clone());
388        }
389
390        channel.start().await?;
391
392        // Ask the peer for its command list. The response is handled
393        // by the driver loop, which will register each entry as remote.
394        if let Err(e) = channel.send(Message::ListCommandsRequest {
395            id: MessageId::new_v4(),
396            meta: None,
397        }) {
398            self.inner.channels.lock().remove(&id);
399            return Err(e);
400        }
401
402        let inner = self.inner.clone();
403        let ch = channel;
404        Ok(async move {
405            // In-flight handler futures cooperatively interleaved with
406            // message recv. This is what makes the pump non-blocking:
407            // when one handler awaits (sleep / network / forwarded
408            // call), the executor polls other handlers and the recv
409            // future on the same task, so a slow handler can no
410            // longer wedge every subsequent message behind it.
411            //
412            // Stays runtime-agnostic (no `tokio::spawn`); the user's
413            // executor still owns the single task driving this pump.
414            // True multi-thread parallelism is possible if the user
415            // wraps their handler body in `tokio::spawn` themselves.
416            let mut in_flight: FuturesUnordered<BoxFuture<'static, ()>> = FuturesUnordered::new();
417            let cap = inner.max_in_flight_per_channel;
418
419            loop {
420                // Backpressure: while at capacity, drain in-flight
421                // handlers instead of accepting new messages. Ordering
422                // of message *reception* is preserved — only dispatch
423                // fans out. Messages are never dropped.
424                while cap > 0 && in_flight.len() >= cap {
425                    if in_flight.next().await.is_none() {
426                        break;
427                    }
428                }
429
430                // Opportunistic TTL sweep: any registry activity
431                // triggers a pass over the pending-reply / route maps
432                // so timed-out entries fire their on_expire callbacks
433                // (delivering Timeout to waiting callers). Cheap —
434                // O(n) over maps that are normally tiny.
435                inner.execute_replies.sweep_expired();
436                inner.register_replies.sweep_expired();
437                inner.routes.sweep_expired();
438
439                // Wait for either the next message OR for an in-flight
440                // handler to make progress. When `in_flight` is empty
441                // we'd otherwise busy-loop on its `.next()` returning
442                // `None`, so park on `pending()` in that case.
443                let next_msg = if in_flight.is_empty() {
444                    ch.recv().await
445                } else {
446                    futures::select_biased! {
447                        // Drain finished handlers so the in_flight set
448                        // shrinks promptly. `select_next_some` skips
449                        // the empty case, but we've already guarded
450                        // against that above.
451                        _ = in_flight.select_next_some() => continue,
452                        msg = ch.recv().fuse() => msg,
453                    }
454                };
455
456                let Some(msg) = next_msg else { break };
457                in_flight.push(Inner::handle_message(inner.clone(), ch.clone(), msg).boxed());
458            }
459
460            // Channel closed — drain any in-flight handlers so their
461            // outgoing responses are sent before we tear the channel
462            // state down. Handlers that try to `send` on the now-closed
463            // channel will get `Err(Closed)` from the transport, which
464            // is the existing semantics.
465            while in_flight.next().await.is_some() {}
466            Inner::handle_channel_close(&inner, ch.id());
467        })
468    }
469
470    /// Executes a command identified by a compile-time [`Command`] type
471    /// — the **strict** form, giving the same compile-time type safety
472    /// that TypeScript's strict-mode `executeCommand<K>` gives via the
473    /// `CommandSchemaMap` type parameter.
474    ///
475    /// The command id comes from `C::ID`, the request type is pinned to
476    /// `C::Request`, and the response type is pinned to `C::Response`,
477    /// so the compiler rejects mismatches at the call site:
478    ///
479    /// ```ignore
480    /// let sum: i64 = registry.execute::<MathAdd>(AddReq { a: 2, b: 3 }).await?;
481    /// ```
482    ///
483    /// For commands whose id or payload shape is only known at runtime
484    /// (scripting hosts, FFI, plugins that advertise their own schema),
485    /// use [`execute_dyn`](Self::execute_dyn).
486    pub async fn execute<C: Command>(
487        &self,
488        request: C::Request,
489    ) -> Result<C::Response, CommandError>
490    where
491        C::Request: Serialize,
492        C::Response: serde::de::DeserializeOwned,
493    {
494        let req_value = value_from_request(&request)?;
495        let result = self.execute_raw_impl(C::ID.to_string(), req_value).await?;
496        let deserialized = serde_json::from_value(result.unwrap_or(Value::Null))?;
497        Ok(deserialized)
498    }
499
500    /// Executes a command whose id is only known at runtime — the
501    /// **loose** form, mirroring the TypeScript library's
502    /// `executeCommand(id, args)` in loose mode.
503    ///
504    /// Request and response are raw [`serde_json::Value`]s, so this is
505    /// the canonical entry point for plugin hosts, scripting runtimes,
506    /// FFI bridges, and any code where the schema is discovered via
507    /// [`list_commands`](Self::list_commands) rather than declared at
508    /// compile time.
509    ///
510    /// For statically-known commands, prefer [`execute`](Self::execute)
511    /// — it pins both types via the [`Command`] trait.
512    pub async fn execute_dyn(
513        &self,
514        command_id: &str,
515        request: Value,
516    ) -> Result<Value, CommandError> {
517        let result = self
518            .execute_raw_impl(command_id.to_string(), request)
519            .await?;
520        Ok(result.unwrap_or(Value::Null))
521    }
522
523    async fn execute_raw_impl(
524        &self,
525        command_id: String,
526        request: Value,
527    ) -> Result<Option<Value>, CommandError> {
528        // Flush any expired pending replies so stale entries fire their
529        // Timeout on_expire before we enqueue a new one.
530        self.inner.execute_replies.sweep_expired();
531        self.inner.register_replies.sweep_expired();
532        // 1) Local handler wins.
533        let local_handler = self
534            .inner
535            .local
536            .lock()
537            .get(&command_id)
538            .map(|entry| entry.handler.clone());
539        if let Some(handler) = local_handler {
540            return handler(request)
541                .await
542                .map(Some)
543                .map_err(|e| e.into_command_error(&command_id));
544        }
545
546        // 2) Known remote command.
547        let remote_target = self.inner.remote.lock().get(&command_id).cloned();
548        let target = match remote_target {
549            Some(t) => Some(t),
550            None => self.inner.router_channel.clone(),
551        };
552
553        let Some(target_id) = target else {
554            return Err(CommandError::NotFound(command_id));
555        };
556
557        let channel = self.inner.channels.lock().get(&target_id).cloned();
558        let Some(channel) = channel else {
559            return Err(CommandError::ChannelDisconnected);
560        };
561
562        self.forward_execute(command_id, request, &channel, target_id)
563            .await
564    }
565
566    async fn forward_execute(
567        &self,
568        command_id: String,
569        request: Value,
570        channel: &Arc<dyn CommandChannel>,
571        target_id: String,
572    ) -> Result<Option<Value>, CommandError> {
573        let req_id = MessageId::new_v4();
574        let (tx, rx) = oneshot::channel();
575        self.inner.execute_replies.insert(
576            req_id,
577            PendingExecute {
578                tx,
579                target_channel: target_id,
580            },
581        );
582        channel
583            .send(Message::ExecuteCommandRequest {
584                id: req_id,
585                meta: None,
586                command_id: command_id.clone(),
587                // Void requests are elided from the wire (Null → None)
588                // so peers expecting an absent `request` field (per the
589                // JSON Schema spec) don't see `"request": null`.
590                request: value_to_wire(request),
591            })
592            .map_err(|_| CommandError::ChannelDisconnected)?;
593
594        match rx.await {
595            Ok(ExecuteResult::Ok { result, .. }) => Ok(result),
596            Ok(ExecuteResult::Err { error, .. }) => Err(error_to_command_error(error, &command_id)),
597            Err(_) => {
598                self.inner.execute_replies.remove(&req_id);
599                Err(CommandError::ChannelDisconnected)
600            }
601        }
602    }
603
604    /// Emit an event. Dispatches to local listeners and — unless the
605    /// event id is private (starts with `_`) — broadcasts to every
606    /// connected channel.
607    ///
608    /// Works for both compile-time events (`#[event]`-annotated
609    /// structs) and runtime events ([`DynEvent`](crate::command::Command)
610    /// — actually [`DynEvent`](crate::event::DynEvent)). Id, description,
611    /// and schema are all read off the event instance.
612    pub fn emit<E: Event>(&self, event: E) -> Result<(), CommandError> {
613        let event_id = event.id().to_string();
614        let payload_value = serde_json::to_value(&event)?;
615        let msg_id = MessageId::new_v4();
616        self.inner.seen_events.insert(msg_id, ());
617
618        self.dispatch_event_locally(&event_id, &payload_value);
619
620        if !event_id.starts_with('_') {
621            let channels: Vec<Arc<dyn CommandChannel>> =
622                self.inner.channels.lock().values().cloned().collect();
623            // Void payloads (serde `()` → `Value::Null`) are elided
624            // from the wire per the event schema (`payload` is optional).
625            let wire_payload = value_to_wire(payload_value);
626            for ch in channels {
627                let _ = ch.send(Message::Event {
628                    id: msg_id,
629                    meta: None,
630                    event_id: event_id.clone(),
631                    payload: wire_payload.clone(),
632                });
633            }
634        }
635        Ok(())
636    }
637
638    /// Subscribe a typed listener. The callback receives a
639    /// deserialized `E` every time an event with id `E::ID` fires,
640    /// whether emitted locally or received from a connected channel.
641    ///
642    /// Returns an unsubscribe closure — call it (and drop it) to
643    /// remove just this listener. Ignoring the return value is fine;
644    /// the listener then lives for the life of the registry.
645    ///
646    /// Listeners for the same event fire in insertion order. Payloads
647    /// that fail to deserialize into `E` are silently dropped for
648    /// this listener — they still flow to any typed-for-Value
649    /// listeners registered via [`on_dyn`](Self::on_dyn).
650    pub fn on<E: Event + serde::de::DeserializeOwned>(
651        &self,
652        listener: impl Fn(E) + Send + Sync + 'static,
653    ) -> impl FnOnce() + Send + Sync + 'static {
654        self.install_listener(E::ID, move |value| {
655            if let Ok(typed) = serde_json::from_value::<E>(value) {
656                listener(typed);
657            }
658        })
659    }
660
661    /// Subscribe a dynamic listener by runtime id. The callback
662    /// receives the raw JSON payload. Use this when the event id is
663    /// only known at runtime (plugin runtimes, FFI, scripting hosts);
664    /// prefer [`on`](Self::on) whenever you have a compile-time
665    /// [`Event`] type.
666    ///
667    /// Same unsubscribe semantics as [`on`](Self::on).
668    pub fn on_dyn<F>(
669        &self,
670        event_id: impl Into<String>,
671        listener: F,
672    ) -> impl FnOnce() + Send + Sync + 'static
673    where
674        F: Fn(Value) + Send + Sync + 'static,
675    {
676        self.install_listener(&event_id.into(), listener)
677    }
678
679    fn install_listener<F>(
680        &self,
681        event_id: &str,
682        listener: F,
683    ) -> impl FnOnce() + Send + Sync + 'static
684    where
685        F: Fn(Value) + Send + Sync + 'static,
686    {
687        let token = self
688            .inner
689            .next_listener_token
690            .fetch_add(1, Ordering::Relaxed);
691        self.inner
692            .event_listeners
693            .lock()
694            .entry(event_id.to_string())
695            .or_default()
696            .insert(token, Arc::new(listener));
697
698        let inner = Arc::clone(&self.inner);
699        let event_id = event_id.to_string();
700        move || {
701            let mut map = inner.event_listeners.lock();
702            if let Some(slot) = map.get_mut(&event_id) {
703                slot.remove(&token);
704                if slot.is_empty() {
705                    map.remove(&event_id);
706                }
707            }
708        }
709    }
710
711    fn dispatch_event_locally(&self, event_id: &str, payload: &Value) {
712        let listeners: Vec<EventListener> = self
713            .inner
714            .event_listeners
715            .lock()
716            .get(event_id)
717            .map(|m| m.values().cloned().collect())
718            .unwrap_or_default();
719        for l in listeners {
720            l(payload.clone());
721        }
722    }
723
724    /// Tears down the registry: awaits `close()` on every connected
725    /// channel, drops every local and remote command, and clears all
726    /// event listeners. In-flight executes and register requests fail
727    /// with [`CommandError::ChannelDisconnected`] via the existing
728    /// channel-close path.
729    ///
730    /// Mirrors the TypeScript library's `dispose()`, but async so that
731    /// transports doing real teardown work (HTTP flush, MCP goodbye,
732    /// plugin sandbox shutdown) complete before this returns.
733    ///
734    /// Callers normally don't need this — dropping the last
735    /// `CommandRegistry` clone releases the inner state automatically
736    /// via `Drop`. Use `dispose` when a *shared* registry (held through
737    /// multiple clones) needs to be forcibly torn down, or in tests.
738    pub async fn dispose(&self) {
739        // Snapshot channel arcs so we can call `close` without holding
740        // the channels lock for the duration.
741        let channels: Vec<Arc<dyn CommandChannel>> = {
742            let mut locked = self.inner.channels.lock();
743            let out: Vec<_> = locked.values().cloned().collect();
744            locked.clear();
745            out
746        };
747
748        // Await each channel's close sequentially. Channels define
749        // their own close semantics (InMemoryChannel is effectively
750        // synchronous; MCPServerChannel flushes the transport; Flow's
751        // SourceChannel tears down its QuickJS VM), so we let every
752        // implementation finish its teardown before returning.
753        for ch in channels {
754            ch.close().await;
755        }
756
757        self.inner.local.lock().clear();
758        self.inner.remote.lock().clear();
759        self.inner.remote_defs.lock().clear();
760        self.inner.event_listeners.lock().clear();
761    }
762}
763
764impl Inner {
765    fn local_command_defs(&self) -> Vec<CommandDef> {
766        self.local
767            .lock()
768            .values()
769            .filter(|e| !e.is_private)
770            .map(|e| e.def.clone())
771            .collect()
772    }
773
774    /// Central dispatcher invoked by each channel's driver loop.
775    async fn handle_message(inner: Arc<Self>, channel: Arc<dyn CommandChannel>, msg: Message) {
776        match msg {
777            Message::RegisterCommandRequest { id, command, .. } => {
778                Self::handle_register_request(inner, channel, id, command).await;
779            }
780            Message::RegisterCommandResponse { thid, response, .. } => {
781                if let Some(pending) = inner.register_replies.remove(&thid) {
782                    let _ = pending.tx.send(RegisterOutcome::Wire(response));
783                }
784            }
785            Message::ListCommandsRequest { id, .. } => {
786                let commands = inner.local_command_defs();
787                let _ = channel.send(Message::ListCommandsResponse {
788                    id: MessageId::new_v4(),
789                    meta: None,
790                    thid: id,
791                    commands,
792                });
793            }
794            Message::ListCommandsResponse { commands, .. } => {
795                let channel_id = channel.id().to_string();
796                let mut remote = inner.remote.lock();
797                let mut remote_defs = inner.remote_defs.lock();
798                for cmd in commands {
799                    // Normalize ingested schemas so the local cache
800                    // matches what register() produces.
801                    let cmd = CommandDef {
802                        id: cmd.id,
803                        description: cmd.description,
804                        schema: cmd.schema.map(crate::schema::normalize_command_schema),
805                    };
806                    let entry_is_new = !remote.contains_key(&cmd.id);
807                    if entry_is_new {
808                        remote.insert(cmd.id.clone(), channel_id.clone());
809                    }
810                    // Always refresh the def (the latest advertisement wins).
811                    remote_defs.insert(cmd.id.clone(), cmd);
812                }
813            }
814            Message::ExecuteCommandRequest {
815                id,
816                command_id,
817                request,
818                ..
819            } => {
820                Self::handle_execute_request(
821                    inner,
822                    channel,
823                    id,
824                    command_id,
825                    request.unwrap_or(Value::Null),
826                )
827                .await;
828            }
829            Message::ExecuteCommandResponse { thid, response, .. } => {
830                Self::handle_execute_response(&inner, thid, response);
831            }
832            Message::Event {
833                id,
834                event_id,
835                payload,
836                ..
837            } => {
838                Self::handle_event(&inner, channel, id, event_id, payload);
839            }
840        }
841    }
842
843    async fn handle_register_request(
844        inner: Arc<Self>,
845        channel: Arc<dyn CommandChannel>,
846        req_id: MessageId,
847        command: CommandDef,
848    ) {
849        // Normalize ingested schemas so our cached copy is guaranteed
850        // to be language-agnostic JSON Schema, even if the peer didn't
851        // normalize on its side.
852        let command = CommandDef {
853            id: command.id,
854            description: command.description,
855            schema: command.schema.map(crate::schema::normalize_command_schema),
856        };
857        let channel_id = channel.id().to_string();
858        let command_id = command.id.clone();
859
860        // Duplicate against local?
861        let dup = inner.local.lock().contains_key(&command_id);
862        if dup {
863            let _ = channel.send(Message::RegisterCommandResponse {
864                id: MessageId::new_v4(),
865                meta: None,
866                thid: req_id,
867                response: RegisterResult::Err {
868                    ok: False,
869                    error: RegisterErrorCode::DuplicateCommand,
870                },
871            });
872            return;
873        }
874
875        // Already known in the remote table?
876        // - Same channel re-advertising: short-circuit with a success
877        //   ack, do NOT re-escalate upstream. A re-escalation would
878        //   bubble up as a duplicate rejection from the router and
879        //   spuriously fail a legitimate re-registration (common when
880        //   a plugin host re-advertises its command list).
881        // - Different channel claiming the same id: reject as duplicate.
882        let existing_owner = inner.remote.lock().get(&command_id).cloned();
883        match existing_owner {
884            Some(owner) if owner == channel_id => {
885                // Refresh the cached def and re-ack. No upstream traffic.
886                inner.remote_defs.lock().insert(command_id, command);
887                let _ = channel.send(Message::RegisterCommandResponse {
888                    id: MessageId::new_v4(),
889                    meta: None,
890                    thid: req_id,
891                    response: RegisterResult::Ok { ok: True },
892                });
893                return;
894            }
895            Some(_) => {
896                let _ = channel.send(Message::RegisterCommandResponse {
897                    id: MessageId::new_v4(),
898                    meta: None,
899                    thid: req_id,
900                    response: RegisterResult::Err {
901                        ok: False,
902                        error: RegisterErrorCode::DuplicateCommand,
903                    },
904                });
905                return;
906            }
907            None => {}
908        }
909
910        // Escalate upstream if we have a router.
911        if let Some(router_id) = inner.router_channel.clone() {
912            if router_id != channel_id {
913                let router_ch = inner.channels.lock().get(&router_id).cloned();
914                if let Some(router_ch) = router_ch {
915                    let up_id = MessageId::new_v4();
916                    let (tx, rx) = oneshot::channel();
917                    inner.register_replies.insert(
918                        up_id,
919                        PendingRegister {
920                            tx,
921                            target_channel: router_id,
922                        },
923                    );
924                    if router_ch
925                        .send(Message::RegisterCommandRequest {
926                            id: up_id,
927                            meta: None,
928                            command: command.clone(),
929                        })
930                        .is_ok()
931                    {
932                        let up = rx.await;
933                        match up {
934                            Ok(RegisterOutcome::Wire(RegisterResult::Ok { .. })) => {}
935                            Ok(RegisterOutcome::Wire(RegisterResult::Err { error, .. })) => {
936                                let _ = channel.send(Message::RegisterCommandResponse {
937                                    id: MessageId::new_v4(),
938                                    meta: None,
939                                    thid: req_id,
940                                    response: RegisterResult::Err { ok: False, error },
941                                });
942                                return;
943                            }
944                            // Timeout / disconnected upstream: we have no
945                            // wire-level error code for these on the
946                            // register response, so surface them as
947                            // duplicate_command to match the prior
948                            // behaviour. The escalating caller's own
949                            // register_command call will see the correct
950                            // Timeout / ChannelDisconnected via its own
951                            // pending entry.
952                            Ok(RegisterOutcome::Timeout)
953                            | Ok(RegisterOutcome::Disconnected)
954                            | Err(_) => {
955                                let _ = channel.send(Message::RegisterCommandResponse {
956                                    id: MessageId::new_v4(),
957                                    meta: None,
958                                    thid: req_id,
959                                    response: RegisterResult::Err {
960                                        ok: False,
961                                        error: RegisterErrorCode::DuplicateCommand,
962                                    },
963                                });
964                                return;
965                            }
966                        }
967                    }
968                }
969            }
970        }
971
972        inner.remote.lock().insert(command_id.clone(), channel_id);
973        inner.remote_defs.lock().insert(command_id, command);
974        let _ = channel.send(Message::RegisterCommandResponse {
975            id: MessageId::new_v4(),
976            meta: None,
977            thid: req_id,
978            response: RegisterResult::Ok { ok: True },
979        });
980    }
981
982    async fn handle_execute_request(
983        inner: Arc<Self>,
984        origin: Arc<dyn CommandChannel>,
985        req_id: MessageId,
986        command_id: String,
987        request: Value,
988    ) {
989        // Local handler?
990        let handler = inner
991            .local
992            .lock()
993            .get(&command_id)
994            .map(|e| e.handler.clone());
995        if let Some(handler) = handler {
996            let result = handler(request).await;
997            let response = match result {
998                Ok(v) => ExecuteResult::Ok {
999                    ok: True,
1000                    // Void responses (`() → Value::Null`) are elided from
1001                    // the wire per the response schema (`result` optional).
1002                    result: value_to_wire(v),
1003                },
1004                Err(error) => ExecuteResult::Err { ok: False, error },
1005            };
1006            let _ = origin.send(Message::ExecuteCommandResponse {
1007                id: MessageId::new_v4(),
1008                meta: None,
1009                thid: req_id,
1010                response,
1011            });
1012            return;
1013        }
1014
1015        // Forward?
1016        let target_id = inner
1017            .remote
1018            .lock()
1019            .get(&command_id)
1020            .cloned()
1021            .or_else(|| inner.router_channel.clone());
1022
1023        let origin_id = origin.id().to_string();
1024        let Some(target_id) = target_id else {
1025            let _ = origin.send(Message::ExecuteCommandResponse {
1026                id: MessageId::new_v4(),
1027                meta: None,
1028                thid: req_id,
1029                response: ExecuteResult::Err {
1030                    ok: False,
1031                    error: ExecuteError {
1032                        code: ExecuteErrorCode::NotFound,
1033                        message: format!("command not found: {command_id}"),
1034                    },
1035                },
1036            });
1037            return;
1038        };
1039
1040        if target_id == origin_id {
1041            // Would loop; treat as not found.
1042            let _ = origin.send(Message::ExecuteCommandResponse {
1043                id: MessageId::new_v4(),
1044                meta: None,
1045                thid: req_id,
1046                response: ExecuteResult::Err {
1047                    ok: False,
1048                    error: ExecuteError {
1049                        code: ExecuteErrorCode::NotFound,
1050                        message: format!("command not found: {command_id}"),
1051                    },
1052                },
1053            });
1054            return;
1055        }
1056
1057        let target = inner.channels.lock().get(&target_id).cloned();
1058        let Some(target) = target else {
1059            let _ = origin.send(Message::ExecuteCommandResponse {
1060                id: MessageId::new_v4(),
1061                meta: None,
1062                thid: req_id,
1063                response: ExecuteResult::Err {
1064                    ok: False,
1065                    error: ExecuteError {
1066                        code: ExecuteErrorCode::ChannelDisconnected,
1067                        message: "target channel disconnected".into(),
1068                    },
1069                },
1070            });
1071            return;
1072        };
1073
1074        inner.routes.insert(
1075            req_id,
1076            RouteEntry {
1077                origin_channel: origin_id,
1078                target_channel: target_id,
1079            },
1080        );
1081        let _ = target.send(Message::ExecuteCommandRequest {
1082            id: req_id,
1083            meta: None,
1084            command_id,
1085            request: value_to_wire(request),
1086        });
1087    }
1088
1089    fn handle_execute_response(inner: &Arc<Self>, thid: MessageId, response: ExecuteResult) {
1090        // Either this is a reply to a local call…
1091        if let Some(pending) = inner.execute_replies.remove(&thid) {
1092            let _ = pending.tx.send(response);
1093            return;
1094        }
1095
1096        // …or we forwarded this request and need to route the reply.
1097        if let Some(route) = inner.routes.remove(&thid) {
1098            let origin = inner.channels.lock().get(&route.origin_channel).cloned();
1099            if let Some(origin) = origin {
1100                let _ = origin.send(Message::ExecuteCommandResponse {
1101                    id: MessageId::new_v4(),
1102                    meta: None,
1103                    thid,
1104                    response,
1105                });
1106            }
1107        }
1108    }
1109
1110    fn handle_event(
1111        inner: &Arc<Self>,
1112        origin: Arc<dyn CommandChannel>,
1113        msg_id: MessageId,
1114        event_id: String,
1115        payload: Option<Value>,
1116    ) {
1117        if inner.seen_events.contains_key(&msg_id) {
1118            return;
1119        }
1120        inner.seen_events.insert(msg_id, ());
1121
1122        let payload_value = payload.clone().unwrap_or(Value::Null);
1123        let listeners: Vec<EventListener> = inner
1124            .event_listeners
1125            .lock()
1126            .get(&event_id)
1127            .map(|m| m.values().cloned().collect())
1128            .unwrap_or_default();
1129        for l in listeners {
1130            l(payload_value.clone());
1131        }
1132
1133        if event_id.starts_with('_') {
1134            return;
1135        }
1136
1137        let channels: Vec<Arc<dyn CommandChannel>> = inner
1138            .channels
1139            .lock()
1140            .iter()
1141            .filter(|(k, _)| k.as_str() != origin.id())
1142            .map(|(_, v)| v.clone())
1143            .collect();
1144        for ch in channels {
1145            let _ = ch.send(Message::Event {
1146                id: msg_id,
1147                meta: None,
1148                event_id: event_id.clone(),
1149                payload: payload.clone(),
1150            });
1151        }
1152    }
1153
1154    /// Invoked by the driver once the channel returns `None` from recv.
1155    fn handle_channel_close(inner: &Arc<Self>, channel_id: &str) {
1156        // Drop the channel from the lookup table.
1157        inner.channels.lock().remove(channel_id);
1158
1159        // Drop every remote command owned by this channel, along with
1160        // its cached definition.
1161        let dropped_ids: Vec<String> = {
1162            let mut remote = inner.remote.lock();
1163            let to_drop: Vec<String> = remote
1164                .iter()
1165                .filter(|(_, owner)| *owner == channel_id)
1166                .map(|(id, _)| id.clone())
1167                .collect();
1168            for id in &to_drop {
1169                remote.remove(id);
1170            }
1171            to_drop
1172        };
1173        let mut remote_defs = inner.remote_defs.lock();
1174        for id in dropped_ids {
1175            remote_defs.remove(&id);
1176        }
1177        drop(remote_defs);
1178
1179        // Reject any pending executes whose response was expected from
1180        // this channel.
1181        let exec_ids: Vec<MessageId> = inner
1182            .execute_replies
1183            .snapshot_keys_where(|v| v.target_channel == channel_id);
1184        for id in exec_ids {
1185            if let Some(pending) = inner.execute_replies.remove(&id) {
1186                let _ = pending.tx.send(ExecuteResult::Err {
1187                    ok: False,
1188                    error: ExecuteError {
1189                        code: ExecuteErrorCode::ChannelDisconnected,
1190                        message: "channel disconnected".into(),
1191                    },
1192                });
1193            }
1194        }
1195
1196        let reg_ids: Vec<MessageId> = inner
1197            .register_replies
1198            .snapshot_keys_where(|v| v.target_channel == channel_id);
1199        for id in reg_ids {
1200            if let Some(pending) = inner.register_replies.remove(&id) {
1201                let _ = pending.tx.send(RegisterOutcome::Disconnected);
1202            }
1203        }
1204
1205        // For every route where either endpoint is the dead channel,
1206        // notify the origin (if it is still alive).
1207        let route_ids: Vec<MessageId> = inner.routes.snapshot_keys_where(|r| {
1208            r.origin_channel == channel_id || r.target_channel == channel_id
1209        });
1210        for id in route_ids {
1211            if let Some(route) = inner.routes.remove(&id) {
1212                if route.origin_channel == channel_id {
1213                    continue;
1214                }
1215                let origin = inner.channels.lock().get(&route.origin_channel).cloned();
1216                if let Some(origin) = origin {
1217                    let _ = origin.send(Message::ExecuteCommandResponse {
1218                        id: MessageId::new_v4(),
1219                        meta: None,
1220                        thid: id,
1221                        response: ExecuteResult::Err {
1222                            ok: False,
1223                            error: ExecuteError {
1224                                code: ExecuteErrorCode::ChannelDisconnected,
1225                                message: "target channel disconnected".into(),
1226                            },
1227                        },
1228                    });
1229                }
1230            }
1231        }
1232    }
1233}
1234
1235// ------- helpers -----------------------------------------------------
1236
1237fn command_error_to_execute(e: &CommandError, command_id: &str) -> ExecuteError {
1238    match e {
1239        CommandError::InvalidRequest { message, .. } => ExecuteError {
1240            code: ExecuteErrorCode::InvalidRequest,
1241            message: message.clone(),
1242        },
1243        CommandError::Internal { message, .. } => ExecuteError {
1244            code: ExecuteErrorCode::InternalError,
1245            message: message.clone(),
1246        },
1247        CommandError::Timeout => ExecuteError {
1248            code: ExecuteErrorCode::Timeout,
1249            message: "request timed out".into(),
1250        },
1251        CommandError::ChannelDisconnected => ExecuteError {
1252            code: ExecuteErrorCode::ChannelDisconnected,
1253            message: "channel disconnected".into(),
1254        },
1255        CommandError::NotFound(id) => ExecuteError {
1256            code: ExecuteErrorCode::NotFound,
1257            message: format!("command not found: {id}"),
1258        },
1259        _ => ExecuteError {
1260            code: ExecuteErrorCode::InternalError,
1261            message: format!("{e} [command {command_id}]"),
1262        },
1263    }
1264}
1265
1266fn error_to_command_error(err: ExecuteError, command_id: &str) -> CommandError {
1267    match err.code {
1268        ExecuteErrorCode::NotFound => CommandError::NotFound(err.message),
1269        ExecuteErrorCode::InvalidRequest => CommandError::InvalidRequest {
1270            command_id: command_id.into(),
1271            message: err.message,
1272        },
1273        ExecuteErrorCode::InternalError => CommandError::Internal {
1274            command_id: command_id.into(),
1275            message: err.message,
1276        },
1277        ExecuteErrorCode::Timeout => CommandError::Timeout,
1278        ExecuteErrorCode::ChannelDisconnected => CommandError::ChannelDisconnected,
1279    }
1280}
1281
1282// Small convenience on ExecuteError.
1283impl ExecuteError {
1284    fn into_command_error(self, command_id: &str) -> CommandError {
1285        error_to_command_error(self, command_id)
1286    }
1287}
1288
1289/// Collapse a serialized request/result/payload value to `None` when it
1290/// is JSON `null`. This is what makes void commands and events
1291/// spec-compliant on the wire: `request` / `result` / `payload` are all
1292/// optional fields in the JSON schemas, so an absent value must be
1293/// encoded by omitting the key, not by emitting `null`.
1294///
1295/// Used on every outgoing `execute.command.request`,
1296/// `execute.command.response` success, and `event` message.
1297fn value_to_wire(v: Value) -> Option<Value> {
1298    if v.is_null() {
1299        None
1300    } else {
1301        Some(v)
1302    }
1303}
1304
1305/// Serialize a strict-mode request value to JSON. Wraps `serde_json`
1306/// with the right error type for the strict `execute::<C>` path.
1307fn value_from_request<T: Serialize>(v: &T) -> Result<Value, CommandError> {
1308    serde_json::to_value(v).map_err(CommandError::Serde)
1309}