1use parking_lot::{Condvar, Mutex, RwLock, RwLockReadGuard};
16use serde::{Serialize, de::DeserializeOwned};
17use std::path::PathBuf;
18use std::sync::Arc;
19use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
20use std::time::Duration;
21
22use crate::backend::{Backend, NullBackend};
23use crate::error::{Error, Result};
24use crate::wal::{IncrementalSave, Op, Replayable, Transactable, WalBackend};
25
/// Controls when writes become durable on the backing storage.
///
/// `Copy`/`PartialEq` are derived so callers can store and compare
/// policies freely; `Duration` (the only payload) is itself `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlushPolicy {
    /// Persist inside every write call before it returns.
    Immediate,
    /// Queue writes and let a background thread flush them at most
    /// every `interval`.
    Grouped { interval: Duration },
}
34
/// State shared between a `Store` and its background flusher thread
/// under the `Grouped` flush policy.
struct FlushShared<T, B: Backend<T>> {
    // The live in-memory state (same allocation the owning `Store` holds).
    state: Arc<RwLock<T>>,
    // Backend used for full-snapshot saves.
    backend: Arc<B>,
    // Op-level (WAL) sink, when the backend supports incremental saves.
    incremental: Option<Arc<dyn IncrementalSave<T>>>,
    // Op batches queued by `write`, drained by the flusher (or by
    // `write_durable`).
    pending_ops: Mutex<Vec<Vec<Op>>>,
    // Monotonic generation bumped once per completed write.
    gen_written: AtomicU64,
    // Highest generation known to be durable on disk; only advances on a
    // fully successful flush.
    gen_flushed: AtomicU64,
    // Wakes the flusher when new work arrives or shutdown is requested.
    notify: Condvar,
    // Mutex paired with `notify`; guards no data of its own.
    notify_mu: Mutex<()>,
    // Last background-flush error; taken by the next write/flush call.
    last_error: Mutex<Option<Error>>,
    // Requests flusher shutdown; the flusher runs one final flush pass
    // before exiting.
    shutdown: AtomicBool,
}
53
/// Owns the join handle of the background flusher thread.
struct FlushState {
    // Taken (and joined) exactly once, by `shutdown_flusher` or `Drop`.
    handle: Mutex<Option<std::thread::JoinHandle<()>>>,
}
58
/// A state store generic over a persistence backend.
///
/// Reads take a shared lock on `state`; writes are serialized through
/// `write_gate` and persisted either inline or via a background flusher
/// thread, depending on the configured [`FlushPolicy`].
pub struct Store<T, B: Backend<T> = NullBackend> {
    // The live in-memory state.
    state: Arc<RwLock<T>>,
    // Serializes writers so transactions never interleave.
    write_gate: Mutex<()>,
    // Snapshot-persistence backend.
    backend: Arc<B>,
    // Op-level (WAL) persistence, when the backend supports it.
    incremental: Option<Arc<dyn IncrementalSave<T>>>,
    // Present only under the `Grouped` policy: state shared with the
    // background flusher thread.
    shared: Option<Arc<FlushShared<T, B>>>,
    // Join handle for the background flusher thread, if one is running.
    flusher: Option<FlushState>,
}
80
/// Read guard over the store's state; dereferences to `T`.
pub struct Ref<'a, T>(RwLockReadGuard<'a, T>);
83
84impl<'a, T> std::ops::Deref for Ref<'a, T> {
85 type Target = T;
86 fn deref(&self) -> &T {
87 &self.0
88 }
89}
90
91impl<T: Default> Store<T, NullBackend> {
92 pub fn memory() -> Self {
94 Self {
95 state: Arc::new(RwLock::new(T::default())),
96 write_gate: Mutex::new(()),
97 backend: Arc::new(NullBackend),
98 incremental: None,
99 shared: None,
100 flusher: None,
101 }
102 }
103}
104
105impl<T: Replayable + Serialize + DeserializeOwned + Default> Store<T, WalBackend<T>> {
106 pub fn open_wal(dir: PathBuf) -> Result<Self> {
108 let backend = WalBackend::open(&dir)?;
109 let state = backend.load()?;
110 let backend = Arc::new(backend);
111 let incremental: Arc<dyn IncrementalSave<T>> = Arc::clone(&backend) as _;
112 Ok(Self {
113 state: Arc::new(RwLock::new(state)),
114 write_gate: Mutex::new(()),
115 backend,
116 incremental: Some(incremental),
117 shared: None,
118 flusher: None,
119 })
120 }
121}
122
123impl<T: Clone, B: Backend<T>> Store<T, B> {
124 pub fn with_backend(backend: B) -> Result<Self>
126 where
127 T: DeserializeOwned,
128 {
129 let state = backend.load()?;
130 Ok(Self {
131 state: Arc::new(RwLock::new(state)),
132 write_gate: Mutex::new(()),
133 backend: Arc::new(backend),
134 incremental: None,
135 shared: None,
136 flusher: None,
137 })
138 }
139
140 pub fn read(&self) -> Ref<'_, T> {
142 Ref(self.state.read())
143 }
144
145 pub fn flush_error(&self) -> Option<Error> {
147 self.shared
148 .as_ref()
149 .and_then(|s| s.last_error.lock().take())
150 }
151
152 pub fn backend(&self) -> &B {
154 &self.backend
155 }
156}
157
158impl<T: Transactable, B: Backend<T>> Store<T, B> {
160 pub fn write<F, R>(&self, f: F) -> Result<R>
170 where
171 F: for<'a> FnOnce(&mut T::Tx<'a>) -> Result<R>,
172 {
173 let _gate = self.write_gate.lock();
174
175 if let Some(ref shared) = self.shared
177 && let Some(err) = shared.last_error.lock().take()
178 {
179 return Err(err);
180 }
181
182 let state_guard = self.state.read();
184 let mut tx = state_guard.begin_tx();
185 let result = f(&mut tx)?;
186 let (ops, overlay) = T::finish_tx(tx);
187 drop(state_guard); if let Some(ref inc) = self.incremental {
191 if !ops.is_empty() {
193 match &self.shared {
194 None => {
195 inc.save_ops(&ops)?;
196 inc.sync()?;
197 }
198 Some(shared) => {
199 shared.pending_ops.lock().push(ops);
200 }
201 }
202 }
203 } else {
204 self.state.write().apply_overlay(overlay);
206 match &self.shared {
207 None => {
208 self.backend.save(&self.state.read())?;
209 }
210 Some(shared) => {
211 shared.gen_written.fetch_add(1, Ordering::Release);
212 shared.notify.notify_one();
213 }
214 }
215 return Ok(result);
216 }
217
218 self.state.write().apply_overlay(overlay);
219
220 if let Some(ref shared) = self.shared {
221 shared.gen_written.fetch_add(1, Ordering::Release);
222 shared.notify.notify_one();
223 }
224
225 Ok(result)
226 }
227
228 pub fn write_durable<F, R>(&self, f: F) -> Result<R>
233 where
234 F: for<'a> FnOnce(&mut T::Tx<'a>) -> Result<R>,
235 {
236 let _gate = self.write_gate.lock();
237
238 let state_guard = self.state.read();
239 let mut tx = state_guard.begin_tx();
240 let result = f(&mut tx)?;
241 let (ops, overlay) = T::finish_tx(tx);
242 drop(state_guard);
243
244 if let Some(ref inc) = self.incremental {
245 if let Some(ref shared) = self.shared {
247 let batched: Vec<Vec<Op>> = {
248 let mut pending = shared.pending_ops.lock();
249 std::mem::take(&mut *pending)
250 };
251 for batch in &batched {
252 inc.save_ops(batch)?;
253 }
254 }
255
256 if !ops.is_empty() {
257 inc.save_ops(&ops)?;
258 }
259 inc.sync()?;
260 } else {
261 self.state.write().apply_overlay(overlay);
263 self.backend.save(&self.state.read())?;
264 if let Some(ref shared) = self.shared {
265 let generation = shared.gen_written.fetch_add(1, Ordering::Release) + 1;
266 shared.gen_flushed.store(generation, Ordering::Release);
267 }
268 return Ok(result);
269 }
270
271 self.state.write().apply_overlay(overlay);
272
273 if let Some(ref shared) = self.shared {
274 let generation = shared.gen_written.fetch_add(1, Ordering::Release) + 1;
275 shared.gen_flushed.store(generation, Ordering::Release);
276 }
277
278 Ok(result)
279 }
280}
281
impl<T: Clone + Send + Sync + 'static, B: Backend<T> + Send + Sync + 'static> Store<T, B> {
    /// Switches the flush policy, shutting down any existing flusher first.
    ///
    /// `Immediate` persists inside each write call; `Grouped` spawns a
    /// background thread that flushes at most every `interval`.
    pub fn set_flush_policy(&mut self, policy: FlushPolicy) {
        // Stop and join the current flusher so two flushers never run
        // concurrently; its final pass persists anything still queued.
        self.shutdown_flusher();

        match policy {
            FlushPolicy::Immediate => {
                self.shared = None;
                self.flusher = None;
            }
            FlushPolicy::Grouped { interval } => {
                // Fresh shared state: generation counters restart at 0.
                let shared = Arc::new(FlushShared {
                    state: Arc::clone(&self.state),
                    backend: Arc::clone(&self.backend),
                    incremental: self.incremental.clone(),
                    pending_ops: Mutex::new(Vec::new()),
                    gen_written: AtomicU64::new(0),
                    gen_flushed: AtomicU64::new(0),
                    notify: Condvar::new(),
                    notify_mu: Mutex::new(()),
                    last_error: Mutex::new(None),
                    shutdown: AtomicBool::new(false),
                });

                let thread_shared = Arc::clone(&shared);
                let handle = std::thread::Builder::new()
                    .name("store-flusher".into())
                    .spawn(move || flusher_loop(&thread_shared, interval))
                    .expect("failed to spawn flusher thread");

                self.shared = Some(shared);
                self.flusher = Some(FlushState {
                    handle: Mutex::new(Some(handle)),
                });
            }
        }
    }

