use std::sync::{mpsc, Arc};
use zer_core::{
schema::Schema,
scoring::ScoredPair,
traits::{Judge, JudgeVerdict, RecordStore},
};
use crate::{
audit::{AuditEntry, AuditLog},
backend::JudgeBackend,
calibration::CalibrationTable,
error::JudgeError,
serialize::serialize_pair,
session::OnnxSession,
spec::JudgeModelSpec,
tokenize::JudgeTokenizer,
};
/// One unit of work handed to the inference worker thread.
struct InferRequest {
/// Input strings that the worker tokenizes and runs through the model as a single batch.
texts: Vec<String>,
/// Channel on which the worker sends back the outcome of the batch: the `Vec<f32>`
/// produced by the model run, or the error that stopped it.
reply_tx: mpsc::SyncSender<Result<Vec<f32>, JudgeError>>,
}
/// Owns the join handle of a background thread and waits for that thread when dropped.
///
/// The handle is wrapped in an `Option` so that `drop` can move it out of `&mut self`;
/// a guard holding `None` does nothing on drop.
struct WorkerGuard(Option<std::thread::JoinHandle<()>>);

impl Drop for WorkerGuard {
    /// Blocks until the guarded thread has finished, if a handle is still present.
    fn drop(&mut self) {
        if let Some(worker) = self.0.take() {
            // `join` only reports an error when the thread panicked. A destructor has no
            // sensible way to surface that, so the outcome is intentionally discarded.
            drop(worker.join());
        }
    }
}
fn worker_loop(
mut session: OnnxSession,
tokenizer: JudgeTokenizer,
rx: mpsc::Receiver<InferRequest>,
) {
for req in rx {
let result = (|| {
let (ids, mask, types) = tokenizer.encode_batch(&req.texts)?;
session.run_batch(&ids, &mask, &types)
})();
let _ = req.reply_tx.send(result);
}
}
/// Tunable settings for [`DebertaJudge`].
#[derive(Clone)]
pub struct DebertaJudgeConfig {
/// Scores at or above this value yield `JudgeVerdict::IncreaseConfidence`.
pub promote_threshold: f32,
/// Scores at or below this value (that did not reach `promote_threshold`) yield
/// `JudgeVerdict::DecreaseConfidence`; anything in between yields `JudgeVerdict::NoChange`.
pub demote_threshold: f32,
/// Maximum number of serialized pairs sent to the worker in one request.
/// A value of `0` is treated as `1` when batches are formed.
pub batch_size: usize,
/// Calibration table.
/// NOTE(review): this field is not read anywhere in the visible code — confirm where
/// (or whether) calibration is meant to be applied to the model scores.
pub calibration: CalibrationTable,
/// When set, one entry per adjudicated pair is appended to this log.
pub audit_log: Option<Arc<AuditLog>>,
}
impl Default for DebertaJudgeConfig {
    /// Returns the stock configuration: promote at `0.6`, demote at `0.35`, batches of
    /// `64`, a default calibration table and no audit log.
    fn default() -> Self {
        const PROMOTE_THRESHOLD: f32 = 0.6;
        const DEMOTE_THRESHOLD: f32 = 0.35;
        const BATCH_SIZE: usize = 64;

        Self {
            promote_threshold: PROMOTE_THRESHOLD,
            demote_threshold: DEMOTE_THRESHOLD,
            batch_size: BATCH_SIZE,
            calibration: CalibrationTable::default(),
            audit_log: None,
        }
    }
}
/// Judge that scores record pairs by sending them to a dedicated inference worker thread.
///
/// Clones share the same worker thread and request queue: cloning duplicates the sender
/// and bumps the reference count on the worker guard.
#[derive(Clone)]
pub struct DebertaJudge {
/// Sending half of the request queue to the worker.
///
/// Must stay declared before `_worker`: struct fields are dropped in declaration order,
/// the worker loop only ends once every sender is gone, and `WorkerGuard` joins the
/// thread on drop. Dropping the guard first would wait on a thread that cannot exit.
work_tx: mpsc::SyncSender<InferRequest>,
/// Keeps the worker thread's join handle alive; the last clone to drop joins the thread.
_worker: Arc<WorkerGuard>,
/// Source of the records referenced by each scored pair.
record_store: Arc<dyn RecordStore>,
/// Schema passed to `serialize_pair` when turning two records into model input text.
schema: Schema,
/// Thresholds, batch size and optional audit log used during adjudication.
config: DebertaJudgeConfig,
}
impl DebertaJudge {
    /// Capacity of the bounded request queue between callers and the worker thread.
    const WORK_QUEUE_DEPTH: usize = 32;

