1use std::collections::BTreeMap;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::{bail, Context, Result};
9use async_trait::async_trait;
10use tokio::sync::Mutex;
11use tokio::time::timeout;
12
13use super::client::McpClient;
14use super::config::McpHttpAuthConfig;
15use super::oauth::{refresh_mcp_oauth_token, token_is_expired, McpTokenStore};
16use super::types::{CallToolResult, InitializeResult, ListToolsResult, ServerStatus};
17
/// Request timeout in milliseconds, used when the caller does not supply one.
const DEFAULT_TIMEOUT_MS: u64 = 30_000;

/// `Accept` header value sent when the user-configured headers contain no
/// `Accept` entry of their own.
const MCP_HTTP_ACCEPT: &str = "application/json, text/event-stream";
24
/// MCP client that exchanges JSON-RPC messages with a server via HTTP POST.
pub struct HttpClient {
    /// Server name used in error messages and for the OAuth token lookup.
    server_name: String,
    /// Endpoint every request is POSTed to.
    url: String,
    /// User-configured headers added to every outgoing request.
    headers: BTreeMap<String, String>,
    /// Optional auth configuration; the OAuth variant enables the stored-token lookup.
    auth: Option<McpHttpAuthConfig>,
    /// Per-request timeout in milliseconds.
    timeout_ms: u64,
    /// Connection state reported by `status()`.
    status: Arc<Mutex<ServerStatus>>,
    /// Next JSON-RPC request id; starts at 1 and is incremented per request.
    next_id: AtomicU64,
    /// Underlying HTTP client, built once in `new`.
    client: reqwest::Client,
}
36
37impl HttpClient {
38 pub fn new(
40 server_name: String,
41 url: String,
42 headers: BTreeMap<String, String>,
43 auth: Option<McpHttpAuthConfig>,
44 timeout_ms: Option<u64>,
45 ) -> Self {
46 let timeout = Duration::from_millis(timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS));
47
48 let client = reqwest::Client::builder()
49 .timeout(timeout)
50 .build()
51 .unwrap_or_else(|_| reqwest::Client::new());
52
53 Self {
54 server_name,
55 url,
56 headers,
57 auth,
58 timeout_ms: timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS),
59 status: Arc::new(Mutex::new(ServerStatus::Disconnected)),
60 next_id: AtomicU64::new(1),
61 client,
62 }
63 }
64
65 async fn send_request(
67 &self,
68 method: &str,
69 params: Option<serde_json::Value>,
70 ) -> Result<serde_json::Value> {
71 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
72
73 let mut request = serde_json::Map::new();
78 request.insert(
79 "jsonrpc".to_string(),
80 serde_json::Value::String("2.0".to_string()),
81 );
82 request.insert("id".to_string(), serde_json::Value::Number(id.into()));
83 request.insert(
84 "method".to_string(),
85 serde_json::Value::String(method.to_string()),
86 );
87 if let Some(p) = params {
88 request.insert("params".to_string(), p);
89 }
90 let request = serde_json::Value::Object(request);
91
92 let mut req = self.client.post(&self.url).json(&request);
93
94 let user_has_accept = self
95 .headers
96 .keys()
97 .any(|k| k.eq_ignore_ascii_case("accept"));
98 if !user_has_accept {
99 req = req.header("Accept", MCP_HTTP_ACCEPT);
100 }
101
102 let user_has_authorization = self
103 .headers
104 .keys()
105 .any(|k| k.eq_ignore_ascii_case("authorization"));
106
107 for (key, value) in &self.headers {
108 req = req.header(key, value);
109 }
110
111 if !user_has_authorization {
112 if let Some(token) = self.load_oauth_token()? {
113 req = req.bearer_auth(token);
114 }
115 }
116
117 let timeout_duration = Duration::from_millis(self.timeout_ms);
118 let response = timeout(timeout_duration, req.send())
119 .await
120 .with_context(|| {
121 format!(
122 "HTTP request to MCP server {} timed out after {}ms",
123 self.server_name, self.timeout_ms
124 )
125 })?
126 .with_context(|| format!("HTTP request to MCP server {} failed", self.server_name))?;
127
128 if !response.status().is_success() {
129 let status = response.status();
130 let body = response.text().await.unwrap_or_default();
131 if status == reqwest::StatusCode::UNAUTHORIZED && self.auth.is_some() {
132 bail!(
133 "MCP server {} requires OAuth; run `atomcode mcp login {}` or `/mcp login {}`",
134 self.server_name,
135 self.server_name,
136 self.server_name
137 );
138 }
139 bail!(
140 "MCP server {} returned HTTP {}: {}",
141 self.server_name,
142 status,
143 body
144 );
145 }
146
147 let content_type = response
153 .headers()
154 .get(reqwest::header::CONTENT_TYPE)
155 .and_then(|v| v.to_str().ok())
156 .map(|s| s.to_ascii_lowercase())
157 .unwrap_or_default();
158
159 let body = response
160 .text()
161 .await
162 .with_context(|| format!("Failed to read MCP HTTP body from {}", self.server_name))?;
163
164 let result: super::types::JsonRpcResponse = if content_type.contains("text/event-stream") {
165 parse_sse_jsonrpc(&body, id).with_context(|| {
166 format!(
167 "Failed to parse MCP SSE response from {} (first 200 bytes: {:?})",
168 self.server_name,
169 body.chars().take(200).collect::<String>()
170 )
171 })?
172 } else {
173 serde_json::from_str(&body).with_context(|| {
174 format!(
175 "Failed to parse MCP HTTP response from {} (content-type={:?}, \
176 first 200 bytes: {:?})",
177 self.server_name,
178 content_type,
179 body.chars().take(200).collect::<String>()
180 )
181 })?
182 };
183
184 if let Some(error) = result.error {
185 bail!("MCP error {} (code {}): {}", error.message, error.code, "");
186 }
187
188 result
189 .result
190 .ok_or_else(|| anyhow::anyhow!("MCP response missing result"))
191 }
192
193 fn load_oauth_token(&self) -> Result<Option<String>> {
194 let Some(McpHttpAuthConfig::OAuth(_)) = &self.auth else {
195 return Ok(None);
196 };
197 let Some(token) = McpTokenStore::default().load_token(&self.server_name)? else {
198 return Ok(None);
199 };
200 if token_is_expired(&token) {
201 let refreshed =
202 refresh_mcp_oauth_token(&self.server_name, &token).with_context(|| {
203 format!(
204 "MCP server {} OAuth token is expired; run `atomcode mcp login {}`",
205 self.server_name, self.server_name
206 )
207 })?;
208 return Ok(Some(refreshed.access_token));
209 }
210 Ok(Some(token.access_token))
211 }
212
213 async fn send_notification(&self, method: &str) -> Result<()> {
216 let request = serde_json::json!({
217 "jsonrpc": "2.0",
218 "method": method
219 });
220
221 let mut req = self.client.post(&self.url).json(&request);
222
223 let user_has_accept = self.headers.keys().any(|k| k.eq_ignore_ascii_case("accept"));
224 if !user_has_accept {
225 req = req.header("Accept", MCP_HTTP_ACCEPT);
226 }
227
228 let user_has_authorization = self.headers.keys().any(|k| k.eq_ignore_ascii_case("authorization"));
229 for (key, value) in &self.headers {
230 req = req.header(key, value);
231 }
232
233 if !user_has_authorization {
234 if let Some(token) = self.load_oauth_token()? {
235 req = req.bearer_auth(token);
236 }
237 }
238
239 let _ = req.send().await;
241 Ok(())
242 }
243}
244
245#[async_trait]
246impl McpClient for HttpClient {
247 async fn initialize(&mut self) -> Result<InitializeResult> {
248 let mut status = self.status.lock().await;
249 *status = ServerStatus::Connecting;
250 drop(status);
251
252 let params = serde_json::json!({
253 "protocolVersion": "2024-11-05",
254 "capabilities": {
255 "tools": {}
256 },
257 "clientInfo": {
258 "name": "atomcode",
259 "version": env!("CARGO_PKG_VERSION")
260 }
261 });
262
263 let result = self.send_request("initialize", Some(params)).await?;
264
265 let init_result: InitializeResult =
266 serde_json::from_value(result).context("Failed to parse initialize result")?;
267
268 let _ = self.send_notification("notifications/initialized").await;
270
271 let mut status = self.status.lock().await;
272 *status = ServerStatus::Connected;
273
274 Ok(init_result)
275 }
276
277 async fn list_tools(&self) -> Result<ListToolsResult> {
278 let result = self.send_request("tools/list", None).await?;
279 serde_json::from_value(result).context("Failed to parse tools/list result")
280 }
281
282 async fn call_tool(
283 &self,
284 tool_name: &str,
285 arguments: serde_json::Value,
286 ) -> Result<CallToolResult> {
287 let params = serde_json::json!({
288 "name": tool_name,
289 "arguments": arguments
290 });
291
292 let result = self.send_request("tools/call", Some(params)).await?;
293 serde_json::from_value(result).context("Failed to parse tools/call result")
294 }
295
296 fn server_name(&self) -> &str {
297 &self.server_name
298 }
299
300 fn status(&self) -> ServerStatus {
301 self.status
302 .try_lock()
303 .map(|s| s.clone())
304 .unwrap_or(ServerStatus::Disconnected)
305 }
306}
307
308fn parse_sse_jsonrpc(body: &str, request_id: u64) -> Result<super::types::JsonRpcResponse> {
331 let mut current = String::new();
332 let try_match = |buf: &str| -> Option<super::types::JsonRpcResponse> {
333 if buf.is_empty() {
334 return None;
335 }
336 let val: serde_json::Value = serde_json::from_str(buf).ok()?;
337 let id_match = val
338 .get("id")
339 .and_then(|v| v.as_u64())
340 .map_or(false, |id| id == request_id);
341 if !id_match {
342 return None;
343 }
344 serde_json::from_value(val).ok()
345 };
346
347 for line in body.lines() {
348 if line.is_empty() {
349 if let Some(resp) = try_match(¤t) {
350 return Ok(resp);
351 }
352 current.clear();
353 continue;
354 }
355 if let Some(rest) = line.strip_prefix("data:") {
356 let rest = rest.strip_prefix(' ').unwrap_or(rest);
358 if !current.is_empty() {
359 current.push('\n');
360 }
361 current.push_str(rest);
362 }
363 }
365 if let Some(resp) = try_match(¤t) {
367 return Ok(resp);
368 }
369 bail!(
370 "event-stream contained no JSON-RPC response matching id {}",
371 request_id
372 )
373}
374
#[cfg(test)]
mod sse_tests {
    use super::*;

