use std::ffi::OsStr;
use std::hash::{BuildHasher, BuildHasherDefault, Hasher};
use std::time::Duration;
use bytecheck::CheckBytes;
use rkyv::ser::serializers::{AlignedSerializer, AllocSerializer};
use rkyv::ser::Serializer;
use rkyv::validation::validators::DefaultValidator;
use rkyv::{archived_root, check_archived_root, AlignedVec, Archive, Serialize};
use thiserror::Error;
use wyhash::WyHash;
use crate::data::DataContainer;
use crate::guard::{ReadGuard, ReadResult};
use crate::instance::InstanceVersion;
use crate::locks::{LockDisabled, WriteLockStrategy};
use crate::state::StateContainer;
use crate::synchronizer::SynchronizerError::*;
/// Shares one rkyv-archived entity between a writer and readers through files that are
/// derived from a common path prefix (a state file plus data files).
///
/// Type parameters:
/// * `H` — hasher used to checksum every payload before it is published.
/// * `WL` — write-lock strategy type, forwarded to `StateContainer<WL>`.
/// * `N` — const parameter of the `AllocSerializer<N>` used by `write`.
/// * `SD` — number of nanoseconds handed to `acquire_next_idx` as its sleep duration.
pub struct Synchronizer<
    H: Hasher + Default = WyHash,
    WL = LockDisabled,
    const N: usize = 1024,
    const SD: u64 = 1_000_000_000,
