use std::sync::Arc;
use crate::ident::PluginEventKind;
/// Errors returned when a [`StatefulPlugin`] fails to capture or re-apply its state.
///
/// Every variant carries a `plugin` string naming the plugin involved; the
/// in-file test plugin fills it from its `PluginEventKind::as_str()`.
#[derive(Debug, thiserror::Error)]
pub enum StatefulPluginError {
/// Capturing the plugin's state failed; `details` is a free-form explanation.
#[error("stateful plugin '{plugin}' failed to snapshot: {details}")]
SnapshotFailed { plugin: String, details: String },
/// Re-applying a snapshot failed (e.g. a malformed payload); `details` explains why.
#[error("stateful plugin '{plugin}' failed to restore: {details}")]
RestoreFailed { plugin: String, details: String },
/// The snapshot's `version` is not one the plugin can decode.
///
/// `version` is the value found in the snapshot; `expected` lists the versions the
/// plugin accepts (rendered with `Debug`, e.g. `[1]`, in the message).
#[error(
"stateful plugin '{plugin}' received unsupported snapshot version: {version} (expected one of {expected:?})"
)]
UnsupportedVersion {
plugin: String,
version: u32,
expected: Vec<u32>,
},
}
/// Shorthand for a `Result` whose error type is [`StatefulPluginError`].
pub type StatefulPluginResult<T> = std::result::Result<T, StatefulPluginError>;
/// A versioned, opaque capture of a [`StatefulPlugin`]'s state.
///
/// This type never interprets `bytes`; their layout is defined by the plugin and
/// tagged with `version` so the plugin can reject layouts it does not understand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StatefulPluginSnapshot {
    /// Identifier of the plugin the snapshot belongs to (the in-file
    /// implementations set it to their own `id()`).
    pub id: PluginEventKind,
    /// Plugin-chosen format version of `bytes`.
    pub version: u32,
    /// Plugin-defined serialized state.
    pub bytes: Vec<u8>,
}

impl StatefulPluginSnapshot {
    /// Assembles a snapshot from an id, a format version and a payload.
    #[must_use]
    pub const fn new(id: PluginEventKind, version: u32, bytes: Vec<u8>) -> Self {
        Self { bytes, version, id }
    }

    /// Borrows the raw payload without copying it.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}
/// A plugin whose internal state can be captured as a [`StatefulPluginSnapshot`]
/// and later re-applied.
///
/// The `Send + Sync` supertraits let trait objects of this type be shared across
/// threads.
pub trait StatefulPlugin: Send + Sync {
    /// Returns this plugin's identifier.
    fn id(&self) -> PluginEventKind;

    /// Captures the plugin's current state.
    ///
    /// # Errors
    ///
    /// Returns a [`StatefulPluginError`] when the state cannot be captured.
    fn snapshot(&self) -> StatefulPluginResult<StatefulPluginSnapshot>;

    /// Re-applies a previously captured snapshot.
    ///
    /// The default implementation discards `snapshot` and reports success, which
    /// suits plugins that have nothing to restore.
    ///
    /// # Errors
    ///
    /// The default never fails; overriding implementations may return a
    /// [`StatefulPluginError`] (for example on an unsupported version).
    fn restore_snapshot(&self, snapshot: StatefulPluginSnapshot) -> StatefulPluginResult<()> {
        drop(snapshot);
        Ok(())
    }
}
/// Cheaply cloneable, shared handle to a type-erased [`StatefulPlugin`].
///
/// The handle wraps an `Arc<dyn StatefulPlugin>`; the derived `Clone` clones that
/// `Arc`, so every clone refers to the same plugin instance.
#[derive(Clone)]
pub struct StatefulPluginHandle(pub Arc<dyn StatefulPlugin>);

impl StatefulPluginHandle {
    /// Moves a concrete plugin into a freshly allocated [`Arc`] and wraps it.
    #[must_use]
    pub fn new<P: StatefulPlugin + 'static>(plugin: P) -> Self {
        Self(Arc::new(plugin))
    }

    /// Wraps a plugin that is already shared behind an [`Arc`], without re-allocating.
    #[must_use]
    pub fn from_arc(plugin: Arc<dyn StatefulPlugin>) -> Self {
        Self(plugin)
    }

    /// Borrows the wrapped plugin as a trait object.
    #[must_use]
    pub fn as_dyn(&self) -> &(dyn StatefulPlugin + 'static) {
        self.0.as_ref()
    }
}

