1pub mod anthropic;
2pub mod provider;
3
4use crate::error::Error;
5use provider::ClassificationProvider;
6use serde::de::DeserializeOwned;
7use std::marker::PhantomData;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::time::sleep;
11use tracing::{info, warn};
12
/// Settings that control how a [`Classifier`] talks to its provider.
///
/// The whole config is forwarded to the provider on every request, and the
/// classifier itself reads the retry and confidence settings.
#[derive(Clone, Debug)]
pub struct ClassifierConfig {
    /// Model identifier. It is logged on each attempt and forwarded to the provider.
    /// The default is an empty string.
    pub model: String,
    /// Token limit forwarded to the provider (default `1024`). How it is applied is up
    /// to the provider implementation.
    pub max_tokens: u32,
    /// Extra attempts allowed after the first one, used only for retryable errors
    /// (default `1`, i.e. two attempts total).
    pub max_retries: u32,
    /// Pause between consecutive attempts (default one second).
    pub retry_delay: Duration,
    /// Outputs whose numeric top-level `"confidence"` field is below this value are
    /// rejected (default `0.7`).
    pub confidence_threshold: f64,
}

impl Default for ClassifierConfig {
    /// Returns one retry, a one-second delay, a `1024` token limit, a `0.7` confidence
    /// threshold, and an empty model name.
    fn default() -> Self {
        let one_second = Duration::from_millis(1_000);
        ClassifierConfig {
            model: String::default(),
            max_tokens: 1_024,
            max_retries: 1,
            retry_delay: one_second,
            confidence_threshold: 0.7,
        }
    }
}
44
45#[derive(Debug)]
47pub struct ClassificationResult<T> {
48 pub value: T,
50 pub confidence: Option<f64>,
56 pub raw_json: serde_json::Value,
58}
59
60pub struct Classifier<T> {
87 provider: Arc<dyn ClassificationProvider>,
88 config: ClassifierConfig,
89 _phantom: PhantomData<T>,
90}
91
92impl<T: DeserializeOwned> Classifier<T> {
93 pub fn new(provider: Arc<dyn ClassificationProvider>, config: ClassifierConfig) -> Self {
95 Self {
96 provider,
97 config,
98 _phantom: PhantomData,
99 }
100 }
101
102 pub async fn classify(
107 &self,
108 system_prompt: &str,
109 user_prompt: &str,
110 schema: &serde_json::Value,
111 ) -> Result<ClassificationResult<T>, Error> {
112 let max_attempts = self.config.max_retries + 1;
113 let mut last_error: Option<Error> = None;
114
115 for attempt in 1..=max_attempts {
116 info!(
117 model = %self.config.model,
118 attempt,
119 max_attempts,
120 "Classifying"
121 );
122
123 match self
124 .provider
125 .classify_raw(system_prompt, user_prompt, schema, &self.config)
126 .await
127 {
128 Ok(raw_json) => {
129 let confidence = raw_json.get("confidence").and_then(|v| v.as_f64());
130
131 if let Some(conf) = confidence {
132 if conf < self.config.confidence_threshold {
133 return Err(Error::LowConfidence {
134 best_guess: raw_json,
135 confidence: conf,
136 });
137 }
138 }
139
140 let value = serde_json::from_value::<T>(raw_json.clone())
141 .map_err(|e| Error::Deserialization(e.to_string()))?;
142
143 return Ok(ClassificationResult {
144 value,
145 confidence,
146 raw_json,
147 });
148 }
149 Err(e) if !e.is_retryable() => {
150 return Err(e);
152 }
153 Err(e) => {
154 warn!(attempt, error = %e, "Classification attempt failed, may retry");
155 last_error = Some(e);
156 if attempt < max_attempts {
157 sleep(self.config.retry_delay).await;
158 }
159 }
160 }
161 }
162
163 match last_error {
165 Some(Error::Timeout) => Err(Error::Timeout),
166 Some(e) => Err(e),
167 None => Err(Error::Timeout),
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde::Deserialize;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_classifier_config_defaults() {
        let config = ClassifierConfig::default();
        assert!(config.model.is_empty());
        assert_eq!(config.max_tokens, 1024);
        assert_eq!(config.max_retries, 1);
        assert_eq!(config.retry_delay, Duration::from_secs(1));
        assert_eq!(config.confidence_threshold, 0.7);
    }

    /// Minimal output type shared by most tests.
    #[derive(Debug, Deserialize)]
    struct SampleOutput {
        category: String,
    }

    /// Provider that always returns the same JSON payload.
    struct ConstProvider {
        response: serde_json::Value,
    }

    #[async_trait]
    impl ClassificationProvider for ConstProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            Ok(self.response.clone())
        }
    }

    #[tokio::test]
    async fn test_classification_result_deserialization() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting"}),
        };
        let classifier = Classifier::<SampleOutput>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.0,
                ..Default::default()
            },
        );
        let schema = serde_json::json!({});
        let result = classifier
            .classify("system", "user", &schema)
            .await
            .unwrap();
        assert_eq!(result.value.category, "greeting");
        assert!(result.confidence.is_none());
    }

    #[tokio::test]
    async fn test_classification_extracts_confidence() {
        // The fields mirror the provider payload so deserialization succeeds. This test
        // only inspects `result.confidence`, so they are never read.
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct WithConfidence {
            category: String,
            confidence: f64,
        }

        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.9}),
        };
        let classifier = Classifier::<WithConfidence>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.5,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.confidence, Some(0.9));
    }

    #[tokio::test]
    async fn test_low_confidence_is_rejected() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.2}),
        };
        let classifier = Classifier::<SampleOutput>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.5,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await;
        match result {
            Err(Error::LowConfidence { confidence, .. }) => assert_eq!(confidence, 0.2),
            other => panic!("expected LowConfidence, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_deserialization_failure_is_reported() {
        let provider = ConstProvider {
            response: serde_json::json!({"label": "greeting"}),
        };
        let classifier = Classifier::<SampleOutput>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.0,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await;
        assert!(matches!(result, Err(Error::Deserialization(_))));
    }

    /// Provider that fails with a retryable 500 for the first `fail_times` calls,
    /// then succeeds.
    struct CountingProvider {
        call_count: Arc<AtomicU32>,
        fail_times: u32,
    }

    #[async_trait]
    impl ClassificationProvider for CountingProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst) + 1;
            if count <= self.fail_times {
                Err(Error::Provider {
                    status: Some(500),
                    message: "internal server error".into(),
                })
            } else {
                Ok(serde_json::json!({"category": "ok"}))
            }
        }
    }

    #[tokio::test]
    async fn test_retry_on_transient_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            fail_times: 1,
        };
        let config = ClassifierConfig {
            max_retries: 1,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier
            .classify("s", "u", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.value.category, "ok");
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retries_exhausted_returns_last_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            fail_times: 10,
        };
        let config = ClassifierConfig {
            max_retries: 2,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(matches!(
            result,
            Err(Error::Provider {
                status: Some(500),
                ..
            })
        ));
        // One initial attempt plus `max_retries` retries.
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_no_retry_on_permanent_error() {
        /// Provider that always fails with a non-retryable 401.
        struct PermanentProvider {
            call_count: Arc<AtomicU32>,
        }

        #[async_trait]
        impl ClassificationProvider for PermanentProvider {
            async fn classify_raw(
                &self,
                _system_prompt: &str,
                _user_prompt: &str,
                _schema: &serde_json::Value,
                _config: &ClassifierConfig,
            ) -> Result<serde_json::Value, Error> {
                self.call_count.fetch_add(1, Ordering::SeqCst);
                Err(Error::Provider {
                    status: Some(401),
                    message: "unauthorized".into(),
                })
            }
        }

        let perm_count = Arc::new(AtomicU32::new(0));
        let perm_provider = PermanentProvider {
            call_count: Arc::clone(&perm_count),
        };
        let config = ClassifierConfig {
            max_retries: 3,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(perm_provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(result.is_err());
        assert_eq!(perm_count.load(Ordering::SeqCst), 1);
    }
}