1use async_trait::async_trait;
12use std::collections::HashMap;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::{mpsc, Mutex, RwLock};
17
18use crate::mcp::error::{McpError, McpResult};
19use crate::mcp::transport::{
20 McpMessage, McpRequest, McpResponse, Transport, TransportConfig, TransportEvent, TransportState,
21};
22use crate::mcp::types::{ConnectionOptions, TransportType};
23
/// Endpoint settings for the HTTP transport.
///
/// Besides `Debug` and `Clone`, the type derives `PartialEq`/`Eq` so that two
/// configurations can be compared directly (e.g. in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpConfig {
    /// URL that every message is POSTed to.
    pub url: String,
    /// Extra header name/value pairs installed as client default headers
    /// when the transport connects.
    pub headers: HashMap<String, String>,
}
32
33pub struct HttpTransport {
39 config: HttpConfig,
41 options: ConnectionOptions,
43 state: Arc<RwLock<TransportState>>,
45 client: Option<reqwest::Client>,
47 event_tx: Arc<Mutex<Option<mpsc::Sender<TransportEvent>>>>,
49 request_counter: AtomicU64,
51}
52
53impl HttpTransport {
54 pub fn new(config: HttpConfig, options: ConnectionOptions) -> Self {
56 Self {
57 config,
58 options,
59 state: Arc::new(RwLock::new(TransportState::Disconnected)),
60 client: None,
61 event_tx: Arc::new(Mutex::new(None)),
62 request_counter: AtomicU64::new(1),
63 }
64 }
65
66 pub fn from_config(config: TransportConfig, options: ConnectionOptions) -> McpResult<Self> {
68 match config {
69 TransportConfig::Http { url, headers } | TransportConfig::Sse { url, headers } => {
70 Ok(Self::new(HttpConfig { url, headers }, options))
71 }
72 _ => Err(McpError::config("Expected HTTP transport configuration")),
73 }
74 }
75
76 pub fn next_request_id(&self) -> String {
78 let id = self.request_counter.fetch_add(1, Ordering::SeqCst);
79 format!("http-req-{}", id)
80 }
81
82 async fn set_state(&self, state: TransportState) {
84 let mut current = self.state.write().await;
85 *current = state;
86 }
87
88 async fn emit_event(&self, event: TransportEvent) {
90 if let Some(tx) = self.event_tx.lock().await.as_ref() {
91 let _ = tx.send(event).await;
92 }
93 }
94}
95
96#[async_trait]
97impl Transport for HttpTransport {
98 fn transport_type(&self) -> TransportType {
99 TransportType::Http
100 }
101
102 fn state(&self) -> TransportState {
103 self.state
104 .try_read()
105 .map(|s| *s)
106 .unwrap_or(TransportState::Disconnected)
107 }
108
109 async fn connect(&mut self) -> McpResult<()> {
110 self.set_state(TransportState::Connecting).await;
111 self.emit_event(TransportEvent::Connecting).await;
112
113 let mut headers = reqwest::header::HeaderMap::new();
115 headers.insert(
116 reqwest::header::CONTENT_TYPE,
117 reqwest::header::HeaderValue::from_static("application/json"),
118 );
119
120 for (key, value) in &self.config.headers {
121 if let (Ok(name), Ok(val)) = (
122 reqwest::header::HeaderName::from_bytes(key.as_bytes()),
123 reqwest::header::HeaderValue::from_str(value),
124 ) {
125 headers.insert(name, val);
126 }
127 }
128
129 let client = reqwest::Client::builder()
130 .default_headers(headers)
131 .timeout(self.options.timeout)
132 .build()
133 .map_err(|e| McpError::transport_with_source("Failed to create HTTP client", e))?;
134
135 self.client = Some(client);
136 self.set_state(TransportState::Connected).await;
137 self.emit_event(TransportEvent::Connected).await;
138
139 Ok(())
140 }
141
142 async fn disconnect(&mut self) -> McpResult<()> {
143 self.set_state(TransportState::Closing).await;
144 self.client = None;
145 self.set_state(TransportState::Disconnected).await;
146 self.emit_event(TransportEvent::Disconnected {
147 reason: Some("Disconnected by user".to_string()),
148 })
149 .await;
150 Ok(())
151 }
152
153 async fn send(&mut self, message: McpMessage) -> McpResult<()> {
154 let state = *self.state.read().await;
155 if state != TransportState::Connected {
156 return Err(McpError::transport("Transport is not connected"));
157 }
158
159 let client = self
160 .client
161 .as_ref()
162 .ok_or_else(|| McpError::transport("HTTP client not initialized"))?;
163
164 let json = serde_json::to_string(&message)?;
165
166 client
167 .post(&self.config.url)
168 .body(json)
169 .send()
170 .await
171 .map_err(|e| McpError::transport_with_source("Failed to send HTTP request", e))?;
172
173 Ok(())
174 }
175
176 async fn send_request(&mut self, request: McpRequest) -> McpResult<McpResponse> {
177 self.send_request_with_timeout(request, self.options.timeout)
178 .await
179 }
180
181 async fn send_request_with_timeout(
182 &mut self,
183 request: McpRequest,
184 timeout: Duration,
185 ) -> McpResult<McpResponse> {
186 let state = *self.state.read().await;
187 if state != TransportState::Connected {
188 return Err(McpError::transport("Transport is not connected"));
189 }
190
191 let client = self
192 .client
193 .as_ref()
194 .ok_or_else(|| McpError::transport("HTTP client not initialized"))?;
195
196 let json = serde_json::to_string(&request)?;
197
198 let response =
199 tokio::time::timeout(timeout, client.post(&self.config.url).body(json).send())
200 .await
201 .map_err(|_| McpError::timeout("HTTP request timed out", timeout))?
202 .map_err(|e| McpError::transport_with_source("Failed to send HTTP request", e))?;
203
204 let status = response.status();
206 if !status.is_success() {
207 return Err(McpError::transport(format!(
208 "HTTP request failed with status: {}",
209 status
210 )));
211 }
212
213 let body = response
214 .text()
215 .await
216 .map_err(|e| McpError::transport_with_source("Failed to read response body", e))?;
217
218 let mcp_response: McpResponse = serde_json::from_str(&body)?;
219
220 Ok(mcp_response)
221 }
222
223 fn subscribe(&self) -> mpsc::Receiver<TransportEvent> {
224 let (tx, rx) = mpsc::channel(100);
225 let event_tx = self.event_tx.clone();
226 tokio::spawn(async move {
227 *event_tx.lock().await = Some(tx);
228 });
229 rx
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header-less configuration pointing at `url`.
    fn config_for(url: &str) -> HttpConfig {
        HttpConfig {
            url: url.to_string(),
            headers: HashMap::new(),
        }
    }

    /// Builds a fresh, not-yet-connected transport with default options.
    fn local_transport() -> HttpTransport {
        HttpTransport::new(
            config_for("http://localhost:8080"),
            ConnectionOptions::default(),
        )
    }

    #[test]
    fn test_http_config() {
        let config = config_for("http://localhost:8080");
        assert_eq!(config.url, "http://localhost:8080");
    }

    #[test]
    fn test_http_transport_new() {
        let transport = local_transport();
        assert_eq!(transport.transport_type(), TransportType::Http);
        assert_eq!(transport.state(), TransportState::Disconnected);
    }

    #[test]
    fn test_from_config() {
        let config = TransportConfig::Http {
            url: "http://localhost:8080".to_string(),
            headers: HashMap::new(),
        };
        let transport = HttpTransport::from_config(config, ConnectionOptions::default());
        assert!(transport.is_ok());
    }

    #[test]
    fn test_from_config_sse() {
        let config = TransportConfig::Sse {
            url: "http://localhost:8080/sse".to_string(),
            headers: HashMap::new(),
        };
        let transport = HttpTransport::from_config(config, ConnectionOptions::default());
        assert!(transport.is_ok());
    }

    #[test]
    fn test_from_config_wrong_type() {
        // A stdio configuration must be rejected by the HTTP transport.
        let config = TransportConfig::Stdio {
            command: "node".to_string(),
            args: vec![],
            env: HashMap::new(),
            cwd: None,
        };
        let transport = HttpTransport::from_config(config, ConnectionOptions::default());
        assert!(transport.is_err());
    }

    #[test]
    fn test_next_request_id() {
        let transport = local_transport();

        let id1 = transport.next_request_id();
        let id2 = transport.next_request_id();

        assert_ne!(id1, id2);
        assert!(id1.starts_with("http-req-"));
        assert!(id2.starts_with("http-req-"));
    }

    #[tokio::test]
    async fn test_connect_creates_client() {
        let mut transport = local_transport();

        let result = transport.connect().await;
        assert!(result.is_ok());
        assert_eq!(transport.state(), TransportState::Connected);
        assert!(transport.client.is_some());
    }

    #[tokio::test]
    async fn test_disconnect() {
        let mut transport = local_transport();

        transport.connect().await.unwrap();
        let result = transport.disconnect().await;

        assert!(result.is_ok());
        assert_eq!(transport.state(), TransportState::Disconnected);
        assert!(transport.client.is_none());
    }

    #[tokio::test]
    async fn test_send_not_connected() {
        // Sending before `connect` must fail without touching the network.
        let mut transport = local_transport();

        let request = McpRequest::new(serde_json::json!(1), "test/method");
        let result = transport.send(McpMessage::Request(request)).await;
        assert!(result.is_err());
    }
}