1use std::marker::PhantomData;
4
5use bevy_ecs::archetype::Archetype;
6use bevy_ecs::change_detection::Tick;
7use bevy_ecs::component::{ComponentId, Components, Immutable, StorageType};
8use bevy_ecs::lifecycle::{ComponentHook, HookContext};
9use bevy_ecs::prelude::*;
10use bevy_ecs::query::{FilteredAccess, QueryData, ReadOnlyQueryData, WorldQuery};
11use bevy_ecs::storage::{Table, TableRow};
12use bevy_ecs::world::{unsafe_world_cell::UnsafeWorldCell, DeferredWorld};
13use bevy_platform::collections::HashMap;
14
15use crate::Static;
16
17pub struct Expect<T>(PhantomData<T>);
96
97impl<T: Component> Expect<T> {
98 fn on_add(mut world: DeferredWorld, ctx: HookContext) {
99 world.commands().queue(move |world: &mut World| {
100 let expect = world.entity_mut(ctx.entity).take::<Self>().unwrap();
101 let entity = world.entity(ctx.entity);
102 if world.contains_resource::<ExpectDeferred>() || entity.contains::<ExpectDeferred>() {
103 let mut buffer = world.get_resource_or_init::<ExpectDeferredBuffer>();
104 buffer.add(ctx.entity, Box::new(expect));
105 } else {
106 expect.validate(entity);
107 }
108 });
109 }
110
111 fn validate(self, entity: EntityRef) {
112 if !entity.contains::<T>() {
113 panic!(
114 "expected component of type `{}` does not exist on entity {:?}",
115 std::any::type_name::<T>(),
116 entity.id()
117 );
118 }
119 }
120}
121
122impl<T: Component> Component for Expect<T> {
123 const STORAGE_TYPE: StorageType = StorageType::SparseSet;
124
125 type Mutability = Immutable;
126
127 fn on_add() -> Option<ComponentHook> {
128 Some(Self::on_add)
129 }
130}
131
132impl<T: Component> Default for Expect<T> {
133 fn default() -> Self {
134 Self(Default::default())
135 }
136}
137
138trait ExpectValidate: Static {
139 fn validate(self: Box<Self>, entity: EntityRef);
140}
141
142impl<T: Component> ExpectValidate for Expect<T> {
143 fn validate(self: Box<Self>, entity: EntityRef) {
144 (*self).validate(entity);
145 }
146}
147
148#[derive(Resource, Component, Default)]
164#[component(on_remove = Self::on_remove)]
165pub struct ExpectDeferred;
166
167impl ExpectDeferred {
168 fn on_remove(mut world: DeferredWorld, ctx: HookContext) {
169 world.commands().queue(move |world: &mut World| {
170 let Some(mut buffer) = world.get_resource_mut::<ExpectDeferredBuffer>() else {
171 return;
172 };
173
174 let Some(expects) = buffer.0.remove(&ctx.entity) else {
175 return;
176 };
177
178 let entity = world.entity(ctx.entity);
179 for expect in expects {
180 expect.validate(entity);
181 }
182 });
183 }
184}
185
186#[derive(Resource, Default)]
187struct ExpectDeferredBuffer(HashMap<Entity, Vec<Box<dyn ExpectValidate>>>);
188
189impl ExpectDeferredBuffer {
190 fn add(&mut self, entity: Entity, expect: Box<dyn ExpectValidate>) {
191 self.0.entry(entity).or_default().push(expect);
192 }
193}
194
195pub fn expect_deferred(world: &mut World) {
205 let Some(ExpectDeferredBuffer(buffer)) = world.remove_resource::<ExpectDeferredBuffer>() else {
206 return;
207 };
208
209 for (entity, expects) in buffer {
210 let Ok(entity) = world.get_entity(entity) else {
211 continue;
212 };
213
214 if entity.contains::<ExpectDeferred>() {
215 continue;
216 }
217
218 for expect in expects {
219 expect.validate(entity);
220 }
221 }
222
223 let _ = world.remove_resource::<ExpectDeferred>();
224}
225
226#[doc(hidden)]
227pub struct ExpectFetch<'w, T: WorldQuery> {
228 fetch: T::Fetch<'w>,
229 matches: bool,
230}
231
232impl<T: WorldQuery> Clone for ExpectFetch<'_, T> {
233 fn clone(&self) -> Self {
234 Self {
235 fetch: self.fetch.clone(),
236 matches: self.matches,
237 }
238 }
239}
240
241unsafe impl<T: QueryData> QueryData for Expect<T> {
242 type ReadOnly = Expect<T::ReadOnly>;
243
244 const IS_READ_ONLY: bool = true;
245
246 const IS_ARCHETYPAL: bool = T::IS_ARCHETYPAL;
247
248 type Item<'w, 's> = T::Item<'w, 's>;
249
250 fn shrink<'wlong: 'wshort, 'wshort, 's>(
251 item: Self::Item<'wlong, 's>,
252 ) -> Self::Item<'wshort, 's> {
253 T::shrink(item)
254 }
255
256 unsafe fn fetch<'w, 's>(
257 state: &'s Self::State,
258 fetch: &mut Self::Fetch<'w>,
259 entity: Entity,
260 table_row: TableRow,
261 ) -> Option<Self::Item<'w, 's>> {
262 if !fetch.matches {
263 panic!(
264 "expected query of type `{}` does not match entity {:?}",
265 std::any::type_name::<T>(),
266 entity
267 );
268 }
269 fetch
270 .matches
271 .then(|| T::fetch(state, &mut fetch.fetch, entity, table_row))
272 .flatten()
273 }
274
275 fn iter_access(
276 state: &Self::State,
277 ) -> impl Iterator<Item = bevy_ecs::query::EcsAccessType<'_>> {
278 T::iter_access(state)
279 }
280}
281
282unsafe impl<T: ReadOnlyQueryData> ReadOnlyQueryData for Expect<T> {}
283
284unsafe impl<T: QueryData> WorldQuery for Expect<T> {
285 type Fetch<'w> = ExpectFetch<'w, T>;
286 type State = T::State;
287
288 fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
289 ExpectFetch {
290 fetch: T::shrink_fetch(fetch.fetch),
291 matches: fetch.matches,
292 }
293 }
294
295 const IS_DENSE: bool = T::IS_DENSE;
296
297 #[inline]
298 unsafe fn init_fetch<'w>(
299 world: UnsafeWorldCell<'w>,
300 state: &T::State,
301 last_run: Tick,
302 this_run: Tick,
303 ) -> ExpectFetch<'w, T> {
304 ExpectFetch {
305 fetch: T::init_fetch(world, state, last_run, this_run),
306 matches: false,
307 }
308 }
309
310 #[inline]
311 unsafe fn set_archetype<'w>(
312 fetch: &mut ExpectFetch<'w, T>,
313 state: &T::State,
314 archetype: &'w Archetype,
315 table: &'w Table,
316 ) {
317 fetch.matches = T::matches_component_set(state, &|id| archetype.contains(id));
318 if fetch.matches {
319 T::set_archetype(&mut fetch.fetch, state, archetype, table);
320 }
321 }
322
323 #[inline]
324 unsafe fn set_table<'w>(fetch: &mut ExpectFetch<'w, T>, state: &T::State, table: &'w Table) {
325 fetch.matches = T::matches_component_set(state, &|id| table.has_column(id));
326 if fetch.matches {
327 T::set_table(&mut fetch.fetch, state, table);
328 }
329 }
330
331 fn update_component_access(state: &T::State, access: &mut FilteredAccess) {
332 let mut intermediate = access.clone();
333 T::update_component_access(state, &mut intermediate);
334 access.extend_access(&intermediate);
335 }
336
337 fn get_state(components: &Components) -> Option<Self::State> {
338 T::get_state(components)
339 }
340
341 fn init_state(world: &mut World) -> T::State {
342 T::init_state(world)
343 }
344
345 fn matches_component_set(
346 _state: &T::State,
347 _set_contains_id: &impl Fn(ComponentId) -> bool,
348 ) -> bool {
349 true
350 }
351}
352
#[cfg(test)]
mod tests {
    use bevy_ecs::system::RunSystemOnce;

