1use std::marker::PhantomData;
32
33use schemars::{JsonSchema, schema_for};
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36
37use crate::{
38 agent::{Agent, AgentBuilder},
39 completion::{Completion, CompletionError, CompletionModel, ToolDefinition},
40 message::{AssistantContent, Message, ToolCall, ToolFunction},
41 tool::Tool,
42};
43
/// Name of the tool the model is instructed to call to hand back extracted data.
///
/// Used both as the advertised tool name and as the filter applied to tool
/// calls found in the model's response.
const SUBMIT_TOOL_NAME: &str = "submit";
45
46#[derive(Debug, thiserror::Error)]
47pub enum ExtractionError {
48 #[error("No data extracted")]
49 NoData,
50
51 #[error("Failed to deserialize the extracted data: {0}")]
52 DeserializationError(#[from] serde_json::Error),
53
54 #[error("CompletionError: {0}")]
55 CompletionError(#[from] CompletionError),
56}
57
58pub struct Extractor<M: CompletionModel, T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
60 agent: Agent<M>,
61 _t: PhantomData<T>,
62}
63
64impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel> Extractor<M, T>
65where
66 M: Sync,
67{
68 pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
69 let response = self.agent.completion(text, vec![]).await?.send().await?;
70
71 let arguments = response
72 .choice
73 .into_iter()
74 .filter_map(|content| {
76 if let AssistantContent::ToolCall(ToolCall {
77 function: ToolFunction { arguments, name },
78 ..
79 }) = content
80 {
81 if name == SUBMIT_TOOL_NAME {
82 Some(arguments)
83 } else {
84 None
85 }
86 } else {
87 None
88 }
89 })
90 .collect::<Vec<_>>();
91
92 if arguments.len() > 1 {
93 tracing::warn!(
94 "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
95 );
96 }
97
98 let raw_data = if let Some(arg) = arguments.into_iter().next() {
99 arg
100 } else {
101 return Err(ExtractionError::NoData);
102 };
103
104 Ok(serde_json::from_value(raw_data)?)
105 }
106
107 pub async fn get_inner(&self) -> &Agent<M> {
108 &self.agent
109 }
110
111 pub async fn into_inner(self) -> Agent<M> {
112 self.agent
113 }
114}
115
116pub struct ExtractorBuilder<
118 T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
119 M: CompletionModel,
120> {
121 agent_builder: AgentBuilder<M>,
122 _t: PhantomData<T>,
123}
124
125impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, M: CompletionModel>
126 ExtractorBuilder<T, M>
127{
128 pub fn new(model: M) -> Self {
129 Self {
130 agent_builder: AgentBuilder::new(model)
131 .preamble("\
132 You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
133 You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
134 Use the `submit` function to submit the structured data.\n\
135 Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
136 ")
137 .tool(SubmitTool::<T> {_t: PhantomData}),
138
139 _t: PhantomData,
140 }
141 }
142
143 pub fn preamble(mut self, preamble: &str) -> Self {
145 self.agent_builder = self.agent_builder.append_preamble(&format!(
146 "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
147 ));
148 self
149 }
150
151 pub fn context(mut self, doc: &str) -> Self {
153 self.agent_builder = self.agent_builder.context(doc);
154 self
155 }
156
157 pub fn additional_params(mut self, params: serde_json::Value) -> Self {
158 self.agent_builder = self.agent_builder.additional_params(params);
159 self
160 }
161
162 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
164 self.agent_builder = self.agent_builder.max_tokens(max_tokens);
165 self
166 }
167
168 pub fn build(self) -> Extractor<M, T> {
170 Extractor {
171 agent: self.agent_builder.build(),
172 _t: PhantomData,
173 }
174 }
175}
176
177#[derive(Deserialize, Serialize)]
178struct SubmitTool<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
179 _t: PhantomData<T>,
180}
181
182#[derive(Debug, thiserror::Error)]
183#[error("SubmitError")]
184struct SubmitError;
185
186impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync> Tool for SubmitTool<T> {
187 const NAME: &'static str = SUBMIT_TOOL_NAME;
188 type Error = SubmitError;
189 type Args = T;
190 type Output = T;
191
192 async fn definition(&self, _prompt: String) -> ToolDefinition {
193 ToolDefinition {
194 name: Self::NAME.to_string(),
195 description: "Submit the structured data you extracted from the provided text."
196 .to_string(),
197 parameters: json!(schema_for!(T)),
198 }
199 }
200
201 async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
202 Ok(data)
203 }
204}