1use std::marker::PhantomData;
2
3use crate::{
4 agent::Agent,
5 chat::Completion,
6 task::TaskError,
7 tool::{StructureTool, ToolError},
8};
9use async_trait::async_trait;
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12
13pub struct Extractor<M>
14where
15 M: Completion,
16{
17 agent: Agent<M>,
18}
19
20impl<M> Extractor<M>
21where
22 M: Completion,
23{
24 #[inline]
26 pub async fn new<T>(model: M) -> Self
27 where
28 T: Serialize + for<'a> Deserialize<'a> + JsonSchema + Send + Sync + 'static,
29 {
30 Self {
31 agent: Agent::new("extract-agent", model)
32 .preamble(
33 r#"Extract the data structure from the input string.
34Note you MUST use the tool named `extractor` to extract the input string to the
35data structure.
36"#,
37 )
38 .tool(ExtractTool::<T> { _data: PhantomData })
39 .await,
40 }
41 }
42
43 #[inline]
45 pub async fn extract<T>(&self, input: &str) -> Result<T, ExtractionError>
46 where
47 T: Serialize + for<'a> Deserialize<'a> + JsonSchema + Send + Sync + 'static,
48 {
49 Ok(serde_json::from_str(&self.agent.prompt(input).await?)?)
50 }
51}
52
53#[derive(Debug, thiserror::Error)]
54#[error("Extraction error")]
55pub enum ExtractionError {
56 #[error("TaskError: {0}")]
57 TaskError(#[from] TaskError),
58 #[error("JsonError: {0}")]
59 JsonError(#[from] serde_json::Error),
60}
61
/// Zero-sized tool marker whose accepted and returned type is `T`.
///
/// The trait bounds live on its `StructureTool` impl rather than here, so the type
/// itself can be named and constructed without them.
struct ExtractTool<T> {
    _data: PhantomData<T>,
}
68
69#[async_trait]
70impl<T> StructureTool for ExtractTool<T>
71where
72 T: Serialize + for<'a> Deserialize<'a> + JsonSchema + Send + Sync,
73{
74 type Input = T;
75 type Output = T;
76
77 fn name(&self) -> &str {
78 "extractor"
79 }
80
81 fn description(&self) -> &str {
82 "Extract the data structure from the input string."
83 }
84
85 async fn run_with_args(&self, input: Self::Input) -> Result<Self::Output, ToolError> {
86 Ok(input)
87 }
88}