use anyhow::{anyhow, bail, Context, Result};
use bytes::Bytes;
use futures_util::StreamExt;
use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE};
use serde_json::{json, Value};
use std::time::Duration;
use crate::cli::OutputFormat;
pub struct McpConfig {
pub url: String,
pub token: Option<String>,
pub timeout_secs: u64,
pub output_format: OutputFormat,
}
pub struct McpClient {
client: reqwest::Client,
config: McpConfig,
session_id: Option<String>,
}
impl McpClient {
pub fn new(config: McpConfig) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(config.timeout_secs))
.build()
.context("Failed to build HTTP client")?;
Ok(Self {
client,
config,
session_id: None,
})
}
fn build_headers(&self, session_id: Option<&str>) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
headers.insert(
ACCEPT,
HeaderValue::from_static("application/json, text/event-stream"),
);
if let Some(token) = &self.config.token {
let auth = format!("Bearer {token}");
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&auth).context("Invalid token value")?,
);
}
if let Some(sid) = session_id {
headers.insert(
"mcp-session-id",
HeaderValue::from_str(sid).context("Invalid session ID")?,
);
}
Ok(headers)
}
async fn send(&self, body: Value, session_id: Option<&str>) -> Result<Value> {
let headers = self.build_headers(session_id)?;
let response = self
.client
.post(&self.config.url)
.headers(headers)
.json(&body)
.send()
.await
.context("HTTP request failed")?;
let status = response.status();
let content_type = response
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
bail!("Server returned HTTP {status}: {text}");
}
if content_type.contains("text/event-stream") {
self.parse_sse_response(response).await
} else {
let text = response.text().await.context("Failed to read response body")?;
serde_json::from_str(&text).context("Failed to parse JSON response")
}
}
async fn parse_sse_response(&self, response: reqwest::Response) -> Result<Value> {
let mut stream = response.bytes_stream();
let mut buffer = String::new();
let mut last_result: Option<Value> = None;
while let Some(chunk) = stream.next().await {
let chunk: Bytes = chunk.context("SSE stream error")?;
buffer.push_str(&String::from_utf8_lossy(&chunk));
while let Some(event_end) = buffer.find("\n\n") {
let event_str = buffer[..event_end].to_string();
buffer = buffer[event_end + 2..].to_string();
let event = parse_sse_event(&event_str);
if let Some(data) = event.data {
if data == "[DONE]" {
break;
}
if let Ok(parsed) = serde_json::from_str::<Value>(&data) {
last_result = Some(parsed);
}
}
}
}
if !buffer.is_empty() {
let event = parse_sse_event(&buffer);
if let Some(data) = event.data {
if data != "[DONE]" {
if let Ok(parsed) = serde_json::from_str::<Value>(&data) {
last_result = Some(parsed);
}
}
}
}
last_result.ok_or_else(|| anyhow!("No data received from SSE stream"))
}
pub async fn initialize(&mut self) -> Result<()> {
let body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {
"name": "dwiki",
"version": env!("CARGO_PKG_VERSION")
}
}
});
let headers = self.build_headers(None)?;
let response = self
.client
.post(&self.config.url)
.headers(headers)
.json(&body)
.send()
.await
.context("Initialize request failed")?;
if let Some(sid) = response.headers().get("mcp-session-id") {
if let Ok(sid_str) = sid.to_str() {
self.session_id = Some(sid_str.to_string());
}
}
let status = response.status();
let _body = response.text().await.ok();
if !status.is_success() {
}
Ok(())
}
pub async fn call_tool(&self, tool_name: &str, arguments: Value) -> Result<String> {
let body = json!({
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": arguments
}
});
let response = self.send(body, self.session_id.as_deref()).await?;
extract_tool_result(&response)
}
}
struct SseEvent {
data: Option<String>,
}
fn parse_sse_event(raw: &str) -> SseEvent {
let mut data: Option<String> = None;
for line in raw.lines() {
if let Some(value) = line.strip_prefix("data:") {
let value = value.trim();
data = Some(
data.map(|d| format!("{d}\n{value}"))
.unwrap_or_else(|| value.to_string()),
);
}
}
SseEvent { data }
}
fn extract_tool_result(response: &Value) -> Result<String> {
if let Some(err) = response.get("error") {
let code = err.get("code").and_then(|v| v.as_i64()).unwrap_or(-1);
let msg = err
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("unknown error");
bail!("MCP error {code}: {msg}");
}
let result = response
.get("result")
.ok_or_else(|| anyhow!("Response has no 'result' field"))?;
if let Some(content_arr) = result.get("content").and_then(|v| v.as_array()) {
let mut parts: Vec<String> = Vec::new();
for item in content_arr {
if item.get("type").and_then(|v| v.as_str()) == Some("text") {
if let Some(text) = item.get("text").and_then(|v| v.as_str()) {
parts.push(text.to_string());
}
}
}
if !parts.is_empty() {
return Ok(parts.join("\n"));
}
}
Ok(serde_json::to_string_pretty(result)
.unwrap_or_else(|_| result.to_string()))
}
#[allow(dead_code)]
pub async fn connect(config: McpConfig) -> Result<McpClient> {
let mut client = McpClient::new(config)?;
client.initialize().await?;
Ok(client)
}
pub fn print_result(output_format: &OutputFormat, text: &str) {
match output_format {
OutputFormat::Text => println!("{text}"),
OutputFormat::Json => {
if let Ok(v) = serde_json::from_str::<Value>(text) {
println!("{}", serde_json::to_string_pretty(&v).unwrap_or_else(|_| text.to_string()));
} else {
let wrapper = json!({ "result": text });
println!("{}", serde_json::to_string_pretty(&wrapper).unwrap());
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_sse_event_single_data() {
let raw = "data: {\"foo\":\"bar\"}";
let event = parse_sse_event(raw);
assert_eq!(event.data, Some("{\"foo\":\"bar\"}".to_string()));
}
#[test]
fn test_parse_sse_event_multiline_data() {
let raw = "data: line1\ndata: line2";
let event = parse_sse_event(raw);
assert_eq!(event.data, Some("line1\nline2".to_string()));
}
#[test]
fn test_parse_sse_event_done() {
let raw = "data: [DONE]";
let event = parse_sse_event(raw);
assert_eq!(event.data, Some("[DONE]".to_string()));
}
#[test]
fn test_extract_tool_result_text_content() {
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": 2,
"result": {
"content": [
{ "type": "text", "text": "Hello, world!" }
]
}
});
let result = extract_tool_result(&response).unwrap();
assert_eq!(result, "Hello, world!");
}
#[test]
fn test_extract_tool_result_error() {
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": 2,
"error": { "code": -32601, "message": "Method not found" }
});
let result = extract_tool_result(&response);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Method not found"));
}
#[test]
fn test_print_result_json_wraps_plain_text() {
print_result(&OutputFormat::Json, "plain text");
}
}