use anda_core::{
Agent, AgentOutput, BoxError, CompletionFeatures, CompletionRequest, FunctionDefinition,
Resource, Tool, ToolOutput, root_schema_for,
};
use schemars::JsonSchema;
use serde_json::Value;
use std::marker::PhantomData;
pub use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::context::{AgentCtx, BaseCtx};
/// A tool whose parameter schema is generated from `T`.
///
/// The model "submits" structured data by calling this tool; the call
/// arguments are deserialized into `T` and echoed back as the tool output.
#[derive(Debug, Clone)]
pub struct SubmitTool<T: JsonSchema + DeserializeOwned + Send + Sync> {
/// Lower-cased schema `title` (or `"tool"` when absent); the exposed tool
/// name is `submit_{name}`.
name: String,
/// JSON schema of `T`, sent as the function definition's `parameters`.
schema: Value,
/// Marker tying the tool to `T` without storing a value of it.
_t: PhantomData<T>,
}
impl<T> Default for SubmitTool<T>
where
T: JsonSchema + DeserializeOwned + Serialize + Send + Sync,
{
fn default() -> Self {
Self::new()
}
}
impl<T> SubmitTool<T>
where
    T: JsonSchema + DeserializeOwned + Serialize + Send + Sync,
{
    /// Builds a submit tool whose parameter schema is generated from `T`.
    ///
    /// The tool's base name is the schema's `title`, ASCII lower-cased,
    /// falling back to `"tool"` when the schema has no string title.
    pub fn new() -> SubmitTool<T> {
        let root = root_schema_for::<T>();
        let title = match root.get("title").and_then(Value::as_str) {
            Some(found) => found,
            None => "tool",
        };
        let name = title.to_ascii_lowercase();

        SubmitTool {
            name,
            schema: root.to_value(),
            _t: PhantomData,
        }
    }

    /// Deserializes raw JSON tool-call arguments into `T`.
    ///
    /// # Errors
    /// Returns an error whose message starts with `invalid args:` when `args`
    /// does not match `T`.
    pub fn submit(&self, args: Value) -> Result<T, BoxError> {
        match serde_json::from_value::<T>(args) {
            Ok(parsed) => Ok(parsed),
            Err(err) => Err(format!("invalid args: {err}").into()),
        }
    }
}
impl<T> Tool<BaseCtx> for SubmitTool<T>
where
    T: JsonSchema + DeserializeOwned + Serialize + Send + Sync,
{
    type Args = T;
    type Output = T;

    /// The public tool name: `submit_` followed by the schema-derived base name.
    fn name(&self) -> String {
        ["submit_", self.name.as_str()].concat()
    }

    fn description(&self) -> String {
        String::from("Submit the structured data you extracted from the provided text.")
    }

    /// Function definition advertised to the model; the parameters are the
    /// JSON schema generated for `T`, with strict mode requested.
    fn definition(&self) -> FunctionDefinition {
        let parameters = self.schema.clone();
        FunctionDefinition {
            name: self.name(),
            description: self.description(),
            parameters,
            strict: Some(true),
        }
    }

    /// Echoes the already-deserialized arguments back as the tool output.
    async fn call(
        &self,
        _ctx: BaseCtx,
        args: Self::Args,
        _resources: Vec<Resource>,
    ) -> Result<ToolOutput<Self::Output>, BoxError> {
        let output = ToolOutput::new(args);
        Ok(output)
    }
}
/// Extracts a `T` from free text by requiring an LLM to call a [`SubmitTool<T>`].
#[derive(Debug, Clone)]
pub struct Extractor<T: JsonSchema + DeserializeOwned + Serialize + Send + Sync> {
/// Submit tool whose definition is offered to the model and whose schema parses the result.
tool: SubmitTool<T>,
/// System instructions sent with every completion request.
instructions: String,
/// Optional output-token cap, forwarded as `max_output_tokens`.
max_tokens: Option<usize>,
}
impl<T: JsonSchema + DeserializeOwned + Serialize + Send + Sync> Default for Extractor<T> {
fn default() -> Self {
Self::new(None, None)
}
}
impl<T: JsonSchema + DeserializeOwned + Serialize + Send + Sync> Extractor<T> {
    /// Creates an extractor backed by a [`SubmitTool`] derived from `T`'s JSON schema.
    ///
    /// * `max_tokens` - optional cap forwarded as `max_output_tokens` on each request.
    /// * `system_prompt` - custom instructions; when `None`, a default prompt that
    ///   names the submit tool is generated.
    pub fn new(max_tokens: Option<usize>, system_prompt: Option<String>) -> Self {
        let tool = SubmitTool::new();
        Self::new_with_tool(tool, max_tokens, system_prompt)
    }

