1use std::sync::Arc;
25use std::time::Duration;
26
27use dashmap::DashMap;
28use indexmap::IndexMap;
29use objectiveai_sdk::client_objectiveai_mcp::{
30 McpKind,
31 client_request::{self, McpListChangedKind},
32 client_response,
33 server_request::{self, InitializeRequest, Request as ServerRequest},
34 server_response::{self, JsonRpcResult, Response as ServerResponse},
35};
36use objectiveai_sdk::mcp::resource::{
37 ListResourcesRequest, ReadResourceRequestParams, ReadResourceResult, Resource,
38};
39use objectiveai_sdk::mcp::tool::{
40 CallToolRequestParams, CallToolResult, ListToolsRequest, Tool,
41};
42use objectiveai_sdk::mcp::{Connection, Error as McpError};
43use tokio::sync::{RwLock, mpsc, oneshot};
44
45type ListChangedCb = Arc<dyn Fn() + Send + Sync>;
47
48struct Inner {
49 tx: mpsc::UnboundedSender<ServerRequest>,
52 pending: DashMap<String, oneshot::Sender<ServerResponse>>,
54 timeout: Duration,
56 list_changed: DashMap<McpKind, (Option<ListChangedCb>, Option<ListChangedCb>)>,
59}
60
61#[derive(Clone)]
63pub struct ReverseChannel(Arc<Inner>);
64
65impl std::fmt::Debug for ReverseChannel {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("ReverseChannel").finish_non_exhaustive()
68 }
69}
70
71impl ReverseChannel {
72 pub fn new(timeout: Duration) -> (Self, mpsc::UnboundedReceiver<ServerRequest>) {
75 let (tx, rx) = mpsc::unbounded_channel();
76 let inner = Inner {
77 tx,
78 pending: DashMap::new(),
79 timeout,
80 list_changed: DashMap::new(),
81 };
82 (Self(Arc::new(inner)), rx)
83 }
84
85 async fn request(
89 &self,
90 payload: server_request::Payload,
91 headers: IndexMap<String, String>,
92 ) -> Result<ServerResponse, McpError> {
93 let id = uuid::Uuid::new_v4().to_string();
94 let (resp_tx, resp_rx) = oneshot::channel();
95 self.0.pending.insert(id.clone(), resp_tx);
96 let request = ServerRequest {
97 id: id.clone(),
98 headers,
99 payload,
100 };
101 if self.0.tx.send(request).is_err() {
102 self.0.pending.remove(&id);
103 return Err(transport_error("reverse channel closed before send"));
104 }
105 match tokio::time::timeout(self.0.timeout, resp_rx).await {
106 Ok(Ok(response)) => Ok(response),
107 Ok(Err(_)) => {
108 self.0.pending.remove(&id);
109 Err(transport_error("reverse channel dropped before response"))
110 }
111 Err(_) => {
112 self.0.pending.remove(&id);
113 Err(transport_error("reverse channel timed out waiting for response"))
114 }
115 }
116 }
117
118 pub fn deliver_response(&self, response: ServerResponse) {
122 if let Some((_, tx)) = self.0.pending.remove(&response.id) {
123 let _ = tx.send(response);
124 }
125 }
126
127 pub fn deliver_client_request(
131 &self,
132 request: client_request::Request,
133 ) -> client_response::Response {
134 let client_request::Request { id, payload } = request;
135 match payload {
136 client_request::Payload::McpListChanged(change) => {
137 if let Some(cbs) = self.0.list_changed.get(&change.mcp_kind) {
138 let cb = match change.kind {
139 McpListChangedKind::Tools => cbs.0.clone(),
140 McpListChangedKind::Resources => cbs.1.clone(),
141 };
142 drop(cbs);
143 if let Some(cb) = cb {
144 cb();
145 }
146 }
147 client_response::Response::Ok { id }
148 }
149 }
150 }
151
152 fn set_tools_list_changed(&self, mcp_kind: McpKind, cb: ListChangedCb) {
153 let mut entry = self.0.list_changed.entry(mcp_kind).or_default();
154 entry.0 = Some(cb);
155 }
156
157 fn set_resources_list_changed(&self, mcp_kind: McpKind, cb: ListChangedCb) {
158 let mut entry = self.0.list_changed.entry(mcp_kind).or_default();
159 entry.1 = Some(cb);
160 }
161}
162
163pub struct WsUpstream {
168 channel: ReverseChannel,
169 mcp_kind: McpKind,
170 pub url: String,
172 pub session_id: String,
174 server_name: String,
177 server_version: String,
178 has_tools_cap: bool,
187 has_resources_cap: bool,
188 base_headers: IndexMap<String, String>,
196 extra_headers: RwLock<IndexMap<String, String>>,
202}
203
204impl std::fmt::Debug for WsUpstream {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 f.debug_struct("WsUpstream")
207 .field("url", &self.url)
208 .field("session_id", &self.session_id)
209 .finish_non_exhaustive()
210 }
211}
212
213impl WsUpstream {
214 async fn headers(&self) -> IndexMap<String, String> {
222 let mut h = self.base_headers.clone();
223 for (k, v) in self.extra_headers.read().await.iter() {
224 h.insert(k.clone(), v.clone());
225 }
226 h.insert(
227 crate::upstream::MCP_SESSION_ID_KEY.to_string(),
228 self.session_id.clone(),
229 );
230 h
231 }
232
233 pub async fn list_tools(&self) -> Result<Arc<Vec<Tool>>, Arc<McpError>> {
234 if !self.has_tools_cap {
237 return Ok(Arc::new(Vec::new()));
238 }
239 let headers = self.headers().await;
240 let response = self
241 .channel
242 .request(
243 server_request::Payload::ToolsList {
244 mcp_kind: self.mcp_kind.clone(),
245 params: ListToolsRequest { cursor: None },
246 },
247 headers,
248 )
249 .await
250 .map_err(Arc::new)?;
251 match response.payload {
252 server_response::Payload::ToolsList { result, .. } => {
253 Ok(Arc::new(unwrap_rpc(&self.url, result).map_err(Arc::new)?.tools))
254 }
255 other => Err(Arc::new(variant_mismatch(&self.url, "tools_list", &other))),
256 }
257 }
258
259 pub async fn list_resources(&self) -> Result<Arc<Vec<Resource>>, Arc<McpError>> {
260 if !self.has_resources_cap {
264 return Ok(Arc::new(Vec::new()));
265 }
266 let headers = self.headers().await;
267 let response = self
268 .channel
269 .request(
270 server_request::Payload::ResourcesList {
271 mcp_kind: self.mcp_kind.clone(),
272 params: ListResourcesRequest { cursor: None },
273 },
274 headers,
275 )
276 .await
277 .map_err(Arc::new)?;
278 match response.payload {
279 server_response::Payload::ResourcesList { result, .. } => {
280 Ok(Arc::new(unwrap_rpc(&self.url, result).map_err(Arc::new)?.resources))
281 }
282 other => Err(Arc::new(variant_mismatch(&self.url, "resources_list", &other))),
283 }
284 }
285
286 pub async fn call_tool(
287 &self,
288 params: &CallToolRequestParams,
289 ) -> Result<CallToolResult, McpError> {
290 let headers = self.headers().await;
291 let response = self
292 .channel
293 .request(
294 server_request::Payload::ToolsCall {
295 mcp_kind: self.mcp_kind.clone(),
296 params: params.clone(),
297 },
298 headers,
299 )
300 .await?;
301 match response.payload {
302 server_response::Payload::ToolsCall { result, .. } => unwrap_rpc(&self.url, result),
303 other => Err(variant_mismatch(&self.url, "tools_call", &other)),
304 }
305 }
306
307 pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, McpError> {
308 let headers = self.headers().await;
309 let response = self
310 .channel
311 .request(
312 server_request::Payload::ResourcesRead {
313 mcp_kind: self.mcp_kind.clone(),
314 params: ReadResourceRequestParams {
315 uri: uri.to_string(),
316 },
317 },
318 headers,
319 )
320 .await?;
321 match response.payload {
322 server_response::Payload::ResourcesRead { result, .. } => unwrap_rpc(&self.url, result),
323 other => Err(variant_mismatch(&self.url, "resources_read", &other)),
324 }
325 }
326
327 pub async fn delete(&self) -> Result<(), McpError> {
328 let headers = self.headers().await;
329 let response = self
330 .channel
331 .request(
332 server_request::Payload::SessionTerminate {
333 mcp_kind: self.mcp_kind.clone(),
334 },
335 headers,
336 )
337 .await?;
338 match response.payload {
339 server_response::Payload::SessionTerminate { result, .. } => unwrap_rpc(&self.url, result),
340 other => Err(variant_mismatch(&self.url, "session_terminate", &other)),
341 }
342 }
343
344 pub fn set_on_tools_list_changed<F>(&self, callback: F)
345 where
346 F: Fn() + Send + Sync + 'static,
347 {
348 self.channel
349 .set_tools_list_changed(self.mcp_kind.clone(), Arc::new(callback));
350 }
351
352 pub fn set_on_resources_list_changed<F>(&self, callback: F)
353 where
354 F: Fn() + Send + Sync + 'static,
355 {
356 self.channel
357 .set_resources_list_changed(self.mcp_kind.clone(), Arc::new(callback));
358 }
359
360 pub async fn set_extra_headers(&self, extras: IndexMap<String, String>) {
361 *self.extra_headers.write().await = extras;
362 }
363}
364
365#[derive(Debug)]
368pub enum Upstream {
369 Http(Connection),
370 Ws(WsUpstream),
371}
372
373impl Upstream {
374 pub fn url(&self) -> &str {
375 match self {
376 Upstream::Http(c) => &c.url,
377 Upstream::Ws(w) => &w.url,
378 }
379 }
380
381 pub fn session_id(&self) -> &str {
382 match self {
383 Upstream::Http(c) => &c.session_id,
384 Upstream::Ws(w) => &w.session_id,
385 }
386 }
387
388 pub fn server_name(&self) -> &str {
391 match self {
392 Upstream::Http(c) => &c.initialize_result.server_info.name,
393 Upstream::Ws(w) => &w.server_name,
394 }
395 }
396
397 pub fn server_version(&self) -> &str {
399 match self {
400 Upstream::Http(c) => &c.initialize_result.server_info.version,
401 Upstream::Ws(w) => &w.server_version,
402 }
403 }
404
405 pub async fn list_tools(&self) -> Result<Arc<Vec<Tool>>, Arc<McpError>> {
406 match self {
407 Upstream::Http(c) => c.list_tools().await,
408 Upstream::Ws(w) => w.list_tools().await,
409 }
410 }
411
412 pub async fn list_resources(&self) -> Result<Arc<Vec<Resource>>, Arc<McpError>> {
413 match self {
414 Upstream::Http(c) => c.list_resources().await,
415 Upstream::Ws(w) => w.list_resources().await,
416 }
417 }
418
419 pub async fn call_tool(
420 &self,
421 params: &CallToolRequestParams,
422 ) -> Result<CallToolResult, McpError> {
423 match self {
424 Upstream::Http(c) => c.call_tool(params).await,
425 Upstream::Ws(w) => w.call_tool(params).await,
426 }
427 }
428
429 pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, McpError> {
430 match self {
431 Upstream::Http(c) => c.read_resource(uri).await,
432 Upstream::Ws(w) => w.read_resource(uri).await,
433 }
434 }
435
436 pub async fn delete(&self) -> Result<(), McpError> {
437 match self {
438 Upstream::Http(c) => c.delete().await,
439 Upstream::Ws(w) => w.delete().await,
440 }
441 }
442
443 pub fn set_on_tools_list_changed<F>(&self, callback: F)
444 where
445 F: Fn() + Send + Sync + 'static,
446 {
447 match self {
448 Upstream::Http(c) => c.set_on_tools_list_changed(callback),
449 Upstream::Ws(w) => w.set_on_tools_list_changed(callback),
450 }
451 }
452
453 pub fn set_on_resources_list_changed<F>(&self, callback: F)
454 where
455 F: Fn() + Send + Sync + 'static,
456 {
457 match self {
458 Upstream::Http(c) => c.set_on_resources_list_changed(callback),
459 Upstream::Ws(w) => w.set_on_resources_list_changed(callback),
460 }
461 }
462
463 pub async fn set_extra_headers(&self, extras: IndexMap<String, String>) {
464 match self {
465 Upstream::Http(c) => c.set_extra_headers(extras).await,
466 Upstream::Ws(w) => w.set_extra_headers(extras).await,
467 }
468 }
469}
470
471pub fn parse_ws_mcp_kind(url: &str) -> Option<McpKind> {
474 let rest = url.strip_prefix("ws://")?;
475 let rest = rest.split('?').next().unwrap_or(rest);
477 if rest == "objectiveai" {
479 return Some(McpKind::ObjectiveAi);
480 }
481 let path = rest.strip_prefix('/')?;
483 let parts: Vec<&str> = path.split('/').collect();
484 if let [owner, name, version, mcp] = parts.as_slice() {
485 if !owner.is_empty() && !name.is_empty() && !version.is_empty() && !mcp.is_empty() {
486 return Some(McpKind::Other {
487 owner: (*owner).to_string(),
488 name: (*name).to_string(),
489 version: (*version).to_string(),
490 mcp: (*mcp).to_string(),
491 });
492 }
493 }
494 None
495}
496
497pub async fn connect_ws(
503 channel: ReverseChannel,
504 url: String,
505 mcp_kind: McpKind,
506 args: IndexMap<String, Option<String>>,
507 mut headers: IndexMap<String, String>,
508) -> Result<WsUpstream, McpError> {
509 let response = channel
510 .request(
511 server_request::Payload::Initialize {
512 mcp_kind: mcp_kind.clone(),
513 params: InitializeRequest { args },
514 },
515 headers.clone(),
516 )
517 .await?;
518 let reply = match response.payload {
519 server_response::Payload::Initialize { result, .. } => unwrap_rpc(&url, result)?,
520 other => return Err(variant_mismatch(&url, "initialize", &other)),
521 };
522 headers.shift_remove(crate::upstream::MCP_SESSION_ID_KEY);
527 let has_tools_cap = reply.result.capabilities.tools.is_some();
528 let has_resources_cap = reply.result.capabilities.resources.is_some();
529 Ok(WsUpstream {
530 channel,
531 mcp_kind,
532 url,
533 session_id: reply.mcp_session_id,
534 server_name: reply.result.server_info.name,
535 server_version: reply.result.server_info.version,
536 has_tools_cap,
537 has_resources_cap,
538 base_headers: headers,
544 extra_headers: RwLock::new(IndexMap::new()),
545 })
546}
547
548fn unwrap_rpc<R>(url: &str, result: JsonRpcResult<R>) -> Result<R, McpError> {
549 match result {
550 JsonRpcResult::Ok { result } => Ok(result),
551 JsonRpcResult::Err {
552 code,
553 message,
554 data,
555 } => Err(McpError::JsonRpc {
556 url: url.to_string(),
557 code,
558 message,
559 data,
560 }),
561 }
562}
563
564fn transport_error(message: &str) -> McpError {
565 McpError::MalformedResponse {
566 url: "ws".to_string(),
567 message: message.to_string(),
568 }
569}
570
571fn variant_mismatch(url: &str, expected: &str, got: &server_response::Payload) -> McpError {
572 McpError::MalformedResponse {
573 url: url.to_string(),
574 message: format!(
575 "reverse channel returned wrong payload variant: expected {expected}, got {}",
576 got_variant_name(got),
577 ),
578 }
579}
580
581fn got_variant_name(p: &server_response::Payload) -> &'static str {
582 use server_response::Payload as P;
583 match p {
584 P::Initialize { .. } => "initialize",
585 P::ToolsList { .. } => "tools_list",
586 P::ToolsCall { .. } => "tools_call",
587 P::ResourcesList { .. } => "resources_list",
588 P::ResourcesRead { .. } => "resources_read",
589 P::SessionTerminate { .. } => "session_terminate",
590 P::ReadMessageQueue(_) => "read_message_queue",
591 P::Retrieve(_) => "retrieve",
592 }
593}