Skip to main content

rig/
extractor.rs

1//! This module provides high-level abstractions for extracting structured data from text using LLMs.
2//!
3//! Note: The target structure must implement the `serde::Deserialize`, `serde::Serialize`,
4//! and `schemars::JsonSchema` traits. Those can be easily derived using the `derive` macro.
5//!
6//! # Example
7//! ```
8//! use rig::providers::openai;
9//!
10//! // Initialize the OpenAI client
11//! let openai = openai::Client::new("your-open-ai-api-key");
12//!
13//! // Define the structure of the data you want to extract
14//! #[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
15//! struct Person {
16//!    name: Option<String>,
17//!    age: Option<u8>,
18//!    profession: Option<String>,
19//! }
20//!
21//! // Create the extractor
22//! let extractor = openai.extractor::<Person>(openai::GPT_4O)
23//!     .build();
24//!
25//! // Extract structured data from text
26//! let person = extractor.extract("John Doe is a 30 year old doctor.")
27//!     .await
28//!     .expect("Failed to extract data from text");
29//! ```
30
31use std::marker::PhantomData;
32
33use schemars::{JsonSchema, schema_for};
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36
37use crate::{
38    agent::{Agent, AgentBuilder, WithBuilderTools},
39    completion::{Completion, CompletionError, CompletionModel, ToolDefinition, Usage},
40    message::{AssistantContent, Message, ToolCall, ToolChoice, ToolFunction},
41    tool::Tool,
42    vector_store::VectorStoreIndexDyn,
43    wasm_compat::{WasmCompatSend, WasmCompatSync},
44};
45
/// Name under which the structured-data submission tool is registered with the agent
/// and looked up again in the model's tool calls.
const SUBMIT_TOOL_NAME: &str = "submit";
47
48/// Response from an extraction operation containing the extracted data and usage information.
49#[derive(Debug, Clone)]
50pub struct ExtractionResponse<T> {
51    /// The extracted structured data
52    pub data: T,
53    /// Accumulated token usage across all attempts (including retries)
54    pub usage: Usage,
55}
56
57#[derive(Debug, thiserror::Error)]
58pub enum ExtractionError {
59    #[error("No data extracted")]
60    NoData,
61
62    #[error("Failed to deserialize the extracted data: {0}")]
63    DeserializationError(#[from] serde_json::Error),
64
65    #[error("CompletionError: {0}")]
66    CompletionError(#[from] CompletionError),
67}
68
69/// Extractor for structured data from text
70pub struct Extractor<M, T>
71where
72    M: CompletionModel,
73    T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
74{
75    agent: Agent<M>,
76    _t: PhantomData<T>,
77    retries: u64,
78}
79
80impl<M, T> Extractor<M, T>
81where
82    M: CompletionModel,
83    T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
84{
85    /// Attempts to extract data from the given text with a number of retries.
86    ///
87    /// The function will retry the extraction if the initial attempt fails or
88    /// if the model does not call the `submit` tool.
89    ///
90    /// The number of retries is determined by the `retries` field on the Extractor struct.
91    pub async fn extract(
92        &self,
93        text: impl Into<Message> + WasmCompatSend,
94    ) -> Result<T, ExtractionError> {
95        let mut last_error = None;
96        let text_message = text.into();
97
98        for i in 0..=self.retries {
99            tracing::debug!(
100                "Attempting to extract JSON. Retries left: {retries}",
101                retries = self.retries - i
102            );
103            let attempt_text = text_message.clone();
104            match self.extract_json_with_usage(attempt_text, vec![]).await {
105                Ok((data, _usage)) => return Ok(data),
106                Err(e) => {
107                    tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
108                    last_error = Some(e);
109                }
110            }
111        }
112
113        // If the loop finishes without a successful extraction, return the last error encountered.
114        Err(last_error.unwrap_or(ExtractionError::NoData))
115    }
116
117    /// Attempts to extract data from the given text with a number of retries.
118    ///
119    /// The function will retry the extraction if the initial attempt fails or
120    /// if the model does not call the `submit` tool.
121    ///
122    /// The number of retries is determined by the `retries` field on the Extractor struct.
123    pub async fn extract_with_chat_history(
124        &self,
125        text: impl Into<Message> + WasmCompatSend,
126        chat_history: Vec<Message>,
127    ) -> Result<T, ExtractionError> {
128        let mut last_error = None;
129        let text_message = text.into();
130
131        for i in 0..=self.retries {
132            tracing::debug!(
133                "Attempting to extract JSON. Retries left: {retries}",
134                retries = self.retries - i
135            );
136            let attempt_text = text_message.clone();
137            match self
138                .extract_json_with_usage(attempt_text, chat_history.clone())
139                .await
140            {
141                Ok((data, _usage)) => return Ok(data),
142                Err(e) => {
143                    tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
144                    last_error = Some(e);
145                }
146            }
147        }
148
149        // If the loop finishes without a successful extraction, return the last error encountered.
150        Err(last_error.unwrap_or(ExtractionError::NoData))
151    }
152
153    /// Attempts to extract data from the given text with a number of retries,
154    /// returning both the extracted data and accumulated token usage.
155    ///
156    /// The function will retry the extraction if the initial attempt fails or
157    /// if the model does not call the `submit` tool.
158    ///
159    /// The number of retries is determined by the `retries` field on the Extractor struct.
160    ///
161    /// Usage accumulates across all retry attempts, providing the complete cost picture
162    /// including failed attempts.
163    pub async fn extract_with_usage(
164        &self,
165        text: impl Into<Message> + WasmCompatSend,
166    ) -> Result<ExtractionResponse<T>, ExtractionError> {
167        let mut last_error = None;
168        let text_message = text.into();
169        let mut usage = Usage::new();
170
171        for i in 0..=self.retries {
172            tracing::debug!(
173                "Attempting to extract JSON. Retries left: {retries}",
174                retries = self.retries - i
175            );
176            let attempt_text = text_message.clone();
177            match self.extract_json_with_usage(attempt_text, vec![]).await {
178                Ok((data, u)) => {
179                    usage += u;
180                    return Ok(ExtractionResponse { data, usage });
181                }
182                Err(e) => {
183                    tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
184                    last_error = Some(e);
185                }
186            }
187        }
188
189        // If the loop finishes without a successful extraction, return the last error encountered.
190        Err(last_error.unwrap_or(ExtractionError::NoData))
191    }
192
193    /// Attempts to extract data from the given text with a number of retries,
194    /// providing chat history context, and returning both the extracted data
195    /// and accumulated token usage.
196    ///
197    /// The function will retry the extraction if the initial attempt fails or
198    /// if the model does not call the `submit` tool.
199    ///
200    /// The number of retries is determined by the `retries` field on the Extractor struct.
201    ///
202    /// Usage accumulates across all retry attempts, providing the complete cost picture
203    /// including failed attempts.
204    pub async fn extract_with_chat_history_with_usage(
205        &self,
206        text: impl Into<Message> + WasmCompatSend,
207        chat_history: Vec<Message>,
208    ) -> Result<ExtractionResponse<T>, ExtractionError> {
209        let mut last_error = None;
210        let text_message = text.into();
211        let mut usage = Usage::new();
212
213        for i in 0..=self.retries {
214            tracing::debug!(
215                "Attempting to extract JSON. Retries left: {retries}",
216                retries = self.retries - i
217            );
218            let attempt_text = text_message.clone();
219            match self
220                .extract_json_with_usage(attempt_text, chat_history.clone())
221                .await
222            {
223                Ok((data, u)) => {
224                    usage += u;
225                    return Ok(ExtractionResponse { data, usage });
226                }
227                Err(e) => {
228                    tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
229                    last_error = Some(e);
230                }
231            }
232        }
233
234        // If the loop finishes without a successful extraction, return the last error encountered.
235        Err(last_error.unwrap_or(ExtractionError::NoData))
236    }
237
238    async fn extract_json_with_usage(
239        &self,
240        text: impl Into<Message> + WasmCompatSend,
241        messages: Vec<Message>,
242    ) -> Result<(T, Usage), ExtractionError> {
243        let response = self.agent.completion(text, messages).await?.send().await?;
244        let usage = response.usage;
245
246        if !response.choice.iter().any(|x| {
247            let AssistantContent::ToolCall(ToolCall {
248                function: ToolFunction { name, .. },
249                ..
250            }) = x
251            else {
252                return false;
253            };
254
255            name == SUBMIT_TOOL_NAME
256        }) {
257            tracing::warn!(
258                "The submit tool was not called. If this happens more than once, please ensure the model you are using is powerful enough to reliably call tools."
259            );
260        }
261
262        let arguments = response
263            .choice
264            .into_iter()
265            // We filter tool calls to look for submit tool calls
266            .filter_map(|content| {
267                if let AssistantContent::ToolCall(ToolCall {
268                    function: ToolFunction { arguments, name },
269                    ..
270                }) = content
271                {
272                    if name == SUBMIT_TOOL_NAME {
273                        Some(arguments)
274                    } else {
275                        None
276                    }
277                } else {
278                    None
279                }
280            })
281            .collect::<Vec<_>>();
282
283        if arguments.len() > 1 {
284            tracing::warn!(
285                "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
286            );
287        }
288
289        let raw_data = if let Some(arg) = arguments.into_iter().next() {
290            arg
291        } else {
292            return Err(ExtractionError::NoData);
293        };
294
295        let data = serde_json::from_value(raw_data)?;
296        Ok((data, usage))
297    }
298
299    pub async fn get_inner(&self) -> &Agent<M> {
300        &self.agent
301    }
302
303    pub async fn into_inner(self) -> Agent<M> {
304        self.agent
305    }
306}
307
308/// Builder for the Extractor
309pub struct ExtractorBuilder<M, T>
310where
311    M: CompletionModel,
312    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
313{
314    agent_builder: AgentBuilder<M, (), WithBuilderTools>,
315    _t: PhantomData<T>,
316    retries: Option<u64>,
317}
318
319impl<M, T> ExtractorBuilder<M, T>
320where
321    M: CompletionModel,
322    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
323{
324    pub fn new(model: M) -> Self {
325        Self {
326            agent_builder: AgentBuilder::new(model)
327                .preamble("\
328                    You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
329                    You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
330                    Use the `submit` function to submit the structured data.\n\
331                    Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
332                ")
333                .tool(SubmitTool::<T> {_t: PhantomData})
334                .tool_choice(ToolChoice::Required),
335            retries: None,
336            _t: PhantomData,
337        }
338    }
339
340    /// Add additional preamble to the extractor
341    pub fn preamble(mut self, preamble: &str) -> Self {
342        self.agent_builder = self.agent_builder.append_preamble(&format!(
343            "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
344        ));
345        self
346    }
347
348    /// Add a context document to the extractor
349    pub fn context(mut self, doc: &str) -> Self {
350        self.agent_builder = self.agent_builder.context(doc);
351        self
352    }
353
354    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
355        self.agent_builder = self.agent_builder.additional_params(params);
356        self
357    }
358
359    /// Set the maximum number of tokens for the completion
360    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
361        self.agent_builder = self.agent_builder.max_tokens(max_tokens);
362        self
363    }
364
365    /// Set the maximum number of retries for the extractor.
366    pub fn retries(mut self, retries: u64) -> Self {
367        self.retries = Some(retries);
368        self
369    }
370
371    /// Set the `tool_choice` option for the inner Agent.
372    pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
373        self.agent_builder = self.agent_builder.tool_choice(choice);
374        self
375    }
376
377    /// Build the Extractor
378    pub fn build(self) -> Extractor<M, T> {
379        Extractor {
380            agent: self.agent_builder.build(),
381            _t: PhantomData,
382            retries: self.retries.unwrap_or(0),
383        }
384    }
385
386    /// Add dynamic context (RAG) to the extractor.
387    ///
388    /// On each prompt, `sample` documents will be retrieved from the index based on the RAG text
389    /// and inserted in the request.
390    pub fn dynamic_context(
391        mut self,
392        sample: usize,
393        dynamic_context: impl VectorStoreIndexDyn + Send + Sync + 'static,
394    ) -> Self {
395        self.agent_builder = self.agent_builder.dynamic_context(sample, dynamic_context);
396        self
397    }
398}
399
400#[derive(Deserialize, Serialize)]
401struct SubmitTool<T>
402where
403    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
404{
405    _t: PhantomData<T>,
406}
407
408#[derive(Debug, thiserror::Error)]
409#[error("SubmitError")]
410struct SubmitError;
411
412impl<T> Tool for SubmitTool<T>
413where
414    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
415{
416    const NAME: &'static str = SUBMIT_TOOL_NAME;
417    type Error = SubmitError;
418    type Args = T;
419    type Output = T;
420
421    async fn definition(&self, _prompt: String) -> ToolDefinition {
422        ToolDefinition {
423            name: Self::NAME.to_string(),
424            description: "Submit the structured data you extracted from the provided text."
425                .to_string(),
426            parameters: json!(schema_for!(T)),
427        }
428    }
429
430    async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
431        Ok(data)
432    }
433}