use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use crate::plugin::PluginId;
/// Stage in a plugin's lifecycle.
///
/// The explicit `u8` discriminants are significant: `PluginLifecycle` stores
/// the state as this byte in an `AtomicU8` and decodes it with `from_u8`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum LifecycleState {
    /// State a `PluginLifecycle` is created in.
    Loaded = 0,
    Linked = 1,
    Initialized = 2,
    /// The plugin is live.
    Active = 3,
    /// The plugin is being wound down.
    Draining = 4,
    /// Terminal state.
    Removed = 5,
}

impl LifecycleState {
    /// Decodes a raw discriminant byte back into a state.
    ///
    /// Bytes `0..=4` map to the variant with that discriminant; every other
    /// byte (the legitimate `5` as well as any out-of-range value) decodes as
    /// `LifecycleState::Removed`.
    fn from_u8(v: u8) -> Self {
        // Indexed by discriminant; anything past the end falls back to Removed.
        const BY_DISCRIMINANT: [LifecycleState; 5] = [
            LifecycleState::Loaded,
            LifecycleState::Linked,
            LifecycleState::Initialized,
            LifecycleState::Active,
            LifecycleState::Draining,
        ];
        BY_DISCRIMINANT
            .get(usize::from(v))
            .copied()
            .unwrap_or(Self::Removed)
    }
}
/// Lifecycle state of a single plugin, readable and updatable through `&self`.
///
/// The current `LifecycleState` is kept as its `u8` discriminant in an
/// `AtomicU8`; every access uses `Ordering::SeqCst`.
#[derive(Debug)]
pub struct PluginLifecycle {
    /// Identifier of the plugin this lifecycle belongs to.
    plugin: PluginId,
    /// Current state, stored as a `LifecycleState` discriminant.
    state: AtomicU8,
}

impl PluginLifecycle {
    /// Creates a lifecycle for `plugin`, starting in `LifecycleState::Loaded`.
    #[must_use]
    pub fn new(plugin: PluginId) -> Self {
        Self {
            plugin,
            state: AtomicU8::new(LifecycleState::Loaded as u8),
        }
    }

    /// Returns the current lifecycle state.
    #[must_use]
    pub fn state(&self) -> LifecycleState {
        LifecycleState::from_u8(self.state.load(Ordering::SeqCst))
    }

    /// Returns the identifier of the plugin this lifecycle tracks.
    #[must_use]
    pub fn plugin(&self) -> &PluginId {
        &self.plugin
    }

    /// Moves to the next state in the fixed order
    /// `Loaded → Linked → Initialized → Active → Draining → Removed`
    /// and returns the state that was stored.
    ///
    /// `Removed` is terminal: advancing from it stores and returns `Removed`.
    pub fn advance(&self) -> LifecycleState {
        // Read-modify-write as one compare-and-swap loop. A plain load followed
        // by a store could lose a transition when callers race (two advances
        // from Active would both store Draining), or overwrite a concurrent
        // `set` with a successor computed from a stale value.
        let mut current = self.state.load(Ordering::SeqCst);
        loop {
            let next = Self::successor(LifecycleState::from_u8(current));
            match self.state.compare_exchange_weak(
                current,
                next as u8,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => return next,
                // The state changed underneath us (or the weak CAS failed
                // spuriously); retry from the value actually observed.
                Err(observed) => current = observed,
            }
        }
    }

    /// Successor of `state` in the lifecycle order; `Removed` maps to itself.
    fn successor(state: LifecycleState) -> LifecycleState {
        match state {
            LifecycleState::Loaded => LifecycleState::Linked,
            LifecycleState::Linked => LifecycleState::Initialized,
            LifecycleState::Initialized => LifecycleState::Active,
            LifecycleState::Active => LifecycleState::Draining,
            LifecycleState::Draining => LifecycleState::Removed,
            LifecycleState::Removed => LifecycleState::Removed,
        }
    }

    /// Unconditionally stores `s`, bypassing the ordering enforced by `advance`.
    pub fn set(&self, s: LifecycleState) {
        self.state.store(s as u8, Ordering::SeqCst);
    }

