Skip to main content

coralstack_cmd_ipc/
registry.rs

1//! The [`CommandRegistry`] — core routing and execution hub.
2//!
3//! Mirrors the TypeScript `CommandRegistry` in
4//! `packages/cmd-ipc/src/registry/command-registry.ts`: it owns a local
5//! command table, a remote command table (command → owning channel), a
6//! set of connected channels, and a handful of [`TtlMap`]s correlating
7//! in-flight requests, forwarded routes, and recently-seen events.
8//!
9//! # Topology
10//!
11//! * **Root** registries have no `router_channel`. Unknown commands
12//!   produce `NotFound` errors.
13//! * **Child** registries set `router_channel = Some(peer_id)`. Unknown
14//!   commands, and new local registrations, are escalated upstream.
15//! * Events fan out across every connected channel; dedup by message
16//!   id prevents echo loops in meshes.
17//!
18//! Private commands/events (identifiers starting with `_`) stay local
19//! — never escalated, never advertised, never broadcast.
20
21use std::collections::{BTreeMap, HashMap};
22use std::future::Future;
23use std::sync::atomic::{AtomicU64, Ordering};
24use std::sync::Arc;
25use std::time::Duration;
26
27use futures::channel::oneshot;
28use futures::future::BoxFuture;
29use futures::stream::FuturesUnordered;
30use futures::{FutureExt, StreamExt};
31use parking_lot::Mutex;
32use serde::Serialize;
33use serde_json::Value;
34use uuid::Uuid;
35
36use crate::channel::CommandChannel;
37use crate::command::Command;
38use crate::error::{ChannelError, CommandError, ExecuteErrorCode, RegisterErrorCode};
39use crate::event::Event;
40use crate::message::{
41    CommandDef, ExecuteError, ExecuteResult, False, Message, MessageId, RegisterResult, True,
42};
43use crate::ttl_map::TtlMap;
44
/// Configuration for a [`CommandRegistry`].
///
/// Every field has a default (see the `Default` impl), so the usual
/// construction is `Config { /* overrides */ ..Default::default() }`.
pub struct Config {
    /// Registry identifier used in log messages. Defaults to a random
    /// UUID.
    pub id: Option<String>,
    /// Channel id to escalate unknown commands and new registrations
    /// to. Leave `None` for a root registry.
    pub router_channel: Option<String>,
    /// How long a pending `execute` / `register` reply can wait before
    /// being rejected with [`CommandError::Timeout`]. Zero disables
    /// the TTL check (request hangs until the channel closes).
    /// Defaults to 30 seconds.
    pub request_ttl: Duration,
    /// How long a seen event id is remembered for dedup purposes.
    /// Defaults to 5 seconds.
    pub event_ttl: Duration,
    /// Maximum number of handler futures that may be in flight
    /// concurrently on a single channel's pump. Once this cap is
    /// reached, the pump stops pulling new messages off the channel
    /// (applying upstream backpressure) until an in-flight handler
    /// finishes. `0` disables the cap.
    ///
    /// **Optional in practice** — `Config::default()` sets this to
    /// `256`, so callers using `..Default::default()` never need to
    /// supply it. It is only *syntactically* required when you build
    /// `Config { … }` field-by-field (Rust struct literals must list
    /// every field). The cap exists so a misbehaving peer can't cause
    /// unbounded handler-future buildup; the default is fine for
    /// almost every workload.
    pub max_in_flight_per_channel: usize,
}
74
75impl Default for Config {
76    fn default() -> Self {
77        Self {
78            id: None,
79            router_channel: None,
80            request_ttl: Duration::from_secs(30),
81            event_ttl: Duration::from_secs(5),
82            max_in_flight_per_channel: 256,
83        }
84    }
85}
86
/// Type-erased command handler: takes the JSON request value and
/// returns a boxed future resolving to the JSON response or an
/// [`ExecuteError`].
type HandlerFn = dyn Fn(Value) -> BoxFuture<'static, Result<Value, ExecuteError>> + Send + Sync;
/// Type-erased event listener, invoked with a clone of the event's
/// JSON payload.
type EventListener = Arc<dyn Fn(Value) + Send + Sync>;
89
/// A command registered on this registry, as stored in the local
/// command table.
struct LocalEntry {
    /// Type-erased handler invoked to execute the command.
    handler: Arc<HandlerFn>,
    /// Definition (id, description, schema) reported for this command.
    def: CommandDef,
    /// `true` when the command id starts with `_`; private entries are
    /// filtered out of command listings.
    is_private: bool,
}
95
/// A forwarded execute request that is still waiting for its reply.
struct PendingExecute {
    /// Completes the `rx.await` of the caller that forwarded the
    /// request; also used to deliver the synthetic timeout result.
    tx: oneshot::Sender<ExecuteResult>,
    /// Id of the channel the request was sent on.
    target_channel: String,
}
100
/// Outcome delivered to the caller awaiting a `register.command.request`
/// reply. The peer's `RegisterResult` is wrapped so the registry can
/// synthesize timeout / disconnect errors distinguishable from the
/// wire-level duplicate-command case.
enum RegisterOutcome {
    /// The peer answered; carries its `RegisterResult` verbatim.
    Wire(RegisterResult),
    /// The pending entry expired before a reply arrived (sent by the
    /// TTL map's on-expire callback).
    Timeout,
    /// The request can no longer be answered; surfaced to the caller
    /// as `CommandError::ChannelDisconnected`.
    // NOTE(review): the producer of this variant is outside this view —
    // presumably the channel-close path; confirm.
    Disconnected,
}
110
/// A `register.command.request` sent to the router that is still
/// waiting for its reply.
struct PendingRegister {
    /// Completes the `rx.await` inside `register_inner`.
    tx: oneshot::Sender<RegisterOutcome>,
    /// Id of the channel the request was sent on (the router channel).
    target_channel: String,
}
115
/// Entry of the `routes` TTL map, keyed by message id.
// NOTE(review): the code that fills and reads this map is outside this
// view; presumably it records where a forwarded request came from and
// where it was sent so the reply can be relayed back — confirm.
struct RouteEntry {
    /// Channel id on the originating side of the route.
    origin_channel: String,
    /// Channel id on the destination side of the route.
    target_channel: String,
}
120
/// Shared state behind the [`CommandRegistry`] Arc.
struct Inner {
    /// Registry identifier (from `Config::id`, or a generated UUID).
    id: String,
    /// Id of the upstream router channel; `None` for a root registry.
    router_channel: Option<String>,
    /// command id -> locally registered handler and definition
    local: Mutex<HashMap<String, LocalEntry>>,
    /// command id -> owning channel id
    remote: Mutex<HashMap<String, String>>,
    /// command id -> advertised definition (description + schema)
    /// kept parallel to `remote` so `list_commands` can render
    /// the same richness for remote entries as for local ones.
    remote_defs: Mutex<HashMap<String, CommandDef>>,
    /// channel id -> connected channel
    channels: Mutex<HashMap<String, Arc<dyn CommandChannel>>>,
    /// request id -> caller waiting for an execute reply; expiry
    /// delivers a `Timeout` execute error to the waiter.
    execute_replies: TtlMap<MessageId, PendingExecute>,
    /// request id -> caller waiting for a register reply; expiry
    /// delivers `RegisterOutcome::Timeout` to the waiter.
    register_replies: TtlMap<MessageId, PendingRegister>,
    /// message id -> route entry; shares the request TTL.
    routes: TtlMap<MessageId, RouteEntry>,
    /// Ids of recently seen events, remembered for `event_ttl`; `emit`
    /// records every locally emitted event id here.
    seen_events: TtlMap<MessageId, ()>,
    /// Event listeners keyed by event id then by a monotonic token, so
    /// `add_event_listener` can return an unsubscribe closure that
    /// removes just that one listener. `BTreeMap` preserves insertion
    /// order (tokens are monotonically increasing) for dispatch.
    event_listeners: Mutex<HashMap<String, BTreeMap<u64, EventListener>>>,
    /// Monotonically-increasing token used to key event listeners.
    next_listener_token: AtomicU64,
    /// Per-channel in-flight cap, copied from [`Config`].
    max_in_flight_per_channel: usize,
}
147
/// The main entry point of the crate.
///
/// A registry is cheap to clone: internally it's an `Arc<Inner>`.
/// All clones share the same command tables, channels and listeners.
#[derive(Clone)]
pub struct CommandRegistry {
    /// Shared state; cloning the registry only bumps this refcount.
    inner: Arc<Inner>,
}
155
156impl CommandRegistry {
157    pub fn new(cfg: Config) -> Self {
158        // On TTL expiry, deliver an explicit Timeout outcome on the
159        // pending oneshot so the caller's `rx.await` resolves to
160        // `CommandError::Timeout` rather than blocking forever (or
161        // collapsing to a generic `ChannelDisconnected` when the tx is
162        // dropped). Eviction is lazy — driven by `sweep_expired()`
163        // calls in the driver loop and registry hot paths.
164        let execute_replies =
165            TtlMap::new(cfg.request_ttl).with_on_expire(|_, pending: PendingExecute| {
166                let _ = pending.tx.send(ExecuteResult::Err {
167                    ok: False,
168                    error: ExecuteError {
169                        code: ExecuteErrorCode::Timeout,
170                        message: "request timed out".into(),
171                    },
172                });
173            });
174        let register_replies =
175            TtlMap::new(cfg.request_ttl).with_on_expire(|_, pending: PendingRegister| {
176                let _ = pending.tx.send(RegisterOutcome::Timeout);
177            });
178
179        let inner = Arc::new(Inner {
180            id: cfg.id.unwrap_or_else(|| Uuid::new_v4().to_string()),
181            router_channel: cfg.router_channel,
182            local: Mutex::new(HashMap::new()),
183            remote: Mutex::new(HashMap::new()),
184            remote_defs: Mutex::new(HashMap::new()),
185            channels: Mutex::new(HashMap::new()),
186            execute_replies,
187            register_replies,
188            routes: TtlMap::new(cfg.request_ttl),
189            seen_events: TtlMap::new(cfg.event_ttl),
190            event_listeners: Mutex::new(HashMap::new()),
191            next_listener_token: AtomicU64::new(0),
192            max_in_flight_per_channel: cfg.max_in_flight_per_channel,
193        });
194        Self { inner }
195    }
196
197    /// Returns this registry's identifier.
198    pub fn id(&self) -> &str {
199        &self.inner.id
200    }
201
202    /// Returns the ids of every currently-registered channel, sorted.
203    ///
204    /// Mirrors the TypeScript library's `listChannels()` method.
205    pub fn list_channels(&self) -> Vec<String> {
206        let mut ids: Vec<String> = self.inner.channels.lock().keys().cloned().collect();
207        ids.sort();
208        ids
209    }
210
211    /// Returns the full [`CommandDef`] (id + description + schema) for
212    /// every reachable command — local (non-private) and remote. Remote
213    /// defs are those advertised via `register.command.request` or
214    /// `list.commands.response` on the channel.
215    ///
216    /// Mirrors the TypeScript library's `listCommands()` method.
217    /// Results are sorted by id. A command id is only included once even
218    /// if both a local and remote entry exist (local wins).
219    pub fn list_commands(&self) -> Vec<CommandDef> {
220        let mut out: HashMap<String, CommandDef> = HashMap::new();
221        for (id, entry) in self.inner.local.lock().iter() {
222            if !entry.is_private {
223                out.insert(id.clone(), entry.def.clone());
224            }
225        }
226        for (id, def) in self.inner.remote_defs.lock().iter() {
227            out.entry(id.clone()).or_insert_with(|| def.clone());
228        }
229        let mut v: Vec<CommandDef> = out.into_values().collect();
230        v.sort_by(|a, b| a.id.cmp(&b.id));
231        v
232    }
233
234    /// Register a command on this registry.
235    ///
236    /// The single registration entry point, covering both compile-time
237    /// and runtime commands:
238    ///
239    /// - **Compile-time**: pass an instance of a type that implements
240    ///   [`Command`]. The `#[command]` / `#[command_service]` macros
241    ///   generate such types from a plain `async fn`.
242    /// - **Runtime**: pass a [`DynCommand`] carrying owned id /
243    ///   description / schema and a closure handler.
244    ///
245    /// Mirrors the TypeScript library's `registerCommand`.
246    ///
247    /// - Commands whose id starts with `_` stay local: they are never
248    ///   escalated to a `router_channel` and never advertised to peers
249    ///   via `list.commands.response`.
250    /// - Non-private commands are escalated upstream if this registry
251    ///   has a `router_channel`; the local entry is only committed
252    ///   after the router acks.
253    /// - The advertised schema is normalized via
254    ///   [`crate::schema::normalize_schema`] on the way in, so every
255    ///   schema leaving the registry is language-agnostic JSON Schema
256    ///   regardless of how the caller built it.
257    pub async fn register_command<C: Command>(&self, cmd: C) -> Result<(), CommandError> {
258        let id = cmd.id().to_string();
259        let description = cmd.description().map(str::to_string);
260        let schema = cmd.schema().map(crate::schema::normalize_command_schema);
261        let is_private = id.starts_with('_');
262        let def = CommandDef {
263            id: id.clone(),
264            description,
265            schema,
266        };
267        let handler: Arc<HandlerFn> = Arc::new({
268            let cmd = Arc::new(cmd);
269            move |value: Value| {
270                let cmd = cmd.clone();
271                async move {
272                    let req: C::Request =
273                        serde_json::from_value(value).map_err(|e| ExecuteError {
274                            code: ExecuteErrorCode::InvalidRequest,
275                            message: e.to_string(),
276                        })?;
277                    let res = cmd
278                        .handle(req)
279                        .await
280                        .map_err(|e| command_error_to_execute(&e, cmd.id()))?;
281                    serde_json::to_value(res).map_err(|e| ExecuteError {
282                        code: ExecuteErrorCode::InternalError,
283                        message: e.to_string(),
284                    })
285                }
286                .boxed()
287            }
288        });
289        self.register_inner(id, handler, def, is_private).await
290    }
291
292    async fn register_inner(
293        &self,
294        id: String,
295        handler: Arc<HandlerFn>,
296        def: CommandDef,
297        is_private: bool,
298    ) -> Result<(), CommandError> {
299        self.inner.execute_replies.sweep_expired();
300        self.inner.register_replies.sweep_expired();
301        // Duplicate check against the local table.
302        if self.inner.local.lock().contains_key(&id) {
303            return Err(CommandError::DuplicateCommand(id));
304        }
305
306        // Non-private commands escalate to the router before being added.
307        if !is_private {
308            if let Some(router_id) = self.inner.router_channel.clone() {
309                let router_ch = self.inner.channels.lock().get(&router_id).cloned();
310                if let Some(router_ch) = router_ch {
311                    let req_id = MessageId::new_v4();
312                    let (tx, rx) = oneshot::channel();
313                    self.inner.register_replies.insert(
314                        req_id,
315                        PendingRegister {
316                            tx,
317                            target_channel: router_id.clone(),
318                        },
319                    );
320                    router_ch
321                        .send(Message::RegisterCommandRequest {
322                            id: req_id,
323                            command: def.clone(),
324                        })
325                        .map_err(|_| CommandError::ChannelDisconnected)?;
326                    match rx.await {
327                        Ok(RegisterOutcome::Wire(RegisterResult::Ok { .. })) => {}
328                        Ok(RegisterOutcome::Wire(RegisterResult::Err { error, .. })) => {
329                            return Err(match error {
330                                RegisterErrorCode::DuplicateCommand => {
331                                    CommandError::DuplicateCommand(id)
332                                }
333                            });
334                        }
335                        Ok(RegisterOutcome::Timeout) => return Err(CommandError::Timeout),
336                        Ok(RegisterOutcome::Disconnected) | Err(_) => {
337                            return Err(CommandError::ChannelDisconnected);
338                        }
339                    }
340                }
341            }
342        }
343
344        self.inner.local.lock().insert(
345            id,
346            LocalEntry {
347                handler,
348                def,
349                is_private,
350            },
351        );
352        Ok(())
353    }
354
355    /// Connects a [`CommandChannel`] to this registry.
356    ///
357    /// Returns a driver future which must be polled by the caller's
358    /// executor (via `tokio::spawn`, `smol::spawn`,
359    /// `futures::executor::block_on`, …) for the registry to exchange
360    /// messages with the peer. The future completes when the channel
361    /// closes.
362    ///
363    /// Handler dispatch is **concurrent within the driver task**: the
364    /// pump pushes each incoming message's handler future into a
365    /// [`FuturesUnordered`] and cooperatively interleaves them with
366    /// the next `recv`, so a slow handler that awaits external work
367    /// no longer blocks subsequent messages on the same channel. The
368    /// number of simultaneously in-flight handlers is capped by
369    /// [`Config::max_in_flight_per_channel`] (default 256); at the
370    /// cap the pump applies backpressure to the channel rather than
371    /// dropping messages. For true multi-thread parallelism, wrap
372    /// the handler body in your runtime's `spawn` (e.g.
373    /// `tokio::spawn`) — the crate itself stays runtime-agnostic.
374    pub async fn register_channel(
375        &self,
376        channel: Arc<dyn CommandChannel>,
377    ) -> Result<impl Future<Output = ()> + Send + 'static, ChannelError> {
378        let id = channel.id().to_string();
379        {
380            let mut chans = self.inner.channels.lock();
381            if chans.contains_key(&id) {
382                return Err(ChannelError::Other(format!(
383                    "channel with id `{id}` already registered"
384                )));
385            }
386            chans.insert(id.clone(), channel.clone());
387        }
388
389        channel.start().await?;
390
391        // Ask the peer for its command list. The response is handled
392        // by the driver loop, which will register each entry as remote.
393        if let Err(e) = channel.send(Message::ListCommandsRequest {
394            id: MessageId::new_v4(),
395        }) {
396            self.inner.channels.lock().remove(&id);
397            return Err(e);
398        }
399
400        let inner = self.inner.clone();
401        let ch = channel;
402        Ok(async move {
403            // In-flight handler futures cooperatively interleaved with
404            // message recv. This is what makes the pump non-blocking:
405            // when one handler awaits (sleep / network / forwarded
406            // call), the executor polls other handlers and the recv
407            // future on the same task, so a slow handler can no
408            // longer wedge every subsequent message behind it.
409            //
410            // Stays runtime-agnostic (no `tokio::spawn`); the user's
411            // executor still owns the single task driving this pump.
412            // True multi-thread parallelism is possible if the user
413            // wraps their handler body in `tokio::spawn` themselves.
414            let mut in_flight: FuturesUnordered<BoxFuture<'static, ()>> = FuturesUnordered::new();
415            let cap = inner.max_in_flight_per_channel;
416
417            loop {
418                // Backpressure: while at capacity, drain in-flight
419                // handlers instead of accepting new messages. Ordering
420                // of message *reception* is preserved — only dispatch
421                // fans out. Messages are never dropped.
422                while cap > 0 && in_flight.len() >= cap {
423                    if in_flight.next().await.is_none() {
424                        break;
425                    }
426                }
427
428                // Opportunistic TTL sweep: any registry activity
429                // triggers a pass over the pending-reply / route maps
430                // so timed-out entries fire their on_expire callbacks
431                // (delivering Timeout to waiting callers). Cheap —
432                // O(n) over maps that are normally tiny.
433                inner.execute_replies.sweep_expired();
434                inner.register_replies.sweep_expired();
435                inner.routes.sweep_expired();
436
437                // Wait for either the next message OR for an in-flight
438                // handler to make progress. When `in_flight` is empty
439                // we'd otherwise busy-loop on its `.next()` returning
440                // `None`, so park on `pending()` in that case.
441                let next_msg = if in_flight.is_empty() {
442                    ch.recv().await
443                } else {
444                    futures::select_biased! {
445                        // Drain finished handlers so the in_flight set
446                        // shrinks promptly. `select_next_some` skips
447                        // the empty case, but we've already guarded
448                        // against that above.
449                        _ = in_flight.select_next_some() => continue,
450                        msg = ch.recv().fuse() => msg,
451                    }
452                };
453
454                let Some(msg) = next_msg else { break };
455                in_flight.push(Inner::handle_message(inner.clone(), ch.clone(), msg).boxed());
456            }
457
458            // Channel closed — drain any in-flight handlers so their
459            // outgoing responses are sent before we tear the channel
460            // state down. Handlers that try to `send` on the now-closed
461            // channel will get `Err(Closed)` from the transport, which
462            // is the existing semantics.
463            while in_flight.next().await.is_some() {}
464            Inner::handle_channel_close(&inner, ch.id());
465        })
466    }
467
468    /// Executes a command identified by a compile-time [`Command`] type
469    /// — the **strict** form, giving the same compile-time type safety
470    /// that TypeScript's strict-mode `executeCommand<K>` gives via the
471    /// `CommandSchemaMap` type parameter.
472    ///
473    /// The command id comes from `C::ID`, the request type is pinned to
474    /// `C::Request`, and the response type is pinned to `C::Response`,
475    /// so the compiler rejects mismatches at the call site:
476    ///
477    /// ```ignore
478    /// let sum: i64 = registry.execute::<MathAdd>(AddReq { a: 2, b: 3 }).await?;
479    /// ```
480    ///
481    /// For commands whose id or payload shape is only known at runtime
482    /// (scripting hosts, FFI, plugins that advertise their own schema),
483    /// use [`execute_dyn`](Self::execute_dyn).
484    pub async fn execute<C: Command>(
485        &self,
486        request: C::Request,
487    ) -> Result<C::Response, CommandError>
488    where
489        C::Request: Serialize,
490        C::Response: serde::de::DeserializeOwned,
491    {
492        let req_value = value_from_request(&request)?;
493        let result = self.execute_raw_impl(C::ID.to_string(), req_value).await?;
494        let deserialized = serde_json::from_value(result.unwrap_or(Value::Null))?;
495        Ok(deserialized)
496    }
497
498    /// Executes a command whose id is only known at runtime — the
499    /// **loose** form, mirroring the TypeScript library's
500    /// `executeCommand(id, args)` in loose mode.
501    ///
502    /// Request and response are raw [`serde_json::Value`]s, so this is
503    /// the canonical entry point for plugin hosts, scripting runtimes,
504    /// FFI bridges, and any code where the schema is discovered via
505    /// [`list_commands`](Self::list_commands) rather than declared at
506    /// compile time.
507    ///
508    /// For statically-known commands, prefer [`execute`](Self::execute)
509    /// — it pins both types via the [`Command`] trait.
510    pub async fn execute_dyn(
511        &self,
512        command_id: &str,
513        request: Value,
514    ) -> Result<Value, CommandError> {
515        let result = self
516            .execute_raw_impl(command_id.to_string(), request)
517            .await?;
518        Ok(result.unwrap_or(Value::Null))
519    }
520
521    async fn execute_raw_impl(
522        &self,
523        command_id: String,
524        request: Value,
525    ) -> Result<Option<Value>, CommandError> {
526        // Flush any expired pending replies so stale entries fire their
527        // Timeout on_expire before we enqueue a new one.
528        self.inner.execute_replies.sweep_expired();
529        self.inner.register_replies.sweep_expired();
530        // 1) Local handler wins.
531        let local_handler = self
532            .inner
533            .local
534            .lock()
535            .get(&command_id)
536            .map(|entry| entry.handler.clone());
537        if let Some(handler) = local_handler {
538            return handler(request)
539                .await
540                .map(Some)
541                .map_err(|e| e.into_command_error(&command_id));
542        }
543
544        // 2) Known remote command.
545        let remote_target = self.inner.remote.lock().get(&command_id).cloned();
546        let target = match remote_target {
547            Some(t) => Some(t),
548            None => self.inner.router_channel.clone(),
549        };
550
551        let Some(target_id) = target else {
552            return Err(CommandError::NotFound(command_id));
553        };
554
555        let channel = self.inner.channels.lock().get(&target_id).cloned();
556        let Some(channel) = channel else {
557            return Err(CommandError::ChannelDisconnected);
558        };
559
560        self.forward_execute(command_id, request, &channel, target_id)
561            .await
562    }
563
564    async fn forward_execute(
565        &self,
566        command_id: String,
567        request: Value,
568        channel: &Arc<dyn CommandChannel>,
569        target_id: String,
570    ) -> Result<Option<Value>, CommandError> {
571        let req_id = MessageId::new_v4();
572        let (tx, rx) = oneshot::channel();
573        self.inner.execute_replies.insert(
574            req_id,
575            PendingExecute {
576                tx,
577                target_channel: target_id,
578            },
579        );
580        channel
581            .send(Message::ExecuteCommandRequest {
582                id: req_id,
583                command_id: command_id.clone(),
584                // Void requests are elided from the wire (Null → None)
585                // so peers expecting an absent `request` field (per the
586                // JSON Schema spec) don't see `"request": null`.
587                request: value_to_wire(request),
588            })
589            .map_err(|_| CommandError::ChannelDisconnected)?;
590
591        match rx.await {
592            Ok(ExecuteResult::Ok { result, .. }) => Ok(result),
593            Ok(ExecuteResult::Err { error, .. }) => Err(error_to_command_error(error, &command_id)),
594            Err(_) => {
595                self.inner.execute_replies.remove(&req_id);
596                Err(CommandError::ChannelDisconnected)
597            }
598        }
599    }
600
601    /// Emit an event. Dispatches to local listeners and — unless the
602    /// event id is private (starts with `_`) — broadcasts to every
603    /// connected channel.
604    ///
605    /// Works for both compile-time events (`#[event]`-annotated
606    /// structs) and runtime events ([`DynEvent`](crate::command::Command)
607    /// — actually [`DynEvent`](crate::event::DynEvent)). Id, description,
608    /// and schema are all read off the event instance.
609    pub fn emit<E: Event>(&self, event: E) -> Result<(), CommandError> {
610        let event_id = event.id().to_string();
611        let payload_value = serde_json::to_value(&event)?;
612        let msg_id = MessageId::new_v4();
613        self.inner.seen_events.insert(msg_id, ());
614
615        self.dispatch_event_locally(&event_id, &payload_value);
616
617        if !event_id.starts_with('_') {
618            let channels: Vec<Arc<dyn CommandChannel>> =
619                self.inner.channels.lock().values().cloned().collect();
620            // Void payloads (serde `()` → `Value::Null`) are elided
621            // from the wire per the event schema (`payload` is optional).
622            let wire_payload = value_to_wire(payload_value);
623            for ch in channels {
624                let _ = ch.send(Message::Event {
625                    id: msg_id,
626                    event_id: event_id.clone(),
627                    payload: wire_payload.clone(),
628                });
629            }
630        }
631        Ok(())
632    }
633
634    /// Subscribe a typed listener. The callback receives a
635    /// deserialized `E` every time an event with id `E::ID` fires,
636    /// whether emitted locally or received from a connected channel.
637    ///
638    /// Returns an unsubscribe closure — call it (and drop it) to
639    /// remove just this listener. Ignoring the return value is fine;
640    /// the listener then lives for the life of the registry.
641    ///
642    /// Listeners for the same event fire in insertion order. Payloads
643    /// that fail to deserialize into `E` are silently dropped for
644    /// this listener — they still flow to any typed-for-Value
645    /// listeners registered via [`on_dyn`](Self::on_dyn).
646    pub fn on<E: Event + serde::de::DeserializeOwned>(
647        &self,
648        listener: impl Fn(E) + Send + Sync + 'static,
649    ) -> impl FnOnce() + Send + Sync + 'static {
650        self.install_listener(E::ID, move |value| {
651            if let Ok(typed) = serde_json::from_value::<E>(value) {
652                listener(typed);
653            }
654        })
655    }
656
657    /// Subscribe a dynamic listener by runtime id. The callback
658    /// receives the raw JSON payload. Use this when the event id is
659    /// only known at runtime (plugin runtimes, FFI, scripting hosts);
660    /// prefer [`on`](Self::on) whenever you have a compile-time
661    /// [`Event`] type.
662    ///
663    /// Same unsubscribe semantics as [`on`](Self::on).
664    pub fn on_dyn<F>(
665        &self,
666        event_id: impl Into<String>,
667        listener: F,
668    ) -> impl FnOnce() + Send + Sync + 'static
669    where
670        F: Fn(Value) + Send + Sync + 'static,
671    {
672        self.install_listener(&event_id.into(), listener)
673    }
674
675    fn install_listener<F>(
676        &self,
677        event_id: &str,
678        listener: F,
679    ) -> impl FnOnce() + Send + Sync + 'static
680    where
681        F: Fn(Value) + Send + Sync + 'static,
682    {
683        let token = self
684            .inner
685            .next_listener_token
686            .fetch_add(1, Ordering::Relaxed);
687        self.inner
688            .event_listeners
689            .lock()
690            .entry(event_id.to_string())
691            .or_default()
692            .insert(token, Arc::new(listener));
693
694        let inner = Arc::clone(&self.inner);
695        let event_id = event_id.to_string();
696        move || {
697            let mut map = inner.event_listeners.lock();
698            if let Some(slot) = map.get_mut(&event_id) {
699                slot.remove(&token);
700                if slot.is_empty() {
701                    map.remove(&event_id);
702                }
703            }
704        }
705    }
706
707    fn dispatch_event_locally(&self, event_id: &str, payload: &Value) {
708        let listeners: Vec<EventListener> = self
709            .inner
710            .event_listeners
711            .lock()
712            .get(event_id)
713            .map(|m| m.values().cloned().collect())
714            .unwrap_or_default();
715        for l in listeners {
716            l(payload.clone());
717        }
718    }
719
720    /// Tears down the registry: awaits `close()` on every connected
721    /// channel, drops every local and remote command, and clears all
722    /// event listeners. In-flight executes and register requests fail
723    /// with [`CommandError::ChannelDisconnected`] via the existing
724    /// channel-close path.
725    ///
726    /// Mirrors the TypeScript library's `dispose()`, but async so that
727    /// transports doing real teardown work (HTTP flush, MCP goodbye,
728    /// plugin sandbox shutdown) complete before this returns.
729    ///
730    /// Callers normally don't need this — dropping the last
731    /// `CommandRegistry` clone releases the inner state automatically
732    /// via `Drop`. Use `dispose` when a *shared* registry (held through
733    /// multiple clones) needs to be forcibly torn down, or in tests.
734    pub async fn dispose(&self) {
735        // Snapshot channel arcs so we can call `close` without holding
736        // the channels lock for the duration.
737        let channels: Vec<Arc<dyn CommandChannel>> = {
738            let mut locked = self.inner.channels.lock();
739            let out: Vec<_> = locked.values().cloned().collect();
740            locked.clear();
741            out
742        };
743
744        // Await each channel's close sequentially. Channels define
745        // their own close semantics (InMemoryChannel is effectively
746        // synchronous; MCPServerChannel flushes the transport; Flow's
747        // SourceChannel tears down its QuickJS VM), so we let every
748        // implementation finish its teardown before returning.
749        for ch in channels {
750            ch.close().await;
751        }
752
753        self.inner.local.lock().clear();
754        self.inner.remote.lock().clear();
755        self.inner.remote_defs.lock().clear();
756        self.inner.event_listeners.lock().clear();
757    }
758}
759
760impl Inner {
761    fn local_command_defs(&self) -> Vec<CommandDef> {
762        self.local
763            .lock()
764            .values()
765            .filter(|e| !e.is_private)
766            .map(|e| e.def.clone())
767            .collect()
768    }
769
770    /// Central dispatcher invoked by each channel's driver loop.
771    async fn handle_message(inner: Arc<Self>, channel: Arc<dyn CommandChannel>, msg: Message) {
772        match msg {
773            Message::RegisterCommandRequest { id, command } => {
774                Self::handle_register_request(inner, channel, id, command).await;
775            }
776            Message::RegisterCommandResponse { thid, response, .. } => {
777                if let Some(pending) = inner.register_replies.remove(&thid) {
778                    let _ = pending.tx.send(RegisterOutcome::Wire(response));
779                }
780            }
781            Message::ListCommandsRequest { id } => {
782                let commands = inner.local_command_defs();
783                let _ = channel.send(Message::ListCommandsResponse {
784                    id: MessageId::new_v4(),
785                    thid: id,
786                    commands,
787                });
788            }
789            Message::ListCommandsResponse { commands, .. } => {
790                let channel_id = channel.id().to_string();
791                let mut remote = inner.remote.lock();
792                let mut remote_defs = inner.remote_defs.lock();
793                for cmd in commands {
794                    // Normalize ingested schemas so the local cache
795                    // matches what register() produces.
796                    let cmd = CommandDef {
797                        id: cmd.id,
798                        description: cmd.description,
799                        schema: cmd.schema.map(crate::schema::normalize_command_schema),
800                    };
801                    let entry_is_new = !remote.contains_key(&cmd.id);
802                    if entry_is_new {
803                        remote.insert(cmd.id.clone(), channel_id.clone());
804                    }
805                    // Always refresh the def (the latest advertisement wins).
806                    remote_defs.insert(cmd.id.clone(), cmd);
807                }
808            }
809            Message::ExecuteCommandRequest {
810                id,
811                command_id,
812                request,
813            } => {
814                Self::handle_execute_request(
815                    inner,
816                    channel,
817                    id,
818                    command_id,
819                    request.unwrap_or(Value::Null),
820                )
821                .await;
822            }
823            Message::ExecuteCommandResponse { thid, response, .. } => {
824                Self::handle_execute_response(&inner, thid, response);
825            }
826            Message::Event {
827                id,
828                event_id,
829                payload,
830            } => {
831                Self::handle_event(&inner, channel, id, event_id, payload);
832            }
833        }
834    }
835
836    async fn handle_register_request(
837        inner: Arc<Self>,
838        channel: Arc<dyn CommandChannel>,
839        req_id: MessageId,
840        command: CommandDef,
841    ) {
842        // Normalize ingested schemas so our cached copy is guaranteed
843        // to be language-agnostic JSON Schema, even if the peer didn't
844        // normalize on its side.
845        let command = CommandDef {
846            id: command.id,
847            description: command.description,
848            schema: command.schema.map(crate::schema::normalize_command_schema),
849        };
850        let channel_id = channel.id().to_string();
851        let command_id = command.id.clone();
852
853        // Duplicate against local?
854        let dup = inner.local.lock().contains_key(&command_id);
855        if dup {
856            let _ = channel.send(Message::RegisterCommandResponse {
857                id: MessageId::new_v4(),
858                thid: req_id,
859                response: RegisterResult::Err {
860                    ok: False,
861                    error: RegisterErrorCode::DuplicateCommand,
862                },
863            });
864            return;
865        }
866
867        // Already known in the remote table?
868        // - Same channel re-advertising: short-circuit with a success
869        //   ack, do NOT re-escalate upstream. A re-escalation would
870        //   bubble up as a duplicate rejection from the router and
871        //   spuriously fail a legitimate re-registration (common when
872        //   a plugin host re-advertises its command list).
873        // - Different channel claiming the same id: reject as duplicate.
874        let existing_owner = inner.remote.lock().get(&command_id).cloned();
875        match existing_owner {
876            Some(owner) if owner == channel_id => {
877                // Refresh the cached def and re-ack. No upstream traffic.
878                inner.remote_defs.lock().insert(command_id, command);
879                let _ = channel.send(Message::RegisterCommandResponse {
880                    id: MessageId::new_v4(),
881                    thid: req_id,
882                    response: RegisterResult::Ok { ok: True },
883                });
884                return;
885            }
886            Some(_) => {
887                let _ = channel.send(Message::RegisterCommandResponse {
888                    id: MessageId::new_v4(),
889                    thid: req_id,
890                    response: RegisterResult::Err {
891                        ok: False,
892                        error: RegisterErrorCode::DuplicateCommand,
893                    },
894                });
895                return;
896            }
897            None => {}
898        }
899
900        // Escalate upstream if we have a router.
901        if let Some(router_id) = inner.router_channel.clone() {
902            if router_id != channel_id {
903                let router_ch = inner.channels.lock().get(&router_id).cloned();
904                if let Some(router_ch) = router_ch {
905                    let up_id = MessageId::new_v4();
906                    let (tx, rx) = oneshot::channel();
907                    inner.register_replies.insert(
908                        up_id,
909                        PendingRegister {
910                            tx,
911                            target_channel: router_id,
912                        },
913                    );
914                    if router_ch
915                        .send(Message::RegisterCommandRequest {
916                            id: up_id,
917                            command: command.clone(),
918                        })
919                        .is_ok()
920                    {
921                        let up = rx.await;
922                        match up {
923                            Ok(RegisterOutcome::Wire(RegisterResult::Ok { .. })) => {}
924                            Ok(RegisterOutcome::Wire(RegisterResult::Err { error, .. })) => {
925                                let _ = channel.send(Message::RegisterCommandResponse {
926                                    id: MessageId::new_v4(),
927                                    thid: req_id,
928                                    response: RegisterResult::Err { ok: False, error },
929                                });
930                                return;
931                            }
932                            // Timeout / disconnected upstream: we have no
933                            // wire-level error code for these on the
934                            // register response, so surface them as
935                            // duplicate_command to match the prior
936                            // behaviour. The escalating caller's own
937                            // register_command call will see the correct
938                            // Timeout / ChannelDisconnected via its own
939                            // pending entry.
940                            Ok(RegisterOutcome::Timeout)
941                            | Ok(RegisterOutcome::Disconnected)
942                            | Err(_) => {
943                                let _ = channel.send(Message::RegisterCommandResponse {
944                                    id: MessageId::new_v4(),
945                                    thid: req_id,
946                                    response: RegisterResult::Err {
947                                        ok: False,
948                                        error: RegisterErrorCode::DuplicateCommand,
949                                    },
950                                });
951                                return;
952                            }
953                        }
954                    }
955                }
956            }
957        }
958
959        inner.remote.lock().insert(command_id.clone(), channel_id);
960        inner.remote_defs.lock().insert(command_id, command);
961        let _ = channel.send(Message::RegisterCommandResponse {
962            id: MessageId::new_v4(),
963            thid: req_id,
964            response: RegisterResult::Ok { ok: True },
965        });
966    }
967
968    async fn handle_execute_request(
969        inner: Arc<Self>,
970        origin: Arc<dyn CommandChannel>,
971        req_id: MessageId,
972        command_id: String,
973        request: Value,
974    ) {
975        // Local handler?
976        let handler = inner
977            .local
978            .lock()
979            .get(&command_id)
980            .map(|e| e.handler.clone());
981        if let Some(handler) = handler {
982            let result = handler(request).await;
983            let response = match result {
984                Ok(v) => ExecuteResult::Ok {
985                    ok: True,
986                    // Void responses (`() → Value::Null`) are elided from
987                    // the wire per the response schema (`result` optional).
988                    result: value_to_wire(v),
989                },
990                Err(error) => ExecuteResult::Err { ok: False, error },
991            };
992            let _ = origin.send(Message::ExecuteCommandResponse {
993                id: MessageId::new_v4(),
994                thid: req_id,
995                response,
996            });
997            return;
998        }
999
1000        // Forward?
1001        let target_id = inner
1002            .remote
1003            .lock()
1004            .get(&command_id)
1005            .cloned()
1006            .or_else(|| inner.router_channel.clone());
1007
1008        let origin_id = origin.id().to_string();
1009        let Some(target_id) = target_id else {
1010            let _ = origin.send(Message::ExecuteCommandResponse {
1011                id: MessageId::new_v4(),
1012                thid: req_id,
1013                response: ExecuteResult::Err {
1014                    ok: False,
1015                    error: ExecuteError {
1016                        code: ExecuteErrorCode::NotFound,
1017                        message: format!("command not found: {command_id}"),
1018                    },
1019                },
1020            });
1021            return;
1022        };
1023
1024        if target_id == origin_id {
1025            // Would loop; treat as not found.
1026            let _ = origin.send(Message::ExecuteCommandResponse {
1027                id: MessageId::new_v4(),
1028                thid: req_id,
1029                response: ExecuteResult::Err {
1030                    ok: False,
1031                    error: ExecuteError {
1032                        code: ExecuteErrorCode::NotFound,
1033                        message: format!("command not found: {command_id}"),
1034                    },
1035                },
1036            });
1037            return;
1038        }
1039
1040        let target = inner.channels.lock().get(&target_id).cloned();
1041        let Some(target) = target else {
1042            let _ = origin.send(Message::ExecuteCommandResponse {
1043                id: MessageId::new_v4(),
1044                thid: req_id,
1045                response: ExecuteResult::Err {
1046                    ok: False,
1047                    error: ExecuteError {
1048                        code: ExecuteErrorCode::ChannelDisconnected,
1049                        message: "target channel disconnected".into(),
1050                    },
1051                },
1052            });
1053            return;
1054        };
1055
1056        inner.routes.insert(
1057            req_id,
1058            RouteEntry {
1059                origin_channel: origin_id,
1060                target_channel: target_id,
1061            },
1062        );
1063        let _ = target.send(Message::ExecuteCommandRequest {
1064            id: req_id,
1065            command_id,
1066            request: value_to_wire(request),
1067        });
1068    }
1069
1070    fn handle_execute_response(inner: &Arc<Self>, thid: MessageId, response: ExecuteResult) {
1071        // Either this is a reply to a local call…
1072        if let Some(pending) = inner.execute_replies.remove(&thid) {
1073            let _ = pending.tx.send(response);
1074            return;
1075        }
1076
1077        // …or we forwarded this request and need to route the reply.
1078        if let Some(route) = inner.routes.remove(&thid) {
1079            let origin = inner.channels.lock().get(&route.origin_channel).cloned();
1080            if let Some(origin) = origin {
1081                let _ = origin.send(Message::ExecuteCommandResponse {
1082                    id: MessageId::new_v4(),
1083                    thid,
1084                    response,
1085                });
1086            }
1087        }
1088    }
1089
1090    fn handle_event(
1091        inner: &Arc<Self>,
1092        origin: Arc<dyn CommandChannel>,
1093        msg_id: MessageId,
1094        event_id: String,
1095        payload: Option<Value>,
1096    ) {
1097        if inner.seen_events.contains_key(&msg_id) {
1098            return;
1099        }
1100        inner.seen_events.insert(msg_id, ());
1101
1102        let payload_value = payload.clone().unwrap_or(Value::Null);
1103        let listeners: Vec<EventListener> = inner
1104            .event_listeners
1105            .lock()
1106            .get(&event_id)
1107            .map(|m| m.values().cloned().collect())
1108            .unwrap_or_default();
1109        for l in listeners {
1110            l(payload_value.clone());
1111        }
1112
1113        if event_id.starts_with('_') {
1114            return;
1115        }
1116
1117        let channels: Vec<Arc<dyn CommandChannel>> = inner
1118            .channels
1119            .lock()
1120            .iter()
1121            .filter(|(k, _)| k.as_str() != origin.id())
1122            .map(|(_, v)| v.clone())
1123            .collect();
1124        for ch in channels {
1125            let _ = ch.send(Message::Event {
1126                id: msg_id,
1127                event_id: event_id.clone(),
1128                payload: payload.clone(),
1129            });
1130        }
1131    }
1132
1133    /// Invoked by the driver once the channel returns `None` from recv.
1134    fn handle_channel_close(inner: &Arc<Self>, channel_id: &str) {
1135        // Drop the channel from the lookup table.
1136        inner.channels.lock().remove(channel_id);
1137
1138        // Drop every remote command owned by this channel, along with
1139        // its cached definition.
1140        let dropped_ids: Vec<String> = {
1141            let mut remote = inner.remote.lock();
1142            let to_drop: Vec<String> = remote
1143                .iter()
1144                .filter(|(_, owner)| *owner == channel_id)
1145                .map(|(id, _)| id.clone())
1146                .collect();
1147            for id in &to_drop {
1148                remote.remove(id);
1149            }
1150            to_drop
1151        };
1152        let mut remote_defs = inner.remote_defs.lock();
1153        for id in dropped_ids {
1154            remote_defs.remove(&id);
1155        }
1156        drop(remote_defs);
1157
1158        // Reject any pending executes whose response was expected from
1159        // this channel.
1160        let exec_ids: Vec<MessageId> = inner
1161            .execute_replies
1162            .snapshot_keys_where(|v| v.target_channel == channel_id);
1163        for id in exec_ids {
1164            if let Some(pending) = inner.execute_replies.remove(&id) {
1165                let _ = pending.tx.send(ExecuteResult::Err {
1166                    ok: False,
1167                    error: ExecuteError {
1168                        code: ExecuteErrorCode::ChannelDisconnected,
1169                        message: "channel disconnected".into(),
1170                    },
1171                });
1172            }
1173        }
1174
1175        let reg_ids: Vec<MessageId> = inner
1176            .register_replies
1177            .snapshot_keys_where(|v| v.target_channel == channel_id);
1178        for id in reg_ids {
1179            if let Some(pending) = inner.register_replies.remove(&id) {
1180                let _ = pending.tx.send(RegisterOutcome::Disconnected);
1181            }
1182        }
1183
1184        // For every route where either endpoint is the dead channel,
1185        // notify the origin (if it is still alive).
1186        let route_ids: Vec<MessageId> = inner.routes.snapshot_keys_where(|r| {
1187            r.origin_channel == channel_id || r.target_channel == channel_id
1188        });
1189        for id in route_ids {
1190            if let Some(route) = inner.routes.remove(&id) {
1191                if route.origin_channel == channel_id {
1192                    continue;
1193                }
1194                let origin = inner.channels.lock().get(&route.origin_channel).cloned();
1195                if let Some(origin) = origin {
1196                    let _ = origin.send(Message::ExecuteCommandResponse {
1197                        id: MessageId::new_v4(),
1198                        thid: id,
1199                        response: ExecuteResult::Err {
1200                            ok: False,
1201                            error: ExecuteError {
1202                                code: ExecuteErrorCode::ChannelDisconnected,
1203                                message: "target channel disconnected".into(),
1204                            },
1205                        },
1206                    });
1207                }
1208            }
1209        }
1210    }
1211}
1212
1213// ------- helpers -----------------------------------------------------
1214
1215fn command_error_to_execute(e: &CommandError, command_id: &str) -> ExecuteError {
1216    match e {
1217        CommandError::InvalidRequest { message, .. } => ExecuteError {
1218            code: ExecuteErrorCode::InvalidRequest,
1219            message: message.clone(),
1220        },
1221        CommandError::Internal { message, .. } => ExecuteError {
1222            code: ExecuteErrorCode::InternalError,
1223            message: message.clone(),
1224        },
1225        CommandError::Timeout => ExecuteError {
1226            code: ExecuteErrorCode::Timeout,
1227            message: "request timed out".into(),
1228        },
1229        CommandError::ChannelDisconnected => ExecuteError {
1230            code: ExecuteErrorCode::ChannelDisconnected,
1231            message: "channel disconnected".into(),
1232        },
1233        CommandError::NotFound(id) => ExecuteError {
1234            code: ExecuteErrorCode::NotFound,
1235            message: format!("command not found: {id}"),
1236        },
1237        _ => ExecuteError {
1238            code: ExecuteErrorCode::InternalError,
1239            message: format!("{e} [command {command_id}]"),
1240        },
1241    }
1242}
1243
1244fn error_to_command_error(err: ExecuteError, command_id: &str) -> CommandError {
1245    match err.code {
1246        ExecuteErrorCode::NotFound => CommandError::NotFound(err.message),
1247        ExecuteErrorCode::InvalidRequest => CommandError::InvalidRequest {
1248            command_id: command_id.into(),
1249            message: err.message,
1250        },
1251        ExecuteErrorCode::InternalError => CommandError::Internal {
1252            command_id: command_id.into(),
1253            message: err.message,
1254        },
1255        ExecuteErrorCode::Timeout => CommandError::Timeout,
1256        ExecuteErrorCode::ChannelDisconnected => CommandError::ChannelDisconnected,
1257    }
1258}
1259
1260// Small convenience on ExecuteError.
1261impl ExecuteError {
1262    fn into_command_error(self, command_id: &str) -> CommandError {
1263        error_to_command_error(self, command_id)
1264    }
1265}
1266
1267/// Collapse a serialized request/result/payload value to `None` when it
1268/// is JSON `null`. This is what makes void commands and events
1269/// spec-compliant on the wire: `request` / `result` / `payload` are all
1270/// optional fields in the JSON schemas, so an absent value must be
1271/// encoded by omitting the key, not by emitting `null`.
1272///
1273/// Used on every outgoing `execute.command.request`,
1274/// `execute.command.response` success, and `event` message.
1275fn value_to_wire(v: Value) -> Option<Value> {
1276    if v.is_null() {
1277        None
1278    } else {
1279        Some(v)
1280    }
1281}
1282
1283/// Serialize a strict-mode request value to JSON. Wraps `serde_json`
1284/// with the right error type for the strict `execute::<C>` path.
1285fn value_from_request<T: Serialize>(v: &T) -> Result<Value, CommandError> {
1286    serde_json::to_value(v).map_err(CommandError::Serde)
1287}