use std::{future::Future, pin::Pin, sync::Arc};
use serde::de::DeserializeOwned;
use serde_json::json;
use crate::schema::{JsonSchema, Schema};
/// A named, described operation that can be invoked with JSON arguments.
///
/// This is the object-safe, type-erased interface (usable as `Box<dyn Callback>`).
/// Implement it directly when arguments are handled as raw JSON, or implement
/// [`TypedCallback`] and receive this trait through the blanket impl.
/// The `Send + Sync` supertraits let implementors be shared across threads.
pub trait Callback: Send + Sync {
/// The callback's name.
fn name(&self) -> &str;
/// A human-readable description of what the callback does.
fn description(&self) -> &str;
/// The schema describing the arguments accepted by [`Callback::invoke`].
fn parameters_schema(&self) -> Schema;
/// Invokes the callback with raw JSON arguments.
///
/// The returned future is `Send` and may borrow `self` for its whole lifetime.
///
/// # Errors
///
/// The future resolves to a [`CallbackError`] when the callback fails; which
/// variant is used is up to the implementation.
fn invoke(
&self,
args: serde_json::Value,
) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, CallbackError>> + Send + '_>>;
}
/// A callback whose JSON arguments are deserialized into a concrete `Args` type.
///
/// Every `TypedCallback` also implements [`Callback`] via a blanket impl, which derives
/// the parameter schema from `Args` and deserializes incoming JSON before calling
/// [`TypedCallback::invoke_typed`].
pub trait TypedCallback: Send + Sync {
/// The argument type; its `JsonSchema` impl supplies the parameter schema.
///
/// Use `()` for callbacks without arguments; the blanket [`Callback`] impl then accepts
/// both `null` and an empty object `{}`.
type Args: DeserializeOwned + JsonSchema + Send;
/// The callback's name.
fn name(&self) -> &str;
/// A human-readable description of what the callback does.
fn description(&self) -> &str;
/// Runs the callback with already-deserialized arguments.
///
/// # Errors
///
/// The future resolves to a [`CallbackError`] when execution fails.
fn invoke_typed(
&self,
args: Self::Args,
) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, CallbackError>> + Send + '_>>;
}
/// Every [`TypedCallback`] is usable as a type-erased [`Callback`].
///
/// The parameter schema is generated from `T::Args`, and JSON arguments are deserialized
/// into `T::Args` before [`TypedCallback::invoke_typed`] runs.
impl<T: TypedCallback> Callback for T {
    fn name(&self) -> &str {
        <Self as TypedCallback>::name(self)
    }

    fn description(&self) -> &str {
        <Self as TypedCallback>::description(self)
    }

    fn parameters_schema(&self) -> Schema {
        Schema::for_type::<T::Args>()
    }