    /// Creates an extractor around an existing submit tool.
    ///
    /// When `instructions` is `None`, a default system prompt is generated that
    /// tells the model to always call the tool by its function name.
    pub fn new_with_tool(
        tool: SubmitTool<T>,
        max_tokens: Option<usize>,
        instructions: Option<String>,
    ) -> Self {
        let instructions =
            instructions.unwrap_or_else(|| Self::default_instructions(&tool.name()));
        Self {
            tool,
            max_tokens,
            instructions,
        }
    }

    /// Builds the default system prompt for a submit function named `tool_name`.
    fn default_instructions(tool_name: &str) -> String {
        // A trailing `\` in a string literal swallows the newline AND the next
        // line's leading whitespace, so any word separator must appear before it.
        format!(
            "\
            You are an AI assistant whose purpose is to \
            extract structured data from the provided text.\n\
            You will have access to a `{tool_name}` function that defines the structure of the data to extract from the provided text.\n\
            Use the `{tool_name}` function to submit the structured data.\n\
            Be sure to fill out every field and ALWAYS CALL THE `{tool_name}` function, even with default values!!!."
        )
    }

    /// Runs a single completion that requires a tool call, then parses the
    /// first tool call's arguments into `T`.
    ///
    /// Returns the parsed value together with the raw completion output.
    ///
    /// # Errors
    /// Fails if the completion call fails, if the output carries a
    /// `failed_reason`, if the model made no tool calls, or if the arguments do
    /// not deserialize into `T` (message prefixed with `invalid args:`).
    pub async fn extract(
        &self,
        ctx: &impl CompletionFeatures,
        prompt: String,
    ) -> Result<(T, AgentOutput), BoxError> {
        let req = CompletionRequest {
            instructions: self.instructions.clone(),
            prompt,
            tools: vec![self.tool.definition()],
            tool_choice_required: true,
            max_output_tokens: self.max_tokens,
            ..Default::default()
        };

        let res = ctx.completion(req, Vec::new()).await?;
        if let Some(failed) = res.failed_reason {
            return Err(failed.into());
        }

        // Only the submit tool is offered and a tool call is required, so the
        // first call is treated as the submission.
        let Some(call) = res.tool_calls.first() else {
            return Err(format!("extract with {} failed, no tool_calls", self.tool.name()).into());
        };
        let result = self.tool.submit(call.args.clone())?;
        Ok((result, res))
    }
}
impl<T> Agent<AgentCtx> for Extractor<T>
where
    T: JsonSchema + DeserializeOwned + Serialize + Send + Sync,
{
    /// Agent name: the submit tool's base name followed by `_extractor`.
    fn name(&self) -> String {
        let mut agent_name = self.tool.name.clone();
        agent_name.push_str("_extractor");
        agent_name
    }

    fn description(&self) -> String {
        String::from("Extract structured data from text using LLMs.")
    }

    /// Runs [`Extractor::extract`] and returns only the raw completion output,
    /// discarding the parsed value.
    async fn run(
        &self,
        ctx: AgentCtx,
        prompt: String,
        _resources: Vec<Resource>,
    ) -> Result<AgentOutput, BoxError> {
        self.extract(&ctx, prompt)
            .await
            .map(|(_parsed, output)| output)
    }
}
#[cfg(test)]
mod tests {
use anda_core::{AgentContext, AgentInput, ToolInput};
use serde_json::json;
use std::sync::Arc;
use super::*;
use crate::{engine::EngineBuilder, model::Model};
// Sample target type: `name` is required; `age` is optional and therefore
// expected to be absent from the schema's `required` list.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
struct TestStruct {
name: String,
age: Option<u8>,
}
// Checks the generated function definitions of the submit tool and the extractor agent.
#[test]
fn test_definition() {
// Submit tool: name is `submit_` + lower-cased struct name; schema marks
// only `name` as required and has no `$schema` key.
let tool = SubmitTool::<TestStruct>::new();
let definition = tool.definition();
assert_eq!(definition.name, "submit_teststruct");
let s = serde_json::to_string(&definition).unwrap();
println!("{}", s);
assert!(s.contains(r#""required":["name"]"#));
assert!(!s.contains("$schema"));
// Agent: `definition()` is not defined in this file (presumably a default
// method on `Agent` — TODO confirm); its parameters describe a single
// non-empty `prompt` string.
let agent = Extractor::<TestStruct>::default();
let definition = agent.definition();
assert_eq!(definition.name, "teststruct_extractor");
let s = serde_json::to_string(&definition).unwrap();
println!("{}", s);
assert_eq!(
definition.parameters["description"],
json!(
"Run this agent on a focused task. Provide a self-contained prompt with the goal, relevant context, constraints, and expected output."
)
);
assert_eq!(
definition.parameters["properties"]["prompt"]["minLength"],
json!(1)
);
assert_eq!(definition.parameters["additionalProperties"], json!(false));
assert!(!s.contains("$schema"));
}
// End-to-end: registers the tool and agent on an engine backed by a mock model
// and exercises both through the context.
#[tokio::test]
async fn test_with_ctx() {
let tool = SubmitTool::<TestStruct>::default();
let agent = Extractor::<TestStruct>::default();
let tool_name = tool.name();
let agent_name = agent.name();
let ctx = EngineBuilder::new()
.with_model(Model::mock_implemented())
.register_tool(Arc::new(tool))
.unwrap()
.register_agent(Arc::new(agent), None)
.unwrap()
.mock_ctx();
// Valid arguments are echoed back unchanged by `SubmitTool::call`.
let (res, _) = ctx
.tool_call(ToolInput::new(
tool_name.clone(),
json!({"name":"Anda","age": 1}),
))
.await
.unwrap();
assert_eq!(res.output, json!({"name":"Anda","age": 1}));
// An omitted optional field round-trips as `null`.
let (res, _) = ctx
.tool_call(ToolInput::new(tool_name.clone(), json!({"name": "Anda"})))
.await
.unwrap();
assert_eq!(res.output, json!({"name": "Anda","age": null}));
// A wrongly-typed field is rejected with an "invalid args" error.
let res = ctx
.tool_call(ToolInput::new(tool_name.clone(), json!({"name": 123})))
.await;
assert!(res.is_err());
assert!(res.unwrap_err().to_string().contains("invalid args"));
// Agent runs: the mock model presumably turns the prompt text into the
// submit tool's call arguments (TODO confirm against `Model::mock_implemented`),
// so a valid JSON prompt succeeds and an ill-typed one surfaces the
// `invalid args` error from `SubmitTool::submit`.
let _res = ctx
.clone()
.agent_run(AgentInput::new(
agent_name.to_string(),
r#"{"name": "Anda"}"#.into(),
))
.await
.unwrap();
let res = ctx
.agent_run(AgentInput::new(
agent_name.to_string(),
r#"{"name": 123}"#.into(),
))
.await;
assert!(res.is_err());
assert!(res.unwrap_err().to_string().contains("invalid args"));
}
}