1#![forbid(unsafe_code)]
2
3use std::collections::HashMap;
32use std::fmt;
33use std::sync::{Arc, RwLock};
34
35use arc_swap::ArcSwap;
36
37use crate::policy_config::PolicyConfig;
38
/// Name under which the built-in default policy is registered.
///
/// `PolicyRegistry::new` registers this name with `PolicyConfig::default()`
/// and makes it the initial active policy. `PolicyRegistry::register` and
/// `PolicyRegistry::remove` both refuse to touch it.
pub const STANDARD_POLICY: &str = "standard";
41
/// The currently active policy: its registry name paired with the
/// configuration that was cloned out of the registry when it was activated.
///
/// Name and config are published together as one `Arc` inside the registry's
/// `ArcSwap`, so a single load always yields a matching pair.
#[derive(Debug, Clone)]
struct ActivePolicy {
    /// Registry key of the active policy.
    name: String,
    /// Copy of the policy's configuration taken at activation time.
    /// Re-registering a policy under the same name does not update this copy.
    config: PolicyConfig,
}
53
/// Record describing one successful switch of the active policy.
///
/// Returned by `PolicyRegistry::set_active`; render it with
/// [`PolicySwitchEvent::to_jsonl`] to obtain a single JSON-lines record.
#[derive(Debug, Clone)]
pub struct PolicySwitchEvent {
    /// Name of the policy that was active immediately before the switch.
    pub old_name: String,
    /// Name of the policy that is active after the switch.
    pub new_name: String,
    /// Sequence number taken from the registry's switch counter
    /// (0 for the first switch of a registry).
    pub switch_id: u64,
}

impl PolicySwitchEvent {
    /// Renders the event as a single-line JSON object (one JSONL record).
    ///
    /// Shape:
    /// `{"schema":"policy-switch-v1","switch_id":N,"old":"...","new":"..."}`.
    ///
    /// Policy names are escaped per RFC 8259. Names containing quotes,
    /// backslashes or control characters therefore still produce valid JSON.
    /// Newlines are included, which would otherwise split the record across
    /// lines.
    pub fn to_jsonl(&self) -> String {
        format!(
            r#"{{"schema":"policy-switch-v1","switch_id":{},"old":"{}","new":"{}"}}"#,
            self.switch_id,
            Self::escape_json_str(&self.old_name),
            Self::escape_json_str(&self.new_name),
        )
    }

    /// Escapes `s` so it can be embedded between the quotes of a JSON string
    /// literal (RFC 8259 section 7).
    fn escape_json_str(s: &str) -> String {
        use std::fmt::Write as _;

        let mut out = String::with_capacity(s.len());
        for c in s.chars() {
            match c {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                '\u{08}' => out.push_str("\\b"),
                '\u{0C}' => out.push_str("\\f"),
                // Remaining C0 control characters must use the \uXXXX form.
                c if u32::from(c) < 0x20 => {
                    // Formatting into a String is infallible.
                    let _ = write!(out, "\\u{:04x}", u32::from(c));
                }
                c => out.push(c),
            }
        }
        out
    }
}
80
/// Errors returned by `PolicyRegistry` operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyRegistryError {
    /// No policy with the given name is registered.
    ///
    /// `PolicyRegistry::remove` also uses this variant when refusing to remove
    /// the currently active policy. In that case the payload is an
    /// explanatory message rather than a bare name.
    NotFound(String),
    /// The operation would overwrite or remove the built-in standard policy.
    /// Returned by both `register` and `remove`.
    StandardPolicyProtected,
    /// Configuration validation reported one or more problems.
    ValidationFailed(Vec<String>),
}

impl fmt::Display for PolicyRegistryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NotFound(name) => write!(f, "policy not found: {name}"),
            // Covers both the register (overwrite) and remove paths, so the
            // message must not mention only removal.
            Self::StandardPolicyProtected => write!(
                f,
                "standard policy is protected and cannot be overwritten or removed"
            ),
            Self::ValidationFailed(errors) => {
                write!(f, "policy validation failed: {}", errors.join("; "))
            }
        }
    }
}

