use std::marker::PhantomData;
use crate::Handler;
use crate::handler::Param;
use crate::world::{Registry, World};
/// Compile-time description of a handler family: the tuple of [`Param`]s fetched
/// from the [`World`] on every run, and the event value passed alongside them.
///
/// Implementors are typically zero-sized marker types (see `handler_blueprint!`).
pub trait Blueprint {
    /// Parameter tuple resolved from the world before each invocation.
    type Params: Param + 'static;
    /// Value handed to the handler on each invocation, after the params.
    type Event;
}
/// A [`Blueprint`] whose handlers also own a per-instance context value.
///
/// The context is moved into each callback when it is generated and is passed to
/// the wrapped function as `&mut Context`, ahead of the params and the event.
pub trait CallbackBlueprint: Blueprint {
    /// Per-callback state owned by each generated callback.
    type Context: Send + 'static;
}
/// Glue between an ordinary function and a [`Blueprint`]'s `(Params, Event)` shape.
///
/// Two kinds of provided implementations exist:
/// - functions taking only `E`, when `P = ()`;
/// - functions taking each element of the param tuple `P`, followed by `E`.
///
/// This is an implementation detail of [`HandlerTemplate::new`]. It is not meant to
/// be used directly.
#[doc(hidden)]
#[diagnostic::on_unimplemented(
    message = "function signature doesn't match the blueprint's Event and Params types",
    note = "the function must accept the blueprint's Params then Event, in that order"
)]
pub trait TemplateDispatch<P: Param, E> {
    /// Checks the initialised param `state` against `registry`.
    ///
    /// The tuple implementations report every param's resource to
    /// `Registry::check_access`. The `()` implementation does nothing.
    fn validate(state: &P::State, registry: &Registry);

    /// Returns a type-erased shim that fetches `P` from the world and invokes `Self`.
    ///
    /// The provided shims recreate `Self` from zero bytes instead of storing it. They
    /// are therefore only valid for zero-sized `Self`, such as named function items.
    fn run_fn_ptr() -> unsafe fn(&mut P::State, &mut World, E);
}
/// Callback counterpart of [`TemplateDispatch`]: glue between a function taking
/// `&mut C` first and a [`CallbackBlueprint`]'s `(Context, Params, Event)` shape.
///
/// This is an implementation detail of [`CallbackTemplate::new`]. It is not meant to
/// be used directly.
#[doc(hidden)]
#[diagnostic::on_unimplemented(
    message = "function signature doesn't match the callback blueprint's Context, Event, and Params types",
    note = "the function must accept &mut Context first, then Params, then Event"
)]
pub trait CallbackTemplateDispatch<C, P: Param, E> {
    /// Checks the initialised param `state` against `registry`.
    ///
    /// The `()` implementation does nothing.
    fn validate(state: &P::State, registry: &Registry);

