use std::sync::Mutex;
use std::time::Duration;
use anyhow::Result;
use rmcp::model::{
CallToolRequestParams, CallToolResult, CancelledNotificationParam, ClientInfo, ClientRequest,
CompleteRequestParams, CompleteResult, CreateElicitationRequestParams, CreateElicitationResult,
CreateMessageRequestParams, CreateMessageResult, ElicitationResponseNotificationParam,
ErrorCode, ErrorData, GetPromptRequestParams, GetPromptResult, Implementation,
InitializeRequestParams, InitializeResult, ListPromptsResult, ListResourceTemplatesResult,
ListResourcesResult, ListRootsResult, ListToolsResult, LoggingMessageNotificationParam,
PaginatedRequestParams, PingRequest, ProgressNotificationParam, ReadResourceRequestParams,
ReadResourceResult, ResourceUpdatedNotificationParam, ServerCapabilities,
SetLevelRequestParams, SubscribeRequestParams, UnsubscribeRequestParams,
};
use rmcp::service::{
NotificationContext, Peer, RequestContext, RoleClient, RoleServer, RunningService, ServiceExt,
};
use rmcp::{ClientHandler, ServerHandler};
use tokio::sync::OnceCell;
use tokio::time::MissedTickBehavior;
use tokio_util::sync::CancellationToken;
use crate::filter::ToolFilter;
use crate::transport::RemoteTransport;
/// Keepalive settings for the proxy's connection to the remote MCP server.
///
/// When `interval` is `Some`, a background task pings the remote once per
/// `interval` and waits at most `timeout` for each reply; `None` disables
/// pinging entirely.
#[derive(Debug, Clone, Copy)]
pub struct KeepaliveConfig {
    /// Time between pings; `None` turns keepalive off.
    pub interval: Option<Duration>,
    /// Longest a single ping may take before it is reported as timed out.
    pub timeout: Duration,
}

impl KeepaliveConfig {
    /// Builds a config from whole seconds.
    ///
    /// An `interval_secs` of `0` disables keepalive. A `timeout_secs` of `0`
    /// is clamped up to one second.
    pub fn from_secs(interval_secs: u64, timeout_secs: u64) -> Self {
        let interval = match interval_secs {
            0 => None,
            secs => Some(Duration::from_secs(secs)),
        };
        let timeout = Duration::from_secs(std::cmp::max(timeout_secs, 1));
        Self { interval, timeout }
    }

