use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};
use turbomcp_protocol::{Error, Result};
use turbomcp_transport::{Transport, TransportConfig, TransportMessage};
use super::dispatcher::MessageDispatcher;
/// JSON-RPC client core: owns a shared [`Transport`], a [`MessageDispatcher`] built over it,
/// and the counter used to allocate request ids.
#[derive(Debug)]
pub(super) struct ProtocolClient<T>
where
    T: Transport,
{
    /// Transport used to send outgoing requests and notifications.
    transport: Arc<T>,
    /// Dispatcher on which response waiters are registered for outgoing requests.
    dispatcher: Arc<MessageDispatcher>,
    /// Next numeric request id to hand out.
    next_id: AtomicU64,
    /// Transport configuration; its `timeouts` bound request round-trips.
    config: TransportConfig,
}
impl<T: Transport + 'static> ProtocolClient<T> {
pub(super) fn with_config(transport: T, config: TransportConfig) -> Self {
let transport = Arc::new(transport);
let dispatcher = MessageDispatcher::new(transport.clone());
Self {
transport,
dispatcher,
next_id: AtomicU64::new(1),
config,
}
}
pub(super) fn dispatcher(&self) -> &Arc<MessageDispatcher> {
&self.dispatcher
}
pub(super) async fn request<R: serde::de::DeserializeOwned>(
&self,
method: &str,
params: Option<serde_json::Value>,
) -> Result<R> {
let operation = self.request_inner(method, params);
if let Some(total_timeout) = self.config.timeouts.total {
match tokio::time::timeout(total_timeout, operation).await {
Ok(result) => result,
Err(_) => {
let err = turbomcp_transport::TransportError::TotalTimeout {
operation: format!("{}()", method),
timeout: total_timeout,
};
Err(Error::transport(err.to_string()))
}
}
} else {
operation.await
}
}
async fn request_inner<R: serde::de::DeserializeOwned>(
&self,
method: &str,
params: Option<serde_json::Value>,
) -> Result<R> {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let request_id = turbomcp_protocol::MessageId::from(id.to_string());
let request = JsonRpcRequest {
jsonrpc: JsonRpcVersion,
id: request_id.clone(),
method: method.to_string(),
params,
};
let response_receiver = self.dispatcher.wait_for_response(request_id.clone());
let payload = serde_json::to_vec(&request)
.map_err(|e| Error::internal(format!("Failed to serialize request: {e}")))?;
let message = TransportMessage::new(
turbomcp_protocol::MessageId::from(format!("req-{id}")),
payload.into(),
);
self.transport.send(message).await.map_err(|e| {
self.dispatcher.remove_response_waiter(&request_id);
Error::transport(format!("Transport send failed: {e}"))
})?;
let response = if let Some(request_timeout) = self.config.timeouts.request {
match tokio::time::timeout(request_timeout, response_receiver).await {
Ok(Ok(response)) => response,
Ok(Err(_)) => return Err(Error::transport("Response channel closed".to_string())),
Err(_) => {
self.dispatcher.remove_response_waiter(&request_id);
let err = turbomcp_transport::TransportError::RequestTimeout {
operation: format!("{}()", method),
timeout: request_timeout,
};
return Err(Error::transport(err.to_string()));
}
}
} else {
response_receiver
.await
.map_err(|_| Error::transport("Response channel closed".to_string()))?
};
if let Some(error) = response.error() {
return Err(Error::from_rpc_code(error.code, &error.message));
}
serde_json::from_value(response.result().unwrap_or_default().clone())
.map_err(|e| Error::internal(format!("Failed to deserialize response: {e}")))
}
pub(super) async fn notify(
&self,
method: &str,
params: Option<serde_json::Value>,
) -> Result<()> {
let request = serde_json::json!({
"jsonrpc": "2.0",
"method": method,
"params": params
});
let payload = serde_json::to_vec(&request)
.map_err(|e| Error::internal(format!("Failed to serialize notification: {e}")))?;
let message = TransportMessage::new(
turbomcp_protocol::MessageId::from("notification"),
payload.into(),
);
self.transport
.send(message)
.await
.map_err(|e| Error::transport(format!("Transport send failed: {e}")))
}
pub(super) fn transport(&self) -> &Arc<T> {
&self.transport
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use turbomcp_transport::{
TransportCapabilities, TransportConfig, TransportError, TransportMetrics, TransportResult,
TransportState, TransportType,
};
#[derive(Debug)]
struct MockTransport {
capabilities: TransportCapabilities,
fail_send: AtomicBool,
}
impl MockTransport {
fn ok() -> Self {
Self {
capabilities: TransportCapabilities::default(),
fail_send: AtomicBool::new(false),
}
}
fn fail_send() -> Self {
Self {
capabilities: TransportCapabilities::default(),
fail_send: AtomicBool::new(true),
}
}
}
impl Transport for MockTransport {
fn transport_type(&self) -> TransportType {
TransportType::Stdio
}
fn capabilities(&self) -> &TransportCapabilities {
&self.capabilities
}
fn state(&self) -> Pin<Box<dyn Future<Output = TransportState> + Send + '_>> {
Box::pin(async { TransportState::Connected })
}
fn connect(&self) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
Box::pin(async { Ok(()) })
}
fn disconnect(&self) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
Box::pin(async { Ok(()) })
}
fn send(
&self,
_message: TransportMessage,
) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
let fail = self.fail_send.load(Ordering::Relaxed);
Box::pin(async move {
if fail {
Err(TransportError::SendFailed("send failed".to_string()))
} else {
Ok(())
}
})
}
fn receive(
&self,
) -> Pin<Box<dyn Future<Output = TransportResult<Option<TransportMessage>>> + Send + '_>>
{
Box::pin(async { Ok(None) })
}
fn metrics(&self) -> Pin<Box<dyn Future<Output = TransportMetrics> + Send + '_>> {
Box::pin(async { TransportMetrics::default() })
}
fn configure(
&self,
_config: TransportConfig,
) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
Box::pin(async { Ok(()) })
}
}
#[tokio::test]
async fn test_request_timeout_cleans_up_waiter() {
let config = TransportConfig {
timeouts: turbomcp_transport::config::TimeoutConfig {
request: Some(Duration::from_millis(10)),
total: Some(Duration::from_millis(25)),
..Default::default()
},
..Default::default()
};
let client = ProtocolClient::with_config(MockTransport::ok(), config);
let result: Result<serde_json::Value> = client.request("tools/list", None).await;
assert!(result.is_err());
assert_eq!(client.dispatcher.response_waiter_count(), 0);
client.dispatcher.shutdown();
}
#[tokio::test]
async fn test_send_failure_cleans_up_waiter() {
let client =
ProtocolClient::with_config(MockTransport::fail_send(), TransportConfig::default());
let result: Result<serde_json::Value> = client.request("tools/list", None).await;
assert!(result.is_err());
assert_eq!(client.dispatcher.response_waiter_count(), 0);
client.dispatcher.shutdown();
}
}