Skip to main content

ferro_ai/classifier/
mod.rs

1#[cfg(feature = "llm")]
2pub mod anthropic;
3pub mod provider;
4
5use crate::error::Error;
6use provider::ClassificationProvider;
7use serde::de::DeserializeOwned;
8use std::marker::PhantomData;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::time::sleep;
12use tracing::{info, warn};
13
/// Configuration for the AI classifier.
///
/// All fields have sensible defaults via [`Default`].
#[derive(Debug, Clone)]
pub struct ClassifierConfig {
    /// Anthropic model ID (or equivalent for other providers).
    ///
    /// An empty string means "use the client's default model", resolved at call
    /// time. This is the value produced by [`Default`].
    pub model: String,
    /// Maximum tokens in the response.
    pub max_tokens: u32,
    /// Number of retry attempts on transient errors (total attempts = max_retries + 1).
    ///
    /// Errors that are not retryable are returned after the first attempt,
    /// regardless of this value.
    pub max_retries: u32,
    /// Fixed delay awaited between retry attempts (the same delay every time; no backoff).
    pub retry_delay: Duration,
    /// Minimum confidence score required to return a successful result.
    ///
    /// If the response includes a numeric `"confidence"` field below this threshold,
    /// an [`Error::LowConfidence`] is returned. Set to `0.0` to disable.
    pub confidence_threshold: f64,
}
33
34impl Default for ClassifierConfig {
35    fn default() -> Self {
36        Self {
37            model: String::new(), // resolved from client.default_model() at call time (D-03)
38            max_tokens: 1024,
39            max_retries: 1,
40            retry_delay: Duration::from_secs(1),
41            confidence_threshold: 0.7,
42        }
43    }
44}
45
46/// The result of a successful classification.
47#[derive(Debug)]
48pub struct ClassificationResult<T> {
49    /// The deserialized output value.
50    pub value: T,
51    /// Confidence score if the provider included one in the response.
52    ///
53    /// The schema must include a `"confidence"` field of type `f64` for this
54    /// to be populated; the Anthropic API does not return metadata outside the
55    /// schema.
56    pub confidence: Option<f64>,
57    /// Raw JSON returned by the provider, useful for prompt improvement feedback.
58    pub raw_json: serde_json::Value,
59}
60
61/// Generic AI classification facade.
62///
63/// `T` is the output type. It must implement [`serde::de::DeserializeOwned`] so
64/// the raw JSON from the provider can be deserialized into it.
65///
66/// # Example
67///
68/// ```rust,ignore
69/// use ferro_ai::{Classifier, ClassifierConfig, AnthropicProvider};
70/// use serde::Deserialize;
71///
72/// #[derive(Deserialize)]
73/// struct Intent { category: String, confidence: f64 }
74///
75/// async fn classify_message(text: &str) -> ferro_ai::Error {
76///     let provider = AnthropicProvider::from_env().unwrap();
77///     let classifier = Classifier::<Intent>::new(
78///         std::sync::Arc::new(provider),
79///         ClassifierConfig::default(),
80///     );
81///     let schema = serde_json::json!({ /* ... */ });
82///     let result = classifier.classify("You classify intents.", text, &schema).await?;
83///     println!("category: {}", result.value.category);
84///     Ok(())
85/// }
86/// ```
87pub struct Classifier<T> {
88    provider: Arc<dyn ClassificationProvider>,
89    config: ClassifierConfig,
90    _phantom: PhantomData<T>,
91}
92
93impl<T: DeserializeOwned> Classifier<T> {
94    /// Create a new classifier with the given provider and configuration.
95    pub fn new(provider: Arc<dyn ClassificationProvider>, config: ClassifierConfig) -> Self {
96        Self {
97            provider,
98            config,
99            _phantom: PhantomData,
100        }
101    }
102
103    /// Classify using the given prompts and JSON schema.
104    ///
105    /// Retries on transient errors up to `config.max_retries` additional times.
106    /// Fails immediately on permanent errors (auth, bad request, schema mismatch).
107    pub async fn classify(
108        &self,
109        system_prompt: &str,
110        user_prompt: &str,
111        schema: &serde_json::Value,
112    ) -> Result<ClassificationResult<T>, Error> {
113        let max_attempts = self.config.max_retries + 1;
114        let mut last_error: Option<Error> = None;
115
116        for attempt in 1..=max_attempts {
117            info!(
118                model = %self.config.model,
119                attempt,
120                max_attempts,
121                "Classifying"
122            );
123
124            match self
125                .provider
126                .classify_raw(system_prompt, user_prompt, schema, &self.config)
127                .await
128            {
129                Ok(raw_json) => {
130                    let confidence = raw_json.get("confidence").and_then(|v| v.as_f64());
131
132                    if let Some(conf) = confidence {
133                        if conf < self.config.confidence_threshold {
134                            return Err(Error::LowConfidence {
135                                best_guess: raw_json,
136                                confidence: conf,
137                            });
138                        }
139                    }
140
141                    let value = serde_json::from_value::<T>(raw_json.clone())
142                        .map_err(|e| Error::Deserialization(e.to_string()))?;
143
144                    return Ok(ClassificationResult {
145                        value,
146                        confidence,
147                        raw_json,
148                    });
149                }
150                Err(e) if !e.is_retryable() => {
151                    // Do not retry permanent errors
152                    return Err(e);
153                }
154                Err(e) => {
155                    warn!(attempt, error = %e, "Classification attempt failed, may retry");
156                    last_error = Some(e);
157                    if attempt < max_attempts {
158                        sleep(self.config.retry_delay).await;
159                    }
160                }
161            }
162        }
163
164        // All attempts exhausted
165        match last_error {
166            Some(Error::Timeout) => Err(Error::Timeout),
167            Some(e) => Err(e),
168            None => Err(Error::Timeout),
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde::Deserialize;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[test]
    fn test_classifier_config_defaults() {
        let config = ClassifierConfig::default();
        // model is empty string after D-03: resolved from client.default_model() at call time
        assert!(config.model.is_empty());
        assert_eq!(config.max_tokens, 1024);
        assert_eq!(config.max_retries, 1);
        assert_eq!(config.retry_delay, Duration::from_secs(1));
        assert_eq!(config.confidence_threshold, 0.7);
    }

    #[derive(Debug, Deserialize)]
    struct SampleOutput {
        category: String,
    }

    /// Provider that returns the same JSON on every call.
    struct ConstProvider {
        response: serde_json::Value,
    }

    #[async_trait]
    impl ClassificationProvider for ConstProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            Ok(self.response.clone())
        }
    }

    #[tokio::test]
    async fn test_classification_result_deserialization() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting"}),
        };
        let classifier = Classifier::<SampleOutput>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.0,
                ..Default::default()
            },
        );
        let schema = serde_json::json!({});
        let result = classifier
            .classify("system", "user", &schema)
            .await
            .unwrap();
        assert_eq!(result.value.category, "greeting");
        assert!(result.confidence.is_none());
    }

    #[tokio::test]
    async fn test_classification_extracts_confidence() {
        // Fields are only deserialized, never read; the test checks `result.confidence`.
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct WithConfidence {
            category: String,
            confidence: f64,
        }

        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.9}),
        };
        let classifier = Classifier::<WithConfidence>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.5,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.confidence, Some(0.9));
    }

    #[tokio::test]
    async fn test_low_confidence_is_rejected() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.3}),
        };
        // Default threshold is 0.7, so 0.3 must be rejected.
        let classifier =
            Classifier::<SampleOutput>::new(Arc::new(provider), ClassifierConfig::default());
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await;
        match result {
            Err(Error::LowConfidence { confidence, .. }) => assert_eq!(confidence, 0.3),
            other => panic!("expected LowConfidence, got {:?}", other.err()),
        }
    }

    /// Provider that fails with a retryable 500 for the first `fail_times` calls,
    /// then returns `{"category": "ok"}`.
    struct CountingProvider {
        call_count: Arc<AtomicU32>,
        fail_times: u32,
    }

    #[async_trait]
    impl ClassificationProvider for CountingProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst) + 1;
            if count <= self.fail_times {
                Err(Error::Provider {
                    status: Some(500),
                    message: "internal server error".into(),
                })
            } else {
                Ok(serde_json::json!({"category": "ok"}))
            }
        }
    }

    #[tokio::test]
    async fn test_retry_on_transient_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            fail_times: 1, // fail once, succeed on second attempt
        };
        let config = ClassifierConfig {
            max_retries: 1,
            retry_delay: Duration::from_millis(1), // fast for tests
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier
            .classify("s", "u", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.value.category, "ok");
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_returns_last_error_after_exhausting_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            fail_times: 10, // more failures than attempts allowed
        };
        let config = ClassifierConfig {
            max_retries: 2,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(matches!(
            result,
            Err(Error::Provider {
                status: Some(500),
                ..
            })
        ));
        // max_retries + 1 attempts in total.
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_no_retry_on_permanent_error() {
        /// Provider that always fails with a non-retryable 401.
        struct PermanentProvider {
            call_count: Arc<AtomicU32>,
        }

        #[async_trait]
        impl ClassificationProvider for PermanentProvider {
            async fn classify_raw(
                &self,
                _system_prompt: &str,
                _user_prompt: &str,
                _schema: &serde_json::Value,
                _config: &ClassifierConfig,
            ) -> Result<serde_json::Value, Error> {
                self.call_count.fetch_add(1, Ordering::SeqCst);
                Err(Error::Provider {
                    status: Some(401),
                    message: "unauthorized".into(),
                })
            }
        }

        let perm_count = Arc::new(AtomicU32::new(0));
        let perm_provider = PermanentProvider {
            call_count: Arc::clone(&perm_count),
        };
        let config = ClassifierConfig {
            max_retries: 3,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(perm_provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(result.is_err());
        // Must not retry on permanent error — only 1 call
        assert_eq!(perm_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_deserialization_error_is_not_retried() {
        // The field is never read: deserializing `{"category": "ok"}` into this must fail.
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct NeedsScore {
            score: u32,
        }

        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            fail_times: 0, // always succeeds at the provider level
        };
        let config = ClassifierConfig {
            max_retries: 3,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<NeedsScore>::new(Arc::new(provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(matches!(result, Err(Error::Deserialization(_))));
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
}