1mod error;
38
39#[doc(inline)]
40pub use error::NliError;
41
42use std::sync::Mutex;
43
44use ort::session::Session;
45use ort::value::Tensor;
46use tokenizers::Tokenizer;
47
/// Hugging Face Hub repository used by [`NliConfig::default`].
const DEFAULT_MODEL_REPO: &str = "MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33";

/// Path of the ONNX model file inside [`DEFAULT_MODEL_REPO`] (a quantized export, per its name).
const DEFAULT_MODEL_FILE: &str = "onnx/model_quantized.onnx";

/// Path of the `tokenizers` JSON file inside [`DEFAULT_MODEL_REPO`].
const DEFAULT_TOKENIZER_FILE: &str = "tokenizer.json";
59
/// Describes where the NLI model and its tokenizer are loaded from.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NliConfig {
    /// Files hosted in a Hugging Face Hub model repository.
    HuggingFace {
        /// Repository id, e.g. `MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33`.
        repo: String,
        /// Path of the ONNX model file within `repo`.
        model_file: String,
        /// Path of the `tokenizers` JSON file within `repo`.
        tokenizer_file: String,
    },
}
95
96impl NliConfig {
97 #[must_use]
99 pub fn huggingface(
100 repo: impl Into<String>,
101 model_file: impl Into<String>,
102 tokenizer_file: impl Into<String>,
103 ) -> Self {
104 Self::HuggingFace {
105 repo: repo.into(),
106 model_file: model_file.into(),
107 tokenizer_file: tokenizer_file.into(),
108 }
109 }
110}
111
112impl Default for NliConfig {
113 fn default() -> Self {
115 Self::huggingface(DEFAULT_MODEL_REPO, DEFAULT_MODEL_FILE, DEFAULT_TOKENIZER_FILE)
116 }
117}
118
/// Position of the entailment class in the model's output logits.
///
/// NOTE(review): assumes the configured model lists `entailment` first in its
/// `id2label` mapping — confirm against the model's `config.json` when changing models.
const ENTAILMENT_IDX: usize = 0_usize;
123
/// A candidate label paired with its zero-shot score.
#[derive(Debug, Clone)]
pub struct ScoredLabel {
    /// The candidate label exactly as it was passed to the classifier.
    pub label: String,

