#[cfg(any(feature = "server", feature = "client"))]
use crate::error::{Error, ErrorCode};
use crate::shared;
use crate::types::{Cursor, Icon, PropertyType, request::RequestParamsMeta};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
#[cfg(feature = "server")]
use {
super::helpers::TypeCategory,
crate::json::JsonSchema,
crate::types::{FromRequest, IntoResponse, Page, Request, RequestId, Response},
crate::{
Context,
app::handler::{FromHandlerParams, GenericHandler, Handler, HandlerParams, RequestHandler},
},
futures_util::future::BoxFuture,
std::{future::Future, sync::Arc},
};
#[cfg(all(feature = "server", feature = "tasks"))]
use crate::types::RelatedTaskMetadata;
#[cfg(feature = "tasks")]
use crate::types::TaskMetadata;
#[cfg(feature = "client")]
use jsonschema::validator_for;
pub use call_tool_response::CallToolResponse;
mod call_tool_response;
#[cfg(feature = "server")]
mod from_request;
/// JSON-RPC method and notification names used by the tools feature.
pub mod commands {
    /// Request method that invokes a tool.
    pub const CALL: &str = "tools/call";
    /// Request method that lists the available tools.
    pub const LIST: &str = "tools/list";
    /// Notification method announcing that the tool list has changed.
    pub const LIST_CHANGED: &str = "notifications/tools/list_changed";
}
/// Definition of a tool that a server exposes and a client can call.
///
/// Serialized fields use the wire names given in their `serde` attributes
/// (`description`, `inputSchema`, `outputSchema`, `execution`, `_meta`); optional
/// fields are omitted from JSON when `None`. Fields marked `#[serde(skip)]` are
/// process-local: they are never serialized and are left at `None` when a `Tool`
/// is deserialized.
///
/// `Debug` is implemented by hand and does not print every field.
#[derive(Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Name of the tool. `ListToolsResult::get` matches on this field.
    pub name: String,
    /// Optional human-readable title.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
    /// Optional description, serialized as `description`.
    #[serde(rename = "description", skip_serializing_if = "Option::is_none")]
    pub descr: Option<String>,
    /// Schema of the tool's arguments, serialized as `inputSchema`.
    /// On the server, `Tool::new` seeds it from the handler's parameter types.
    #[serde(rename = "inputSchema")]
    pub input_schema: ToolSchema,
    /// Optional schema of the tool's structured output, serialized as `outputSchema`.
    /// The client-side `Tool::validate` checks responses against it.
    #[serde(rename = "outputSchema", skip_serializing_if = "Option::is_none")]
    pub output_schema: Option<ToolSchema>,
    /// Optional behavior hints.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub annotations: Option<ToolAnnotations>,
    /// Optional icons for the tool.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub icons: Option<Vec<Icon>>,
    /// Task-execution settings, serialized as `execution`.
    #[cfg(feature = "tasks")]
    #[serde(rename = "execution", skip_serializing_if = "Option::is_none")]
    pub exec: Option<ToolExecution>,
    /// Arbitrary metadata, serialized as `_meta`.
    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
    pub meta: Option<Value>,
    /// Roles set via `with_roles`; never serialized.
    /// NOTE(review): presumably enforced by the HTTP server before a call — confirm at the call site.
    #[serde(skip)]
    #[cfg(feature = "http-server")]
    pub(crate) roles: Option<Vec<String>>,
    /// Permissions set via `with_permissions`; never serialized.
    /// NOTE(review): presumably enforced by the HTTP server before a call — confirm at the call site.
    #[serde(skip)]
    #[cfg(feature = "http-server")]
    pub(crate) permissions: Option<Vec<String>>,
    /// Handler invoked by `Tool::call`. `None` (e.g. on a deserialized `Tool`) makes
    /// `call` return an `InternalError`.
    #[serde(skip)]
    #[cfg(feature = "server")]
    handler: Option<RequestHandler<CallToolResponse>>,
}
/// Task-related execution settings of a tool (its `execution` object).
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
#[cfg(feature = "tasks")]
pub struct ToolExecution {
    /// How the tool supports task-based execution, serialized as `taskSupport`;
    /// omitted when `None`.
    #[serde(rename = "taskSupport", skip_serializing_if = "Option::is_none")]
    pub task_support: Option<TaskSupport>,
}
/// Level of task support a tool declares.
///
/// Serialized in lowercase (`"forbidden"`, `"optional"`, `"required"`).
/// `Forbidden` is the `Default`.
#[derive(Default, Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[cfg(feature = "tasks")]
#[serde(rename_all = "lowercase")]
pub enum TaskSupport {
    /// The tool must not be run as a task.
    #[default]
    Forbidden,
    /// The tool may be run as a task.
    Optional,
    /// The tool must be run as a task.
    Required,
}
/// Parameters of a `tools/list` request.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct ListToolsRequestParams {
    /// Pagination cursor to continue listing from; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cursor: Option<Cursor>,
}
/// Result of a `tools/list` request: one page of tools.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct ListToolsResult {
    /// Tools in this page.
    pub tools: Vec<Tool>,
    /// Cursor for the next page, serialized as `nextCursor`; omitted when `None`.
    #[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
    pub next_cursor: Option<Cursor>,
}
/// Parameters of a `tools/call` request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallToolRequestParams {
    /// Name of the tool to call.
    pub name: String,
    /// Call arguments keyed by argument name, serialized as `arguments`.
    ///
    /// The key is omitted entirely when `None`. The MCP schema declares `arguments` as
    /// an optional object, so emitting `"arguments": null` may be rejected by peers that
    /// validate against the schema. A missing key still deserializes to `None`.
    #[serde(rename = "arguments", skip_serializing_if = "Option::is_none")]
    pub args: Option<HashMap<String, Value>>,
    /// Task metadata for running the call as a task; omitted when `None`.
    #[cfg(feature = "tasks")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task: Option<TaskMetadata>,
    /// Request metadata, serialized as `_meta`; omitted when `None`.
    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
    pub meta: Option<RequestParamsMeta>,
}
/// Minimal JSON Schema object describing a tool's input or output.
///
/// Only `type`, `properties` and `required` are modeled. Any other keywords in the
/// input are dropped when deserializing.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolSchema {
    /// Schema `type`; falls back to `PropertyType::default()` when the key is missing.
    #[serde(rename = "type", default)]
    pub r#type: PropertyType,
    /// Properties keyed by property name; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<HashMap<String, SchemaProperty>>,
    /// Names of the required properties; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
}
/// A single property of a `ToolSchema`: its `type` and optional `description`.
///
/// Nested keywords (e.g. `items`, `properties`, `enum`) are not modeled.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SchemaProperty {
    /// Property `type`; falls back to `PropertyType::default()` when missing.
    #[serde(rename = "type", default)]
    pub r#type: PropertyType,
    /// Optional description, serialized as `description`.
    #[serde(rename = "description", skip_serializing_if = "Option::is_none")]
    pub descr: Option<String>,
}
/// Behavior hints for a tool, serialized with the `*Hint` wire names.
///
/// Every field is omitted from JSON when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolAnnotations {
    /// Optional human-readable title.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
    /// `destructiveHint`: whether the tool may perform destructive updates.
    #[serde(rename = "destructiveHint", skip_serializing_if = "Option::is_none")]
    pub destructive: Option<bool>,
    /// `idempotentHint`: whether repeated calls with the same arguments have no extra effect.
    #[serde(rename = "idempotentHint", skip_serializing_if = "Option::is_none")]
    pub idempotent: Option<bool>,
    /// `openWorldHint`: whether the tool interacts with external entities.
    #[serde(rename = "openWorldHint", skip_serializing_if = "Option::is_none")]
    pub open_world: Option<bool>,
    /// `readOnlyHint`: whether the tool leaves its environment unmodified.
    #[serde(rename = "readOnlyHint", skip_serializing_if = "Option::is_none")]
    pub readonly: Option<bool>,
}
#[cfg(feature = "server")]
impl IntoResponse for ListToolsResult {
    /// Serializes the result into a success response for `req_id`.
    /// If serialization fails, the serializer error is converted into an error response.
    #[inline]
    fn into_response(self, req_id: RequestId) -> Response {
        let serialized = serde_json::to_value(self);
        match serialized {
            Err(err) => Response::error(req_id, err.into()),
            Ok(value) => Response::success(req_id, value),
        }
    }
}
#[cfg(feature = "server")]
impl From<Vec<Tool>> for ListToolsResult {
    /// Wraps a complete tool list as a single page with no next-page cursor.
    #[inline]
    fn from(tools: Vec<Tool>) -> Self {
        Self {
            tools,
            ..Default::default()
        }
    }
}
#[cfg(feature = "server")]
impl From<Page<'_, Tool>> for ListToolsResult {
    /// Builds a result from one page of tools.
    /// The page's items are copied and its cursor becomes `next_cursor`.
    #[inline]
    fn from(page: Page<'_, Tool>) -> Self {
        let Page {
            items, next_cursor, ..
        } = page;
        Self {
            tools: items.to_vec(),
            next_cursor,
        }
    }
}
#[cfg(feature = "server")]
impl ListToolsResult {
    /// Creates an empty result: no tools and no next-page cursor.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }
}
#[cfg(feature = "client")]
impl ListToolsResult {
    /// Returns the first tool whose name equals `name`, if any.
    #[inline]
    pub fn get(&self, name: impl AsRef<str>) -> Option<&Tool> {
        let wanted = name.as_ref();
        self.get_by(|tool| tool.name == wanted)
    }