    /// Deserializes `args` into `T::Args` and forwards them to `invoke_typed`.
    ///
    /// When `args` is an empty object (`{}`) that does not deserialize directly, it is retried
    /// as `null`. This lets callers that always send an object invoke callbacks whose `Args`
    /// is `()`. If the retry also fails, the error from the first attempt is reported.
    ///
    /// # Errors
    ///
    /// The future resolves to [`CallbackError::InvalidArguments`] when `args` cannot be
    /// deserialized, or to whatever error `invoke_typed` produces.
    fn invoke(
        &self,
        args: serde_json::Value,
    ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, CallbackError>> + Send + '_>> {
        // Decide on the fallback before `args` is consumed, so it can be moved, not cloned.
        let is_empty_object = args.as_object().is_some_and(|map| map.is_empty());
        let typed_args: Result<T::Args, _> = serde_json::from_value(args).or_else(|err| {
            if is_empty_object {
                serde_json::from_value(serde_json::Value::Null).map_err(|_| err)
            } else {
                Err(err)
            }
        });
        Box::pin(async move {
            let args = typed_args.map_err(|e| CallbackError::InvalidArguments(e.to_string()))?;
            self.invoke_typed(args).await
        })
    }
}
/// Errors produced when invoking a callback.
#[derive(Debug, thiserror::Error)]
pub enum CallbackError {
/// The arguments were rejected: they failed to deserialize into the expected type,
/// or a handler's own validation failed.
#[error("invalid arguments: {0}")]
InvalidArguments(String),
/// The callback ran but could not complete successfully.
#[error("execution failed: {0}")]
ExecutionFailed(String),
/// No callback with the given name was found.
///
/// NOTE(review): not produced in this file; presumably returned by a name-based lookup elsewhere.
#[error("callback not found: {0}")]
NotFound(String),
/// The callback did not finish in time.
///
/// NOTE(review): not produced in this file; presumably set by an invoking layer that enforces a deadline.
#[error("timeout")]
Timeout,
}
/// Returns an empty parameter schema, suitable for callbacks that declare no parameters.
///
/// Equivalent to [`Schema::empty`].
#[must_use]
pub fn empty_schema() -> Schema {
Schema::empty()
}
/// Shared, type-erased handler behind a [`DynamicCallback`].
///
/// It receives the raw JSON arguments and returns a boxed `Send` future. The boxed future
/// has no lifetime parameter (so it is `'static`) and cannot borrow from the caller. The
/// `Arc` lets clones of a [`DynamicCallback`] share one handler.
pub type DynamicHandler = Arc<
dyn Fn(
serde_json::Value,
)
-> Pin<Box<dyn Future<Output = Result<serde_json::Value, CallbackError>> + Send>>
+ Send
+ Sync,
>;
/// A callback defined at runtime from a name, description, schema and handler closure.
///
/// Arguments are passed to the handler as raw JSON; they are not checked against `schema`
/// before the handler runs. Create one with [`DynamicCallback::builder`] or
/// [`DynamicCallback::new`].
#[derive(Clone)]
pub struct DynamicCallback {
/// Name returned by [`Callback::name`].
name: String,
/// Description returned by [`Callback::description`].
description: String,
/// Schema returned by [`Callback::parameters_schema`].
schema: Schema,
/// Handler run by [`Callback::invoke`].
handler: DynamicHandler,
}
/// Formats the callback's metadata.
///
/// The handler is a `dyn Fn` with no `Debug` output, so a fixed placeholder is printed in
/// its place.
impl std::fmt::Debug for DynamicCallback {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            name,
            description,
            schema,
            handler: _,
        } = self;
        let mut out = f.debug_struct("DynamicCallback");
        out.field("name", name);
        out.field("description", description);
        out.field("schema", schema);
        out.field("handler", &"<handler>");
        out.finish()
    }
}
impl DynamicCallback {
    /// Starts a [`DynamicCallbackBuilder`] for a callback with the given name, description
    /// and async handler.
    ///
    /// Declare parameters with [`DynamicCallbackBuilder::param`] (or supply a complete schema
    /// with [`DynamicCallbackBuilder::schema`]), then finish with
    /// [`DynamicCallbackBuilder::build`].
    #[must_use]
    pub fn builder<F>(
        name: impl Into<String>,
        description: impl Into<String>,
        handler: F,
    ) -> DynamicCallbackBuilder
    where
        F: Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, CallbackError>> + Send>>
            + Send
            + Sync
            + 'static,
    {
        let name: String = name.into();
        let description: String = description.into();
        DynamicCallbackBuilder::new(name, description, handler)
    }

