use std::any::Any;
use std::future::Future;
use std::marker::PhantomData;
use std::ops::DerefMut;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::task::Wake;
#[cfg(feature = "meta")]
use dfir_lang::diagnostic::{Diagnostic, Diagnostics, SerdeSpan};
#[cfg(feature = "meta")]
use dfir_lang::graph::DfirGraph;
use super::metrics::{DfirMetrics, DfirMetricsIntervals};
use super::state::StateHandle;
use super::{StateLifespan, StateTag, SubgraphId};
use crate::scheduled::ticks::TickInstant;
use crate::util::slot_vec::SlotVec;
/// One registered piece of state, stored type-erased, together with its optional
/// lifespan reset hook.
struct StateData {
/// The state value itself, boxed as `dyn Any`.
state: Box<dyn Any>,
/// Hook to call with `state` when the lifespan ends; `None` until a hook is registered.
lifespan_hook_fn: Option<LifespanResetFn>,
/// Lifespan associated with the hook; `None` until a hook is registered.
lifespan: Option<StateLifespan>,
}
/// Boxed, type-erased callback that receives a state value as `&mut dyn Any`.
type LifespanResetFn = Box<dyn FnMut(&mut dyn Any)>;
/// Wake-up state shared (via `Arc`) between wakers and the code that runs ticks.
#[doc(hidden)]
pub struct WakeState {
/// Flag that is set to `true` whenever a wake is requested and cleared by the run loops.
can_start_tick: std::sync::atomic::AtomicBool,
/// Waker slot that is notified whenever a wake is requested.
task_waker: futures::task::AtomicWaker,
}
impl Default for WakeState {
    /// Builds a wake state with the flag cleared and no waker registered.
    fn default() -> Self {
        let can_start_tick = std::sync::atomic::AtomicBool::new(false);
        let task_waker = futures::task::AtomicWaker::new();
        Self {
            can_start_tick,
            task_waker,
        }
    }
}
impl Wake for WakeState {
    /// Consuming wake: identical to [`Wake::wake_by_ref`], the `Arc` is simply dropped after.
    fn wake(self: Arc<Self>) {
        Self::wake_by_ref(&self);
    }

    /// Raises the `can_start_tick` flag and then notifies the registered task waker.
    fn wake_by_ref(self: &Arc<Self>) {
        // The flag is stored before waking so that a task polled as a result of the
        // wake-up is able to observe it.
        self.can_start_tick.store(true, Ordering::Relaxed);
        self.task_waker.wake();
    }
}
/// Runtime context handed to the tick closure: registered state, the current tick,
/// the shared wake state, metrics, and tasks queued for spawning.
#[doc(hidden)]
#[derive(Default)]
pub struct Context {
/// All registered state values, keyed by `StateTag`.
states: SlotVec<StateTag, StateData>,
/// The tick this context is currently on.
current_tick: TickInstant,
/// Wake state shared with every waker created from this context.
wake_state: Arc<WakeState>,
/// Shared metrics handle.
metrics: Rc<DfirMetrics>,
/// Futures queued through `request_task` that have not been spawned yet.
tasks_to_spawn: Vec<Pin<Box<dyn Future<Output = ()> + 'static>>>,
}
impl Context {
    /// Creates a context around the given wake state and metrics handle.
    ///
    /// The new context has no registered state, no queued tasks, and starts at the
    /// default [`TickInstant`].
    pub fn new(wake_state: Arc<WakeState>, metrics: Rc<DfirMetrics>) -> Self {
        Self {
            states: SlotVec::new(),
            current_tick: TickInstant::default(),
            wake_state,
            metrics,
            tasks_to_spawn: Vec::new(),
        }
    }

    /// Registers `state` with this context and returns a typed handle to it.
    ///
    /// The state starts without a lifespan hook; use [`Self::set_state_lifespan_hook`]
    /// to attach one.
    pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
    where
        T: Any,
    {
        let state_id = self.states.insert(StateData {
            state: Box::new(state),
            lifespan_hook_fn: None,
            lifespan: None,
        });
        StateHandle {
            state_id,
            _phantom: PhantomData,
        }
    }

