use std::marker::PhantomData;
use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize};
use serde_json::json;
use crate::{
agent::{Agent, AgentBuilder, WithBuilderTools},
completion::{Completion, CompletionError, CompletionModel, ToolDefinition, Usage},
message::{AssistantContent, Message, ToolCall, ToolChoice, ToolFunction},
tool::Tool,
vector_store::VectorStoreIndexDyn,
wasm_compat::{WasmCompatSend, WasmCompatSync},
};
/// Name under which the extraction tool is registered. It is also the tool-call name
/// the extractor looks for in the model's response when collecting the submitted data.
const SUBMIT_TOOL_NAME: &str = "submit";
/// Result of an extraction: the structured value plus the token usage attached to it.
#[derive(Clone, Debug)]
pub struct ExtractionResponse<T> {
    /// The value deserialized from the arguments of the model's `submit` tool call.
    pub data: T,
    /// Token usage reported by the provider for the extraction.
    pub usage: Usage,
}
#[derive(Debug, thiserror::Error)]
pub enum ExtractionError {
#[error("No data extracted")]
NoData,
#[error("Failed to deserialize the extracted data: {0}")]
DeserializationError(#[from] serde_json::Error),
#[error("CompletionError: {0}")]
CompletionError(#[from] CompletionError),
}
/// Extracts values of type `T` from text by prompting an [`Agent`] and reading the
/// arguments of the `submit` tool call in its response.
pub struct Extractor<M, T>
where
    M: CompletionModel,
    T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
{
    /// The agent that performs the completion requests.
    agent: Agent<M>,
    /// Number of additional attempts made after the first one fails.
    retries: u64,
    /// Ties the extractor to its output type without storing a `T`.
    _t: PhantomData<T>,
}
impl<M, T> Extractor<M, T>
where
    M: CompletionModel,
    T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
{
    /// Extracts a `T` from `text`.
    ///
    /// Up to `retries + 1` attempts are made; the first successful extraction wins.
    ///
    /// # Errors
    ///
    /// Returns the error of the final attempt when every attempt fails.
    pub async fn extract(
        &self,
        text: impl Into<Message> + WasmCompatSend,
    ) -> Result<T, ExtractionError> {
        self.extract_with_retries(text.into(), Vec::new())
            .await
            .map(|response| response.data)
    }

    /// Extracts a `T` from `text`, sending `chat_history` along with the request.
    ///
    /// Retry behaviour and errors are the same as for [`Extractor::extract`].
    pub async fn extract_with_chat_history(
        &self,
        text: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> Result<T, ExtractionError> {
        self.extract_with_retries(text.into(), chat_history)
            .await
            .map(|response| response.data)
    }

    /// Extracts a `T` from `text` and also reports token usage.
    ///
    /// The returned usage is the sum over every attempt that received a response from
    /// the provider. This includes attempts whose response could not be turned into a
    /// `T` and had to be retried. Attempts that failed before a response arrived
    /// contribute nothing.
    ///
    /// # Errors
    ///
    /// Returns the error of the final attempt when every attempt fails.
    pub async fn extract_with_usage(
        &self,
        text: impl Into<Message> + WasmCompatSend,
    ) -> Result<ExtractionResponse<T>, ExtractionError> {
        self.extract_with_retries(text.into(), Vec::new()).await
    }

    /// Like [`Extractor::extract_with_usage`], but sends `chat_history` along with
    /// the request.
    pub async fn extract_with_chat_history_with_usage(
        &self,
        text: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> Result<ExtractionResponse<T>, ExtractionError> {
        self.extract_with_retries(text.into(), chat_history).await
    }

    /// Shared retry loop behind every public `extract*` method.
    ///
    /// Runs `retries + 1` attempts at most. Usage is accumulated across all attempts
    /// and returned with the first successful extraction.
    async fn extract_with_retries(
        &self,
        text: Message,
        chat_history: Vec<Message>,
    ) -> Result<ExtractionResponse<T>, ExtractionError> {
        let mut usage = Usage::new();
        let mut last_error = None;

        for attempt in 0..=self.retries {
            tracing::debug!(
                "Attempting to extract JSON. Retries left: {retries}",
                retries = self.retries - attempt
            );

            let outcome = self
                .attempt_extraction(text.clone(), chat_history.clone(), &mut usage)
                .await;

            match outcome {
                Ok(data) => return Ok(ExtractionResponse { data, usage }),
                Err(e) => {
                    // Only announce a retry when one will actually happen.
                    if attempt < self.retries {
                        tracing::warn!(
                            "Attempt {attempt} to extract JSON failed: {e:?}. Retrying..."
                        );
                    } else {
                        tracing::warn!("Attempt {attempt} to extract JSON failed: {e:?}.");
                    }
                    last_error = Some(e);
                }
            }
        }

        // The loop body runs at least once, so `last_error` is always set here;
        // `NoData` is purely a defensive fallback.
        Err(last_error.unwrap_or(ExtractionError::NoData))
    }

    /// Runs one completion request and turns its `submit` tool call into a `T`.
    ///
    /// The provider-reported usage is added to `usage` as soon as the response arrives,
    /// before the response is inspected. This way tokens spent on an unusable response
    /// are still counted.
    ///
    /// # Errors
    ///
    /// - [`ExtractionError::CompletionError`] if the request cannot be built or sent.
    /// - [`ExtractionError::NoData`] if the response contains no `submit` call.
    /// - [`ExtractionError::DeserializationError`] if the submitted arguments do not
    ///   deserialize into `T`.
    async fn attempt_extraction(
        &self,
        text: Message,
        chat_history: Vec<Message>,
        usage: &mut Usage,
    ) -> Result<T, ExtractionError> {
        let response = self
            .agent
            .completion(text, &chat_history)
            .await?
            .send()
            .await?;
        *usage += response.usage;

        // Keep only the arguments of calls to the `submit` tool, in response order;
        // any other content (text, other tool calls) is ignored.
        let mut submissions: Vec<_> = response
            .choice
            .into_iter()
            .filter_map(|content| match content {
                AssistantContent::ToolCall(ToolCall {
                    function: ToolFunction { name, arguments },
                    ..
                }) if name == SUBMIT_TOOL_NAME => Some(arguments),
                _ => None,
            })
            .collect();

        if submissions.is_empty() {
            tracing::warn!(
                "The submit tool was not called. If this happens more than once, please ensure the model you are using is powerful enough to reliably call tools."
            );
        } else if submissions.len() > 1 {
            tracing::warn!(
                "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
            );
        }

        // `pop` yields the final submission, matching what the warning above promises.
        let raw_data = submissions.pop().ok_or(ExtractionError::NoData)?;
        let data = serde_json::from_value(raw_data)?;
        Ok(data)
    }

    /// Returns a reference to the underlying agent.
    ///
    /// This is declared `async` even though it does no asynchronous work, so callers
    /// must `.await` it.
    pub async fn get_inner(&self) -> &Agent<M> {
        &self.agent
    }

    /// Consumes the extractor and returns the underlying agent.
    ///
    /// This is declared `async` even though it does no asynchronous work, so callers
    /// must `.await` it.
    pub async fn into_inner(self) -> Agent<M> {
        self.agent
    }
}
/// Builder for [`Extractor`]. It wraps an [`AgentBuilder`] that is pre-configured for
/// extraction and records the retry count.
pub struct ExtractorBuilder<M, T>
where
    M: CompletionModel,
    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
{
    /// Agent configuration under construction.
    agent_builder: AgentBuilder<M, (), WithBuilderTools>,
    /// Requested retry count; `None` until [`ExtractorBuilder::retries`] is called.
    retries: Option<u64>,
    /// Ties the builder to the extraction output type.
    _t: PhantomData<T>,
}
impl<M, T> ExtractorBuilder<M, T>
where
    M: CompletionModel,
    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
{
    /// Starts a builder for an extractor backed by `model`.
    ///
    /// The agent gets a fixed extraction preamble. It also has the `submit` tool
    /// registered, whose parameter schema comes from `T`, and uses
    /// [`ToolChoice::Required`]. No retry count is set, so [`ExtractorBuilder::build`]
    /// falls back to zero retries.
    pub fn new(model: M) -> Self {
        let submit_tool = SubmitTool::<T> { _t: PhantomData };
        let agent_builder = AgentBuilder::new(model)
            .preamble("\
                You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
                You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
                Use the `submit` function to submit the structured data.\n\
                Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.\n\
            ")
            .tool(submit_tool)
            .tool_choice(ToolChoice::Required);

        Self {
            agent_builder,
            retries: None,
            _t: PhantomData,
        }
    }

    /// Appends `preamble` to the built-in instructions, below an
    /// "ADDITIONAL INSTRUCTIONS" separator line.
    pub fn preamble(self, preamble: &str) -> Self {
        let section =
            format!("\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}");
        Self {
            agent_builder: self.agent_builder.append_preamble(&section),
            ..self
        }
    }

    /// Adds a static context document to the agent.
    pub fn context(self, doc: &str) -> Self {
        Self {
            agent_builder: self.agent_builder.context(doc),
            ..self
        }
    }

    /// Forwards provider-specific request parameters to the agent.
    pub fn additional_params(self, params: serde_json::Value) -> Self {
        Self {
            agent_builder: self.agent_builder.additional_params(params),
            ..self
        }
    }

    /// Caps the number of tokens the model may generate per request.
    pub fn max_tokens(self, max_tokens: u64) -> Self {
        Self {
            agent_builder: self.agent_builder.max_tokens(max_tokens),
            ..self
        }
    }

    /// Sets how many extra attempts are made after a failed extraction.
    pub fn retries(self, retries: u64) -> Self {
        Self {
            retries: Some(retries),
            ..self
        }
    }

    /// Overrides the agent's tool choice (the builder starts with
    /// [`ToolChoice::Required`]).
    pub fn tool_choice(self, choice: ToolChoice) -> Self {
        Self {
            agent_builder: self.agent_builder.tool_choice(choice),
            ..self
        }
    }

    /// Finalizes the configuration into an [`Extractor`]. Zero retries are used when
    /// none were set.
    pub fn build(self) -> Extractor<M, T> {
        let retries = self.retries.unwrap_or_default();
        Extractor {
            retries,
            agent: self.agent_builder.build(),
            _t: PhantomData,
        }
    }

    /// Attaches a vector-store index from which `sample` documents are pulled as
    /// dynamic context for each request.
    pub fn dynamic_context(
        self,
        sample: usize,
        dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
    ) -> Self {
        Self {
            agent_builder: self.agent_builder.dynamic_context(sample, dynamic_context),
            ..self
        }
    }
}
/// The tool the model calls to hand back extracted data; its arguments are a `T`.
#[derive(Serialize, Deserialize)]
struct SubmitTool<T>
where
    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
{
    /// Marker for the submitted type; no value is stored.
    _t: PhantomData<T>,
}
#[derive(Debug, thiserror::Error)]
#[error("SubmitError")]
struct SubmitError;
impl<T> Tool for SubmitTool<T>
where
    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
{
    const NAME: &'static str = SUBMIT_TOOL_NAME;

    type Error = SubmitError;
    type Args = T;
    type Output = T;

    /// Describes the `submit` tool to the model. Its parameter schema is the JSON
    /// schema generated for `T`; the prompt argument is ignored.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        let name = String::from(Self::NAME);
        let description =
            String::from("Submit the structured data you extracted from the provided text.");
        let parameters = json!(schema_for!(T));

        ToolDefinition {
            name,
            description,
            parameters,
        }
    }

    /// Returns the submitted arguments unchanged.
    async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
        Ok(data)
    }
}