    /// Reads a boolean member out of an optional JSON-RPC `result` object.
    fn bool_field(result: Option<&serde_json::Value>, key: &str) -> Option<bool> {
        result?.get(key)?.as_bool()
    }

    /// An event with an `event:` line before its `data:` line is parsed.
    #[test]
    fn single_data_frame_with_event_header() {
        let body =
            "event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"ok\":true}}\n\n";
        let parsed = parse_sse_jsonrpc(body, 1).expect("parse");
        assert_eq!(parsed.id, 1);
        assert!(parsed.error.is_none());
        assert_eq!(bool_field(parsed.result.as_ref(), "ok"), Some(true));
    }

    /// An event consisting of a lone `data:` line is parsed.
    #[test]
    fn data_only_frame_without_event_header() {
        let body = "data: {\"jsonrpc\":\"2.0\",\"id\":7,\"result\":{}}\n\n";
        let parsed = parse_sse_jsonrpc(body, 7).expect("parse");
        assert_eq!(parsed.id, 7);
    }

    /// Notifications and responses for other ids are passed over.
    #[test]
    fn skips_notifications_picks_matching_id() {
        let body = "data: {\"jsonrpc\":\"2.0\",\"method\":\"progress\",\"params\":{}}\n\n\
                    data: {\"jsonrpc\":\"2.0\",\"id\":99,\"result\":{}}\n\n\
                    data: {\"jsonrpc\":\"2.0\",\"id\":42,\"result\":{\"hit\":true}}\n\n";
        let parsed = parse_sse_jsonrpc(body, 42).expect("parse");
        assert_eq!(parsed.id, 42);
        assert_eq!(bool_field(parsed.result.as_ref(), "hit"), Some(true));
    }

