use std::marker::PhantomData;
use bevy_ecs::archetype::Archetype;
use bevy_ecs::change_detection::Tick;
use bevy_ecs::component::{ComponentId, Components, Immutable, StorageType};
use bevy_ecs::lifecycle::{ComponentHook, HookContext};
use bevy_ecs::prelude::*;
use bevy_ecs::query::{FilteredAccess, QueryData, ReadOnlyQueryData, WorldQuery};
use bevy_ecs::storage::{Table, TableRow};
use bevy_ecs::world::{unsafe_world_cell::UnsafeWorldCell, DeferredWorld};
use bevy_platform::collections::HashMap;
use crate::Static;
/// Asserts that the entity it ends up on also has component `T`.
///
/// Typically used as a required component, e.g. `#[require(Expect<B>)]`. When
/// `Expect<T>` is added, a command is queued that takes the marker back off the
/// entity and checks that `T` is present, panicking otherwise. If the
/// `ExpectDeferred` resource exists, or the entity has the `ExpectDeferred`
/// component, the check is buffered instead of being run immediately.
///
/// `Expect<Q>` can also be used as query data: it yields the same items as `Q`,
/// but panics if `Q` does not match an entity the query visits.
pub struct Expect<T>(PhantomData<T>);

impl<T: Component> Expect<T> {
    /// `on_add` hook. Queues a command, which has the `&mut World` access needed
    /// to `take` the marker, and then validates or buffers the expectation.
    fn on_add(mut world: DeferredWorld, ctx: HookContext) {
        world.commands().queue(move |world: &mut World| {
            // The entity may have been despawned, or the marker removed, before this
            // command was applied; in either case there is nothing left to check.
            let Some(expect) = world
                .get_entity_mut(ctx.entity)
                .ok()
                .and_then(|mut entity| entity.take::<Self>())
            else {
                return;
            };
            let Ok(entity) = world.get_entity(ctx.entity) else {
                return;
            };
            if world.contains_resource::<ExpectDeferred>() || entity.contains::<ExpectDeferred>() {
                let mut buffer = world.get_resource_or_init::<ExpectDeferredBuffer>();
                buffer.add(ctx.entity, Box::new(expect));
            } else {
                expect.validate(entity);
            }
        });
    }

    /// Checks that `entity` has component `T`.
    ///
    /// # Panics
    ///
    /// Panics if `entity` does not contain `T`.
    fn validate(self, entity: EntityRef) {
        if !entity.contains::<T>() {
            panic!(
                "expected component of type `{}` does not exist on entity {:?}",
                std::any::type_name::<T>(),
                entity.id()
            );
        }
    }
}
impl<T: Component> Component for Expect<T> {
    // The marker carries no data and is taken back off the entity by its own hook,
    // so it never needs to be mutated in place.
    type Mutability = Immutable;
    // NOTE(review): sparse-set storage is presumably chosen because the marker is
    // added and removed again almost immediately.
    const STORAGE_TYPE: StorageType = StorageType::SparseSet;

    /// Registers the inherent `Expect::on_add` as this component's `on_add` hook.
    fn on_add() -> Option<ComponentHook> {
        let hook: ComponentHook = Self::on_add;
        Some(hook)
    }
}
impl<T: Component> Default for Expect<T> {
fn default() -> Self {
Self(Default::default())
}
}
/// Type-erased form of [`Expect`], so expectations for different component types
/// can be stored together in one buffer.
trait ExpectValidate: Static {
    /// Consumes the boxed expectation and checks it against `entity`.
    fn validate(self: Box<Self>, entity: EntityRef);
}

impl<T: Component> ExpectValidate for Expect<T> {
    fn validate(self: Box<Self>, entity: EntityRef) {
        // Unbox and defer to the inherent `Expect::validate`.
        let expect: Expect<T> = *self;
        expect.validate(entity);
    }
}
/// Postpones `Expect` checks.
///
/// As a resource, it makes every newly added `Expect` check get buffered instead
/// of run; [`expect_deferred`] runs the buffered checks and removes the resource.
/// As a component, it buffers the checks of the entity it is on; the buffered
/// checks for that entity are run when the component is removed.
#[derive(Resource, Component, Default)]
#[component(on_remove = Self::on_remove)]
pub struct ExpectDeferred;

