1use error::{StoreError, StoreResult};
2use std::convert::Infallible;
3use std::ffi::OsStr;
4use std::marker::PhantomData;
5use std::num::NonZeroUsize;
6use std::ops::Deref;
7use std::path::Path;
8use std::path::PathBuf;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::fs::{self, File};
12use tokio::io::BufWriter;
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14use tokio::sync::RwLock;
15use tokio::sync::RwLockReadGuard;
16use tokio::task::JoinHandle;
17use tokio::time::sleep;
18
19pub mod error;
20pub mod tx;
21
22pub use tx::Tx;
23
/// A [`Store`] whose serializer type parameter is fixed to [`JsonSerializer`].
pub type JsonStore<D, T> = Store<D, T, JsonSerializer>;
25
/// Handle to a store.
///
/// `D` is the data kept by the store, `T` is the transaction type whose
/// serialized form is appended to the journal, and `S` is the serializer used
/// for both. The actual state lives in a [`StoreInner`] behind an [`Arc`].
pub struct Store<D, T, S> {
    /// Reference-counted state backing this handle.
    inner: Arc<StoreInner<D, T, S>>,
}
38
39impl<D, T, S> Clone for Store<D, T, S> {
40 fn clone(&self) -> Self {
41 Self {
42 inner: Arc::clone(&self.inner),
43 }
44 }
45}
46
47impl<D, T, S> Store<D, T, S>
48where
49 D: Default + Send + Sync + 'static,
50 T: Tx<D>,
51 S: Serializer<D> + Serializer<T>,
52{
53 pub async fn open(
56 serializer: S,
57 options: StoreOptions,
58 dir: impl Into<PathBuf>,
59 ) -> StoreResult<Self> {
60 let inner = StoreInner::open(serializer, options, dir).await?;
61 Ok(Self {
62 inner: Arc::new(inner),
63 })
64 }
65
66 pub async fn commit<Q, R>(&mut self, tx_query: Q) -> StoreResult<R>
84 where
85 Q: Tx<D, R> + Into<T> + From<T>,
86 S: Send + Sync + 'static,
87 T: Send + Sync + 'static,
88 R: Send + 'static,
89 {
90 let inner = Arc::clone(&self.inner);
91 let wrapped_tx: T = tx_query.into();
92 let serialized: Vec<u8> = inner
93 .serializer
94 .serialize(&wrapped_tx)
95 .map_err(|err| StoreError::EncodeJournalEntry(err.into()))?;
96 let result = tokio::spawn(async move {
97 let mut persistent_holder = inner.persistent.write().await;
98 let mut persistent_data = match persistent_holder.take() {
99 Some(persistent_data) => persistent_data,
100 None => {
101 let reloaded_data = inner.load_persistent_data().await?;
102 reloaded_data
103 }
104 };
105
106 let journal_file = persistent_data
107 .writable_journal(
108 inner.options.max_journal_entries,
109 &inner.serializer,
110 &inner.dir,
111 )
112 .await?;
113
114 journal_file
115 .append(&serialized)
116 .await
117 .map_err(StoreError::JournalIO)?;
118 if let JournalFlushPolicy::EveryCommit {
119 journal_flush_method,
120 ..
121 } = inner.options.journal_flush_policy
122 {
123 journal_file
124 .persist(journal_flush_method)
125 .await
126 .map_err(StoreError::JournalIO)?;
127 }
128 let tx_query: Q = wrapped_tx.into();
129 let result = tx_query.execute(&mut persistent_data.data);
130
131 *persistent_holder = Some(persistent_data);
132 Ok::<R, StoreError>(result)
133 })
134 .await
135 .map_err(StoreError::JoinError)??;
136 Ok(result)
137 }
138
139 pub async fn query(&self) -> StoreResult<QueryGuard<'_, D>> {
142 let mut persistent_data = self.inner.persistent.read().await;
143 if persistent_data.is_none() {
144 drop(persistent_data);
145 let mut writable_persistent_data = self.inner.persistent.write().await;
146 *writable_persistent_data = Some(self.inner.load_persistent_data().await?);
147 persistent_data = writable_persistent_data.downgrade();
148 }
149 Ok(QueryGuard(RwLockReadGuard::map(persistent_data, |p| {
150 p.as_ref().unwrap()
151 })))
152 }
153
154 pub async fn flush(&mut self) -> StoreResult<()> {
155 let mut persistent = self.inner.persistent.write().await;
156 match &mut *persistent {
157 Some(persistent) => persistent
158 .journal
159 .flush()
160 .await
161 .map_err(StoreError::JournalIO),
162 None => Err(StoreError::StatePoisoned),
163 }
164 }
165
166 pub async fn flush_and_sync(&mut self) -> StoreResult<()> {
168 let mut persistent = self.inner.persistent.write().await;
169 match &mut *persistent {
170 Some(persistent) => persistent
171 .journal
172 .flush_and_sync()
173 .await
174 .map_err(StoreError::JournalIO),
175 None => Err(StoreError::StatePoisoned),
176 }
177 }
178}
179
/// State shared by the handles of one store.
struct StoreInner<D, T, S> {
    /// In-memory data plus the open journal. `None` while no state is loaded:
    /// `Store::commit` takes the value out while it works and only puts it
    /// back on success.
    persistent: SharedPersistentData<D>,
    /// Encodes and decodes snapshots and journal entries.
    serializer: S,
    /// Options the store was opened with.
    options: StoreOptions,
    /// Directory holding the snapshot and journal files.
    dir: PathBuf,
    /// Aborts the periodic flush task when dropped; only set for
    /// `JournalFlushPolicy::Every`.
    _flusher_guard: Option<FlusherGuard>,
    /// Ties the transaction type `T` to the store without storing a value.
    _phantom: PhantomData<T>,
}
188
189impl<T, D, S> StoreInner<D, T, S>
190where
191 D: Default + Send + Sync + 'static,
192 S: Serializer<D> + Serializer<T>,
193 T: Tx<D>,
194{
195 pub async fn open(
196 serializer: S,
197 options: StoreOptions,
198 dir: impl Into<PathBuf>,
199 ) -> StoreResult<Self> {
200 let dir: PathBuf = dir.into();
201 let persistent_data = PersistentData::load::<T, S>(
202 &dir,
203 options.can_flush_synchronously_on_drop(),
204 &serializer,
205 )
206 .await?;
207 let shared_persistent_data = Arc::new(RwLock::new(Some(persistent_data)));
208 let _flusher_guard = match options.journal_flush_policy {
209 JournalFlushPolicy::EveryCommit { .. } | JournalFlushPolicy::Manually => None,
210 JournalFlushPolicy::Every {
211 duration,
212 journal_flush_method,
213 ..
214 } => Some(Self::start_flusher(
215 duration,
216 journal_flush_method,
217 Arc::clone(&shared_persistent_data),
218 )),
219 };
220 Ok(Self {
221 persistent: shared_persistent_data,
222 serializer,
223 options,
224 dir,
225 _flusher_guard,
226 _phantom: PhantomData,
227 })
228 }
229
230 async fn load_persistent_data(&self) -> StoreResult<PersistentData<D>> {
231 PersistentData::load::<T, S>(
232 &self.dir,
233 self.options.can_flush_synchronously_on_drop(),
234 &self.serializer,
235 )
236 .await
237 }
238
239 fn start_flusher(
240 interval: Duration,
241 journal_flush_method: JournalFlushMethod,
242 persistent: SharedPersistentData<D>,
243 ) -> FlusherGuard
244 where
245 D: Sync + Send + 'static,
246 S: Serializer<T>,
247 {
248 FlusherGuard(tokio::spawn(async move {
249 loop {
250 sleep(interval).await;
251 let mut persistent = persistent.write().await;
252 if let Some(persistent) = &mut *persistent {
253 if let Err(err) = persistent.journal.persist(journal_flush_method).await {
254 eprintln!("Could not flush journal log: {err:?}");
255 }
256 }
257 }
258 }))
259 }
260}
261
/// Settings that control journal rotation and flushing.
#[derive(Debug, Clone)]
pub struct StoreOptions {
    /// Once the current journal holds this many entries, the next commit
    /// writes a snapshot and starts a new journal.
    max_journal_entries: NonZeroUsize,
    /// When (and how) the journal is flushed.
    journal_flush_policy: JournalFlushPolicy,
    /// Whether the journal should be flushed and synced when the in-memory
    /// state is dropped.
    flush_synchronously_on_drop: bool,
}
268
269impl StoreOptions {
270 pub fn max_journal_entries(mut self, value: NonZeroUsize) -> Self {
273 self.max_journal_entries = value;
274 self
275 }
276
277 pub fn journal_flush_policy(mut self, value: JournalFlushPolicy) -> Self {
279 self.journal_flush_policy = value;
280 self
281 }
282
283 pub fn flush_synchronously_on_drop(mut self, value: bool) -> Self {
286 self.flush_synchronously_on_drop = value;
287 self
288 }
289
290 fn can_flush_synchronously_on_drop(&self) -> bool {
291 self.flush_synchronously_on_drop
292 && !matches!(
293 self.journal_flush_policy,
294 JournalFlushPolicy::EveryCommit {
295 journal_flush_method: JournalFlushMethod::FlushAndSync
296 }
297 )
298 }
299}
300
301impl Default for StoreOptions {
302 fn default() -> Self {
303 Self {
304 max_journal_entries: NonZeroUsize::new(65535).unwrap(),
305 journal_flush_policy: JournalFlushPolicy::EveryCommit {
306 journal_flush_method: JournalFlushMethod::FlushAndSync,
307 },
308 flush_synchronously_on_drop: true,
309 }
310 }
311}
312
/// Decides when the journal is flushed.
#[derive(Debug, Clone, Copy)]
pub enum JournalFlushPolicy {
    /// `Store::commit` persists the journal right after appending each entry.
    EveryCommit {
        /// How the journal is persisted after each commit.
        journal_flush_method: JournalFlushMethod,
    },
    /// A background task persists the journal periodically.
    Every {
        /// Time the background task sleeps between two persist attempts.
        duration: Duration,
        /// How the journal is persisted on each attempt.
        journal_flush_method: JournalFlushMethod,
    },
    /// The store never flushes on its own schedule; use `Store::flush` or
    /// `Store::flush_and_sync`.
    Manually,
}
333
/// How the journal is persisted.
#[derive(Debug, Clone, Copy)]
pub enum JournalFlushMethod {
    /// Flush the buffered writer, then call `sync_data` on the journal file.
    FlushAndSync,
    /// Flush the buffered writer only.
    Flush,
}
342
343struct FlusherGuard(JoinHandle<Infallible>);
344
345impl Drop for FlusherGuard {
346 fn drop(&mut self) {
347 self.0.abort();
348 }
349}
350
/// Shared, lock-protected slot for the loaded state; `None` means that no
/// state is currently loaded.
type SharedPersistentData<D> = Arc<RwLock<Option<PersistentData<D>>>>;
352
/// The in-memory data together with the journal that records changes to it.
struct PersistentData<D> {
    /// Current value of the stored data.
    data: D,
    /// Version used for the file name of the next snapshot; it is also the
    /// version of the journal that is currently open.
    next_snapshot_version: SnapshotVersion,
    /// Journal that new entries are appended to.
    journal: JournalFile,
    /// Whether `Drop` flushes and syncs the journal.
    flush_on_drop: bool,
}
359
360impl<D> PersistentData<D> {
361 pub async fn load<T, S>(dir: &Path, flush_on_drop: bool, serializer: &S) -> StoreResult<Self>
362 where
363 T: Tx<D>,
364 S: Serializer<D> + Serializer<T>,
365 D: Default,
366 {
367 fs::create_dir_all(dir).await.map_err(StoreError::FileIO)?;
368 let persistence_actions = PersistenceAction::rebuild(dir).await?;
369 let next_snapshot_version: SnapshotVersion = persistence_actions
370 .last()
371 .map(|action| match action {
372 PersistenceAction::Snapshot { version, .. } => *version + 1,
373 PersistenceAction::Journal { version, .. } => *version,
374 })
375 .unwrap_or_default();
376 let mut persistent = Self {
377 data: Default::default(),
378 journal: JournalFile::open(dir, next_snapshot_version).await?,
379 next_snapshot_version,
380 flush_on_drop,
381 };
382 persistent
383 .rebuild::<T, S>(&serializer, persistence_actions)
384 .await?;
385 Ok(persistent)
386 }
387
388 async fn rebuild<T, S>(
389 &mut self,
390 serializer: &S,
391 persistence_actions: Vec<PersistenceAction>,
392 ) -> StoreResult<()>
393 where
394 T: Tx<D>,
395 S: Serializer<D> + Serializer<T>,
396 {
397 for action in persistence_actions {
398 match action {
399 PersistenceAction::Snapshot { path, .. } => {
400 let file = fs::read(path).await.map_err(StoreError::SnapshotIO)?;
401 self.data = serializer
402 .deserialize(&file[..])
403 .map_err(|err| StoreError::DecodeSnapshot(Box::new(err)))?
404 .ok_or(StoreError::DecodeSnapshot(Box::new(std::io::Error::new(
405 std::io::ErrorKind::UnexpectedEof,
406 "corrupted snapshot",
407 ))))?;
408 }
409 PersistenceAction::Journal { path, .. } => {
410 let mut file = File::open(path).await.map_err(StoreError::JournalIO)?;
411 for tx in JournalFile::parse::<T, S>(serializer, &mut file).await? {
412 tx.execute(&mut self.data);
413 }
414 }
415 }
416 }
417 Ok(())
418 }
419
420 async fn writable_journal<S>(
421 &mut self,
422 max_entries: NonZeroUsize,
423 serializer: &S,
424 dir: &Path,
425 ) -> StoreResult<&mut JournalFile>
426 where
427 S: Serializer<D>,
428 {
429 if self.journal.written_entries >= max_entries.into() {
430 self.create_new_journal(serializer, dir).await?;
431 }
432 Ok(&mut self.journal)
433 }
434
435 async fn create_new_journal<S>(&mut self, serializer: &S, dir: &Path) -> StoreResult<()>
436 where
437 S: Serializer<D>,
438 {
439 self.journal
440 .flush_and_sync()
441 .await
442 .map_err(StoreError::JournalIO)?;
443 self.snapshot(serializer, dir).await?;
444 self.journal = JournalFile::open(dir, self.next_snapshot_version).await?;
445 Ok(())
446 }
447
448 async fn snapshot<S>(&mut self, serializer: &S, dir: &Path) -> StoreResult<()>
451 where
452 S: Serializer<D>,
453 {
454 let serialized = serializer
455 .serialize(&self.data)
456 .map_err(|err| StoreError::EncodeSnapshot(Box::new(err)))?;
457 let mut file = fs::OpenOptions::new()
458 .create(true)
459 .write(true)
460 .open(dir.join(format!("{:0>10}.snapshot", self.next_snapshot_version)))
461 .await
462 .map_err(StoreError::SnapshotIO)?;
463 file.write_all(&serialized)
464 .await
465 .map_err(StoreError::SnapshotIO)?;
466 file.sync_data().await.map_err(StoreError::SnapshotIO)?;
467 self.next_snapshot_version += 1;
468 Ok(())
469 }
470}
471
472impl<T> Drop for PersistentData<T> {
473 fn drop(&mut self) {
474 if self.flush_on_drop {
475 futures::executor::block_on(async move {
476 if let Err(err) = self.journal.flush_and_sync().await {
477 eprintln!("Could not flush journal log on drop: {err:?}");
478 };
479 });
480 }
481 }
482}
483
484pub struct QueryGuard<'a, T>(RwLockReadGuard<'a, PersistentData<T>>);
485
486impl<'a, T> Deref for QueryGuard<'a, T> {
487 type Target = T;
488
489 fn deref(&self) -> &Self::Target {
490 &self.0.data
491 }
492}
493
/// Version number embedded in snapshot and journal file names.
type SnapshotVersion = u32;
495
/// An open journal file that entries are appended to.
#[derive(Debug)]
struct JournalFile {
    /// Buffered writer over the journal file.
    writer: BufWriter<File>,
    /// Number of entries appended through this handle (starts at zero when
    /// the file is opened).
    written_entries: usize,
}
501
502impl JournalFile {
503 pub async fn open(dir: impl Into<PathBuf>, version: SnapshotVersion) -> StoreResult<Self> {
504 let dir: PathBuf = dir.into();
505 let file = fs::OpenOptions::new()
506 .create(true)
507 .append(true)
508 .open(dir.join(format!("{:0>10}.journal", version)))
509 .await
510 .map_err(StoreError::JournalIO)?;
511 Ok(Self {
512 writer: BufWriter::new(file),
513 written_entries: 0,
514 })
515 }
516
517 async fn append(&mut self, transaction: &[u8]) -> std::io::Result<()> {
518 self.writer.write_all(transaction).await?;
519 self.written_entries += 1;
520 Ok(())
521 }
522
523 async fn flush(&mut self) -> std::io::Result<()> {
524 self.writer.flush().await
525 }
526
527 async fn flush_and_sync(&mut self) -> std::io::Result<()> {
528 self.flush().await?;
529 self.writer.get_mut().sync_data().await
530 }
531
532 async fn persist(&mut self, flush_method: JournalFlushMethod) -> std::io::Result<()> {
533 match flush_method {
534 JournalFlushMethod::FlushAndSync => {
535 self.flush().await?;
536 self.writer.get_mut().sync_data().await
537 }
538 JournalFlushMethod::Flush => self.flush().await,
539 }
540 }
541
542 async fn parse<T, S>(serializer: &S, file: &mut File) -> StoreResult<Vec<T>>
543 where
544 S: Serializer<T>,
545 {
546 let mut buf = Vec::new();
547 file.read_to_end(&mut buf)
548 .await
549 .map_err(StoreError::JournalIO)?;
550
551 let mut cursor = &buf[..];
552 let mut transactions: Vec<T> = Vec::new();
553 while let Some(entry) = serializer.deserialize(&mut cursor).transpose() {
554 let tx = entry.map_err(|err| StoreError::DecodeJournalEntry(err.into()))?;
555 transactions.push(tx);
556 }
557 Ok(transactions)
558 }
559}
560
/// One step of rebuilding the state from the files in the store directory.
enum PersistenceAction {
    /// Replace the data with the content of a snapshot file.
    Snapshot {
        /// Version parsed from the file name.
        version: SnapshotVersion,
        /// Location of the snapshot file.
        path: PathBuf,
    },
    /// Replay the transactions recorded in a journal file.
    Journal {
        /// Version parsed from the file name.
        version: SnapshotVersion,
        /// Location of the journal file.
        path: PathBuf,
    },
}
573
574impl PersistenceAction {
575 async fn rebuild(dir_path: impl AsRef<Path>) -> StoreResult<Vec<PersistenceAction>> {
586 let mut actions = Self::read_dir(&dir_path).await?;
587 actions.sort_by_key(|action| action.snapshot_version());
588 let latest_version = actions
589 .iter()
590 .filter_map(|action| match action {
591 PersistenceAction::Snapshot {
592 version: snapshot_version,
593 ..
594 } => Some(*snapshot_version),
595 _ => None,
596 })
597 .last();
598 if let Some(latest_version) = latest_version {
599 actions.retain(|action| match action {
600 PersistenceAction::Journal { version, .. } => *version > latest_version,
601 PersistenceAction::Snapshot { version, .. } => *version == latest_version,
602 });
603 }
604 Ok(actions)
605 }
606
607 async fn read_dir(path: impl AsRef<Path>) -> StoreResult<Vec<PersistenceAction>> {
609 let mut actions = Vec::new();
610 let mut read_dir = fs::read_dir(&path).await.map_err(StoreError::JournalIO)?;
611 while let Some(entry) = read_dir.next_entry().await.map_err(StoreError::JournalIO)? {
612 let entry_path = entry.path();
613 let version = entry_path
614 .file_stem()
615 .and_then(OsStr::to_str)
616 .and_then(|path| path.parse::<SnapshotVersion>().ok())
617 .ok_or_else(|| {
618 StoreError::JournalInvalidFileName(entry_path.file_stem().map(OsStr::to_owned))
619 })?;
620 match entry_path.extension() {
621 Some(extension) if extension == "journal" => {
622 actions.push(PersistenceAction::Journal {
623 version,
624 path: entry_path,
625 });
626 }
627 Some(extension) if extension == "snapshot" => {
628 actions.push(PersistenceAction::Snapshot {
629 version,
630 path: entry_path,
631 });
632 }
633 _ => {}
634 }
635 }
636 Ok(actions)
637 }
638
639 fn snapshot_version(&self) -> u32 {
640 match self {
641 PersistenceAction::Snapshot { version, .. } => *version,
642 PersistenceAction::Journal { version, .. } => *version,
643 }
644 }
645}
646
/// Converts values of type `T` to and from bytes.
pub trait Serializer<T> {
    /// Error produced when encoding or decoding fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Encodes `transaction` into a byte buffer.
    fn serialize(&self, transaction: &T) -> Result<Vec<u8>, Self::Error>;

