use anda_core::{
Agent, AgentOutput, BoxError, CompletionFeatures, CompletionRequest, FunctionDefinition,
Resource, Tool, ToolOutput, root_schema_for,
};
use schemars::JsonSchema;
use serde_json::Value;
use std::marker::PhantomData;
pub use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::context::{AgentCtx, BaseCtx};
/// A function-calling tool whose parameter schema is generated from `T`.
///
/// Invoking the tool does no work of its own: it hands the model-supplied
/// arguments straight back as a `T`.
#[derive(Debug, Clone)]
pub struct SubmitTool<T>
where
    T: JsonSchema + DeserializeOwned + Send + Sync,
{
    /// Base name used to build the tool's function name (derived in `new`).
    name: String,
    /// JSON schema of `T`, sent to the model as the function parameters.
    schema: Value,
    /// Ties the tool to `T` without storing a value of it.
    _t: PhantomData<T>,
}
impl<T> Default for SubmitTool<T>
where
T: JsonSchema + DeserializeOwned + Serialize + Send + Sync,
{
fn default() -> Self {
Self::new()
}
}
impl<T> SubmitTool<T>
where
    T: JsonSchema + DeserializeOwned + Serialize + Send + Sync,
{
    /// Builds a submit tool whose parameters are the JSON schema of `T`.
    ///
    /// The base name is taken from the schema's `title`. It is lowercased, and
    /// every character that is not an ASCII letter, digit or `_` is replaced
    /// with `_`, so the name stays within the `[a-z0-9_]` set commonly accepted
    /// for function names. A custom title such as "User Profile" becomes
    /// `user_profile`. When the schema has no title, or the title is empty,
    /// the name falls back to `"tool"`.
    pub fn new() -> SubmitTool<T> {
        let schema = root_schema_for::<T>();
        let name = schema
            .get("title")
            .and_then(Value::as_str)
            .map(|title| {
                title
                    .chars()
                    .map(|c| {
                        if c.is_ascii_alphanumeric() || c == '_' {
                            c.to_ascii_lowercase()
                        } else {
                            '_'
                        }
                    })
                    .collect::<String>()
            })
            // An empty title would otherwise produce the bare name `submit_`.
            .filter(|name| !name.is_empty())
            .unwrap_or_else(|| "tool".to_string());

        SubmitTool {
            name,
            schema: schema.to_value(),
            _t: PhantomData,
        }
    }

    /// Deserializes the arguments of a model tool call into `T`.
    ///
    /// # Errors
    /// Returns an `"invalid args: …"` error when `args` does not match `T`.
    pub fn submit(&self, args: Value) -> Result<T, BoxError> {
        serde_json::from_value(args).map_err(|err| format!("invalid args: {err}").into())
    }
}
impl<T: JsonSchema + DeserializeOwned + Serialize + Send + Sync> Tool<BaseCtx> for SubmitTool<T> {
    type Args = T;
    type Output = T;

    /// Function name shown to the model: `submit_` followed by the base name.
    fn name(&self) -> String {
        let mut full_name = String::from("submit_");
        full_name.push_str(&self.name);
        full_name
    }

    fn description(&self) -> String {
        String::from("Submit the structured data you extracted from the provided text.")
    }

    /// Function definition with `T`'s schema as parameters, in strict mode.
    fn definition(&self) -> FunctionDefinition {
        let name = self.name();
        let description = self.description();
        let parameters = self.schema.clone();
        FunctionDefinition {
            name,
            description,
            parameters,
            strict: Some(true),
        }
    }

