pub mod anthropic;
pub mod provider;
use crate::error::Error;
use provider::ClassificationProvider;
use serde::de::DeserializeOwned;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{info, warn};
/// Tunable settings for a [`Classifier`].
///
/// A reference to this struct is handed to the provider on every call, so
/// providers can read any of these fields as well.
#[derive(Debug, Clone)]
pub struct ClassifierConfig {
    /// Identifier of the model to use; it is also logged on each attempt.
    pub model: String,
    /// Not read in this module; presumably an upper bound on the response
    /// length that the provider honours (TODO confirm against the provider).
    pub max_tokens: u32,
    /// Number of additional attempts made after the first one fails with a
    /// retryable error.
    pub max_retries: u32,
    /// Pause inserted between a failed attempt and the next one.
    pub retry_delay: Duration,
    /// Responses whose top-level `"confidence"` value is below this threshold
    /// are rejected as low-confidence.
    pub confidence_threshold: f64,
}
impl Default for ClassifierConfig {
    /// Returns the stock configuration: the `claude-sonnet-4-6` model, a
    /// 1024-token limit, one retry after a one-second pause, and a 0.7
    /// confidence threshold.
    fn default() -> Self {
        let model = String::from("claude-sonnet-4-6");
        let retry_delay = Duration::from_secs(1);
        Self {
            model,
            max_tokens: 1024,
            max_retries: 1,
            retry_delay,
            confidence_threshold: 0.7,
        }
    }
}
/// Outcome of a successful classification.
#[derive(Debug)]
pub struct ClassificationResult<T> {
    /// The provider's response deserialized into the caller's output type.
    pub value: T,
    /// The numeric top-level `"confidence"` field of the response, when the
    /// response carries one.
    pub confidence: Option<f64>,
    /// The JSON document returned by the provider, kept alongside `value`.
    pub raw_json: serde_json::Value,
}
/// Classifies prompts through a [`ClassificationProvider`] and deserializes
/// each response into `T`.
pub struct Classifier<T> {
    /// Backend that performs the actual classification request.
    provider: Arc<dyn ClassificationProvider>,
    /// Retry, threshold and model settings applied to every request.
    config: ClassifierConfig,
    /// Ties the classifier to its output type without storing a `T`.
    _phantom: PhantomData<T>,
}
impl<T: DeserializeOwned> Classifier<T> {
pub fn new(provider: Arc<dyn ClassificationProvider>, config: ClassifierConfig) -> Self {
Self {
provider,
config,
_phantom: PhantomData,
}
}
pub async fn classify(
&self,
system_prompt: &str,
user_prompt: &str,
schema: &serde_json::Value,
) -> Result<ClassificationResult<T>, Error> {
let max_attempts = self.config.max_retries + 1;
let mut last_error: Option<Error> = None;
for attempt in 1..=max_attempts {
info!(
model = %self.config.model,
attempt,
max_attempts,
"Classifying"
);
match self
.provider
.classify_raw(system_prompt, user_prompt, schema, &self.config)
.await
{
Ok(raw_json) => {
let confidence = raw_json.get("confidence").and_then(|v| v.as_f64());
if let Some(conf) = confidence {
if conf < self.config.confidence_threshold {
return Err(Error::LowConfidence {
best_guess: raw_json,
confidence: conf,
});
}
}
let value = serde_json::from_value::<T>(raw_json.clone())
.map_err(|e| Error::Deserialization(e.to_string()))?;
return Ok(ClassificationResult {
value,
confidence,
raw_json,
});
}
Err(Error::Provider(msg)) if is_permanent_provider_error(&msg) => {
return Err(Error::Provider(msg));
}
Err(e) => {
warn!(attempt, error = %e, "Classification attempt failed, may retry");
last_error = Some(e);
if attempt < max_attempts {
sleep(self.config.retry_delay).await;
}
}
}
}
match last_error {
Some(Error::Timeout) => Err(Error::Timeout),
Some(e) => Err(e),
None => Err(Error::Timeout),
}
}
}
/// Reports whether a provider error message denotes a failure that retrying
/// cannot fix.
///
/// The message is considered permanent when it contains one of the status
/// codes 400, 401, 403, 404 or 422 as a number of its own, i.e. not directly
/// preceded or followed by another ASCII digit. Digit runs that merely
/// contain a code, such as `4000` or `14010`, are therefore not mistaken for
/// a status.
pub(crate) fn is_permanent_provider_error(msg: &str) -> bool {
    const PERMANENT_STATUS_CODES: [&str; 5] = ["400", "401", "403", "404", "422"];

    let bytes = msg.as_bytes();
    PERMANENT_STATUS_CODES.iter().any(|&code| {
        // `match_indices` skips overlapping occurrences, but any skipped one
        // starts inside a previous all-digit match and so would be rejected
        // by the digit-boundary check anyway.
        msg.match_indices(code).any(|(start, matched)| {
            let end = start + matched.len();
            let digit_before = start > 0 && bytes[start - 1].is_ascii_digit();
            let digit_after = matches!(bytes.get(end), Some(b) if b.is_ascii_digit());
            !digit_before && !digit_after
        })
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde::Deserialize;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_classifier_config_defaults() {
        let config = ClassifierConfig::default();
        assert_eq!(config.model, "claude-sonnet-4-6");
        assert_eq!(config.max_tokens, 1024);
        assert_eq!(config.max_retries, 1);
        assert_eq!(config.retry_delay, Duration::from_secs(1));
        assert_eq!(config.confidence_threshold, 0.7);
    }

    /// Minimal output type used by most tests.
    #[derive(Debug, Deserialize)]
    struct SampleOutput {
        category: String,
    }

    /// Provider that always succeeds with the same canned response.
    struct ConstProvider {
        response: serde_json::Value,
    }

    #[async_trait]
    impl ClassificationProvider for ConstProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            Ok(self.response.clone())
        }
    }

    #[tokio::test]
    async fn test_classification_result_deserialization() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting"}),
        };
        let classifier = Classifier::<SampleOutput>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.0,
                ..Default::default()
            },
        );
        let schema = serde_json::json!({});
        let result = classifier
            .classify("system", "user", &schema)
            .await
            .unwrap();
        assert_eq!(result.value.category, "greeting");
        assert!(result.confidence.is_none());
    }

    #[tokio::test]
    async fn test_classification_extracts_confidence() {
        // The fields are only populated through deserialization and never
        // read back, hence the `dead_code` allowance.
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct WithConfidence {
            category: String,
            confidence: f64,
        }

        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.9}),
        };
        let classifier = Classifier::<WithConfidence>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.5,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.confidence, Some(0.9));
    }

    #[tokio::test]
    async fn test_low_confidence_is_rejected() {
        let provider = ConstProvider {
            response: serde_json::json!({"category": "greeting", "confidence": 0.3}),
        };
        let classifier = Classifier::<SampleOutput>::new(
            Arc::new(provider),
            ClassifierConfig {
                confidence_threshold: 0.7,
                ..Default::default()
            },
        );
        let result = classifier
            .classify("system", "user", &serde_json::json!({}))
            .await;
        assert!(matches!(
            result,
            Err(Error::LowConfidence { confidence, .. }) if confidence == 0.3
        ));
    }

    /// Provider that fails with a transient (500) error for the first
    /// `fail_times` calls and succeeds afterwards, counting every call.
    struct CountingProvider {
        call_count: Arc<AtomicU32>,
        fail_times: u32,
    }

    #[async_trait]
    impl ClassificationProvider for CountingProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst) + 1;
            if count <= self.fail_times {
                Err(Error::Provider("500 internal server error".to_string()))
            } else {
                Ok(serde_json::json!({"category": "ok"}))
            }
        }
    }

    #[tokio::test]
    async fn test_retry_on_transient_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            // Fail once, then succeed on the retry.
            fail_times: 1,
        };
        let config = ClassifierConfig {
            max_retries: 1,
            // Keep the test fast.
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier
            .classify("s", "u", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.value.category, "ok");
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_gives_up_after_max_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = CountingProvider {
            call_count: Arc::clone(&call_count),
            // More failures than there are attempts.
            fail_times: 10,
        };
        let config = ClassifierConfig {
            max_retries: 2,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(matches!(result, Err(Error::Provider(_))));
        // One initial attempt plus two retries.
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_no_retry_on_permanent_error() {
        /// Provider that always fails with a permanent (401) error.
        struct PermanentProvider {
            call_count: Arc<AtomicU32>,
        }

        #[async_trait]
        impl ClassificationProvider for PermanentProvider {
            async fn classify_raw(
                &self,
                _system_prompt: &str,
                _user_prompt: &str,
                _schema: &serde_json::Value,
                _config: &ClassifierConfig,
            ) -> Result<serde_json::Value, Error> {
                self.call_count.fetch_add(1, Ordering::SeqCst);
                Err(Error::Provider("401 unauthorized".to_string()))
            }
        }

        let perm_count = Arc::new(AtomicU32::new(0));
        let perm_provider = PermanentProvider {
            call_count: Arc::clone(&perm_count),
        };
        let config = ClassifierConfig {
            max_retries: 3,
            retry_delay: Duration::from_millis(1),
            confidence_threshold: 0.0,
            ..Default::default()
        };
        let classifier = Classifier::<SampleOutput>::new(Arc::new(perm_provider), config);
        let result = classifier.classify("s", "u", &serde_json::json!({})).await;
        assert!(matches!(result, Err(Error::Provider(_))));
        // Retries are configured, yet the provider must be called only once.
        assert_eq!(perm_count.load(Ordering::SeqCst), 1);
    }
}