    /// Blocks until everything written so far is durable. A no-op under
    /// the `Immediate` policy, where writes persist inline.
    ///
    /// # Errors
    /// Propagates a flusher error, or times out after 5 seconds.
    pub fn flush(&self) -> Result<()> {
        let Some(ref shared) = self.shared else {
            return Ok(());
        };

        // Fast path: already durable up to the latest write.
        let target_gen = shared.gen_written.load(Ordering::Acquire);
        if target_gen == shared.gen_flushed.load(Ordering::Acquire) {
            return Ok(());
        }

        // Poll with a short sleep, nudging the flusher each iteration;
        // completion is observed via `gen_flushed`, not a return signal.
        let start = std::time::Instant::now();
        loop {
            shared.notify.notify_one();

            if shared.gen_flushed.load(Ordering::Acquire) >= target_gen {
                break;
            }

            if start.elapsed() > std::time::Duration::from_secs(5) {
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    "flush timed out waiting for flusher",
                )));
            }

            // A flusher failure never advances `gen_flushed`; surface it
            // here instead of spinning until the timeout.
            if let Some(err) = shared.last_error.lock().take() {
                return Err(err);
            }

            std::thread::sleep(std::time::Duration::from_millis(1));
        }

        // Surface any error recorded meanwhile, even though our target
        // generation itself made it to disk.
        if let Some(err) = shared.last_error.lock().take() {
            return Err(err);
        }
        Ok(())
    }

    /// Shuts down the background flusher (final flush pass included).
    pub fn close(&mut self) -> Result<()> {
        self.shutdown_flusher();
        Ok(())
    }

    /// Signals the flusher thread to stop and joins it, if present.
    /// Duplicated in `Drop` (which has looser trait bounds).
    fn shutdown_flusher(&mut self) {
        if let Some(ref shared) = self.shared {
            shared.shutdown.store(true, Ordering::Release);
            shared.notify.notify_one();
        }
        if let Some(ref flusher) = self.flusher
            && let Some(handle) = flusher.handle.lock().take()
        {
            let _ = handle.join();
        }
    }
}
386
/// Body of the background flusher thread spawned under the `Grouped`
/// policy.
///
/// Each cycle waits up to `interval` (or an explicit wake-up), then
/// persists anything written since the last successful flush: queued op
/// batches plus a sync in WAL mode, or a full snapshot otherwise. Exits
/// after one final flush pass once `shutdown` is observed.
fn flusher_loop<T: Clone, B: Backend<T>>(shared: &FlushShared<T, B>, interval: Duration) {
    loop {
        {
            // Sleep until the interval elapses or a writer/shutdown wakes
            // us. A missed wake-up is harmless: the generation comparison
            // below re-detects pending work on the next cycle.
            let mut guard = shared.notify_mu.lock();
            shared.notify.wait_for(&mut guard, interval);
        }

        // Read the shutdown flag *before* flushing so a shutdown request
        // still gets exactly one final flush pass below.
        let should_shutdown = shared.shutdown.load(Ordering::Acquire);

        let current_gen = shared.gen_written.load(Ordering::Acquire);
        let flushed_gen = shared.gen_flushed.load(Ordering::Acquire);

        if current_gen != flushed_gen {
            let result = if let Some(ref inc) = shared.incremental {
                // WAL mode: drain the queued batches, write each, then sync.
                let batched: Vec<Vec<Op>> = {
                    let mut pending = shared.pending_ops.lock();
                    std::mem::take(&mut *pending)
                };
                let mut write_err = None;
                for ops in &batched {
                    if let Err(e) = inc.save_ops(ops) {
                        write_err = Some(e);
                        break;
                    }
                }
                // NOTE(review): on error the drained batches are dropped —
                // they are not re-queued, so neither a later cycle nor
                // `write_durable` will retry them; the failure is only
                // reported via `last_error`. Confirm this is the intended
                // durability contract (the data still exists in `state`
                // and would be captured by a future snapshot).
                match write_err {
                    Some(e) => Err(e),
                    None => match inc.sync() {
                        Ok(()) => {
                            // Periodically compact the log into a snapshot.
                            if inc.should_snapshot() {
                                let snapshot = shared.state.read().clone();
                                inc.snapshot(&snapshot)
                            } else {
                                Ok(())
                            }
                        }
                        Err(e) => Err(e),
                    },
                }
            } else {
                // Snapshot mode: clone under the read lock (guard is
                // released at the end of the statement), then save.
                let snapshot = shared.state.read().clone();
                shared.backend.save(&snapshot)
            };

            match result {
                Ok(()) => {
                    // Publish durability up to the generation read above.
                    // (A write racing in after that read is picked up by
                    // the next cycle.)
                    shared.gen_flushed.store(current_gen, Ordering::Release);
                }
                Err(e) => {
                    // Leave gen_flushed untouched; the next write/flush
                    // call takes this error.
                    *shared.last_error.lock() = Some(e);
                }
            }
        }

        if should_shutdown {
            break;
        }
    }
}
453
454impl<T, B: Backend<T>> Drop for Store<T, B> {
455 fn drop(&mut self) {
456 if let Some(ref shared) = self.shared {
457 shared.shutdown.store(true, Ordering::Release);
458 shared.notify.notify_one();
459 }
460 if let Some(ref flusher) = self.flusher
461 && let Some(handle) = flusher.handle.lock().take()
462 {
463 let _ = handle.join();
464 }
465 }
466}
467
// Unit tests live in a sibling file to keep this module focused.
#[cfg(test)]
#[path = "store_test.rs"]
mod store_test;