use std::{collections::HashMap, str::FromStr};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use crate::{
error::{Error, Result},
json_rpc::JsonRPCParameters,
};
#[async_trait(?Send)]
pub trait BakedMcpToolTrait {
type Error;
async fn call(&mut self, params: &McpParams) -> core::result::Result<String, Self::Error>;
}
/// Type tag used by tool schemas and their properties.
///
/// Every variant is (de)serialized as its lowercase name (`"object"`,
/// `"string"`, `"integer"`, `"boolean"`, `"array"`, `"number"`,
/// `"function"`).
// `rename_all = "lowercase"` is only equivalent to per-variant renames while
// every variant is a single word; a multi-word variant would need an explicit
// `#[serde(rename = "...")]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// Serialized as `"object"`.
    Object,
    /// Serialized as `"string"`.
    String,
    /// Serialized as `"integer"`.
    Integer,
    /// Serialized as `"boolean"`.
    Boolean,
    /// Serialized as `"array"`.
    Array,
    /// Serialized as `"number"`.
    Number,
    /// Serialized as `"function"`.
    Function,
}
/// Arguments of a tool call: argument name mapped to its JSON value.
pub type McpArguments = HashMap<String, Value>;
/// Parameters of a tool invocation: the tool's name plus its arguments.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct McpParams {
    /// Name of the tool to invoke; carried under the JSON key `name`.
    #[serde(rename = "name")]
    pub tool_name: String,
    /// Arguments passed to the tool.
    ///
    /// The key is omitted from serialized output when the map is empty.
    /// `default` is the matching rule for input: a payload without an
    /// `arguments` key (including this type's own output for a call with no
    /// arguments) deserializes to an empty map instead of failing with a
    /// "missing field" error.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub arguments: McpArguments,
}
/// Converts tool-call parameters into generic JSON-RPC parameters.
///
/// The conversion goes through JSON text: `McpParams` is serialized and the
/// resulting string is parsed as `JsonRPCParameters`, so the target sees the
/// serialized field names (`name`, and `arguments` only when non-empty).
impl TryFrom<&McpParams> for JsonRPCParameters {
    type Error = Error;

    /// # Errors
    ///
    /// Propagates a failure from either the serialization or the parsing
    /// step, converted into the crate's `Error`.
    fn try_from(mcp_params: &McpParams) -> Result<JsonRPCParameters> {
        let encoded = serde_json::to_string(mcp_params)?;
        Ok(serde_json::from_str(&encoded)?)
    }
}
impl McpParams {
pub fn new<S>(name: S) -> Self
where
S: AsRef<str>,
{
Self {
tool_name: name.as_ref().to_string(),
..Default::default()
}
}
pub fn add_argument<S>(&mut self, name: S, value: Value)
where
S: AsRef<str>,
{
self.arguments.insert(name.as_ref().to_string(), value);
}
pub fn set_argument(&mut self, args: McpArguments) {
self.arguments = args
}
pub fn get_bool<S>(&self, key: S) -> Result<bool>
where
S: AsRef<str>,
{
let v = self
.arguments
.get(key.as_ref())
.ok_or(Error::ParameterNotFound)?
.as_bool()
.ok_or(Error::ParameterInvalidFormat)?;
Ok(v)
}
pub fn get_object<S>(&self, key: S) -> Result<Map<String, Value>>
where
S: AsRef<str>,
{
let v = self
.arguments
.get(key.as_ref())
.ok_or(Error::ParameterNotFound)?
.as_object()
.ok_or(Error::ParameterInvalidFormat)?;
Ok(v.clone())
}
pub fn get_int<S>(&self, key: S) -> Result<i64>
where
S: AsRef<str>,
{
let v = self
.arguments
.get(key.as_ref())
.ok_or(Error::ParameterNotFound)?
.as_i64()
.ok_or(Error::ParameterInvalidFormat)?;
Ok(v)
}
pub fn get_string<S>(&self, key: S) -> Result<&str>
where
S: AsRef<str>,
{
let v = self
.arguments
.get(key.as_ref())
.ok_or(Error::ParameterNotFound)?
.as_str()
.ok_or(Error::ParameterInvalidFormat)?;
Ok(v)
}
pub fn get<S>(&self, key: S) -> Result<&Value>
where
S: AsRef<str>,
{
let v = self.arguments.get(key.as_ref()).ok_or(Error::ParameterNotFound)?;
Ok(v)
}
}
/// Identity conversion, so an API bounded by `AsRef<McpParams>` can be handed
/// an `McpParams` (or a reference to one) directly.
impl AsRef<McpParams> for McpParams {
    fn as_ref(&self) -> &Self {
        self
    }
}
/// One property entry inside a tool schema.
///
/// Every field is optional; a field that is `None` is left out of the
/// serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolProperty {
    /// Type of the property; carried under the JSON key `type`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub property_type: Option<ToolType>,
    /// Free-text description of the property.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Nested schema carried under `items`.
    // NOTE(review): presumably the element schema for array-typed
    // properties — confirm against the producers of this data.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<McpToolSchema>,
    /// List of string values carried under the JSON key `enum`.
    #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
    pub enums: Option<Vec<String>>,
}
/// Schema describing a tool's input (also used for nested `items` schemas).
///
/// Only `type` is mandatory; optional fields that are `None` are left out of
/// the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolSchema {
    /// Type of the described value; carried under the JSON key `type`.
    #[serde(rename = "type")]
    pub schema_type: ToolType,
    /// Property definitions keyed by property name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<HashMap<String, McpToolProperty>>,
    /// List of names carried under `required`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
    /// List of string values carried under the JSON key `enum`.
    #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
    pub enums: Option<Vec<String>>,
}
/// Description of a tool: its name, a human-readable description and an
/// optional input schema.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct McpTool {
    /// Name of the tool.
    pub name: String,
    /// Human-readable description of the tool.
    ///
    /// The MCP `Tool` schema marks `description` as optional, so a
    /// definition without it deserializes to an empty string instead of
    /// failing with a "missing field" error.
    #[serde(default)]
    pub description: String,
    /// Schema of the tool's input.
    ///
    /// The key name is direction-dependent: it is read from `inputSchema`
    /// and written as `input_schema`.
    #[serde(rename(serialize = "input_schema", deserialize = "inputSchema"))]
    pub input_schema: Option<McpToolSchema>,
}
/// Kind of MCP tool source, (de)serialized as `"sse"` or `"baked"`.
///
/// NOTE(review): presumably `Sse` is a server reached over server-sent events
/// and `Baked` is a tool built into this program — confirm with the callers.
// `rename_all = "lowercase"` matches explicit per-variant renames only while
// each variant is a single word.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum McpTypes {
    /// Serialized as `"sse"`.
    Sse,
    /// Serialized as `"baked"`.
    Baked,
}
/// Formats the variant as the lowercase name it is serialized under
/// (`sse` / `baked`).
impl std::fmt::Display for McpTypes {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            McpTypes::Sse => "sse",
            McpTypes::Baked => "baked",
        };
        // Written verbatim: width/alignment flags on the formatter are not applied.
        f.write_str(label)
    }
}
impl FromStr for ToolType {
type Err = Error;
fn from_str(s: &str) -> Result<ToolType> {
match s {
"object" => Ok(ToolType::Object),
"string" => Ok(ToolType::String),
"integer" => Ok(ToolType::Integer),
"boolean" => Ok(ToolType::Boolean),
"array" => Ok(ToolType::Array),
"number" => Ok(ToolType::Number),
"function" => Ok(ToolType::Function),
_ => Err(Error::NotImplemented),
}
}
}