    /// Echoes the already-deserialized arguments back as the tool output.
    async fn call(
        &self,
        _ctx: BaseCtx,
        args: Self::Args,
        _resources: Vec<Resource>,
    ) -> Result<ToolOutput<Self::Output>, BoxError> {
        let output = ToolOutput::new(args);
        Ok(output)
    }
}
/// Pulls a structured `T` out of free text by requiring the model to call a
/// [`SubmitTool<T>`].
#[derive(Debug, Clone)]
pub struct Extractor<T>
where
    T: JsonSchema + DeserializeOwned + Serialize + Send + Sync,
{
    /// Tool whose schema describes `T`.
    tool: SubmitTool<T>,
    /// System instructions sent with every completion request.
    instructions: String,
    /// Optional limit forwarded as `max_output_tokens`.
    max_tokens: Option<usize>,
}
impl<T: JsonSchema + DeserializeOwned + Serialize + Send + Sync> Default for Extractor<T> {
fn default() -> Self {
Self::new(None, None)
}
}
impl<T: JsonSchema + DeserializeOwned + Serialize + Send + Sync> Extractor<T> {
    /// Creates an extractor whose submit tool is derived from `T`'s JSON schema.
    ///
    /// * `max_tokens` – optional limit forwarded as `max_output_tokens`.
    /// * `system_prompt` – custom instructions; `None` selects the built-in prompt.
    pub fn new(max_tokens: Option<usize>, system_prompt: Option<String>) -> Self {
        Self::new_with_tool(SubmitTool::new(), max_tokens, system_prompt)
    }

    /// Creates an extractor around an existing [`SubmitTool`].
    ///
    /// When `instructions` is `None`, a default system prompt is generated. It
    /// names the tool and tells the model to always call it.
    pub fn new_with_tool(
        tool: SubmitTool<T>,
        max_tokens: Option<usize>,
        instructions: Option<String>,
    ) -> Self {
        let instructions = instructions.unwrap_or_else(|| {
            let tool_name = tool.name();
            // Every `\` line continuation below is preceded by either a space or
            // an explicit `\n`, so words are never glued together across lines.
            format!(
                "You are an AI assistant whose purpose is to \
                 extract structured data from the provided text.\n\
                 You will have access to a `{tool_name}` function that defines the structure of the data to extract from the provided text.\n\
                 Use the `{tool_name}` function to submit the structured data.\n\
                 Be sure to fill out every field and ALWAYS CALL THE `{tool_name}` function, even with default values!!!"
            )
        });

        Self {
            tool,
            max_tokens,
            instructions,
        }
    }

    /// Runs one completion and parses the first tool call's arguments into `T`.
    ///
    /// Returns the parsed value together with the raw [`AgentOutput`].
    ///
    /// # Errors
    /// - the completion request itself fails;
    /// - the output carries a `failed_reason`;
    /// - the model returned no tool call;
    /// - the tool call's arguments do not deserialize into `T`.
    pub async fn extract(
        &self,
        ctx: &impl CompletionFeatures,
        prompt: String,
    ) -> Result<(T, AgentOutput), BoxError> {
        let req = CompletionRequest {
            instructions: self.instructions.clone(),
            prompt,
            tools: vec![self.tool.definition()],
            // Force the model to answer through the submit tool.
            tool_choice_required: true,
            max_output_tokens: self.max_tokens,
            ..Default::default()
        };

        let res = ctx.completion(req, Vec::new()).await?;
        if let Some(failed) = res.failed_reason {
            return Err(failed.into());
        }

        let Some(call) = res.tool_calls.first() else {
            return Err(format!("extract with {} failed, no tool_calls", self.tool.name()).into());
        };
        // `res` is handed back to the caller, so parse from a copy of the arguments.
        let result = self.tool.submit(call.args.clone())?;
        Ok((result, res))
    }
}
impl<T: JsonSchema + DeserializeOwned + Serialize + Send + Sync> Agent<AgentCtx> for Extractor<T> {
    /// Agent name: the submit tool's base name followed by `_extractor`.
    fn name(&self) -> String {
        let base = &self.tool.name;
        format!("{base}_extractor")
    }

    fn description(&self) -> String {
        String::from("Extract structured data from text using LLMs.")
    }

    /// Runs [`Extractor::extract`] and returns only the raw agent output,
    /// discarding the parsed value. Resources are ignored.
    async fn run(
        &self,
        ctx: AgentCtx,
        prompt: String,
        _resources: Vec<Resource>,
    ) -> Result<AgentOutput, BoxError> {
        self.extract(&ctx, prompt)
            .await
            .map(|(_parsed, output)| output)
    }
}