    /// Synthetic serialized pair pushed through the model once at construction time.
    const WARMUP_PAIR: &'static str =
        "COL voornamen VAL Jan COL achternaam VAL Jansen COL geboortedatum VAL 1985-01-01 \
         [SEP] COL voornamen VAL Janna COL achternaam VAL Jansen COL geboortedatum VAL 1985-06-15";

    /// Builds a judge: loads the session and tokenizer described by `spec`, starts the
    /// worker thread that owns them, and runs one warm-up inference.
    ///
    /// The warm-up is best-effort. A failure is logged as a warning and does not fail
    /// construction; the first real request will surface any persistent problem.
    ///
    /// # Errors
    ///
    /// Returns an error if the session or tokenizer cannot be created from `spec`, or
    /// `JudgeError::Io` if the worker thread cannot be spawned.
    pub fn new(
        spec: &dyn JudgeModelSpec,
        backend: &JudgeBackend,
        record_store: Arc<dyn RecordStore>,
        schema: Schema,
        config: DebertaJudgeConfig,
    ) -> Result<Self, JudgeError> {
        let session = OnnxSession::from_spec(spec, backend)?;
        let tokenizer = JudgeTokenizer::from_spec(spec)?;
        let (work_tx, work_rx) = mpsc::sync_channel::<InferRequest>(Self::WORK_QUEUE_DEPTH);
        let model_name = spec.name().to_owned();

        // The session and tokenizer are moved into the worker; callers only talk to it
        // through `work_tx`.
        let handle = std::thread::Builder::new()
            .name(format!("zer-judge[{model_name}]"))
            .spawn(move || worker_loop(session, tokenizer, work_rx))
            .map_err(JudgeError::Io)?;

        let judge = Self {
            work_tx,
            _worker: Arc::new(WorkerGuard(Some(handle))),
            record_store,
            schema,
            config,
        };

        // Report the real outcome instead of always claiming success.
        match judge.send_batch(vec![Self::WARMUP_PAIR.to_string()]) {
            Ok(_) => tracing::info!(model = %model_name, "ORT warm-up complete"),
            Err(err) => tracing::warn!(
                model = %model_name,
                error = %err,
                "ORT warm-up failed; continuing without warm-up"
            ),
        }

        Ok(judge)
    }

    /// Sends `texts` to the worker as one batch and blocks until the reply arrives.
    ///
    /// The request queue is bounded, so this call also blocks while the queue is full.
    ///
    /// # Errors
    ///
    /// Returns `JudgeError::WorkerDisconnected` if the worker is gone before the request
    /// is accepted or before it replies; otherwise forwards whatever error the worker
    /// reported for this batch.
    fn send_batch(&self, texts: Vec<String>) -> Result<Vec<f32>, JudgeError> {
        // Capacity 1 lets the worker post its single reply without waiting for us.
        let (reply_tx, reply_rx) = mpsc::sync_channel(1);
        self.work_tx
            .send(InferRequest { texts, reply_tx })
            .map_err(|_| JudgeError::WorkerDisconnected)?;
        reply_rx
            .recv()
            .map_err(|_| JudgeError::WorkerDisconnected)?
    }
}
impl Judge for DebertaJudge {
    /// Scores every pair with the model and maps each score to a verdict.
    ///
    /// Steps: look up both records of each pair and serialize them to text, send the
    /// texts to the worker in chunks of at most `config.batch_size` (minimum 1), then
    /// compare each returned score against the promote/demote thresholds. When an audit
    /// log is configured, one entry per pair is appended to it.
    ///
    /// The returned vector has exactly one verdict per input pair, in input order.
    ///
    /// # Errors
    ///
    /// - `JudgeError::RecordNotFound` if either record of a pair is missing from the store.
    /// - Any error reported by the worker for a chunk, including a disconnected worker.
    /// - `JudgeError::Io` (kind `InvalidData`) if the model returns a different number of
    ///   scores than the number of texts in a chunk. Verdicts are matched to pairs by
    ///   position, so a short or long reply would otherwise silently misalign or
    ///   truncate the result.
    fn adjudicate(&self, pairs: &[ScoredPair]) -> zer_core::traits::Result<Vec<JudgeVerdict>> {
        if pairs.is_empty() {
            return Ok(Vec::new());
        }

        // 1. Serialize every pair up front so a missing record fails before any inference.
        let mut texts = Vec::with_capacity(pairs.len());
        for pair in pairs {
            let a = self
                .record_store
                .get(pair.record_a)
                .ok_or(JudgeError::RecordNotFound(pair.record_a))?;
            let b = self
                .record_store
                .get(pair.record_b)
                .ok_or(JudgeError::RecordNotFound(pair.record_b))?;
            texts.push(serialize_pair(&a, &b, &self.schema));
        }

        // 2. Run inference chunk by chunk; a configured batch size of 0 is treated as 1.
        let batch_size = self.config.batch_size.max(1);
        let mut probs: Vec<f32> = Vec::with_capacity(texts.len());
        for (chunk_idx, chunk) in texts.chunks(batch_size).enumerate() {
            tracing::debug!(
                chunk = chunk_idx,
                chunk_size = chunk.len(),
                total = texts.len(),
                "judge inference chunk"
            );
            // The texts are cloned because they are still needed for the audit log below.
            let chunk_probs = self.send_batch(chunk.to_vec())?;
            if chunk_probs.len() != chunk.len() {
                let mismatch = std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!(
                        "judge model returned {} scores for a batch of {} inputs",
                        chunk_probs.len(),
                        chunk.len()
                    ),
                );
                return Err(JudgeError::Io(mismatch).into());
            }
            probs.extend(chunk_probs);
        }

