1use std::ffi::OsStr;
7use std::hash::{BuildHasher, BuildHasherDefault, Hasher};
8use std::time::Duration;
9
10use bytecheck::CheckBytes;
11use rkyv::ser::serializers::{AlignedSerializer, AllocSerializer};
12use rkyv::ser::Serializer;
13use rkyv::validation::validators::DefaultValidator;
14use rkyv::{archived_root, check_archived_root, AlignedVec, Archive, Serialize};
15use thiserror::Error;
16use wyhash::WyHash;
17
18use crate::data::DataContainer;
19use crate::guard::{ReadGuard, ReadResult};
20use crate::instance::InstanceVersion;
21use crate::locks::{LockDisabled, WriteLockStrategy};
22use crate::state::StateContainer;
23use crate::synchronizer::SynchronizerError::*;
24
25pub struct Synchronizer<
35 H: Hasher + Default = WyHash,
36 WL = LockDisabled,
37 const N: usize = 1024,
38 const SD: u64 = 1_000_000_000,
39> {
40 state_container: StateContainer<WL>,
42 data_container: DataContainer,
44 build_hasher: BuildHasherDefault<H>,
46 serialize_buffer: Option<AlignedVec>,
48}
49
50#[derive(Error, Debug)]
54pub enum SynchronizerError {
55 #[error("error writing data file: {0}")]
57 FailedDataWrite(std::io::Error),
58 #[error("error reading data file: {0}")]
60 FailedDataRead(std::io::Error),
61 #[error("error reading state file: {0}")]
63 FailedStateRead(std::io::Error),
64 #[error("error writing entity")]
66 FailedEntityWrite,
67 #[error("error reading entity")]
69 FailedEntityRead,
70 #[error("uninitialized state")]
72 UninitializedState,
73 #[error("invalid instance version params")]
75 InvalidInstanceVersionParams,
76 #[error("write blocked by conflicting lock")]
78 WriteLockConflict,
79}
80
81impl Synchronizer {
82 pub fn new(path_prefix: &OsStr) -> Self {
84 Self::with_params(path_prefix)
85 }
86}
87
88impl<'a, H, WL, const N: usize, const SD: u64> Synchronizer<H, WL, N, SD>
89where
90 H: Hasher + Default,
91 WL: WriteLockStrategy<'a>,
92{
93 pub fn with_params(path_prefix: &OsStr) -> Self {
95 Synchronizer {
96 state_container: StateContainer::new(path_prefix),
97 data_container: DataContainer::new(path_prefix),
98 build_hasher: BuildHasherDefault::default(),
99 serialize_buffer: Some(AlignedVec::new()),
100 }
101 }
102
103 pub fn write<T>(
121 &'a mut self,
122 entity: &T,
123 grace_duration: Duration,
124 ) -> Result<(usize, bool), SynchronizerError>
125 where
126 T: Serialize<AllocSerializer<N>>,
127 T::Archived: for<'b> CheckBytes<DefaultValidator<'b>>,
128 {
129 let mut buf = self.serialize_buffer.take().ok_or(FailedEntityWrite)?;
130 buf.clear();
131
132 let mut serializer = AllocSerializer::new(
134 AlignedSerializer::new(buf),
135 Default::default(),
136 Default::default(),
137 );
138 let _ = serializer
139 .serialize_value(entity)
140 .map_err(|_| FailedEntityWrite)?;
141 let data = serializer.into_serializer().into_inner();
142
143 check_archived_root::<T>(&data).map_err(|_| FailedEntityRead)?;
145
146 let state = self.state_container.state::<true>(true)?;
148
149 let mut hasher = self.build_hasher.build_hasher();
151 hasher.write(&data);
152 let checksum = hasher.finish();
153
154 let acquire_sleep_duration = Duration::from_nanos(SD);
156 let (new_idx, reset) = state.acquire_next_idx(grace_duration, acquire_sleep_duration);
157 let new_version = InstanceVersion::new(new_idx, data.len(), checksum)?;
158 let size = self.data_container.write(&data, new_version)?;
159
160 state.switch_version(new_version);
162
163 self.serialize_buffer.replace(data);
165
166 Ok((size, reset))
167 }
168
169 pub fn write_raw<T>(
173 &'a mut self,
174 data: &[u8],
175 grace_duration: Duration,
176 ) -> Result<(usize, bool), SynchronizerError>
177 where
178 T: Serialize<AllocSerializer<N>>,
179 T::Archived: for<'b> CheckBytes<DefaultValidator<'b>>,
180 {
181 let state = self.state_container.state::<true>(true)?;
183
184 let mut hasher = self.build_hasher.build_hasher();
186 hasher.write(data);
187 let checksum = hasher.finish();
188
189 let acquire_sleep_duration = Duration::from_nanos(SD);
191 let (new_idx, reset) = state.acquire_next_idx(grace_duration, acquire_sleep_duration);
192 let new_version = InstanceVersion::new(new_idx, data.len(), checksum)?;
193 let size = self.data_container.write(data, new_version)?;
194
195 state.switch_version(new_version);
197
198 Ok((size, reset))
199 }
200
201 pub unsafe fn read<T>(
220 &'a mut self,
221 check_bytes: bool,
222 ) -> Result<ReadResult<T>, SynchronizerError>
223 where
224 T: Archive,
225 T::Archived: for<'b> CheckBytes<DefaultValidator<'b>>,
226 {
227 let state = self.state_container.state::<false>(false)?;
229
230 let version = state.version()?;
232
233 let guard = ReadGuard::new(state, version)?;
235
236 let (data, switched) = self.data_container.data(version)?;
238
239 let entity = match check_bytes {
241 false => archived_root::<T>(data),
242 true => check_archived_root::<T>(data).map_err(|_| FailedEntityRead)?,
243 };
244
245 Ok(ReadResult::new(guard, entity, switched))
246 }
247
248 pub fn version(&'a mut self) -> Result<InstanceVersion, SynchronizerError> {
251 let state = self.state_container.state::<false>(false)?;
253
254 state.version()
256 }
257}
258
#[cfg(test)]
mod tests {
    use crate::instance::InstanceVersion;
    use crate::locks::SingleWriter;
    use crate::synchronizer::{Synchronizer, SynchronizerError};
    use bytecheck::CheckBytes;
    use rand::distributions::Uniform;
    use rand::prelude::*;
    use rkyv::{Archive, Deserialize, Serialize};
    use std::collections::HashMap;
    use std::fs;
    use std::path::Path;
    use std::time::Duration;
    use wyhash::WyHash;

