use std::sync::Arc;
use arc_swap::ArcSwap;
use rmcp::handler::client::ClientHandler;
use rmcp::handler::server::ServerHandler;
use rmcp::model::*;
use rmcp::service::{Peer, RequestContext, RoleClient, RoleServer};
use rmcp::Error as McpError;
use crate::filter::FilterEngine;
use crate::tracking::Tracker;
/// MCP server half of the proxy.
///
/// Downstream requests are answered by forwarding them to `upstream`. Tool
/// call results pass through the current [`FilterEngine`] before being
/// returned.
#[derive(Clone)]
pub struct ProxyServer {
    /// Shared filter-engine slot. `call_tool` loads the current engine on
    /// every call, so an engine stored into this `ArcSwap` applies to later calls.
    filter: Arc<ArcSwap<FilterEngine>>,
    /// Optional recorder. When present, each filtered text item is tracked.
    tracker: Option<Arc<Tracker>>,
    /// Client-side handle used to issue `list_tools` / `call_tool` upstream.
    upstream: Peer<RoleClient>,
    /// Server-side peer handle, stored by `set_peer`. It is `None` until then.
    peer: Option<Peer<RoleServer>>,
}
impl ProxyServer {
    /// Builds a proxy server that forwards requests to `upstream`.
    ///
    /// * `engine` – shared filter slot. The engine it holds is loaded per tool call.
    /// * `tracker` – optional recorder of raw vs. filtered responses.
    /// * `upstream` – client-side handle to the real MCP server.
    ///
    /// The server-side peer starts unset and is installed later via `set_peer`.
    pub fn new(
        engine: Arc<ArcSwap<FilterEngine>>,
        tracker: Option<Arc<Tracker>>,
        upstream: Peer<RoleClient>,
    ) -> Self {
        let peer = None;
        Self {
            filter: engine,
            tracker,
            upstream,
            peer,
        }
    }
}
impl ServerHandler for ProxyServer {
    /// Describes this proxy to the connecting client.
    ///
    /// Only the tools capability is advertised, with `list_changed` set.
    /// NOTE(review): nothing in this file forwards tool-list-changed
    /// notifications. Confirm they are emitted elsewhere; otherwise this flag
    /// over-promises.
    fn get_info(&self) -> ServerInfo {
        let capabilities = ServerCapabilities {
            tools: Some(ToolsCapability {
                list_changed: Some(true),
            }),
            ..Default::default()
        };
        let server_info = Implementation {
            name: "mcp-rtk".into(),
            version: env!("CARGO_PKG_VERSION").into(),
        };
        ServerInfo {
            protocol_version: Default::default(),
            capabilities,
            server_info,
            instructions: Some("Token-optimizing MCP proxy".into()),
        }
    }

    /// Returns the stored server-side peer handle, if one has been set.
    fn get_peer(&self) -> Option<Peer<RoleServer>> {
        self.peer.clone()
    }

    /// Stores the server-side peer handle, replacing any previous one.
    fn set_peer(&mut self, peer: Peer<RoleServer>) {
        self.peer = Some(peer);
    }

    /// Forwards a tool listing request to the upstream server and returns its
    /// answer unmodified.
    ///
    /// A request without pagination parameters is sent upstream as a page
    /// request with no cursor. Any upstream failure is reported as an
    /// internal error.
    fn list_tools(
        &self,
        request: PaginatedRequestParam,
        _context: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
        let upstream = self.upstream.clone();
        async move {
            let page = request.unwrap_or(PaginatedRequestParamInner { cursor: None });
            upstream.list_tools(Some(page)).await.map_err(|e| {
                McpError::internal_error(format!("upstream list_tools failed: {e}"), None)
            })
        }
    }

