use std::collections::HashMap;
use std::time::Duration;
use eventsource_stream::Eventsource;
use futures::{SinkExt, StreamExt};
use reqwest::{RequestBuilder, Url};
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
use tokio::process::{Child, ChildStdin, Command};
use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
use tokio::task::JoinHandle;
use tokio_stream::wrappers::ReceiverStream;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use tonic::metadata::{MetadataKey, MetadataValue};
use tonic::transport::Endpoint;
use tonic::Request;
use crate::grpc_proto::bridge::mcp_bridge_client::McpBridgeClient;
use crate::grpc_proto::bridge::Envelope;
use crate::tool_api::error::ToolCallError;
/// Default connect timeout applied by `TransportOptions::default`.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(1_500);
/// Default request timeout; the stdio, SSE, WebSocket and gRPC transports also wait this long
/// for a reply when no request timeout is configured.
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// WebSocket connection type produced by `connect_async`.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Outgoing half of a split `WsStream`.
type WsWriter = futures::stream::SplitSink<WsStream, Message>;
/// Wire transport used to exchange JSON-RPC messages with a tool server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Transport {
/// Child process exchanging newline-delimited JSON over stdin/stdout.
Stdio,
/// One HTTP POST per message; replies come back in the POST response.
StreamableHttp,
/// Replies arrive on a Server-Sent Events stream; messages are POSTed to a server-announced URL.
Sse,
/// A single WebSocket connection; messages are sent as text frames.
Ws,
/// A bidirectional gRPC stream of `Envelope`s through the `McpBridge` service.
Grpc,
}
/// Settings shared by every transport; each transport reads only the fields it needs.
#[derive(Debug, Clone)]
pub(crate) struct TransportOptions {
/// Server URL for the network transports; unused by stdio.
pub(crate) endpoint: String,
/// Extra headers (gRPC: metadata) sent by the network transports; unused by stdio.
pub(crate) headers: HashMap<String, String>,
/// Connection-establishment limit; SSE also uses it as the wait for the `endpoint` event.
pub(crate) connect_timeout: Duration,
/// Per-request limit. `None` leaves HTTP requests without a client timeout, while the
/// stdio/SSE/WebSocket/gRPC reply waits fall back to `DEFAULT_REQUEST_TIMEOUT`.
pub(crate) request_timeout: Option<Duration>,
/// Program launched by `Transport::Stdio`; required (and must be non-blank) for that transport.
pub(crate) stdio_command: Option<String>,
/// Arguments passed to the stdio program.
pub(crate) stdio_args: Vec<String>,
/// Extra environment variables for the stdio program.
pub(crate) stdio_env: HashMap<String, String>,
/// Working directory for the stdio program; blank values are ignored.
pub(crate) stdio_cwd: Option<String>,
}
impl Default for TransportOptions {
fn default() -> Self {
Self {
endpoint: String::new(),
headers: HashMap::new(),
connect_timeout: DEFAULT_CONNECT_TIMEOUT,
request_timeout: Some(DEFAULT_REQUEST_TIMEOUT),
stdio_command: None,
stdio_args: Vec::new(),
stdio_env: HashMap::new(),
stdio_cwd: None,
}
}
}
/// A connected-on-demand client for one of the supported [`Transport`] kinds.
pub(crate) enum TransportClient {
    Stdio(StdioTransport),
    StreamableHttp(StreamableHttpTransport),
    Sse(SseTransport),
    Ws(WsTransport),
    Grpc(GrpcTransport),
}
impl TransportClient {
    /// Creates the client matching `transport` from `options`.
    ///
    /// # Errors
    /// Propagates validation errors from the stdio, Streamable HTTP and SSE constructors;
    /// building a WebSocket or gRPC client cannot fail at this point.
    pub(crate) fn new(
        transport: Transport,
        options: TransportOptions,
    ) -> Result<Self, ToolCallError> {
        let client = match transport {
            Transport::Stdio => StdioTransport::new(options).map(Self::Stdio)?,
            Transport::StreamableHttp => {
                StreamableHttpTransport::new(options).map(Self::StreamableHttp)?
            }
            Transport::Sse => SseTransport::new(options).map(Self::Sse)?,
            Transport::Ws => Self::Ws(WsTransport::new(options)),
            Transport::Grpc => Self::Grpc(GrpcTransport::new(options)),
        };
        Ok(client)
    }

    /// Sends a JSON-RPC request through the active transport and returns its reply.
    pub(crate) async fn send_request(&mut self, request: &Value) -> Result<Value, ToolCallError> {
        match self {
            Self::Stdio(transport) => transport.send_request(request).await,
            Self::StreamableHttp(transport) => transport.send_request(request).await,
            Self::Sse(transport) => transport.send_request(request).await,
            Self::Ws(transport) => transport.send_request(request).await,
            Self::Grpc(transport) => transport.send_request(request).await,
        }
    }

