Skip to main content

aster/mcp/transport/
http.rs

1//! HTTP Transport Implementation
2//!
3//! This module implements the HTTP transport for MCP communication.
4//! It uses HTTP POST requests for request/response communication.
5//!
6//! # Message Format
7//!
8//! Messages are sent in JSON-RPC 2.0 format in HTTP POST request bodies.
9//! Responses are received in JSON-RPC 2.0 format in HTTP response bodies.
10
11use async_trait::async_trait;
12use std::collections::HashMap;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::{mpsc, Mutex, RwLock};
17
18use crate::mcp::error::{McpError, McpResult};
19use crate::mcp::transport::{
20    McpMessage, McpRequest, McpResponse, Transport, TransportConfig, TransportEvent, TransportState,
21};
22use crate::mcp::types::{ConnectionOptions, TransportType};
23
/// Settings specific to the HTTP transport.
#[derive(Debug, Clone)]
pub struct HttpConfig {
    /// Endpoint that every JSON-RPC message is POSTed to.
    pub url: String,
    /// Additional headers, keyed by header name, attached to every request.
    pub headers: HashMap<String, String>,
}
32
33/// HTTP transport for MCP communication
34///
35/// This transport uses HTTP POST requests for request/response communication.
36/// Each request is sent as a separate HTTP POST request and the response
37/// is received in the HTTP response body.
38pub struct HttpTransport {
39    /// Transport configuration
40    config: HttpConfig,
41    /// Connection options
42    options: ConnectionOptions,
43    /// Current transport state
44    state: Arc<RwLock<TransportState>>,
45    /// HTTP client
46    client: Option<reqwest::Client>,
47    /// Event channel sender
48    event_tx: Arc<Mutex<Option<mpsc::Sender<TransportEvent>>>>,
49    /// Request ID counter
50    request_counter: AtomicU64,
51}
52
53impl HttpTransport {
54    /// Create a new HTTP transport
55    pub fn new(config: HttpConfig, options: ConnectionOptions) -> Self {
56        Self {
57            config,
58            options,
59            state: Arc::new(RwLock::new(TransportState::Disconnected)),
60            client: None,
61            event_tx: Arc::new(Mutex::new(None)),
62            request_counter: AtomicU64::new(1),
63        }
64    }
65
66    /// Create from transport config
67    pub fn from_config(config: TransportConfig, options: ConnectionOptions) -> McpResult<Self> {
68        match config {
69            TransportConfig::Http { url, headers } | TransportConfig::Sse { url, headers } => {
70                Ok(Self::new(HttpConfig { url, headers }, options))
71            }
72            _ => Err(McpError::config("Expected HTTP transport configuration")),
73        }
74    }
75
76    /// Generate a unique request ID
77    pub fn next_request_id(&self) -> String {
78        let id = self.request_counter.fetch_add(1, Ordering::SeqCst);
79        format!("http-req-{}", id)
80    }
81
82    /// Set the transport state
83    async fn set_state(&self, state: TransportState) {
84        let mut current = self.state.write().await;
85        *current = state;
86    }
87
88    /// Emit a transport event
89    async fn emit_event(&self, event: TransportEvent) {
90        if let Some(tx) = self.event_tx.lock().await.as_ref() {
91            let _ = tx.send(event).await;
92        }
93    }
94}
95
96#[async_trait]
97impl Transport for HttpTransport {
98    fn transport_type(&self) -> TransportType {
99        TransportType::Http
100    }
101
102    fn state(&self) -> TransportState {
103        self.state
104            .try_read()
105            .map(|s| *s)
106            .unwrap_or(TransportState::Disconnected)
107    }
108
109    async fn connect(&mut self) -> McpResult<()> {
110        self.set_state(TransportState::Connecting).await;
111        self.emit_event(TransportEvent::Connecting).await;
112
113        // Build HTTP client with headers
114        let mut headers = reqwest::header::HeaderMap::new();
115        headers.insert(
116            reqwest::header::CONTENT_TYPE,
117            reqwest::header::HeaderValue::from_static("application/json"),
118        );
119
120        for (key, value) in &self.config.headers {
121            if let (Ok(name), Ok(val)) = (
122                reqwest::header::HeaderName::from_bytes(key.as_bytes()),
123                reqwest::header::HeaderValue::from_str(value),
124            ) {
125                headers.insert(name, val);
126            }
127        }
128
129        let client = reqwest::Client::builder()
130            .default_headers(headers)
131            .timeout(self.options.timeout)
132            .build()
133            .map_err(|e| McpError::transport_with_source("Failed to create HTTP client", e))?;
134
135        self.client = Some(client);
136        self.set_state(TransportState::Connected).await;
137        self.emit_event(TransportEvent::Connected).await;
138
139        Ok(())
140    }
141
142    async fn disconnect(&mut self) -> McpResult<()> {
143        self.set_state(TransportState::Closing).await;
144        self.client = None;
145        self.set_state(TransportState::Disconnected).await;
146        self.emit_event(TransportEvent::Disconnected {
147            reason: Some("Disconnected by user".to_string()),
148        })
149        .await;
150        Ok(())
151    }
152
153    async fn send(&mut self, message: McpMessage) -> McpResult<()> {
154        let state = *self.state.read().await;
155        if state != TransportState::Connected {
156            return Err(McpError::transport("Transport is not connected"));
157        }
158
159        let client = self
160            .client
161            .as_ref()
162            .ok_or_else(|| McpError::transport("HTTP client not initialized"))?;
163
164        let json = serde_json::to_string(&message)?;
165
166        client
167            .post(&self.config.url)
168            .body(json)
169            .send()
170            .await
171            .map_err(|e| McpError::transport_with_source("Failed to send HTTP request", e))?;
172
173        Ok(())
174    }
175
176    async fn send_request(&mut self, request: McpRequest) -> McpResult<McpResponse> {
177        self.send_request_with_timeout(request, self.options.timeout)
178            .await
179    }
180
181    async fn send_request_with_timeout(
182        &mut self,
183        request: McpRequest,
184        timeout: Duration,
185    ) -> McpResult<McpResponse> {
186        let state = *self.state.read().await;
187        if state != TransportState::Connected {
188            return Err(McpError::transport("Transport is not connected"));
189        }
190
191        let client = self
192            .client
193            .as_ref()
194            .ok_or_else(|| McpError::transport("HTTP client not initialized"))?;
195
196        let json = serde_json::to_string(&request)?;
197
198        let response =
199            tokio::time::timeout(timeout, client.post(&self.config.url).body(json).send())
200                .await
201                .map_err(|_| McpError::timeout("HTTP request timed out", timeout))?
202                .map_err(|e| McpError::transport_with_source("Failed to send HTTP request", e))?;
203
204        // Check HTTP status
205        let status = response.status();
206        if !status.is_success() {
207            return Err(McpError::transport(format!(
208                "HTTP request failed with status: {}",
209                status
210            )));
211        }
212
213        let body = response
214            .text()
215            .await
216            .map_err(|e| McpError::transport_with_source("Failed to read response body", e))?;
217
218        let mcp_response: McpResponse = serde_json::from_str(&body)?;
219
220        Ok(mcp_response)
221    }
222
223    fn subscribe(&self) -> mpsc::Receiver<TransportEvent> {
224        let (tx, rx) = mpsc::channel(100);
225        let event_tx = self.event_tx.clone();
226        tokio::spawn(async move {
227            *event_tx.lock().await = Some(tx);
228        });
229        rx
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_URL: &str = "http://localhost:8080";

    /// HTTP config pointing at a local test URL with no extra headers.
    fn test_config() -> HttpConfig {
        HttpConfig {
            url: TEST_URL.to_string(),
            headers: HashMap::new(),
        }
    }

    /// Fresh, disconnected transport built from [`test_config`].
    fn test_transport() -> HttpTransport {
        HttpTransport::new(test_config(), ConnectionOptions::default())
    }

    #[test]
    fn test_http_config() {
        let config = test_config();
        assert_eq!(config.url, TEST_URL);
        assert!(config.headers.is_empty());
    }

    #[test]
    fn test_http_transport_new() {
        let transport = test_transport();
        assert_eq!(transport.transport_type(), TransportType::Http);
        assert_eq!(transport.state(), TransportState::Disconnected);
    }

    #[test]
    fn test_from_config() {
        let config = TransportConfig::Http {
            url: TEST_URL.to_string(),
            headers: HashMap::new(),
        };
        let transport = HttpTransport::from_config(config, ConnectionOptions::default());
        assert!(transport.is_ok());
    }

    #[test]
    fn test_from_config_preserves_fields() {
        let mut headers = HashMap::new();
        headers.insert("X-Api-Key".to_string(), "secret".to_string());
        let config = TransportConfig::Http {
            url: TEST_URL.to_string(),
            headers: headers.clone(),
        };
        let transport = HttpTransport::from_config(config, ConnectionOptions::default())
            .expect("HTTP config should be accepted");
        assert_eq!(transport.config.url, TEST_URL);
        assert_eq!(transport.config.headers, headers);
    }

    #[test]
    fn test_from_config_sse() {
        let config = TransportConfig::Sse {
            url: "http://localhost:8080/sse".to_string(),
            headers: HashMap::new(),
        };
        let transport = HttpTransport::from_config(config, ConnectionOptions::default());
        assert!(transport.is_ok());
    }

    #[test]
    fn test_from_config_wrong_type() {
        let config = TransportConfig::Stdio {
            command: "node".to_string(),
            args: vec![],
            env: HashMap::new(),
            cwd: None,
        };
        let transport = HttpTransport::from_config(config, ConnectionOptions::default());
        assert!(transport.is_err());
    }

    #[test]
    fn test_next_request_id() {
        let transport = test_transport();

        let id1 = transport.next_request_id();
        let id2 = transport.next_request_id();

        assert_ne!(id1, id2);
        assert!(id1.starts_with("http-req-"));
        assert!(id2.starts_with("http-req-"));
    }

    #[tokio::test]
    async fn test_connect_creates_client() {
        let mut transport = test_transport();

        let result = transport.connect().await;
        assert!(result.is_ok());
        assert_eq!(transport.state(), TransportState::Connected);
        assert!(transport.client.is_some());
    }

    #[tokio::test]
    async fn test_connect_with_custom_headers() {
        let mut config = test_config();
        config
            .headers
            .insert("X-Api-Key".to_string(), "secret".to_string());
        let mut transport = HttpTransport::new(config, ConnectionOptions::default());

        transport
            .connect()
            .await
            .expect("valid headers should be accepted");
        assert_eq!(transport.state(), TransportState::Connected);
    }

    #[tokio::test]
    async fn test_disconnect() {
        let mut transport = test_transport();

        transport.connect().await.unwrap();
        let result = transport.disconnect().await;

        assert!(result.is_ok());
        assert_eq!(transport.state(), TransportState::Disconnected);
        assert!(transport.client.is_none());
    }

    #[tokio::test]
    async fn test_send_not_connected() {
        let mut transport = test_transport();

        let request = McpRequest::new(serde_json::json!(1), "test/method");
        let result = transport.send(McpMessage::Request(request)).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_send_request_not_connected() {
        let mut transport = test_transport();

        let request = McpRequest::new(serde_json::json!(1), "test/method");
        let result = transport.send_request(request).await;
        assert!(result.is_err());
    }
}