Skip to main content

memoir_core/nli/
mod.rs

//! Zero-shot text classification via Natural Language Inference (NLI).
//!
//! Wraps a DeBERTa-v3-xsmall ONNX model (~87 MB quantized) trained on 33
//! classification datasets. Classification is framed as entailment: the input
//! text is the premise, each candidate label is rendered into a hypothesis
//! template, and the model scores how strongly the premise entails each
//! hypothesis. The entailment probability ranks the labels.
//!
//! The model is downloaded from HuggingFace on first construction and cached
//! locally, mirroring how [`crate::embedding::OnnxEmbedding`] loads its model.
//!
//! # Execution providers
//!
//! By default the classifier runs on CPU. Enabling the `cuda` cargo feature
//! adds GPU execution. The classifier then selects a provider from the
//! `NLI_EXECUTION_PROVIDER` environment variable (`auto` | `cuda` | `cpu`,
//! defaulting to `auto`):
//!
//! - `auto` tries CUDA and falls back to CPU.
//! - `cuda` or `cpu` uses only that provider.
//!
//! Enabling the feature does not change the public API, only which provider is
//! selected internally.
//!
//! # Examples
//!
//! ```no_run
//! use memoir_core::nli::{NliClassifier, NliConfig};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let classifier = NliClassifier::new(NliConfig::default())?;
//! let results = classifier.classify(
//!     "We decided to use Pulumi instead of Terraform",
//!     &["a decision that was made", "a personal preference"],
//!     "This text is about {}.",
//! )?;
//! // results[0].label == "a decision that was made"
//! # Ok(())
//! # }
//! ```
36
37mod error;
38
39#[doc(inline)]
40pub use error::NliError;
41
42use std::sync::Mutex;
43
44use ort::session::Session;
45use ort::value::Tensor;
46use tokenizers::Tokenizer;
47
/// HuggingFace repo id of the model memoir ships with.
///
/// MoritzLaurer's checkpoint was trained on 33 classification datasets, which
/// makes it much stronger at zero-shot labelling than a plain NLI
/// cross-encoder. The repo already contains an ONNX export.
const DEFAULT_MODEL_REPO: &str = "MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33";

/// Repo-relative path of the quantized ONNX export (~87 MB).
const DEFAULT_MODEL_FILE: &str = "onnx/model_quantized.onnx";

/// Repo-relative path of the tokenizer definition.
const DEFAULT_TOKENIZER_FILE: &str = "tokenizer.json";
59
/// Source of the zero-shot NLI classifier model.
///
/// Selects which model [`NliClassifier::new`] downloads and loads for the
/// categorize stage. [`NliConfig::default`] is the model memoir ships with
/// (MoritzLaurer's DeBERTa-v3-xsmall). A consumer overrides it to point the
/// classifier at a different HuggingFace repo.
///
/// The shape mirrors [`crate::llm::LlmConfig`]'s enum-of-sources design. There
/// is one variant today, with room for others later (a local path, a direct
/// URL). The enum is not `#[non_exhaustive]`, so adding a variant would break
/// downstream exhaustive `match`es and irrefutable destructuring.
///
/// A custom model must produce logits in the same `[entailment, ...]` order as
/// the default one, because the classifier reads the entailment probability
/// from a fixed output index.
///
/// # Examples
///
/// ```
/// # use memoir_core::nli::NliConfig;
/// // The shipped default.
/// let config = NliConfig::default();
///
/// // A different zero-shot NLI repo, ONNX export + tokenizer.
/// let custom = NliConfig::huggingface(
///     "my-org/my-zeroshot-model",
///     "onnx/model.onnx",
///     "tokenizer.json",
/// );
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NliConfig {
    /// A model hosted on the HuggingFace Hub.
    HuggingFace {
        /// Repo id, e.g. `"MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33"`.
        repo: String,
        /// Path to the ONNX model file within the repo.
        model_file: String,
        /// Path to the tokenizer JSON within the repo.
        tokenizer_file: String,
    },
}
95
96impl NliConfig {
97    /// Builds a config for a model on the HuggingFace Hub.
98    #[must_use]
99    pub fn huggingface(
100        repo: impl Into<String>,
101        model_file: impl Into<String>,
102        tokenizer_file: impl Into<String>,
103    ) -> Self {
104        Self::HuggingFace {
105            repo: repo.into(),
106            model_file: model_file.into(),
107            tokenizer_file: tokenizer_file.into(),
108        }
109    }
110}
111
112impl Default for NliConfig {
113    /// The model memoir ships with — MoritzLaurer's DeBERTa-v3-xsmall.
114    fn default() -> Self {
115        Self::huggingface(DEFAULT_MODEL_REPO, DEFAULT_MODEL_FILE, DEFAULT_TOKENIZER_FILE)
116    }
117}
118
/// Index of the entailment logit in the model's two-class output.
///
/// The default model's output order is `[entailment, not_entailment]`, so
/// entailment is index 0.
///
/// This index is applied to every model loaded through [`NliConfig`],
/// including custom HuggingFace repos. Such a model must emit its entailment
/// logit at the same position; otherwise [`NliClassifier::classify`] ranks
/// labels by the wrong class.
const ENTAILMENT_IDX: usize = 0;
123
/// One candidate label together with the entailment probability that
/// [`NliClassifier::classify`] assigned to it.
#[derive(Clone, Debug)]
pub struct ScoredLabel {
    /// The candidate label, copied verbatim from the caller's input.
    pub label: String,

