rig/
extractor.rs

1//! This module provides high-level abstractions for extracting structured data from text using LLMs.
2//!
3//! Note: The target structure must implement the `serde::Deserialize`, `serde::Serialize`,
4//! and `schemars::JsonSchema` traits. Those can be easily derived using the `derive` macro.
5//!
6//! # Example
//! ```no_run
//! use rig::providers::openai;
//!
//! # async fn run() {
//! // Initialize the OpenAI client
//! let openai = openai::Client::new("your-open-ai-api-key");
//!
//! // Define the structure of the data you want to extract
//! #[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
//! struct Person {
//!    name: Option<String>,
//!    age: Option<u8>,
//!    profession: Option<String>,
//! }
//!
//! // Create the extractor
//! let extractor = openai.extractor::<Person>(openai::GPT_4O)
//!     .build();
//!
//! // Extract structured data from text (must be awaited inside an async context)
//! let person = extractor.extract("John Doe is a 30 year old doctor.")
//!     .await
//!     .expect("Failed to extract data from text");
//! # }
//! ```
30
31use std::marker::PhantomData;
32
33use schemars::{schema_for, JsonSchema};
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36
37use crate::{
38    agent::{Agent, AgentBuilder},
39    completion::{CompletionModel, Prompt, PromptError, ToolDefinition},
40    tool::Tool,
41};
42
43#[derive(Debug, thiserror::Error)]
44pub enum ExtractionError {
45    #[error("No data extracted")]
46    NoData,
47
48    #[error("Failed to deserialize the extracted data: {0}")]
49    DeserializationError(#[from] serde_json::Error),
50
51    #[error("PromptError: {0}")]
52    PromptError(#[from] PromptError),
53}
54
55/// Extractor for structured data from text
56pub struct Extractor<M: CompletionModel, T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
57    agent: Agent<M>,
58    _t: PhantomData<T>,
59}
60
61impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel> Extractor<M, T>
62where
63    M: Sync,
64{
65    pub async fn extract(&self, text: &str) -> Result<T, ExtractionError> {
66        let summary = self.agent.prompt(text).await?;
67
68        if summary.is_empty() {
69            return Err(ExtractionError::NoData);
70        }
71
72        Ok(serde_json::from_str(&summary)?)
73    }
74}
75
76/// Builder for the Extractor
77pub struct ExtractorBuilder<
78    T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
79    M: CompletionModel,
80> {
81    agent_builder: AgentBuilder<M>,
82    _t: PhantomData<T>,
83}
84
85impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, M: CompletionModel>
86    ExtractorBuilder<T, M>
87{
88    pub fn new(model: M) -> Self {
89        Self {
90            agent_builder: AgentBuilder::new(model)
91                .preamble("\
92                    You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
93                    You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
94                    Use the `submit` function to submit the structured data.\n\
95                    Be sure to fill out every field and ALWAYS CALL THE `submit` function, event with default values!!!.
96                ")
97                .tool(SubmitTool::<T> {_t: PhantomData}),
98            _t: PhantomData,
99        }
100    }
101
102    /// Add additional preamble to the extractor
103    pub fn preamble(mut self, preamble: &str) -> Self {
104        self.agent_builder = self.agent_builder.append_preamble(&format!(
105            "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
106        ));
107        self
108    }
109
110    /// Add a context document to the extractor
111    pub fn context(mut self, doc: &str) -> Self {
112        self.agent_builder = self.agent_builder.context(doc);
113        self
114    }
115
116    /// Build the Extractor
117    pub fn build(self) -> Extractor<M, T> {
118        Extractor {
119            agent: self.agent_builder.build(),
120            _t: PhantomData,
121        }
122    }
123}
124
125#[derive(Deserialize, Serialize)]
126struct SubmitTool<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
127    _t: PhantomData<T>,
128}
129
130#[derive(Debug, thiserror::Error)]
131#[error("SubmitError")]
132struct SubmitError;
133
134impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync> Tool for SubmitTool<T> {
135    const NAME: &'static str = "submit";
136    type Error = SubmitError;
137    type Args = T;
138    type Output = T;
139
140    async fn definition(&self, _prompt: String) -> ToolDefinition {
141        ToolDefinition {
142            name: Self::NAME.to_string(),
143            description: "Submit the structured data you extracted from the provided text."
144                .to_string(),
145            parameters: json!(schema_for!(T)),
146        }
147    }
148
149    async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
150        Ok(data)
151    }
152}