impl ExpectDeferred {
    /// `on_remove` hook: runs the checks buffered for this entity, if any.
    fn on_remove(mut world: DeferredWorld, ctx: HookContext) {
        world.commands().queue(move |world: &mut World| {
            let Some(mut buffer) = world.get_resource_mut::<ExpectDeferredBuffer>() else {
                return;
            };
            // Always drop the entity's entry so it cannot leak, even if the
            // entity no longer exists.
            let Some(expects) = buffer.0.remove(&ctx.entity) else {
                return;
            };
            // `on_remove` also fires when the entity is despawned; by the time this
            // command runs the entity is gone and there is nothing left to check.
            let Ok(entity) = world.get_entity(ctx.entity) else {
                return;
            };
            for expect in expects {
                expect.validate(entity);
            }
        });
    }
}
/// Expectations whose check has been postponed, grouped by the entity they apply to.
#[derive(Resource, Default)]
struct ExpectDeferredBuffer(HashMap<Entity, Vec<Box<dyn ExpectValidate>>>);

impl ExpectDeferredBuffer {
    /// Appends `expect` to the postponed checks for `entity`.
    fn add(&mut self, entity: Entity, expect: Box<dyn ExpectValidate>) {
        let pending = self.0.entry(entity).or_insert_with(Vec::new);
        pending.push(expect);
    }
}
pub fn expect_deferred(world: &mut World) {
let Some(ExpectDeferredBuffer(buffer)) = world.remove_resource::<ExpectDeferredBuffer>() else {
return;
};
for (entity, expects) in buffer {
let Ok(entity) = world.get_entity(entity) else {
continue;
};
if entity.contains::<ExpectDeferred>() {
continue;
}
for expect in expects {
expect.validate(entity);
}
}
let _ = world.remove_resource::<ExpectDeferred>();
}
/// Fetch for [`Expect`] used as query data: the wrapped query's fetch, plus
/// whether the wrapped query matches the current archetype/table.
#[doc(hidden)]
pub struct ExpectFetch<'w, T: WorldQuery> {
    fetch: T::Fetch<'w>,
    matches: bool,
}

// Written by hand: `#[derive(Clone)]` would add a `T: Clone` bound, while only the
// inner `T::Fetch` has to be cloned.
impl<T: WorldQuery> Clone for ExpectFetch<'_, T> {
    fn clone(&self) -> Self {
        let Self { fetch, matches } = self;
        Self {
            fetch: fetch.clone(),
            matches: *matches,
        }
    }
}
// SAFETY: Item fetching and access reporting are delegated to `T`. `fetch` only
// calls `T::fetch` after checking `matches`, which is set alongside the calls to
// `T::set_archetype`/`T::set_table` on the inner fetch. `ReadOnly` and
// `IS_READ_ONLY` mirror `T`, so mutable access is never advertised as read-only.
unsafe impl<T: QueryData> QueryData for Expect<T> {
    type ReadOnly = Expect<T::ReadOnly>;
    // Must follow `T`: e.g. `Expect<&mut C>` writes to `C` and is not read-only.
    const IS_READ_ONLY: bool = T::IS_READ_ONLY;
    const IS_ARCHETYPAL: bool = T::IS_ARCHETYPAL;
    type Item<'w, 's> = T::Item<'w, 's>;

    fn shrink<'wlong: 'wshort, 'wshort, 's>(
        item: Self::Item<'wlong, 's>,
    ) -> Self::Item<'wshort, 's> {
        T::shrink(item)
    }

    /// Fetches `T`'s item for `entity`.
    ///
    /// # Panics
    ///
    /// Panics if `T` does not match the archetype/table currently being iterated.
    unsafe fn fetch<'w, 's>(
        state: &'s Self::State,
        fetch: &mut Self::Fetch<'w>,
        entity: Entity,
        table_row: TableRow,
    ) -> Option<Self::Item<'w, 's>> {
        if !fetch.matches {
            panic!(
                "expected query of type `{}` does not match entity {:?}",
                std::any::type_name::<T>(),
                entity
            );
        }
        // `matches` is true, so the inner fetch was set up for this archetype/table;
        // the remaining `fetch` invariants are upheld by our caller.
        T::fetch(state, &mut fetch.fetch, entity, table_row)
    }

    fn iter_access(
        state: &Self::State,
    ) -> impl Iterator<Item = bevy_ecs::query::EcsAccessType<'_>> {
        T::iter_access(state)
    }
}

