use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use thiserror::Error;
use crate::ontology::SigmaSnapshot;
#[derive(Error, Debug)]
pub enum PromotionError {
#[error("Lock poisoned: {0}")]
LockPoisoned(&'static str),
}
/// Holds the currently active `SigmaSnapshot` and swaps in replacements.
///
/// Readers obtain the active snapshot via `get_current`; writers replace it
/// via `promote`, which holds the write lock only briefly to swap the handle.
pub struct AtomicSnapshotPromoter {
// Active handle. The lock guards only the `Arc` pointer; readers that already
// cloned the previous handle keep it alive after a swap.
current: Arc<RwLock<Arc<SnapshotHandle>>>,
// Number of successful `promote` calls; starts at 0.
promotion_count: AtomicUsize,
// Wall-clock time (ns since the UNIX epoch) stored by `promote`; 0 until the
// first promotion.
last_promotion_ns: AtomicUsize,
}
/// Pairs a snapshot with a manually maintained count of outstanding users.
///
/// The count starts at 1 for the promoter's own reference. It is bumped for
/// each guard handed out and decremented when a guard drops or when the
/// promoter replaces this handle. It is bookkeeping only: the snapshot's
/// lifetime is governed by the surrounding `Arc`s.
pub struct SnapshotHandle {
    snapshot: Arc<SigmaSnapshot>,
    reference_count: AtomicUsize,
}

impl SnapshotHandle {
    /// Wraps `snapshot` with an initial reference count of one.
    fn new(snapshot: Arc<SigmaSnapshot>) -> Self {
        let reference_count = AtomicUsize::new(1);
        Self {
            snapshot,
            reference_count,
        }
    }

    /// Records one more outstanding user of this handle.
    fn increment_refs(&self) {
        let _previous = self.reference_count.fetch_add(1, Ordering::AcqRel);
    }

    /// Records that a user released this handle.
    ///
    /// Returns the count *before* the decrement. Like every atomic `fetch_sub`,
    /// the counter wraps instead of panicking on underflow.
    fn decrement_refs(&self) -> usize {
        let counter = &self.reference_count;
        counter.fetch_sub(1, Ordering::AcqRel)
    }
}
impl AtomicSnapshotPromoter {
    /// Creates a promoter whose active snapshot is `initial_snapshot`.
    ///
    /// Metrics start at zero: no promotions and a last-promotion timestamp of `0`.
    pub fn new(initial_snapshot: Arc<SigmaSnapshot>) -> Self {
        let handle = Arc::new(SnapshotHandle::new(initial_snapshot));
        Self {
            current: Arc::new(RwLock::new(handle)),
            promotion_count: AtomicUsize::new(0),
            last_promotion_ns: AtomicUsize::new(0),
        }
    }

    /// Returns a guard on the currently active snapshot.
    ///
    /// The guard keeps that snapshot alive even if a later `promote` replaces it.
    ///
    /// # Errors
    ///
    /// Returns [`PromotionError::LockPoisoned`] if the snapshot lock was
    /// poisoned by a panicking thread.
    pub fn get_current(&self) -> Result<SnapshotGuard, PromotionError> {
        let handle = {
            let guard = self
                .current
                .read()
                .map_err(|_| PromotionError::LockPoisoned("current snapshot lock poisoned"))?;
            let handle = Arc::clone(&*guard);
            // Register the reader while the read lock is still held. Otherwise a
            // concurrent `promote` could swap and decrement this handle between
            // the clone and the increment, transiently dropping its count to 0.
            handle.increment_refs();
            handle
        };
        Ok(SnapshotGuard { handle })
    }

