1pub mod anthropic;
2pub mod provider;
3
4use crate::error::Error;
5use provider::ClassificationProvider;
6use serde::de::DeserializeOwned;
7use std::marker::PhantomData;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::time::sleep;
11use tracing::{info, warn};
12
/// Tunable settings for a [`Classifier`].
///
/// The whole struct is handed by reference to the provider on every
/// `classify_raw` call, so provider implementations may read any field.
#[derive(Clone, Debug)]
pub struct ClassifierConfig {
    /// Model identifier; logged on every classification attempt.
    pub model: String,
    /// Token limit for a response.
    // NOTE(review): not read in this module — presumably consumed by the
    // provider implementation; confirm against the provider code.
    pub max_tokens: u32,
    /// Number of additional attempts made after the first one fails with a
    /// retryable error.
    pub max_retries: u32,
    /// Pause inserted between two consecutive attempts.
    pub retry_delay: Duration,
    /// Minimum acceptable value of the response's `"confidence"` field;
    /// responses reporting a lower value are rejected.
    pub confidence_threshold: f64,
}

impl Default for ClassifierConfig {
    /// Returns the stock configuration: one retry after a one-second pause,
    /// a 1024-token limit, and a confidence threshold of 0.7.
    fn default() -> Self {
        let model = String::from("claude-sonnet-4-6");
        let retry_delay = Duration::from_secs(1);

        ClassifierConfig {
            model,
            max_tokens: 1024,
            max_retries: 1,
            retry_delay,
            confidence_threshold: 0.7,
        }
    }
}
44
45#[derive(Debug)]
47pub struct ClassificationResult<T> {
48 pub value: T,
50 pub confidence: Option<f64>,
56 pub raw_json: serde_json::Value,
58}
59
60pub struct Classifier<T> {
87 provider: Arc<dyn ClassificationProvider>,
88 config: ClassifierConfig,
89 _phantom: PhantomData<T>,
90}
91
92impl<T: DeserializeOwned> Classifier<T> {
93 pub fn new(provider: Arc<dyn ClassificationProvider>, config: ClassifierConfig) -> Self {
95 Self {
96 provider,
97 config,
98 _phantom: PhantomData,
99 }
100 }
101
102 pub async fn classify(
107 &self,
108 system_prompt: &str,
109 user_prompt: &str,
110 schema: &serde_json::Value,
111 ) -> Result<ClassificationResult<T>, Error> {
112 let max_attempts = self.config.max_retries + 1;
113 let mut last_error: Option<Error> = None;
114
115 for attempt in 1..=max_attempts {
116 info!(
117 model = %self.config.model,
118 attempt,
119 max_attempts,
120 "Classifying"
121 );
122
123 match self
124 .provider
125 .classify_raw(system_prompt, user_prompt, schema, &self.config)
126 .await
127 {
128 Ok(raw_json) => {
129 let confidence = raw_json.get("confidence").and_then(|v| v.as_f64());
130
131 if let Some(conf) = confidence {
132 if conf < self.config.confidence_threshold {
133 return Err(Error::LowConfidence {
134 best_guess: raw_json,
135 confidence: conf,
136 });
137 }
138 }
139
140 let value = serde_json::from_value::<T>(raw_json.clone())
141 .map_err(|e| Error::Deserialization(e.to_string()))?;
142
143 return Ok(ClassificationResult {
144 value,
145 confidence,
146 raw_json,
147 });
148 }
149 Err(Error::Provider(msg)) if is_permanent_provider_error(&msg) => {
150 return Err(Error::Provider(msg));
152 }
153 Err(e) => {
154 warn!(attempt, error = %e, "Classification attempt failed, may retry");
155 last_error = Some(e);
156 if attempt < max_attempts {
157 sleep(self.config.retry_delay).await;
158 }
159 }
160 }
161 }
162
163 match last_error {
165 Some(Error::Timeout) => Err(Error::Timeout),
166 Some(e) => Err(e),
167 None => Err(Error::Timeout),
168 }
169 }
170}
171
/// Reports whether a provider error message names an HTTP status code that
/// retrying cannot fix (400, 401, 403, 404 or 422).
///
/// A code only counts when it appears as a complete run of digits. This
/// avoids false positives from longer numbers that merely contain one of the
/// codes, such as `"4000ms"` or `"request id 14012"`, which would otherwise
/// turn a retryable failure into a permanent one.
pub(crate) fn is_permanent_provider_error(msg: &str) -> bool {
    const PERMANENT_STATUS_CODES: [&str; 5] = ["400", "401", "403", "404", "422"];

    // Splitting on every non-digit yields the maximal digit runs (plus empty
    // strings, which never match a status code).
    msg.split(|c: char| !c.is_ascii_digit())
        .any(|digits| PERMANENT_STATUS_CODES.contains(&digits))
}
183
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde::Deserialize;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Minimal deserialization target used by most tests.
    #[derive(Debug, Deserialize)]
    struct SampleOutput {
        category: String,
    }

    /// Provider that always succeeds with the same JSON payload.
    struct ConstProvider {
        response: serde_json::Value,
    }

    #[async_trait]
    impl ClassificationProvider for ConstProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            Ok(self.response.clone())
        }
    }

    /// Provider that fails its first `fail_times` calls with a transient
    /// (HTTP 500) error, then succeeds. Every call is counted.
    struct CountingProvider {
        call_count: Arc<AtomicU32>,
        fail_times: u32,
    }

    #[async_trait]
    impl ClassificationProvider for CountingProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst) + 1;
            if count <= self.fail_times {
                Err(Error::Provider("500 internal server error".to_string()))
            } else {
                Ok(serde_json::json!({"category": "ok"}))
            }
        }
    }

    #[test]
    fn test_classifier_config_defaults() {
        let config = ClassifierConfig::default();
        assert_eq!(config.model, "claude-sonnet-4-6");
        assert_eq!(config.max_tokens, 1024);
        assert_eq!(config.max_retries, 1);
        assert_eq!(config.retry_delay, Duration::from_secs(1));
        assert_eq!(config.confidence_threshold, 0.7);
    }

    #[test]
    fn test_is_permanent_provider_error_classifies_status_codes() {
        assert!(is_permanent_provider_error("401 unauthorized"));
        assert!(!is_permanent_provider_error("500 internal server error"));
    }

    #[tokio::test]
    async fn test_classification_result_deserialization() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting"}),
        };
        let classifier = Classifier::<SampleOutput>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.0,
                ..Default::default()
            },
        );
        let schema = serde_json::json!({});
        let result = classifier
            .classify("system", "user", &schema)
            .await
            .unwrap();
        assert_eq!(result.value.category, "greeting");
        assert!(result.confidence.is_none());
    }

    #[tokio::test]
    async fn test_classification_extracts_confidence() {
        #[derive(Debug, Deserialize)]
        // The fields are only exercised through deserialization.
        #[allow(dead_code)]
        struct WithConfidence {
            category: String,
            confidence: f64,
        }

        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.9}),
        };
        let classifier = Classifier::<WithConfidence>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.5,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.confidence, Some(0.9));
    }

    #[tokio::test]
    async fn test_low_confidence_is_rejected() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.2}),
        };
        let classifier = Classifier::<SampleOutput>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.5,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await;
        match result {
            Err(Error::LowConfidence { confidence, .. }) => assert_eq!(confidence, 0.2),
            other => panic!("expected LowConfidence, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_retry_on_transient_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        // Fail once, then succeed on the single permitted retry.
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            fail_times: 1,
        };
        let config = ClassifierConfig {
            max_retries: 1,
            // Keep the test fast.
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier
            .classify("s", "u", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.value.category, "ok");
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_returns_last_error_when_retries_exhausted() {
        let call_count = Arc::new(AtomicU32::new(0));
        // Never succeeds within the allowed number of attempts.
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            fail_times: u32::MAX,
        };
        let config = ClassifierConfig {
            max_retries: 2,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(matches!(result, Err(Error::Provider(_))));
        // One initial attempt plus two retries.
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_no_retry_on_permanent_error() {
        /// Provider that always fails with a permanent (HTTP 401) error.
        struct PermanentProvider {
            call_count: Arc<AtomicU32>,
        }

        #[async_trait]
        impl ClassificationProvider for PermanentProvider {
            async fn classify_raw(
                &self,
                _system_prompt: &str,
                _user_prompt: &str,
                _schema: &serde_json::Value,
                _config: &ClassifierConfig,
            ) -> Result<serde_json::Value, Error> {
                self.call_count.fetch_add(1, Ordering::SeqCst);
                Err(Error::Provider("401 unauthorized".to_string()))
            }
        }

        let perm_count = Arc::new(AtomicU32::new(0));
        let perm_provider = PermanentProvider {
            call_count: Arc::clone(&perm_count),
        };
        let config = ClassifierConfig {
            max_retries: 3,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(perm_provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        // The permanent error must surface unchanged after a single call,
        // even though three retries were allowed.
        assert!(matches!(result, Err(Error::Provider(_))));
        assert_eq!(perm_count.load(Ordering::SeqCst), 1);
    }
}