use crate::error::DrasiError;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// Cheaply clonable, thread-safe flag recording whether initialization has completed.
///
/// The state lives behind an `Arc`, so every clone of a guard observes (and can set)
/// the same underlying flag.
#[derive(Clone, Debug)]
pub struct StateGuard {
    /// Starts `false`; set to `true` by `mark_initialized`. Shared by all clones.
    initialized: Arc<AtomicBool>,
}
impl StateGuard {
    /// Creates a guard in the uninitialized state.
    ///
    /// Clones of the returned guard share the same underlying flag.
    pub fn new() -> Self {
        Self {
            initialized: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Marks the guard (and every clone of it) as initialized.
    ///
    /// Performs a `Release` store, pairing with the `Acquire` load in
    /// [`is_initialized`](Self::is_initialized): writes made before this call are
    /// visible to any thread that subsequently observes the flag as `true`.
    /// This type never resets the flag back to `false`.
    pub fn mark_initialized(&self) {
        self.initialized.store(true, Ordering::Release);
    }

    /// Returns `true` once [`mark_initialized`](Self::mark_initialized) has been
    /// called on this guard or any of its clones.
    pub fn is_initialized(&self) -> bool {
        self.initialized.load(Ordering::Acquire)
    }

    /// Succeeds only if the guard has been marked initialized.
    ///
    /// # Errors
    ///
    /// Returns an error built with `DrasiError::invalid_state` when
    /// [`mark_initialized`](Self::mark_initialized) has not been called yet.
    pub fn require_initialized(&self) -> crate::error::Result<()> {
        // Delegate to `is_initialized` so the load ordering is defined in one place.
        if self.is_initialized() {
            Ok(())
        } else {
            Err(DrasiError::invalid_state(
                "Server must be initialized before this operation",
            ))
        }
    }
}
impl Default for StateGuard {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;
    use std::thread;
    use std::time::{Duration, Instant};

    /// Number of OS threads used by the cross-thread tests.
    ///
    /// Real threads are used (rather than `#[tokio::test]`, whose default runtime is
    /// single-threaded) so the Release/Acquire pairing is exercised across threads.
    const THREADS: usize = 8;

    #[test]
    fn test_initial_state_not_initialized() {
        let guard = StateGuard::new();
        assert!(!guard.is_initialized());
    }

    #[test]
    fn test_mark_initialized() {
        let guard = StateGuard::new();
        guard.mark_initialized();
        assert!(guard.is_initialized());
    }

    #[test]
    fn test_require_initialized_fails_when_not_initialized() {
        let guard = StateGuard::new();
        match guard.require_initialized() {
            Err(DrasiError::InvalidState { message }) => {
                assert!(message.contains("initialized"));
            }
            _ => panic!("Expected InvalidState error"),
        }
    }

    #[test]
    fn test_require_initialized_succeeds_when_initialized() {
        let guard = StateGuard::new();
        guard.mark_initialized();
        assert!(guard.require_initialized().is_ok());
    }

    #[test]
    fn test_clone_shares_state() {
        let guard1 = StateGuard::new();
        let guard2 = guard1.clone();
        guard1.mark_initialized();
        assert!(guard1.is_initialized());
        assert!(guard2.is_initialized());
    }

    #[test]
    fn test_concurrent_reads() {
        let guard = StateGuard::new();
        guard.mark_initialized();
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let guard = guard.clone();
                thread::spawn(move || {
                    for _ in 0..100 {
                        assert!(guard.is_initialized());
                        assert!(guard.require_initialized().is_ok());
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("reader thread panicked");
        }
    }

    #[test]
    fn test_initialization_visibility() {
        let guard = StateGuard::new();
        assert!(!guard.is_initialized());
        let started = Arc::new(AtomicUsize::new(0));
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let guard = guard.clone();
                let started = Arc::clone(&started);
                thread::spawn(move || {
                    started.fetch_add(1, Ordering::Release);
                    let deadline = Instant::now() + Duration::from_secs(5);
                    while !guard.is_initialized() {
                        assert!(
                            Instant::now() < deadline,
                            "Timeout waiting for initialization to be visible"
                        );
                        thread::yield_now();
                    }
                    guard.require_initialized().unwrap();
                })
            })
            .collect();
        // Publish the flag only after every reader thread is running, so the store
        // must propagate to threads that are already polling it.
        while started.load(Ordering::Acquire) < THREADS {
            thread::yield_now();
        }
        guard.mark_initialized();
        for handle in handles {
            handle.join().expect("reader thread panicked");
        }
    }
}