use crate::DynFields;
use crate::{
InjectionError,
node::{AnyNode, Node},
};
use std::collections::HashMap;
/// The kind of access used to reference a value.
///
/// The variant names mirror Rust's three access modes: an owned value, a shared borrow and a
/// mutable borrow. NOTE(review): how this enum is consumed lives elsewhere in the crate;
/// confirm the intended mapping there.
///
/// `Eq` is derived together with `PartialEq` and `Hash` so that equality and hashing agree
/// and the type can be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RefType {
    /// An owned value (by name; see the note above).
    Owned,
    /// A shared borrow (by name; see the note above).
    Borrowed,
    /// A mutable borrow (by name; see the note above).
    BorrowedMut,
}
/// Compile-time description of a stage's wiring.
///
/// Every field is a `'static` string or slice of strings, so a shape can be written as a
/// `const` item and copied freely.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct StageShape {
    /// Display name of the stage (for example `"()"` or `"_"`).
    pub stage_name: &'static str,
    /// Names of the stage's inputs, in declaration order.
    pub inputs: &'static [&'static str],
    /// Names of the stage's outputs, in declaration order.
    pub outputs: &'static [&'static str],
}
/// A unit of computation that can be placed in a [`Node`].
///
/// An implementor declares its wiring statically via [`Stage::SHAPE`]. Its data types are
/// given by the `State`, `Input` and `Output` associated types. Its output is computed by
/// [`Stage::evaluate`], or by [`Stage::evaluate_async`] when the `tokio` feature is enabled.
#[cfg_attr(feature = "tokio", async_trait::async_trait)]
pub trait Stage: Clone + 'static {
    /// Static name plus input/output names for this stage.
    const SHAPE: StageShape;

    /// Mutable state handed to every evaluation.
    type State: Send + Sync;
    /// Input bundle; must be default-constructible and expose its fields through `DynFields`.
    type Input: DynFields + Default + Send + Sync;
    /// Output bundle; must be default-constructible and expose its fields through `DynFields`.
    type Output: DynFields + Default + Send + Sync;

    /// Computes this stage's output from `state` and `inputs`.
    ///
    /// `cache` maps `u64` keys to the cached entries for this stage type.
    /// NOTE(review): the meaning of the key is defined by the caller (presumably a hash of
    /// the inputs); confirm.
    ///
    /// # Errors
    /// Returns an [`InjectionError`] when the stage cannot produce an output.
    fn evaluate(
        &self,
        state: &mut Self::State,
        inputs: &mut Self::Input,
        cache: &mut HashMap<u64, Vec<crate::Cached<Self>>>,
    ) -> Result<Self::Output, InjectionError>;

    /// Asynchronous counterpart of [`Stage::evaluate`]; only present with the `tokio` feature.
    ///
    /// # Errors
    /// Returns an [`InjectionError`] when the stage cannot produce an output.
    #[cfg(feature = "tokio")]
    async fn evaluate_async(
        &self,
        state: &mut Self::State,
        inputs: &mut Self::Input,
        cache: &mut HashMap<u64, Vec<crate::Cached<Self>>>,
    ) -> Result<Self::Output, InjectionError>;

    /// Reevaluation policy for this stage; defaults to [`ReevaluationRule::Move`].
    fn reeval_rule(&self) -> ReevaluationRule {
        ReevaluationRule::Move
    }

    /// Connects `parent` to `node`.
    ///
    /// `output` and `input` optionally name the fields involved.
    /// NOTE(review): the exact wiring semantics are defined by each implementor; confirm.
    ///
    /// # Errors
    /// Returns an [`InjectionError`] when the injection fails.
    fn inject_input(
        &self,
        node: &mut Node<Self>,
        parent: &mut Box<dyn AnyNode>,
        output: Option<&'static str>,
        input: Option<&'static str>,
    ) -> Result<(), InjectionError>;
}
/// Policy a stage reports through `Stage::reeval_rule`.
///
/// Stored as a `u8`; the discriminants are spelled out explicitly and match the implicit
/// `0, 1, 2` numbering.
///
/// NOTE(review): the runtime meaning of each variant is implemented elsewhere. The
/// descriptions below are read from the names only; confirm them against the evaluator.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ReevaluationRule {
    /// Presumably: the output is moved out rather than cached (the trait default).
    Move = 0,
    /// Presumably: only the most recent result is cached.
    CacheLast = 1,
    /// Presumably: every result is cached.
    CacheAll = 2,
}
/// The unit stage: no inputs, no outputs, no state. Every operation succeeds and does
/// nothing.
#[cfg_attr(feature = "tokio", async_trait::async_trait)]
impl Stage for () {
    type State = ();
    type Input = ();
    type Output = ();

    const SHAPE: StageShape = StageShape {
        stage_name: "()",
        inputs: &[],
        outputs: &[],
    };

    /// Always returns `Ok(())`.
    fn evaluate(
        &self,
        _state: &mut Self::State,
        _inputs: &mut Self::Input,
        _cache: &mut HashMap<u64, Vec<crate::Cached<Self>>>,
    ) -> Result<Self::Output, InjectionError> {
        Ok(())
    }

    /// Always returns `Ok(())`.
    #[cfg(feature = "tokio")]
    async fn evaluate_async(
        &self,
        _state: &mut Self::State,
        _inputs: &mut Self::Input,
        _cache: &mut HashMap<u64, Vec<crate::Cached<Self>>>,
    ) -> Result<Self::Output, InjectionError> {
        Ok(())
    }

