use crate::schemas::structured_output::{
validate_against_schema, StructuredOutputError, StructuredOutputStrategy, ToolStrategy,
};
use crate::tools::{Tool, ToolResult, ToolRuntime};
use async_trait::async_trait;
use serde_json::{json, Value};
use std::error::Error;
use std::sync::Arc;
/// Tool wrapper around a [`ToolStrategy`] for the schema type `T`.
///
/// The [`Tool`] implementation for this type validates tool input against the
/// strategy's schema and returns the input re-serialized as a JSON string.
pub struct StructuredOutputTool<T> {
/// Strategy that supplies the schema (and the schema name) for `T`.
strategy: ToolStrategy<T>,
/// Schema name captured from `strategy` at construction; also used as the tool name.
schema_name: String,
}
impl<T> StructuredOutputTool<T>
where
    T: crate::schemas::StructuredOutputSchema,
{
    /// Builds a tool from `strategy`, storing the schema name it reports.
    pub fn new(strategy: ToolStrategy<T>) -> Self {
        // Field initializers run in source order, so the name is read from
        // `strategy` before `strategy` itself is moved into the struct.
        Self {
            schema_name: strategy.schema_name(),
            strategy,
        }
    }

    /// Returns the schema produced by the wrapped strategy.
    pub fn schema(&self) -> Value {
        let strategy = &self.strategy;
        strategy.schema()
    }

    /// Returns the schema name captured when the tool was constructed.
    pub fn schema_name(&self) -> &str {
        self.schema_name.as_str()
    }
}
#[async_trait]
impl<T> Tool for StructuredOutputTool<T>
where
T: crate::schemas::StructuredOutputSchema + Send + Sync + 'static,
{
fn name(&self) -> String {
self.schema_name.clone()
}
fn description(&self) -> String {
format!(
"Tool for returning structured output in the format: {}",
self.schema_name
)
}
fn parameters(&self) -> Value {
let mut schema = self.schema();
if let Some(schema_obj) = schema.as_object_mut() {
schema_obj.remove("$schema");
}
json!({
"type": "object",
"properties": schema,
"required": self.get_required_fields()
})
}
async fn run(&self, input: Value) -> Result<String, crate::error::ToolError> {
validate_against_schema(&input, &self.schema())
.map_err(|e| crate::error::ToolError::InvalidInputError(e.to_string()))?;
serde_json::to_string(&input)
.map_err(|e| crate::error::ToolError::ExecutionError(e.to_string()))
}
async fn run_with_runtime(
&self,
input: Value,
_runtime: &ToolRuntime,
) -> Result<ToolResult, Box<dyn Error>> {
let result = self.run(input).await?;
Ok(ToolResult::Text(result))
}
fn requires_runtime(&self) -> bool {
false
}
}
impl<T> StructuredOutputTool<T>
where
T: crate::schemas::StructuredOutputSchema,
{
fn get_required_fields(&self) -> Vec<String> {
if let Some(schema_obj) = self.schema().as_object() {
if let Some(required) = schema_obj.get("required").and_then(|v| v.as_array()) {
return required
.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect();
}
}
Vec::new()
}
}
/// Decodes the raw arguments of a structured-output tool call and validates them.
///
/// `tool_input` is parsed as JSON and checked against `_strategy.schema()` with
/// `validate_against_schema`; the parsed value is returned on success.
/// `_tool_name` is not used by this function.
///
/// # Errors
///
/// * `StructuredOutputError::ParseError` if `tool_input` is not valid JSON.
/// * The error reported by `validate_against_schema`, converted by `?`, if the
///   parsed value does not satisfy the schema.
pub async fn handle_structured_output_tool_call(
    _tool_name: &str,
    tool_input: &str,
    _strategy: &dyn StructuredOutputStrategy,
) -> Result<Value, StructuredOutputError> {
    let payload = match serde_json::from_str::<Value>(tool_input) {
        Ok(value) => value,
        Err(parse_err) => {
            return Err(StructuredOutputError::ParseError(format!(
                "Failed to parse tool input: {}",
                parse_err
            )));
        }
    };

    let schema = _strategy.schema();
    validate_against_schema(&payload, &schema)?;
    Ok(payload)
}
/// Placeholder factory that never succeeds.
///
/// # Errors
///
/// Always returns `StructuredOutputError::SchemaError`; `_strategy` is
/// accepted but not used.
// The boxed argument is part of the public signature and is never unboxed
// here, so clippy's `boxed_local` lint is silenced rather than acted on.
#[allow(clippy::boxed_local)]
pub fn create_structured_output_tool<S>(
    _strategy: Box<S>,
) -> Result<Arc<dyn Tool>, StructuredOutputError>
where
    S: StructuredOutputStrategy,
{
    let reason = String::from("Cannot create tool from trait object - use concrete type");
    Err(StructuredOutputError::SchemaError(reason))
}
/// Error describing the case where more than one structured output was returned.
///
/// Its `Display` text is `Multiple structured outputs returned: ` followed by
/// the `Debug` rendering of `tool_names`.
#[derive(Debug, thiserror::Error)]
#[error("Multiple structured outputs returned: {tool_names:?}")]
pub struct MultipleStructuredOutputsError {
/// Tool names listed in the error message.
pub tool_names: Vec<String>,
}
/// Error for a structured output that failed validation.
///
/// Only `message` appears in the `Display` text; `tool_name` is carried as
/// additional context for callers.
#[derive(Debug, thiserror::Error)]
#[error("Structured output validation failed: {message}")]
pub struct StructuredOutputValidationError {
/// Description of the validation failure; shown in the error message.
pub message: String,
/// Name of the tool associated with the failure; not part of the message.
pub tool_name: String,
}
/// Converts a JSON (de)serialization error into a validation error.
///
/// The source error carries no tool name, so `tool_name` is set to the
/// placeholder `"unknown"`.
impl From<serde_json::Error> for StructuredOutputValidationError {
    fn from(err: serde_json::Error) -> Self {
        let message = err.to_string();
        let tool_name = String::from("unknown");
        Self { message, tool_name }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schemas::structured_output::{StructuredOutputSchema, ToolStrategy};
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    /// Minimal schema type used as the structured-output shape in these tests.
    #[derive(Serialize, Deserialize, JsonSchema, Debug)]
    struct TestOutput {
        name: String,
        age: i32,
    }

    impl StructuredOutputSchema for TestOutput {}

    /// Builds a tool over a fresh `ToolStrategy<TestOutput>`.
    fn make_tool() -> StructuredOutputTool<TestOutput> {
        StructuredOutputTool::new(ToolStrategy::<TestOutput>::new())
    }

    /// Running the tool with input matching `TestOutput` succeeds.
    #[tokio::test]
    async fn test_structured_output_tool() {
        let tool = make_tool();
        let payload = json!({ "name": "John", "age": 30 });

        let outcome = tool.run(payload).await;

        assert!(outcome.is_ok());
    }

    /// The required-field list is either empty or contains `name`.
    ///
    /// Note that this assertion also passes when no required fields are
    /// reported at all.
    #[test]
    fn test_get_required_fields() {
        let fields = make_tool().get_required_fields();
        let has_name = fields.iter().any(|field| field == "name");

        assert!(fields.is_empty() || has_name);
    }
}