Skip to main content

objectiveai_sdk/mcp/
client.rs

1//! MCP client for creating connections to MCP servers.
2
3use std::time::Duration;
4
5use indexmap::IndexMap;
6
7
8/// Client for creating MCP connections.
9///
10/// Holds shared configuration (HTTP client, headers, backoff parameters)
11/// and creates [`Connection`](super::Connection) instances via
12/// [`connect`](Client::connect).
13#[derive(Debug, Clone)]
14pub struct Client {
15    /// HTTP client for making requests.
16    pub http_client: reqwest::Client,
17    /// User-Agent header value.
18    pub user_agent: String,
19    /// X-Title header value.
20    pub x_title: String,
21    /// Referer header value.
22    pub http_referer: String,
23    /// Timeout for the initial connection (initialize request).
24    pub connect_timeout: Duration,
25
26    /// Current backoff interval for retry logic.
27    pub backoff_current_interval: Duration,
28    /// Initial backoff interval for retry logic.
29    pub backoff_initial_interval: Duration,
30    /// Randomization factor for backoff jitter.
31    pub backoff_randomization_factor: f64,
32    /// Multiplier for exponential backoff growth.
33    pub backoff_multiplier: f64,
34    /// Maximum backoff interval.
35    pub backoff_max_interval: Duration,
36    /// Maximum total time to spend on retries.
37    pub backoff_max_elapsed_time: Duration,
38    /// Timeout for individual RPC calls after connection is established.
39    pub call_timeout: Duration,
40}
41
42impl Client {
43    /// Creates a new MCP client.
44    pub fn new(
45        http_client: reqwest::Client,
46        user_agent: String,
47        x_title: String,
48        http_referer: String,
49        connect_timeout: Duration,
50        backoff_current_interval: Duration,
51        backoff_initial_interval: Duration,
52        backoff_randomization_factor: f64,
53        backoff_multiplier: f64,
54        backoff_max_interval: Duration,
55        backoff_max_elapsed_time: Duration,
56        call_timeout: Duration,
57    ) -> Self {
58        Self {
59            http_client,
60            user_agent,
61            x_title,
62            http_referer,
63            connect_timeout,
64            backoff_current_interval,
65            backoff_initial_interval,
66            backoff_randomization_factor,
67            backoff_multiplier,
68            backoff_max_interval,
69            backoff_max_elapsed_time,
70            call_timeout,
71        }
72    }
73
74    /// Build the canonical header map that the client stamps on every
75    /// request opened under this client / connection. The supplied
76    /// caller map (if any) wins on conflict — defaults are inserted
77    /// only when the caller didn't already provide them. The merged
78    /// map is computed once at the top of `connect_once` and reused
79    /// across all three handshake requests + handed to the resulting
80    /// [`Connection`] for every later RPC.
81    fn headers(
82        &self,
83        supplied: Option<IndexMap<String, String>>,
84    ) -> IndexMap<String, String> {
85        let mut out = supplied.unwrap_or_default();
86        out.entry("User-Agent".to_string())
87            .or_insert_with(|| self.user_agent.clone());
88        out.entry("X-Title".to_string())
89            .or_insert_with(|| self.x_title.clone());
90        out.entry("Referer".to_string())
91            .or_insert_with(|| self.http_referer.clone());
92        out.entry("HTTP-Referer".to_string())
93            .or_insert_with(|| self.http_referer.clone());
94        out
95    }
96
97    /// Connects to an MCP server using the Streamable HTTP transport.
98    ///
99    /// Sends an `initialize` JSON-RPC request to the server and extracts
100    /// the `Mcp-Session-Id` from the response. Returns a [`Connection`]
101    /// that can be used to list/call tools and list/read resources.
102    ///
103    /// `headers` are forwarded on every request this connection makes
104    /// to the upstream — both the initial `initialize` POST and every
105    /// subsequent RPC. The client merges its own defaults
106    /// (`User-Agent`, `X-Title`, `Referer`, `HTTP-Referer`) into this
107    /// map, but caller-supplied values for any of those win on
108    /// conflict. `Authorization` (when needed) is just another entry
109    /// in `headers`. The `Mcp-Session-Id` header is reserved — pass it
110    /// via `session_id` instead so the explicit argument can never be
111    /// clobbered by the headers map.
112    ///
113    /// ## SSE handoff
114    ///
115    /// `Accept` is `text/event-stream, application/json` — stream first
116    /// — so the server is encouraged to keep the underlying connection
117    /// open. If the response comes back as SSE we read the initialize
118    /// event off the stream and hand the *still-open* line reader to the
119    /// returned [`Connection`]'s list-changed listener. The listener
120    /// starts reading from that pre-opened stream immediately, which
121    /// closes the race where a peer (e.g. an in-process rmcp upstream)
122    /// would broadcast `notifications/tools/list_changed` before our
123    /// listener had managed to open its own GET `/` SSE.
124    ///
125    /// If the response is unary JSON and the server advertises either
126    /// `tools.list_changed` or `resources.list_changed`, we proactively
127    /// open a GET `/` SSE stream *before returning* and hand it to the
128    /// listener for the same reason. If neither capability is set, no
129    /// listener is needed and we return without touching SSE.
130    pub async fn connect(
131        &self,
132        url: String,
133        session_id: Option<String>,
134        headers: Option<IndexMap<String, String>>,
135    ) -> Result<super::Connection, super::Error> {
136        if url == "mock" {
137            return Ok(super::Connection::new_mock(url));
138        }
139
140        // Merge the caller's headers with the client's defaults once,
141        // then reuse the same merged map across every retry of the
142        // handshake AND hand it to the resulting Connection for every
143        // later RPC. Caller-supplied headers always win over defaults.
144        let headers = self.headers(headers);
145
146        // One outer backoff retry around all three handshake steps —
147        // initialize POST, notifications/initialized POST, GET / SSE
148        // (when capabilities require it). On a failure of any step we
149        // restart from scratch: a partial handshake leaves server-side
150        // session state we can't reuse, so retrying just the failed
151        // step would reference a session the server already discarded.
152        // Every error is treated as transient — the loop only gives up
153        // when the backoff's `max_elapsed_time` is exceeded.
154        let mut backoff = backoff::ExponentialBackoff {
155            current_interval: self.backoff_current_interval,
156            initial_interval: self.backoff_initial_interval,
157            randomization_factor: self.backoff_randomization_factor,
158            multiplier: self.backoff_multiplier,
159            max_interval: self.backoff_max_interval,
160            start_time: std::time::Instant::now(),
161            max_elapsed_time: Some(self.backoff_max_elapsed_time),
162            clock: backoff::SystemClock::default(),
163        };
164
165        loop {
166            match self
167                .connect_once(&url, session_id.as_deref(), &headers)
168                .await
169            {
170                Ok(conn) => return Ok(conn),
171                Err(e) => {
172                    use backoff::backoff::Backoff;
173                    match backoff.next_backoff() {
174                        Some(d) => tokio::time::sleep(d).await,
175                        None => return Err(e),
176                    }
177                }
178            }
179        }
180    }
181
182    /// One pass through the full Streamable-HTTP handshake. Caller
183    /// applies the outer backoff retry loop in [`Self::connect`].
184    /// `headers` is the already-merged map (defaults + caller overrides
185    /// from [`Self::headers`]), reused on every request without further
186    /// processing. `Mcp-Session-Id` is applied AFTER the headers loop
187    /// so it always wins over any same-named entry in `headers`.
188    async fn connect_once(
189        &self,
190        url: &str,
191        session_id: Option<&str>,
192        headers: &IndexMap<String, String>,
193    ) -> Result<super::Connection, super::Error> {
194        let init_request = serde_json::json!({
195            "jsonrpc": "2.0",
196            "id": 1,
197            "method": "initialize",
198            "params": {
199                "protocolVersion": "2025-06-18",
200                "capabilities": {},
201                "clientInfo": {
202                    "name": "objectiveai",
203                    "version": env!("CARGO_PKG_VERSION"),
204                }
205            }
206        });
207
208        let mut request = self
209            .http_client
210            .post(url)
211            .timeout(self.connect_timeout)
212            .header("Content-Type", "application/json")
213            .header("Accept", "text/event-stream, application/json")
214            .json(&init_request);
215
216        for (name, value) in headers {
217            request = request.header(name, value);
218        }
219        // Mcp-Session-Id is applied last so the explicit `session_id`
220        // argument always wins over any same-named entry in `headers`.
221        if let Some(sid) = session_id {
222            request = request.header("Mcp-Session-Id", sid);
223        }
224
225        let response = request.send().await.map_err(|source| super::Error::Connection {
226            url: url.to_string(),
227            source,
228        })?;
229
230        if !response.status().is_success() {
231            let code = response.status();
232            let body = response.text().await.unwrap_or_default();
233            return Err(super::Error::BadStatus {
234                url: url.to_string(),
235                code,
236                body,
237            });
238        }
239
240        // Extract session ID from response header.
241        //
242        // For *new* sessions the server must mint and return a session
243        // id. For *existing* sessions (caller passed `session_id` in)
244        // many servers — including rmcp's `StreamableHttpService` on
245        // its existing-session branch — don't echo the header back
246        // because nothing changed. When the caller already knew the
247        // session id, fall back to it instead of erroring; that's the
248        // value we'll be stamping on every subsequent request anyway.
249        let resolved_session_id = match response
250            .headers()
251            .get("Mcp-Session-Id")
252            .and_then(|v| v.to_str().ok())
253            .map(String::from)
254        {
255            Some(s) => s,
256            None => match session_id {
257                Some(provided) => provided.to_string(),
258                None => {
259                    let body = response.text().await.unwrap_or_default();
260                    return Err(super::Error::NoSessionId {
261                        url: url.to_string(),
262                        body: body.chars().take(800).collect(),
263                    });
264                }
265            },
266        };
267
268        // Did the server return SSE or unary JSON? rmcp's
269        // `StreamableHttpService` always returns SSE; many other servers
270        // reply with bare JSON.
271        let is_sse = response
272            .headers()
273            .get(reqwest::header::CONTENT_TYPE)
274            .and_then(|v| v.to_str().ok())
275            .map(|v| v.starts_with("text/event-stream"))
276            .unwrap_or(false);
277
278        // Parse the initialize response. SSE path consumes one event
279        // from the stream and keeps the rest of the stream alive for
280        // the listener; unary path consumes the whole body.
281        let (initialize_result, mut initial_sse_lines) = if is_sse {
282            let mut lines = super::lines_from_response(response);
283            let rpc_response: super::JsonRpcResponse<
284                super::initialize_result::InitializeResult,
285            > = super::read_next_sse_event(url, &mut lines).await?;
286            let result = match rpc_response {
287                super::JsonRpcResponse::Success { result, .. } => result,
288                super::JsonRpcResponse::Error { error, .. } => {
289                    return Err(super::Error::JsonRpc {
290                        url: url.to_string(),
291                        code: error.code,
292                        message: error.message,
293                        data: error.data,
294                    });
295                }
296            };
297            (result, Some(lines))
298        } else {
299            let rpc_response: super::JsonRpcResponse<
300                super::initialize_result::InitializeResult,
301            > = super::parse_streamable_http_response(url, response).await?;
302            let result = match rpc_response {
303                super::JsonRpcResponse::Success { result, .. } => result,
304                super::JsonRpcResponse::Error { error, .. } => {
305                    return Err(super::Error::JsonRpc {
306                        url: url.to_string(),
307                        code: error.code,
308                        message: error.message,
309                        data: error.data,
310                    });
311                }
312            };
313            (result, None)
314        };
315
316        // Whether we need a notification SSE channel at all.
317        let needs_sse = initialize_result
318            .capabilities
319            .tools
320            .as_ref()
321            .and_then(|t| t.list_changed)
322            .unwrap_or(false)
323            || initialize_result
324                .capabilities
325                .resources
326                .as_ref()
327                .and_then(|r| r.list_changed)
328                .unwrap_or(false);
329
330        // Send `notifications/initialized` BEFORE any other request.
331        // rmcp's per-session worker is in `expect_notification("initialized")`
332        // at this point — anything else (a `tools/list`, an opportunistic
333        // GET `/`) that lands during that window pushes a non-notification
334        // through the worker, makes `serve_server_with_ct_inner` return
335        // `Err(ExpectedInitializedNotification(...))`, drops the
336        // WorkerTransport, cancels the worker via its drop_guard, and
337        // tears the whole session down. Every later POST then 500s with
338        // "Session service terminated."
339        //
340        // We don't have a `Connection` yet — building one would spawn
341        // `refresh_tools` / `refresh_resources` background tasks that
342        // race with this notification, which is exactly the bug we're
343        // avoiding. We therefore POST inline here.
344        let init_notification_body = serde_json::json!({
345            "jsonrpc": "2.0",
346            "method": "notifications/initialized",
347            "params": {},
348        });
349        let mut notify_request = self
350            .http_client
351            .post(url)
352            .timeout(self.call_timeout)
353            .header("Content-Type", "application/json")
354            .header("Accept", "application/json, text/event-stream");
355        for (name, value) in headers {
356            notify_request = notify_request.header(name, value);
357        }
358        notify_request = notify_request.header("Mcp-Session-Id", &resolved_session_id);
359        let notify_response = notify_request
360            .json(&init_notification_body)
361            .send()
362            .await
363            .map_err(|source| super::Error::Request {
364                url: url.to_string(),
365                source,
366            })?;
367        if !notify_response.status().is_success() {
368            let code = notify_response.status();
369            let body = notify_response.text().await.unwrap_or_default();
370            return Err(super::Error::BadStatus {
371                url: url.to_string(),
372                code,
373                body,
374            });
375        }
376
377        // Now safe to drop the init-side SSE stream; rmcp's session
378        // worker is past the init handshake and we have no further use
379        // for it (notifications come on the GET / stream below).
380        drop(initial_sse_lines.take());
381
382        // Now that the server's session worker is past the
383        // `expect_notification` gate, it's safe to open the proactive
384        // GET `/` SSE stream the listener will read
385        // `notifications/{tools,resources}/list_changed` from. Capability
386        // inspection only gates whether we open this stream; the
387        // `Connection` itself is naive about capabilities.
388        let initial_sse_lines: Option<super::LinesStream> = if needs_sse {
389            let mut get_request = self
390                .http_client
391                .get(url)
392                .timeout(self.connect_timeout)
393                .header("Accept", "text/event-stream");
394            for (name, value) in headers {
395                get_request = get_request.header(name, value);
396            }
397            get_request = get_request.header("Mcp-Session-Id", &resolved_session_id);
398            let get_response = get_request.send().await.map_err(|source| {
399                super::Error::Connection {
400                    url: url.to_string(),
401                    source,
402                }
403            })?;
404            if !get_response.status().is_success() {
405                let code = get_response.status();
406                let body = get_response.text().await.unwrap_or_default();
407                return Err(super::Error::BadStatus {
408                    url: url.to_string(),
409                    code,
410                    body,
411                });
412            }
413            Some(super::lines_from_response(get_response))
414        } else {
415            None
416        };
417
418        // Construct the Connection at the very end. This is the only
419        // place in `connect` where the listener task and the
420        // `refresh_tools` / `refresh_resources` background tasks get
421        // spawned — by now the upstream is fully past its init
422        // handshake, so any of those POSTs land safely.
423        let connection = super::Connection::new(
424            self.http_client.clone(),
425            url.to_string(),
426            resolved_session_id,
427            headers.clone(),
428            self.backoff_current_interval,
429            self.backoff_initial_interval,
430            self.backoff_randomization_factor,
431            self.backoff_multiplier,
432            self.backoff_max_interval,
433            self.backoff_max_elapsed_time,
434            self.call_timeout,
435            initialize_result,
436            initial_sse_lines,
437        )
438        .await;
439
440        Ok(connection)
441    }
442}
443