1use std::marker::PhantomData;
32
33use schemars::{schema_for, JsonSchema};
34use serde::{Deserialize, Serialize};
35use serde_json::json;
36
37use crate::{
38 agent::{Agent, AgentBuilder},
39 completion::{CompletionModel, Prompt, PromptError, ToolDefinition},
40 tool::Tool,
41};
42
43#[derive(Debug, thiserror::Error)]
44pub enum ExtractionError {
45 #[error("No data extracted")]
46 NoData,
47
48 #[error("Failed to deserialize the extracted data: {0}")]
49 DeserializationError(#[from] serde_json::Error),
50
51 #[error("PromptError: {0}")]
52 PromptError(#[from] PromptError),
53}
54
55pub struct Extractor<M: CompletionModel, T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
57 agent: Agent<M>,
58 _t: PhantomData<T>,
59}
60
61impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel> Extractor<M, T>
62where
63 M: Sync,
64{
65 pub async fn extract(&self, text: &str) -> Result<T, ExtractionError> {
66 let summary = self.agent.prompt(text).await?;
67
68 if summary.is_empty() {
69 return Err(ExtractionError::NoData);
70 }
71
72 Ok(serde_json::from_str(&summary)?)
73 }
74}
75
76pub struct ExtractorBuilder<
78 T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static,
79 M: CompletionModel,
80> {
81 agent_builder: AgentBuilder<M>,
82 _t: PhantomData<T>,
83}
84
85impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, M: CompletionModel>
86 ExtractorBuilder<T, M>
87{
88 pub fn new(model: M) -> Self {
89 Self {
90 agent_builder: AgentBuilder::new(model)
91 .preamble("\
92 You are an AI assistant whose purpose is to extract structured data from the provided text.\n\
93 You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\
94 Use the `submit` function to submit the structured data.\n\
95 Be sure to fill out every field and ALWAYS CALL THE `submit` function, event with default values!!!.
96 ")
97 .tool(SubmitTool::<T> {_t: PhantomData}),
98 _t: PhantomData,
99 }
100 }
101
102 pub fn preamble(mut self, preamble: &str) -> Self {
104 self.agent_builder = self.agent_builder.append_preamble(&format!(
105 "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}"
106 ));
107 self
108 }
109
110 pub fn context(mut self, doc: &str) -> Self {
112 self.agent_builder = self.agent_builder.context(doc);
113 self
114 }
115
116 pub fn build(self) -> Extractor<M, T> {
118 Extractor {
119 agent: self.agent_builder.build(),
120 _t: PhantomData,
121 }
122 }
123}
124
125#[derive(Deserialize, Serialize)]
126struct SubmitTool<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
127 _t: PhantomData<T>,
128}
129
130#[derive(Debug, thiserror::Error)]
131#[error("SubmitError")]
132struct SubmitError;
133
134impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync> Tool for SubmitTool<T> {
135 const NAME: &'static str = "submit";
136 type Error = SubmitError;
137 type Args = T;
138 type Output = T;
139
140 async fn definition(&self, _prompt: String) -> ToolDefinition {
141 ToolDefinition {
142 name: Self::NAME.to_string(),
143 description: "Submit the structured data you extracted from the provided text."
144 .to_string(),
145 parameters: json!(schema_for!(T)),
146 }
147 }
148
149 async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> {
150 Ok(data)
151 }
152}