Skip to main content

zeph_llm/
extractor.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Typed structured extraction from free-form text.
//!
//! [`Extractor`] wraps any [`LlmProvider`] and exposes a single
//! [`extract::<T>()`](Extractor::extract) method that:
//! 1. Injects a JSON schema derived from `T` into the system prompt.
//! 2. Sends the input text as a user message.
//! 3. Parses the response as `T`, retrying once on parse failure.
//!
//! # Examples
//!
//! ```rust,no_run
//! use serde::Deserialize;
//! use schemars::JsonSchema;
//! use zeph_llm::Extractor;
//!
//! #[derive(Debug, Deserialize, JsonSchema)]
//! struct Sentiment {
//!     label: String,  // "positive" | "negative" | "neutral"
//!     score: f32,
//! }
//!
//! # async fn run(provider: &impl zeph_llm::provider::LlmProvider) -> Result<(), zeph_llm::LlmError> {
//! let sentiment: Sentiment = Extractor::new(provider)
//!     .with_preamble("Classify the sentiment of the following text.")
//!     .extract("I love Rust!").await?;
//! println!("{:?}", sentiment);
//! # Ok(())
//! # }
//! ```
33
34use schemars::JsonSchema;
35use serde::de::DeserializeOwned;
36
37use crate::LlmError;
38use crate::provider::{LlmProvider, Message, Role};
39
40/// Structured data extractor built on top of any [`LlmProvider`].
41///
42/// See the [module documentation](self) for usage examples.
43pub struct Extractor<'a, P: LlmProvider> {
44    provider: &'a P,
45    preamble: Option<String>,
46}
47
48impl<'a, P: LlmProvider> Extractor<'a, P> {
49    /// Create a new extractor borrowing `provider`.
50    pub fn new(provider: &'a P) -> Self {
51        Self {
52            provider,
53            preamble: None,
54        }
55    }
56
57    /// Set an optional system-level preamble that guides the model on what to extract.
58    #[must_use]
59    pub fn with_preamble(mut self, preamble: impl Into<String>) -> Self {
60        self.preamble = Some(preamble.into());
61        self
62    }
63
64    /// Extract structured data of type `T` from free-form `input` text.
65    ///
66    /// The JSON schema for `T` is injected into the prompt automatically. On a parse failure
67    /// the call is retried once with the raw response appended for self-correction.
68    ///
69    /// # Errors
70    ///
71    /// Returns an error if the provider fails or the response cannot be parsed after the retry.
72    pub async fn extract<T>(&self, input: &str) -> Result<T, LlmError>
73    where
74        T: DeserializeOwned + JsonSchema + 'static,
75    {
76        let mut messages = Vec::new();
77        if let Some(ref preamble) = self.preamble {
78            messages.push(Message::from_legacy(Role::System, preamble.clone()));
79        }
80        messages.push(Message::from_legacy(Role::User, input));
81        self.provider.chat_typed::<T>(&messages).await
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{ChatStream, LlmProvider, Message};

    /// Provider double that answers every chat request with a fixed string.
    struct StubProvider {
        response: String,
    }

    impl StubProvider {
        /// Create a stub whose `chat` always returns `body`.
        fn replying(body: &str) -> Self {
            Self {
                response: body.into(),
            }
        }
    }

    impl LlmProvider for StubProvider {
        async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
            Ok(self.response.clone())
        }

        async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
            // Emit the canned chat response as a single content chunk.
            let text = self.chat(messages).await?;
            let chunk = crate::StreamChunk::Content(text);
            Ok(Box::pin(tokio_stream::once(Ok(chunk))))
        }

        fn supports_streaming(&self) -> bool {
            false
        }

        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Err(LlmError::EmbedUnsupported {
                provider: "stub".into(),
            })
        }

        fn supports_embeddings(&self) -> bool {
            false
        }

        fn name(&self) -> &'static str {
            "stub"
        }
    }

    /// Provider double whose fallible operations all report `LlmError::Unavailable`.
    struct FailProvider;

    impl LlmProvider for FailProvider {
        async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
            Err(LlmError::Unavailable)
        }

        async fn chat_stream(&self, _messages: &[Message]) -> Result<ChatStream, LlmError> {
            Err(LlmError::Unavailable)
        }

        fn supports_streaming(&self) -> bool {
            false
        }

        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Err(LlmError::Unavailable)
        }

        fn supports_embeddings(&self) -> bool {
            false
        }

        fn name(&self) -> &'static str {
            "fail"
        }
    }

    /// Minimal extraction target used by the tests below.
    #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
    struct TestOutput {
        value: String,
    }

    #[tokio::test]
    async fn extract_without_preamble() {
        let provider = StubProvider::replying(r#"{"value": "result"}"#);
        let expected = TestOutput {
            value: "result".into(),
        };

        let extractor = Extractor::new(&provider);
        let actual: TestOutput = extractor.extract("test input").await.unwrap();

        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn extract_with_preamble() {
        let provider = StubProvider::replying(r#"{"value": "with_preamble"}"#);
        let expected = TestOutput {
            value: "with_preamble".into(),
        };

        let extractor = Extractor::new(&provider).with_preamble("Analyze this");
        let actual: TestOutput = extractor.extract("test input").await.unwrap();

        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn extract_error_propagation() {
        let provider = FailProvider;
        let extractor = Extractor::new(&provider);

        let outcome = extractor.extract::<TestOutput>("test").await;

        // The provider's error must surface unchanged through the extractor.
        assert!(matches!(outcome, Err(LlmError::Unavailable)));
    }
}
195}