mod error;
#[doc(inline)]
pub use error::NliError;
use std::sync::Mutex;
use ort::session::Session;
use ort::value::Tensor;
use tokenizers::Tokenizer;
/// Hugging Face repository used by `NliConfig::default()`.
const DEFAULT_MODEL_REPO: &str = "MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33";
/// ONNX model file inside `DEFAULT_MODEL_REPO` that is loaded into the session
/// (presumably a quantized export, judging by the file name).
const DEFAULT_MODEL_FILE: &str = "onnx/model_quantized.onnx";
/// Tokenizer definition inside `DEFAULT_MODEL_REPO`, loaded with `Tokenizer::from_file`.
const DEFAULT_TOKENIZER_FILE: &str = "tokenizer.json";
/// Where [`NliClassifier::new`] obtains the NLI model and its tokenizer from.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NliConfig {
    /// Files fetched from a Hugging Face Hub model repository.
    HuggingFace {
        /// Repository id, e.g. `"org/model"`.
        repo: String,
        /// Path of the ONNX model file within the repository.
        model_file: String,
        /// Path of the tokenizer JSON file within the repository.
        tokenizer_file: String,
    },
}
impl NliConfig {
    /// Builds an [`NliConfig::HuggingFace`] source.
    ///
    /// `repo` is the Hub repository id. `model_file` and `tokenizer_file` are the
    /// paths of the ONNX model and tokenizer JSON inside that repository.
    #[must_use]
    pub fn huggingface(
        repo: impl Into<String>,
        model_file: impl Into<String>,
        tokenizer_file: impl Into<String>,
    ) -> Self {
        let repo: String = repo.into();
        let model_file: String = model_file.into();
        let tokenizer_file: String = tokenizer_file.into();
        Self::HuggingFace {
            repo,
            model_file,
            tokenizer_file,
        }
    }
}
impl Default for NliConfig {
fn default() -> Self {
Self::huggingface(DEFAULT_MODEL_REPO, DEFAULT_MODEL_FILE, DEFAULT_TOKENIZER_FILE)
}
}
/// Position of the entailment class in the model's output logits; the softmax of the
/// logits is indexed here to obtain a label's score.
// NOTE(review): assumes the configured model's `id2label` puts entailment at index 0 —
// confirm against the model's `config.json` when switching models.
const ENTAILMENT_IDX: usize = 0;
/// A candidate label paired with its entailment probability, as returned by
/// [`NliClassifier::classify`].
#[derive(Debug, Clone)]
pub struct ScoredLabel {
    /// The label exactly as it was passed to `classify`.
    pub label: String,
    /// Softmax probability of the entailment class for this label's hypothesis
    /// (normally in `[0, 1]`). Each label is scored independently, so scores across
    /// labels need not sum to 1.
    pub score: f32,
}
/// ONNX Runtime execution provider that an [`NliClassifier`] session runs on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionProvider {
    /// The CUDA execution provider.
    Cuda,
    /// The CPU execution provider.
    Cpu,
}
impl ExecutionProvider {
    /// Returns the ONNX Runtime name of this provider, e.g. `"CPUExecutionProvider"`.
    #[must_use]
    pub fn ort_name(self) -> &'static str {
        match self {
            Self::Cpu => "CPUExecutionProvider",
            Self::Cuda => "CUDAExecutionProvider",
        }
    }
}
/// Zero-shot text classifier backed by an ONNX natural-language-inference model.
///
/// Build one with [`NliClassifier::new`] and score candidate labels with
/// [`NliClassifier::classify`].
pub struct NliClassifier {
    // ONNX Runtime session; `entailment_score` runs it through a mutable lock guard,
    // which is what lets inference be driven from `&self`.
    session: Mutex<Session>,
    // Tokenizer; locked only while a premise/hypothesis pair is being encoded.
    tokenizer: Mutex<Tokenizer>,
    // Provider the session was actually created on (CPU after a CUDA fallback).
    execution_provider: ExecutionProvider,
}
impl std::fmt::Debug for NliClassifier {
    /// Shows only the execution provider; the session and tokenizer are elided
    /// (rendered as `..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = f.debug_struct("NliClassifier");
        view.field("execution_provider", &self.execution_provider);
        view.finish_non_exhaustive()
    }
}
impl NliClassifier {
    /// Downloads the configured model and tokenizer and builds a ready-to-use classifier.
    ///
    /// Both files are fetched with `download_model_files`. The ONNX Runtime session is
    /// created by `create_session`, which also decides the execution provider. The
    /// tokenizer is loaded from the downloaded JSON file. A `memoir.nli.loaded` event is
    /// logged on success.
    ///
    /// # Errors
    ///
    /// Returns an [`NliError`] if either download fails, the session cannot be created,
    /// or the tokenizer file cannot be loaded ([`NliError::TokenizerLoad`]).
    pub fn new(config: NliConfig) -> Result<Self, NliError> {
        let NliConfig::HuggingFace {
            repo,
            model_file,
            tokenizer_file,
        } = config;
        let (model_path, tokenizer_path) =
            download_model_files(&repo, &model_file, &tokenizer_file)?;
        let (session, execution_provider) = create_session(&model_path)?;
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| NliError::TokenizerLoad(e.to_string()))?;
        // NOTE(review): `{{...}}` is a format-string escape, so the rendered message is
        // literally "NLI classifier loaded with {execution_provider}". Presumably a
        // downstream formatter substitutes the structured field of the same name —
        // confirm; otherwise interpolate `execution_provider.ort_name()` directly.
        tracing::event!(
            name: "memoir.nli.loaded",
            tracing::Level::INFO,
            model = %repo,
            execution_provider = execution_provider.ort_name(),
            "NLI classifier loaded with {{execution_provider}}",
        );
        Ok(Self {
            session: Mutex::new(session),
            tokenizer: Mutex::new(tokenizer),
            execution_provider,
        })
    }

    /// Returns the execution provider the session was actually created on.
    #[must_use]
    pub fn execution_provider(&self) -> ExecutionProvider {
        self.execution_provider
    }

    /// Scores every candidate label against `text` and returns them best-first.
    ///
    /// For each label, every `{}` in `hypothesis_template` is replaced by the label
    /// (e.g. `"This text is about {}."`). The pair `(text, hypothesis)` is then run
    /// through the model, one inference per label. If the template contains no `{}`,
    /// every label gets the same hypothesis.
    ///
    /// A label's score is the softmax probability of the entailment class for its own
    /// pair. Labels are therefore scored independently, and their scores need not sum
    /// to 1.
    ///
    /// Returns an empty vector when `labels` is empty. Results are sorted by descending
    /// score. Labels with equal scores keep their input order, and NaN scores (possible
    /// only if the model emits non-finite logits) are placed last.
    ///
    /// # Errors
    ///
    /// Returns [`NliError::Inference`] if locking, tokenization, tensor construction,
    /// model inference or logit extraction fails for any label.
    pub fn classify(
        &self,
        text: &str,
        labels: &[&str],
        hypothesis_template: &str,
    ) -> Result<Vec<ScoredLabel>, NliError> {
        if labels.is_empty() {
            return Ok(Vec::new());
        }
        let mut scored: Vec<ScoredLabel> = labels
            .iter()
            .map(|label| {
                let hypothesis = hypothesis_template.replace("{}", label);
                let score = self.entailment_score(text, &hypothesis)?;
                Ok(ScoredLabel {
                    label: (*label).to_string(),
                    score,
                })
            })
            .collect::<Result<Vec<_>, NliError>>()?;
        // The old `partial_cmp(..).unwrap_or(Equal)` comparator stops being a total
        // order as soon as a NaN appears, and `sort_by` may panic on an inconsistent
        // comparator. Instead, order NaN-ness first (non-NaN before NaN), then sort
        // descending with `total_cmp`; both keys are total, so their lexicographic
        // combination is too. `sort_by` is stable, so ties keep their input order.
        scored.sort_by(|a, b| {
            a.score
                .is_nan()
                .cmp(&b.score.is_nan())
                .then_with(|| b.score.total_cmp(&a.score))
        });
        Ok(scored)
    }

    /// Returns the entailment probability for a single `(premise, hypothesis)` pair.
    ///
    /// The pair is encoded as a sentence pair with special tokens and fed to the model
    /// as a batch of one, using positional `input_ids` / `attention_mask` inputs. The
    /// softmax of the first output's logits is then indexed at `ENTAILMENT_IDX`.
    fn entailment_score(&self, premise: &str, hypothesis: &str) -> Result<f32, NliError> {
        // Scoped so the tokenizer lock is released before the (slower) inference runs.
        let encoding = {
            let tokenizer = self
                .tokenizer
                .lock()
                .map_err(|e| NliError::Inference(format!("tokenizer lock poisoned: {e}")))?;
            tokenizer
                .encode((premise, hypothesis), true)
                .map_err(|e| NliError::Inference(format!("tokenization failed: {e}")))?
        };
        let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| i64::from(id)).collect();
        let attention_mask: Vec<i64> = encoding
            .get_attention_mask()
            .iter()
            .map(|&m| i64::from(m))
            .collect();
        // Batch of exactly one sequence.
        let shape = [1_usize, input_ids.len()];
        let input_ids_tensor = Tensor::from_array((shape, input_ids))
            .map_err(|e| NliError::Inference(format!("failed to create input_ids tensor: {e}")))?;
        let attention_mask_tensor = Tensor::from_array((shape, attention_mask)).map_err(|e| {
            NliError::Inference(format!("failed to create attention_mask tensor: {e}"))
        })?;
        let mut session = self
            .session
            .lock()
            .map_err(|e| NliError::Inference(format!("session lock poisoned: {e}")))?;
        // Inputs are positional: the model's first input must be the token ids and
        // its second the attention mask.
        let outputs = session
            .run(ort::inputs![input_ids_tensor, attention_mask_tensor])
            .map_err(|e| NliError::Inference(format!("model inference failed: {e}")))?;
        let (_shape, logits) = outputs[0]
            .try_extract_tensor::<f32>()
            .map_err(|e| NliError::Inference(format!("failed to extract logits: {e}")))?;
        // Guards the `ENTAILMENT_IDX` lookup and rejects degenerate single-class outputs.
        if logits.len() < 2 {
            return Err(NliError::Inference(format!(
                "expected at least 2 logits, got {}",
                logits.len()
            )));
        }
        Ok(softmax(logits)[ENTAILMENT_IDX])
    }
}
/// Fetches `model_file` and `tokenizer_file` from the Hugging Face Hub repository
/// `repo` through the synchronous `hf_hub` API.
///
/// Returns the local paths as `(model_path, tokenizer_path)`. `hf_hub` presumably
/// serves files from its local cache when they are already present — confirm.
///
/// # Errors
///
/// Returns [`NliError::Download`] if the Hub client cannot be created or if either
/// file cannot be fetched (the message names the failing file).
fn download_model_files(
    repo: &str,
    model_file: &str,
    tokenizer_file: &str,
) -> Result<(std::path::PathBuf, std::path::PathBuf), NliError> {
    let hub = hf_hub::api::sync::Api::new().map_err(|err| NliError::Download(err.to_string()))?;
    let hub_repo = hub.model(repo.to_string());
    // Both files share the same fetch-and-wrap-error logic.
    let fetch = |file: &str| {
        hub_repo
            .get(file)
            .map_err(|err| NliError::Download(format!("failed to download {file}: {err}")))
    };
    let model_path = fetch(model_file)?;
    let tokenizer_path = fetch(tokenizer_file)?;
    Ok((model_path, tokenizer_path))
}
/// Creates the ONNX Runtime session on the CPU provider (builds without the `cuda`
/// feature) and reports [`ExecutionProvider::Cpu`] as the provider used.
///
/// # Errors
///
/// Returns [`NliError::ModelLoad`] if the CPU session cannot be built.
#[cfg(not(feature = "cuda"))]
fn create_session(model_path: &std::path::Path) -> Result<(Session, ExecutionProvider), NliError> {
    match build_cpu_session(model_path) {
        Ok(session) => Ok((session, ExecutionProvider::Cpu)),
        Err(e) => Err(NliError::ModelLoad(format!(
            "failed to initialize NLI session on CPU: {e}"
        ))),
    }
}
/// Builds a CPU session, flattening the `ort` error into a `String` so callers can
/// embed it in their own error messages.
fn build_cpu_session(model_path: &std::path::Path) -> Result<Session, String> {
    match build_cpu_session_inner(model_path) {
        Ok(session) => Ok(session),
        Err(err) => Err(err.to_string()),
    }
}
/// Loads the model at `model_path` into a session with no execution provider
/// registered explicitly (ONNX Runtime's default) and one intra-op thread.
fn build_cpu_session_inner(model_path: &std::path::Path) -> ort::Result<Session> {
    let builder = Session::builder()?.with_intra_threads(1)?;
    builder.commit_from_file(model_path)
}
/// Creates the ONNX Runtime session according to `NLI_EXECUTION_PROVIDER` (CUDA builds).
///
/// * unset / `auto`: try CUDA and fall back to CPU (see `create_auto_session`);
/// * `cuda`: CUDA only, with no fallback;
/// * `cpu`: CPU only.
///
/// Returns the session together with the provider it was created on.
///
/// # Errors
///
/// Returns [`NliError::ModelLoad`] if the environment variable is invalid or the
/// requested provider(s) fail to initialize.
#[cfg(feature = "cuda")]
fn create_session(model_path: &std::path::Path) -> Result<(Session, ExecutionProvider), NliError> {
    let preference = ExecutionProviderPreference::from_env()?;
    match preference {
        ExecutionProviderPreference::Auto => create_auto_session(model_path),
        ExecutionProviderPreference::Cuda => {
            match build_gpu_session(model_path, ExecutionProvider::Cuda) {
                Ok(session) => Ok((session, ExecutionProvider::Cuda)),
                Err(e) => Err(NliError::ModelLoad(format!(
                    "failed to initialize NLI session on CUDA: {e}"
                ))),
            }
        }
        ExecutionProviderPreference::Cpu => match build_cpu_session(model_path) {
            Ok(session) => Ok((session, ExecutionProvider::Cpu)),
            Err(e) => Err(NliError::ModelLoad(format!(
                "failed to initialize NLI session on CPU: {e}"
            ))),
        },
    }
}
/// Tries the CUDA provider first and falls back to CPU if CUDA fails to initialize.
///
/// The `memoir.nli.cuda_fallback` warning, which carries the CUDA error, is emitted
/// only after the CPU session has been created successfully.
///
/// # Errors
///
/// Returns [`NliError::ModelLoad`] with both the CUDA and the CPU error when neither
/// provider can be initialized.
#[cfg(feature = "cuda")]
fn create_auto_session(model_path: &std::path::Path) -> Result<(Session, ExecutionProvider), NliError> {
    let cuda_err = match build_gpu_session(model_path, ExecutionProvider::Cuda) {
        Ok(gpu_session) => return Ok((gpu_session, ExecutionProvider::Cuda)),
        Err(err) => err,
    };
    let cpu_session = match build_cpu_session(model_path) {
        Ok(session) => session,
        Err(cpu_err) => {
            return Err(NliError::ModelLoad(format!(
                "failed to initialize NLI session; CUDA error: {cuda_err}; CPU fallback error: {cpu_err}"
            )));
        }
    };
    tracing::event!(
        name: "memoir.nli.cuda_fallback",
        tracing::Level::WARN,
        error = %cuda_err,
        "CUDA init failed; falling back to CPU",
    );
    Ok((cpu_session, ExecutionProvider::Cpu))
}
/// Builds a session on `provider`, flattening the `ort` error into a `String`
/// (the same convention as `build_cpu_session`).
#[cfg(feature = "cuda")]
fn build_gpu_session(model_path: &std::path::Path, provider: ExecutionProvider) -> Result<Session, String> {
    match build_gpu_session_inner(model_path, provider) {
        Ok(session) => Ok(session),
        Err(err) => Err(err.to_string()),
    }
}
/// Loads the model at `model_path` with `provider` explicitly registered and one
/// intra-op thread.
///
/// The registration uses `error_on_failure()`, so a provider that cannot be
/// registered yields an error rather than a silently degraded session. This is what
/// lets `create_auto_session` detect CUDA failures. Callers in this file only pass
/// `ExecutionProvider::Cuda`.
#[cfg(feature = "cuda")]
fn build_gpu_session_inner(model_path: &std::path::Path, provider: ExecutionProvider) -> ort::Result<Session> {
    let registration = match provider {
        ExecutionProvider::Cpu => ort::ep::CPU::default().build().error_on_failure(),
        ExecutionProvider::Cuda => ort::ep::CUDA::default().build().error_on_failure(),
    };
    let builder = Session::builder()?
        .with_execution_providers([registration])?
        .with_intra_threads(1)?;
    builder.commit_from_file(model_path)
}
/// Execution provider requested through the `NLI_EXECUTION_PROVIDER` environment
/// variable (only compiled with the `cuda` feature).
#[cfg(feature = "cuda")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExecutionProviderPreference {
    /// Try CUDA, falling back to CPU on failure; used when the variable is unset.
    Auto,
    /// CUDA only; session creation fails instead of falling back.
    Cuda,
    /// CPU only.
    Cpu,
}
#[cfg(feature = "cuda")]
impl ExecutionProviderPreference {
    /// Reads the preference from the `NLI_EXECUTION_PROVIDER` environment variable.
    ///
    /// An unset variable means [`Self::Auto`].
    ///
    /// # Errors
    ///
    /// Returns [`NliError::ModelLoad`] in two cases: the value is not `auto`, `cuda`
    /// or `cpu` (ASCII case-insensitive, surrounding whitespace ignored), or the
    /// variable cannot be read (it is not valid Unicode).
    fn from_env() -> Result<Self, NliError> {
        let raw = match std::env::var("NLI_EXECUTION_PROVIDER") {
            Ok(raw) => raw,
            Err(std::env::VarError::NotPresent) => return Ok(Self::Auto),
            Err(e) => {
                return Err(NliError::ModelLoad(format!(
                    "failed to read NLI_EXECUTION_PROVIDER: {e}"
                )));
            }
        };
        Self::parse(&raw).map_err(|invalid| {
            NliError::ModelLoad(format!(
                "invalid NLI_EXECUTION_PROVIDER `{invalid}`; expected one of: auto, cuda, cpu"
            ))
        })
    }

