use std::collections::HashMap;
use rmcp::{
model::{CallToolRequestParam, CallToolResult, Tool},
service::ServerSink,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use validator::Validate;
use crate::model::{Function, Tools};
/// Converts a single MCP [`Tool`] into the chat API's function-tool representation.
///
/// The tool's name and JSON input schema are carried over unchanged. When the tool has
/// no description, the placeholder `"Remote MCP tool"` is used instead.
#[inline]
pub fn mcp_tool_to_function(t: &Tool) -> Tools {
    let function = Function::new(
        t.name.to_string(),
        t.description
            .as_deref()
            .unwrap_or("Remote MCP tool")
            .to_string(),
        t.schema_as_json_value(),
    );
    Tools::Function { function }
}
/// Converts a slice of MCP tools into chat function tools, preserving their order.
///
/// Each element is converted with [`mcp_tool_to_function`].
#[inline]
pub fn mcp_tools_to_functions(tools: &[Tool]) -> Vec<Tools> {
    let mut converted = Vec::with_capacity(tools.len());
    converted.extend(tools.iter().map(mcp_tool_to_function));
    converted
}
/// Turns an MCP tool result into a single JSON value.
///
/// If the result carries `structured_content`, a clone of it is returned on its own.
/// Otherwise the whole `CallToolResult` is serialized. If that serialization fails, a
/// `{"error": {"type": "serialization_error", ...}}` object is returned instead of an
/// error, so this function never fails.
#[inline]
pub fn call_tool_result_to_json(res: &CallToolResult) -> Value {
    match &res.structured_content {
        Some(structured) => structured.clone(),
        None => serde_json::to_value(res).unwrap_or_else(|_| {
            serde_json::json!({
                "error": {"type": "serialization_error", "message": "failed to serialize tool result"}
            })
        }),
    }
}
/// Serializable description of one MCP tool call: which tool to invoke and with what
/// JSON arguments.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct McpCallSpec {
    /// Name of the tool to call. Validation only requires it to be non-empty; a
    /// whitespace-only name still passes this check.
    #[validate(length(min = 1))]
    pub name: String,
    /// Optional tool arguments; the field is omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<Value>,
}
impl McpCallSpec {
    /// Creates a call spec from anything convertible into a tool name plus optional
    /// arguments. No validation is performed here.
    pub fn new(name: impl Into<String>, arguments: Option<Value>) -> Self {
        let name = name.into();
        Self { name, arguments }
    }
}
pub async fn call_mcp_tool(
server: &ServerSink,
name: impl Into<String>,
args: Option<Value>,
) -> crate::ZaiResult<(String, Value)> {
let name: String = name.into();
if name.trim().is_empty() {
return Err(crate::client::error::ZaiError::Unknown {
code: 0,
message: "tool name cannot be empty".to_string(),
});
}
let arguments = match args {
Some(Value::Object(map)) => Some(map),
Some(other) => {
let val = serde_json::json!({
"error": {"type": "invalid_arguments", "message": "arguments must be a JSON object", "got": other}
});
return Ok((name.clone(), val));
},
None => None,
};
let res = server
.call_tool(CallToolRequestParam {
name: name.clone().into(),
arguments,
})
.await
.map_err(|e| crate::client::error::ZaiError::Unknown {
code: 0,
message: format!("RMCP service error: {}", e),
})?;
Ok((name, call_tool_result_to_json(&res)))
}
/// Runs several MCP tool calls concurrently and gathers their results keyed by tool name.
///
/// All calls are polled together through a `FuturesUnordered`, so results arrive in
/// completion order. Results are keyed by tool name only: if the same name appears more
/// than once, only one result survives, and which one depends on completion order.
///
/// # Errors
///
/// Returns the first error produced by any call (see [`call_mcp_tool`]). The remaining
/// in-flight calls are then dropped without being awaited.
pub async fn call_mcp_tools_collect<I>(
    server: &ServerSink,
    calls: I,
) -> crate::ZaiResult<HashMap<String, Value>>
where
    I: IntoIterator<Item = (String, Option<Value>)>,
{
    use futures::stream::{FuturesUnordered, TryStreamExt};

    calls
        .into_iter()
        .map(|(name, args)| call_mcp_tool(server, name, args))
        .collect::<FuturesUnordered<_>>()
        .try_collect::<HashMap<_, _>>()
        .await
}
/// Cloneable handle that owns an RMCP `ServerSink`, so tool calls can be made without
/// passing the sink around explicitly.
#[derive(Clone)]
pub struct McpToolCaller {
    server: ServerSink,
}

