1use std::{
7 collections::HashMap,
8 sync::{
9 Arc,
10 atomic::{AtomicU64, Ordering},
11 },
12 time::Instant,
13};
14
15use crate::wire::MEvent;
16
/// Error produced when an event could not be persisted.
///
/// Rendered by `Display` as `persist failed for <entity_type>: <message>`.
#[derive(Debug, Clone)]
pub struct PersistError {
    /// Entity type named in the `Display` output.
    pub entity_type: String,
    /// Failure description, shown after the entity type in the `Display` output.
    pub message: String,
}

impl std::fmt::Display for PersistError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            entity_type,
            message,
        } = self;
        write!(f, "persist failed for {}: {}", entity_type, message)
    }
}

impl std::error::Error for PersistError {}
35
/// Minimum age, in seconds, the rate window must reach before
/// [`PersistHealth::writes_per_second`] closes it and starts a new one.
const RATE_WINDOW_SECS: f64 = 1.0;

/// Counters describing the health of a persistence backend.
///
/// Every counter is an independent atomic updated with `Ordering::Relaxed`.
/// Each value is consistent on its own, but a reader may observe different
/// counters at slightly different moments.
#[derive(Debug)]
pub struct PersistHealth {
    /// Events counted by `record_enqueue` that have not yet been settled by a
    /// success or by `record_error`.
    ///
    /// Decrements use wrapping atomic arithmetic. A decrement that lands before
    /// its matching increment therefore shows up transiently as a very large
    /// value, and the gauge returns to the right value once the increment
    /// arrives.
    pub queued: AtomicU64,
    /// Total events reported as persisted.
    pub total_persisted: AtomicU64,
    /// Total failures reported via `record_error`, `record_dropped` or
    /// `record_error_no_dequeue`.
    pub total_errors: AtomicU64,
    /// Failures reported since the most recent success.
    pub consecutive_errors: AtomicU64,
    /// Message of the most recent failure; cleared when a success ends a
    /// failure streak.
    pub last_error: std::sync::RwLock<Option<String>>,
    /// Value of `total_persisted` at the start of the current rate window.
    rate_window_count: AtomicU64,
    /// Moment the current rate window started.
    rate_window_start: std::sync::RwLock<Instant>,
}

impl Default for PersistHealth {
    /// All counters start at zero, no error is recorded, and the rate window
    /// starts now.
    fn default() -> Self {
        Self {
            queued: AtomicU64::new(0),
            total_persisted: AtomicU64::new(0),
            total_errors: AtomicU64::new(0),
            consecutive_errors: AtomicU64::new(0),
            last_error: std::sync::RwLock::new(None),
            rate_window_count: AtomicU64::new(0),
            rate_window_start: std::sync::RwLock::new(Instant::now()),
        }
    }
}

impl PersistHealth {
    /// Records that one event was handed to the persister (increments `queued`).
    pub fn record_enqueue(&self) {
        self.queued.fetch_add(1, Ordering::Relaxed);
    }

    /// Records that one queued event was persisted successfully.
    ///
    /// Equivalent to `record_success_batch(1)`.
    pub fn record_success(&self) {
        self.record_success_batch(1);
    }

    /// Records that `count` queued events were persisted successfully.
    ///
    /// Any call, even with `count == 0`, ends a failure streak. It resets
    /// `consecutive_errors` and, if a streak was in progress, clears
    /// `last_error`.
    pub fn record_success_batch(&self, count: u64) {
        self.queued.fetch_sub(count, Ordering::Relaxed);
        self.total_persisted.fetch_add(count, Ordering::Relaxed);
        if self.consecutive_errors.swap(0, Ordering::Relaxed) > 0 {
            *self.last_error.write().unwrap() = None;
        }
    }

    /// Records that one queued event failed to persist: decrements `queued`
    /// and registers the failure.
    pub fn record_error(&self, msg: String) {
        self.queued.fetch_sub(1, Ordering::Relaxed);
        self.note_failure(msg);
    }

    /// Registers a failure without touching `queued`.
    ///
    /// NOTE(review): presumably for events discarded before they were counted
    /// as queued; confirm against callers.
    pub fn record_dropped(&self, msg: String) {
        self.note_failure(msg);
    }

    /// Registers a failure without decrementing `queued`.
    pub fn record_error_no_dequeue(&self, msg: String) {
        self.note_failure(msg);
    }

