Skip to main content

objectiveai_sdk/mcp/
client.rs

1//! MCP client for creating connections to MCP servers.
2
3use std::time::Duration;
4
5use indexmap::IndexMap;
6
7/// Client for creating MCP connections.
8///
9/// Holds shared configuration (HTTP client, headers, backoff parameters)
10/// and creates [`Connection`](super::Connection) instances via
11/// [`connect`](Client::connect).
12#[derive(Debug, Clone)]
13pub struct Client {
14    /// HTTP client for making requests.
15    pub http_client: reqwest::Client,
16    /// User-Agent header value.
17    pub user_agent: String,
18    /// X-Title header value.
19    pub x_title: String,
20    /// Referer header value.
21    pub http_referer: String,
22    /// Timeout for the initial connection (initialize request).
23    pub connect_timeout: Duration,
24
25    /// Current backoff interval for retry logic.
26    pub backoff_current_interval: Duration,
27    /// Initial backoff interval for retry logic.
28    pub backoff_initial_interval: Duration,
29    /// Randomization factor for backoff jitter.
30    pub backoff_randomization_factor: f64,
31    /// Multiplier for exponential backoff growth.
32    pub backoff_multiplier: f64,
33    /// Maximum backoff interval.
34    pub backoff_max_interval: Duration,
35    /// Maximum total time to spend on retries.
36    pub backoff_max_elapsed_time: Duration,
37    /// Timeout for individual RPC calls after connection is established.
38    pub call_timeout: Duration,
39}
40
41impl Client {
42    /// Creates a new MCP client.
43    pub fn new(
44        http_client: reqwest::Client,
45        user_agent: String,
46        x_title: String,
47        http_referer: String,
48        connect_timeout: Duration,
49        backoff_current_interval: Duration,
50        backoff_initial_interval: Duration,
51        backoff_randomization_factor: f64,
52        backoff_multiplier: f64,
53        backoff_max_interval: Duration,
54        backoff_max_elapsed_time: Duration,
55        call_timeout: Duration,
56    ) -> Self {
57        Self {
58            http_client,
59            user_agent,
60            x_title,
61            http_referer,
62            connect_timeout,
63            backoff_current_interval,
64            backoff_initial_interval,
65            backoff_randomization_factor,
66            backoff_multiplier,
67            backoff_max_interval,
68            backoff_max_elapsed_time,
69            call_timeout,
70        }
71    }
72
73    /// Build the canonical header map that the client stamps on every
74    /// request opened under this client / connection. The supplied
75    /// caller map (if any) wins on conflict — defaults are inserted
76    /// only when the caller didn't already provide them. The merged
77    /// map is computed once at the top of `connect_once` and reused
78    /// across all three handshake requests + handed to the resulting
79    /// [`Connection`] for every later RPC.
80    fn headers(
81        &self,
82        supplied: Option<IndexMap<String, String>>,
83    ) -> IndexMap<String, String> {
84        let mut out = supplied.unwrap_or_default();
85        out.entry("User-Agent".to_string())
86            .or_insert_with(|| self.user_agent.clone());
87        out.entry("X-Title".to_string())
88            .or_insert_with(|| self.x_title.clone());
89        out.entry("Referer".to_string())
90            .or_insert_with(|| self.http_referer.clone());
91        out.entry("HTTP-Referer".to_string())
92            .or_insert_with(|| self.http_referer.clone());
93        out
94    }
95
96    /// Connects to an MCP server using the Streamable HTTP transport.
97    ///
98    /// Sends an `initialize` JSON-RPC request to the server and extracts
99    /// the `Mcp-Session-Id` from the response. Returns a [`Connection`]
100    /// that can be used to list/call tools and list/read resources.
101    ///
102    /// `headers` are forwarded on every request this connection makes
103    /// to the upstream — both the initial `initialize` POST and every
104    /// subsequent RPC. The client merges its own defaults
105    /// (`User-Agent`, `X-Title`, `Referer`, `HTTP-Referer`) into this
106    /// map, but caller-supplied values for any of those win on
107    /// conflict. `Authorization` (when needed) is just another entry
108    /// in `headers`. The `Mcp-Session-Id` header is reserved — pass it
109    /// via `session_id` instead so the explicit argument can never be
110    /// clobbered by the headers map.
111    ///
112    /// ## SSE handoff
113    ///
114    /// `Accept` is `text/event-stream, application/json` — stream first
115    /// — so the server is encouraged to keep the underlying connection
116    /// open. If the response comes back as SSE we read the initialize
117    /// event off the stream and hand the *still-open* line reader to the
118    /// returned [`Connection`]'s list-changed listener. The listener
119    /// starts reading from that pre-opened stream immediately, which
120    /// closes the race where a peer (e.g. an in-process rmcp upstream)
121    /// would broadcast `notifications/tools/list_changed` before our
122    /// listener had managed to open its own GET `/` SSE.
123    ///
124    /// If the response is unary JSON and the server advertises either
125    /// `tools.list_changed` or `resources.list_changed`, we proactively
126    /// open a GET `/` SSE stream *before returning* and hand it to the
127    /// listener for the same reason. If neither capability is set, no
128    /// listener is needed and we return without touching SSE.
129    pub async fn connect(
130        &self,
131        url: String,
132        session_id: Option<String>,
133        headers: Option<IndexMap<String, String>>,
134    ) -> Result<super::Connection, super::Error> {
135        if url == "mock" {
136            return Ok(super::Connection::new_mock(url));
137        }
138
139        // Merge the caller's headers with the client's defaults once,
140        // then reuse the same merged map across every retry of the
141        // handshake AND hand it to the resulting Connection for every
142        // later RPC. Caller-supplied headers always win over defaults.
143        let headers = self.headers(headers);
144
145        // One outer backoff retry around all three handshake steps —
146        // initialize POST, notifications/initialized POST, GET / SSE
147        // (when capabilities require it). On a failure of any step we
148        // restart from scratch: a partial handshake leaves server-side
149        // session state we can't reuse, so retrying just the failed
150        // step would reference a session the server already discarded.
151        // Every error is treated as transient — the loop only gives up
152        // when the backoff's `max_elapsed_time` is exceeded.
153        let mut backoff = backoff::ExponentialBackoff {
154            current_interval: self.backoff_current_interval,
155            initial_interval: self.backoff_initial_interval,
156            randomization_factor: self.backoff_randomization_factor,
157            multiplier: self.backoff_multiplier,
158            max_interval: self.backoff_max_interval,
159            start_time: std::time::Instant::now(),
160            max_elapsed_time: Some(self.backoff_max_elapsed_time),
161            clock: backoff::SystemClock::default(),
162        };
163
164        loop {
165            match self
166                .connect_once(&url, session_id.as_deref(), &headers)
167                .await
168            {
169                Ok(conn) => return Ok(conn),
170                Err(e) => {
171                    use backoff::backoff::Backoff;
172                    match backoff.next_backoff() {
173                        Some(d) => tokio::time::sleep(d).await,
174                        None => return Err(e),
175                    }
176                }
177            }
178        }
179    }
180
181    /// Issue a stateless HTTP `DELETE /` to one upstream MCP server,
182    /// telling it to terminate the session identified by `session_id`.
183    ///
184    /// This is the low-level primitive — no backoff, no listener
185    /// teardown, no connection state. Caller is responsible for any
186    /// retry semantics. For the stateful tear-down path that also
187    /// cancels the connection's own active streams and absorbs
188    /// upstream `404 / 401 / 403` as success, use
189    /// [`Connection::delete`](super::Connection::delete) instead.
190    ///
191    /// `headers` is merged with the same defaults `connect` applies
192    /// (`User-Agent`, `X-Title`, `Referer`, `HTTP-Referer`). The
193    /// explicit `session_id` argument always wins over any
194    /// `Mcp-Session-Id` entry that happens to appear in `headers` — the
195    /// shape mirrors `connect`'s argument split for the same reason.
196    ///
197    /// Returns `Ok(())` on any 2xx status. Any non-2xx (including
198    /// `404 Not Found`) surfaces as [`Error::BadStatus`]. Network /
199    /// transport failures surface as [`Error::Request`].
200    pub async fn delete(
201        &self,
202        url: String,
203        session_id: String,
204        headers: Option<IndexMap<String, String>>,
205    ) -> Result<(), super::Error> {
206        if url == "mock" {
207            return Ok(());
208        }
209        let headers = self.headers(headers);
210        let mut request = self
211            .http_client
212            .delete(&url)
213            .timeout(self.call_timeout)
214            .header("Mcp-Session-Id", &session_id);
215        for (name, value) in &headers {
216            // Explicit `session_id` arg always wins.
217            if name.eq_ignore_ascii_case("Mcp-Session-Id") {
218                continue;
219            }
220            request = request.header(name, value);
221        }
222        let response = request.send().await.map_err(|source| {
223            super::Error::Request {
224                url: url.clone(),
225                source,
226            }
227        })?;
228        if !response.status().is_success() {
229            let code = response.status();
230            let body = response.text().await.unwrap_or_default();
231            return Err(super::Error::BadStatus {
232                url,
233                code,
234                body: body.chars().take(800).collect(),
235            });
236        }
237        Ok(())
238    }
239
240    /// One pass through the full Streamable-HTTP handshake. Caller
241    /// applies the outer backoff retry loop in [`Self::connect`].
242    /// `headers` is the already-merged map (defaults + caller overrides
243    /// from [`Self::headers`]), reused on every request without further
244    /// processing. `Mcp-Session-Id` is applied AFTER the headers loop
245    /// so it always wins over any same-named entry in `headers`.
246    async fn connect_once(
247        &self,
248        url: &str,
249        session_id: Option<&str>,
250        headers: &IndexMap<String, String>,
251    ) -> Result<super::Connection, super::Error> {
252        let init_request = serde_json::json!({
253            "jsonrpc": "2.0",
254            "id": 1,
255            "method": "initialize",
256            "params": {
257                "protocolVersion": "2025-06-18",
258                "capabilities": {},
259                "clientInfo": {
260                    "name": "objectiveai",
261                    "version": env!("CARGO_PKG_VERSION"),
262                }
263            }
264        });
265
266        let mut request = self
267            .http_client
268            .post(url)
269            .timeout(self.connect_timeout)
270            .header("Content-Type", "application/json")
271            .header("Accept", "text/event-stream, application/json")
272            .json(&init_request);
273
274        for (name, value) in headers {
275            request = request.header(name, value);
276        }
277        // Mcp-Session-Id is applied last so the explicit `session_id`
278        // argument always wins over any same-named entry in `headers`.
279        if let Some(sid) = session_id {
280            request = request.header("Mcp-Session-Id", sid);
281        }
282
283        let response = request.send().await.map_err(|source| {
284            super::Error::Connection {
285                url: url.to_string(),
286                source,
287            }
288        })?;
289
290        if !response.status().is_success() {
291            let code = response.status();
292            let body = response.text().await.unwrap_or_default();
293            return Err(super::Error::BadStatus {
294                url: url.to_string(),
295                code,
296                body,
297            });
298        }
299
300        // Extract session ID from response header.
301        //
302        // For *new* sessions the server must mint and return a session
303        // id. For *existing* sessions (caller passed `session_id` in)
304        // many servers — including rmcp's `StreamableHttpService` on
305        // its existing-session branch — don't echo the header back
306        // because nothing changed. When the caller already knew the
307        // session id, fall back to it instead of erroring; that's the
308        // value we'll be stamping on every subsequent request anyway.
309        let resolved_session_id = match response
310            .headers()
311            .get("Mcp-Session-Id")
312            .and_then(|v| v.to_str().ok())
313            .map(String::from)
314        {
315            Some(s) => s,
316            None => match session_id {
317                Some(provided) => provided.to_string(),
318                None => {
319                    let body = response.text().await.unwrap_or_default();
320                    return Err(super::Error::NoSessionId {
321                        url: url.to_string(),
322                        body: body.chars().take(800).collect(),
323                    });
324                }
325            },
326        };
327
328        // Did the server return SSE or unary JSON? rmcp's
329        // `StreamableHttpService` always returns SSE; many other servers
330        // reply with bare JSON.
331        let is_sse = response
332            .headers()
333            .get(reqwest::header::CONTENT_TYPE)
334            .and_then(|v| v.to_str().ok())
335            .map(|v| v.starts_with("text/event-stream"))
336            .unwrap_or(false);
337
338        // Parse the initialize response. SSE path consumes one event
339        // from the stream and keeps the rest of the stream alive for
340        // the listener; unary path consumes the whole body.
341        let (initialize_result, mut initial_sse_lines) = if is_sse {
342            let mut lines = super::lines_from_response(response);
343            let rpc_response: super::JsonRpcResponse<
344                super::initialize_result::InitializeResult,
345            > = super::read_next_sse_event(url, &mut lines).await?;
346            let result = match rpc_response {
347                super::JsonRpcResponse::Success { result, .. } => result,
348                super::JsonRpcResponse::Error { error, .. } => {
349                    return Err(super::Error::JsonRpc {
350                        url: url.to_string(),
351                        code: error.code,
352                        message: error.message,
353                        data: error.data,
354                    });
355                }
356            };
357            (result, Some(lines))
358        } else {
359            let rpc_response: super::JsonRpcResponse<
360                super::initialize_result::InitializeResult,
361            > = super::parse_streamable_http_response(url, response).await?;
362            let result = match rpc_response {
363                super::JsonRpcResponse::Success { result, .. } => result,
364                super::JsonRpcResponse::Error { error, .. } => {
365                    return Err(super::Error::JsonRpc {
366                        url: url.to_string(),
367                        code: error.code,
368                        message: error.message,
369                        data: error.data,
370                    });
371                }
372            };
373            (result, None)
374        };
375
376        // Whether we need a notification SSE channel at all.
377        let needs_sse = initialize_result
378            .capabilities
379            .tools
380            .as_ref()
381            .and_then(|t| t.list_changed)
382            .unwrap_or(false)
383            || initialize_result
384                .capabilities
385                .resources
386                .as_ref()
387                .and_then(|r| r.list_changed)
388                .unwrap_or(false);
389
390        // Send `notifications/initialized` BEFORE any other request.
391        // rmcp's per-session worker is in `expect_notification("initialized")`
392        // at this point — anything else (a `tools/list`, an opportunistic
393        // GET `/`) that lands during that window pushes a non-notification
394        // through the worker, makes `serve_server_with_ct_inner` return
395        // `Err(ExpectedInitializedNotification(...))`, drops the
396        // WorkerTransport, cancels the worker via its drop_guard, and
397        // tears the whole session down. Every later POST then 500s with
398        // "Session service terminated."
399        //
400        // We don't have a `Connection` yet — building one would spawn
401        // `refresh_tools` / `refresh_resources` background tasks that
402        // race with this notification, which is exactly the bug we're
403        // avoiding. We therefore POST inline here.
404        let init_notification_body = serde_json::json!({
405            "jsonrpc": "2.0",
406            "method": "notifications/initialized",
407            "params": {},
408        });
409        let mut notify_request = self
410            .http_client
411            .post(url)
412            .timeout(self.call_timeout)
413            .header("Content-Type", "application/json")
414            .header("Accept", "application/json, text/event-stream");
415        for (name, value) in headers {
416            notify_request = notify_request.header(name, value);
417        }
418        notify_request =
419            notify_request.header("Mcp-Session-Id", &resolved_session_id);
420        let notify_response = notify_request
421            .json(&init_notification_body)
422            .send()
423            .await
424            .map_err(|source| super::Error::Request {
425                url: url.to_string(),
426                source,
427            })?;
428        if !notify_response.status().is_success() {
429            let code = notify_response.status();
430            let body = notify_response.text().await.unwrap_or_default();
431            return Err(super::Error::BadStatus {
432                url: url.to_string(),
433                code,
434                body,
435            });
436        }
437
438        // Now safe to drop the init-side SSE stream; rmcp's session
439        // worker is past the init handshake and we have no further use
440        // for it (notifications come on the GET / stream below).
441        drop(initial_sse_lines.take());
442
443        // Now that the server's session worker is past the
444        // `expect_notification` gate, it's safe to open the proactive
445        // GET `/` SSE stream the listener will read
446        // `notifications/{tools,resources}/list_changed` from. Capability
447        // inspection only gates whether we open this stream; the
448        // `Connection` itself is naive about capabilities.
449        let initial_sse_lines: Option<super::LinesStream> = if needs_sse {
450            let mut get_request = self
451                .http_client
452                .get(url)
453                .timeout(self.connect_timeout)
454                .header("Accept", "text/event-stream");
455            for (name, value) in headers {
456                get_request = get_request.header(name, value);
457            }
458            get_request =
459                get_request.header("Mcp-Session-Id", &resolved_session_id);
460            let get_response = get_request.send().await.map_err(|source| {
461                super::Error::Connection {
462                    url: url.to_string(),
463                    source,
464                }
465            })?;
466            if !get_response.status().is_success() {
467                let code = get_response.status();
468                let body = get_response.text().await.unwrap_or_default();
469                return Err(super::Error::BadStatus {
470                    url: url.to_string(),
471                    code,
472                    body,
473                });
474            }
475            Some(super::lines_from_response(get_response))
476        } else {
477            None
478        };
479
480        // Construct the Connection at the very end. This is the only
481        // place in `connect` where the listener task and the
482        // `refresh_tools` / `refresh_resources` background tasks get
483        // spawned — by now the upstream is fully past its init
484        // handshake, so any of those POSTs land safely.
485        //
486        // `was_resumed` is true when the caller passed an existing
487        // `session_id` into `connect` — i.e. this is a re-attach to a
488        // pre-existing upstream session, not a fresh mint. The bit
489        // rides into `ConnectionInner` as `is_reconnect` and gates
490        // the drop-time orphan-DELETE that releases resumed sessions
491        // nobody ended up using.
492        let was_resumed = session_id.is_some();
493        let connection = super::Connection::new(
494            self.http_client.clone(),
495            url.to_string(),
496            resolved_session_id,
497            headers.clone(),
498            self.backoff_current_interval,
499            self.backoff_initial_interval,
500            self.backoff_randomization_factor,
501            self.backoff_multiplier,
502            self.backoff_max_interval,
503            self.backoff_max_elapsed_time,
504            self.call_timeout,
505            initialize_result,
506            initial_sse_lines,
507            was_resumed,
508        )
509        .await;
510
511        Ok(connection)
512    }
513}