    /// Replaces the active snapshot with `new_snapshot`.
    ///
    /// Readers already holding a guard on the previous snapshot keep seeing it.
    /// Subsequent `get_current` calls see `new_snapshot`. The returned
    /// `promotion_count` is the ordinal of *this* promotion, and `duration_ns`
    /// is measured with a monotonic clock.
    ///
    /// # Errors
    ///
    /// Returns [`PromotionError::LockPoisoned`] if the snapshot lock was
    /// poisoned; in that case neither the active snapshot nor the metrics change.
    pub fn promote(
        &self,
        new_snapshot: Arc<SigmaSnapshot>,
    ) -> Result<PromotionResult, PromotionError> {
        // `Instant` is monotonic. Unlike a difference of wall-clock readings, it
        // cannot go negative (and underflow `usize`) if the system clock is adjusted.
        let started = std::time::Instant::now();
        let new_handle = Arc::new(SnapshotHandle::new(new_snapshot));
        let (old_handle, promotion_count) = {
            let mut current_guard = self
                .current
                .write()
                .map_err(|_| PromotionError::LockPoisoned("current snapshot lock poisoned"))?;
            let old = std::mem::replace(&mut *current_guard, new_handle);
            // Publish metrics inside the critical section so the count and
            // timestamp follow the same order as the swaps. Take this promotion's
            // ordinal from `fetch_add` itself rather than from a separate, racy
            // load; `wrapping_add` mirrors the atomic's own wrap-around.
            let count = self
                .promotion_count
                .fetch_add(1, Ordering::Relaxed)
                .wrapping_add(1);
            self.last_promotion_ns.store(get_time_ns(), Ordering::Relaxed);
            (old, count)
        };
        // Saturate rather than truncate where `usize` is narrower than `u128`.
        let duration_ns = usize::try_from(started.elapsed().as_nanos()).unwrap_or(usize::MAX);
        // The promoter itself no longer references the replaced handle.
        old_handle.decrement_refs();
        Ok(PromotionResult {
            duration_ns,
            promotion_count,
        })
    }

