rig/
extractor.rs

1//! This module provides high-level abstractions for extracting structured data from text using LLMs.
2//!
3//! Note: The target structure must implement the `serde::Deserialize`, `serde::Serialize`,
4//! and `schemars::JsonSchema` traits. Those can be easily derived using the `derive` macro.
5//!
6//! # Example
7//! ```
8//! use rig::providers::openai;
9//!
10//! // Initialize the OpenAI client
11//! let openai = openai::Client::new("your-open-ai-api-key");
12//!
13//! // Define the structure of the data you want to extract
14//! #[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
15//! struct Person {
16//!    name: Option<String>,
17//!    age: Option<u8>,
18//!    profession: Option<String>,
19//! }
20//!
21//! // Create the extractor
22//! let extractor = openai.extractor::<Person>(openai::GPT_4O)
23//!     .build();
24//!
25//! // Extract structured data from text
26//! let person = extractor.extract("John Doe is a 30 year old doctor.")
27//!     .await
28//!     .expect("Failed to extract data from text");
29//! ```
30
31use std::marker::PhantomData;
32
33use schemars::{JsonSchema, schema_for};
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36
37use crate::{
38    agent::{Agent, AgentBuilder, AgentBuilderSimple},
39    completion::{Completion, CompletionError, CompletionModel, ToolDefinition},
40    message::{AssistantContent, Message, ToolCall, ToolChoice, ToolFunction},
41    tool::Tool,
42    wasm_compat::{WasmCompatSend, WasmCompatSync},
43};
44
45const SUBMIT_TOOL_NAME: &str = "submit";
46
47#[derive(Debug, thiserror::Error)]
48pub enum ExtractionError {
49    #[error("No data extracted")]
50    NoData,
51
52    #[error("Failed to deserialize the extracted data: {0}")]
53    DeserializationError(#[from] serde_json::Error),
54
55    #[error("CompletionError: {0}")]
56    CompletionError(#[from] CompletionError),
57}
58
59/// Extractor for structured data from text
60pub struct Extractor<M, T>
61where
62    M: CompletionModel,
63    T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
64{
65    agent: Agent<M>,
66    _t: PhantomData<T>,
67    retries: u64,
68}
69
70impl<M, T> Extractor<M, T>
71where
72    M: CompletionModel,
73    T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
74{
75    /// Attempts to extract data from the given text with a number of retries.
76    ///
77    /// The function will retry the extraction if the initial attempt fails or
78    /// if the model does not call the `submit` tool.
79    ///
80    /// The number of retries is determined by the `retries` field on the Extractor struct.
81    pub async fn extract(
82        &self,
83        text: impl Into<Message> + WasmCompatSend,
84    ) -> Result<T, ExtractionError> {
85        let mut last_error = None;
86        let text_message = text.into();
87
88        for i in 0..=self.retries {
89            tracing::debug!(
90                "Attempting to extract JSON. Retries left: {retries}",
91                retries = self.retries - i
92            );
93            let attempt_text = text_message.clone();
94            match self.extract_json(attempt_text, vec![]).await {
95                Ok(data) => return Ok(data),
96                Err(e) => {
97                    tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
98                    last_error = Some(e);
99                }
100            }
101        }
102
103        // If the loop finishes without a successful extraction, return the last error encountered.
104        Err(last_error.unwrap_or(ExtractionError::NoData))
105    }
106
107    /// Attempts to extract data from the given text with a number of retries.
108    ///
109    /// The function will retry the extraction if the initial attempt fails or
110    /// if the model does not call the `submit` tool.
111    ///
112    /// The number of retries is determined by the `retries` field on the Extractor struct.
113    pub async fn extract_with_chat_history(
114        &self,
115        text: impl Into<Message> + WasmCompatSend,
116        chat_history: Vec<Message>,
117    ) -> Result<T, ExtractionError> {
118        let mut last_error = None;
119        let text_message = text.into();
120
121        for i in 0..=self.retries {
122            tracing::debug!(
123                "Attempting to extract JSON. Retries left: {retries}",
124                retries = self.retries - i
125            );
126            let attempt_text = text_message.clone();
127            match self.extract_json(attempt_text, chat_history.clone()).await {
128                Ok(data) => return Ok(data),
129                Err(e) => {
130                    tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
131                    last_error = Some(e);
132                }
133            }
134        }
135
136        // If the loop finishes without a successful extraction, return the last error encountered.
137        Err(last_error.unwrap_or(ExtractionError::NoData))
138    }
139
140    async fn extract_json(
141        &self,
142        text: impl Into<Message> + WasmCompatSend,
143        messages: Vec<Message>,
144    ) -> Result<T, ExtractionError> {
145        let response = self.agent.completion(text, messages).await?.send().await?;
146
147        if !response.choice.iter().any(|x| {
148            let AssistantContent::ToolCall(ToolCall {
149                function: ToolFunction { name, .. },
150                ..
151            }) = x
152            else {
153                return false;
154            };
155
156            name == SUBMIT_TOOL_NAME
157        }) {
158            tracing::warn!(
159                "The submit tool was not called. If this happens more than once, please ensure the model you are using is powerful enough to reliably call tools."
160            );
161        }
162
163        let arguments = response
164            .choice
165            .into_iter()
166            // We filter tool calls to look for submit tool calls
167            .filter_map(|content| {
168                if let AssistantContent::ToolCall(ToolCall {
169                    function: ToolFunction { arguments, name },
170                    ..
171                }) = content
172                {
173                    if name == SUBMIT_TOOL_NAME {
174                        Some(arguments)
175                    } else {
176                        None
177                    }
178                } else {
179                    None
180                }
181            })
182            .collect::<Vec<_>>();
183
184        if arguments.len() > 1 {
185            tracing::warn!(
186                "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
187            );
188        }
189
190        let raw_data = if let Some(arg) = arguments.into_iter().next() {
191            arg
192        } else {
193            return Err(ExtractionError::NoData);
194        };
195
196        Ok(serde_json::from_value(raw_data)?)
197    }
198
199    pub async fn get_inner(&self) -> &Agent<M> {
200        &self.agent
201    }
202
203    pub async fn into_inner(self) -> Agent<M> {
204        self.agent
205    }
206}
207
208/// Builder for the Extractor
209pub struct ExtractorBuilder<M, T>
210where
211    M: CompletionModel,
212    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
213{
214    agent_builder: AgentBuilderSimple<M>,
215    _t: PhantomData<T>,
216    retries: Option<u64>,
217}
218
219impl<M, T> ExtractorBuilder<M, T>
220where
221    M: CompletionModel,
222    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
223{
224    pub fn new(model: M) -> Self {
225        Self {
226            agent_builder: AgentBuilder::new(model)
227                .preamble("\
228                    You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
229                    You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
230                    Use the `submit` function to submit the structured data.\n\
231                    Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
232                ")
233                .tool(SubmitTool::<T> {_t: PhantomData})
234                .tool_choice(ToolChoice::Required),
235            retries: None,
236            _t: PhantomData,
237        }
238    }
239
240    /// Add additional preamble to the extractor
241    pub fn preamble(mut self, preamble: &str) -> Self {
242        self.agent_builder = self.agent_builder.append_preamble(&format!(
243            "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
244        ));
245        self
246    }
247
248    /// Add a context document to the extractor
249    pub fn context(mut self, doc: &str) -> Self {
250        self.agent_builder = self.agent_builder.context(doc);
251        self
252    }
253
254    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
255        self.agent_builder = self.agent_builder.additional_params(params);
256        self
257    }
258
259    /// Set the maximum number of tokens for the completion
260    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
261        self.agent_builder = self.agent_builder.max_tokens(max_tokens);
262        self
263    }
264
265    /// Set the maximum number of retries for the extractor.
266    pub fn retries(mut self, retries: u64) -> Self {
267        self.retries = Some(retries);
268        self
269    }
270
271    /// Set the `tool_choice` option for the inner Agent.
272    pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
273        self.agent_builder = self.agent_builder.tool_choice(choice);
274        self
275    }
276
277    /// Build the Extractor
278    pub fn build(self) -> Extractor<M, T> {
279        Extractor {
280            agent: self.agent_builder.build(),
281            _t: PhantomData,
282            retries: self.retries.unwrap_or(0),
283        }
284    }
285}
286
287#[derive(Deserialize, Serialize)]
288struct SubmitTool<T>
289where
290    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
291{
292    _t: PhantomData<T>,
293}
294
295#[derive(Debug, thiserror::Error)]
296#[error("SubmitError")]
297struct SubmitError;
298
299impl<T> Tool for SubmitTool<T>
300where
301    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
302{
303    const NAME: &'static str = SUBMIT_TOOL_NAME;
304    type Error = SubmitError;
305    type Args = T;
306    type Output = T;
307
308    async fn definition(&self, _prompt: String) -> ToolDefinition {
309        ToolDefinition {
310            name: Self::NAME.to_string(),
311            description: "Submit the structured data you extracted from the provided text."
312                .to_string(),
313            parameters: json!(schema_for!(T)),
314        }
315    }
316
317    async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
318        Ok(data)
319    }
320}