> {
    /// Serialization buffer kept between writes so its allocation can be reused
    /// (taken out by `write` while it is in use).
    serialize_buffer: Option<AlignedVec>,
    /// Builds a fresh `H` for each checksum computation.
    build_hasher: BuildHasherDefault<H>,
    /// Data files holding the serialized entity bytes.
    data_container: DataContainer,
    /// Shared state: current version, next-slot acquisition and version switching.
    state_container: StateContainer<WL>,
}
/// Errors produced by [`Synchronizer`] operations.
#[derive(Error, Debug)]
pub enum SynchronizerError {
    /// Reading a data file failed.
    #[error("error reading data file: {0}")]
    FailedDataRead(std::io::Error),
    /// Writing a data file failed.
    #[error("error writing data file: {0}")]
    FailedDataWrite(std::io::Error),
    /// Reading the state file failed (e.g. it does not exist yet because nothing was written).
    #[error("error reading state file: {0}")]
    FailedStateRead(std::io::Error),
    /// Archived entity bytes failed `check_archived_root` validation.
    #[error("error reading entity")]
    FailedEntityRead,
    /// Writing (serializing) an entity failed.
    #[error("error writing entity")]
    FailedEntityWrite,
    /// The shared state is not initialized.
    #[error("uninitialized state")]
    UninitializedState,
    /// Invalid parameters for an `InstanceVersion` (presumably raised by `InstanceVersion::new`).
    #[error("invalid instance version params")]
    InvalidInstanceVersionParams,
    /// The write was refused because another writer holds a conflicting lock.
    #[error("write blocked by conflicting lock")]
    WriteLockConflict,
}
impl Synchronizer {
pub fn new(path_prefix: &OsStr) -> Self {
Self::with_params(path_prefix)
}
}
impl<'a, H, WL, const N: usize, const SD: u64> Synchronizer<H, WL, N, SD>
where
    H: Hasher + Default,
    WL: WriteLockStrategy<'a>,
{
    /// Creates a synchronizer whose state and data containers are both built from
    /// `path_prefix`, with an empty, reusable serialization buffer.
    pub fn with_params(path_prefix: &OsStr) -> Self {
        Synchronizer {
            state_container: StateContainer::new(path_prefix),
            data_container: DataContainer::new(path_prefix),
            build_hasher: BuildHasherDefault::default(),
            serialize_buffer: Some(AlignedVec::new()),
        }
    }

    /// Serializes `entity` with rkyv and publishes it as the new version for readers.
    ///
    /// Steps, in order:
    /// 1. Serialize into the reused buffer.
    /// 2. Validate the bytes with `check_archived_root::<T>`.
    /// 3. Checksum them with `H`.
    /// 4. Acquire the next data slot via `acquire_next_idx`, passing `grace_duration` and a
    ///    sleep of `SD` nanoseconds.
    /// 5. Write the bytes to that slot.
    /// 6. Only then call `switch_version`, so readers see the new data.
    ///
    /// Returns the size reported by the data container and the `reset` flag from
    /// `acquire_next_idx`.
    ///
    /// # Errors
    ///
    /// * [`SynchronizerError::FailedEntityWrite`] if rkyv serialization fails.
    /// * [`SynchronizerError::FailedEntityRead`] if the serialized bytes fail validation.
    /// * Errors from the state container, `InstanceVersion::new` and the data container
    ///   are propagated unchanged. For example, a conflicting writer surfaces as
    ///   [`SynchronizerError::WriteLockConflict`].
    ///
    /// The serialization buffer is kept for reuse on every path, so a failed call does not
    /// affect later calls.
    pub fn write<T>(
        &'a mut self,
        entity: &T,
        grace_duration: Duration,
    ) -> Result<(usize, bool), SynchronizerError>
    where
        T: Serialize<AllocSerializer<N>>,
        T::Archived: for<'b> CheckBytes<DefaultValidator<'b>>,
    {
        // Reuse the previous allocation when present; otherwise start with a fresh buffer
        // instead of failing, so a missing buffer can never wedge the writer.
        let mut buf = self
            .serialize_buffer
            .take()
            .unwrap_or_else(AlignedVec::new);
        buf.clear();

        let mut serializer = AllocSerializer::new(
            AlignedSerializer::new(buf),
            Default::default(),
            Default::default(),
        );
        let serialized = serializer.serialize_value(entity);

        // Hand the buffer back to `self` *before* inspecting any result. Every `?` below
        // then leaves it in place for the next call. It used to be restored only on
        // success, so a single failed write (e.g. a lock conflict) made every later write
        // fail with `FailedEntityWrite`.
        let data: &[u8] = self
            .serialize_buffer
            .insert(serializer.into_serializer().into_inner());

        serialized.map_err(|_| FailedEntityWrite)?;
        // Make sure the bytes can be read back as `T` before they are published.
        check_archived_root::<T>(data).map_err(|_| FailedEntityRead)?;

        let state = self.state_container.state::<true>(true)?;

        let mut hasher = self.build_hasher.build_hasher();
        hasher.write(data);
        let checksum = hasher.finish();

        let acquire_sleep_duration = Duration::from_nanos(SD);
        let (new_idx, reset) = state.acquire_next_idx(grace_duration, acquire_sleep_duration);
        let new_version = InstanceVersion::new(new_idx, data.len(), checksum)?;
        let size = self.data_container.write(data, new_version)?;

        // Switch readers over only once the new data has been fully written.
        state.switch_version(new_version);
        Ok((size, reset))
    }

    /// Publishes already-serialized `data` as the new version for readers.
    ///
    /// Performs the same checksum / slot acquisition / write / `switch_version` sequence as
    /// [`Synchronizer::write`], but does **not** validate `data`. The type parameter `T` is
    /// not referenced by the body. The caller must ensure `data` is a valid archive of
    /// the type readers will request.
    ///
    /// Returns the size reported by the data container and the `reset` flag from
    /// `acquire_next_idx`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the state container, `InstanceVersion::new` and the data
    /// container.
    pub fn write_raw<T>(
        &'a mut self,
        data: &[u8],
        grace_duration: Duration,
    ) -> Result<(usize, bool), SynchronizerError>
    where
        T: Serialize<AllocSerializer<N>>,
        T::Archived: for<'b> CheckBytes<DefaultValidator<'b>>,
    {
        let state = self.state_container.state::<true>(true)?;

        let mut hasher = self.build_hasher.build_hasher();
        hasher.write(data);
        let checksum = hasher.finish();

        let acquire_sleep_duration = Duration::from_nanos(SD);
        let (new_idx, reset) = state.acquire_next_idx(grace_duration, acquire_sleep_duration);
        let new_version = InstanceVersion::new(new_idx, data.len(), checksum)?;
        let size = self.data_container.write(data, new_version)?;

        // Switch readers over only once the new data has been fully written.
        state.switch_version(new_version);
        Ok((size, reset))
    }

    /// Returns a zero-copy archived view of the currently published entity.
    ///
    /// The returned [`ReadResult`] combines:
    /// * a [`ReadGuard`] created from the state and the current version;
    /// * the archived entity;
    /// * the `switched` flag reported by the data container.
    ///
    /// # Safety
    ///
    /// When `check_bytes` is `false`, the mapped bytes are reinterpreted with `archived_root`
    /// without any validation. The caller must guarantee they are a valid archive of `T`,
    /// e.g. produced by `write::<T>`. Passing `true` validates them with
    /// `check_archived_root` first.
    ///
    /// # Errors
    ///
    /// * Propagates errors from the state container. For example, reading before anything
    ///   was written yields [`SynchronizerError::FailedStateRead`].
    /// * Also propagates errors from `version()`, `ReadGuard::new` and the data container.
    /// * Returns [`SynchronizerError::FailedEntityRead`] if validation is requested and fails.
    pub unsafe fn read<T>(
        &'a mut self,
        check_bytes: bool,
    ) -> Result<ReadResult<T>, SynchronizerError>
    where
        T: Archive,
        T::Archived: for<'b> CheckBytes<DefaultValidator<'b>>,
    {
        let state = self.state_container.state::<false>(false)?;
        let version = state.version()?;
        let guard = ReadGuard::new(state, version)?;
        let (data, switched) = self.data_container.data(version)?;
        let entity = match check_bytes {
            // SAFETY: upheld by the caller per this function's `# Safety` contract.
            false => archived_root::<T>(data),
            true => check_archived_root::<T>(data).map_err(|_| FailedEntityRead)?,
        };
        Ok(ReadResult::new(guard, entity, switched))
    }

    /// Returns the currently published [`InstanceVersion`] from the shared state.
    ///
    /// # Errors
    ///
    /// Propagates errors from the state container and from reading its version.
    pub fn version(&'a mut self) -> Result<InstanceVersion, SynchronizerError> {
        let state = self.state_container.state::<false>(false)?;
        state.version()
    }
}
#[cfg(test)]
mod tests {
    use crate::instance::InstanceVersion;
    use crate::locks::SingleWriter;
    use crate::synchronizer::{Synchronizer, SynchronizerError};
    use bytecheck::CheckBytes;
    use rand::distributions::Uniform;
    use rand::prelude::*;
    use rkyv::{Archive, Deserialize, Serialize};
    use std::collections::HashMap;
    use std::fs;
    use std::path::Path;
    use std::time::Duration;
    use wyhash::WyHash;

