flax/fetch/
source.rs

1use core::{fmt::Debug, marker::PhantomData};
2
3use alloc::vec::Vec;
4
5use crate::{
6    archetype::{Archetype, ArchetypeId, Slice, Slot},
7    system::Access,
8    Entity, Fetch, FetchItem,
9};
10
11use super::{FetchAccessData, FetchPrepareData, PreparedFetch, RandomFetch};
12
13pub trait FetchSource {
14    fn resolve<'a, 'w, Q: Fetch<'w>>(
15        &self,
16        fetch: &Q,
17        data: FetchAccessData<'a>,
18    ) -> Option<(ArchetypeId, &'a Archetype, Option<Slot>)>;
19
20    fn describe(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result;
21}
22
23/// Selects the fetch value from the first target of the specified relation
24pub struct FromRelation {
25    pub(crate) relation: Entity,
26    pub(crate) name: &'static str,
27}
28
29impl FetchSource for FromRelation {
30    fn resolve<'a, 'w, Q: Fetch<'w>>(
31        &self,
32        fetch: &Q,
33        data: FetchAccessData<'a>,
34    ) -> Option<(ArchetypeId, &'a Archetype, Option<Slot>)> {
35        for (key, _) in data.arch.relations_like(self.relation) {
36            let target = key.target.unwrap();
37
38            let loc = data
39                .world
40                .location(target)
41                .expect("Relation contains invalid entity");
42
43            let arch = data.world.archetypes.get(loc.arch_id);
44
45            if fetch.filter_arch(FetchAccessData {
46                arch,
47                arch_id: loc.arch_id,
48                ..data
49            }) {
50                return Some((loc.arch_id, arch, Some(loc.slot)));
51            }
52        }
53
54        None
55    }
56
57    fn describe(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
58        write!(f, "{}", self.name)
59    }
60}
61
62impl FetchSource for Entity {
63    fn resolve<'a, 'w, Q: Fetch<'w>>(
64        &self,
65        _fetch: &Q,
66        data: FetchAccessData<'a>,
67    ) -> Option<(ArchetypeId, &'a Archetype, Option<Slot>)> {
68        let loc = data.world.location(*self).ok()?;
69
70        Some((
71            loc.arch_id,
72            data.world.archetypes.get(loc.arch_id),
73            Some(loc.slot),
74        ))
75    }
76
77    fn describe(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
78        self.fmt(f)
79    }
80}
81
82/// Traverse the edges of a relation recursively to find the first entity which matches the fetch
83pub struct Traverse {
84    pub(crate) relation: Entity,
85}
86
87fn traverse_resolve<'a, 'w, Q: Fetch<'w>>(
88    relation: Entity,
89    fetch: &Q,
90    data: FetchAccessData<'a>,
91) -> Option<(ArchetypeId, &'a Archetype, Option<Slot>)> {
92    let mut stack = Vec::new();
93    stack.push((data.arch_id, None));
94    while let Some((arch_id, slot)) = stack.pop() {
95        let data = FetchAccessData {
96            arch_id,
97            arch: data.world.archetypes.get(arch_id),
98            world: data.world,
99        };
100
101        if fetch.filter_arch(data) {
102            return (arch_id, data.arch, slot).into();
103        }
104
105        for (key, _) in data.arch.relations_like(relation) {
106            let target = key.target.unwrap();
107
108            let loc = data
109                .world
110                .location(target)
111                .expect("Relation contains invalid entity");
112
113            stack.push((loc.arch_id, Some(loc.slot)))
114        }
115    }
116
117    None
118}
119impl FetchSource for Traverse {
120    #[inline]
121    fn resolve<'a, 'w, Q: Fetch<'w>>(
122        &self,
123        fetch: &Q,
124        data: FetchAccessData<'a>,
125    ) -> Option<(ArchetypeId, &'a Archetype, Option<Slot>)> {
126        return traverse_resolve(self.relation, fetch, data);
127    }
128
129    fn describe(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
130        write!(f, "transitive({})", self.relation)
131    }
132}
133
/// A fetch which redirects the wrapped fetch `Q` to a different source, chosen by `S`, instead of
/// the entity currently being iterated.
///
/// This allows you to fetch the components of other entities (such as a relation target or a
/// fixed entity) in tandem with the items of the current entity.
///
/// Because an explicit source may yield the same item for every entity in the iteration, the
/// prepared fetch is required to implement [`RandomFetch`], so that items can be fetched through
/// shared access and safely alias. This also limits collateral damage, since mutation is mostly
/// confined to the currently iterated entity.
pub struct Source<Q, S> {
    /// The wrapped fetch, which reads from the resolved source.
    fetch: Q,
    /// Decides which archetype (and entity) the wrapped fetch reads from.
    source: S,
}

