use std::sync::Arc;
use async_trait::async_trait;
use entelix_core::AgentContext;
use entelix_core::LlmFacingSchema;
use entelix_core::error::{Error, Result};
use entelix_core::tools::{RetryHint, Tool, ToolEffect, ToolMetadata};
use schemars::JsonSchema;
use serde::Serialize;
use serde::de::DeserializeOwned;
/// A tool whose arguments and result are strongly typed Rust values.
///
/// The argument type is deserialized from JSON and also supplies the tool's
/// advertised input schema via [`JsonSchema`]; the result type is serialized
/// back to JSON. Call [`SchemaToolExt::into_adapter`] to obtain a
/// [`SchemaToolAdapter`], which exposes the implementor through the untyped
/// [`Tool`] interface.
///
/// Only [`SchemaTool::NAME`], [`SchemaTool::description`] and
/// [`SchemaTool::execute`] must be provided; every other method has a default.
#[async_trait]
pub trait SchemaTool: Send + Sync + 'static {
    /// Typed arguments. Deserialized from the caller's JSON and used to
    /// generate the input schema recorded in the tool metadata.
    type Input: DeserializeOwned + JsonSchema + Send + 'static;

    /// Typed result, serialized to JSON after a successful call.
    type Output: Serialize + Send + 'static;

    /// Tool name recorded in the metadata and quoted in input-validation errors.
    const NAME: &'static str;

    /// Human-readable description recorded in the tool metadata.
    fn description(&self) -> &str;

    /// Runs the tool on already-deserialized input.
    async fn execute(&self, input: Self::Input, ctx: &AgentContext<()>) -> Result<Self::Output>;

    /// Side-effect classification recorded in the metadata.
    /// Defaults to `ToolEffect::default()`.
    fn effect(&self) -> ToolEffect {
        Default::default()
    }

    /// Idempotency flag recorded in the metadata. Defaults to `false`.
    fn idempotent(&self) -> bool {
        false
    }

    /// Optional version string; recorded in the metadata only when `Some`.
    fn version(&self) -> Option<&str> {
        None
    }

    /// Optional retry hint; recorded in the metadata only when `Some`.
    fn retry_hint(&self) -> Option<RetryHint> {
        None
    }

    /// Optional JSON Schema describing the output. When `Some`, it is passed
    /// through `LlmFacingSchema::strip` before being recorded in the metadata.
    fn output_schema(&self) -> Option<serde_json::Value> {
        None
    }
}
/// Extension methods available on every [`SchemaTool`].
pub trait SchemaToolExt: SchemaTool + Sized {
    /// Consumes `self` and wraps it in a [`SchemaToolAdapter`], computing the
    /// tool's [`ToolMetadata`] once up front.
    ///
    /// The returned adapter is the only owner of the tool, so discarding it
    /// discards the tool as well; hence `#[must_use]`.
    #[must_use = "the adapter owns the wrapped tool; dropping it discards the tool"]
    fn into_adapter(self) -> SchemaToolAdapter<Self> {
        SchemaToolAdapter::new(self)
    }
}

// Blanket impl: every `SchemaTool` gets `into_adapter` automatically.
impl<T: SchemaTool> SchemaToolExt for T {}
/// Exposes a typed [`SchemaTool`] through the untyped, JSON-based [`Tool`] trait.
///
/// Constructed with [`SchemaToolExt::into_adapter`]; the metadata is computed
/// once when the adapter is built.
pub struct SchemaToolAdapter<T>
where
    T: SchemaTool,
{
    /// The wrapped typed tool.
    inner: T,
    /// Metadata derived from `inner` at construction time.
    metadata: Arc<ToolMetadata>,
}
impl<T: SchemaTool> SchemaToolAdapter<T> {
    /// Wraps `inner`, deriving the tool metadata from its trait methods.
    ///
    /// The input schema is generated from `T::Input` with `schemars` and passed
    /// through [`LlmFacingSchema::strip`]. The version, retry hint and output
    /// schema are only set when the tool returns `Some`; an output schema is
    /// stripped the same way as the input schema.
    fn new(inner: T) -> Self {
        let generated = schemars::schema_for!(T::Input);
        let raw_schema: serde_json::Value = generated.to_value();
        let input_schema = LlmFacingSchema::strip(&raw_schema);

        let metadata = ToolMetadata::function(T::NAME, inner.description(), input_schema)
            .with_effect(inner.effect())
            .with_idempotent(inner.idempotent());

        // Optional fields: apply each override only when the tool provides one.
        let metadata = match inner.version() {
            Some(version) => metadata.with_version(version),
            None => metadata,
        };
        let metadata = match inner.retry_hint() {
            Some(hint) => metadata.with_retry_hint(hint),
            None => metadata,
        };
        let metadata = match inner.output_schema() {
            Some(schema) => metadata.with_output_schema(LlmFacingSchema::strip(&schema)),
            None => metadata,
        };

        Self {
            inner,
            metadata: Arc::new(metadata),
        }
    }

