1use parking_lot::RwLock;
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::{Duration, Instant, SystemTime};
15
16use super::{CompiledRule, ReloadReport, RuleError, RuleHandle, RuleStatus};
17
/// Pluggable signature check applied to a rule before the registry accepts it.
///
/// When a verifier is installed on the registry, validation first requires the
/// rule to carry a signature and then calls [`SignatureVerifier::verify`].
pub trait SignatureVerifier: Send + Sync {
    /// Checks `rule`. An `Err` is propagated unchanged to the caller of the
    /// registry operation, which rejects the rule.
    fn verify(&self, rule: &CompiledRule) -> Result<(), RuleError>;
}
29
/// Hooks the registry calls while registering, reloading or rolling back a rule.
///
/// Within this file the calls are made in the order `pre_stage`, `quiesce`,
/// `swap`, `terminate_old`; any `Err` aborts the registry operation at that
/// point and is returned to its caller.
pub trait RuleSwapBackend: Send + Sync {
    /// Called with the incoming rule before the registry changes any of its
    /// own state.
    fn pre_stage(&self, rule: &CompiledRule) -> Result<(), RuleError>;

    /// Called with the currently active version of `rule_id`. The returned
    /// count is reported as `messages_in_flight_during_swap`.
    fn quiesce(&self, rule_id: &str, version: u64) -> Result<u64, RuleError>;

    /// Called with the version that is about to become active.
    fn swap(&self, rule_id: &str, new_version: u64) -> Result<(), RuleError>;