impl<Q, S> Source<Q, S> {
    /// Creates a new source fetch which runs `fetch` against the entity selected by `source`.
    pub const fn new(fetch: Q, source: S) -> Self {
        Self { fetch, source }
    }
}
153
154impl<'q, Q, S> FetchItem<'q> for Source<Q, S>
155where
156    Q: FetchItem<'q>,
157{
158    type Item = Q::Item;
159}
160
161impl<'w, Q, S> Fetch<'w> for Source<Q, S>
162where
163    Q: Fetch<'w>,
164    Q::Prepared: for<'x> RandomFetch<'x>,
165    S: FetchSource,
166{
167    const MUTABLE: bool = Q::MUTABLE;
168
169    type Prepared = PreparedSource<'w, Q::Prepared>;
170
171    fn prepare(&'w self, data: super::FetchPrepareData<'w>) -> Option<Self::Prepared> {
172        let (arch_id, arch, slot) = self.source.resolve(&self.fetch, data.into())?;
173
174        // Bounce to the resolved archetype
175        let fetch = self.fetch.prepare(FetchPrepareData {
176            arch,
177            arch_id,
178            old_tick: data.old_tick,
179            new_tick: data.new_tick,
180            world: data.world,
181        })?;
182
183        Some(PreparedSource {
184            slot,
185            fetch,
186            _marker: PhantomData,
187        })
188    }
189
190    fn filter_arch(&self, data: FetchAccessData) -> bool {
191        self.source.resolve(&self.fetch, data).is_some()
192    }
193
194    fn describe(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
195        self.fetch.describe(f)?;
196        write!(f, "(")?;
197        self.source.describe(f)?;
198        write!(f, ")")?;
199        Ok(())
200    }
201
202    fn access(&self, data: FetchAccessData, dst: &mut Vec<Access>) {
203        if let Some((arch_id, arch, _)) = self.source.resolve(&self.fetch, data) {
204            self.fetch.access(
205                FetchAccessData {
206                    arch_id,
207                    world: data.world,
208                    arch,
209                },
210                dst,
211            )
212        }
213    }
214}
215
216// impl<'w, 'q, Q> ReadOnlyFetch<'q> for PreparedSource<Q>
217// where
218//     Q: ReadOnlyFetch<'q>,
219// {
220//     unsafe fn fetch_shared(&'q self, _: crate::archetype::Slot) -> Self::Item {
221//         self.fetch.fetch_shared(self.slot)
222//     }
223// }
224
225impl<'w, 'q, Q> PreparedFetch<'q> for PreparedSource<'w, Q>
226where
227    Q: 'w + RandomFetch<'q>,
228{
229    type Item = Q::Item;
230    const HAS_FILTER: bool = Q::HAS_FILTER;
231
232    unsafe fn filter_slots(&mut self, slots: crate::archetype::Slice) -> crate::archetype::Slice {
233        if let Some(slot) = self.slot {
234            if self.fetch.filter_slots(Slice::single(slot)).is_empty() {
235                Slice::new(slots.end, slots.end)
236            } else {
237                slots
238            }
239        } else {
240            self.fetch.filter_slots(slots)
241        }
242    }
243
244    type Chunk = (Q::Chunk, bool);
245
246    unsafe fn create_chunk(&'q mut self, slice: crate::archetype::Slice) -> Self::Chunk {
247        if let Some(slot) = self.slot {
248            (self.fetch.create_chunk(Slice::single(slot)), true)
249        } else {
250            (self.fetch.create_chunk(slice), false)
251        }
252    }
253
254    unsafe fn fetch_next(chunk: &mut Self::Chunk) -> Self::Item {
255        if chunk.1 {
256            Q::fetch_shared_chunk(&chunk.0, 0)
257        } else {
258            Q::fetch_next(&mut chunk.0)
259        }
260    }
261}
262
263pub struct PreparedSource<'w, Q> {
264    slot: Option<Slot>,
265    fetch: Q,
266    _marker: PhantomData<&'w mut ()>,
267}
268
#[cfg(test)]
mod test {
    use alloc::string::ToString;

    use itertools::Itertools;

    use crate::{
        component,
        components::{child_of, name},
        entity_ids, FetchExt, Query, Topo, World,
    };

    use super::*;

    component! {
        a: u32,
        relation(id): (),
    }

