use crate::merge::Finding;
use crate::scan::ScanError;
use thiserror::Error;
/// Errors raised while loading a redaction model or running inference with it.
///
/// Each variant carries a human-readable detail string that is rendered into
/// the `Display` message. All variants convert into
/// `ScanError::InferenceFailed` through the `From` impl in this file.
#[derive(Error, Debug)]
pub enum EngineError {
/// A required model file is missing, or the model directory could not be
/// determined. `dir` is the directory that was probed (the ORT-backed
/// engine uses the placeholder `"<env unset>"` when the environment
/// variable is unavailable).
#[error("model not found in directory: {dir}")]
ModelNotFound {
dir: String,
},
/// The tokenizer definition could not be loaded.
#[error("tokenizer load failed: {0}")]
TokenizerLoad(String),
/// The ONNX Runtime session could not be built or the model file could not
/// be committed.
#[error("ORT session init failed: {0}")]
SessionInit(String),
/// Tokenizing the input or executing the session failed.
#[error("inference run failed: {0}")]
InferRun(String),
/// A tensor could not be built or extracted, or the output had an
/// unexpected shape.
#[error("decode tensor shape failed: {0}")]
DecodeShape(String),
/// Any other failure, e.g. a poisoned session mutex or an unreadable
/// `config.json`.
#[error("internal engine error: {0}")]
Internal(String),
}
impl From<EngineError> for ScanError {
fn from(e: EngineError) -> Self {
ScanError::InferenceFailed {
reason: format!("{e}"),
}
}
}
/// A backend that inspects text and reports findings.
///
/// Implementors must be `Send + Sync` so an engine can be shared across
/// threads (the static assertions in this file check this for the provided
/// implementations).
pub trait RedactionEngine: Send + Sync {
/// Runs inference over `text` and returns the findings, or an
/// [`EngineError`] when the backend fails.
fn infer(&self, text: &str) -> Result<Vec<Finding>, EngineError>;
/// Language-aware variant of [`RedactionEngine::infer`].
///
/// The default implementation ignores `_lang` and forwards to `infer`;
/// engines that can use a language hint override this method.
fn infer_with_lang(
&self,
text: &str,
_lang: Option<&str>,
) -> Result<Vec<Finding>, EngineError> {
self.infer(text)
}
}
/// Engine that never reports anything; every call succeeds with no findings.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopEngine;

impl RedactionEngine for NoopEngine {
    /// Ignores the input and returns an empty finding list.
    fn infer(&self, _text: &str) -> Result<Vec<Finding>, EngineError> {
        let nothing: Vec<Finding> = Vec::new();
        Ok(nothing)
    }
}
/// Engine that replays a fixed list of findings regardless of the input text.
#[derive(Debug, Default, Clone)]
pub struct MockEngine {
    /// Findings handed back, cloned, on every `infer` call.
    findings: Vec<Finding>,
}

impl MockEngine {
    /// Creates a mock that will return `findings` from every inference call.
    pub fn from_findings(findings: Vec<Finding>) -> Self {
        MockEngine { findings }
    }
}

impl RedactionEngine for MockEngine {
    /// Returns a copy of the preset findings; the input text is not inspected.
    fn infer(&self, _text: &str) -> Result<Vec<Finding>, EngineError> {
        let replay = self.findings.to_vec();
        Ok(replay)
    }
}
#[cfg(test)]
mod static_assertions {
    use super::*;

    /// Compiles only when `T` is both `Send` and `Sync`.
    fn _assert_send_sync<T>()
    where
        T: Send + Sync,
    {
    }

