Skip to main content

zeph_config/
classifiers.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use serde::{Deserialize, Serialize};
5
/// Serde default for `ClassifiersConfig::timeout_ms`: five seconds, in milliseconds.
fn default_classifier_timeout_ms() -> u64 {
    const DEFAULT_TIMEOUT_MS: u64 = 5_000;
    DEFAULT_TIMEOUT_MS
}
9
/// Serde default for `ClassifiersConfig::injection_model` (a `HuggingFace` repo ID).
fn default_injection_model() -> String {
    String::from("protectai/deberta-v3-small-prompt-injection-v2")
}
13
/// Serde default for the hard injection threshold (`ClassifiersConfig::injection_threshold`).
fn default_injection_threshold() -> f32 {
    0.95_f32
}
17
/// Serde default for the soft injection threshold
/// (`ClassifiersConfig::injection_threshold_soft`).
fn default_injection_threshold_soft() -> f32 {
    0.5_f32
}
21
22fn default_enforcement_mode() -> InjectionEnforcementMode {
23    InjectionEnforcementMode::Warn
24}
25
/// Serde default for `ClassifiersConfig::pii_model` (a `HuggingFace` repo ID).
fn default_pii_model() -> String {
    String::from("iiiorg/piiranha-v1-detect-personal-information")
}
29
/// Serde default for `ClassifiersConfig::pii_threshold`.
fn default_pii_threshold() -> f32 {
    0.75_f32
}
33
/// Serde default for `ClassifiersConfig::pii_ner_max_chars` (8 KiB).
fn default_pii_ner_max_chars() -> usize {
    8 * 1024
}
37
/// Serde default for `ClassifiersConfig::pii_ner_circuit_breaker`.
fn default_pii_ner_circuit_breaker() -> u32 {
    const DEFAULT_TRIP_COUNT: u32 = 2;
    DEFAULT_TRIP_COUNT
}
41
/// Serde default for `ClassifiersConfig::pii_ner_allowlist`.
///
/// Returns a freshly allocated list of the built-in allowlist entries, in a fixed order.
fn default_pii_ner_allowlist() -> Vec<String> {
    const BUILT_IN: [&str; 5] = ["Zeph", "Rust", "OpenAI", "Ollama", "Claude"];

    BUILT_IN.iter().copied().map(String::from).collect()
}
51
/// Serde default for `ClassifiersConfig::three_class_threshold`.
fn default_three_class_threshold() -> f32 {
    0.7_f32
}
55
56fn validate_unit_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
57where
58    D: serde::Deserializer<'de>,
59{
60    let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
61    if value.is_nan() || value.is_infinite() {
62        return Err(serde::de::Error::custom(
63            "threshold must be a finite number",
64        ));
65    }
66    if !(value > 0.0 && value <= 1.0) {
67        return Err(serde::de::Error::custom("threshold must be in (0.0, 1.0]"));
68    }
69    Ok(value)
70}
71
72/// Enforcement mode for the injection classifier.
73///
74/// `warn` (default): scores above `injection_threshold` emit WARN and increment metrics
75/// but do NOT block content. Use this when deploying `DeBERTa` classifiers on tool outputs —
76/// FPR of 12-37% on benign content makes hard-blocking unsafe.
77///
78/// `block`: scores above `injection_threshold` block content (behavior before v0.17).
79/// Only safe for well-calibrated models or when FPR is verified on your workload.
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
81#[serde(rename_all = "snake_case")]
82pub enum InjectionEnforcementMode {
83    /// Log + metric only, never block.
84    Warn,
85    /// Block content above hard threshold.
86    Block,
87}
88
89/// Configuration for the ML-backed classifier subsystem.
90///
91/// Placed under `[classifiers]` in `config.toml`. All fields are optional with safe defaults
92/// so existing configs continue to work when this section is absent.
93///
94/// When `enabled = false` (the default), all classifier code is bypassed and the existing
95/// regex-based detection runs unchanged.
96#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
97pub struct ClassifiersConfig {
98    /// Master switch. When `false`, classifiers are never loaded or invoked.
99    #[serde(default)]
100    pub enabled: bool,
101
102    /// Per-inference timeout in milliseconds.
103    ///
104    /// On timeout the call site falls back to regex. Separate from model download time.
105    #[serde(default = "default_classifier_timeout_ms")]
106    pub timeout_ms: u64,
107
108    /// Resolved `HuggingFace` Hub API token.
109    ///
110    /// Must be the **token value** (not a vault key name) — resolved by the caller before
111    /// constructing `ClassifiersConfig`. When `None`, model downloads are unauthenticated,
112    /// which fails for gated or private repos.
113    #[serde(default)]
114    pub hf_token: Option<String>,
115
116    /// When `true`, the ML injection classifier runs on direct user chat messages.
117    ///
118    /// Default `false`: the `DeBERTa` model is intended for external/untrusted content
119    /// (tool output, web scrapes) — not for direct user input. Enabling this may cause
120    /// false positives on benign conversational messages.
121    #[serde(default)]
122    pub scan_user_input: bool,
123
124    /// `HuggingFace` repo ID for the injection detection model.
125    #[serde(default = "default_injection_model")]
126    pub injection_model: String,
127
128    /// Enforcement mode for the injection classifier.
129    ///
130    /// `warn` (default): scores above `injection_threshold` emit WARN and increment metrics
131    /// but do NOT block content. Use this when deploying classifiers on tool outputs —
132    /// FPR of 12-37% on benign content makes hard-blocking unsafe.
133    ///
134    /// `block`: scores above `injection_threshold` block content. Only safe for well-calibrated
135    /// models or when FPR is verified on your workload.
136    #[serde(default = "default_enforcement_mode")]
137    pub enforcement_mode: InjectionEnforcementMode,
138
139    /// Soft threshold: classifier score at or above this emits a WARN log and increments
140    /// the suspicious-injection metric, but content is allowed through.
141    ///
142    /// Range: `(0.0, 1.0]`. Default `0.5`. Must be ≤ `injection_threshold`.
143    #[serde(
144        default = "default_injection_threshold_soft",
145        deserialize_with = "validate_unit_threshold"
146    )]
147    pub injection_threshold_soft: f32,
148
149    /// Hard threshold: classifier score at or above this blocks the content (in `block` mode)
150    /// or emits WARN (in `warn` mode).
151    ///
152    /// Range: `(0.0, 1.0]`. Conservative default of `0.95` minimises false positives.
153    /// Real-world ML injection classifiers have 12–37% recall gaps at high thresholds —
154    /// defense-in-depth via regex fallback and spotlighting is mandatory.
155    #[serde(
156        default = "default_injection_threshold",
157        deserialize_with = "validate_unit_threshold"
158    )]
159    pub injection_threshold: f32,
160
161    /// Optional SHA-256 hex digest of the injection model safetensors file.
162    ///
163    /// When set, the file is verified before loading. Mismatch aborts startup with an error.
164    /// Useful for security-sensitive deployments to detect corruption or tampering.
165    #[serde(default)]
166    pub injection_model_sha256: Option<String>,
167
168    /// Optional `HuggingFace` repo ID or local path for the three-class `AlignSentinel` model.
169    ///
170    /// When set, content flagged as Suspicious or Blocked by the binary `DeBERTa` classifier
171    /// is passed to this model for refinement. If the three-class model classifies the content
172    /// as `aligned-instruction` or `no-instruction`, the verdict is downgraded to `Clean`.
173    /// This directly reduces false positives from legitimate instruction-style content.
174    #[serde(default)]
175    pub three_class_model: Option<String>,
176
177    /// Confidence threshold for the three-class model's `misaligned-instruction` label.
178    ///
179    /// Content is only kept as Suspicious/Blocked when the misaligned score meets this threshold.
180    /// Range: `(0.0, 1.0]`. Default `0.7`.
181    #[serde(
182        default = "default_three_class_threshold",
183        deserialize_with = "validate_unit_threshold"
184    )]
185    pub three_class_threshold: f32,
186
187    /// Optional SHA-256 hex digest of the three-class model safetensors file.
188    #[serde(default)]
189    pub three_class_model_sha256: Option<String>,
190
191    /// Enable PII detection via the NER model (`pii_model`).
192    ///
193    /// When `true`, `CandlePiiClassifier` runs on user messages in addition to the
194    /// regex-based `PiiFilter`. Both results are merged (union with deduplication).
195    #[serde(default)]
196    pub pii_enabled: bool,
197
198    /// `HuggingFace` repo ID for the PII NER model.
199    #[serde(default = "default_pii_model")]
200    pub pii_model: String,
201
202    /// Minimum per-token confidence to accept a PII label.
203    ///
204    /// Tokens below this threshold are treated as O (no entity).
205    /// Default `0.75` balances recall on rarer entity types (DRIVERLICENSE, PASSPORT, IBAN)
206    /// with precision. Raise to `0.85` to prefer precision over recall.
207    #[serde(default = "default_pii_threshold")]
208    pub pii_threshold: f32,
209
210    /// Optional SHA-256 hex digest of the PII model safetensors file.
211    #[serde(default)]
212    pub pii_model_sha256: Option<String>,
213
214    /// Maximum number of bytes passed to the NER PII classifier per call.
215    ///
216    /// Input is truncated at a valid UTF-8 boundary before classification to prevent
217    /// timeout on large tool outputs (e.g. `search_code`). Default `8192`.
218    #[serde(default = "default_pii_ner_max_chars")]
219    pub pii_ner_max_chars: usize,
220
221    /// Allowlist of tokens that are never redacted by the NER PII classifier, regardless
222    /// of model confidence.
223    ///
224    /// Matching is case-insensitive and exact (whole span text must equal an allowlist entry).
225    /// This suppresses common false positives from the piiranha model — for example,
226    /// "Zeph" is misclassified as a city (PII:CITY) by the base model.
227    ///
228    /// Default entries: `["Zeph", "Rust", "OpenAI", "Ollama", "Claude"]`.
229    /// Set to `[]` to disable the allowlist entirely.
230    #[serde(default = "default_pii_ner_allowlist")]
231    pub pii_ner_allowlist: Vec<String>,
232
233    /// Number of consecutive NER timeouts before the circuit breaker trips and disables NER
234    /// for the remainder of the session.
235    ///
236    /// When the breaker trips, all subsequent chunks fall back to regex-only PII detection,
237    /// preventing repeated timeout stalls on paginated reads (e.g. 12 chunks × 30 s = 6 min).
238    /// Set to `0` to disable the circuit breaker (NER is always attempted).
239    ///
240    /// Default: `2`. Takes effect on the next session start if changed mid-session.
241    #[serde(default = "default_pii_ner_circuit_breaker")]
242    pub pii_ner_circuit_breaker: u32,
243}
244
245impl Default for ClassifiersConfig {
246    fn default() -> Self {
247        Self {
248            enabled: false,
249            timeout_ms: default_classifier_timeout_ms(),
250            hf_token: None,
251            scan_user_input: false,
252            injection_model: default_injection_model(),
253            enforcement_mode: default_enforcement_mode(),
254            injection_threshold_soft: default_injection_threshold_soft(),
255            injection_threshold: default_injection_threshold(),
256            injection_model_sha256: None,
257            three_class_model: None,
258            three_class_threshold: default_three_class_threshold(),
259            three_class_model_sha256: None,
260            pii_enabled: false,
261            pii_model: default_pii_model(),
262            pii_threshold: default_pii_threshold(),
263            pii_model_sha256: None,
264            pii_ner_max_chars: default_pii_ner_max_chars(),
265            pii_ner_allowlist: default_pii_ner_allowlist(),
266            pii_ner_circuit_breaker: default_pii_ner_circuit_breaker(),
267        }
268    }
269}
270
/// Unit tests for `ClassifiersConfig`: defaults, TOML (de)serialization, and threshold
/// validation.
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default::default()` must yield the documented default for every field.
    #[test]
    fn default_values() {
        let cfg = ClassifiersConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert!(cfg.hf_token.is_none());
        assert!(!cfg.scan_user_input);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.95).abs() < 1e-6);
        assert!(cfg.injection_model_sha256.is_none());
        assert!(cfg.three_class_model.is_none());
        assert!((cfg.three_class_threshold - 0.7).abs() < 1e-6);
        assert!(cfg.three_class_model_sha256.is_none());
        assert!(!cfg.pii_enabled);
        assert_eq!(
            cfg.pii_model,
            "iiiorg/piiranha-v1-detect-personal-information"
        );
        assert!((cfg.pii_threshold - 0.75).abs() < 1e-6);
        assert!(cfg.pii_model_sha256.is_none());
        assert_eq!(cfg.pii_ner_max_chars, 8192);
        assert_eq!(
            cfg.pii_ner_allowlist,
            vec!["Zeph", "Rust", "OpenAI", "Ollama", "Claude"]
        );
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }

    /// An empty TOML document and `Default::default()` must agree field for field, so the
    /// serde defaults and the hand-written `Default` impl cannot drift apart unnoticed.
    #[test]
    fn empty_section_matches_default_impl() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert_eq!(cfg, ClassifiersConfig::default());
    }

    #[test]
    fn hf_token_and_scan_user_input_round_trip() {
        let toml = r#"
            hf_token = "hf_secret"
            scan_user_input = true
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.hf_token.as_deref(), Some("hf_secret"));
        assert!(cfg.scan_user_input);
    }

    #[test]
    fn deserialize_empty_section_uses_defaults() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.95).abs() < 1e-6);
        assert!(!cfg.pii_enabled);
        assert!((cfg.pii_threshold - 0.75).abs() < 1e-6);
        assert_eq!(cfg.pii_ner_max_chars, 8192);
    }

    #[test]
    fn deserialize_custom_values() {
        let toml = r#"
            enabled = true
            timeout_ms = 2000
            injection_model = "custom/model-v1"
            injection_threshold = 0.9
            pii_enabled = true
            pii_threshold = 0.85
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(cfg.enabled);
        assert_eq!(cfg.timeout_ms, 2000);
        assert_eq!(cfg.injection_model, "custom/model-v1");
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.9).abs() < 1e-6);
        assert!(cfg.pii_enabled);
        assert!((cfg.pii_threshold - 0.85).abs() < 1e-6);
    }

    #[test]
    fn deserialize_sha256_fields() {
        let toml = r#"
            injection_model_sha256 = "abc123"
            pii_model_sha256 = "def456"
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.injection_model_sha256.as_deref(), Some("abc123"));
        assert_eq!(cfg.pii_model_sha256.as_deref(), Some("def456"));
    }

    /// Every field, set to a non-default value, must survive serialize → deserialize.
    #[test]
    fn serialize_roundtrip() {
        let original = ClassifiersConfig {
            enabled: true,
            timeout_ms: 3000,
            hf_token: Some("hf_test_token".into()),
            scan_user_input: true,
            injection_model: "org/model".into(),
            enforcement_mode: InjectionEnforcementMode::Block,
            injection_threshold_soft: 0.45,
            injection_threshold: 0.75,
            injection_model_sha256: Some("deadbeef".into()),
            three_class_model: Some("org/three-class".into()),
            three_class_threshold: 0.65,
            three_class_model_sha256: Some("abc456".into()),
            pii_enabled: true,
            pii_model: "org/pii-model".into(),
            pii_threshold: 0.80,
            pii_model_sha256: None,
            pii_ner_max_chars: 4096,
            pii_ner_allowlist: vec!["MyProject".into(), "Rust".into()],
            pii_ner_circuit_breaker: 3,
        };
        let serialized = toml::to_string(&original).unwrap();
        let deserialized: ClassifiersConfig = toml::from_str(&serialized).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn dual_threshold_deserialization() {
        let toml = r"
            injection_threshold_soft = 0.4
            injection_threshold = 0.85
        ";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!((cfg.injection_threshold_soft - 0.4).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.85).abs() < 1e-6);
    }

    #[test]
    fn soft_threshold_defaults_when_only_hard_provided() {
        let toml = "injection_threshold = 0.9";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.9).abs() < 1e-6);
    }

    #[test]
    fn partial_override_timeout_only() {
        let toml = "timeout_ms = 1000";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 1000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.95).abs() < 1e-6);
    }

    #[test]
    fn enforcement_mode_warn_is_default() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
    }

    #[test]
    fn enforcement_mode_block_roundtrip() {
        let toml = r#"enforcement_mode = "block""#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Block);
        let back = toml::to_string(&cfg).unwrap();
        let cfg2: ClassifiersConfig = toml::from_str(&back).unwrap();
        assert_eq!(cfg2.enforcement_mode, InjectionEnforcementMode::Block);
    }

    #[test]
    fn threshold_validation_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_rejects_above_one() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = 1.1");
        assert!(result.is_err());
    }

    /// Negative values fall outside `(0.0, 1.0]` and must be refused.
    #[test]
    fn threshold_validation_rejects_negative() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = -0.5");
        assert!(result.is_err());
    }

    /// NaN and the infinities must never end up in a validated threshold field.
    #[test]
    fn threshold_validation_rejects_non_finite() {
        for literal in ["nan", "inf", "-inf"] {
            let doc = format!("injection_threshold = {literal}");
            let result: Result<ClassifiersConfig, _> = toml::from_str(&doc);
            assert!(result.is_err(), "{literal} must be rejected");
        }
    }

    #[test]
    fn threshold_validation_accepts_exactly_one() {
        let cfg: ClassifiersConfig = toml::from_str("injection_threshold = 1.0").unwrap();
        assert!((cfg.injection_threshold - 1.0).abs() < 1e-6);
    }

    #[test]
    fn threshold_validation_soft_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold_soft = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn three_class_model_roundtrip() {
        let toml = r#"
            three_class_model = "org/align-sentinel"
            three_class_threshold = 0.65
            three_class_model_sha256 = "aabbcc"
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.three_class_model.as_deref(), Some("org/align-sentinel"));
        assert!((cfg.three_class_threshold - 0.65).abs() < 1e-6);
        assert_eq!(cfg.three_class_model_sha256.as_deref(), Some("aabbcc"));
    }

    #[test]
    fn pii_ner_allowlist_default_entries() {
        let cfg = ClassifiersConfig::default();
        assert!(cfg.pii_ner_allowlist.contains(&"Zeph".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Rust".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"OpenAI".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Ollama".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Claude".to_owned()));
    }

    #[test]
    fn pii_ner_allowlist_configurable() {
        let toml = r#"pii_ner_allowlist = ["MyProject", "AcmeCorp"]"#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.pii_ner_allowlist, vec!["MyProject", "AcmeCorp"]);
    }

    #[test]
    fn pii_ner_allowlist_empty_disables() {
        let toml = "pii_ner_allowlist = []";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(cfg.pii_ner_allowlist.is_empty());
    }

    #[test]
    fn three_class_threshold_validation_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("three_class_threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn pii_ner_max_chars_configurable() {
        let cfg: ClassifiersConfig = toml::from_str("pii_ner_max_chars = 1024").unwrap();
        assert_eq!(cfg.pii_ner_max_chars, 1024);
    }

    #[test]
    fn pii_ner_circuit_breaker_default() {
        let cfg = ClassifiersConfig::default();
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }

    #[test]
    fn pii_ner_circuit_breaker_configurable() {
        let cfg: ClassifiersConfig = toml::from_str("pii_ner_circuit_breaker = 5").unwrap();
        assert_eq!(cfg.pii_ner_circuit_breaker, 5);
    }

    #[test]
    fn pii_ner_circuit_breaker_zero_disables() {
        let cfg: ClassifiersConfig = toml::from_str("pii_ner_circuit_breaker = 0").unwrap();
        assert_eq!(cfg.pii_ner_circuit_breaker, 0);
    }

    #[test]
    fn pii_ner_circuit_breaker_missing_uses_default() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }
}