1use std::iter::FusedIterator;
2use std::mem::replace;
3
4use rayon::iter::{
5 plumbing::{bridge, Consumer, Folder, Producer, ProducerCallback, UnindexedConsumer},
6 IndexedParallelIterator, ParallelIterator,
7};
8
9use crate::{
10 archetype::Archetype,
11 query::{Fetch, QuerySpec},
12};
13
/// Rayon parallel iterator over the entities matched by a query.
///
/// Entities are addressed by a global position that counts through the archetypes
/// listed in `comps`, in order. This iterator covers the positions `idx..len`;
/// `split_at` divides that range into two disjoint halves.
pub struct QueryParIter<'q, S>
where
    S: QuerySpec,
{
    /// Matched archetypes: an index into `archetypes` paired with the per-archetype
    /// fetch data that is handed to `Fetch::base_pointer`.
    comps: &'q [(u16, <S::Fetch as Fetch<'q>>::Ty)],
    /// Archetype storage, indexed by the `u16` stored in `comps`.
    archetypes: &'q [Archetype],
    /// First global position covered by this iterator (inclusive).
    idx: u32,
    /// End global position (exclusive). This is an end index, not a count;
    /// the number of remaining items is `len - idx`.
    len: u32,
}
24
25unsafe impl<'q, S> Send for QueryParIter<'q, S>
26where
27 S: QuerySpec,
28 <S::Fetch as Fetch<'q>>::Ty: Send + Sync,
29{
30}
31
32impl<'q, S> QueryParIter<'q, S>
33where
34 S: QuerySpec,
35{
36 pub(crate) fn new(
37 comps: &'q [(u16, <S::Fetch as Fetch<'q>>::Ty)],
38 archetypes: &'q [Archetype],
39 ) -> Self {
40 let len = comps
41 .iter()
42 .map(|(idx, _ty)| archetypes[*idx as usize].len())
43 .sum();
44
45 Self {
46 comps,
47 archetypes,
48 idx: 0,
49 len,
50 }
51 }
52}
53
54impl<'q, S> ParallelIterator for QueryParIter<'q, S>
55where
56 S: QuerySpec,
57 <S::Fetch as Fetch<'q>>::Ty: Send + Sync,
58 <S::Fetch as Fetch<'q>>::Item: Send,
59{
60 type Item = <S::Fetch as Fetch<'q>>::Item;
61
62 fn drive_unindexed<C>(self, consumer: C) -> C::Result
63 where
64 C: UnindexedConsumer<Self::Item>,
65 {
66 bridge(self, consumer)
67 }
68
69 fn opt_len(&self) -> Option<usize> {
70 Some(self.len())
71 }
72}
73
74impl<'q, S> IndexedParallelIterator for QueryParIter<'q, S>
75where
76 S: QuerySpec,
77 <S::Fetch as Fetch<'q>>::Ty: Send + Sync,
78 <S::Fetch as Fetch<'q>>::Item: Send,
79{
80 fn drive<C>(self, consumer: C) -> C::Result
81 where
82 C: Consumer<Self::Item>,
83 {
84 bridge(self, consumer)
85 }
86
87 fn len(&self) -> usize {
88 (self.len - self.idx) as usize
89 }
90
91 fn with_producer<CB>(self, callback: CB) -> CB::Output
92 where
93 CB: ProducerCallback<Self::Item>,
94 {
95 callback.callback(self)
96 }
97}
98
99impl<'q, S> Producer for QueryParIter<'q, S>
100where
101 S: QuerySpec,
102 <S::Fetch as Fetch<'q>>::Ty: Send + Sync,
103 <S::Fetch as Fetch<'q>>::Item: Send,
104{
105 type Item = <S::Fetch as Fetch<'q>>::Item;
106 type IntoIter = QueryParIterIntoIter<'q, S>;
107
108 fn into_iter(self) -> Self::IntoIter {
109 let mut sum = 0;
110
111 let mut first = 0;
112 let mut last = 0;
113
114 let mut idx = 0;
115 let mut len = 0;
116 let mut ptr = S::Fetch::dangling();
117
118 let idx_back = 0;
119 let mut len_back = 0;
120 let mut ptr_back = S::Fetch::dangling();
121
122 for (pos, (archetype_idx, ty)) in self.comps.iter().enumerate() {
123 let archetype = &self.archetypes[*archetype_idx as usize];
124
125 if archetype.len() == 0 {
126 continue;
127 }
128
129 sum += archetype.len();
130
131 if self.idx >= sum {
132 continue;
133 }
134
135 if sum - self.idx <= archetype.len() {
136 first = pos + 1;
137
138 idx = archetype.len() - (sum - self.idx);
139 len = archetype.len();
140 ptr = unsafe { S::Fetch::base_pointer(archetype, *ty) };
141
142 if self.len <= sum {
143 last = first;
144
145 len -= sum - self.len;
146
147 break;
148 }
149 }
150
151 if self.len <= sum {
152 last = pos;
153
154 len_back = archetype.len() - (sum - self.len);
155 ptr_back = unsafe { S::Fetch::base_pointer(archetype, *ty) };
156
157 break;
158 }
159 }
160
161 let comps = &self.comps[first..last];
162
163 QueryParIterIntoIter {
164 comps,
165 archetypes: self.archetypes,
166 idx,
167 len,
168 ptr,
169 idx_back,
170 len_back,
171 ptr_back,
172 }
173 }
174
175 fn fold_with<F>(self, mut folder: F) -> F
176 where
177 F: Folder<Self::Item>,
178 {
179 let mut sum = 0;
180
181 for (archetype_idx, ty) in self.comps {
182 if self.len <= sum {
183 break;
184 }
185
186 let archetype = &self.archetypes[*archetype_idx as usize];
187
188 if archetype.len() == 0 {
189 continue;
190 }
191
192 sum += archetype.len();
193
194 if self.idx >= sum {
195 continue;
196 }
197
198 let mut idx = 0;
199 let mut len = archetype.len();
200 let ptr = unsafe { S::Fetch::base_pointer(archetype, *ty) };
201
202 if sum - self.idx < len {
203 idx = len - (sum - self.idx);
204 }
205
206 if self.len < sum {
207 len -= sum - self.len;
208 }
209
210 while idx != len {
211 let val = unsafe { S::Fetch::deref(ptr, idx) };
212 idx += 1;
213
214 folder = folder.consume(val);
215 if folder.full() {
216 break;
217 }
218 }
219 }
220
221 folder
222 }
223
224 fn split_at(self, mid: usize) -> (Self, Self) {
225 let mid = self.idx + mid as u32;
226
227 let lhs = Self {
228 comps: self.comps,
229 archetypes: self.archetypes,
230 idx: self.idx,
231 len: mid,
232 };
233
234 let rhs = Self {
235 comps: self.comps,
236 archetypes: self.archetypes,
237 idx: mid,
238 len: self.len,
239 };
240
241 (lhs, rhs)
242 }
243}
244
/// Sequential iterator produced by [`Producer::into_iter`] for one `QueryParIter` range.
///
/// In forward order, items come from the front cursor (`idx..len` read through `ptr`),
/// then from every archetype still listed in `comps` (iterated in full), then from the
/// back cursor (`idx_back..len_back` read through `ptr_back`).
pub struct QueryParIterIntoIter<'q, S>
where
    S: QuerySpec,
{
    /// Archetypes lying between the front and back cursors; consumed from both ends.
    comps: &'q [(u16, <S::Fetch as Fetch<'q>>::Ty)],
    /// Archetype storage, indexed by the `u16` stored in `comps`.
    archetypes: &'q [Archetype],
    /// Front cursor: next index to yield.
    idx: u32,
    /// Front cursor: exclusive end index within its archetype.
    len: u32,
    /// Front cursor: base pointer of its archetype (dangling until populated).
    ptr: <S::Fetch as Fetch<'q>>::Ptr,
    /// Back cursor: lowest index still to be yielded (inclusive).
    idx_back: u32,
    /// Back cursor: exclusive upper bound; `next_back` yields `len_back - 1` and decrements it.
    len_back: u32,
    /// Back cursor: base pointer of its archetype (dangling until populated).
    ptr_back: <S::Fetch as Fetch<'q>>::Ptr,
}
258
259impl<'q, S> Iterator for QueryParIterIntoIter<'q, S>
260where
261 S: QuerySpec,
262{
263 type Item = <S::Fetch as Fetch<'q>>::Item;
264
265 fn next(&mut self) -> Option<Self::Item> {
266 loop {
267 if self.idx != self.len {
268 let val = unsafe { S::Fetch::deref(self.ptr, self.idx) };
269 self.idx += 1;
270 return Some(val);
271 } else {
272 match self.comps.split_first() {
273 Some(((idx, ty), rest)) => {
274 self.comps = rest;
275
276 let archetype = &self.archetypes[*idx as usize];
277 self.idx = 0;
278 self.len = archetype.len();
279 self.ptr = unsafe { S::Fetch::base_pointer(archetype, *ty) };
280 }
281 None => {
282 if self.idx_back == self.len_back {
283 return None;
284 } else {
285 self.idx = replace(&mut self.idx_back, 0);
286 self.len = replace(&mut self.len_back, 0);
287 self.ptr = replace(&mut self.ptr_back, S::Fetch::dangling());
288 }
289 }
290 }
291 }
292 }
293 }
294
295 fn size_hint(&self) -> (usize, Option<usize>) {
296 let len = self.len();
297 (len, Some(len))
298 }
299}
300
301impl<S> DoubleEndedIterator for QueryParIterIntoIter<'_, S>
302where
303 S: QuerySpec,
304{
305 fn next_back(&mut self) -> Option<Self::Item> {
306 loop {
307 if self.idx_back != self.len_back {
308 let val = unsafe { S::Fetch::deref(self.ptr, self.len_back - 1) };
309 self.len_back -= 1;
310 return Some(val);
311 } else {
312 match self.comps.split_last() {
313 Some(((idx, ty), rest)) => {
314 self.comps = rest;
315
316 let archetype = &self.archetypes[*idx as usize];
317 self.idx_back = 0;
318 self.len_back = archetype.len();
319 self.ptr_back = unsafe { S::Fetch::base_pointer(archetype, *ty) };
320 }
321 None => {
322 if self.idx == self.len {
323 return None;
324 } else {
325 self.idx_back = replace(&mut self.idx, 0);
326 self.len_back = replace(&mut self.len, 0);
327 self.ptr_back = replace(&mut self.ptr, S::Fetch::dangling());
328 }
329 }
330 }
331 }
332 }
333 }
334}
335
336impl<S> ExactSizeIterator for QueryParIterIntoIter<'_, S>
337where
338 S: QuerySpec,
339{
340 fn len(&self) -> usize {
341 let len = self
342 .comps
343 .iter()
344 .map(|(idx, _)| self.archetypes[*idx as usize].len())
345 .sum::<u32>()
346 + self.len
347 - self.idx;
348 len as usize
349 }
350}
351
352impl<S> FusedIterator for QueryParIterIntoIter<'_, S> where S: QuerySpec {}
353
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{query::Query, world::World};

    struct Pos(#[allow(dead_code)] f32);
    struct Vel(#[allow(dead_code)] f32);

    /// Calls `spawn_two::<N>(world)` once for every listed const argument, in order.
    macro_rules! spawn_pairs {
        ($world:expr; $($n:literal),+ $(,)?) => {
            $(spawn_two::<$n>($world);)+
        };
    }

    /// Spawns two entities that both carry the marker component `[(); N]`.
    ///
    /// The first has `Pos`, `Vel` and `[N; 1]`, so it matches the query in `par_sum` and
    /// contributes `N` to the sum. The second has no `Vel` and does not match. Each entity
    /// also receives extra array components, one of which is removed again right away.
    fn spawn_two<const N: usize>(world: &mut World) {
        let matching = world.alloc();
        world.insert(
            matching,
            (Pos(0.), Vel(0.), [N; 1], [0; 2], [0; 3], [(); N]),
        );
        world.remove::<([i32; 2],)>(matching).unwrap();

        let other = world.alloc();
        world.insert(other, (Pos(0.), [0; 4], [0; 5], [(); N]));
        world.remove::<([i32; 4],)>(other).unwrap();
    }

    /// 65_536 pairs, all with marker `N = 1`.
    fn spawn_few(world: &mut World) {
        for _ in 0..131072 / 2 {
            spawn_pairs!(world; 1);
        }
    }

    /// 8_192 rounds, each spawning pairs with markers `N = 1..=8`.
    fn spawn_few_in_many_archetypes(world: &mut World) {
        for _ in 0..131072 / 2 / 8 {
            spawn_pairs!(world; 1, 2, 3, 4, 5, 6, 7, 8);
        }
    }

    /// 16 rounds, each spawning pairs with markers `N = 1..=32`.
    fn spawn_few_in_very_many_small_archetypes(world: &mut World) {
        for _ in 0..1024 / 2 / 32 {
            spawn_pairs!(
                world;
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
                17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
            );
        }
    }

    /// Runs the query in parallel and sums the `[usize; 1]` component of every match.
    fn par_sum(world: &World) -> usize {
        let mut query = Query::<(&mut Pos, &Vel, &[usize; 1])>::new();

        // Bound to a local so the borrow temporaries are dropped before `query`.
        let total = query
            .borrow(world)
            .par_iter()
            .map(|(_pos, _vel, comp)| comp[0])
            .sum::<usize>();
        total
    }

    #[test]
    fn it_works_with_a_single_archetype() {
        let mut world = World::new();
        spawn_few(&mut world);

        // 65_536 matching entities, each contributing 1.
        assert_eq!(par_sum(&world), 65_536);
    }

    #[test]
    fn it_works_with_many_archetypes() {
        let mut world = World::new();
        spawn_few_in_many_archetypes(&mut world);

        // 8_192 rounds * (1 + 2 + ... + 8 = 36).
        assert_eq!(par_sum(&world), 294_912);
    }

    #[test]
    fn it_works_with_very_many_small_archetypes() {
        let mut world = World::new();
        spawn_few_in_very_many_small_archetypes(&mut world);

        // 16 rounds * (1 + 2 + ... + 32 = 528).
        assert_eq!(par_sum(&world), 8_448);
    }
}