1#[cfg(feature = "llm")]
2pub mod anthropic;
3pub mod provider;
4
5use crate::error::Error;
6use provider::ClassificationProvider;
7use serde::de::DeserializeOwned;
8use std::marker::PhantomData;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::time::sleep;
12use tracing::{info, warn};
13
/// Tunable settings shared by a [`Classifier`] and the provider it drives.
///
/// The whole struct is handed to the provider on every request, so providers
/// may read any of these fields.
#[derive(Clone, Debug)]
pub struct ClassifierConfig {
    /// Model identifier. It is logged on every classification attempt.
    pub model: String,
    /// Token limit for a response. Not read in this module; presumably enforced
    /// by the provider implementation (confirm there).
    pub max_tokens: u32,
    /// Number of additional attempts allowed after the first one fails with a
    /// retryable error.
    pub max_retries: u32,
    /// Fixed pause between two consecutive attempts.
    pub retry_delay: Duration,
    /// Minimum acceptable value of the response's `"confidence"` field.
    /// Responses reporting a lower confidence are rejected; responses without
    /// a numeric confidence are not checked against it.
    pub confidence_threshold: f64,
}
33
34impl Default for ClassifierConfig {
35 fn default() -> Self {
36 Self {
37 model: String::new(), max_tokens: 1024,
39 max_retries: 1,
40 retry_delay: Duration::from_secs(1),
41 confidence_threshold: 0.7,
42 }
43 }
44}
45
46#[derive(Debug)]
48pub struct ClassificationResult<T> {
49 pub value: T,
51 pub confidence: Option<f64>,
57 pub raw_json: serde_json::Value,
59}
60
61pub struct Classifier<T> {
88 provider: Arc<dyn ClassificationProvider>,
89 config: ClassifierConfig,
90 _phantom: PhantomData<T>,
91}
92
93impl<T: DeserializeOwned> Classifier<T> {
94 pub fn new(provider: Arc<dyn ClassificationProvider>, config: ClassifierConfig) -> Self {
96 Self {
97 provider,
98 config,
99 _phantom: PhantomData,
100 }
101 }
102
103 pub async fn classify(
108 &self,
109 system_prompt: &str,
110 user_prompt: &str,
111 schema: &serde_json::Value,
112 ) -> Result<ClassificationResult<T>, Error> {
113 let max_attempts = self.config.max_retries + 1;
114 let mut last_error: Option<Error> = None;
115
116 for attempt in 1..=max_attempts {
117 info!(
118 model = %self.config.model,
119 attempt,
120 max_attempts,
121 "Classifying"
122 );
123
124 match self
125 .provider
126 .classify_raw(system_prompt, user_prompt, schema, &self.config)
127 .await
128 {
129 Ok(raw_json) => {
130 let confidence = raw_json.get("confidence").and_then(|v| v.as_f64());
131
132 if let Some(conf) = confidence {
133 if conf < self.config.confidence_threshold {
134 return Err(Error::LowConfidence {
135 best_guess: raw_json,
136 confidence: conf,
137 });
138 }
139 }
140
141 let value = serde_json::from_value::<T>(raw_json.clone())
142 .map_err(|e| Error::Deserialization(e.to_string()))?;
143
144 return Ok(ClassificationResult {
145 value,
146 confidence,
147 raw_json,
148 });
149 }
150 Err(e) if !e.is_retryable() => {
151 return Err(e);
153 }
154 Err(e) => {
155 warn!(attempt, error = %e, "Classification attempt failed, may retry");
156 last_error = Some(e);
157 if attempt < max_attempts {
158 sleep(self.config.retry_delay).await;
159 }
160 }
161 }
162 }
163
164 match last_error {
166 Some(Error::Timeout) => Err(Error::Timeout),
167 Some(e) => Err(e),
168 None => Err(Error::Timeout),
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde::Deserialize;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Minimal deserialization target used by most tests.
    #[derive(Debug, Deserialize)]
    struct SampleOutput {
        category: String,
    }

    /// Provider that always succeeds with a fixed JSON payload.
    struct ConstProvider {
        response: serde_json::Value,
    }

    #[async_trait]
    impl ClassificationProvider for ConstProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            Ok(self.response.clone())
        }
    }

    /// Provider that fails with a status-500 error for its first `fail_times`
    /// calls and succeeds afterwards, counting every call it receives.
    struct CountingProvider {
        call_count: Arc<AtomicU32>,
        fail_times: u32,
    }

    #[async_trait]
    impl ClassificationProvider for CountingProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst) + 1;
            if count <= self.fail_times {
                Err(Error::Provider {
                    status: Some(500),
                    message: "internal server error".into(),
                })
            } else {
                Ok(serde_json::json!({"category": "ok"}))
            }
        }
    }

    /// Provider that always fails with a status-401 error, counting every call
    /// it receives.
    struct PermanentProvider {
        call_count: Arc<AtomicU32>,
    }

    #[async_trait]
    impl ClassificationProvider for PermanentProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Err(Error::Provider {
                status: Some(401),
                message: "unauthorized".into(),
            })
        }
    }

    // ---------------------------------------------------------------------
    // Configuration
    // ---------------------------------------------------------------------

    #[test]
    fn test_classifier_config_defaults() {
        let config = ClassifierConfig::default();
        assert!(config.model.is_empty());
        assert_eq!(config.max_tokens, 1024);
        assert_eq!(config.max_retries, 1);
        assert_eq!(config.retry_delay, Duration::from_secs(1));
        assert_eq!(config.confidence_threshold, 0.7);
    }

    // ---------------------------------------------------------------------
    // Successful classification
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_classification_result_deserialization() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting"}),
        };
        let classifier = Classifier::<SampleOutput>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.0,
                ..Default::default()
            },
        );
        let schema = serde_json::json!({});
        let result = classifier
            .classify("system", "user", &schema)
            .await
            .unwrap();
        assert_eq!(result.value.category, "greeting");
        // The payload has no "confidence" field, so none is reported.
        assert!(result.confidence.is_none());
    }

    #[tokio::test]
    async fn test_classification_extracts_confidence() {
        // The fields exist only to drive deserialization and are never read.
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct WithConfidence {
            category: String,
            confidence: f64,
        }

        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.9}),
        };
        let classifier = Classifier::<WithConfidence>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.5,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.confidence, Some(0.9));
    }

    // ---------------------------------------------------------------------
    // Confidence threshold
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_low_confidence_is_rejected() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.2}),
        };
        let classifier = Classifier::<SampleOutput>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.5,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await;
        match result {
            Err(Error::LowConfidence {
                best_guess,
                confidence,
            }) => {
                assert_eq!(confidence, 0.2);
                // The raw response is handed back as the best guess.
                assert_eq!(best_guess["category"], "greeting");
            }
            _ => panic!("expected Error::LowConfidence"),
        }
    }

    // ---------------------------------------------------------------------
    // Retry behaviour
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_retry_on_transient_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            fail_times: 1,
        };
        let config = ClassifierConfig {
            max_retries: 1,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier
            .classify("s", "u", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.value.category, "ok");
        // One failed attempt plus the successful retry.
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retries_exhausted_returns_last_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            // More failures than there are attempts, so no call ever succeeds.
            fail_times: 10,
        };
        let config = ClassifierConfig {
            max_retries: 2,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(matches!(
            result,
            Err(Error::Provider {
                status: Some(500),
                ..
            })
        ));
        // The first attempt plus two retries.
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_no_retry_on_permanent_error() {
        let perm_count = Arc::new(AtomicU32::new(0));
        let perm_provider = PermanentProvider {
            call_count: Arc::clone(&perm_count),
        };
        let config = ClassifierConfig {
            max_retries: 3,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(perm_provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(result.is_err());
        // Retries were available, but the permanent error stops after one call.
        assert_eq!(perm_count.load(Ordering::SeqCst), 1);
    }
}