    /// A config with pinging switched off (test-only helper).
    #[cfg(test)]
    pub const fn disabled() -> Self {
        Self {
            interval: None,
            timeout: Duration::from_secs(10),
        }
    }
}
/// Client-side handler for the proxy's session with the remote MCP server.
///
/// Requests and notifications that the remote sends to the proxy are relayed
/// to the local stdio client through `upstream`.
pub struct RemoteClientHandler {
    /// Peer for the local stdio client; the proxy plays the server role on that link.
    upstream: Peer<RoleServer>,
    /// Client info announced to the remote; returned verbatim by `get_info`.
    proxied_info: ClientInfo,
}
impl ClientHandler for RemoteClientHandler {
    /// Returns the client info presented to the remote.
    fn get_info(&self) -> ClientInfo {
        self.proxied_info.clone()
    }
    /// Relays the remote's `create_message` request to the local client and
    /// returns its answer; local failures are wrapped with `upstream_error`.
    #[tracing::instrument(skip_all, fields(dir = "remote←local"))]
    async fn create_message(
        &self,
        params: CreateMessageRequestParams,
        _ctx: RequestContext<RoleClient>,
    ) -> Result<CreateMessageResult, ErrorData> {
        self.upstream
            .create_message(params)
            .await
            .map_err(upstream_error)
    }
    /// Relays the remote's `list_roots` request to the local client.
    #[tracing::instrument(skip_all, fields(dir = "remote←local"))]
    async fn list_roots(
        &self,
        _ctx: RequestContext<RoleClient>,
    ) -> Result<ListRootsResult, ErrorData> {
        self.upstream.list_roots().await.map_err(upstream_error)
    }
    /// Relays the remote's elicitation request to the local client.
    #[tracing::instrument(skip_all, fields(dir = "remote←local"))]
    async fn create_elicitation(
        &self,
        request: CreateElicitationRequestParams,
        _ctx: RequestContext<RoleClient>,
    ) -> Result<CreateElicitationResult, ErrorData> {
        self.upstream
            .create_elicitation(request)
            .await
            .map_err(upstream_error)
    }
    /// Forwards a cancellation from the remote to the local client.
    /// Forwarding is best-effort: failures are only logged.
    // NOTE(review): `params` is forwarded verbatim, but the request id it names
    // was assigned on the remote→proxy link, while the proxy's relayed request
    // to the local client gets its id from `upstream`. Unless those id spaces
    // happen to line up, this may cancel the wrong local request (or none).
    // Confirm, and map ids if so.
    #[tracing::instrument(skip_all, fields(dir = "remote→local"))]
    async fn on_cancelled(
        &self,
        params: CancelledNotificationParam,
        _ctx: NotificationContext<RoleClient>,
    ) {
        if let Err(e) = self.upstream.notify_cancelled(params).await {
            tracing::warn!(error = %e, "failed to forward cancelled to stdio client");
        }
    }
    /// Forwards a progress notification from the remote to the local client
    /// (best-effort; failures are logged).
    #[tracing::instrument(level = "debug", skip_all, fields(dir = "remote→local"))]
    async fn on_progress(
        &self,
        params: ProgressNotificationParam,
        _ctx: NotificationContext<RoleClient>,
    ) {
        if let Err(e) = self.upstream.notify_progress(params).await {
            tracing::warn!(error = %e, "failed to forward progress to stdio client");
        }
    }
    /// Forwards a log message from the remote to the local client (best-effort).
    #[tracing::instrument(level = "debug", skip_all, fields(dir = "remote→local"))]
    async fn on_logging_message(
        &self,
        params: LoggingMessageNotificationParam,
        _ctx: NotificationContext<RoleClient>,
    ) {
        if let Err(e) = self.upstream.notify_logging_message(params).await {
            tracing::warn!(error = %e, "failed to forward logging message to stdio client");
        }
    }
    /// Forwards a resource-updated notification (best-effort); the URI is
    /// recorded on the tracing span.
    #[tracing::instrument(skip_all, fields(dir = "remote→local", uri = %params.uri))]
    async fn on_resource_updated(
        &self,
        params: ResourceUpdatedNotificationParam,
        _ctx: NotificationContext<RoleClient>,
    ) {
        if let Err(e) = self.upstream.notify_resource_updated(params).await {
            tracing::warn!(error = %e, "failed to forward resource_updated to stdio client");
        }
    }
    /// Forwards the remote's resource-list-changed notification (best-effort).
    #[tracing::instrument(skip_all, fields(dir = "remote→local"))]
    async fn on_resource_list_changed(&self, _ctx: NotificationContext<RoleClient>) {
        if let Err(e) = self.upstream.notify_resource_list_changed().await {
            tracing::warn!(error = %e, "failed to forward resource_list_changed to stdio client");
        }
    }
    /// Forwards the remote's tool-list-changed notification (best-effort).
    /// The local client will re-list tools, and that call passes through the
    /// proxy's tool filter.
    #[tracing::instrument(skip_all, fields(dir = "remote→local"))]
    async fn on_tool_list_changed(&self, _ctx: NotificationContext<RoleClient>) {
        if let Err(e) = self.upstream.notify_tool_list_changed().await {
            tracing::warn!(error = %e, "failed to forward tool_list_changed to stdio client");
        }
    }
    /// Forwards the remote's prompt-list-changed notification (best-effort).
    #[tracing::instrument(skip_all, fields(dir = "remote→local"))]
    async fn on_prompt_list_changed(&self, _ctx: NotificationContext<RoleClient>) {
        if let Err(e) = self.upstream.notify_prompt_list_changed().await {
            tracing::warn!(error = %e, "failed to forward prompt_list_changed to stdio client");
        }
    }
    /// Forwards the remote's URL-elicitation completion notification (best-effort).
    #[tracing::instrument(skip_all, fields(dir = "remote→local"))]
    async fn on_url_elicitation_notification_complete(
        &self,
        params: ElicitationResponseNotificationParam,
        _ctx: NotificationContext<RoleClient>,
    ) {
        if let Err(e) = self.upstream.notify_url_elicitation_completed(params).await {
            tracing::warn!(error = %e, "failed to forward elicitation completion to stdio client");
        }
    }
}
/// Live state of the proxy's session with the remote MCP server.
///
/// Built once by `ProxyHandler::connect` and stored in the handler's
/// `OnceCell`, so it lives exactly as long as the `ProxyHandler`.
struct Remote {
    /// Client-side peer used to forward local requests to the remote.
    peer: Peer<RoleClient>,
    /// The remote's initialize result, with `server_info` rewritten to name the proxy.
    init_result: InitializeResult,
    /// Supervisor task: waits for the remote service to end, then cancels the keepalive.
    _service_handle: tokio::task::JoinHandle<()>,
    /// Keepalive pinger task, present only when keepalive is enabled.
    _keepalive_handle: Option<tokio::task::JoinHandle<()>>,
    /// Token that stops the keepalive pinger; cancelled when `Remote` is dropped.
    _keepalive_cancel: CancellationToken,
}