    /// Decodes one value from `reader`.
    ///
    /// `Ok(None)` signals that no value could be read: journal parsing treats
    /// it as the end of the journal, snapshot loading reports it as a
    /// corrupted snapshot.
    fn deserialize<R>(&self, reader: R) -> Result<Option<T>, Self::Error>
    where
        R: std::io::Read;
}
656
/// [`Serializer`] that encodes values as JSON via `serde_json`.
pub struct JsonSerializer;
658
659impl<D> Serializer<D> for JsonSerializer
660where
661 D: serde::Serialize,
662 D: for<'a> serde::Deserialize<'a>,
663{
664 type Error = serde_json::Error;
665
666 fn serialize(&self, data: &D) -> Result<Vec<u8>, Self::Error> {
667 serde_json::to_vec(data).map(|mut bytes| {
668 bytes.push(b'\n');
670 bytes
671 })
672 }
673
674 fn deserialize<'a, R>(&self, reader: R) -> Result<Option<D>, Self::Error>
675 where
676 R: std::io::Read,
677 {
678 let mut deserializer = serde_json::Deserializer::from_reader(reader);
679 match serde::de::Deserialize::deserialize(&mut deserializer) {
680 Ok(data) => Ok(data),
681 Err(err) if err.is_eof() => {
682 Ok(None)
687 }
688 Err(err) => Err(err),
689 }
690 }
691}
692
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use tempfile::tempdir;

    /// Minimal store state used by the tests.
    #[derive(Serialize, Deserialize, Default, Debug)]
    struct Counter {
        value: usize,
    }

    MergeTx!(CounterTx<Counter> = Increase | IncreaseBy | DecreaseBy);

    /// Adds one to the counter.
    #[derive(Serialize, Deserialize)]
    struct Increase;

    /// Adds `by` to the counter.
    #[derive(Serialize, Deserialize)]
    struct IncreaseBy {
        by: usize,
    }

    /// Subtracts `by` from the counter.
    #[derive(Serialize, Deserialize)]
    struct DecreaseBy {
        by: usize,
    }

    impl Tx<Counter> for Increase {
        fn execute(self, data: &mut Counter) {
            data.value += 1;
        }
    }

    impl Tx<Counter> for IncreaseBy {
        fn execute(self, data: &mut Counter) {
            // Use the requested amount, not the counter's own value.
            data.value += self.by;
        }
    }

    impl Tx<Counter> for DecreaseBy {
        fn execute(self, data: &mut Counter) {
            // Saturate so that replaying the transaction can never panic on
            // underflow.
            data.value = data.value.saturating_sub(self.by);
        }
    }

    /// A journal limited to two entries is rotated by the third commit.
    #[tokio::test]
    async fn test_journal_chunking() {
        let dir = tempdir().unwrap();
        let options = StoreOptions::default().max_journal_entries(NonZeroUsize::new(2).unwrap());
        let mut store: JsonStore<Counter, CounterTx> =
            Store::open(JsonSerializer, options, dir.path())
                .await
                .unwrap();
        let first_ver = get_snapshot_version(&store).await;
        store.commit(IncreaseBy { by: 2115 }).await.unwrap();
        store.commit(Increase).await.unwrap();
        assert_eq!(get_snapshot_version(&store).await, first_ver);
        store.commit(Increase).await.unwrap();
        assert_eq!(get_snapshot_version(&store).await, first_ver + 1);
    }

    /// Reopening a store continues with the journal that was not yet full.
    #[tokio::test]
    async fn test_retake_unfulfilled_journal_on_recovery() {
        let dir = tempdir().unwrap();
        let options = StoreOptions::default().max_journal_entries(NonZeroUsize::new(10).unwrap());
        let first_ver = {
            let mut store: JsonStore<Counter, CounterTx> =
                Store::open(JsonSerializer, options.clone(), dir.path())
                    .await
                    .unwrap();
            store.commit(Increase).await.unwrap();
            get_snapshot_version(&store).await
        };

        let mut store: JsonStore<Counter, CounterTx> =
            Store::open(JsonSerializer, options, dir.path())
                .await
                .unwrap();
        store.commit(Increase).await.unwrap();
        assert_eq!(get_snapshot_version(&store).await, first_ver);
    }

    /// Reads `next_snapshot_version` from the store's loaded state.
    async fn get_snapshot_version(store: &JsonStore<Counter, CounterTx>) -> u32 {
        store
            .inner
            .persistent
            .read()
            .await
            .as_ref()
            .unwrap()
            .next_snapshot_version
    }
}