1use std::marker::PhantomData;
32
33use schemars::{JsonSchema, schema_for};
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36
37use crate::{
38 agent::{Agent, AgentBuilder, WithBuilderTools},
39 completion::{Completion, CompletionError, CompletionModel, ToolDefinition, Usage},
40 message::{AssistantContent, Message, ToolCall, ToolChoice, ToolFunction},
41 tool::Tool,
42 vector_store::VectorStoreIndexDyn,
43 wasm_compat::{WasmCompatSend, WasmCompatSync},
44};
45
/// Name of the tool the model is expected to call to hand back extracted data.
///
/// The same constant is used as the registered tool name and as the key for
/// locating the matching tool call in the model's response.
const SUBMIT_TOOL_NAME: &str = "submit";
47
48#[derive(Debug, Clone)]
50pub struct ExtractionResponse<T> {
51 pub data: T,
53 pub usage: Usage,
55}
56
57#[derive(Debug, thiserror::Error)]
58pub enum ExtractionError {
59 #[error("No data extracted")]
60 NoData,
61
62 #[error("Failed to deserialize the extracted data: {0}")]
63 DeserializationError(#[from] serde_json::Error),
64
65 #[error("CompletionError: {0}")]
66 CompletionError(#[from] CompletionError),
67}
68
69pub struct Extractor<M, T>
71where
72 M: CompletionModel,
73 T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
74{
75 agent: Agent<M>,
76 _t: PhantomData<T>,
77 retries: u64,
78}
79
80impl<M, T> Extractor<M, T>
81where
82 M: CompletionModel,
83 T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
84{
85 pub async fn extract(
92 &self,
93 text: impl Into<Message> + WasmCompatSend,
94 ) -> Result<T, ExtractionError> {
95 let mut last_error = None;
96 let text_message = text.into();
97
98 for i in 0..=self.retries {
99 tracing::debug!(
100 "Attempting to extract JSON. Retries left: {retries}",
101 retries = self.retries - i
102 );
103 let attempt_text = text_message.clone();
104 match self.extract_json_with_usage(attempt_text, vec![]).await {
105 Ok((data, _usage)) => return Ok(data),
106 Err(e) => {
107 tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
108 last_error = Some(e);
109 }
110 }
111 }
112
113 Err(last_error.unwrap_or(ExtractionError::NoData))
115 }
116
117 pub async fn extract_with_chat_history(
124 &self,
125 text: impl Into<Message> + WasmCompatSend,
126 chat_history: Vec<Message>,
127 ) -> Result<T, ExtractionError> {
128 let mut last_error = None;
129 let text_message = text.into();
130
131 for i in 0..=self.retries {
132 tracing::debug!(
133 "Attempting to extract JSON. Retries left: {retries}",
134 retries = self.retries - i
135 );
136 let attempt_text = text_message.clone();
137 match self
138 .extract_json_with_usage(attempt_text, chat_history.clone())
139 .await
140 {
141 Ok((data, _usage)) => return Ok(data),
142 Err(e) => {
143 tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
144 last_error = Some(e);
145 }
146 }
147 }
148
149 Err(last_error.unwrap_or(ExtractionError::NoData))
151 }
152
153 pub async fn extract_with_usage(
164 &self,
165 text: impl Into<Message> + WasmCompatSend,
166 ) -> Result<ExtractionResponse<T>, ExtractionError> {
167 let mut last_error = None;
168 let text_message = text.into();
169 let mut usage = Usage::new();
170
171 for i in 0..=self.retries {
172 tracing::debug!(
173 "Attempting to extract JSON. Retries left: {retries}",
174 retries = self.retries - i
175 );
176 let attempt_text = text_message.clone();
177 match self.extract_json_with_usage(attempt_text, vec![]).await {
178 Ok((data, u)) => {
179 usage += u;
180 return Ok(ExtractionResponse { data, usage });
181 }
182 Err(e) => {
183 tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
184 last_error = Some(e);
185 }
186 }
187 }
188
189 Err(last_error.unwrap_or(ExtractionError::NoData))
191 }
192
193 pub async fn extract_with_chat_history_with_usage(
205 &self,
206 text: impl Into<Message> + WasmCompatSend,
207 chat_history: Vec<Message>,
208 ) -> Result<ExtractionResponse<T>, ExtractionError> {
209 let mut last_error = None;
210 let text_message = text.into();
211 let mut usage = Usage::new();
212
213 for i in 0..=self.retries {
214 tracing::debug!(
215 "Attempting to extract JSON. Retries left: {retries}",
216 retries = self.retries - i
217 );
218 let attempt_text = text_message.clone();
219 match self
220 .extract_json_with_usage(attempt_text, chat_history.clone())
221 .await
222 {
223 Ok((data, u)) => {
224 usage += u;
225 return Ok(ExtractionResponse { data, usage });
226 }
227 Err(e) => {
228 tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
229 last_error = Some(e);
230 }
231 }
232 }
233
234 Err(last_error.unwrap_or(ExtractionError::NoData))
236 }
237
238 async fn extract_json_with_usage(
239 &self,
240 text: impl Into<Message> + WasmCompatSend,
241 messages: Vec<Message>,
242 ) -> Result<(T, Usage), ExtractionError> {
243 let response = self.agent.completion(text, messages).await?.send().await?;
244 let usage = response.usage;
245
246 if !response.choice.iter().any(|x| {
247 let AssistantContent::ToolCall(ToolCall {
248 function: ToolFunction { name, .. },
249 ..
250 }) = x
251 else {
252 return false;
253 };
254
255 name == SUBMIT_TOOL_NAME
256 }) {
257 tracing::warn!(
258 "The submit tool was not called. If this happens more than once, please ensure the model you are using is powerful enough to reliably call tools."
259 );
260 }
261
262 let arguments = response
263 .choice
264 .into_iter()
265 .filter_map(|content| {
267 if let AssistantContent::ToolCall(ToolCall {
268 function: ToolFunction { arguments, name },
269 ..
270 }) = content
271 {
272 if name == SUBMIT_TOOL_NAME {
273 Some(arguments)
274 } else {
275 None
276 }
277 } else {
278 None
279 }
280 })
281 .collect::<Vec<_>>();
282
283 if arguments.len() > 1 {
284 tracing::warn!(
285 "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
286 );
287 }
288
289 let raw_data = if let Some(arg) = arguments.into_iter().next() {
290 arg
291 } else {
292 return Err(ExtractionError::NoData);
293 };
294
295 let data = serde_json::from_value(raw_data)?;
296 Ok((data, usage))
297 }
298
299 pub async fn get_inner(&self) -> &Agent<M> {
300 &self.agent
301 }
302
303 pub async fn into_inner(self) -> Agent<M> {
304 self.agent
305 }
306}
307
308pub struct ExtractorBuilder<M, T>
310where
311 M: CompletionModel,
312 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
313{
314 agent_builder: AgentBuilder<M, (), WithBuilderTools>,
315 _t: PhantomData<T>,
316 retries: Option<u64>,
317}
318
319impl<M, T> ExtractorBuilder<M, T>
320where
321 M: CompletionModel,
322 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
323{
324 pub fn new(model: M) -> Self {
325 Self {
326 agent_builder: AgentBuilder::new(model)
327 .preamble("\
328 You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
329 You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
330 Use the `submit` function to submit the structured data.\n\
331 Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
332 ")
333 .tool(SubmitTool::<T> {_t: PhantomData})
334 .tool_choice(ToolChoice::Required),
335 retries: None,
336 _t: PhantomData,
337 }
338 }
339
340 pub fn preamble(mut self, preamble: &str) -> Self {
342 self.agent_builder = self.agent_builder.append_preamble(&format!(
343 "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
344 ));
345 self
346 }
347
348 pub fn context(mut self, doc: &str) -> Self {
350 self.agent_builder = self.agent_builder.context(doc);
351 self
352 }
353
354 pub fn additional_params(mut self, params: serde_json::Value) -> Self {
355 self.agent_builder = self.agent_builder.additional_params(params);
356 self
357 }
358
359 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
361 self.agent_builder = self.agent_builder.max_tokens(max_tokens);
362 self
363 }
364
365 pub fn retries(mut self, retries: u64) -> Self {
367 self.retries = Some(retries);
368 self
369 }
370
371 pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
373 self.agent_builder = self.agent_builder.tool_choice(choice);
374 self
375 }
376
377 pub fn build(self) -> Extractor<M, T> {
379 Extractor {
380 agent: self.agent_builder.build(),
381 _t: PhantomData,
382 retries: self.retries.unwrap_or(0),
383 }
384 }
385
386 pub fn dynamic_context(
391 mut self,
392 sample: usize,
393 dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
394 ) -> Self {
395 self.agent_builder = self.agent_builder.dynamic_context(sample, dynamic_context);
396 self
397 }
398}
399
400#[derive(Deserialize, Serialize)]
401struct SubmitTool<T>
402where
403 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
404{
405 _t: PhantomData<T>,
406}
407
408#[derive(Debug, thiserror::Error)]
409#[error("SubmitError")]
410struct SubmitError;
411
412impl<T> Tool for SubmitTool<T>
413where
414 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
415{
416 const NAME: &'static str = SUBMIT_TOOL_NAME;
417 type Error = SubmitError;
418 type Args = T;
419 type Output = T;
420
421 async fn definition(&self, _prompt: String) -> ToolDefinition {
422 ToolDefinition {
423 name: Self::NAME.to_string(),
424 description: "Submit the structured data you extracted from the provided text."
425 .to_string(),
426 parameters: json!(schema_for!(T)),
427 }
428 }
429
430 async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
431 Ok(data)
432 }
433}