impl Drop for Remote {
    /// Stops the keepalive pinger together with the session state.
    ///
    /// Dropping a `JoinHandle` only detaches its task, and a
    /// `CancellationToken` is not cancelled by being dropped. Without this,
    /// the pinger would keep running after the owning `ProxyHandler` is gone,
    /// until the remote service itself ends.
    fn drop(&mut self) {
        self._keepalive_cancel.cancel();
    }
}
pub struct ProxyHandler {
transport: Mutex<Option<RemoteTransport>>,
remote: OnceCell<Remote>,
keepalive: KeepaliveConfig,
tool_filter: ToolFilter,
}
impl ProxyHandler {
pub fn new(
transport: RemoteTransport,
keepalive: KeepaliveConfig,
tool_filter: ToolFilter,
) -> Self {
Self {
transport: Mutex::new(Some(transport)),
remote: OnceCell::new(),
keepalive,
tool_filter,
}
}
fn peer(&self) -> Result<&Peer<RoleClient>, ErrorData> {
self.remote
.get()
.map(|r| &r.peer)
.ok_or_else(|| internal_error("proxy session not yet initialized"))
}
#[tracing::instrument(skip_all)]
async fn connect(
&self,
upstream: Peer<RoleServer>,
local_init: InitializeRequestParams,
) -> Result<&Remote, ErrorData> {
self.remote
.get_or_try_init(|| async {
let transport = {
let mut guard = self
.transport
.lock()
.map_err(|_| internal_error("transport mutex poisoned"))?;
guard
.take()
.ok_or_else(|| internal_error("remote transport already consumed"))?
};
let handler = RemoteClientHandler {
upstream,
proxied_info: build_proxied_client_info(local_init),
};
let running: RunningService<RoleClient, RemoteClientHandler> = match transport {
RemoteTransport::Anonymous(t) => {
handler.serve(t).await.map_err(remote_error)?
}
RemoteTransport::Authorized(t) => {
handler.serve(t).await.map_err(remote_error)?
}
};
let peer = running.peer().clone();
let mut init_result = running
.peer_info()
.cloned()
.unwrap_or_else(|| InitializeResult::new(ServerCapabilities::default()));
init_result.server_info =
Implementation::new("hyper-mcp-remote", env!("CARGO_PKG_VERSION"));
let keepalive_cancel = CancellationToken::new();
let keepalive_handle =
spawn_keepalive(peer.clone(), self.keepalive, keepalive_cancel.clone());
let supervisor_cancel = keepalive_cancel.clone();
let handle = tokio::spawn(async move {
let result = running.waiting().await;
supervisor_cancel.cancel();
match result {
Ok(reason) => {
tracing::info!(?reason, "remote MCP service ended");
}
Err(e) => {
tracing::warn!(error = %e, "remote MCP service ended with error");
}
}
});
Ok::<_, ErrorData>(Remote {
peer,
init_result,
_service_handle: handle,
_keepalive_handle: keepalive_handle,
_keepalive_cancel: keepalive_cancel,
})
})
.await
}
}
/// Local-facing MCP server. Requests and notifications from the stdio client
/// are forwarded to the remote peer that `initialize` establishes.
///
/// Forwarded requests fail with an internal error until `initialize` has
/// completed. Forwarded notifications sent before that are silently dropped.
impl ServerHandler for ProxyHandler {
    /// Opens the remote session (first call only) and answers with the
    /// remote's initialize result, whose `server_info` names the proxy.
    #[tracing::instrument(
        skip_all,
        fields(
            dir = "local→remote",
            client = %request.client_info.name,
            client_version = %request.client_info.version,
            protocol = ?request.protocol_version,
        )
    )]
    async fn initialize(
        &self,
        request: InitializeRequestParams,
        context: RequestContext<RoleServer>,
    ) -> Result<InitializeResult, ErrorData> {
        // The local client's peer handle becomes the `upstream` for remote→local relaying.
        let remote = self.connect(context.peer.clone(), request).await?;
        Ok(remote.init_result.clone())
    }
    /// Forwards `list_tools` and drops tools the filter does not permit.
    ///
    /// Filtering is applied per page and only touches `tools`. The other
    /// fields of the remote's result are returned unchanged, so a page may
    /// come back short or even empty.
    #[tracing::instrument(skip_all, fields(dir = "local→remote"))]
    async fn list_tools(
        &self,
        request: Option<PaginatedRequestParams>,
        _ctx: RequestContext<RoleServer>,
    ) -> Result<ListToolsResult, ErrorData> {
        let mut result = self
            .peer()?
            .list_tools(request)
            .await
            .map_err(remote_error)?;
        // Skip the scan entirely when the filter admits everything.
        if !self.tool_filter.is_noop() {
            let before = result.tools.len();
            result
                .tools
                .retain(|t| self.tool_filter.permits(t.name.as_ref()));
            let dropped = before - result.tools.len();
            if dropped > 0 {
                tracing::debug!(
                    dropped,
                    kept = result.tools.len(),
                    "tool filter applied to list_tools page"
                );
            }
        }
        Ok(result)
    }
    /// Forwards `call_tool`. Filtered tools are rejected locally, so the
    /// remote never receives them.
    #[tracing::instrument(skip_all, fields(dir = "local→remote", tool = %request.name))]
    async fn call_tool(
        &self,
        request: CallToolRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> Result<CallToolResult, ErrorData> {
        enforce_tool_filter(&self.tool_filter, &request.name)?;
        self.peer()?.call_tool(request).await.map_err(remote_error)
    }
    /// Forwards `list_resources` unchanged.
    #[tracing::instrument(skip_all, fields(dir = "local→remote"))]
    async fn list_resources(
        &self,
        request: Option<PaginatedRequestParams>,
        _ctx: RequestContext<RoleServer>,
    ) -> Result<ListResourcesResult, ErrorData> {
        self.peer()?
            .list_resources(request)
            .await
            .map_err(remote_error)
    }
    /// Forwards `list_resource_templates` unchanged.
    #[tracing::instrument(skip_all, fields(dir = "local→remote"))]
    async fn list_resource_templates(
        &self,
        request: Option<PaginatedRequestParams>,
        _ctx: RequestContext<RoleServer>,
    ) -> Result<ListResourceTemplatesResult, ErrorData> {
        self.peer()?
            .list_resource_templates(request)
            .await
            .map_err(remote_error)
    }
    /// Forwards `read_resource` unchanged.
    #[tracing::instrument(skip_all, fields(dir = "local→remote", uri = %request.uri))]
    async fn read_resource(
        &self,
        request: ReadResourceRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> Result<ReadResourceResult, ErrorData> {
        self.peer()?
            .read_resource(request)
            .await
            .map_err(remote_error)
    }
    /// Forwards a resource subscription to the remote.
    #[tracing::instrument(skip_all, fields(dir = "local→remote", uri = %request.uri))]
    async fn subscribe(
        &self,
        request: SubscribeRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> Result<(), ErrorData> {
        self.peer()?.subscribe(request).await.map_err(remote_error)
    }
    /// Forwards a resource unsubscription to the remote.
    #[tracing::instrument(skip_all, fields(dir = "local→remote", uri = %request.uri))]
    async fn unsubscribe(
        &self,
        request: UnsubscribeRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> Result<(), ErrorData> {
        self.peer()?
            .unsubscribe(request)
            .await
            .map_err(remote_error)
    }
    /// Forwards `list_prompts` unchanged.
    #[tracing::instrument(skip_all, fields(dir = "local→remote"))]
    async fn list_prompts(
        &self,
        request: Option<PaginatedRequestParams>,
        _ctx: RequestContext<RoleServer>,
    ) -> Result<ListPromptsResult, ErrorData> {
        self.peer()?
            .list_prompts(request)
            .await
            .map_err(remote_error)
    }
    /// Forwards `get_prompt` unchanged.
    #[tracing::instrument(skip_all, fields(dir = "local→remote", prompt = %request.name))]
    async fn get_prompt(
        &self,
        request: GetPromptRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> Result<GetPromptResult, ErrorData> {
        self.peer()?.get_prompt(request).await.map_err(remote_error)
    }
    /// Forwards a completion request unchanged.
    #[tracing::instrument(skip_all, fields(dir = "local→remote"))]
    async fn complete(
        &self,
        request: CompleteRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> Result<CompleteResult, ErrorData> {
        self.peer()?.complete(request).await.map_err(remote_error)
    }
    /// Forwards the client's requested logging level to the remote.
    #[tracing::instrument(skip_all, fields(dir = "local→remote", level = ?request.level))]
    async fn set_level(
        &self,
        request: SetLevelRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> Result<(), ErrorData> {
        self.peer()?.set_level(request).await.map_err(remote_error)
    }
    /// Forwards a local cancellation to the remote (best-effort; dropped
    /// before initialization, failures logged).
    // NOTE(review): the request id in `notification` comes from the local
    // client's id space, while forwarded requests get fresh ids from the
    // remote `Peer`. Keepalive pings on the same peer presumably also consume
    // ids. So this may cancel an unrelated remote request. Confirm, and
    // consider cancelling via the forwarded request's handle instead.
    #[tracing::instrument(skip_all, fields(dir = "local→remote"))]
    async fn on_cancelled(
        &self,
        notification: CancelledNotificationParam,
        _ctx: NotificationContext<RoleServer>,
    ) {
        if let Ok(peer) = self.peer()
            && let Err(e) = peer.notify_cancelled(notification).await
        {
            tracing::warn!(error = %e, "failed to forward cancellation to remote");
        }
    }
    /// Forwards a local progress notification to the remote (best-effort).
    #[tracing::instrument(level = "debug", skip_all, fields(dir = "local→remote"))]
    async fn on_progress(
        &self,
        notification: ProgressNotificationParam,
        _ctx: NotificationContext<RoleServer>,
    ) {
        if let Ok(peer) = self.peer()
            && let Err(e) = peer.notify_progress(notification).await
        {
            tracing::warn!(error = %e, "failed to forward progress to remote");
        }
    }
    /// Only logged, not forwarded. The remote session's handshake is
    /// presumably completed by `serve` inside `connect` (confirm against rmcp).
    #[tracing::instrument(skip_all, fields(dir = "local→remote"))]
    async fn on_initialized(&self, _ctx: NotificationContext<RoleServer>) {
        tracing::debug!("local client sent initialized");
    }
    /// Forwards the local client's roots-list-changed notification (best-effort).
    #[tracing::instrument(skip_all, fields(dir = "local→remote"))]
    async fn on_roots_list_changed(&self, _ctx: NotificationContext<RoleServer>) {
        if let Ok(peer) = self.peer()
            && let Err(e) = peer.notify_roots_list_changed().await
        {
            tracing::warn!(error = %e, "failed to forward roots_list_changed to remote");
        }
    }
}
/// Derives the client info the proxy announces to the remote from the local
/// client's `initialize` parameters.
///
/// Capabilities, protocol version and every other field are kept as sent.
/// Only the client name gains a ` (via hyper-mcp-remote <version>)` suffix.
fn build_proxied_client_info(local: InitializeRequestParams) -> ClientInfo {
    const PROXY_TAG: &str = concat!(" (via hyper-mcp-remote ", env!("CARGO_PKG_VERSION"), ")");
    let mut proxied = local;
    proxied.client_info.name += PROXY_TAG;
    proxied
}
/// Spawns the background task that periodically pings the remote MCP server.
///
/// Returns `None`, and spawns nothing, when keepalive is disabled: either
/// `config.interval` is `None`, or it is zero. `tokio::time::interval` panics
/// on a zero period, and the fields are public, so a zero can bypass
/// `KeepaliveConfig::from_secs`. Otherwise the task pings every `interval`
/// until `cancel` fires, and each ping is bounded by `config.timeout`.
///
/// Cancellation is checked first on every iteration. It also interrupts both
/// the wait for the next tick and an in-flight ping, so the task exits
/// promptly. It does not finish a ping that would likely fail only because
/// the session is shutting down.
fn spawn_keepalive(
    peer: Peer<RoleClient>,
    config: KeepaliveConfig,
    cancel: CancellationToken,
) -> Option<tokio::task::JoinHandle<()>> {
    let interval = config.interval.filter(|period| !period.is_zero())?;
    let timeout = config.timeout;
    Some(tokio::spawn(async move {
        tracing::debug!(?interval, ?timeout, "starting remote MCP keepalive pinger");
        let mut ticker = tokio::time::interval(interval);
        // After a slow ping, wait a full interval rather than bursting to catch up.
        ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
        // The first tick completes immediately; consume it so the first ping
        // goes out one interval after startup.
        ticker.tick().await;
        loop {
            tokio::select! {
                // Prefer cancellation when both branches are ready.
                biased;
                _ = cancel.cancelled() => {
                    tracing::debug!("keepalive cancelled; exiting");
                    return;
                }
                // Tick and ping race cancellation as one unit; dropping this
                // future mid-ping simply abandons that ping.
                _ = async {
                    ticker.tick().await;
                    ping_remote(&peer, timeout).await;
                } => {}
            }
        }
    }))
}
#[tracing::instrument(level = "debug", skip_all, fields(dir = "local→remote"))]
async fn ping_remote(peer: &Peer<RoleClient>, timeout: Duration) {
let request = ClientRequest::PingRequest(PingRequest::default());
match tokio::time::timeout(timeout, peer.send_request(request)).await {
Ok(Ok(_)) => {
tracing::debug!("remote ping ok");
}
Ok(Err(e)) => {
tracing::warn!(error = %e, "remote ping failed");
}
Err(_) => {
tracing::warn!(?timeout, "remote ping timed out");
}
}
}
/// Admits `name` if `filter` permits it. Otherwise logs the refusal and
/// returns a `METHOD_NOT_FOUND` error naming the tool.
///
/// `METHOD_NOT_FOUND` is used so that a filtered tool looks nonexistent to
/// the client rather than forbidden.
fn enforce_tool_filter(filter: &ToolFilter, name: &str) -> Result<(), ErrorData> {
    if !filter.permits(name) {
        tracing::info!(tool = %name, "refusing filtered tool call");
        return Err(ErrorData::new(
            ErrorCode::METHOD_NOT_FOUND,
            format!("tool {name:?} is not available on this proxy"),
            None,
        ));
    }
    Ok(())
}
/// Wraps a proxy-internal failure message in a JSON-RPC `INTERNAL_ERROR`.
fn internal_error(msg: impl Into<String>) -> ErrorData {
    let message: String = msg.into();
    ErrorData::new(ErrorCode::INTERNAL_ERROR, message, None)
}

/// Reports a failure talking to the remote MCP server as `INTERNAL_ERROR`.
///
/// Only the `Display` text of `e` is kept, prefixed with
/// `remote MCP server error: `. Any structured error code it carried is
/// replaced by `INTERNAL_ERROR`.
fn remote_error(e: impl std::fmt::Display) -> ErrorData {
    let message = format!("remote MCP server error: {e}");
    ErrorData::new(ErrorCode::INTERNAL_ERROR, message, None)
}

/// Reports a failure talking to the local stdio client as `INTERNAL_ERROR`.
///
/// Only the `Display` text of `e` is kept, prefixed with
/// `local stdio client error: `.
fn upstream_error(e: impl std::fmt::Display) -> ErrorData {
    let message = format!("local stdio client error: {e}");
    ErrorData::new(ErrorCode::INTERNAL_ERROR, message, None)
}
// Unit tests for the pure helpers in this module. None of these tests opens
// a transport or spawns a task.
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::ProtocolVersion;
    // The remote must see the local client's capabilities and protocol version
    // unchanged; only the name is tagged with the proxy identity.
    #[test]
    fn proxied_client_info_preserves_capabilities_and_tags_name() {
        let mut input = InitializeRequestParams::default();
        input.client_info.name = "Test Client".to_string();
        input.client_info.version = "9.9.9".to_string();
        input.protocol_version = ProtocolVersion::default();
        let out = build_proxied_client_info(input.clone());
        assert_eq!(out.client_info.version, "9.9.9");
        assert!(
            out.client_info.name.starts_with("Test Client"),
            "name should keep the original prefix; got {:?}",
            out.client_info.name
        );
        assert!(
            out.client_info.name.contains("hyper-mcp-remote"),
            "name should be suffixed with proxy identity; got {:?}",
            out.client_info.name
        );
        assert_eq!(
            out.capabilities, input.capabilities,
            "client capabilities must be forwarded verbatim"
        );
        assert_eq!(out.protocol_version, input.protocol_version);
    }
    // All three error helpers produce INTERNAL_ERROR.
    #[test]
    fn helper_error_codes_are_internal_error() {
        let e1 = internal_error("x");
        let e2 = remote_error("y");
        let e3 = upstream_error("z");
        assert_eq!(e1.code, ErrorCode::INTERNAL_ERROR);
        assert_eq!(e2.code, ErrorCode::INTERNAL_ERROR);
        assert_eq!(e3.code, ErrorCode::INTERNAL_ERROR);
    }
    // The message prefix tells which side of the proxy failed.
    #[test]
    fn remote_and_upstream_have_distinct_prefixes() {
        assert!(remote_error("boom").message.contains("remote MCP server"));
        assert!(
            upstream_error("boom")
                .message
                .contains("local stdio client")
        );
    }
    #[test]
    fn keepalive_config_from_secs_disables_when_interval_is_zero() {
        let cfg = KeepaliveConfig::from_secs(0, 10);
        assert!(
            cfg.interval.is_none(),
            "interval 0 must map to disabled keepalive"
        );
    }
    #[test]
    fn keepalive_config_from_secs_uses_provided_values() {
        let cfg = KeepaliveConfig::from_secs(30, 7);
        assert_eq!(cfg.interval, Some(Duration::from_secs(30)));
        assert_eq!(cfg.timeout, Duration::from_secs(7));
    }
    #[test]
    fn keepalive_config_from_secs_clamps_zero_timeout_to_one() {
        let cfg = KeepaliveConfig::from_secs(30, 0);
        assert_eq!(cfg.timeout, Duration::from_secs(1));
    }
    #[test]
    fn enforce_tool_filter_allows_permitted_name() {
        let filter = ToolFilter::allow_all();
        enforce_tool_filter(&filter, "anything").expect("allow-all must permit");
    }
    // A deny-listed tool is rejected with METHOD_NOT_FOUND and named in the
    // message, while an unrelated name still passes.
    #[test]
    fn enforce_tool_filter_rejects_denied_name_with_method_not_found() {
        let filter = ToolFilter::from_cli(&[], &["read_secrets".to_string()]).expect("build");
        let err = enforce_tool_filter(&filter, "read_secrets")
            .expect_err("denied tool must produce an error");
        assert_eq!(
            err.code,
            ErrorCode::METHOD_NOT_FOUND,
            "filtered tools must surface as METHOD_NOT_FOUND so clients treat them as nonexistent"
        );
        assert!(
            err.message.contains("read_secrets"),
            "error must name the offending tool; got {:?}",
            err.message
        );
        enforce_tool_filter(&filter, "read_file")
            .expect("non-matching name must still be admitted");
    }
    // NOTE(review): despite its name, this only checks the config value;
    // `spawn_keepalive` is never called here (it needs a live `Peer`).
    #[test]
    fn keepalive_disabled_returns_no_task() {
        let cfg = KeepaliveConfig::disabled();
        assert!(cfg.interval.is_none());
    }
}