    /// Returns the first tool for which `f` returns `true`, if any.
    /// Tools are visited in list order.
    #[inline]
    pub fn get_by<F>(&self, mut f: F) -> Option<&Tool>
    where
        F: FnMut(&Tool) -> bool,
    {
        for tool in &self.tools {
            if f(tool) {
                return Some(tool);
            }
        }
        None
    }
}
impl Default for ToolSchema {
#[inline]
fn default() -> Self {
Self {
r#type: PropertyType::Object,
properties: Some(HashMap::new()),
required: None,
}
}
}
impl Default for ToolAnnotations {
    /// Returns annotations with no title and these hint values:
    /// `readOnlyHint: false`, `destructiveHint: true`, `idempotentHint: false` and
    /// `openWorldHint: true`.
    ///
    /// NOTE(review): these appear to mirror the defaults documented in the MCP spec for
    /// tool annotations — confirm before changing.
    #[inline]
    fn default() -> Self {
        let (readonly, destructive, idempotent, open_world) = (false, true, false, true);
        Self {
            title: None,
            readonly: Some(readonly),
            destructive: Some(destructive),
            idempotent: Some(idempotent),
            open_world: Some(open_world),
        }
    }
}
#[cfg(feature = "tasks")]
impl From<&str> for TaskSupport {
    /// Converts a lowercase wire name (`"forbidden"`, `"optional"` or `"required"`)
    /// into a [`TaskSupport`].
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message if `value` is any other string.
    /// `From` cannot report failure, and this conversion accepts caller-supplied text
    /// (e.g. through `Tool::with_task_support`). That branch is therefore reachable and
    /// must not be marked `unreachable!()`.
    #[inline]
    fn from(value: &str) -> Self {
        match value {
            "forbidden" => Self::Forbidden,
            "required" => Self::Required,
            "optional" => Self::Optional,
            other => panic!(
                "TaskSupport: unsupported value {other:?}; expected \"forbidden\", \"optional\" or \"required\""
            ),
        }
    }
}
#[cfg(feature = "tasks")]
impl From<String> for TaskSupport {
    /// Converts an owned wire name by delegating to the `From<&str>` conversion.
    #[inline]
    fn from(value: String) -> Self {
        value.as_str().into()
    }
}
#[cfg(feature = "server")]
impl ToolSchema {
    /// Creates an `object` schema with the given properties and no `required` list.
    #[inline]
    pub(crate) fn new(props: Option<HashMap<String, SchemaProperty>>) -> Self {
        Self {
            r#type: PropertyType::Object,
            properties: props,
            required: None,
        }
    }