    use super::*;

    #[derive(Default, Component)]
    struct A;

    #[derive(Default, Component)]
    struct B;

    /// Iterating a query containing `Expect<&B>` over an entity without `B` must panic.
    #[test]
    #[should_panic]
    fn expect_query_panic() {
        let mut world = World::default();
        world.spawn(A);
        let outcome = world.run_system_once(|query: Query<(&A, Expect<&B>)>| {
            for _ in query.iter() {}
        });
        outcome.unwrap();
    }

    /// Spawning a component which requires `Expect<B>` without `B` must panic.
    #[test]
    #[should_panic]
    fn expect_require_panic() {
        #[derive(Component)]
        #[require(Expect<B>)]
        struct C;

        let mut world = World::default();
        world.spawn(C);
    }

    /// With `ExpectDeferred` on the entity, `B` may be inserted after `C` without a panic.
    #[test]
    fn expect_deferred() {
        #[derive(Component)]
        #[require(Expect<B>)]
        struct C;

        let mut world = World::default();
        let id = world.spawn((ExpectDeferred, C)).id();

        let mut entity = world.entity_mut(id);
        entity.insert(B);
        entity.remove::<ExpectDeferred>();
    }

    /// Removing `ExpectDeferred` while `B` is still missing must panic.
    #[test]
    #[should_panic]
    fn expect_deferred_panic() {
        #[derive(Component)]
        #[require(Expect<B>)]
        struct C;

        let mut world = World::default();
        let id = world.spawn((ExpectDeferred, C)).id();

        let mut entity = world.entity_mut(id);
        entity.remove::<ExpectDeferred>();
    }
}