broccoli_rayon/queries/
colfind.rs

1use broccoli::{
2    aabb::pin::AabbPin,
3    aabb::Aabb,
4    queries::colfind::{
5        build::{CollisionHandler, CollisionVisitor, NodeHandler},
6        oned::DefaultNodeHandler,
7    },
8    Tree,
9};
10
11pub trait CollisionHandlerExt<T: Aabb>: CollisionHandler<T> + Sized {
12    ///Called to split this into two to be passed to the children.
13    fn div(&mut self) -> Self;
14
15    ///Called to add the results of the recursive calls on the children.
16    fn add(&mut self, b: Self);
17}
18
19pub trait NodeHandlerExt<T: Aabb>: NodeHandler<T> + Sized {
20    ///Called to split this into two to be passed to the children.
21    fn div(&mut self) -> Self;
22
23    ///Called to add the results of the recursive calls on the children.
24    fn add(&mut self, b: Self);
25}
26
/// Default `num_seq_fallback` threshold used by the `RayonQueryPar` queries
/// when calling [`recurse_par`].
///
/// A subtree whose `CollisionVisitor::num_elem()` is at most this value is
/// processed sequentially instead of being split further across rayon tasks.
pub const SEQ_FALLBACK_DEFAULT: usize = 256;
29
30pub trait RayonQueryPar<'a, T: Aabb> {
31    fn par_find_colliding_pairs<F>(&mut self, func: F)
32    where
33        F: FnMut(AabbPin<&mut T>, AabbPin<&mut T>),
34        F: Send + Clone,
35        T: Send,
36        T::Num: Send;
37
38    fn par_find_colliding_pairs_acc_closure<Acc, A, B, F>(
39        &mut self,
40        acc: Acc,
41        div: A,
42        add: B,
43        func: F,
44    ) -> Acc
45    where
46        A: FnMut(&mut Acc) -> Acc + Clone + Send,
47        B: FnMut(&mut Acc, Acc) + Clone + Send,
48        F: FnMut(&mut Acc, AabbPin<&mut T>, AabbPin<&mut T>) + Clone + Send,
49        Acc: Send,
50        T: Send,
51        T::Num: Send;
52}
53
54impl<'a, T: Aabb> RayonQueryPar<'a, T> for Tree<'a, T> {
55    fn par_find_colliding_pairs_acc_closure<Acc, A, B, F>(
56        &mut self,
57        acc: Acc,
58        div: A,
59        add: B,
60        func: F,
61    ) -> Acc
62    where
63        A: FnMut(&mut Acc) -> Acc + Clone + Send,
64        B: FnMut(&mut Acc, Acc) + Clone + Send,
65        F: FnMut(&mut Acc, AabbPin<&mut T>, AabbPin<&mut T>) + Clone + Send,
66        Acc: Send,
67        T: Send,
68        T::Num: Send,
69    {
70        let mut f = DefaultNodeHandler::new(ClosureExt {
71            acc,
72            div,
73            add,
74            func,
75        });
76
77        let vv = CollisionVisitor::new(self.vistr_mut());
78        recurse_par(vv, &mut f, SEQ_FALLBACK_DEFAULT);
79        f.coll_handler.acc
80    }
81
82    fn par_find_colliding_pairs<F>(&mut self, func: F)
83    where
84        F: FnMut(AabbPin<&mut T>, AabbPin<&mut T>) + Clone,
85        F: Send,
86        T: Send,
87        T::Num: Send,
88    {
89        let mut f = DefaultNodeHandler::new(func);
90
91        let vv = CollisionVisitor::new(self.vistr_mut());
92        recurse_par(vv, &mut f, SEQ_FALLBACK_DEFAULT);
93    }
94}
95
96impl<F, T: Aabb> CollisionHandlerExt<T> for F
97where
98    F: Clone + FnMut(AabbPin<&mut T>, AabbPin<&mut T>),
99{
100    fn div(&mut self) -> Self {
101        self.clone()
102    }
103
104    fn add(&mut self, _: Self) {}
105}
106
/// Collision handler built from an accumulator plus three closures, so that the
/// accumulator can be split across parallel tasks and joined again afterwards.
pub struct ClosureExt<K, A, B, F> {
    /// Accumulator passed to every `func` call.
    pub acc: K,
    /// Produces the accumulator for a newly split-off handler from the current one.
    pub div: A,
    /// Folds a split-off handler's accumulator back into this one.
    pub add: B,
    /// Called with the accumulator and both elements of each colliding pair.
    pub func: F,
}
118impl<T: Aabb, K, A, B, F> CollisionHandler<T> for ClosureExt<K, A, B, F>
119where
120    F: FnMut(&mut K, AabbPin<&mut T>, AabbPin<&mut T>),
121{
122    fn collide(&mut self, a: AabbPin<&mut T>, b: AabbPin<&mut T>) {
123        (self.func)(&mut self.acc, a, b)
124    }
125}
126impl<T: Aabb, K, A: FnMut(&mut K) -> K + Clone, B: FnMut(&mut K, K) + Clone, F: Clone>
127    CollisionHandlerExt<T> for ClosureExt<K, A, B, F>
128where
129    F: FnMut(&mut K, AabbPin<&mut T>, AabbPin<&mut T>),
130{
131    fn div(&mut self) -> Self {
132        ClosureExt {
133            acc: (self.div)(&mut self.acc),
134            div: self.div.clone(),
135            add: self.add.clone(),
136            func: self.func.clone(),
137        }
138    }
139
140    fn add(&mut self, b: Self) {
141        (self.add)(&mut self.acc, b.acc)
142    }
143}
144
145impl<Acc: CollisionHandlerExt<T>, T: Aabb> NodeHandlerExt<T> for DefaultNodeHandler<Acc> {
146    fn div(&mut self) -> Self {
147        DefaultNodeHandler::new(self.coll_handler.div())
148    }
149
150    fn add(&mut self, b: Self) {
151        self.coll_handler.add(b.coll_handler);
152    }
153}
154
155pub fn recurse_par<T: Aabb, SO: NodeHandlerExt<T>>(
156    vistr: CollisionVisitor<T>,
157    handler: &mut SO,
158    num_seq_fallback: usize,
159) where
160    T: Send,
161    T::Num: Send,
162    SO: Send,
163{
164    if vistr.num_elem() <= num_seq_fallback {
165        vistr.recurse_seq(handler);
166    } else {
167        let (n, rest) = vistr.collide_and_next(handler);
168        if let Some([left, right]) = rest {
169            let mut h2 = handler.div();
170
171            rayon::join(
172                || {
173                    n.finish(handler);
174                    recurse_par(left, handler, num_seq_fallback)
175                },
176                || recurse_par(right, &mut h2, num_seq_fallback),
177            );
178            handler.add(h2);
179        } else {
180            n.finish(handler);
181        }
182    }
183}