1use std::marker::PhantomData;
32
33use schemars::{JsonSchema, schema_for};
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36
37use crate::{
38 agent::{Agent, AgentBuilder, AgentBuilderSimple},
39 completion::{Completion, CompletionError, CompletionModel, ToolDefinition},
40 message::{AssistantContent, Message, ToolCall, ToolChoice, ToolFunction},
41 tool::Tool,
42 wasm_compat::{WasmCompatSend, WasmCompatSync},
43};
44
/// Name under which the submit tool is exposed to the model. Tool calls found in a completion
/// response are compared against this name to locate the extracted payload.
const SUBMIT_TOOL_NAME: &str = "submit";
46
47#[derive(Debug, thiserror::Error)]
48pub enum ExtractionError {
49 #[error("No data extracted")]
50 NoData,
51
52 #[error("Failed to deserialize the extracted data: {0}")]
53 DeserializationError(#[from] serde_json::Error),
54
55 #[error("CompletionError: {0}")]
56 CompletionError(#[from] CompletionError),
57}
58
59pub struct Extractor<M, T>
61where
62 M: CompletionModel,
63 T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
64{
65 agent: Agent<M>,
66 _t: PhantomData<T>,
67 retries: u64,
68}
69
70impl<M, T> Extractor<M, T>
71where
72 M: CompletionModel,
73 T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync,
74{
75 pub async fn extract(
82 &self,
83 text: impl Into<Message> + WasmCompatSend,
84 ) -> Result<T, ExtractionError> {
85 let mut last_error = None;
86 let text_message = text.into();
87
88 for i in 0..=self.retries {
89 tracing::debug!(
90 "Attempting to extract JSON. Retries left: {retries}",
91 retries = self.retries - i
92 );
93 let attempt_text = text_message.clone();
94 match self.extract_json(attempt_text, vec![]).await {
95 Ok(data) => return Ok(data),
96 Err(e) => {
97 tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
98 last_error = Some(e);
99 }
100 }
101 }
102
103 Err(last_error.unwrap_or(ExtractionError::NoData))
105 }
106
107 pub async fn extract_with_chat_history(
114 &self,
115 text: impl Into<Message> + WasmCompatSend,
116 chat_history: Vec<Message>,
117 ) -> Result<T, ExtractionError> {
118 let mut last_error = None;
119 let text_message = text.into();
120
121 for i in 0..=self.retries {
122 tracing::debug!(
123 "Attempting to extract JSON. Retries left: {retries}",
124 retries = self.retries - i
125 );
126 let attempt_text = text_message.clone();
127 match self.extract_json(attempt_text, chat_history.clone()).await {
128 Ok(data) => return Ok(data),
129 Err(e) => {
130 tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying...");
131 last_error = Some(e);
132 }
133 }
134 }
135
136 Err(last_error.unwrap_or(ExtractionError::NoData))
138 }
139
140 async fn extract_json(
141 &self,
142 text: impl Into<Message> + WasmCompatSend,
143 messages: Vec<Message>,
144 ) -> Result<T, ExtractionError> {
145 let response = self.agent.completion(text, messages).await?.send().await?;
146
147 if !response.choice.iter().any(|x| {
148 let AssistantContent::ToolCall(ToolCall {
149 function: ToolFunction { name, .. },
150 ..
151 }) = x
152 else {
153 return false;
154 };
155
156 name == SUBMIT_TOOL_NAME
157 }) {
158 tracing::warn!(
159 "The submit tool was not called. If this happens more than once, please ensure the model you are using is powerful enough to reliably call tools."
160 );
161 }
162
163 let arguments = response
164 .choice
165 .into_iter()
166 .filter_map(|content| {
168 if let AssistantContent::ToolCall(ToolCall {
169 function: ToolFunction { arguments, name },
170 ..
171 }) = content
172 {
173 if name == SUBMIT_TOOL_NAME {
174 Some(arguments)
175 } else {
176 None
177 }
178 } else {
179 None
180 }
181 })
182 .collect::<Vec<_>>();
183
184 if arguments.len() > 1 {
185 tracing::warn!(
186 "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
187 );
188 }
189
190 let raw_data = if let Some(arg) = arguments.into_iter().next() {
191 arg
192 } else {
193 return Err(ExtractionError::NoData);
194 };
195
196 Ok(serde_json::from_value(raw_data)?)
197 }
198
199 pub async fn get_inner(&self) -> &Agent<M> {
200 &self.agent
201 }
202
203 pub async fn into_inner(self) -> Agent<M> {
204 self.agent
205 }
206}
207
208pub struct ExtractorBuilder<M, T>
210where
211 M: CompletionModel,
212 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
213{
214 agent_builder: AgentBuilderSimple<M>,
215 _t: PhantomData<T>,
216 retries: Option<u64>,
217}
218
219impl<M, T> ExtractorBuilder<M, T>
220where
221 M: CompletionModel,
222 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static,
223{
224 pub fn new(model: M) -> Self {
225 Self {
226 agent_builder: AgentBuilder::new(model)
227 .preamble("\
228 You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
229 You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
230 Use the `submit` function to submit the structured data.\n\
231 Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
232 ")
233 .tool(SubmitTool::<T> {_t: PhantomData})
234 .tool_choice(ToolChoice::Required),
235 retries: None,
236 _t: PhantomData,
237 }
238 }
239
240 pub fn preamble(mut self, preamble: &str) -> Self {
242 self.agent_builder = self.agent_builder.append_preamble(&format!(
243 "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
244 ));
245 self
246 }
247
248 pub fn context(mut self, doc: &str) -> Self {
250 self.agent_builder = self.agent_builder.context(doc);
251 self
252 }
253
254 pub fn additional_params(mut self, params: serde_json::Value) -> Self {
255 self.agent_builder = self.agent_builder.additional_params(params);
256 self
257 }
258
259 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
261 self.agent_builder = self.agent_builder.max_tokens(max_tokens);
262 self
263 }
264
265 pub fn retries(mut self, retries: u64) -> Self {
267 self.retries = Some(retries);
268 self
269 }
270
271 pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
273 self.agent_builder = self.agent_builder.tool_choice(choice);
274 self
275 }
276
277 pub fn build(self) -> Extractor<M, T> {
279 Extractor {
280 agent: self.agent_builder.build(),
281 _t: PhantomData,
282 retries: self.retries.unwrap_or(0),
283 }
284 }
285}
286
287#[derive(Deserialize, Serialize)]
288struct SubmitTool<T>
289where
290 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
291{
292 _t: PhantomData<T>,
293}
294
295#[derive(Debug, thiserror::Error)]
296#[error("SubmitError")]
297struct SubmitError;
298
299impl<T> Tool for SubmitTool<T>
300where
301 T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync,
302{
303 const NAME: &'static str = SUBMIT_TOOL_NAME;
304 type Error = SubmitError;
305 type Args = T;
306 type Output = T;
307
308 async fn definition(&self, _prompt: String) -> ToolDefinition {
309 ToolDefinition {
310 name: Self::NAME.to_string(),
311 description: "Submit the structured data you extracted from the provided text."
312 .to_string(),
313 parameters: json!(schema_for!(T)),
314 }
315 }
316
317 async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
318 Ok(data)
319 }
320}