1use std::marker::PhantomData;
32
33use schemars::{JsonSchema, schema_for};
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36
37use crate::{
38 agent::{Agent, AgentBuilder},
39 completion::{Completion, CompletionError, CompletionModel, ToolDefinition},
40 message::{AssistantContent, Message, ToolCall, ToolFunction},
41 tool::Tool,
42};
43
/// Name of the tool the model is asked to call in order to hand back the extracted data.
///
/// The same constant is used as the tool's registered name and as the name that
/// tool calls in the model response are compared against.
const SUBMIT_TOOL_NAME: &str = "submit";
45
46#[derive(Debug, thiserror::Error)]
47pub enum ExtractionError {
48 #[error("No data extracted")]
49 NoData,
50
51 #[error("Failed to deserialize the extracted data: {0}")]
52 DeserializationError(#[from] serde_json::Error),
53
54 #[error("CompletionError: {0}")]
55 CompletionError(#[from] CompletionError),
56}
57
58pub struct Extractor<M: CompletionModel, T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
60 agent: Agent<M>,
61 _t: PhantomData<T>,
62}
63
64impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel> Extractor<M, T>
65where
66 M: Sync,
67{
68 pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
69 let response = self.agent.completion(text, vec![]).await?.send().await?;
70
71 if !response.choice.iter().any(|x| {
72 let AssistantContent::ToolCall(ToolCall {
73 function: ToolFunction { name, .. },
74 ..
75 }) = x
76 else {
77 return false;
78 };
79
80 name == SUBMIT_TOOL_NAME
81 }) {
82 tracing::warn!(
83 "The submit tool was not called. If this happens more than once, please ensure the model you are using is powerful enough to reliably call tools."
84 );
85 }
86
87 let arguments = response
88 .choice
89 .into_iter()
90 .filter_map(|content| {
92 if let AssistantContent::ToolCall(ToolCall {
93 function: ToolFunction { arguments, name },
94 ..
95 }) = content
96 {
97 if name == SUBMIT_TOOL_NAME {
98 Some(arguments)
99 } else {
100 None
101 }
102 } else {
103 None
104 }
105 })
106 .collect::<Vec<_>>();
107
108 if arguments.len() > 1 {
109 tracing::warn!(
110 "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
111 );
112 }
113
114 let raw_data = if let Some(arg) = arguments.into_iter().next() {
115 arg
116 } else {
117 return Err(ExtractionError::NoData);
118 };
119
120 Ok(serde_json::from_value(raw_data)?)
121 }
122
123 pub async fn get_inner(&self) -> &Agent<M> {
124 &self.agent
125 }
126
127 pub async fn into_inner(self) -> Agent<M> {
128 self.agent
129 }
130}
131
132pub struct ExtractorBuilder<
134 T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
135 M: CompletionModel,
136> {
137 agent_builder: AgentBuilder<M>,
138 _t: PhantomData<T>,
139}
140
141impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, M: CompletionModel>
142 ExtractorBuilder<T, M>
143{
144 pub fn new(model: M) -> Self {
145 Self {
146 agent_builder: AgentBuilder::new(model)
147 .preamble("\
148 You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
149 You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
150 Use the `submit` function to submit the structured data.\n\
151 Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
152 ")
153 .tool(SubmitTool::<T> {_t: PhantomData}),
154
155 _t: PhantomData,
156 }
157 }
158
159 pub fn preamble(mut self, preamble: &str) -> Self {
161 self.agent_builder = self.agent_builder.append_preamble(&format!(
162 "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
163 ));
164 self
165 }
166
167 pub fn context(mut self, doc: &str) -> Self {
169 self.agent_builder = self.agent_builder.context(doc);
170 self
171 }
172
173 pub fn additional_params(mut self, params: serde_json::Value) -> Self {
174 self.agent_builder = self.agent_builder.additional_params(params);
175 self
176 }
177
178 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
180 self.agent_builder = self.agent_builder.max_tokens(max_tokens);
181 self
182 }
183
184 pub fn build(self) -> Extractor<M, T> {
186 Extractor {
187 agent: self.agent_builder.build(),
188 _t: PhantomData,
189 }
190 }
191}
192
193#[derive(Deserialize, Serialize)]
194struct SubmitTool<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
195 _t: PhantomData<T>,
196}
197
198#[derive(Debug, thiserror::Error)]
199#[error("SubmitError")]
200struct SubmitError;
201
202impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync> Tool for SubmitTool<T> {
203 const NAME: &'static str = SUBMIT_TOOL_NAME;
204 type Error = SubmitError;
205 type Args = T;
206 type Output = T;
207
208 async fn definition(&self, _prompt: String) -> ToolDefinition {
209 ToolDefinition {
210 name: Self::NAME.to_string(),
211 description: "Submit the structured data you extracted from the provided text."
212 .to_string(),
213 parameters: json!(schema_for!(T)),
214 }
215 }
216
217 async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
218 Ok(data)
219 }
220}