use std::future::Future;
use std::marker::PhantomData;
use std::ops::Deref;
use std::pin::Pin;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use serde_json::Value;
use crate::context::RequestContext;
use crate::error::{Error, Result};
use crate::protocol::CallToolResult;
/// A plain, message-only rejection returned by extractors that can fail for
/// reasons not covered by a more specific rejection type.
///
/// It converts into the crate [`Error`] as a tool error carrying the message.
#[derive(Debug, Clone)]
pub struct Rejection {
    message: String,
}

impl Rejection {
    /// Creates a rejection from any string-like message.
    pub fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self { message }
    }

    /// Returns the human-readable rejection message.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl std::fmt::Display for Rejection {
    /// Writes the message verbatim, with no prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for Rejection {}
impl From<Rejection> for Error {
fn from(rejection: Rejection) -> Self {
Error::tool(rejection.message)
}
}
/// Rejection produced when the tool arguments cannot be deserialized into the
/// type requested by the `Json` extractor.
///
/// Besides the message it may record the `path` of the offending value (for
/// example `users[0].name`), which is then included in the `Display` output.
#[derive(Debug, Clone)]
pub struct JsonRejection {
    message: String,
    path: Option<String>,
}

impl JsonRejection {
    /// Creates a rejection with a message and no location.
    pub fn new(message: impl Into<String>) -> Self {
        Self::from_parts(message.into(), None)
    }

    /// Creates a rejection with a message and the path of the invalid value.
    pub fn with_path(message: impl Into<String>, path: impl Into<String>) -> Self {
        let message = message.into();
        let path = path.into();
        Self::from_parts(message, Some(path))
    }

    /// Returns the underlying error message (without the `Invalid input` prefix).
    pub fn message(&self) -> &str {
        self.message.as_str()
    }

    /// Returns the path of the invalid value, if one was recorded.
    pub fn path(&self) -> Option<&str> {
        self.path.as_deref()
    }

    /// Shared constructor used by [`JsonRejection::new`] and
    /// [`JsonRejection::with_path`].
    fn from_parts(message: String, path: Option<String>) -> Self {
        Self { message, path }
    }
}

impl std::fmt::Display for JsonRejection {
    /// Formats as ``Invalid input at `<path>`: <message>`` when a path is known,
    /// otherwise as `Invalid input: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.path.as_deref() {
            Some(path) => write!(f, "Invalid input at `{}`: {}", path, self.message),
            None => write!(f, "Invalid input: {}", self.message),
        }
    }
}

impl std::error::Error for JsonRejection {}
impl From<JsonRejection> for Error {
fn from(rejection: JsonRejection) -> Self {
Error::tool(rejection.to_string())
}
}
/// Builds a [`JsonRejection`] from a `serde_json` deserialization error.
///
/// The rejection message is the error's `Display` text. `serde_json::Error`
/// exposes only a line/column position and an error category, not the path
/// of the failing field. For that reason no `path` is recorded for any error
/// category. Missing-field errors still name the field inside the message
/// itself.
impl From<serde_json::Error> for JsonRejection {
    fn from(err: serde_json::Error) -> Self {
        Self {
            message: err.to_string(),
            path: None,
        }
    }
}
/// Rejection produced by the `Extension` extractor when the request context
/// does not contain a value of the requested type.
#[derive(Debug, Clone)]
pub struct ExtensionRejection {
    type_name: &'static str,
}

impl ExtensionRejection {
    /// Creates a rejection that names `T` as the missing extension type.
    pub fn not_found<T>() -> Self {
        let type_name = std::any::type_name::<T>();
        Self { type_name }
    }

    /// Returns the name of the missing type as reported by
    /// `std::any::type_name`.
    pub fn type_name(&self) -> &'static str {
        let Self { type_name } = self;
        *type_name
    }
}

impl std::fmt::Display for ExtensionRejection {
    /// Explains which type is missing and hints at how to register it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let missing = self.type_name;
        write!(
            f,
            "Extension of type `{}` not found. Did you call `router.with_state()` or `router.with_extension()`?",
            missing
        )
    }
}

impl std::error::Error for ExtensionRejection {}
impl From<ExtensionRejection> for Error {
fn from(rejection: ExtensionRejection) -> Self {
Error::tool(rejection.to_string())
}
}
/// Builds one handler argument from an incoming tool call.
///
/// Every extractor is handed the same three inputs: the per-request context,
/// the shared state `S`, and the raw JSON arguments. It either produces `Self`
/// or fails with its associated `Rejection`, which must convert into the crate
/// [`Error`].
pub trait FromToolRequest<S = ()>: Sized {
    /// Error returned when extraction fails.
    type Rejection: Into<Error>;

    /// Attempts to extract `Self` from the request.
    fn from_tool_request(
        ctx: &RequestContext,
        state: &S,
        args: &Value,
    ) -> std::result::Result<Self, Self::Rejection>;
}
/// Extractor that deserializes the tool arguments into `T`.
///
/// Through its `HasSchema` impl it also supplies `T`'s JSON schema as the
/// tool's input schema.
#[derive(Debug, Clone, Copy)]
pub struct Json<T>(pub T);