    /// Attaches `hook_fn` to the state behind `handle`, to be run when `lifespan` ends.
    ///
    /// Any previously registered hook and lifespan for that state are replaced.
    ///
    /// # Panics
    ///
    /// Panics if `handle` does not refer to a state stored in this context. The hook
    /// itself panics when invoked if the stored state is not actually a `T`.
    pub fn set_state_lifespan_hook<T>(
        &mut self,
        handle: StateHandle<T>,
        lifespan: StateLifespan,
        mut hook_fn: impl 'static + FnMut(&mut T),
    ) where
        T: Any,
    {
        let state_data = self
            .states
            .get_mut(handle.state_id)
            .expect("Failed to find state with given handle.");
        // The hook is stored type-erased, so it downcasts back to `T` on every call.
        state_data.lifespan_hook_fn = Some(Box::new(move |state| {
            (hook_fn)(state.downcast_mut::<T>().unwrap());
        }));
        state_data.lifespan = Some(lifespan);
    }

    /// Queues `future` to be spawned later; it is not polled by this call.
    pub fn request_task<Fut>(&mut self, future: Fut)
    where
        Fut: Future<Output = ()> + 'static,
    {
        self.tasks_to_spawn.push(Box::pin(future));
    }

    /// Returns a shared reference to the state behind `handle` without a checked downcast.
    ///
    /// # Safety
    ///
    /// The value stored under `handle` in this context must really be of type `T`.
    /// The type is only verified by a `debug_assert!`, so in release builds a
    /// mismatched handle results in undefined behavior.
    ///
    /// # Panics
    ///
    /// Panics if `handle` does not refer to a state stored in this context.
    pub unsafe fn state_ref_unchecked<T>(&self, handle: StateHandle<T>) -> &'_ T
    where
        T: Any,
    {
        let state = self
            .states
            .get(handle.state_id)
            .expect("Failed to find state with given handle.")
            .state
            .as_ref();
        debug_assert!(state.is::<T>());
        // SAFETY: the caller guarantees that the value stored under `handle` is a `T`,
        // so discarding the vtable and reinterpreting the data pointer as `*const T`
        // is valid. The returned reference borrows from `self`, which owns the box.
        unsafe { &*(state as *const dyn Any as *const T) }
    }

    /// Always returns `true` in this implementation.
    pub fn is_first_run_this_tick(&self) -> bool {
        true
    }

    /// Returns the tick this context is currently on.
    pub fn current_tick(&self) -> TickInstant {
        self.current_tick
    }

    /// Returns the shared metrics handle.
    pub fn metrics(&self) -> &Rc<DfirMetrics> {
        &self.metrics
    }

    /// Always returns the subgraph ID with raw value `0` in this implementation.
    pub fn current_subgraph(&self) -> SubgraphId {
        SubgraphId::from_raw(0)
    }

    /// Requests a wake-up when `is_external` is `true`; otherwise does nothing.
    ///
    /// The subgraph ID is ignored by this implementation.
    pub fn schedule_subgraph(&self, _sg_id: SubgraphId, is_external: bool) {
        if is_external {
            self.wake_state.wake_by_ref();
        }
    }

    /// Returns a [`std::task::Waker`] backed by this context's shared [`WakeState`].
    pub fn waker(&self) -> std::task::Waker {
        std::task::Waker::from(Arc::clone(&self.wake_state))
    }

