use std::sync::Arc;
use tokio::sync::broadcast;
use super::upstream::UpstreamConnection;
use crate::util::RefreshGuard;
use bitrouter_core::api::mcp::gateway::{
McpCompletionServer, McpLoggingServer, McpPromptServer, McpResourceServer,
McpSubscriptionServer, McpToolServer,
};
use bitrouter_core::api::mcp::types::McpGatewayError;
use bitrouter_core::api::mcp::types::{
CompleteParams, CompleteResult, Completion, LoggingLevel, McpGetPromptResult, McpPrompt,
McpResource, McpResourceContent, McpResourceTemplate, McpTool, McpToolCallResult,
};
/// Gateway adapter that exposes exactly one upstream MCP connection through the
/// gateway server traits.
///
/// Requests are delegated straight to the upstream. Change notifications from the
/// upstream are re-broadcast on channels owned by the bridge, and those channels
/// are the ones handed out by the `subscribe_*_changes` trait methods.
pub struct SingleServerBridge {
    /// The single upstream connection every request is delegated to.
    upstream: Arc<UpstreamConnection>,
    /// Bridge-owned sender for "tool list changed" notifications.
    tool_change_tx: broadcast::Sender<()>,
    /// Bridge-owned sender for "resource list changed" notifications.
    resource_change_tx: broadcast::Sender<()>,
    /// Bridge-owned sender for "prompt list changed" notifications.
    prompt_change_tx: broadcast::Sender<()>,
}
impl SingleServerBridge {
pub fn new(
upstream: Arc<UpstreamConnection>,
upstream_tool_rx: broadcast::Receiver<()>,
upstream_resource_rx: broadcast::Receiver<()>,
upstream_prompt_rx: broadcast::Receiver<()>,
) -> (Arc<Self>, RefreshGuard) {
let (tool_tx, _) = broadcast::channel(16);
let (resource_tx, _) = broadcast::channel(16);
let (prompt_tx, _) = broadcast::channel(16);
let bridge = Arc::new(Self {
upstream,
tool_change_tx: tool_tx,
resource_change_tx: resource_tx,
prompt_change_tx: prompt_tx,
});
let guard = bridge.spawn_forward_listeners(
upstream_tool_rx,
upstream_resource_rx,
upstream_prompt_rx,
);
(bridge, guard)
}
fn spawn_forward_listeners(
self: &Arc<Self>,
mut tool_rx: broadcast::Receiver<()>,
mut resource_rx: broadcast::Receiver<()>,
mut prompt_rx: broadcast::Receiver<()>,
) -> RefreshGuard {
let mut handles = Vec::new();
let tx = self.tool_change_tx.clone();
handles.push(tokio::spawn(async move {
loop {
match tool_rx.recv().await {
Ok(()) => {
let _ = tx.send(());
}
Err(broadcast::error::RecvError::Lagged(_)) => {}
Err(broadcast::error::RecvError::Closed) => break,
}
}
}));
let tx = self.resource_change_tx.clone();
handles.push(tokio::spawn(async move {
loop {
match resource_rx.recv().await {
Ok(()) => {
let _ = tx.send(());
}
Err(broadcast::error::RecvError::Lagged(_)) => {}
Err(broadcast::error::RecvError::Closed) => break,
}
}
}));
let tx = self.prompt_change_tx.clone();
handles.push(tokio::spawn(async move {
loop {
match prompt_rx.recv().await {
Ok(()) => {
let _ = tx.send(());
}
Err(broadcast::error::RecvError::Lagged(_)) => {}
Err(broadcast::error::RecvError::Closed) => break,
}
}
}));
RefreshGuard::from_handles(handles)
}
}
impl McpToolServer for SingleServerBridge {
    /// Hands out a fresh receiver on the bridge-owned tool-change channel.
    ///
    /// Notifications forwarded from the upstream arrive on this receiver. It only
    /// observes messages sent after it was created.
    fn subscribe_tool_changes(&self) -> broadcast::Receiver<()> {
        self.tool_change_tx.subscribe()
    }