    /// Creates a callback directly from an explicit schema and an already-shared handler,
    /// bypassing the builder.
    #[must_use]
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        schema: Schema,
        handler: DynamicHandler,
    ) -> Self {
        let name = name.into();
        let description = description.into();
        Self {
            name,
            description,
            schema,
            handler,
        }
    }
}
impl Callback for DynamicCallback {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn parameters_schema(&self) -> Schema {
self.schema.clone()
}
fn invoke(
&self,
args: serde_json::Value,
) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, CallbackError>> + Send + '_>> {
(self.handler)(args)
}
}
/// Incrementally configures a [`DynamicCallback`]; obtain one from [`DynamicCallback::builder`].
pub struct DynamicCallbackBuilder {
/// Callback name.
name: String,
/// Callback description.
description: String,
/// Parameters declared via `param`, in declaration order.
parameters: Vec<ParameterDef>,
/// Complete schema set via `schema`; when present, `parameters` is not used to build the schema.
custom_schema: Option<Schema>,
/// Handler moved into the built callback.
handler: DynamicHandler,
}
/// Formats the builder's accumulated configuration.
///
/// The handler is a `dyn Fn` with no `Debug` output, so a fixed placeholder is printed in
/// its place.
impl std::fmt::Debug for DynamicCallbackBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            name,
            description,
            parameters,
            custom_schema,
            handler: _,
        } = self;
        let mut out = f.debug_struct("DynamicCallbackBuilder");
        out.field("name", name)
            .field("description", description)
            .field("parameters", parameters)
            .field("custom_schema", custom_schema)
            .field("handler", &"<handler>");
        out.finish()
    }
}
/// One parameter declared on a [`DynamicCallbackBuilder`].
#[derive(Debug, Clone)]
struct ParameterDef {
/// Property name in the generated schema.
name: String,
/// JSON Schema `type` value, written verbatim (e.g. `"string"`); not validated.
json_type: String,
/// Property description in the generated schema.
description: String,
/// Whether the name is listed in the generated schema's `required` array.
required: bool,
}
impl DynamicCallbackBuilder {
fn new<F>(name: impl Into<String>, description: impl Into<String>, handler: F) -> Self
where
F: Fn(
serde_json::Value,
)
-> Pin<Box<dyn Future<Output = Result<serde_json::Value, CallbackError>> + Send>>
+ Send
+ Sync
+ 'static,
{
Self {
name: name.into(),
description: description.into(),
parameters: Vec::new(),
custom_schema: None,
handler: Arc::new(handler),
}
}
#[must_use]
pub fn param(
mut self,
name: impl Into<String>,
json_type: impl Into<String>,
description: impl Into<String>,
required: bool,
) -> Self {
self.parameters.push(ParameterDef {
name: name.into(),
json_type: json_type.into(),
description: description.into(),
required,
});
self
}
#[must_use]
pub fn schema(mut self, schema: Schema) -> Self {
self.custom_schema = Some(schema);
self
}
#[must_use]
pub fn build(self) -> DynamicCallback {
let schema = match self.custom_schema {
Some(s) => s,
None => {
if self.parameters.is_empty() {
Schema::empty()
} else {
let mut properties = serde_json::Map::new();
let mut required = Vec::new();
for param in &self.parameters {
properties.insert(
param.name.clone(),
json!({
"type": param.json_type,
"description": param.description
}),
);
if param.required {
required.push(serde_json::Value::String(param.name.clone()));
}
}
let schema_json = json!({
"type": "object",
"properties": properties,
"required": required
});
Schema::try_from_value(schema_json).unwrap_or_default()
}
}
};
DynamicCallback {
name: self.name,
description: self.description,
schema,
handler: self.handler,
}
}
}
#[cfg(test)]
// A panicking `unwrap`/`expect` is the intended failure signal inside tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use serde_json::json;

    // ---- CallbackError ------------------------------------------------------------------

    #[test]
    fn callback_error_invalid_arguments_displays_message() {
        let display = CallbackError::InvalidArguments("missing field 'name'".to_string()).to_string();
        assert!(display.contains("invalid arguments"));
        assert!(display.contains("missing field 'name'"));
    }

    #[test]
    fn callback_error_execution_failed_displays_message() {
        let display = CallbackError::ExecutionFailed("connection timeout".to_string()).to_string();
        assert!(display.contains("execution failed"));
        assert!(display.contains("connection timeout"));
    }

    #[test]
    fn callback_error_not_found_displays_name() {
        let display = CallbackError::NotFound("unknown_callback".to_string()).to_string();
        assert!(display.contains("not found"));
        assert!(display.contains("unknown_callback"));
    }

    #[test]
    fn callback_error_timeout_displays_correctly() {
        assert_eq!(CallbackError::Timeout.to_string(), "timeout");
    }

    #[test]
    fn callback_error_is_debug() {
        let debug = format!("{:?}", CallbackError::InvalidArguments("test".to_string()));
        assert!(debug.contains("InvalidArguments"));
    }

    // ---- TypedCallback fixtures ---------------------------------------------------------

    #[derive(Deserialize, JsonSchema)]
    struct EchoArgs {
        message: String,
    }

    struct EchoCallback;

    impl TypedCallback for EchoCallback {
        type Args = EchoArgs;

        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes the message back"
        }

        fn invoke_typed(
            &self,
            args: EchoArgs,
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, CallbackError>> + Send + '_>>
        {
            Box::pin(async move { Ok(json!({ "echoed": args.message })) })
        }
    }

    #[derive(Deserialize, JsonSchema)]
    struct AddArgs {
        a: i64,
        b: i64,
    }

    struct AddCallback;

    impl TypedCallback for AddCallback {
        type Args = AddArgs;

        fn name(&self) -> &str {
            "add"
        }

        fn description(&self) -> &str {
            "Adds two numbers"
        }

        fn invoke_typed(
            &self,
            args: AddArgs,
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, CallbackError>> + Send + '_>>
        {
            Box::pin(async move { Ok(json!(args.a + args.b)) })
        }
    }

    struct NoArgsCallback;

    impl TypedCallback for NoArgsCallback {
        type Args = ();

        fn name(&self) -> &str {
            "no_args"
        }

        fn description(&self) -> &str {
            "A callback with no arguments"
        }

        fn invoke_typed(
            &self,
            _args: (),
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, CallbackError>> + Send + '_>>
        {
            Box::pin(async move { Ok(json!("success")) })
        }
    }

    struct FailingCallback;

    impl TypedCallback for FailingCallback {
        type Args = ();

        fn name(&self) -> &str {
            "fail"
        }

        fn description(&self) -> &str {
            "Always fails"
        }

        fn invoke_typed(
            &self,
            _args: (),
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, CallbackError>> + Send + '_>>
        {
            Box::pin(
                async move { Err(CallbackError::ExecutionFailed("intentional failure".into())) },
            )
        }
    }

    // ---- TypedCallback behaviour --------------------------------------------------------

    #[test]
    fn typed_callback_name_and_description() {
        let callback = EchoCallback;
        assert_eq!(TypedCallback::name(&callback), "echo");
        assert_eq!(
            TypedCallback::description(&callback),
            "Echoes the message back"
        );
    }

    #[test]
    fn typed_callback_generates_schema() {
        let callback = EchoCallback;
        let schema = Callback::parameters_schema(&callback);
        let value = schema.to_value();
        let properties = value.get("properties").expect("should have properties");
        assert!(properties.get("message").is_some());
    }

    #[test]
    fn typed_callback_no_args_generates_empty_schema() {
        let callback = NoArgsCallback;
        let schema = Callback::parameters_schema(&callback);
        let value = schema.to_value();
        assert!(value.is_object() || value.is_boolean());
    }

    #[tokio::test]
    async fn typed_callback_invoke_with_valid_args() {
        let callback = EchoCallback;
        let value = Callback::invoke(&callback, json!({ "message": "hello world" }))
            .await
            .expect("valid args should succeed");
        assert_eq!(value["echoed"], "hello world");
    }

    #[tokio::test]
    async fn typed_callback_invoke_with_numeric_args() {
        let callback = AddCallback;
        let value = Callback::invoke(&callback, json!({ "a": 10, "b": 32 }))
            .await
            .expect("valid args should succeed");
        assert_eq!(value, json!(42));
    }

    #[tokio::test]
    async fn typed_callback_invoke_with_null_for_unit() {
        let callback = NoArgsCallback;
        let value = Callback::invoke(&callback, serde_json::Value::Null)
            .await
            .expect("null should be accepted for unit args");
        assert_eq!(value, json!("success"));
    }

    #[tokio::test]
    async fn typed_callback_invoke_with_empty_object_for_unit() {
        // `{}` does not deserialize as `()` directly; `invoke` retries it as `null`.
        let callback = NoArgsCallback;
        let value = Callback::invoke(&callback, json!({}))
            .await
            .expect("an empty object should be accepted for unit args");
        assert_eq!(value, json!("success"));
    }

    #[tokio::test]
    async fn typed_callback_unit_rejects_non_empty_object() {
        // The `null` fallback only applies to an empty object.
        let callback = NoArgsCallback;
        let error = Callback::invoke(&callback, json!({ "unexpected": 1 }))
            .await
            .expect_err("a non-empty object is not valid unit args");
        assert!(matches!(error, CallbackError::InvalidArguments(_)));
    }

    #[tokio::test]
    async fn typed_callback_invoke_with_missing_required_field() {
        let callback = EchoCallback;
        let error = Callback::invoke(&callback, json!({}))
            .await
            .expect_err("missing field should be rejected");
        assert!(matches!(error, CallbackError::InvalidArguments(_)));
    }

    #[tokio::test]
    async fn typed_callback_invoke_with_wrong_type() {
        let callback = EchoCallback;
        let error = Callback::invoke(&callback, json!({ "message": 12345 }))
            .await
            .expect_err("wrong field type should be rejected");
        assert!(matches!(error, CallbackError::InvalidArguments(_)));
    }

    #[tokio::test]
    async fn typed_callback_invoke_with_non_object_args() {
        let callback = EchoCallback;
        let error = Callback::invoke(&callback, json!("hello"))
            .await
            .expect_err("a bare string is not a valid EchoArgs");
        assert!(matches!(error, CallbackError::InvalidArguments(_)));
    }

    #[tokio::test]
    async fn typed_callback_can_return_error() {
        let callback = FailingCallback;
        let error = Callback::invoke(&callback, serde_json::Value::Null)
            .await
            .expect_err("callback always fails");
        assert!(
            matches!(&error, CallbackError::ExecutionFailed(msg) if msg == "intentional failure")
        );
    }

    #[test]
    fn typed_callback_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<EchoCallback>();
        assert_send_sync::<AddCallback>();
        assert_send_sync::<NoArgsCallback>();
        assert_send_sync::<FailingCallback>();
        assert_send_sync::<DynamicCallback>();
    }

    // ---- DynamicCallback ----------------------------------------------------------------

    #[test]
    fn dynamic_callback_builder_basic() {
        let callback = DynamicCallback::builder("test", "A test callback", |_args| {
            Box::pin(async move { Ok(json!({"ok": true})) })
        })
        .build();
        assert_eq!(callback.name(), "test");
        assert_eq!(callback.description(), "A test callback");
    }

    #[test]
    fn dynamic_callback_builder_with_params() {
        let callback = DynamicCallback::builder("greet", "Greets someone", |_args| {
            Box::pin(async move { Ok(json!({})) })
        })
        .param("name", "string", "The person's name", true)
        .param("formal", "boolean", "Use formal greeting", false)
        .build();
        let schema = callback.parameters_schema();
        let value = schema.to_value();
        let properties = value.get("properties").expect("should have properties");
        assert!(properties.get("name").is_some());
        assert!(properties.get("formal").is_some());
        let required = value.get("required").expect("should have required");
        let required_arr = required.as_array().expect("required should be array");
        assert!(required_arr.contains(&json!("name")));
        assert!(!required_arr.contains(&json!("formal")));
    }

    #[test]
    fn dynamic_callback_builder_no_params_empty_schema() {
        let callback = DynamicCallback::builder("simple", "No params", |_args| {
            Box::pin(async move { Ok(json!(null)) })
        })
        .build();
        let schema = callback.parameters_schema();
        let value = schema.to_value();
        assert!(value.is_object() || value.is_boolean());
    }

    #[test]
    fn dynamic_callback_builder_with_custom_schema() {
        let custom_schema = Schema::try_from_value(json!({
            "type": "object",
            "properties": {
                "mode": {
                    "type": "string",
                    "enum": ["fast", "slow"]
                }
            },
            "required": ["mode"]
        }))
        .unwrap();
        let callback = DynamicCallback::builder("custom", "Custom schema", |_args| {
            Box::pin(async move { Ok(json!({})) })
        })
        .schema(custom_schema)
        .build();
        let schema = callback.parameters_schema();
        let value = schema.to_value();
        let properties = value.get("properties").unwrap();
        let mode = properties.get("mode").unwrap();
        assert!(mode.get("enum").is_some());
    }

    #[test]
    fn dynamic_callback_custom_schema_overrides_params() {
        let custom_schema = Schema::try_from_value(json!({
            "type": "object",
            "properties": {
                "override": { "type": "string" }
            }
        }))
        .unwrap();
        let callback = DynamicCallback::builder("test", "Test", |_args| {
            Box::pin(async move { Ok(json!({})) })
        })
        .param("ignored", "string", "This should be ignored", true)
        .schema(custom_schema)
        .build();
        let schema = callback.parameters_schema();
        let value = schema.to_value();
        let properties = value.get("properties").unwrap();
        assert!(properties.get("override").is_some());
        assert!(properties.get("ignored").is_none());
    }

    #[tokio::test]
    async fn dynamic_callback_invoke_success() {
        let callback = DynamicCallback::builder("echo", "Echo", |args| {
            Box::pin(async move {
                let msg = args
                    .get("message")
                    .and_then(|v| v.as_str())
                    .unwrap_or("default");
                Ok(json!({ "echoed": msg }))
            })
        })
        .param("message", "string", "Message to echo", true)
        .build();
        let value = callback
            .invoke(json!({ "message": "hello" }))
            .await
            .expect("handler should succeed");
        assert_eq!(value["echoed"], "hello");
    }

    #[tokio::test]
    async fn dynamic_callback_invoke_can_fail() {
        let callback = DynamicCallback::builder("fail", "Always fails", |_args| {
            Box::pin(async move { Err(CallbackError::ExecutionFailed("boom".into())) })
        })
        .build();
        let error = callback
            .invoke(json!({}))
            .await
            .expect_err("handler always fails");
        assert!(matches!(&error, CallbackError::ExecutionFailed(msg) if msg == "boom"));
    }

    #[tokio::test]
    async fn dynamic_callback_invoke_with_validation() {
        let callback = DynamicCallback::builder("validate", "Validates input", |args| {
            Box::pin(async move {
                let value = args
                    .get("value")
                    .and_then(|v| v.as_i64())
                    .ok_or_else(|| CallbackError::InvalidArguments("missing 'value'".into()))?;
                if value < 0 {
                    return Err(CallbackError::InvalidArguments("value must be >= 0".into()));
                }
                Ok(json!({ "validated": value }))
            })
        })
        .param("value", "integer", "Value to validate", true)
        .build();

        let value = callback
            .invoke(json!({ "value": 42 }))
            .await
            .expect("non-negative value should validate");
        assert_eq!(value, json!({ "validated": 42 }));

        let error = callback
            .invoke(json!({ "value": -5 }))
            .await
            .expect_err("negative value should be rejected");
        assert!(matches!(error, CallbackError::InvalidArguments(_)));

        // The schema is not enforced before the handler runs; the handler itself rejects this.
        let error = callback
            .invoke(json!({}))
            .await
            .expect_err("missing value should be rejected");
        assert!(matches!(error, CallbackError::InvalidArguments(_)));
    }

    #[test]
    fn dynamic_callback_is_clone() {
        let callback = DynamicCallback::builder("test", "Test", |_args| {
            Box::pin(async move { Ok(json!({})) })
        })
        .build();
        let cloned = callback.clone();
        assert_eq!(cloned.name(), callback.name());
    }

    #[test]
    fn dynamic_callback_is_debug() {
        let callback = DynamicCallback::builder("test", "Test callback", |_args| {
            Box::pin(async move { Ok(json!({})) })
        })
        .build();
        let debug = format!("{:?}", callback);
        assert!(debug.contains("DynamicCallback"));
        assert!(debug.contains("test"));
    }

    #[test]
    fn dynamic_callback_builder_is_debug() {
        let builder = DynamicCallback::builder("test", "Test", |_args| {
            Box::pin(async move { Ok(json!({})) })
        })
        .param("x", "string", "A param", true);
        let debug = format!("{:?}", builder);
        assert!(debug.contains("DynamicCallbackBuilder"));
    }

    #[test]
    fn dynamic_callback_new_direct_construction() {
        let handler: DynamicHandler = Arc::new(|_args| Box::pin(async move { Ok(json!({})) }));
        let callback =
            DynamicCallback::new("direct", "Directly constructed", Schema::empty(), handler);
        assert_eq!(callback.name(), "direct");
        assert_eq!(callback.description(), "Directly constructed");
    }

    #[test]
    fn empty_schema_returns_valid_schema() {
        let schema = empty_schema();
        let value = schema.to_value();
        assert!(value.is_object() || value.is_boolean());
    }

    // ---- Trait objects ------------------------------------------------------------------

    #[test]
    fn callbacks_can_be_boxed_as_trait_objects() {
        let typed: Box<dyn Callback> = Box::new(EchoCallback);
        let dynamic: Box<dyn Callback> = Box::new(
            DynamicCallback::builder("dyn", "Dynamic", |_| Box::pin(async { Ok(json!({})) }))
                .build(),
        );
        let callbacks: Vec<Box<dyn Callback>> = vec![typed, dynamic];
        assert_eq!(callbacks.len(), 2);
        assert_eq!(callbacks[0].name(), "echo");
        assert_eq!(callbacks[1].name(), "dyn");
    }

    #[tokio::test]
    async fn callback_trait_object_invoke() {
        let callback: Box<dyn Callback> = Box::new(AddCallback);
        let value = callback
            .invoke(json!({ "a": 5, "b": 7 }))
            .await
            .expect("valid args should succeed");
        assert_eq!(value, json!(12));
    }

    // ---- Generated schema contents ------------------------------------------------------

    #[test]
    fn dynamic_callback_supports_all_json_types() {
        let callback =
            DynamicCallback::builder("types", "All types", |_| Box::pin(async { Ok(json!({})) }))
                .param("str_param", "string", "A string", true)
                .param("num_param", "number", "A number", true)
                .param("int_param", "integer", "An integer", true)
                .param("bool_param", "boolean", "A boolean", true)
                .param("obj_param", "object", "An object", false)
                .param("arr_param", "array", "An array", false)
                .build();
        let schema = callback.parameters_schema();
        let value = schema.to_value();
        let properties = value.get("properties").unwrap();
        assert_eq!(properties["str_param"]["type"], "string");
        assert_eq!(properties["num_param"]["type"], "number");
        assert_eq!(properties["int_param"]["type"], "integer");
        assert_eq!(properties["bool_param"]["type"], "boolean");
        assert_eq!(properties["obj_param"]["type"], "object");
        assert_eq!(properties["arr_param"]["type"], "array");
    }

    #[test]
    fn dynamic_callback_param_descriptions_preserved() {
        let callback = DynamicCallback::builder("desc", "Test descriptions", |_| {
            Box::pin(async { Ok(json!({})) })
        })
        .param("field", "string", "This is a detailed description", true)
        .build();
        let schema = callback.parameters_schema();
        let value = schema.to_value();
        let properties = value.get("properties").unwrap();
        assert_eq!(
            properties["field"]["description"],
            "This is a detailed description"
        );
    }
}