    /// Forwards a tool call upstream, then runs every returned content item
    /// through `filter_content`.
    ///
    /// The upstream `is_error` flag is passed through unchanged. The preset
    /// name given to the tracker comes from the engine config and defaults to
    /// `"generic"`.
    /// NOTE(review): every upstream failure is mapped to `internal_error`, so
    /// the upstream error code survives only inside the message text.
    fn call_tool(
        &self,
        request: CallToolRequestParam,
        _context: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<CallToolResult, McpError>> + Send + '_ {
        // Snapshot everything the future needs so it owns its state.
        let upstream = self.upstream.clone();
        let tracker = self.tracker.clone();
        let engine = self.filter.load_full();
        let tool_name = request.name.to_string();
        let preset = engine
            .config()
            .preset
            .as_deref()
            .unwrap_or("generic")
            .to_owned();

        async move {
            let CallToolResult { content, is_error } =
                upstream.call_tool(request).await.map_err(|e| {
                    McpError::internal_error(format!("upstream call_tool failed: {e}"), None)
                })?;

            let filtered: Vec<Content> = content
                .into_iter()
                .map(|item| filter_content(&engine, &tool_name, item, tracker.as_ref(), &preset))
                .collect();

            Ok(CallToolResult {
                content: filtered,
                is_error,
            })
        }
    }
}
/// MCP client half of the proxy.
///
/// It identifies itself as `mcp-rtk-client`, presumably on the connection to
/// the upstream server (verify against the caller). Its only state is the
/// peer handle the rmcp runtime hands it through `set_peer`.
#[derive(Default, Clone)]
pub struct ProxyClient {
    /// Client-side peer handle. It is `None` until `set_peer` is called.
    peer: Option<Peer<RoleClient>>,
}
impl ProxyClient {
    /// Creates a client handler with no peer attached yet.
    pub fn new() -> Self {
        Self { peer: None }
    }
}
impl ClientHandler for ProxyClient {
    /// Identifies this client as `mcp-rtk-client` at the crate version, with
    /// default protocol version and default client capabilities.
    fn get_info(&self) -> ClientInfo {
        let client_info = Implementation {
            name: "mcp-rtk-client".into(),
            version: env!("CARGO_PKG_VERSION").into(),
        };
        ClientInfo {
            protocol_version: Default::default(),
            capabilities: Default::default(),
            client_info,
        }
    }

    /// Returns the stored client-side peer handle, if one has been set.
    fn get_peer(&self) -> Option<Peer<RoleClient>> {
        self.peer.clone()
    }

    /// Stores the client-side peer handle, replacing any previous one.
    fn set_peer(&mut self, peer: Peer<RoleClient>) {
        self.peer = Some(peer);
    }
}
/// Runs one item of upstream tool output through the filter engine.
///
/// Only `RawContent::Text` items are rewritten. Their text is replaced by
/// `filter.filter(tool_name, text)`, and the original annotations are kept.
/// Every other content kind is returned unchanged.
///
/// When `tracker` is `Some`, the raw/filtered pair is recorded together with
/// `preset`. A tracking failure is logged as a warning and never affects the
/// returned content. A debug event reports byte lengths and the percentage
/// saved.
fn filter_content(
    filter: &FilterEngine,
    tool_name: &str,
    content: Content,
    tracker: Option<&Arc<Tracker>>,
    preset: &str,
) -> Content {
    match &content.raw {
        RawContent::Text(text_content) => {
            let raw = &text_content.text;
            let filtered = filter.filter(tool_name, raw);
            if let Some(tracker) = tracker {
                // Best-effort: a tracking failure must not drop the tool response.
                if let Err(e) = tracker.track(tool_name, raw, &filtered, preset) {
                    tracing::warn!("Failed to track tool call: {e}");
                }
            }
            tracing::debug!(
                tool = tool_name,
                raw_len = raw.len(),
                filtered_len = filtered.len(),
                savings_pct = format!("{:.1}%", savings_percent(raw.len(), filtered.len())),
                "Filtered tool response"
            );
            Annotated {
                raw: RawContent::Text(RawTextContent { text: filtered }),
                annotations: content.annotations,
            }
        }
        _ => content,
    }
}

/// Percentage of bytes removed when `raw_len` bytes shrink to `filtered_len`.
///
/// The result is negative when filtering grew the output. For an empty input
/// the ratio is undefined and nothing could have been saved, so it reports
/// `0.0`. The previous code clamped the divisor to 1 and logged "100.0%" for
/// empty responses.
fn savings_percent(raw_len: usize, filtered_len: usize) -> f64 {
    if raw_len == 0 {
        return 0.0;
    }
    (1.0 - filtered_len as f64 / raw_len as f64) * 100.0
}