Skip to main content

ferro_ai/classifier/
mod.rs

1pub mod anthropic;
2pub mod provider;
3
4use crate::error::Error;
5use provider::ClassificationProvider;
6use serde::de::DeserializeOwned;
7use std::marker::PhantomData;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::time::sleep;
11use tracing::{info, warn};
12
/// Tunable settings for a classification run.
///
/// Every field receives a value from the [`Default`] implementation, so
/// callers usually override only the fields they care about using
/// struct-update syntax (`ClassifierConfig { max_retries: 3, ..Default::default() }`).
#[derive(Clone, Debug)]
pub struct ClassifierConfig {
    /// Model identifier: an Anthropic model ID, or the equivalent for another provider.
    pub model: String,
    /// Upper bound on the number of tokens in the response.
    pub max_tokens: u32,
    /// How many times a transient failure is retried.
    ///
    /// The provider is called at most `max_retries + 1` times in total.
    pub max_retries: u32,
    /// Pause inserted between two consecutive attempts.
    pub retry_delay: Duration,
    /// Lowest confidence score that still counts as a successful result.
    ///
    /// When the response contains a `"confidence"` field whose value is below
    /// this threshold, an [`Error::LowConfidence`] is returned. Set to `0.0`
    /// to disable.
    pub confidence_threshold: f64,
}
32
33impl Default for ClassifierConfig {
34    fn default() -> Self {
35        Self {
36            model: String::new(), // resolved from client.default_model() at call time (D-03)
37            max_tokens: 1024,
38            max_retries: 1,
39            retry_delay: Duration::from_secs(1),
40            confidence_threshold: 0.7,
41        }
42    }
43}
44
45/// The result of a successful classification.
46#[derive(Debug)]
47pub struct ClassificationResult<T> {
48    /// The deserialized output value.
49    pub value: T,
50    /// Confidence score if the provider included one in the response.
51    ///
52    /// The schema must include a `"confidence"` field of type `f64` for this
53    /// to be populated; the Anthropic API does not return metadata outside the
54    /// schema.
55    pub confidence: Option<f64>,
56    /// Raw JSON returned by the provider, useful for prompt improvement feedback.
57    pub raw_json: serde_json::Value,
58}
59
60/// Generic AI classification facade.
61///
62/// `T` is the output type. It must implement [`serde::de::DeserializeOwned`] so
63/// the raw JSON from the provider can be deserialized into it.
64///
65/// # Example
66///
67/// ```rust,ignore
68/// use ferro_ai::{Classifier, ClassifierConfig, AnthropicProvider};
69/// use serde::Deserialize;
70///
71/// #[derive(Deserialize)]
72/// struct Intent { category: String, confidence: f64 }
73///
74/// async fn classify_message(text: &str) -> ferro_ai::Error {
75///     let provider = AnthropicProvider::from_env().unwrap();
76///     let classifier = Classifier::<Intent>::new(
77///         std::sync::Arc::new(provider),
78///         ClassifierConfig::default(),
79///     );
80///     let schema = serde_json::json!({ /* ... */ });
81///     let result = classifier.classify("You classify intents.", text, &schema).await?;
82///     println!("category: {}", result.value.category);
83///     Ok(())
84/// }
85/// ```
86pub struct Classifier<T> {
87    provider: Arc<dyn ClassificationProvider>,
88    config: ClassifierConfig,
89    _phantom: PhantomData<T>,
90}
91
92impl<T: DeserializeOwned> Classifier<T> {
93    /// Create a new classifier with the given provider and configuration.
94    pub fn new(provider: Arc<dyn ClassificationProvider>, config: ClassifierConfig) -> Self {
95        Self {
96            provider,
97            config,
98            _phantom: PhantomData,
99        }
100    }
101
102    /// Classify using the given prompts and JSON schema.
103    ///
104    /// Retries on transient errors up to `config.max_retries` additional times.
105    /// Fails immediately on permanent errors (auth, bad request, schema mismatch).
106    pub async fn classify(
107        &self,
108        system_prompt: &str,
109        user_prompt: &str,
110        schema: &serde_json::Value,
111    ) -> Result<ClassificationResult<T>, Error> {
112        let max_attempts = self.config.max_retries + 1;
113        let mut last_error: Option<Error> = None;
114
115        for attempt in 1..=max_attempts {
116            info!(
117                model = %self.config.model,
118                attempt,
119                max_attempts,
120                "Classifying"
121            );
122
123            match self
124                .provider
125                .classify_raw(system_prompt, user_prompt, schema, &self.config)
126                .await
127            {
128                Ok(raw_json) => {
129                    let confidence = raw_json.get("confidence").and_then(|v| v.as_f64());
130
131                    if let Some(conf) = confidence {
132                        if conf < self.config.confidence_threshold {
133                            return Err(Error::LowConfidence {
134                                best_guess: raw_json,
135                                confidence: conf,
136                            });
137                        }
138                    }
139
140                    let value = serde_json::from_value::<T>(raw_json.clone())
141                        .map_err(|e| Error::Deserialization(e.to_string()))?;
142
143                    return Ok(ClassificationResult {
144                        value,
145                        confidence,
146                        raw_json,
147                    });
148                }
149                Err(e) if !e.is_retryable() => {
150                    // Do not retry permanent errors
151                    return Err(e);
152                }
153                Err(e) => {
154                    warn!(attempt, error = %e, "Classification attempt failed, may retry");
155                    last_error = Some(e);
156                    if attempt < max_attempts {
157                        sleep(self.config.retry_delay).await;
158                    }
159                }
160            }
161        }
162
163        // All attempts exhausted
164        match last_error {
165            Some(Error::Timeout) => Err(Error::Timeout),
166            Some(e) => Err(e),
167            None => Err(Error::Timeout),
168        }
169    }
170}
171
#[cfg(test)]
mod tests {
    //! Unit tests for [`ClassifierConfig`] defaults and the retry, confidence
    //! and deserialization behavior of [`Classifier::classify`], using
    //! in-memory stub providers.