    /// Never called; exists so the bounds below are checked at compile time.
    #[allow(dead_code)]
    fn _check() {
        _assert_send_sync::<MockEngine>();
        _assert_send_sync::<NoopEngine>();
        _assert_send_sync::<Box<dyn RedactionEngine>>();
    }
}
#[cfg(feature = "ort")]
mod ort_engine {
use super::{EngineError, Finding, RedactionEngine};
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use ort::execution_providers::CPUExecutionProvider;
use ort::inputs;
use ort::session::{builder::GraphOptimizationLevel, Session};
use ort::value::Value;
use tokenizers::Tokenizer;
/// ONNX Runtime backed [`RedactionEngine`].
///
/// Holds everything needed to tokenize text, run the model and turn its
/// per-token predictions into findings.
pub struct OrtEngine {
// Locked for the duration of each `run` call, which is made through a
// mutable guard.
session: Mutex<Session>,
// Tokenizer loaded from `tokenizer.json` in the model directory.
tokenizer: Tokenizer,
// Label names from `config.json`, indexed by predicted class id.
id2label: Vec<String>,
// Directory the model was loaded from; only read by the `Debug` impl.
#[allow(dead_code)]
model_dir: PathBuf,
// Supplies the label mapping and confidence thresholds used when decoding.
descriptor: Box<dyn crate::model_descriptor::ModelDescriptor>,
}
/// Compact debug view: prints the model directory and the number of labels,
/// and marks the remaining fields as omitted.
impl std::fmt::Debug for OrtEngine {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label_count = self.id2label.len();
        let mut builder = f.debug_struct("OrtEngine");
        builder.field("model_dir", &self.model_dir);
        builder.field("id2label_count", &label_count);
        builder.finish_non_exhaustive()
    }
}
impl OrtEngine {
pub fn from_env() -> Result<Self, EngineError> {
let dir = std::env::var("VIGIL_PRIVACY_FILTER_MODEL_DIR").map_err(|_| {
EngineError::ModelNotFound {
dir: "<env unset>".to_string(),
}
})?;
let model_dir = PathBuf::from(&dir);
let tok_path = model_dir.join("tokenizer.json");
let cfg_path = model_dir.join("config.json");
let onnx_path = model_dir.join("model_q4f16.onnx");
for p in [&tok_path, &cfg_path, &onnx_path] {
if !p.exists() {
return Err(EngineError::ModelNotFound { dir: dir.clone() });
}
}
let tokenizer = Tokenizer::from_file(&tok_path)
.map_err(|e| EngineError::TokenizerLoad(e.to_string()))?;
let id2label = parse_id2label(&cfg_path)?;
let _ = ort::init()
.with_name("vigil-redaction-ort")
.with_execution_providers([CPUExecutionProvider::default().build()])
.commit();
let session = Session::builder()
.map_err(|e| EngineError::SessionInit(e.to_string()))?
.with_optimization_level(GraphOptimizationLevel::Level1)
.map_err(|e| EngineError::SessionInit(e.to_string()))?
.with_intra_threads(4)
.map_err(|e| EngineError::SessionInit(e.to_string()))?
.commit_from_file(&onnx_path)
.map_err(|e| EngineError::SessionInit(e.to_string()))?;
Ok(Self {
session: Mutex::new(session),
tokenizer,
id2label,
model_dir,
descriptor: Box::new(crate::model_descriptor::OpenAIPrivacyFilterDescriptor),
})
}
#[allow(dead_code)]
pub fn from_env_with_descriptor(
descriptor: Box<dyn crate::model_descriptor::ModelDescriptor>,
) -> Result<Self, EngineError> {
let mut engine = Self::from_env()?;
engine.descriptor = descriptor;
Ok(engine)
}
#[allow(dead_code)]
pub fn from_dir_with_descriptor(
dir: &Path,
descriptor: Box<dyn crate::model_descriptor::ModelDescriptor>,
) -> Result<Self, EngineError> {
let dir_str = dir.to_string_lossy().into_owned();
let model_dir = dir.to_path_buf();
let tok_path = model_dir.join("tokenizer.json");
let cfg_path = model_dir.join("config.json");
let onnx_path = model_dir.join(descriptor.onnx_filename());
for p in [&tok_path, &cfg_path, &onnx_path] {
if !p.exists() {
return Err(EngineError::ModelNotFound {
dir: dir_str.clone(),
});
}
}
let tokenizer = Tokenizer::from_file(&tok_path)
.map_err(|e| EngineError::TokenizerLoad(e.to_string()))?;
let id2label = parse_id2label(&cfg_path)?;
let _ = ort::init()
.with_name("vigil-redaction-ort")
.with_execution_providers([CPUExecutionProvider::default().build()])
.commit();
let session = Session::builder()
.map_err(|e| EngineError::SessionInit(e.to_string()))?
.with_optimization_level(GraphOptimizationLevel::Level1)
.map_err(|e| EngineError::SessionInit(e.to_string()))?
.with_intra_threads(4)
.map_err(|e| EngineError::SessionInit(e.to_string()))?
.commit_from_file(&onnx_path)
.map_err(|e| EngineError::SessionInit(e.to_string()))?;
Ok(Self {
session: Mutex::new(session),
tokenizer,
id2label,
model_dir,
descriptor,
})
}
#[allow(dead_code)] pub fn descriptor_model_id(&self) -> &str {
self.descriptor.model_id()
}
pub fn warmup(&self) -> Result<(), EngineError> {
let _ = <Self as RedactionEngine>::infer(self, "a")?;
Ok(())
}
}
impl RedactionEngine for OrtEngine {
    /// Runs inference without a language hint.
    fn infer(&self, text: &str) -> Result<Vec<Finding>, EngineError> {
        self.infer_with_lang(text, None)
    }

