use std::{
collections::HashMap,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Instant,
};
use crate::wire::MEvent;
/// Error reported when persisting an event fails.
#[derive(Debug, Clone)]
pub struct PersistError {
    /// Entity type whose persist attempt failed.
    pub entity_type: String,
    /// Human-readable description of the failure.
    pub message: String,
}

impl std::fmt::Display for PersistError {
    /// Renders as `persist failed for <entity_type>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            entity_type,
            message,
        } = self;
        write!(f, "persist failed for {}: {}", entity_type, message)
    }
}

/// Marker impl so `PersistError` can be boxed as `dyn Error` or used with `?`.
impl std::error::Error for PersistError {}
/// Minimum age, in seconds, a rate window must reach before
/// [`PersistHealth::writes_per_second`] rolls it over.
const RATE_WINDOW_SECS: f64 = 1.0;

/// Health counters for a persister, safe to share between threads.
///
/// Every counter is an independent atomic updated with `Ordering::Relaxed`, so
/// a reader sampling several of them may see values that are slightly out of
/// step with one another.
#[derive(Debug)]
pub struct PersistHealth {
    /// Events recorded via `record_enqueue` and not yet resolved as a success
    /// or an error. Never drops below zero.
    pub queued: AtomicU64,
    /// Events successfully persisted since creation.
    pub total_persisted: AtomicU64,
    /// Failures recorded since creation (errors and drops alike).
    pub total_errors: AtomicU64,
    /// Failures since the last success; any success resets it to 0.
    pub consecutive_errors: AtomicU64,
    /// Message of the most recent failure; cleared by a success that ends an
    /// error streak.
    pub last_error: std::sync::RwLock<Option<String>>,
    /// `total_persisted` as sampled when the current rate window started.
    rate_window_count: AtomicU64,
    /// Start of the current rate window; its write lock also serializes rollovers.
    rate_window_start: std::sync::RwLock<Instant>,
}

impl Default for PersistHealth {
    /// All counters zero, no last error, rate window starting now.
    fn default() -> Self {
        Self {
            queued: AtomicU64::new(0),
            total_persisted: AtomicU64::new(0),
            total_errors: AtomicU64::new(0),
            consecutive_errors: AtomicU64::new(0),
            last_error: std::sync::RwLock::new(None),
            rate_window_count: AtomicU64::new(0),
            rate_window_start: std::sync::RwLock::new(Instant::now()),
        }
    }
}

impl PersistHealth {
    /// Records that one event has been queued for persistence.
    pub fn record_enqueue(&self) {
        self.queued.fetch_add(1, Ordering::Relaxed);
    }

    /// Records one successfully persisted event.
    ///
    /// Equivalent to `record_success_batch(1)`.
    pub fn record_success(&self) {
        self.record_success_batch(1);
    }

    /// Records `count` successfully persisted events.
    ///
    /// Removes them from `queued` (saturating at zero), adds them to
    /// `total_persisted`, and, if an error streak was in progress, resets
    /// `consecutive_errors` and clears `last_error`.
    pub fn record_success_batch(&self, count: u64) {
        self.dec_queued(count);
        self.total_persisted.fetch_add(count, Ordering::Relaxed);
        if self.consecutive_errors.swap(0, Ordering::Relaxed) > 0 {
            self.set_last_error(None);
        }
    }

    /// Records a failure for a queued event.
    ///
    /// Removes it from `queued` (saturating at zero), then behaves like
    /// [`record_error_no_dequeue`](Self::record_error_no_dequeue).
    pub fn record_error(&self, msg: String) {
        self.dec_queued(1);
        self.record_error_no_dequeue(msg);
    }

    /// Records a dropped event as a failure without touching `queued`.
    ///
    /// Currently identical to
    /// [`record_error_no_dequeue`](Self::record_error_no_dequeue).
    pub fn record_dropped(&self, msg: String) {
        self.record_error_no_dequeue(msg);
    }

    /// Records a failure without touching `queued`.
    ///
    /// Bumps `total_errors` and `consecutive_errors`, and stores `msg` as
    /// `last_error`.
    pub fn record_error_no_dequeue(&self, msg: String) {
        self.total_errors.fetch_add(1, Ordering::Relaxed);
        self.consecutive_errors.fetch_add(1, Ordering::Relaxed);
        self.set_last_error(Some(msg));
    }

    /// Approximate number of events persisted per second.
    ///
    /// Measures the growth of `total_persisted` since the current rate window
    /// started. Once the window is at least `RATE_WINDOW_SECS` old, the call
    /// that notices reports the rate over the whole elapsed span and starts a
    /// new window. Windows only advance when this method is called, and right
    /// after a rollover the divisor is small, so readings can be noisy.
    ///
    /// Returns `0.0` if no time has elapsed since the window started.
    pub fn writes_per_second(&self) -> f64 {
        let mut start = self
            .rate_window_start
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Sample the total while holding the window lock. If it were sampled
        // outside, a caller that lost the race to a concurrent rollover could
        // install an older baseline than the one just written. The next window
        // would then double-count the difference.
        let current_total = self.total_persisted.load(Ordering::Relaxed);
        // One clock read serves as both the end of this window and the start of
        // the next, so no time is lost between windows.
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(*start).as_secs_f64();
        if elapsed <= 0.0 {
            return 0.0;
        }
        let window_count = if elapsed >= RATE_WINDOW_SECS {
            *start = now;
            self.rate_window_count.swap(current_total, Ordering::Relaxed)
        } else {
            self.rate_window_count.load(Ordering::Relaxed)
        };
        current_total.saturating_sub(window_count) as f64 / elapsed
    }

