pub mod keyword;
pub mod refusal;
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::{Map, Value};
use crate::canonical::{ChatRequest, ChatResponse};
use crate::config::Config;
/// Per-classifier context handed to every `classify` call.
///
/// The registries in this module fill `settings` from the classifier's config entry, or
/// leave it empty when the classifier has no entry.
#[derive(Debug, Clone, Default)]
pub struct ClassifierContext {
    /// Classifier-specific settings object; presumably each classifier interprets its own
    /// keys (TODO confirm against the classifier implementations).
    pub settings: Map<String, Value>,
}
/// A classifier that derives tags from an incoming chat request alone.
#[async_trait]
pub trait Classifier: Send + Sync {
    /// Identifier used to look up this classifier's entry in `config.classifiers`.
    fn id(&self) -> &'static str;

    /// Returns the tags that apply to `req`.
    ///
    /// `ctx.settings` carries this classifier's configured settings. Returning `Err` is
    /// non-fatal: `ClassifierRegistry::classify` logs it and moves on to the next classifier.
    async fn classify(
        &self,
        ctx: &ClassifierContext,
        req: &ChatRequest,
    ) -> anyhow::Result<Vec<String>>;
}
struct ClassifierEntry {
classifier: Arc<dyn Classifier>,
enabled: bool,
settings: Map<String, Value>,
}
pub struct ClassifierRegistry {
entries: Vec<ClassifierEntry>,
}
impl ClassifierRegistry {
pub fn from_config(config: &Config) -> Self {
let classifiers: Vec<Arc<dyn Classifier>> = vec![Arc::new(keyword::KeywordClassifier)];
let entries = classifiers
.into_iter()
.map(|classifier| {
let cfg = config.classifiers.get(classifier.id());
ClassifierEntry {
enabled: cfg.is_some_and(|c| c.enabled),
settings: cfg.map(|c| c.settings.clone()).unwrap_or_default(),
classifier,
}
})
.collect();
ClassifierRegistry { entries }
}
pub async fn classify(&self, req: &ChatRequest) -> Vec<String> {
let mut tags: Vec<String> = Vec::new();
for entry in &self.entries {
if !entry.enabled {
continue;
}
let ctx = ClassifierContext {
settings: entry.settings.clone(),
};
match entry.classifier.classify(&ctx, req).await {
Ok(new_tags) => {
for tag in new_tags {
if !tags.contains(&tag) {
tags.push(tag);
}
}
}
Err(err) => {
tracing::warn!("classifier '{}' failed: {err}", entry.classifier.id());
}
}
}
tags
}
}
/// A classifier that tags a completed exchange by inspecting both the request and the
/// response produced for it.
#[async_trait]
pub trait ResponseClassifier: Send + Sync {
    /// Identifier used to look up this classifier's entry in `config.response_classifiers`.
    fn id(&self) -> &'static str;
    /// Returns the tags that apply to the `req`/`resp` pair.
    ///
    /// `ctx.settings` carries this classifier's configured settings. Returning `Err` is
    /// non-fatal: `ResponseClassifierRegistry::classify` logs it and continues with the
    /// remaining classifiers.
    async fn classify(
        &self,
        ctx: &ClassifierContext,
        req: &ChatRequest,
        resp: &ChatResponse,
    ) -> anyhow::Result<Vec<String>>;
}
struct ResponseClassifierEntry {
classifier: Arc<dyn ResponseClassifier>,
enabled: bool,
settings: Map<String, Value>,
}
pub struct ResponseClassifierRegistry {
entries: Vec<ResponseClassifierEntry>,
}
impl ResponseClassifierRegistry {
pub fn from_config(config: &Config) -> Self {
let classifiers: Vec<Arc<dyn ResponseClassifier>> = vec![Arc::new(refusal::RefusalClassifier)];
let entries = classifiers
.into_iter()
.map(|classifier| {
let cfg = config.response_classifiers.get(classifier.id());
ResponseClassifierEntry {
enabled: cfg.is_some_and(|c| c.enabled),
settings: cfg.map(|c| c.settings.clone()).unwrap_or_default(),
classifier,
}
})
.collect();
ResponseClassifierRegistry { entries }
}
pub async fn classify(&self, req: &ChatRequest, resp: &ChatResponse) -> Vec<String> {
let mut tags: Vec<String> = Vec::new();
for entry in &self.entries {
if !entry.enabled {
continue;
}
let ctx = ClassifierContext {
settings: entry.settings.clone(),
};
match entry.classifier.classify(&ctx, req, resp).await {
Ok(new_tags) => {
for tag in new_tags {
if !tags.contains(&tag) {
tags.push(tag);
}
}
}
Err(err) => {
tracing::warn!("response classifier '{}' failed: {err}", entry.classifier.id());
}
}
}
tags
}
}