    /// Tokenizes `text`, runs the ONNX session and merges per-token
    /// predictions into findings.
    ///
    /// Consecutive tokens whose labels share the same core name (after the
    /// BIOES prefix is stripped) are merged into one span. The span's
    /// confidence is the minimum token confidence. A span is emitted only if
    /// the descriptor maps its label and it passes the applicable threshold;
    /// the language-conditional profile is consulted first, then the general
    /// threshold profile, and with no threshold the span always passes.
    ///
    /// # Errors
    ///
    /// * [`EngineError::InferRun`] — tokenization or `Session::run` failed.
    /// * [`EngineError::DecodeShape`] — an input tensor could not be built,
    ///   the output is missing or has an unexpected shape, or the model
    ///   predicted a class id that `config.json` has no label for.
    /// * [`EngineError::Internal`] — the session mutex is poisoned.
    fn infer_with_lang(
        &self,
        text: &str,
        lang: Option<&str>,
    ) -> Result<Vec<Finding>, EngineError> {
        let enc = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| EngineError::InferRun(e.to_string()))?;
        let ids: Vec<i64> = enc.get_ids().iter().map(|&i| i as i64).collect();
        let mask: Vec<i64> = enc.get_attention_mask().iter().map(|&m| m as i64).collect();
        let offsets = enc.get_offsets();
        let seq_len = ids.len();
        if seq_len == 0 {
            return Ok(Vec::new());
        }

        let input_ids_val = Value::from_array((vec![1i64, seq_len as i64], ids))
            .map_err(|e| EngineError::DecodeShape(e.to_string()))?;
        let mask_val = Value::from_array((vec![1i64, seq_len as i64], mask))
            .map_err(|e| EngineError::DecodeShape(e.to_string()))?;

        // Copy the logits out inside this scope so the session lock is
        // released before decoding starts.
        let (shape, data): (Vec<i64>, Vec<f32>) = {
            let mut session = self
                .session
                .lock()
                .map_err(|e| EngineError::Internal(format!("session mutex poisoned: {e}")))?;
            let outputs = session
                .run(inputs![
                    "input_ids" => input_ids_val,
                    "attention_mask" => mask_val,
                ])
                .map_err(|e| EngineError::InferRun(e.to_string()))?;
            let (_name, logits_val) = outputs
                .iter()
                .next()
                .ok_or_else(|| EngineError::DecodeShape("no output tensor".to_string()))?;
            let (raw_shape, raw_data) = logits_val
                .try_extract_tensor::<f32>()
                .map_err(|e| EngineError::DecodeShape(e.to_string()))?;
            (raw_shape.to_vec(), raw_data.to_vec())
        };

