rig/
extractor.rs

1//! This module provides high-level abstractions for extracting structured data from text using LLMs.
2//!
3//! Note: The target structure must implement the `serde::Deserialize`, `serde::Serialize`,
4//! and `schemars::JsonSchema` traits. Those can be easily derived using the `derive` macro.
5//!
6//! # Example
7//! ```
8//! use rig::providers::openai;
9//!
10//! // Initialize the OpenAI client
11//! let openai = openai::Client::new("your-open-ai-api-key");
12//!
13//! // Define the structure of the data you want to extract
14//! #[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
15//! struct Person {
16//!    name: Option<String>,
17//!    age: Option<u8>,
18//!    profession: Option<String>,
19//! }
20//!
21//! // Create the extractor
22//! let extractor = openai.extractor::<Person>(openai::GPT_4O)
23//!     .build();
24//!
25//! // Extract structured data from text
26//! let person = extractor.extract("John Doe is a 30 year old doctor.")
27//!     .await
28//!     .expect("Failed to extract data from text");
29//! ```
30
31use std::marker::PhantomData;
32
33use schemars::{JsonSchema, schema_for};
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36
37use crate::{
38    agent::{Agent, AgentBuilder},
39    completion::{Completion, CompletionError, CompletionModel, ToolDefinition},
40    message::{AssistantContent, Message, ToolCall, ToolFunction},
41    tool::Tool,
42};
43
/// Name of the tool the extraction agent must call to hand back its result.
///
/// Used both as the tool's `NAME` and by [`Extractor::extract`], which keeps only
/// tool calls carrying this name.
const SUBMIT_TOOL_NAME: &str = "submit";
45
46#[derive(Debug, thiserror::Error)]
47pub enum ExtractionError {
48    #[error("No data extracted")]
49    NoData,
50
51    #[error("Failed to deserialize the extracted data: {0}")]
52    DeserializationError(#[from] serde_json::Error),
53
54    #[error("CompletionError: {0}")]
55    CompletionError(#[from] CompletionError),
56}
57
58/// Extractor for structured data from text
59pub struct Extractor<M: CompletionModel, T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
60    agent: Agent<M>,
61    _t: PhantomData<T>,
62}
63
64impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel> Extractor<M, T>
65where
66    M: Sync,
67{
68    pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
69        let response = self.agent.completion(text, vec![]).await?.send().await?;
70
71        let arguments = response
72            .choice
73            .into_iter()
74            // We filter tool calls to look for submit tool calls
75            .filter_map(|content| {
76                if let AssistantContent::ToolCall(ToolCall {
77                    function: ToolFunction { arguments, name },
78                    ..
79                }) = content
80                {
81                    if name == SUBMIT_TOOL_NAME {
82                        Some(arguments)
83                    } else {
84                        None
85                    }
86                } else {
87                    None
88                }
89            })
90            .collect::<Vec<_>>();
91
92        if arguments.len() > 1 {
93            tracing::warn!(
94                "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
95            );
96        }
97
98        let raw_data = if let Some(arg) = arguments.into_iter().next() {
99            arg
100        } else {
101            return Err(ExtractionError::NoData);
102        };
103
104        Ok(serde_json::from_value(raw_data)?)
105    }
106
107    pub async fn get_inner(&self) -> &Agent<M> {
108        &self.agent
109    }
110
111    pub async fn into_inner(self) -> Agent<M> {
112        self.agent
113    }
114}
115
116/// Builder for the Extractor
117pub struct ExtractorBuilder<
118    T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
119    M: CompletionModel,
120> {
121    agent_builder: AgentBuilder<M>,
122    _t: PhantomData<T>,
123}
124
125impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, M: CompletionModel>
126    ExtractorBuilder<T, M>
127{
128    pub fn new(model: M) -> Self {
129        Self {
130            agent_builder: AgentBuilder::new(model)
131                .preamble("\
132                    You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
133                    You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
134                    Use the `submit` function to submit the structured data.\n\
135                    Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
136                ")
137                .tool(SubmitTool::<T> {_t: PhantomData}),
138
139            _t: PhantomData,
140        }
141    }
142
143    /// Add additional preamble to the extractor
144    pub fn preamble(mut self, preamble: &str) -> Self {
145        self.agent_builder = self.agent_builder.append_preamble(&format!(
146            "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
147        ));
148        self
149    }
150
151    /// Add a context document to the extractor
152    pub fn context(mut self, doc: &str) -> Self {
153        self.agent_builder = self.agent_builder.context(doc);
154        self
155    }
156
157    pub fn additional_params(mut self, params: serde_json::Value) -> Self {
158        self.agent_builder = self.agent_builder.additional_params(params);
159        self
160    }
161
162    /// Set the maximum number of tokens for the completion
163    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
164        self.agent_builder = self.agent_builder.max_tokens(max_tokens);
165        self
166    }
167
168    /// Build the Extractor
169    pub fn build(self) -> Extractor<M, T> {
170        Extractor {
171            agent: self.agent_builder.build(),
172            _t: PhantomData,
173        }
174    }
175}
176
177#[derive(Deserialize, Serialize)]
178struct SubmitTool<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
179    _t: PhantomData<T>,
180}
181
182#[derive(Debug, thiserror::Error)]
183#[error("SubmitError")]
184struct SubmitError;
185
186impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync> Tool for SubmitTool<T> {
187    const NAME: &'static str = SUBMIT_TOOL_NAME;
188    type Error = SubmitError;
189    type Args = T;
190    type Output = T;
191
192    async fn definition(&self, _prompt: String) -> ToolDefinition {
193        ToolDefinition {
194            name: Self::NAME.to_string(),
195            description: "Submit the structured data you extracted from the provided text."
196                .to_string(),
197            parameters: json!(schema_for!(T)),
198        }
199    }
200
201    async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
202        Ok(data)
203    }
204}