    /// Parses a schema from a JSON string.
    ///
    /// # Panics
    ///
    /// Panics if `json` cannot be deserialized into a [`ToolSchema`].
    #[inline]
    pub fn from_json_str(json: &str) -> Self {
        serde_json::from_str(json).expect("InputSchema: Incorrect JSON string provided")
    }

    /// Adds (or replaces) the optional property `name` with the given description and type.
    pub fn with_prop<T: Into<PropertyType>>(
        self,
        name: &str,
        descr: &str,
        property_type: T,
    ) -> Self {
        self.add_property_impl(name, descr, property_type.into())
    }

    /// Adds (or replaces) the property `name` and lists it in `required`.
    pub fn with_required<T: Into<PropertyType>>(
        self,
        name: &str,
        descr: &str,
        property_type: T,
    ) -> Self {
        self.add_required_property_impl(name, descr, property_type.into())
    }

    /// Adds the top-level properties of the JSON Schema generated for `T`.
    ///
    /// Properties are mapped as described on [`ToolSchema::from_schema`].
    pub fn with_schema<T: JsonSchema>(self) -> Self {
        let json_schema = schemars::schema_for!(T);
        self.with_schema_impl(json_schema)
    }

    /// Builds a schema from the top-level properties of `json_schema`, starting from
    /// [`ToolSchema::default`].
    ///
    /// Mapping rules:
    /// - A property is marked required only if its name appears in the top-level
    ///   `required` array. This follows JSON Schema semantics: properties are optional
    ///   unless listed.
    /// - The description comes from the property's `description`, falling back to the
    ///   property name.
    /// - The type comes from `type`. When `type` is an array (e.g. `["integer", "null"]`
    ///   for an optional field), the first non-`"null"` entry is used. When it is absent,
    ///   `"string"` is used.
    pub fn from_schema(json_schema: schemars::Schema) -> Self {
        Self::default().with_schema_impl(json_schema)
    }