    /// Several `data:` lines of one event are joined into a single payload.
    #[test]
    fn multi_line_data_concatenates() {
        let body = "data: {\"jsonrpc\":\"2.0\",\n\
                    data: \"id\":3,\n\
                    data: \"result\":{}}\n\n";
        let parsed = parse_sse_jsonrpc(body, 3).expect("parse");
        assert_eq!(parsed.id, 3);
    }

    /// A final event is accepted even when no blank line follows it.
    #[test]
    fn trailing_frame_without_blank_terminator() {
        let body = "data: {\"jsonrpc\":\"2.0\",\"id\":2,\"result\":{}}";
        let parsed = parse_sse_jsonrpc(body, 2).expect("parse");
        assert_eq!(parsed.id, 2);
    }

    /// Comment lines and non-`data` fields do not disturb parsing.
    #[test]
    fn ignores_sse_comments_and_other_fields() {
        let body = ": this is a heartbeat comment\n\
                    event: message\n\
                    id: 17\n\
                    retry: 5000\n\
                    data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}\n\n";
        let parsed = parse_sse_jsonrpc(body, 1).expect("parse");
        assert_eq!(parsed.id, 1);
    }

    /// A stream without the requested id produces the documented error.
    #[test]
    fn no_matching_id_returns_error() {
        let body = "data: {\"jsonrpc\":\"2.0\",\"id\":99,\"result\":{}}\n\n";
        let failure = parse_sse_jsonrpc(body, 1).expect_err("must fail");
        assert!(failure
            .to_string()
            .contains("no JSON-RPC response matching id 1"));
    }

    /// Payloads that are not JSON are skipped rather than treated as fatal.
    #[test]
    fn skips_non_json_data_lines() {
        let body = "data: [DONE]\n\n\
                    data: {\"jsonrpc\":\"2.0\",\"id\":5,\"result\":{}}\n\n";
        let parsed = parse_sse_jsonrpc(body, 5).expect("parse");
        assert_eq!(parsed.id, 5);
    }
}