use std::cell::UnsafeCell;
use std::fmt::{Debug, Formatter};
use std::sync::{Arc, Condvar, Mutex};
use serde::{Deserialize, Serialize};
mod iterate;
mod iterate_delta;
mod iteration_end;
mod leader;
mod replay;
mod state_handler;
/// Outcome of an iteration step: either keep iterating or stop.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub(crate) enum IterationResult {
    /// Another iteration should be performed.
    Continue,
    /// The iteration is over.
    Finished,
}

impl IterationResult {
    /// Converts a "should continue" flag into the matching variant.
    ///
    /// `true` becomes [`IterationResult::Continue`], `false` becomes
    /// [`IterationResult::Finished`].
    pub(crate) fn from_condition(should_continue: bool) -> IterationResult {
        match should_continue {
            true => Self::Continue,
            false => Self::Finished,
        }
    }
}
/// Tuple pairing an [`IterationResult`] with a `State` value.
pub(crate) type StateFeedback<State> = (IterationResult, State);
/// Reference-counted pointer to a state value that can be overwritten through a shared
/// reference.
///
/// The value is stored in an [`UnsafeCell`] behind an [`Arc`]: every clone points to the same
/// value, and this type performs no synchronization of its own.
#[derive(Clone, Debug)]
pub(crate) struct IterationStateRef<State> {
    /// The shared cell holding the current state.
    state: Arc<UnsafeCell<State>>,
}

impl<State> IterationStateRef<State> {
    /// Wraps `init` in a freshly allocated shared cell.
    fn new(init: State) -> Self {
        let state = Arc::new(UnsafeCell::new(init));
        Self { state }
    }

    /// Overwrites the stored value with `new_state`; the previous value is dropped.
    ///
    /// # Safety
    ///
    /// The write goes through a shared reference without any locking. The caller must
    /// guarantee that no reference returned by `get` (on this instance or on any clone) is
    /// still alive, and that nothing else is reading or writing the value at the same time.
    unsafe fn set(&self, new_state: State) {
        *self.state.get() = new_state;
    }

    /// Borrows the stored value.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `set` is not called (on this instance or on any clone)
    /// for as long as the returned reference is alive.
    unsafe fn get(&self) -> &State {
        let raw: *const State = self.state.get();
        &*raw
    }
}

// SAFETY: NOTE(review): `Arc<UnsafeCell<State>>` is not `Send` on its own because `UnsafeCell`
// is not `Sync`. This impl is only sound if users of `set`/`get` uphold the contracts documented
// on those methods across threads; presumably an external lock coordinates them. Confirm.
unsafe impl<State: Send + Sync> Send for IterationStateRef<State> {}
/// Cloneable handle to the state of an iteration.
///
/// Thin wrapper around an `IterationStateRef`; clones of a handle refer to the same underlying
/// value.
#[derive(Clone)]
pub struct IterationStateHandle<T> {
    /// Shared pointer to the state value.
    result: IterationStateRef<T>,
}

impl<T: Clone> IterationStateHandle<T> {
    /// Creates a handle whose state starts out as `init`.
    pub(crate) fn new(init: T) -> Self {
        let result = IterationStateRef::new(init);
        Self { result }
    }

    /// Replaces the state with `new_state`.
    ///
    /// # Safety
    ///
    /// Forwards to `IterationStateRef::set`, which writes without synchronization: the caller
    /// must ensure that nothing is reading or writing the state (through this handle or any of
    /// its clones) while this call runs, and that no reference returned by `get` is still alive.
    pub(crate) unsafe fn set(&self, new_state: T) {
        IterationStateRef::set(&self.result, new_state)
    }

    /// Borrows the current state.
    pub fn get(&self) -> &T {
        // SAFETY: NOTE(review): nothing in this type prevents a concurrent `set`. Soundness
        // presumably depends on the code calling `set` doing so only while no reader holds a
        // reference obtained here. Confirm against the callers of `set`.
        unsafe { self.result.get() }
    }
}
/// Opaque `Debug` output: only the type name is printed, so `T` is not required to be `Debug`.
impl<T> Debug for IterationStateHandle<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // No fields are added on purpose: the state itself is never formatted.
        let mut repr = f.debug_struct("IterationStateHandle");
        repr.finish()
    }
}
/// Generation counter with a condition variable, used to signal state updates.
///
/// The counter starts at 0. An odd value means "locked" (an update is in progress), an even
/// value means "unlocked". Each `lock`/`unlock` pair advances the counter by two.
#[derive(Debug, Default)]
// A `Mutex<usize>` (rather than an atomic) is intentional: the counter is waited on through
// `cond_var`, and `Condvar` works on a mutex guard.
#[allow(clippy::mutex_atomic)]
pub(crate) struct IterationStateLock {
    /// Current generation; odd while locked.
    generation: Mutex<usize>,
    /// Notified every time `unlock` advances the generation.
    cond_var: Condvar,
}

#[allow(clippy::mutex_atomic)]
impl IterationStateLock {
    /// Marks the lock as taken by moving the generation to an odd value.
    ///
    /// If the lock is already taken this is a no-op; it never waits for an `unlock`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn lock(&self) {
        let mut current = self.generation.lock().unwrap();
        let is_locked = *current % 2 == 1;
        if !is_locked {
            *current += 1;
        }
    }

    /// Releases the lock, advancing the generation to the next even value and waking every
    /// thread blocked in `wait_for_update`.
    ///
    /// # Panics
    ///
    /// Panics if the lock is not currently taken, or if the internal mutex is poisoned.
    pub fn unlock(&self) {
        let mut current = self.generation.lock().unwrap();
        let parity = *current % 2;
        assert_eq!(parity, 1, "cannot unlock a non-locked lock");
        *current += 1;
        // Notify while still holding the guard, right after the counter has been advanced.
        self.cond_var.notify_all();
    }

    /// Blocks the calling thread until the stored generation is at least `generation`.
    ///
    /// Returns immediately when that is already the case.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn wait_for_update(&self, generation: usize) {
        let guard = self.generation.lock().unwrap();
        let _guard = self
            .cond_var
            .wait_while(guard, |current| *current < generation)
            .unwrap();
    }
}