    /// Borrows the wrapped typed tool.
    pub const fn inner(&self) -> &T {
        &self.inner
    }
}
/// Prints the tool name from the metadata and the wrapped type's name, so the
/// wrapped `T` is not required to implement [`std::fmt::Debug`].
impl<T: SchemaTool> std::fmt::Debug for SchemaToolAdapter<T> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner_type = std::any::type_name::<T>();
        let mut builder = formatter.debug_struct("SchemaToolAdapter");
        builder.field("name", &self.metadata.name);
        builder.field("inner", &inner_type);
        builder.finish()
    }
}
#[async_trait]
impl<T: SchemaTool> Tool for SchemaToolAdapter<T> {
fn metadata(&self) -> &ToolMetadata {
&self.metadata
}
async fn execute(
&self,
input: serde_json::Value,
ctx: &AgentContext<()>,
) -> Result<serde_json::Value> {
let typed: T::Input = serde_json::from_value(input).map_err(|e| {
Error::invalid_request(format!(
"tool '{name}': input did not match schema: {e}",
name = T::NAME,
))
})?;
let output = self.inner.execute(typed, ctx).await?;
serde_json::to_value(output).map_err(Error::from)
}
}
// Unit tests for `SchemaToolAdapter`: typed JSON round-trip, input-error
// surfacing, and propagation of every `SchemaTool` override into metadata.
#[cfg(test)]
// unwrap/expect are acceptable here: a failure should panic the test.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use serde_json::json;

    // Shared fixture input: a single required integer field `n`.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct DoubleInput {
        n: i64,
    }

    // Shared fixture output.
    #[derive(Debug, Serialize, JsonSchema)]
    struct DoubleOutput {
        doubled: i64,
    }

    // Minimal tool that relies on every default method.
    #[derive(Debug)]
    struct DoubleTool;

    #[async_trait]
    impl SchemaTool for DoubleTool {
        type Input = DoubleInput;
        type Output = DoubleOutput;
        const NAME: &'static str = "double";
        fn description(&self) -> &str {
            "Doubles an integer."
        }
        async fn execute(
            &self,
            input: Self::Input,
            _ctx: &AgentContext<()>,
        ) -> Result<Self::Output> {
            Ok(DoubleOutput {
                doubled: input.n * 2,
            })
        }
    }

    // Overrides `version` and `effect`; only its metadata is asserted on.
    #[derive(Debug)]
    struct VersionedTool;

    #[async_trait]
    impl SchemaTool for VersionedTool {
        type Input = DoubleInput;
        type Output = DoubleOutput;
        const NAME: &'static str = "versioned";
        fn description(&self) -> &str {
            "Versioned tool."
        }
        fn version(&self) -> Option<&str> {
            Some("1.2.3")
        }
        fn effect(&self) -> ToolEffect {
            ToolEffect::Mutating
        }
        async fn execute(
            &self,
            input: Self::Input,
            _ctx: &AgentContext<()>,
        ) -> Result<Self::Output> {
            Ok(DoubleOutput {
                doubled: input.n + 1,
            })
        }
    }

    // Overrides `retry_hint` and `output_schema`; note it does NOT override
    // `idempotent`, so the trait default (`false`) applies.
    #[derive(Debug)]
    struct RetryableTool;

    #[async_trait]
    impl SchemaTool for RetryableTool {
        type Input = DoubleInput;
        type Output = DoubleOutput;
        const NAME: &'static str = "retryable";
        fn description(&self) -> &str {
            "Retryable tool."
        }
        fn retry_hint(&self) -> Option<RetryHint> {
            Some(RetryHint::idempotent_transport())
        }
        fn output_schema(&self) -> Option<serde_json::Value> {
            Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "doubled": { "type": "integer" }
                },
                "required": ["doubled"]
            }))
        }
        async fn execute(
            &self,
            input: Self::Input,
            _ctx: &AgentContext<()>,
        ) -> Result<Self::Output> {
            Ok(DoubleOutput { doubled: input.n })
        }
    }

    // JSON in -> typed input -> typed output -> JSON out.
    #[tokio::test]
    async fn typed_round_trip_through_adapter() {
        let adapter = DoubleTool.into_adapter();
        let ctx = AgentContext::default();
        let out = adapter.execute(json!({"n": 21}), &ctx).await.unwrap();
        assert_eq!(out, json!({"doubled": 42}));
    }

    // A payload missing the required `n` must error with a message that names
    // the tool but not the Rust input type.
    #[tokio::test]
    async fn malformed_input_surfaces_invalid_request() {
        let adapter = DoubleTool.into_adapter();
        let ctx = AgentContext::default();
        let err = adapter
            .execute(json!({"wrong_field": 21}), &ctx)
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("double"), "{msg}");
        assert!(msg.contains("input did not match schema"), "{msg}");
        assert!(
            !msg.contains("DoubleInput"),
            "internal type name must not surface to the model: {msg}"
        );
    }

    // Name, description and the schemars-derived input schema land in metadata.
    #[test]
    fn metadata_carries_autogenerated_input_schema() {
        let adapter = DoubleTool.into_adapter();
        let meta = adapter.metadata();
        assert_eq!(meta.name, "double");
        assert_eq!(meta.description, "Doubles an integer.");
        let schema_str = meta.input_schema.to_string();
        assert!(schema_str.contains("\"n\""), "{schema_str}");
    }

    #[test]
    fn metadata_propagates_effect_and_version() {
        let adapter = VersionedTool.into_adapter();
        let meta = adapter.metadata();
        assert_eq!(meta.effect, ToolEffect::Mutating);
        assert_eq!(meta.version.as_deref(), Some("1.2.3"));
    }

    // With no overrides, effect is ReadOnly and all optional fields are unset.
    #[test]
    fn defaults_apply_when_overrides_absent() {
        let adapter = DoubleTool.into_adapter();
        let meta = adapter.metadata();
        assert_eq!(meta.effect, ToolEffect::ReadOnly);
        assert!(meta.version.is_none());
        assert!(meta.retry_hint.is_none());
        assert!(meta.output_schema.is_none());
    }

    // NOTE(review): `RetryableTool` keeps the default `idempotent() == false`,
    // and the adapter calls `with_idempotent(false)` before `with_retry_hint`.
    // The `meta.idempotent` assertion therefore relies on
    // `ToolMetadata::with_retry_hint` setting idempotency from the hint.
    // Confirm that coupling is intended.
    #[test]
    fn metadata_propagates_retry_hint() {
        let adapter = RetryableTool.into_adapter();
        let meta = adapter.metadata();
        assert!(meta.retry_hint.is_some());
        assert!(meta.idempotent);
    }

    #[test]
    fn metadata_propagates_output_schema() {
        let adapter = RetryableTool.into_adapter();
        let meta = adapter.metadata();
        let schema = meta
            .output_schema
            .as_ref()
            .expect("output_schema override should land in metadata");
        let schema_str = schema.to_string();
        assert!(schema_str.contains("doubled"), "{schema_str}");
    }

    // Carries a field so the test can tell whether `inner()` returns the
    // exact value that was wrapped.
    #[derive(Debug, Default, PartialEq, Eq)]
    struct StatefulTool {
        marker: u32,
    }

    #[async_trait]
    impl SchemaTool for StatefulTool {
        type Input = DoubleInput;
        type Output = DoubleOutput;
        const NAME: &'static str = "stateful";
        fn description(&self) -> &str {
            "Stateful tool."
        }
        async fn execute(
            &self,
            input: Self::Input,
            _ctx: &AgentContext<()>,
        ) -> Result<Self::Output> {
            Ok(DoubleOutput { doubled: input.n })
        }
    }

    // `inner()` exposes the wrapped instance's state unchanged.
    #[test]
    fn inner_preserves_wrapped_instance_identity() {
        let adapter = StatefulTool {
            marker: 0xDEAD_BEEF,
        }
        .into_adapter();
        assert_eq!(adapter.inner().marker, 0xDEAD_BEEF);
    }
}