    /// Decrements `queued` by `count`, saturating at zero.
    ///
    /// Unbalanced accounting (e.g. a success reported for an event that was
    /// never enqueued) would otherwise wrap the gauge to `u64::MAX`.
    fn dec_queued(&self, count: u64) {
        // The closure always returns `Some`, so `fetch_update` cannot fail.
        let _ = self
            .queued
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |q| {
                Some(q.saturating_sub(count))
            });
    }

    /// Replaces `last_error`.
    ///
    /// If the lock is poisoned, it recovers the guard instead of panicking.
    /// Overwriting an `Option<String>` cannot observe a half-updated value, so
    /// poisoning carries no risk here, and health reporting should not turn
    /// into a panic.
    fn set_last_error(&self, msg: Option<String>) {
        *self
            .last_error
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = msg;
    }
}
/// A backend that persists [`MEvent`]s.
///
/// Implementors must be `Send + Sync + 'static` so they can be shared behind
/// `Arc<dyn Persister>`.
pub trait Persister: Send + Sync + 'static {
    /// Persists a single event, reporting failure as a [`PersistError`].
    fn persist(&self, event: MEvent) -> Result<(), PersistError>;

    /// Startup-time readiness check.
    ///
    /// On failure, returns a human-readable reason. The default implementation
    /// always succeeds.
    fn startup_healthcheck(&self) -> Result<(), String> {
        Ok(())
    }

    /// Returns this persister's health counters.
    ///
    /// The default implementation hands out a single process-wide
    /// [`PersistHealth`]. A `static` declared inside a default method is one
    /// item, not one per implementing type, so every implementor that does not
    /// override this method shares the same instance.
    fn health(&self) -> Arc<PersistHealth> {
        static SHARED_HEALTH: std::sync::OnceLock<Arc<PersistHealth>> =
            std::sync::OnceLock::new();
        let shared = SHARED_HEALTH.get_or_init(|| Arc::new(PersistHealth::default()));
        Arc::clone(shared)
    }
}
/// A [`Persister`] that ignores every event and always reports success.
///
/// Uses the trait's default `startup_healthcheck` and `health`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NullPersister;

impl Persister for NullPersister {
    /// Discards `event` and returns `Ok(())`.
    fn persist(&self, _event: MEvent) -> Result<(), PersistError> {
        Ok(())
    }
}
/// A [`Persister`] that swallows every event and always reports success.
///
/// Its behaviour is identical to `NullPersister`. It uses the trait's default
/// `startup_healthcheck` and `health`.
#[derive(Debug, Clone, Copy, Default)]
pub struct BlackholePersister;

impl Persister for BlackholePersister {
    /// Discards `event` and returns `Ok(())`.
    fn persist(&self, _event: MEvent) -> Result<(), PersistError> {
        Ok(())
    }
}
/// Chooses a [`Persister`] per entity type.
///
/// A per-type override wins; otherwise the default persister (if any) is used.
#[derive(Default, Clone)]
pub struct PersisterRouter {
    /// Fallback for entity types without an override. When `None`, `resolve`
    /// yields nothing for such types.
    default: Option<Arc<dyn Persister>>,
    /// Per-entity-type persisters that take precedence over `default`.
    overrides: HashMap<String, Arc<dyn Persister>>,
}

impl PersisterRouter {
    /// Replaces the fallback persister; pass `None` to clear it.
    pub fn set_default(&mut self, persister: Option<Arc<dyn Persister>>) {
        self.default = persister;
    }

    /// Routes `entity_type` to `persister`, replacing any earlier override.
    pub fn set_override(&mut self, entity_type: impl Into<String>, persister: Arc<dyn Persister>) {
        let key: String = entity_type.into();
        self.overrides.insert(key, persister);
    }

    /// Returns the persister for `entity_type`.
    ///
    /// Looks up its override first, then falls back to the default. Returns
    /// `None` if neither exists.
    pub fn resolve(&self, entity_type: &str) -> Option<Arc<dyn Persister>> {
        match self.overrides.get(entity_type) {
            Some(persister) => Some(Arc::clone(persister)),
            None => self.default.clone(),
        }
    }

    /// Returns the default persister's health.
    ///
    /// When no default is set, it returns a process-wide `PersistHealth` built
    /// with `PersistHealth::default()`. That instance belongs to this method
    /// and is separate from the one handed out by `Persister::health`'s default
    /// implementation.
    pub fn default_health(&self) -> Arc<PersistHealth> {
        static FALLBACK_HEALTH: std::sync::OnceLock<Arc<PersistHealth>> =
            std::sync::OnceLock::new();
        match &self.default {
            Some(persister) => persister.health(),
            None => Arc::clone(
                FALLBACK_HEALTH.get_or_init(|| Arc::new(PersistHealth::default())),
            ),
        }
    }

    /// Runs `startup_healthcheck` on the persister resolved for each entity
    /// type, in order, and stops at the first failure.
    ///
    /// Entity types that resolve to no persister are skipped. A persister
    /// serving several of the listed types is checked once per type.
    ///
    /// # Errors
    ///
    /// Returns a message naming the failing entity type and the persister's
    /// reason.
    pub fn startup_healthcheck(&self, entity_types: &[&str]) -> Result<(), String> {
        entity_types
            .iter()
            .filter_map(|&entity_type| {
                self.resolve(entity_type)
                    .map(|persister| (entity_type, persister))
            })
            .try_for_each(|(entity_type, persister)| {
                persister.startup_healthcheck().map_err(|reason| {
                    format!(
                        "Persister startup healthcheck failed for entity type `{}`: {}",
                        entity_type, reason
                    )
                })
            })
    }
}