vigil_redaction/engine.rs
1//! ISS-008 Phase 1:Privacy Filter 推理引擎抽象。
2//!
3//! 设计目标(详见 `docs/adr/0013-hardfp-model-merge.md` + roadmap ISS-008):
4//! - 为 [`crate::scan::scan_text`] 引入**可注入的 Model 侧 finding 来源**,使
5//! "Hard 路径已闭环 + Model 路径预留扩展"成为同一函数的两条 path,而不是
6//! 分两个公共 API 导致 caller 双轨升级。
7//! - **默认 feature 0 ort 痕迹**:[`NoopEngine`] 提供"返空 model findings"等价语义,
8//! `scan_text` delegating 到它,行为与 v0.3 完全一致(`scan_text_v03_public_api_intact`
9//! 守门测试不动继续过)。
10//! - **`--features ort` 路径**:[`OrtEngine`] 用 ORT 1.24 q4f16 Privacy Filter 模型
11//! 做真推理,产出 BIOES 解码后的 Span findings。
12//!
13//! **不变量**:
14//! - `EngineError` **不持有** `ort::Error`(后者非 `Send + Sync`),全部子分类持 `String`。
15//! - `From<EngineError> for ScanError` 一律塌缩为 `ScanError::InferenceFailed { reason }`,
16//! `reason` 仅来自 `e.to_string()`,**绝不**拼接 input 内容(由 caller 保证)。
17//! - 所有引擎实现都 `Send + Sync`,在编译期由 [`static_assertions`] mod 守门。
18//! - 未识别的模型 label(BIOES core 不在 [`crate::label::PrivacyLabel::from_kind`] 白名单)
19//! 走 `eprintln!` warn 跳过,**不**panic / **不**fail-closed(Phase 1 决议)。
20//! - `OrtEngine` 内部 `Mutex<Session>`:rc.12 `Session::run` 需要 `&mut self`,
21//! `infer(&self, ..)` 用 Mutex 包外让 trait 保持 `&self`(锁开销 ns 级 vs 推理 358-630 ms)。
22
23use crate::merge::Finding;
24use crate::scan::ScanError;
25use thiserror::Error;
26
27/// 引擎私有诊断错误。
28///
29/// 6 变体覆盖从模型加载到张量解码的全链路失败模式;**绝不**持有 `ort::Error`
30/// 本身(rc.12 该类型非 `Send + Sync`,会污染 trait object 边界)。所有 ort/tokenizer
31/// 错误一律 `e.to_string()` 后塞进 `String` 字段。
32///
33/// caller 拿到的不是这个类型 —— `From<EngineError> for ScanError` 把 6 变体
34/// 全塌缩到 [`ScanError::InferenceFailed`],只暴露统一的 `reason`。
35#[derive(Error, Debug)]
36pub enum EngineError {
37 /// 指定 dir 下缺少模型文件(`tokenizer.json` / `config.json` / `model_q4f16.onnx`
38 /// 三件齐全才算就绪)或 `VIGIL_PRIVACY_FILTER_MODEL_DIR` 未设置。
39 #[error("model not found in directory: {dir}")]
40 ModelNotFound {
41 /// 失败时尝试加载的目录(诊断用,不含模型权重内容)
42 dir: String,
43 },
44 /// `tokenizers::Tokenizer::from_file` 失败。内部串来自 tokenizers crate。
45 #[error("tokenizer load failed: {0}")]
46 TokenizerLoad(String),
47 /// ORT `Session::builder` / `commit_from_file` / 优化等级设置等 init 阶段失败。
48 #[error("ORT session init failed: {0}")]
49 SessionInit(String),
50 /// 推理执行阶段失败(`tokenizer.encode` / `session.run`)。
51 #[error("inference run failed: {0}")]
52 InferRun(String),
53 /// 输出张量 shape 不符预期或 `try_extract_tensor::<f32>` 失败。
54 #[error("decode tensor shape failed: {0}")]
55 DecodeShape(String),
56 /// 其它内部错误(config.json 解析 / Mutex poisoned 等),用兜底变体避免新增 variant
57 /// 立刻冲击 caller。
58 #[error("internal engine error: {0}")]
59 Internal(String),
60}
61
62impl From<EngineError> for ScanError {
63 fn from(e: EngineError) -> Self {
64 // 6 变体全塌缩到 InferenceFailed;reason 只来自 e.to_string(),
65 // 绝不拼接 input 内容(避免 caller 把原文 secret 写进 audit log)。
66 ScanError::InferenceFailed {
67 reason: format!("{e}"),
68 }
69 }
70}
71
72/// Privacy Filter 推理引擎抽象。`scan_text_with_engine` 通过本 trait 拿 Model 侧 findings,
73/// 与 Hard 侧 [`crate::scan::collect_hard_findings`] 输出送 `merge_findings` 决策合并。
74///
75/// 实现要求:
76/// - 必须 `Send + Sync`(由 trait bound 强制;`scan_text_with_engine` 接 `&dyn`,
77/// 线程边界由 caller 决定)。
78/// - `infer` 失败应返 [`EngineError`] 各分类;`scan_text_with_engine` 经
79/// `From<EngineError> for ScanError` 自动转 [`ScanError::InferenceFailed`]。
80/// - 返回的 `Finding` `risk_delta` **可填 0**:caller 在 `scan_text_with_engine` 内会
81/// 按 `risk_of(kind)` 重新补值(engine 与 risk 表彻底解耦,SSOT 见 ADR 0012 §1.3)。
82/// `kind` 字段必须是 `&'static str`(由 [`crate::label::PrivacyLabel::as_str`] 提供)。
83pub trait RedactionEngine: Send + Sync {
84 /// 对 `text` 做模型推理,返回 Model 侧 findings。
85 ///
86 /// # Errors
87 /// 任意 [`EngineError`] 变体表示推理失败;caller 的 `scan_text_with_engine`
88 /// 会以 `?` 转 [`ScanError::InferenceFailed`] 早返(fail-closed)。
89 fn infer(&self, text: &str) -> Result<Vec<Finding>, EngineError>;
90
91 /// **v0.9 Sprint 1 P1.2** — 带 lang 上下文的推理(spike)。
92 ///
93 /// **default 实现**:忽略 `lang` 参数,委托 [`Self::infer`](向后兼容,SemVer
94 /// 安全;现有 RedactionEngine 实现不需改)。
95 ///
96 /// **OrtEngine override**:若 descriptor 提供
97 /// [`crate::model_descriptor::LangConditionalThresholdProfile`](通过新方法
98 /// `lang_conditional_profile()`),threshold 应用时优先查
99 /// `(lang, label)` override;无则 fallback default profile。
100 ///
101 /// **lang 规范**:case-sensitive,推荐 ISO 639-1 lowercase(`"en"` / `"de"` /
102 /// `"it"` / `"fr"` / ...),与 fixture lang 字段对齐。`None` 等价 `infer()`
103 /// 行为(无 lang 上下文)。
104 fn infer_with_lang(
105 &self,
106 text: &str,
107 _lang: Option<&str>,
108 ) -> Result<Vec<Finding>, EngineError> {
109 self.infer(text)
110 }
111}
112
113/// "什么也不做"的引擎:始终返回空 Model findings。
114///
115/// 用途:让 [`crate::scan::scan_text`] 公共 API 可 delegating 到
116/// `scan_text_with_engine(input, &NoopEngine)`,默认 feature 路径 0 ort 依赖,
117/// 而行为与 Stage 1 scaffold "Hard + 空 Model" 完全等价。
118#[derive(Debug, Default, Clone, Copy)]
119pub struct NoopEngine;
120
121impl RedactionEngine for NoopEngine {
122 fn infer(&self, _text: &str) -> Result<Vec<Finding>, EngineError> {
123 Ok(Vec::new())
124 }
125}
126
127/// 测试 / 集成用:固定返回构造时给的 findings 切片。
128///
129/// 用于在不接真模型的前提下走通 `scan_text_with_engine` → merge 链路,验证
130/// risk_delta 注入 / merge / aggregate 各环节(无 ort 依赖,默认 feature 即可用)。
131#[derive(Debug, Default, Clone)]
132pub struct MockEngine {
133 findings: Vec<Finding>,
134}
135
136impl MockEngine {
137 /// 用一组预设 findings 构造 mock(每次 `infer` 都克隆返出)。
138 pub fn from_findings(findings: Vec<Finding>) -> Self {
139 Self { findings }
140 }
141}
142
143impl RedactionEngine for MockEngine {
144 fn infer(&self, _text: &str) -> Result<Vec<Finding>, EngineError> {
145 Ok(self.findings.clone())
146 }
147}
148
149// 编译期 Send + Sync 守门(默认 feature)。新增引擎类型必须同步进这里。
150#[cfg(test)]
151mod static_assertions {
152 use super::*;
153 fn _assert_send_sync<T: Send + Sync>() {}
154 #[allow(dead_code)]
155 fn _check() {
156 _assert_send_sync::<MockEngine>();
157 _assert_send_sync::<NoopEngine>();
158 _assert_send_sync::<Box<dyn RedactionEngine>>();
159 }
160}
161
162// ──────────────────────────── ORT 真推理引擎(feature gated)────────────────────────────
163
164#[cfg(feature = "ort")]
165mod ort_engine {
166 use super::{EngineError, Finding, RedactionEngine};
167 use std::path::{Path, PathBuf};
168 use std::sync::Mutex;
169
170 use ort::execution_providers::CPUExecutionProvider;
171 use ort::inputs;
172 use ort::session::{builder::GraphOptimizationLevel, Session};
173 use ort::value::Value;
174 use tokenizers::Tokenizer;
175
176 /// ORT 1.24 q4f16 Privacy Filter 模型推理引擎。
177 ///
178 /// **生命周期纪律**(ISS-022 Phase 2 实测,详见 `project_vigil_v04_iss022_done.md`):
179 /// - cold-start ~7 s(commit_from_file + 809 MB weights);构造一次长期持有,
180 /// 不要把 [`OrtEngine::from_env`] 放进 hot path。
181 /// - warm 推理 358-630 ms / sample(CPU);Stage 2 API 必须 async 化(ADR 0013)。
182 ///
183 /// **线程模型**:
184 /// - rc.12 `Session::run` 需 `&mut self`(spike `main.rs:165` 实测)。我们让
185 /// `OrtEngine.session` 持 `Mutex<Session>`,trait 保持 `infer(&self, ..)`。
186 /// 锁开销纳秒级,与 358 ms 推理相比可忽略;并发场景由 caller 决定是否多实例。
187 pub struct OrtEngine {
188 session: Mutex<Session>,
189 tokenizer: Tokenizer,
190 /// 由 `config.json` 解析得;index = label_id,value = label 字面量(可能含 BIOES 前缀)。
191 id2label: Vec<String>,
192 /// 仅供 `Debug` / 诊断;不参与推理逻辑。
193 #[allow(dead_code)]
194 model_dir: PathBuf,
195 /// v0.7-α3 Phase 3 S2(E6a):descriptor 决定 decode kind + canonical mapping;
196 /// 默认 [`OpenAIPrivacyFilterDescriptor`] 保 v0.6 回归不变;新模型走
197 /// [`Self::from_env_with_descriptor`] 注入。
198 descriptor: Box<dyn crate::model_descriptor::ModelDescriptor>,
199 }
200
201 impl std::fmt::Debug for OrtEngine {
202 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203 // Session / Tokenizer 不实现 Debug 或 Debug 输出冗长,这里只露诊断字段。
204 f.debug_struct("OrtEngine")
205 .field("model_dir", &self.model_dir)
206 .field("id2label_count", &self.id2label.len())
207 .finish_non_exhaustive()
208 }
209 }
210
211 impl OrtEngine {
212 /// 从环境变量 `VIGIL_PRIVACY_FILTER_MODEL_DIR`(absolute path)读模型并构造 Session。
213 ///
214 /// 模型目录必须包含三件套:`tokenizer.json` / `config.json` / `model_q4f16.onnx`。
215 /// 三件齐全才算就绪;任一缺失返 [`EngineError::ModelNotFound`]。
216 ///
217 /// # Errors
218 /// - [`EngineError::ModelNotFound`]:env 未设置 / dir 不存在 / 三件套缺失
219 /// - [`EngineError::TokenizerLoad`] / [`EngineError::SessionInit`] / [`EngineError::Internal`]:
220 /// 底层 init 失败(具体 e.to_string() 进 `String` 字段)
221 pub fn from_env() -> Result<Self, EngineError> {
222 let dir = std::env::var("VIGIL_PRIVACY_FILTER_MODEL_DIR").map_err(|_| {
223 EngineError::ModelNotFound {
224 dir: "<env unset>".to_string(),
225 }
226 })?;
227 let model_dir = PathBuf::from(&dir);
228 let tok_path = model_dir.join("tokenizer.json");
229 let cfg_path = model_dir.join("config.json");
230 let onnx_path = model_dir.join("model_q4f16.onnx");
231 // 三件齐全检查(spike 同口径);任一缺失即视为模型未就绪
232 for p in [&tok_path, &cfg_path, &onnx_path] {
233 if !p.exists() {
234 return Err(EngineError::ModelNotFound { dir: dir.clone() });
235 }
236 }
237
238 let tokenizer = Tokenizer::from_file(&tok_path)
239 .map_err(|e| EngineError::TokenizerLoad(e.to_string()))?;
240 let id2label = parse_id2label(&cfg_path)?;
241
242 // ort::init().commit() 在 rc.12 返 bool(成功 / 重复 init 都返 ok),不返 Result。
243 // 多次 init 无副作用(spike main.rs:31)。
244 let _ = ort::init()
245 .with_name("vigil-redaction-ort")
246 .with_execution_providers([CPUExecutionProvider::default().build()])
247 .commit();
248
249 let session = Session::builder()
250 .map_err(|e| EngineError::SessionInit(e.to_string()))?
251 .with_optimization_level(GraphOptimizationLevel::Level1)
252 .map_err(|e| EngineError::SessionInit(e.to_string()))?
253 .with_intra_threads(4)
254 .map_err(|e| EngineError::SessionInit(e.to_string()))?
255 .commit_from_file(&onnx_path)
256 .map_err(|e| EngineError::SessionInit(e.to_string()))?;
257
258 Ok(Self {
259 session: Mutex::new(session),
260 tokenizer,
261 id2label,
262 model_dir,
263 // 默认 OpenAIPrivacyFilterDescriptor(BIOES 解码 + 33-class id2label)
264 // 保 v0.6 回归不变;新模型走 from_env_with_descriptor 工厂
265 descriptor: Box::new(crate::model_descriptor::OpenAIPrivacyFilterDescriptor),
266 })
267 }
268
269 /// v0.7-α3 Phase 3 S2(E6a):带 descriptor 注入的工厂,支持 BIO scheme 模型。
270 ///
271 /// 与 [`Self::from_env`] 区别:descriptor 决定 decode kind(BIOES vs BIO)+
272 /// canonical_mapping(影响 [`crate::PrivacyLabel::from_kind`] 路由)。
273 ///
274 /// **典型用例**:
275 /// - xlmr-pii(BIO):`from_env_with_descriptor(Box::new(XlmrPiiDescriptor))`
276 /// - yonigo-pii(BIO):`from_env_with_descriptor(Box::new(YonigoPiiDescriptor))`
277 /// - openai(默认 BIOES):`from_env()`(等价 [`OpenAIPrivacyFilterDescriptor`])
278 ///
279 /// **环境变量**与 [`Self::from_env`] 一致使用 `VIGIL_PRIVACY_FILTER_MODEL_DIR`。
280 /// 若 ensemble 多模型场景需独立路径,用 [`Self::from_dir_with_descriptor`]。
281 #[allow(dead_code)]
282 pub fn from_env_with_descriptor(
283 descriptor: Box<dyn crate::model_descriptor::ModelDescriptor>,
284 ) -> Result<Self, EngineError> {
285 let mut engine = Self::from_env()?;
286 engine.descriptor = descriptor;
287 Ok(engine)
288 }
289
290 /// v0.7-α3 Phase 3 S4(E6a):从指定目录构造 OrtEngine + descriptor 注入。
291 ///
292 /// 与 [`Self::from_env_with_descriptor`] 区别:**不**读 env,直接接 `dir`
293 /// 参数。专为 ensemble 多模型场景:每模型独立 dir,避免 env var 互斥。
294 ///
295 /// **典型用例**(ensemble runtime):
296 /// ```ignore
297 /// use std::sync::Arc;
298 /// use std::path::Path;
299 /// use vigil_redaction::OrtEngine;
300 /// use vigil_redaction::EnsembleEngine;
301 /// // (model_descriptor 是 crate-private,这里仅示意)
302 /// // let openai = Arc::new(OrtEngine::from_dir_with_descriptor(
303 /// // Path::new("/var/vigil/models/openai-pf/v1"),
304 /// // Box::new(OpenAIPrivacyFilterDescriptor),
305 /// // ).unwrap());
306 /// // let xlmr = Arc::new(OrtEngine::from_dir_with_descriptor(...).unwrap());
307 /// // let ens = EnsembleEngine::new(vec![openai, xlmr]);
308 /// ```
309 ///
310 /// **三件套契约**与 [`Self::from_env`] 同口径:`tokenizer.json` /
311 /// `config.json` / `model_q4f16.onnx`(后续可能扩 model.onnx)三件齐全。
312 ///
313 /// # Errors
314 /// - [`EngineError::ModelNotFound`]:dir 不存在或三件套缺失
315 /// - [`EngineError::TokenizerLoad`] / [`EngineError::SessionInit`] /
316 /// [`EngineError::Internal`]:底层 init 失败
317 #[allow(dead_code)]
318 pub fn from_dir_with_descriptor(
319 dir: &Path,
320 descriptor: Box<dyn crate::model_descriptor::ModelDescriptor>,
321 ) -> Result<Self, EngineError> {
322 let dir_str = dir.to_string_lossy().into_owned();
323 let model_dir = dir.to_path_buf();
324 let tok_path = model_dir.join("tokenizer.json");
325 let cfg_path = model_dir.join("config.json");
326 // v0.7-α4 R1b:用 descriptor.onnx_filename() 取代 hardcoded "model_q4f16.onnx",
327 // 适配多模型布局(OpenAI 顶层 / xlmr 在 onnx/ 子目录 / yonigo model.onnx)
328 let onnx_path = model_dir.join(descriptor.onnx_filename());
329 for p in [&tok_path, &cfg_path, &onnx_path] {
330 if !p.exists() {
331 return Err(EngineError::ModelNotFound {
332 dir: dir_str.clone(),
333 });
334 }
335 }
336 let tokenizer = Tokenizer::from_file(&tok_path)
337 .map_err(|e| EngineError::TokenizerLoad(e.to_string()))?;
338 let id2label = parse_id2label(&cfg_path)?;
339 let _ = ort::init()
340 .with_name("vigil-redaction-ort")
341 .with_execution_providers([CPUExecutionProvider::default().build()])
342 .commit();
343 let session = Session::builder()
344 .map_err(|e| EngineError::SessionInit(e.to_string()))?
345 .with_optimization_level(GraphOptimizationLevel::Level1)
346 .map_err(|e| EngineError::SessionInit(e.to_string()))?
347 .with_intra_threads(4)
348 .map_err(|e| EngineError::SessionInit(e.to_string()))?
349 .commit_from_file(&onnx_path)
350 .map_err(|e| EngineError::SessionInit(e.to_string()))?;
351 Ok(Self {
352 session: Mutex::new(session),
353 tokenizer,
354 id2label,
355 model_dir,
356 descriptor,
357 })
358 }
359
360 /// 返回当前 engine 装载的 descriptor model_id(诊断 / audit 关联)。
361 #[allow(dead_code)] // S3 EnsembleEngine 用此调度三引擎
362 pub fn descriptor_model_id(&self) -> &str {
363 self.descriptor.model_id()
364 }
365
366 /// v0.7-α2 Phase 2B(ADR 0016):预热 ORT session,把 cold inference 摊到启动期。
367 ///
368 /// **意图**:首次 [`infer`] 调用包含 graph optimization / kernel JIT / arena
369 /// 分配等 cold-path 开销(实测 ~7s on CPU q4f16);本 API 用 1-token 短文本
370 /// 触发同样路径,把 cold 开销前移到 app 启动 / 模型分发完成时,真正 user
371 /// 请求即落 warm 路径(实测 ~462ms/sample)。
372 ///
373 /// **不变量保留**:
374 /// - 仅消耗 1 次推理预算(短 prompt,~ms 级 token 数);不写日志、不影响 ledger
375 /// - 失败传 [`EngineError`] 但 caller 一般忽略(预热失败应不影响 cold-path 退化能力);
376 /// 推荐 caller `let _ = engine.warmup();` fire-and-forget
377 /// - 线程安全:与 `infer` 同走 `Mutex<Session>` 锁路径
378 ///
379 /// # 推荐用法
380 ///
381 /// apps/desktop GUI build 启动时异步 spawn:
382 /// ```ignore
383 /// let engine = Arc::new(OrtEngine::from_env()?);
384 /// std::thread::spawn({
385 /// let e = engine.clone();
386 /// move || { let _ = e.warmup(); }
387 /// });
388 /// ```
389 ///
390 /// # Errors
391 /// 同 [`infer`]:任何推理路径错误都会 propagate;caller 通常忽略。
392 pub fn warmup(&self) -> Result<(), EngineError> {
393 // 用单字符短文本(tokenizer 至少给 [CLS]+[SEP],seq_len ≥ 2);
394 // 推理结果丢弃,目的纯粹是触发 cold-path 一次性开销
395 let _ = <Self as RedactionEngine>::infer(self, "a")?;
396 Ok(())
397 }
398 }
399
400 impl RedactionEngine for OrtEngine {
401 fn infer(&self, text: &str) -> Result<Vec<Finding>, EngineError> {
402 // **v0.9 Sprint 1 P1.2**:legacy 路径 → infer_with_lang(text, None)
403 // (lang None 等价 v0.8 行为,threshold 走 threshold_profile() default;
404 // 不引入 LangConditionalThresholdProfile.overrides — caller 没 lang 上下文)
405 self.infer_with_lang(text, None)
406 }
407
408 fn infer_with_lang(
409 &self,
410 text: &str,
411 lang: Option<&str>,
412 ) -> Result<Vec<Finding>, EngineError> {
413 // ─── 1. tokenize ───
414 let enc = self
415 .tokenizer
416 .encode(text, true)
417 .map_err(|e| EngineError::InferRun(e.to_string()))?;
418 let ids: Vec<i64> = enc.get_ids().iter().map(|&i| i as i64).collect();
419 let mask: Vec<i64> = enc.get_attention_mask().iter().map(|&m| m as i64).collect();
420 let offsets = enc.get_offsets().to_vec();
421 let seq_len = ids.len();
422 if seq_len == 0 {
423 // 空 token 序列(理论上 tokenizer 至少给 [CLS]/[SEP],但保守兜底)
424 return Ok(Vec::new());
425 }
426
427 // ─── 2. 构 Value(spike main.rs:179-182 形态)───
428 let input_ids_val = Value::from_array((vec![1i64, seq_len as i64], ids))
429 .map_err(|e| EngineError::DecodeShape(e.to_string()))?;
430 let mask_val = Value::from_array((vec![1i64, seq_len as i64], mask))
431 .map_err(|e| EngineError::DecodeShape(e.to_string()))?;
432
433 // ─── 3 + 4. session.run + 提取 logits(都在锁内,因为 SessionOutputs<'_>
434 // 借 session;锁外只持 owned (shape, data) 副本以解耦借用)───
435 let (shape, data): (Vec<i64>, Vec<f32>) = {
436 let mut session = self
437 .session
438 .lock()
439 .map_err(|e| EngineError::Internal(format!("session mutex poisoned: {e}")))?;
440 let outputs = session
441 .run(inputs![
442 "input_ids" => input_ids_val,
443 "attention_mask" => mask_val,
444 ])
445 .map_err(|e| EngineError::InferRun(e.to_string()))?;
446
447 // 取 logits 张量(spike main.rs:192-198)。try_extract_tensor 借 outputs,
448 // outputs 又借 session;必须在锁释放前把数据 to_vec 出来。
449 let (_name, logits_val) = outputs
450 .iter()
451 .next()
452 .ok_or_else(|| EngineError::DecodeShape("no output tensor".to_string()))?;
453 let (raw_shape, raw_data) = logits_val
454 .try_extract_tensor::<f32>()
455 .map_err(|e| EngineError::DecodeShape(e.to_string()))?;
456 (raw_shape.to_vec(), raw_data.to_vec())
457 // session 锁在此 block 末释放,后续是纯 CPU 解码不持锁
458 };
459
460 if shape.len() != 3 || shape[0] != 1 || shape[1] as usize != seq_len {
461 return Err(EngineError::DecodeShape(format!(
462 "unexpected logits shape: {shape:?}"
463 )));
464 }
465 let num_labels = shape[2] as usize;
466
467 // ─── 5. argmax + max-shifted softmax(spike main.rs:201-211)───
468 // 注:`1.0 / sum_exp` 是 max-shifted softmax 的等价写法
469 // (exp(max - max) / Σexp(x - max) = 1 / Σ);保持与 spike 一致避免误改。
470 let mut token_preds: Vec<(usize, f32)> = Vec::with_capacity(seq_len);
471 for t in 0..seq_len {
472 let base = t * num_labels;
473 let slice = &data[base..base + num_labels];
474 let (arg, max_logit) = slice.iter().enumerate().fold(
475 (0usize, f32::NEG_INFINITY),
476 |(ai, av), (i, &v)| if v > av { (i, v) } else { (ai, av) },
477 );
478 let sum: f32 = slice.iter().map(|&v| (v - max_logit).exp()).sum();
479 let conf = if sum > 0.0 { 1.0 / sum } else { 0.0 };
480 token_preds.push((arg, conf));
481 }
482
483 // ─── 6. 合并 BIOES 同 core label 的连续 token 为 span ───
484 // (spike main.rs:214-247 同算法)
485 let mut findings: Vec<Finding> = Vec::new();
486 let mut i = 0usize;
487 while i < seq_len {
488 let (lid, conf) = token_preds[i];
489 let label_raw = &self.id2label[lid];
490 if label_raw == "O" || label_raw.is_empty() {
491 i += 1;
492 continue;
493 }
494 let core_raw = strip_bioes(label_raw);
495 let start = offsets[i].0;
496 let mut end = offsets[i].1;
497 let mut conf_min = conf;
498 let mut j = i + 1;
499 while j < seq_len {
500 let (nid, nconf) = token_preds[j];
501 let nlabel = &self.id2label[nid];
502 if nlabel == "O" || strip_bioes(nlabel) != core_raw {
503 break;
504 }
505 end = offsets[j].1;
506 conf_min = conf_min.min(nconf);
507 j += 1;
508 }
509
510 if start < end && end <= text.len() {
511 // ─── 7. canonical mapping via ModelDescriptor(v0.7-α3 S2)───
512 // descriptor.canonical_mapping(core_raw) 路由到 8 类 PrivacyLabel;
513 // OpenAI descriptor 内部 normalize lowercase,xlmr/yonigo 直
514 // match uppercase 字面量。SSOT 移到 descriptor,engine 不再
515 // hardcode PrivacyLabel::from_kind 路径,使新模型不需改 engine。
516 match self.descriptor.canonical_mapping(core_raw) {
517 Some(label) => {
518 // ─── v0.7-α4 R1h + v0.9 Sprint 1 P1.2 — threshold filter ───
519 // 优先级:
520 // 1. lang_conditional_profile().threshold_for(label, lang)
521 // (P1.2 新路径 — caller 提供 lang 时命中 (lang, label)
522 // override,否则该 profile 自身的 default)
523 // 2. fallback threshold_profile()(legacy / lang None 路径)
524 //
525 // > 1.0 阈值等价"屏蔽该 label",留给互补 engine + Hard
526 // rules 兜底。**关键**:不能用 `continue`(会跳过外层
527 // `i = j.max(i + 1)` 推进导致死循环);用 if-pass 包 push。
528 let min_conf_opt = self
529 .descriptor
530 .lang_conditional_profile()
531 .and_then(|p| p.threshold_for(label, lang))
532 .or_else(|| {
533 self.descriptor
534 .threshold_profile()
535 .and_then(|p| p.thresholds.get(&label).copied())
536 });
537 let pass_threshold = min_conf_opt
538 .map(|min_conf| conf_min >= min_conf)
539 .unwrap_or(true);
540 if pass_threshold {
541 // Finding.kind 是 &'static str,这里用 PrivacyLabel::as_str()
542 // 拿 'static 字面量(label.rs::as_str 已是 'static 契约)。
543 // risk_delta 由 caller `scan_text_with_engine` 按 risk_of(kind)
544 // 重新补值(C-7 决议:engine 不依赖 risk 表,避免漂移)。
545 findings.push(Finding::model(
546 label.as_str(),
547 (start, end),
548 conf_min,
549 0,
550 ));
551 }
552 // pass_threshold == false → silent drop(R1h FP filter)
553 }
554 None => {
555 // descriptor 显式 None = 该 native label 在 canonical 8 类外
556 // 应忽略(如 OpenAI/xlmr 的 AGE/GENDER/SEX);非显式漏(隐式
557 // 遗漏由 assert_canonical_mapping_total 测试守门捕获)。
558 // 不再 stderr warn 避免噪声(改 quiet 跳过)。
559 }
560 }
561 }
562 i = j.max(i + 1);
563 }
564 Ok(findings)
565 }
566 }
567
568 /// 剥 BIOES 前缀:`B-Person` / `I-Person` / `E-Person` / `S-Person` → `Person`。
569 /// 非 BIOES 前缀的 label 原样返回(例如 spike 模型可能直出 `private_email`)。
570 fn strip_bioes(label: &str) -> &str {
571 if let Some((prefix, rest)) = label.split_once('-') {
572 if matches!(prefix, "B" | "I" | "E" | "S") {
573 return rest;
574 }
575 }
576 label
577 }
578
579 /// 从 `config.json` 抽 `id2label` 表,按 id 升序还原 `Vec<String>`。
580 /// HF 标准 config 格式:`{"id2label": {"0": "O", "1": "B-Person", ...}}`。
581 fn parse_id2label(cfg_path: &Path) -> Result<Vec<String>, EngineError> {
582 let raw = std::fs::read_to_string(cfg_path)
583 .map_err(|e| EngineError::Internal(format!("read config.json: {e}")))?;
584 let cfg: serde_json::Value = serde_json::from_str(&raw)
585 .map_err(|e| EngineError::Internal(format!("parse config.json: {e}")))?;
586 let id2label = cfg
587 .get("id2label")
588 .and_then(|v| v.as_object())
589 .ok_or_else(|| EngineError::Internal("config.json missing id2label".to_string()))?;
590 let mut entries: Vec<(usize, String)> = id2label
591 .iter()
592 .map(|(k, v)| {
593 (
594 k.parse().unwrap_or(0),
595 v.as_str().unwrap_or("?").to_string(),
596 )
597 })
598 .collect();
599 entries.sort_by_key(|&(id, _)| id);
600 Ok(entries.into_iter().map(|(_, n)| n).collect())
601 }
602
603 // 编译期 Send + Sync 守门(--features ort 路径)
604 #[cfg(test)]
605 mod ort_static_assertions {
606 use super::*;
607 fn _assert_send_sync<T: Send + Sync>() {}
608 #[allow(dead_code)]
609 fn _check() {
610 _assert_send_sync::<OrtEngine>();
611 }
612 }
613}
614
615#[cfg(feature = "ort")]
616pub use ort_engine::OrtEngine;
617
618#[cfg(test)]
619mod tests {
620 use super::*;
621 use crate::merge::FindingSource;
622
623 #[test]
624 fn noop_engine_returns_empty_findings() {
625 let engine = NoopEngine;
626 let result = engine.infer("anything").expect("noop should not fail");
627 assert!(result.is_empty(), "NoopEngine 必须返空 Vec");
628 }
629
630 #[test]
631 fn mock_engine_returns_preset_findings() {
632 let preset = vec![
633 Finding::model("private_person", (0, 5), 0.9, 5),
634 Finding::model("private_email", (10, 30), 0.95, 10),
635 ];
636 let engine = MockEngine::from_findings(preset.clone());
637 let got = engine.infer("ignored").expect("mock should not fail");
638 assert_eq!(got, preset, "MockEngine 应原样返回构造时的 findings");
639 // 第二次调用不应被消耗
640 let got2 = engine.infer("ignored").expect("mock again");
641 assert_eq!(got2, preset);
642 }
643
644 #[test]
645 fn mock_engine_default_is_empty() {
646 let engine = MockEngine::default();
647 let got = engine.infer("anything").expect("default mock");
648 assert!(got.is_empty());
649 }
650
651 #[test]
652 fn engine_error_to_scan_error_collapses_to_inference_failed() {
653 let cases: Vec<(EngineError, &str)> = vec![
654 (
655 EngineError::ModelNotFound {
656 dir: "/tmp/x".to_string(),
657 },
658 "model not found",
659 ),
660 (
661 EngineError::TokenizerLoad("bad json".to_string()),
662 "tokenizer load",
663 ),
664 (
665 EngineError::SessionInit("ort init fail".to_string()),
666 "session init",
667 ),
668 (
669 EngineError::InferRun("session.run fail".to_string()),
670 "inference run",
671 ),
672 (
673 EngineError::DecodeShape("bad shape".to_string()),
674 "decode tensor",
675 ),
676 (
677 EngineError::Internal("config.json missing".to_string()),
678 "internal",
679 ),
680 ];
681 for (e, fragment) in cases {
682 let scan_err: ScanError = e.into();
683 // 6 EngineError 变体必须全塌缩到 InferenceFailed,先用 matches! 守门塌缩走向,
684 // 再单独取 reason 校验包含原 Display 片段(避免单 if-let 的 else 走 panic!,
685 // 兼顾 workspace clippy::panic + clippy::assertions_on_constants 双严格规则)。
686 assert!(
687 matches!(scan_err, ScanError::InferenceFailed { .. }),
688 "EngineError 应塌缩到 InferenceFailed,实际:{scan_err:?}"
689 );
690 if let ScanError::InferenceFailed { reason } = scan_err {
691 assert!(
692 reason.contains(fragment),
693 "InferenceFailed.reason 应含原 EngineError Display 片段 {fragment:?},\
694 实际 reason = {reason:?}"
695 );
696 }
697 }
698 }
699
700 #[test]
701 fn mock_engine_finding_source_is_model() {
702 let preset = vec![Finding::model("private_phone", (0, 11), 0.88, 5)];
703 let engine = MockEngine::from_findings(preset);
704 let got = engine.infer("ignored").expect("mock");
705 assert_eq!(got.len(), 1);
706 assert_eq!(got[0].source, FindingSource::Model);
707 }
708
709 // ─────────── v0.7-α3 S2 守门 ───────────
710
711 /// OrtEngine.from_env_with_descriptor 在 env 缺失时返 ModelNotFound 不 panic。
712 /// 与 from_env 同 fail-fast 口径(沿用 ADR 0012)。
713 #[cfg(feature = "ort")]
714 #[test]
715 fn ort_engine_from_env_with_descriptor_env_miss_returns_modelnotfound() {
716 if std::env::var("VIGIL_PRIVACY_FILTER_MODEL_DIR").is_ok() {
717 eprintln!("skip: env already set");
718 return;
719 }
720 // 注入 XlmrPiiDescriptor — 工厂应仍 fail-fast 在 env miss(因为没有真模型路径)
721 let r = OrtEngine::from_env_with_descriptor(Box::new(
722 crate::model_descriptor::XlmrPiiDescriptor::default(),
723 ));
724 assert!(
725 matches!(r, Err(EngineError::ModelNotFound { .. })),
726 "env unset 应返 ModelNotFound,实际: {:?}",
727 r.map(|_| "Ok(engine)")
728 );
729 }
730
731 /// 三 descriptor 类型可作为 Box<dyn ModelDescriptor> 注入(类型层兼容守门)。
732 /// 编译期检查;不需真 ort 模型。
733 #[test]
734 fn descriptors_dyn_box_compatible_with_engine_field() {
735 // 编译期类型测试:Box<dyn ModelDescriptor> 可装载 3 个实例,符合 OrtEngine.descriptor 字段类型
736 let _list: Vec<Box<dyn crate::model_descriptor::ModelDescriptor>> = vec![
737 Box::new(crate::model_descriptor::OpenAIPrivacyFilterDescriptor),
738 Box::new(crate::model_descriptor::XlmrPiiDescriptor::default()),
739 Box::new(crate::model_descriptor::YonigoPiiDescriptor),
740 ];
741 }
742
743 /// S4(E6a):from_dir_with_descriptor 在 dir 不存在时返 ModelNotFound 不 panic。
744 #[cfg(feature = "ort")]
745 #[test]
746 fn ort_engine_from_dir_with_descriptor_missing_dir_returns_modelnotfound() {
747 use std::path::Path;
748 let bogus_dir = Path::new("/nonexistent/vigil/spike-p3/model");
749 let r = OrtEngine::from_dir_with_descriptor(
750 bogus_dir,
751 Box::new(crate::model_descriptor::XlmrPiiDescriptor::default()),
752 );
753 assert!(
754 matches!(r, Err(EngineError::ModelNotFound { .. })),
755 "不存在 dir 应返 ModelNotFound,实际: {:?}",
756 r.map(|_| "Ok(engine)")
757 );
758 }
759}