impl McpToolCaller {
    /// Wraps an existing RMCP server sink.
    pub fn new(server: ServerSink) -> Self {
        McpToolCaller { server }
    }

    /// Calls a single tool; behaves exactly like [`call_mcp_tool`] on the wrapped sink.
    pub async fn call(
        &self,
        name: impl Into<String>,
        args: Option<Value>,
    ) -> crate::ZaiResult<(String, Value)> {
        let server = &self.server;
        call_mcp_tool(server, name, args).await
    }

    /// Calls several tools concurrently; behaves exactly like [`call_mcp_tools_collect`]
    /// on the wrapped sink (results keyed by tool name).
    pub async fn call_collect<I>(&self, calls: I) -> crate::ZaiResult<HashMap<String, Value>>
    where
        I: IntoIterator<Item = (String, Option<Value>)>,
    {
        let server = &self.server;
        call_mcp_tools_collect(server, calls).await
    }
}
/// Executes the tool calls requested by the first choice of `resp` and converts each
/// result into a tool message addressed to the originating tool-call id.
///
/// Calls run one after another in the order the model listed them, and the returned
/// messages keep that order. An empty vector is returned when the response has no choices
/// or the first choice carries no tool calls.
///
/// Malformed entries are handled leniently:
/// - A call with no id, no function payload, or a missing or blank function name is
///   skipped with a warning.
/// - Arguments that are absent or blank are forwarded as `None`.
/// - Arguments that are not valid JSON, or not a JSON object, are forwarded as `None`
///   with a warning.
///
/// # Errors
///
/// Returns `ZaiError::Unknown` ("RMCP call_tool failed: …") as soon as one tool invocation
/// fails. Later calls are not attempted and results gathered so far are discarded.
#[cfg(feature = "rmcp-kits")]
pub async fn execute_tool_calls_as_messages(
    caller: &McpToolCaller,
    resp: &crate::model::chat_base_response::ChatCompletionResponse,
) -> crate::ZaiResult<Vec<crate::model::chat_message_types::TextMessage>> {
    use crate::model::{chat_base_response::ToolCallMessage, chat_message_types::TextMessage};
    let calls: Option<&[ToolCallMessage]> = resp
        .choices()
        .and_then(|v| v.first())
        .and_then(|c| c.message().tool_calls());
    let Some(calls) = calls else { return Ok(Vec::new()) };
    tracing::info!("AI requested tool calls: {}", calls.len());
    let mut out: Vec<TextMessage> = Vec::with_capacity(calls.len());
    for tc in calls {
        let Some(id) = tc.id() else {
            tracing::warn!("Tool call without id, skipping");
            continue;
        };
        let id = id.to_string();
        // NOTE(review): a call skipped below still has an id but gets no tool message;
        // confirm the upstream chat API tolerates an unanswered tool-call id.
        let Some(func) = tc.function() else {
            tracing::warn!("Tool call missing function payload, skipping");
            continue;
        };
        let name = match func.name() {
            Some(n) if !n.trim().is_empty() => n.to_string(),
            // Treat a blank name like a missing one. Forwarding it would make
            // `call_mcp_tool` return an error and abort every remaining call.
            Some(_) => {
                tracing::warn!("Tool call has a blank function name, skipping");
                continue;
            },
            None => {
                tracing::warn!("Tool call missing function name, skipping");
                continue;
            },
        };
        let args_value: Option<Value> = match func.arguments() {
            None => None,
            // An empty or whitespace-only argument string means "no arguments". It is not
            // reported as a JSON parse failure.
            Some(raw) if raw.trim().is_empty() => None,
            Some(raw) => match serde_json::from_str::<Value>(raw) {
                Ok(obj @ Value::Object(_)) => Some(obj),
                Ok(_) => {
                    tracing::warn!("Function arguments are not an object; passing None");
                    None
                },
                Err(e) => {
                    tracing::warn!("Failed to parse function arguments JSON: {}", e);
                    None
                },
            },
        };
        let (_tool, payload): (String, Value) =
            caller.call(name, args_value).await.map_err(|e| {
                crate::client::error::ZaiError::Unknown {
                    code: 0,
                    message: format!("RMCP call_tool failed: {}", e),
                }
            })?;
        out.push(TextMessage::tool_with_id(payload.to_string(), id));
    }
    Ok(out)
}
/// Performs a single tool-calling round trip against the chat endpoint.
///
/// 1. Sends `chat` and passes the response to [`execute_tool_calls_as_messages`].
/// 2. If that yields no tool messages, the first response is returned unchanged.
/// 3. Otherwise, the following changes are made to `chat`:
///    - every tool message is appended;
///    - the request's `tools` list is cleared;
///    - `system_hint_after_tools` (when given) is appended as a system message.
/// 4. The updated chat is sent again, and that second response is returned.
///
/// NOTE(review): the assistant message that carried the tool calls is not appended before
/// the tool results; confirm the upstream chat API accepts tool messages without it.
///
/// # Errors
///
/// Propagates errors from either `send` call and from [`execute_tool_calls_as_messages`].
#[cfg(feature = "rmcp-kits")]
pub async fn run_mcp_tool_roundtrip<N>(
    caller: &McpToolCaller,
    mut chat: crate::model::chat::data::ChatCompletion<
        N,
        crate::model::chat_message_types::TextMessage,
        crate::model::traits::StreamOff,
    >,
    system_hint_after_tools: Option<&str>,
) -> crate::ZaiResult<crate::model::chat_base_response::ChatCompletionResponse>
where
    N: crate::model::traits::ModelName + crate::model::traits::Chat + serde::Serialize,
    (N, crate::model::chat_message_types::TextMessage): crate::model::traits::Bounded,
{
    use crate::model::chat_message_types::TextMessage;

    let first_resp = chat.send().await?;
    // A pretty-printed dump of the full response is large and contains conversation
    // content, so it is logged at debug level rather than info.
    tracing::debug!("AI response: {:#?}", first_resp);

    let tool_msgs: Vec<TextMessage> = execute_tool_calls_as_messages(caller, &first_resp).await?;
    if tool_msgs.is_empty() {
        return Ok(first_resp);
    }
    for m in tool_msgs {
        chat = chat.add_messages(m);
    }
    // Clear the tool list for the follow-up request (presumably so the model answers
    // from the tool results instead of requesting more calls).
    chat.body_mut().tools = None;
    if let Some(hint) = system_hint_after_tools {
        chat = chat.add_messages(TextMessage::system(hint));
    }
    let final_resp = chat.send().await?;
    Ok(final_resp)
}
/// Returns the text of the first choice's message in `resp`, if there is any.
///
/// Plain string content is returned as-is. For array ("content parts") content, the
/// result is the `text` of the first part that meets both conditions:
/// - it is an object whose `"type"` is `"text"`;
/// - its `"text"` field is a string.
///
/// Later text parts are ignored. Returns `None` when there are no choices, no content,
/// or no such part.
#[cfg(feature = "rmcp-kits")]
pub fn extract_final_text(
    resp: &crate::model::chat_base_response::ChatCompletionResponse,
) -> Option<String> {
    let message = resp.choices()?.first()?.message();
    match message.content() {
        Some(serde_json::Value::String(text)) => Some(text.clone()),
        Some(serde_json::Value::Array(parts)) => parts
            .iter()
            .filter_map(serde_json::Value::as_object)
            .filter(|part| part.get("type").and_then(serde_json::Value::as_str) == Some("text"))
            .find_map(|part| part.get("text").and_then(serde_json::Value::as_str))
            .map(str::to_owned),
        _ => None,
    }
}