    #[test]
    fn parent_fetch() {
        let mut world = World::new();

        let child_1 = Entity::builder()
            .set(name(), "child.1".into())
            .set(a(), 8)
            .spawn(&mut world);

        let root = Entity::builder()
            .set(name(), "root".into())
            .set(a(), 4)
            .spawn(&mut world);

        let child_1_1 = Entity::builder()
            .set(name(), "child.1.1".into())
            .spawn(&mut world);

        let child_2 = Entity::builder()
            .set(name(), "child.2".into())
            .spawn(&mut world);

        world.set(child_1, child_of(root), ()).unwrap();
        world.set(child_2, child_of(root), ()).unwrap();
        world.set(child_1_1, child_of(child_1), ()).unwrap();

        let mut query = Query::new((
            name().deref(),
            (name().deref(), a().copied()).relation(child_of).opt(),
        ))
        .with_strategy(Topo::new(child_of));

        pretty_assertions::assert_eq!(
            query.borrow(&world).iter().collect_vec(),
            [
                ("root", None),
                ("child.1", Some(("root", 4))),
                ("child.1.1", Some(("child.1", 8))),
                ("child.2", Some(("root", 4))),
            ]
        );
    }

    #[test]
    fn multi_parent_fetch() {
        let mut world = World::new();

        let child = Entity::builder()
            .set(name(), "child".into())
            .set(a(), 8)
            .spawn(&mut world);

        let parent = Entity::builder()
            .set(name(), "parent".into())
            .spawn(&mut world);

        let parent2 = Entity::builder()
            .set(name(), "parent2".into())
            .set(a(), 8)
            .spawn(&mut world);

        world.set(child, relation(parent), ()).unwrap();
        world.set(child, relation(parent2), ()).unwrap();

        let mut query = Query::new((
            name().deref(),
            (name().deref(), a().copied()).relation(relation).opt(),
        ))
        .with_strategy(Topo::new(relation));

        assert_eq!(
            query.borrow(&world).iter().collect_vec(),
            [
                ("parent", None),
                ("parent2", None),
                ("child", Some(("parent2", 8))),
            ]
        );
    }

    #[test]
    fn traverse() {
        let mut world = World::new();

        let root = Entity::builder()
            .set(name(), "root".into())
            .set(a(), 5)
            .spawn(&mut world);

        // Deliberately has no `a`, so neither it nor `child_4` (which only points at it) match.
        let root3 = Entity::builder()
            .set(name(), "root".into())
            .spawn(&mut world);

        let root2 = Entity::builder()
            .set(name(), "root2".into())
            .set(a(), 7)
            .spawn(&mut world);

        let child_1 = Entity::builder()
            .set(name(), "child_1".into())
            .set(relation(root), ())
            .spawn(&mut world);

        let _child_3 = Entity::builder()
            .set(name(), "child_3".into())
            .set(relation(root2), ())
            .spawn(&mut world);

        let _child_4 = Entity::builder()
            .set(name(), "child_4".into())
            .set(relation(root3), ())
            .spawn(&mut world);

        let _child_5 = Entity::builder()
            .set(name(), "child_5".into())
            .set(relation(root3), ())
            .set(relation(root2), ())
            .spawn(&mut world);

        let _child_2 = Entity::builder()
            .set(name(), "child_2".into())
            .set(relation(root), ())
            .spawn(&mut world);

        let _child_1_1 = Entity::builder()
            .set(name(), "child_1_1".into())
            .set(relation(child_1), ())
            .spawn(&mut world);

        let mut query = Query::new((
            name().deref(),
            (name().deref(), a().copied()).traverse(relation),
        ));

        assert_eq!(
            query.borrow(&world).iter().sorted().collect_vec(),
            [
                ("child_1", ("root", 5)),
                ("child_1_1", ("root", 5)),
                ("child_2", ("root", 5)),
                ("child_3", ("root2", 7)),
                ("child_5", ("root2", 7)),
                ("root", ("root", 5)),
                ("root2", ("root2", 7)),
            ]
        );
    }

    #[test]
    fn id_source() {
        let mut world = World::new();

        let _id1 = Entity::builder()
            .set(name(), "id1".to_string())
            .spawn(&mut world);
        let _id2 = Entity::builder()
            .set(name(), "id2".to_string())
            .spawn(&mut world);

        let id3 = Entity::builder()
            .set(name(), "id3".to_string())
            .set(a(), 5)
            .spawn(&mut world);

        let mut query = Query::new((
            name().cloned(),
            Source::new((entity_ids(), a(), name().cloned()), id3),
        ));

        assert_eq!(
            query.borrow(&world).iter().collect_vec(),
            &[
                ("id1".to_string(), (id3, &5, "id3".to_string())),
                ("id2".to_string(), (id3, &5, "id3".to_string())),
                ("id3".to_string(), (id3, &5, "id3".to_string()))
            ]
        );

        let mut query2 = Query::new((name().cloned(), Source::new(a().maybe_mut(), id3)));

        for (name, id3_a) in &mut query2.borrow(&world) {
            *id3_a.write() += name.len() as u32;
        }

        assert_eq!(
            query.borrow(&world).iter().collect_vec(),
            &[
                ("id1".to_string(), (id3, &14, "id3".to_string())),
                ("id2".to_string(), (id3, &14, "id3".to_string())),
                ("id3".to_string(), (id3, &14, "id3".to_string()))
            ]
        );
    }
}