use crate::error::ClassifierError;
use crate::types::{ClassificationResult, Intent};
use ort::session::Session;
use ort::value::Tensor;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::io::Read;
use std::path::Path;
use super::BAKED_MANIFEST;
use super::calibration::CalibrationParams;
use super::manifest::Manifest;
use super::resolve::TrustMode;
/// Returns a platform-specific, human-readable hint describing how to
/// install the ONNX Runtime shared library. Exactly one of the cfg'd
/// bindings below is compiled in for any given target.
#[must_use]
pub fn onnx_runtime_install_hint() -> String {
    #[cfg(target_os = "linux")]
    let hint = "Install via apt: 'sudo apt-get install libonnxruntime-dev' OR \
        download from https://github.com/microsoft/onnxruntime/releases";
    #[cfg(target_os = "macos")]
    let hint = "Install via brew: 'brew install onnxruntime'";
    #[cfg(target_os = "windows")]
    let hint = "Download libonnxruntime.dll from \
        https://github.com/microsoft/onnxruntime/releases and place in PATH";
    #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
    let hint = "Install libonnxruntime via your platform package manager OR \
        download from https://github.com/microsoft/onnxruntime/releases";
    hint.to_string()
}
/// Debug-only escape hatch: reports whether the `SQRY_NL_FORCE_ORT_MISSING`
/// environment variable is set to a truthy value ("1"/"true"/"yes"/"on",
/// case-insensitive, whitespace-trimmed). Lets tests simulate a missing
/// ONNX Runtime without uninstalling it.
#[cfg(debug_assertions)]
fn ort_missing_forced() -> bool {
    std::env::var("SQRY_NL_FORCE_ORT_MISSING")
        .map(|raw| {
            let flag = raw.trim();
            ["1", "true", "yes", "on"]
                .iter()
                .any(|truthy| flag.eq_ignore_ascii_case(truthy))
        })
        .unwrap_or(false)
}
/// Release-build counterpart of the debug-only override: the forced
/// "ONNX Runtime missing" path is compiled out of release binaries, so
/// this always reports `false` regardless of environment variables.
#[cfg(not(debug_assertions))]
fn ort_missing_forced() -> bool {
    false
}
/// Heuristically detects whether an error message stems from a failed
/// attempt to load the ONNX Runtime dynamic library (rather than some
/// other ort failure) by scanning, case-insensitively, for the library
/// name or its entry-point symbol.
fn looks_like_dylib_load_failure(msg: &str) -> bool {
    const MARKERS: [&str; 3] = ["libonnxruntime", "onnxruntime.dll", "ortgetapibase"];
    let needle = msg.to_ascii_lowercase();
    MARKERS.iter().any(|marker| needle.contains(marker))
}
/// Builds the dedicated "ONNX Runtime is not installed" error, attaching
/// the platform-appropriate installation hint for the user.
fn onnx_runtime_missing_error() -> ClassifierError {
    let hint = onnx_runtime_install_hint();
    ClassifierError::OnnxRuntimeMissing { hint }
}
/// ONNX-backed intent classifier: wraps an ort inference session, the
/// matching tokenizer, and calibration parameters applied to the model's
/// raw logits.
pub struct IntentClassifier {
    // Loaded ONNX inference session for the intent model.
    session: Session,
    // Tokenizer paired with the model (loaded from tokenizer.json).
    tokenizer: tokenizers::Tokenizer,
    // Temperature-scaling parameters applied to raw logits at inference.
    calibration: CalibrationParams,
    // Version string parsed from version.txt ("unknown" when absent).
    model_version: String,
    // Whether the ONNX graph declares a `token_type_ids` input.
    has_token_type_ids: bool,
}
/// Computes the SHA-256 digest of the file at `path`, returned as a
/// lowercase hex string.
///
/// Streams the file through the hasher so large model files are never
/// buffered whole in memory.
///
/// # Errors
/// Returns `ClassifierError::OnnxError` if the file cannot be opened or
/// read.
fn compute_file_hash(path: &Path) -> Result<String, ClassifierError> {
    let file = std::fs::File::open(path).map_err(|e| {
        ClassifierError::OnnxError(format!("Failed to open {}: {e}", path.display()))
    })?;
    let mut reader = std::io::BufReader::new(file);
    let mut hasher = Sha256::new();
    // sha2's hashers implement `std::io::Write`, so `io::copy` replaces
    // the hand-rolled fixed-buffer read loop (and buffers reads too).
    std::io::copy(&mut reader, &mut hasher).map_err(|e| {
        ClassifierError::OnnxError(format!("Failed to read {}: {e}", path.display()))
    })?;
    Ok(format!("{:x}", hasher.finalize()))
}
/// Reads and parses `checksums.json` (a filename → SHA-256 hex map).
///
/// Returns `Ok(None)` when the file does not exist, so callers decide
/// whether a missing checksum list is fatal.
///
/// # Errors
/// Returns `ClassifierError::OnnxError` when the file exists but cannot
/// be read or parsed.
fn try_load_checksums(
    checksums_path: &Path,
) -> Result<Option<HashMap<String, String>>, ClassifierError> {
    if !checksums_path.exists() {
        return Ok(None);
    }
    std::fs::read_to_string(checksums_path)
        .map_err(|e| ClassifierError::OnnxError(format!("Failed to read checksums.json: {e}")))
        .and_then(|content| {
            serde_json::from_str(&content).map_err(|e| {
                ClassifierError::OnnxError(format!("Failed to parse checksums.json: {e}"))
            })
        })
        .map(Some)
}
/// Verifies model-directory integrity against the compile-time
/// `BAKED_MANIFEST`. This thin wrapper exists so tests can inject a
/// different trusted manifest via
/// `verify_integrity_with_trusted_manifest`.
fn verify_integrity(
    model_dir: &Path,
    allow_unverified: bool,
    trust_mode: TrustMode,
) -> Result<(), ClassifierError> {
    verify_integrity_with_trusted_manifest(model_dir, allow_unverified, trust_mode, &BAKED_MANIFEST)
}
/// Verifies every file listed in `checksums.json` against its recorded
/// SHA-256 hash, after first anchoring `checksums.json` itself according
/// to `trust_mode` (so a tampered checksum list cannot vouch for tampered
/// model files):
/// - `Trusted`: checksums.json must hash-match the supplied trusted manifest.
/// - `Custom`: checksums.json must hash-match a local manifest.json.
///
/// With `allow_unverified = true`, a *missing* checksums.json or missing
/// listed file downgrades to a warning (development workflow); a hash
/// *mismatch* is always fatal.
///
/// # Errors
/// `ChecksumsMissing`, `ChecksummedFileMissing`, `ChecksumMismatch`, or
/// anchor/IO errors from the helpers.
fn verify_integrity_with_trusted_manifest(
    model_dir: &Path,
    allow_unverified: bool,
    trust_mode: TrustMode,
    trusted_manifest: &Manifest,
) -> Result<(), ClassifierError> {
    let checksums_path = model_dir.join("checksums.json");
    // Step 1: anchor checksums.json itself.
    match trust_mode {
        TrustMode::Trusted => {
            verify_trusted_checksums_anchor(&checksums_path, allow_unverified, trusted_manifest)?;
        }
        TrustMode::Custom => verify_custom_checksums_anchor(model_dir, &checksums_path)?,
    }
    // Step 2: hash every listed file and compare.
    let Some(checksums) = try_load_checksums(&checksums_path)? else {
        if allow_unverified {
            tracing::warn!(
                "No checksums.json found in {} — allow_unverified=true; \
                skipping integrity verification (development workflow)",
                model_dir.display()
            );
            return Ok(());
        }
        return Err(ClassifierError::ChecksumsMissing);
    };
    let mut verified_count = 0usize;
    for (filename, expected_hash) in &checksums {
        let file_path = model_dir.join(filename);
        if !file_path.exists() {
            if allow_unverified {
                // Fix: log the actual missing path instead of the literal
                // "(unknown)" placeholder the message previously contained.
                tracing::warn!(
                    "Checksummed file missing: {} — allow_unverified=true; \
                    continuing (other listed files will still be hashed)",
                    file_path.display()
                );
                continue;
            }
            return Err(ClassifierError::ChecksummedFileMissing(filename.clone()));
        }
        let actual_hash = compute_file_hash(&file_path)?;
        if &actual_hash != expected_hash {
            return Err(ClassifierError::ChecksumMismatch {
                file: filename.clone(),
                expected: expected_hash.clone(),
                actual: actual_hash,
            });
        }
        verified_count += 1;
        // Fix: include the verified filename (was a literal "(unknown)").
        tracing::debug!("Verified checksum for {filename}");
    }
    tracing::info!(
        "Model integrity verified: {} of {} listed files checked",
        verified_count,
        checksums.len()
    );
    Ok(())
}
/// Trusted-mode anchor check: when the baked-in manifest records a hash
/// for `checksums.json`, the on-disk file must match it. A missing
/// checksums.json is fatal unless `allow_unverified` downgrades it to a
/// warning; when the manifest carries no checksums.json entry there is
/// nothing to anchor against and the check is skipped.
fn verify_trusted_checksums_anchor(
    checksums_path: &Path,
    allow_unverified: bool,
    trusted_manifest: &Manifest,
) -> Result<(), ClassifierError> {
    let expected = match trusted_manifest.files.get("checksums.json") {
        Some(hash) => hash,
        // No baked-in entry → nothing to cross-check.
        None => return Ok(()),
    };
    if checksums_path.exists() {
        return verify_checksums_json_hash(
            checksums_path,
            expected,
            "Trusted-mode anchor OK: checksums.json matches BAKED_MANIFEST",
        );
    }
    if allow_unverified {
        tracing::warn!(
            "checksums.json missing under Trusted resolver level — \
            allow_unverified=true downgrades to warn; baked-in trust \
            anchor cannot be cross-checked"
        );
        return Ok(());
    }
    Err(ClassifierError::ChecksumsMissing)
}
/// Custom-mode anchor check: an operator-supplied model dir must carry a
/// parseable `manifest.json` whose entry for `checksums.json` matches the
/// on-disk checksums.json. When checksums.json itself is absent, the
/// anchor is skipped with a warning (incomplete operator-supplied dir).
fn verify_custom_checksums_anchor(
    model_dir: &Path,
    checksums_path: &Path,
) -> Result<(), ClassifierError> {
    let local_manifest_path = model_dir.join("manifest.json");
    if !local_manifest_path.exists() {
        return Err(ClassifierError::ManifestAnchorInvalid(format!(
            "manifest.json missing at {}",
            local_manifest_path.display()
        )));
    }
    let local_manifest = Manifest::parse_path(&local_manifest_path).map_err(|err| {
        ClassifierError::ManifestAnchorInvalid(format!(
            "failed to parse manifest.json at {}: {err}",
            local_manifest_path.display()
        ))
    })?;
    let expected_checksums_hash = local_manifest.files.get("checksums.json").ok_or_else(|| {
        ClassifierError::ManifestAnchorInvalid(format!(
            "manifest.files[\"checksums.json\"] missing in {}",
            local_manifest_path.display()
        ))
    })?;
    if !checksums_path.exists() {
        tracing::warn!(
            target: "sqry_nl::classifier",
            "Custom-mode integrity anchor skipped: checksums.json missing at {} \
            (operator-supplied dir without a complete manifest)",
            checksums_path.display()
        );
        return Ok(());
    }
    verify_checksums_json_hash(
        checksums_path,
        expected_checksums_hash,
        "Custom-mode anchor OK: checksums.json matches local manifest.json",
    )
}
/// Hashes the on-disk checksums.json and compares it against the anchor
/// hash taken from a manifest; logs `success_message` at debug level on a
/// match, otherwise returns a `ChecksumMismatch`.
fn verify_checksums_json_hash(
    checksums_path: &Path,
    expected_checksums_hash: &str,
    success_message: &str,
) -> Result<(), ClassifierError> {
    let actual = compute_file_hash(checksums_path)?;
    if actual == expected_checksums_hash {
        tracing::debug!("{success_message}");
        Ok(())
    } else {
        Err(ClassifierError::ChecksumMismatch {
            file: "checksums.json".to_string(),
            expected: expected_checksums_hash.to_string(),
            actual,
        })
    }
}
/// Extracts the value of the first `model_version=…` line (after
/// trimming surrounding whitespace) from version.txt content; returns
/// "unknown" when no such line exists.
fn parse_model_version(content: &str) -> String {
    content
        .lines()
        .filter_map(|line| line.trim().strip_prefix("model_version="))
        .next()
        .map_or_else(|| "unknown".to_string(), str::to_string)
}
impl IntentClassifier {
    /// Loads a classifier from `model_dir`.
    ///
    /// Integrity is verified before any model bytes are parsed:
    /// `trust_mode` selects how checksums.json itself is anchored, and
    /// `allow_unverified` downgrades *missing* checksum data (never a
    /// mismatch) to warnings for development workflows.
    ///
    /// # Errors
    /// Returns `ClassifierError` when the directory or required files are
    /// missing, integrity verification fails, the ONNX Runtime library is
    /// unavailable, or session/tokenizer construction fails.
    pub fn load(
        model_dir: &Path,
        allow_unverified: bool,
        trust_mode: TrustMode,
    ) -> Result<Self, ClassifierError> {
        Self::load_inner(model_dir, allow_unverified, trust_mode)
    }

