1use crate::client::McpClient;
8use crate::error::{Error, Result};
9use crate::server::ServerId;
10use crate::sse_proxy::events::EventManager;
11use crate::sse_proxy::proxy::SSEProxy;
12use crate::sse_proxy::types::{ResourceInfo, ToolInfo};
13use crate::transport::json_rpc::{JSON_RPC_VERSION, JsonRpcRequest, JsonRpcResponse};
14
15use actix_web::{
16 HttpRequest, HttpResponse, Responder,
17 web::{Bytes, Data, Json, Path},
18};
19use serde::{Deserialize, Serialize};
20use serde_json::{Value, json};
21use std::sync::Arc;
22use std::time::Duration;
23use tokio::sync::Mutex;
24use tokio::time::interval;
25use tracing;
26
27#[derive(Deserialize)]
29#[serde(rename_all = "camelCase")]
30struct ResourceResponse {
31 content_type: String,
32 data: String,
34}
35
36async fn get_validated_client(
39 proxy: &SSEProxy,
40 server_name: &str,
41) -> Result<(McpClient, ServerId, String)> {
42 if let Some(allowed_servers) = (proxy.get_runner_access().get_allowed_servers)() {
44 if !allowed_servers.contains(&server_name.to_string()) {
45 return Err(Error::Unauthorized(format!(
46 "Server '{}' not in allowed list",
47 server_name
48 )));
49 }
50 }
51
52 let server_id = match (proxy.get_runner_access().get_server_id)(server_name) {
54 Ok(id) => id,
55 Err(e) => {
56 tracing::warn!(server = %server_name, error = %e, "Server not found");
57 return Err(e);
58 }
59 };
60 let server_id_str = format!("{:?}", server_id);
61
62 let client = match (proxy.get_runner_access().get_client)(server_id) {
64 Ok(client) => client,
65 Err(e) => {
66 tracing::error!(server_id = ?server_id, error = %e, "Failed to get client");
67 return Err(e);
68 }
69 };
70
71 client.initialize().await?;
73
74 Ok((client, server_id, server_id_str))
75}
76
77fn create_jsonrpc_error(request_id: Value, code: i32, message: String) -> HttpResponse {
79 let error_response = JsonRpcResponse::error(request_id, code, message, None);
80
81 HttpResponse::InternalServerError().json(error_response)
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct ToolCallRequest {
87 pub tool: String,
89 pub server: String,
91 pub args: Value,
93 #[serde(rename = "requestId")]
95 pub request_id: String,
96}
97
98pub async fn sse_events(
106 event_manager: Data<Arc<EventManager>>,
107 _req: HttpRequest,
108) -> impl Responder {
109 tracing::debug!("Client connected to SSE stream");
110
111 let mut receiver = event_manager.subscribe();
113
114 let stream = async_stream::stream! {
116 let initial_event = json!({
118 "type": "notification",
119 "title": "Connected",
120 "message": "SSE connection established",
121 "level": "info"
122 });
123
124 if let Ok(data) = serde_json::to_string(&initial_event) {
125 let msg = crate::sse_proxy::types::SSEMessage::new("notification", &data, None);
126 yield Ok::<_, actix_web::Error>(EventManager::format_sse_message(&msg));
127 }
128
129 let mut heartbeat_interval = interval(Duration::from_secs(30));
131
132 loop {
133 tokio::select! {
134 event = receiver.recv() => {
136 match event {
137 Ok(msg) => {
138 yield Ok::<_, actix_web::Error>(EventManager::format_sse_message(&msg));
139 },
140 Err(e) => {
141 tracing::error!(error = %e, "Error receiving SSE event");
142 break;
143 }
144 }
145 },
146 _ = heartbeat_interval.tick() => {
148 yield Ok::<_, actix_web::Error>(Bytes::from(":\n\n")); }
150 }
151 }
152 };
153
154 HttpResponse::Ok()
156 .append_header(("Content-Type", "text/event-stream"))
157 .append_header(("Cache-Control", "no-cache"))
158 .append_header(("Connection", "keep-alive"))
159 .append_header(("Access-Control-Allow-Origin", "*"))
160 .streaming(stream)
161}
162
163pub async fn initialize(_req: HttpRequest) -> impl Responder {
171 let response = json!({
172 "status": "ok",
173 "version": env!("CARGO_PKG_VERSION"),
174 "serverType": "mcp-runner-actix",
175 });
176
177 HttpResponse::Ok().json(response)
178}
179
180pub async fn tool_call(
189 proxy: Data<Arc<Mutex<SSEProxy>>>,
190 body: Json<ToolCallRequest>,
191) -> impl Responder {
192 let request = body.into_inner();
193 let tool_name = request.tool.clone();
194 let server_name = request.server.clone();
195 let args = request.args.clone();
196 let request_id = request.request_id.clone();
197
198 tracing::debug!(
199 req_id = %request_id,
200 server = %server_name,
201 tool = %tool_name,
202 "Tool call request received"
203 );
204
205 let proxy_lock = Arc::clone(&proxy);
207 let request_id_clone = request_id.clone();
209
210 tokio::spawn(async move {
212 let proxy = proxy_lock.lock().await;
214
215 if let Err(e) = proxy
217 .process_tool_call(&server_name, &tool_name, args, &request_id_clone)
218 .await
219 {
220 tracing::error!(
221 req_id = %request_id_clone, server = %server_name,
223 tool = %tool_name,
224 error = %e,
225 "Failed to process tool call"
226 );
227 }
228 });
229
230 HttpResponse::Accepted().json(json!({
232 "status": "accepted",
233 "requestId": request_id, "message": "Tool call accepted and processing asynchronously"
235 }))
236}
237
238pub async fn list_servers(proxy: Data<Arc<Mutex<SSEProxy>>>) -> impl Responder {
246 let proxy = proxy.lock().await;
247 let servers = proxy.get_server_info().lock().await;
248
249 let server_list: Vec<_> = servers.values().cloned().collect();
251
252 HttpResponse::Ok().json(server_list)
253}
254
255pub async fn list_server_tools(
263 proxy: Data<Arc<Mutex<SSEProxy>>>,
264 path: Path<(String,)>,
265) -> Result<impl Responder> {
266 let server_name = &path.0;
267 let proxy = proxy.lock().await;
268
269 let (client, _, _) = get_validated_client(&proxy, server_name).await?;
270
271 let tools = client.list_tools().await?;
273
274 let tool_infos: Vec<ToolInfo> = tools
276 .into_iter()
277 .map(|tool| ToolInfo {
278 name: tool.name,
279 description: tool.description,
280 parameters: tool.input_schema, return_type: tool.output_schema.map(|v| v.to_string()), })
283 .collect();
284
285 Ok(HttpResponse::Ok().json(tool_infos))
286}
287
288pub async fn list_server_resources(
296 proxy: Data<Arc<Mutex<SSEProxy>>>,
297 path: Path<(String,)>,
298) -> Result<impl Responder> {
299 let server_name = &path.0;
300 let proxy = proxy.lock().await;
301
302 let (client, _, _) = get_validated_client(&proxy, server_name).await?;
303
304 let resources_result = client.list_resources().await;
306
307 match resources_result {
308 Ok(resources) => {
309 let resource_infos: Vec<ResourceInfo> = resources
311 .into_iter()
312 .map(|r| ResourceInfo {
313 name: r.name,
314 description: r.description,
315 metadata: None, })
317 .collect();
318
319 Ok(HttpResponse::Ok().json(resource_infos))
320 }
321 Err(e) => {
322 if matches!(&e, Error::JsonRpc(s) if s.contains("method not found") || s.contains("unsupported"))
324 {
325 tracing::debug!(server = %server_name, "Server does not support resources (method not found)");
327 Ok(HttpResponse::Ok().json(Vec::<ResourceInfo>::new()))
328 } else {
329 Err(e)
330 }
331 }
332 }
333}
334
335pub async fn get_server_resource(
343 proxy: Data<Arc<Mutex<SSEProxy>>>,
344 path: Path<(String, String)>,
345) -> Result<impl Responder> {
346 let (server_name, resource_name) = (&path.0, &path.1);
347 let proxy = proxy.lock().await;
348
349 let (client, _, _) = get_validated_client(&proxy, server_name).await?;
350
351 let resource_response: ResourceResponse = client.get_resource(resource_name).await?;
353
354 let data_bytes = Bytes::from(resource_response.data);
357
358 Ok(HttpResponse::Ok()
360 .content_type(resource_response.content_type)
361 .body(data_bytes))
362}
363
364pub async fn tool_call_jsonrpc(
372 proxy: Data<Arc<Mutex<SSEProxy>>>,
373 body: Bytes,
374) -> Result<impl Responder> {
375 let json_rpc_req: JsonRpcRequest = match serde_json::from_slice(&body) {
377 Ok(req) => req,
378 Err(e) => {
379 tracing::error!(error = %e, "Failed to parse JsonRPC request");
380 return Err(Error::JsonRpc(format!("Invalid JsonRPC request: {}", e)));
382 }
383 };
384
385 if json_rpc_req.jsonrpc != JSON_RPC_VERSION {
387 return Err(Error::JsonRpc(format!(
389 "Unsupported JsonRPC version: {}",
390 json_rpc_req.jsonrpc
391 )));
392 }
393
394 let request_id = json_rpc_req.id.clone();
396 let method = json_rpc_req.method.clone();
397
398 let parts: Vec<&str> = method.splitn(2, '.').collect();
400 if parts.len() != 2 {
401 return Err(Error::JsonRpc(format!(
403 "Invalid method format. Expected 'server.tool', got '{}'",
404 method
405 )));
406 }
407
408 let server_name = parts[0];
409 let tool_name = parts[1];
410 let args = json_rpc_req.params.unwrap_or(json!({}));
411
412 tracing::debug!(
413 req_id = ?request_id,
414 server = %server_name,
415 tool = %tool_name,
416 "JsonRPC tool call received"
417 );
418
419 let proxy_instance = proxy.lock().await;
421
422 let (client, _, _) = match get_validated_client(&proxy_instance, server_name).await {
423 Ok(result) => result,
424 Err(e) => {
425 let error_response =
426 create_jsonrpc_error(request_id, -32000, format!("Validation failed: {}", e));
427 return Ok(error_response);
428 }
429 };
430
431 match client.call_tool(tool_name, &args).await {
433 Ok(result) => {
434 let json_rpc_resp = JsonRpcResponse::success(request_id, result);
436 Ok(HttpResponse::Ok().json(json_rpc_resp))
437 }
438 Err(e) => {
439 tracing::error!(
440 req_id = ?request_id,
441 error = %e,
442 "Tool call failed"
443 );
444 let error_response =
445 create_jsonrpc_error(request_id, -32000, format!("Tool call failed: {}", e));
446 Ok(error_response)
447 }
448 }
449}