    /// Parses a provider name, ignoring ASCII case and surrounding whitespace.
    ///
    /// On failure the untrimmed input is handed back as the error so it can be quoted
    /// verbatim in messages.
    fn parse(value: &str) -> Result<Self, &str> {
        let name = value.trim();
        if name.eq_ignore_ascii_case("auto") {
            Ok(Self::Auto)
        } else if name.eq_ignore_ascii_case("cuda") {
            Ok(Self::Cuda)
        } else if name.eq_ignore_ascii_case("cpu") {
            Ok(Self::Cpu)
        } else {
            Err(value)
        }
    }
}
/// Converts raw logits into probabilities.
///
/// The maximum logit is subtracted before exponentiating so that large logits cannot
/// overflow `exp`. The output has the same length as the input, and an empty input
/// yields an empty vector.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));
    let mut probs: Vec<f32> = logits.iter().map(|&v| (v - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in &mut probs {
        *p /= total;
    }
    probs
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the calling test unless `probs` sums to 1 within `1e-5`.
    fn assert_sums_to_one(probs: &[f32]) {
        let total: f32 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-5);
    }

    #[test]
    fn should_compute_softmax_correctly() {
        let probs = softmax(&[2.0, 1.0, 0.1]);
        assert_sums_to_one(&probs);
        // Strictly decreasing logits must give strictly decreasing probabilities.
        assert!(probs.windows(2).all(|pair| pair[0] > pair[1]));
    }

    #[test]
    fn should_handle_softmax_with_large_values() {
        let probs = softmax(&[1000.0, 1.0, 0.1]);
        assert_sums_to_one(&probs);
        assert!(probs[0] > 0.99);
    }

    #[test]
    fn should_report_cpu_provider_ort_name() {
        let name = ExecutionProvider::Cpu.ort_name();
        assert_eq!(name, "CPUExecutionProvider");
    }

    #[test]
    fn should_default_nli_config_to_the_shipped_moritzlaurer_model() {
        let NliConfig::HuggingFace {
            repo,
            model_file,
            tokenizer_file,
        } = NliConfig::default();
        assert_eq!(
            (repo.as_str(), model_file.as_str(), tokenizer_file.as_str()),
            (
                "MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33",
                "onnx/model_quantized.onnx",
                "tokenizer.json",
            )
        );
    }

    #[test]
    fn should_build_nli_config_from_huggingface_constructor() {
        let expected = NliConfig::HuggingFace {
            repo: String::from("org/model"),
            model_file: String::from("m.onnx"),
            tokenizer_file: String::from("tok.json"),
        };
        assert_eq!(NliConfig::huggingface("org/model", "m.onnx", "tok.json"), expected);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn should_parse_auto_execution_provider_preference() {
        let parsed = ExecutionProviderPreference::parse("auto");
        assert_eq!(parsed, Ok(ExecutionProviderPreference::Auto));
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn should_parse_cpu_execution_provider_preference_case_insensitively() {
        let parsed = ExecutionProviderPreference::parse("CPU");
        assert_eq!(parsed, Ok(ExecutionProviderPreference::Cpu));
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn should_parse_cuda_execution_provider_preference_with_whitespace() {
        let parsed = ExecutionProviderPreference::parse(" cuda ");
        assert_eq!(parsed, Ok(ExecutionProviderPreference::Cuda));
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn should_reject_unknown_execution_provider_preference() {
        let parsed = ExecutionProviderPreference::parse("metal");
        assert_eq!(parsed, Err("metal"));
    }
}