use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use parking_lot::{Condvar, Mutex, RwLock};
use crate::binary::Bounds;
use crate::error::ComputableError;
/// Uniform, read-only access to the bounds of a value.
///
/// In this file it is implemented for every [`BaseNode`] through a blanket
/// impl, and directly for [`Node`].
pub trait BoundsAccess {
    /// Returns the current bounds, or the [`ComputableError`] raised while
    /// producing them.
    fn get_bounds(&self) -> Result<Bounds, ComputableError>;
}
/// A shareable (`Send + Sync`) object whose bounds can be queried and refined
/// through a shared reference.
pub trait BaseNode: Send + Sync {
    /// Returns the bounds for the node's current state.
    fn get_bounds(&self) -> Result<Bounds, ComputableError>;

    /// Performs one refinement step on the node.
    ///
    /// Takes `&self`, so implementations need interior mutability to update
    /// their state.
    fn refine(&self) -> Result<(), ComputableError>;
}
/// The mutable portion of a [`TypedBaseNode`]: the current state together
/// with the bounds cached for it.
#[derive(Clone)]
struct BaseSnapshot<X> {
    /// Current state of the node.
    state: X,
    /// Bounds cached for `state`; `None` until they have been computed once.
    bounds: Option<Bounds>,
}
/// A [`BaseNode`] built from a state value and two user-supplied closures.
///
/// * `X` - the state type; compared with `==` to detect a refinement that
///   made no progress.
/// * `B` - computes the bounds for a given state.
/// * `F` - consumes a state and produces the next one.
pub struct TypedBaseNode<X, B, F>
where
    X: Eq + Clone + Send + Sync + 'static,
    B: Fn(&X) -> Result<Bounds, ComputableError> + Send + Sync + 'static,
    F: Fn(X) -> Result<X, ComputableError> + Send + Sync + 'static,
{
    /// Current state plus its cached bounds, guarded by a reader-writer lock.
    snapshot: RwLock<BaseSnapshot<X>>,
    /// Closure mapping a state to its bounds.
    bounds: B,
    /// Closure mapping a state to its successor.
    refine: F,
}
impl<X, B, F> TypedBaseNode<X, B, F>
where
    X: Eq + Clone + Send + Sync + 'static,
    B: Fn(&X) -> Result<Bounds, ComputableError> + Send + Sync + 'static,
    F: Fn(X) -> Result<X, ComputableError> + Send + Sync + 'static,
{
    /// Builds a node from an initial `state`, a `bounds` closure and a
    /// `refine` closure.
    ///
    /// No bounds are computed here; the cache starts out empty.
    pub fn new(state: X, bounds: B, refine: F) -> Self {
        let initial = BaseSnapshot {
            state,
            bounds: None,
        };
        Self {
            snapshot: RwLock::new(initial),
            bounds,
            refine,
        }
    }

    /// Returns the bounds for `snapshot`, computing and caching them on a miss.
    ///
    /// The caller passes the snapshot in explicitly (normally through a held
    /// write guard), so this method performs no locking of its own.
    ///
    /// # Errors
    ///
    /// Propagates any error from the bounds closure; the cache is left empty
    /// in that case.
    fn snapshot_bounds(&self, snapshot: &mut BaseSnapshot<X>) -> Result<Bounds, ComputableError> {
        match snapshot.bounds.clone() {
            Some(cached) => Ok(cached),
            None => {
                let fresh = (self.bounds)(&snapshot.state)?;
                snapshot.bounds = Some(fresh.clone());
                Ok(fresh)
            }
        }
    }
}
impl<X, B, F> BaseNode for TypedBaseNode<X, B, F>
where
    X: Eq + Clone + Send + Sync + 'static,
    B: Fn(&X) -> Result<Bounds, ComputableError> + Send + Sync + 'static,
    F: Fn(X) -> Result<X, ComputableError> + Send + Sync + 'static,
{
    /// Returns the bounds for the current state.
    ///
    /// A cache hit is served under a shared read lock so concurrent readers
    /// do not serialize. Only a miss takes the write lock to compute and
    /// store the bounds.
    ///
    /// # Errors
    ///
    /// Propagates any error from the bounds closure.
    fn get_bounds(&self) -> Result<Bounds, ComputableError> {
        // Fast path. The read guard is a temporary of this statement, so it is
        // released before the write lock below is requested.
        let cached = self.snapshot.read().bounds.clone();
        if let Some(bounds) = cached {
            return Ok(bounds);
        }

        // Slow path. `snapshot_bounds` re-checks the cache, so a value filled
        // in by another thread between the two lock acquisitions is reused.
        let mut snapshot = self.snapshot.write();
        self.snapshot_bounds(&mut snapshot)
    }

    /// Advances the state by one call to the refine closure and updates the
    /// cached bounds.
    ///
    /// The write lock is held for the whole operation, including the calls to
    /// the user closures.
    ///
    /// # Errors
    ///
    /// * Any error returned by the bounds or refine closures.
    /// * [`ComputableError::StateUnchanged`] if the refine closure returned a
    ///   state equal to the previous one while the bounds' two endpoints still
    ///   differ.
    /// * [`ComputableError::BoundsWorsened`] if the new lower endpoint is below
    ///   the old one or the new upper endpoint is above the old one.
    ///
    /// On error the stored state is left untouched.
    fn refine(&self) -> Result<(), ComputableError> {
        let mut snapshot = self.snapshot.write();
        let previous_bounds = self.snapshot_bounds(&mut snapshot)?;
        let previous_state = snapshot.state.clone();

        // The closure consumes its argument, so hand it a clone and keep the
        // original for the progress check below.
        let next_state = (self.refine)(previous_state.clone())?;
        if next_state == previous_state {
            // No progress is acceptable only when the two endpoints already
            // coincide.
            if previous_bounds.small() == &previous_bounds.large() {
                return Ok(());
            }
            return Err(ComputableError::StateUnchanged);
        }

        // Validate the candidate before committing anything to the snapshot.
        let next_bounds = (self.bounds)(&next_state)?;
        let lower_worsened = next_bounds.small() < previous_bounds.small();
        let upper_worsened = next_bounds.large() > previous_bounds.large();
        if lower_worsened || upper_worsened {
            return Err(ComputableError::BoundsWorsened);
        }

        snapshot.state = next_state;
        snapshot.bounds = Some(next_bounds);
        Ok(())
    }
}
/// Every [`BaseNode`], including unsized ones such as `dyn BaseNode`, exposes
/// its bounds through [`BoundsAccess`].
impl<T: BaseNode + ?Sized> BoundsAccess for T {
    /// Forwards to [`BaseNode::get_bounds`].
    fn get_bounds(&self) -> Result<Bounds, ComputableError> {
        // Fully qualified so the call cannot resolve back to this trait method.
        <T as BaseNode>::get_bounds(self)
    }
}
/// The operation behind a [`Node`]; `Node`'s methods of the same names
/// delegate to it.
pub trait NodeOp: Send + Sync {
    /// Computes the bounds for this operation.
    fn compute_bounds(&self) -> Result<Bounds, ComputableError>;

