use std::future::Future;
use std::pin::Pin;
use rig::completion::ToolDefinition;
use rig::tool::{ToolDyn, ToolError};
use serde_json::Value;
/// Heap-allocated, pinned, `Send` future that may borrow data for `'a`.
///
/// Used as the return type of the `ToolDyn` methods implemented in this file.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
/// A "submit" tool whose parameter schema is supplied at runtime.
///
/// The tool advertises `schema` as its parameters and hands the model's
/// arguments back unchanged when called, so the caller can read the
/// structured result.
#[derive(Clone, Debug)]
pub struct DynamicSubmitTool {
    /// Tool name reported to the model.
    name: String,
    /// JSON schema exposed as the tool definition's `parameters`.
    schema: Value,
}
impl DynamicSubmitTool {
    /// Creates a submit tool named `"submit_result"` whose parameters are
    /// described by `schema`.
    pub fn new(schema: Value) -> Self {
        Self::with_name("submit_result", schema)
    }

    /// Creates a submit tool with a caller-chosen `name` and parameter `schema`.
    pub fn with_name(name: impl Into<String>, schema: Value) -> Self {
        let owned_name: String = name.into();
        Self {
            name: owned_name,
            schema,
        }
    }

    /// Returns the JSON schema this tool advertises as its parameters.
    pub fn schema(&self) -> &Value {
        &self.schema
    }
}
impl ToolDyn for DynamicSubmitTool {
    /// Returns the tool name reported to the model.
    fn name(&self) -> String {
        self.name.clone()
    }

    /// Builds the tool definition, exposing the stored JSON schema verbatim as
    /// the tool's `parameters`.
    ///
    /// `_prompt` is ignored: the definition does not vary per prompt.
    fn definition(&self, _prompt: String) -> BoxFuture<'_, ToolDefinition> {
        let def = ToolDefinition {
            name: self.name.clone(),
            description: "Submit the structured result. You MUST call this tool with your \
                          response formatted as JSON matching the provided schema."
                .to_string(),
            parameters: self.schema.clone(),
        };
        Box::pin(async move { def })
    }

    /// Accepts the model's submission and returns `args` unchanged.
    ///
    /// Only well-formedness is checked: `args` must parse as JSON. It is NOT
    /// validated against `self.schema`; callers needing schema conformance must
    /// check the returned string themselves.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::ToolCallError` wrapping an `InvalidData` I/O error if
    /// `args` is not valid JSON. The message is prefixed with this tool's name,
    /// so tools created via `with_name` report their own name rather than
    /// `submit_result`.
    fn call(&self, args: String) -> BoxFuture<'_, Result<String, ToolError>> {
        Box::pin(async move {
            let _: Value = serde_json::from_str(&args).map_err(|e| {
                ToolError::ToolCallError(Box::new(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("{}: invalid JSON: {}", self.name, e),
                )))
            })?;
            Ok(args)
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rig::tool::ToolDyn;
    use serde_json::json;

    /// Minimal `{"type": "object"}` schema for tests that don't inspect the schema.
    fn any_object_schema() -> Value {
        json!({ "type": "object" })
    }

    #[test]
    fn submit_tool_has_correct_name() {
        let tool = DynamicSubmitTool::new(json!({
            "type": "object",
            "properties": { "name": { "type": "string" } },
            "required": ["name"]
        }));
        assert_eq!(tool.name(), "submit_result");
    }

    #[tokio::test]
    async fn submit_tool_definition_has_schema_as_parameters() {
        let expected = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer" }
            },
            "required": ["name"]
        });
        let tool = DynamicSubmitTool::new(expected.clone());

        let definition = tool.definition(String::from("test prompt")).await;

        assert_eq!(definition.name, "submit_result");
        assert_eq!(definition.parameters, expected);
        assert!(definition.description.contains("Submit"));
    }

    #[tokio::test]
    async fn submit_tool_call_returns_args_as_is() {
        let tool = DynamicSubmitTool::new(any_object_schema());
        let payload = r#"{"name": "Alice", "age": 30}"#;

        let echoed = tool.call(payload.to_owned()).await.unwrap();

        assert_eq!(echoed, payload);
    }

    #[tokio::test]
    async fn submit_tool_call_rejects_invalid_json() {
        let tool = DynamicSubmitTool::new(any_object_schema());

        let outcome = tool.call(String::from("not json")).await;

        assert!(outcome.is_err());
    }

    #[test]
    fn submit_tool_custom_name() {
        let tool = DynamicSubmitTool::with_name("output_json", any_object_schema());
        assert_eq!(tool.name(), "output_json");
    }

    #[test]
    fn submit_tool_schema_accessor() {
        let expected = json!({
            "type": "object",
            "properties": { "x": { "type": "integer" } }
        });
        let tool = DynamicSubmitTool::new(expected.clone());
        assert_eq!(tool.schema(), &expected);
    }

    #[tokio::test]
    async fn submit_tool_with_nested_schema() {
        let address = json!({
            "type": "object",
            "properties": { "city": { "type": "string" } }
        });
        let user = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "address": address
            }
        });
        let expected = json!({
            "type": "object",
            "properties": { "user": user }
        });
        let tool = DynamicSubmitTool::new(expected.clone());

        let definition = tool.definition(String::from("test")).await;

        assert_eq!(definition.parameters, expected);
    }

    #[tokio::test]
    async fn submit_tool_with_array_schema() {
        let tool = DynamicSubmitTool::new(json!({
            "type": "object",
            "properties": {
                "items": {
                    "type": "array",
                    "items": { "type": "string" }
                }
            }
        }));
        let payload = r#"{"items": ["a", "b", "c"]}"#;

        let echoed = tool.call(payload.to_owned()).await.unwrap();
        let parsed: Value = serde_json::from_str(&echoed).unwrap();
        let items = parsed["items"].as_array().unwrap();

        assert_eq!(items.len(), 3);
    }

    #[tokio::test]
    async fn submit_tool_preserves_exact_json() {
        let tool = DynamicSubmitTool::new(any_object_schema());
        let payload = r#"{"key":"value","num":42}"#;

        let echoed = tool.call(payload.to_owned()).await.unwrap();

        assert_eq!(echoed, payload);
    }
}