1use serde::{Deserialize, Serialize};
11use thiserror::Error;
12
/// Version of the injection-policy document format understood by this build.
///
/// [`InjectionPolicy::parse`] refuses any document whose `schema_version`
/// differs from this value.
pub const SCHEMA_VERSION: u32 = 1;
15
16#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Debug, Default)]
18#[serde(rename_all = "snake_case")]
19pub enum FailMode {
20 #[default]
22 Open,
23 Closed,
25}
26
27#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)]
29#[serde(deny_unknown_fields)]
30pub struct InjectionPolicy {
31 pub schema_version: u32,
33 pub reject_threshold: f64,
35 pub pattern_weight: f64,
37 pub model_weight: f64,
39 pub model_threshold: f64,
41 pub gate_attributions: Vec<String>,
43 pub fail_mode: FailMode,
45}
46
47impl Default for InjectionPolicy {
48 fn default() -> Self {
49 Self {
50 schema_version: SCHEMA_VERSION,
51 reject_threshold: 0.85,
52 pattern_weight: 0.5,
53 model_weight: 0.5,
54 model_threshold: 0.5,
55 gate_attributions: vec![
56 "third_party".to_owned(),
57 "community".to_owned(),
58 "unknown".to_owned(),
59 ],
60 fail_mode: FailMode::Open,
61 }
62 }
63}
64
65impl InjectionPolicy {
66 pub fn parse(body: &str) -> Result<Self, InjectionPolicyError> {
76 let policy: Self =
77 toml::from_str(body).map_err(|e| InjectionPolicyError::Parse(e.to_string()))?;
78 if policy.schema_version != SCHEMA_VERSION {
79 return Err(InjectionPolicyError::SchemaVersionMismatch {
80 found: policy.schema_version,
81 expected: SCHEMA_VERSION,
82 });
83 }
84 policy.validate()?;
85 Ok(policy)
86 }
87
88 fn validate(&self) -> Result<(), InjectionPolicyError> {
89 let weights: [(&str, f64); 4] = [
90 ("reject_threshold", self.reject_threshold),
91 ("pattern_weight", self.pattern_weight),
92 ("model_weight", self.model_weight),
93 ("model_threshold", self.model_threshold),
94 ];
95 for (field, value) in weights {
96 if !value.is_finite() || value < 0.0 {
97 return Err(InjectionPolicyError::InvalidWeight { field: field.to_owned(), value });
98 }
99 }
100 for (field, value) in [
101 ("reject_threshold", self.reject_threshold),
102 ("model_threshold", self.model_threshold),
103 ] {
104 if value > 1.0 {
105 return Err(InjectionPolicyError::InvalidWeight { field: field.to_owned(), value });
106 }
107 }
108 Ok(())
109 }
110
111 #[must_use]
117 pub fn blend(&self, pattern: f64, model: Option<f64>) -> f64 {
118 let Some(model) = model else {
119 return pattern.clamp(0.0, 1.0);
120 };
121 let denom = self.pattern_weight + self.model_weight;
122 if denom <= 0.0 {
123 return pattern.clamp(0.0, 1.0);
124 }
125 let weighted = self
126 .pattern_weight
127 .mul_add(pattern, self.model_weight * model);
128 (weighted / denom).clamp(0.0, 1.0)
129 }
130
131 #[must_use]
133 pub fn model_gated_for(&self, attribution: &str) -> bool {
134 self.gate_attributions.iter().any(|a| a == attribution)
135 }
136}
137
138#[derive(Debug, Error)]
140pub enum InjectionPolicyError {
141 #[error("failed to parse injection policy: {0}")]
143 Parse(String),
144 #[error("injection policy schema_version={found}; expected {expected}")]
146 SchemaVersionMismatch {
147 found: u32,
149 expected: u32,
151 },
152 #[error("injection policy field `{field}` has invalid value {value}")]
154 InvalidWeight {
155 field: String,
157 value: f64,
159 },
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a complete, valid policy document in which `model_weight` is
    /// set to the raw TOML literal `model_weight`.
    fn body_with_model_weight(model_weight: &str) -> String {
        format!(
            "schema_version = 1\n\
             reject_threshold = 0.85\n\
             pattern_weight = 0.5\n\
             model_weight = {model_weight}\n\
             model_threshold = 0.5\n\
             gate_attributions = [\"unknown\"]\n\
             fail_mode = \"open\"\n"
        )
    }

    #[test]
    fn defaults_are_well_formed() {
        let p = InjectionPolicy::default();
        assert_eq!(p.schema_version, 1);
        assert!((p.reject_threshold - 0.85).abs() < 1e-12);
        assert!((p.pattern_weight - 0.5).abs() < 1e-12);
        assert!((p.model_weight - 0.5).abs() < 1e-12);
        assert!((p.model_threshold - 0.5).abs() < 1e-12);
        assert_eq!(p.fail_mode, FailMode::Open);
        assert!(p.model_gated_for("unknown"));
        assert!(!p.model_gated_for("foundation"));
    }

    #[test]
    fn round_trips_through_toml() {
        let body = toml::to_string(&InjectionPolicy::default()).unwrap();
        let back = InjectionPolicy::parse(&body).unwrap();
        assert_eq!(back, InjectionPolicy::default());
    }

    #[test]
    fn parses_hand_written_document() {
        let p = InjectionPolicy::parse(&body_with_model_weight("0.25")).unwrap();
        assert!((p.model_weight - 0.25).abs() < 1e-12);
        assert_eq!(p.gate_attributions, vec!["unknown".to_owned()]);
        assert_eq!(p.fail_mode, FailMode::Open);
        assert!(p.model_gated_for("unknown"));
        assert!(!p.model_gated_for("community"));
    }

    #[test]
    fn parses_closed_fail_mode() {
        let body = body_with_model_weight("0.5")
            .replace("fail_mode = \"open\"", "fail_mode = \"closed\"");
        let p = InjectionPolicy::parse(&body).unwrap();
        assert_eq!(p.fail_mode, FailMode::Closed);
    }

    #[test]
    fn rejects_schema_mismatch() {
        let p = InjectionPolicy {
            schema_version: 99,
            ..InjectionPolicy::default()
        };
        let body = toml::to_string(&p).unwrap();
        let err = InjectionPolicy::parse(&body).unwrap_err();
        assert!(matches!(
            err,
            InjectionPolicyError::SchemaVersionMismatch { found: 99, expected: 1 }
        ));
    }

    #[test]
    fn rejects_unknown_key() {
        let mut body = toml::to_string(&InjectionPolicy::default()).unwrap();
        body.push_str("\nbogus_top_level_key = 42\n");
        let err = InjectionPolicy::parse(&body).unwrap_err();
        assert!(matches!(err, InjectionPolicyError::Parse(_)));
    }

    #[test]
    fn rejects_negative_weight() {
        let p = InjectionPolicy {
            pattern_weight: -1.0,
            ..InjectionPolicy::default()
        };
        let body = toml::to_string(&p).unwrap();
        let err = InjectionPolicy::parse(&body).unwrap_err();
        assert!(matches!(
            err,
            InjectionPolicyError::InvalidWeight { ref field, .. } if field == "pattern_weight"
        ));
    }

    #[test]
    fn rejects_non_finite_weight() {
        for literal in ["nan", "inf", "-inf"] {
            let err = InjectionPolicy::parse(&body_with_model_weight(literal)).unwrap_err();
            assert!(
                matches!(
                    err,
                    InjectionPolicyError::InvalidWeight { ref field, value }
                        if field == "model_weight" && !value.is_finite()
                ),
                "{literal} should be rejected as model_weight"
            );
        }
    }

    #[test]
    fn rejects_threshold_above_one() {
        let p = InjectionPolicy {
            reject_threshold: 1.5,
            ..InjectionPolicy::default()
        };
        let body = toml::to_string(&p).unwrap();
        let err = InjectionPolicy::parse(&body).unwrap_err();
        assert!(matches!(
            err,
            InjectionPolicyError::InvalidWeight { ref field, .. } if field == "reject_threshold"
        ));
    }

    #[test]
    fn rejects_model_threshold_above_one() {
        let p = InjectionPolicy {
            model_threshold: 1.01,
            ..InjectionPolicy::default()
        };
        let body = toml::to_string(&p).unwrap();
        let err = InjectionPolicy::parse(&body).unwrap_err();
        assert!(matches!(
            err,
            InjectionPolicyError::InvalidWeight { ref field, .. } if field == "model_threshold"
        ));
    }

    #[test]
    fn accepts_threshold_of_exactly_one() {
        let p = InjectionPolicy {
            reject_threshold: 1.0,
            ..InjectionPolicy::default()
        };
        let body = toml::to_string(&p).unwrap();
        assert_eq!(InjectionPolicy::parse(&body).unwrap(), p);
    }

    #[test]
    fn blend_without_model_returns_pattern() {
        let p = InjectionPolicy::default();
        assert!((p.blend(0.7, None) - 0.7).abs() < 1e-12);
        assert!((p.blend(1.5, None) - 1.0).abs() < 1e-12);
        assert!((p.blend(-0.2, None) - 0.0).abs() < 1e-12);
    }

    #[test]
    fn blend_with_model_is_weighted_mean() {
        let p = InjectionPolicy::default();
        assert!((p.blend(0.8, Some(0.4)) - 0.6).abs() < 1e-12);

        let skewed = InjectionPolicy {
            pattern_weight: 0.75,
            model_weight: 0.25,
            ..InjectionPolicy::default()
        };
        assert!((skewed.blend(0.8, Some(0.4)) - 0.7).abs() < 1e-12);
    }

    #[test]
    fn blend_with_zero_weights_falls_back_to_pattern() {
        let p = InjectionPolicy {
            pattern_weight: 0.0,
            model_weight: 0.0,
            ..InjectionPolicy::default()
        };
        assert!((p.blend(0.42, Some(0.99)) - 0.42).abs() < 1e-12);
    }

    #[test]
    fn gating_matches_default_set() {
        let p = InjectionPolicy::default();
        for a in ["third_party", "community", "unknown"] {
            assert!(p.model_gated_for(a), "{a} should gate");
        }
        for a in ["foundation", "partner"] {
            assert!(!p.model_gated_for(a), "{a} should not gate");
        }
    }
}