use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Policy schema version understood by this module.
///
/// `InjectionPolicy::parse` rejects any document whose `schema_version`
/// differs from this value, and `InjectionPolicy::default` stamps it on
/// freshly built policies.
pub const SCHEMA_VERSION: u32 = 1;
/// Failure-handling mode carried by `InjectionPolicy::fail_mode`.
///
/// Serialized in `snake_case`, i.e. as `"open"` or `"closed"`.
///
/// The code in this section only stores and compares the value; nothing here
/// acts on it.
// NOTE(review): presumably `Open` lets content through and `Closed` blocks it
// when scoring cannot complete — confirm against the code that consumes it.
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Debug, Default)]
#[serde(rename_all = "snake_case")]
pub enum FailMode {
/// The variant returned by `FailMode::default()`.
#[default]
Open,
/// The non-default alternative.
Closed,
}
/// Tunable parameters for injection scoring, (de)serializable with serde.
///
/// `deny_unknown_fields` turns any unrecognized key into a deserialization
/// error. No field declares a serde default, so every key is expected to be
/// present in a serialized policy.
#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)]
#[serde(deny_unknown_fields)]
pub struct InjectionPolicy {
/// Schema version of the document; `parse` requires it to equal `SCHEMA_VERSION`.
pub schema_version: u32,
/// Threshold validated by `parse` to be finite and within `[0, 1]`.
// NOTE(review): presumably the blended score at which content is rejected;
// the comparison itself is not in this section — confirm.
pub reject_threshold: f64,
/// Weight applied to the pattern score in `blend`; validated to be finite and non-negative.
pub pattern_weight: f64,
/// Weight applied to the model score in `blend`; validated to be finite and non-negative.
pub model_weight: f64,
/// Threshold validated by `parse` to be finite and within `[0, 1]`.
// NOTE(review): not read by any method in this section; presumably a cutoff
// on the raw model score — confirm against the consumer.
pub model_threshold: f64,
/// Attribution labels for which `model_gated_for` returns `true` (exact string match).
pub gate_attributions: Vec<String>,
/// Failure-handling mode; see `FailMode`.
pub fail_mode: FailMode,
}
impl Default for InjectionPolicy {
    /// Builds the built-in policy.
    ///
    /// The result carries the current [`SCHEMA_VERSION`], a reject threshold
    /// of `0.85`, equal pattern/model weights of `0.5`, a model threshold of
    /// `0.5`, the gate list `third_party` / `community` / `unknown`, and
    /// [`FailMode::Open`].
    fn default() -> Self {
        // Labels kept in this order so the serialized form stays stable.
        let gate_attributions = ["third_party", "community", "unknown"]
            .iter()
            .map(|&label| label.to_owned())
            .collect();

        let reject_threshold = 0.85;
        let pattern_weight = 0.5;
        let model_weight = 0.5;
        let model_threshold = 0.5;

        Self {
            schema_version: SCHEMA_VERSION,
            reject_threshold,
            pattern_weight,
            model_weight,
            model_threshold,
            gate_attributions,
            fail_mode: FailMode::Open,
        }
    }
}
impl InjectionPolicy {
    /// Deserializes a policy from TOML text and checks that it is usable.
    ///
    /// The checks run in a fixed order: TOML deserialization, then the
    /// `schema_version` comparison against [`SCHEMA_VERSION`], then range
    /// validation of the numeric fields.
    ///
    /// # Errors
    ///
    /// * [`InjectionPolicyError::Parse`] when `body` does not deserialize
    ///   into `Self`; the payload is the rendered `toml` error.
    /// * [`InjectionPolicyError::SchemaVersionMismatch`] when the document's
    ///   `schema_version` is not [`SCHEMA_VERSION`].
    /// * [`InjectionPolicyError::InvalidWeight`] when a numeric field is out
    ///   of range.
    pub fn parse(body: &str) -> Result<Self, InjectionPolicyError> {
        let decoded: Result<Self, _> = toml::from_str(body);
        let policy = match decoded {
            Ok(policy) => policy,
            Err(cause) => return Err(InjectionPolicyError::Parse(cause.to_string())),
        };

        let found = policy.schema_version;
        if found != SCHEMA_VERSION {
            return Err(InjectionPolicyError::SchemaVersionMismatch {
                found,
                expected: SCHEMA_VERSION,
            });
        }

        policy.validate().map(|()| policy)
    }

