sqry_nl/classifier/model.rs
//! ONNX model loading and inference.

use crate::error::ClassifierError;
use crate::types::{ClassificationResult, Intent};
use ort::session::Session;
use ort::value::Tensor;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::io::Read;
use std::path::Path;

use super::BAKED_MANIFEST;
use super::calibration::CalibrationParams;
use super::manifest::Manifest;
use super::resolve::TrustMode;

// ---------------------------------------------------------------------------
// NL08 — ONNX Runtime "missing dylib" detection
// ---------------------------------------------------------------------------
//
// The `ort` crate (with the `load-dynamic` feature, which sqry-nl uses)
// resolves `libonnxruntime` at first API call via `libloading`. If the
// shared library is absent, ort's `setup_api()` calls `.expect("Failed
// to load ONNX Runtime dylib")` — meaning the failure surfaces as a
// **panic**, not a typed `Result::Err`. Some downstream surfaces (e.g.
// symbol lookup after a successful library open) do return a typed
// `ort::Error` that carries the substring `"libonnxruntime"` /
// `"failed to load"` / `"OrtGetApiBase"` / `"dlopen"` / `"DyLib"` in its
// `Display` form.
//
// We therefore detect the missing-dylib condition through TWO channels:
//
//  1. `std::panic::catch_unwind` around the `Session::builder()` chain
//     to convert panics into a typed `OnnxRuntimeMissing` error.
//  2. String-pattern matching on the `Display` of any returned
//     `ort::Error` for the substrings above, so symbol-lookup failures
//     after a partial library load also surface as
//     `OnnxRuntimeMissing` instead of the opaque `OnnxError(_)`.
//
// A deterministic test seam — the `SQRY_NL_FORCE_ORT_MISSING` env var
// — short-circuits this path before any ORT call. The seam is gated on
// `debug_assertions` so it cannot be exploited in release binaries
// shipped to operators. Cargo test runs under `debug_assertions` by
// default, so the CLI / MCP / LSP integration tests can drive this path
// without needing an actual missing libonnxruntime on the host.

/// Return the platform-specific install hint for missing ONNX Runtime.
///
/// Used to populate
/// [`crate::error::ClassifierError::OnnxRuntimeMissing`] /
/// [`crate::error::NlError::OnnxRuntimeMissing`].
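///
/// # Example
///
/// A minimal sketch; the path assumes this module is publicly reachable
/// as `sqry_nl::classifier::model` (the actual re-export may differ),
/// so the block is `ignore`d:
///
/// ```ignore
/// let hint = sqry_nl::classifier::model::onnx_runtime_install_hint();
/// // Every platform arm returns a non-empty, actionable hint string.
/// assert!(!hint.is_empty());
/// ```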
#[must_use]
pub fn onnx_runtime_install_hint() -> String {
    #[cfg(target_os = "linux")]
    {
        "Install via apt: 'sudo apt-get install libonnxruntime-dev' OR \
         download from https://github.com/microsoft/onnxruntime/releases"
            .to_string()
    }
    #[cfg(target_os = "macos")]
    {
        "Install via brew: 'brew install onnxruntime'".to_string()
    }
    #[cfg(target_os = "windows")]
    {
        "Download libonnxruntime.dll from \
         https://github.com/microsoft/onnxruntime/releases and place in PATH"
            .to_string()
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
    {
        // Other Unix-likes (FreeBSD, etc.) — mirror Linux guidance.
        "Install libonnxruntime via your platform package manager OR \
         download from https://github.com/microsoft/onnxruntime/releases"
            .to_string()
    }
}

/// Return `true` when the env-var test seam is active.
///
/// Gated on `debug_assertions` so release binaries do not honour the
/// override. `cargo test` runs under `debug_assertions` regardless of
/// the harness binary's profile, so subprocess tests of the release
/// `sqry` binary need to spawn the debug-built binary (which `cargo
/// test` always does — `cargo build --release` is a separate command).
#[cfg(debug_assertions)]
fn ort_missing_forced() -> bool {
    match std::env::var("SQRY_NL_FORCE_ORT_MISSING") {
        Ok(v) => {
            let v = v.trim();
            v.eq_ignore_ascii_case("1")
                || v.eq_ignore_ascii_case("true")
                || v.eq_ignore_ascii_case("yes")
                || v.eq_ignore_ascii_case("on")
        }
        Err(_) => false,
    }
}

#[cfg(not(debug_assertions))]
fn ort_missing_forced() -> bool {
    false
}

/// Returns `true` if the given error string looks like a dylib-load
/// failure for `libonnxruntime`. Matches the substrings ort emits in
/// the load-dynamic path. Case-insensitive on the substring tokens.
///
/// NL08 review iter-1: the broad tokens `"dylib"`, `"dlopen"`, and
/// `"failed to load"` were intentionally excluded from this OR set.
/// They false-positive on operator-supplied paths (e.g.
/// `SQRY_NL_MODEL_DIR=/some/dylib-models/...`) and on unrelated model
/// load errors carrying such paths in their message — a bad-ONNX-bytes
/// failure for a model under such a path would otherwise be
/// misclassified as `OnnxRuntimeMissing`. The three remaining tokens
/// (`libonnxruntime`, `onnxruntime.dll`, `ortgetapibase`) uniquely
/// identify the ort dylib-load surface and will never appear in a
/// legitimate file path or model parse error.
fn looks_like_dylib_load_failure(msg: &str) -> bool {
    let lower = msg.to_ascii_lowercase();
    lower.contains("libonnxruntime")
        || lower.contains("onnxruntime.dll")
        || lower.contains("ortgetapibase")
}

/// Construct `ClassifierError::OnnxRuntimeMissing` with the platform hint.
fn onnx_runtime_missing_error() -> ClassifierError {
    ClassifierError::OnnxRuntimeMissing {
        hint: onnx_runtime_install_hint(),
    }
}

/// Intent classifier using an ONNX model (`all-MiniLM-L6-v2` or `DistilBERT`).
pub struct IntentClassifier {
    /// ONNX Runtime session
    session: Session,
    /// `HuggingFace` tokenizer
    tokenizer: tokenizers::Tokenizer,
    /// Calibration parameters for confidence scaling
    calibration: CalibrationParams,
    /// Model version string
    model_version: String,
    /// Whether the ONNX model declares `token_type_ids` as an input.
    /// BERT-architecture models (`MiniLM`) require it; `DistilBERT` does not.
    /// Passing an undeclared input to ort causes a runtime error.
    has_token_type_ids: bool,
}

/// Compute SHA256 hash of a file.
fn compute_file_hash(path: &Path) -> Result<String, ClassifierError> {
    let mut file = std::fs::File::open(path).map_err(|e| {
        ClassifierError::OnnxError(format!("Failed to open {}: {e}", path.display()))
    })?;

    let mut hasher = Sha256::new();
    let mut buffer = [0u8; 8192];

    loop {
        let bytes_read = file.read(&mut buffer).map_err(|e| {
            ClassifierError::OnnxError(format!("Failed to read {}: {e}", path.display()))
        })?;
        if bytes_read == 0 {
            break;
        }
        hasher.update(&buffer[..bytes_read]);
    }

    Ok(format!("{:x}", hasher.finalize()))
}

// ---------------------------------------------------------------------------
// NL04 Integrity Contract — AUTHORITATIVE
// ---------------------------------------------------------------------------
//
// `verify_integrity` is the single point at which on-disk model artifacts
// are validated against an expected-hash table. Two distinct failure modes
// must NEVER be conflated:
//
//  1. TAMPERING — a file is present on disk, its sha256 was checked, and
//     the computed hash does NOT match the expected hash. This ALWAYS
//     yields `Err(ChecksumMismatch { file, expected, actual })`,
//     regardless of `allow_unverified`. The escape hatch covers
//     missingness only; it never silences hash mismatch on a present
//     file. This matches spec FR-7 + FR-13.
//
//  2. MISSINGNESS — `checksums.json` itself is absent, or a file listed
//     in `checksums.json` is absent on disk. In strict mode
//     (`allow_unverified == false`, the default per FR-7), missingness
//     is a fatal error (`ChecksumsMissing` / `ChecksummedFileMissing`).
//     With `allow_unverified == true`, missingness downgrades to a
//     `tracing::warn!` and the loader continues — but ALL still-present
//     files are still hashed.
//
// Trust mode (FR-14):
//  - `TrustMode::Trusted` (resolver levels 4-5): the on-disk
//    `checksums.json` is hashed and cross-checked against
//    `BAKED_MANIFEST.files["checksums.json"]`. A mismatch ALWAYS errors,
//    even when `allow_unverified == true`. This anchors integrity in
//    the binary itself rather than the operator-supplied directory.
//  - `TrustMode::Custom` (resolver levels 1-3): the local
//    `manifest.json` (parsed from disk in the same directory) is the
//    trust root. `Translator::new` is responsible for emitting the
//    loud `tracing::warn!` that integrity is rooted in user-supplied
//    data; this function focuses on the actual verification.
// ---------------------------------------------------------------------------

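// For reference, `checksums.json` deserializes into the flat
// `HashMap<String, String>` used below: filename → lowercase sha256
// hex. An illustrative shape (made-up, truncated digests):
//
// {
//   "intent_classifier.onnx": "9f86d081884c7d65…",
//   "tokenizer.json": "60303ae22b998861…",
//   "config.json": "fd61a03af4f77d87…"
// }
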
/// Load checksums from `checksums.json` if present.
///
/// Returns `Ok(None)` when the file is absent (caller decides whether
/// that is fatal based on `allow_unverified`). Returns `Ok(Some(map))`
/// when present and parseable. Returns `Err` only on parse / I/O
/// failure — those are always fatal.
fn try_load_checksums(
    checksums_path: &Path,
) -> Result<Option<HashMap<String, String>>, ClassifierError> {
    if !checksums_path.exists() {
        return Ok(None);
    }
    let content = std::fs::read_to_string(checksums_path)
        .map_err(|e| ClassifierError::OnnxError(format!("Failed to read checksums.json: {e}")))?;
    let map = serde_json::from_str(&content)
        .map_err(|e| ClassifierError::OnnxError(format!("Failed to parse checksums.json: {e}")))?;
    Ok(Some(map))
}

/// Verify model directory integrity per the NL04 contract documented above.
///
/// See the module-level "NL04 Integrity Contract — AUTHORITATIVE" comment
/// block for the full tampering-vs-missingness rules. A short summary:
///
/// - Tampering on a present file ALWAYS errors.
/// - Missingness errors only when `allow_unverified == false`.
/// - In `TrustMode::Trusted`, `checksums.json`'s own bytes are
///   cross-checked against `BAKED_MANIFEST.files["checksums.json"]` —
///   a mismatch is ALWAYS fatal.
fn verify_integrity(
    model_dir: &Path,
    allow_unverified: bool,
    trust_mode: TrustMode,
) -> Result<(), ClassifierError> {
    verify_integrity_with_trusted_manifest(model_dir, allow_unverified, trust_mode, &BAKED_MANIFEST)
}

fn verify_integrity_with_trusted_manifest(
    model_dir: &Path,
    allow_unverified: bool,
    trust_mode: TrustMode,
    trusted_manifest: &Manifest,
) -> Result<(), ClassifierError> {
    let checksums_path = model_dir.join("checksums.json");

    match trust_mode {
        TrustMode::Trusted => {
            verify_trusted_checksums_anchor(&checksums_path, allow_unverified, trusted_manifest)?;
        }
        TrustMode::Custom => verify_custom_checksums_anchor(model_dir, &checksums_path)?,
    }

    // Per-file pass over `checksums.json`. Same tampering-vs-missingness
    // rules apply file-by-file.
    let Some(checksums) = try_load_checksums(&checksums_path)? else {
        if allow_unverified {
            tracing::warn!(
                "No checksums.json found in {} — allow_unverified=true; \
                 skipping integrity verification (development workflow)",
                model_dir.display()
            );
            return Ok(());
        }
        return Err(ClassifierError::ChecksumsMissing);
    };

    let mut verified_count = 0usize;
    for (filename, expected_hash) in &checksums {
        let file_path = model_dir.join(filename);
        if !file_path.exists() {
            // MISSINGNESS — strict by default, warn-and-skip with hatch.
            if allow_unverified {
                tracing::warn!(
                    "Checksummed file missing: {filename} — allow_unverified=true; \
                     continuing (other listed files will still be hashed)"
                );
                continue;
            }
            return Err(ClassifierError::ChecksummedFileMissing(filename.clone()));
        }

        let actual_hash = compute_file_hash(&file_path)?;
        if &actual_hash != expected_hash {
            // TAMPERING — ALWAYS fatal, regardless of allow_unverified.
            return Err(ClassifierError::ChecksumMismatch {
                file: filename.clone(),
                expected: expected_hash.clone(),
                actual: actual_hash,
            });
        }
        verified_count += 1;
        tracing::debug!("Verified checksum for {filename}");
    }
    tracing::info!(
        "Model integrity verified: {} of {} listed files checked",
        verified_count,
        checksums.len()
    );
    Ok(())
}

fn verify_trusted_checksums_anchor(
    checksums_path: &Path,
    allow_unverified: bool,
    trusted_manifest: &Manifest,
) -> Result<(), ClassifierError> {
    let Some(expected_checksums_hash) = trusted_manifest.files.get("checksums.json") else {
        return Ok(());
    };

    if checksums_path.exists() {
        verify_checksums_json_hash(
            checksums_path,
            expected_checksums_hash,
            "Trusted-mode anchor OK: checksums.json matches BAKED_MANIFEST",
        )
    } else if allow_unverified {
        tracing::warn!(
            "checksums.json missing under Trusted resolver level — \
             allow_unverified=true downgrades to warn; baked-in trust \
             anchor cannot be cross-checked"
        );
        Ok(())
    } else {
        Err(ClassifierError::ChecksumsMissing)
    }
}

fn verify_custom_checksums_anchor(
    model_dir: &Path,
    checksums_path: &Path,
) -> Result<(), ClassifierError> {
    let local_manifest_path = model_dir.join("manifest.json");
    if !local_manifest_path.exists() {
        return Err(ClassifierError::ManifestAnchorInvalid(format!(
            "manifest.json missing at {}",
            local_manifest_path.display()
        )));
    }

    let local_manifest = Manifest::parse_path(&local_manifest_path).map_err(|err| {
        ClassifierError::ManifestAnchorInvalid(format!(
            "failed to parse manifest.json at {}: {err}",
            local_manifest_path.display()
        ))
    })?;
    let expected_checksums_hash = local_manifest.files.get("checksums.json").ok_or_else(|| {
        ClassifierError::ManifestAnchorInvalid(format!(
            "manifest.files[\"checksums.json\"] missing in {}",
            local_manifest_path.display()
        ))
    })?;

    if checksums_path.exists() {
        verify_checksums_json_hash(
            checksums_path,
            expected_checksums_hash,
            "Custom-mode anchor OK: checksums.json matches local manifest.json",
        )
    } else {
        tracing::warn!(
            target: "sqry_nl::classifier",
            "Custom-mode integrity anchor skipped: checksums.json missing at {} \
             (operator-supplied dir without a complete manifest)",
            checksums_path.display()
        );
        Ok(())
    }
}

fn verify_checksums_json_hash(
    checksums_path: &Path,
    expected_checksums_hash: &str,
    success_message: &str,
) -> Result<(), ClassifierError> {
    let actual = compute_file_hash(checksums_path)?;
    if actual != expected_checksums_hash {
        // TAMPERING — always fatal, no opt-out.
        return Err(ClassifierError::ChecksumMismatch {
            file: "checksums.json".to_string(),
            expected: expected_checksums_hash.to_string(),
            actual,
        });
    }
    tracing::debug!("{success_message}");
    Ok(())
}

/// Parse model version from version.txt content.
fn parse_model_version(content: &str) -> String {
    for line in content.lines() {
        let line = line.trim();
        if line.starts_with("model_version=") {
            return line
                .strip_prefix("model_version=")
                .unwrap_or("unknown")
                .to_string();
        }
    }
    "unknown".to_string()
}

impl IntentClassifier {
    /// Load classifier from model directory.
    ///
    /// Expected directory structure:
    /// ```text
    /// model_dir/
    /// ├── intent_classifier.onnx
    /// ├── tokenizer.json
    /// ├── config.json
    /// ├── calibration.json or temperature.json (optional)
    /// ├── checksums.json
    /// └── version.txt
    /// ```
    ///
    /// # Arguments
    ///
    /// * `model_dir` — Resolved model directory (output of NL02
    ///   resolver chain).
    /// * `allow_unverified` — Operator escape hatch. When `false`
    ///   (NL04 default per FR-7), missingness is fatal. When `true`,
    ///   missingness downgrades to `tracing::warn!`. **Tampering on a
    ///   present file ALWAYS errors regardless of this flag** — see
    ///   the inline contract documented at [`verify_integrity`].
    /// * `trust_mode` — Output of [`TrustMode::from(ResolverLevel)`].
    ///   Trusted mode anchors integrity in the binary's baked-in
    ///   manifest; Custom mode trusts the user-supplied
    ///   `manifest.json` shipped alongside the model directory.
    ///
    /// # Errors
    ///
    /// Returns [`ClassifierError`] if:
    /// - Model files not found
    /// - Checksum verification fails (AC-11.8 / NL04 integrity contract)
    /// - ONNX Runtime initialization fails
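    ///
    /// # Example
    ///
    /// A minimal sketch (the model path is illustrative; `ignore`d
    /// because it needs real model files and libonnxruntime on disk):
    ///
    /// ```ignore
    /// let mut classifier =
    ///     IntentClassifier::load(Path::new("/path/to/model"), false, TrustMode::Trusted)?;
    /// let result = classifier.classify("some natural-language query")?;
    /// println!("{:?} ({:.2})", result.intent, result.confidence);
    /// ```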
    pub fn load(
        model_dir: &Path,
        allow_unverified: bool,
        trust_mode: TrustMode,
    ) -> Result<Self, ClassifierError> {
        Self::load_inner(model_dir, allow_unverified, trust_mode)
    }

    /// Run only the NL04 integrity contract for a model directory,
    /// without invoking ONNX Runtime.
    ///
    /// Same contract as [`Self::load`]'s integrity pass — exists so
    /// integration tests can exercise the contract on synthetic
    /// fixtures (stub ONNX bytes) without the dylib dependency.
    ///
    /// # Errors
    ///
    /// Returns [`ClassifierError::ChecksumMismatch`] /
    /// [`ClassifierError::ChecksumsMissing`] /
    /// [`ClassifierError::ChecksummedFileMissing`] per the contract.
    #[doc(hidden)]
    pub fn verify_integrity_for_tests(
        model_dir: &Path,
        allow_unverified: bool,
        trust_mode: TrustMode,
    ) -> Result<(), ClassifierError> {
        verify_integrity(model_dir, allow_unverified, trust_mode)
    }

    /// Run the NL04 integrity contract with a test-supplied trusted
    /// manifest instead of the binary's baked model manifest.
    ///
    /// This keeps active integration tests hermetic: they can exercise
    /// the Trusted-mode anchor and strict per-file pass against
    /// synthetic model fixtures without committing the large external
    /// ONNX model tree.
    ///
    /// # Errors
    ///
    /// Returns the same [`ClassifierError`] variants as
    /// [`Self::verify_integrity_for_tests`].
    #[doc(hidden)]
    pub fn verify_integrity_with_manifest_for_tests(
        model_dir: &Path,
        allow_unverified: bool,
        trust_mode: TrustMode,
        trusted_manifest: &Manifest,
    ) -> Result<(), ClassifierError> {
        verify_integrity_with_trusted_manifest(
            model_dir,
            allow_unverified,
            trust_mode,
            trusted_manifest,
        )
    }

    fn load_inner(
        model_dir: &Path,
        allow_unverified: bool,
        trust_mode: TrustMode,
    ) -> Result<Self, ClassifierError> {
        // NL08: deterministic test seam — when
        // `SQRY_NL_FORCE_ORT_MISSING` is truthy AND we are running a
        // debug build (cargo test / cargo run), short-circuit straight
        // to `OnnxRuntimeMissing`. This lets the CLI / MCP / LSP
        // integration tests drive the missing-runtime path without
        // needing an actual missing libonnxruntime on the host. The
        // helper is a no-op in release builds.
        if ort_missing_forced() {
            return Err(onnx_runtime_missing_error());
        }

        // Check model directory exists
        if !model_dir.exists() {
            return Err(ClassifierError::ModelNotFound(
                model_dir.display().to_string(),
            ));
        }

        // Verify integrity BEFORE any artifact load — this is the
        // first-fail gate per the NL04 integrity contract. Tampering
        // detection happens here, prior to ONNX session creation, so
        // synthetic test fixtures (stub ONNX bytes) can exercise the
        // contract without invoking the inference engine.
        verify_integrity(model_dir, allow_unverified, trust_mode)?;

        let model_path = model_dir.join("intent_classifier.onnx");
        let tokenizer_path = model_dir.join("tokenizer.json");

        if !model_path.exists() {
            return Err(ClassifierError::ModelNotFound(
                model_path.display().to_string(),
            ));
        }

        if !tokenizer_path.exists() {
            return Err(ClassifierError::ModelNotFound(
                tokenizer_path.display().to_string(),
            ));
        }

        // Load ONNX session.
        //
        // NL08: the `ort` crate panics in `setup_api()` (with the
        // `load-dynamic` feature) if `libonnxruntime` cannot be loaded,
        // so we wrap the whole builder chain in `catch_unwind` and
        // reinterpret either a panic or any error string that looks
        // like a dylib-load failure as
        // `ClassifierError::OnnxRuntimeMissing` so callers can surface
        // an actionable platform-specific install hint.
        let model_path_for_load = model_path.clone();
        let session_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            Session::builder()?
                .with_intra_threads(1)?
                .commit_from_file(&model_path_for_load)
        }));
        let session = match session_result {
            Ok(Ok(session)) => session,
            Ok(Err(e)) => {
                let msg = e.to_string();
                if looks_like_dylib_load_failure(&msg) {
                    return Err(onnx_runtime_missing_error());
                }
                return Err(ClassifierError::OnnxError(msg));
            }
            Err(panic_payload) => {
                let panic_msg = panic_payload
                    .downcast_ref::<&'static str>()
                    .map(|s| (*s).to_string())
                    .or_else(|| panic_payload.downcast_ref::<String>().cloned())
                    .unwrap_or_else(|| "ort panic with unknown payload".to_string());
                if looks_like_dylib_load_failure(&panic_msg) {
                    return Err(onnx_runtime_missing_error());
                }
                // Any other panic from ort is escalated as a generic
                // ONNX error rather than re-thrown — translator
                // construction must always return a typed error.
                return Err(ClassifierError::OnnxError(format!(
                    "ort panic during session init: {panic_msg}"
                )));
            }
        };

        // Detect whether model expects token_type_ids (BERT vs DistilBERT)
        let model_inputs = session.inputs();
        let has_token_type_ids = model_inputs
            .iter()
            .any(|input| input.name() == "token_type_ids");
        tracing::debug!(
            "Model inputs: {:?}, has_token_type_ids: {has_token_type_ids}",
            model_inputs
                .iter()
                .map(ort::value::Outlet::name)
                .collect::<Vec<_>>()
        );

        // Load tokenizer
        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| ClassifierError::TokenizationFailed(e.to_string()))?;

        // Load calibration (optional) — try calibration.json first, then temperature.json
        let calibration_path = model_dir.join("calibration.json");
        let temperature_path = model_dir.join("temperature.json");
        let calibration = if calibration_path.exists() {
            let content = std::fs::read_to_string(&calibration_path)
                .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
            serde_json::from_str(&content).unwrap_or_default()
        } else if temperature_path.exists() {
            let content = std::fs::read_to_string(&temperature_path)
                .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
            let params: CalibrationParams = serde_json::from_str(&content).unwrap_or_default();
            tracing::debug!(
                "Loaded calibration temperature={} from temperature.json",
                params.temperature
            );
            params
        } else {
            CalibrationParams::default()
        };

        // Load and parse version
        let version_path = model_dir.join("version.txt");
        let model_version = if version_path.exists() {
            std::fs::read_to_string(&version_path)
                .map_or_else(|_| "unknown".to_string(), |s| parse_model_version(&s))
        } else {
            "unknown".to_string()
        };

        Ok(Self {
            session,
            tokenizer,
            calibration,
            model_version,
            has_token_type_ids,
        })
    }

    /// Classify intent from natural language input.
    ///
    /// # Critical: `batch_size=1` enforcement (C1 mitigation)
    ///
    /// ONNX Runtime may crash with `batch_size` > 1. This method
    /// always processes exactly one input.
    ///
    /// # Errors
    ///
    /// Returns [`ClassifierError`] if tokenization or inference fails.
    ///
    /// # Note
    ///
    /// This method requires `&mut self` due to ort 2.0 API requirements.
    /// Use a Mutex wrapper if concurrent access is needed.
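    ///
    /// A minimal sharing sketch (`ignore`d; assumes `classifier` is an
    /// already-loaded `IntentClassifier`):
    ///
    /// ```ignore
    /// use std::sync::Mutex;
    ///
    /// let classifier = Mutex::new(classifier);
    /// // Each caller holds the lock for the duration of one inference.
    /// let result = classifier.lock().unwrap().classify("example query")?;
    /// ```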
    pub fn classify(&mut self, input: &str) -> Result<ClassificationResult, ClassifierError> {
        // Tokenize input
        let encoding = self
            .tokenizer
            .encode(input, true)
            .map_err(|e| ClassifierError::TokenizationFailed(e.to_string()))?;

        let input_ids = encoding.get_ids();
        let attention_mask = encoding.get_attention_mask();

        // Truncate to max 512 tokens
        let seq_len = input_ids.len().min(512);
        if input_ids.len() > 512 {
            tracing::warn!("Input truncated from {} to 512 tokens", input_ids.len());
        }

        // Prepare input tensors (batch_size=1)
        let input_ids_i64: Vec<i64> = input_ids[..seq_len].iter().map(|&x| i64::from(x)).collect();
        let attention_mask_i64: Vec<i64> = attention_mask[..seq_len]
            .iter()
            .map(|&x| i64::from(x))
            .collect();

        // Create input tensors with shape [1, seq_len] - ort 2.0 requires Vec not slice
        let input_ids_tensor = Tensor::from_array(([1, seq_len], input_ids_i64))
            .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
        let attention_mask_tensor = Tensor::from_array(([1, seq_len], attention_mask_i64))
            .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;

        // Build inputs conditionally: BERT-family models (MiniLM) require token_type_ids,
        // while DistilBERT does not declare it. ort rejects undeclared input names.
        let inputs = if self.has_token_type_ids {
            let type_ids = encoding.get_type_ids();
            let token_type_ids_i64: Vec<i64> =
                type_ids[..seq_len].iter().map(|&x| i64::from(x)).collect();
            let token_type_ids_tensor = Tensor::from_array(([1, seq_len], token_type_ids_i64))
                .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
            ort::inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
                "token_type_ids" => token_type_ids_tensor,
            ]
        } else {
            ort::inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
            ]
        };

        let outputs = self
            .session
            .run(inputs)
            .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;

        // Extract logits from output
        let logits_tensor = outputs
            .get("logits")
            .ok_or_else(|| ClassifierError::OnnxError("No 'logits' output".to_string()))?;

        // try_extract_tensor returns (&Shape, &[T]) tuple in ort 2.0
        let (_, logits_data) = logits_tensor
            .try_extract_tensor::<f32>()
            .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;

        let logits: Vec<f32> = logits_data.to_vec();

        // Apply calibration and softmax
        let probabilities = self.calibration.apply_temperature_scaling(&logits);

        // Find argmax
        let (intent_idx, confidence) = probabilities
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map_or((Intent::NUM_CLASSES - 1, 0.0), |(idx, &conf)| (idx, conf)); // Default to Ambiguous

        let intent = Intent::from_index(intent_idx);

        Ok(ClassificationResult {
            intent,
            confidence,
            all_probabilities: probabilities,
            model_version: self.model_version.clone(),
        })
    }

    /// Get the model version.
    #[must_use]
    pub fn model_version(&self) -> &str {
        &self.model_version
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_model_version() {
        let content = r"
# sqry-nl Intent Classifier Model
model_version=1.0.0
model_date=2025-12-09T07:34:00Z
accuracy=0.9998
";
        assert_eq!(parse_model_version(content), "1.0.0");
    }
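
    // A minimal sanity check for `compute_file_hash`: hashing an empty
    // file must yield the well-known sha256 digest of zero bytes. The
    // temp-file name is arbitrary/illustrative.
    #[test]
    fn test_compute_file_hash_empty_file() {
        let path = std::env::temp_dir().join("sqry_nl_hash_test_empty");
        std::fs::write(&path, b"").expect("write temp file");
        let hash = compute_file_hash(&path).expect("hash temp file");
        assert_eq!(
            hash,
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        let _ = std::fs::remove_file(&path);
    }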

    #[test]
    fn test_parse_model_version_missing() {
        let content = "# No version here\naccuracy=0.99";
        assert_eq!(parse_model_version(content), "unknown");
    }

    #[test]
    fn test_parse_model_version_empty() {
        assert_eq!(parse_model_version(""), "unknown");
    }
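
    // Sketch tests for the NL08 matcher, derived from the doc comment on
    // `looks_like_dylib_load_failure`: the three tokens match
    // case-insensitively, while dylib-ish operator paths must not. The
    // message strings are illustrative, not verbatim ort output.
    #[test]
    fn test_looks_like_dylib_load_failure_tokens() {
        assert!(looks_like_dylib_load_failure(
            "Failed to load libonnxruntime.so.1.16.0"
        ));
        assert!(looks_like_dylib_load_failure(
            "could not find ONNXRUNTIME.DLL"
        ));
        assert!(looks_like_dylib_load_failure(
            "symbol OrtGetApiBase not found"
        ));
    }

    #[test]
    fn test_looks_like_dylib_load_failure_rejects_paths() {
        // Broad tokens like "dylib" / "dlopen" / "failed to load" were
        // excluded precisely so messages like these stay OnnxError.
        assert!(!looks_like_dylib_load_failure(
            "failed to load /some/dylib-models/intent_classifier.onnx"
        ));
        assert!(!looks_like_dylib_load_failure("dlopen: unrelated error"));
    }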

    // Tests requiring actual model files are marked as ignored
    // and run during integration testing.

    #[test]
    #[ignore = "Requires trained model files"]
    fn test_classifier_load() {
        // Would test model loading
    }

    #[test]
    #[ignore = "Requires trained model files"]
    fn test_classifier_inference() {
        // Would test inference
    }

    #[test]
    #[ignore = "Requires trained model files"]
    fn test_checksum_verification() {
        // Would test checksum verification against deployed model
    }
}