    /// Softmax probability of the entailment class for this label's hypothesis.
    pub score: f32,
}
133
/// Backend the ONNX Runtime session executes on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionProvider {
    /// CUDA GPU execution.
    Cuda,

    /// CPU execution.
    Cpu,
}
143
144impl ExecutionProvider {
145 #[must_use]
147 pub fn ort_name(self) -> &'static str {
148 match self {
149 Self::Cuda => "CUDAExecutionProvider",
150 Self::Cpu => "CPUExecutionProvider",
151 }
152 }
153}
154
155pub struct NliClassifier {
161 session: Mutex<Session>,
162 tokenizer: Mutex<Tokenizer>,
163 execution_provider: ExecutionProvider,
164}
165
166impl std::fmt::Debug for NliClassifier {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 f.debug_struct("NliClassifier")
169 .field("execution_provider", &self.execution_provider)
170 .finish_non_exhaustive()
171 }
172}
173
174impl NliClassifier {
175 pub fn new(config: NliConfig) -> Result<Self, NliError> {
189 let NliConfig::HuggingFace {
190 repo,
191 model_file,
192 tokenizer_file,
193 } = config;
194
195 let (model_path, tokenizer_path) = download_model_files(&repo, &model_file, &tokenizer_file)?;
196
197 let (session, execution_provider) = create_session(&model_path)?;
198
199 let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| NliError::TokenizerLoad(e.to_string()))?;
200
201 tracing::event!(
202 name: "memoir.nli.loaded",
203 tracing::Level::INFO,
204 model = %repo,
205 execution_provider = execution_provider.ort_name(),
206 "NLI classifier loaded with {{execution_provider}}",
207 );
208
209 Ok(Self {
210 session: Mutex::new(session),
211 tokenizer: Mutex::new(tokenizer),
212 execution_provider,
213 })
214 }
215
216 #[must_use]
218 pub fn execution_provider(&self) -> ExecutionProvider {
219 self.execution_provider
220 }
221
222 pub fn classify(
235 &self,
236 text: &str,
237 labels: &[&str],
238 hypothesis_template: &str,
239 ) -> Result<Vec<ScoredLabel>, NliError> {
240 if labels.is_empty() {
241 return Ok(Vec::new());
242 }
243
244 let mut scored: Vec<ScoredLabel> = labels
245 .iter()
246 .map(|label| {
247 let hypothesis = hypothesis_template.replace("{}", label);
248 let score = self.entailment_score(text, &hypothesis)?;
249 Ok(ScoredLabel {
250 label: (*label).to_string(),
251 score,
252 })
253 })
254 .collect::<Result<Vec<_>, NliError>>()?;
255
256 scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
257
258 Ok(scored)
259 }
260
261 fn entailment_score(&self, premise: &str, hypothesis: &str) -> Result<f32, NliError> {
263 let encoding = {
264 let tokenizer = self
265 .tokenizer
266 .lock()
267 .map_err(|e| NliError::Inference(format!("tokenizer lock poisoned: {e}")))?;
268 tokenizer
270 .encode((premise, hypothesis), true)
271 .map_err(|e| NliError::Inference(format!("tokenization failed: {e}")))?
272 };
273
274 let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| i64::from(id)).collect();
275 let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&m| i64::from(m)).collect();
276 let shape = [1_usize, input_ids.len()];
277
278 let input_ids_tensor = Tensor::from_array((shape, input_ids))
279 .map_err(|e| NliError::Inference(format!("failed to create input_ids tensor: {e}")))?;
280 let attention_mask_tensor = Tensor::from_array((shape, attention_mask))
281 .map_err(|e| NliError::Inference(format!("failed to create attention_mask tensor: {e}")))?;
282
283 let mut session = self
284 .session
285 .lock()
286 .map_err(|e| NliError::Inference(format!("session lock poisoned: {e}")))?;
287
288 let outputs = session
290 .run(ort::inputs![input_ids_tensor, attention_mask_tensor])
291 .map_err(|e| NliError::Inference(format!("model inference failed: {e}")))?;
292
293 let (_shape, logits) = outputs[0]
296 .try_extract_tensor::<f32>()
297 .map_err(|e| NliError::Inference(format!("failed to extract logits: {e}")))?;
298
299 if logits.len() < 2 {
300 return Err(NliError::Inference(format!(
301 "expected at least 2 logits, got {}",
302 logits.len()
303 )));
304 }
305
306 Ok(softmax(logits)[ENTAILMENT_IDX])
307 }
308}
309
310fn download_model_files(
312 repo: &str,
313 model_file: &str,
314 tokenizer_file: &str,
315) -> Result<(std::path::PathBuf, std::path::PathBuf), NliError> {
316 let api = hf_hub::api::sync::Api::new().map_err(|e| NliError::Download(e.to_string()))?;
317 let repo = api.model(repo.to_string());
318
319 let model_path = repo
320 .get(model_file)
321 .map_err(|e| NliError::Download(format!("failed to download {model_file}: {e}")))?;
322 let tokenizer_path = repo
323 .get(tokenizer_file)
324 .map_err(|e| NliError::Download(format!("failed to download {tokenizer_file}: {e}")))?;
325
326 Ok((model_path, tokenizer_path))
327}
328
329#[cfg(not(feature = "cuda"))]
334fn create_session(model_path: &std::path::Path) -> Result<(Session, ExecutionProvider), NliError> {
335 let session = build_cpu_session(model_path)
336 .map_err(|e| NliError::ModelLoad(format!("failed to initialize NLI session on CPU: {e}")))?;
337 Ok((session, ExecutionProvider::Cpu))
338}
339
340fn build_cpu_session(model_path: &std::path::Path) -> Result<Session, String> {
342 build_cpu_session_inner(model_path).map_err(|e| e.to_string())
348}
349
350fn build_cpu_session_inner(model_path: &std::path::Path) -> ort::Result<Session> {
352 let session = Session::builder()?
353 .with_intra_threads(1)?
354 .commit_from_file(model_path)?;
355 Ok(session)
356}
357
/// Creates the inference session according to `NLI_EXECUTION_PROVIDER`
/// (`auto` when unset): `auto` tries CUDA and falls back to CPU, while `cuda` and
/// `cpu` force that provider with no fallback.
///
/// # Errors
///
/// Returns [`NliError::ModelLoad`] for an unreadable/invalid preference or when the
/// chosen provider cannot build a session.
#[cfg(feature = "cuda")]
fn create_session(model_path: &std::path::Path) -> Result<(Session, ExecutionProvider), NliError> {
    let preference = ExecutionProviderPreference::from_env()?;
    match preference {
        ExecutionProviderPreference::Auto => create_auto_session(model_path),
        ExecutionProviderPreference::Cuda => {
            let session = build_gpu_session(model_path, ExecutionProvider::Cuda).map_err(|e| {
                NliError::ModelLoad(format!("failed to initialize NLI session on CUDA: {e}"))
            })?;
            Ok((session, ExecutionProvider::Cuda))
        }
        ExecutionProviderPreference::Cpu => {
            let session = build_cpu_session(model_path).map_err(|e| {
                NliError::ModelLoad(format!("failed to initialize NLI session on CPU: {e}"))
            })?;
            Ok((session, ExecutionProvider::Cpu))
        }
    }
}
375
/// Tries a CUDA session first. If that fails, builds a CPU session and logs a
/// `memoir.nli.cuda_fallback` warning carrying the CUDA error.
///
/// # Errors
///
/// Returns [`NliError::ModelLoad`] with both error messages when CUDA and the CPU
/// fallback both fail.
#[cfg(feature = "cuda")]
fn create_auto_session(model_path: &std::path::Path) -> Result<(Session, ExecutionProvider), NliError> {
    let cuda_err = match build_gpu_session(model_path, ExecutionProvider::Cuda) {
        Ok(gpu_session) => return Ok((gpu_session, ExecutionProvider::Cuda)),
        Err(cuda_err) => cuda_err,
    };

    let cpu_session = match build_cpu_session(model_path) {
        Ok(cpu_session) => cpu_session,
        Err(cpu_err) => {
            return Err(NliError::ModelLoad(format!(
                "failed to initialize NLI session; CUDA error: {cuda_err}; CPU fallback error: {cpu_err}"
            )));
        }
    };

    // Only warn once the fallback actually succeeded.
    tracing::event!(
        name: "memoir.nli.cuda_fallback",
        tracing::Level::WARN,
        error = %cuda_err,
        "CUDA init failed; falling back to CPU",
    );
    Ok((cpu_session, ExecutionProvider::Cpu))
}
397
/// Builds a session pinned to `provider`, flattening any `ort` error into its
/// display string.
#[cfg(feature = "cuda")]
fn build_gpu_session(model_path: &std::path::Path, provider: ExecutionProvider) -> Result<Session, String> {
    match build_gpu_session_inner(model_path, provider) {
        Ok(session) => Ok(session),
        Err(e) => Err(e.to_string()),
    }
}
403
/// Builds a session that registers exactly `provider` with `error_on_failure`, so
/// registration problems surface as errors, and uses one intra-op thread.
#[cfg(feature = "cuda")]
fn build_gpu_session_inner(model_path: &std::path::Path, provider: ExecutionProvider) -> ort::Result<Session> {
    let dispatch = match provider {
        ExecutionProvider::Cuda => ort::ep::CUDA::default().build(),
        ExecutionProvider::Cpu => ort::ep::CPU::default().build(),
    }
    .error_on_failure();

    let builder = Session::builder()?;
    let builder = builder.with_execution_providers([dispatch])?;
    let builder = builder.with_intra_threads(1)?;
    let session = builder.commit_from_file(model_path)?;
    Ok(session)
}
417
/// Execution provider requested through the `NLI_EXECUTION_PROVIDER` variable.
#[cfg(feature = "cuda")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExecutionProviderPreference {
    /// Try CUDA, falling back to CPU on failure (the default when unset).
    Auto,
    /// Require CUDA.
    Cuda,
    /// Require CPU.
    Cpu,
}
426
#[cfg(feature = "cuda")]
impl ExecutionProviderPreference {
    /// Reads `NLI_EXECUTION_PROVIDER`, returning [`Self::Auto`] when it is unset.
    ///
    /// # Errors
    ///
    /// Returns [`NliError::ModelLoad`] if the variable is present but cannot be read
    /// or does not parse as `auto`, `cuda`, or `cpu`.
    fn from_env() -> Result<Self, NliError> {
        let raw = match std::env::var("NLI_EXECUTION_PROVIDER") {
            Ok(raw) => raw,
            Err(std::env::VarError::NotPresent) => return Ok(Self::Auto),
            Err(e) => {
                return Err(NliError::ModelLoad(format!(
                    "failed to read NLI_EXECUTION_PROVIDER: {e}"
                )));
            }
        };

        Self::parse(&raw).map_err(|invalid| {
            NliError::ModelLoad(format!(
                "invalid NLI_EXECUTION_PROVIDER `{invalid}`; expected one of: auto, cuda, cpu"
            ))
        })
    }