        // Expect [1, seq_len, num_labels] with at least one label; a zero or
        // negative label dimension is rejected here instead of being cast.
        let shape_ok = shape.len() == 3
            && shape[0] == 1
            && shape[1] as usize == seq_len
            && shape[2] > 0;
        if !shape_ok {
            return Err(EngineError::DecodeShape(format!(
                "unexpected logits shape: {shape:?}"
            )));
        }
        let num_labels = shape[2] as usize;

        // Guard the slicing below: the flat buffer must cover every token row.
        let buffer_ok = seq_len
            .checked_mul(num_labels)
            .map_or(false, |needed| data.len() >= needed);
        if !buffer_ok {
            return Err(EngineError::DecodeShape(format!(
                "logits buffer has {} values, too few for shape {shape:?}",
                data.len()
            )));
        }

        let token_preds: Vec<(usize, f32)> = data
            .chunks_exact(num_labels)
            .take(seq_len)
            .map(argmax_with_confidence)
            .collect();

        let mut findings: Vec<Finding> = Vec::new();
        let mut i = 0usize;
        while i < seq_len {
            let (lid, conf) = token_preds[i];
            let label_raw = label_for_id(&self.id2label, lid)?;
            if label_raw == "O" || label_raw.is_empty() {
                i += 1;
                continue;
            }
            let core_raw = strip_bioes(label_raw);
            let start = offsets[i].0;
            let mut end = offsets[i].1;
            let mut conf_min = conf;

            // Extend the span over following tokens with the same core label.
            let mut j = i + 1;
            while j < seq_len {
                let (nid, nconf) = token_preds[j];
                let nlabel = label_for_id(&self.id2label, nid)?;
                if nlabel == "O" || strip_bioes(nlabel) != core_raw {
                    break;
                }
                end = offsets[j].1;
                conf_min = conf_min.min(nconf);
                j += 1;
            }

            // Empty spans and spans reaching past `text.len()` are dropped.
            if start < end && end <= text.len() {
                if let Some(label) = self.descriptor.canonical_mapping(core_raw) {
                    let min_conf_opt = self
                        .descriptor
                        .lang_conditional_profile()
                        .and_then(|p| p.threshold_for(label, lang))
                        .or_else(|| {
                            self.descriptor
                                .threshold_profile()
                                .and_then(|p| p.thresholds.get(&label).copied())
                        });
                    let pass_threshold =
                        min_conf_opt.map_or(true, |min_conf| conf_min >= min_conf);
                    if pass_threshold {
                        findings.push(Finding::model(
                            label.as_str(),
                            (start, end),
                            conf_min,
                            0,
                        ));
                    }
                }
            }

            // `j` starts at `i + 1` and only grows, so this always advances.
            i = j;
        }
        Ok(findings)
    }
}

/// Returns the index of the largest logit together with the softmax
/// probability of that index.
///
/// The probability is `1 / sum(exp(v - max))` because the winning term is
/// `exp(0) = 1`; a non-positive or NaN sum yields a confidence of `0.0`.
fn argmax_with_confidence(logits: &[f32]) -> (usize, f32) {
    let (arg, max_logit) = logits.iter().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(best_i, best_v), (i, &v)| if v > best_v { (i, v) } else { (best_i, best_v) },
    );
    let sum: f32 = logits.iter().map(|&v| (v - max_logit).exp()).sum();
    let conf = if sum > 0.0 { 1.0 / sum } else { 0.0 };
    (arg, conf)
}

