use super::http::HttpTransport;
use super::stdio::StdioTransport;
use super::transport::{Transport, TransportError};
use serde::{Deserialize, Serialize};
use turbomcp_core::types::{
capabilities::{ClientCapabilities, ServerCapabilities},
content::ResourceContent,
core::Implementation,
initialization::InitializeResult,
prompts::{GetPromptResult, Prompt},
resources::{Resource, ResourceTemplate},
tools::{CallToolResult, Tool},
};
/// The concrete transport an [`McpClient`] sends its messages through.
///
/// Each variant owns its transport; the `impl` for this enum forwards
/// `request`, `notify`, `is_ready` and `close` to whichever one is held.
enum TransportKind {
/// Wraps a [`StdioTransport`].
Stdio(StdioTransport),
/// Wraps an [`HttpTransport`].
Http(HttpTransport),
}
impl TransportKind {
fn request<P, R>(&self, method: &str, params: Option<P>) -> Result<R, TransportError>
where
P: Serialize,
R: serde::de::DeserializeOwned,
{
match self {
Self::Stdio(t) => t.request(method, params),
Self::Http(t) => t.request(method, params),
}
}
fn notify<P>(&self, method: &str, params: Option<P>) -> Result<(), TransportError>
where
P: Serialize,
{
match self {
Self::Stdio(t) => t.notify(method, params),
Self::Http(t) => t.notify(method, params),
}
}
fn is_ready(&self) -> bool {
match self {
Self::Stdio(t) => t.is_ready(),
Self::Http(t) => t.is_ready(),
}
}
fn close(&self) -> Result<(), TransportError> {
match self {
Self::Stdio(t) => t.close(),
Self::Http(t) => t.close(),
}
}
}
/// Client for the Model Context Protocol (MCP).
///
/// Construct it with `with_stdio` or `with_http`, then call `initialize`
/// before any request method that checks initialization.
pub struct McpClient {
// Transport every request and notification is sent through.
transport: TransportKind,
// Set by a successful `initialize()`; request methods refuse to run while false.
initialized: bool,
// Server implementation info from the `initialize` result; `None` until then.
server_info: Option<Implementation>,
// Server capabilities from the `initialize` result; `None` until then.
server_capabilities: Option<ServerCapabilities>,
// Version sent in `initialize`; starts at "2025-11-25" and is replaced by the
// version the server returns.
protocol_version: String,
}
impl McpClient {
#[must_use]
pub fn with_stdio(transport: StdioTransport) -> Self {
Self {
transport: TransportKind::Stdio(transport),
initialized: false,
server_info: None,
server_capabilities: None,
protocol_version: "2025-11-25".to_string(),
}
}
#[must_use]
pub fn with_http(transport: HttpTransport) -> Self {
Self {
transport: TransportKind::Http(transport),
initialized: false,
server_info: None,
server_capabilities: None,
protocol_version: "2025-11-25".to_string(),
}
}
pub fn initialize(&mut self) -> Result<InitializeResult, TransportError> {
let params = InitializeParams {
protocol_version: self.protocol_version.clone(),
capabilities: ClientCapabilities::default(),
client_info: Implementation {
name: "turbomcp-wasm".to_string(),
title: Some("TurboMCP WASI Client".to_string()),
description: Some("MCP client running in WASI Preview 2 environment".to_string()),
version: env!("CARGO_PKG_VERSION").to_string(),
icon: None,
},
};
let result: InitializeResult = self.transport.request("initialize", Some(params))?;
self.initialized = true;
self.server_info = Some(result.server_info.clone());
self.server_capabilities = Some(result.capabilities.clone());
self.protocol_version = result.protocol_version.clone();
self.transport
.notify("notifications/initialized", None::<()>)?;
Ok(result)
}
#[must_use]
pub fn is_initialized(&self) -> bool {
self.initialized
}
#[must_use]
pub fn server_info(&self) -> Option<&Implementation> {
self.server_info.as_ref()
}
#[must_use]
pub fn server_capabilities(&self) -> Option<&ServerCapabilities> {
self.server_capabilities.as_ref()
}
#[must_use]
pub fn protocol_version(&self) -> &str {
&self.protocol_version
}
pub fn list_tools(&self) -> Result<Vec<Tool>, TransportError> {
self.ensure_initialized()?;
let result: ListToolsResult = self.transport.request("tools/list", None::<()>)?;
Ok(result.tools)
}
pub fn call_tool(
&self,
name: &str,
arguments: Option<serde_json::Value>,
) -> Result<CallToolResult, TransportError> {
self.ensure_initialized()?;
let params = CallToolParams {
name: name.to_string(),
arguments,
};
self.transport.request("tools/call", Some(params))
}
pub fn list_resources(&self) -> Result<Vec<Resource>, TransportError> {
self.ensure_initialized()?;
let result: ListResourcesResult = self.transport.request("resources/list", None::<()>)?;
Ok(result.resources)
}
pub fn read_resource(&self, uri: &str) -> Result<Vec<ResourceContent>, TransportError> {
self.ensure_initialized()?;
let params = ReadResourceParams {
uri: uri.to_string(),
};
let result: ReadResourceResult = self.transport.request("resources/read", Some(params))?;
Ok(result.contents)
}
pub fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>, TransportError> {
self.ensure_initialized()?;
let result: ListResourceTemplatesResult = self
.transport
.request("resources/templates/list", None::<()>)?;
Ok(result.resource_templates)
}
pub fn list_prompts(&self) -> Result<Vec<Prompt>, TransportError> {
self.ensure_initialized()?;
let result: ListPromptsResult = self.transport.request("prompts/list", None::<()>)?;
Ok(result.prompts)
}
pub fn get_prompt(
&self,
name: &str,
arguments: Option<serde_json::Value>,
) -> Result<GetPromptResult, TransportError> {
self.ensure_initialized()?;
let params = GetPromptParams {
name: name.to_string(),
arguments,
};
self.transport.request("prompts/get", Some(params))
}
pub fn ping(&self) -> Result<(), TransportError> {
let _: serde_json::Value = self.transport.request("ping", None::<()>)?;
Ok(())
}
pub fn close(&self) -> Result<(), TransportError> {
self.transport.close()
}
#[must_use]
pub fn is_ready(&self) -> bool {
self.transport.is_ready()
}
fn ensure_initialized(&self) -> Result<(), TransportError> {
if !self.initialized {
Err(TransportError::Protocol(
"Client not initialized. Call initialize() first.".to_string(),
))
} else {
Ok(())
}
}
}
/// Parameters of the `initialize` request.
///
/// Serialized with camelCase keys: `protocolVersion`, `capabilities`, `clientInfo`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct InitializeParams {
/// Protocol version the client asks the server to use.
protocol_version: String,
/// Capabilities the client advertises.
capabilities: ClientCapabilities,
/// Identification of this client implementation.
client_info: Implementation,
}
/// Parameters of the `tools/call` request.
#[derive(Serialize)]
struct CallToolParams {
/// Name of the tool to invoke.
name: String,
/// Tool arguments; the `arguments` key is left out of the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
arguments: Option<serde_json::Value>,
}
/// Parameters of the `resources/read` request.
#[derive(Serialize)]
struct ReadResourceParams {
/// URI of the resource to read.
uri: String,
}
/// Parameters of the `prompts/get` request.
#[derive(Serialize)]
struct GetPromptParams {
/// Name of the prompt to fetch.
name: String,
/// Prompt arguments; the `arguments` key is left out of the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
arguments: Option<serde_json::Value>,
}
/// Response of `tools/list`; only the `tools` array is read.
///
/// Other response fields are ignored by serde's default behaviour.
/// NOTE(review): any pagination cursor in the response is not read, so only the
/// first page of tools reaches the caller; confirm whether target servers paginate.
#[derive(Deserialize)]
struct ListToolsResult {
tools: Vec<Tool>,
}
/// Response of `resources/list`; only the `resources` array is read.
///
/// NOTE(review): any pagination cursor in the response is not read, so only the
/// first page reaches the caller; confirm whether target servers paginate.
#[derive(Deserialize)]
struct ListResourcesResult {
resources: Vec<Resource>,
}
/// Response of `resources/templates/list`.
///
/// `rename_all = "camelCase"` maps the JSON key `resourceTemplates` onto the field below.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct ListResourceTemplatesResult {
resource_templates: Vec<ResourceTemplate>,
}
/// Response of `resources/read`; only the `contents` array is read.
#[derive(Deserialize)]
struct ReadResourceResult {
contents: Vec<ResourceContent>,
}
/// Response of `prompts/list`; only the `prompts` array is read.
///
/// NOTE(review): any pagination cursor in the response is not read, so only the
/// first page reaches the caller; confirm whether target servers paginate.
#[derive(Deserialize)]
struct ListPromptsResult {
prompts: Vec<Prompt>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, un-initialized client over a default stdio transport.
    fn stdio_client() -> McpClient {
        McpClient::with_stdio(StdioTransport::new())
    }

    #[test]
    fn test_client_with_stdio() {
        let client = stdio_client();
        assert!(!client.is_initialized());
        assert!(client.is_ready());
    }

    #[test]
    fn test_client_with_http() {
        let client = McpClient::with_http(HttpTransport::new("https://api.example.com/mcp"));
        assert!(!client.is_initialized());
        assert!(client.is_ready());
    }

    #[test]
    fn test_client_protocol_version() {
        assert_eq!(stdio_client().protocol_version(), "2025-11-25");
    }

    #[test]
    fn test_ensure_initialized_fails() {
        // Request methods must be rejected before the handshake has run.
        assert!(stdio_client().ensure_initialized().is_err());
    }
}