    /// Parses a preference case-insensitively, ignoring surrounding whitespace.
    ///
    /// On failure the original, untrimmed input is returned as the error.
    fn parse(value: &str) -> Result<Self, &str> {
        let normalized = value.trim().to_ascii_lowercase();
        let preference = match normalized.as_str() {
            "auto" => Self::Auto,
            "cuda" => Self::Cuda,
            "cpu" => Self::Cpu,
            _ => return Err(value),
        };
        Ok(preference)
    }
}
452
/// Converts raw logits into probabilities that sum to 1.
///
/// The maximum logit is subtracted before exponentiating so large inputs do not
/// overflow `exp`. An empty slice yields an empty vector.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
    let mut probs: Vec<f32> = logits.iter().map(|&x| (x - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in &mut probs {
        *p /= total;
    }
    probs
}
460
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_compute_softmax_correctly() {
        let probs = softmax(&[2.0, 1.0, 0.1]);
        let total: f32 = probs.iter().sum();

        assert!((total - 1.0).abs() < 1e-5);
        // Probabilities must keep the strictly decreasing order of the logits.
        assert!(probs.windows(2).all(|pair| pair[0] > pair[1]));
    }

    #[test]
    fn should_handle_softmax_with_large_values() {
        // 1000.0 would overflow a naive `exp`; the max-subtraction keeps it finite.
        let probs = softmax(&[1000.0, 1.0, 0.1]);
        let total: f32 = probs.iter().sum();

        assert!((total - 1.0).abs() < 1e-5);
        assert!(probs[0] > 0.99);
    }

    #[test]
    fn should_report_cpu_provider_ort_name() {
        let name = ExecutionProvider::Cpu.ort_name();
        assert_eq!(name, "CPUExecutionProvider");
    }

    #[test]
    fn should_default_nli_config_to_the_shipped_moritzlaurer_model() {
        let expected = NliConfig::HuggingFace {
            repo: "MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33".to_string(),
            model_file: "onnx/model_quantized.onnx".to_string(),
            tokenizer_file: "tokenizer.json".to_string(),
        };
        assert_eq!(NliConfig::default(), expected);
    }

    #[test]
    fn should_build_nli_config_from_huggingface_constructor() {
        let expected = NliConfig::HuggingFace {
            repo: "org/model".to_string(),
            model_file: "m.onnx".to_string(),
            tokenizer_file: "tok.json".to_string(),
        };
        let built = NliConfig::huggingface("org/model", "m.onnx", "tok.json");
        assert_eq!(built, expected);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn should_parse_auto_execution_provider_preference() {
        let parsed = ExecutionProviderPreference::parse("auto");
        assert_eq!(parsed, Ok(ExecutionProviderPreference::Auto));
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn should_parse_cpu_execution_provider_preference_case_insensitively() {
        let parsed = ExecutionProviderPreference::parse("CPU");
        assert_eq!(parsed, Ok(ExecutionProviderPreference::Cpu));
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn should_parse_cuda_execution_provider_preference_with_whitespace() {
        let parsed = ExecutionProviderPreference::parse(" cuda ");
        assert_eq!(parsed, Ok(ExecutionProviderPreference::Cuda));
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn should_reject_unknown_execution_provider_preference() {
        let parsed = ExecutionProviderPreference::parse("metal");
        assert_eq!(parsed, Err("metal"));
    }
}