    /// Sends a JSON-RPC notification (no reply expected) through the active transport.
    pub(crate) async fn send_notification(
        &mut self,
        notification: &Value,
    ) -> Result<(), ToolCallError> {
        match self {
            Self::Stdio(transport) => transport.send_notification(notification).await,
            Self::StreamableHttp(transport) => transport.send_notification(notification).await,
            Self::Sse(transport) => transport.send_notification(notification).await,
            Self::Ws(transport) => transport.send_notification(notification).await,
            Self::Grpc(transport) => transport.send_notification(notification).await,
        }
    }
}
/// JSON-RPC over a child process's stdin/stdout, one JSON message per line.
pub(crate) struct StdioTransport {
/// Program to spawn (trimmed and validated non-empty by `new`).
command: String,
/// Arguments passed to `command`.
args: Vec<String>,
/// Extra environment variables for the child.
env: HashMap<String, String>,
/// Working directory for the child; `None` inherits the parent's.
cwd: Option<String>,
/// Reply wait limit; `None` falls back to `DEFAULT_REQUEST_TIMEOUT`.
request_timeout: Option<Duration>,
/// Pipe to the running child's stdin; `None` until connected.
stdin: Option<ChildStdin>,
/// Handle to the running child; `None` until connected.
child: Option<Child>,
/// In-flight requests keyed by `id_key_from_envelope`, completed by the stdout reader task.
pending: std::sync::Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
/// Task reading the child's stdout and routing replies into `pending`.
reader_task: Option<JoinHandle<()>>,
}
impl StdioTransport {
fn new(options: TransportOptions) -> Result<Self, ToolCallError> {
let command = options
.stdio_command
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
.map(str::to_string)
.ok_or_else(|| {
ToolCallError::InvalidArguments("stdio command cannot be empty".to_string())
})?;
let cwd = options.stdio_cwd.and_then(|value| {
let trimmed = value.trim();
if trimmed.is_empty() {
None
} else {
Some(trimmed.to_string())
}
});
Ok(Self {
command,
args: options.stdio_args,
env: options.stdio_env,
cwd,
request_timeout: options.request_timeout,
stdin: None,
child: None,
pending: std::sync::Arc::new(Mutex::new(HashMap::new())),
reader_task: None,
})
}
async fn send_request(&mut self, request: &Value) -> Result<Value, ToolCallError> {
self.ensure_connected().await?;
let request_id = id_key_from_envelope(request).ok_or_else(|| {
ToolCallError::Protocol("JSON-RPC request is missing an id".to_string())
})?;
let (tx, rx) = oneshot::channel();
self.pending.lock().await.insert(request_id.clone(), tx);
if let Err(err) = self.send_line(request).await {
self.pending.lock().await.remove(&request_id);
return Err(err);
}
let timeout = self.request_timeout.unwrap_or(DEFAULT_REQUEST_TIMEOUT);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(payload)) => Ok(payload),
Ok(Err(_)) => Err(ToolCallError::Transport(
"stdio child closed while waiting for response".to_string(),
)),
Err(_) => {
self.pending.lock().await.remove(&request_id);
Err(ToolCallError::Transport(format!(
"Timed out waiting for stdio response after {}ms",
timeout.as_millis()
)))
}
}
}
async fn send_notification(&mut self, notification: &Value) -> Result<(), ToolCallError> {
self.ensure_connected().await?;
self.send_line(notification).await
}
async fn send_line(&mut self, payload: &Value) -> Result<(), ToolCallError> {
let line = serde_json::to_string(payload).map_err(|err| {
ToolCallError::InvalidArguments(format!("Failed to serialize JSON-RPC payload: {err}"))
})?;
let stdin = self.stdin.as_mut().ok_or_else(|| {
ToolCallError::Transport("stdio child stdin is not connected".to_string())
})?;
stdin.write_all(line.as_bytes()).await.map_err(|err| {
ToolCallError::Transport(format!("Failed writing to stdio child: {err}"))
})?;
stdin.write_all(b"\n").await.map_err(|err| {
ToolCallError::Transport(format!("Failed writing newline to stdio child: {err}"))
})?;
stdin
.flush()
.await
.map_err(|err| ToolCallError::Transport(format!("Failed flushing stdio child: {err}")))
}
async fn ensure_connected(&mut self) -> Result<(), ToolCallError> {
if self
.reader_task
.as_ref()
.map(|task| task.is_finished())
.unwrap_or(false)
{
self.child = None;
self.stdin = None;
self.reader_task = None;
}
if let Some(child) = self.child.as_mut() {
match child.try_wait() {
Ok(None) => {}
Ok(Some(_)) => {
self.child = None;
self.stdin = None;
self.reader_task = None;
}
Err(err) => {
self.child = None;
self.stdin = None;
self.reader_task = None;
return Err(ToolCallError::Transport(format!(
"Failed to poll stdio child process: {err}"
)));
}
}
}
if self.stdin.is_some() && self.child.is_some() && self.reader_task.is_some() {
return Ok(());
}
self.pending.lock().await.clear();
let mut command = Command::new(&self.command);
command.args(&self.args);
command.stdin(std::process::Stdio::piped());
command.stdout(std::process::Stdio::piped());
command.stderr(std::process::Stdio::piped());
if !self.env.is_empty() {
command.envs(&self.env);
}
if let Some(cwd) = self.cwd.as_deref() {
command.current_dir(cwd);
}
let mut child = command.spawn().map_err(|err| {
ToolCallError::Transport(format!(
"Failed to spawn stdio command '{}': {err}",
self.command
))
})?;
let stdin = child.stdin.take().ok_or_else(|| {
ToolCallError::Transport("Spawned stdio child missing stdin".to_string())
})?;
let stdout = child.stdout.take().ok_or_else(|| {
ToolCallError::Transport("Spawned stdio child missing stdout".to_string())
})?;
let stderr = child.stderr.take().ok_or_else(|| {
ToolCallError::Transport("Spawned stdio child missing stderr".to_string())
})?;
let command_label = self.command.clone();
tokio::spawn(async move {
let mut stderr_lines = BufReader::new(stderr).lines();
loop {
match stderr_lines.next_line().await {
Ok(Some(line)) => {
if !line.trim().is_empty() {
tracing::error!("stdio child stderr ({command_label}): {line}");
}
}
Ok(None) => break,
Err(err) => {
tracing::error!(
"Failed reading stdio child stderr ({command_label}): {err}"
);
break;
}
}
}
});
let pending = self.pending.clone();
let command_label = self.command.clone();
let reader_task = tokio::spawn(async move {
let mut stdout_lines = BufReader::new(stdout).lines();
loop {
match stdout_lines.next_line().await {
Ok(Some(line)) => {
let line = line.trim();
if line.is_empty() {
continue;
}
let Ok(payload) = serde_json::from_str::<Value>(line) else {
tracing::error!(
"stdio child produced non-JSON line ({command_label}): {line}"
);
continue;
};
let Some(id_key) = id_key_from_envelope(&payload) else {
continue;
};
if let Some(tx) = pending.lock().await.remove(&id_key) {
let _ = tx.send(payload);
}
}
Ok(None) => break,
Err(err) => {
tracing::error!(
"Failed reading stdio child stdout ({command_label}): {err}"
);
break;
}
}
}
pending.lock().await.clear();
});
self.stdin = Some(stdin);
self.child = Some(child);
self.reader_task = Some(reader_task);
Ok(())
}
}
impl Drop for StdioTransport {
    /// Releases the child process together with the transport.
    fn drop(&mut self) {
        // Close the pipe first so the child sees EOF on stdin.
        drop(self.stdin.take());
        if let Some(reader) = self.reader_task.take() {
            reader.abort();
        }
        if let Some(process) = &mut self.child {
            // Best effort: failure here only means the process is already gone.
            let _ = process.start_kill();
        }
    }
}
/// Streamable HTTP transport: every JSON-RPC message is POSTed to a single endpoint.
pub(crate) struct StreamableHttpTransport {
/// Target URL (validated with `Url::parse` by `new`).
endpoint: String,
/// Extra headers added to every POST.
headers: HashMap<String, String>,
/// Client configured with the connect timeout and the optional overall request timeout.
client: reqwest::Client,
/// Last `Mcp-Session-Id` returned by the server, echoed on subsequent requests.
session_id: Option<String>,
}
impl StreamableHttpTransport {
fn new(options: TransportOptions) -> Result<Self, ToolCallError> {
Url::parse(&options.endpoint)
.map_err(|err| ToolCallError::InvalidEndpoint(err.to_string()))?;
let mut builder = reqwest::Client::builder().connect_timeout(options.connect_timeout);
if let Some(timeout) = options.request_timeout {
builder = builder.timeout(timeout);
}
let client = builder.build().map_err(|err| {
ToolCallError::Transport(format!("Failed to build HTTP client: {err}"))
})?;
Ok(Self {
endpoint: options.endpoint,
headers: options.headers,
client,
session_id: None,
})
}
async fn send_request(&mut self, request: &Value) -> Result<Value, ToolCallError> {
let response = self.dispatch(request).await?;
response.ok_or_else(|| {
ToolCallError::Protocol(
"Expected JSON-RPC response but received empty body".to_string(),
)
})
}
async fn send_notification(&mut self, notification: &Value) -> Result<(), ToolCallError> {
let _ = self.dispatch(notification).await?;
Ok(())
}
async fn dispatch(&mut self, payload: &Value) -> Result<Option<Value>, ToolCallError> {
let mut req = apply_headers(
self.client.post(&self.endpoint).json(payload),
&self.headers,
);
if let Some(session_id) = &self.session_id {
req = req.header("Mcp-Session-Id", session_id);
}
let response = req
.send()
.await
.map_err(|err| ToolCallError::Transport(err.to_string()))?;
if let Some(session_id) = extract_session_id(response.headers()) {
self.session_id = Some(session_id);
}
parse_json_response(response, Some(self.endpoint.as_str())).await
}
}
/// Server-Sent Events transport: replies arrive on a long-lived GET stream, while messages are
/// POSTed to the endpoint the server announces in an `endpoint` event.
pub(crate) struct SseTransport {
/// URL of the event stream; also the base for resolving the announced message endpoint.
sse_endpoint: Url,
/// Extra headers sent on the stream GET and on every POST.
headers: HashMap<String, String>,
/// HTTP client built with only a connect timeout (presumably so the stream may stay open).
client: reqwest::Client,
/// Per-POST timeout; the reply wait falls back to `DEFAULT_REQUEST_TIMEOUT` when `None`.
request_timeout: Option<Duration>,
/// How long to wait for the `endpoint` event (taken from the connect timeout).
endpoint_wait_timeout: Duration,
/// POST target announced by the server; reset to `None` whenever the stream is reopened.
message_endpoint: std::sync::Arc<RwLock<Option<Url>>>,
/// In-flight requests keyed by `id_key_from_envelope`, completed by the stream task.
pending: std::sync::Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
/// Task consuming the event stream.
stream_task: Option<JoinHandle<()>>,
}
impl SseTransport {
    /// Validates the stream URL and builds the HTTP client; the stream opens lazily.
    ///
    /// The connect timeout doubles as the deadline for the server's `endpoint` event.
    ///
    /// # Errors
    /// [`ToolCallError::InvalidEndpoint`] for an unparsable endpoint and
    /// [`ToolCallError::Transport`] when the HTTP client cannot be built.
    fn new(options: TransportOptions) -> Result<Self, ToolCallError> {
        let sse_endpoint = Url::parse(&options.endpoint)
            .map_err(|err| ToolCallError::InvalidEndpoint(err.to_string()))?;
        let client = reqwest::Client::builder()
            .connect_timeout(options.connect_timeout)
            .build()
            .map_err(|err| {
                ToolCallError::Transport(format!("Failed to build SSE client: {err}"))
            })?;
        Ok(Self {
            sse_endpoint,
            headers: options.headers,
            client,
            request_timeout: options.request_timeout,
            endpoint_wait_timeout: options.connect_timeout,
            message_endpoint: std::sync::Arc::new(RwLock::new(None)),
            pending: std::sync::Arc::new(Mutex::new(HashMap::new())),
            stream_task: None,
        })
    }

