rig/
extractor.rs

//! This module provides high-level abstractions for extracting structured data from text using LLMs.
//!
//! Note: The target structure must implement the `serde::Deserialize`, `serde::Serialize`,
//! and `schemars::JsonSchema` traits. These can easily be derived using the `derive` macro.
//!
//! # Example
//! ```
//! use rig::providers::openai;
//!
//! // Initialize the OpenAI client
//! let openai = openai::Client::new("your-open-ai-api-key");
//!
//! // Define the structure of the data you want to extract
//! #[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
//! struct Person {
//!     name: Option<String>,
//!     age: Option<u8>,
//!     profession: Option<String>,
//! }
//!
//! // Create the extractor
//! let extractor = openai.extractor::<Person>(openai::GPT_4O)
//!     .build();
//!
//! // Extract structured data from text
//! let person = extractor.extract("John Doe is a 30 year old doctor.")
//!     .await
//!     .expect("Failed to extract data from text");
//! ```
30
31use std::marker::PhantomData;
32
33use schemars::{JsonSchema, schema_for};
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36
37use crate::{
38    agent::{Agent, AgentBuilder},
39    completion::{Completion, CompletionError, CompletionModel, ToolDefinition},
40    message::{AssistantContent, Message, ToolCall, ToolFunction},
41    tool::Tool,
42};
43
/// Name of the tool the model is expected to call to hand back the extracted data.
///
/// Used both as the registered tool name and as the filter when scanning the
/// model's response for tool calls.
const SUBMIT_TOOL_NAME: &str = "submit";
45
46#[derive(Debug, thiserror::Error)]
47pub enum ExtractionError {
48    #[error("No data extracted")]
49    NoData,
50
51    #[error("Failed to deserialize the extracted data: {0}")]
52    DeserializationError(#[from] serde_json::Error),
53
54    #[error("CompletionError: {0}")]
55    CompletionError(#[from] CompletionError),
56}
57
58/// Extractor for structured data from text
59pub struct Extractor<M, T>
60where
61    M: CompletionModel,
62    T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync,
63{
64    agent: Agent<M>,
65    _t: PhantomData<T>,
66    retries: u64,
67}
68
69impl<M, T> Extractor<M, T>
70where
71    M: CompletionModel,
72    T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync,
73{
74    /// Attempts to extract data from the given text with a number of retries.
75    ///
76    /// The function will retry the extraction if the initial attempt fails or
77    /// if the model does not call the `submit` tool.
78    ///
79    /// The number of retries is determined by the `retries` field on the Extractor struct.
80    pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
81        let mut last_error = None;
82        let text_message = text.into();
83
84        for i in 0..=self.retries {
85            tracing::debug!(
86                "Attempting to extract JSON. Retries left: {retries}",
87                retries = self.retries - i
88            );
89            let attempt_text = text_message.clone();
90            match self.extract_json(attempt_text).await {
91                Ok(data) => return Ok(data),
92                Err(e) => {
93                    tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
94                    last_error = Some(e);
95                }
96            }
97        }
98
99        // If the loop finishes without a successful extraction, return the last error encountered.
100        Err(last_error.unwrap_or(ExtractionError::NoData))
101    }
102
103    async fn extract_json(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
104        let response = self.agent.completion(text, vec![]).await?.send().await?;
105
106        if !response.choice.iter().any(|x| {
107            let AssistantContent::ToolCall(ToolCall {
108                function: ToolFunction { name, .. },
109                ..
110            }) = x
111            else {
112                return false;
113            };
114
115            name == SUBMIT_TOOL_NAME
116        }) {
117            tracing::warn!(
118                "The submit tool was not called. If this happens more than once, please ensure the model you are using is powerful enough to reliably call tools."
119            );
120        }
121
122        let arguments = response
123            .choice
124            .into_iter()
125            // We filter tool calls to look for submit tool calls
126            .filter_map(|content| {
127                if let AssistantContent::ToolCall(ToolCall {
128                    function: ToolFunction { arguments, name },
129                    ..
130                }) = content
131                {
132                    if name == SUBMIT_TOOL_NAME {
133                        Some(arguments)
134                    } else {
135                        None
136                    }
137                } else {
138                    None
139                }
140            })
141            .collect::<Vec<_>>();
142
143        if arguments.len() > 1 {
144            tracing::warn!(
145                "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
146            );
147        }
148
149        let raw_data = if let Some(arg) = arguments.into_iter().next() {
150            arg
151        } else {
152            return Err(ExtractionError::NoData);
153        };
154
155        Ok(serde_json::from_value(raw_data)?)
156    }
157
158    pub async fn get_inner(&self) -> &Agent<M> {
159        &self.agent
160    }
161
162    pub async fn into_inner(self) -> Agent<M> {
163        self.agent
164    }
165}
166
167/// Builder for the Extractor
168pub struct ExtractorBuilder<M, T>
169where
170    M: CompletionModel,
171    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static,
172{
173    agent_builder: AgentBuilder<M>,
174    _t: PhantomData<T>,
175    retries: Option<u64>,
176}
177
178impl<M, T> ExtractorBuilder<M, T>
179where
180    M: CompletionModel,
181    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static,
182{
183    pub fn new(model: M) -> Self {
184        Self {
185            agent_builder: AgentBuilder::new(model)
186                .preamble("\
187                    You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
188                    You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
189                    Use the `submit` function to submit the structured data.\n\
190                    Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
191                ")
192                .tool(SubmitTool::<T> {_t: PhantomData}),
193            retries: None,
194            _t: PhantomData,
195        }
196    }
197
198    /// Add additional preamble to the extractor
199    pub fn preamble(mut self, preamble: &str) -> Self {
200        self.agent_builder = self.agent_builder.append_preamble(&format!(
201            "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
202        ));
203        self
204    }
205
206    /// Add a context document to the extractor
207    pub fn context(mut self, doc: &str) -> Self {
208        self.agent_builder = self.agent_builder.context(doc);
209        self
210    }
211
212    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
213        self.agent_builder = self.agent_builder.additional_params(params);
214        self
215    }
216
217    /// Set the maximum number of tokens for the completion
218    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
219        self.agent_builder = self.agent_builder.max_tokens(max_tokens);
220        self
221    }
222
223    /// Set the maximum number of retries for the extractor.
224    pub fn retries(mut self, retries: u64) -> Self {
225        self.retries = Some(retries);
226        self
227    }
228
229    /// Build the Extractor
230    pub fn build(self) -> Extractor<M, T> {
231        Extractor {
232            agent: self.agent_builder.build(),
233            _t: PhantomData,
234            retries: self.retries.unwrap_or(0),
235        }
236    }
237}
238
239#[derive(Deserialize, Serialize)]
240struct SubmitTool<T>
241where
242    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync,
243{
244    _t: PhantomData<T>,
245}
246
247#[derive(Debug, thiserror::Error)]
248#[error("SubmitError")]
249struct SubmitError;
250
251impl<T> Tool for SubmitTool<T>
252where
253    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync,
254{
255    const NAME: &'static str = SUBMIT_TOOL_NAME;
256    type Error = SubmitError;
257    type Args = T;
258    type Output = T;
259
260    async fn definition(&self, _prompt: String) -> ToolDefinition {
261        ToolDefinition {
262            name: Self::NAME.to_string(),
263            description: "Submit the structured data you extracted from the provided text."
264                .to_string(),
265            parameters: json!(schema_for!(T)),
266        }
267    }
268
269    async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
270        Ok(data)
271    }
272}