1use std::marker::PhantomData;
33
34use schemars::{JsonSchema, schema_for};
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37
38use crate::{
39 agent::{Agent, AgentBuilder, WithBuilderTools},
40 completion::{Completion, CompletionError, CompletionModel, ToolDefinition, Usage},
41 message::{AssistantContent, Message, ToolCall, ToolChoice, ToolFunction},
42 tool::Tool,
43 vector_store::VectorStoreIndexDyn,
44 wasm_compat::{WasmCompatSend, WasmCompatSync},
45};
46
/// Name of the tool the model must call to hand back the extracted data.
///
/// Used both as the tool's registered name and as the key for locating the
/// matching tool call in the model's response.
const SUBMIT_TOOL_NAME: &str = "submit";
48
49#[derive(Debug, Clone)]
51pub struct ExtractionResponse<T> {
52 pub data: T,
54 pub usage: Usage,
56}
57
58#[derive(Debug, thiserror::Error)]
59pub enum ExtractionError {
60 #[error("No data extracted")]
61 NoData,
62
63 #[error("Failed to deserialize the extracted data: {0}")]
64 DeserializationError(#[from] serde_json::Error),
65
66 #[error("CompletionError: {0}")]
67 CompletionError(#[from] CompletionError),
68}
69
70pub struct Extractor<M, T>
72where
73 M: CompletionModel,
74 T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
75{
76 agent: Agent<M>,
77 _t: PhantomData<T>,
78 retries: u64,
79}
80
81impl<M, T> Extractor<M, T>
82where
83 M: CompletionModel,
84 T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
85{
86 pub async fn extract(
93 &self,
94 text: impl Into<Message> + WasmCompatSend,
95 ) -> Result<T, ExtractionError> {
96 let mut last_error = None;
97 let text_message = text.into();
98
99 for i in 0..=self.retries {
100 tracing::debug!(
101 "Attempting to extract JSON. Retries left: {retries}",
102 retries = self.retries - i
103 );
104 let attempt_text = text_message.clone();
105 match self.extract_json_with_usage(attempt_text, vec![]).await {
106 Ok((data, _usage)) => return Ok(data),
107 Err(e) => {
108 tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
109 last_error = Some(e);
110 }
111 }
112 }
113
114 Err(last_error.unwrap_or(ExtractionError::NoData))
116 }
117
118 pub async fn extract_with_chat_history(
125 &self,
126 text: impl Into<Message> + WasmCompatSend,
127 chat_history: Vec<Message>,
128 ) -> Result<T, ExtractionError> {
129 let mut last_error = None;
130 let text_message = text.into();
131
132 for i in 0..=self.retries {
133 tracing::debug!(
134 "Attempting to extract JSON. Retries left: {retries}",
135 retries = self.retries - i
136 );
137 let attempt_text = text_message.clone();
138 match self
139 .extract_json_with_usage(attempt_text, chat_history.clone())
140 .await
141 {
142 Ok((data, _usage)) => return Ok(data),
143 Err(e) => {
144 tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
145 last_error = Some(e);
146 }
147 }
148 }
149
150 Err(last_error.unwrap_or(ExtractionError::NoData))
152 }
153
154 pub async fn extract_with_usage(
165 &self,
166 text: impl Into<Message> + WasmCompatSend,
167 ) -> Result<ExtractionResponse<T>, ExtractionError> {
168 let mut last_error = None;
169 let text_message = text.into();
170 let mut usage = Usage::new();
171
172 for i in 0..=self.retries {
173 tracing::debug!(
174 "Attempting to extract JSON. Retries left: {retries}",
175 retries = self.retries - i
176 );
177 let attempt_text = text_message.clone();
178 match self.extract_json_with_usage(attempt_text, vec![]).await {
179 Ok((data, u)) => {
180 usage += u;
181 return Ok(ExtractionResponse { data, usage });
182 }
183 Err(e) => {
184 tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
185 last_error = Some(e);
186 }
187 }
188 }
189
190 Err(last_error.unwrap_or(ExtractionError::NoData))
192 }
193
194 pub async fn extract_with_chat_history_with_usage(
206 &self,
207 text: impl Into<Message> + WasmCompatSend,
208 chat_history: Vec<Message>,
209 ) -> Result<ExtractionResponse<T>, ExtractionError> {
210 let mut last_error = None;
211 let text_message = text.into();
212 let mut usage = Usage::new();
213
214 for i in 0..=self.retries {
215 tracing::debug!(
216 "Attempting to extract JSON. Retries left: {retries}",
217 retries = self.retries - i
218 );
219 let attempt_text = text_message.clone();
220 match self
221 .extract_json_with_usage(attempt_text, chat_history.clone())
222 .await
223 {
224 Ok((data, u)) => {
225 usage += u;
226 return Ok(ExtractionResponse { data, usage });
227 }
228 Err(e) => {
229 tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
230 last_error = Some(e);
231 }
232 }
233 }
234
235 Err(last_error.unwrap_or(ExtractionError::NoData))
237 }
238
239 async fn extract_json_with_usage(
240 &self,
241 text: impl Into<Message> + WasmCompatSend,
242 messages: Vec<Message>,
243 ) -> Result<(T, Usage), ExtractionError> {
244 let response = self.agent.completion(text, &messages).await?.send().await?;
245 let usage = response.usage;
246
247 if !response.choice.iter().any(|x| {
248 let AssistantContent::ToolCall(ToolCall {
249 function: ToolFunction { name, .. },
250 ..
251 }) = x
252 else {
253 return false;
254 };
255
256 name == SUBMIT_TOOL_NAME
257 }) {
258 tracing::warn!(
259 "The submit tool was not called. If this happens more than once, please ensure the model you are using is powerful enough to reliably call tools."
260 );
261 }
262
263 let arguments = response
264 .choice
265 .into_iter()
266 .filter_map(|content| {
268 if let AssistantContent::ToolCall(ToolCall {
269 function: ToolFunction { arguments, name },
270 ..
271 }) = content
272 {
273 if name == SUBMIT_TOOL_NAME {
274 Some(arguments)
275 } else {
276 None
277 }
278 } else {
279 None
280 }
281 })
282 .collect::<Vec<_>>();
283
284 if arguments.len() > 1 {
285 tracing::warn!(
286 "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
287 );
288 }
289
290 let raw_data = if let Some(arg) = arguments.into_iter().next() {
291 arg
292 } else {
293 return Err(ExtractionError::NoData);
294 };
295
296 let data = serde_json::from_value(raw_data)?;
297 Ok((data, usage))
298 }
299
300 pub async fn get_inner(&self) -> &Agent<M> {
301 &self.agent
302 }
303
304 pub async fn into_inner(self) -> Agent<M> {
305 self.agent
306 }
307}
308
309pub struct ExtractorBuilder<M, T>
311where
312 M: CompletionModel,
313 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
314{
315 agent_builder: AgentBuilder<M, (), WithBuilderTools>,
316 _t: PhantomData<T>,
317 retries: Option<u64>,
318}
319
320impl<M, T> ExtractorBuilder<M, T>
321where
322 M: CompletionModel,
323 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
324{
325 pub fn new(model: M) -> Self {
326 Self {
327 agent_builder: AgentBuilder::new(model)
328 .preamble("\
329 You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
330 You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
331 Use the `submit` function to submit the structured data.\n\
332 Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
333 ")
334 .tool(SubmitTool::<T> {_t: PhantomData})
335 .tool_choice(ToolChoice::Required),
336 retries: None,
337 _t: PhantomData,
338 }
339 }
340
341 pub fn preamble(mut self, preamble: &str) -> Self {
343 self.agent_builder = self.agent_builder.append_preamble(&format!(
344 "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
345 ));
346 self
347 }
348
349 pub fn context(mut self, doc: &str) -> Self {
351 self.agent_builder = self.agent_builder.context(doc);
352 self
353 }
354
355 pub fn additional_params(mut self, params: serde_json::Value) -> Self {
356 self.agent_builder = self.agent_builder.additional_params(params);
357 self
358 }
359
360 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
362 self.agent_builder = self.agent_builder.max_tokens(max_tokens);
363 self
364 }
365
366 pub fn retries(mut self, retries: u64) -> Self {
368 self.retries = Some(retries);
369 self
370 }
371
372 pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
374 self.agent_builder = self.agent_builder.tool_choice(choice);
375 self
376 }
377
378 pub fn build(self) -> Extractor<M, T> {
380 Extractor {
381 agent: self.agent_builder.build(),
382 _t: PhantomData,
383 retries: self.retries.unwrap_or(0),
384 }
385 }
386
387 pub fn dynamic_context(
392 mut self,
393 sample: usize,
394 dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
395 ) -> Self {
396 self.agent_builder = self.agent_builder.dynamic_context(sample, dynamic_context);
397 self
398 }
399}
400
401#[derive(Deserialize, Serialize)]
402struct SubmitTool<T>
403where
404 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
405{
406 _t: PhantomData<T>,
407}
408
409#[derive(Debug, thiserror::Error)]
410#[error("SubmitError")]
411struct SubmitError;
412
413impl<T> Tool for SubmitTool<T>
414where
415 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
416{
417 const NAME: &'static str = SUBMIT_TOOL_NAME;
418 type Error = SubmitError;
419 type Args = T;
420 type Output = T;
421
422 async fn definition(&self, _prompt: String) -> ToolDefinition {
423 ToolDefinition {
424 name: Self::NAME.to_string(),
425 description: "Submit the structured data you extracted from the provided text."
426 .to_string(),
427 parameters: json!(schema_for!(T)),
428 }
429 }
430
431 async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
432 Ok(data)
433 }
434}