    /// POSTs `request` and waits for its reply, which may come back inline in the POST
    /// response or later on the event stream.
    ///
    /// # Errors
    /// Stream/endpoint setup failures, a request without `id`, HTTP errors, a dropped stream,
    /// or no reply within the request timeout.
    async fn send_request(&mut self, request: &Value) -> Result<Value, ToolCallError> {
        self.ensure_stream_started().await?;
        let message_endpoint = self.wait_for_message_endpoint().await?;
        let request_id = id_key_from_envelope(request).ok_or_else(|| {
            ToolCallError::Protocol("JSON-RPC request is missing an id".to_string())
        })?;
        // Register before POSTing: the reply can reach the stream before the POST returns.
        let (tx, rx) = oneshot::channel();
        self.pending.lock().await.insert(request_id.clone(), tx);
        match self.post_json(message_endpoint, request).await {
            Ok(Some(payload)) => {
                self.pending.lock().await.remove(&request_id);
                return Ok(payload);
            }
            Ok(None) => {}
            Err(err) => {
                self.pending.lock().await.remove(&request_id);
                return Err(err);
            }
        }
        let timeout = self.request_timeout.unwrap_or(DEFAULT_REQUEST_TIMEOUT);
        match tokio::time::timeout(timeout, rx).await {
            Ok(Ok(payload)) => Ok(payload),
            Ok(Err(_)) => Err(ToolCallError::Transport(
                "SSE stream disconnected while waiting for response".to_string(),
            )),
            Err(_) => {
                self.pending.lock().await.remove(&request_id);
                Err(ToolCallError::Transport(format!(
                    "Timed out waiting for SSE response after {}ms",
                    timeout.as_millis()
                )))
            }
        }
    }

