rs_ecs/
rayon.rs

1use std::iter::FusedIterator;
2use std::mem::replace;
3
4use rayon::iter::{
5    plumbing::{bridge, Consumer, Folder, Producer, ProducerCallback, UnindexedConsumer},
6    IndexedParallelIterator, ParallelIterator,
7};
8
9use crate::{
10    archetype::Archetype,
11    query::{Fetch, QuerySpec},
12};
13
14/// Used to iterate through the entities which match a certain [Query][crate::query::Query] in parallel.
15pub struct QueryParIter<'q, S>
16where
17    S: QuerySpec,
18{
19    comps: &'q [(u16, <S::Fetch as Fetch<'q>>::Ty)],
20    archetypes: &'q [Archetype],
21    idx: u32,
22    len: u32,
23}
24
25unsafe impl<'q, S> Send for QueryParIter<'q, S>
26where
27    S: QuerySpec,
28    <S::Fetch as Fetch<'q>>::Ty: Send + Sync,
29{
30}
31
32impl<'q, S> QueryParIter<'q, S>
33where
34    S: QuerySpec,
35{
36    pub(crate) fn new(
37        comps: &'q [(u16, <S::Fetch as Fetch<'q>>::Ty)],
38        archetypes: &'q [Archetype],
39    ) -> Self {
40        let len = comps
41            .iter()
42            .map(|(idx, _ty)| archetypes[*idx as usize].len())
43            .sum();
44
45        Self {
46            comps,
47            archetypes,
48            idx: 0,
49            len,
50        }
51    }
52}
53
54impl<'q, S> ParallelIterator for QueryParIter<'q, S>
55where
56    S: QuerySpec,
57    <S::Fetch as Fetch<'q>>::Ty: Send + Sync,
58    <S::Fetch as Fetch<'q>>::Item: Send,
59{
60    type Item = <S::Fetch as Fetch<'q>>::Item;
61
62    fn drive_unindexed<C>(self, consumer: C) -> C::Result
63    where
64        C: UnindexedConsumer<Self::Item>,
65    {
66        bridge(self, consumer)
67    }
68
69    fn opt_len(&self) -> Option<usize> {
70        Some(self.len())
71    }
72}
73
74impl<'q, S> IndexedParallelIterator for QueryParIter<'q, S>
75where
76    S: QuerySpec,
77    <S::Fetch as Fetch<'q>>::Ty: Send + Sync,
78    <S::Fetch as Fetch<'q>>::Item: Send,
79{
80    fn drive<C>(self, consumer: C) -> C::Result
81    where
82        C: Consumer<Self::Item>,
83    {
84        bridge(self, consumer)
85    }
86
87    fn len(&self) -> usize {
88        (self.len - self.idx) as usize
89    }
90
91    fn with_producer<CB>(self, callback: CB) -> CB::Output
92    where
93        CB: ProducerCallback<Self::Item>,
94    {
95        callback.callback(self)
96    }
97}
98
99impl<'q, S> Producer for QueryParIter<'q, S>
100where
101    S: QuerySpec,
102    <S::Fetch as Fetch<'q>>::Ty: Send + Sync,
103    <S::Fetch as Fetch<'q>>::Item: Send,
104{
105    type Item = <S::Fetch as Fetch<'q>>::Item;
106    type IntoIter = QueryParIterIntoIter<'q, S>;
107
108    fn into_iter(self) -> Self::IntoIter {
109        let mut sum = 0;
110
111        let mut first = 0;
112        let mut last = 0;
113
114        let mut idx = 0;
115        let mut len = 0;
116        let mut ptr = S::Fetch::dangling();
117
118        let idx_back = 0;
119        let mut len_back = 0;
120        let mut ptr_back = S::Fetch::dangling();
121
122        for (pos, (archetype_idx, ty)) in self.comps.iter().enumerate() {
123            let archetype = &self.archetypes[*archetype_idx as usize];
124
125            if archetype.len() == 0 {
126                continue;
127            }
128
129            sum += archetype.len();
130
131            if self.idx >= sum {
132                continue;
133            }
134
135            if sum - self.idx <= archetype.len() {
136                first = pos + 1;
137
138                idx = archetype.len() - (sum - self.idx);
139                len = archetype.len();
140                ptr = unsafe { S::Fetch::base_pointer(archetype, *ty) };
141
142                if self.len <= sum {
143                    last = first;
144
145                    len -= sum - self.len;
146
147                    break;
148                }
149            }
150
151            if self.len <= sum {
152                last = pos;
153
154                len_back = archetype.len() - (sum - self.len);
155                ptr_back = unsafe { S::Fetch::base_pointer(archetype, *ty) };
156
157                break;
158            }
159        }
160
161        let comps = &self.comps[first..last];
162
163        QueryParIterIntoIter {
164            comps,
165            archetypes: self.archetypes,
166            idx,
167            len,
168            ptr,
169            idx_back,
170            len_back,
171            ptr_back,
172        }
173    }
174
175    fn fold_with<F>(self, mut folder: F) -> F
176    where
177        F: Folder<Self::Item>,
178    {
179        let mut sum = 0;
180
181        for (archetype_idx, ty) in self.comps {
182            if self.len <= sum {
183                break;
184            }
185
186            let archetype = &self.archetypes[*archetype_idx as usize];
187
188            if archetype.len() == 0 {
189                continue;
190            }
191
192            sum += archetype.len();
193
194            if self.idx >= sum {
195                continue;
196            }
197
198            let mut idx = 0;
199            let mut len = archetype.len();
200            let ptr = unsafe { S::Fetch::base_pointer(archetype, *ty) };
201
202            if sum - self.idx < len {
203                idx = len - (sum - self.idx);
204            }
205
206            if self.len < sum {
207                len -= sum - self.len;
208            }
209
210            while idx != len {
211                let val = unsafe { S::Fetch::deref(ptr, idx) };
212                idx += 1;
213
214                folder = folder.consume(val);
215                if folder.full() {
216                    break;
217                }
218            }
219        }
220
221        folder
222    }
223
224    fn split_at(self, mid: usize) -> (Self, Self) {
225        let mid = self.idx + mid as u32;
226
227        let lhs = Self {
228            comps: self.comps,
229            archetypes: self.archetypes,
230            idx: self.idx,
231            len: mid,
232        };
233
234        let rhs = Self {
235            comps: self.comps,
236            archetypes: self.archetypes,
237            idx: mid,
238            len: self.len,
239        };
240
241        (lhs, rhs)
242    }
243}
244
245pub struct QueryParIterIntoIter<'q, S>
246where
247    S: QuerySpec,
248{
249    comps: &'q [(u16, <S::Fetch as Fetch<'q>>::Ty)],
250    archetypes: &'q [Archetype],
251    idx: u32,
252    len: u32,
253    ptr: <S::Fetch as Fetch<'q>>::Ptr,
254    idx_back: u32,
255    len_back: u32,
256    ptr_back: <S::Fetch as Fetch<'q>>::Ptr,
257}
258
259impl<'q, S> Iterator for QueryParIterIntoIter<'q, S>
260where
261    S: QuerySpec,
262{
263    type Item = <S::Fetch as Fetch<'q>>::Item;
264
265    fn next(&mut self) -> Option<Self::Item> {
266        loop {
267            if self.idx != self.len {
268                let val = unsafe { S::Fetch::deref(self.ptr, self.idx) };
269                self.idx += 1;
270                return Some(val);
271            } else {
272                match self.comps.split_first() {
273                    Some(((idx, ty), rest)) => {
274                        self.comps = rest;
275
276                        let archetype = &self.archetypes[*idx as usize];
277                        self.idx = 0;
278                        self.len = archetype.len();
279                        self.ptr = unsafe { S::Fetch::base_pointer(archetype, *ty) };
280                    }
281                    None => {
282                        if self.idx_back == self.len_back {
283                            return None;
284                        } else {
285                            self.idx = replace(&mut self.idx_back, 0);
286                            self.len = replace(&mut self.len_back, 0);
287                            self.ptr = replace(&mut self.ptr_back, S::Fetch::dangling());
288                        }
289                    }
290                }
291            }
292        }
293    }
294
295    fn size_hint(&self) -> (usize, Option<usize>) {
296        let len = self.len();
297        (len, Some(len))
298    }
299}
300
301impl<S> DoubleEndedIterator for QueryParIterIntoIter<'_, S>
302where
303    S: QuerySpec,
304{
305    fn next_back(&mut self) -> Option<Self::Item> {
306        loop {
307            if self.idx_back != self.len_back {
308                let val = unsafe { S::Fetch::deref(self.ptr, self.len_back - 1) };
309                self.len_back -= 1;
310                return Some(val);
311            } else {
312                match self.comps.split_last() {
313                    Some(((idx, ty), rest)) => {
314                        self.comps = rest;
315
316                        let archetype = &self.archetypes[*idx as usize];
317                        self.idx_back = 0;
318                        self.len_back = archetype.len();
319                        self.ptr_back = unsafe { S::Fetch::base_pointer(archetype, *ty) };
320                    }
321                    None => {
322                        if self.idx == self.len {
323                            return None;
324                        } else {
325                            self.idx_back = replace(&mut self.idx, 0);
326                            self.len_back = replace(&mut self.len, 0);
327                            self.ptr_back = replace(&mut self.ptr, S::Fetch::dangling());
328                        }
329                    }
330                }
331            }
332        }
333    }
334}
335
336impl<S> ExactSizeIterator for QueryParIterIntoIter<'_, S>
337where
338    S: QuerySpec,
339{
340    fn len(&self) -> usize {
341        let len = self
342            .comps
343            .iter()
344            .map(|(idx, _)| self.archetypes[*idx as usize].len())
345            .sum::<u32>()
346            + self.len
347            - self.idx;
348        len as usize
349    }
350}
351
352impl<S> FusedIterator for QueryParIterIntoIter<'_, S> where S: QuerySpec {}
353
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{query::Query, world::World};

    struct Pos(#[allow(dead_code)] f32);
    struct Vel(#[allow(dead_code)] f32);

    /// Calls `spawn_two::<N>(world)` once for every listed `N`, in the order given.
    macro_rules! spawn_two_for_each {
        ($world:expr; $($n:tt),+ $(,)?) => {
            $(spawn_two::<$n>($world);)+
        };
    }

    /// Spawns one entity matched by the test query, whose `[usize; 1]` component holds `N`, and
    /// one entity without `Vel` that the query skips. The `[(); N]` component varies with `N`.
    fn spawn_two<const N: usize>(world: &mut World) {
        let matched = world.alloc();
        world.insert(matched, (Pos(0.), Vel(0.), [N; 1], [0; 2], [0; 3], [(); N]));
        world.remove::<([i32; 2],)>(matched).unwrap();

        let unmatched = world.alloc();
        world.insert(unmatched, (Pos(0.), [0; 4], [0; 5], [(); N]));
        world.remove::<([i32; 4],)>(unmatched).unwrap();
    }

    /// 65 536 rounds of `N = 1`.
    fn spawn_few(world: &mut World) {
        for _ in 0..65_536 {
            spawn_two::<1>(world);
        }
    }

    /// 8 192 rounds over `N = 1..=8`.
    fn spawn_few_in_many_archetypes(world: &mut World) {
        for _ in 0..8_192 {
            spawn_two_for_each!(world; 1, 2, 3, 4, 5, 6, 7, 8);
        }
    }

    /// 16 rounds over `N = 1..=32`.
    fn spawn_few_in_very_many_small_archetypes(world: &mut World) {
        for _ in 0..16 {
            spawn_two_for_each!(
                world;
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
                17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
            );
        }
    }

    /// Sums the `[usize; 1]` component of every entity matched by the test query, in parallel.
    fn par_sum(world: &World) -> usize {
        let mut query = Query::<(&mut Pos, &Vel, &[usize; 1])>::new();

        // Bound to a local so the query borrow is released before `query` is dropped.
        let total = query
            .borrow(world)
            .par_iter()
            .map(|(_pos, _vel, value)| value[0])
            .sum::<usize>();

        total
    }

    #[test]
    fn it_works_with_a_single_archetype() {
        let mut world = World::new();
        spawn_few(&mut world);

        assert_eq!(par_sum(&world), 65_536)
    }

    #[test]
    fn it_works_with_many_archetypes() {
        let mut world = World::new();
        spawn_few_in_many_archetypes(&mut world);

        assert_eq!(par_sum(&world), 294_912)
    }

    #[test]
    fn it_works_with_very_many_small_archetypes() {
        let mut world = World::new();
        spawn_few_in_very_many_small_archetypes(&mut world);

        assert_eq!(par_sum(&world), 8_448)
    }
}