/// Looks up the label for a predicted class id, returning
/// [`EngineError::DecodeShape`] instead of panicking when the id is unknown.
fn label_for_id(id2label: &[String], id: usize) -> Result<&str, EngineError> {
    id2label.get(id).map(String::as_str).ok_or_else(|| {
        EngineError::DecodeShape(format!(
            "predicted label id {id} out of range for {} configured labels",
            id2label.len()
        ))
    })
}
/// Removes a leading `B-`, `I-`, `E-` or `S-` tag from `label`.
///
/// Only the text before the first `-` is examined; labels without one of those
/// four single-letter prefixes are returned unchanged.
fn strip_bioes(label: &str) -> &str {
    match label.split_once('-') {
        Some(("B" | "I" | "E" | "S", core)) => core,
        _ => label,
    }
}
/// Reads the `id2label` object from a `config.json` file and returns the label
/// names as a vector indexed by class id.
///
/// The returned vector is indexed directly by the predicted class id, so the
/// ids must be exactly `0..n`. Malformed entries are reported as errors; a
/// silently substituted id or label would shift every later label onto the
/// wrong class.
///
/// # Errors
///
/// [`EngineError::Internal`] when the file cannot be read or parsed, when
/// `id2label` is missing or not an object, when a key is not a non-negative
/// integer, when a value is not a string, or when the ids are not contiguous
/// starting at zero.
fn parse_id2label(cfg_path: &Path) -> Result<Vec<String>, EngineError> {
    let raw = std::fs::read_to_string(cfg_path)
        .map_err(|e| EngineError::Internal(format!("read config.json: {e}")))?;
    let cfg: serde_json::Value = serde_json::from_str(&raw)
        .map_err(|e| EngineError::Internal(format!("parse config.json: {e}")))?;
    let id2label = cfg
        .get("id2label")
        .and_then(|v| v.as_object())
        .ok_or_else(|| EngineError::Internal("config.json missing id2label".to_string()))?;

    let mut entries = id2label
        .iter()
        .map(|(key, value)| -> Result<(usize, String), EngineError> {
            let id: usize = key.parse().map_err(|_| {
                EngineError::Internal(format!(
                    "config.json id2label key {key:?} is not a non-negative integer"
                ))
            })?;
            let name = value.as_str().ok_or_else(|| {
                EngineError::Internal(format!(
                    "config.json id2label[{key:?}] is not a string"
                ))
            })?;
            Ok((id, name.to_string()))
        })
        .collect::<Result<Vec<(usize, String)>, EngineError>>()?;
    entries.sort_unstable_by_key(|&(id, _)| id);

    // After sorting, position and id must coincide; this also catches
    // duplicates such as "1" and "01".
    for (expected, (id, _)) in entries.iter().enumerate() {
        if *id != expected {
            return Err(EngineError::Internal(format!(
                "config.json id2label ids must be contiguous from 0: expected {expected}, found {id}"
            )));
        }
    }

    Ok(entries.into_iter().map(|(_, name)| name).collect())
}
#[cfg(test)]
mod ort_static_assertions {
    use super::*;

    /// Compiles only when `T` is both `Send` and `Sync`.
    fn _assert_send_sync<T>()
    where
        T: Send + Sync,
    {
    }

