coralstack_cmd_ipc/registry.rs
1//! The [`CommandRegistry`] — core routing and execution hub.
2//!
3//! Mirrors the TypeScript `CommandRegistry` in
4//! `packages/cmd-ipc/src/registry/command-registry.ts`: it owns a local
5//! command table, a remote command table (command → owning channel), a
6//! set of connected channels, and a handful of [`TtlMap`]s correlating
7//! in-flight requests, forwarded routes, and recently-seen events.
8//!
9//! # Topology
10//!
11//! * **Root** registries have no `router_channel`. Unknown commands
12//! produce `NotFound` errors.
13//! * **Child** registries set `router_channel = Some(peer_id)`. Unknown
14//! commands, and new local registrations, are escalated upstream.
15//! * Events fan out across every connected channel; dedup by message
16//! id prevents echo loops in meshes.
17//!
18//! Private commands/events (identifiers starting with `_`) stay local
19//! — never escalated, never advertised, never broadcast.
20
21use std::collections::{BTreeMap, HashMap};
22use std::future::Future;
23use std::sync::atomic::{AtomicU64, Ordering};
24use std::sync::Arc;
25use std::time::Duration;
26
27use futures::channel::oneshot;
28use futures::future::BoxFuture;
29use futures::stream::FuturesUnordered;
30use futures::{FutureExt, StreamExt};
31use parking_lot::Mutex;
32use serde::Serialize;
33use serde_json::Value;
34use uuid::Uuid;
35
36use crate::channel::CommandChannel;
37use crate::command::Command;
38use crate::error::{ChannelError, CommandError, ExecuteErrorCode, RegisterErrorCode};
39use crate::event::Event;
40use crate::message::{
41 CommandDef, ExecuteError, ExecuteResult, False, Message, MessageId, RegisterResult, True,
42};
43use crate::ttl_map::TtlMap;
44
/// Configuration for a [`CommandRegistry`].
///
/// Usually built with struct-update syntax on top of the defaults:
///
/// ```ignore
/// let cfg = Config {
///     router_channel: Some("parent".into()),
///     ..Default::default()
/// };
/// ```
#[derive(Debug, Clone)]
pub struct Config {
    /// Registry identifier used in log messages. When `None`, a random
    /// v4 UUID is generated at construction.
    pub id: Option<String>,
    /// Channel id to escalate unknown commands and new registrations
    /// to. Leave `None` for a root registry.
    pub router_channel: Option<String>,
    /// How long a pending `execute` / `register` reply can wait before
    /// being rejected with [`CommandError::Timeout`]. Zero disables
    /// the TTL check (request hangs until the channel closes).
    pub request_ttl: Duration,
    /// How long a seen event id is remembered for dedup purposes.
    pub event_ttl: Duration,
    /// Maximum number of handler futures that may be in flight
    /// concurrently on a single channel's pump. Once this cap is
    /// reached, the pump stops pulling new messages off the channel
    /// (applying upstream backpressure) until an in-flight handler
    /// finishes. `0` disables the cap.
    ///
    /// `Config::default()` sets this to `256`, which suits almost every
    /// workload; callers using `..Default::default()` never need to
    /// spell it out. The cap exists so a misbehaving peer cannot cause
    /// unbounded handler-future buildup.
    pub max_in_flight_per_channel: usize,
}
74
75impl Default for Config {
76 fn default() -> Self {
77 Self {
78 id: None,
79 router_channel: None,
80 request_ttl: Duration::from_secs(30),
81 event_ttl: Duration::from_secs(5),
82 max_in_flight_per_channel: 256,
83 }
84 }
85}
86
87type HandlerFn = dyn Fn(Value) -> BoxFuture<'static, Result<Value, ExecuteError>> + Send + Sync;
88type EventListener = Arc<dyn Fn(Value) + Send + Sync>;
89
90struct LocalEntry {
91 handler: Arc<HandlerFn>,
92 def: CommandDef,
93 is_private: bool,
94}
95
96struct PendingExecute {
97 tx: oneshot::Sender<ExecuteResult>,
98 target_channel: String,
99}
100
101/// Outcome delivered to the caller awaiting a `register.command.request`
102/// reply. The peer's `RegisterResult` is wrapped so the registry can
103/// synthesize timeout / disconnect errors distinguishable from the
104/// wire-level duplicate-command case.
105enum RegisterOutcome {
106 Wire(RegisterResult),
107 Timeout,
108 Disconnected,
109}
110
111struct PendingRegister {
112 tx: oneshot::Sender<RegisterOutcome>,
113 target_channel: String,
114}
115
/// A forwarded-request route stored in `Inner::routes`, keyed by message
/// id.
///
/// NOTE(review): populated by the forwarding path, which is not visible
/// in this chunk; the field meanings below are inferred from their names.
struct RouteEntry {
    /// Channel the request presumably arrived on (where the reply returns).
    origin_channel: String,
    /// Channel the request was presumably forwarded to.
    target_channel: String,
}
120
121/// Shared state behind the [`CommandRegistry`] Arc.
122struct Inner {
123 id: String,
124 router_channel: Option<String>,
125 local: Mutex<HashMap<String, LocalEntry>>,
126 /// command id -> owning channel id
127 remote: Mutex<HashMap<String, String>>,
128 /// command id -> advertised definition (description + schema)
129 /// kept parallel to `remote` so `list_commands` can render
130 /// the same richness for remote entries as for local ones.
131 remote_defs: Mutex<HashMap<String, CommandDef>>,
132 channels: Mutex<HashMap<String, Arc<dyn CommandChannel>>>,
133 execute_replies: TtlMap<MessageId, PendingExecute>,
134 register_replies: TtlMap<MessageId, PendingRegister>,
135 routes: TtlMap<MessageId, RouteEntry>,
136 seen_events: TtlMap<MessageId, ()>,
137 /// Event listeners keyed by event id then by a monotonic token, so
138 /// `add_event_listener` can return an unsubscribe closure that
139 /// removes just that one listener. `BTreeMap` preserves insertion
140 /// order (tokens are monotonically increasing) for dispatch.
141 event_listeners: Mutex<HashMap<String, BTreeMap<u64, EventListener>>>,
142 /// Monotonically-increasing token used to key event listeners.
143 next_listener_token: AtomicU64,
144 /// Per-channel in-flight cap, copied from [`Config`].
145 max_in_flight_per_channel: usize,
146}
147
148/// The main entry point of the crate.
149///
150/// A registry is cheap to clone: internally it's an `Arc<Inner>`.
151#[derive(Clone)]
152pub struct CommandRegistry {
153 inner: Arc<Inner>,
154}
155
156impl CommandRegistry {
157 pub fn new(cfg: Config) -> Self {
158 // On TTL expiry, deliver an explicit Timeout outcome on the
159 // pending oneshot so the caller's `rx.await` resolves to
160 // `CommandError::Timeout` rather than blocking forever (or
161 // collapsing to a generic `ChannelDisconnected` when the tx is
162 // dropped). Eviction is lazy — driven by `sweep_expired()`
163 // calls in the driver loop and registry hot paths.
164 let execute_replies =
165 TtlMap::new(cfg.request_ttl).with_on_expire(|_, pending: PendingExecute| {
166 let _ = pending.tx.send(ExecuteResult::Err {
167 ok: False,
168 error: ExecuteError {
169 code: ExecuteErrorCode::Timeout,
170 message: "request timed out".into(),
171 },
172 });
173 });
174 let register_replies =
175 TtlMap::new(cfg.request_ttl).with_on_expire(|_, pending: PendingRegister| {
176 let _ = pending.tx.send(RegisterOutcome::Timeout);
177 });
178
179 let inner = Arc::new(Inner {
180 id: cfg.id.unwrap_or_else(|| Uuid::new_v4().to_string()),
181 router_channel: cfg.router_channel,
182 local: Mutex::new(HashMap::new()),
183 remote: Mutex::new(HashMap::new()),
184 remote_defs: Mutex::new(HashMap::new()),
185 channels: Mutex::new(HashMap::new()),
186 execute_replies,
187 register_replies,
188 routes: TtlMap::new(cfg.request_ttl),
189 seen_events: TtlMap::new(cfg.event_ttl),
190 event_listeners: Mutex::new(HashMap::new()),
191 next_listener_token: AtomicU64::new(0),
192 max_in_flight_per_channel: cfg.max_in_flight_per_channel,
193 });
194 Self { inner }
195 }
196
197 /// Returns this registry's identifier.
198 pub fn id(&self) -> &str {
199 &self.inner.id
200 }
201
202 /// Returns the ids of every currently-registered channel, sorted.
203 ///
204 /// Mirrors the TypeScript library's `listChannels()` method.
205 pub fn list_channels(&self) -> Vec<String> {
206 let mut ids: Vec<String> = self.inner.channels.lock().keys().cloned().collect();
207 ids.sort();
208 ids
209 }
210
211 /// Returns the full [`CommandDef`] (id + description + schema) for
212 /// every reachable command — local (non-private) and remote. Remote
213 /// defs are those advertised via `register.command.request` or
214 /// `list.commands.response` on the channel.
215 ///
216 /// Mirrors the TypeScript library's `listCommands()` method.
217 /// Results are sorted by id. A command id is only included once even
218 /// if both a local and remote entry exist (local wins).
219 pub fn list_commands(&self) -> Vec<CommandDef> {
220 let mut out: HashMap<String, CommandDef> = HashMap::new();
221 for (id, entry) in self.inner.local.lock().iter() {
222 if !entry.is_private {
223 out.insert(id.clone(), entry.def.clone());
224 }
225 }
226 for (id, def) in self.inner.remote_defs.lock().iter() {
227 out.entry(id.clone()).or_insert_with(|| def.clone());
228 }
229 let mut v: Vec<CommandDef> = out.into_values().collect();
230 v.sort_by(|a, b| a.id.cmp(&b.id));
231 v
232 }
233
234 /// Register a command on this registry.
235 ///
236 /// The single registration entry point, covering both compile-time
237 /// and runtime commands:
238 ///
239 /// - **Compile-time**: pass an instance of a type that implements
240 /// [`Command`]. The `#[command]` / `#[command_service]` macros
241 /// generate such types from a plain `async fn`.
242 /// - **Runtime**: pass a [`DynCommand`] carrying owned id /
243 /// description / schema and a closure handler.
244 ///
245 /// Mirrors the TypeScript library's `registerCommand`.
246 ///
247 /// - Commands whose id starts with `_` stay local: they are never
248 /// escalated to a `router_channel` and never advertised to peers
249 /// via `list.commands.response`.
250 /// - Non-private commands are escalated upstream if this registry
251 /// has a `router_channel`; the local entry is only committed
252 /// after the router acks.
253 /// - The advertised schema is normalized via
254 /// [`crate::schema::normalize_schema`] on the way in, so every
255 /// schema leaving the registry is language-agnostic JSON Schema
256 /// regardless of how the caller built it.
257 pub async fn register_command<C: Command>(&self, cmd: C) -> Result<(), CommandError> {
258 let id = cmd.id().to_string();
259 let description = cmd.description().map(str::to_string);
260 let schema = cmd.schema().map(crate::schema::normalize_command_schema);
261 let is_private = id.starts_with('_');
262 let def = CommandDef {
263 id: id.clone(),
264 description,
265 schema,
266 };
267 let handler: Arc<HandlerFn> = Arc::new({
268 let cmd = Arc::new(cmd);
269 move |value: Value| {
270 let cmd = cmd.clone();
271 async move {
272 let req: C::Request =
273 serde_json::from_value(value).map_err(|e| ExecuteError {
274 code: ExecuteErrorCode::InvalidRequest,
275 message: e.to_string(),
276 })?;
277 let res = cmd
278 .handle(req)
279 .await
280 .map_err(|e| command_error_to_execute(&e, cmd.id()))?;
281 serde_json::to_value(res).map_err(|e| ExecuteError {
282 code: ExecuteErrorCode::InternalError,
283 message: e.to_string(),
284 })
285 }
286 .boxed()
287 }
288 });
289 self.register_inner(id, handler, def, is_private).await
290 }
291
292 async fn register_inner(
293 &self,
294 id: String,
295 handler: Arc<HandlerFn>,
296 def: CommandDef,
297 is_private: bool,
298 ) -> Result<(), CommandError> {
299 self.inner.execute_replies.sweep_expired();
300 self.inner.register_replies.sweep_expired();
301 // Duplicate check against the local table.
302 if self.inner.local.lock().contains_key(&id) {
303 return Err(CommandError::DuplicateCommand(id));
304 }
305
306 // Non-private commands escalate to the router before being added.
307 if !is_private {
308 if let Some(router_id) = self.inner.router_channel.clone() {
309 let router_ch = self.inner.channels.lock().get(&router_id).cloned();
310 if let Some(router_ch) = router_ch {
311 let req_id = MessageId::new_v4();
312 let (tx, rx) = oneshot::channel();
313 self.inner.register_replies.insert(
314 req_id,
315 PendingRegister {
316 tx,
317 target_channel: router_id.clone(),
318 },
319 );
320 router_ch
321 .send(Message::RegisterCommandRequest {
322 id: req_id,
323 command: def.clone(),
324 })
325 .map_err(|_| CommandError::ChannelDisconnected)?;
326 match rx.await {
327 Ok(RegisterOutcome::Wire(RegisterResult::Ok { .. })) => {}
328 Ok(RegisterOutcome::Wire(RegisterResult::Err { error, .. })) => {
329 return Err(match error {
330 RegisterErrorCode::DuplicateCommand => {
331 CommandError::DuplicateCommand(id)
332 }
333 });
334 }
335 Ok(RegisterOutcome::Timeout) => return Err(CommandError::Timeout),
336 Ok(RegisterOutcome::Disconnected) | Err(_) => {
337 return Err(CommandError::ChannelDisconnected);
338 }
339 }
340 }
341 }
342 }
343
344 self.inner.local.lock().insert(
345 id,
346 LocalEntry {
347 handler,
348 def,
349 is_private,
350 },
351 );
352 Ok(())
353 }
354
355 /// Connects a [`CommandChannel`] to this registry.
356 ///
357 /// Returns a driver future which must be polled by the caller's
358 /// executor (via `tokio::spawn`, `smol::spawn`,
359 /// `futures::executor::block_on`, …) for the registry to exchange
360 /// messages with the peer. The future completes when the channel
361 /// closes.
362 ///
363 /// Handler dispatch is **concurrent within the driver task**: the
364 /// pump pushes each incoming message's handler future into a
365 /// [`FuturesUnordered`] and cooperatively interleaves them with
366 /// the next `recv`, so a slow handler that awaits external work
367 /// no longer blocks subsequent messages on the same channel. The
368 /// number of simultaneously in-flight handlers is capped by
369 /// [`Config::max_in_flight_per_channel`] (default 256); at the
370 /// cap the pump applies backpressure to the channel rather than
371 /// dropping messages. For true multi-thread parallelism, wrap
372 /// the handler body in your runtime's `spawn` (e.g.
373 /// `tokio::spawn`) — the crate itself stays runtime-agnostic.
374 pub async fn register_channel(
375 &self,
376 channel: Arc<dyn CommandChannel>,
377 ) -> Result<impl Future<Output = ()> + Send + 'static, ChannelError> {
378 let id = channel.id().to_string();
379 {
380 let mut chans = self.inner.channels.lock();
381 if chans.contains_key(&id) {
382 return Err(ChannelError::Other(format!(
383 "channel with id `{id}` already registered"
384 )));
385 }
386 chans.insert(id.clone(), channel.clone());
387 }
388
389 channel.start().await?;
390
391 // Ask the peer for its command list. The response is handled
392 // by the driver loop, which will register each entry as remote.
393 if let Err(e) = channel.send(Message::ListCommandsRequest {
394 id: MessageId::new_v4(),
395 }) {
396 self.inner.channels.lock().remove(&id);
397 return Err(e);
398 }
399
400 let inner = self.inner.clone();
401 let ch = channel;
402 Ok(async move {
403 // In-flight handler futures cooperatively interleaved with
404 // message recv. This is what makes the pump non-blocking:
405 // when one handler awaits (sleep / network / forwarded
406 // call), the executor polls other handlers and the recv
407 // future on the same task, so a slow handler can no
408 // longer wedge every subsequent message behind it.
409 //
410 // Stays runtime-agnostic (no `tokio::spawn`); the user's
411 // executor still owns the single task driving this pump.
412 // True multi-thread parallelism is possible if the user
413 // wraps their handler body in `tokio::spawn` themselves.
414 let mut in_flight: FuturesUnordered<BoxFuture<'static, ()>> = FuturesUnordered::new();
415 let cap = inner.max_in_flight_per_channel;
416
417 loop {
418 // Backpressure: while at capacity, drain in-flight
419 // handlers instead of accepting new messages. Ordering
420 // of message *reception* is preserved — only dispatch
421 // fans out. Messages are never dropped.
422 while cap > 0 && in_flight.len() >= cap {
423 if in_flight.next().await.is_none() {
424 break;
425 }
426 }
427
428 // Opportunistic TTL sweep: any registry activity
429 // triggers a pass over the pending-reply / route maps
430 // so timed-out entries fire their on_expire callbacks
431 // (delivering Timeout to waiting callers). Cheap —
432 // O(n) over maps that are normally tiny.
433 inner.execute_replies.sweep_expired();
434 inner.register_replies.sweep_expired();
435 inner.routes.sweep_expired();
436
437 // Wait for either the next message OR for an in-flight
438 // handler to make progress. When `in_flight` is empty
439 // we'd otherwise busy-loop on its `.next()` returning
440 // `None`, so park on `pending()` in that case.
441 let next_msg = if in_flight.is_empty() {
442 ch.recv().await
443 } else {
444 futures::select_biased! {
445 // Drain finished handlers so the in_flight set
446 // shrinks promptly. `select_next_some` skips
447 // the empty case, but we've already guarded
448 // against that above.
449 _ = in_flight.select_next_some() => continue,
450 msg = ch.recv().fuse() => msg,
451 }
452 };
453
454 let Some(msg) = next_msg else { break };
455 in_flight.push(Inner::handle_message(inner.clone(), ch.clone(), msg).boxed());
456 }
457
458 // Channel closed — drain any in-flight handlers so their
459 // outgoing responses are sent before we tear the channel
460 // state down. Handlers that try to `send` on the now-closed
461 // channel will get `Err(Closed)` from the transport, which
462 // is the existing semantics.
463 while in_flight.next().await.is_some() {}
464 Inner::handle_channel_close(&inner, ch.id());
465 })
466 }
467
468 /// Executes a command identified by a compile-time [`Command`] type
469 /// — the **strict** form, giving the same compile-time type safety
470 /// that TypeScript's strict-mode `executeCommand<K>` gives via the
471 /// `CommandSchemaMap` type parameter.
472 ///
473 /// The command id comes from `C::ID`, the request type is pinned to
474 /// `C::Request`, and the response type is pinned to `C::Response`,
475 /// so the compiler rejects mismatches at the call site:
476 ///
477 /// ```ignore
478 /// let sum: i64 = registry.execute::<MathAdd>(AddReq { a: 2, b: 3 }).await?;
479 /// ```
480 ///
481 /// For commands whose id or payload shape is only known at runtime
482 /// (scripting hosts, FFI, plugins that advertise their own schema),
483 /// use [`execute_dyn`](Self::execute_dyn).
484 pub async fn execute<C: Command>(
485 &self,
486 request: C::Request,
487 ) -> Result<C::Response, CommandError>
488 where
489 C::Request: Serialize,
490 C::Response: serde::de::DeserializeOwned,
491 {
492 let req_value = value_from_request(&request)?;
493 let result = self.execute_raw_impl(C::ID.to_string(), req_value).await?;
494 let deserialized = serde_json::from_value(result.unwrap_or(Value::Null))?;
495 Ok(deserialized)
496 }
497
498 /// Executes a command whose id is only known at runtime — the
499 /// **loose** form, mirroring the TypeScript library's
500 /// `executeCommand(id, args)` in loose mode.
501 ///
502 /// Request and response are raw [`serde_json::Value`]s, so this is
503 /// the canonical entry point for plugin hosts, scripting runtimes,
504 /// FFI bridges, and any code where the schema is discovered via
505 /// [`list_commands`](Self::list_commands) rather than declared at
506 /// compile time.
507 ///
508 /// For statically-known commands, prefer [`execute`](Self::execute)
509 /// — it pins both types via the [`Command`] trait.
510 pub async fn execute_dyn(
511 &self,
512 command_id: &str,
513 request: Value,
514 ) -> Result<Value, CommandError> {
515 let result = self
516 .execute_raw_impl(command_id.to_string(), request)
517 .await?;
518 Ok(result.unwrap_or(Value::Null))
519 }
520
521 async fn execute_raw_impl(
522 &self,
523 command_id: String,
524 request: Value,
525 ) -> Result<Option<Value>, CommandError> {
526 // Flush any expired pending replies so stale entries fire their
527 // Timeout on_expire before we enqueue a new one.
528 self.inner.execute_replies.sweep_expired();
529 self.inner.register_replies.sweep_expired();
530 // 1) Local handler wins.
531 let local_handler = self
532 .inner
533 .local
534 .lock()
535 .get(&command_id)
536 .map(|entry| entry.handler.clone());
537 if let Some(handler) = local_handler {
538 return handler(request)
539 .await
540 .map(Some)
541 .map_err(|e| e.into_command_error(&command_id));
542 }
543
544 // 2) Known remote command.
545 let remote_target = self.inner.remote.lock().get(&command_id).cloned();
546 let target = match remote_target {
547 Some(t) => Some(t),
548 None => self.inner.router_channel.clone(),
549 };
550
551 let Some(target_id) = target else {
552 return Err(CommandError::NotFound(command_id));
553 };
554
555 let channel = self.inner.channels.lock().get(&target_id).cloned();
556 let Some(channel) = channel else {
557 return Err(CommandError::ChannelDisconnected);
558 };
559
560 self.forward_execute(command_id, request, &channel, target_id)
561 .await
562 }
563
564 async fn forward_execute(
565 &self,
566 command_id: String,
567 request: Value,
568 channel: &Arc<dyn CommandChannel>,
569 target_id: String,
570 ) -> Result<Option<Value>, CommandError> {
571 let req_id = MessageId::new_v4();
572 let (tx, rx) = oneshot::channel();
573 self.inner.execute_replies.insert(
574 req_id,
575 PendingExecute {
576 tx,
577 target_channel: target_id,
578 },
579 );
580 channel
581 .send(Message::ExecuteCommandRequest {
582 id: req_id,
583 command_id: command_id.clone(),
584 // Void requests are elided from the wire (Null → None)
585 // so peers expecting an absent `request` field (per the
586 // JSON Schema spec) don't see `"request": null`.
587 request: value_to_wire(request),
588 })
589 .map_err(|_| CommandError::ChannelDisconnected)?;
590
591 match rx.await {
592 Ok(ExecuteResult::Ok { result, .. }) => Ok(result),
593 Ok(ExecuteResult::Err { error, .. }) => Err(error_to_command_error(error, &command_id)),
594 Err(_) => {
595 self.inner.execute_replies.remove(&req_id);
596 Err(CommandError::ChannelDisconnected)
597 }
598 }
599 }
600
601 /// Emit an event. Dispatches to local listeners and — unless the
602 /// event id is private (starts with `_`) — broadcasts to every
603 /// connected channel.
604 ///
605 /// Works for both compile-time events (`#[event]`-annotated
606 /// structs) and runtime events ([`DynEvent`](crate::command::Command)
607 /// — actually [`DynEvent`](crate::event::DynEvent)). Id, description,
608 /// and schema are all read off the event instance.
609 pub fn emit<E: Event>(&self, event: E) -> Result<(), CommandError> {
610 let event_id = event.id().to_string();
611 let payload_value = serde_json::to_value(&event)?;
612 let msg_id = MessageId::new_v4();
613 self.inner.seen_events.insert(msg_id, ());
614
615 self.dispatch_event_locally(&event_id, &payload_value);
616
617 if !event_id.starts_with('_') {
618 let channels: Vec<Arc<dyn CommandChannel>> =
619 self.inner.channels.lock().values().cloned().collect();
620 // Void payloads (serde `()` → `Value::Null`) are elided
621 // from the wire per the event schema (`payload` is optional).
622 let wire_payload = value_to_wire(payload_value);
623 for ch in channels {
624 let _ = ch.send(Message::Event {
625 id: msg_id,
626 event_id: event_id.clone(),
627 payload: wire_payload.clone(),
628 });
629 }
630 }
631 Ok(())
632 }
633
634 /// Subscribe a typed listener. The callback receives a
635 /// deserialized `E` every time an event with id `E::ID` fires,
636 /// whether emitted locally or received from a connected channel.
637 ///
638 /// Returns an unsubscribe closure — call it (and drop it) to
639 /// remove just this listener. Ignoring the return value is fine;
640 /// the listener then lives for the life of the registry.
641 ///
642 /// Listeners for the same event fire in insertion order. Payloads
643 /// that fail to deserialize into `E` are silently dropped for
644 /// this listener — they still flow to any typed-for-Value
645 /// listeners registered via [`on_dyn`](Self::on_dyn).
646 pub fn on<E: Event + serde::de::DeserializeOwned>(
647 &self,
648 listener: impl Fn(E) + Send + Sync + 'static,
649 ) -> impl FnOnce() + Send + Sync + 'static {
650 self.install_listener(E::ID, move |value| {
651 if let Ok(typed) = serde_json::from_value::<E>(value) {
652 listener(typed);
653 }
654 })
655 }
656
657 /// Subscribe a dynamic listener by runtime id. The callback
658 /// receives the raw JSON payload. Use this when the event id is
659 /// only known at runtime (plugin runtimes, FFI, scripting hosts);
660 /// prefer [`on`](Self::on) whenever you have a compile-time
661 /// [`Event`] type.
662 ///
663 /// Same unsubscribe semantics as [`on`](Self::on).
664 pub fn on_dyn<F>(
665 &self,
666 event_id: impl Into<String>,
667 listener: F,
668 ) -> impl FnOnce() + Send + Sync + 'static
669 where
670 F: Fn(Value) + Send + Sync + 'static,
671 {
672 self.install_listener(&event_id.into(), listener)
673 }
674
675 fn install_listener<F>(
676 &self,
677 event_id: &str,
678 listener: F,
679 ) -> impl FnOnce() + Send + Sync + 'static
680 where
681 F: Fn(Value) + Send + Sync + 'static,
682 {
683 let token = self
684 .inner
685 .next_listener_token
686 .fetch_add(1, Ordering::Relaxed);
687 self.inner
688 .event_listeners
689 .lock()
690 .entry(event_id.to_string())
691 .or_default()
692 .insert(token, Arc::new(listener));
693
694 let inner = Arc::clone(&self.inner);
695 let event_id = event_id.to_string();
696 move || {
697 let mut map = inner.event_listeners.lock();
698 if let Some(slot) = map.get_mut(&event_id) {
699 slot.remove(&token);
700 if slot.is_empty() {
701 map.remove(&event_id);
702 }
703 }
704 }
705 }
706
707 fn dispatch_event_locally(&self, event_id: &str, payload: &Value) {
708 let listeners: Vec<EventListener> = self
709 .inner
710 .event_listeners
711 .lock()
712 .get(event_id)
713 .map(|m| m.values().cloned().collect())
714 .unwrap_or_default();
715 for l in listeners {
716 l(payload.clone());
717 }
718 }
719
720 /// Tears down the registry: awaits `close()` on every connected
721 /// channel, drops every local and remote command, and clears all
722 /// event listeners. In-flight executes and register requests fail
723 /// with [`CommandError::ChannelDisconnected`] via the existing
724 /// channel-close path.
725 ///
726 /// Mirrors the TypeScript library's `dispose()`, but async so that
727 /// transports doing real teardown work (HTTP flush, MCP goodbye,
728 /// plugin sandbox shutdown) complete before this returns.
729 ///
730 /// Callers normally don't need this — dropping the last
731 /// `CommandRegistry` clone releases the inner state automatically
732 /// via `Drop`. Use `dispose` when a *shared* registry (held through
733 /// multiple clones) needs to be forcibly torn down, or in tests.
734 pub async fn dispose(&self) {
735 // Snapshot channel arcs so we can call `close` without holding
736 // the channels lock for the duration.
737 let channels: Vec<Arc<dyn CommandChannel>> = {
738 let mut locked = self.inner.channels.lock();
739 let out: Vec<_> = locked.values().cloned().collect();
740 locked.clear();
741 out
742 };
743
744 // Await each channel's close sequentially. Channels define
745 // their own close semantics (InMemoryChannel is effectively
746 // synchronous; MCPServerChannel flushes the transport; Flow's
747 // SourceChannel tears down its QuickJS VM), so we let every
748 // implementation finish its teardown before returning.
749 for ch in channels {
750 ch.close().await;
751 }
752
753 self.inner.local.lock().clear();
754 self.inner.remote.lock().clear();
755 self.inner.remote_defs.lock().clear();
756 self.inner.event_listeners.lock().clear();
757 }
758}
759
760impl Inner {
761 fn local_command_defs(&self) -> Vec<CommandDef> {
762 self.local
763 .lock()
764 .values()
765 .filter(|e| !e.is_private)
766 .map(|e| e.def.clone())
767 .collect()
768 }
769
770 /// Central dispatcher invoked by each channel's driver loop.
771 async fn handle_message(inner: Arc<Self>, channel: Arc<dyn CommandChannel>, msg: Message) {
772 match msg {
773 Message::RegisterCommandRequest { id, command } => {
774 Self::handle_register_request(inner, channel, id, command).await;
775 }
776 Message::RegisterCommandResponse { thid, response, .. } => {
777 if let Some(pending) = inner.register_replies.remove(&thid) {
778 let _ = pending.tx.send(RegisterOutcome::Wire(response));
779 }
780 }
781 Message::ListCommandsRequest { id } => {
782 let commands = inner.local_command_defs();
783 let _ = channel.send(Message::ListCommandsResponse {
784 id: MessageId::new_v4(),
785 thid: id,
786 commands,
787 });
788 }
789 Message::ListCommandsResponse { commands, .. } => {
790 let channel_id = channel.id().to_string();
791 let mut remote = inner.remote.lock();
792 let mut remote_defs = inner.remote_defs.lock();
793 for cmd in commands {
794 // Normalize ingested schemas so the local cache
795 // matches what register() produces.
796 let cmd = CommandDef {
797 id: cmd.id,
798 description: cmd.description,
799 schema: cmd.schema.map(crate::schema::normalize_command_schema),
800 };
801 let entry_is_new = !remote.contains_key(&cmd.id);
802 if entry_is_new {
803 remote.insert(cmd.id.clone(), channel_id.clone());
804 }
805 // Always refresh the def (the latest advertisement wins).
806 remote_defs.insert(cmd.id.clone(), cmd);
807 }
808 }
809 Message::ExecuteCommandRequest {
810 id,
811 command_id,
812 request,
813 } => {
814 Self::handle_execute_request(
815 inner,
816 channel,
817 id,
818 command_id,
819 request.unwrap_or(Value::Null),
820 )
821 .await;
822 }
823 Message::ExecuteCommandResponse { thid, response, .. } => {
824 Self::handle_execute_response(&inner, thid, response);
825 }
826 Message::Event {
827 id,
828 event_id,
829 payload,
830 } => {
831 Self::handle_event(&inner, channel, id, event_id, payload);
832 }
833 }
834 }
835
836 async fn handle_register_request(
837 inner: Arc<Self>,
838 channel: Arc<dyn CommandChannel>,
839 req_id: MessageId,
840 command: CommandDef,
841 ) {
842 // Normalize ingested schemas so our cached copy is guaranteed
843 // to be language-agnostic JSON Schema, even if the peer didn't
844 // normalize on its side.
845 let command = CommandDef {
846 id: command.id,
847 description: command.description,
848 schema: command.schema.map(crate::schema::normalize_command_schema),
849 };
850 let channel_id = channel.id().to_string();
851 let command_id = command.id.clone();
852
853 // Duplicate against local?
854 let dup = inner.local.lock().contains_key(&command_id);
855 if dup {
856 let _ = channel.send(Message::RegisterCommandResponse {
857 id: MessageId::new_v4(),
858 thid: req_id,
859 response: RegisterResult::Err {
860 ok: False,
861 error: RegisterErrorCode::DuplicateCommand,
862 },
863 });
864 return;
865 }
866
867 // Already known in the remote table?
868 // - Same channel re-advertising: short-circuit with a success
869 // ack, do NOT re-escalate upstream. A re-escalation would
870 // bubble up as a duplicate rejection from the router and
871 // spuriously fail a legitimate re-registration (common when
872 // a plugin host re-advertises its command list).
873 // - Different channel claiming the same id: reject as duplicate.
874 let existing_owner = inner.remote.lock().get(&command_id).cloned();
875 match existing_owner {
876 Some(owner) if owner == channel_id => {
877 // Refresh the cached def and re-ack. No upstream traffic.
878 inner.remote_defs.lock().insert(command_id, command);
879 let _ = channel.send(Message::RegisterCommandResponse {
880 id: MessageId::new_v4(),
881 thid: req_id,
882 response: RegisterResult::Ok { ok: True },
883 });
884 return;
885 }
886 Some(_) => {
887 let _ = channel.send(Message::RegisterCommandResponse {
888 id: MessageId::new_v4(),
889 thid: req_id,
890 response: RegisterResult::Err {
891 ok: False,
892 error: RegisterErrorCode::DuplicateCommand,
893 },
894 });
895 return;
896 }
897 None => {}
898 }
899
900 // Escalate upstream if we have a router.
901 if let Some(router_id) = inner.router_channel.clone() {
902 if router_id != channel_id {
903 let router_ch = inner.channels.lock().get(&router_id).cloned();
904 if let Some(router_ch) = router_ch {
905 let up_id = MessageId::new_v4();
906 let (tx, rx) = oneshot::channel();
907 inner.register_replies.insert(
908 up_id,
909 PendingRegister {
910 tx,
911 target_channel: router_id,
912 },
913 );
914 if router_ch
915 .send(Message::RegisterCommandRequest {
916 id: up_id,
917 command: command.clone(),
918 })
919 .is_ok()
920 {
921 let up = rx.await;
922 match up {
923 Ok(RegisterOutcome::Wire(RegisterResult::Ok { .. })) => {}
924 Ok(RegisterOutcome::Wire(RegisterResult::Err { error, .. })) => {
925 let _ = channel.send(Message::RegisterCommandResponse {
926 id: MessageId::new_v4(),
927 thid: req_id,
928 response: RegisterResult::Err { ok: False, error },
929 });
930 return;
931 }
932 // Timeout / disconnected upstream: we have no
933 // wire-level error code for these on the
934 // register response, so surface them as
935 // duplicate_command to match the prior
936 // behaviour. The escalating caller's own
937 // register_command call will see the correct
938 // Timeout / ChannelDisconnected via its own
939 // pending entry.
940 Ok(RegisterOutcome::Timeout)
941 | Ok(RegisterOutcome::Disconnected)
942 | Err(_) => {
943 let _ = channel.send(Message::RegisterCommandResponse {
944 id: MessageId::new_v4(),
945 thid: req_id,
946 response: RegisterResult::Err {
947 ok: False,
948 error: RegisterErrorCode::DuplicateCommand,
949 },
950 });
951 return;
952 }
953 }
954 }
955 }
956 }
957 }
958
959 inner.remote.lock().insert(command_id.clone(), channel_id);
960 inner.remote_defs.lock().insert(command_id, command);
961 let _ = channel.send(Message::RegisterCommandResponse {
962 id: MessageId::new_v4(),
963 thid: req_id,
964 response: RegisterResult::Ok { ok: True },
965 });
966 }
967
968 async fn handle_execute_request(
969 inner: Arc<Self>,
970 origin: Arc<dyn CommandChannel>,
971 req_id: MessageId,
972 command_id: String,
973 request: Value,
974 ) {
975 // Local handler?
976 let handler = inner
977 .local
978 .lock()
979 .get(&command_id)
980 .map(|e| e.handler.clone());
981 if let Some(handler) = handler {
982 let result = handler(request).await;
983 let response = match result {
984 Ok(v) => ExecuteResult::Ok {
985 ok: True,
986 // Void responses (`() → Value::Null`) are elided from
987 // the wire per the response schema (`result` optional).
988 result: value_to_wire(v),
989 },
990 Err(error) => ExecuteResult::Err { ok: False, error },
991 };
992 let _ = origin.send(Message::ExecuteCommandResponse {
993 id: MessageId::new_v4(),
994 thid: req_id,
995 response,
996 });
997 return;
998 }
999
1000 // Forward?
1001 let target_id = inner
1002 .remote
1003 .lock()
1004 .get(&command_id)
1005 .cloned()
1006 .or_else(|| inner.router_channel.clone());
1007
1008 let origin_id = origin.id().to_string();
1009 let Some(target_id) = target_id else {
1010 let _ = origin.send(Message::ExecuteCommandResponse {
1011 id: MessageId::new_v4(),
1012 thid: req_id,
1013 response: ExecuteResult::Err {
1014 ok: False,
1015 error: ExecuteError {
1016 code: ExecuteErrorCode::NotFound,
1017 message: format!("command not found: {command_id}"),
1018 },
1019 },
1020 });
1021 return;
1022 };
1023
1024 if target_id == origin_id {
1025 // Would loop; treat as not found.
1026 let _ = origin.send(Message::ExecuteCommandResponse {
1027 id: MessageId::new_v4(),
1028 thid: req_id,
1029 response: ExecuteResult::Err {
1030 ok: False,
1031 error: ExecuteError {
1032 code: ExecuteErrorCode::NotFound,
1033 message: format!("command not found: {command_id}"),
1034 },
1035 },
1036 });
1037 return;
1038 }
1039
1040 let target = inner.channels.lock().get(&target_id).cloned();
1041 let Some(target) = target else {
1042 let _ = origin.send(Message::ExecuteCommandResponse {
1043 id: MessageId::new_v4(),
1044 thid: req_id,
1045 response: ExecuteResult::Err {
1046 ok: False,
1047 error: ExecuteError {
1048 code: ExecuteErrorCode::ChannelDisconnected,
1049 message: "target channel disconnected".into(),
1050 },
1051 },
1052 });
1053 return;
1054 };
1055
1056 inner.routes.insert(
1057 req_id,
1058 RouteEntry {
1059 origin_channel: origin_id,
1060 target_channel: target_id,
1061 },
1062 );
1063 let _ = target.send(Message::ExecuteCommandRequest {
1064 id: req_id,
1065 command_id,
1066 request: value_to_wire(request),
1067 });
1068 }
1069
1070 fn handle_execute_response(inner: &Arc<Self>, thid: MessageId, response: ExecuteResult) {
1071 // Either this is a reply to a local call…
1072 if let Some(pending) = inner.execute_replies.remove(&thid) {
1073 let _ = pending.tx.send(response);
1074 return;
1075 }
1076
1077 // …or we forwarded this request and need to route the reply.
1078 if let Some(route) = inner.routes.remove(&thid) {
1079 let origin = inner.channels.lock().get(&route.origin_channel).cloned();
1080 if let Some(origin) = origin {
1081 let _ = origin.send(Message::ExecuteCommandResponse {
1082 id: MessageId::new_v4(),
1083 thid,
1084 response,
1085 });
1086 }
1087 }
1088 }
1089
1090 fn handle_event(
1091 inner: &Arc<Self>,
1092 origin: Arc<dyn CommandChannel>,
1093 msg_id: MessageId,
1094 event_id: String,
1095 payload: Option<Value>,
1096 ) {
1097 if inner.seen_events.contains_key(&msg_id) {
1098 return;
1099 }
1100 inner.seen_events.insert(msg_id, ());
1101
1102 let payload_value = payload.clone().unwrap_or(Value::Null);
1103 let listeners: Vec<EventListener> = inner
1104 .event_listeners
1105 .lock()
1106 .get(&event_id)
1107 .map(|m| m.values().cloned().collect())
1108 .unwrap_or_default();
1109 for l in listeners {
1110 l(payload_value.clone());
1111 }
1112
1113 if event_id.starts_with('_') {
1114 return;
1115 }
1116
1117 let channels: Vec<Arc<dyn CommandChannel>> = inner
1118 .channels
1119 .lock()
1120 .iter()
1121 .filter(|(k, _)| k.as_str() != origin.id())
1122 .map(|(_, v)| v.clone())
1123 .collect();
1124 for ch in channels {
1125 let _ = ch.send(Message::Event {
1126 id: msg_id,
1127 event_id: event_id.clone(),
1128 payload: payload.clone(),
1129 });
1130 }
1131 }
1132
1133 /// Invoked by the driver once the channel returns `None` from recv.
1134 fn handle_channel_close(inner: &Arc<Self>, channel_id: &str) {
1135 // Drop the channel from the lookup table.
1136 inner.channels.lock().remove(channel_id);
1137
1138 // Drop every remote command owned by this channel, along with
1139 // its cached definition.
1140 let dropped_ids: Vec<String> = {
1141 let mut remote = inner.remote.lock();
1142 let to_drop: Vec<String> = remote
1143 .iter()
1144 .filter(|(_, owner)| *owner == channel_id)
1145 .map(|(id, _)| id.clone())
1146 .collect();
1147 for id in &to_drop {
1148 remote.remove(id);
1149 }
1150 to_drop
1151 };
1152 let mut remote_defs = inner.remote_defs.lock();
1153 for id in dropped_ids {
1154 remote_defs.remove(&id);
1155 }
1156 drop(remote_defs);
1157
1158 // Reject any pending executes whose response was expected from
1159 // this channel.
1160 let exec_ids: Vec<MessageId> = inner
1161 .execute_replies
1162 .snapshot_keys_where(|v| v.target_channel == channel_id);
1163 for id in exec_ids {
1164 if let Some(pending) = inner.execute_replies.remove(&id) {
1165 let _ = pending.tx.send(ExecuteResult::Err {
1166 ok: False,
1167 error: ExecuteError {
1168 code: ExecuteErrorCode::ChannelDisconnected,
1169 message: "channel disconnected".into(),
1170 },
1171 });
1172 }
1173 }
1174
1175 let reg_ids: Vec<MessageId> = inner
1176 .register_replies
1177 .snapshot_keys_where(|v| v.target_channel == channel_id);
1178 for id in reg_ids {
1179 if let Some(pending) = inner.register_replies.remove(&id) {
1180 let _ = pending.tx.send(RegisterOutcome::Disconnected);
1181 }
1182 }
1183
1184 // For every route where either endpoint is the dead channel,
1185 // notify the origin (if it is still alive).
1186 let route_ids: Vec<MessageId> = inner.routes.snapshot_keys_where(|r| {
1187 r.origin_channel == channel_id || r.target_channel == channel_id
1188 });
1189 for id in route_ids {
1190 if let Some(route) = inner.routes.remove(&id) {
1191 if route.origin_channel == channel_id {
1192 continue;
1193 }
1194 let origin = inner.channels.lock().get(&route.origin_channel).cloned();
1195 if let Some(origin) = origin {
1196 let _ = origin.send(Message::ExecuteCommandResponse {
1197 id: MessageId::new_v4(),
1198 thid: id,
1199 response: ExecuteResult::Err {
1200 ok: False,
1201 error: ExecuteError {
1202 code: ExecuteErrorCode::ChannelDisconnected,
1203 message: "target channel disconnected".into(),
1204 },
1205 },
1206 });
1207 }
1208 }
1209 }
1210 }
1211}
1212
1213// ------- helpers -----------------------------------------------------
1214
1215fn command_error_to_execute(e: &CommandError, command_id: &str) -> ExecuteError {
1216 match e {
1217 CommandError::InvalidRequest { message, .. } => ExecuteError {
1218 code: ExecuteErrorCode::InvalidRequest,
1219 message: message.clone(),
1220 },
1221 CommandError::Internal { message, .. } => ExecuteError {
1222 code: ExecuteErrorCode::InternalError,
1223 message: message.clone(),
1224 },
1225 CommandError::Timeout => ExecuteError {
1226 code: ExecuteErrorCode::Timeout,
1227 message: "request timed out".into(),
1228 },
1229 CommandError::ChannelDisconnected => ExecuteError {
1230 code: ExecuteErrorCode::ChannelDisconnected,
1231 message: "channel disconnected".into(),
1232 },
1233 CommandError::NotFound(id) => ExecuteError {
1234 code: ExecuteErrorCode::NotFound,
1235 message: format!("command not found: {id}"),
1236 },
1237 _ => ExecuteError {
1238 code: ExecuteErrorCode::InternalError,
1239 message: format!("{e} [command {command_id}]"),
1240 },
1241 }
1242}
1243
1244fn error_to_command_error(err: ExecuteError, command_id: &str) -> CommandError {
1245 match err.code {
1246 ExecuteErrorCode::NotFound => CommandError::NotFound(err.message),
1247 ExecuteErrorCode::InvalidRequest => CommandError::InvalidRequest {
1248 command_id: command_id.into(),
1249 message: err.message,
1250 },
1251 ExecuteErrorCode::InternalError => CommandError::Internal {
1252 command_id: command_id.into(),
1253 message: err.message,
1254 },
1255 ExecuteErrorCode::Timeout => CommandError::Timeout,
1256 ExecuteErrorCode::ChannelDisconnected => CommandError::ChannelDisconnected,
1257 }
1258}
1259
1260// Small convenience on ExecuteError.
1261impl ExecuteError {
1262 fn into_command_error(self, command_id: &str) -> CommandError {
1263 error_to_command_error(self, command_id)
1264 }
1265}
1266
1267/// Collapse a serialized request/result/payload value to `None` when it
1268/// is JSON `null`. This is what makes void commands and events
1269/// spec-compliant on the wire: `request` / `result` / `payload` are all
1270/// optional fields in the JSON schemas, so an absent value must be
1271/// encoded by omitting the key, not by emitting `null`.
1272///
1273/// Used on every outgoing `execute.command.request`,
1274/// `execute.command.response` success, and `event` message.
1275fn value_to_wire(v: Value) -> Option<Value> {
1276 if v.is_null() {
1277 None
1278 } else {
1279 Some(v)
1280 }
1281}
1282
1283/// Serialize a strict-mode request value to JSON. Wraps `serde_json`
1284/// with the right error type for the strict `execute::<C>` path.
1285fn value_from_request<T: Serialize>(v: &T) -> Result<Value, CommandError> {
1286 serde_json::to_value(v).map_err(CommandError::Serde)
1287}