use serde::{Deserialize, Serialize};
/// A single classification label together with the labeler's confidence in it.
///
/// Derives `PartialEq` so labels (and [`LabelResult`]s containing them) can be
/// compared directly; `Eq` is not derived because `confidence` is an `f32`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ContentLabel {
    /// Label name, e.g. `"safe"` or `"blocked"` as emitted by the built-in labelers.
    pub label: String,
    /// Confidence score for `label`. The built-in labelers always emit `1.0`.
    /// NOTE(review): presumably scores lie in `0.0..=1.0`, but nothing here
    /// enforces that — confirm with the producers of real model output.
    pub confidence: f32,
}

/// Outcome of classifying one piece of media with a [`MediaLabeler`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LabelResult {
    /// Labels assigned to the content.
    pub labels: Vec<ContentLabel>,
    /// `true` when the labeler decided the content should be blocked.
    pub blocked: bool,
    /// Optional explanation for the decision. The built-in labelers set this
    /// only when `blocked` is `true`, but the type itself does not enforce that
    /// pairing.
    pub reason: Option<String>,
}
#[derive(Debug, thiserror::Error)]
pub enum LabelError {
#[error("unsupported content type: {0}")]
UnsupportedType(String),
#[error("model not loaded: {0}")]
ModelNotLoaded(String),
#[error("classification failed: {0}")]
ClassificationFailed(String),
#[error("API error: {0}")]
ApiError(String),
}
/// Classifies raw media bytes into [`ContentLabel`]s plus a block/allow decision.
///
/// The `Send + Sync` supertraits let a single labeler instance be shared
/// across threads.
pub trait MediaLabeler: Sync + Send {
    /// Reports whether this labeler accepts content of the given MIME type.
    ///
    /// The trait does not itself require callers to check this before calling
    /// [`MediaLabeler::classify`].
    fn supports(&self, mime_type: &str) -> bool;

    /// Classifies `data`, whose MIME type is `mime_type`.
    ///
    /// # Errors
    ///
    /// Returns a [`LabelError`] when the labeler cannot produce a result.
    fn classify(&self, data: &[u8], mime_type: &str) -> Result<LabelResult, LabelError>;
}
/// A [`MediaLabeler`] that approves everything: every input is labelled
/// `"safe"` with confidence `1.0` and nothing is blocked.
///
/// It is a stateless unit struct, so it is cheaply `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoopLabeler;
impl MediaLabeler for NoopLabeler {
    /// Claims every MIME type.
    fn supports(&self, _mime: &str) -> bool {
        true
    }

    /// Ignores the input and reports a single `"safe"` label at full
    /// confidence; never blocks and never errors.
    fn classify(&self, _bytes: &[u8], _mime: &str) -> Result<LabelResult, LabelError> {
        let safe = ContentLabel {
            label: String::from("safe"),
            confidence: 1.0,
        };
        Ok(LabelResult {
            labels: vec![safe],
            blocked: false,
            reason: None,
        })
    }
}
/// A [`MediaLabeler`] that blocks every input, attaching a fixed reason to
/// each result.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BlockAllLabeler {
    /// Explanation copied into the `reason` of every result this labeler produces.
    reason: String,
}

impl BlockAllLabeler {
    /// Creates a labeler that blocks all content, reporting `reason` as the cause.
    pub fn new(reason: &str) -> Self {
        Self {
            reason: reason.to_owned(),
        }
    }
}
impl MediaLabeler for BlockAllLabeler {
    /// Claims every MIME type.
    fn supports(&self, _mime: &str) -> bool {
        true
    }

    /// Ignores the input and always returns a blocked result carrying a single
    /// `"blocked"` label at full confidence plus this labeler's configured
    /// reason. Never errors.
    fn classify(&self, _bytes: &[u8], _mime: &str) -> Result<LabelResult, LabelError> {
        let why = Some(self.reason.clone());
        let verdict = ContentLabel {
            label: String::from("blocked"),
            confidence: 1.0,
        };
        Ok(LabelResult {
            labels: vec![verdict],
            blocked: true,
            reason: why,
        })
    }
}
// Unit tests for the built-in labelers and the serde round-trip of the result types.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_noop_labeler_safe() {
        let labeler = NoopLabeler;
        let result = labeler.classify(b"test data", "image/png").unwrap();
        assert!(!result.blocked);
        assert!(result.reason.is_none());
        assert_eq!(result.labels.len(), 1);
        assert_eq!(result.labels[0].label, "safe");
        assert_eq!(result.labels[0].confidence, 1.0);
    }

    #[test]
    fn test_noop_supports_everything() {
        let labeler = NoopLabeler;
        assert!(labeler.supports("image/png"));
        assert!(labeler.supports("video/mp4"));
        assert!(labeler.supports("application/pdf"));
    }

    #[test]
    fn test_block_all_labeler() {
        let labeler = BlockAllLabeler::new("maintenance mode");
        let result = labeler.classify(b"data", "image/jpeg").unwrap();
        assert!(result.blocked);
        assert_eq!(result.reason, Some("maintenance mode".to_string()));
        // The single label should mark the content as blocked at full confidence.
        assert_eq!(result.labels.len(), 1);
        assert_eq!(result.labels[0].label, "blocked");
        assert_eq!(result.labels[0].confidence, 1.0);
    }

    #[test]
    fn test_block_all_supports_everything() {
        let labeler = BlockAllLabeler::new("maintenance mode");
        assert!(labeler.supports("image/png"));
        assert!(labeler.supports("video/mp4"));
        assert!(labeler.supports("application/pdf"));
    }

    #[test]
    fn test_content_label_serde() {
        let label = ContentLabel {
            label: "nsfw".to_string(),
            confidence: 0.95,
        };
        let json = serde_json::to_string(&label).unwrap();
        let parsed: ContentLabel = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.label, "nsfw");
        assert!((parsed.confidence - 0.95).abs() < f32::EPSILON);
    }

    #[test]
    fn test_label_result_serde() {
        let result = LabelResult {
            labels: vec![
                ContentLabel {
                    label: "safe".into(),
                    confidence: 0.8,
                },
                ContentLabel {
                    label: "nature".into(),
                    confidence: 0.6,
                },
            ],
            blocked: false,
            reason: None,
        };
        let json = serde_json::to_string(&result).unwrap();
        let parsed: LabelResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.labels.len(), 2);
        // Check contents and order, not just the count.
        assert_eq!(parsed.labels[0].label, "safe");
        assert_eq!(parsed.labels[1].label, "nature");
        assert!((parsed.labels[0].confidence - 0.8).abs() < f32::EPSILON);
        assert!((parsed.labels[1].confidence - 0.6).abs() < f32::EPSILON);
        assert!(!parsed.blocked);
        assert!(parsed.reason.is_none());
    }

    #[test]
    fn test_label_result_serde_with_reason() {
        let result = BlockAllLabeler::new("maintenance mode")
            .classify(b"data", "image/png")
            .unwrap();
        let json = serde_json::to_string(&result).unwrap();
        let parsed: LabelResult = serde_json::from_str(&json).unwrap();
        assert!(parsed.blocked);
        assert_eq!(parsed.reason.as_deref(), Some("maintenance mode"));
        assert_eq!(parsed.labels.len(), 1);
        assert_eq!(parsed.labels[0].label, "blocked");
    }
}