// rig_core/extractor.rs

//! This module provides high-level abstractions for extracting structured data from text using LLMs.
//!
//! Note: The target structure must implement the `serde::Deserialize`, `serde::Serialize`,
//! and `schemars::JsonSchema` traits. All three can be derived with `#[derive(...)]`.
//!
//! # Example
//! ```no_run
//! use rig_core::{client::CompletionClient, providers::openai};
//!
//! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
//! // Initialize the OpenAI client
//! let openai = openai::Client::new("your-open-ai-api-key")?;
//!
//! // Define the structure of the data you want to extract
//! #[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
//! struct Person {
//!     name: Option<String>,
//!     age: Option<u8>,
//!     profession: Option<String>,
//! }
//!
//! // Create the extractor
//! let extractor = openai.extractor::<Person>(openai::GPT_4O)
//!     .build();
//!
//! // Extract structured data from text
//! let person = extractor.extract("John Doe is a 30 year old doctor.").await?;
//! # Ok(())
//! # }
//! ```
31
32use std::marker::PhantomData;
33
34use schemars::{JsonSchema, schema_for};
35use serde::{Deserialize, Serialize};
36use serde_json::json;
37
38use crate::{
39    agent::{Agent, AgentBuilder, WithBuilderTools},
40    completion::{Completion, CompletionError, CompletionModel, ToolDefinition, Usage},
41    message::{AssistantContent, Message, ToolCall, ToolChoice, ToolFunction},
42    tool::Tool,
43    vector_store::VectorStoreIndexDyn,
44    wasm_compat::{WasmCompatSend, WasmCompatSync},
45};
46
/// Name of the tool the model must call to hand back the extracted data.
///
/// Used both as the registered tool name and as the name that tool calls in a
/// completion response are matched against.
const SUBMIT_TOOL_NAME: &str = "submit";
48
49/// Response from an extraction operation containing the extracted data and usage information.
50#[derive(Debug, Clone)]
51pub struct ExtractionResponse<T> {
52    /// The extracted structured data
53    pub data: T,
54    /// Accumulated token usage across all attempts (including retries)
55    pub usage: Usage,
56}
57
58#[derive(Debug, thiserror::Error)]
59pub enum ExtractionError {
60    #[error("No data extracted")]
61    NoData,
62
63    #[error("Failed to deserialize the extracted data: {0}")]
64    DeserializationError(#[from] serde_json::Error),
65
66    #[error("CompletionError: {0}")]
67    CompletionError(#[from] CompletionError),
68}
69
70/// Extractor for structured data from text
71pub struct Extractor<M, T>
72where
73    M: CompletionModel,
74    T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
75{
76    agent: Agent<M>,
77    _t: PhantomData<T>,
78    retries: u64,
79}
80
81impl<M, T> Extractor<M, T>
82where
83    M: CompletionModel,
84    T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
85{
86    /// Attempts to extract data from the given text with a number of retries.
87    ///
88    /// The function will retry the extraction if the initial attempt fails or
89    /// if the model does not call the `submit` tool.
90    ///
91    /// The number of retries is determined by the `retries` field on the Extractor struct.
92    pub async fn extract(
93        &self,
94        text: impl Into<Message> + WasmCompatSend,
95    ) -> Result<T, ExtractionError> {
96        let mut last_error = None;
97        let text_message = text.into();
98
99        for i in 0..=self.retries {
100            tracing::debug!(
101                "Attempting to extract JSON. Retries left: {retries}",
102                retries = self.retries - i
103            );
104            let attempt_text = text_message.clone();
105            match self.extract_json_with_usage(attempt_text, vec![]).await {
106                Ok((data, _usage)) => return Ok(data),
107                Err(e) => {
108                    tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
109                    last_error = Some(e);
110                }
111            }
112        }
113
114        // If the loop finishes without a successful extraction, return the last error encountered.
115        Err(last_error.unwrap_or(ExtractionError::NoData))
116    }
117
118    /// Attempts to extract data from the given text with a number of retries.
119    ///
120    /// The function will retry the extraction if the initial attempt fails or
121    /// if the model does not call the `submit` tool.
122    ///
123    /// The number of retries is determined by the `retries` field on the Extractor struct.
124    pub async fn extract_with_chat_history(
125        &self,
126        text: impl Into<Message> + WasmCompatSend,
127        chat_history: Vec<Message>,
128    ) -> Result<T, ExtractionError> {
129        let mut last_error = None;
130        let text_message = text.into();
131
132        for i in 0..=self.retries {
133            tracing::debug!(
134                "Attempting to extract JSON. Retries left: {retries}",
135                retries = self.retries - i
136            );
137            let attempt_text = text_message.clone();
138            match self
139                .extract_json_with_usage(attempt_text, chat_history.clone())
140                .await
141            {
142                Ok((data, _usage)) => return Ok(data),
143                Err(e) => {
144                    tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
145                    last_error = Some(e);
146                }
147            }
148        }
149
150        // If the loop finishes without a successful extraction, return the last error encountered.
151        Err(last_error.unwrap_or(ExtractionError::NoData))
152    }
153
154    /// Attempts to extract data from the given text with a number of retries,
155    /// returning both the extracted data and accumulated token usage.
156    ///
157    /// The function will retry the extraction if the initial attempt fails or
158    /// if the model does not call the `submit` tool.
159    ///
160    /// The number of retries is determined by the `retries` field on the Extractor struct.
161    ///
162    /// Usage accumulates across all retry attempts, providing the complete cost picture
163    /// including failed attempts.
164    pub async fn extract_with_usage(
165        &self,
166        text: impl Into<Message> + WasmCompatSend,
167    ) -> Result<ExtractionResponse<T>, ExtractionError> {
168        let mut last_error = None;
169        let text_message = text.into();
170        let mut usage = Usage::new();
171
172        for i in 0..=self.retries {
173            tracing::debug!(
174                "Attempting to extract JSON. Retries left: {retries}",
175                retries = self.retries - i
176            );
177            let attempt_text = text_message.clone();
178            match self.extract_json_with_usage(attempt_text, vec![]).await {
179                Ok((data, u)) => {
180                    usage += u;
181                    return Ok(ExtractionResponse { data, usage });
182                }
183                Err(e) => {
184                    tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
185                    last_error = Some(e);
186                }
187            }
188        }
189
190        // If the loop finishes without a successful extraction, return the last error encountered.
191        Err(last_error.unwrap_or(ExtractionError::NoData))
192    }
193
194    /// Attempts to extract data from the given text with a number of retries,
195    /// providing chat history context, and returning both the extracted data
196    /// and accumulated token usage.
197    ///
198    /// The function will retry the extraction if the initial attempt fails or
199    /// if the model does not call the `submit` tool.
200    ///
201    /// The number of retries is determined by the `retries` field on the Extractor struct.
202    ///
203    /// Usage accumulates across all retry attempts, providing the complete cost picture
204    /// including failed attempts.
205    pub async fn extract_with_chat_history_with_usage(
206        &self,
207        text: impl Into<Message> + WasmCompatSend,
208        chat_history: Vec<Message>,
209    ) -> Result<ExtractionResponse<T>, ExtractionError> {
210        let mut last_error = None;
211        let text_message = text.into();
212        let mut usage = Usage::new();
213
214        for i in 0..=self.retries {
215            tracing::debug!(
216                "Attempting to extract JSON. Retries left: {retries}",
217                retries = self.retries - i
218            );
219            let attempt_text = text_message.clone();
220            match self
221                .extract_json_with_usage(attempt_text, chat_history.clone())
222                .await
223            {
224                Ok((data, u)) => {
225                    usage += u;
226                    return Ok(ExtractionResponse { data, usage });
227                }
228                Err(e) => {
229                    tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
230                    last_error = Some(e);
231                }
232            }
233        }
234
235        // If the loop finishes without a successful extraction, return the last error encountered.
236        Err(last_error.unwrap_or(ExtractionError::NoData))
237    }
238
239    async fn extract_json_with_usage(
240        &self,
241        text: impl Into<Message> + WasmCompatSend,
242        messages: Vec<Message>,
243    ) -> Result<(T, Usage), ExtractionError> {
244        let response = self.agent.completion(text, &messages).await?.send().await?;
245        let usage = response.usage;
246
247        if !response.choice.iter().any(|x| {
248            let AssistantContent::ToolCall(ToolCall {
249                function: ToolFunction { name, .. },
250                ..
251            }) = x
252            else {
253                return false;
254            };
255
256            name == SUBMIT_TOOL_NAME
257        }) {
258            tracing::warn!(
259                "The submit tool was not called. If this happens more than once, please ensure the model you are using is powerful enough to reliably call tools."
260            );
261        }
262
263        let arguments = response
264            .choice
265            .into_iter()
266            // We filter tool calls to look for submit tool calls
267            .filter_map(|content| {
268                if let AssistantContent::ToolCall(ToolCall {
269                    function: ToolFunction { arguments, name },
270                    ..
271                }) = content
272                {
273                    if name == SUBMIT_TOOL_NAME {
274                        Some(arguments)
275                    } else {
276                        None
277                    }
278                } else {
279                    None
280                }
281            })
282            .collect::<Vec<_>>();
283
284        if arguments.len() > 1 {
285            tracing::warn!(
286                "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
287            );
288        }
289
290        let raw_data = if let Some(arg) = arguments.into_iter().next() {
291            arg
292        } else {
293            return Err(ExtractionError::NoData);
294        };
295
296        let data = serde_json::from_value(raw_data)?;
297        Ok((data, usage))
298    }
299
300    pub async fn get_inner(&self) -> &Agent<M> {
301        &self.agent
302    }
303
304    pub async fn into_inner(self) -> Agent<M> {
305        self.agent
306    }
307}
308
309/// Builder for the Extractor
310pub struct ExtractorBuilder<M, T>
311where
312    M: CompletionModel,
313    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
314{
315    agent_builder: AgentBuilder<M, (), WithBuilderTools>,
316    _t: PhantomData<T>,
317    retries: Option<u64>,
318}
319
320impl<M, T> ExtractorBuilder<M, T>
321where
322    M: CompletionModel,
323    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
324{
325    pub fn new(model: M) -> Self {
326        Self {
327            agent_builder: AgentBuilder::new(model)
328                .preamble("\
329                    You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
330                    You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
331                    Use the `submit` function to submit the structured data.\n\
332                    Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
333                ")
334                .tool(SubmitTool::<T> {_t: PhantomData})
335                .tool_choice(ToolChoice::Required),
336            retries: None,
337            _t: PhantomData,
338        }
339    }
340
341    /// Add additional preamble to the extractor
342    pub fn preamble(mut self, preamble: &str) -> Self {
343        self.agent_builder = self.agent_builder.append_preamble(&format!(
344            "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
345        ));
346        self
347    }
348
349    /// Add a context document to the extractor
350    pub fn context(mut self, doc: &str) -> Self {
351        self.agent_builder = self.agent_builder.context(doc);
352        self
353    }
354
355    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
356        self.agent_builder = self.agent_builder.additional_params(params);
357        self
358    }
359
360    /// Set the maximum number of tokens for the completion
361    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
362        self.agent_builder = self.agent_builder.max_tokens(max_tokens);
363        self
364    }
365
366    /// Set the maximum number of retries for the extractor.
367    pub fn retries(mut self, retries: u64) -> Self {
368        self.retries = Some(retries);
369        self
370    }
371
372    /// Set the `tool_choice` option for the inner Agent.
373    pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
374        self.agent_builder = self.agent_builder.tool_choice(choice);
375        self
376    }
377
378    /// Build the Extractor
379    pub fn build(self) -> Extractor<M, T> {
380        Extractor {
381            agent: self.agent_builder.build(),
382            _t: PhantomData,
383            retries: self.retries.unwrap_or(0),
384        }
385    }
386
387    /// Add dynamic context (RAG) to the extractor.
388    ///
389    /// On each prompt, `sample` documents will be retrieved from the index based on the RAG text
390    /// and inserted in the request.
391    pub fn dynamic_context(
392        mut self,
393        sample: usize,
394        dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
395    ) -> Self {
396        self.agent_builder = self.agent_builder.dynamic_context(sample, dynamic_context);
397        self
398    }
399}
400
401#[derive(Deserialize, Serialize)]
402struct SubmitTool<T>
403where
404    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
405{
406    _t: PhantomData<T>,
407}
408
409#[derive(Debug, thiserror::Error)]
410#[error("SubmitError")]
411struct SubmitError;
412
413impl<T> Tool for SubmitTool<T>
414where
415    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
416{
417    const NAME: &'static str = SUBMIT_TOOL_NAME;
418    type Error = SubmitError;
419    type Args = T;
420    type Output = T;
421
422    async fn definition(&self, _prompt: String) -> ToolDefinition {
423        ToolDefinition {
424            name: Self::NAME.to_string(),
425            description: "Submit the structured data you extracted from the provided text."
426                .to_string(),
427            parameters: json!(schema_for!(T)),
428        }
429    }
430
431    async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
432        Ok(data)
433    }
434}