    /// Shared failure bookkeeping: bumps both error counters and stores `msg`
    /// as the latest error.
    fn note_failure(&self, msg: String) {
        self.total_errors.fetch_add(1, Ordering::Relaxed);
        self.consecutive_errors.fetch_add(1, Ordering::Relaxed);
        *self.last_error.write().unwrap() = Some(msg);
    }

    /// Returns the persist rate, in events per second, measured since the start
    /// of the current rate window.
    ///
    /// Once the window is at least `RATE_WINDOW_SECS` old, the rate is computed
    /// over the whole elapsed span and a new window begins at the current total.
    /// Returns `0.0` when no time has elapsed.
    pub fn writes_per_second(&self) -> f64 {
        let mut start = self.rate_window_start.write().unwrap();
        // Read the total only while holding the window lock. Baselines are
        // stored under this lock, and `total_persisted` only increases, so each
        // holder now reads a value >= the baseline left by the previous holder.
        // Reading it before locking let a stalled caller swap in a stale,
        // smaller baseline, which re-counted writes in the next window.
        let current_total = self.total_persisted.load(Ordering::Relaxed);
        let elapsed = start.elapsed().as_secs_f64();

        if elapsed >= RATE_WINDOW_SECS {
            let window_count = self
                .rate_window_count
                .swap(current_total, Ordering::Relaxed);
            *start = Instant::now();
            current_total.saturating_sub(window_count) as f64 / elapsed
        } else if elapsed > 0.0 {
            let window_count = self.rate_window_count.load(Ordering::Relaxed);
            current_total.saturating_sub(window_count) as f64 / elapsed
        } else {
            0.0
        }
    }
}
140
141pub trait Persister: Send + Sync + 'static {
143 fn persist(&self, event: MEvent) -> Result<(), PersistError>;
145
146 fn startup_healthcheck(&self) -> Result<(), String> {
151 Ok(())
152 }
153
154 fn health(&self) -> Arc<PersistHealth> {
156 static HEALTHY: std::sync::OnceLock<Arc<PersistHealth>> = std::sync::OnceLock::new();
158 HEALTHY
159 .get_or_init(|| Arc::new(PersistHealth::default()))
160 .clone()
161 }
162}
163
164pub struct NullPersister;
166
167impl Persister for NullPersister {
168 fn persist(&self, _event: MEvent) -> Result<(), PersistError> {
169 Ok(())
170 }
171}
172
173pub struct BlackholePersister;
175
176impl Persister for BlackholePersister {
177 fn persist(&self, _event: MEvent) -> Result<(), PersistError> {
178 Ok(())
179 }
180}
181
182#[derive(Default, Clone)]
186pub struct PersisterRouter {
187 default: Option<Arc<dyn Persister>>,
188 overrides: HashMap<String, Arc<dyn Persister>>,
189}
190
191impl PersisterRouter {
192 pub fn set_default(&mut self, persister: Option<Arc<dyn Persister>>) {
194 self.default = persister;
195 }
196
197 pub fn set_override(&mut self, entity_type: impl Into<String>, persister: Arc<dyn Persister>) {
199 self.overrides.insert(entity_type.into(), persister);
200 }
201
202 pub fn resolve(&self, entity_type: &str) -> Option<Arc<dyn Persister>> {
204 self.overrides
205 .get(entity_type)
206 .cloned()
207 .or_else(|| self.default.clone())
208 }
209
210 pub fn default_health(&self) -> Arc<PersistHealth> {
216 self.default
217 .as_ref()
218 .map(|p| p.health())
219 .unwrap_or_else(|| {
220 static HEALTHY: std::sync::OnceLock<Arc<PersistHealth>> =
221 std::sync::OnceLock::new();
222 HEALTHY
223 .get_or_init(|| Arc::new(PersistHealth::default()))
224 .clone()
225 })
226 }
227
228 pub fn startup_healthcheck(&self, entity_types: &[&str]) -> Result<(), String> {
230 for entity_type in entity_types {
231 if let Some(persister) = self.resolve(entity_type) {
232 persister.startup_healthcheck().map_err(|reason| {
233 format!(
234 "Persister startup healthcheck failed for entity type `{}`: {}",
235 entity_type, reason
236 )
237 })?;
238 }
239 }
240 Ok(())
241 }
242}