alith_core/
extractor.rs

1use std::marker::PhantomData;
2
3use crate::{
4    agent::Agent,
5    chat::Completion,
6    task::TaskError,
7    tool::{StructureTool, ToolError},
8};
9use async_trait::async_trait;
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12
13pub struct Extractor<M>
14where
15    M: Completion,
16{
17    agent: Agent<M>,
18}
19
20impl<M> Extractor<M>
21where
22    M: Completion,
23{
24    /// Constructor for Extractor that initializes the agent with the given model.
25    #[inline]
26    pub async fn new<T>(model: M) -> Self
27    where
28        T: Serialize + for<'a> Deserialize<'a> + JsonSchema + Send + Sync + 'static,
29    {
30        Self {
31            agent: Agent::new("extract-agent", model)
32                .preamble(
33                    r#"Extract the data structure from the input string.
34Note you MUST use the tool named `extractor` to extract the input string to the
35data structure.
36"#,
37                )
38                .tool(ExtractTool::<T> { _data: PhantomData })
39                .await,
40        }
41    }
42
43    /// Extract structure data from an input string.
44    #[inline]
45    pub async fn extract<T>(&self, input: &str) -> Result<T, ExtractionError>
46    where
47        T: Serialize + for<'a> Deserialize<'a> + JsonSchema + Send + Sync + 'static,
48    {
49        Ok(serde_json::from_str(&self.agent.prompt(input).await?)?)
50    }
51}
52
53#[derive(Debug, thiserror::Error)]
54#[error("Extraction error")]
55pub enum ExtractionError {
56    #[error("TaskError: {0}")]
57    TaskError(#[from] TaskError),
58    #[error("JsonError: {0}")]
59    JsonError(#[from] serde_json::Error),
60}
61
/// Tool registered on the extraction agent; `T` is the structure being extracted.
///
/// No value of `T` is stored — `PhantomData` only records the type — so the struct
/// definition carries no trait bounds. The bounds required to act as a tool live on
/// the `StructureTool` impl, which is the place that actually needs them.
struct ExtractTool<T> {
    /// Zero-sized marker tying this tool to `T`.
    _data: PhantomData<T>,
}
68
69#[async_trait]
70impl<T> StructureTool for ExtractTool<T>
71where
72    T: Serialize + for<'a> Deserialize<'a> + JsonSchema + Send + Sync,
73{
74    type Input = T;
75    type Output = T;
76
77    fn name(&self) -> &str {
78        "extractor"
79    }
80
81    fn description(&self) -> &str {
82        "Extract the data structure from the input string."
83    }
84
85    async fn run_with_args(&self, input: Self::Input) -> Result<Self::Output, ToolError> {
86        Ok(input)
87    }
88}