Skip to main content

objectiveai_mcp_proxy/
reverse_channel.rs

1//! WS reverse-channel transport for CLI-hosted upstreams.
2//!
3//! When the proxy is embedded in the API (per request), it is handed a
4//! [`ReverseChannel`] — the means to speak the `client_objectiveai_mcp`
5//! protocol over the request's WebSocket. Upstreams whose URL scheme is
6//! `ws` ([`WsUpstream`]) are reached through it instead of over HTTP:
7//!
8//! - `ws://objectiveai` → [`McpKind::ObjectiveAi`]
9//! - `ws:///owner/name/version/mcp` → [`McpKind::Other`]
10//!
11//! Direction split (the API owns the WS itself):
12//! - **send**: the proxy emits a `server_request::Request` into the
13//!   channel's mpsc; the API serializes it onto the shared WS sink.
14//! - **recv**: the API's recv loop demuxes incoming frames by type and
15//!   hands the proxy-bound ones back via [`ReverseChannel::deliver_response`]
16//!   (the 6 MCP `server_response` variants) and
17//!   [`ReverseChannel::deliver_client_request`] (`McpListChanged`). The
18//!   proxy correlates responses to its own outstanding requests by id.
19//!
20//! [`Upstream`] is the proxy's per-upstream handle — either an HTTP
21//! [`Connection`] or a [`WsUpstream`] — exposing the slice of the
22//! `Connection` interface the [`crate::session::Session`] depends on.
23
24use std::sync::Arc;
25use std::time::Duration;
26
27use dashmap::DashMap;
28use indexmap::IndexMap;
29use objectiveai_sdk::client_objectiveai_mcp::{
30    McpKind,
31    client_request::{self, McpListChangedKind},
32    client_response,
33    server_request::{self, InitializeRequest, Request as ServerRequest},
34    server_response::{self, JsonRpcResult, Response as ServerResponse},
35};
36use objectiveai_sdk::mcp::resource::{
37    ListResourcesRequest, ReadResourceRequestParams, ReadResourceResult, Resource,
38};
39use objectiveai_sdk::mcp::tool::{
40    CallToolRequestParams, CallToolResult, ListToolsRequest, Tool,
41};
42use objectiveai_sdk::mcp::{Connection, Error as McpError};
43use tokio::sync::{RwLock, mpsc, oneshot};
44
/// Shared, thread-safe, argument-less callback fired when an upstream
/// reports a list change (same role as the callbacks registered through
/// `Connection::set_on_*_list_changed`).
type ListChangedCb = Arc<dyn Fn() + Send + Sync>;
47
48struct Inner {
49    /// proxy → API → WS. The API drains the paired receiver and writes
50    /// each request onto the shared WS sink.
51    tx: mpsc::UnboundedSender<ServerRequest>,
52    /// Outstanding requests awaiting their `server_response`, by id.
53    pending: DashMap<String, oneshot::Sender<ServerResponse>>,
54    /// Per-upstream round-trip budget.
55    timeout: Duration,
56    /// list-changed callbacks per upstream `McpKind`: `(tools, resources)`.
57    /// Fired when a matching `client_request::McpListChanged` arrives.
58    list_changed: DashMap<McpKind, (Option<ListChangedCb>, Option<ListChangedCb>)>,
59}
60
61/// Cheaply-cloneable handle the proxy uses to speak over the WS.
62#[derive(Clone)]
63pub struct ReverseChannel(Arc<Inner>);
64
65impl std::fmt::Debug for ReverseChannel {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        f.debug_struct("ReverseChannel").finish_non_exhaustive()
68    }
69}
70
71impl ReverseChannel {
72    /// Build a channel. Returns the channel plus the receiver the API
73    /// drains (serializing each `server_request` onto the shared WS sink).
74    pub fn new(timeout: Duration) -> (Self, mpsc::UnboundedReceiver<ServerRequest>) {
75        let (tx, rx) = mpsc::unbounded_channel();
76        let inner = Inner {
77            tx,
78            pending: DashMap::new(),
79            timeout,
80            list_changed: DashMap::new(),
81        };
82        (Self(Arc::new(inner)), rx)
83    }
84
85    /// Emit a `server_request` and await its matching `server_response`,
86    /// bounded by the configured timeout. `id` is minted here; the API's
87    /// recv loop routes the reply back via [`Self::deliver_response`].
88    async fn request(
89        &self,
90        payload: server_request::Payload,
91        headers: IndexMap<String, String>,
92    ) -> Result<ServerResponse, McpError> {
93        let id = uuid::Uuid::new_v4().to_string();
94        let (resp_tx, resp_rx) = oneshot::channel();
95        self.0.pending.insert(id.clone(), resp_tx);
96        let request = ServerRequest {
97            id: id.clone(),
98            headers,
99            payload,
100        };
101        if self.0.tx.send(request).is_err() {
102            self.0.pending.remove(&id);
103            return Err(transport_error("reverse channel closed before send"));
104        }
105        match tokio::time::timeout(self.0.timeout, resp_rx).await {
106            Ok(Ok(response)) => Ok(response),
107            Ok(Err(_)) => {
108                self.0.pending.remove(&id);
109                Err(transport_error("reverse channel dropped before response"))
110            }
111            Err(_) => {
112                self.0.pending.remove(&id);
113                Err(transport_error("reverse channel timed out waiting for response"))
114            }
115        }
116    }
117
118    /// Hand a proxy-bound `server_response` (one of the 6 MCP variants)
119    /// back to the waiter that issued the matching request. Called by the
120    /// API's recv loop. Unknown id → dropped.
121    pub fn deliver_response(&self, response: ServerResponse) {
122        if let Some((_, tx)) = self.0.pending.remove(&response.id) {
123            let _ = tx.send(response);
124        }
125    }
126
127    /// Hand a proxy-bound `client_request` (today only `McpListChanged`)
128    /// to the proxy. Fires the registered list-changed callback for the
129    /// upstream, and returns the ack the API writes back over the WS.
130    pub fn deliver_client_request(
131        &self,
132        request: client_request::Request,
133    ) -> client_response::Response {
134        let client_request::Request { id, payload } = request;
135        match payload {
136            client_request::Payload::McpListChanged(change) => {
137                if let Some(cbs) = self.0.list_changed.get(&change.mcp_kind) {
138                    let cb = match change.kind {
139                        McpListChangedKind::Tools => cbs.0.clone(),
140                        McpListChangedKind::Resources => cbs.1.clone(),
141                    };
142                    drop(cbs);
143                    if let Some(cb) = cb {
144                        cb();
145                    }
146                }
147                client_response::Response::Ok { id }
148            }
149        }
150    }
151
152    fn set_tools_list_changed(&self, mcp_kind: McpKind, cb: ListChangedCb) {
153        let mut entry = self.0.list_changed.entry(mcp_kind).or_default();
154        entry.0 = Some(cb);
155    }
156
157    fn set_resources_list_changed(&self, mcp_kind: McpKind, cb: ListChangedCb) {
158        let mut entry = self.0.list_changed.entry(mcp_kind).or_default();
159        entry.1 = Some(cb);
160    }
161}
162
163/// A `ws://`-scheme upstream, reached over the [`ReverseChannel`]. Mirrors
164/// the slice of [`Connection`]'s interface the [`crate::session::Session`]
165/// uses, translating each op into a `server_request` carrying this
166/// upstream's [`McpKind`].
167pub struct WsUpstream {
168    channel: ReverseChannel,
169    mcp_kind: McpKind,
170    /// The `ws://…` URL this upstream was dialed with (used for filtering).
171    pub url: String,
172    /// Upstream `Mcp-Session-Id` returned by the CLI on `initialize`.
173    pub session_id: String,
174    /// Upstream `server_info.name` / `.version` from the `initialize`
175    /// reply — feeds the session's routing-prefix derivation.
176    server_name: String,
177    server_version: String,
178    /// Whether the upstream advertised the `tools` / `resources`
179    /// capability in its `initialize` reply. We must NOT issue
180    /// `tools/list` / `resources/list` against an upstream that didn't
181    /// advertise the capability: many servers (incl. the test
182    /// fixtures) 404 the un-advertised method, and a hard error there
183    /// fails the whole aggregate — and, on the post-init health probe,
184    /// fails the connect and churns endless re-`initialize`s. Mirrors
185    /// `mcp::Connection::has_{tools,resources}_cap`.
186    has_tools_cap: bool,
187    has_resources_cap: bool,
188    /// Persistent per-upstream headers captured at connect: the per-URL
189    /// set (`Authorization`, custom `X-*`, `X-OBJECTIVEAI-ARGUMENTS`)
190    /// plus whatever identity headers were present at dial. Never
191    /// mutated after connect — mirrors the SDK `Connection`'s base
192    /// `headers`. The transient subset is overridden per request by
193    /// `extra_headers`; the per-URL subset has no overlay key, so it
194    /// always survives on every request.
195    base_headers: IndexMap<String, String>,
196    /// Mutable transient-identity overlay, full-replaced every turn by
197    /// `apply_transient_headers` → `set_extra_headers`. Overrides
198    /// `base_headers` per key (mirrors `Connection::extra_headers`).
199    /// Starts empty: until the first refresh, `base_headers` alone
200    /// carries the dial-time identity headers.
201    extra_headers: RwLock<IndexMap<String, String>>,
202}
203
204impl std::fmt::Debug for WsUpstream {
205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206        f.debug_struct("WsUpstream")
207            .field("url", &self.url)
208            .field("session_id", &self.session_id)
209            .finish_non_exhaustive()
210    }
211}
212
213impl WsUpstream {
214    /// Headers to stamp on every outbound `server_request`, mirroring
215    /// `Connection::build_request_headers` exactly: the persistent
216    /// `base_headers` first, then the `extra_headers` transient overlay
217    /// (overrides per key), then this upstream's `Mcp-Session-Id` last
218    /// (so it can never be shadowed). Per-URL headers live only in
219    /// `base_headers` (no overlay key collides), so they're present on
220    /// EVERY request — identical to the HTTP path.
221    async fn headers(&self) -> IndexMap<String, String> {
222        let mut h = self.base_headers.clone();
223        for (k, v) in self.extra_headers.read().await.iter() {
224            h.insert(k.clone(), v.clone());
225        }
226        h.insert(
227            crate::upstream::MCP_SESSION_ID_KEY.to_string(),
228            self.session_id.clone(),
229        );
230        h
231    }
232
233    pub async fn list_tools(&self) -> Result<Arc<Vec<Tool>>, Arc<McpError>> {
234        // Capability gate — an upstream that didn't advertise `tools`
235        // has no `tools/list`; calling it anyway 404s on most servers.
236        if !self.has_tools_cap {
237            return Ok(Arc::new(Vec::new()));
238        }
239        let headers = self.headers().await;
240        let response = self
241            .channel
242            .request(
243                server_request::Payload::ToolsList {
244                    mcp_kind: self.mcp_kind.clone(),
245                    params: ListToolsRequest { cursor: None },
246                },
247                headers,
248            )
249            .await
250            .map_err(Arc::new)?;
251        match response.payload {
252            server_response::Payload::ToolsList { result, .. } => {
253                Ok(Arc::new(unwrap_rpc(&self.url, result).map_err(Arc::new)?.tools))
254            }
255            other => Err(Arc::new(variant_mismatch(&self.url, "tools_list", &other))),
256        }
257    }
258
259    pub async fn list_resources(&self) -> Result<Arc<Vec<Resource>>, Arc<McpError>> {
260        // Capability gate — an upstream that didn't advertise
261        // `resources` has no `resources/list`; calling it anyway 404s
262        // on most servers (e.g. the tools-only plugin fixtures).
263        if !self.has_resources_cap {
264            return Ok(Arc::new(Vec::new()));
265        }
266        let headers = self.headers().await;
267        let response = self
268            .channel
269            .request(
270                server_request::Payload::ResourcesList {
271                    mcp_kind: self.mcp_kind.clone(),
272                    params: ListResourcesRequest { cursor: None },
273                },
274                headers,
275            )
276            .await
277            .map_err(Arc::new)?;
278        match response.payload {
279            server_response::Payload::ResourcesList { result, .. } => {
280                Ok(Arc::new(unwrap_rpc(&self.url, result).map_err(Arc::new)?.resources))
281            }
282            other => Err(Arc::new(variant_mismatch(&self.url, "resources_list", &other))),
283        }
284    }
285
286    pub async fn call_tool(
287        &self,
288        params: &CallToolRequestParams,
289    ) -> Result<CallToolResult, McpError> {
290        let headers = self.headers().await;
291        let response = self
292            .channel
293            .request(
294                server_request::Payload::ToolsCall {
295                    mcp_kind: self.mcp_kind.clone(),
296                    params: params.clone(),
297                },
298                headers,
299            )
300            .await?;
301        match response.payload {
302            server_response::Payload::ToolsCall { result, .. } => unwrap_rpc(&self.url, result),
303            other => Err(variant_mismatch(&self.url, "tools_call", &other)),
304        }
305    }
306
307    pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, McpError> {
308        let headers = self.headers().await;
309        let response = self
310            .channel
311            .request(
312                server_request::Payload::ResourcesRead {
313                    mcp_kind: self.mcp_kind.clone(),
314                    params: ReadResourceRequestParams {
315                        uri: uri.to_string(),
316                    },
317                },
318                headers,
319            )
320            .await?;
321        match response.payload {
322            server_response::Payload::ResourcesRead { result, .. } => unwrap_rpc(&self.url, result),
323            other => Err(variant_mismatch(&self.url, "resources_read", &other)),
324        }
325    }
326
327    pub async fn delete(&self) -> Result<(), McpError> {
328        let headers = self.headers().await;
329        let response = self
330            .channel
331            .request(
332                server_request::Payload::SessionTerminate {
333                    mcp_kind: self.mcp_kind.clone(),
334                },
335                headers,
336            )
337            .await?;
338        match response.payload {
339            server_response::Payload::SessionTerminate { result, .. } => unwrap_rpc(&self.url, result),
340            other => Err(variant_mismatch(&self.url, "session_terminate", &other)),
341        }
342    }
343
344    pub fn set_on_tools_list_changed<F>(&self, callback: F)
345    where
346        F: Fn() + Send + Sync + 'static,
347    {
348        self.channel
349            .set_tools_list_changed(self.mcp_kind.clone(), Arc::new(callback));
350    }
351
352    pub fn set_on_resources_list_changed<F>(&self, callback: F)
353    where
354        F: Fn() + Send + Sync + 'static,
355    {
356        self.channel
357            .set_resources_list_changed(self.mcp_kind.clone(), Arc::new(callback));
358    }
359
360    pub async fn set_extra_headers(&self, extras: IndexMap<String, String>) {
361        *self.extra_headers.write().await = extras;
362    }
363}
364
365/// A per-upstream handle: HTTP [`Connection`] or WS [`WsUpstream`]. Exposes
366/// exactly the surface [`crate::session::Session`] + `handle_delete` use.
367#[derive(Debug)]
368pub enum Upstream {
369    Http(Connection),
370    Ws(WsUpstream),
371}
372
373impl Upstream {
374    pub fn url(&self) -> &str {
375        match self {
376            Upstream::Http(c) => &c.url,
377            Upstream::Ws(w) => &w.url,
378        }
379    }
380
381    pub fn session_id(&self) -> &str {
382        match self {
383            Upstream::Http(c) => &c.session_id,
384            Upstream::Ws(w) => &w.session_id,
385        }
386    }
387
388    /// Upstream `server_info.name` — used to derive the session's routing
389    /// prefix. (`Connection` exposes it via `initialize_result`.)
390    pub fn server_name(&self) -> &str {
391        match self {
392            Upstream::Http(c) => &c.initialize_result.server_info.name,
393            Upstream::Ws(w) => &w.server_name,
394        }
395    }
396
397    /// Upstream `server_info.version` — the prefix collision tie-breaker.
398    pub fn server_version(&self) -> &str {
399        match self {
400            Upstream::Http(c) => &c.initialize_result.server_info.version,
401            Upstream::Ws(w) => &w.server_version,
402        }
403    }
404
405    pub async fn list_tools(&self) -> Result<Arc<Vec<Tool>>, Arc<McpError>> {
406        match self {
407            Upstream::Http(c) => c.list_tools().await,
408            Upstream::Ws(w) => w.list_tools().await,
409        }
410    }
411
412    pub async fn list_resources(&self) -> Result<Arc<Vec<Resource>>, Arc<McpError>> {
413        match self {
414            Upstream::Http(c) => c.list_resources().await,
415            Upstream::Ws(w) => w.list_resources().await,
416        }
417    }
418
419    pub async fn call_tool(
420        &self,
421        params: &CallToolRequestParams,
422    ) -> Result<CallToolResult, McpError> {
423        match self {
424            Upstream::Http(c) => c.call_tool(params).await,
425            Upstream::Ws(w) => w.call_tool(params).await,
426        }
427    }
428
429    pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, McpError> {
430        match self {
431            Upstream::Http(c) => c.read_resource(uri).await,
432            Upstream::Ws(w) => w.read_resource(uri).await,
433        }
434    }
435
436    pub async fn delete(&self) -> Result<(), McpError> {
437        match self {
438            Upstream::Http(c) => c.delete().await,
439            Upstream::Ws(w) => w.delete().await,
440        }
441    }
442
443    pub fn set_on_tools_list_changed<F>(&self, callback: F)
444    where
445        F: Fn() + Send + Sync + 'static,
446    {
447        match self {
448            Upstream::Http(c) => c.set_on_tools_list_changed(callback),
449            Upstream::Ws(w) => w.set_on_tools_list_changed(callback),
450        }
451    }
452
453    pub fn set_on_resources_list_changed<F>(&self, callback: F)
454    where
455        F: Fn() + Send + Sync + 'static,
456    {
457        match self {
458            Upstream::Http(c) => c.set_on_resources_list_changed(callback),
459            Upstream::Ws(w) => w.set_on_resources_list_changed(callback),
460        }
461    }
462
463    pub async fn set_extra_headers(&self, extras: IndexMap<String, String>) {
464        match self {
465            Upstream::Http(c) => c.set_extra_headers(extras).await,
466            Upstream::Ws(w) => w.set_extra_headers(extras).await,
467        }
468    }
469}
470
471/// Parse a `ws://objectiveai` / `ws:///owner/name/version/mcp` URL into
472/// its [`McpKind`]. Returns `None` for any other shape.
473pub fn parse_ws_mcp_kind(url: &str) -> Option<McpKind> {
474    let rest = url.strip_prefix("ws://")?;
475    // Drop any `?query` (plugin args ride there, parsed separately).
476    let rest = rest.split('?').next().unwrap_or(rest);
477    // `ws://objectiveai` → host "objectiveai", no path.
478    if rest == "objectiveai" {
479        return Some(McpKind::ObjectiveAi);
480    }
481    // `ws:///owner/name/version/mcp` → empty host, leading '/'.
482    let path = rest.strip_prefix('/')?;
483    let parts: Vec<&str> = path.split('/').collect();
484    if let [owner, name, version, mcp] = parts.as_slice() {
485        if !owner.is_empty() && !name.is_empty() && !version.is_empty() && !mcp.is_empty() {
486            return Some(McpKind::Other {
487                owner: (*owner).to_string(),
488                name: (*name).to_string(),
489                version: (*version).to_string(),
490                mcp: (*mcp).to_string(),
491            });
492        }
493    }
494    None
495}
496
497/// `initialize` a `ws://` upstream over `channel` and build its
498/// [`WsUpstream`]. `headers` is the full set sent on the `initialize`
499/// request — the session-global transient identity headers, plus (on
500/// resume) the upstream `Mcp-Session-Id` and any auth. `args` carries
501/// plugin init arguments (empty for `objectiveai`).
502pub async fn connect_ws(
503    channel: ReverseChannel,
504    url: String,
505    mcp_kind: McpKind,
506    args: IndexMap<String, Option<String>>,
507    mut headers: IndexMap<String, String>,
508) -> Result<WsUpstream, McpError> {
509    let response = channel
510        .request(
511            server_request::Payload::Initialize {
512                mcp_kind: mcp_kind.clone(),
513                params: InitializeRequest { args },
514            },
515            headers.clone(),
516        )
517        .await?;
518    let reply = match response.payload {
519        server_response::Payload::Initialize { result, .. } => unwrap_rpc(&url, result)?,
520        other => return Err(variant_mismatch(&url, "initialize", &other)),
521    };
522    // The per-request stamped set drops the resume `Mcp-Session-Id`
523    // ([`WsUpstream::headers`] re-adds whatever the upstream just minted)
524    // but keeps the transient identity + auth so the post-init health
525    // probe + every later call still pass the conduit's transient check.
526    headers.shift_remove(crate::upstream::MCP_SESSION_ID_KEY);
527    let has_tools_cap = reply.result.capabilities.tools.is_some();
528    let has_resources_cap = reply.result.capabilities.resources.is_some();
529    Ok(WsUpstream {
530        channel,
531        mcp_kind,
532        url,
533        session_id: reply.mcp_session_id,
534        server_name: reply.result.server_info.name,
535        server_version: reply.result.server_info.version,
536        has_tools_cap,
537        has_resources_cap,
538        // The connect-time set (per-URL ∪ dial-time identity) is the
539        // persistent base; the transient overlay starts empty and is
540        // filled by the first `set_extra_headers`. Mirrors the SDK
541        // `Connection`, where connect headers are the base and
542        // `extra_headers` begins empty.
543        base_headers: headers,
544        extra_headers: RwLock::new(IndexMap::new()),
545    })
546}
547
548fn unwrap_rpc<R>(url: &str, result: JsonRpcResult<R>) -> Result<R, McpError> {
549    match result {
550        JsonRpcResult::Ok { result } => Ok(result),
551        JsonRpcResult::Err {
552            code,
553            message,
554            data,
555        } => Err(McpError::JsonRpc {
556            url: url.to_string(),
557            code,
558            message,
559            data,
560        }),
561    }
562}
563
564fn transport_error(message: &str) -> McpError {
565    McpError::MalformedResponse {
566        url: "ws".to_string(),
567        message: message.to_string(),
568    }
569}
570
571fn variant_mismatch(url: &str, expected: &str, got: &server_response::Payload) -> McpError {
572    McpError::MalformedResponse {
573        url: url.to_string(),
574        message: format!(
575            "reverse channel returned wrong payload variant: expected {expected}, got {}",
576            got_variant_name(got),
577        ),
578    }
579}
580
581fn got_variant_name(p: &server_response::Payload) -> &'static str {
582    use server_response::Payload as P;
583    match p {
584        P::Initialize { .. } => "initialize",
585        P::ToolsList { .. } => "tools_list",
586        P::ToolsCall { .. } => "tools_call",
587        P::ResourcesList { .. } => "resources_list",
588        P::ResourcesRead { .. } => "resources_read",
589        P::SessionTerminate { .. } => "session_terminate",
590        P::ReadMessageQueue(_) => "read_message_queue",
591        P::Retrieve(_) => "retrieve",
592    }
593}