coralstack_cmd_ipc/registry.rs
1//! The [`CommandRegistry`] — core routing and execution hub.
2//!
3//! Mirrors the TypeScript `CommandRegistry` in
4//! `packages/cmd-ipc/src/registry/command-registry.ts`: it owns a local
5//! command table, a remote command table (command → owning channel), a
6//! set of connected channels, and a handful of [`TtlMap`]s correlating
7//! in-flight requests, forwarded routes, and recently-seen events.
8//!
9//! # Topology
10//!
11//! * **Root** registries have no `router_channel`. Unknown commands
12//! produce `NotFound` errors.
13//! * **Child** registries set `router_channel = Some(peer_id)`. Unknown
14//! commands, and new local registrations, are escalated upstream.
15//! * Events fan out across every connected channel; dedup by message
16//! id prevents echo loops in meshes.
17//!
18//! Private commands/events (identifiers starting with `_`) stay local
19//! — never escalated, never advertised, never broadcast.
20
21use std::collections::{BTreeMap, HashMap};
22use std::future::Future;
23use std::sync::atomic::{AtomicU64, Ordering};
24use std::sync::Arc;
25use std::time::Duration;
26
27use futures::channel::oneshot;
28use futures::future::BoxFuture;
29use futures::stream::FuturesUnordered;
30use futures::{FutureExt, StreamExt};
31use parking_lot::Mutex;
32use serde::Serialize;
33use serde_json::Value;
34use uuid::Uuid;
35
36use crate::channel::CommandChannel;
37use crate::command::Command;
38use crate::error::{ChannelError, CommandError, ExecuteErrorCode, RegisterErrorCode};
39use crate::event::Event;
40use crate::message::{
41 CommandDef, ExecuteError, ExecuteResult, False, Message, MessageId, RegisterResult, True,
42};
43use crate::ttl_map::TtlMap;
44
45/// Configuration for a [`CommandRegistry`].
46pub struct Config {
47 /// Registry identifier used in log messages. Defaults to a random
48 /// UUID.
49 pub id: Option<String>,
50 /// Channel id to escalate unknown commands and new registrations
51 /// to. Leave `None` for a root registry.
52 pub router_channel: Option<String>,
53 /// How long a pending `execute` / `register` reply can wait before
54 /// being rejected with [`CommandError::Timeout`]. Zero disables
55 /// the TTL check (request hangs until the channel closes).
56 pub request_ttl: Duration,
57 /// How long a seen event id is remembered for dedup purposes.
58 pub event_ttl: Duration,
59 /// Maximum number of handler futures that may be in flight
60 /// concurrently on a single channel's pump. Once this cap is
61 /// reached, the pump stops pulling new messages off the channel
62 /// (applying upstream backpressure) until an in-flight handler
63 /// finishes. `0` disables the cap.
64 ///
65 /// **Optional in practice** — `Config::default()` sets this to
66 /// `256`, so callers using `..Default::default()` never need to
67 /// supply it. It is only *syntactically* required when you build
68 /// `Config { … }` field-by-field (Rust struct literals must list
69 /// every field). The cap exists so a misbehaving peer can't cause
70 /// unbounded handler-future buildup; the default is fine for
71 /// almost every workload.
72 pub max_in_flight_per_channel: usize,
73}
74
75impl Default for Config {
76 fn default() -> Self {
77 Self {
78 id: None,
79 router_channel: None,
80 request_ttl: Duration::from_secs(30),
81 event_ttl: Duration::from_secs(5),
82 max_in_flight_per_channel: 256,
83 }
84 }
85}
86
87type HandlerFn = dyn Fn(Value) -> BoxFuture<'static, Result<Value, ExecuteError>> + Send + Sync;
88type EventListener = Arc<dyn Fn(Value) + Send + Sync>;
89
90struct LocalEntry {
91 handler: Arc<HandlerFn>,
92 def: CommandDef,
93 is_private: bool,
94}
95
96struct PendingExecute {
97 tx: oneshot::Sender<ExecuteResult>,
98 target_channel: String,
99}
100
101/// Outcome delivered to the caller awaiting a `register.command.request`
102/// reply. The peer's `RegisterResult` is wrapped so the registry can
103/// synthesize timeout / disconnect errors distinguishable from the
104/// wire-level duplicate-command case.
105enum RegisterOutcome {
106 Wire(RegisterResult),
107 Timeout,
108 Disconnected,
109}
110
111struct PendingRegister {
112 tx: oneshot::Sender<RegisterOutcome>,
113 target_channel: String,
114}
115
116struct RouteEntry {
117 origin_channel: String,
118 target_channel: String,
119}
120
121/// Shared state behind the [`CommandRegistry`] Arc.
122struct Inner {
123 id: String,
124 router_channel: Option<String>,
125 local: Mutex<HashMap<String, LocalEntry>>,
126 /// command id -> owning channel id
127 remote: Mutex<HashMap<String, String>>,
128 /// command id -> advertised definition (description + schema)
129 /// kept parallel to `remote` so `list_commands` can render
130 /// the same richness for remote entries as for local ones.
131 remote_defs: Mutex<HashMap<String, CommandDef>>,
132 channels: Mutex<HashMap<String, Arc<dyn CommandChannel>>>,
133 execute_replies: TtlMap<MessageId, PendingExecute>,
134 register_replies: TtlMap<MessageId, PendingRegister>,
135 routes: TtlMap<MessageId, RouteEntry>,
136 seen_events: TtlMap<MessageId, ()>,
137 /// Event listeners keyed by event id then by a monotonic token, so
138 /// `add_event_listener` can return an unsubscribe closure that
139 /// removes just that one listener. `BTreeMap` preserves insertion
140 /// order (tokens are monotonically increasing) for dispatch.
141 event_listeners: Mutex<HashMap<String, BTreeMap<u64, EventListener>>>,
142 /// Monotonically-increasing token used to key event listeners.
143 next_listener_token: AtomicU64,
144 /// Per-channel in-flight cap, copied from [`Config`].
145 max_in_flight_per_channel: usize,
146}
147
148/// The main entry point of the crate.
149///
150/// A registry is cheap to clone: internally it's an `Arc<Inner>`.
151#[derive(Clone)]
152pub struct CommandRegistry {
153 inner: Arc<Inner>,
154}
155
156impl CommandRegistry {
157 pub fn new(cfg: Config) -> Self {
158 // On TTL expiry, deliver an explicit Timeout outcome on the
159 // pending oneshot so the caller's `rx.await` resolves to
160 // `CommandError::Timeout` rather than blocking forever (or
161 // collapsing to a generic `ChannelDisconnected` when the tx is
162 // dropped). Eviction is lazy — driven by `sweep_expired()`
163 // calls in the driver loop and registry hot paths.
164 let execute_replies =
165 TtlMap::new(cfg.request_ttl).with_on_expire(|_, pending: PendingExecute| {
166 let _ = pending.tx.send(ExecuteResult::Err {
167 ok: False,
168 error: ExecuteError {
169 code: ExecuteErrorCode::Timeout,
170 message: "request timed out".into(),
171 },
172 });
173 });
174 let register_replies =
175 TtlMap::new(cfg.request_ttl).with_on_expire(|_, pending: PendingRegister| {
176 let _ = pending.tx.send(RegisterOutcome::Timeout);
177 });
178
179 let inner = Arc::new(Inner {
180 id: cfg.id.unwrap_or_else(|| Uuid::new_v4().to_string()),
181 router_channel: cfg.router_channel,
182 local: Mutex::new(HashMap::new()),
183 remote: Mutex::new(HashMap::new()),
184 remote_defs: Mutex::new(HashMap::new()),
185 channels: Mutex::new(HashMap::new()),
186 execute_replies,
187 register_replies,
188 routes: TtlMap::new(cfg.request_ttl),
189 seen_events: TtlMap::new(cfg.event_ttl),
190 event_listeners: Mutex::new(HashMap::new()),
191 next_listener_token: AtomicU64::new(0),
192 max_in_flight_per_channel: cfg.max_in_flight_per_channel,
193 });
194 Self { inner }
195 }
196
197 /// Returns this registry's identifier.
198 pub fn id(&self) -> &str {
199 &self.inner.id
200 }
201
202 /// Returns the ids of every currently-registered channel, sorted.
203 ///
204 /// Mirrors the TypeScript library's `listChannels()` method.
205 pub fn list_channels(&self) -> Vec<String> {
206 let mut ids: Vec<String> = self.inner.channels.lock().keys().cloned().collect();
207 ids.sort();
208 ids
209 }
210
211 /// Returns the full [`CommandDef`] (id + description + schema) for
212 /// every reachable command — local (non-private) and remote. Remote
213 /// defs are those advertised via `register.command.request` or
214 /// `list.commands.response` on the channel.
215 ///
216 /// Mirrors the TypeScript library's `listCommands()` method.
217 /// Results are sorted by id. A command id is only included once even
218 /// if both a local and remote entry exist (local wins).
219 pub fn list_commands(&self) -> Vec<CommandDef> {
220 let mut out: HashMap<String, CommandDef> = HashMap::new();
221 for (id, entry) in self.inner.local.lock().iter() {
222 if !entry.is_private {
223 out.insert(id.clone(), entry.def.clone());
224 }
225 }
226 for (id, def) in self.inner.remote_defs.lock().iter() {
227 out.entry(id.clone()).or_insert_with(|| def.clone());
228 }
229 let mut v: Vec<CommandDef> = out.into_values().collect();
230 v.sort_by(|a, b| a.id.cmp(&b.id));
231 v
232 }
233
234 /// Register a command on this registry.
235 ///
236 /// The single registration entry point, covering both compile-time
237 /// and runtime commands:
238 ///
239 /// - **Compile-time**: pass an instance of a type that implements
240 /// [`Command`]. The `#[command]` / `#[command_service]` macros
241 /// generate such types from a plain `async fn`.
242 /// - **Runtime**: pass a [`DynCommand`] carrying owned id /
243 /// description / schema and a closure handler.
244 ///
245 /// Mirrors the TypeScript library's `registerCommand`.
246 ///
247 /// - Commands whose id starts with `_` stay local: they are never
248 /// escalated to a `router_channel` and never advertised to peers
249 /// via `list.commands.response`.
250 /// - Non-private commands are escalated upstream if this registry
251 /// has a `router_channel`; the local entry is only committed
252 /// after the router acks.
253 /// - The advertised schema is normalized via
254 /// [`crate::schema::normalize_schema`] on the way in, so every
255 /// schema leaving the registry is language-agnostic JSON Schema
256 /// regardless of how the caller built it.
257 pub async fn register_command<C: Command>(&self, cmd: C) -> Result<(), CommandError> {
258 let id = cmd.id().to_string();
259 let description = cmd.description().map(str::to_string);
260 let schema = cmd.schema().map(crate::schema::normalize_command_schema);
261 let is_private = id.starts_with('_');
262 let def = CommandDef {
263 id: id.clone(),
264 description,
265 schema,
266 };
267 let handler: Arc<HandlerFn> = Arc::new({
268 let cmd = Arc::new(cmd);
269 move |value: Value| {
270 let cmd = cmd.clone();
271 async move {
272 let req: C::Request =
273 serde_json::from_value(value).map_err(|e| ExecuteError {
274 code: ExecuteErrorCode::InvalidRequest,
275 message: e.to_string(),
276 })?;
277 let res = cmd
278 .handle(req)
279 .await
280 .map_err(|e| command_error_to_execute(&e, cmd.id()))?;
281 serde_json::to_value(res).map_err(|e| ExecuteError {
282 code: ExecuteErrorCode::InternalError,
283 message: e.to_string(),
284 })
285 }
286 .boxed()
287 }
288 });
289 self.register_inner(id, handler, def, is_private).await
290 }
291
292 async fn register_inner(
293 &self,
294 id: String,
295 handler: Arc<HandlerFn>,
296 def: CommandDef,
297 is_private: bool,
298 ) -> Result<(), CommandError> {
299 self.inner.execute_replies.sweep_expired();
300 self.inner.register_replies.sweep_expired();
301 // Duplicate check against the local table.
302 if self.inner.local.lock().contains_key(&id) {
303 return Err(CommandError::DuplicateCommand(id));
304 }
305
306 // Non-private commands escalate to the router before being added.
307 if !is_private {
308 if let Some(router_id) = self.inner.router_channel.clone() {
309 let router_ch = self.inner.channels.lock().get(&router_id).cloned();
310 if let Some(router_ch) = router_ch {
311 let req_id = MessageId::new_v4();
312 let (tx, rx) = oneshot::channel();
313 self.inner.register_replies.insert(
314 req_id,
315 PendingRegister {
316 tx,
317 target_channel: router_id.clone(),
318 },
319 );
320 router_ch
321 .send(Message::RegisterCommandRequest {
322 id: req_id,
323 meta: None,
324 command: def.clone(),
325 })
326 .map_err(|_| CommandError::ChannelDisconnected)?;
327 match rx.await {
328 Ok(RegisterOutcome::Wire(RegisterResult::Ok { .. })) => {}
329 Ok(RegisterOutcome::Wire(RegisterResult::Err { error, .. })) => {
330 return Err(match error {
331 RegisterErrorCode::DuplicateCommand => {
332 CommandError::DuplicateCommand(id)
333 }
334 });
335 }
336 Ok(RegisterOutcome::Timeout) => return Err(CommandError::Timeout),
337 Ok(RegisterOutcome::Disconnected) | Err(_) => {
338 return Err(CommandError::ChannelDisconnected);
339 }
340 }
341 }
342 }
343 }
344
345 self.inner.local.lock().insert(
346 id,
347 LocalEntry {
348 handler,
349 def,
350 is_private,
351 },
352 );
353 Ok(())
354 }
355
356 /// Connects a [`CommandChannel`] to this registry.
357 ///
358 /// Returns a driver future which must be polled by the caller's
359 /// executor (via `tokio::spawn`, `smol::spawn`,
360 /// `futures::executor::block_on`, …) for the registry to exchange
361 /// messages with the peer. The future completes when the channel
362 /// closes.
363 ///
364 /// Handler dispatch is **concurrent within the driver task**: the
365 /// pump pushes each incoming message's handler future into a
366 /// [`FuturesUnordered`] and cooperatively interleaves them with
367 /// the next `recv`, so a slow handler that awaits external work
368 /// no longer blocks subsequent messages on the same channel. The
369 /// number of simultaneously in-flight handlers is capped by
370 /// [`Config::max_in_flight_per_channel`] (default 256); at the
371 /// cap the pump applies backpressure to the channel rather than
372 /// dropping messages. For true multi-thread parallelism, wrap
373 /// the handler body in your runtime's `spawn` (e.g.
374 /// `tokio::spawn`) — the crate itself stays runtime-agnostic.
375 pub async fn register_channel(
376 &self,
377 channel: Arc<dyn CommandChannel>,
378 ) -> Result<impl Future<Output = ()> + Send + 'static, ChannelError> {
379 let id = channel.id().to_string();
380 {
381 let mut chans = self.inner.channels.lock();
382 if chans.contains_key(&id) {
383 return Err(ChannelError::Other(format!(
384 "channel with id `{id}` already registered"
385 )));
386 }
387 chans.insert(id.clone(), channel.clone());
388 }
389
390 channel.start().await?;
391
392 // Ask the peer for its command list. The response is handled
393 // by the driver loop, which will register each entry as remote.
394 if let Err(e) = channel.send(Message::ListCommandsRequest {
395 id: MessageId::new_v4(),
396 meta: None,
397 }) {
398 self.inner.channels.lock().remove(&id);
399 return Err(e);
400 }
401
402 let inner = self.inner.clone();
403 let ch = channel;
404 Ok(async move {
405 // In-flight handler futures cooperatively interleaved with
406 // message recv. This is what makes the pump non-blocking:
407 // when one handler awaits (sleep / network / forwarded
408 // call), the executor polls other handlers and the recv
409 // future on the same task, so a slow handler can no
410 // longer wedge every subsequent message behind it.
411 //
412 // Stays runtime-agnostic (no `tokio::spawn`); the user's
413 // executor still owns the single task driving this pump.
414 // True multi-thread parallelism is possible if the user
415 // wraps their handler body in `tokio::spawn` themselves.
416 let mut in_flight: FuturesUnordered<BoxFuture<'static, ()>> = FuturesUnordered::new();
417 let cap = inner.max_in_flight_per_channel;
418
419 loop {
420 // Backpressure: while at capacity, drain in-flight
421 // handlers instead of accepting new messages. Ordering
422 // of message *reception* is preserved — only dispatch
423 // fans out. Messages are never dropped.
424 while cap > 0 && in_flight.len() >= cap {
425 if in_flight.next().await.is_none() {
426 break;
427 }
428 }
429
430 // Opportunistic TTL sweep: any registry activity
431 // triggers a pass over the pending-reply / route maps
432 // so timed-out entries fire their on_expire callbacks
433 // (delivering Timeout to waiting callers). Cheap —
434 // O(n) over maps that are normally tiny.
435 inner.execute_replies.sweep_expired();
436 inner.register_replies.sweep_expired();
437 inner.routes.sweep_expired();
438
439 // Wait for either the next message OR for an in-flight
440 // handler to make progress. When `in_flight` is empty
441 // we'd otherwise busy-loop on its `.next()` returning
442 // `None`, so park on `pending()` in that case.
443 let next_msg = if in_flight.is_empty() {
444 ch.recv().await
445 } else {
446 futures::select_biased! {
447 // Drain finished handlers so the in_flight set
448 // shrinks promptly. `select_next_some` skips
449 // the empty case, but we've already guarded
450 // against that above.
451 _ = in_flight.select_next_some() => continue,
452 msg = ch.recv().fuse() => msg,
453 }
454 };
455
456 let Some(msg) = next_msg else { break };
457 in_flight.push(Inner::handle_message(inner.clone(), ch.clone(), msg).boxed());
458 }
459
460 // Channel closed — drain any in-flight handlers so their
461 // outgoing responses are sent before we tear the channel
462 // state down. Handlers that try to `send` on the now-closed
463 // channel will get `Err(Closed)` from the transport, which
464 // is the existing semantics.
465 while in_flight.next().await.is_some() {}
466 Inner::handle_channel_close(&inner, ch.id());
467 })
468 }
469
470 /// Executes a command identified by a compile-time [`Command`] type
471 /// — the **strict** form, giving the same compile-time type safety
472 /// that TypeScript's strict-mode `executeCommand<K>` gives via the
473 /// `CommandSchemaMap` type parameter.
474 ///
475 /// The command id comes from `C::ID`, the request type is pinned to
476 /// `C::Request`, and the response type is pinned to `C::Response`,
477 /// so the compiler rejects mismatches at the call site:
478 ///
479 /// ```ignore
480 /// let sum: i64 = registry.execute::<MathAdd>(AddReq { a: 2, b: 3 }).await?;
481 /// ```
482 ///
483 /// For commands whose id or payload shape is only known at runtime
484 /// (scripting hosts, FFI, plugins that advertise their own schema),
485 /// use [`execute_dyn`](Self::execute_dyn).
486 pub async fn execute<C: Command>(
487 &self,
488 request: C::Request,
489 ) -> Result<C::Response, CommandError>
490 where
491 C::Request: Serialize,
492 C::Response: serde::de::DeserializeOwned,
493 {
494 let req_value = value_from_request(&request)?;
495 let result = self.execute_raw_impl(C::ID.to_string(), req_value).await?;
496 let deserialized = serde_json::from_value(result.unwrap_or(Value::Null))?;
497 Ok(deserialized)
498 }
499
500 /// Executes a command whose id is only known at runtime — the
501 /// **loose** form, mirroring the TypeScript library's
502 /// `executeCommand(id, args)` in loose mode.
503 ///
504 /// Request and response are raw [`serde_json::Value`]s, so this is
505 /// the canonical entry point for plugin hosts, scripting runtimes,
506 /// FFI bridges, and any code where the schema is discovered via
507 /// [`list_commands`](Self::list_commands) rather than declared at
508 /// compile time.
509 ///
510 /// For statically-known commands, prefer [`execute`](Self::execute)
511 /// — it pins both types via the [`Command`] trait.
512 pub async fn execute_dyn(
513 &self,
514 command_id: &str,
515 request: Value,
516 ) -> Result<Value, CommandError> {
517 let result = self
518 .execute_raw_impl(command_id.to_string(), request)
519 .await?;
520 Ok(result.unwrap_or(Value::Null))
521 }
522
523 async fn execute_raw_impl(
524 &self,
525 command_id: String,
526 request: Value,
527 ) -> Result<Option<Value>, CommandError> {
528 // Flush any expired pending replies so stale entries fire their
529 // Timeout on_expire before we enqueue a new one.
530 self.inner.execute_replies.sweep_expired();
531 self.inner.register_replies.sweep_expired();
532 // 1) Local handler wins.
533 let local_handler = self
534 .inner
535 .local
536 .lock()
537 .get(&command_id)
538 .map(|entry| entry.handler.clone());
539 if let Some(handler) = local_handler {
540 return handler(request)
541 .await
542 .map(Some)
543 .map_err(|e| e.into_command_error(&command_id));
544 }
545
546 // 2) Known remote command.
547 let remote_target = self.inner.remote.lock().get(&command_id).cloned();
548 let target = match remote_target {
549 Some(t) => Some(t),
550 None => self.inner.router_channel.clone(),
551 };
552
553 let Some(target_id) = target else {
554 return Err(CommandError::NotFound(command_id));
555 };
556
557 let channel = self.inner.channels.lock().get(&target_id).cloned();
558 let Some(channel) = channel else {
559 return Err(CommandError::ChannelDisconnected);
560 };
561
562 self.forward_execute(command_id, request, &channel, target_id)
563 .await
564 }
565
566 async fn forward_execute(
567 &self,
568 command_id: String,
569 request: Value,
570 channel: &Arc<dyn CommandChannel>,
571 target_id: String,
572 ) -> Result<Option<Value>, CommandError> {
573 let req_id = MessageId::new_v4();
574 let (tx, rx) = oneshot::channel();
575 self.inner.execute_replies.insert(
576 req_id,
577 PendingExecute {
578 tx,
579 target_channel: target_id,
580 },
581 );
582 channel
583 .send(Message::ExecuteCommandRequest {
584 id: req_id,
585 meta: None,
586 command_id: command_id.clone(),
587 // Void requests are elided from the wire (Null → None)
588 // so peers expecting an absent `request` field (per the
589 // JSON Schema spec) don't see `"request": null`.
590 request: value_to_wire(request),
591 })
592 .map_err(|_| CommandError::ChannelDisconnected)?;
593
594 match rx.await {
595 Ok(ExecuteResult::Ok { result, .. }) => Ok(result),
596 Ok(ExecuteResult::Err { error, .. }) => Err(error_to_command_error(error, &command_id)),
597 Err(_) => {
598 self.inner.execute_replies.remove(&req_id);
599 Err(CommandError::ChannelDisconnected)
600 }
601 }
602 }
603
604 /// Emit an event. Dispatches to local listeners and — unless the
605 /// event id is private (starts with `_`) — broadcasts to every
606 /// connected channel.
607 ///
608 /// Works for both compile-time events (`#[event]`-annotated
609 /// structs) and runtime events ([`DynEvent`](crate::command::Command)
610 /// — actually [`DynEvent`](crate::event::DynEvent)). Id, description,
611 /// and schema are all read off the event instance.
612 pub fn emit<E: Event>(&self, event: E) -> Result<(), CommandError> {
613 let event_id = event.id().to_string();
614 let payload_value = serde_json::to_value(&event)?;
615 let msg_id = MessageId::new_v4();
616 self.inner.seen_events.insert(msg_id, ());
617
618 self.dispatch_event_locally(&event_id, &payload_value);
619
620 if !event_id.starts_with('_') {
621 let channels: Vec<Arc<dyn CommandChannel>> =
622 self.inner.channels.lock().values().cloned().collect();
623 // Void payloads (serde `()` → `Value::Null`) are elided
624 // from the wire per the event schema (`payload` is optional).
625 let wire_payload = value_to_wire(payload_value);
626 for ch in channels {
627 let _ = ch.send(Message::Event {
628 id: msg_id,
629 meta: None,
630 event_id: event_id.clone(),
631 payload: wire_payload.clone(),
632 });
633 }
634 }
635 Ok(())
636 }
637
638 /// Subscribe a typed listener. The callback receives a
639 /// deserialized `E` every time an event with id `E::ID` fires,
640 /// whether emitted locally or received from a connected channel.
641 ///
642 /// Returns an unsubscribe closure — call it (and drop it) to
643 /// remove just this listener. Ignoring the return value is fine;
644 /// the listener then lives for the life of the registry.
645 ///
646 /// Listeners for the same event fire in insertion order. Payloads
647 /// that fail to deserialize into `E` are silently dropped for
648 /// this listener — they still flow to any typed-for-Value
649 /// listeners registered via [`on_dyn`](Self::on_dyn).
650 pub fn on<E: Event + serde::de::DeserializeOwned>(
651 &self,
652 listener: impl Fn(E) + Send + Sync + 'static,
653 ) -> impl FnOnce() + Send + Sync + 'static {
654 self.install_listener(E::ID, move |value| {
655 if let Ok(typed) = serde_json::from_value::<E>(value) {
656 listener(typed);
657 }
658 })
659 }
660
661 /// Subscribe a dynamic listener by runtime id. The callback
662 /// receives the raw JSON payload. Use this when the event id is
663 /// only known at runtime (plugin runtimes, FFI, scripting hosts);
664 /// prefer [`on`](Self::on) whenever you have a compile-time
665 /// [`Event`] type.
666 ///
667 /// Same unsubscribe semantics as [`on`](Self::on).
668 pub fn on_dyn<F>(
669 &self,
670 event_id: impl Into<String>,
671 listener: F,
672 ) -> impl FnOnce() + Send + Sync + 'static
673 where
674 F: Fn(Value) + Send + Sync + 'static,
675 {
676 self.install_listener(&event_id.into(), listener)
677 }
678
679 fn install_listener<F>(
680 &self,
681 event_id: &str,
682 listener: F,
683 ) -> impl FnOnce() + Send + Sync + 'static
684 where
685 F: Fn(Value) + Send + Sync + 'static,
686 {
687 let token = self
688 .inner
689 .next_listener_token
690 .fetch_add(1, Ordering::Relaxed);
691 self.inner
692 .event_listeners
693 .lock()
694 .entry(event_id.to_string())
695 .or_default()
696 .insert(token, Arc::new(listener));
697
698 let inner = Arc::clone(&self.inner);
699 let event_id = event_id.to_string();
700 move || {
701 let mut map = inner.event_listeners.lock();
702 if let Some(slot) = map.get_mut(&event_id) {
703 slot.remove(&token);
704 if slot.is_empty() {
705 map.remove(&event_id);
706 }
707 }
708 }
709 }
710
711 fn dispatch_event_locally(&self, event_id: &str, payload: &Value) {
712 let listeners: Vec<EventListener> = self
713 .inner
714 .event_listeners
715 .lock()
716 .get(event_id)
717 .map(|m| m.values().cloned().collect())
718 .unwrap_or_default();
719 for l in listeners {
720 l(payload.clone());
721 }
722 }
723
724 /// Tears down the registry: awaits `close()` on every connected
725 /// channel, drops every local and remote command, and clears all
726 /// event listeners. In-flight executes and register requests fail
727 /// with [`CommandError::ChannelDisconnected`] via the existing
728 /// channel-close path.
729 ///
730 /// Mirrors the TypeScript library's `dispose()`, but async so that
731 /// transports doing real teardown work (HTTP flush, MCP goodbye,
732 /// plugin sandbox shutdown) complete before this returns.
733 ///
734 /// Callers normally don't need this — dropping the last
735 /// `CommandRegistry` clone releases the inner state automatically
736 /// via `Drop`. Use `dispose` when a *shared* registry (held through
737 /// multiple clones) needs to be forcibly torn down, or in tests.
738 pub async fn dispose(&self) {
739 // Snapshot channel arcs so we can call `close` without holding
740 // the channels lock for the duration.
741 let channels: Vec<Arc<dyn CommandChannel>> = {
742 let mut locked = self.inner.channels.lock();
743 let out: Vec<_> = locked.values().cloned().collect();
744 locked.clear();
745 out
746 };
747
748 // Await each channel's close sequentially. Channels define
749 // their own close semantics (InMemoryChannel is effectively
750 // synchronous; MCPServerChannel flushes the transport; Flow's
751 // SourceChannel tears down its QuickJS VM), so we let every
752 // implementation finish its teardown before returning.
753 for ch in channels {
754 ch.close().await;
755 }
756
757 self.inner.local.lock().clear();
758 self.inner.remote.lock().clear();
759 self.inner.remote_defs.lock().clear();
760 self.inner.event_listeners.lock().clear();
761 }
762}
763
764impl Inner {
765 fn local_command_defs(&self) -> Vec<CommandDef> {
766 self.local
767 .lock()
768 .values()
769 .filter(|e| !e.is_private)
770 .map(|e| e.def.clone())
771 .collect()
772 }
773
774 /// Central dispatcher invoked by each channel's driver loop.
775 async fn handle_message(inner: Arc<Self>, channel: Arc<dyn CommandChannel>, msg: Message) {
776 match msg {
777 Message::RegisterCommandRequest { id, command, .. } => {
778 Self::handle_register_request(inner, channel, id, command).await;
779 }
780 Message::RegisterCommandResponse { thid, response, .. } => {
781 if let Some(pending) = inner.register_replies.remove(&thid) {
782 let _ = pending.tx.send(RegisterOutcome::Wire(response));
783 }
784 }
785 Message::ListCommandsRequest { id, .. } => {
786 let commands = inner.local_command_defs();
787 let _ = channel.send(Message::ListCommandsResponse {
788 id: MessageId::new_v4(),
789 meta: None,
790 thid: id,
791 commands,
792 });
793 }
794 Message::ListCommandsResponse { commands, .. } => {
795 let channel_id = channel.id().to_string();
796 let mut remote = inner.remote.lock();
797 let mut remote_defs = inner.remote_defs.lock();
798 for cmd in commands {
799 // Normalize ingested schemas so the local cache
800 // matches what register() produces.
801 let cmd = CommandDef {
802 id: cmd.id,
803 description: cmd.description,
804 schema: cmd.schema.map(crate::schema::normalize_command_schema),
805 };
806 let entry_is_new = !remote.contains_key(&cmd.id);
807 if entry_is_new {
808 remote.insert(cmd.id.clone(), channel_id.clone());
809 }
810 // Always refresh the def (the latest advertisement wins).
811 remote_defs.insert(cmd.id.clone(), cmd);
812 }
813 }
814 Message::ExecuteCommandRequest {
815 id,
816 command_id,
817 request,
818 ..
819 } => {
820 Self::handle_execute_request(
821 inner,
822 channel,
823 id,
824 command_id,
825 request.unwrap_or(Value::Null),
826 )
827 .await;
828 }
829 Message::ExecuteCommandResponse { thid, response, .. } => {
830 Self::handle_execute_response(&inner, thid, response);
831 }
832 Message::Event {
833 id,
834 event_id,
835 payload,
836 ..
837 } => {
838 Self::handle_event(&inner, channel, id, event_id, payload);
839 }
840 }
841 }
842
843 async fn handle_register_request(
844 inner: Arc<Self>,
845 channel: Arc<dyn CommandChannel>,
846 req_id: MessageId,
847 command: CommandDef,
848 ) {
849 // Normalize ingested schemas so our cached copy is guaranteed
850 // to be language-agnostic JSON Schema, even if the peer didn't
851 // normalize on its side.
852 let command = CommandDef {
853 id: command.id,
854 description: command.description,
855 schema: command.schema.map(crate::schema::normalize_command_schema),
856 };
857 let channel_id = channel.id().to_string();
858 let command_id = command.id.clone();
859
860 // Duplicate against local?
861 let dup = inner.local.lock().contains_key(&command_id);
862 if dup {
863 let _ = channel.send(Message::RegisterCommandResponse {
864 id: MessageId::new_v4(),
865 meta: None,
866 thid: req_id,
867 response: RegisterResult::Err {
868 ok: False,
869 error: RegisterErrorCode::DuplicateCommand,
870 },
871 });
872 return;
873 }
874
875 // Already known in the remote table?
876 // - Same channel re-advertising: short-circuit with a success
877 // ack, do NOT re-escalate upstream. A re-escalation would
878 // bubble up as a duplicate rejection from the router and
879 // spuriously fail a legitimate re-registration (common when
880 // a plugin host re-advertises its command list).
881 // - Different channel claiming the same id: reject as duplicate.
882 let existing_owner = inner.remote.lock().get(&command_id).cloned();
883 match existing_owner {
884 Some(owner) if owner == channel_id => {
885 // Refresh the cached def and re-ack. No upstream traffic.
886 inner.remote_defs.lock().insert(command_id, command);
887 let _ = channel.send(Message::RegisterCommandResponse {
888 id: MessageId::new_v4(),
889 meta: None,
890 thid: req_id,
891 response: RegisterResult::Ok { ok: True },
892 });
893 return;
894 }
895 Some(_) => {
896 let _ = channel.send(Message::RegisterCommandResponse {
897 id: MessageId::new_v4(),
898 meta: None,
899 thid: req_id,
900 response: RegisterResult::Err {
901 ok: False,
902 error: RegisterErrorCode::DuplicateCommand,
903 },
904 });
905 return;
906 }
907 None => {}
908 }
909
910 // Escalate upstream if we have a router.
911 if let Some(router_id) = inner.router_channel.clone() {
912 if router_id != channel_id {
913 let router_ch = inner.channels.lock().get(&router_id).cloned();
914 if let Some(router_ch) = router_ch {
915 let up_id = MessageId::new_v4();
916 let (tx, rx) = oneshot::channel();
917 inner.register_replies.insert(
918 up_id,
919 PendingRegister {
920 tx,
921 target_channel: router_id,
922 },
923 );
924 if router_ch
925 .send(Message::RegisterCommandRequest {
926 id: up_id,
927 meta: None,
928 command: command.clone(),
929 })
930 .is_ok()
931 {
932 let up = rx.await;
933 match up {
934 Ok(RegisterOutcome::Wire(RegisterResult::Ok { .. })) => {}
935 Ok(RegisterOutcome::Wire(RegisterResult::Err { error, .. })) => {
936 let _ = channel.send(Message::RegisterCommandResponse {
937 id: MessageId::new_v4(),
938 meta: None,
939 thid: req_id,
940 response: RegisterResult::Err { ok: False, error },
941 });
942 return;
943 }
944 // Timeout / disconnected upstream: we have no
945 // wire-level error code for these on the
946 // register response, so surface them as
947 // duplicate_command to match the prior
948 // behaviour. The escalating caller's own
949 // register_command call will see the correct
950 // Timeout / ChannelDisconnected via its own
951 // pending entry.
952 Ok(RegisterOutcome::Timeout)
953 | Ok(RegisterOutcome::Disconnected)
954 | Err(_) => {
955 let _ = channel.send(Message::RegisterCommandResponse {
956 id: MessageId::new_v4(),
957 meta: None,
958 thid: req_id,
959 response: RegisterResult::Err {
960 ok: False,
961 error: RegisterErrorCode::DuplicateCommand,
962 },
963 });
964 return;
965 }
966 }
967 }
968 }
969 }
970 }
971
972 inner.remote.lock().insert(command_id.clone(), channel_id);
973 inner.remote_defs.lock().insert(command_id, command);
974 let _ = channel.send(Message::RegisterCommandResponse {
975 id: MessageId::new_v4(),
976 meta: None,
977 thid: req_id,
978 response: RegisterResult::Ok { ok: True },
979 });
980 }
981
982 async fn handle_execute_request(
983 inner: Arc<Self>,
984 origin: Arc<dyn CommandChannel>,
985 req_id: MessageId,
986 command_id: String,
987 request: Value,
988 ) {
989 // Local handler?
990 let handler = inner
991 .local
992 .lock()
993 .get(&command_id)
994 .map(|e| e.handler.clone());
995 if let Some(handler) = handler {
996 let result = handler(request).await;
997 let response = match result {
998 Ok(v) => ExecuteResult::Ok {
999 ok: True,
1000 // Void responses (`() → Value::Null`) are elided from
1001 // the wire per the response schema (`result` optional).
1002 result: value_to_wire(v),
1003 },
1004 Err(error) => ExecuteResult::Err { ok: False, error },
1005 };
1006 let _ = origin.send(Message::ExecuteCommandResponse {
1007 id: MessageId::new_v4(),
1008 meta: None,
1009 thid: req_id,
1010 response,
1011 });
1012 return;
1013 }
1014
1015 // Forward?
1016 let target_id = inner
1017 .remote
1018 .lock()
1019 .get(&command_id)
1020 .cloned()
1021 .or_else(|| inner.router_channel.clone());
1022
1023 let origin_id = origin.id().to_string();
1024 let Some(target_id) = target_id else {
1025 let _ = origin.send(Message::ExecuteCommandResponse {
1026 id: MessageId::new_v4(),
1027 meta: None,
1028 thid: req_id,
1029 response: ExecuteResult::Err {
1030 ok: False,
1031 error: ExecuteError {
1032 code: ExecuteErrorCode::NotFound,
1033 message: format!("command not found: {command_id}"),
1034 },
1035 },
1036 });
1037 return;
1038 };
1039
1040 if target_id == origin_id {
1041 // Would loop; treat as not found.
1042 let _ = origin.send(Message::ExecuteCommandResponse {
1043 id: MessageId::new_v4(),
1044 meta: None,
1045 thid: req_id,
1046 response: ExecuteResult::Err {
1047 ok: False,
1048 error: ExecuteError {
1049 code: ExecuteErrorCode::NotFound,
1050 message: format!("command not found: {command_id}"),
1051 },
1052 },
1053 });
1054 return;
1055 }
1056
1057 let target = inner.channels.lock().get(&target_id).cloned();
1058 let Some(target) = target else {
1059 let _ = origin.send(Message::ExecuteCommandResponse {
1060 id: MessageId::new_v4(),
1061 meta: None,
1062 thid: req_id,
1063 response: ExecuteResult::Err {
1064 ok: False,
1065 error: ExecuteError {
1066 code: ExecuteErrorCode::ChannelDisconnected,
1067 message: "target channel disconnected".into(),
1068 },
1069 },
1070 });
1071 return;
1072 };
1073
1074 inner.routes.insert(
1075 req_id,
1076 RouteEntry {
1077 origin_channel: origin_id,
1078 target_channel: target_id,
1079 },
1080 );
1081 let _ = target.send(Message::ExecuteCommandRequest {
1082 id: req_id,
1083 meta: None,
1084 command_id,
1085 request: value_to_wire(request),
1086 });
1087 }
1088
1089 fn handle_execute_response(inner: &Arc<Self>, thid: MessageId, response: ExecuteResult) {
1090 // Either this is a reply to a local call…
1091 if let Some(pending) = inner.execute_replies.remove(&thid) {
1092 let _ = pending.tx.send(response);
1093 return;
1094 }
1095
1096 // …or we forwarded this request and need to route the reply.
1097 if let Some(route) = inner.routes.remove(&thid) {
1098 let origin = inner.channels.lock().get(&route.origin_channel).cloned();
1099 if let Some(origin) = origin {
1100 let _ = origin.send(Message::ExecuteCommandResponse {
1101 id: MessageId::new_v4(),
1102 meta: None,
1103 thid,
1104 response,
1105 });
1106 }
1107 }
1108 }
1109
1110 fn handle_event(
1111 inner: &Arc<Self>,
1112 origin: Arc<dyn CommandChannel>,
1113 msg_id: MessageId,
1114 event_id: String,
1115 payload: Option<Value>,
1116 ) {
1117 if inner.seen_events.contains_key(&msg_id) {
1118 return;
1119 }
1120 inner.seen_events.insert(msg_id, ());
1121
1122 let payload_value = payload.clone().unwrap_or(Value::Null);
1123 let listeners: Vec<EventListener> = inner
1124 .event_listeners
1125 .lock()
1126 .get(&event_id)
1127 .map(|m| m.values().cloned().collect())
1128 .unwrap_or_default();
1129 for l in listeners {
1130 l(payload_value.clone());
1131 }
1132
1133 if event_id.starts_with('_') {
1134 return;
1135 }
1136
1137 let channels: Vec<Arc<dyn CommandChannel>> = inner
1138 .channels
1139 .lock()
1140 .iter()
1141 .filter(|(k, _)| k.as_str() != origin.id())
1142 .map(|(_, v)| v.clone())
1143 .collect();
1144 for ch in channels {
1145 let _ = ch.send(Message::Event {
1146 id: msg_id,
1147 meta: None,
1148 event_id: event_id.clone(),
1149 payload: payload.clone(),
1150 });
1151 }
1152 }
1153
1154 /// Invoked by the driver once the channel returns `None` from recv.
1155 fn handle_channel_close(inner: &Arc<Self>, channel_id: &str) {
1156 // Drop the channel from the lookup table.
1157 inner.channels.lock().remove(channel_id);
1158
1159 // Drop every remote command owned by this channel, along with
1160 // its cached definition.
1161 let dropped_ids: Vec<String> = {
1162 let mut remote = inner.remote.lock();
1163 let to_drop: Vec<String> = remote
1164 .iter()
1165 .filter(|(_, owner)| *owner == channel_id)
1166 .map(|(id, _)| id.clone())
1167 .collect();
1168 for id in &to_drop {
1169 remote.remove(id);
1170 }
1171 to_drop
1172 };
1173 let mut remote_defs = inner.remote_defs.lock();
1174 for id in dropped_ids {
1175 remote_defs.remove(&id);
1176 }
1177 drop(remote_defs);
1178
1179 // Reject any pending executes whose response was expected from
1180 // this channel.
1181 let exec_ids: Vec<MessageId> = inner
1182 .execute_replies
1183 .snapshot_keys_where(|v| v.target_channel == channel_id);
1184 for id in exec_ids {
1185 if let Some(pending) = inner.execute_replies.remove(&id) {
1186 let _ = pending.tx.send(ExecuteResult::Err {
1187 ok: False,
1188 error: ExecuteError {
1189 code: ExecuteErrorCode::ChannelDisconnected,
1190 message: "channel disconnected".into(),
1191 },
1192 });
1193 }
1194 }
1195
1196 let reg_ids: Vec<MessageId> = inner
1197 .register_replies
1198 .snapshot_keys_where(|v| v.target_channel == channel_id);
1199 for id in reg_ids {
1200 if let Some(pending) = inner.register_replies.remove(&id) {
1201 let _ = pending.tx.send(RegisterOutcome::Disconnected);
1202 }
1203 }
1204
1205 // For every route where either endpoint is the dead channel,
1206 // notify the origin (if it is still alive).
1207 let route_ids: Vec<MessageId> = inner.routes.snapshot_keys_where(|r| {
1208 r.origin_channel == channel_id || r.target_channel == channel_id
1209 });
1210 for id in route_ids {
1211 if let Some(route) = inner.routes.remove(&id) {
1212 if route.origin_channel == channel_id {
1213 continue;
1214 }
1215 let origin = inner.channels.lock().get(&route.origin_channel).cloned();
1216 if let Some(origin) = origin {
1217 let _ = origin.send(Message::ExecuteCommandResponse {
1218 id: MessageId::new_v4(),
1219 meta: None,
1220 thid: id,
1221 response: ExecuteResult::Err {
1222 ok: False,
1223 error: ExecuteError {
1224 code: ExecuteErrorCode::ChannelDisconnected,
1225 message: "target channel disconnected".into(),
1226 },
1227 },
1228 });
1229 }
1230 }
1231 }
1232 }
1233}
1234
1235// ------- helpers -----------------------------------------------------
1236
1237fn command_error_to_execute(e: &CommandError, command_id: &str) -> ExecuteError {
1238 match e {
1239 CommandError::InvalidRequest { message, .. } => ExecuteError {
1240 code: ExecuteErrorCode::InvalidRequest,
1241 message: message.clone(),
1242 },
1243 CommandError::Internal { message, .. } => ExecuteError {
1244 code: ExecuteErrorCode::InternalError,
1245 message: message.clone(),
1246 },
1247 CommandError::Timeout => ExecuteError {
1248 code: ExecuteErrorCode::Timeout,
1249 message: "request timed out".into(),
1250 },
1251 CommandError::ChannelDisconnected => ExecuteError {
1252 code: ExecuteErrorCode::ChannelDisconnected,
1253 message: "channel disconnected".into(),
1254 },
1255 CommandError::NotFound(id) => ExecuteError {
1256 code: ExecuteErrorCode::NotFound,
1257 message: format!("command not found: {id}"),
1258 },
1259 _ => ExecuteError {
1260 code: ExecuteErrorCode::InternalError,
1261 message: format!("{e} [command {command_id}]"),
1262 },
1263 }
1264}
1265
1266fn error_to_command_error(err: ExecuteError, command_id: &str) -> CommandError {
1267 match err.code {
1268 ExecuteErrorCode::NotFound => CommandError::NotFound(err.message),
1269 ExecuteErrorCode::InvalidRequest => CommandError::InvalidRequest {
1270 command_id: command_id.into(),
1271 message: err.message,
1272 },
1273 ExecuteErrorCode::InternalError => CommandError::Internal {
1274 command_id: command_id.into(),
1275 message: err.message,
1276 },
1277 ExecuteErrorCode::Timeout => CommandError::Timeout,
1278 ExecuteErrorCode::ChannelDisconnected => CommandError::ChannelDisconnected,
1279 }
1280}
1281
1282// Small convenience on ExecuteError.
1283impl ExecuteError {
1284 fn into_command_error(self, command_id: &str) -> CommandError {
1285 error_to_command_error(self, command_id)
1286 }
1287}
1288
1289/// Collapse a serialized request/result/payload value to `None` when it
1290/// is JSON `null`. This is what makes void commands and events
1291/// spec-compliant on the wire: `request` / `result` / `payload` are all
1292/// optional fields in the JSON schemas, so an absent value must be
1293/// encoded by omitting the key, not by emitting `null`.
1294///
1295/// Used on every outgoing `execute.command.request`,
1296/// `execute.command.response` success, and `event` message.
1297fn value_to_wire(v: Value) -> Option<Value> {
1298 if v.is_null() {
1299 None
1300 } else {
1301 Some(v)
1302 }
1303}
1304
1305/// Serialize a strict-mode request value to JSON. Wraps `serde_json`
1306/// with the right error type for the strict `execute::<C>` path.
1307fn value_from_request<T: Serialize>(v: &T) -> Result<Value, CommandError> {
1308 serde_json::to_value(v).map_err(CommandError::Serde)
1309}