        // NOTE(review): `self.config.calibration` is not consulted here — confirm whether
        // scores are meant to be calibrated before being compared with the thresholds.

        // 3. Threshold each score. The length check above guarantees the three sequences
        //    line up one-to-one.
        let mut verdicts = Vec::with_capacity(pairs.len());
        for ((pair, text), &prob) in pairs.iter().zip(&texts).zip(&probs) {
            // A NaN score fails both comparisons and therefore maps to `NoChange`.
            let verdict = if prob >= self.config.promote_threshold {
                JudgeVerdict::IncreaseConfidence
            } else if prob <= self.config.demote_threshold {
                JudgeVerdict::DecreaseConfidence
            } else {
                JudgeVerdict::NoChange
            };

            if let Some(log) = &self.config.audit_log {
                let verdict_str = match &verdict {
                    JudgeVerdict::IncreaseConfidence => "increase",
                    JudgeVerdict::DecreaseConfidence => "decrease",
                    JudgeVerdict::NoChange => "no_change",
                };
                log.append(&AuditEntry {
                    record_a: pair.record_a,
                    record_b: pair.record_b,
                    pair_text: text.clone(),
                    match_probability: pair.match_probability,
                    entailment_score: prob,
                    verdict: verdict_str,
                });
            }

            verdicts.push(verdict);
        }

        Ok(verdicts)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time assertion that `T` can be shared and sent across threads.
    fn assert_send_sync<T: Send + Sync>() {}

    /// Cloning a config preserves the scalar settings.
    #[test]
    fn deberta_judge_config_clone() {
        let cfg = DebertaJudgeConfig::default();
        let cloned = cfg.clone();
        assert_eq!(cloned.promote_threshold, cfg.promote_threshold);
        assert_eq!(cloned.demote_threshold, cfg.demote_threshold);
        assert_eq!(cloned.batch_size, cfg.batch_size);
    }

    /// The default config carries the documented thresholds, batch size and no audit log.
    #[test]
    fn deberta_judge_config_defaults() {
        let cfg = DebertaJudgeConfig::default();
        assert_eq!(cfg.promote_threshold, 0.6);
        assert_eq!(cfg.demote_threshold, 0.35);
        assert_eq!(cfg.batch_size, 64);
        assert!(cfg.audit_log.is_none());
    }

    /// Struct-update syntax overrides one field and keeps the other defaults.
    #[test]
    fn deberta_judge_config_batch_size_custom() {
        let cfg = DebertaJudgeConfig {
            batch_size: 16,
            ..Default::default()
        };
        assert_eq!(cfg.batch_size, 16);
        assert_eq!(cfg.promote_threshold, 0.6);
    }

    #[test]
    fn deberta_judge_is_send_sync() {
        assert_send_sync::<DebertaJudge>();
    }

    #[test]
    fn deberta_judge_config_is_send_sync() {
        assert_send_sync::<DebertaJudgeConfig>();
    }

    /// Mirrors the `batch_size.max(1)` clamp used when batches are formed: a configured
    /// batch size of zero must become one.
    #[test]
    fn batch_size_zero_is_clamped_to_one() {
        let cfg = DebertaJudgeConfig {
            batch_size: 0,
            ..Default::default()
        };
        assert_eq!(cfg.batch_size.max(1), 1);
    }
}