    /// Never called; checks at compile time that the ORT engine is shareable.
    #[allow(dead_code)]
    fn _check() {
        _assert_send_sync::<OrtEngine>();
    }
}
}
#[cfg(feature = "ort")]
pub use ort_engine::OrtEngine;
#[cfg(test)]
mod tests {
use super::*;
use crate::merge::FindingSource;
/// The no-op engine must succeed and report nothing.
#[test]
fn noop_engine_returns_empty_findings() {
    let findings = NoopEngine.infer("anything").expect("noop should not fail");
    assert!(findings.is_empty(), "NoopEngine 必须返空 Vec");
}
/// The mock engine replays its preset findings unchanged on every call.
#[test]
fn mock_engine_returns_preset_findings() {
    let person = Finding::model("private_person", (0, 5), 0.9, 5);
    let email = Finding::model("private_email", (10, 30), 0.95, 10);
    let preset = vec![person, email];
    let engine = MockEngine::from_findings(preset.clone());

    let first = engine.infer("ignored").expect("mock should not fail");
    assert_eq!(first, preset, "MockEngine 应原样返回构造时的 findings");

    // A second call must give the same answer.
    let second = engine.infer("ignored").expect("mock again");
    assert_eq!(second, preset);
}
/// A default-constructed mock has no preset findings.
#[test]
fn mock_engine_default_is_empty() {
    let findings = MockEngine::default()
        .infer("anything")
        .expect("default mock");
    assert!(findings.is_empty());
}
/// Every `EngineError` variant converts to `ScanError::InferenceFailed`, and
/// the reason keeps a recognisable fragment of the original message.
#[test]
fn engine_error_to_scan_error_collapses_to_inference_failed() {
    let model_not_found = EngineError::ModelNotFound {
        dir: "/tmp/x".to_string(),
    };
    let cases: Vec<(EngineError, &str)> = vec![
        (model_not_found, "model not found"),
        (
            EngineError::TokenizerLoad("bad json".to_string()),
            "tokenizer load",
        ),
        (
            EngineError::SessionInit("ort init fail".to_string()),
            "session init",
        ),
        (
            EngineError::InferRun("session.run fail".to_string()),
            "inference run",
        ),
        (
            EngineError::DecodeShape("bad shape".to_string()),
            "decode tensor",
        ),
        (
            EngineError::Internal("config.json missing".to_string()),
            "internal",
        ),
    ];

    for (engine_err, fragment) in cases {
        let scan_err = ScanError::from(engine_err);
        assert!(
            matches!(scan_err, ScanError::InferenceFailed { .. }),
            "EngineError 应塌缩到 InferenceFailed,实际:{scan_err:?}"
        );
        if let ScanError::InferenceFailed { reason } = scan_err {
            let has_fragment = reason.contains(fragment);
            assert!(
                has_fragment,
                "InferenceFailed.reason 应含原 EngineError Display 片段 {fragment:?},\
                 实际 reason = {reason:?}"
            );
        }
    }
}
/// Findings built with `Finding::model` keep the `Model` source through the mock.
#[test]
fn mock_engine_finding_source_is_model() {
    let phone = Finding::model("private_phone", (0, 11), 0.88, 5);
    let engine = MockEngine::from_findings(vec![phone]);

    let findings = engine.infer("ignored").expect("mock");

    assert_eq!(findings.len(), 1);
    assert_eq!(findings[0].source, FindingSource::Model);
}
/// With the model-directory variable unset, construction must fail with
/// `ModelNotFound`. Skipped when the variable is already set.
#[cfg(feature = "ort")]
#[test]
fn ort_engine_from_env_with_descriptor_env_miss_returns_modelnotfound() {
    let env_present = std::env::var("VIGIL_PRIVACY_FILTER_MODEL_DIR").is_ok();
    if env_present {
        eprintln!("skip: env already set");
        return;
    }

    let descriptor = Box::new(crate::model_descriptor::XlmrPiiDescriptor::default());
    let r = OrtEngine::from_env_with_descriptor(descriptor);

    assert!(
        matches!(r, Err(EngineError::ModelNotFound { .. })),
        "env unset 应返 ModelNotFound,实际: {:?}",
        r.map(|_| "Ok(engine)")
    );
}
/// Compile-time check that every descriptor can be boxed as a
/// `dyn ModelDescriptor`, the type stored in the engine's descriptor field.
#[test]
fn descriptors_dyn_box_compatible_with_engine_field() {
    use crate::model_descriptor as md;

    let openai: Box<dyn md::ModelDescriptor> = Box::new(md::OpenAIPrivacyFilterDescriptor);
    let xlmr: Box<dyn md::ModelDescriptor> = Box::new(md::XlmrPiiDescriptor::default());
    let yonigo: Box<dyn md::ModelDescriptor> = Box::new(md::YonigoPiiDescriptor);
    let _list: Vec<Box<dyn md::ModelDescriptor>> = vec![openai, xlmr, yonigo];
}
/// Pointing the loader at a directory that does not exist must fail with
/// `ModelNotFound`.
#[cfg(feature = "ort")]
#[test]
fn ort_engine_from_dir_with_descriptor_missing_dir_returns_modelnotfound() {
    let bogus_dir = std::path::Path::new("/nonexistent/vigil/spike-p3/model");
    let descriptor = Box::new(crate::model_descriptor::XlmrPiiDescriptor::default());

    let r = OrtEngine::from_dir_with_descriptor(bogus_dir, descriptor);

    assert!(
        matches!(r, Err(EngineError::ModelNotFound { .. })),
        "不存在 dir 应返 ModelNotFound,实际: {:?}",
        r.map(|_| "Ok(engine)")
    );
}
}