    /// Returns `true` when the current state is exactly `LifecycleState::Active`.
    #[must_use]
    pub fn is_active(&self) -> bool {
        self.state() == LifecycleState::Active
    }

    /// Returns `true` once the plugin is `Draining` or `Removed`.
    #[must_use]
    pub fn is_winding_down(&self) -> bool {
        matches!(
            self.state(),
            LifecycleState::Draining | LifecycleState::Removed
        )
    }
}
/// Reference-counted handle to a `PluginLifecycle`, for sharing one lifecycle
/// among several owners.
pub type SharedLifecycle = Arc<PluginLifecycle>;
#[derive(Debug)]
pub struct EpochFencedReload {
old: Arc<PluginLifecycle>,
}
impl EpochFencedReload {
#[must_use]
pub fn new(old: Arc<PluginLifecycle>) -> Self {
Self { old }
}
pub fn begin_drain(&self) -> Result<(), DrainError> {
if !self.old.is_active() {
return Err(DrainError::NotActive {
current: self.old.state(),
});
}
let new = self.old.advance();
if new != LifecycleState::Draining {
return Err(DrainError::UnexpectedTransition { reached: new });
}
Ok(())
}
pub fn wait_for_drain(
&self,
threshold: usize,
poll_interval: std::time::Duration,
max_wait: std::time::Duration,
) -> Result<(), DrainError> {
let start = std::time::Instant::now();
loop {
let count = Arc::strong_count(&self.old);
if count <= threshold {
return Ok(());
}
if start.elapsed() >= max_wait {
return Err(DrainError::Timeout {
waited: start.elapsed(),
strong_count: count,
});
}
std::thread::sleep(poll_interval);
}
}
pub fn finalize(&self) {
while self.old.state() != LifecycleState::Removed {
self.old.advance();
}
}
#[must_use]
pub fn old_lifecycle(&self) -> &Arc<PluginLifecycle> {
&self.old
}
}
/// Errors reported while draining a plugin lifecycle.
#[derive(Debug, thiserror::Error)]
pub enum DrainError {
/// Draining was requested while the plugin was not in the `Active` state.
#[error("cannot drain: plugin is in {current:?}, not Active")]
NotActive {
/// State the plugin was in when draining was requested.
current: LifecycleState,
},
/// A lifecycle transition landed on a state other than the expected one.
#[error("unexpected lifecycle transition: reached {reached:?}")]
UnexpectedTransition {
/// State actually reached by the transition.
reached: LifecycleState,
},
/// The `Arc` strong count did not fall to the requested threshold in time.
#[error(
"drain timed out after {waited:?}; strong_count remained {strong_count} (threshold not reached)"
)]
Timeout {
/// Time elapsed since waiting began when the timeout was reported.
waited: std::time::Duration,
/// Strong count observed at the final check.
strong_count: usize,
},
}
// Unit tests for the lifecycle state machine and the drain driver.
#[cfg(test)]
mod tests {
use super::*;
// A new lifecycle begins in Loaded.
#[test]
fn lifecycle_starts_at_loaded() {
let l = PluginLifecycle::new(PluginId::new("x"));
assert_eq!(l.state(), LifecycleState::Loaded);
}
// Walks the full state order and checks that Removed is terminal.
#[test]
fn advance_progresses_through_states() {
let l = PluginLifecycle::new(PluginId::new("x"));
assert_eq!(l.advance(), LifecycleState::Linked);
assert_eq!(l.advance(), LifecycleState::Initialized);
assert_eq!(l.advance(), LifecycleState::Active);
assert!(l.is_active());
assert_eq!(l.advance(), LifecycleState::Draining);
assert!(l.is_winding_down());
assert_eq!(l.advance(), LifecycleState::Removed);
assert!(l.is_winding_down());
assert_eq!(l.advance(), LifecycleState::Removed);
}
// `set` jumps straight to the given state.
#[test]
fn set_is_explicit_state_override() {
let l = PluginLifecycle::new(PluginId::new("x"));
l.set(LifecycleState::Active);
assert!(l.is_active());
}
// An Active plugin moves to Draining when draining begins.
#[test]
fn epoch_drain_advances_state_from_active() {
let l = Arc::new(PluginLifecycle::new(PluginId::new("x")));
l.set(LifecycleState::Active);
let driver = EpochFencedReload::new(Arc::clone(&l));
driver.begin_drain().unwrap();
assert_eq!(l.state(), LifecycleState::Draining);
}
// Draining a plugin that is not Active reports the state it was in.
#[test]
fn epoch_drain_rejects_non_active_state() {
let l = Arc::new(PluginLifecycle::new(PluginId::new("x")));
let driver = EpochFencedReload::new(l);
let err = driver.begin_drain().unwrap_err();
match err {
DrainError::NotActive { current } => {
assert_eq!(current, LifecycleState::Loaded);
}
other => panic!("expected NotActive, got {other:?}"),
}
}
// Strong count here is 2 (`l` plus the driver's clone), which already meets
// threshold 2, so the wait returns on the first check.
#[test]
fn wait_for_drain_returns_immediately_when_below_threshold() {
let l = Arc::new(PluginLifecycle::new(PluginId::new("x")));
l.set(LifecycleState::Active);
let driver = EpochFencedReload::new(Arc::clone(&l));
driver.begin_drain().unwrap();
driver
.wait_for_drain(
2,
std::time::Duration::from_millis(1),
std::time::Duration::from_millis(100),
)
.expect("should drain immediately");
}
// `l`, `extra` and the driver's clone keep the strong count at 3, above
// threshold 1, so the wait must time out. `extra` is dropped only after the
// wait so it is alive for the whole polling window.
#[test]
fn wait_for_drain_times_out_when_references_persist() {
let l = Arc::new(PluginLifecycle::new(PluginId::new("x")));
l.set(LifecycleState::Active);
let extra = Arc::clone(&l);
let driver = EpochFencedReload::new(Arc::clone(&l));
driver.begin_drain().unwrap();
let err = driver
.wait_for_drain(
1,
std::time::Duration::from_millis(1),
std::time::Duration::from_millis(20),
)
.unwrap_err();
match err {
DrainError::Timeout {
waited: _,
strong_count,
} => {
assert!(strong_count >= 3);
}
other => panic!("expected Timeout, got {other:?}"),
}
drop(extra); }
// After draining, finalize leaves the plugin in Removed.
#[test]
fn finalize_advances_to_removed() {
let l = Arc::new(PluginLifecycle::new(PluginId::new("x")));
l.set(LifecycleState::Active);
let driver = EpochFencedReload::new(Arc::clone(&l));
driver.begin_drain().unwrap();
driver.finalize();
assert_eq!(l.state(), LifecycleState::Removed);
}
// Calling finalize on an already-Removed plugin keeps it Removed.
#[test]
fn finalize_is_idempotent_at_removed() {
let l = Arc::new(PluginLifecycle::new(PluginId::new("x")));
l.set(LifecycleState::Removed);
let driver = EpochFencedReload::new(Arc::clone(&l));
driver.finalize();
driver.finalize();
assert_eq!(l.state(), LifecycleState::Removed);
}
// Full reload: begin drain, release the extra host handle, wait until only
// `l` and the driver's clone remain (count 2), then finalize.
#[test]
fn epoch_fenced_reload_end_to_end() {
let l = Arc::new(PluginLifecycle::new(PluginId::new("plugin.geo")));
l.set(LifecycleState::Active);
let host_arc = Arc::clone(&l);
let driver = EpochFencedReload::new(Arc::clone(&l));
driver.begin_drain().expect("drain begin");
drop(host_arc);
driver
.wait_for_drain(
2,
std::time::Duration::from_millis(1),
std::time::Duration::from_secs(1),
)
.expect("wait_for_drain");
driver.finalize();
assert_eq!(l.state(), LifecycleState::Removed);
}
}