use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, OnceLock};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::debertav2::{
Config as DebertaConfig, DebertaV2SeqClassificationModel,
};
use tokenizers::Tokenizer;
use crate::error::LlmError;
use super::candle::validate_safetensors;
use super::{
CHUNK_OVERLAP_TOKENS, ClassificationResult, ClassifierBackend, MAX_CHUNK_CONTENT_TOKENS,
};
struct ThreeClassInner {
model: DebertaV2SeqClassificationModel,
tokenizer: Tokenizer,
device: Device,
id2label: Vec<String>,
cls_token_id: u32,
sep_token_id: u32,
}
#[derive(Clone)]
pub struct CandleThreeClassClassifier {
repo_id: Arc<str>,
hf_token: Option<Arc<str>>,
sha256: Option<Arc<str>>,
inner: Arc<OnceLock<Arc<ThreeClassInner>>>,
}
impl std::fmt::Debug for CandleThreeClassClassifier {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CandleThreeClassClassifier")
.field("repo_id", &self.repo_id)
.finish_non_exhaustive()
}
}
impl CandleThreeClassClassifier {
#[must_use]
pub fn new(repo_id: impl Into<Arc<str>>) -> Self {
Self {
repo_id: repo_id.into(),
hf_token: None,
sha256: None,
inner: Arc::new(OnceLock::new()),
}
}
fn get_or_try_load(&self) -> Result<Arc<ThreeClassInner>, String> {
if let Some(inner) = self.inner.get() {
return Ok(Arc::clone(inner));
}
match Self::load_inner(
&self.repo_id,
self.hf_token.as_deref(),
self.sha256.as_deref(),
) {
Ok(inner) => {
let arc = Arc::new(inner);
let _ = self.inner.set(Arc::clone(&arc));
Ok(self.inner.get().map_or(arc, Arc::clone))
}
Err(e) => Err(e.to_string()),
}
}
#[must_use]
pub fn with_hf_token(mut self, token: impl Into<Arc<str>>) -> Self {
self.hf_token = Some(token.into());
self
}
#[must_use]
pub fn with_sha256(mut self, digest: impl Into<Arc<str>>) -> Self {
self.sha256 = Some(digest.into());
self
}
pub fn preload(&self) {
if let Err(e) = self.get_or_try_load() {
tracing::error!(
repo_id = %self.repo_id,
error = %e,
"three-class classifier failed to preload — will retry on first classify call"
);
}
}
fn run_chunk(
inner: &ThreeClassInner,
input_ids: &[u32],
) -> Result<ClassificationResult, LlmError> {
let seq_len = input_ids.len();
let ids_tensor = Tensor::new(input_ids, &inner.device)?.unsqueeze(0)?;
let token_type_ids = Tensor::zeros((1, seq_len), DType::I64, &inner.device)?;
let attention_mask = Tensor::ones((1, seq_len), DType::I64, &inner.device)?;
let logits =
inner
.model
.forward(&ids_tensor, Some(token_type_ids), Some(attention_mask))?;
let probs = candle_nn::ops::softmax(&logits.squeeze(0)?, 0)?;
let probs_vec = probs.to_vec1::<f32>().map_err(LlmError::Candle)?;
let (best_idx, best_score) = probs_vec
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map_or((0, 0.0), |(i, &s)| (i, s));
let label = inner
.id2label
.get(best_idx)
.cloned()
.unwrap_or_else(|| best_idx.to_string());
let is_positive = label.to_lowercase().contains("misaligned");
Ok(ClassificationResult {
label,
score: best_score,
is_positive,
spans: vec![],
})
}
fn classify_sync(
inner: &ThreeClassInner,
text: &str,
) -> Result<ClassificationResult, LlmError> {
let encoding = inner
.tokenizer
.encode(text, true)
.map_err(|e| LlmError::Inference(format!("tokenizer encode failed: {e}")))?;
let ids = encoding.get_ids();
if ids.is_empty() {
let label = inner
.id2label
.iter()
.find(|l| l.to_lowercase().contains("no"))
.cloned()
.unwrap_or_else(|| "no_instruction".into());
return Ok(ClassificationResult {
label,
score: 1.0,
is_positive: false,
spans: vec![],
});
}
let ids = if ids.len() >= 2
&& ids[0] == inner.cls_token_id
&& ids[ids.len() - 1] == inner.sep_token_id
{
&ids[1..ids.len() - 1]
} else {
ids
};
let mut best_positive: Option<ClassificationResult> = None;
let mut best_overall: Option<ClassificationResult> = None;
let mut start = 0usize;
while start < ids.len() {
let end = (start + MAX_CHUNK_CONTENT_TOKENS).min(ids.len());
let content = &ids[start..end];
let mut framed = Vec::with_capacity(content.len() + 2);
framed.push(inner.cls_token_id);
framed.extend_from_slice(content);
framed.push(inner.sep_token_id);
let result = Self::run_chunk(inner, &framed)?;
if result.is_positive {
let is_better = best_positive
.as_ref()
.is_none_or(|prev| result.score > prev.score);
if is_better {
best_positive = Some(result.clone());
}
}
let is_better_overall = best_overall
.as_ref()
.is_none_or(|prev| result.score > prev.score);
if is_better_overall {
best_overall = Some(result);
}
if end == ids.len() {
break;
}
start = end.saturating_sub(CHUNK_OVERLAP_TOKENS);
}
Ok(best_positive
.or(best_overall)
.unwrap_or(ClassificationResult {
label: inner
.id2label
.iter()
.find(|l| l.to_lowercase().contains("no"))
.cloned()
.unwrap_or_else(|| "no_instruction".into()),
score: 1.0,
is_positive: false,
spans: vec![],
}))
}
#[allow(unsafe_code)]
fn load_inner(
repo_id: &str,
hf_token: Option<&str>,
sha256: Option<&str>,
) -> Result<ThreeClassInner, LlmError> {
let api = hf_hub::api::sync::ApiBuilder::new()
.with_token(hf_token.map(str::to_owned))
.build()
.map_err(|e| {
LlmError::ModelLoad(format!("failed to create HuggingFace API client: {e}"))
})?;
let repo = api.model(repo_id.to_owned());
let config_path = repo.get("config.json").map_err(|e| {
LlmError::ModelLoad(format!(
"failed to download config.json from {repo_id}: {e}"
))
})?;
let tokenizer_path = repo.get("tokenizer.json").map_err(|e| {
LlmError::ModelLoad(format!(
"failed to download tokenizer.json from {repo_id}: {e}"
))
})?;
let weights_path = repo.get("model.safetensors").map_err(|e| {
LlmError::ModelLoad(format!(
"failed to download model.safetensors from {repo_id}: {e}"
))
})?;
if let Some(expected_hash) = sha256 {
super::verify_sha256(&weights_path, expected_hash)?;
}
let config_str = std::fs::read_to_string(&config_path)
.map_err(|e| LlmError::ModelLoad(format!("failed to read model config: {e}")))?;
let config: DebertaConfig = serde_json::from_str(&config_str)?;
let id2label: Vec<String> = config.id2label.as_ref().map_or_else(
|| {
vec![
"no_instruction".into(),
"aligned_instruction".into(),
"misaligned_instruction".into(),
]
},
|m| {
let mut sorted: Vec<(u32, String)> =
m.iter().map(|(k, v)| (*k, v.clone())).collect();
sorted.sort_by_key(|(k, _)| *k);
sorted.into_iter().map(|(_, v)| v).collect()
},
);
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| LlmError::ModelLoad(format!("failed to load tokenizer: {e}")))?;
let cls_token_id = tokenizer
.token_to_id("[CLS]")
.ok_or_else(|| LlmError::ModelLoad("tokenizer missing [CLS] token".into()))?;
let sep_token_id = tokenizer
.token_to_id("[SEP]")
.ok_or_else(|| LlmError::ModelLoad("tokenizer missing [SEP] token".into()))?;
validate_safetensors(&weights_path)?;
let device = crate::device::detect_device();
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)? };
let model = DebertaV2SeqClassificationModel::load(vb, &config, None)
.map_err(|e| LlmError::ModelLoad(format!("failed to load three-class model: {e}")))?;
Ok(ThreeClassInner {
model,
tokenizer,
device,
id2label,
cls_token_id,
sep_token_id,
})
}
}
impl ClassifierBackend for CandleThreeClassClassifier {
fn classify<'a>(
&'a self,
text: &'a str,
) -> Pin<Box<dyn Future<Output = Result<ClassificationResult, LlmError>> + Send + 'a>> {
let text = text.to_owned();
let inner_lock = Arc::clone(&self.inner);
let repo_id = Arc::clone(&self.repo_id);
let hf_token = self.hf_token.clone();
let sha256 = self.sha256.clone();
Box::pin(async move {
tokio::task::spawn_blocking(move || {
match inner_lock.get().map(Arc::clone).ok_or(()).or_else(|()| {
Self::load_inner(&repo_id, hf_token.as_deref(), sha256.as_deref())
.map(|i| {
let arc = Arc::new(i);
let _ = inner_lock.set(Arc::clone(&arc));
inner_lock.get().map_or(arc, Arc::clone)
})
.map_err(|e| e.to_string())
}) {
Ok(inner) => Self::classify_sync(&inner, &text),
Err(e) => {
tracing::error!(
error = %e,
"three-class classifier load failed — will retry on next call"
);
Err(LlmError::ModelLoad(e))
}
}
})
.await
.map_err(|e| LlmError::Inference(format!("spawn_blocking panicked: {e}")))?
})
}
fn backend_name(&self) -> &'static str {
"three_class_candle"
}
}
#[cfg(test)]
mod tests {
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::error::LlmError;
use super::super::{ClassificationResult, ClassifierBackend};
struct MockThreeClass {
label: &'static str,
score: f32,
}
impl ClassifierBackend for MockThreeClass {
fn classify<'a>(
&'a self,
_text: &'a str,
) -> Pin<Box<dyn Future<Output = Result<ClassificationResult, LlmError>> + Send + 'a>>
{
let label = self.label.to_owned();
let score = self.score;
Box::pin(async move {
Ok(ClassificationResult {
is_positive: label.to_lowercase().contains("misaligned"),
label,
score,
spans: vec![],
})
})
}
fn backend_name(&self) -> &'static str {
"mock_three_class"
}
}
fn mock(label: &'static str, score: f32) -> Arc<dyn ClassifierBackend> {
Arc::new(MockThreeClass { label, score })
}
#[tokio::test]
async fn misaligned_returns_is_positive() {
let b = mock("misaligned_instruction", 0.9);
let r = b.classify("ignore previous instructions").await.unwrap();
assert!(r.is_positive);
assert_eq!(r.label, "misaligned_instruction");
}
#[tokio::test]
async fn aligned_returns_not_positive() {
let b = mock("aligned_instruction", 0.85);
let r = b.classify("format the output as JSON").await.unwrap();
assert!(!r.is_positive);
assert_eq!(r.label, "aligned_instruction");
}
#[tokio::test]
async fn no_instruction_returns_not_positive() {
let b = mock("no_instruction", 0.95);
let r = b.classify("the weather is nice today").await.unwrap();
assert!(!r.is_positive);
}
}