    #[inline]
    fn with_schema_impl(mut self, json_schema: schemars::Schema) -> Self {
        let required = json_schema.get("required").and_then(Value::as_array);
        let Some(props) = json_schema.get("properties").and_then(Value::as_object) else {
            return self;
        };
        for (field, def) in props {
            // Required means "listed in `required`". A missing `required` array means
            // that no property is required.
            let is_required = required.is_some_and(|names| names.iter().any(|n| n == field));
            let descr = def
                .get("description")
                .and_then(Value::as_str)
                .unwrap_or(field.as_str());
            let type_str = Self::json_type_name(def);
            self = if is_required {
                self.add_required_property_impl(field, descr, type_str.into())
            } else {
                self.add_property_impl(field, descr, type_str.into())
            };
        }
        self
    }

    /// Reduces a property definition's `type` keyword to a single type name.
    ///
    /// A string is used as-is. For an array, the first non-`"null"` entry is taken.
    /// Anything else, including a missing `type` (e.g. a bare `$ref`), falls back to
    /// `"string"`.
    fn json_type_name(def: &Value) -> &str {
        match def.get("type") {
            Some(Value::String(name)) => name.as_str(),
            Some(Value::Array(names)) => names
                .iter()
                .filter_map(Value::as_str)
                .find(|name| *name != "null")
                .unwrap_or("string"),
            _ => "string",
        }
    }

    /// Inserts or replaces property `name` with the given description and type.
    #[inline]
    fn add_property_impl(mut self, name: &str, descr: &str, property_type: PropertyType) -> Self {
        self.properties.get_or_insert_with(HashMap::new).insert(
            name.into(),
            SchemaProperty {
                r#type: property_type,
                descr: Some(descr.into()),
            },
        );
        self
    }