// SAFETY: `Expect<T>` performs exactly the accesses of `T`, and `T` is read-only.
unsafe impl<T: ReadOnlyQueryData> ReadOnlyQueryData for Expect<T> {}
// SAFETY: `Expect<T>` visits every archetype (`matches_component_set` always
// returns `true`). `T`'s accesses are merged into `access` in
// `update_component_access` so access reporting stays complete. The inner fetch
// is only configured (via `T::set_archetype`/`T::set_table`) where `T` matches, and
// `Expect::fetch` panics rather than reading it anywhere else.
unsafe impl<T: QueryData> WorldQuery for Expect<T> {
    type Fetch<'w> = ExpectFetch<'w, T>;
    type State = T::State;

    const IS_DENSE: bool = T::IS_DENSE;

    fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
        let ExpectFetch {
            fetch: inner,
            matches,
        } = fetch;
        ExpectFetch {
            fetch: T::shrink_fetch(inner),
            matches,
        }
    }

    #[inline]
    unsafe fn init_fetch<'w>(
        world: UnsafeWorldCell<'w>,
        state: &T::State,
        last_run: Tick,
        this_run: Tick,
    ) -> ExpectFetch<'w, T> {
        // Until an archetype or table is selected, `T` is treated as not matching.
        let inner = T::init_fetch(world, state, last_run, this_run);
        ExpectFetch {
            fetch: inner,
            matches: false,
        }
    }

    #[inline]
    unsafe fn set_archetype<'w>(
        fetch: &mut ExpectFetch<'w, T>,
        state: &T::State,
        archetype: &'w Archetype,
        table: &'w Table,
    ) {
        let matches = T::matches_component_set(state, &|id| archetype.contains(id));
        fetch.matches = matches;
        // Only configure the inner fetch where `T` can actually be read.
        if matches {
            T::set_archetype(&mut fetch.fetch, state, archetype, table);
        }
    }

    #[inline]
    unsafe fn set_table<'w>(fetch: &mut ExpectFetch<'w, T>, state: &T::State, table: &'w Table) {
        let matches = T::matches_component_set(state, &|id| table.has_column(id));
        fetch.matches = matches;
        if matches {
            T::set_table(&mut fetch.fetch, state, table);
        }
    }

    fn update_component_access(state: &T::State, access: &mut FilteredAccess) {
        // Let `T` record its accesses on a scratch copy, then merge them back with
        // `extend_access`, so `T` cannot narrow which entities the query matches.
        let mut scratch = access.clone();
        T::update_component_access(state, &mut scratch);
        access.extend_access(&scratch);
    }

    fn get_state(components: &Components) -> Option<Self::State> {
        T::get_state(components)
    }

    fn init_state(world: &mut World) -> T::State {
        T::init_state(world)
    }

    /// Matches every archetype; a mismatch of `T` is reported by a panic in
    /// `fetch` instead of filtering the entity out.
    fn matches_component_set(
        _state: &T::State,
        _set_contains_id: &impl Fn(ComponentId) -> bool,
    ) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bevy_ecs::system::RunSystemOnce;

    #[derive(Default, Component)]
    struct A;

    #[derive(Default, Component)]
    struct B;

    /// Iterating `Expect<&B>` over an entity without `B` panics.
    #[test]
    #[should_panic]
    fn expect_query_panic() {
        let mut world = World::default();
        world.spawn(A);
        world
            .run_system_once(|query: Query<(&A, Expect<&B>)>| {
                for _ in query.iter() {}
            })
            .unwrap();
    }

    /// Requiring `Expect<B>` without ever providing `B` panics.
    #[test]
    #[should_panic]
    fn expect_require_panic() {
        #[derive(Component)]
        #[require(Expect<B>)]
        struct C;

        let mut world = World::default();
        world.spawn(C);
    }

    /// With `ExpectDeferred` on the entity, `B` may arrive later, as long as it is
    /// present once `ExpectDeferred` is removed.
    #[test]
    fn expect_deferred() {
        #[derive(Component)]
        #[require(Expect<B>)]
        struct C;

        let mut world = World::default();
        let id = world.spawn((ExpectDeferred, C)).id();
        world
            .entity_mut(id)
            .insert(B)
            .remove::<ExpectDeferred>();
    }

    /// Removing `ExpectDeferred` while `B` is still missing panics.
    #[test]
    #[should_panic]
    fn expect_deferred_panic() {
        #[derive(Component)]
        #[require(Expect<B>)]
        struct C;

        let mut world = World::default();
        let id = world.spawn((ExpectDeferred, C)).id();
        world.entity_mut(id).remove::<ExpectDeferred>();
    }
}