    /// POSTs `notification` to the message endpoint without waiting for a reply.
    async fn send_notification(&mut self, notification: &Value) -> Result<(), ToolCallError> {
        self.ensure_stream_started().await?;
        let message_endpoint = self.wait_for_message_endpoint().await?;
        let _ = self.post_json(message_endpoint, notification).await?;
        Ok(())
    }

    /// (Re)opens the event stream unless a stream task is still running.
    ///
    /// # Errors
    /// Fails when the GET cannot be sent or the server answers with a non-success status.
    async fn ensure_stream_started(&mut self) -> Result<(), ToolCallError> {
        if self
            .stream_task
            .as_ref()
            .map(|task| !task.is_finished())
            .unwrap_or(false)
        {
            return Ok(());
        }
        let mut req = self.client.get(self.sse_endpoint.clone());
        for (k, v) in &self.headers {
            req = req.header(k, v);
        }
        let response = req
            .send()
            .await
            .map_err(|err| ToolCallError::Transport(format!("Failed to open SSE stream: {err}")))?;
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(map_http_status_error(status, body.trim()));
        }
        // A new stream announces its own endpoint; forget the previous one.
        *self.message_endpoint.write().await = None;
        let base_url = self.sse_endpoint.clone();
        let pending = self.pending.clone();
        let message_endpoint = self.message_endpoint.clone();
        self.stream_task = Some(tokio::spawn(async move {
            let stream = response.bytes_stream().eventsource();
            tokio::pin!(stream);
            while let Some(event) = stream.next().await {
                let Ok(event) = event else {
                    break;
                };
                if event.event == "endpoint" {
                    if let Ok(endpoint_url) = base_url.join(event.data.trim()) {
                        *message_endpoint.write().await = Some(endpoint_url);
                    }
                    continue;
                }
                if event.data.trim().is_empty() {
                    continue;
                }
                let Ok(payload) = serde_json::from_str::<Value>(&event.data) else {
                    continue;
                };
                // Server-initiated requests/notifications carry a `method`; only responses
                // may complete a pending request, or a colliding id would be misrouted.
                if payload.get("method").is_some() {
                    continue;
                }
                let Some(id_key) = id_key_from_envelope(&payload) else {
                    continue;
                };
                if let Some(tx) = pending.lock().await.remove(&id_key) {
                    let _ = tx.send(payload);
                }
            }
            // Dropping the senders makes waiters fail fast with "disconnected" instead of
            // running out their full timeout.
            pending.lock().await.clear();
        }));
        Ok(())
    }

    /// Polls (every 50ms) until the stream has announced the message endpoint.
    ///
    /// # Errors
    /// Times out after `endpoint_wait_timeout`.
    async fn wait_for_message_endpoint(&self) -> Result<Url, ToolCallError> {
        let deadline = tokio::time::Instant::now() + self.endpoint_wait_timeout;
        loop {
            if let Some(endpoint) = self.message_endpoint.read().await.clone() {
                return Ok(endpoint);
            }
            if tokio::time::Instant::now() >= deadline {
                return Err(ToolCallError::Transport(format!(
                    "Timed out waiting for SSE message endpoint after {}ms",
                    self.endpoint_wait_timeout.as_millis()
                )));
            }
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
    }

    /// POSTs `payload` to the announced message endpoint.
    ///
    /// SSE servers usually acknowledge with `202 Accepted` and a plain-text body (e.g.
    /// "Accepted") while the real reply travels over the event stream. The body is therefore
    /// returned only when it is a JSON-RPC response object (`result` or `error`); any other
    /// successful body yields `None`.
    ///
    /// # Errors
    /// Transport failures and non-success HTTP statuses.
    async fn post_json(
        &self,
        endpoint: Url,
        payload: &Value,
    ) -> Result<Option<Value>, ToolCallError> {
        let mut req = self.client.post(endpoint).json(payload);
        for (k, v) in &self.headers {
            req = req.header(k, v);
        }
        if let Some(timeout) = self.request_timeout {
            req = req.timeout(timeout);
        }
        let response = req
            .send()
            .await
            .map_err(|err| ToolCallError::Transport(err.to_string()))?;
        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            return Err(map_http_status_error(status, body.trim()));
        }
        let body = response
            .bytes()
            .await
            .map_err(|err| ToolCallError::Transport(err.to_string()))?;
        let inline_reply = serde_json::from_slice::<Value>(&body)
            .ok()
            .filter(|value| value.get("result").is_some() || value.get("error").is_some());
        Ok(inline_reply)
    }
}
impl Drop for SseTransport {
    /// Stops the stream task, which owns the HTTP response, so the connection closes with us.
    fn drop(&mut self) {
        if let Some(task) = self.stream_task.take() {
            task.abort();
        }
    }
}
/// JSON-RPC over a single WebSocket connection.
pub(crate) struct WsTransport {
/// WebSocket URL.
endpoint: String,
/// Extra headers added to the opening handshake request.
headers: HashMap<String, String>,
/// Limit for establishing the connection.
connect_timeout: Duration,
/// Reply wait limit; `None` falls back to `DEFAULT_REQUEST_TIMEOUT`.
request_timeout: Option<Duration>,
/// Outgoing half of the socket; `None` until connected.
writer: Option<WsWriter>,
/// In-flight requests keyed by `id_key_from_envelope`, completed by the reader task.
pending: std::sync::Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
/// Task reading incoming frames and routing replies into `pending`.
reader_task: Option<JoinHandle<()>>,
}
impl WsTransport {
fn new(options: TransportOptions) -> Self {
Self {
endpoint: options.endpoint,
headers: options.headers,
connect_timeout: options.connect_timeout,
request_timeout: options.request_timeout,
writer: None,
pending: std::sync::Arc::new(Mutex::new(HashMap::new())),
reader_task: None,
}
}
async fn send_request(&mut self, request: &Value) -> Result<Value, ToolCallError> {
self.ensure_connected().await?;
let request_id = id_key_from_envelope(request).ok_or_else(|| {
ToolCallError::Protocol("JSON-RPC request is missing an id".to_string())
})?;
let (tx, rx) = oneshot::channel();
self.pending.lock().await.insert(request_id.clone(), tx);
let writer = self.writer.as_mut().ok_or_else(|| {
ToolCallError::Transport("WebSocket writer is not connected".to_string())
})?;
if let Err(err) = writer.send(Message::Text(request.to_string().into())).await {
self.pending.lock().await.remove(&request_id);
return Err(ToolCallError::Transport(format!(
"Failed to send WebSocket message: {err}"
)));
}
let timeout = self.request_timeout.unwrap_or(DEFAULT_REQUEST_TIMEOUT);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(payload)) => Ok(payload),
Ok(Err(_)) => Err(ToolCallError::Transport(
"WebSocket disconnected while waiting for response".to_string(),
)),
Err(_) => {
self.pending.lock().await.remove(&request_id);
Err(ToolCallError::Transport(format!(
"Timed out waiting for WebSocket response after {}ms",
timeout.as_millis()
)))
}
}
}
async fn send_notification(&mut self, notification: &Value) -> Result<(), ToolCallError> {
self.ensure_connected().await?;
let writer = self.writer.as_mut().ok_or_else(|| {
ToolCallError::Transport("WebSocket writer is not connected".to_string())
})?;
writer
.send(Message::Text(notification.to_string().into()))
.await
.map_err(|err| {
ToolCallError::Transport(format!("Failed to send WebSocket message: {err}"))
})
}
async fn ensure_connected(&mut self) -> Result<(), ToolCallError> {
if self
.reader_task
.as_ref()
.map(|task| task.is_finished())
.unwrap_or(false)
{
self.writer = None;
self.reader_task = None;
}
if self.writer.is_some() {
return Ok(());
}
let mut request = self
.endpoint
.as_str()
.into_client_request()
.map_err(|err| ToolCallError::InvalidEndpoint(err.to_string()))?;
for (key, value) in &self.headers {
let header_name = tokio_tungstenite::tungstenite::http::header::HeaderName::from_bytes(
key.as_bytes(),
)
.map_err(|err| {
ToolCallError::InvalidArguments(format!(
"Invalid WebSocket header name '{key}': {err}"
))
})?;
let header_value =
tokio_tungstenite::tungstenite::http::header::HeaderValue::from_str(value)
.map_err(|err| {
ToolCallError::InvalidArguments(format!(
"Invalid WebSocket header value for '{key}': {err}"
))
})?;
request.headers_mut().insert(header_name, header_value);
}
let (stream, _) = tokio::time::timeout(self.connect_timeout, connect_async(request))
.await
.map_err(|_| {
ToolCallError::Transport(format!(
"Timed out connecting to WebSocket endpoint after {}ms",
self.connect_timeout.as_millis()
))
})?
.map_err(|err| {
ToolCallError::Transport(format!("WebSocket connection failed: {err}"))
})?;
let (writer, mut reader) = stream.split();
let pending = self.pending.clone();
self.reader_task = Some(tokio::spawn(async move {
while let Some(frame) = reader.next().await {
let Ok(frame) = frame else {
break;
};
let payload = match frame {
Message::Text(text) => serde_json::from_str::<Value>(&text).ok(),
Message::Binary(bytes) => serde_json::from_slice::<Value>(&bytes).ok(),
Message::Close(_) => break,
Message::Ping(_) | Message::Pong(_) => None,
_ => None,
};
let Some(payload) = payload else {
continue;
};
let Some(id_key) = id_key_from_envelope(&payload) else {
continue;
};
if let Some(tx) = pending.lock().await.remove(&id_key) {
let _ = tx.send(payload);
}
}
}));
self.writer = Some(writer);
Ok(())
}
}
/// JSON-RPC carried in `Envelope`s over a bidirectional `McpBridge` gRPC stream.
pub(crate) struct GrpcTransport {
/// Endpoint as configured (grpc://, grpcs://, http:// or https://).
endpoint: String,
/// Sent both as request metadata and inside every envelope's `metadata`.
headers: HashMap<String, String>,
/// Limit for establishing the channel.
connect_timeout: Duration,
/// Reply wait limit; `None` falls back to `DEFAULT_REQUEST_TIMEOUT`.
request_timeout: Option<Duration>,
/// Sender feeding the outbound stream; `None` until connected.
writer: Option<mpsc::Sender<Envelope>>,
/// In-flight requests keyed by `id_key_from_envelope`, completed by the reader task.
pending: std::sync::Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
/// Task reading the inbound stream and routing replies into `pending`.
reader_task: Option<JoinHandle<()>>,
/// Sequence number of the last envelope built (incremented with saturation).
seq: u64,
}
impl GrpcTransport {
fn new(options: TransportOptions) -> Self {
Self {
endpoint: options.endpoint,
headers: options.headers,
connect_timeout: options.connect_timeout,
request_timeout: options.request_timeout,
writer: None,
pending: std::sync::Arc::new(Mutex::new(HashMap::new())),
reader_task: None,
seq: 0,
}
}
async fn send_request(&mut self, request: &Value) -> Result<Value, ToolCallError> {
self.ensure_connected().await?;
let request_id = id_key_from_envelope(request).ok_or_else(|| {
ToolCallError::Protocol("JSON-RPC request is missing an id".to_string())
})?;
let envelope = self.next_envelope(request.to_string());
let (tx, rx) = oneshot::channel();
self.pending.lock().await.insert(request_id.clone(), tx);
let writer = self.writer.as_mut().ok_or_else(|| {
ToolCallError::Transport("gRPC stream writer is not connected".to_string())
})?;
if writer.send(envelope).await.is_err() {
self.pending.lock().await.remove(&request_id);
return Err(ToolCallError::Transport(
"Failed to send gRPC message: stream is closed".to_string(),
));
}
let timeout = self.request_timeout.unwrap_or(DEFAULT_REQUEST_TIMEOUT);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(payload)) => Ok(payload),
Ok(Err(_)) => Err(ToolCallError::Transport(
"gRPC stream disconnected while waiting for response".to_string(),
)),
Err(_) => {
self.pending.lock().await.remove(&request_id);
Err(ToolCallError::Transport(format!(
"Timed out waiting for gRPC response after {}ms",
timeout.as_millis()
)))
}
}
}
async fn send_notification(&mut self, notification: &Value) -> Result<(), ToolCallError> {
self.ensure_connected().await?;
let envelope = self.next_envelope(notification.to_string());
let writer = self.writer.as_mut().ok_or_else(|| {
ToolCallError::Transport("gRPC stream writer is not connected".to_string())
})?;
writer.send(envelope).await.map_err(|_| {
ToolCallError::Transport("Failed to send gRPC message: stream is closed".into())
})
}
async fn ensure_connected(&mut self) -> Result<(), ToolCallError> {
if self
.reader_task
.as_ref()
.map(|task| task.is_finished())
.unwrap_or(false)
{
self.writer = None;
self.reader_task = None;
}
if self.writer.is_some() {
return Ok(());
}
let normalized_endpoint = normalize_grpc_endpoint(&self.endpoint)?;
let channel = Endpoint::from_shared(normalized_endpoint.clone())
.map_err(|err| ToolCallError::InvalidEndpoint(err.to_string()))?
.connect_timeout(self.connect_timeout)
.connect()
.await
.map_err(|err| ToolCallError::Transport(format!("gRPC connection failed: {err}")))?;
let mut client = McpBridgeClient::new(channel);
let (tx, rx) = mpsc::channel::<Envelope>(256);
let mut request = Request::new(ReceiverStream::new(rx));
apply_grpc_headers(request.metadata_mut(), &self.headers)?;
let mut inbound = client
.stream(request)
.await
.map_err(|err| ToolCallError::Transport(format!("gRPC stream failed: {err}")))?
.into_inner();
let pending = self.pending.clone();
self.reader_task = Some(tokio::spawn(async move {
loop {
let payload = match inbound.message().await {
Ok(Some(envelope)) => envelope,
Ok(None) => break,
Err(_) => break,
};
if payload.json_rpc.trim().is_empty() {
continue;
}
let Ok(parsed) = serde_json::from_str::<Value>(&payload.json_rpc) else {
continue;
};
let Some(id_key) = id_key_from_envelope(&parsed) else {
continue;
};
if let Some(tx) = pending.lock().await.remove(&id_key) {
let _ = tx.send(parsed);
}
}
}));
self.writer = Some(tx);
Ok(())
}
fn next_envelope(&mut self, json_rpc: String) -> Envelope {
self.seq = self.seq.saturating_add(1);
Envelope {
json_rpc,
metadata: self.headers.clone(),
session_id: String::new(),
seq: self.seq,
}
}
}
/// Builds a map key from a JSON-RPC message's `id`, tagged by JSON type so that e.g. the
/// number `1` and the string `"1"` never collide.
///
/// Returns `None` when the message has no `id` member.
pub(crate) fn id_key_from_envelope(message: &Value) -> Option<String> {
    let key = match message.get("id")? {
        Value::Null => String::from("null"),
        Value::Bool(flag) => format!("b:{flag}"),
        Value::Number(number) => format!("n:{number}"),
        Value::String(text) => format!("s:{text}"),
        composite => format!("j:{composite}"),
    };
    Some(key)
}
/// Converts a configured gRPC endpoint into the `http(s)://` form tonic expects.
///
/// `grpc://` and `grpcs://` are rewritten textually to `http://` and `https://`. Any other
/// input must parse as an `http`/`https` URL and is returned in `Url`'s serialized form.
///
/// # Errors
/// [`ToolCallError::InvalidEndpoint`] for unparsable URLs or unsupported schemes.
fn normalize_grpc_endpoint(endpoint: &str) -> Result<String, ToolCallError> {
    const ALIASES: [(&str, &str); 2] = [("grpc://", "http://"), ("grpcs://", "https://")];
    for (alias, scheme) in ALIASES {
        if let Some(rest) = endpoint.strip_prefix(alias) {
            return Ok(format!("{scheme}{rest}"));
        }
    }
    let parsed =
        Url::parse(endpoint).map_err(|err| ToolCallError::InvalidEndpoint(err.to_string()))?;
    if matches!(parsed.scheme(), "http" | "https") {
        return Ok(parsed.to_string());
    }
    Err(ToolCallError::InvalidEndpoint(format!(
        "Unsupported gRPC endpoint scheme '{}'. Use grpc://, grpcs://, http://, or https://",
        parsed.scheme()
    )))
}
/// Copies `headers` into gRPC request metadata, lowercasing each key.
///
/// # Errors
/// [`ToolCallError::InvalidArguments`] on the first key or value that is not valid ASCII
/// metadata.
fn apply_grpc_headers(
    metadata: &mut tonic::metadata::MetadataMap,
    headers: &HashMap<String, String>,
) -> Result<(), ToolCallError> {
    headers.iter().try_for_each(|(key, raw_value)| {
        // gRPC metadata keys must be lowercase.
        let name = MetadataKey::from_bytes(key.to_ascii_lowercase().as_bytes()).map_err(|err| {
            ToolCallError::InvalidArguments(format!("Invalid gRPC metadata key '{key}': {err}"))
        })?;
        let value = MetadataValue::try_from(raw_value.as_str()).map_err(|err| {
            ToolCallError::InvalidArguments(format!(
                "Invalid gRPC metadata value for '{key}': {err}"
            ))
        })?;
        metadata.insert(name, value);
        Ok(())
    })
}
/// Adds every entry of `headers` to the request builder.
fn apply_headers(req: RequestBuilder, headers: &HashMap<String, String>) -> RequestBuilder {
    headers
        .iter()
        .fold(req, |builder, (name, value)| builder.header(name, value))
}
/// Returns the `Mcp-Session-Id` response header, if present and valid visible ASCII.
fn extract_session_id(headers: &reqwest::header::HeaderMap) -> Option<String> {
    let raw = headers.get("Mcp-Session-Id")?;
    raw.to_str().ok().map(String::from)
}
/// Turns an HTTP response into an optional JSON payload.
///
/// Non-success statuses become errors carrying the trimmed body. An empty success body yields
/// `None`. `_endpoint_hint` is currently unused.
///
/// # Errors
/// HTTP status errors, body read failures, and non-JSON bodies ([`ToolCallError::Protocol`]).
async fn parse_json_response(
    response: reqwest::Response,
    _endpoint_hint: Option<&str>,
) -> Result<Option<Value>, ToolCallError> {
    let status = response.status();
    if !status.is_success() {
        let body = response.text().await.unwrap_or_default();
        return Err(map_http_status_error(status, body.trim()));
    }
    let body = response
        .bytes()
        .await
        .map_err(|err| ToolCallError::Transport(err.to_string()))?;
    if body.is_empty() {
        return Ok(None);
    }
    serde_json::from_slice::<Value>(&body)
        .map(Some)
        .map_err(|err| ToolCallError::Protocol(format!("Response was not JSON: {err}")))
}
/// Wraps a non-success HTTP status and its body text in a transport error.
fn map_http_status_error(status: reqwest::StatusCode, body: &str) -> ToolCallError {
    let message = format!("HTTP {status} {body}");
    ToolCallError::Transport(message)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn stdio_transport_rejects_missing_command() {
        let result = TransportClient::new(
            Transport::Stdio,
            TransportOptions {
                stdio_command: None,
                ..Default::default()
            },
        );
        assert!(matches!(result, Err(ToolCallError::InvalidArguments(_))));
    }

    #[test]
    fn stdio_transport_rejects_blank_command() {
        let result = TransportClient::new(
            Transport::Stdio,
            TransportOptions {
                stdio_command: Some("   ".to_string()),
                ..Default::default()
            },
        );
        assert!(matches!(result, Err(ToolCallError::InvalidArguments(_))));
    }

    #[test]
    fn stdio_transport_accepts_command() {
        let result = TransportClient::new(
            Transport::Stdio,
            TransportOptions {
                stdio_command: Some("node".to_string()),
                stdio_args: vec!["--version".to_string()],
                ..Default::default()
            },
        );
        assert!(result.is_ok());
    }

    #[test]
    fn http_transports_reject_invalid_endpoint() {
        for transport in [Transport::StreamableHttp, Transport::Sse] {
            let result = TransportClient::new(
                transport,
                TransportOptions {
                    endpoint: "not a url".to_string(),
                    ..Default::default()
                },
            );
            assert!(matches!(result, Err(ToolCallError::InvalidEndpoint(_))));
        }
    }

    #[test]
    fn id_keys_are_tagged_by_json_type() {
        assert_eq!(
            id_key_from_envelope(&serde_json::json!({"id": 7})),
            Some("n:7".to_string())
        );
        assert_eq!(
            id_key_from_envelope(&serde_json::json!({"id": "7"})),
            Some("s:7".to_string())
        );
        assert_eq!(
            id_key_from_envelope(&serde_json::json!({"id": null})),
            Some("null".to_string())
        );
        assert_eq!(
            id_key_from_envelope(&serde_json::json!({"method": "ping"})),
            None
        );
    }

    #[test]
    fn grpc_endpoints_are_normalized() {
        assert_eq!(
            normalize_grpc_endpoint("grpc://localhost:50051").ok().as_deref(),
            Some("http://localhost:50051")
        );
        assert_eq!(
            normalize_grpc_endpoint("grpcs://example.com:443").ok().as_deref(),
            Some("https://example.com:443")
        );
        assert_eq!(
            normalize_grpc_endpoint("http://localhost:50051").ok().as_deref(),
            Some("http://localhost:50051/")
        );
        assert!(matches!(
            normalize_grpc_endpoint("ftp://localhost"),
            Err(ToolCallError::InvalidEndpoint(_))
        ));
    }

    #[test]
    fn unauthorized_status_maps_to_transport_error() {
        let err = map_http_status_error(reqwest::StatusCode::UNAUTHORIZED, "unauthorized");
        match err {
            ToolCallError::Transport(message) => assert!(message.contains("401")),
            other => panic!("unexpected error variant: {other}"),
        }
    }
}