    /// Performs one refinement step for the given `precision_bits`.
    ///
    /// NOTE(review): the meaning of the returned `bool` is not visible in this
    /// section; confirm against the implementors.
    fn refine_step(&self, precision_bits: usize) -> Result<bool, ComputableError>;

    /// Returns the nodes this operation reports as its children.
    fn children(&self) -> Vec<Arc<Node>>;

    /// Reports whether this operation is a refiner.
    ///
    /// NOTE(review): presumably marks nodes that can be refined directly;
    /// verify against callers.
    fn is_refiner(&self) -> bool;
}
/// A mutex-guarded [`RefinementState`] paired with a condition variable.
pub struct RefinementSync {
    /// State protected by the mutex.
    pub state: Mutex<RefinementState>,
    /// Condition variable on which all waiters are notified whenever the
    /// epoch in `state` is advanced.
    pub condvar: Condvar,
}
/// Data guarded by the mutex inside [`RefinementSync`].
pub struct RefinementState {
    /// Starts out `false`.
    ///
    /// NOTE(review): presumably set while a refinement is in progress; nothing
    /// in this section writes it, so confirm against callers.
    pub active: bool,
    /// Counter advanced (with wrap-around) on every bounds-update
    /// notification; starts at 0.
    pub epoch: u64,
}
impl RefinementSync {
    /// Creates a synchronizer that is inactive and at epoch 0.
    pub fn new() -> Self {
        let initial = RefinementState {
            active: false,
            epoch: 0,
        };
        Self {
            state: Mutex::new(initial),
            condvar: Condvar::new(),
        }
    }

