use crate::errors::{Error, Result};
use crate::server::VersionId;
use crate::storage::{StorageTxn, TaskMap};
use flate2::{read::ZlibDecoder, write::ZlibEncoder, Compression};
use serde::de::{Deserialize, Deserializer, MapAccess, Visitor};
use serde::ser::{Serialize, SerializeMap, Serializer};
use std::fmt;
use uuid::Uuid;
/// The tasks carried by a snapshot, as `(uuid, task)` pairs.
///
/// Serialized as a map from task UUID to that task's data. Entries are kept in a `Vec`
/// rather than a map, so serialization emits them in vector order and deserialization
/// stores them in the order they were read.
pub(super) struct SnapshotTasks(Vec<(Uuid, TaskMap)>);
impl Serialize for SnapshotTasks {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut map = serializer.serialize_map(Some(self.0.len()))?;
for (k, v) in &self.0 {
map.serialize_entry(k, v)?;
}
map.end()
}
}
struct TaskDbVisitor;
impl<'de> Visitor<'de> for TaskDbVisitor {
type Value = SnapshotTasks;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a map representing a task snapshot")
}
fn visit_map<M>(self, mut access: M) -> std::result::Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
let mut map = SnapshotTasks(Vec::with_capacity(access.size_hint().unwrap_or(0)));
while let Some((key, value)) = access.next_entry()? {
map.0.push((key, value));
}
Ok(map)
}
}
/// Deserializes the snapshot from a map keyed by task UUID, via [`TaskDbVisitor`].
impl<'de> Deserialize<'de> for SnapshotTasks {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor = TaskDbVisitor;
        deserializer.deserialize_map(visitor)
    }
}
impl SnapshotTasks {
    /// Serialize these tasks as JSON and compress the result with zlib.
    ///
    /// # Errors
    ///
    /// Returns an error if JSON serialization fails, if flushing the buffered writer into
    /// the compressor fails, or if finishing the zlib stream fails.
    pub(super) fn encode(&self) -> Result<Vec<u8>> {
        let encoder = ZlibEncoder::new(Vec::new(), Compression::default());
        // serde_json issues many small writes; batch them before they reach the compressor.
        let mut encoder = std::io::BufWriter::new(encoder);
        serde_json::to_writer(&mut encoder, self)?;
        // `into_inner` flushes the buffer and reports any error, rather than letting drop
        // discard it silently.
        let encoder = encoder
            .into_inner()
            .map_err(|e| anyhow::anyhow!("While flushing snapshot encoder: {e}"))?;
        Ok(encoder.finish()?)
    }

    /// Decompress and parse a snapshot in the format produced by [`SnapshotTasks::encode`].
    ///
    /// # Errors
    ///
    /// Returns an error if the data is not a valid zlib stream or does not contain a
    /// JSON map of task UUID to task data.
    pub(super) fn decode(snapshot: &[u8]) -> Result<Self> {
        // serde_json does not buffer its input and reads one byte at a time. Without this
        // BufReader, every byte would be a separate call into the zlib decompressor.
        let decoder = std::io::BufReader::new(ZlibDecoder::new(snapshot));
        Ok(serde_json::from_reader(decoder)?)
    }