impl<T> Deref for Json<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Json(inner) = self;
        inner
    }
}
/// Deserializes the whole argument object into `T`.
///
/// On failure the `serde_json` error is wrapped in a [`JsonRejection`].
impl<S, T> FromToolRequest<S> for Json<T>
where
    T: DeserializeOwned,
{
    type Rejection = JsonRejection;

    fn from_tool_request(
        _ctx: &RequestContext,
        _state: &S,
        args: &Value,
    ) -> std::result::Result<Self, Self::Rejection> {
        // `&Value` implements `serde::Deserializer`, so `T` can be read straight
        // from the borrowed arguments. `serde_json::from_value` would need an
        // owned `Value` and therefore a deep clone of the whole JSON tree.
        T::deserialize(args).map(Json).map_err(JsonRejection::from)
    }
}
/// Extractor that yields a clone of the router's shared state.
#[derive(Debug, Clone, Copy)]
pub struct State<T>(pub T);

impl<T> Deref for State<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let State(inner) = self;
        inner
    }
}
/// Hands the handler a clone of the shared state `S`.
///
/// This never fails; `Rejection` is named only to satisfy the trait.
impl<S: Clone> FromToolRequest<S> for State<S> {
    type Rejection = Rejection;

    fn from_tool_request(
        _ctx: &RequestContext,
        state: &S,
        _args: &Value,
    ) -> std::result::Result<Self, Self::Rejection> {
        let snapshot = S::clone(state);
        Ok(Self(snapshot))
    }
}
/// Extractor that gives the handler its own copy of the [`RequestContext`].
#[derive(Debug, Clone)]
pub struct Context(RequestContext);

impl Context {
    /// Unwraps the extractor into the underlying [`RequestContext`].
    pub fn into_inner(self) -> RequestContext {
        let Context(request_ctx) = self;
        request_ctx
    }
}

impl Deref for Context {
    type Target = RequestContext;

    fn deref(&self) -> &RequestContext {
        let Context(request_ctx) = self;
        request_ctx
    }
}
impl<S> FromToolRequest<S> for Context {
type Rejection = Rejection;
fn from_tool_request(
ctx: &RequestContext,
_state: &S,
_args: &Value,
) -> std::result::Result<Self, Self::Rejection> {
Ok(Context(ctx.clone()))
}
}
/// Extractor that yields the tool arguments as an untyped JSON [`Value`].
#[derive(Debug, Clone)]
pub struct RawArgs(pub Value);

impl Deref for RawArgs {
    type Target = Value;

    fn deref(&self) -> &Value {
        let RawArgs(raw) = self;
        raw
    }
}
impl<S> FromToolRequest<S> for RawArgs {
type Rejection = Rejection;
fn from_tool_request(
_ctx: &RequestContext,
_state: &S,
args: &Value,
) -> std::result::Result<Self, Self::Rejection> {
Ok(RawArgs(args.clone()))
}
}
/// Extractor that yields a clone of a typed value stored in the request
/// context's extensions.
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);

