1use std::marker::PhantomData;
32
33use schemars::{JsonSchema, schema_for};
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36
37use crate::{
38 agent::{Agent, AgentBuilder},
39 completion::{Completion, CompletionError, CompletionModel, ToolDefinition},
40 message::{AssistantContent, Message, ToolCall, ToolFunction},
41 tool::Tool,
42};
43
/// Name of the tool the model is expected to call to hand back the extracted data.
/// It is used both as the tool's registered name and as the name searched for in responses.
const SUBMIT_TOOL_NAME: &str = "submit";
45
/// Errors that can occur while extracting structured data from a model response.
#[derive(Debug, thiserror::Error)]
pub enum ExtractionError {
    /// The response contained no call to the `submit` tool, so there was nothing to decode.
    #[error("No data extracted")]
    NoData,

    /// The arguments passed to the `submit` tool could not be deserialized into the target type.
    #[error("Failed to deserialize the extracted data: {0}")]
    DeserializationError(#[from] serde_json::Error),

    /// The underlying completion request failed.
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),
}
57
/// Extracts a value of type `T` from input text by running a completion through an [`Agent`]
/// and deserializing the arguments of the model's `submit` tool call.
pub struct Extractor<M: CompletionModel, T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
    /// Agent used to issue the completion request.
    agent: Agent<M>,
    /// Zero-sized marker tying this extractor to its target type `T`.
    _t: PhantomData<T>,
    /// Number of additional attempts made after a failed first attempt.
    retries: u64,
}
64
65impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel> Extractor<M, T>
66where
67 M: Sync,
68{
69 pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
76 let mut last_error = None;
77 let text_message = text.into();
78
79 for i in 0..=self.retries {
80 tracing::debug!(
81 "Attempting to extract JSON. Retries left: {retries}",
82 retries = self.retries - i
83 );
84 let attempt_text = text_message.clone();
85 match self.extract_json(attempt_text).await {
86 Ok(data) => return Ok(data),
87 Err(e) => {
88 tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
89 last_error = Some(e);
90 }
91 }
92 }
93
94 Err(last_error.unwrap_or(ExtractionError::NoData))
96 }
97
98 async fn extract_json(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
99 let response = self.agent.completion(text, vec![]).await?.send().await?;
100
101 if !response.choice.iter().any(|x| {
102 let AssistantContent::ToolCall(ToolCall {
103 function: ToolFunction { name, .. },
104 ..
105 }) = x
106 else {
107 return false;
108 };
109
110 name == SUBMIT_TOOL_NAME
111 }) {
112 tracing::warn!(
113 "The submit tool was not called. If this happens more than once, please ensure the model you are using is powerful enough to reliably call tools."
114 );
115 }
116
117 let arguments = response
118 .choice
119 .into_iter()
120 .filter_map(|content| {
122 if let AssistantContent::ToolCall(ToolCall {
123 function: ToolFunction { arguments, name },
124 ..
125 }) = content
126 {
127 if name == SUBMIT_TOOL_NAME {
128 Some(arguments)
129 } else {
130 None
131 }
132 } else {
133 None
134 }
135 })
136 .collect::<Vec<_>>();
137
138 if arguments.len() > 1 {
139 tracing::warn!(
140 "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
141 );
142 }
143
144 let raw_data = if let Some(arg) = arguments.into_iter().next() {
145 arg
146 } else {
147 return Err(ExtractionError::NoData);
148 };
149
150 Ok(serde_json::from_value(raw_data)?)
151 }
152
153 pub async fn get_inner(&self) -> &Agent<M> {
154 &self.agent
155 }
156
157 pub async fn into_inner(self) -> Agent<M> {
158 self.agent
159 }
160}
161
/// Builder for an [`Extractor`] targeting type `T` on top of completion model `M`.
pub struct ExtractorBuilder<
    T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
    M: CompletionModel,
> {
    /// Builder for the agent that will back the extractor.
    agent_builder: AgentBuilder<M>,
    /// Zero-sized marker tying this builder to its target type `T`.
    _t: PhantomData<T>,
    /// Requested retry count; `None` until `retries` is called.
    retries: Option<u64>,
}
171
impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, M: CompletionModel>
    ExtractorBuilder<T, M>
{
    /// Creates a builder for `model`, pre-configured with an extraction preamble and the
    /// `submit` tool for `T`. No retry count is set.
    pub fn new(model: M) -> Self {
        Self {
            agent_builder: AgentBuilder::new(model)
                // NOTE(review): the last text line of this literal has no `\n\` continuation,
                // so the newline and the indentation before the closing quote are part of the
                // prompt — confirm this is intended before reformatting.
                .preamble("\
                    You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
                    You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
                    Use the `submit` function to submit the structured data.\n\
                    Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
                ")
                .tool(SubmitTool::<T> {_t: PhantomData}),
            retries: None,
            _t: PhantomData,
        }
    }

    /// Adds caller-supplied instructions by passing `preamble`, prefixed with an
    /// "ADDITIONAL INSTRUCTIONS" separator line, to `AgentBuilder::append_preamble`.
    pub fn preamble(mut self, preamble: &str) -> Self {
        self.agent_builder = self.agent_builder.append_preamble(&format!(
            "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
        ));
        self
    }

    /// Forwards `doc` to `AgentBuilder::context`.
    pub fn context(mut self, doc: &str) -> Self {
        self.agent_builder = self.agent_builder.context(doc);
        self
    }

    /// Forwards `params` to `AgentBuilder::additional_params`.
    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
        self.agent_builder = self.agent_builder.additional_params(params);
        self
    }

    /// Forwards `max_tokens` to `AgentBuilder::max_tokens`.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.agent_builder = self.agent_builder.max_tokens(max_tokens);
        self
    }

    /// Sets the retry count stored on the built extractor.
    pub fn retries(mut self, retries: u64) -> Self {
        self.retries = Some(retries);
        self
    }

    /// Builds the agent and wraps it in an [`Extractor`].
    /// The retry count defaults to 0 when `retries` was never called.
    pub fn build(self) -> Extractor<M, T> {
        Extractor {
            agent: self.agent_builder.build(),
            _t: PhantomData,
            retries: self.retries.unwrap_or(0),
        }
    }
}
230
/// Tool registered under the `submit` name; its argument type is the extraction target `T`.
#[derive(Deserialize, Serialize)]
struct SubmitTool<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
    /// Zero-sized marker carrying the target type `T`.
    _t: PhantomData<T>,
}
235
/// Error type used by the `Tool` implementation for `SubmitTool`.
/// `SubmitTool::call` always returns `Ok`, so it never produces this value.
#[derive(Debug, thiserror::Error)]
#[error("SubmitError")]
struct SubmitError;
239
240impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync> Tool for SubmitTool<T> {
241 const NAME: &'static str = SUBMIT_TOOL_NAME;
242 type Error = SubmitError;
243 type Args = T;
244 type Output = T;
245
246 async fn definition(&self, _prompt: String) -> ToolDefinition {
247 ToolDefinition {
248 name: Self::NAME.to_string(),
249 description: "Submit the structured data you extracted from the provided text."
250 .to_string(),
251 parameters: json!(schema_for!(T)),
252 }
253 }
254
255 async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
256 Ok(data)
257 }
258}