Skip to main content

mnm_core/injection/
policy.rs

//! Server-side injection-scoring policy TOML loader.
//!
//! Mirrors [`crate::scoring_policy`] in style: a schema-versioned, fail-fast
//! TOML shape with `deny_unknown_fields`, finite/range validation, and a
//! `thiserror` error enum. The policy is loaded once at server startup, and
//! invalid TOML fails the load (Constitution fail-fast). It controls how the
//! pattern score and the optional model score are blended into a single reject
//! decision, and which source attributions trigger the (more expensive) model
//! pass.
9
10use serde::{Deserialize, Serialize};
11use thiserror::Error;
12
/// Schema version that every injection-policy TOML body must declare.
///
/// [`InjectionPolicy::parse`] refuses any body whose `schema_version` key
/// holds a different value.
pub const SCHEMA_VERSION: u32 = 1;
15
16/// What to do when the model scorer is unavailable or errors.
17#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Debug, Default)]
18#[serde(rename_all = "snake_case")]
19pub enum FailMode {
20    /// Admit the content when the model pass cannot run (availability-first).
21    #[default]
22    Open,
23    /// Reject the content when the model pass cannot run (security-first).
24    Closed,
25}
26
27/// Full injection-policy shape — every knob is independently overridable.
28#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)]
29#[serde(deny_unknown_fields)]
30pub struct InjectionPolicy {
31    /// Schema sentinel. Always `1` in v1.
32    pub schema_version: u32,
33    /// Blended-score threshold at or above which content is rejected.
34    pub reject_threshold: f64,
35    /// Weight of the pattern score in the blend.
36    pub pattern_weight: f64,
37    /// Weight of the model score in the blend.
38    pub model_weight: f64,
39    /// Model score at or above which the model considers content injection.
40    pub model_threshold: f64,
41    /// Source attributions that trigger the model pass (e.g. untrusted tiers).
42    pub gate_attributions: Vec<String>,
43    /// Behaviour when the model pass cannot run.
44    pub fail_mode: FailMode,
45}
46
47impl Default for InjectionPolicy {
48    fn default() -> Self {
49        Self {
50            schema_version: SCHEMA_VERSION,
51            reject_threshold: 0.85,
52            pattern_weight: 0.5,
53            model_weight: 0.5,
54            model_threshold: 0.5,
55            gate_attributions: vec![
56                "third_party".to_owned(),
57                "community".to_owned(),
58                "unknown".to_owned(),
59            ],
60            fail_mode: FailMode::Open,
61        }
62    }
63}
64
65impl InjectionPolicy {
66    /// Parse an injection-policy TOML body.
67    ///
68    /// # Errors
69    ///
70    /// Returns [`InjectionPolicyError::Parse`] if the TOML is malformed or
71    /// carries unknown keys, [`InjectionPolicyError::SchemaVersionMismatch`] if
72    /// the schema sentinel disagrees, or [`InjectionPolicyError::InvalidWeight`]
73    /// if any weight is non-finite/negative or a threshold falls outside
74    /// `0.0..=1.0`.
75    pub fn parse(body: &str) -> Result<Self, InjectionPolicyError> {
76        let policy: Self =
77            toml::from_str(body).map_err(|e| InjectionPolicyError::Parse(e.to_string()))?;
78        if policy.schema_version != SCHEMA_VERSION {
79            return Err(InjectionPolicyError::SchemaVersionMismatch {
80                found: policy.schema_version,
81                expected: SCHEMA_VERSION,
82            });
83        }
84        policy.validate()?;
85        Ok(policy)
86    }
87
88    fn validate(&self) -> Result<(), InjectionPolicyError> {
89        let weights: [(&str, f64); 4] = [
90            ("reject_threshold", self.reject_threshold),
91            ("pattern_weight", self.pattern_weight),
92            ("model_weight", self.model_weight),
93            ("model_threshold", self.model_threshold),
94        ];
95        for (field, value) in weights {
96            if !value.is_finite() || value < 0.0 {
97                return Err(InjectionPolicyError::InvalidWeight { field: field.to_owned(), value });
98            }
99        }
100        for (field, value) in [
101            ("reject_threshold", self.reject_threshold),
102            ("model_threshold", self.model_threshold),
103        ] {
104            if value > 1.0 {
105                return Err(InjectionPolicyError::InvalidWeight { field: field.to_owned(), value });
106            }
107        }
108        Ok(())
109    }
110
111    /// Blend the pattern score with an optional model score into `0.0..=1.0`.
112    ///
113    /// With no model score, the clamped pattern score is returned. Otherwise the
114    /// weighted mean `(pw·pattern + mw·model) / (pw + mw)` is returned, clamped.
115    /// If both weights are zero (degenerate config), the pattern score is used.
116    #[must_use]
117    pub fn blend(&self, pattern: f64, model: Option<f64>) -> f64 {
118        let Some(model) = model else {
119            return pattern.clamp(0.0, 1.0);
120        };
121        let denom = self.pattern_weight + self.model_weight;
122        if denom <= 0.0 {
123            return pattern.clamp(0.0, 1.0);
124        }
125        let weighted = self
126            .pattern_weight
127            .mul_add(pattern, self.model_weight * model);
128        (weighted / denom).clamp(0.0, 1.0)
129    }
130
131    /// Whether the (more expensive) model pass should run for `attribution`.
132    #[must_use]
133    pub fn model_gated_for(&self, attribution: &str) -> bool {
134        self.gate_attributions.iter().any(|a| a == attribution)
135    }
136}
137
138/// All the ways injection-policy parsing can fail.
139#[derive(Debug, Error)]
140pub enum InjectionPolicyError {
141    /// TOML body did not parse against the [`InjectionPolicy`] shape.
142    #[error("failed to parse injection policy: {0}")]
143    Parse(String),
144    /// `schema_version` did not match [`SCHEMA_VERSION`].
145    #[error("injection policy schema_version={found}; expected {expected}")]
146    SchemaVersionMismatch {
147        /// The schema version we found.
148        found: u32,
149        /// The version we expected.
150        expected: u32,
151    },
152    /// A weight/threshold was non-finite, negative, or out of `0.0..=1.0`.
153    #[error("injection policy field `{field}` has invalid value {value}")]
154    InvalidWeight {
155        /// Name of the offending field.
156        field: String,
157        /// The offending value.
158        value: f64,
159    },
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// A complete, valid policy body with a non-default fail mode and a single
    /// gated attribution. TOML allows the leading indentation.
    const CLOSED_POLICY: &str = r#"
        schema_version = 1
        reject_threshold = 0.9
        pattern_weight = 0.25
        model_weight = 0.75
        model_threshold = 0.6
        gate_attributions = ["unknown"]
        fail_mode = "closed"
    "#;

    /// Return `body` with the `key = ...` line replaced by `line`.
    ///
    /// The line is dropped entirely when `line` is empty. Blank lines are
    /// removed and the remaining lines are trimmed.
    fn with_line(body: &str, key: &str, line: &str) -> String {
        let prefix = format!("{key} =");
        body.lines()
            .map(str::trim)
            .filter(|l| !l.is_empty())
            .map(|l| if l.starts_with(&prefix) { line } else { l })
            .filter(|l| !l.is_empty())
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Parse `body` and return the field/value of the `InvalidWeight` it must
    /// be rejected with.
    fn invalid_weight(body: &str) -> (String, f64) {
        match InjectionPolicy::parse(body) {
            Err(InjectionPolicyError::InvalidWeight { field, value }) => (field, value),
            other => panic!("expected InvalidWeight, got {other:?}"),
        }
    }

    #[test]
    fn defaults_are_well_formed() {
        let p = InjectionPolicy::default();
        assert_eq!(p.schema_version, 1);
        assert!((p.reject_threshold - 0.85).abs() < 1e-12);
        assert!((p.pattern_weight - 0.5).abs() < 1e-12);
        assert!((p.model_weight - 0.5).abs() < 1e-12);
        assert!((p.model_threshold - 0.5).abs() < 1e-12);
        assert_eq!(p.fail_mode, FailMode::Open);
        assert!(p.model_gated_for("unknown"));
        assert!(!p.model_gated_for("foundation"));
    }

    #[test]
    fn round_trips_through_toml() {
        let body = toml::to_string(&InjectionPolicy::default()).unwrap();
        let back = InjectionPolicy::parse(&body).unwrap();
        assert_eq!(back, InjectionPolicy::default());
    }

    #[test]
    fn parses_closed_fail_mode_and_custom_gates() {
        let p = InjectionPolicy::parse(CLOSED_POLICY).unwrap();
        assert_eq!(p.fail_mode, FailMode::Closed);
        assert!((p.reject_threshold - 0.9).abs() < 1e-12);
        assert!((p.pattern_weight - 0.25).abs() < 1e-12);
        assert!((p.model_weight - 0.75).abs() < 1e-12);
        assert!((p.model_threshold - 0.6).abs() < 1e-12);
        assert!(p.model_gated_for("unknown"));
        assert!(!p.model_gated_for("community"));
    }

    #[test]
    fn rejects_schema_mismatch() {
        let p = InjectionPolicy {
            schema_version: 99,
            ..InjectionPolicy::default()
        };
        let body = toml::to_string(&p).unwrap();
        let err = InjectionPolicy::parse(&body).unwrap_err();
        assert!(matches!(
            err,
            InjectionPolicyError::SchemaVersionMismatch { found: 99, expected: 1 }
        ));
    }

    #[test]
    fn schema_check_precedes_range_check() {
        // Both the sentinel and a weight are wrong; the sentinel must win.
        let body = with_line(CLOSED_POLICY, "schema_version", "schema_version = 2");
        let body = with_line(&body, "pattern_weight", "pattern_weight = -1.0");
        let err = InjectionPolicy::parse(&body).unwrap_err();
        assert!(matches!(
            err,
            InjectionPolicyError::SchemaVersionMismatch { found: 2, expected: 1 }
        ));
    }

    #[test]
    fn rejects_unknown_key() {
        let mut body = toml::to_string(&InjectionPolicy::default()).unwrap();
        body.push_str("\nbogus_top_level_key = 42\n");
        let err = InjectionPolicy::parse(&body).unwrap_err();
        assert!(matches!(err, InjectionPolicyError::Parse(_)));
    }

    #[test]
    fn rejects_missing_key() {
        // Every field is required: dropping `fail_mode` must fail the load.
        let body = with_line(CLOSED_POLICY, "fail_mode", "");
        let err = InjectionPolicy::parse(&body).unwrap_err();
        assert!(matches!(err, InjectionPolicyError::Parse(_)));
    }

    #[test]
    fn rejects_negative_weight() {
        let p = InjectionPolicy {
            pattern_weight: -1.0,
            ..InjectionPolicy::default()
        };
        let body = toml::to_string(&p).unwrap();
        let err = InjectionPolicy::parse(&body).unwrap_err();
        assert!(matches!(
            err,
            InjectionPolicyError::InvalidWeight { ref field, .. } if field == "pattern_weight"
        ));
    }

    #[test]
    fn rejects_non_finite_values() {
        let (field, value) =
            invalid_weight(&with_line(CLOSED_POLICY, "model_weight", "model_weight = nan"));
        assert_eq!(field, "model_weight");
        assert!(value.is_nan());

        let (field, value) = invalid_weight(&with_line(
            CLOSED_POLICY,
            "reject_threshold",
            "reject_threshold = inf",
        ));
        assert_eq!(field, "reject_threshold");
        assert!(value.is_infinite());
    }

    #[test]
    fn rejects_threshold_above_one() {
        let p = InjectionPolicy {
            reject_threshold: 1.5,
            ..InjectionPolicy::default()
        };
        let body = toml::to_string(&p).unwrap();
        let err = InjectionPolicy::parse(&body).unwrap_err();
        assert!(matches!(
            err,
            InjectionPolicyError::InvalidWeight { ref field, .. } if field == "reject_threshold"
        ));
    }

    #[test]
    fn rejects_model_threshold_above_one() {
        let (field, value) = invalid_weight(&with_line(
            CLOSED_POLICY,
            "model_threshold",
            "model_threshold = 1.01",
        ));
        assert_eq!(field, "model_threshold");
        assert!((value - 1.01).abs() < 1e-12);
    }

    #[test]
    fn reports_first_invalid_field() {
        // The non-negativity pass runs over all fields before the `<= 1.0`
        // pass, so the negative weight is reported, not the oversized threshold.
        let body = with_line(CLOSED_POLICY, "reject_threshold", "reject_threshold = 1.5");
        let body = with_line(&body, "pattern_weight", "pattern_weight = -1.0");
        let (field, value) = invalid_weight(&body);
        assert_eq!(field, "pattern_weight");
        assert!((value + 1.0).abs() < 1e-12);
    }

    #[test]
    fn accepts_inclusive_boundaries() {
        // Thresholds may sit exactly on 0.0 / 1.0, and zero weights are allowed
        // (blend falls back to the pattern score).
        let body = with_line(CLOSED_POLICY, "reject_threshold", "reject_threshold = 1.0");
        let body = with_line(&body, "model_threshold", "model_threshold = 0.0");
        let body = with_line(&body, "pattern_weight", "pattern_weight = 0.0");
        let body = with_line(&body, "model_weight", "model_weight = 0.0");
        let p = InjectionPolicy::parse(&body).unwrap();
        assert!((p.reject_threshold - 1.0).abs() < 1e-12);
        assert!(p.model_threshold.abs() < 1e-12);
        assert!((p.blend(0.42, Some(0.99)) - 0.42).abs() < 1e-12);
    }

    #[test]
    fn blend_without_model_returns_pattern() {
        let p = InjectionPolicy::default();
        assert!((p.blend(0.7, None) - 0.7).abs() < 1e-12);
        // Clamped.
        assert!((p.blend(1.5, None) - 1.0).abs() < 1e-12);
        assert!((p.blend(-0.2, None) - 0.0).abs() < 1e-12);
    }

    #[test]
    fn blend_with_model_is_weighted_mean() {
        let p = InjectionPolicy::default(); // 0.5 / 0.5
        assert!((p.blend(0.8, Some(0.4)) - 0.6).abs() < 1e-12);

        let skewed = InjectionPolicy {
            pattern_weight: 0.75,
            model_weight: 0.25,
            ..InjectionPolicy::default()
        };
        // (0.75*0.8 + 0.25*0.4) / 1.0 = 0.7
        assert!((skewed.blend(0.8, Some(0.4)) - 0.7).abs() < 1e-12);
    }

    #[test]
    fn blend_with_zero_weights_falls_back_to_pattern() {
        let p = InjectionPolicy {
            pattern_weight: 0.0,
            model_weight: 0.0,
            ..InjectionPolicy::default()
        };
        assert!((p.blend(0.42, Some(0.99)) - 0.42).abs() < 1e-12);
    }

    #[test]
    fn gating_matches_default_set() {
        let p = InjectionPolicy::default();
        for a in ["third_party", "community", "unknown"] {
            assert!(p.model_gated_for(a), "{a} should gate");
        }
        for a in ["foundation", "partner"] {
            assert!(!p.model_gated_for(a), "{a} should not gate");
        }
    }
}