mcp_compressor_core/proxy/
server.rs1use std::net::{Ipv4Addr, SocketAddr};
7use std::sync::Arc;
8
9use axum::extract::State;
10use axum::http::{header, HeaderMap, StatusCode};
11use axum::response::{IntoResponse, Response};
12use axum::routing::{get, post};
13use axum::{Json, Router};
14use serde::Deserialize;
15use serde_json::Value;
16use tokio::net::TcpListener;
17
18use crate::proxy::auth::SessionToken;
19use crate::server::compressed::CompressedServer;
20use crate::Error;
21
/// Launcher for the loopback tool proxy.
///
/// Carries no data of its own; call `ToolProxyServer::start` to bind the HTTP
/// listener and obtain a [`RunningToolProxy`] handle.
#[derive(Debug)]
pub struct ToolProxyServer;
24
25#[derive(Debug)]
26pub struct RunningToolProxy {
27 bridge_url: String,
28 token: SessionToken,
29 task: tokio::task::JoinHandle<()>,
30}
31
32#[derive(Clone)]
33struct ProxyState {
34 server: Arc<CompressedServer>,
35 token: SessionToken,
36}
37
38#[derive(Debug, Deserialize)]
39struct ExecRequest {
40 tool: String,
41 #[serde(default)]
42 input: Value,
43}
44
45#[derive(Debug, Deserialize)]
46struct WrapperInvokeInput {
47 tool_name: String,
48 #[serde(default)]
49 tool_input: Value,
50}
51
52impl ToolProxyServer {
53 pub async fn start(server: CompressedServer) -> Result<RunningToolProxy, Error> {
54 let token = SessionToken::generate();
55 let state = ProxyState {
56 server: Arc::new(server),
57 token: token.clone(),
58 };
59
60 let app = Router::new()
61 .route("/health", get(health))
62 .route("/exec", post(exec))
63 .with_state(state);
64
65 let listener = TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))).await?;
66 let addr = listener.local_addr()?;
67 let task = tokio::spawn(async move {
68 if let Err(error) = axum::serve(listener, app).await {
69 eprintln!("mcp-compressor proxy server error: {error}");
70 }
71 });
72
73 Ok(RunningToolProxy {
74 bridge_url: format!("http://{addr}"),
75 token,
76 task,
77 })
78 }
79}
80
81async fn health() -> Response {
82 close_response(StatusCode::OK, "ok")
83}
84
85async fn exec(
86 State(state): State<ProxyState>,
87 headers: HeaderMap,
88 Json(request): Json<ExecRequest>,
89) -> Response {
90 if !authorized(&state.token, &headers) {
91 return close_response(StatusCode::UNAUTHORIZED, "unauthorized");
92 }
93
94 match dispatch_exec(&state.server, request).await {
95 Ok(result) => close_response(StatusCode::OK, result),
96 Err(error) => close_response(StatusCode::BAD_REQUEST, error.to_string()),
97 }
98}
99
100fn close_response(status: StatusCode, body: impl Into<String>) -> Response {
101 let mut response = (status, body.into()).into_response();
102 response
103 .headers_mut()
104 .insert(header::CONNECTION, header::HeaderValue::from_static("close"));
105 response
106}
107
108async fn dispatch_exec(server: &CompressedServer, request: ExecRequest) -> Result<String, Error> {
109 if request.tool.ends_with("_invoke_tool") || request.tool == "invoke_tool" {
110 let wrapper_input: WrapperInvokeInput = serde_json::from_value(request.input)?;
111 server
112 .invoke_tool(&request.tool, &wrapper_input.tool_name, wrapper_input.tool_input)
113 .await
114 } else {
115 server
116 .invoke_single_backend_tool(&request.tool, request.input)
117 .await
118 }
119}
120
121fn authorized(token: &SessionToken, headers: &HeaderMap) -> bool {
122 headers
123 .get(axum::http::header::AUTHORIZATION)
124 .and_then(|value| value.to_str().ok())
125 .is_some_and(|header| token.verify(header))
126}
127
128impl Drop for RunningToolProxy {
129 fn drop(&mut self) {
130 self.task.abort();
131 }
132}
133
134impl RunningToolProxy {
135 pub fn bridge_url(&self) -> &str {
136 &self.bridge_url
137 }
138
139 pub fn token(&self) -> &SessionToken {
140 &self.token
141 }
142
143 pub fn token_value(&self) -> &str {
144 self.token.value()
145 }
146
147 pub fn health_url(&self) -> String {
148 format!("{}/health", self.bridge_url)
149 }
150
151 pub fn exec_url(&self) -> String {
152 format!("{}/exec", self.bridge_url)
153 }
154}