    /// Inserts or replaces property `name` and records it in `required` exactly once.
    #[inline]
    fn add_required_property_impl(
        mut self,
        name: &str,
        descr: &str,
        property_type: PropertyType,
    ) -> Self {
        self = self.add_property_impl(name, descr, property_type);
        let required = self.required.get_or_insert_with(Vec::new);
        // JSON Schema requires the entries of `required` to be unique.
        if !required.iter().any(|existing| existing == name) {
            required.push(name.into());
        }
        self
    }
}
#[cfg(feature = "server")]
impl SchemaProperty {
    /// Creates an undescribed property whose type is `T`'s `TypeCategory::category()`.
    #[inline]
    pub(crate) fn new<T: TypeCategory>() -> Self {
        let category = T::category();
        Self {
            descr: None,
            r#type: category,
        }
    }
}
#[cfg(feature = "server")]
impl FromHandlerParams for CallToolRequestParams {
    /// Extracts the underlying `Request` from the handler parameters.
    /// The request is then decoded into `tools/call` parameters.
    #[inline]
    fn from_params(params: &HandlerParams) -> Result<Self, Error> {
        Self::from_request(Request::from_params(params)?)
    }
}
#[cfg(feature = "server")]
impl FromHandlerParams for ListToolsRequestParams {
    /// Extracts the underlying `Request` from the handler parameters.
    /// The request is then decoded into `tools/list` parameters.
    #[inline]
    fn from_params(params: &HandlerParams) -> Result<Self, Error> {
        Self::from_request(Request::from_params(params)?)
    }
}
/// A handler that can back a [`Tool`].
///
/// `args` reports the properties to advertise in the tool's input schema; `Tool::new`
/// passes its result to `ToolSchema::new`. The default advertises no properties.
/// The `impl_generic_tool_handler!` macro in this module implements the trait for
/// closures of up to five `TypeCategory` parameters.
#[cfg(feature = "server")]
pub trait ToolHandler<Args>: GenericHandler<Args> {
    /// Properties derived from the handler's parameters, or `None` if there are none.
    #[inline]
    fn args() -> Option<HashMap<String, SchemaProperty>> {
        None
    }
}
/// Adapter that exposes a [`ToolHandler`] as a `Handler<CallToolResponse>`.
///
/// It converts incoming `CallToolRequestParams` into `Args` and converts the
/// handler's output into a `CallToolResponse`.
#[cfg(feature = "server")]
pub(crate) struct ToolFunc<F, R, Args>
where
    F: ToolHandler<Args, Output = R>,
    R: Into<CallToolResponse>,
    Args: TryFrom<CallToolRequestParams, Error = Error>,
{
    /// The wrapped user handler.
    func: F,
    // Ties the `Args` type parameter to the struct without storing a value.
    _marker: std::marker::PhantomData<Args>,
}
#[cfg(feature = "server")]
impl<F, R, Args> ToolFunc<F, R, Args>
where
    F: ToolHandler<Args, Output = R>,
    R: Into<CallToolResponse>,
    Args: TryFrom<CallToolRequestParams, Error = Error>,
{
    /// Wraps `func` in a shareable, reference-counted adapter.
    pub(crate) fn new(func: F) -> Arc<Self> {
        Arc::new(Self {
            _marker: std::marker::PhantomData,
            func,
        })
    }
}
#[cfg(feature = "server")]
impl<F, R, Args> Handler<CallToolResponse> for ToolFunc<F, R, Args>
where
    F: ToolHandler<Args, Output = R>,
    R: Into<CallToolResponse>,
    Args: TryFrom<CallToolRequestParams, Error = Error> + Send + Sync,
{
    /// Converts the tool-call parameters into `Args`, then runs the handler.
    /// The handler's output is converted into a `CallToolResponse`.
    ///
    /// Panics (`unreachable!`) if `params` is not `HandlerParams::Tool`.
    #[inline]
    fn call(&self, params: HandlerParams) -> BoxFuture<'_, Result<CallToolResponse, Error>> {
        let tool_params = match params {
            HandlerParams::Tool(inner) => inner,
            _ => unreachable!(),
        };
        Box::pin(async move {
            let args = Args::try_from(tool_params)?;
            let output = self.func.call(args).await;
            Ok(output.into())
        })
    }
}
impl CallToolRequestParams {
    /// Creates parameters for calling tool `name`, with no arguments, metadata or task.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            #[cfg(feature = "tasks")]
            task: None,
            args: None,
            meta: None,
        }
    }

    /// Replaces the arguments with the result of `args.into_args()`.
    pub fn with_args<Args: shared::IntoArgs>(self, args: Args) -> Self {
        Self {
            args: args.into_args(),
            ..self
        }
    }

    /// Replaces the request metadata.
    pub fn with_meta(self, meta: RequestParamsMeta) -> Self {
        Self {
            meta: Some(meta),
            ..self
        }
    }

    /// Requests task execution with the given `ttl`, replacing any previous task metadata.
    #[cfg(feature = "tasks")]
    pub fn with_ttl(self, ttl: Option<usize>) -> Self {
        Self {
            task: Some(TaskMetadata { ttl }),
            ..self
        }
    }
}
#[cfg(feature = "server")]
impl CallToolRequestParams {
    /// Stores `ctx` in the request metadata, creating default metadata if there is none.
    pub(crate) fn with_context(mut self, ctx: Context) -> Self {
        let meta = self.meta.get_or_insert_default();
        meta.context = Some(ctx);
        self
    }