    /// Entailment probability in `[0.0, 1.0]`; larger means a better fit.
    pub score: f32,
}
133
/// Hardware backend that runs an [`NliClassifier`]'s inference.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExecutionProvider {
    /// NVIDIA GPU through CUDA. Only selectable when built with the `cuda`
    /// cargo feature.
    Cuda,

    /// CPU (MLAS). Always available, and the only choice without `cuda`.
    Cpu,
}
143
144impl ExecutionProvider {
145    /// Returns the ONNX Runtime identifier for this provider.
146    #[must_use]
147    pub fn ort_name(self) -> &'static str {
148        match self {
149            Self::Cuda => "CUDAExecutionProvider",
150            Self::Cpu => "CPUExecutionProvider",
151        }
152    }
153}
154
155/// Zero-shot text classifier using NLI entailment scoring.
156///
157/// Holds a DeBERTa-v3-xsmall ONNX session and its tokenizer, each behind a
158/// `Mutex` so the classifier is `Send + Sync` and can be shared via `Arc`. The
159/// model is downloaded on first construction and cached locally.
160pub struct NliClassifier {
161    session: Mutex<Session>,
162    tokenizer: Mutex<Tokenizer>,
163    execution_provider: ExecutionProvider,
164}
165
166impl std::fmt::Debug for NliClassifier {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        f.debug_struct("NliClassifier")
169            .field("execution_provider", &self.execution_provider)
170            .finish_non_exhaustive()
171    }
172}
173
174impl NliClassifier {
175    /// Creates a classifier from `config`, downloading the model if not cached.
176    ///
177    /// This is synchronous and blocks on the HuggingFace download and ONNX
178    /// session creation — mirroring [`crate::embedding::OnnxEmbedding::new`].
179    /// Call it from a blocking context (e.g. `tokio::task::spawn_blocking`)
180    /// when constructing from async code. Pass [`NliConfig::default`] for the
181    /// model memoir ships with.
182    ///
183    /// # Errors
184    ///
185    /// Returns [`NliError::Download`] if the model or tokenizer cannot be
186    /// fetched, [`NliError::ModelLoad`] if the ONNX session cannot be built,
187    /// and [`NliError::TokenizerLoad`] if the tokenizer cannot be parsed.
188    pub fn new(config: NliConfig) -> Result<Self, NliError> {
189        let NliConfig::HuggingFace {
190            repo,
191            model_file,
192            tokenizer_file,
193        } = config;
194
195        let (model_path, tokenizer_path) = download_model_files(&repo, &model_file, &tokenizer_file)?;
196
197        let (session, execution_provider) = create_session(&model_path)?;
198
199        let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| NliError::TokenizerLoad(e.to_string()))?;
200
201        tracing::event!(
202            name: "memoir.nli.loaded",
203            tracing::Level::INFO,
204            model = %repo,
205            execution_provider = execution_provider.ort_name(),
206            "NLI classifier loaded with {{execution_provider}}",
207        );
208
209        Ok(Self {
210            session: Mutex::new(session),
211            tokenizer: Mutex::new(tokenizer),
212            execution_provider,
213        })
214    }
215
216    /// Returns the execution provider this classifier resolved to at load time.
217    #[must_use]
218    pub fn execution_provider(&self) -> ExecutionProvider {
219        self.execution_provider
220    }
221
222    /// Classifies `text` against `labels` using a hypothesis template.
223    ///
224    /// The template must contain `{}`, which is replaced with each label: with
225    /// template `"This text is about {}."` and label `"a decision"`, the
226    /// hypothesis is `"This text is about a decision."`. Returns the labels
227    /// sorted by entailment score, highest first. An empty `labels` slice
228    /// returns an empty vector.
229    ///
230    /// # Errors
231    ///
232    /// Returns [`NliError::Inference`] if tokenization or model inference
233    /// fails for any label.
234    pub fn classify(
235        &self,
236        text: &str,
237        labels: &[&str],
238        hypothesis_template: &str,
239    ) -> Result<Vec<ScoredLabel>, NliError> {
240        if labels.is_empty() {
241            return Ok(Vec::new());
242        }
243
244        let mut scored: Vec<ScoredLabel> = labels
245            .iter()
246            .map(|label| {
247                let hypothesis = hypothesis_template.replace("{}", label);
248                let score = self.entailment_score(text, &hypothesis)?;
249                Ok(ScoredLabel {
250                    label: (*label).to_string(),
251                    score,
252                })
253            })
254            .collect::<Result<Vec<_>, NliError>>()?;
255
256        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
257
258        Ok(scored)
259    }
260
261    /// Computes the entailment probability for one premise-hypothesis pair.
262    fn entailment_score(&self, premise: &str, hypothesis: &str) -> Result<f32, NliError> {
263        let encoding = {
264            let tokenizer = self
265                .tokenizer
266                .lock()
267                .map_err(|e| NliError::Inference(format!("tokenizer lock poisoned: {e}")))?;
268            // The tokenizer pairs the inputs as [CLS] premise [SEP] hypothesis [SEP].
269            tokenizer
270                .encode((premise, hypothesis), true)
271                .map_err(|e| NliError::Inference(format!("tokenization failed: {e}")))?
272        };
273
274        let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| i64::from(id)).collect();
275        let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&m| i64::from(m)).collect();
276        let shape = [1_usize, input_ids.len()];
277
278        let input_ids_tensor = Tensor::from_array((shape, input_ids))
279            .map_err(|e| NliError::Inference(format!("failed to create input_ids tensor: {e}")))?;
280        let attention_mask_tensor = Tensor::from_array((shape, attention_mask))
281            .map_err(|e| NliError::Inference(format!("failed to create attention_mask tensor: {e}")))?;
282
283        let mut session = self
284            .session
285            .lock()
286            .map_err(|e| NliError::Inference(format!("session lock poisoned: {e}")))?;
287
288        // DeBERTa-v2 takes only input_ids + attention_mask (no token_type_ids).
289        let outputs = session
290            .run(ort::inputs![input_ids_tensor, attention_mask_tensor])
291            .map_err(|e| NliError::Inference(format!("model inference failed: {e}")))?;
292
293        // Output shape is [1, 2] — [entailment, not_entailment]. We only need
294        // the data slice, not the shape.
295        let (_shape, logits) = outputs[0]
296            .try_extract_tensor::<f32>()
297            .map_err(|e| NliError::Inference(format!("failed to extract logits: {e}")))?;
298
299        if logits.len() < 2 {
300            return Err(NliError::Inference(format!(
301                "expected at least 2 logits, got {}",
302                logits.len()
303            )));
304        }
305
306        Ok(softmax(logits)[ENTAILMENT_IDX])
307    }
308}
309
310/// Downloads the model and tokenizer from HuggingFace, caching them locally.
311fn download_model_files(
312    repo: &str,
313    model_file: &str,
314    tokenizer_file: &str,
315) -> Result<(std::path::PathBuf, std::path::PathBuf), NliError> {
316    let api = hf_hub::api::sync::Api::new().map_err(|e| NliError::Download(e.to_string()))?;
317    let repo = api.model(repo.to_string());
318
319    let model_path = repo
320        .get(model_file)
321        .map_err(|e| NliError::Download(format!("failed to download {model_file}: {e}")))?;
322    let tokenizer_path = repo
323        .get(tokenizer_file)
324        .map_err(|e| NliError::Download(format!("failed to download {tokenizer_file}: {e}")))?;
325
326    Ok((model_path, tokenizer_path))
327}
328
329/// Builds an ONNX session on CPU.
330///
331/// Without the `cuda` feature this is the only path: the CPU provider is always
332/// available, so there is no provider to negotiate.
333#[cfg(not(feature = "cuda"))]
334fn create_session(model_path: &std::path::Path) -> Result<(Session, ExecutionProvider), NliError> {
335    let session = build_cpu_session(model_path)
336        .map_err(|e| NliError::ModelLoad(format!("failed to initialize NLI session on CPU: {e}")))?;
337    Ok((session, ExecutionProvider::Cpu))
338}
339
340/// Builds a CPU ONNX session from `model_path`.
341fn build_cpu_session(model_path: &std::path::Path) -> Result<Session, String> {
342    // A single intra-op thread: classification runs one short sequence per
343    // call, so extra threads add scheduling overhead without speedup. rc.12's
344    // builder methods return `Error<SessionBuilder>` (carrying the builder for
345    // recovery), so `?` is used to coerce into a uniform error rather than
346    // `.and_then`, whose error types would not unify with `commit_from_file`.
347    build_cpu_session_inner(model_path).map_err(|e| e.to_string())
348}
349
350/// Inner CPU session builder using `?`-chaining over `ort::Result`.
351fn build_cpu_session_inner(model_path: &std::path::Path) -> ort::Result<Session> {
352    let session = Session::builder()?
353        .with_intra_threads(1)?
354        .commit_from_file(model_path)?;
355    Ok(session)
356}
357
/// Builds the ONNX session in a `cuda` build, honoring `NLI_EXECUTION_PROVIDER`.
///
/// The variable accepts `auto` | `cuda` | `cpu`; when unset it means `auto`.
/// `auto` tries CUDA and falls back to CPU. An explicit `cuda` or `cpu` uses
/// only that provider and reports its failure as [`NliError::ModelLoad`].
#[cfg(feature = "cuda")]
fn create_session(model_path: &std::path::Path) -> Result<(Session, ExecutionProvider), NliError> {
    let preference = ExecutionProviderPreference::from_env()?;

    match preference {
        ExecutionProviderPreference::Auto => create_auto_session(model_path),
        ExecutionProviderPreference::Cuda => match build_gpu_session(model_path, ExecutionProvider::Cuda) {
            Ok(session) => Ok((session, ExecutionProvider::Cuda)),
            Err(e) => Err(NliError::ModelLoad(format!("failed to initialize NLI session on CUDA: {e}"))),
        },
        ExecutionProviderPreference::Cpu => match build_cpu_session(model_path) {
            Ok(session) => Ok((session, ExecutionProvider::Cpu)),
            Err(e) => Err(NliError::ModelLoad(format!("failed to initialize NLI session on CPU: {e}"))),
        },
    }
}
375
/// Tries CUDA first and falls back to CPU (`cuda` build, `auto` preference).
///
/// A warning event is logged only when the CPU fallback succeeds. If both
/// providers fail, the returned error carries both underlying messages.
#[cfg(feature = "cuda")]
fn create_auto_session(model_path: &std::path::Path) -> Result<(Session, ExecutionProvider), NliError> {
    let cuda_err = match build_gpu_session(model_path, ExecutionProvider::Cuda) {
        Ok(session) => return Ok((session, ExecutionProvider::Cuda)),
        Err(cuda_err) => cuda_err,
    };

    let session = build_cpu_session(model_path).map_err(|cpu_err| {
        NliError::ModelLoad(format!(
            "failed to initialize NLI session; CUDA error: {cuda_err}; CPU fallback error: {cpu_err}"
        ))
    })?;

    tracing::event!(
        name: "memoir.nli.cuda_fallback",
        tracing::Level::WARN,
        error = %cuda_err,
        "CUDA init failed; falling back to CPU",
    );

    Ok((session, ExecutionProvider::Cpu))
}
397
/// Builds a session bound to `provider` (`cuda` build), with the error
/// flattened to text.
#[cfg(feature = "cuda")]
fn build_gpu_session(model_path: &std::path::Path, provider: ExecutionProvider) -> Result<Session, String> {
    match build_gpu_session_inner(model_path, provider) {
        Ok(session) => Ok(session),
        Err(err) => Err(err.to_string()),
    }
}
403
/// Builds a session registered with exactly one execution provider (`cuda` build).
///
/// The provider is registered with `error_on_failure`, so a provider that
/// cannot initialize makes the build fail instead of silently running
/// elsewhere.
#[cfg(feature = "cuda")]
fn build_gpu_session_inner(model_path: &std::path::Path, provider: ExecutionProvider) -> ort::Result<Session> {
    let provider_dispatch = match provider {
        ExecutionProvider::Cpu => ort::ep::CPU::default().build().error_on_failure(),
        ExecutionProvider::Cuda => ort::ep::CUDA::default().build().error_on_failure(),
    };

    // Same single intra-op thread as the plain CPU path.
    let session = Session::builder()?
        .with_execution_providers([provider_dispatch])?
        .with_intra_threads(1)?
        .commit_from_file(model_path)?;
    Ok(session)
}
417
/// Provider choice read from the `NLI_EXECUTION_PROVIDER` environment variable.
#[cfg(feature = "cuda")]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ExecutionProviderPreference {
    /// Try CUDA, then fall back to CPU. Also used when the variable is unset.
    Auto,
    /// CUDA only; session creation fails if CUDA cannot be initialized.
    Cuda,
    /// CPU only.
    Cpu,
}
426
#[cfg(feature = "cuda")]
impl ExecutionProviderPreference {
    /// Reads the preference from `NLI_EXECUTION_PROVIDER`.
    ///
    /// An unset variable means [`Self::Auto`]. A value that cannot be read
    /// (e.g. not valid Unicode) or that [`Self::parse`] rejects yields
    /// [`NliError::ModelLoad`].
    fn from_env() -> Result<Self, NliError> {
        let value = match std::env::var("NLI_EXECUTION_PROVIDER") {
            Ok(value) => value,
            Err(std::env::VarError::NotPresent) => return Ok(Self::Auto),
            Err(e) => {
                return Err(NliError::ModelLoad(format!(
                    "failed to read NLI_EXECUTION_PROVIDER: {e}"
                )));
            }
        };

        Self::parse(&value).map_err(|invalid| {
            NliError::ModelLoad(format!(
                "invalid NLI_EXECUTION_PROVIDER `{invalid}`; expected one of: auto, cuda, cpu"
            ))
        })
    }