impl std::error::Error for PolicyRegistryError {}
109
/// Thread-safe registry of named policy configurations, exactly one of which
/// is active at any time.
///
/// The active policy is read through an `ArcSwap` without touching the
/// `RwLock`. The set of registered policies is guarded by the `RwLock`.
pub struct PolicyRegistry {
    /// All registered policies keyed by name. Always contains `STANDARD_POLICY`:
    /// it is inserted by `new`, and `register`/`remove` refuse to touch it.
    policies: RwLock<HashMap<String, PolicyConfig>>,
    /// Name and config snapshot of the currently active policy.
    active: ArcSwap<ActivePolicy>,
    /// Number of successful `set_active` calls. Also the source of each
    /// event's `switch_id`.
    switch_count: std::sync::atomic::AtomicU64,
}
129
130impl PolicyRegistry {
131 pub fn new() -> Self {
133 let standard = PolicyConfig::default();
134 let mut map = HashMap::new();
135 map.insert(STANDARD_POLICY.to_string(), standard.clone());
136
137 Self {
138 policies: RwLock::new(map),
139 active: ArcSwap::from_pointee(ActivePolicy {
140 name: STANDARD_POLICY.to_string(),
141 config: standard,
142 }),
143 switch_count: std::sync::atomic::AtomicU64::new(0),
144 }
145 }
146
147 pub fn active_config(&self) -> PolicyConfig {
149 self.active.load().config.clone()
150 }
151
152 pub fn active_name(&self) -> String {
154 self.active.load().name.clone()
155 }
156
157 pub fn register(&self, name: &str, config: PolicyConfig) -> Result<(), PolicyRegistryError> {
162 if name == STANDARD_POLICY {
163 return Err(PolicyRegistryError::StandardPolicyProtected);
164 }
165
166 let errors = config.validate();
167 if !errors.is_empty() {
168 return Err(PolicyRegistryError::ValidationFailed(errors));
169 }
170
171 let mut map = self.policies.write().unwrap_or_else(|e| e.into_inner());
172 map.insert(name.to_string(), config);
173 Ok(())
174 }
175
176 pub fn remove(&self, name: &str) -> Result<(), PolicyRegistryError> {
179 if name == STANDARD_POLICY {
180 return Err(PolicyRegistryError::StandardPolicyProtected);
181 }
182
183 if self.active_name() == name {
185 return Err(PolicyRegistryError::NotFound(format!(
186 "cannot remove active policy: {name}"
187 )));
188 }
189
190 let mut map = self.policies.write().unwrap_or_else(|e| e.into_inner());
191 map.remove(name)
192 .map(|_| ())
193 .ok_or_else(|| PolicyRegistryError::NotFound(name.to_string()))
194 }
195
196 pub fn set_active(&self, name: &str) -> Result<PolicySwitchEvent, PolicyRegistryError> {
202 let map = self.policies.read().unwrap_or_else(|e| e.into_inner());
203 let config = map
204 .get(name)
205 .cloned()
206 .ok_or_else(|| PolicyRegistryError::NotFound(name.to_string()))?;
207 drop(map);
208
209 let old_name = self.active_name();
210 let switch_id = self
211 .switch_count
212 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
213
214 self.active.store(Arc::new(ActivePolicy {
215 name: name.to_string(),
216 config,
217 }));
218
219 Ok(PolicySwitchEvent {
220 old_name,
221 new_name: name.to_string(),
222 switch_id,
223 })
224 }
225
226 pub fn list(&self) -> Vec<String> {
228 let map = self.policies.read().unwrap_or_else(|e| e.into_inner());
229 let mut names: Vec<String> = map.keys().cloned().collect();
230 names.sort();
231 names
232 }
233
234 pub fn get(&self, name: &str) -> Option<PolicyConfig> {
236 let map = self.policies.read().unwrap_or_else(|e| e.into_inner());
237 map.get(name).cloned()
238 }
239
240 pub fn switch_count(&self) -> u64 {
242 self.switch_count.load(std::sync::atomic::Ordering::Relaxed)
243 }
244}
245
246impl Default for PolicyRegistry {
247 fn default() -> Self {
248 Self::new()
249 }
250}
251
252impl fmt::Debug for PolicyRegistry {
253 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
254 f.debug_struct("PolicyRegistry")
255 .field("active", &self.active_name())
256 .field("policies", &self.list())
257 .field("switch_count", &self.switch_count())
258 .finish()
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_has_standard_policy() {
        let reg = PolicyRegistry::new();
        assert_eq!(reg.active_name(), STANDARD_POLICY);
        assert_eq!(reg.list(), vec![STANDARD_POLICY.to_string()]);
    }

    #[test]
    fn register_and_switch() {
        let reg = PolicyRegistry::new();
        let mut custom = PolicyConfig::default();
        custom.conformal.alpha = 0.01;

        reg.register("custom", custom).unwrap();
        let event = reg.set_active("custom").unwrap();

        assert_eq!(event.old_name, STANDARD_POLICY);
        assert_eq!(event.new_name, "custom");
        assert_eq!(event.switch_id, 0);
        assert_eq!(reg.active_name(), "custom");
        assert!((reg.active_config().conformal.alpha - 0.01).abs() < f64::EPSILON);
    }

    #[test]
    fn switch_back_to_standard() {
        let reg = PolicyRegistry::new();
        let custom = PolicyConfig::default();
        reg.register("custom", custom).unwrap();
        reg.set_active("custom").unwrap();

        let event = reg.set_active(STANDARD_POLICY).unwrap();
        assert_eq!(event.old_name, "custom");
        assert_eq!(event.new_name, STANDARD_POLICY);
        assert_eq!(event.switch_id, 1);
        assert_eq!(reg.switch_count(), 2);
    }

    #[test]
    fn switch_to_nonexistent_fails() {
        let reg = PolicyRegistry::new();
        let err = reg.set_active("nonexistent").unwrap_err();
        assert!(matches!(err, PolicyRegistryError::NotFound(_)));
        assert_eq!(reg.switch_count(), 0);
    }

    #[test]
    fn cannot_overwrite_standard() {
        let reg = PolicyRegistry::new();
        let err = reg
            .register(STANDARD_POLICY, PolicyConfig::default())
            .unwrap_err();
        assert!(matches!(err, PolicyRegistryError::StandardPolicyProtected));
    }

    #[test]
    fn cannot_remove_standard() {
        let reg = PolicyRegistry::new();
        let err = reg.remove(STANDARD_POLICY).unwrap_err();
        assert!(matches!(err, PolicyRegistryError::StandardPolicyProtected));
    }

    #[test]
    fn cannot_remove_active() {
        let reg = PolicyRegistry::new();
        reg.register("custom", PolicyConfig::default()).unwrap();
        reg.set_active("custom").unwrap();
        let err = reg.remove("custom").unwrap_err();
        assert!(matches!(err, PolicyRegistryError::NotFound(_)));
        assert!(reg.get("custom").is_some());
    }

    #[test]
    fn remove_inactive() {
        let reg = PolicyRegistry::new();
        reg.register("custom", PolicyConfig::default()).unwrap();
        assert_eq!(reg.list().len(), 2);

        reg.remove("custom").unwrap();
        assert_eq!(reg.list().len(), 1);
    }

    #[test]
    fn remove_nonexistent_fails() {
        let reg = PolicyRegistry::new();
        let err = reg.remove("nonexistent").unwrap_err();
        assert_eq!(err, PolicyRegistryError::NotFound("nonexistent".to_string()));
    }

    #[test]
    fn list_is_sorted() {
        let reg = PolicyRegistry::new();
        reg.register("zeta", PolicyConfig::default()).unwrap();
        reg.register("alpha", PolicyConfig::default()).unwrap();
        assert_eq!(
            reg.list(),
            vec!["alpha".to_string(), STANDARD_POLICY.to_string(), "zeta".to_string()]
        );
    }

    #[test]
    fn register_validates() {
        let reg = PolicyRegistry::new();
        let mut bad = PolicyConfig::default();
        bad.conformal.alpha = 0.0;
        let err = reg.register("bad", bad).unwrap_err();
        assert!(matches!(err, PolicyRegistryError::ValidationFailed(_)));
    }

    #[test]
    fn get_existing() {
        let reg = PolicyRegistry::new();
        let config = reg.get(STANDARD_POLICY);
        assert!(config.is_some());
    }

    #[test]
    fn get_nonexistent() {
        let reg = PolicyRegistry::new();
        assert!(reg.get("nonexistent").is_none());
    }

    #[test]
    fn switch_event_jsonl() {
        let event = PolicySwitchEvent {
            old_name: "standard".into(),
            new_name: "aggressive".into(),
            switch_id: 42,
        };
        let jsonl = event.to_jsonl();
        assert!(jsonl.contains("policy-switch-v1"));
        assert!(jsonl.contains("\"switch_id\":42"));
        assert!(jsonl.contains("\"old\":\"standard\""));
        assert!(jsonl.contains("\"new\":\"aggressive\""));

        let parsed: serde_json::Value = serde_json::from_str(&jsonl).unwrap();
        assert!(parsed.is_object());
    }

    #[test]
    fn debug_format() {
        let reg = PolicyRegistry::new();
        let debug = format!("{reg:?}");
        assert!(debug.contains("PolicyRegistry"));
        assert!(debug.contains("standard"));
    }

    #[test]
    fn concurrent_reads_during_switch() {
        let reg = Arc::new(PolicyRegistry::new());
        let mut custom = PolicyConfig::default();
        custom.conformal.alpha = 0.02;
        reg.register("custom", custom).unwrap();

        std::thread::scope(|s| {
            // Readers: repeatedly load the active policy while it is switching.
            for _ in 0..4 {
                let reg = Arc::clone(&reg);
                s.spawn(move || {
                    for _ in 0..100 {
                        let _name = reg.active_name();
                        let _config = reg.active_config();
                    }
                });
            }

            // Writer: alternate between the two registered policies.
            {
                let reg = Arc::clone(&reg);
                s.spawn(move || {
                    for i in 0..50 {
                        let target = if i % 2 == 0 { "custom" } else { STANDARD_POLICY };
                        reg.set_active(target).unwrap();
                    }
                });
            }
        });

        // Both targets stay registered, so every one of the 50 switches succeeds.
        assert_eq!(reg.switch_count(), 50);
    }

    #[test]
    fn overwrite_registered_policy() {
        let reg = PolicyRegistry::new();
        let mut v1 = PolicyConfig::default();
        v1.conformal.alpha = 0.02;
        reg.register("custom", v1).unwrap();

        let mut v2 = PolicyConfig::default();
        v2.conformal.alpha = 0.03;
        reg.register("custom", v2).unwrap();

        let config = reg.get("custom").unwrap();
        assert!((config.conformal.alpha - 0.03).abs() < f64::EPSILON);
    }
}