use serde::{Deserialize, Serialize};
/// Serde default for `ClassifiersConfig::timeout_ms` (milliseconds, per the field name).
fn default_classifier_timeout_ms() -> u64 {
    const TIMEOUT_MS: u64 = 5_000;
    TIMEOUT_MS
}
/// Serde default for `ClassifiersConfig::injection_model`: the model identifier
/// used when the config does not name one.
fn default_injection_model() -> String {
    String::from("protectai/deberta-v3-small-prompt-injection-v2")
}
/// Serde default for `ClassifiersConfig::injection_threshold`.
fn default_injection_threshold() -> f32 {
    const THRESHOLD: f32 = 0.95;
    THRESHOLD
}
/// Serde default for `ClassifiersConfig::injection_threshold_soft`.
fn default_injection_threshold_soft() -> f32 {
    const SOFT_THRESHOLD: f32 = 0.5;
    SOFT_THRESHOLD
}
/// Serde default for `ClassifiersConfig::enforcement_mode`: the `Warn` variant.
fn default_enforcement_mode() -> InjectionEnforcementMode {
    const DEFAULT_MODE: InjectionEnforcementMode = InjectionEnforcementMode::Warn;
    DEFAULT_MODE
}
/// Serde default for `ClassifiersConfig::pii_model`: the model identifier used
/// when the config does not name one.
fn default_pii_model() -> String {
    String::from("iiiorg/piiranha-v1-detect-personal-information")
}
/// Serde default for `ClassifiersConfig::pii_threshold`.
fn default_pii_threshold() -> f32 {
    const THRESHOLD: f32 = 0.75;
    THRESHOLD
}
/// Serde default for `ClassifiersConfig::pii_ner_max_chars`.
fn default_pii_ner_max_chars() -> usize {
    const MAX_CHARS: usize = 8_192;
    MAX_CHARS
}
/// Serde default for `ClassifiersConfig::pii_ner_circuit_breaker`.
fn default_pii_ner_circuit_breaker() -> u32 {
    const BREAKER: u32 = 2;
    BREAKER
}
/// Serde default for `ClassifiersConfig::pii_ner_allowlist`.
///
/// Returns the built-in entries as owned strings, in declaration order.
fn default_pii_ner_allowlist() -> Vec<String> {
    const BUILTIN: [&str; 5] = ["Zeph", "Rust", "OpenAI", "Ollama", "Claude"];
    BUILTIN.iter().copied().map(String::from).collect()
}
/// Serde default for `ClassifiersConfig::three_class_threshold`.
fn default_three_class_threshold() -> f32 {
    const THRESHOLD: f32 = 0.7;
    THRESHOLD
}
fn validate_unit_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
if value.is_nan() || value.is_infinite() {
return Err(serde::de::Error::custom(
"threshold must be a finite number",
));
}
if !(value > 0.0 && value <= 1.0) {
return Err(serde::de::Error::custom("threshold must be in (0.0, 1.0]"));
}
Ok(value)
}
/// Enforcement mode selected by `ClassifiersConfig::enforcement_mode`.
///
/// Variants are (de)serialized in `snake_case`, i.e. as `"warn"` and `"block"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum InjectionEnforcementMode {
/// Config value `"warn"`; this is the default returned by `default_enforcement_mode`.
/// NOTE(review): presumably detections are only reported in this mode — confirm against the enforcement code.
Warn,
/// Config value `"block"`.
/// NOTE(review): presumably detections are rejected in this mode — confirm against the enforcement code.
Block,
}
/// Configuration for the classifier models (prompt-injection and PII detection).
///
/// Every field carries a serde default, so an empty config section
/// deserializes to the same values as `ClassifiersConfig::default()`.
///
/// All `*_threshold` fields are validated on deserialization by
/// `validate_unit_threshold` and must lie in `(0.0, 1.0]`. Values assigned
/// directly in Rust code bypass that validation.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ClassifiersConfig {
    /// Master switch; defaults to `false`.
    #[serde(default)]
    pub enabled: bool,
    /// Classifier timeout in milliseconds (per the field name); defaults to 5000.
    #[serde(default = "default_classifier_timeout_ms")]
    pub timeout_ms: u64,
    /// Optional access token; defaults to `None`.
    /// NOTE(review): presumably a Hugging Face Hub token — confirm against the model loader.
    /// NOTE(review): this value appears in `Debug` output and in serialized
    /// config because the struct derives `Debug` and `Serialize`.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Defaults to `false`.
    /// NOTE(review): presumably enables classification of user-supplied input — confirm with the caller.
    #[serde(default)]
    pub scan_user_input: bool,
    /// Identifier of the prompt-injection model;
    /// defaults to `"protectai/deberta-v3-small-prompt-injection-v2"`.
    #[serde(default = "default_injection_model")]
    pub injection_model: String,
    /// Enforcement mode; defaults to `InjectionEnforcementMode::Warn`.
    #[serde(default = "default_enforcement_mode")]
    pub enforcement_mode: InjectionEnforcementMode,
    /// Soft injection threshold in `(0.0, 1.0]`; defaults to 0.5.
    /// NOTE(review): no check here enforces `injection_threshold_soft <= injection_threshold`.
    #[serde(
        default = "default_injection_threshold_soft",
        deserialize_with = "validate_unit_threshold"
    )]
    pub injection_threshold_soft: f32,
    /// Hard injection threshold in `(0.0, 1.0]`; defaults to 0.95.
    #[serde(
        default = "default_injection_threshold",
        deserialize_with = "validate_unit_threshold"
    )]
    pub injection_threshold: f32,
    /// Optional SHA-256 string associated with `injection_model`; defaults to `None`.
    /// NOTE(review): presumably used to verify the downloaded model — confirm against the loader.
    #[serde(default)]
    pub injection_model_sha256: Option<String>,
    /// Optional identifier of a three-class model; defaults to `None`.
    #[serde(default)]
    pub three_class_model: Option<String>,
    /// Three-class threshold in `(0.0, 1.0]`; defaults to 0.7.
    #[serde(
        default = "default_three_class_threshold",
        deserialize_with = "validate_unit_threshold"
    )]
    pub three_class_threshold: f32,
    /// Optional SHA-256 string associated with `three_class_model`; defaults to `None`.
    #[serde(default)]
    pub three_class_model_sha256: Option<String>,
    /// Switch for PII detection; defaults to `false`.
    #[serde(default)]
    pub pii_enabled: bool,
    /// Identifier of the PII model;
    /// defaults to `"iiiorg/piiranha-v1-detect-personal-information"`.
    #[serde(default = "default_pii_model")]
    pub pii_model: String,
    /// PII threshold in `(0.0, 1.0]`; defaults to 0.75.
    ///
    /// Validated like the other thresholds so that NaN, infinite, zero,
    /// negative or >1 values are rejected at config-load time instead of
    /// being accepted silently.
    #[serde(
        default = "default_pii_threshold",
        deserialize_with = "validate_unit_threshold"
    )]
    pub pii_threshold: f32,
    /// Optional SHA-256 string associated with `pii_model`; defaults to `None`.
    #[serde(default)]
    pub pii_model_sha256: Option<String>,
    /// Character limit for PII NER input (per the field name); defaults to 8192.
    /// NOTE(review): whether longer input is truncated or skipped is not visible here.
    #[serde(default = "default_pii_ner_max_chars")]
    pub pii_ner_max_chars: usize,
    /// Allowlist for PII NER; defaults to `["Zeph", "Rust", "OpenAI", "Ollama", "Claude"]`.
    /// NOTE(review): presumably entries are never reported as PII, and an empty
    /// list disables the allowlist — confirm against the NER code.
    #[serde(default = "default_pii_ner_allowlist")]
    pub pii_ner_allowlist: Vec<String>,
    /// Circuit-breaker setting for PII NER; defaults to 2.
    /// NOTE(review): presumably a failure count, with 0 disabling the breaker — confirm against the NER code.
    #[serde(default = "default_pii_ner_circuit_breaker")]
    pub pii_ner_circuit_breaker: u32,
}
/// Builds the configuration from the same helper functions that serde uses as
/// per-field defaults, so `default()` and an empty config section agree.
impl Default for ClassifiersConfig {
    fn default() -> Self {
        // Values shared with the `#[serde(default = "...")]` attributes.
        let timeout_ms = default_classifier_timeout_ms();
        let injection_model = default_injection_model();
        let enforcement_mode = default_enforcement_mode();
        let injection_threshold_soft = default_injection_threshold_soft();
        let injection_threshold = default_injection_threshold();
        let three_class_threshold = default_three_class_threshold();
        let pii_model = default_pii_model();
        let pii_threshold = default_pii_threshold();
        let pii_ner_max_chars = default_pii_ner_max_chars();
        let pii_ner_allowlist = default_pii_ner_allowlist();
        let pii_ner_circuit_breaker = default_pii_ner_circuit_breaker();

        Self {
            // Switches are off and optional values are absent by default.
            enabled: false,
            scan_user_input: false,
            pii_enabled: false,
            hf_token: None,
            injection_model_sha256: None,
            three_class_model: None,
            three_class_model_sha256: None,
            pii_model_sha256: None,
            timeout_ms,
            injection_model,
            enforcement_mode,
            injection_threshold_soft,
            injection_threshold,
            three_class_threshold,
            pii_model,
            pii_threshold,
            pii_ner_max_chars,
            pii_ner_allowlist,
            pii_ner_circuit_breaker,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `ClassifiersConfig`: defaults, TOML (de)serialization,
    //! and threshold validation.

    use super::*;

    /// `Default::default()` yields the documented value for every field.
    #[test]
    fn default_values() {
        let cfg = ClassifiersConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert!(cfg.hf_token.is_none());
        assert!(!cfg.scan_user_input);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.95).abs() < 1e-6);
        assert!(cfg.injection_model_sha256.is_none());
        assert!(cfg.three_class_model.is_none());
        assert!((cfg.three_class_threshold - 0.7).abs() < 1e-6);
        assert!(cfg.three_class_model_sha256.is_none());
        assert!(!cfg.pii_enabled);
        assert_eq!(
            cfg.pii_model,
            "iiiorg/piiranha-v1-detect-personal-information"
        );
        assert!((cfg.pii_threshold - 0.75).abs() < 1e-6);
        assert!(cfg.pii_model_sha256.is_none());
        assert_eq!(cfg.pii_ner_max_chars, 8192);
        assert_eq!(
            cfg.pii_ner_allowlist,
            vec!["Zeph", "Rust", "OpenAI", "Ollama", "Claude"]
        );
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }

    #[test]
    fn hf_token_and_scan_user_input_round_trip() {
        let toml = r#"
hf_token = "hf_secret"
scan_user_input = true
"#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.hf_token.as_deref(), Some("hf_secret"));
        assert!(cfg.scan_user_input);
    }

    /// An empty section must fall back to the serde defaults for every field.
    #[test]
    fn deserialize_empty_section_uses_defaults() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.95).abs() < 1e-6);
        assert!(!cfg.pii_enabled);
        assert!((cfg.pii_threshold - 0.75).abs() < 1e-6);
    }

    #[test]
    fn deserialize_custom_values() {
        let toml = r#"
enabled = true
timeout_ms = 2000
injection_model = "custom/model-v1"
injection_threshold = 0.9
pii_enabled = true
pii_threshold = 0.85
"#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(cfg.enabled);
        assert_eq!(cfg.timeout_ms, 2000);
        assert_eq!(cfg.injection_model, "custom/model-v1");
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.9).abs() < 1e-6);
        assert!(cfg.pii_enabled);
        assert!((cfg.pii_threshold - 0.85).abs() < 1e-6);
    }

    #[test]
    fn deserialize_sha256_fields() {
        let toml = r#"
injection_model_sha256 = "abc123"
pii_model_sha256 = "def456"
"#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.injection_model_sha256.as_deref(), Some("abc123"));
        assert_eq!(cfg.pii_model_sha256.as_deref(), Some("def456"));
    }

    /// Serializing and re-parsing a fully populated config is lossless.
    #[test]
    fn serialize_roundtrip() {
        let original = ClassifiersConfig {
            enabled: true,
            timeout_ms: 3000,
            hf_token: Some("hf_test_token".into()),
            scan_user_input: true,
            injection_model: "org/model".into(),
            enforcement_mode: InjectionEnforcementMode::Block,
            injection_threshold_soft: 0.45,
            injection_threshold: 0.75,
            injection_model_sha256: Some("deadbeef".into()),
            three_class_model: Some("org/three-class".into()),
            three_class_threshold: 0.65,
            three_class_model_sha256: Some("abc456".into()),
            pii_enabled: true,
            pii_model: "org/pii-model".into(),
            pii_threshold: 0.80,
            pii_model_sha256: None,
            pii_ner_max_chars: 4096,
            pii_ner_allowlist: vec!["MyProject".into(), "Rust".into()],
            pii_ner_circuit_breaker: 3,
        };
        let serialized = toml::to_string(&original).unwrap();
        let deserialized: ClassifiersConfig = toml::from_str(&serialized).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn dual_threshold_deserialization() {
        let toml = r"
injection_threshold_soft = 0.4
injection_threshold = 0.85
";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!((cfg.injection_threshold_soft - 0.4).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.85).abs() < 1e-6);
    }

    #[test]
    fn soft_threshold_defaults_when_only_hard_provided() {
        let toml = "injection_threshold = 0.9";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.9).abs() < 1e-6);
    }

    #[test]
    fn partial_override_timeout_only() {
        let toml = "timeout_ms = 1000";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 1000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.95).abs() < 1e-6);
    }

    #[test]
    fn enforcement_mode_warn_is_default() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
    }

    #[test]
    fn enforcement_mode_block_roundtrip() {
        let toml = r#"enforcement_mode = "block""#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Block);
        let back = toml::to_string(&cfg).unwrap();
        let cfg2: ClassifiersConfig = toml::from_str(&back).unwrap();
        assert_eq!(cfg2.enforcement_mode, InjectionEnforcementMode::Block);
    }

    #[test]
    fn threshold_validation_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_rejects_above_one() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = 1.1");
        assert!(result.is_err());
    }

    /// The lower bound is exclusive, so negative values are rejected as well.
    #[test]
    fn threshold_validation_rejects_negative() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = -0.5");
        assert!(result.is_err());
    }

    /// TOML can express `nan` and `inf`; both must hit the finite-number check.
    #[test]
    fn threshold_validation_rejects_non_finite() {
        for raw in ["nan", "inf", "-inf"] {
            let toml = format!("injection_threshold = {raw}");
            let result: Result<ClassifiersConfig, _> = toml::from_str(&toml);
            assert!(result.is_err(), "expected `{raw}` to be rejected");
        }
    }

    /// The upper bound of the interval is inclusive.
    #[test]
    fn threshold_validation_accepts_exactly_one() {
        let cfg: ClassifiersConfig = toml::from_str("injection_threshold = 1.0").unwrap();
        assert!((cfg.injection_threshold - 1.0).abs() < 1e-6);
    }

    #[test]
    fn threshold_validation_soft_rejects_zero() {
        let result: Result<ClassifiersConfig, _> =
            toml::from_str("injection_threshold_soft = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn three_class_model_roundtrip() {
        let toml = r#"
three_class_model = "org/align-sentinel"
three_class_threshold = 0.65
three_class_model_sha256 = "aabbcc"
"#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.three_class_model.as_deref(), Some("org/align-sentinel"));
        assert!((cfg.three_class_threshold - 0.65).abs() < 1e-6);
        assert_eq!(cfg.three_class_model_sha256.as_deref(), Some("aabbcc"));
    }

    #[test]
    fn pii_ner_allowlist_default_entries() {
        let cfg = ClassifiersConfig::default();
        assert!(cfg.pii_ner_allowlist.contains(&"Zeph".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Rust".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"OpenAI".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Ollama".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Claude".to_owned()));
    }

    #[test]
    fn pii_ner_allowlist_configurable() {
        let toml = r#"pii_ner_allowlist = ["MyProject", "AcmeCorp"]"#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.pii_ner_allowlist, vec!["MyProject", "AcmeCorp"]);
    }

    /// An explicit empty list overrides the built-in entries instead of merging with them.
    #[test]
    fn pii_ner_allowlist_empty_disables() {
        let toml = "pii_ner_allowlist = []";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(cfg.pii_ner_allowlist.is_empty());
    }

    #[test]
    fn three_class_threshold_validation_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("three_class_threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn pii_ner_max_chars_configurable() {
        let cfg: ClassifiersConfig = toml::from_str("pii_ner_max_chars = 4096").unwrap();
        assert_eq!(cfg.pii_ner_max_chars, 4096);
    }

    #[test]
    fn pii_ner_circuit_breaker_default() {
        let cfg = ClassifiersConfig::default();
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }

    #[test]
    fn pii_ner_circuit_breaker_configurable() {
        let cfg: ClassifiersConfig = toml::from_str("pii_ner_circuit_breaker = 5").unwrap();
        assert_eq!(cfg.pii_ner_circuit_breaker, 5);
    }

    /// Zero is an accepted value for the circuit-breaker setting.
    #[test]
    fn pii_ner_circuit_breaker_zero_disables() {
        let cfg: ClassifiersConfig = toml::from_str("pii_ner_circuit_breaker = 0").unwrap();
        assert_eq!(cfg.pii_ner_circuit_breaker, 0);
    }

    #[test]
    fn pii_ner_circuit_breaker_missing_uses_default() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }
}