    /// Payload written through the synchronizer in these tests.
    #[derive(Archive, Deserialize, Serialize, Debug, PartialEq)]
    #[archive_attr(derive(CheckBytes))]
    struct MockEntity {
        version: u32,
        map: HashMap<u64, Vec<f32>>,
    }

    /// Produces reproducible `MockEntity` values from a seeded `StdRng`.
    struct MockEntityGenerator {
        rng: StdRng,
    }

    impl MockEntityGenerator {
        /// Seeds the RNG with 32 copies of `seed`.
        fn new(seed: u8) -> Self {
            let seed_bytes = [seed; 32];
            Self {
                rng: StdRng::from_seed(seed_bytes),
            }
        }

        /// Builds an entity with a random `version` and `n` random `u64` keys.
        ///
        /// Duplicate keys overwrite earlier ones. Each key maps to 0..=19
        /// floats drawn uniformly from [0, 100).
        ///
        /// The order of RNG draws matters: written bytes feed the checksum in
        /// `InstanceVersion`, so the pinned versions below depend on it.
        fn gen(&mut self, n: usize) -> MockEntity {
            let version: u32 = self.rng.gen();
            let range = Uniform::<f32>::from(0.0..100.0);
            let mut map = HashMap::new();
            for _ in 0..n {
                let key: u64 = self.rng.gen();
                let len = self.rng.gen::<usize>() % 20;
                let values = (0..len)
                    .map(|_| self.rng.sample(range))
                    .collect::<Vec<f32>>();
                map.insert(key, values);
            }
            MockEntity { version, map }
        }
    }