    /// Range-checks the four numeric fields.
    ///
    /// Every field must be finite and non-negative; the two thresholds must
    /// additionally be at most `1.0`. The finite/non-negative pass over all
    /// four fields runs to completion before the upper-bound pass, so a
    /// negative weight is reported in preference to an oversized threshold.
    ///
    /// # Errors
    ///
    /// Returns [`InjectionPolicyError::InvalidWeight`] naming the first
    /// offending field and its value.
    fn validate(&self) -> Result<(), InjectionPolicyError> {
        let reject_threshold = ("reject_threshold", self.reject_threshold);
        let pattern_weight = ("pattern_weight", self.pattern_weight);
        let model_weight = ("model_weight", self.model_weight);
        let model_threshold = ("model_threshold", self.model_threshold);

        let every_field = [reject_threshold, pattern_weight, model_weight, model_threshold];
        let thresholds = [reject_threshold, model_threshold];

        let offender = every_field
            .iter()
            .copied()
            .find(|&(_, value)| !(value.is_finite() && value >= 0.0))
            .or_else(|| thresholds.iter().copied().find(|&(_, value)| value > 1.0));

        match offender {
            Some((field, value)) => Err(InjectionPolicyError::InvalidWeight {
                field: field.to_owned(),
                value,
            }),
            None => Ok(()),
        }
    }

    /// Combines a pattern score with an optional model score.
    ///
    /// When `model` is `None`, or the two weights sum to zero or less, the
    /// pattern score alone is used. Otherwise the result is the weighted
    /// mean `(pattern_weight * pattern + model_weight * model) /
    /// (pattern_weight + model_weight)`. The outcome is clamped to
    /// `[0.0, 1.0]` in every case.
    ///
    /// Note: `f64::clamp` returns NaN for a NaN receiver, so a NaN score
    /// passes through unchanged.
    #[must_use]
    pub fn blend(&self, pattern: f64, model: Option<f64>) -> f64 {
        let total_weight = self.pattern_weight + self.model_weight;

        let unclamped = match model {
            None => pattern,
            Some(model_score) => {
                if total_weight <= 0.0 {
                    // No usable weighting: fall back to the pattern score.
                    pattern
                } else {
                    let weighted_sum = self
                        .pattern_weight
                        .mul_add(pattern, self.model_weight * model_score);
                    weighted_sum / total_weight
                }
            }
        };

        unclamped.clamp(0.0, 1.0)
    }

    /// Reports whether `attribution` appears in `gate_attributions`.
    ///
    /// The comparison is an exact, case-sensitive string match.
    #[must_use]
    pub fn model_gated_for(&self, attribution: &str) -> bool {
        self.gate_attributions
            .iter()
            .map(String::as_str)
            .any(|label| label == attribution)
    }
}
/// Errors returned by `InjectionPolicy::parse`.
#[derive(Debug, Error)]
pub enum InjectionPolicyError {
/// The input could not be deserialized as a policy; holds the underlying
/// error rendered to a string.
#[error("failed to parse injection policy: {0}")]
Parse(String),
/// The document deserialized, but its `schema_version` is not the one this
/// module accepts.
#[error("injection policy schema_version={found}; expected {expected}")]
SchemaVersionMismatch {
/// Version read from the document.
found: u32,
/// Version required by this module (`SCHEMA_VERSION`).
expected: u32,
},
/// A numeric field failed range validation. Despite the name, this variant
/// is also used for the two threshold fields.
#[error("injection policy field `{field}` has invalid value {value}")]
InvalidWeight {
/// Name of the offending field.
field: String,
/// The rejected value.
value: f64,
},
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for every floating-point comparison in this module.
    const TOLERANCE: f64 = 1e-12;

    /// Returns `true` when `actual` lies within [`TOLERANCE`] of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Serializes `policy` to TOML and returns the error `parse` reports for it.
    fn parse_failure(policy: &InjectionPolicy) -> InjectionPolicyError {
        let body = toml::to_string(policy).unwrap();
        InjectionPolicy::parse(&body).unwrap_err()
    }