    /// Parses a preference case-insensitively, ignoring surrounding whitespace.
    ///
    /// On failure the original, untrimmed `value` is returned as the error, so
    /// the caller can echo exactly what was set.
    fn parse(value: &str) -> Result<Self, &str> {
        let normalized = value.trim().to_ascii_lowercase();
        let preference = match normalized.as_str() {
            "auto" => Self::Auto,
            "cuda" => Self::Cuda,
            "cpu" => Self::Cpu,
            _ => return Err(value),
        };
        Ok(preference)
    }
}
452
/// Softmax over a slice of logits.
///
/// Every logit is shifted by the maximum before exponentiating, so large
/// values do not overflow `exp()`. An empty slice yields an empty vector.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
    let shifted: Vec<f32> = logits.iter().map(|&x| (x - peak).exp()).collect();
    let total = shifted.iter().sum::<f32>();
    shifted.into_iter().map(|e| e / total).collect()
}
460
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_compute_softmax_correctly() {
        let logits = [2.0, 1.0, 0.1];
        let probs = softmax(&logits);

        assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-5);
        assert!(probs[0] > probs[1]);
        assert!(probs[1] > probs[2]);
    }

    #[test]
    fn should_handle_softmax_with_large_values() {
        // Shifting by the max keeps exp() from overflowing on large logits.
        let logits = [1000.0, 1.0, 0.1];
        let probs = softmax(&logits);

        assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-5);
        assert!(probs[0] > 0.99);
    }

    #[test]
    fn should_return_empty_softmax_for_empty_logits() {
        let logits: [f32; 0] = [];
        assert!(softmax(&logits).is_empty());
    }

    #[test]
    fn should_return_uniform_softmax_for_equal_logits() {
        let probs = softmax(&[0.5, 0.5]);

        assert!((probs[0] - 0.5).abs() < 1e-6);
        assert!((probs[1] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn should_report_cpu_provider_ort_name() {
        assert_eq!(ExecutionProvider::Cpu.ort_name(), "CPUExecutionProvider");
    }

    #[test]
    fn should_report_cuda_provider_ort_name() {
        assert_eq!(ExecutionProvider::Cuda.ort_name(), "CUDAExecutionProvider");
    }

    #[test]
    fn should_default_nli_config_to_the_shipped_moritzlaurer_model() {
        let NliConfig::HuggingFace {
            repo,
            model_file,
            tokenizer_file,
        } = NliConfig::default();
        assert_eq!(repo, "MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33");
        assert_eq!(model_file, "onnx/model_quantized.onnx");
        assert_eq!(tokenizer_file, "tokenizer.json");
    }

    #[test]
    fn should_build_default_nli_config_from_the_default_constants() {
        assert_eq!(
            NliConfig::default(),
            NliConfig::huggingface(DEFAULT_MODEL_REPO, DEFAULT_MODEL_FILE, DEFAULT_TOKENIZER_FILE)
        );
    }

    #[test]
    fn should_build_nli_config_from_huggingface_constructor() {
        let config = NliConfig::huggingface("org/model", "m.onnx", "tok.json");
        assert_eq!(
            config,
            NliConfig::HuggingFace {
                repo: "org/model".to_string(),
                model_file: "m.onnx".to_string(),
                tokenizer_file: "tok.json".to_string(),
            }
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn should_parse_auto_execution_provider_preference() {
        assert_eq!(
            ExecutionProviderPreference::parse("auto"),
            Ok(ExecutionProviderPreference::Auto)
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn should_parse_cpu_execution_provider_preference_case_insensitively() {
        assert_eq!(
            ExecutionProviderPreference::parse("CPU"),
            Ok(ExecutionProviderPreference::Cpu)
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn should_parse_cuda_execution_provider_preference_with_whitespace() {
        assert_eq!(
            ExecutionProviderPreference::parse(" cuda "),
            Ok(ExecutionProviderPreference::Cuda)
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn should_reject_unknown_execution_provider_preference() {
        assert_eq!(ExecutionProviderPreference::parse("metal"), Err("metal"));
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn should_return_untrimmed_value_when_rejecting_preference() {
        // The error echoes exactly what the user set, not the normalized form.
        assert_eq!(ExecutionProviderPreference::parse(" Metal "), Err(" Metal "));
    }
}