ferro_ai/classifier/mod.rs

1pub mod anthropic;
2pub mod provider;
3
4use crate::error::Error;
5use provider::ClassificationProvider;
6use serde::de::DeserializeOwned;
7use std::marker::PhantomData;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::time::sleep;
11use tracing::{info, warn};
12
/// Tuning knobs for a [`Classifier`].
///
/// Every field has a reasonable value in the [`Default`] implementation, so
/// callers typically override only what they need via struct-update syntax.
#[derive(Debug, Clone)]
pub struct ClassifierConfig {
    /// Model identifier passed to the provider (an Anthropic model ID, or the
    /// equivalent for other providers).
    pub model: String,
    /// Upper bound on the number of tokens the provider may generate.
    pub max_tokens: u32,
    /// How many extra attempts are made after a transient failure; the total
    /// number of attempts is therefore `max_retries + 1`.
    pub max_retries: u32,
    /// Pause inserted between consecutive attempts.
    pub retry_delay: Duration,
    /// Lowest acceptable confidence for a successful classification.
    ///
    /// When the provider's JSON carries a `"confidence"` number smaller than
    /// this value, classification fails with [`Error::LowConfidence`].
    /// Use `0.0` to effectively turn the check off.
    pub confidence_threshold: f64,
}
32
33impl Default for ClassifierConfig {
34    fn default() -> Self {
35        Self {
36            model: "claude-sonnet-4-6".to_string(),
37            max_tokens: 1024,
38            max_retries: 1,
39            retry_delay: Duration::from_secs(1),
40            confidence_threshold: 0.7,
41        }
42    }
43}
44
45/// The result of a successful classification.
46#[derive(Debug)]
47pub struct ClassificationResult<T> {
48    /// The deserialized output value.
49    pub value: T,
50    /// Confidence score if the provider included one in the response.
51    ///
52    /// The schema must include a `"confidence"` field of type `f64` for this
53    /// to be populated; the Anthropic API does not return metadata outside the
54    /// schema.
55    pub confidence: Option<f64>,
56    /// Raw JSON returned by the provider, useful for prompt improvement feedback.
57    pub raw_json: serde_json::Value,
58}
59
60/// Generic AI classification facade.
61///
62/// `T` is the output type. It must implement [`serde::de::DeserializeOwned`] so
63/// the raw JSON from the provider can be deserialized into it.
64///
65/// # Example
66///
67/// ```rust,ignore
68/// use ferro_ai::{Classifier, ClassifierConfig, AnthropicProvider};
69/// use serde::Deserialize;
70///
71/// #[derive(Deserialize)]
72/// struct Intent { category: String, confidence: f64 }
73///
74/// async fn classify_message(text: &str) -> ferro_ai::Error {
75///     let provider = AnthropicProvider::from_env().unwrap();
76///     let classifier = Classifier::<Intent>::new(
77///         std::sync::Arc::new(provider),
78///         ClassifierConfig::default(),
79///     );
80///     let schema = serde_json::json!({ /* ... */ });
81///     let result = classifier.classify("You classify intents.", text, &schema).await?;
82///     println!("category: {}", result.value.category);
83///     Ok(())
84/// }
85/// ```
86pub struct Classifier<T> {
87    provider: Arc<dyn ClassificationProvider>,
88    config: ClassifierConfig,
89    _phantom: PhantomData<T>,
90}
91
92impl<T: DeserializeOwned> Classifier<T> {
93    /// Create a new classifier with the given provider and configuration.
94    pub fn new(provider: Arc<dyn ClassificationProvider>, config: ClassifierConfig) -> Self {
95        Self {
96            provider,
97            config,
98            _phantom: PhantomData,
99        }
100    }
101
102    /// Classify using the given prompts and JSON schema.
103    ///
104    /// Retries on transient errors up to `config.max_retries` additional times.
105    /// Fails immediately on permanent errors (auth, bad request, schema mismatch).
106    pub async fn classify(
107        &self,
108        system_prompt: &str,
109        user_prompt: &str,
110        schema: &serde_json::Value,
111    ) -> Result<ClassificationResult<T>, Error> {
112        let max_attempts = self.config.max_retries + 1;
113        let mut last_error: Option<Error> = None;
114
115        for attempt in 1..=max_attempts {
116            info!(
117                model = %self.config.model,
118                attempt,
119                max_attempts,
120                "Classifying"
121            );
122
123            match self
124                .provider
125                .classify_raw(system_prompt, user_prompt, schema, &self.config)
126                .await
127            {
128                Ok(raw_json) => {
129                    let confidence = raw_json.get("confidence").and_then(|v| v.as_f64());
130
131                    if let Some(conf) = confidence {
132                        if conf < self.config.confidence_threshold {
133                            return Err(Error::LowConfidence {
134                                best_guess: raw_json,
135                                confidence: conf,
136                            });
137                        }
138                    }
139
140                    let value = serde_json::from_value::<T>(raw_json.clone())
141                        .map_err(|e| Error::Deserialization(e.to_string()))?;
142
143                    return Ok(ClassificationResult {
144                        value,
145                        confidence,
146                        raw_json,
147                    });
148                }
149                Err(Error::Provider(msg)) if is_permanent_provider_error(&msg) => {
150                    // Do not retry permanent errors
151                    return Err(Error::Provider(msg));
152                }
153                Err(e) => {
154                    warn!(attempt, error = %e, "Classification attempt failed, may retry");
155                    last_error = Some(e);
156                    if attempt < max_attempts {
157                        sleep(self.config.retry_delay).await;
158                    }
159                }
160            }
161        }
162
163        // All attempts exhausted
164        match last_error {
165            Some(Error::Timeout) => Err(Error::Timeout),
166            Some(e) => Err(e),
167            None => Err(Error::Timeout),
168        }
169    }
170}
171
/// Returns true if the provider error message indicates a permanent failure
/// that should not be retried.
///
/// Permanent HTTP status codes: 400, 401, 403, 404, 422.
///
/// A code only counts when it stands alone as a number, i.e. it is not
/// directly preceded or followed by another ASCII digit. Without this check,
/// request ids, byte counts or delays such as `"retry after 4000ms"` inside a
/// transient error's message would wrongly suppress retries.
pub(crate) fn is_permanent_provider_error(msg: &str) -> bool {
    const PERMANENT_STATUS_CODES: [&str; 5] = ["400", "401", "403", "404", "422"];

    PERMANENT_STATUS_CODES
        .iter()
        .any(|code| contains_standalone_number(msg, code))
}