    /// Test payload: a scalar plus a map of variable-length float vectors.
    #[derive(Archive, Deserialize, Serialize, Debug, PartialEq)]
    #[archive_attr(derive(CheckBytes))]
    struct MockEntity {
        version: u32,
        map: HashMap<u64, Vec<f32>>,
    }

    /// Deterministic generator of `MockEntity` values backed by a seeded `StdRng`.
    struct MockEntityGenerator {
        rng: StdRng,
    }

    impl MockEntityGenerator {
        /// Seeds the RNG with 32 copies of `seed`; the same seed yields the same sequence.
        fn new(seed: u8) -> Self {
            MockEntityGenerator {
                rng: StdRng::from_seed([seed; 32]),
            }
        }

        /// Generates an entity with `n` random keys (fewer if keys collide, since `insert`
        /// overwrites), each mapped to 0..=19 floats drawn uniformly from `[0, 100)`.
        fn gen(&mut self, n: usize) -> MockEntity {
            // NOTE: the order of RNG draws here determines the generated bytes, whose length
            // and checksum feed the `InstanceVersion` constants asserted in `test_synchronizer`.
            let mut entity = MockEntity {
                version: self.rng.gen(),
                map: HashMap::new(),
            };
            let range = Uniform::<f32>::from(0.0..100.0);
            for _ in 0..n {
                let key: u64 = self.rng.gen();
                let n_vals = self.rng.gen::<usize>() % 20;
                let vals: Vec<f32> = (0..n_vals).map(|_| self.rng.sample(range)).collect();
                entity.map.insert(key, vals);
            }
            entity
        }
    }