    /// Deletes a file left over from a previous run, ignoring any error.
    fn remove_stale(path: &str) {
        let _ = fs::remove_file(path);
    }

    /// Writes `entity` and checks for a non-empty write with `reset == false`.
    fn publish(writer: &mut Synchronizer, entity: &MockEntity) {
        let (size, reset) = writer.write(entity, Duration::from_secs(1)).unwrap();
        assert!(size > 0);
        assert!(!reset);
    }

    #[test]
    fn test_synchronizer() {
        let path = "/tmp/synchro_test";
        let state_path = format!("{}_state", path);
        let data_path_0 = format!("{}_data_0", path);
        let data_path_1 = format!("{}_data_1", path);

        for stale in [state_path.as_str(), data_path_0.as_str(), data_path_1.as_str()].iter() {
            remove_stale(stale);
        }

        let mut writer = Synchronizer::new(path.as_ref());
        let mut reader = Synchronizer::new(path.as_ref());
        let mut entity_generator = MockEntityGenerator::new(3);

        // Nothing has been written yet, so reading must fail on the state file.
        let missing = unsafe { reader.read::<MockEntity>(false) };
        assert!(missing.is_err());
        assert_eq!(
            missing.err().unwrap().to_string(),
            "error reading state file: No such file or directory (os error 2)"
        );
        assert!(!Path::new(&state_path).exists());

        // First write: the state appears, data file 1 does not exist yet.
        let entity = entity_generator.gen(100);
        publish(&mut writer, &entity);
        assert!(Path::new(&state_path).exists());
        assert!(!Path::new(&data_path_1).exists());
        assert_eq!(
            reader.version().unwrap(),
            InstanceVersion(8817430144856633152)
        );
        fetch_and_assert_entity(&mut reader, &entity, true);
        fetch_and_assert_entity(&mut reader, &entity, false);

        // Second write: both data files exist.
        let entity = entity_generator.gen(200);
        publish(&mut writer, &entity);
        assert!(Path::new(&state_path).exists());
        assert!(Path::new(&data_path_0).exists());
        assert!(Path::new(&data_path_1).exists());
        assert_eq!(
            reader.version().unwrap(),
            InstanceVersion(1441050725688826209)
        );
        fetch_and_assert_entity(&mut reader, &entity, true);

        // Third write, not read back.
        let entity = entity_generator.gen(100);
        publish(&mut writer, &entity);
        assert_eq!(
            reader.version().unwrap(),
            InstanceVersion(14058099486534675680)
        );

        // Fourth write, read back.
        let entity = entity_generator.gen(200);
        publish(&mut writer, &entity);
        assert_eq!(
            reader.version().unwrap(),
            InstanceVersion(18228729609619266545)
        );
        fetch_and_assert_entity(&mut reader, &entity, true);
    }

    /// Reads the current entity without validation and compares it with
    /// `expected_entity` and the expected `is_switched` flag.
    fn fetch_and_assert_entity(
        synchronizer: &mut Synchronizer,
        expected_entity: &MockEntity,
        expected_is_switched: bool,
    ) {
        // SAFETY: every write in these tests archives a `MockEntity`.
        let actual = unsafe { synchronizer.read::<MockEntity>(false) }.unwrap();
        assert_eq!(actual.map, expected_entity.map);
        assert_eq!(actual.version, expected_entity.version);
        assert_eq!(actual.is_switched(), expected_is_switched);
    }

    #[test]
    fn single_writer_lock_prevents_multiple_writers() {
        static PATH: &str = "/tmp/synchronizer_single_writer";
        let entity = MockEntityGenerator::new(3).gen(100);

        let mut first = Synchronizer::<WyHash, SingleWriter>::with_params(PATH.as_ref());
        let mut second = Synchronizer::<WyHash, SingleWriter>::with_params(PATH.as_ref());

        first.write(&entity, Duration::from_secs(1)).unwrap();
        let conflict = second.write(&entity, Duration::from_secs(1));
        assert!(matches!(conflict, Err(SynchronizerError::WriteLockConflict)));
    }
}