/// Returns true if `needle` (a run of ASCII digits) occurs in `haystack`
/// without an ASCII digit immediately before or after it.
fn contains_standalone_number(haystack: &str, needle: &str) -> bool {
    let bytes = haystack.as_bytes();
    haystack.match_indices(needle).any(|(start, _)| {
        let end = start + needle.len();
        let digit_before = matches!(
            start.checked_sub(1).and_then(|i| bytes.get(i)),
            Some(b) if b.is_ascii_digit()
        );
        let digit_after = matches!(bytes.get(end), Some(b) if b.is_ascii_digit());
        !digit_before && !digit_after
    })
}
183
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde::Deserialize;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_classifier_config_defaults() {
        let config = ClassifierConfig::default();
        assert_eq!(config.model, "claude-sonnet-4-6");
        assert_eq!(config.max_tokens, 1024);
        assert_eq!(config.max_retries, 1);
        assert_eq!(config.retry_delay, Duration::from_secs(1));
        assert_eq!(config.confidence_threshold, 0.7);
    }

    #[derive(Debug, Deserialize)]
    struct SampleOutput {
        category: String,
    }

    /// Provider that always answers with the same JSON value.
    struct ConstProvider {
        response: serde_json::Value,
    }

    #[async_trait]
    impl ClassificationProvider for ConstProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            Ok(self.response.clone())
        }
    }

    #[tokio::test]
    async fn test_classification_result_deserialization() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting"}),
        };
        let classifier = Classifier::<SampleOutput>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.0,
                ..Default::default()
            },
        );
        let schema = serde_json::json!({});
        let result = classifier
            .classify("system", "user", &schema)
            .await
            .unwrap();
        assert_eq!(result.value.category, "greeting");
        assert!(result.confidence.is_none());
    }

    #[tokio::test]
    async fn test_classification_extracts_confidence() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)] // fields exist only to exercise deserialization
        struct WithConfidence {
            category: String,
            confidence: f64,
        }

        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.9}),
        };
        let classifier = Classifier::<WithConfidence>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.5,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.confidence, Some(0.9));
    }

    #[tokio::test]
    async fn test_low_confidence_is_rejected() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.2}),
        };
        let classifier = Classifier::<SampleOutput>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.5,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await;
        match result {
            Err(Error::LowConfidence {
                best_guess,
                confidence,
            }) => {
                assert_eq!(confidence, 0.2);
                assert_eq!(best_guess["category"], "greeting");
            }
            other => panic!("expected LowConfidence, got {:?}", other),
        }
    }

    /// Provider that fails with a transient 500 for the first `fail_times`
    /// calls and succeeds afterwards, counting every call.
    struct CountingProvider {
        call_count: Arc<AtomicU32>,
        fail_times: u32,
    }

    #[async_trait]
    impl ClassificationProvider for CountingProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst) + 1;
            if count <= self.fail_times {
                Err(Error::Provider("500 internal server error".to_string()))
            } else {
                Ok(serde_json::json!({"category": "ok"}))
            }
        }
    }

    #[tokio::test]
    async fn test_retry_on_transient_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            fail_times: 1, // fail once, succeed on second attempt
        };
        let config = ClassifierConfig {
            max_retries: 1,
            retry_delay: Duration::from_millis(1), // fast for tests
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier
            .classify("s", "u", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.value.category, "ok");
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_returns_last_error_after_retries_exhausted() {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            fail_times: u32::MAX, // never succeeds
        };
        let config = ClassifierConfig {
            max_retries: 2,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(matches!(result, Err(Error::Provider(ref msg)) if msg.starts_with("500")));
        // One initial attempt plus two retries.
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    /// Provider that always fails with a permanent 401, counting every call.
    struct PermanentProvider {
        call_count: Arc<AtomicU32>,
    }

    #[async_trait]
    impl ClassificationProvider for PermanentProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Err(Error::Provider("401 unauthorized".to_string()))
        }
    }

    #[tokio::test]
    async fn test_no_retry_on_permanent_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = PermanentProvider {
            call_count: Arc::clone(&call_count),
        };
        let config = ClassifierConfig {
            max_retries: 3,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(matches!(result, Err(Error::Provider(_))));
        // Must not retry on permanent error — only 1 call
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
}