impl<T> Deref for Extension<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Extension(inner) = self;
        inner
    }
}
impl<S, T> FromToolRequest<S> for Extension<T>
where
T: Clone + Send + Sync + 'static,
{
type Rejection = ExtensionRejection;
fn from_tool_request(
ctx: &RequestContext,
_state: &S,
_args: &Value,
) -> std::result::Result<Self, Self::Rejection> {
ctx.extension::<T>()
.cloned()
.map(Extension)
.ok_or_else(ExtensionRejection::not_found::<T>)
}
}
/// An async function whose parameters are all [`FromToolRequest`] extractors.
///
/// `T` is the tuple of extractor types the function takes. This file
/// implements the trait for functions of one to five extractors.
pub trait ExtractorHandler<S, T>: Clone + Send + Sync + 'static {
    /// Future returned by `call`.
    type Future: Future<Output = Result<CallToolResult>> + Send;

    /// Extracts the handler's arguments from the request and invokes it.
    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;

    /// JSON schema advertised as the tool's input schema.
    fn input_schema() -> Value;
}
impl<S, F, Fut, T1> ExtractorHandler<S, (T1,)> for F
where
S: Clone + Send + Sync + 'static,
F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<CallToolResult>> + Send,
T1: FromToolRequest<S> + HasSchema + Send,
{
type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
Box::pin(async move {
let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
self(t1).await
})
}
fn input_schema() -> Value {
if let Some(schema) = T1::schema() {
return schema;
}
serde_json::json!({
"type": "object",
"additionalProperties": true
})
}
}
impl<S, F, Fut, T1, T2> ExtractorHandler<S, (T1, T2)> for F
where
S: Clone + Send + Sync + 'static,
F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<CallToolResult>> + Send,
T1: FromToolRequest<S> + HasSchema + Send,
T2: FromToolRequest<S> + HasSchema + Send,
{
type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
Box::pin(async move {
let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
self(t1, t2).await
})
}
fn input_schema() -> Value {
if let Some(schema) = T2::schema() {
return schema;
}
if let Some(schema) = T1::schema() {
return schema;
}
serde_json::json!({
"type": "object",
"additionalProperties": true
})
}
}
impl<S, F, Fut, T1, T2, T3> ExtractorHandler<S, (T1, T2, T3)> for F
where
S: Clone + Send + Sync + 'static,
F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<CallToolResult>> + Send,
T1: FromToolRequest<S> + HasSchema + Send,
T2: FromToolRequest<S> + HasSchema + Send,
T3: FromToolRequest<S> + HasSchema + Send,
{
type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
Box::pin(async move {
let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
self(t1, t2, t3).await
})
}
fn input_schema() -> Value {
if let Some(schema) = T3::schema() {
return schema;
}
if let Some(schema) = T2::schema() {
return schema;
}
if let Some(schema) = T1::schema() {
return schema;
}
serde_json::json!({
"type": "object",
"additionalProperties": true
})
}
}
impl<S, F, Fut, T1, T2, T3, T4> ExtractorHandler<S, (T1, T2, T3, T4)> for F
where
S: Clone + Send + Sync + 'static,
F: Fn(T1, T2, T3, T4) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<CallToolResult>> + Send,
T1: FromToolRequest<S> + HasSchema + Send,
T2: FromToolRequest<S> + HasSchema + Send,
T3: FromToolRequest<S> + HasSchema + Send,
T4: FromToolRequest<S> + HasSchema + Send,
{
type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
Box::pin(async move {
let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
self(t1, t2, t3, t4).await
})
}
fn input_schema() -> Value {
if let Some(schema) = T4::schema() {
return schema;
}
if let Some(schema) = T3::schema() {
return schema;
}
if let Some(schema) = T2::schema() {
return schema;
}
if let Some(schema) = T1::schema() {
return schema;
}
serde_json::json!({
"type": "object",
"additionalProperties": true
})
}
}
impl<S, F, Fut, T1, T2, T3, T4, T5> ExtractorHandler<S, (T1, T2, T3, T4, T5)> for F
where
S: Clone + Send + Sync + 'static,
F: Fn(T1, T2, T3, T4, T5) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<CallToolResult>> + Send,
T1: FromToolRequest<S> + HasSchema + Send,
T2: FromToolRequest<S> + HasSchema + Send,
T3: FromToolRequest<S> + HasSchema + Send,
T4: FromToolRequest<S> + HasSchema + Send,
T5: FromToolRequest<S> + HasSchema + Send,
{
type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
Box::pin(async move {
let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
let t4 = T4::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
let t5 = T5::from_tool_request(&ctx, &state, &args).map_err(Into::into)?;
self(t1, t2, t3, t4, t5).await
})
}
fn input_schema() -> Value {
if let Some(schema) = T5::schema() {
return schema;
}
if let Some(schema) = T4::schema() {
return schema;
}
if let Some(schema) = T3::schema() {
return schema;
}
if let Some(schema) = T2::schema() {
return schema;
}
if let Some(schema) = T1::schema() {
return schema;
}
serde_json::json!({
"type": "object",
"additionalProperties": true
})
}
}
/// Reports the JSON schema, if any, that an extractor contributes to a tool's
/// input schema.
pub trait HasSchema {
    /// Returns `Some(schema)` when this extractor describes the tool input, or
    /// `None` when it contributes nothing. `ExtractorHandler` falls back to an
    /// open object schema when every extractor returns `None`.
    fn schema() -> Option<Value>;
}
/// `Json<T>` advertises `T`'s schemars-generated schema, normalized by
/// `crate::tool::ensure_object_schema`.
impl<T: JsonSchema> HasSchema for Json<T> {
    /// Returns `None` only if the generated schema cannot be converted to JSON.
    fn schema() -> Option<Value> {
        let root = schemars::schema_for!(T);
        match serde_json::to_value(root) {
            Ok(value) => Some(crate::tool::ensure_object_schema(value)),
            Err(_) => None,
        }
    }
}
/// [`Context`] is built from the request context, not the arguments, so it
/// contributes no schema.
impl HasSchema for Context {
    fn schema() -> Option<Value> {
        Option::None
    }
}

/// [`RawArgs`] accepts arbitrary JSON, so it contributes no schema.
impl HasSchema for RawArgs {
    fn schema() -> Option<Value> {
        Option::None
    }
}

/// [`State`] comes from the shared state, not the arguments, so it
/// contributes no schema.
impl<T> HasSchema for State<T> {
    fn schema() -> Option<Value> {
        Option::None
    }
}