    /// Records the related task id in the request metadata, creating default
    /// metadata if there is none.
    #[cfg(feature = "tasks")]
    pub(crate) fn with_task(mut self, task_id: impl Into<String>) -> Self {
        let related = RelatedTaskMetadata { id: task_id.into() };
        let meta = self.meta.get_or_insert_default();
        meta.task = Some(related);
        self
    }
}
impl Debug for Tool {
    /// Formats the tool's descriptive fields.
    /// Icons, execution settings, access-control lists and the handler are not printed.
    #[inline]
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Tool");
        out.field("name", &self.name);
        out.field("title", &self.title);
        out.field("descr", &self.descr);
        out.field("input_schema", &self.input_schema);
        out.field("output_schema", &self.output_schema);
        out.field("annotations", &self.annotations);
        out.field("meta", &self.meta);
        out.finish()
    }
}
#[cfg(feature = "server")]
impl Tool {
    /// Creates a tool called `name` that dispatches calls to `handler`.
    ///
    /// The input schema is derived from [`ToolHandler::args`]. Every other optional
    /// field (title, description, output schema, annotations, icons, metadata) starts
    /// out unset.
    pub fn new<F, Args, R>(name: impl Into<String>, handler: F) -> Self
    where
        F: ToolHandler<Args, Output = R>,
        R: Into<CallToolResponse> + Send + 'static,
        Args: TryFrom<CallToolRequestParams, Error = Error> + Send + Sync + 'static,
    {
        let dispatcher = ToolFunc::new(handler);
        let derived_args = F::args();
        Self {
            name: name.into(),
            title: None,
            descr: None,
            input_schema: ToolSchema::new(derived_args),
            output_schema: None,
            annotations: None,
            icons: None,
            #[cfg(feature = "tasks")]
            exec: None,
            meta: None,
            #[cfg(feature = "http-server")]
            roles: None,
            #[cfg(feature = "http-server")]
            permissions: None,
            handler: Some(dispatcher),
        }
    }

    /// Sets the human-readable title.
    pub fn with_title(&mut self, title: impl Into<String>) -> &mut Self {
        let title = title.into();
        self.title = Some(title);
        self
    }

    /// Sets the description.
    pub fn with_description(&mut self, description: &str) -> &mut Self {
        self.descr = Some(String::from(description));
        self
    }

    /// Replaces the input schema with `config` applied to [`ToolSchema::default`].
    ///
    /// This discards the schema derived from the handler in [`Tool::new`].
    pub fn with_input_schema<F>(&mut self, config: F) -> &mut Self
    where
        F: FnOnce(ToolSchema) -> ToolSchema,
    {
        let schema = config(ToolSchema::default());
        self.input_schema = schema;
        self
    }

    /// Sets the output schema to `config` applied to [`ToolSchema::default`].
    pub fn with_output_schema<F>(&mut self, config: F) -> &mut Self
    where
        F: FnOnce(ToolSchema) -> ToolSchema,
    {
        let schema = config(ToolSchema::default());
        self.output_schema = Some(schema);
        self
    }

    /// Replaces the roles attached to this tool.
    #[cfg(feature = "http-server")]
    pub fn with_roles<T, I>(&mut self, roles: T) -> &mut Self
    where
        T: IntoIterator<Item = I>,
        I: Into<String>,
    {
        let roles: Vec<String> = roles.into_iter().map(|role| role.into()).collect();
        self.roles = Some(roles);
        self
    }