    use super::*;
    use async_trait::async_trait;
    use serde::Deserialize;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_classifier_config_defaults() {
        let config = ClassifierConfig::default();
        // model is empty string after D-03: resolved from client.default_model() at call time
        assert!(config.model.is_empty());
        assert_eq!(config.max_tokens, 1024);
        assert_eq!(config.max_retries, 1);
        assert_eq!(config.retry_delay, Duration::from_secs(1));
        assert_eq!(config.confidence_threshold, 0.7);
    }

    /// Minimal output type used by most tests.
    #[derive(Debug, Deserialize)]
    struct SampleOutput {
        category: String,
    }

    /// Provider that always succeeds with the same canned JSON response.
    struct ConstProvider {
        response: serde_json::Value,
    }

    #[async_trait]
    impl ClassificationProvider for ConstProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            Ok(self.response.clone())
        }
    }

    #[tokio::test]
    async fn test_classification_result_deserialization() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting"}),
        };
        let classifier = Classifier::<SampleOutput>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.0,
                ..Default::default()
            },
        );
        let schema = serde_json::json!({});
        let result = classifier
            .classify("system", "user", &schema)
            .await
            .unwrap();
        assert_eq!(result.value.category, "greeting");
        assert!(result.confidence.is_none());
    }

    /// Output type that also carries the provider's confidence score.
    #[derive(Debug, Deserialize)]
    // The fields are only populated through deserialization and never read.
    #[allow(dead_code)]
    struct WithConfidence {
        category: String,
        confidence: f64,
    }

    #[tokio::test]
    async fn test_classification_extracts_confidence() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.9}),
        };
        let classifier = Classifier::<WithConfidence>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.5,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.confidence, Some(0.9));
    }

    #[tokio::test]
    async fn test_low_confidence_is_rejected() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.3}),
        };
        let classifier = Classifier::<WithConfidence>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.7,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await;
        assert!(matches!(
            result,
            Err(Error::LowConfidence { confidence, .. }) if confidence == 0.3
        ));
    }

    /// Provider that fails with a 500 for the first `fail_times` calls and
    /// succeeds afterwards, counting every call it receives.
    struct CountingProvider {
        call_count: Arc<AtomicU32>,
        fail_times: u32,
    }

    #[async_trait]
    impl ClassificationProvider for CountingProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst) + 1;
            if count <= self.fail_times {
                Err(Error::Provider {
                    status: Some(500),
                    message: "internal server error".into(),
                })
            } else {
                Ok(serde_json::json!({"category": "ok"}))
            }
        }
    }

    #[tokio::test]
    async fn test_retry_on_transient_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            fail_times: 1, // fail once, succeed on second attempt
        };
        let config = ClassifierConfig {
            max_retries: 1,
            retry_delay: Duration::from_millis(1), // fast for tests
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier
            .classify("s", "u", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.value.category, "ok");
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_returns_last_error_when_retries_exhausted() {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            fail_times: 10, // more failures than attempts allowed
        };
        let config = ClassifierConfig {
            max_retries: 2,
            retry_delay: Duration::from_millis(1), // fast for tests
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(matches!(
            result,
            Err(Error::Provider {
                status: Some(500),
                ..
            })
        ));
        // One initial attempt plus `max_retries` retries.
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_no_retry_on_permanent_error() {
        /// Provider that always fails with a 401, counting every call.
        struct PermanentProvider {
            call_count: Arc<AtomicU32>,
        }

        #[async_trait]
        impl ClassificationProvider for PermanentProvider {
            async fn classify_raw(
                &self,
                _system_prompt: &str,
                _user_prompt: &str,
                _schema: &serde_json::Value,
                _config: &ClassifierConfig,
            ) -> Result<serde_json::Value, Error> {
                self.call_count.fetch_add(1, Ordering::SeqCst);
                Err(Error::Provider {
                    status: Some(401),
                    message: "unauthorized".into(),
                })
            }
        }

        let perm_count = Arc::new(AtomicU32::new(0));
        let perm_provider = PermanentProvider {
            call_count: Arc::clone(&perm_count),
        };
        let config = ClassifierConfig {
            max_retries: 3,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(perm_provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(result.is_err());
        // Must not retry on permanent error — only 1 call
        assert_eq!(perm_count.load(Ordering::SeqCst), 1);
    }
}