    /// Ends the current tick: runs the hook of every state whose lifespan is
    /// [`StateLifespan::Tick`], then advances the tick counter by a single tick.
    #[doc(hidden)]
    pub fn __end_tick(&mut self) {
        for state_data in self.states.values_mut() {
            // States without a hook, or with any other lifespan, are left untouched.
            if let StateData {
                state,
                lifespan_hook_fn: Some(lifespan_hook_fn),
                lifespan: Some(StateLifespan::Tick),
            } = state_data
            {
                (lifespan_hook_fn)(Box::deref_mut(state));
            }
        }
        self.current_tick += crate::scheduled::ticks::TickDuration::SINGLE_TICK;
    }
}
/// A tick closure bundled with the [`Context`] it is run against and the wake state
/// used to decide when another tick should start.
#[doc(hidden)]
pub struct Dfir<Tick> {
/// The closure invoked once per tick.
tick_closure: Tick,
/// Clone of the context's wake state, taken when the `Dfir` is constructed.
wake_state: Arc<WakeState>,
/// Context passed mutably to `tick_closure` on every tick.
context: Context,
/// Graph deserialized from the JSON given to `new`, if any (`meta` feature only).
#[cfg(feature = "meta")]
meta_graph: Option<DfirGraph>,
/// Diagnostics deserialized from the JSON given to `new`, if any (`meta` feature only).
#[cfg(feature = "meta")]
diagnostics: Option<Vec<Diagnostic<SerdeSpan>>>,
}
/// Something that can run a single tick against a [`Context`].
#[doc(hidden)]
pub trait TickClosure {
/// Runs one tick with `ctx`, returning a future that resolves to a `bool`.
///
/// `Dfir::run_tick` ORs this result into its own return value.
fn call_tick<'a>(&'a mut self, ctx: &'a mut Context) -> impl Future<Output = bool> + 'a;
}
/// Blanket implementation: any async closure taking `&mut Context` and returning `bool`
/// is a [`TickClosure`].
impl<F: for<'a> AsyncFnMut(&'a mut Context) -> bool> TickClosure for F {
/// Calls the closure with `ctx` and returns the future it produces.
fn call_tick<'a>(&'a mut self, ctx: &'a mut Context) -> impl Future<Output = bool> + 'a {
self(ctx)
}
}
/// A [`TickClosure`] that does nothing; its tick always resolves to `false`.
#[doc(hidden)]
pub struct NullTickClosure;
impl TickClosure for NullTickClosure {
    /// Ignores the context and returns an already-completed future yielding `false`.
    fn call_tick<'a>(&'a mut self, _ctx: &'a mut Context) -> impl Future<Output = bool> + 'a {
        let tick_result = false;
        std::future::ready(tick_result)
    }
}
/// A [`TickClosure`] whose concrete closure type is hidden behind a boxed trait object.
#[doc(hidden)]
pub struct TickClosureErased(Box<dyn TickClosureErasedInner>);
/// Private counterpart of `TickClosure` that returns a boxed future instead of
/// `impl Future`; `TickClosureErased` stores it as `Box<dyn TickClosureErasedInner>`.
trait TickClosureErasedInner {
/// Runs one tick with `ctx`, returning the tick's future boxed and pinned.
fn call_tick<'a>(
&'a mut self,
ctx: &'a mut Context,
) -> Pin<Box<dyn Future<Output = bool> + 'a>>;
}
/// Blanket implementation: any async closure taking `&mut Context` and returning `bool`
/// can be used as a `TickClosureErasedInner`.
impl<F: for<'a> AsyncFnMut(&'a mut Context) -> bool> TickClosureErasedInner for F {
/// Calls the closure with `ctx` and boxes the future it produces.
fn call_tick<'a>(
&'a mut self,
ctx: &'a mut Context,
) -> Pin<Box<dyn Future<Output = bool> + 'a>> {
Box::pin(self(ctx))
}
}
impl TickClosure for TickClosureErased {
fn call_tick<'a>(&'a mut self, ctx: &'a mut Context) -> impl Future<Output = bool> + 'a {
self.0.call_tick(ctx)
}
}
/// A [`Dfir`] whose tick closure has been type-erased into a [`TickClosureErased`].
pub type DfirErased = Dfir<TickClosureErased>;
impl<Tick: TickClosure> Dfir<Tick> {
#[doc(hidden)]
pub fn new(
tick_closure: Tick,
context: Context,
meta_graph_json: Option<&str>,
diagnostics_json: Option<&str>,
) -> Self {
#[cfg(not(feature = "meta"))]
let _ = (meta_graph_json, diagnostics_json);
Self {
tick_closure,
wake_state: context.wake_state.clone(),
context,
#[cfg(feature = "meta")]
meta_graph: meta_graph_json.map(|json| {
let mut meta_graph: DfirGraph =
serde_json::from_str(json).expect("Failed to deserialize graph.");
let mut op_inst_diagnostics = Diagnostics::new();
meta_graph.insert_node_op_insts_all(&mut op_inst_diagnostics);
assert!(
op_inst_diagnostics.is_empty(),
"Expected no diagnostics, got: {:#?}",
op_inst_diagnostics
);
meta_graph
}),
#[cfg(feature = "meta")]
diagnostics: diagnostics_json.map(|json| {
serde_json::from_str(json).expect("Failed to deserialize diagnostics.")
}),
}
}
#[cfg(feature = "meta")]
#[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
pub fn meta_graph(&self) -> Option<&DfirGraph> {
self.meta_graph.as_ref()
}
#[cfg(feature = "meta")]
#[cfg_attr(docsrs, doc(cfg(feature = "meta")))]
pub fn diagnostics(&self) -> Option<&[Diagnostic<SerdeSpan>]> {
self.diagnostics.as_deref()
}
pub fn metrics(&self) -> Rc<DfirMetrics> {
Rc::clone(self.context.metrics())
}
pub fn current_tick(&self) -> TickInstant {
self.context.current_tick()
}
pub fn metrics_intervals(&self) -> DfirMetricsIntervals {
DfirMetricsIntervals {
curr: self.metrics(),
prev: None,
}
}
}
impl<Tick> Dfir<Tick>
where
    Tick: TickClosure,
{
    /// Spawns every task queued on the context via `tokio::task::spawn_local`,
    /// leaving the queue empty. The join handles are dropped.
    fn spawn_tasks(&mut self) {
        self.context.tasks_to_spawn.drain(..).for_each(|queued| {
            tokio::task::spawn_local(queued);
        });
    }

    /// Runs a single tick.
    ///
    /// Queued tasks are spawned first and the wake flag is cleared before the tick
    /// closure is called. Returns `true` if the flag was set beforehand, the tick
    /// closure returned `true`, or the flag is set once the tick has finished.
    pub async fn run_tick(&mut self) -> bool {
        self.spawn_tasks();
        let woken_before_tick = self
            .wake_state
            .can_start_tick
            .swap(false, Ordering::Relaxed);
        let closure_result = self.tick_closure.call_tick(&mut self.context).await;
        woken_before_tick
            || closure_result
            || self.wake_state.can_start_tick.load(Ordering::Relaxed)
    }

    /// Runs a single tick synchronously by polling [`Self::run_tick`] exactly once
    /// with a no-op waker.
    ///
    /// # Panics
    ///
    /// Panics if the tick does not complete on that first poll.
    pub fn run_tick_sync(&mut self) -> bool {
        let mut tick_fut = std::pin::pin!(self.run_tick());
        let mut noop_cx = std::task::Context::from_waker(std::task::Waker::noop());
        let std::task::Poll::Ready(outcome) = tick_fut.as_mut().poll(&mut noop_cx) else {
            panic!("Dfir::run_tick_sync: tick yielded asynchronously.")
        };
        outcome
    }

    /// Runs ticks until a tick finishes without the wake flag having been set,
    /// yielding to the Tokio scheduler between consecutive ticks.
    ///
    /// The wake flag is cleared up front, and at least one tick is always run.
    pub async fn run_available(&mut self) {
        self.wake_state
            .can_start_tick
            .store(false, Ordering::Relaxed);
        loop {
            self.run_tick().await;
            // Consume the flag: another tick runs only if a wake arrived meanwhile.
            let woken = self
                .wake_state
                .can_start_tick
                .swap(false, Ordering::Relaxed);
            if !woken {
                return;
            }
            tokio::task::yield_now().await;
        }
    }

    /// Synchronous variant of [`Self::run_available`]: runs ticks with
    /// [`Self::run_tick_sync`] until a tick finishes without the wake flag set.
    pub fn run_available_sync(&mut self) {
        self.wake_state
            .can_start_tick
            .store(false, Ordering::Relaxed);
        loop {
            self.run_tick_sync();
            let woken = self
                .wake_state
                .can_start_tick
                .swap(false, Ordering::Relaxed);
            if !woken {
                return;
            }
        }
    }

    /// Runs forever: runs all available ticks, then sleeps until the wake flag is set,
    /// and repeats.
    pub async fn run(&mut self) -> crate::Never {
        loop {
            self.run_available().await;
            let wake_state = &self.wake_state;
            std::future::poll_fn(|cx| {
                // Register before checking the flag so that a wake arriving between
                // the two steps is not missed.
                wake_state.task_waker.register(cx.waker());
                if !wake_state.can_start_tick.load(Ordering::Relaxed) {
                    return std::task::Poll::Pending;
                }
                std::task::Poll::Ready(())
            })
            .await;
        }
    }
}
impl<Tick> Dfir<Tick>
where
    Tick: 'static + for<'a> AsyncFnMut(&'a mut Context) -> bool,
{
    /// Converts this instance into a [`DfirErased`] by boxing the tick closure;
    /// every other field is moved over unchanged.
    pub fn into_erased(self) -> DfirErased {
        let Self {
            tick_closure,
            wake_state,
            context,
            #[cfg(feature = "meta")]
            meta_graph,
            #[cfg(feature = "meta")]
            diagnostics,
        } = self;
        Dfir {
            tick_closure: TickClosureErased(Box::new(tick_closure)),
            wake_state,
            context,
            #[cfg(feature = "meta")]
            meta_graph,
            #[cfg(feature = "meta")]
            diagnostics,
        }
    }
}