    /// Called with the previously active version after the registry has
    /// recorded the new active version.
    fn terminate_old(&self, rule_id: &str, old_version: u64) -> Result<(), RuleError>;
}
54
/// Stateless swap backend whose trait implementation always succeeds.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopSwapBackend;
58
59impl RuleSwapBackend for NoopSwapBackend {
60 fn pre_stage(&self, _rule: &CompiledRule) -> Result<(), RuleError> {
61 Ok(())
62 }
63
64 fn quiesce(&self, _rule_id: &str, _version: u64) -> Result<u64, RuleError> {
65 Ok(0)
66 }
67
68 fn swap(&self, _rule_id: &str, _new_version: u64) -> Result<(), RuleError> {
69 Ok(())
70 }
71
72 fn terminate_old(&self, _rule_id: &str, _old_version: u64) -> Result<(), RuleError> {
73 Ok(())
74 }
75}
76
/// Per-rule bookkeeping: the retained versions plus their status and
/// registration time.
struct RuleVersionHistory {
    /// Retained versions in insertion order (oldest first).
    versions: Vec<CompiledRule>,
    /// Version number currently marked active, if any.
    active_version: Option<u64>,
    /// Status recorded for each version number.
    status_by_version: HashMap<u64, RuleStatus>,
    /// Wall-clock time at which each version was inserted.
    registered_at: HashMap<u64, SystemTime>,
}
89
90impl RuleVersionHistory {
91 fn new() -> Self {
92 Self {
93 versions: Vec::new(),
94 active_version: None,
95 status_by_version: HashMap::new(),
96 registered_at: HashMap::new(),
97 }
98 }
99
100 fn insert_version(&mut self, rule: CompiledRule, max_history: usize) {
103 let version = rule.version;
104 self.registered_at.insert(version, SystemTime::now());
105 self.versions.push(rule);
106
107 while self.versions.len() > max_history {
108 let evicted = self.versions.remove(0);
109 self.status_by_version.remove(&evicted.version);
110 self.registered_at.remove(&evicted.version);
111 }
112 }
113
114 fn get(&self, version: u64) -> Option<&CompiledRule> {
115 self.versions.iter().find(|r| r.version == version)
116 }
117
118 fn active(&self) -> Option<&CompiledRule> {
119 self.active_version.and_then(|v| self.get(v))
120 }
121}
122
/// Registry of compiled rules keyed by rule id, with per-rule version history.
pub struct RuleRegistry {
    /// Version history for every known rule id, guarded by a read/write lock.
    rules: RwLock<HashMap<String, RuleVersionHistory>>,
    /// Optional verifier; when present, rules must carry a signature.
    signature_verifier: Option<Arc<dyn SignatureVerifier>>,
    /// Backend notified around every registration, reload and rollback.
    swap_backend: Arc<dyn RuleSwapBackend>,
    /// History limit handed to `RuleVersionHistory::insert_version`.
    max_history_per_rule: usize,
}
133
134impl RuleRegistry {
135 pub fn new(max_history_per_rule: usize, swap_backend: Arc<dyn RuleSwapBackend>) -> Self {
141 let max_history_per_rule = max_history_per_rule.max(1);
142 Self {
143 rules: RwLock::new(HashMap::new()),
144 signature_verifier: None,
145 swap_backend,
146 max_history_per_rule,
147 }
148 }
149
150 pub fn with_verifier(mut self, verifier: Arc<dyn SignatureVerifier>) -> Self {
153 self.signature_verifier = Some(verifier);
154 self
155 }
156
157 pub fn rule_count(&self) -> usize {
159 self.rules.read().len()
160 }
161
    /// Returns the per-rule history limit this registry was configured with.
    pub fn max_history(&self) -> usize {
        self.max_history_per_rule
    }
166
167 pub async fn register_rule(
178 &self,
179 rule: CompiledRule,
180 device_compute_cap: &str,
181 ) -> Result<RuleHandle, RuleError> {
182 self.validate(&rule, device_compute_cap, false)?;
183
184 self.swap_backend.pre_stage(&rule)?;
186
187 let version = rule.version;
188 let rule_id = rule.rule_id.clone();
189
190 let mut rules = self.rules.write();
191 let history = rules
192 .entry(rule_id.clone())
193 .or_insert_with(RuleVersionHistory::new);
194
195 if history.get(version).is_some() {
196 return Err(RuleError::DuplicateVersion { rule_id, version });
197 }
198
199 let status = if history.active_version.is_some() {
200 RuleStatus::Registered
201 } else {
202 RuleStatus::Active
203 };
204
205 history.insert_version(rule, self.max_history_per_rule);
206 history.status_by_version.insert(version, status);
207 if matches!(status, RuleStatus::Active) {
208 history.active_version = Some(version);
209 }
210
211 let registered_at = history
212 .registered_at
213 .get(&version)
214 .copied()
215 .unwrap_or_else(SystemTime::now);
216
217 Ok(RuleHandle {
218 rule_id,
219 version,
220 status,
221 registered_at,
222 })
223 }
224
225 pub async fn reload_rule(
237 &self,
238 rule: CompiledRule,
239 device_compute_cap: &str,
240 ) -> Result<ReloadReport, RuleError> {
241 self.validate(&rule, device_compute_cap, true)?;
242
243 let rule_id = rule.rule_id.clone();
244 let new_version = rule.version;
245
246 self.swap_backend.pre_stage(&rule)?;
248
249 let old_version = {
254 let rules = self.rules.read();
255 rules.get(&rule_id).and_then(|h| h.active_version)
256 };
257
258 let quiesce_start = Instant::now();
260 let messages_in_flight = if let Some(old_v) = old_version {
261 self.swap_backend.quiesce(&rule_id, old_v)?
262 } else {
263 0
264 };
265 let quiesce_duration = quiesce_start.elapsed();
266
267 let swap_start = Instant::now();
269 self.swap_backend.swap(&rule_id, new_version)?;
270 let swap_duration = swap_start.elapsed();
271
272 let mut rules = self.rules.write();
274 let history = rules
275 .entry(rule_id.clone())
276 .or_insert_with(RuleVersionHistory::new);
277
278 if history.get(new_version).is_some() {
279 return Err(RuleError::DuplicateVersion {
280 rule_id,
281 version: new_version,
282 });
283 }
284
285 history.insert_version(rule, self.max_history_per_rule);
286
287 if let Some(old_v) = old_version {
288 history
289 .status_by_version
290 .insert(old_v, RuleStatus::Superseded(new_version));
291 }
292 history
293 .status_by_version
294 .insert(new_version, RuleStatus::Active);
295 history.active_version = Some(new_version);
296
297 let rollback_available = old_version
298 .and_then(|v| history.versions.iter().find(|r| r.version == v))
299 .is_some();
300
301 drop(rules);
302
303 if let Some(old_v) = old_version {
306 self.swap_backend.terminate_old(&rule_id, old_v)?;
307 }
308
309 Ok(ReloadReport {
310 rule_id,
311 from_version: old_version.unwrap_or(0),
312 to_version: new_version,
313 quiesce_duration,
314 swap_duration,
315 messages_in_flight_during_swap: messages_in_flight,
316 rollback_available,
317 })
318 }
319
320 pub async fn rollback_rule(
326 &self,
327 rule_id: &str,
328 to_version: u64,
329 ) -> Result<ReloadReport, RuleError> {
330 let (current_active, target_rule) = {
332 let rules = self.rules.read();
333 let history = rules
334 .get(rule_id)
335 .ok_or_else(|| RuleError::NotFound(rule_id.to_string()))?;
336
337 let active = history.active_version.ok_or(RuleError::NoActiveVersion)?;
338 if active == to_version {
339 return Ok(ReloadReport {
341 rule_id: rule_id.to_string(),
342 from_version: active,
343 to_version,
344 quiesce_duration: Duration::from_nanos(0),
345 swap_duration: Duration::from_nanos(0),
346 messages_in_flight_during_swap: 0,
347 rollback_available: true,
348 });
349 }
350
351 let target = history
352 .get(to_version)
353 .cloned()
354 .ok_or(RuleError::RollbackTargetMissing(to_version))?;
355
356 (active, target)
357 };
358
359 self.swap_backend.pre_stage(&target_rule)?;
362
363 let quiesce_start = Instant::now();
364 let drained = self.swap_backend.quiesce(rule_id, current_active)?;
365 let quiesce_duration = quiesce_start.elapsed();
366
367 let swap_start = Instant::now();
368 self.swap_backend.swap(rule_id, to_version)?;
369 let swap_duration = swap_start.elapsed();
370
371 let mut rules = self.rules.write();
373 let history = rules
374 .get_mut(rule_id)
375 .ok_or_else(|| RuleError::NotFound(rule_id.to_string()))?;
376
377 history
378 .status_by_version
379 .insert(current_active, RuleStatus::Rolledback);
380 history
381 .status_by_version
382 .insert(to_version, RuleStatus::Active);
383 history.active_version = Some(to_version);
384
385 drop(rules);
386
387 self.swap_backend.terminate_old(rule_id, current_active)?;
388
389 Ok(ReloadReport {
390 rule_id: rule_id.to_string(),
391 from_version: current_active,
392 to_version,
393 quiesce_duration,
394 swap_duration,
395 messages_in_flight_during_swap: drained,
396 rollback_available: false,
397 })
398 }
399
400 pub fn list_rules(&self) -> Vec<RuleHandle> {
402 let rules = self.rules.read();
403 let mut out = Vec::new();
404 for (rule_id, history) in rules.iter() {
405 if let Some(active) = history.active_version {
406 if let Some(status) = history.status_by_version.get(&active).copied() {
407 let registered_at = history
408 .registered_at
409 .get(&active)
410 .copied()
411 .unwrap_or_else(SystemTime::now);
412 out.push(RuleHandle {
413 rule_id: rule_id.clone(),
414 version: active,
415 status,
416 registered_at,
417 });
418 }
419 }
420 }
421 out
422 }
423
424 pub fn get_rule(&self, rule_id: &str, version: u64) -> Option<CompiledRule> {
426 let rules = self.rules.read();
427 rules.get(rule_id).and_then(|h| h.get(version).cloned())
428 }
429
430 pub fn get_active(&self, rule_id: &str) -> Option<CompiledRule> {
432 let rules = self.rules.read();
433 rules.get(rule_id).and_then(|h| h.active().cloned())
434 }
435
436 pub fn history(&self, rule_id: &str) -> Vec<RuleHandle> {
438 let rules = self.rules.read();
439 let Some(history) = rules.get(rule_id) else {
440 return Vec::new();
441 };
442 history
443 .versions
444 .iter()
445 .map(|rule| RuleHandle {
446 rule_id: rule.rule_id.clone(),
447 version: rule.version,
448 status: history
449 .status_by_version
450 .get(&rule.version)
451 .copied()
452 .unwrap_or(RuleStatus::Registered),
453 registered_at: history
454 .registered_at
455 .get(&rule.version)
456 .copied()
457 .unwrap_or_else(SystemTime::now),
458 })
459 .collect()
460 }
461
462 fn validate(
468 &self,
469 rule: &CompiledRule,
470 device_compute_cap: &str,
471 is_reload: bool,
472 ) -> Result<(), RuleError> {
473 if let Some(verifier) = self.signature_verifier.as_ref() {
475 if rule.signature.is_none() {
476 return Err(RuleError::InvalidSignature);
477 }
478 verifier.verify(rule)?;
479 }
480
481 if !compute_cap_compatible(&rule.compute_cap, device_compute_cap) {
483 return Err(RuleError::ComputeCapMismatch {
484 required: rule.compute_cap.clone(),
485 available: device_compute_cap.to_string(),
486 });
487 }
488
489 {
491 let rules = self.rules.read();
492 for dep in &rule.depends_on {
493 if !rules
494 .get(dep)
495 .map(|h| h.active_version.is_some())
496 .unwrap_or(false)
497 {
498 return Err(RuleError::MissingDependency(dep.clone()));
499 }
500 }
501
502 if let Some(history) = rules.get(&rule.rule_id) {
504 if history.get(rule.version).is_some() {
508 return Err(RuleError::DuplicateVersion {
509 rule_id: rule.rule_id.clone(),
510 version: rule.version,
511 });
512 }
513
514 if let Some(active) = history.active_version {
517 if rule.version <= active {
518 return Err(RuleError::VersionDowngrade {
519 current: active,
520 proposed: rule.version,
521 });
522 }
523 } else if is_reload {
524 }
527 }
528 }
529
530 Ok(())
531 }
532}
533
534fn compute_cap_compatible(rule_cap: &str, device_cap: &str) -> bool {
541 match (parse_sm(rule_cap), parse_sm(device_cap)) {
542 (Some(req), Some(dev)) => dev >= req,
543 _ => rule_cap == device_cap,
544 }
545}
546
/// Extracts the number from an `sm_<digits>` / `SM_<digits>` label.
///
/// Returns `None` when the prefix is missing, when the remainder is empty or
/// contains anything other than ASCII digits, or when the number does not fit
/// in a `u32`.
fn parse_sm(s: &str) -> Option<u32> {
    let digits = s.strip_prefix("sm_").or_else(|| s.strip_prefix("SM_"))?;
    // `u32::from_str` tolerates a leading `+`, which would let a malformed
    // label such as "sm_+90" pass as 90; insist on plain digits.
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    digits.parse().ok()
}
551
552#[cfg(test)]
553mod tests {
554 use super::*;
555 use crate::rules::{ActorConfig, RuleMetadata};
556
557 fn base_rule(rule_id: &str, version: u64) -> CompiledRule {
558 CompiledRule {
559 rule_id: rule_id.to_string(),
560 version,
561 ptx: vec![0xCA, 0xFE, 0xBA, 0xBE],
562 compute_cap: "sm_90".to_string(),
563 depends_on: Vec::new(),
564 signature: None,
565 actor_config: ActorConfig::default(),
566 metadata: RuleMetadata::default(),
567 }
568 }
569
570 fn registry() -> RuleRegistry {
571 RuleRegistry::new(5, Arc::new(NoopSwapBackend))
572 }
573
574 #[tokio::test]
575 async fn register_first_version_activates_immediately() {
576 let reg = registry();
577 let handle = reg
578 .register_rule(base_rule("r1", 1), "sm_90")
579 .await
580 .expect("register");
581 assert_eq!(handle.version, 1);
582 assert_eq!(handle.status, RuleStatus::Active);
583 assert_eq!(reg.get_active("r1").map(|r| r.version), Some(1));
584 }
585
586 #[tokio::test]
587 async fn register_duplicate_version_rejected() {
588 let reg = registry();
589 reg.register_rule(base_rule("r1", 1), "sm_90")
590 .await
591 .expect("initial");
592 let err = reg
593 .register_rule(base_rule("r1", 1), "sm_90")
594 .await
595 .expect_err("duplicate should fail");
596 assert!(matches!(err, RuleError::DuplicateVersion { .. }));
597 }
598
599 #[tokio::test]
600 async fn register_additional_version_stays_registered_not_active() {
601 let reg = registry();
602 reg.register_rule(base_rule("r1", 1), "sm_90")
603 .await
604 .expect("v1");
605 let h2 = reg
606 .register_rule(base_rule("r1", 2), "sm_90")
607 .await
608 .expect("v2");
609 assert_eq!(h2.status, RuleStatus::Registered);
610 assert_eq!(reg.get_active("r1").map(|r| r.version), Some(1));
611 }
612
613 #[tokio::test]
614 async fn reload_higher_version_succeeds() {
615 let reg = registry();
616 reg.register_rule(base_rule("r1", 1), "sm_90")
617 .await
618 .expect("v1");
619 let report = reg
620 .reload_rule(base_rule("r1", 2), "sm_90")
621 .await
622 .expect("reload");
623 assert_eq!(report.from_version, 1);
624 assert_eq!(report.to_version, 2);
625 assert!(report.rollback_available);
626 assert_eq!(reg.get_active("r1").map(|r| r.version), Some(2));
627 }
628
629 #[tokio::test]
630 async fn reload_lower_version_rejected() {
631 let reg = registry();
632 reg.register_rule(base_rule("r1", 5), "sm_90")
633 .await
634 .expect("v5");
635 let err = reg
636 .reload_rule(base_rule("r1", 4), "sm_90")
637 .await
638 .expect_err("downgrade should fail");
639 assert!(matches!(
640 err,
641 RuleError::VersionDowngrade {
642 current: 5,
643 proposed: 4
644 }
645 ));
646 assert_eq!(reg.get_active("r1").map(|r| r.version), Some(5));
647 }
648
649 #[tokio::test]
650 async fn reload_equal_version_rejected() {
651 let reg = registry();
652 reg.register_rule(base_rule("r1", 5), "sm_90")
653 .await
654 .expect("v5");
655 let err = reg
656 .reload_rule(base_rule("r1", 5), "sm_90")
657 .await
658 .expect_err("equal version rejected");
659 assert!(matches!(
662 err,
663 RuleError::DuplicateVersion { .. } | RuleError::VersionDowngrade { .. }
664 ));
665 }
666
667 #[tokio::test]
668 async fn compute_cap_mismatch_rejected() {
669 let reg = registry();
670 let mut rule = base_rule("r1", 1);
671 rule.compute_cap = "sm_100".to_string();
672 let err = reg
673 .register_rule(rule, "sm_90")
674 .await
675 .expect_err("cap mismatch");
676 assert!(matches!(err, RuleError::ComputeCapMismatch { .. }));
677 }
678
679 #[tokio::test]
680 async fn compute_cap_higher_device_is_compatible() {
681 let reg = registry();
682 let mut rule = base_rule("r1", 1);
684 rule.compute_cap = "sm_80".to_string();
685 let handle = reg.register_rule(rule, "sm_90").await.expect("compatible");
686 assert_eq!(handle.status, RuleStatus::Active);
687 }
688
689 #[tokio::test]
690 async fn missing_dependency_rejected() {
691 let reg = registry();
692 let mut rule = base_rule("r1", 1);
693 rule.depends_on = vec!["not_present".to_string()];
694 let err = reg
695 .register_rule(rule, "sm_90")
696 .await
697 .expect_err("missing dep");
698 assert!(matches!(err, RuleError::MissingDependency(_)));
699 }
700
701 #[tokio::test]
702 async fn present_dependency_accepted() {
703 let reg = registry();
704 reg.register_rule(base_rule("dep", 1), "sm_90")
705 .await
706 .expect("dep");
707 let mut rule = base_rule("main", 1);
708 rule.depends_on = vec!["dep".to_string()];
709 reg.register_rule(rule, "sm_90").await.expect("main");
710 }
711
712 struct RejectAllVerifier;
713 impl SignatureVerifier for RejectAllVerifier {
714 fn verify(&self, _rule: &CompiledRule) -> Result<(), RuleError> {
715 Err(RuleError::InvalidSignature)
716 }
717 }
718
719 struct AcceptAllVerifier;
720 impl SignatureVerifier for AcceptAllVerifier {
721 fn verify(&self, _rule: &CompiledRule) -> Result<(), RuleError> {
722 Ok(())
723 }
724 }
725
726 #[tokio::test]
727 async fn signature_rejection() {
728 let reg = RuleRegistry::new(5, Arc::new(NoopSwapBackend))
729 .with_verifier(Arc::new(RejectAllVerifier));
730 let mut rule = base_rule("r1", 1);
731 rule.signature = Some(vec![1, 2, 3]);
732 let err = reg
733 .register_rule(rule, "sm_90")
734 .await
735 .expect_err("bad signature");
736 assert!(matches!(err, RuleError::InvalidSignature));
737 }
738
739 #[tokio::test]
740 async fn signature_required_when_verifier_set() {
741 let reg = RuleRegistry::new(5, Arc::new(NoopSwapBackend))
742 .with_verifier(Arc::new(AcceptAllVerifier));
743 let err = reg
745 .register_rule(base_rule("r1", 1), "sm_90")
746 .await
747 .expect_err("missing signature");
748 assert!(matches!(err, RuleError::InvalidSignature));
749 }
750
751 #[tokio::test]
752 async fn signature_acceptance() {
753 let reg = RuleRegistry::new(5, Arc::new(NoopSwapBackend))
754 .with_verifier(Arc::new(AcceptAllVerifier));
755 let mut rule = base_rule("r1", 1);
756 rule.signature = Some(vec![1]);
757 let handle = reg
758 .register_rule(rule, "sm_90")
759 .await
760 .expect("valid signature");
761 assert_eq!(handle.status, RuleStatus::Active);
762 }
763
764 #[tokio::test]
765 async fn rollback_to_prior_version() {
766 let reg = registry();
767 reg.register_rule(base_rule("r1", 1), "sm_90")
768 .await
769 .expect("v1");
770 reg.reload_rule(base_rule("r1", 2), "sm_90")
771 .await
772 .expect("v2");
773 let report = reg.rollback_rule("r1", 1).await.expect("rollback");
774 assert_eq!(report.from_version, 2);
775 assert_eq!(report.to_version, 1);
776 assert_eq!(reg.get_active("r1").map(|r| r.version), Some(1));
777
778 let history = reg.history("r1");
780 let v2 = history
781 .iter()
782 .find(|h| h.version == 2)
783 .expect("v2 in history");
784 assert_eq!(v2.status, RuleStatus::Rolledback);
785 let v1 = history
786 .iter()
787 .find(|h| h.version == 1)
788 .expect("v1 in history");
789 assert_eq!(v1.status, RuleStatus::Active);
790 }
791
792 #[tokio::test]
793 async fn rollback_to_nonexistent_version_rejected() {
794 let reg = registry();
795 reg.register_rule(base_rule("r1", 1), "sm_90")
796 .await
797 .expect("v1");
798 let err = reg
799 .rollback_rule("r1", 99)
800 .await
801 .expect_err("no such version");
802 assert!(matches!(err, RuleError::RollbackTargetMissing(99)));
803 }
804
805 #[tokio::test]
806 async fn rollback_unknown_rule_rejected() {
807 let reg = registry();
808 let err = reg
809 .rollback_rule("nope", 1)
810 .await
811 .expect_err("unknown rule");
812 assert!(matches!(err, RuleError::NotFound(_)));
813 }
814
815 #[tokio::test]
816 async fn rollback_to_active_is_noop() {
817 let reg = registry();
818 reg.register_rule(base_rule("r1", 1), "sm_90")
819 .await
820 .expect("v1");
821 let report = reg.rollback_rule("r1", 1).await.expect("noop rollback");
822 assert_eq!(report.from_version, 1);
823 assert_eq!(report.to_version, 1);
824 assert_eq!(reg.get_active("r1").map(|r| r.version), Some(1));
825 }
826
827 #[tokio::test]
828 async fn history_retention_evicts_oldest() {
829 let reg = RuleRegistry::new(3, Arc::new(NoopSwapBackend));
830 reg.register_rule(base_rule("r1", 1), "sm_90")
831 .await
832 .expect("v1");
833 for v in 2..=5 {
834 reg.reload_rule(base_rule("r1", v), "sm_90")
835 .await
836 .unwrap_or_else(|e| panic!("reload v{} failed: {:?}", v, e));
837 }
838 let history = reg.history("r1");
839 assert_eq!(history.len(), 3, "retains most recent 3 versions");
840 let versions: Vec<u64> = history.iter().map(|h| h.version).collect();
841 assert_eq!(versions, vec![3, 4, 5]);
843 }
844
845 #[tokio::test]
846 async fn multiple_concurrent_rules() {
847 let reg = registry();
848 reg.register_rule(base_rule("a", 1), "sm_90")
849 .await
850 .expect("a");
851 reg.register_rule(base_rule("b", 7), "sm_90")
852 .await
853 .expect("b");
854 reg.register_rule(base_rule("c", 3), "sm_90")
855 .await
856 .expect("c");
857 assert_eq!(reg.rule_count(), 3);
858 assert_eq!(reg.get_active("a").map(|r| r.version), Some(1));
859 assert_eq!(reg.get_active("b").map(|r| r.version), Some(7));
860 assert_eq!(reg.get_active("c").map(|r| r.version), Some(3));
861 }
862
863 #[tokio::test]
864 async fn list_rules_returns_active_only() {
865 let reg = registry();
866 reg.register_rule(base_rule("a", 1), "sm_90")
867 .await
868 .expect("a");
869 reg.register_rule(base_rule("b", 2), "sm_90")
870 .await
871 .expect("b");
872 let listed = reg.list_rules();
873 assert_eq!(listed.len(), 2);
874 for h in &listed {
875 assert!(matches!(h.status, RuleStatus::Active));
876 }
877 }
878
879 #[tokio::test]
880 async fn get_rule_returns_specific_version() {
881 let reg = registry();
882 reg.register_rule(base_rule("r1", 1), "sm_90")
883 .await
884 .expect("v1");
885 reg.reload_rule(base_rule("r1", 2), "sm_90")
886 .await
887 .expect("v2");
888 assert!(reg.get_rule("r1", 1).is_some());
889 assert!(reg.get_rule("r1", 2).is_some());
890 assert!(reg.get_rule("r1", 3).is_none());
891 }
892
893 #[tokio::test]
894 async fn reload_report_fields_populated() {
895 let reg = registry();
896 reg.register_rule(base_rule("r1", 1), "sm_90")
897 .await
898 .expect("v1");
899 let report = reg
900 .reload_rule(base_rule("r1", 2), "sm_90")
901 .await
902 .expect("v2");
903 assert_eq!(report.rule_id, "r1");
904 assert_eq!(report.from_version, 1);
905 assert_eq!(report.to_version, 2);
906 assert!(report.rollback_available);
907 let _: Duration = report.quiesce_duration;
910 let _: Duration = report.swap_duration;
911 }
912
913 #[tokio::test]
914 async fn rollback_report_marks_no_further_rollback_available() {
915 let reg = registry();
916 reg.register_rule(base_rule("r1", 1), "sm_90")
917 .await
918 .expect("v1");
919 reg.reload_rule(base_rule("r1", 2), "sm_90")
920 .await
921 .expect("v2");
922 let report = reg.rollback_rule("r1", 1).await.expect("rollback");
923 assert!(!report.rollback_available);
924 }
925
926 struct CountingBackend {
927 pre_stage: std::sync::atomic::AtomicU64,
928 quiesce: std::sync::atomic::AtomicU64,
929 swap: std::sync::atomic::AtomicU64,
930 terminate: std::sync::atomic::AtomicU64,
931 drained_per_call: u64,
932 }
933
934 impl CountingBackend {
935 fn new(drained: u64) -> Self {
936 Self {
937 pre_stage: std::sync::atomic::AtomicU64::new(0),
938 quiesce: std::sync::atomic::AtomicU64::new(0),
939 swap: std::sync::atomic::AtomicU64::new(0),
940 terminate: std::sync::atomic::AtomicU64::new(0),
941 drained_per_call: drained,
942 }
943 }
944 }
945
946 impl RuleSwapBackend for CountingBackend {
947 fn pre_stage(&self, _rule: &CompiledRule) -> Result<(), RuleError> {
948 self.pre_stage
949 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
950 Ok(())
951 }
952 fn quiesce(&self, _rule_id: &str, _version: u64) -> Result<u64, RuleError> {
953 self.quiesce
954 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
955 Ok(self.drained_per_call)
956 }
957 fn swap(&self, _rule_id: &str, _new_version: u64) -> Result<(), RuleError> {
958 self.swap.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
959 Ok(())
960 }
961 fn terminate_old(&self, _rule_id: &str, _old_version: u64) -> Result<(), RuleError> {
962 self.terminate
963 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
964 Ok(())
965 }
966 }
967
968 #[tokio::test]
969 async fn backend_called_in_correct_order_for_reload() {
970 let backend = Arc::new(CountingBackend::new(42));
971 let reg = RuleRegistry::new(5, backend.clone());
972 reg.register_rule(base_rule("r1", 1), "sm_90")
973 .await
974 .expect("v1");
975 assert_eq!(
977 backend.pre_stage.load(std::sync::atomic::Ordering::Relaxed),
978 1
979 );
980 assert_eq!(
981 backend.quiesce.load(std::sync::atomic::Ordering::Relaxed),
982 0
983 );
984
985 let report = reg
986 .reload_rule(base_rule("r1", 2), "sm_90")
987 .await
988 .expect("v2");
989 assert_eq!(report.messages_in_flight_during_swap, 42);
990 assert_eq!(
991 backend.pre_stage.load(std::sync::atomic::Ordering::Relaxed),
992 2
993 );
994 assert_eq!(
995 backend.quiesce.load(std::sync::atomic::Ordering::Relaxed),
996 1
997 );
998 assert_eq!(backend.swap.load(std::sync::atomic::Ordering::Relaxed), 1);
999 assert_eq!(
1000 backend.terminate.load(std::sync::atomic::Ordering::Relaxed),
1001 1
1002 );
1003 }
1004
1005 struct FailingSwapBackend;
1006 impl RuleSwapBackend for FailingSwapBackend {
1007 fn pre_stage(&self, _rule: &CompiledRule) -> Result<(), RuleError> {
1008 Err(RuleError::BackendError("pre_stage failed".into()))
1009 }
1010 fn quiesce(&self, _rule_id: &str, _version: u64) -> Result<u64, RuleError> {
1011 Ok(0)
1012 }
1013 fn swap(&self, _rule_id: &str, _new_version: u64) -> Result<(), RuleError> {
1014 Ok(())
1015 }
1016 fn terminate_old(&self, _rule_id: &str, _old_version: u64) -> Result<(), RuleError> {
1017 Ok(())
1018 }
1019 }
1020
1021 #[tokio::test]
1022 async fn backend_pre_stage_failure_propagates_without_state_change() {
1023 let reg = RuleRegistry::new(5, Arc::new(FailingSwapBackend));
1024 let err = reg
1025 .register_rule(base_rule("r1", 1), "sm_90")
1026 .await
1027 .expect_err("pre_stage fails");
1028 assert!(matches!(err, RuleError::BackendError(_)));
1029 assert_eq!(reg.rule_count(), 0);
1030 }
1031
1032 #[tokio::test]
1033 async fn history_lists_all_retained_versions_with_statuses() {
1034 let reg = registry();
1035 reg.register_rule(base_rule("r1", 1), "sm_90")
1036 .await
1037 .expect("v1");
1038 reg.reload_rule(base_rule("r1", 2), "sm_90")
1039 .await
1040 .expect("v2");
1041 reg.reload_rule(base_rule("r1", 3), "sm_90")
1042 .await
1043 .expect("v3");
1044 let history = reg.history("r1");
1045 assert_eq!(history.len(), 3);
1046 let v1 = history
1048 .iter()
1049 .find(|h| h.version == 1)
1050 .expect("v1 in history");
1051 let v2 = history
1052 .iter()
1053 .find(|h| h.version == 2)
1054 .expect("v2 in history");
1055 let v3 = history
1056 .iter()
1057 .find(|h| h.version == 3)
1058 .expect("v3 in history");
1059 assert_eq!(v1.status, RuleStatus::Superseded(2));
1060 assert_eq!(v2.status, RuleStatus::Superseded(3));
1061 assert_eq!(v3.status, RuleStatus::Active);
1062 }
1063
1064 #[tokio::test]
1065 async fn reload_rule_with_no_existing_rule_activates_it() {
1066 let reg = registry();
1067 let report = reg
1069 .reload_rule(base_rule("fresh", 1), "sm_90")
1070 .await
1071 .expect("initial reload");
1072 assert_eq!(report.from_version, 0);
1073 assert_eq!(report.to_version, 1);
1074 assert_eq!(reg.get_active("fresh").map(|r| r.version), Some(1));
1075 }
1076
1077 #[tokio::test]
1078 async fn get_active_none_when_no_rule() {
1079 let reg = registry();
1080 assert!(reg.get_active("missing").is_none());
1081 }
1082
1083 #[tokio::test]
1084 async fn history_empty_for_unknown_rule() {
1085 let reg = registry();
1086 assert!(reg.history("unknown").is_empty());
1087 }
1088
    /// Spot-checks `compute_cap_compatible` for parsable and non-parsable labels.
    #[test]
    fn compute_cap_compatibility_matrix() {
        // Both labels parse: the device number must be >= the rule's number.
        assert!(compute_cap_compatible("sm_80", "sm_90"));
        assert!(compute_cap_compatible("sm_90", "sm_90"));
        assert!(!compute_cap_compatible("sm_90", "sm_80"));
        assert!(!compute_cap_compatible("sm_90", "sm_86"));
        // Labels that do not parse fall back to exact string equality.
        assert!(compute_cap_compatible("custom", "custom"));
        assert!(!compute_cap_compatible("custom", "other"));
    }
1099
1100 #[test]
1101 fn default_history_is_at_least_one() {
1102 let reg = RuleRegistry::new(0, Arc::new(NoopSwapBackend));
1103 assert!(reg.max_history() >= 1);
1104 }
1105
1106 #[tokio::test]
1107 async fn duplicate_version_rejected_on_reload() {
1108 let reg = registry();
1109 reg.register_rule(base_rule("r1", 1), "sm_90")
1110 .await
1111 .expect("v1");
1112 reg.register_rule(base_rule("r1", 2), "sm_90")
1114 .await
1115 .expect("v2");
1116 let err = reg
1117 .reload_rule(base_rule("r1", 2), "sm_90")
1118 .await
1119 .expect_err("duplicate");
1120 assert!(matches!(err, RuleError::DuplicateVersion { .. }));
1121 }
1122}