    /// Nothing to wire; always returns `Ok(())`.
    fn inject_input(
        &self,
        _node: &mut Node<Self>,
        _parent: &mut Box<dyn AnyNode>,
        _output: Option<&'static str>,
        _input: Option<&'static str>,
    ) -> Result<(), InjectionError> {
        Ok(())
    }
}
/// A stage with no inputs that emits a stored value of type `T`.
///
/// The stage itself holds only a `PhantomData<T>` marker; the value is kept in the stage's
/// state.
///
/// `Clone`, `Copy`, `Debug` and `Default` are implemented by hand rather than derived.
/// `#[derive(Copy)]` would add a `T: Copy` bound, and `#[derive(Debug)]` a `T: Debug` bound,
/// even though no `T` is stored. Implementing them manually makes `ValueStage<T>` copyable
/// and debuggable for every `T` (for example `ValueStage<String>`).
pub struct ValueStage<T: Send + Sync + Clone + 'static>(std::marker::PhantomData<T>);

impl<T: Send + Sync + Clone + 'static> ValueStage<T> {
    /// Creates a value stage for `T`.
    pub fn new() -> Self {
        Self(std::marker::PhantomData)
    }
}

impl<T: Send + Sync + Clone + 'static> Default for ValueStage<T> {
    /// Same as [`ValueStage::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Send + Sync + Clone + 'static> Clone for ValueStage<T> {
    fn clone(&self) -> Self {
        // The type is `Copy` (below), so the canonical clone is a bitwise copy.
        *self
    }
}

// No `T: Copy` bound: only a zero-sized `PhantomData<T>` is stored.
impl<T: Send + Sync + Clone + 'static> Copy for ValueStage<T> {}

impl<T: Send + Sync + Clone + 'static> std::fmt::Debug for ValueStage<T> {
    /// Prints `ValueStage<type_name>`; `T` need not implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "ValueStage<{}>", std::any::type_name::<T>())
    }
}
/// Optional holder for a single value of type `T`.
///
/// `Default` is written by hand so that `T` itself does not need to implement `Default`;
/// a default wrapper is empty (`None`).
#[derive(Clone)]
pub struct ValueWrapper<T: Send + Sync + Clone + 'static>(pub Option<T>);

impl<T> Default for ValueWrapper<T>
where
    T: Send + Sync + Clone + 'static,
{
    fn default() -> Self {
        ValueWrapper(Option::None)
    }
}
/// `DynFields` view of a `ValueWrapper`.
///
/// The wrapper has exactly one slot, so every field name (including `None`) resolves to the
/// wrapped value.
impl<T> DynFields for ValueWrapper<T>
where
    T: Send + Sync + Clone + 'static,
{
    /// Borrows the wrapped value as `dyn Any`, or returns `None` if the wrapper is empty.
    /// The field name is ignored.
    fn field<'a>(&'a self, _: Option<&'static str>) -> Option<&'a (dyn std::any::Any + 'static)> {
        match &self.0 {
            Some(value) => Some(value as &dyn std::any::Any),
            None => None,
        }
    }

    /// Mutably borrows the wrapped value as `dyn Any`, or returns `None` if the wrapper is
    /// empty. The field name is ignored.
    fn field_mut<'a>(
        &'a mut self,
        _: Option<&'static str>,
    ) -> Option<&'a mut (dyn std::any::Any + 'static)> {
        match &mut self.0 {
            Some(value) => Some(value as &mut dyn std::any::Any),
            None => None,
        }
    }

    /// Moves the wrapped `T` out, boxed as `dyn Any`, leaving the wrapper empty.
    /// Returns `None` if it was already empty. The field name is ignored.
    fn take_field(&mut self, _: Option<&'static str>) -> Option<Box<dyn std::any::Any>> {
        let taken = self.0.take()?;
        Some(Box::new(taken) as Box<dyn std::any::Any>)
    }

    /// Swaps `self` for `other` and returns the previous wrapper.
    ///
    /// `other` must be a boxed `ValueWrapper<T>` (the whole wrapper, not a bare `T`).
    ///
    /// # Panics
    /// Panics if `other` is not a `ValueWrapper<T>`.
    fn replace(&mut self, other: Box<dyn std::any::Any>) -> Box<dyn DynFields> {
        match other.downcast::<Self>() {
            Ok(incoming) => {
                let previous = std::mem::replace(self, *incoming);
                Box::new(previous)
            }
            Err(_) => panic!("Attempted to replace value with wrong type"),
        }
    }

    /// Empties the wrapper, dropping any stored value.
    fn clear(&mut self) {
        drop(self.0.take());
    }
}
#[cfg_attr(feature = "tokio", async_trait::async_trait)]
impl<T: Send + Sync + Clone + 'static> Stage for ValueStage<T> {
const SHAPE: StageShape = StageShape {
stage_name: "_",
inputs: &[],
outputs: &["_"],
};
type State = ValueWrapper<T>;
type Input = ();
type Output = ValueWrapper<T>;
fn evaluate(
&self,
state: &mut Self::State,
_: &mut Self::Input,
_: &mut HashMap<u64, Vec<crate::Cached<Self>>>,
) -> Result<Self::Output, InjectionError> {
Ok(state.clone())
}
#[cfg(feature = "tokio")]
async fn evaluate_async(
&self,
state: &mut Self::State,
_: &mut Self::Input,
_: &mut HashMap<u64, Vec<crate::Cached<Self>>>,
) -> Result<Self::Output, InjectionError> {
Ok(state.clone())
}
fn inject_input(
&self,
_: &mut Node<Self>,
_: &mut Box<dyn AnyNode>,
_: Option<&'static str>,
_: Option<&'static str>,
) -> Result<(), InjectionError> {
Ok(())
}
}