#![cfg(feature = "storage")]
use serde::{de::DeserializeOwned, Serialize};
use std::fs::{File, OpenOptions};
use std::io::{self, BufReader, Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
/// Record-kind tag written in the header of a snapshot record.
const KIND_SNAPSHOT: u32 = 0;
/// Record-kind tag written in the header of an operation record.
const KIND_OP: u32 = 1;
/// Errors returned by the storage backends in this file.
#[derive(Debug, thiserror::Error)]
pub enum StorageError {
/// An underlying file-system operation failed.
#[error("storage I/O: {0}")]
Io(#[from] io::Error),
/// bincode failed to serialize or deserialize a payload.
#[error("storage codec: {0}")]
Codec(#[from] bincode::Error),
/// A record header carried an unknown kind tag. Also returned (with
/// `len == u32::MAX`) when a payload is too large for the 32-bit length field.
#[error("storage: corrupt record (kind {kind}, len {len})")]
Corrupt {
/// Kind tag of the offending record.
kind: u32,
/// Length field of the offending record, or `u32::MAX` for an oversized payload.
len: u32,
},
/// A snapshot was requested but none has been stored.
#[error("storage: no snapshot")]
NoSnapshot,
}
/// Persistence backend for a state snapshot plus the operations recorded
/// after it.
///
/// Payloads are generic over serde types; the implementations in this file
/// encode them with bincode.
pub trait Storage {
    /// Stores a single operation.
    ///
    /// # Errors
    /// Returns a [`StorageError`] if the operation cannot be encoded or written.
    fn append_op<O>(&mut self, op: &O) -> Result<(), StorageError>
    where
        O: Serialize;

    /// Stores a full snapshot of `state`.
    ///
    /// Both implementations in this file discard previously stored operations
    /// when a snapshot is taken.
    ///
    /// # Errors
    /// Returns a [`StorageError`] if the state cannot be encoded or written.
    fn snapshot<S>(&mut self, state: &S) -> Result<(), StorageError>
    where
        S: Serialize;

    /// Loads and decodes the most recent snapshot.
    ///
    /// # Errors
    /// Returns [`StorageError::NoSnapshot`] when none has been stored, or
    /// another [`StorageError`] if reading or decoding fails.
    fn load_snapshot<S>(&mut self) -> Result<S, StorageError>
    where
        S: DeserializeOwned;

    /// Loads and decodes the operations stored after the most recent snapshot
    /// (all stored operations when there is no snapshot).
    ///
    /// # Errors
    /// Returns a [`StorageError`] if reading or decoding fails.
    fn load_ops_after_snapshot<O>(&mut self) -> Result<Vec<O>, StorageError>
    where
        O: DeserializeOwned;
}
/// File-backed [`Storage`] that keeps length-prefixed records in a single file.
pub struct FileStorage {
/// Location of the record file; also used to derive the temporary snapshot path.
path: PathBuf,
/// Open read/write handle to the file at `path`.
file: File,
}
impl FileStorage {
    /// Opens the record file at `path` for reading and writing, creating it
    /// if it does not exist. Existing contents are preserved.
    ///
    /// # Errors
    /// Returns [`StorageError::Io`] if the file cannot be opened or created.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, StorageError> {
        let path = path.as_ref().to_path_buf();
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(false)
            .open(&path)?;
        Ok(Self { path, file })
    }

    /// Appends one record to the end of the file and syncs it to disk.
    ///
    /// Record layout: `kind` (u32, little-endian), payload length (u32,
    /// little-endian), then the payload bytes.
    ///
    /// # Errors
    /// Returns [`StorageError::Corrupt`] (with `len == u32::MAX`) if the
    /// payload does not fit the 32-bit length field, or [`StorageError::Io`]
    /// if seeking, writing or syncing fails.
    fn write_record(&mut self, kind: u32, payload: &[u8]) -> Result<(), StorageError> {
        // Reject oversized payloads before touching the file.
        let len = u32::try_from(payload.len()).map_err(|_| StorageError::Corrupt {
            kind,
            len: u32::MAX,
        })?;

        // Assemble header and payload into one buffer so the record goes out
        // in a single `write_all` instead of three separate unbuffered writes,
        // narrowing the window in which a header exists without its payload.
        let mut record = Vec::with_capacity(8 + payload.len());
        record.extend_from_slice(&kind.to_le_bytes());
        record.extend_from_slice(&len.to_le_bytes());
        record.extend_from_slice(payload);

        self.file.seek(SeekFrom::End(0))?;
        self.file.write_all(&record)?;
        self.file.flush()?;
        self.file.sync_data()?;
        Ok(())
    }

    /// Scans the file from the start and returns the byte offset and payload
    /// of the last snapshot record found.
    ///
    /// Reaching end-of-file while reading a header ends the scan; reaching it
    /// while reading a payload is reported as an I/O error.
    ///
    /// # Errors
    /// Returns [`StorageError::NoSnapshot`] if the file holds no snapshot
    /// record, [`StorageError::Corrupt`] if a header carries an unknown kind,
    /// or [`StorageError::Io`] on read failure.
    fn find_latest_snapshot(&mut self) -> Result<(u64, Vec<u8>), StorageError> {
        self.file.seek(SeekFrom::Start(0))?;
        let mut reader = BufReader::new(&self.file);
        let mut latest: Option<(u64, Vec<u8>)> = None;
        let mut offset: u64 = 0;
        loop {
            let mut hdr = [0u8; 8];
            match reader.read_exact(&mut hdr) {
                Ok(()) => {}
                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
                Err(e) => return Err(e.into()),
            }
            // Destructure the fixed-size header instead of slicing + unwrap.
            let [k0, k1, k2, k3, l0, l1, l2, l3] = hdr;
            let kind = u32::from_le_bytes([k0, k1, k2, k3]);
            let len = u32::from_le_bytes([l0, l1, l2, l3]);

            // Validate the kind BEFORE allocating `len` bytes: a garbage
            // header must be reported as corruption, not trigger an
            // allocation of up to 4 GiB followed by a misleading EOF error.
            if kind != KIND_SNAPSHOT && kind != KIND_OP {
                return Err(StorageError::Corrupt { kind, len });
            }

            // u32 -> usize is lossless on 32- and 64-bit targets.
            let mut payload = vec![0u8; len as usize];
            reader.read_exact(&mut payload)?;
            if kind == KIND_SNAPSHOT {
                latest = Some((offset, payload));
            }
            offset += 8 + u64::from(len);
        }
        latest.ok_or(StorageError::NoSnapshot)
    }
}
impl Storage for FileStorage {
    /// Encodes `op` with bincode and appends it to the file as an op record.
    ///
    /// # Errors
    /// Returns [`StorageError::Codec`] if encoding fails, or whatever
    /// `write_record` returns if the append fails.
    fn append_op<O: Serialize>(&mut self, op: &O) -> Result<(), StorageError> {
        let payload = bincode::serialize(op)?;
        self.write_record(KIND_OP, &payload)
    }

    /// Replaces the whole file with a single snapshot record for `state`.
    ///
    /// The record is written and synced to a temporary file next to the
    /// target, which is then renamed over the target; the file handle is
    /// reopened afterwards so later operations see the new file.
    ///
    /// # Errors
    /// Returns [`StorageError::Codec`] if encoding fails,
    /// [`StorageError::Corrupt`] (with `len == u32::MAX`) if the payload does
    /// not fit the 32-bit length field, or [`StorageError::Io`] on any
    /// file-system failure.
    fn snapshot<S: Serialize>(&mut self, state: &S) -> Result<(), StorageError> {
        let payload = bincode::serialize(state)?;
        // Check the length up front so a failure leaves no stray temp file.
        let len = u32::try_from(payload.len()).map_err(|_| StorageError::Corrupt {
            kind: KIND_SNAPSHOT,
            len: u32::MAX,
        })?;

        // Append ".tmp" to the full file name rather than swapping the
        // extension: `with_extension("tmp")` would yield the live path itself
        // for a file already named `*.tmp`, and the same temp path for
        // siblings such as `doc.crdt` and `doc.json`.
        let mut tmp_name = self.path.clone().into_os_string();
        tmp_name.push(".tmp");
        let tmp_path = PathBuf::from(tmp_name);

        {
            let mut tmp = OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .open(&tmp_path)?;
            tmp.write_all(&KIND_SNAPSHOT.to_le_bytes())?;
            tmp.write_all(&len.to_le_bytes())?;
            tmp.write_all(&payload)?;
            tmp.flush()?;
            tmp.sync_data()?;
        }
        std::fs::rename(&tmp_path, &self.path)?;

        // The old handle still refers to the replaced file; reopen by path.
        self.file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(false)
            .open(&self.path)?;
        Ok(())
    }

    /// Decodes and returns the most recent snapshot stored in the file.
    ///
    /// # Errors
    /// Returns [`StorageError::NoSnapshot`] if the file holds no snapshot, or
    /// another [`StorageError`] if reading or decoding fails.
    fn load_snapshot<S: DeserializeOwned>(&mut self) -> Result<S, StorageError> {
        let (_off, payload) = self.find_latest_snapshot()?;
        Ok(bincode::deserialize(&payload)?)
    }

    /// Decodes and returns every op record located after the most recent
    /// snapshot record, or every op record in the file if there is no
    /// snapshot.
    ///
    /// # Errors
    /// Returns [`StorageError::Corrupt`] if a header carries an unknown kind,
    /// [`StorageError::Codec`] if an op fails to decode, or
    /// [`StorageError::Io`] on read failure.
    fn load_ops_after_snapshot<O: DeserializeOwned>(&mut self) -> Result<Vec<O>, StorageError> {
        // Decide the start offset from whether a snapshot EXISTS, not from
        // whether its payload is empty: a snapshot may legitimately encode to
        // zero bytes, and ops preceding it must still be skipped.
        let start = match self.find_latest_snapshot() {
            // The payload length originated from a u32 header field, so the
            // widening to u64 cannot lose information.
            Ok((off, payload)) => off + 8 + payload.len() as u64,
            Err(StorageError::NoSnapshot) => 0,
            Err(e) => return Err(e),
        };

        self.file.seek(SeekFrom::Start(start))?;
        let mut reader = BufReader::new(&self.file);
        let mut ops = Vec::new();
        loop {
            let mut hdr = [0u8; 8];
            match reader.read_exact(&mut hdr) {
                Ok(()) => {}
                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
                Err(e) => return Err(e.into()),
            }
            let [k0, k1, k2, k3, l0, l1, l2, l3] = hdr;
            let kind = u32::from_le_bytes([k0, k1, k2, k3]);
            let len = u32::from_le_bytes([l0, l1, l2, l3]);

            // Reject unknown kinds before allocating `len` bytes so a garbage
            // header is reported as corruption rather than causing a huge
            // allocation and an EOF error.
            if kind != KIND_OP && kind != KIND_SNAPSHOT {
                return Err(StorageError::Corrupt { kind, len });
            }

            let mut payload = vec![0u8; len as usize];
            reader.read_exact(&mut payload)?;
            // Snapshot records in this range are read past and ignored.
            if kind == KIND_OP {
                ops.push(bincode::deserialize(&payload)?);
            }
        }
        Ok(ops)
    }
}
/// In-memory [`Storage`] that keeps bincode-encoded payloads in vectors.
#[derive(Default)]
pub struct MemoryStorage {
/// Encoded bytes of the most recent snapshot, if one has been taken.
snapshot: Option<Vec<u8>>,
/// Encoded bytes of each op appended since the last snapshot, in append order.
ops: Vec<Vec<u8>>,
}
impl MemoryStorage {
    /// Creates a store with no snapshot and no recorded ops.
    #[must_use]
    pub fn new() -> Self {
        Self {
            snapshot: None,
            ops: Vec::new(),
        }
    }
}
impl Storage for MemoryStorage {
    /// Encodes `op` with bincode and pushes the bytes onto the op list.
    fn append_op<O: Serialize>(&mut self, op: &O) -> Result<(), StorageError> {
        let encoded = bincode::serialize(op)?;
        self.ops.push(encoded);
        Ok(())
    }

    /// Encodes `state` with bincode, stores it as the current snapshot and
    /// drops all previously recorded ops. Nothing is modified if encoding
    /// fails.
    fn snapshot<S: Serialize>(&mut self, state: &S) -> Result<(), StorageError> {
        let encoded = bincode::serialize(state)?;
        self.snapshot = Some(encoded);
        self.ops.clear();
        Ok(())
    }

    /// Decodes the stored snapshot, or returns [`StorageError::NoSnapshot`]
    /// when none has been taken.
    fn load_snapshot<S: DeserializeOwned>(&mut self) -> Result<S, StorageError> {
        match &self.snapshot {
            Some(encoded) => {
                let state = bincode::deserialize(encoded)?;
                Ok(state)
            }
            None => Err(StorageError::NoSnapshot),
        }
    }

    /// Decodes every recorded op in append order, stopping at the first
    /// decode failure.
    fn load_ops_after_snapshot<O: DeserializeOwned>(&mut self) -> Result<Vec<O>, StorageError> {
        let mut decoded = Vec::with_capacity(self.ops.len());
        for encoded in &self.ops {
            let op: O = bincode::deserialize(encoded)?;
            decoded.push(op);
        }
        Ok(decoded)
    }
}
// Round-trip tests for both storage backends, driven through the crate's
// `List` / `ListOp` types.
#[cfg(test)]
mod tests {
use super::*;
use crate::{List, ListOp};
// Snapshot a two-element list, append one more op, then check that loading
// the snapshot and replaying the stored ops reproduces the three-element list.
#[test]
fn memory_storage_round_trip() {
let mut store = MemoryStorage::new();
let mut list = List::<char>::new(1);
list.insert(0, 'a');
list.insert(1, 'b');
store.snapshot(&list).unwrap();
// Only this op is recorded after the snapshot.
let op = list.insert(2, 'c');
store.append_op(&op).unwrap();
let restored: List<char> = store.load_snapshot().unwrap();
let ops: Vec<ListOp<char>> = store.load_ops_after_snapshot().unwrap();
assert_eq!(restored.to_vec(), vec!['a', 'b']);
assert_eq!(ops.len(), 1);
let mut replayed = restored;
for op in ops {
replayed.apply(op).unwrap();
}
assert_eq!(replayed.to_vec(), vec!['a', 'b', 'c']);
}
// Same round trip against a real file: write 50 ops, snapshot, append one
// more op, drop the store, then reopen the file and verify the contents.
#[test]
fn file_storage_round_trip() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("doc.crdt");
{
// Scope the first store so it is dropped before the file is reopened.
let mut store = FileStorage::open(&path).unwrap();
let mut list = List::<u32>::new(1);
for i in 0..50u32 {
let op = list.insert(i as usize, i);
store.append_op(&op).unwrap();
}
store.snapshot(&list).unwrap();
let op = list.insert(50, 999);
store.append_op(&op).unwrap();
}
let mut store = FileStorage::open(&path).unwrap();
let snap: List<u32> = store.load_snapshot().unwrap();
assert_eq!(snap.len(), 50);
let ops: Vec<ListOp<u32>> = store.load_ops_after_snapshot().unwrap();
assert_eq!(ops.len(), 1);
let mut replayed = snap;
for op in ops {
replayed.apply(op).unwrap();
}
assert_eq!(replayed.len(), 51);
assert_eq!(replayed.get(50), Some(&999));
}
// After a snapshot, no earlier ops remain readable and the snapshot holds
// the full list state.
#[test]
fn file_storage_snapshot_replaces_file() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("doc.crdt");
let mut store = FileStorage::open(&path).unwrap();
let mut list = List::<u32>::new(1);
for i in 0..100u32 {
let op = list.insert(i as usize, i);
store.append_op(&op).unwrap();
}
store.snapshot(&list).unwrap();
let ops: Vec<ListOp<u32>> = store.load_ops_after_snapshot().unwrap();
assert!(ops.is_empty(), "snapshot didn't replace prior ops");
let restored: List<u32> = store.load_snapshot().unwrap();
assert_eq!(restored.to_vec(), list.to_vec());
}
}