1use std::marker::PhantomData;
32
33use schemars::{JsonSchema, schema_for};
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36
37use crate::{
38 agent::{Agent, AgentBuilder},
39 completion::{Completion, CompletionError, CompletionModel, ToolDefinition},
40 message::{AssistantContent, Message, ToolCall, ToolFunction},
41 tool::Tool,
42};
43
/// Name of the tool the model is expected to call to hand back the extracted data.
///
/// Used both as the registered tool name and as the filter when scanning the
/// model's response for tool calls.
const SUBMIT_TOOL_NAME: &str = "submit";
45
46#[derive(Debug, thiserror::Error)]
47pub enum ExtractionError {
48 #[error("No data extracted")]
49 NoData,
50
51 #[error("Failed to deserialize the extracted data: {0}")]
52 DeserializationError(#[from] serde_json::Error),
53
54 #[error("CompletionError: {0}")]
55 CompletionError(#[from] CompletionError),
56}
57
58pub struct Extractor<M, T>
60where
61 M: CompletionModel,
62 T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync,
63{
64 agent: Agent<M>,
65 _t: PhantomData<T>,
66 retries: u64,
67}
68
69impl<M, T> Extractor<M, T>
70where
71 M: CompletionModel,
72 T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync,
73{
74 pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
81 let mut last_error = None;
82 let text_message = text.into();
83
84 for i in 0..=self.retries {
85 tracing::debug!(
86 "Attempting to extract JSON. Retries left: {retries}",
87 retries = self.retries - i
88 );
89 let attempt_text = text_message.clone();
90 match self.extract_json(attempt_text).await {
91 Ok(data) => return Ok(data),
92 Err(e) => {
93 tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
94 last_error = Some(e);
95 }
96 }
97 }
98
99 Err(last_error.unwrap_or(ExtractionError::NoData))
101 }
102
103 async fn extract_json(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
104 let response = self.agent.completion(text, vec![]).await?.send().await?;
105
106 if !response.choice.iter().any(|x| {
107 let AssistantContent::ToolCall(ToolCall {
108 function: ToolFunction { name, .. },
109 ..
110 }) = x
111 else {
112 return false;
113 };
114
115 name == SUBMIT_TOOL_NAME
116 }) {
117 tracing::warn!(
118 "The submit tool was not called. If this happens more than once, please ensure the model you are using is powerful enough to reliably call tools."
119 );
120 }
121
122 let arguments = response
123 .choice
124 .into_iter()
125 .filter_map(|content| {
127 if let AssistantContent::ToolCall(ToolCall {
128 function: ToolFunction { arguments, name },
129 ..
130 }) = content
131 {
132 if name == SUBMIT_TOOL_NAME {
133 Some(arguments)
134 } else {
135 None
136 }
137 } else {
138 None
139 }
140 })
141 .collect::<Vec<_>>();
142
143 if arguments.len() > 1 {
144 tracing::warn!(
145 "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
146 );
147 }
148
149 let raw_data = if let Some(arg) = arguments.into_iter().next() {
150 arg
151 } else {
152 return Err(ExtractionError::NoData);
153 };
154
155 Ok(serde_json::from_value(raw_data)?)
156 }
157
158 pub async fn get_inner(&self) -> &Agent<M> {
159 &self.agent
160 }
161
162 pub async fn into_inner(self) -> Agent<M> {
163 self.agent
164 }
165}
166
167pub struct ExtractorBuilder<M, T>
169where
170 M: CompletionModel,
171 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static,
172{
173 agent_builder: AgentBuilder<M>,
174 _t: PhantomData<T>,
175 retries: Option<u64>,
176}
177
178impl<M, T> ExtractorBuilder<M, T>
179where
180 M: CompletionModel,
181 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static,
182{
183 pub fn new(model: M) -> Self {
184 Self {
185 agent_builder: AgentBuilder::new(model)
186 .preamble("\
187 You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
188 You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
189 Use the `submit` function to submit the structured data.\n\
190 Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
191 ")
192 .tool(SubmitTool::<T> {_t: PhantomData}),
193 retries: None,
194 _t: PhantomData,
195 }
196 }
197
198 pub fn preamble(mut self, preamble: &str) -> Self {
200 self.agent_builder = self.agent_builder.append_preamble(&format!(
201 "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
202 ));
203 self
204 }
205
206 pub fn context(mut self, doc: &str) -> Self {
208 self.agent_builder = self.agent_builder.context(doc);
209 self
210 }
211
212 pub fn additional_params(mut self, params: serde_json::Value) -> Self {
213 self.agent_builder = self.agent_builder.additional_params(params);
214 self
215 }
216
217 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
219 self.agent_builder = self.agent_builder.max_tokens(max_tokens);
220 self
221 }
222
223 pub fn retries(mut self, retries: u64) -> Self {
225 self.retries = Some(retries);
226 self
227 }
228
229 pub fn build(self) -> Extractor<M, T> {
231 Extractor {
232 agent: self.agent_builder.build(),
233 _t: PhantomData,
234 retries: self.retries.unwrap_or(0),
235 }
236 }
237}
238
239#[derive(Deserialize, Serialize)]
240struct SubmitTool<T>
241where
242 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync,
243{
244 _t: PhantomData<T>,
245}
246
247#[derive(Debug, thiserror::Error)]
248#[error("SubmitError")]
249struct SubmitError;
250
251impl<T> Tool for SubmitTool<T>
252where
253 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync,
254{
255 const NAME: &'static str = SUBMIT_TOOL_NAME;
256 type Error = SubmitError;
257 type Args = T;
258 type Output = T;
259
260 async fn definition(&self, _prompt: String) -> ToolDefinition {
261 ToolDefinition {
262 name: Self::NAME.to_string(),
263 description: "Submit the structured data you extracted from the provided text."
264 .to_string(),
265 parameters: json!(schema_for!(T)),
266 }
267 }
268
269 async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
270 Ok(data)
271 }
272}