    /// Returns a point-in-time copy of the promotion counters.
    ///
    /// The two fields are loaded independently, so a concurrent promotion may
    /// be reflected in one but not the other.
    pub fn metrics(&self) -> PromotionMetrics {
        PromotionMetrics {
            total_promotions: self.promotion_count.load(Ordering::Relaxed),
            last_promotion_ns: self.last_promotion_ns.load(Ordering::Relaxed),
        }
    }
}
unsafe impl Send for AtomicSnapshotPromoter {}
unsafe impl Sync for AtomicSnapshotPromoter {}
/// A reader's handle on one snapshot, obtained from `get_current`.
///
/// The guard keeps its snapshot alive even after the promoter has moved on to
/// a newer one. Dropping it releases the reference count it accounted for.
pub struct SnapshotGuard {
    handle: Arc<SnapshotHandle>,
}

impl SnapshotGuard {
    /// Returns a shared pointer to the guarded snapshot (an `Arc` clone, no deep copy).
    pub fn snapshot(&self) -> Arc<SigmaSnapshot> {
        Arc::clone(&self.handle.snapshot)
    }
}

impl Drop for SnapshotGuard {
    /// Releases this guard's entry in the handle's reference count.
    fn drop(&mut self) {
        self.handle.decrement_refs();
    }
}

// Thread-safety follows from the fields (`Arc<SnapshotHandle>`) instead of an
// `unsafe impl`, so it holds only if the snapshot type really is `Send + Sync`.
// This compile-time check enforces that rather than silently assuming it.
const _: fn() = || {
    fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<SnapshotGuard>();
};
/// Outcome of a single `AtomicSnapshotPromoter::promote` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromotionResult {
/// Time taken by the promotion, in nanoseconds.
pub duration_ns: usize,
/// Promotion counter value reported by `promote` for this call.
pub promotion_count: usize,
}
/// Copy of the promoter's counters, as returned by `metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromotionMetrics {
/// Total number of successful promotions so far.
pub total_promotions: usize,
/// Wall-clock timestamp (ns since the UNIX epoch) stored by the latest
/// promotion; `0` if no promotion has happened yet.
pub last_promotion_ns: usize,
}
/// Current wall-clock time in nanoseconds since the UNIX epoch.
///
/// A clock reading before the epoch yields `0`. The `u128` nanosecond count is
/// cast with `as usize`, which truncates on targets narrower than 64 bits.
fn get_time_ns() -> usize {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_nanos() as usize)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal snapshot whose only distinguishing field is `version`.
    fn create_test_snapshot(version: &str) -> Arc<SigmaSnapshot> {
        Arc::new(SigmaSnapshot::new(
            None,
            vec![],
            version.to_string(),
            "sig".to_string(),
            Default::default(),
        ))
    }

    /// Reads the bookkeeping reference count of the promoter's active handle.
    fn current_refs(promoter: &AtomicSnapshotPromoter) -> usize {
        promoter
            .current
            .read()
            .unwrap()
            .reference_count
            .load(Ordering::SeqCst)
    }

    #[test]
    fn test_promoter_creation() {
        let snap = create_test_snapshot("1.0.0");
        let promoter = AtomicSnapshotPromoter::new(snap.clone());
        let current = promoter.get_current().unwrap();
        assert_eq!(current.snapshot().version, "1.0.0");
    }

    #[test]
    fn test_atomic_promotion() {
        let snap1 = create_test_snapshot("1.0.0");
        let promoter = AtomicSnapshotPromoter::new(snap1);
        let snap2 = create_test_snapshot("2.0.0");
        let result = promoter.promote(snap2).unwrap();
        assert_eq!(result.promotion_count, 1);
        // Deliberately loose bound: a swap takes far less than this, but tight
        // timing limits make unit tests flaky on loaded CI machines.
        assert!(result.duration_ns < 1_000_000_000);
        let current = promoter.get_current().unwrap();
        assert_eq!(current.snapshot().version, "2.0.0");
    }

    #[test]
    fn test_promotion_metrics() {
        let snap1 = create_test_snapshot("1.0.0");
        let promoter = AtomicSnapshotPromoter::new(snap1);
        let initial = promoter.metrics();
        assert_eq!(initial.total_promotions, 0);
        assert_eq!(initial.last_promotion_ns, 0);
        let snap2 = create_test_snapshot("2.0.0");
        promoter.promote(snap2).unwrap();
        let metrics = promoter.metrics();
        assert_eq!(metrics.total_promotions, 1);
        assert!(metrics.last_promotion_ns > 0);
    }

    #[test]
    fn test_multiple_promotions() {
        let snap1 = create_test_snapshot("1.0.0");
        let promoter = Arc::new(AtomicSnapshotPromoter::new(snap1));
        for i in 2..=5 {
            let snap = create_test_snapshot(&format!("{}.0.0", i));
            let result = promoter.promote(snap).unwrap();
            // Sequential promotions report consecutive ordinals 1, 2, 3, 4.
            assert_eq!(result.promotion_count, i - 1);
        }
        let current = promoter.get_current().unwrap();
        assert_eq!(current.snapshot().version, "5.0.0");
        let metrics = promoter.metrics();
        assert_eq!(metrics.total_promotions, 4);
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_concurrent_reads() {
        let snap1 = create_test_snapshot("1.0.0");
        let promoter = Arc::new(AtomicSnapshotPromoter::new(snap1));
        let mut handles = vec![];
        for _ in 0..100 {
            let p = promoter.clone();
            handles.push(std::thread::spawn(move || {
                let guard = p.get_current().unwrap();
                assert!(!guard.snapshot().version.is_empty());
            }));
        }
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        // Every reader's guard has been dropped; only the promoter's own
        // reference remains.
        assert_eq!(current_refs(&promoter), 1);
    }

    #[test]
    fn test_snapshot_guard_raii() {
        let snap = create_test_snapshot("1.0.0");
        let promoter = AtomicSnapshotPromoter::new(snap);
        // The promoter itself accounts for one reference.
        assert_eq!(current_refs(&promoter), 1);
        {
            let _guard1 = promoter.get_current().unwrap();
            let _guard2 = promoter.get_current().unwrap();
            assert_eq!(current_refs(&promoter), 3);
        }
        // Dropping the guards releases their references.
        assert_eq!(current_refs(&promoter), 1);
        let current = promoter.get_current().unwrap();
        assert_eq!(current.snapshot().version, "1.0.0");
    }

    #[test]
    fn test_promotion_with_guards() {
        let snap1 = create_test_snapshot("1.0.0");
        let promoter = Arc::new(AtomicSnapshotPromoter::new(snap1));
        let guard1 = promoter.get_current().unwrap();
        assert_eq!(guard1.snapshot().version, "1.0.0");
        let snap2 = create_test_snapshot("2.0.0");
        promoter.promote(snap2).unwrap();
        assert_eq!(guard1.snapshot().version, "1.0.0");
        // After the swap, only `guard1` still accounts for the old handle.
        assert_eq!(guard1.handle.reference_count.load(Ordering::SeqCst), 1);
        let guard2 = promoter.get_current().unwrap();
        assert_eq!(guard2.snapshot().version, "2.0.0");
        assert_eq!(current_refs(&promoter), 2);
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_safe_concurrent_promotion_and_reads() {
        let snap1 = create_test_snapshot("1.0.0");
        let promoter = Arc::new(AtomicSnapshotPromoter::new(snap1));
        let mut handles = vec![];
        for i in 2..=11 {
            let p = promoter.clone();
            handles.push(std::thread::spawn(move || {
                let snap = create_test_snapshot(&format!("{}.0.0", i));
                p.promote(snap).unwrap();
            }));
        }
        for _ in 0..50 {
            let p = promoter.clone();
            handles.push(std::thread::spawn(move || {
                let guard = p.get_current().unwrap();
                assert!(!guard.snapshot().version.is_empty());
            }));
        }
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        let metrics = promoter.metrics();
        assert_eq!(metrics.total_promotions, 10);
        // All guards are gone, so the final active handle is referenced only
        // by the promoter.
        assert_eq!(current_refs(&promoter), 1);
    }
}