    /// Replaces the permissions attached to this tool.
    #[cfg(feature = "http-server")]
    pub fn with_permissions<T, I>(&mut self, permissions: T) -> &mut Self
    where
        T: IntoIterator<Item = I>,
        I: Into<String>,
    {
        let permissions: Vec<String> = permissions
            .into_iter()
            .map(|permission| permission.into())
            .collect();
        self.permissions = Some(permissions);
        self
    }

    /// Sets the annotations to `config` applied to [`ToolAnnotations::default`].
    pub fn with_annotations<F>(&mut self, config: F) -> &mut Self
    where
        F: FnOnce(ToolAnnotations) -> ToolAnnotations,
    {
        let annotations = config(ToolAnnotations::default());
        self.annotations = Some(annotations);
        self
    }

    /// Replaces the tool's icons.
    pub fn with_icons(&mut self, icons: impl IntoIterator<Item = Icon>) -> &mut Self {
        let icons: Vec<Icon> = icons.into_iter().collect();
        self.icons = Some(icons);
        self
    }

    /// Declares the tool's level of task support.
    #[cfg(feature = "tasks")]
    pub fn with_task_support(&mut self, support: impl Into<TaskSupport>) -> &mut Self {
        let support: TaskSupport = support.into();
        self.exec = Some(ToolExecution::new(support));
        self
    }

    /// Invokes the tool's handler with `params`.
    ///
    /// Returns an `InternalError` if the tool has no handler.
    #[inline]
    pub(crate) async fn call(&self, params: HandlerParams) -> Result<CallToolResponse, Error> {
        let Some(handler) = self.handler.as_ref() else {
            return Err(Error::new(
                ErrorCode::InternalError,
                "Tool handler not specified",
            ));
        };
        handler.call(params).await
    }
}
#[cfg(feature = "client")]
impl Tool {
    /// Validates the structured content of `resp` against this tool's output schema.
    ///
    /// Returns `resp` unchanged on success so the call can be chained.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - The tool has no output schema (`ErrorCode::ParseError`).
    /// - The schema cannot be serialized.
    /// - The schema cannot be compiled into a validator (`ErrorCode::ParseError`).
    /// - `resp.struct_content()` fails.
    /// - The structured content violates the schema (`ErrorCode::ParseError`).
    pub fn validate<'a>(&self, resp: &'a CallToolResponse) -> Result<&'a CallToolResponse, Error> {
        let output_schema = self.output_schema.as_ref().ok_or_else(|| {
            Error::new(ErrorCode::ParseError, "Tool: Output schema not specified")
        })?;
        // Serialize through the reference; cloning the whole schema first is unnecessary.
        let schema = serde_json::to_value(output_schema).map_err(Into::<Error>::into)?;
        let validator =
            validator_for(&schema).map_err(|err| Error::new(ErrorCode::ParseError, err))?;
        let content = resp.struct_content()?;
        validator
            .validate(content)
            .map(|_| resp)
            .map_err(|err| Error::new(ErrorCode::ParseError, err.to_string()))
    }
}
#[cfg(feature = "tasks")]
impl Tool {
    /// Returns the declared task support, or `None` if there are no execution
    /// settings or they leave it unset.
    #[inline]
    pub fn task_support(&self) -> Option<TaskSupport> {
        // `ToolExecution` is `Copy`, so the option can be copied out of `&self`.
        self.exec.and_then(|exec| exec.task_support)
    }
}
#[cfg(feature = "server")]
impl ToolAnnotations {
    /// Creates annotations with the [`Default`] hint values.
    #[inline]
    pub fn new() -> Self {
        Default::default()
    }

    /// Parses annotations from a JSON string.
    ///
    /// # Panics
    ///
    /// Panics if `json` cannot be deserialized into [`ToolAnnotations`].
    #[inline]
    pub fn from_json_str(json: &str) -> Self {
        serde_json::from_str(json).expect("ToolAnnotations: Incorrect JSON string provided")
    }

    /// Sets the human-readable title.
    #[inline]
    pub fn with_title(mut self, title: &str) -> Self {
        self.title = Some(title.into());
        self
    }

    /// Sets `destructiveHint` and forces `readOnlyHint` to `false`.
    #[inline]
    pub fn with_destructive(mut self, destructive: bool) -> Self {
        self.destructive = Some(destructive);
        self.readonly = Some(false);
        self
    }