    /// Returns a type-erased shim that fetches `P` and invokes `Self` with the context.
    ///
    /// Like [`TemplateDispatch::run_fn_ptr`], the provided shims recreate `Self` from
    /// zero bytes, so `Self` must be zero-sized.
    fn run_fn_ptr() -> unsafe fn(&mut C, &mut P::State, &mut World, E);
}
/// Dispatch for blueprints with no params: the function receives only the event.
impl<E, F: FnMut(E) + Send + 'static> TemplateDispatch<(), E> for F {
    fn run_fn_ptr() -> unsafe fn(&mut (), &mut World, E) {
        // Shim stored in the template. It rebuilds `F` from zero bytes on every call
        // rather than storing the function value.
        unsafe fn run<E, F: FnMut(E) + Send>(_state: &mut (), _world: &mut World, event: E) {
            // `run_fn_ptr` is a safe method on a public trait, so this shim is
            // reachable without `HandlerTemplate::new`'s size check. Re-check here so
            // a non-ZST `F` (a capturing closure or fn pointer) fails to compile
            // instead of being conjured from zeroed memory (UB).
            const {
                assert!(
                    std::mem::size_of::<F>() == 0,
                    "F must be a ZST (named function item, not a closure or fn pointer)"
                );
            }
            // SAFETY: `F` is zero-sized (asserted above), so there are no bytes to
            // initialise and the zeroed value is a valid `F`.
            let mut f: F = unsafe { std::mem::zeroed() };
            f(event);
        }
        run::<E, F>
    }

    // Nothing is accessed, so there is nothing to validate.
    fn validate(_state: &(), _registry: &Registry) {}
}
/// Dispatch for callback blueprints with no params: the function receives the
/// context and the event only.
impl<C: Send + 'static, E, F: FnMut(&mut C, E) + Send + 'static> CallbackTemplateDispatch<C, (), E>
    for F
{
    fn run_fn_ptr() -> unsafe fn(&mut C, &mut (), &mut World, E) {
        // Shim stored in the template. It rebuilds `F` from zero bytes on every call.
        unsafe fn run<C: Send, E, F: FnMut(&mut C, E) + Send>(
            ctx: &mut C,
            _state: &mut (),
            _world: &mut World,
            event: E,
        ) {
            // Guard the zeroed construction below even when this shim is obtained
            // without going through `CallbackTemplate::new`'s size check.
            const {
                assert!(
                    std::mem::size_of::<F>() == 0,
                    "F must be a ZST (named function item, not a closure or fn pointer)"
                );
            }
            // SAFETY: `F` is zero-sized (asserted above), so the zeroed value is a
            // valid `F` with no bytes to initialise.
            let mut f: F = unsafe { std::mem::zeroed() };
            f(ctx, event);
        }
        run::<C, E, F>
    }

    // Nothing is accessed, so there is nothing to validate.
    fn validate(_state: &(), _registry: &Registry) {}
}
// Implements `TemplateDispatch` and `CallbackTemplateDispatch` for functions whose
// params form the tuple `($($P,)+)`. `all_tuples!` below expands it once per
// supported arity.
//
// In both impls, the first `FnMut` bound on `&mut F` ties the function to the
// blueprint's declared param types. The second bound, over `$P::Item<'a>`, is
// presumably what allows `call_inner` to invoke it with the values returned by
// `fetch`.
//
// `F` is never stored. Each `run` shim rebuilds it from zero bytes, which is only
// valid because `F` is zero-sized. That is checked with a `const` assert inside
// every shim, so the invariant holds even when `run_fn_ptr` is called without going
// through the template constructors.
macro_rules! impl_template_dispatch {
    ($($P:ident),+) => {
        impl<E, F: Send + 'static, $($P: Param + 'static),+> TemplateDispatch<($($P,)+), E> for F
        where
            for<'a> &'a mut F: FnMut($($P,)+ E) + FnMut($($P::Item<'a>,)+ E),
        {
            fn run_fn_ptr() -> unsafe fn(&mut ($($P::State,)+), &mut World, E) {
                #[allow(non_snake_case)]
                unsafe fn run<E, F: Send + 'static, $($P: Param + 'static),+>(
                    state: &mut ($($P::State,)+),
                    world: &mut World,
                    event: E,
                ) where
                    for<'a> &'a mut F: FnMut($($P,)+ E) + FnMut($($P::Item<'a>,)+ E),
                {
                    // Trampoline: passing `&mut f` through an `impl FnMut(..)` argument
                    // fixes the call's argument types to those of the fetched values.
                    #[allow(clippy::too_many_arguments)]
                    fn call_inner<$($P,)+ Ev>(
                        mut f: impl FnMut($($P,)+ Ev),
                        $($P: $P,)+
                        event: Ev,
                    ) {
                        f($($P,)+ event);
                    }
                    // Refuse to monomorphise for non-zero-sized `F`; see the zeroed
                    // construction below.
                    const {
                        assert!(
                            std::mem::size_of::<F>() == 0,
                            "F must be a ZST (named function item, not a closure or fn pointer)"
                        );
                    }
                    // Debug-only reset, presumably of the world's borrow tracking,
                    // before this dispatch's fetch.
                    #[cfg(debug_assertions)]
                    world.clear_borrows();
                    // SAFETY (caller contract): in this module, `state` is always a copy
                    // of a prototype built by `Param::init` and checked by `validate` in
                    // `HandlerTemplate::new`. The full requirements live on
                    // `Param::fetch`, which is not visible here.
                    let ($($P,)+) = unsafe {
                        <($($P,)+) as Param>::fetch(world, state)
                    };
                    // SAFETY: `F` is zero-sized (asserted above), so the zeroed value
                    // is a valid `F` with no bytes to initialise.
                    let mut f: F = unsafe { std::mem::zeroed() };
                    call_inner(&mut f, $($P,)+ event);
                }
                run::<E, F, $($P),+>
            }

            // Reports each param's resource id and type name, so the registry can
            // reject missing or conflicting accesses.
            #[allow(non_snake_case)]
            fn validate(state: &($($P::State,)+), registry: &Registry) {
                let ($($P,)+) = state;
                registry.check_access(&[
                    $((<$P as Param>::resource_id($P), std::any::type_name::<$P>()),)+
                ]);
            }
        }

        impl<C: Send + 'static, E, F: Send + 'static, $($P: Param + 'static),+>
            CallbackTemplateDispatch<C, ($($P,)+), E> for F
        where
            for<'a> &'a mut F:
                FnMut(&mut C, $($P,)+ E) +
                FnMut(&mut C, $($P::Item<'a>,)+ E),
        {
            fn run_fn_ptr() -> unsafe fn(&mut C, &mut ($($P::State,)+), &mut World, E) {
                #[allow(non_snake_case)]
                unsafe fn run<C: Send, E, F: Send + 'static, $($P: Param + 'static),+>(
                    ctx: &mut C,
                    state: &mut ($($P::State,)+),
                    world: &mut World,
                    event: E,
                ) where
                    for<'a> &'a mut F:
                        FnMut(&mut C, $($P,)+ E) +
                        FnMut(&mut C, $($P::Item<'a>,)+ E),
                {
                    // Same trampoline as above, with the context threaded through first.
                    #[allow(clippy::too_many_arguments)]
                    fn call_inner<Ctx, $($P,)+ Ev>(
                        mut f: impl FnMut(&mut Ctx, $($P,)+ Ev),
                        ctx: &mut Ctx,
                        $($P: $P,)+
                        event: Ev,
                    ) {
                        f(ctx, $($P,)+ event);
                    }
                    // Refuse to monomorphise for non-zero-sized `F`.
                    const {
                        assert!(
                            std::mem::size_of::<F>() == 0,
                            "F must be a ZST (named function item, not a closure or fn pointer)"
                        );
                    }
                    #[cfg(debug_assertions)]
                    world.clear_borrows();
                    // SAFETY (caller contract): `state` is a copy of a prototype built
                    // by `Param::init` and checked by `validate` in
                    // `CallbackTemplate::new`. See `Param::fetch` for the full
                    // requirements.
                    let ($($P,)+) = unsafe {
                        <($($P,)+) as Param>::fetch(world, state)
                    };
                    // SAFETY: `F` is zero-sized (asserted above).
                    let mut f: F = unsafe { std::mem::zeroed() };
                    call_inner(&mut f, ctx, $($P,)+ event);
                }
                run::<C, E, F, $($P),+>
            }

            // Same access check as the handler impl.
            #[allow(non_snake_case)]
            fn validate(state: &($($P::State,)+), registry: &Registry) {
                let ($($P,)+) = state;
                registry.check_access(&[
                    $((<$P as Param>::resource_id($P), std::any::type_name::<$P>()),)+
                ]);
            }
        }
    };
}

// Expand the dispatch impls for every tuple arity supported by `all_tuples!`.
all_tuples!(impl_template_dispatch);
/// Factory for [`TemplatedHandler`]s of blueprint `K`.
///
/// The template holds two things:
/// - the param state resolved once by [`HandlerTemplate::new`];
/// - a type-erased pointer to the dispatch shim for the wrapped function.
///
/// Every call to [`HandlerTemplate::generate`] copies that state into a fresh handler.
pub struct HandlerTemplate<K: Blueprint>
where
    <K::Params as Param>::State: Copy,
{
    /// Param state produced by `Param::init`; copied into each generated handler.
    prototype: <K::Params as Param>::State,
    /// Shim that fetches the params from the world and calls the wrapped function.
    run_fn: unsafe fn(&mut <K::Params as Param>::State, &mut World, K::Event),
    /// `std::any::type_name` of the wrapped function, surfaced by `Handler::name`.
    name: &'static str,
    /// Marks the blueprint type. `fn() -> K` keeps auto traits independent of `K`.
    _key: PhantomData<fn() -> K>,
}
impl<K: Blueprint> HandlerTemplate<K>
where
    <K::Params as Param>::State: Copy,
{
    /// Builds a template from the function item `f`.
    ///
    /// The blueprint's params are resolved and validated against `registry` exactly
    /// once. Only the type of `f` matters; the value itself is discarded. `F` must be
    /// zero-sized, which is checked at compile time.
    ///
    /// # Panics
    ///
    /// Panics if `K::Params` cannot be initialised from `registry` or fails the access
    /// check. Examples: a resource that was never registered, or conflicting access to
    /// the same resource.
    #[allow(clippy::needless_pass_by_value)]
    pub fn new<F>(f: F, registry: &Registry) -> Self
    where
        F: TemplateDispatch<K::Params, K::Event>,
    {
        // Checked at monomorphisation time. The dispatch shim recreates `F` from zero
        // bytes, which is only valid when `F` occupies no bytes at all.
        const {
            assert!(
                std::mem::size_of::<F>() == 0,
                "F must be a ZST (named function item, not a closure or fn pointer)"
            );
        }
        let _ = f;

        let prototype = K::Params::init(registry);
        F::validate(&prototype, registry);

        let run_fn = F::run_fn_ptr();
        let name = std::any::type_name::<F>();
        Self {
            name,
            run_fn,
            prototype,
            _key: PhantomData,
        }
    }

    /// Stamps out a new handler with its own copy of the pre-resolved param state.
    #[must_use = "the generated handler must be stored or dispatched"]
    pub fn generate(&self) -> TemplatedHandler<K> {
        // Every field is `Copy`, so this copies out of `*self` without moving.
        let Self {
            prototype,
            run_fn,
            name,
            ..
        } = *self;
        TemplatedHandler {
            state: prototype,
            run_fn,
            name,
            _key: PhantomData,
        }
    }
}
/// A handler stamped out by [`HandlerTemplate::generate`].
///
/// Owns its own copy of the param state and shares the template's dispatch shim
/// and name.
pub struct TemplatedHandler<K: Blueprint>
where
    <K::Params as Param>::State: Copy,
{
    /// This handler's private copy of the template's prototype state.
    state: <K::Params as Param>::State,
    /// Dispatch shim copied from the template.
    run_fn: unsafe fn(&mut <K::Params as Param>::State, &mut World, K::Event),
    /// Type name of the wrapped function, returned by `Handler::name`.
    name: &'static str,
    /// Marks the blueprint type without owning a `K`.
    _key: PhantomData<fn() -> K>,
}
impl<K: Blueprint> Handler<K::Event> for TemplatedHandler<K>
where
    <K::Params as Param>::State: Copy,
{
    /// Fetches the params from `world` and calls the wrapped function with `event`.
    fn run(&mut self, world: &mut World, event: K::Event) {
        let shim = self.run_fn;
        // SAFETY: two facts hold here:
        // - `shim` came from `TemplateDispatch::run_fn_ptr` for this blueprint's params;
        // - `self.state` is a copy of the prototype that `HandlerTemplate::new`
        //   initialised and validated.
        // NOTE(review): nothing here checks that `world` is the world whose registry
        // built the template; confirm callers uphold that.
        unsafe { shim(&mut self.state, world, event) }
    }

    /// Type name of the wrapped function.
    fn name(&self) -> &'static str {
        self.name
    }
}
/// Type-erased callback shim: receives the callback's context, its param state,
/// the world and the event, in that order.
///
/// Fully-qualified paths are used because type aliases do not enforce bounds on `K`.
type CallbackRunFn<K> = unsafe fn(
    &mut <K as CallbackBlueprint>::Context,
    &mut <<K as Blueprint>::Params as Param>::State,
    &mut World,
    <K as Blueprint>::Event,
);
/// Factory for [`TemplatedCallback`]s of callback blueprint `K`.
///
/// Like [`HandlerTemplate`], it resolves the param state once. Each generated
/// callback additionally receives its own `K::Context`.
pub struct CallbackTemplate<K: CallbackBlueprint>
where
    <K::Params as Param>::State: Copy,
{
    /// Param state produced by `Param::init`; copied into each generated callback.
    prototype: <K::Params as Param>::State,
    /// Shim that fetches the params and calls the wrapped function with the context.
    run_fn: CallbackRunFn<K>,
    /// `std::any::type_name` of the wrapped function.
    name: &'static str,
    /// Marks the blueprint type without owning a `K`.
    _key: PhantomData<fn() -> K>,
}
impl<K: CallbackBlueprint> CallbackTemplate<K>
where
    <K::Params as Param>::State: Copy,
{
    /// Builds a callback template from the function item `f`.
    ///
    /// `f` must accept `&mut K::Context`, then the params, then the event. Its params
    /// are resolved and validated against `registry` exactly once. `F` must be
    /// zero-sized, which is checked at compile time.
    ///
    /// # Panics
    ///
    /// Panics if `K::Params` cannot be initialised from `registry` or fails the
    /// access check.
    #[allow(clippy::needless_pass_by_value)]
    pub fn new<F>(f: F, registry: &Registry) -> Self
    where
        F: CallbackTemplateDispatch<K::Context, K::Params, K::Event>,
    {
        // The dispatch shim recreates `F` from zero bytes, so `F` must be zero-sized.
        const {
            assert!(
                std::mem::size_of::<F>() == 0,
                "F must be a ZST (named function item, not a closure or fn pointer)"
            );
        }
        let _ = f;

        let prototype = K::Params::init(registry);
        F::validate(&prototype, registry);

        let run_fn = F::run_fn_ptr();
        let name = std::any::type_name::<F>();
        Self {
            name,
            run_fn,
            prototype,
            _key: PhantomData,
        }
    }

    /// Stamps out a new callback that owns `ctx` and a copy of the param state.
    #[must_use = "the generated callback must be stored or dispatched"]
    pub fn generate(&self, ctx: K::Context) -> TemplatedCallback<K> {
        // Every field of the template is `Copy`, so this copies out of `*self`.
        let Self {
            prototype,
            run_fn,
            name,
            ..
        } = *self;
        TemplatedCallback {
            ctx,
            state: prototype,
            run_fn,
            name,
            _key: PhantomData,
        }
    }
}
/// A callback stamped out by [`CallbackTemplate::generate`].
///
/// Owns its context and its own copy of the param state. Shares the template's
/// dispatch shim and name.
pub struct TemplatedCallback<K: CallbackBlueprint>
where
    <K::Params as Param>::State: Copy,
{
    /// Per-callback context, passed as `&mut` to the wrapped function on each run.
    ctx: K::Context,
    /// This callback's private copy of the template's prototype state.
    state: <K::Params as Param>::State,
    /// Dispatch shim copied from the template.
    run_fn: CallbackRunFn<K>,
    /// Type name of the wrapped function, returned by `Handler::name`.
    name: &'static str,
    /// Marks the blueprint type without owning a `K`.
    _key: PhantomData<fn() -> K>,
}
impl<K: CallbackBlueprint> TemplatedCallback<K>
where
    <K::Params as Param>::State: Copy,
{
    /// Shared view of this callback's context.
    pub fn ctx(&self) -> &K::Context {
        let context = &self.ctx;
        context
    }

    /// Mutable view of this callback's context.
    ///
    /// Changes are seen by the wrapped function on subsequent runs.
    pub fn ctx_mut(&mut self) -> &mut K::Context {
        let context = &mut self.ctx;
        context
    }
}
impl<K: CallbackBlueprint> Handler<K::Event> for TemplatedCallback<K>
where
    <K::Params as Param>::State: Copy,
{
    /// Fetches the params from `world` and calls the wrapped function with this
    /// callback's context and `event`.
    fn run(&mut self, world: &mut World, event: K::Event) {
        let shim = self.run_fn;
        // SAFETY: two facts hold here:
        // - `shim` came from `CallbackTemplateDispatch::run_fn_ptr` for this blueprint;
        // - `self.state` is a copy of the prototype that `CallbackTemplate::new`
        //   initialised and validated.
        // NOTE(review): assumes `world` is the world whose registry built the template.
        unsafe { shim(&mut self.ctx, &mut self.state, world, event) }
    }

    /// Type name of the wrapped function.
    fn name(&self) -> &'static str {
        self.name
    }
}
/// Declares a unit-struct [`Blueprint`](crate::template::Blueprint) named `$name`.
///
/// Syntax:
/// `handler_blueprint!([#[attr]...] [vis] Name, Event = E, Params = (P0, ...)[,])`
///
/// Outer attributes (including doc comments and derives) and a visibility are
/// optional. Without a visibility the struct is private, as before. A trailing comma
/// is accepted.
#[macro_export]
macro_rules! handler_blueprint {
    (
        $(#[$meta:meta])*
        $vis:vis $name:ident,
        Event = $event:ty,
        Params = $params:ty $(,)?
    ) => {
        $(#[$meta])*
        $vis struct $name;
        impl $crate::template::Blueprint for $name {
            type Event = $event;
            type Params = $params;
        }
    };
}
/// Declares a unit-struct
/// [`CallbackBlueprint`](crate::template::CallbackBlueprint) named `$name`, together
/// with its [`Blueprint`](crate::template::Blueprint) impl.
///
/// Syntax:
/// `callback_blueprint!([#[attr]...] [vis] Name, Context = C, Event = E, Params = (P0, ...)[,])`
///
/// Outer attributes and a visibility are optional. Without a visibility the struct
/// is private, as before. A trailing comma is accepted.
#[macro_export]
macro_rules! callback_blueprint {
    (
        $(#[$meta:meta])*
        $vis:vis $name:ident,
        Context = $ctx:ty,
        Event = $event:ty,
        Params = $params:ty $(,)?
    ) => {
        $(#[$meta])*
        $vis struct $name;
        impl $crate::template::Blueprint for $name {
            type Event = $event;
            type Params = $params;
        }
        impl $crate::template::CallbackBlueprint for $name {
            type Context = $ctx;
        }
    };
}
// Unit tests for handler and callback templates: basic dispatch, independence of
// stamped instances, the no-param path, boxing as `dyn Handler`, name reporting,
// construction-time panics, context access, and the blueprint macros.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Res, ResMut, WorldBuilder};

    // One exclusive `u64` param; the event is a `u32` increment.
    struct OnTick;
    impl Blueprint for OnTick {
        type Event = u32;
        type Params = (ResMut<'static, u64>,);
    }

    // No params: exercises the `()` dispatch impl.
    struct EventOnly;
    impl Blueprint for EventOnly {
        type Event = u32;
        type Params = ();
    }

    // Two distinct resources: one shared (`u64`), one exclusive (`bool`).
    struct TwoParams;
    impl Blueprint for TwoParams {
        type Event = ();
        type Params = (Res<'static, u64>, ResMut<'static, bool>);
    }

    // Adds the event value to the `u64` resource.
    fn tick(mut counter: ResMut<u64>, event: u32) {
        *counter += event as u64;
    }

    // A single generated handler mutates the world through its param.
    #[test]
    fn handler_template_basic() {
        let mut builder = WorldBuilder::new();
        builder.register::<u64>(0);
        let mut world = builder.build();
        let template = HandlerTemplate::<OnTick>::new(tick, world.registry());
        let mut h = template.generate();
        h.run(&mut world, 10);
        assert_eq!(*world.resource::<u64>(), 10);
    }

    // Two handlers from one template both work; their effects accumulate (10 + 5).
    #[test]
    fn handler_template_stamps_independent() {
        let mut builder = WorldBuilder::new();
        builder.register::<u64>(0);
        let mut world = builder.build();
        let template = HandlerTemplate::<OnTick>::new(tick, world.registry());
        let mut h1 = template.generate();
        let mut h2 = template.generate();
        h1.run(&mut world, 10);
        h2.run(&mut world, 5);
        assert_eq!(*world.resource::<u64>(), 15);
    }

    // The check lives inside the handler: the test fails if the event is not forwarded.
    fn event_only_fn(event: u32) {
        assert!(event > 0);
    }

    #[test]
    fn handler_template_event_only() {
        let mut world = WorldBuilder::new().build();
        let template = HandlerTemplate::<EventOnly>::new(event_only_fn, world.registry());
        let mut h = template.generate();
        h.run(&mut world, 42);
    }

    // Sets the flag only when the shared counter is positive.
    fn two_params_fn(counter: Res<u64>, mut flag: ResMut<bool>, _event: ()) {
        if *counter > 0 {
            *flag = true;
        }
    }

    #[test]
    fn handler_template_two_params() {
        let mut builder = WorldBuilder::new();
        builder.register::<u64>(1);
        builder.register::<bool>(false);
        let mut world = builder.build();
        let template = HandlerTemplate::<TwoParams>::new(two_params_fn, world.registry());
        let mut h = template.generate();
        h.run(&mut world, ());
        assert!(*world.resource::<bool>());
    }

    // A generated handler can be used through `Box<dyn Handler<_>>`.
    #[test]
    fn templated_handler_boxable() {
        let mut builder = WorldBuilder::new();
        builder.register::<u64>(0);
        let mut world = builder.build();
        let template = HandlerTemplate::<OnTick>::new(tick, world.registry());
        let h = template.generate();
        let mut boxed: Box<dyn Handler<u32>> = Box::new(h);
        boxed.run(&mut world, 7);
        assert_eq!(*world.resource::<u64>(), 7);
    }

    // `name()` reports the wrapped function's type name.
    #[test]
    fn templated_handler_name() {
        let mut builder = WorldBuilder::new();
        builder.register::<u64>(0);
        let world = builder.build();
        let template = HandlerTemplate::<OnTick>::new(tick, world.registry());
        let h = template.generate();
        assert!(h.name().contains("tick"));
    }

    // Constructing a template whose param resource was never registered panics.
    #[test]
    #[should_panic(expected = "not registered")]
    fn handler_template_missing_resource() {
        let world = WorldBuilder::new().build();
        let _template = HandlerTemplate::<OnTick>::new(tick, world.registry());
    }

    // Shared and exclusive access to the same resource is rejected at construction.
    #[test]
    #[should_panic(expected = "conflicting access")]
    fn handler_template_duplicate_access() {
        struct BadBlueprint;
        impl Blueprint for BadBlueprint {
            type Event = ();
            type Params = (Res<'static, u64>, ResMut<'static, u64>);
        }
        fn bad(a: Res<u64>, b: ResMut<u64>, _e: ()) {
            let _ = (*a, &*b);
        }
        let mut builder = WorldBuilder::new();
        builder.register::<u64>(0);
        let world = builder.build();
        let _template = HandlerTemplate::<BadBlueprint>::new(bad, world.registry());
    }

    // Per-callback context used by the callback tests.
    struct TimerCtx {
        order_id: u64,
        fires: u64,
    }

    struct OnTimeout;
    impl Blueprint for OnTimeout {
        type Event = ();
        type Params = (ResMut<'static, u64>,);
    }
    impl CallbackBlueprint for OnTimeout {
        type Context = TimerCtx;
    }

    // Counts invocations in the context and adds `order_id` to the world counter.
    fn on_timeout(ctx: &mut TimerCtx, mut counter: ResMut<u64>, _event: ()) {
        ctx.fires += 1;
        *counter += ctx.order_id;
    }

    #[test]
    fn callback_template_basic() {
        let mut builder = WorldBuilder::new();
        builder.register::<u64>(0);
        let mut world = builder.build();
        let template = CallbackTemplate::<OnTimeout>::new(on_timeout, world.registry());
        let mut cb = template.generate(TimerCtx {
            order_id: 42,
            fires: 0,
        });
        cb.run(&mut world, ());
        assert_eq!(cb.ctx().fires, 1);
        assert_eq!(*world.resource::<u64>(), 42);
    }

    // Each callback owns its context; the shared world counter sees both (10 + 20).
    #[test]
    fn callback_template_independent_contexts() {
        let mut builder = WorldBuilder::new();
        builder.register::<u64>(0);
        let mut world = builder.build();
        let template = CallbackTemplate::<OnTimeout>::new(on_timeout, world.registry());
        let mut cb1 = template.generate(TimerCtx {
            order_id: 10,
            fires: 0,
        });
        let mut cb2 = template.generate(TimerCtx {
            order_id: 20,
            fires: 0,
        });
        cb1.run(&mut world, ());
        cb2.run(&mut world, ());
        assert_eq!(cb1.ctx().fires, 1);
        assert_eq!(cb2.ctx().fires, 1);
        assert_eq!(*world.resource::<u64>(), 30);
    }

    // Callback blueprint with no params: exercises the `()` callback dispatch impl.
    struct CtxOnlyKey;
    impl Blueprint for CtxOnlyKey {
        type Event = u32;
        type Params = ();
    }
    impl CallbackBlueprint for CtxOnlyKey {
        type Context = u64;
    }

    fn ctx_only(ctx: &mut u64, event: u32) {
        *ctx += event as u64;
    }

    #[test]
    fn callback_template_event_only() {
        let mut world = WorldBuilder::new().build();
        let template = CallbackTemplate::<CtxOnlyKey>::new(ctx_only, world.registry());
        let mut cb = template.generate(0u64);
        cb.run(&mut world, 5);
        assert_eq!(*cb.ctx(), 5);
    }

    #[test]
    fn callback_template_boxable() {
        let mut builder = WorldBuilder::new();
        builder.register::<u64>(0);
        let mut world = builder.build();
        let template = CallbackTemplate::<OnTimeout>::new(on_timeout, world.registry());
        let cb = template.generate(TimerCtx {
            order_id: 7,
            fires: 0,
        });
        let mut boxed: Box<dyn Handler<()>> = Box::new(cb);
        boxed.run(&mut world, ());
        assert_eq!(*world.resource::<u64>(), 7);
    }

    // Edits made through `ctx_mut` are seen by the next run (42, then 99).
    #[test]
    fn callback_template_ctx_accessible() {
        let mut builder = WorldBuilder::new();
        builder.register::<u64>(0);
        let mut world = builder.build();
        let template = CallbackTemplate::<OnTimeout>::new(on_timeout, world.registry());
        let mut cb = template.generate(TimerCtx {
            order_id: 42,
            fires: 0,
        });
        assert_eq!(cb.ctx().order_id, 42);
        cb.run(&mut world, ());
        assert_eq!(cb.ctx().fires, 1);
        cb.ctx_mut().order_id = 99;
        cb.run(&mut world, ());
        assert_eq!(*world.resource::<u64>(), 42 + 99);
    }

    // Same shape as `OnTick`, declared through the macro.
    handler_blueprint!(MacroOnTick, Event = u32, Params = (ResMut<'static, u64>,));

    #[test]
    fn macro_handler_blueprint() {
        let mut builder = WorldBuilder::new();
        builder.register::<u64>(0);
        let mut world = builder.build();
        let template = HandlerTemplate::<MacroOnTick>::new(tick, world.registry());
        let mut h = template.generate();
        h.run(&mut world, 3);
        assert_eq!(*world.resource::<u64>(), 3);
    }

    // Extra resource types for the five-param arity test.
    struct Offset(i64);
    impl crate::world::Resource for Offset {}
    struct Scale(u32);
    impl crate::world::Resource for Scale {}
    struct Tag(u32);
    impl crate::world::Resource for Tag {}

    struct FiveParamBlueprint;
    impl Blueprint for FiveParamBlueprint {
        type Event = u32;
        type Params = (
            ResMut<'static, u64>,
            Res<'static, bool>,
            ResMut<'static, Offset>,
            Res<'static, Scale>,
            ResMut<'static, Tag>,
        );
    }

    // Mixes shared and exclusive params. `offset` accumulates `scale * event`;
    // `tag` is overwritten with the latest event.
    fn five_param_fn(
        mut counter: ResMut<u64>,
        flag: Res<bool>,
        mut offset: ResMut<Offset>,
        scale: Res<Scale>,
        mut tag: ResMut<Tag>,
        event: u32,
    ) {
        if *flag {
            *counter += event as u64;
        }
        offset.0 += (scale.0 as i64) * (event as i64);
        tag.0 = event;
    }

    // Expected values are cumulative across `h1` then `h2` (scale = 2).
    #[test]
    fn handler_template_five_params() {
        let mut builder = WorldBuilder::new();
        builder.register::<u64>(0);
        builder.register::<bool>(true);
        builder.register(Offset(0));
        builder.register(Scale(2));
        builder.register(Tag(0));
        let mut world = builder.build();
        let template = HandlerTemplate::<FiveParamBlueprint>::new(five_param_fn, world.registry());
        let mut h1 = template.generate();
        let mut h2 = template.generate();
        h1.run(&mut world, 10);
        assert_eq!(*world.resource::<u64>(), 10);
        assert_eq!(world.resource::<Offset>().0, 20);
        assert_eq!(world.resource::<Tag>().0, 10);
        h2.run(&mut world, 5);
        assert_eq!(*world.resource::<u64>(), 15);
        assert_eq!(world.resource::<Offset>().0, 30);
        assert_eq!(world.resource::<Tag>().0, 5);
    }

    // Same shape as `OnTimeout`, declared through the macro.
    callback_blueprint!(MacroOnTimeout, Context = TimerCtx, Event = (), Params = (ResMut<'static, u64>,));

    #[test]
    fn macro_callback_blueprint() {
        let mut builder = WorldBuilder::new();
        builder.register::<u64>(0);
        let mut world = builder.build();
        let template = CallbackTemplate::<MacroOnTimeout>::new(on_timeout, world.registry());
        let mut cb = template.generate(TimerCtx {
            order_id: 5,
            fires: 0,
        });
        cb.run(&mut world, ());
        assert_eq!(cb.ctx().fires, 1);
        assert_eq!(*world.resource::<u64>(), 5);
    }
}