Skip to main content

zeph_llm/
extractor.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use schemars::JsonSchema;
5use serde::de::DeserializeOwned;
6
7use crate::LlmError;
8use crate::provider::{LlmProvider, Message, Role};
9
/// Sends user input, optionally preceded by a system preamble, to an LLM
/// provider and deserializes the reply into a caller-chosen type via
/// [`LlmProvider::chat_typed`].
///
/// The `P: LlmProvider` bound lives on the `impl` block rather than on the
/// struct itself, so the type can be named without repeating the bound.
pub struct Extractor<'a, P> {
    /// Provider the request is sent to.
    provider: &'a P,
    /// Optional text sent as a system message ahead of the user input.
    preamble: Option<String>,
}
14
15impl<'a, P: LlmProvider> Extractor<'a, P> {
16    pub fn new(provider: &'a P) -> Self {
17        Self {
18            provider,
19            preamble: None,
20        }
21    }
22
23    #[must_use]
24    pub fn with_preamble(mut self, preamble: impl Into<String>) -> Self {
25        self.preamble = Some(preamble.into());
26        self
27    }
28
29    /// # Errors
30    ///
31    /// Returns an error if the provider fails or the response cannot be parsed.
32    pub async fn extract<T>(&self, input: &str) -> Result<T, LlmError>
33    where
34        T: DeserializeOwned + JsonSchema + 'static,
35    {
36        let mut messages = Vec::new();
37        if let Some(ref preamble) = self.preamble {
38            messages.push(Message::from_legacy(Role::System, preamble.clone()));
39        }
40        messages.push(Message::from_legacy(Role::User, input));
41        self.provider.chat_typed::<T>(&messages).await
42    }
43}
44
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::provider::{ChatStream, LlmProvider, Message};

    /// Provider that always answers with a canned response and remembers how
    /// many messages its most recent `chat` call received.
    ///
    /// Recording the count lets tests observe what `Extractor` actually sends,
    /// instead of only checking that the canned response round-trips.
    struct StubProvider {
        response: String,
        last_message_count: AtomicUsize,
    }

    impl StubProvider {
        fn new(response: &str) -> Self {
            Self {
                response: response.into(),
                last_message_count: AtomicUsize::new(0),
            }
        }

        /// Number of messages passed to the most recent `chat` call (0 if none).
        fn last_message_count(&self) -> usize {
            self.last_message_count.load(Ordering::SeqCst)
        }
    }

    impl LlmProvider for StubProvider {
        async fn chat(&self, messages: &[Message]) -> Result<String, LlmError> {
            self.last_message_count.store(messages.len(), Ordering::SeqCst);
            Ok(self.response.clone())
        }

        async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
            // Routed through `chat` so the message count is recorded either way.
            let response = self.chat(messages).await?;
            Ok(Box::pin(tokio_stream::once(Ok(
                crate::StreamChunk::Content(response),
            ))))
        }

        fn supports_streaming(&self) -> bool {
            false
        }

        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Err(LlmError::EmbedUnsupported {
                provider: "stub".into(),
            })
        }

        fn supports_embeddings(&self) -> bool {
            false
        }

        fn name(&self) -> &'static str {
            "stub"
        }
    }

    #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
    struct TestOutput {
        value: String,
    }

    #[tokio::test]
    async fn extract_without_preamble() {
        let provider = StubProvider::new(r#"{"value": "result"}"#);
        let extractor = Extractor::new(&provider);
        let result: TestOutput = extractor.extract("test input").await.unwrap();
        assert_eq!(
            result,
            TestOutput {
                value: "result".into()
            }
        );
        assert!(
            provider.last_message_count() >= 1,
            "the user input must reach the provider"
        );
    }

    #[tokio::test]
    async fn extract_with_preamble() {
        // Baseline run without a preamble: whatever `chat_typed` adds on its
        // own is included in this count too, so the comparison below isolates
        // the effect of the preamble.
        let plain = StubProvider::new(r#"{"value": "with_preamble"}"#);
        let _: TestOutput = Extractor::new(&plain)
            .extract("test input")
            .await
            .unwrap();

        let provider = StubProvider::new(r#"{"value": "with_preamble"}"#);
        let extractor = Extractor::new(&provider).with_preamble("Analyze this");
        let result: TestOutput = extractor.extract("test input").await.unwrap();
        assert_eq!(
            result,
            TestOutput {
                value: "with_preamble".into()
            }
        );
        assert_eq!(
            provider.last_message_count(),
            plain.last_message_count() + 1,
            "the preamble must be sent as exactly one extra message"
        );
    }

    #[tokio::test]
    async fn extract_unparseable_response_is_error() {
        let provider = StubProvider::new("definitely not json");
        let extractor = Extractor::new(&provider);
        let result = extractor.extract::<TestOutput>("test input").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn extract_error_propagation() {
        struct FailProvider;

        impl LlmProvider for FailProvider {
            async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
                Err(LlmError::Unavailable)
            }

            async fn chat_stream(&self, _messages: &[Message]) -> Result<ChatStream, LlmError> {
                Err(LlmError::Unavailable)
            }

            fn supports_streaming(&self) -> bool {
                false
            }

            async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
                Err(LlmError::Unavailable)
            }

            fn supports_embeddings(&self) -> bool {
                false
            }

            fn name(&self) -> &'static str {
                "fail"
            }
        }

        let provider = FailProvider;
        let extractor = Extractor::new(&provider);
        let result = extractor.extract::<TestOutput>("test").await;
        assert!(matches!(result, Err(LlmError::Unavailable)));
    }
}