1use std::marker::PhantomData;
32
33use schemars::{schema_for, JsonSchema};
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36
37use crate::{
38 agent::{Agent, AgentBuilder},
39 completion::{Completion, CompletionError, CompletionModel, ToolDefinition},
40 message::{AssistantContent, Message, ToolCall, ToolFunction},
41 tool::Tool,
42};
43
/// Name of the tool the extraction agent must call to hand back its structured result.
///
/// Used both when registering the tool and when picking tool calls out of the model's response.
const SUBMIT_TOOL_NAME: &str = "submit";
45
46#[derive(Debug, thiserror::Error)]
47pub enum ExtractionError {
48 #[error("No data extracted")]
49 NoData,
50
51 #[error("Failed to deserialize the extracted data: {0}")]
52 DeserializationError(#[from] serde_json::Error),
53
54 #[error("CompletionError: {0}")]
55 CompletionError(#[from] CompletionError),
56}
57
58pub struct Extractor<M: CompletionModel, T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
60 agent: Agent<M>,
61 _t: PhantomData<T>,
62}
63
64impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel> Extractor<M, T>
65where
66 M: Sync,
67{
68 pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
69 let response = self.agent.completion(text, vec![]).await?.send().await?;
70
71 let arguments = response
72 .choice
73 .into_iter()
74 .filter_map(|content| {
76 if let AssistantContent::ToolCall(ToolCall {
77 function: ToolFunction { arguments, name },
78 ..
79 }) = content
80 {
81 if name == SUBMIT_TOOL_NAME {
82 Some(arguments)
83 } else {
84 None
85 }
86 } else {
87 None
88 }
89 })
90 .collect::<Vec<_>>();
91
92 if arguments.len() > 1 {
93 tracing::warn!(
94 "Multiple submit calls detected, using the last one. Providers / agents should only ensure one submit call."
95 );
96 }
97
98 let raw_data = if let Some(arg) = arguments.into_iter().next() {
99 arg
100 } else {
101 return Err(ExtractionError::NoData);
102 };
103
104 Ok(serde_json::from_value(raw_data)?)
105 }
106}
107
108pub struct ExtractorBuilder<
110 T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
111 M: CompletionModel,
112> {
113 agent_builder: AgentBuilder<M>,
114 _t: PhantomData<T>,
115}
116
117impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, M: CompletionModel>
118 ExtractorBuilder<T, M>
119{
120 pub fn new(model: M) -> Self {
121 Self {
122 agent_builder: AgentBuilder::new(model)
123 .preamble("\
124 You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
125 You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
126 Use the `submit` function to submit the structured data.\n\
127 Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
128 ")
129 .tool(SubmitTool::<T> {_t: PhantomData}),
130
131 _t: PhantomData,
132 }
133 }
134
135 pub fn preamble(mut self, preamble: &str) -> Self {
137 self.agent_builder = self.agent_builder.append_preamble(&format!(
138 "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
139 ));
140 self
141 }
142
143 pub fn context(mut self, doc: &str) -> Self {
145 self.agent_builder = self.agent_builder.context(doc);
146 self
147 }
148
149 pub fn additional_params(mut self, params: serde_json::Value) -> Self {
150 self.agent_builder = self.agent_builder.additional_params(params);
151 self
152 }
153
154 pub fn build(self) -> Extractor<M, T> {
156 Extractor {
157 agent: self.agent_builder.build(),
158 _t: PhantomData,
159 }
160 }
161}
162
163#[derive(Deserialize, Serialize)]
164struct SubmitTool<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
165 _t: PhantomData<T>,
166}
167
168#[derive(Debug, thiserror::Error)]
169#[error("SubmitError")]
170struct SubmitError;
171
172impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync> Tool for SubmitTool<T> {
173 const NAME: &'static str = SUBMIT_TOOL_NAME;
174 type Error = SubmitError;
175 type Args = T;
176 type Output = T;
177
178 async fn definition(&self, _prompt: String) -> ToolDefinition {
179 ToolDefinition {
180 name: Self::NAME.to_string(),
181 description: "Submit the structured data you extracted from the provided text."
182 .to_string(),
183 parameters: json!(schema_for!(T)),
184 }
185 }
186
187 async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
188 Ok(data)
189 }
190}