    /// Consume the snapshot, returning its `(uuid, task)` pairs in order.
    pub(super) fn into_inner(self) -> Vec<(Uuid, TaskMap)> {
        self.0
    }
}
/// Build a compressed snapshot containing every task returned by `txn.all_tasks()`.
///
/// # Errors
///
/// Propagates errors from reading the tasks or from encoding the snapshot.
pub(super) async fn make_snapshot(txn: &mut dyn StorageTxn) -> Result<Vec<u8>> {
    let tasks = txn.all_tasks().await?;
    SnapshotTasks(tasks).encode()
}
/// Load every task in `snapshot` into an empty task database and record `version` as
/// its base version.
///
/// `snapshot` must be in the format produced by [`make_snapshot`]. This function does not
/// commit `txn`; the caller is responsible for that.
///
/// # Errors
///
/// Returns [`Error::Database`] if the database is not empty. Also propagates any error
/// from decoding the snapshot or from the storage transaction.
pub(super) async fn apply_snapshot(
    txn: &mut dyn StorageTxn,
    version: VersionId,
    snapshot: &[u8],
) -> Result<()> {
    // Refuse to overwrite existing data. Checking this first avoids decompressing and
    // parsing a (possibly large) snapshot that would be rejected anyway.
    if !txn.is_empty().await? {
        return Err(Error::Database(String::from(
            "Cannot apply snapshot to a non-empty task database",
        )));
    }
    let all_tasks = SnapshotTasks::decode(snapshot)?;
    for (uuid, task) in all_tasks.into_inner() {
        txn.set_task(uuid, task).await?;
    }
    txn.set_base_version(version).await?;
    Ok(())
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::storage::{inmemory::InMemoryStorage, Storage, TaskMap};
    use pretty_assertions::assert_eq;

    /// Build a `TaskMap` whose only property is `description`.
    fn task_with_description(description: &str) -> TaskMap {
        vec![("description".to_owned(), description.to_owned())]
            .into_iter()
            .collect()
    }

    #[test]
    fn test_serialize_empty() -> Result<()> {
        let empty = SnapshotTasks(vec![]);
        assert_eq!(serde_json::to_vec(&empty)?, b"{}".to_owned());
        Ok(())
    }

    #[test]
    fn test_serialize_tasks() -> Result<()> {
        let u = Uuid::new_v4();
        let all_tasks = SnapshotTasks(vec![(u, task_with_description("my task"))]);
        assert_eq!(
            serde_json::to_vec(&all_tasks)?,
            format!("{{\"{u}\":{{\"description\":\"my task\"}}}}").into_bytes(),
        );
        Ok(())
    }

    #[test]
    fn test_encode_decode_round_trip() -> Result<()> {
        let tasks = vec![
            (Uuid::new_v4(), task_with_description("one")),
            (Uuid::new_v4(), task_with_description("two")),
        ];
        let encoded = SnapshotTasks(tasks.clone()).encode()?;
        // Order is preserved because SnapshotTasks stores entries in a Vec.
        assert_eq!(SnapshotTasks::decode(&encoded)?.into_inner(), tasks);
        Ok(())
    }

    #[test]
    fn test_decode_invalid() {
        // Not a zlib stream at all, so decoding must fail rather than yield tasks.
        assert!(SnapshotTasks::decode(b"not a zlib stream").is_err());
    }

    #[tokio::test]
    async fn test_round_trip() -> Result<()> {
        let mut storage = InMemoryStorage::new();
        let version = Uuid::new_v4();
        let task1 = (Uuid::new_v4(), task_with_description("one"));
        let task2 = (Uuid::new_v4(), task_with_description("two"));

        {
            let mut txn = storage.txn().await?;
            txn.set_task(task1.0, task1.1.clone()).await?;
            txn.set_task(task2.0, task2.1.clone()).await?;
            txn.commit().await?;
        }

        let snap = {
            let mut txn = storage.txn().await?;
            make_snapshot(txn.as_mut()).await?
        };

        // Apply the snapshot to a fresh, empty storage.
        let mut storage = InMemoryStorage::new();
        {
            let mut txn = storage.txn().await?;
            apply_snapshot(txn.as_mut(), version, &snap).await?;
            txn.commit().await?;
        }

        {
            let mut txn = storage.txn().await?;
            assert_eq!(txn.get_task(task1.0).await?, Some(task1.1));
            assert_eq!(txn.get_task(task2.0).await?, Some(task2.1));
            assert_eq!(txn.all_tasks().await?.len(), 2);
            assert_eq!(txn.base_version().await?, version);
            assert_eq!(txn.unsynced_operations().await?.len(), 0);
            assert_eq!(txn.get_working_set().await?.len(), 1);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_apply_snapshot_to_non_empty_db() -> Result<()> {
        let snap = SnapshotTasks(vec![]).encode()?;
        let mut storage = InMemoryStorage::new();
        {
            let mut txn = storage.txn().await?;
            txn.set_task(Uuid::new_v4(), task_with_description("existing"))
                .await?;
            txn.commit().await?;
        }

        let mut txn = storage.txn().await?;
        assert!(apply_snapshot(txn.as_mut(), Uuid::new_v4(), &snap)
            .await
            .is_err());
        Ok(())
    }
}