    /// Sets `idempotentHint` and forces `readOnlyHint` to `false`.
    #[inline]
    pub fn with_idempotent(mut self, idempotent: bool) -> Self {
        self.idempotent = Some(idempotent);
        self.readonly = Some(false);
        self
    }

    /// Sets `openWorldHint`.
    #[inline]
    pub fn with_open_world(mut self, open_world: bool) -> Self {
        self.open_world = Some(open_world);
        self
    }

    /// Sets `readOnlyHint`.
    ///
    /// [`Self::with_destructive`] and [`Self::with_idempotent`] reset this hint to
    /// `false`, so call this method after them when combining builders.
    #[inline]
    pub fn with_readonly(mut self, readonly: bool) -> Self {
        self.readonly = Some(readonly);
        self
    }
}
#[cfg(all(feature = "server", feature = "tasks"))]
impl ToolExecution {
    /// Creates execution settings that declare the given task support.
    #[inline]
    pub fn new(support: TaskSupport) -> Self {
        let task_support = Some(support);
        Self { task_support }
    }
}
// Implements `ToolHandler` for functions and closures taking the listed
// `TypeCategory` parameters. `args()` builds one schema property per parameter
// from `TypeCategory::category()`, skipping `PropertyType::None` categories, and
// returns `None` when no property remains.
macro_rules! impl_generic_tool_handler {
    ($($param:ident)*) => {
        #[cfg(feature = "server")]
        impl<Func, Fut: Send, $($param: TypeCategory,)*> ToolHandler<($($param,)*)> for Func
        where
            Func: Fn($($param),*) -> Fut + Send + Sync + Clone + 'static,
            Fut: Future + 'static,
        {
            // `mut` is unused in the zero-parameter expansion.
            #[inline]
            #[allow(unused_mut)]
            fn args() -> Option<HashMap<String, SchemaProperty>> {
                let mut props = HashMap::new();
                $(
                    let prop = SchemaProperty::new::<$param>();
                    if prop.r#type != PropertyType::None {
                        // NOTE(review): the key is the property *type name*, not a parameter
                        // name, so two parameters of the same category share one entry —
                        // confirm this is intended.
                        props.insert(prop.r#type.to_string(), prop);
                    }
                )*
                (!props.is_empty()).then_some(props)
            }
        }
    };
}

impl_generic_tool_handler! {}
impl_generic_tool_handler! { T1 }
impl_generic_tool_handler! { T1 T2 }
impl_generic_tool_handler! { T1 T2 T3 }
impl_generic_tool_handler! { T1 T2 T3 T4 }
impl_generic_tool_handler! { T1 T2 T3 T4 T5 }
#[cfg(test)]
#[cfg(feature = "server")]
mod tests {
    use super::*;

    /// Builds a `sum` tool from a two-argument closure and checks that calling it
    /// returns one text content item holding the sum.
    #[tokio::test]
    async fn it_creates_and_calls_tool() {
        let tool = Tool::new("sum", |a: i32, b: i32| async move { a + b });
        let args = [("a", 5), ("b", 2)]
            .into_iter()
            .map(|(key, val)| (key.to_string(), serde_json::to_value(val).unwrap()))
            .collect::<HashMap<_, _>>();
        let params = CallToolRequestParams {
            name: "sum".into(),
            args: Some(args),
            #[cfg(feature = "tasks")]
            task: None,
            meta: None,
        };
        let resp = tool.call(params.into()).await.unwrap();
        let json = serde_json::to_string(&resp).unwrap();
        assert_eq!(
            json,
            r#"{"content":[{"type":"text","text":"7"}],"isError":false}"#
        );
    }

    /// A schema without an explicit `type` deserializes as an object schema and keeps
    /// its properties.
    #[test]
    fn it_deserializes_input_schema() {
        let json = r#"{
            "properties": {
                "name": {
                    "type": "string",
                    "description": "A name to whom say hello"
                }
            }
        }"#;
        let schema = serde_json::from_str::<ToolSchema>(json).unwrap();
        assert_eq!(schema.r#type, PropertyType::Object);
        assert!(schema.properties.is_some());
    }
}