/// [`Extension`] comes from the request extensions, not the arguments, so it
/// contributes no schema.
impl<T> HasSchema for Extension<T> {
    fn schema() -> Option<Value> {
        Option::None
    }
}
/// Legacy handler trait. The input schema comes from an explicit type
/// parameter `I` instead of being detected from a `Json` extractor.
#[deprecated(
    since = "0.8.0",
    note = "Use `ExtractorHandler` instead -- `extractor_handler` auto-detects JSON schema from `Json<T>` extractors"
)]
pub trait TypedExtractorHandler<S, T, I>: Clone + Send + Sync + 'static
where
    I: JsonSchema,
{
    /// Future returned by `call`.
    type Future: Future<Output = Result<CallToolResult>> + Send;

    /// Extracts the handler's arguments from the request and invokes it.
    fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future;
}
#[allow(deprecated)]
impl<S, F, Fut, T> TypedExtractorHandler<S, (Json<T>,), T> for F
where
S: Clone + Send + Sync + 'static,
F: Fn(Json<T>) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<CallToolResult>> + Send,
T: DeserializeOwned + JsonSchema + Send,
{
type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
Box::pin(async move {
let t1 =
Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
self(t1).await
})
}
}
#[allow(deprecated)]
impl<S, F, Fut, T1, T> TypedExtractorHandler<S, (T1, Json<T>), T> for F
where
S: Clone + Send + Sync + 'static,
F: Fn(T1, Json<T>) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<CallToolResult>> + Send,
T1: FromToolRequest<S> + Send,
T: DeserializeOwned + JsonSchema + Send,
{
type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
Box::pin(async move {
let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
let t2 =
Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
self(t1, t2).await
})
}
}
#[allow(deprecated)]
impl<S, F, Fut, T1, T2, T> TypedExtractorHandler<S, (T1, T2, Json<T>), T> for F
where
S: Clone + Send + Sync + 'static,
F: Fn(T1, T2, Json<T>) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<CallToolResult>> + Send,
T1: FromToolRequest<S> + Send,
T2: FromToolRequest<S> + Send,
T: DeserializeOwned + JsonSchema + Send,
{
type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
Box::pin(async move {
let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
let t3 =
Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
self(t1, t2, t3).await
})
}
}
#[allow(deprecated)]
impl<S, F, Fut, T1, T2, T3, T> TypedExtractorHandler<S, (T1, T2, T3, Json<T>), T> for F
where
S: Clone + Send + Sync + 'static,
F: Fn(T1, T2, T3, Json<T>) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<CallToolResult>> + Send,
T1: FromToolRequest<S> + Send,
T2: FromToolRequest<S> + Send,
T3: FromToolRequest<S> + Send,
T: DeserializeOwned + JsonSchema + Send,
{
type Future = Pin<Box<dyn Future<Output = Result<CallToolResult>> + Send>>;
fn call(self, ctx: RequestContext, state: S, args: Value) -> Self::Future {
Box::pin(async move {
let t1 = T1::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
let t2 = T2::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
let t3 = T3::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
let t4 =
Json::<T>::from_tool_request(&ctx, &state, &args).map_err(Into::<Error>::into)?;
self(t1, t2, t3, t4).await
})
}
}
use crate::tool::{
BoxFuture, GuardLayer, Tool, ToolCatchError, ToolHandler, ToolHandlerService, ToolRequest,
};
use tower::util::BoxCloneService;
use tower_service::Service;
/// Adapts an [`ExtractorHandler`] to the crate's `ToolHandler` interface.
///
/// It owns the shared state and the input schema computed by the builder.
pub(crate) struct ExtractorToolHandler<S, F, T> {
    /// Shared state, cloned into every call.
    state: S,
    /// The user's extractor-based handler.
    handler: F,
    /// Schema reported by `ToolHandler::input_schema`.
    input_schema: Value,
    /// Records the extractor tuple type `T`; no `T` value is stored.
    _phantom: PhantomData<T>,
}
impl<S, F, T> ToolHandler for ExtractorToolHandler<S, F, T>
where
S: Clone + Send + Sync + 'static,
F: ExtractorHandler<S, T> + Clone,
T: Send + Sync + 'static,
{
fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
self.call_with_context(ctx, args)
}
fn call_with_context(
&self,
ctx: RequestContext,
args: Value,
) -> BoxFuture<'_, Result<CallToolResult>> {
let state = self.state.clone();
let handler = self.handler.clone();
Box::pin(async move { handler.call(ctx, state, args).await })
}
fn uses_context(&self) -> bool {
true
}
fn input_schema(&self) -> Value {
self.input_schema.clone()
}
}
/// Builder stage reached once an extractor-based handler has been supplied.
///
/// It carries the tool metadata, the shared state, the handler, and the input
/// schema chosen for the tool.
#[doc(hidden)]
pub struct ToolBuilderWithExtractor<S, F, T> {
    pub(crate) name: String,
    pub(crate) title: Option<String>,
    pub(crate) description: Option<String>,
    pub(crate) output_schema: Option<Value>,
    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
    pub(crate) task_support: crate::protocol::TaskSupportMode,
    pub(crate) state: S,
    pub(crate) handler: F,
    pub(crate) input_schema: Value,
    /// Records the extractor tuple type `T`; no `T` value is stored.
    pub(crate) _phantom: PhantomData<T>,
}
impl<S, F, T> ToolBuilderWithExtractor<S, F, T>
where
    S: Clone + Send + Sync + 'static,
    F: ExtractorHandler<S, T> + Clone,
    T: Send + Sync + 'static,
{
    /// Assembles the [`Tool`].
    ///
    /// The handler is wrapped in `ToolHandlerService`, then `ToolCatchError`,
    /// and finally boxed as a cloneable service.
    pub fn build(self) -> Tool {
        let Self {
            name,
            title,
            description,
            output_schema,
            icons,
            annotations,
            task_support,
            state,
            handler,
            input_schema,
            _phantom: _,
        } = self;

        let adapter = ExtractorToolHandler {
            state,
            handler,
            input_schema: input_schema.clone(),
            _phantom: PhantomData,
        };
        let service = BoxCloneService::new(ToolCatchError::new(ToolHandlerService::new(adapter)));

        Tool {
            name,
            title,
            description,
            output_schema,
            icons,
            annotations,
            task_support,
            service,
            input_schema,
        }
    }

    /// Wraps the handler service in a tower `layer`. The builder moves to the
    /// layered stage.
    pub fn layer<L>(self, layer: L) -> ToolBuilderWithExtractorLayer<S, F, T, L> {
        let Self {
            name,
            title,
            description,
            output_schema,
            icons,
            annotations,
            task_support,
            state,
            handler,
            input_schema,
            _phantom: _,
        } = self;

        ToolBuilderWithExtractorLayer {
            name,
            title,
            description,
            output_schema,
            icons,
            annotations,
            task_support,
            state,
            handler,
            input_schema,
            layer,
            _phantom: PhantomData,
        }
    }

    /// Shorthand for `layer(GuardLayer::new(guard))`.
    pub fn guard<G>(self, guard: G) -> ToolBuilderWithExtractorLayer<S, F, T, GuardLayer<G>>
    where
        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
    {
        let guard_layer = GuardLayer::new(guard);
        self.layer(guard_layer)
    }
}
/// Builder stage for an extractor-based tool with one or more tower layers
/// around its handler service.
#[doc(hidden)]
pub struct ToolBuilderWithExtractorLayer<S, F, T, L> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<crate::protocol::ToolIcon>>,
    annotations: Option<crate::protocol::ToolAnnotations>,
    task_support: crate::protocol::TaskSupportMode,
    state: S,
    handler: F,
    input_schema: Value,
    /// The layer, or stack of layers, applied to the handler service.
    layer: L,
    /// Records the extractor tuple type `T`; no `T` value is stored.
    _phantom: PhantomData<T>,
}
#[allow(private_bounds)]
impl<S, F, T, L> ToolBuilderWithExtractorLayer<S, F, T, L>
where
    S: Clone + Send + Sync + 'static,
    F: ExtractorHandler<S, T> + Clone,
    T: Send + Sync + 'static,
    L: tower::Layer<ToolHandlerService<ExtractorToolHandler<S, F, T>>>
        + Clone
        + Send
        + Sync
        + 'static,
    L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
    <L::Service as Service<ToolRequest>>::Error: std::fmt::Display + Send,
    <L::Service as Service<ToolRequest>>::Future: Send,
{
    /// Assembles the [`Tool`].
    ///
    /// The configured layer wraps the handler service. `ToolCatchError` is
    /// applied outside the layer, and the result is boxed as a cloneable service.
    pub fn build(self) -> Tool {
        let Self {
            name,
            title,
            description,
            output_schema,
            icons,
            annotations,
            task_support,
            state,
            handler,
            input_schema,
            layer,
            _phantom: _,
        } = self;

        let adapter = ExtractorToolHandler {
            state,
            handler,
            input_schema: input_schema.clone(),
            _phantom: PhantomData,
        };
        let layered = layer.layer(ToolHandlerService::new(adapter));
        let service = BoxCloneService::new(ToolCatchError::new(layered));

        Tool {
            name,
            title,
            description,
            output_schema,
            icons,
            annotations,
            task_support,
            service,
            input_schema,
        }
    }

    /// Adds another layer.
    ///
    /// The new layer becomes the inner half of a tower `Stack`, so layers added
    /// later sit closer to the handler than layers added earlier.
    pub fn layer<L2>(
        self,
        layer: L2,
    ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<L2, L>> {
        let Self {
            name,
            title,
            description,
            output_schema,
            icons,
            annotations,
            task_support,
            state,
            handler,
            input_schema,
            layer: existing,
            _phantom: _,
        } = self;

        ToolBuilderWithExtractorLayer {
            name,
            title,
            description,
            output_schema,
            icons,
            annotations,
            task_support,
            state,
            handler,
            input_schema,
            layer: tower::layer::util::Stack::new(layer, existing),
            _phantom: PhantomData,
        }
    }

    /// Shorthand for `layer(GuardLayer::new(guard))`.
    pub fn guard<G>(
        self,
        guard: G,
    ) -> ToolBuilderWithExtractorLayer<S, F, T, tower::layer::util::Stack<GuardLayer<G>, L>>
    where
        G: Fn(&ToolRequest) -> std::result::Result<(), String> + Clone + Send + Sync + 'static,
    {
        let guard_layer = GuardLayer::new(guard);
        self.layer(guard_layer)
    }
}
/// Legacy builder stage for handlers registered through the typed
/// (`TypedExtractorHandler`) API. The input schema comes from `I`.
#[doc(hidden)]
#[deprecated(
    since = "0.8.0",
    note = "Use `ToolBuilderWithExtractor` via `extractor_handler` instead"
)]
pub struct ToolBuilderWithTypedExtractor<S, F, T, I> {
    pub(crate) name: String,
    pub(crate) title: Option<String>,
    pub(crate) description: Option<String>,
    pub(crate) output_schema: Option<Value>,
    pub(crate) icons: Option<Vec<crate::protocol::ToolIcon>>,
    pub(crate) annotations: Option<crate::protocol::ToolAnnotations>,
    pub(crate) task_support: crate::protocol::TaskSupportMode,
    pub(crate) state: S,
    pub(crate) handler: F,
    /// Records the extractor tuple `T` and the schema type `I`; neither is stored.
    pub(crate) _phantom: PhantomData<(T, I)>,
}
#[allow(deprecated)]
impl<S, F, T, I> ToolBuilderWithTypedExtractor<S, F, T, I>
where
    S: Clone + Send + Sync + 'static,
    F: TypedExtractorHandler<S, T, I> + Clone,
    T: Send + Sync + 'static,
    I: JsonSchema + Send + Sync + 'static,
{
    /// Assembles the [`Tool`], deriving its input schema from `I`.
    ///
    /// If the generated schema cannot be converted to JSON, `{"type": "object"}`
    /// is used instead. Either way the result is passed through
    /// `crate::tool::ensure_object_schema`.
    pub fn build(self) -> Tool {
        let generated = serde_json::to_value(schemars::schema_for!(I))
            .unwrap_or_else(|_| serde_json::json!({ "type": "object" }));
        let input_schema = crate::tool::ensure_object_schema(generated);

        let Self {
            name,
            title,
            description,
            output_schema,
            icons,
            annotations,
            task_support,
            state,
            handler,
            _phantom: _,
        } = self;

        let adapter = TypedExtractorToolHandler {
            state,
            handler,
            input_schema: input_schema.clone(),
            _phantom: PhantomData,
        };
        let service = BoxCloneService::new(ToolCatchError::new(
            crate::tool::ToolHandlerService::new(adapter),
        ));

        Tool {
            name,
            title,
            description,
            output_schema,
            icons,
            annotations,
            task_support,
            service,
            input_schema,
        }
    }
}
/// Adapts a legacy `TypedExtractorHandler` to the crate's `ToolHandler`
/// interface.
struct TypedExtractorToolHandler<S, F, T, I> {
    /// Shared state, cloned into every call.
    state: S,
    /// The user's handler.
    handler: F,
    /// Schema derived from `I` by the builder.
    input_schema: Value,
    /// Records the extractor tuple `T` and the schema type `I`; neither is stored.
    _phantom: PhantomData<(T, I)>,
}
#[allow(deprecated)]
impl<S, F, T, I> ToolHandler for TypedExtractorToolHandler<S, F, T, I>
where
S: Clone + Send + Sync + 'static,
F: TypedExtractorHandler<S, T, I> + Clone,
T: Send + Sync + 'static,
I: JsonSchema + Send + Sync + 'static,
{
fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
self.call_with_context(ctx, args)
}
fn call_with_context(
&self,
ctx: RequestContext,
args: Value,
) -> BoxFuture<'_, Result<CallToolResult>> {
let state = self.state.clone();
let handler = self.handler.clone();
Box::pin(async move { handler.call(ctx, state, args).await })
}
fn uses_context(&self) -> bool {
true
}
fn input_schema(&self) -> Value {
self.input_schema.clone()
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the extractors, the rejection types, the `ExtractorHandler`
    // arities, and the `ToolBuilder` integration (schema detection and layers).
    use super::*;
    use crate::protocol::RequestId;
    use schemars::JsonSchema;
    use serde::Deserialize;
    use std::sync::Arc;

    #[derive(Debug, Deserialize, JsonSchema)]
    struct TestInput {
        name: String,
        count: i32,
    }

    #[test]
    fn test_json_extraction() {
        let args = serde_json::json!({"name": "test", "count": 42});
        let ctx = RequestContext::new(RequestId::Number(1));
        let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
        assert!(result.is_ok());
        let Json(input) = result.unwrap();
        assert_eq!(input.name, "test");
        assert_eq!(input.count, 42);
    }

    #[test]
    fn test_json_extraction_error() {
        // `count` is missing, so deserialization must fail and name the field.
        let args = serde_json::json!({"name": "test"});
        let ctx = RequestContext::new(RequestId::Number(1));
        let result = Json::<TestInput>::from_tool_request(&ctx, &(), &args);
        assert!(result.is_err());
        let rejection = result.unwrap_err();
        assert!(rejection.message().contains("count"));
    }

    #[test]
    fn test_state_extraction() {
        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("my-state".to_string());
        let result = State::<Arc<String>>::from_tool_request(&ctx, &state, &args);
        assert!(result.is_ok());
        let State(extracted) = result.unwrap();
        assert_eq!(*extracted, "my-state");
    }

    #[test]
    fn test_context_extraction() {
        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(42));
        let result = Context::from_tool_request(&ctx, &(), &args);
        assert!(result.is_ok());
        let extracted = result.unwrap();
        assert_eq!(*extracted.request_id(), RequestId::Number(42));
    }

    #[test]
    fn test_raw_args_extraction() {
        let args = serde_json::json!({"foo": "bar", "baz": 123});
        let ctx = RequestContext::new(RequestId::Number(1));
        let result = RawArgs::from_tool_request(&ctx, &(), &args);
        assert!(result.is_ok());
        let RawArgs(extracted) = result.unwrap();
        assert_eq!(extracted["foo"], "bar");
        assert_eq!(extracted["baz"], 123);
    }

    #[test]
    fn test_extension_extraction() {
        use crate::context::Extensions;

        #[derive(Clone, Debug, PartialEq)]
        struct DatabasePool {
            url: String,
        }

        let args = serde_json::json!({});
        let mut extensions = Extensions::new();
        extensions.insert(Arc::new(DatabasePool {
            url: "postgres://localhost".to_string(),
        }));
        let ctx = RequestContext::new(RequestId::Number(1)).with_extensions(Arc::new(extensions));
        let result = Extension::<Arc<DatabasePool>>::from_tool_request(&ctx, &(), &args);
        assert!(result.is_ok());
        let Extension(pool) = result.unwrap();
        assert_eq!(pool.url, "postgres://localhost");
    }

    #[test]
    fn test_extension_extraction_missing() {
        #[derive(Clone, Debug)]
        struct NotPresent;

        let args = serde_json::json!({});
        let ctx = RequestContext::new(RequestId::Number(1));
        let result = Extension::<NotPresent>::from_tool_request(&ctx, &(), &args);
        assert!(result.is_err());
        let rejection = result.unwrap_err();
        assert!(rejection.type_name().contains("NotPresent"));
    }

    #[tokio::test]
    async fn test_single_extractor_handler() {
        let handler = |Json(input): Json<TestInput>| async move {
            Ok(CallToolResult::text(format!(
                "{}: {}",
                input.name, input.count
            )))
        };
        let ctx = RequestContext::new(RequestId::Number(1));
        let args = serde_json::json!({"name": "test", "count": 5});
        let result: Result<CallToolResult> =
            ExtractorHandler::<(), (Json<TestInput>,)>::call(handler, ctx, (), args).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_two_extractor_handler() {
        let handler = |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
            Ok(CallToolResult::text(format!(
                "{}: {} - {}",
                state, input.name, input.count
            )))
        };
        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("prefix".to_string());
        let args = serde_json::json!({"name": "test", "count": 5});
        let result: Result<CallToolResult> = ExtractorHandler::<
            Arc<String>,
            (State<Arc<String>>, Json<TestInput>),
        >::call(handler, ctx, state, args)
        .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_three_extractor_handler() {
        let handler = |State(state): State<Arc<String>>,
                       ctx: Context,
                       Json(input): Json<TestInput>| async move {
            assert!(!ctx.is_cancelled());
            Ok(CallToolResult::text(format!(
                "{}: {} - {}",
                state, input.name, input.count
            )))
        };
        let ctx = RequestContext::new(RequestId::Number(1));
        let state = Arc::new("prefix".to_string());
        let args = serde_json::json!({"name": "test", "count": 5});
        let result: Result<CallToolResult> = ExtractorHandler::<
            Arc<String>,
            (State<Arc<String>>, Context, Json<TestInput>),
        >::call(handler, ctx, state, args)
        .await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_json_schema_generation() {
        let schema = Json::<TestInput>::schema();
        assert!(schema.is_some());
        let schema = schema.unwrap();
        assert!(schema.get("properties").is_some());
    }

    #[test]
    fn test_input_schema_found_when_json_is_not_last() {
        // Resolves the static `input_schema` for a handler whose `Json` extractor
        // is followed by a schema-less `Context`.
        fn schema_of<S, T, H: ExtractorHandler<S, T>>(_handler: &H) -> Value {
            H::input_schema()
        }

        let handler = |Json(input): Json<TestInput>, _ctx: Context| async move {
            Ok(CallToolResult::text(input.name))
        };
        let schema = schema_of::<(), (Json<TestInput>, Context), _>(&handler);
        let props = schema
            .get("properties")
            .expect("schema should come from TestInput");
        assert!(props.get("name").is_some());
        assert!(props.get("count").is_some());
    }

    #[test]
    fn test_rejection_into_error() {
        let rejection = Rejection::new("test error");
        let error: Error = rejection.into();
        assert!(error.to_string().contains("test error"));
    }

    #[test]
    fn test_json_rejection() {
        let rejection = JsonRejection::new("missing field `name`");
        assert_eq!(rejection.message(), "missing field `name`");
        assert!(rejection.path().is_none());
        assert!(rejection.to_string().contains("Invalid input"));

        let rejection = JsonRejection::with_path("expected string", "users[0].name");
        assert_eq!(rejection.message(), "expected string");
        assert_eq!(rejection.path(), Some("users[0].name"));
        assert!(rejection.to_string().contains("users[0].name"));

        let error: Error = rejection.into();
        assert!(error.to_string().contains("users[0].name"));
    }

    #[test]
    fn test_json_rejection_from_serde_error() {
        #[derive(Debug, serde::Deserialize)]
        struct TestStruct {
            #[allow(dead_code)]
            name: String,
        }

        let result: std::result::Result<TestStruct, _> =
            serde_json::from_value(serde_json::json!({"count": 42}));
        assert!(result.is_err());
        let rejection: JsonRejection = result.unwrap_err().into();
        assert!(rejection.message().contains("name"));
        // `serde_json` errors carry no field path.
        assert!(rejection.path().is_none());
    }

    #[test]
    fn test_extension_rejection() {
        let rejection = ExtensionRejection::not_found::<String>();
        assert!(rejection.type_name().contains("String"));
        assert!(rejection.to_string().contains("not found"));
        assert!(rejection.to_string().contains("with_state"));
        let error: Error = rejection.into();
        assert!(error.to_string().contains("not found"));
    }

    #[tokio::test]
    async fn test_tool_builder_extractor_handler() {
        use crate::ToolBuilder;

        let state = Arc::new("shared-state".to_string());
        let tool = ToolBuilder::new("test_extractor")
            .description("Test extractor handler")
            .extractor_handler(
                state,
                |State(state): State<Arc<String>>,
                 ctx: Context,
                 Json(input): Json<TestInput>| async move {
                    assert!(!ctx.is_cancelled());
                    Ok(CallToolResult::text(format!(
                        "{}: {} - {}",
                        state, input.name, input.count
                    )))
                },
            )
            .build();
        assert_eq!(tool.name, "test_extractor");
        assert_eq!(tool.description.as_deref(), Some("Test extractor handler"));
        let result = tool
            .call(serde_json::json!({"name": "test", "count": 42}))
            .await;
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn test_extractor_handler_rejection_is_error() {
        use crate::ToolBuilder;

        let tool = ToolBuilder::new("test_rejection")
            .description("Test rejection propagation")
            .extractor_handler((), |Json(input): Json<TestInput>| async move {
                Ok(CallToolResult::text(input.name))
            })
            .build();
        // `count` is missing, so the `Json` extractor rejects before the handler runs.
        let result = tool.call(serde_json::json!({"name": "no-count"})).await;
        assert!(result.is_error);
    }

    #[tokio::test]
    #[allow(deprecated)]
    async fn test_tool_builder_extractor_handler_typed() {
        use crate::ToolBuilder;

        let state = Arc::new("typed-state".to_string());
        let tool = ToolBuilder::new("test_typed")
            .description("Test typed extractor handler")
            .extractor_handler_typed::<_, _, _, TestInput>(
                state,
                |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
                    Ok(CallToolResult::text(format!(
                        "{}: {} - {}",
                        state, input.name, input.count
                    )))
                },
            )
            .build();
        assert_eq!(tool.name, "test_typed");
        let def = tool.definition();
        let schema = def.input_schema;
        assert!(schema.get("properties").is_some());
        let result = tool
            .call(serde_json::json!({"name": "world", "count": 99}))
            .await;
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn test_extractor_handler_auto_schema() {
        use crate::ToolBuilder;

        let state = Arc::new("auto-schema".to_string());
        let tool = ToolBuilder::new("test_auto_schema")
            .description("Test auto schema detection")
            .extractor_handler(
                state,
                |State(state): State<Arc<String>>, Json(input): Json<TestInput>| async move {
                    Ok(CallToolResult::text(format!(
                        "{}: {} - {}",
                        state, input.name, input.count
                    )))
                },
            )
            .build();
        let def = tool.definition();
        let schema = def.input_schema;
        assert!(
            schema.get("properties").is_some(),
            "Schema should have properties from TestInput, got: {}",
            schema
        );
        let props = schema.get("properties").unwrap();
        assert!(
            props.get("name").is_some(),
            "Schema should have 'name' property"
        );
        assert!(
            props.get("count").is_some(),
            "Schema should have 'count' property"
        );
        let result = tool
            .call(serde_json::json!({"name": "world", "count": 99}))
            .await;
        assert!(!result.is_error);
    }

    #[test]
    fn test_extractor_handler_no_json_fallback() {
        use crate::ToolBuilder;

        let tool = ToolBuilder::new("test_no_json")
            .description("Test no json fallback")
            .extractor_handler((), |RawArgs(args): RawArgs| async move {
                Ok(CallToolResult::json(args))
            })
            .build();
        let def = tool.definition();
        let schema = def.input_schema;
        assert_eq!(
            schema.get("type").and_then(|v| v.as_str()),
            Some("object"),
            "Schema should be generic object"
        );
        assert_eq!(
            schema.get("additionalProperties").and_then(|v| v.as_bool()),
            Some(true),
            "Schema should allow additional properties"
        );
        assert!(
            schema.get("properties").is_none(),
            "Generic schema should not have specific properties"
        );
    }

    #[tokio::test]
    async fn test_extractor_handler_with_layer() {
        use crate::ToolBuilder;
        use std::time::Duration;
        use tower::timeout::TimeoutLayer;

        let state = Arc::new("layered".to_string());
        let tool = ToolBuilder::new("test_extractor_layer")
            .description("Test extractor handler with layer")
            .extractor_handler(
                state,
                |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
                    Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
                },
            )
            .layer(TimeoutLayer::new(Duration::from_secs(5)))
            .build();
        let result = tool
            .call(serde_json::json!({"name": "test", "count": 1}))
            .await;
        assert!(!result.is_error);
        assert_eq!(result.first_text().unwrap(), "layered: test");
        let def = tool.definition();
        let schema = def.input_schema;
        assert!(
            schema.get("properties").is_some(),
            "Schema should have properties even with layer"
        );
    }

    #[tokio::test]
    async fn test_extractor_handler_with_timeout_layer() {
        use crate::ToolBuilder;
        use std::time::Duration;
        use tower::timeout::TimeoutLayer;

        let tool = ToolBuilder::new("test_extractor_timeout")
            .description("Test extractor handler timeout")
            .extractor_handler((), |Json(input): Json<TestInput>| async move {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok(CallToolResult::text(input.name.to_string()))
            })
            .layer(TimeoutLayer::new(Duration::from_millis(50)))
            .build();
        let result = tool
            .call(serde_json::json!({"name": "slow", "count": 1}))
            .await;
        assert!(result.is_error);
        let msg = result.first_text().unwrap().to_lowercase();
        assert!(
            msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
            "Expected timeout error, got: {}",
            msg
        );
    }

    #[tokio::test]
    async fn test_extractor_handler_with_multiple_layers() {
        use crate::ToolBuilder;
        use std::time::Duration;
        use tower::limit::ConcurrencyLimitLayer;
        use tower::timeout::TimeoutLayer;

        let state = Arc::new("multi".to_string());
        let tool = ToolBuilder::new("test_multi_layer")
            .description("Test multiple layers")
            .extractor_handler(
                state,
                |State(s): State<Arc<String>>, Json(input): Json<TestInput>| async move {
                    Ok(CallToolResult::text(format!("{}: {}", s, input.name)))
                },
            )
            .layer(TimeoutLayer::new(Duration::from_secs(5)))
            .layer(ConcurrencyLimitLayer::new(10))
            .build();
        let result = tool
            .call(serde_json::json!({"name": "test", "count": 1}))
            .await;
        assert!(!result.is_error);
        assert_eq!(result.first_text().unwrap(), "multi: test");
    }
}