// mcp_runner/sse_proxy/handlers.rs

//! HTTP request handlers for the Actix Web-based SSE proxy.
//!
//! This module contains the Actix Web handlers for the endpoints supported by
//! the SSE proxy: SSE event streaming, tool calls (asynchronous and JSON-RPC),
//! server/tool/resource listing, and resource retrieval.
6
7use crate::client::McpClient;
8use crate::error::{Error, Result};
9use crate::server::ServerId;
10use crate::sse_proxy::events::EventManager;
11use crate::sse_proxy::proxy::SSEProxy;
12use crate::sse_proxy::types::{ResourceInfo, ToolInfo};
13use crate::transport::json_rpc::{JSON_RPC_VERSION, JsonRpcRequest, JsonRpcResponse};
14
15use actix_web::{
16    HttpRequest, HttpResponse, Responder,
17    web::{Bytes, Data, Json, Path},
18};
19use serde::{Deserialize, Serialize};
20use serde_json::{Value, json};
21use std::sync::Arc;
22use std::time::Duration;
23use tokio::sync::Mutex;
24use tokio::time::interval;
25use tracing;
26
27// Define a struct to deserialize the resource response from the client
28#[derive(Deserialize)]
29#[serde(rename_all = "camelCase")]
30struct ResourceResponse {
31    content_type: String,
32    // Assuming data is returned as a string (potentially base64 for binary)
33    data: String,
34}
35
36/// Helper function to validate server, retrieve client and initialize it
37/// Returns a tuple of (client, server_id, server_id_str) or an Error
38async fn get_validated_client(
39    proxy: &SSEProxy,
40    server_name: &str,
41) -> Result<(McpClient, ServerId, String)> {
42    // Check if server is allowed
43    if let Some(allowed_servers) = (proxy.get_runner_access().get_allowed_servers)() {
44        if !allowed_servers.contains(&server_name.to_string()) {
45            return Err(Error::Unauthorized(format!(
46                "Server '{}' not in allowed list",
47                server_name
48            )));
49        }
50    }
51
52    // Get server ID
53    let server_id = match (proxy.get_runner_access().get_server_id)(server_name) {
54        Ok(id) => id,
55        Err(e) => {
56            tracing::warn!(server = %server_name, error = %e, "Server not found");
57            return Err(e);
58        }
59    };
60    let server_id_str = format!("{:?}", server_id);
61
62    // Get client
63    let client = match (proxy.get_runner_access().get_client)(server_id) {
64        Ok(client) => client,
65        Err(e) => {
66            tracing::error!(server_id = ?server_id, error = %e, "Failed to get client");
67            return Err(e);
68        }
69    };
70
71    // Initialize client
72    client.initialize().await?;
73
74    Ok((client, server_id, server_id_str))
75}
76
77/// Helper function to create a JsonRPC error response
78fn create_jsonrpc_error(request_id: Value, code: i32, message: String) -> HttpResponse {
79    let error_response = JsonRpcResponse::error(request_id, code, message, None);
80
81    HttpResponse::InternalServerError().json(error_response)
82}
83
84/// Request body for tool calls
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct ToolCallRequest {
87    /// Tool name to call
88    pub tool: String,
89    /// Server name to call the tool on
90    pub server: String,
91    /// Arguments to pass to the tool
92    pub args: Value,
93    /// Request ID for correlation
94    #[serde(rename = "requestId")]
95    pub request_id: String,
96}
97
98/// SSE stream handler
99///
100/// This handler creates and returns an SSE stream for clients to receive events.
101///
102/// # Returns
103///
104/// An HTTP response with an SSE stream
105pub async fn sse_events(
106    event_manager: Data<Arc<EventManager>>,
107    _req: HttpRequest,
108) -> impl Responder {
109    tracing::debug!("Client connected to SSE stream");
110
111    // Create a receiver from the event manager
112    let mut receiver = event_manager.subscribe();
113
114    // Prepare the event stream
115    let stream = async_stream::stream! {
116        // Send an initial event to confirm connection
117        let initial_event = json!({
118            "type": "notification",
119            "title": "Connected",
120            "message": "SSE connection established",
121            "level": "info"
122        });
123
124        if let Ok(data) = serde_json::to_string(&initial_event) {
125            let msg = crate::sse_proxy::types::SSEMessage::new("notification", &data, None);
126            yield Ok::<_, actix_web::Error>(EventManager::format_sse_message(&msg));
127        }
128
129        // Create a heartbeat interval (every 30 seconds)
130        let mut heartbeat_interval = interval(Duration::from_secs(30));
131
132        loop {
133            tokio::select! {
134                // Check for new events
135                event = receiver.recv() => {
136                    match event {
137                        Ok(msg) => {
138                            yield Ok::<_, actix_web::Error>(EventManager::format_sse_message(&msg));
139                        },
140                        Err(e) => {
141                            tracing::error!(error = %e, "Error receiving SSE event");
142                            break;
143                        }
144                    }
145                },
146                // Send heartbeat
147                _ = heartbeat_interval.tick() => {
148                    yield Ok::<_, actix_web::Error>(Bytes::from(":\n\n")); // Colon comment for heartbeat
149                }
150            }
151        }
152    };
153
154    // Return the HTTP response with the SSE stream
155    HttpResponse::Ok()
156        .append_header(("Content-Type", "text/event-stream"))
157        .append_header(("Cache-Control", "no-cache"))
158        .append_header(("Connection", "keep-alive"))
159        .append_header(("Access-Control-Allow-Origin", "*"))
160        .streaming(stream)
161}
162
163/// Initialize the connection
164///
165/// This handler initializes the connection with the client.
166///
167/// # Returns
168///
169/// A JSON response with initialization information
170pub async fn initialize(_req: HttpRequest) -> impl Responder {
171    let response = json!({
172        "status": "ok",
173        "version": env!("CARGO_PKG_VERSION"),
174        "serverType": "mcp-runner-actix",
175    });
176
177    HttpResponse::Ok().json(response)
178}
179
180/// Process tool calls
181///
182/// This handler processes tool calls from clients and streams the response
183/// through the SSE channel.
184///
185/// # Returns
186///
187/// An HTTP response indicating the tool call was accepted
188pub async fn tool_call(
189    proxy: Data<Arc<Mutex<SSEProxy>>>,
190    body: Json<ToolCallRequest>,
191) -> impl Responder {
192    let request = body.into_inner();
193    let tool_name = request.tool.clone();
194    let server_name = request.server.clone();
195    let args = request.args.clone();
196    let request_id = request.request_id.clone();
197
198    tracing::debug!(
199        req_id = %request_id,
200        server = %server_name,
201        tool = %tool_name,
202        "Tool call request received"
203    );
204
205    // Clone the proxy for async processing
206    let proxy_lock = Arc::clone(&proxy);
207    // Clone request_id again for use in the async block
208    let request_id_clone = request_id.clone();
209
210    // Process the tool call asynchronously
211    tokio::spawn(async move {
212        // Acquire the lock on the proxy
213        let proxy = proxy_lock.lock().await;
214
215        // Process the tool call using the cloned request_id
216        if let Err(e) = proxy
217            .process_tool_call(&server_name, &tool_name, args, &request_id_clone)
218            .await
219        {
220            tracing::error!(
221                req_id = %request_id_clone, // Use cloned request_id here
222                server = %server_name,
223                tool = %tool_name,
224                error = %e,
225                "Failed to process tool call"
226            );
227        }
228    });
229
230    // Return a successful response immediately using the original request_id
231    HttpResponse::Accepted().json(json!({
232        "status": "accepted",
233        "requestId": request_id, // Use original request_id here
234        "message": "Tool call accepted and processing asynchronously"
235    }))
236}
237
238/// List available servers
239///
240/// This handler returns a list of all available servers with their status.
241///
242/// # Returns
243///
244/// A JSON response with server information
245pub async fn list_servers(proxy: Data<Arc<Mutex<SSEProxy>>>) -> impl Responder {
246    let proxy = proxy.lock().await;
247    let servers = proxy.get_server_info().lock().await;
248
249    // Convert to a vector for the response
250    let server_list: Vec<_> = servers.values().cloned().collect();
251
252    HttpResponse::Ok().json(server_list)
253}
254
255/// List tools for a specific server
256///
257/// This handler returns a list of all tools available on a specific server.
258///
259/// # Returns
260///
261/// A JSON response with tool information
262pub async fn list_server_tools(
263    proxy: Data<Arc<Mutex<SSEProxy>>>,
264    path: Path<(String,)>,
265) -> Result<impl Responder> {
266    let server_name = &path.0;
267    let proxy = proxy.lock().await;
268
269    let (client, _, _) = get_validated_client(&proxy, server_name).await?;
270
271    // List tools
272    let tools = client.list_tools().await?;
273
274    // Convert to ToolInfo for response
275    let tool_infos: Vec<ToolInfo> = tools
276        .into_iter()
277        .map(|tool| ToolInfo {
278            name: tool.name,
279            description: tool.description,
280            parameters: tool.input_schema, // Map input_schema to parameters
281            return_type: tool.output_schema.map(|v| v.to_string()), // Map output_schema to return_type string
282        })
283        .collect();
284
285    Ok(HttpResponse::Ok().json(tool_infos))
286}
287
288/// List resources for a specific server
289///
290/// This handler returns a list of all resources available on a specific server.
291///
292/// # Returns
293///
294/// A JSON response with resource information
295pub async fn list_server_resources(
296    proxy: Data<Arc<Mutex<SSEProxy>>>,
297    path: Path<(String,)>,
298) -> Result<impl Responder> {
299    let server_name = &path.0;
300    let proxy = proxy.lock().await;
301
302    let (client, _, _) = get_validated_client(&proxy, server_name).await?;
303
304    // List resources
305    let resources_result = client.list_resources().await;
306
307    match resources_result {
308        Ok(resources) => {
309            // Convert to ResourceInfo for response
310            let resource_infos: Vec<ResourceInfo> = resources
311                .into_iter()
312                .map(|r| ResourceInfo {
313                    name: r.name,
314                    description: r.description,
315                    metadata: None, // No metadata field in client::Resource
316                })
317                .collect();
318
319            Ok(HttpResponse::Ok().json(resource_infos))
320        }
321        Err(e) => {
322            // Check if the error indicates an unsupported operation (e.g., method not found)
323            if matches!(&e, Error::JsonRpc(s) if s.contains("method not found") || s.contains("unsupported"))
324            {
325                // Return empty list if resources not supported
326                tracing::debug!(server = %server_name, "Server does not support resources (method not found)");
327                Ok(HttpResponse::Ok().json(Vec::<ResourceInfo>::new()))
328            } else {
329                Err(e)
330            }
331        }
332    }
333}
334
335/// Get a specific resource from a server
336///
337/// This handler retrieves a specific resource from a server.
338///
339/// # Returns
340///
341/// The resource content with appropriate content type
342pub async fn get_server_resource(
343    proxy: Data<Arc<Mutex<SSEProxy>>>,
344    path: Path<(String, String)>,
345) -> Result<impl Responder> {
346    let (server_name, resource_name) = (&path.0, &path.1);
347    let proxy = proxy.lock().await;
348
349    let (client, _, _) = get_validated_client(&proxy, server_name).await?;
350
351    // Get the resource by deserializing into ResourceResponse
352    let resource_response: ResourceResponse = client.get_resource(resource_name).await?;
353
354    // Convert the data string to Bytes.
355    // TODO: Handle potential base64 decoding if necessary based on content_type
356    let data_bytes = Bytes::from(resource_response.data);
357
358    // Determine how to return the resource based on content type
359    Ok(HttpResponse::Ok()
360        .content_type(resource_response.content_type)
361        .body(data_bytes))
362}
363
364/// Process JsonRPC tool calls
365///
366/// This handler processes JsonRPC tool calls and returns the response directly.
367///
368/// # Returns
369///
370/// A JSON response with the tool call result
371pub async fn tool_call_jsonrpc(
372    proxy: Data<Arc<Mutex<SSEProxy>>>,
373    body: Bytes,
374) -> Result<impl Responder> {
375    // Parse the JsonRPC request
376    let json_rpc_req: JsonRpcRequest = match serde_json::from_slice(&body) {
377        Ok(req) => req,
378        Err(e) => {
379            tracing::error!(error = %e, "Failed to parse JsonRPC request");
380            // Use Error::JsonRpc instead of Error::InvalidRequest
381            return Err(Error::JsonRpc(format!("Invalid JsonRPC request: {}", e)));
382        }
383    };
384
385    // Check JsonRPC version
386    if json_rpc_req.jsonrpc != JSON_RPC_VERSION {
387        // Use Error::JsonRpc instead of Error::InvalidRequest
388        return Err(Error::JsonRpc(format!(
389            "Unsupported JsonRPC version: {}",
390            json_rpc_req.jsonrpc
391        )));
392    }
393
394    // Extract method and params
395    let request_id = json_rpc_req.id.clone();
396    let method = json_rpc_req.method.clone();
397
398    // Parse method to get server name and tool name
399    let parts: Vec<&str> = method.splitn(2, '.').collect();
400    if parts.len() != 2 {
401        // Use Error::JsonRpc instead of Error::InvalidRequest
402        return Err(Error::JsonRpc(format!(
403            "Invalid method format. Expected 'server.tool', got '{}'",
404            method
405        )));
406    }
407
408    let server_name = parts[0];
409    let tool_name = parts[1];
410    let args = json_rpc_req.params.unwrap_or(json!({}));
411
412    tracing::debug!(
413        req_id = ?request_id,
414        server = %server_name,
415        tool = %tool_name,
416        "JsonRPC tool call received"
417    );
418
419    // Process the tool call
420    let proxy_instance = proxy.lock().await;
421
422    let (client, _, _) = match get_validated_client(&proxy_instance, server_name).await {
423        Ok(result) => result,
424        Err(e) => {
425            let error_response =
426                create_jsonrpc_error(request_id, -32000, format!("Validation failed: {}", e));
427            return Ok(error_response);
428        }
429    };
430
431    // Call the tool
432    match client.call_tool(tool_name, &args).await {
433        Ok(result) => {
434            // Use JsonRpcResponse::success instead of ::result
435            let json_rpc_resp = JsonRpcResponse::success(request_id, result);
436            Ok(HttpResponse::Ok().json(json_rpc_resp))
437        }
438        Err(e) => {
439            tracing::error!(
440                req_id = ?request_id,
441                error = %e,
442                "Tool call failed"
443            );
444            let error_response =
445                create_jsonrpc_error(request_id, -32000, format!("Tool call failed: {}", e));
446            Ok(error_response)
447        }
448    }
449}