    /// Advances the epoch by one (wrapping on overflow) and wakes every
    /// thread waiting on the condition variable.
    pub fn notify_bounds_updated(&self) {
        let mut guard = self.state.lock();
        let next_epoch = guard.epoch.wrapping_add(1);
        guard.epoch = next_epoch;
        // The guard stays alive until the end of the function, so the
        // notification is issued while the mutex is still held.
        self.condvar.notify_all();
    }
}
impl Default for RefinementSync {
fn default() -> Self {
Self::new()
}
}
/// A node wrapping a [`NodeOp`] together with a bounds cache and refinement
/// synchronization state.
pub struct Node {
    /// Identifier taken from a process-wide counter when the node is created.
    pub id: usize,
    /// The operation that supplies this node's bounds, refinement step and
    /// children.
    pub op: Arc<dyn NodeOp>,
    /// Most recently stored bounds; `None` until bounds are first set.
    pub bounds_cache: RwLock<Option<Bounds>>,
    /// Epoch counter and condition variable signalled whenever the cached
    /// bounds are replaced.
    pub refinement: RefinementSync,
}
impl Node {
    /// Wraps `op` in a new shared node with an empty bounds cache.
    ///
    /// Ids are drawn from a process-wide atomic counter that starts at 0.
    pub fn new(op: Arc<dyn NodeOp>) -> Arc<Self> {
        static NODE_IDS: AtomicUsize = AtomicUsize::new(0);
        Arc::new(Self {
            id: NODE_IDS.fetch_add(1, Ordering::Relaxed),
            op,
            bounds_cache: RwLock::new(None),
            refinement: RefinementSync::new(),
        })
    }

    /// Returns a clone of the cached bounds, if any have been stored.
    pub fn cached_bounds(&self) -> Option<Bounds> {
        self.bounds_cache.read().clone()
    }

    /// Returns the cached bounds, computing and storing them on a miss.
    ///
    /// No lock is held between the cache check and the store, so concurrent
    /// callers that both miss may each compute and store the bounds.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Node::compute_bounds`]; nothing is cached
    /// in that case.
    pub fn get_bounds(&self) -> Result<Bounds, ComputableError> {
        if let Some(bounds) = self.cached_bounds() {
            return Ok(bounds);
        }
        let bounds = self.compute_bounds()?;
        self.set_bounds(bounds.clone());
        Ok(bounds)
    }

    /// Replaces the cached bounds, then advances the refinement epoch and
    /// wakes all waiters.
    pub fn set_bounds(&self, bounds: Bounds) {
        // The write guard is a temporary of this statement, so the cache lock
        // is released before the refinement mutex is taken. Holding both at
        // once would impose a cache -> state lock order that can deadlock
        // against a thread reading the cache while it holds the state mutex.
        *self.bounds_cache.write() = Some(bounds);

        // The cache is already updated by the time the epoch changes, so a
        // waiter that observes the new epoch also sees the new bounds.
        self.refinement.notify_bounds_updated();
    }

    /// Computes bounds via the underlying op, bypassing and not updating the
    /// cache.
    pub fn compute_bounds(&self) -> Result<Bounds, ComputableError> {
        self.op.compute_bounds()
    }

    /// Delegates one refinement step to the underlying op.
    pub fn refine_step(&self, precision_bits: usize) -> Result<bool, ComputableError> {
        self.op.refine_step(precision_bits)
    }

    /// Returns the children reported by the underlying op.
    pub fn children(&self) -> Vec<Arc<Node>> {
        self.op.children()
    }

    /// Reports whether the underlying op is a refiner.
    pub fn is_refiner(&self) -> bool {
        self.op.is_refiner()
    }
}
/// Exposes a [`Node`]'s cached-or-computed bounds through [`BoundsAccess`].
impl BoundsAccess for Node {
    /// Forwards to the inherent [`Node::get_bounds`].
    fn get_bounds(&self) -> Result<Bounds, ComputableError> {
        // `Node::get_bounds` names the inherent method, which takes precedence
        // over this trait method, so the call does not recurse.
        Node::get_bounds(self)
    }
}