    /// Test-only access to the private integrity-verification routine
    /// (checks against the baked-in manifest).
    #[doc(hidden)]
    pub fn verify_integrity_for_tests(
        model_dir: &Path,
        allow_unverified: bool,
        trust_mode: TrustMode,
    ) -> Result<(), ClassifierError> {
        verify_integrity(model_dir, allow_unverified, trust_mode)
    }

    /// Test-only access to integrity verification with a caller-supplied
    /// trusted manifest instead of the baked-in one.
    #[doc(hidden)]
    pub fn verify_integrity_with_manifest_for_tests(
        model_dir: &Path,
        allow_unverified: bool,
        trust_mode: TrustMode,
        trusted_manifest: &Manifest,
    ) -> Result<(), ClassifierError> {
        verify_integrity_with_trusted_manifest(
            model_dir,
            allow_unverified,
            trust_mode,
            trusted_manifest,
        )
    }

    // Load pipeline: forced-failure hook → existence checks → integrity
    // verification → ort session → tokenizer → calibration → version.
    fn load_inner(
        model_dir: &Path,
        allow_unverified: bool,
        trust_mode: TrustMode,
    ) -> Result<Self, ClassifierError> {
        // Debug-only hook so tests can simulate a missing ONNX Runtime.
        if ort_missing_forced() {
            return Err(onnx_runtime_missing_error());
        }
        if !model_dir.exists() {
            return Err(ClassifierError::ModelNotFound(
                model_dir.display().to_string(),
            ));
        }
        // Verify checksums BEFORE parsing anything from the directory.
        verify_integrity(model_dir, allow_unverified, trust_mode)?;
        let model_path = model_dir.join("intent_classifier.onnx");
        let tokenizer_path = model_dir.join("tokenizer.json");
        if !model_path.exists() {
            return Err(ClassifierError::ModelNotFound(
                model_path.display().to_string(),
            ));
        }
        if !tokenizer_path.exists() {
            return Err(ClassifierError::ModelNotFound(
                tokenizer_path.display().to_string(),
            ));
        }
        let model_path_for_load = model_path.clone();
        // Session construction is wrapped in catch_unwind because ort can
        // fail by panicking (not just by returning Err) — e.g. when the
        // runtime dylib cannot be loaded. Both failure shapes are mapped
        // to structured errors below.
        let session_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            Session::builder()?
                .with_intra_threads(1)?
                .commit_from_file(&model_path_for_load)
        }));
        let session = match session_result {
            Ok(Ok(session)) => session,
            Ok(Err(e)) => {
                let msg = e.to_string();
                // Distinguish "runtime not installed" from other ort
                // errors so users get an actionable install hint.
                if looks_like_dylib_load_failure(&msg) {
                    return Err(onnx_runtime_missing_error());
                }
                return Err(ClassifierError::OnnxError(msg));
            }
            Err(panic_payload) => {
                // Panic payloads are typically &'static str or String;
                // anything else gets a generic description.
                let panic_msg = panic_payload
                    .downcast_ref::<&'static str>()
                    .map(|s| (*s).to_string())
                    .or_else(|| panic_payload.downcast_ref::<String>().cloned())
                    .unwrap_or_else(|| "ort panic with unknown payload".to_string());
                if looks_like_dylib_load_failure(&panic_msg) {
                    return Err(onnx_runtime_missing_error());
                }
                return Err(ClassifierError::OnnxError(format!(
                    "ort panic during session init: {panic_msg}"
                )));
            }
        };
        // Record whether the graph declares a token_type_ids input so
        // classify() can feed exactly the inputs the model expects.
        let model_inputs = session.inputs();
        let has_token_type_ids = model_inputs
            .iter()
            .any(|input| input.name() == "token_type_ids");
        tracing::debug!(
            "Model inputs: {:?}, has_token_type_ids: {has_token_type_ids}",
            model_inputs
                .iter()
                .map(ort::value::Outlet::name)
                .collect::<Vec<_>>()
        );
        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| ClassifierError::TokenizationFailed(e.to_string()))?;
        // Calibration precedence: calibration.json, then temperature.json,
        // then defaults. Note parse failures fall back to defaults
        // silently (unwrap_or_default) — only read failures are errors.
        let calibration_path = model_dir.join("calibration.json");
        let temperature_path = model_dir.join("temperature.json");
        let calibration = if calibration_path.exists() {
            let content = std::fs::read_to_string(&calibration_path)
                .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
            serde_json::from_str(&content).unwrap_or_default()
        } else if temperature_path.exists() {
            let content = std::fs::read_to_string(&temperature_path)
                .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
            let params: CalibrationParams = serde_json::from_str(&content).unwrap_or_default();
            tracing::debug!(
                "Loaded calibration temperature={} from temperature.json",
                params.temperature
            );
            params
        } else {
            CalibrationParams::default()
        };
        // Version metadata is best-effort: any read problem yields
        // "unknown" rather than failing the load.
        let version_path = model_dir.join("version.txt");
        let model_version = if version_path.exists() {
            std::fs::read_to_string(&version_path)
                .map_or_else(|_| "unknown".to_string(), |s| parse_model_version(&s))
        } else {
            "unknown".to_string()
        };
        Ok(Self {
            session,
            tokenizer,
            calibration,
            model_version,
            has_token_type_ids,
        })
    }

    /// Classifies `input`, returning the top intent, its calibrated
    /// confidence, and the full probability vector.
    ///
    /// # Errors
    /// Returns `ClassifierError` when tokenization, tensor construction,
    /// or ONNX inference fails, or the model produces no `logits` output.
    pub fn classify(&mut self, input: &str) -> Result<ClassificationResult, ClassifierError> {
        let encoding = self
            .tokenizer
            .encode(input, true)
            .map_err(|e| ClassifierError::TokenizationFailed(e.to_string()))?;
        let input_ids = encoding.get_ids();
        let attention_mask = encoding.get_attention_mask();
        // Cap at 512 tokens — presumably the model's maximum sequence
        // length (TODO confirm against the model export config).
        let seq_len = input_ids.len().min(512);
        if input_ids.len() > 512 {
            tracing::warn!("Input truncated from {} to 512 tokens", input_ids.len());
        }
        // The ONNX graph takes i64 ids shaped [batch=1, seq_len].
        let input_ids_i64: Vec<i64> = input_ids[..seq_len].iter().map(|&x| i64::from(x)).collect();
        let attention_mask_i64: Vec<i64> = attention_mask[..seq_len]
            .iter()
            .map(|&x| i64::from(x))
            .collect();
        let input_ids_tensor = Tensor::from_array(([1, seq_len], input_ids_i64))
            .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
        let attention_mask_tensor = Tensor::from_array(([1, seq_len], attention_mask_i64))
            .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
        // Only feed token_type_ids when the graph declared that input
        // (detected at load time).
        let inputs = if self.has_token_type_ids {
            let type_ids = encoding.get_type_ids();
            let token_type_ids_i64: Vec<i64> =
                type_ids[..seq_len].iter().map(|&x| i64::from(x)).collect();
            let token_type_ids_tensor = Tensor::from_array(([1, seq_len], token_type_ids_i64))
                .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
            ort::inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
                "token_type_ids" => token_type_ids_tensor,
            ]
        } else {
            ort::inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
            ]
        };
        let outputs = self
            .session
            .run(inputs)
            .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
        let logits_tensor = outputs
            .get("logits")
            .ok_or_else(|| ClassifierError::OnnxError("No 'logits' output".to_string()))?;
        let (_, logits_data) = logits_tensor
            .try_extract_tensor::<f32>()
            .map_err(|e| ClassifierError::OnnxError(e.to_string()))?;
        let logits: Vec<f32> = logits_data.to_vec();
        // Temperature scaling converts raw logits into calibrated
        // probabilities.
        let probabilities = self.calibration.apply_temperature_scaling(&logits);
        // Argmax over probabilities; NaN comparisons collapse to Equal.
        // An empty probability vector falls back to the last class index
        // with zero confidence — NOTE(review): presumably an
        // "unknown"-style variant; confirm against Intent's ordering.
        let (intent_idx, confidence) = probabilities
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map_or((Intent::NUM_CLASSES - 1, 0.0), |(idx, &conf)| (idx, conf));
        let intent = Intent::from_index(intent_idx);
        Ok(ClassificationResult {
            intent,
            confidence,
            all_probabilities: probabilities,
            model_version: self.model_version.clone(),
        })
    }

    /// Returns the model version string parsed at load time ("unknown"
    /// when version.txt was absent or unreadable).
    #[must_use]
    pub fn model_version(&self) -> &str {
        &self.model_version
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Happy path: version extracted from a realistic version.txt payload.
    #[test]
    fn test_parse_model_version() {
        let content = "# sqry-nl Intent Classifier Model\n\
                       model_version=1.0.0\n\
                       model_date=2025-12-09T07:34:00Z\n\
                       accuracy=0.9998\n";
        assert_eq!(parse_model_version(content), "1.0.0");
    }

    // No model_version line present → sentinel "unknown".
    #[test]
    fn test_parse_model_version_missing() {
        assert_eq!(
            parse_model_version("# No version here\naccuracy=0.99"),
            "unknown"
        );
    }

    // Empty input is handled without panicking.
    #[test]
    fn test_parse_model_version_empty() {
        assert_eq!(parse_model_version(""), "unknown");
    }

    #[test]
    #[ignore = "Requires trained model files"]
    fn test_classifier_load() {}

    #[test]
    #[ignore = "Requires trained model files"]
    fn test_classifier_inference() {}

    #[test]
    #[ignore = "Requires trained model files"]
    fn test_checksum_verification() {}
}