    /// The built-in defaults carry the documented values.
    #[test]
    fn defaults_are_well_formed() {
        let policy = InjectionPolicy::default();
        assert_eq!(policy.schema_version, 1);
        assert!(close(policy.reject_threshold, 0.85));
        assert!(close(policy.pattern_weight, 0.5));
        assert!(close(policy.model_weight, 0.5));
        assert!(close(policy.model_threshold, 0.5));
        assert_eq!(policy.fail_mode, FailMode::Open);
        assert!(policy.model_gated_for("unknown"));
        assert!(!policy.model_gated_for("foundation"));
    }

    /// Serializing the defaults and parsing them back is lossless.
    #[test]
    fn round_trips_through_toml() {
        let original = InjectionPolicy::default();
        let body = toml::to_string(&original).unwrap();
        let reparsed = InjectionPolicy::parse(&body).unwrap();
        assert_eq!(reparsed, InjectionPolicy::default());
    }

    /// A foreign `schema_version` is reported with both versions attached.
    #[test]
    fn rejects_schema_mismatch() {
        let policy = InjectionPolicy {
            schema_version: 99,
            ..InjectionPolicy::default()
        };
        let err = parse_failure(&policy);
        assert!(matches!(
            err,
            InjectionPolicyError::SchemaVersionMismatch {
                found: 99,
                expected: 1
            }
        ));
    }

    /// Unknown top-level keys surface as a parse error.
    #[test]
    fn rejects_unknown_key() {
        let mut body = toml::to_string(&InjectionPolicy::default()).unwrap();
        body.push_str("\nbogus_top_level_key = 42\n");
        let err = InjectionPolicy::parse(&body).unwrap_err();
        assert!(matches!(err, InjectionPolicyError::Parse(_)));
    }

    /// A negative weight is rejected and the error names the field.
    #[test]
    fn rejects_negative_weight() {
        let policy = InjectionPolicy {
            pattern_weight: -1.0,
            ..InjectionPolicy::default()
        };
        let err = parse_failure(&policy);
        assert!(matches!(
            err,
            InjectionPolicyError::InvalidWeight { ref field, .. } if field == "pattern_weight"
        ));
    }

    /// A threshold above one is rejected and the error names the field.
    #[test]
    fn rejects_threshold_above_one() {
        let policy = InjectionPolicy {
            reject_threshold: 1.5,
            ..InjectionPolicy::default()
        };
        let err = parse_failure(&policy);
        assert!(matches!(
            err,
            InjectionPolicyError::InvalidWeight { ref field, .. } if field == "reject_threshold"
        ));
    }

    /// Without a model score the pattern score is returned, clamped to `[0, 1]`.
    #[test]
    fn blend_without_model_returns_pattern() {
        let policy = InjectionPolicy::default();
        assert!(close(policy.blend(0.7, None), 0.7));
        assert!(close(policy.blend(1.5, None), 1.0));
        assert!(close(policy.blend(-0.2, None), 0.0));
    }

    /// With a model score the result is the weight-normalized mean.
    #[test]
    fn blend_with_model_is_weighted_mean() {
        let even = InjectionPolicy::default();
        assert!(close(even.blend(0.8, Some(0.4)), 0.6));

        let skewed = InjectionPolicy {
            pattern_weight: 0.75,
            model_weight: 0.25,
            ..InjectionPolicy::default()
        };
        assert!(close(skewed.blend(0.8, Some(0.4)), 0.7));
    }

    /// Zero total weight falls back to the pattern score.
    #[test]
    fn blend_with_zero_weights_falls_back_to_pattern() {
        let unweighted = InjectionPolicy {
            pattern_weight: 0.0,
            model_weight: 0.0,
            ..InjectionPolicy::default()
        };
        assert!(close(unweighted.blend(0.42, Some(0.99)), 0.42));
    }

    /// The default gate list matches exactly the three expected labels.
    #[test]
    fn gating_matches_default_set() {
        const GATED: [&str; 3] = ["third_party", "community", "unknown"];
        const UNGATED: [&str; 2] = ["foundation", "partner"];

        let policy = InjectionPolicy::default();
        for a in GATED {
            assert!(policy.model_gated_for(a), "{a} should gate");
        }
        for a in UNGATED {
            assert!(!policy.model_gated_for(a), "{a} should not gate");
        }
    }
}