    /// End-to-end writer/reader round trip over the files under `/tmp/synchro_test`.
    #[test]
    fn test_synchronizer() {
        let path = "/tmp/synchro_test";
        let state_path = path.to_owned() + "_state";
        let data_path_0 = path.to_owned() + "_data_0";
        let data_path_1 = path.to_owned() + "_data_1";

        // Start from a clean slate; removal errors (e.g. file absent) are ignored.
        fs::remove_file(&state_path).unwrap_or_default();
        fs::remove_file(&data_path_0).unwrap_or_default();
        fs::remove_file(&data_path_1).unwrap_or_default();

        let mut writer = Synchronizer::new(path.as_ref());
        let mut reader = Synchronizer::new(path.as_ref());
        let mut entity_generator = MockEntityGenerator::new(3);

        // Reading before any write fails on the missing state file and must not create it.
        let res = unsafe { reader.read::<MockEntity>(false) };
        assert!(res.is_err());
        assert_eq!(
            res.err().unwrap().to_string(),
            "error reading state file: No such file or directory (os error 2)"
        );
        assert!(!Path::new(&state_path).exists());

        // First write: creates the state file and uses only the first data file.
        let entity = entity_generator.gen(100);
        let (size, reset) = writer.write(&entity, Duration::from_secs(1)).unwrap();
        assert!(size > 0);
        assert_eq!(reset, false);
        assert!(Path::new(&state_path).exists());
        assert!(!Path::new(&data_path_1).exists());
        assert_eq!(
            reader.version().unwrap(),
            InstanceVersion(8817430144856633152)
        );
        // The first read after a new version is expected to report `switched`; a repeat
        // read of the same version is not.
        fetch_and_assert_entity(&mut reader, &entity, true);
        fetch_and_assert_entity(&mut reader, &entity, false);

        // Second write: both data files now exist.
        let entity = entity_generator.gen(200);
        let (size, reset) = writer.write(&entity, Duration::from_secs(1)).unwrap();
        assert!(size > 0);
        assert_eq!(reset, false);
        assert!(Path::new(&state_path).exists());
        assert!(Path::new(&data_path_0).exists());
        assert!(Path::new(&data_path_1).exists());
        assert_eq!(
            reader.version().unwrap(),
            InstanceVersion(1441050725688826209)
        );
        fetch_and_assert_entity(&mut reader, &entity, true);

        // Third write, without reading it back.
        let entity = entity_generator.gen(100);
        let (size, reset) = writer.write(&entity, Duration::from_secs(1)).unwrap();
        assert!(size > 0);
        assert_eq!(reset, false);
        assert_eq!(
            reader.version().unwrap(),
            InstanceVersion(14058099486534675680)
        );

        // Fourth write: the reader must see the latest entity.
        let entity = entity_generator.gen(200);
        let (size, reset) = writer.write(&entity, Duration::from_secs(1)).unwrap();
        assert!(size > 0);
        assert_eq!(reset, false);
        assert_eq!(
            reader.version().unwrap(),
            InstanceVersion(18228729609619266545)
        );
        fetch_and_assert_entity(&mut reader, &entity, true);
    }

    /// Reads the current entity without byte validation (`check_bytes = false`) and checks
    /// its contents and `is_switched()` flag against the expectations.
    fn fetch_and_assert_entity(
        synchronizer: &mut Synchronizer,
        expected_entity: &MockEntity,
        expected_is_switched: bool,
    ) {
        let actual_entity = unsafe { synchronizer.read::<MockEntity>(false).unwrap() };
        assert_eq!(actual_entity.map, expected_entity.map);
        assert_eq!(actual_entity.version, expected_entity.version);
        assert_eq!(actual_entity.is_switched(), expected_is_switched);
    }

    /// With the `SingleWriter` lock, a second writer on the same path must be rejected
    /// with `WriteLockConflict` while the first writer is still alive.
    // NOTE(review): files left at PATH by earlier runs are not removed first.
    #[test]
    fn single_writer_lock_prevents_multiple_writers() {
        static PATH: &str = "/tmp/synchronizer_single_writer";
        let mut entity_generator = MockEntityGenerator::new(3);
        let entity = entity_generator.gen(100);

        let mut writer1 = Synchronizer::<WyHash, SingleWriter>::with_params(PATH.as_ref());
        let mut writer2 = Synchronizer::<WyHash, SingleWriter>::with_params(PATH.as_ref());

        writer1.write(&entity, Duration::from_secs(1)).unwrap();
        assert!(matches!(
            writer2.write(&entity, Duration::from_secs(1)),
            Err(SynchronizerError::WriteLockConflict)
        ));
    }
}