    /// Returns whatever the upstream's `raw_tools` yields. The bridge applies no
    /// filtering or renaming of its own.
    async fn list_tools(&self) -> Vec<McpTool> {
        let upstream = &self.upstream;
        upstream.raw_tools().await
    }

    /// Passes the tool invocation to the upstream unchanged and returns its
    /// result or error as-is.
    async fn call_tool(
        &self,
        name: &str,
        arguments: Option<serde_json::Map<String, serde_json::Value>>,
    ) -> Result<McpToolCallResult, McpGatewayError> {
        let upstream = &self.upstream;
        upstream.call_tool(name, arguments).await
    }
}
impl McpResourceServer for SingleServerBridge {
    /// Hands out a fresh receiver on the bridge-owned resource-change channel.
    ///
    /// Only notifications sent after the receiver is created are observed.
    fn subscribe_resource_changes(&self) -> broadcast::Receiver<()> {
        self.resource_change_tx.subscribe()
    }

    /// Returns whatever the upstream's `raw_resources` yields, untouched.
    async fn list_resources(&self) -> Vec<McpResource> {
        let upstream = &self.upstream;
        upstream.raw_resources().await
    }

    /// Returns whatever the upstream's `raw_resource_templates` yields, untouched.
    async fn list_resource_templates(&self) -> Vec<McpResourceTemplate> {
        let upstream = &self.upstream;
        upstream.raw_resource_templates().await
    }

    /// Reads `uri` from the upstream and propagates its contents or error as-is.
    async fn read_resource(&self, uri: &str) -> Result<Vec<McpResourceContent>, McpGatewayError> {
        let upstream = &self.upstream;
        upstream.read_resource(uri).await
    }
}
impl McpPromptServer for SingleServerBridge {
    /// Hands out a fresh receiver on the bridge-owned prompt-change channel.
    ///
    /// Only notifications sent after the receiver is created are observed.
    fn subscribe_prompt_changes(&self) -> broadcast::Receiver<()> {
        self.prompt_change_tx.subscribe()
    }

    /// Returns whatever the upstream's `raw_prompts` yields, untouched.
    async fn list_prompts(&self) -> Vec<McpPrompt> {
        let upstream = &self.upstream;
        upstream.raw_prompts().await
    }

    /// Renders prompt `name` on the upstream with the given arguments and
    /// propagates the result or error as-is.
    async fn get_prompt(
        &self,
        name: &str,
        arguments: Option<std::collections::HashMap<String, String>>,
    ) -> Result<McpGetPromptResult, McpGatewayError> {
        let upstream = &self.upstream;
        upstream.get_prompt(name, arguments).await
    }
}
/// Resource subscription handling: every request is accepted and acknowledged.
///
/// NOTE(review): nothing is forwarded to the upstream, so per-URI update
/// notifications are never produced. Confirm this no-op is intended.
impl McpSubscriptionServer for SingleServerBridge {
    /// Accepts an unsubscribe request for any URI without side effects.
    async fn unsubscribe_resource(&self, _uri: &str) -> Result<(), McpGatewayError> {
        Ok(())
    }

    /// Accepts a subscribe request for any URI without side effects.
    async fn subscribe_resource(&self, _uri: &str) -> Result<(), McpGatewayError> {
        Ok(())
    }
}
/// Logging-level control for the bridge.
impl McpLoggingServer for SingleServerBridge {
    /// Acknowledges a logging-level change without acting on it.
    ///
    /// NOTE(review): the requested level is discarded and not passed to the
    /// upstream. Confirm this is intended.
    async fn set_logging_level(&self, _level: LoggingLevel) -> Result<(), McpGatewayError> {
        Ok(())
    }
}
/// Argument completion for the bridge.
impl McpCompletionServer for SingleServerBridge {
    /// Always answers with an empty completion. There are no values, and both
    /// `has_more` and `total` are left unset. The request parameters are ignored.
    async fn complete(&self, _params: CompleteParams) -> Result<CompleteResult, McpGatewayError> {
        let completion = Completion {
            values: vec![],
            has_more: None,
            total: None,
        };
        Ok(CompleteResult { completion })
    }
}