// `dyn StatefulPlugin` is not `Debug`, so the handle cannot derive it. Without a manual
// impl, no type that embeds a handle could derive `Debug` either. The plugin is
// identified by its `id()`, whose type is `Debug` because it is a field of the
// `#[derive(Debug)]` `StatefulPluginSnapshot`.
impl std::fmt::Debug for StatefulPluginHandle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StatefulPluginHandle")
            .field("id", &self.0.id())
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    const TEST_ID: PluginEventKind = PluginEventKind::from_static("bmux.test/stateful");

    // Test plugin whose entire state is one `u32`, serialized as four little-endian
    // bytes under snapshot version 1.
    struct Counter {
        value: Mutex<u32>,
    }

    impl StatefulPlugin for Counter {
        fn id(&self) -> PluginEventKind {
            TEST_ID
        }

        fn snapshot(&self) -> StatefulPluginResult<StatefulPluginSnapshot> {
            let value = *self
                .value
                .lock()
                .map_err(|_| StatefulPluginError::SnapshotFailed {
                    plugin: TEST_ID.as_str().to_string(),
                    details: "poisoned".into(),
                })?;
            Ok(StatefulPluginSnapshot::new(
                TEST_ID,
                1,
                value.to_le_bytes().to_vec(),
            ))
        }

        fn restore_snapshot(&self, snapshot: StatefulPluginSnapshot) -> StatefulPluginResult<()> {
            if snapshot.version != 1 {
                return Err(StatefulPluginError::UnsupportedVersion {
                    plugin: TEST_ID.as_str().to_string(),
                    version: snapshot.version,
                    expected: vec![1],
                });
            }
            // Decode before taking the lock so a malformed payload leaves the current
            // value untouched.
            let bytes: [u8; 4] =
                snapshot
                    .bytes
                    .try_into()
                    .map_err(|_| StatefulPluginError::RestoreFailed {
                        plugin: TEST_ID.as_str().to_string(),
                        details: "expected 4 bytes".into(),
                    })?;
            *self
                .value
                .lock()
                .map_err(|_| StatefulPluginError::RestoreFailed {
                    plugin: TEST_ID.as_str().to_string(),
                    details: "poisoned".into(),
                })? = u32::from_le_bytes(bytes);
            Ok(())
        }
    }

    #[test]
    fn snapshot_round_trips_through_restore() {
        let counter = Counter {
            value: Mutex::new(42),
        };
        let snap = counter.snapshot().expect("snapshot");
        assert_eq!(snap.id, TEST_ID);
        assert_eq!(snap.version, 1);
        let target = Counter {
            value: Mutex::new(0),
        };
        target.restore_snapshot(snap).expect("restore");
        assert_eq!(*target.value.lock().expect("lock"), 42);
    }

    #[test]
    fn unknown_version_is_reported() {
        let counter = Counter {
            value: Mutex::new(0),
        };
        let bad = StatefulPluginSnapshot::new(TEST_ID, 99, vec![0, 0, 0, 0]);
        let err = counter.restore_snapshot(bad).unwrap_err();
        // Pin the user-facing message produced by the `#[error(...)]` attribute.
        assert_eq!(
            err.to_string(),
            format!(
                "stateful plugin '{}' received unsupported snapshot version: 99 (expected one of [1])",
                TEST_ID.as_str()
            )
        );
        assert!(matches!(
            err,
            StatefulPluginError::UnsupportedVersion { version: 99, .. }
        ));
    }

    #[test]
    fn truncated_payload_is_rejected_without_touching_state() {
        let counter = Counter {
            value: Mutex::new(5),
        };
        let short = StatefulPluginSnapshot::new(TEST_ID, 1, vec![1, 2, 3]);
        let err = counter.restore_snapshot(short).unwrap_err();
        assert!(matches!(err, StatefulPluginError::RestoreFailed { .. }));
        assert_eq!(*counter.value.lock().expect("lock"), 5);
    }

    #[test]
    fn default_restore_is_no_op() {
        struct EmptyStateful;
        impl StatefulPlugin for EmptyStateful {
            fn id(&self) -> PluginEventKind {
                TEST_ID
            }
            fn snapshot(&self) -> StatefulPluginResult<StatefulPluginSnapshot> {
                Ok(StatefulPluginSnapshot::new(TEST_ID, 1, Vec::new()))
            }
        }
        let plugin = EmptyStateful;
        let snap = plugin.snapshot().expect("snapshot");
        plugin
            .restore_snapshot(snap)
            .expect("default restore succeeds");
    }

    #[test]
    fn handle_as_dyn_preserves_trait_object_identity() {
        let handle = StatefulPluginHandle::new(Counter {
            value: Mutex::new(7),
        });

        // `as_dyn` must expose the very object stored in the `Arc`. Compare thin
        // addresses only; vtable pointers are not guaranteed to be unique.
        let via_dyn = (handle.as_dyn() as *const dyn StatefulPlugin).cast::<()>();
        let via_arc = Arc::as_ptr(&handle.0).cast::<()>();
        assert!(std::ptr::eq(via_dyn, via_arc));

        let snap = handle.as_dyn().snapshot().expect("snapshot via handle");
        assert_eq!(snap.id, TEST_ID);
        assert_eq!(snap.as_bytes(), &7u32.to_le_bytes()[..]);

        // Clones share the plugin: state restored through one is visible via the other.
        let cloned = handle.clone();
        cloned
            .as_dyn()
            .restore_snapshot(StatefulPluginSnapshot::new(
                TEST_ID,
                1,
                9u32.to_le_bytes().to_vec(),
            ))
            .expect("restore via clone");
        let snap = handle.as_dyn().snapshot().expect("snapshot after restore");
        assert_eq!(snap.as_bytes(), &9u32.to_le_bytes()[..]);
    }
}