// oxygengine_core/ecs/pipeline/engines/jobs.rs

1#[cfg(feature = "parallel")]
2use crate::ecs::System;
3use crate::ecs::{
4    pipeline::{PipelineEngine, PipelineGraph, PipelineGraphSystem},
5    Universe,
6};
7#[cfg(feature = "parallel")]
8use std::{
9    any::TypeId,
10    cell::RefCell,
11    collections::HashMap,
12    sync::{
13        mpsc::{channel, Receiver, Sender},
14        Arc, Condvar, Mutex,
15    },
16    thread::JoinHandle,
17    time::{Duration, Instant},
18};
19
/// Lock bookkeeping for a single shared resource (the scheduler keys these by
/// `TypeId`): how many systems currently read it and whether one writes it.
#[cfg(feature = "parallel")]
#[derive(Debug, Default)]
struct Access {
    /// Number of read locks currently held.
    read: usize,
    /// Whether the write lock is currently held.
    write: bool,
}

#[cfg(feature = "parallel")]
impl Access {
    /// A read lock is available whenever nobody holds the write lock.
    pub fn can_read(&self) -> bool {
        !self.write
    }

    /// The write lock is available only when no lock of any kind is held.
    pub fn can_write(&self) -> bool {
        self.read == 0 && !self.write
    }

    /// Takes a read lock; silently does nothing if one is not available.
    pub fn acquire_read(&mut self) {
        if !self.can_read() {
            return;
        }
        self.read += 1;
    }

    /// Takes the write lock; silently does nothing if it is not available.
    pub fn acquire_write(&mut self) {
        if !self.can_write() {
            return;
        }
        self.write = true;
    }

    /// Drops one read lock, never going below zero.
    pub fn release_read(&mut self) {
        self.read = self.read.saturating_sub(1);
    }

    /// Drops the write lock.
    pub fn release_write(&mut self) {
        self.write = false;
    }
}
57
/// One job handed to a worker thread: the system to invoke and the universe
/// to invoke it on.
#[cfg(feature = "parallel")]
struct Message {
    /// Universe borrowed for the duration of a single `run` call. The
    /// `'static` lifetime is manufactured by an `unsafe` transmute in the
    /// scheduler; the reference does not actually live that long.
    universe: &'static Universe,
    /// The system function to invoke.
    system: System,
}

/// Shared "work pending" flag plus the condition variable that signals changes
/// to it. The scheduler raises the flag before dispatching; a worker lowers it
/// and notifies after finishing a job and again when it exits.
#[cfg(feature = "parallel")]
type Notifier = Arc<(Mutex<bool>, Condvar)>;

/// Handle to a single worker thread.
#[cfg(feature = "parallel")]
struct Worker {
    /// Job channel into the thread; sending `None` tells it to stop.
    sender: Sender<Option<Message>>,
    /// Join handle used to wait for the thread during shutdown.
    handle: JoinHandle<()>,
}
72
73pub struct JobsPipelineEngine {
74    #[cfg(feature = "parallel")]
75    workers: Vec<Worker>,
76    #[cfg(feature = "parallel")]
77    notifier: Notifier,
78    #[cfg(feature = "parallel")]
79    receiver: Receiver<(usize, Option<(System, Duration)>)>,
80    #[cfg(feature = "parallel")]
81    systems_last_duration: RefCell<Vec<Duration>>,
82    #[cfg(feature = "parallel")]
83    systems_preferred_worker: RefCell<Vec<Option<usize>>>,
84    pub(crate) systems: Vec<PipelineGraphSystem>,
85}
86
87#[cfg(feature = "web")]
88unsafe impl Send for JobsPipelineEngine {}
89#[cfg(feature = "web")]
90unsafe impl Sync for JobsPipelineEngine {}
91
92impl Default for JobsPipelineEngine {
93    #[cfg(not(feature = "parallel"))]
94    fn default() -> Self {
95        Self::new(1)
96    }
97
98    #[cfg(feature = "parallel")]
99    fn default() -> Self {
100        Self::new(rayon::current_num_threads())
101    }
102}
103
104impl JobsPipelineEngine {
105    #[cfg(not(feature = "parallel"))]
106    pub fn new(_jobs_count: usize) -> Self {
107        Self {
108            systems: Default::default(),
109        }
110    }
111
112    #[cfg(feature = "parallel")]
113    pub fn new(jobs_count: usize) -> Self {
114        #[allow(clippy::mutex_atomic)]
115        let notifier = Arc::new((Mutex::new(false), Condvar::new()));
116        let (sender, receiver) = channel();
117        let workers = Self::build_workers(notifier.clone(), sender, jobs_count);
118        Self {
119            workers,
120            notifier,
121            receiver,
122            systems_last_duration: Default::default(),
123            systems_preferred_worker: Default::default(),
124            systems: Default::default(),
125        }
126    }
127
128    #[cfg(feature = "parallel")]
129    fn build_workers(
130        notifier: Notifier,
131        sender: Sender<(usize, Option<(System, Duration)>)>,
132        mut jobs_count: usize,
133    ) -> Vec<Worker> {
134        jobs_count = jobs_count.max(1);
135        (0..jobs_count)
136            .into_iter()
137            .map(|index| {
138                let (my_sender, receiver) = channel();
139                let notifier = Arc::clone(&notifier);
140                let sender = sender.clone();
141                let handle = std::thread::spawn(move || {
142                    let (lock, cvar) = &*notifier;
143                    while let Ok(msg) = receiver.recv() {
144                        if let Some(Message { universe, system }) = msg {
145                            let timer = Instant::now();
146                            #[allow(mutable_transmutes)]
147                            #[allow(clippy::transmute_ptr_to_ptr)]
148                            system(unsafe { std::mem::transmute(universe) });
149                            let _ = sender.send((index, Some((system, timer.elapsed()))));
150                            let mut busy = lock.lock().unwrap();
151                            *busy = false;
152                            cvar.notify_all();
153                        } else {
154                            break;
155                        }
156                    }
157                    let _ = sender.send((index, None));
158                    let mut busy = lock.lock().unwrap();
159                    *busy = false;
160                    cvar.notify_all();
161                });
162                Worker {
163                    sender: my_sender,
164                    handle,
165                }
166            })
167            .collect::<Vec<_>>()
168    }
169
170    #[cfg(feature = "parallel")]
171    fn find_system_to_run(
172        systems_left: &[usize],
173        systems: &[PipelineGraphSystem],
174        resources: &HashMap<TypeId, Access>,
175        worker_index: usize,
176        systems_preferred_worker: &[Option<usize>],
177    ) -> Option<usize> {
178        for index in systems_left {
179            if let Some(index) = systems_preferred_worker[*index] {
180                if worker_index != index {
181                    continue;
182                }
183            }
184            let data = &systems[*index];
185            let can_read = data.reads.iter().all(|id| {
186                resources
187                    .get(id)
188                    .map(|access| access.can_read())
189                    .unwrap_or(true)
190            });
191            let can_write = data.writes.iter().all(|id| {
192                resources
193                    .get(id)
194                    .map(|access| access.can_write())
195                    .unwrap_or(true)
196            });
197            if can_read && can_write {
198                return Some(*index);
199            }
200        }
201        None
202    }
203}
204
205impl PipelineEngine for JobsPipelineEngine {
206    fn setup(&mut self, graph: PipelineGraph) {
207        match graph {
208            PipelineGraph::System(system) => {
209                #[cfg(feature = "parallel")]
210                self.systems_last_duration
211                    .borrow_mut()
212                    .push(Default::default());
213                #[cfg(feature = "parallel")]
214                self.systems_preferred_worker.borrow_mut().push(None);
215                self.systems.push(system);
216            }
217            PipelineGraph::Sequence(list) | PipelineGraph::Parallel(list) => {
218                for item in list {
219                    self.setup(item);
220                }
221            }
222        }
223    }
224
225    fn run(&self, universe: &mut Universe) {
226        #[cfg(not(feature = "parallel"))]
227        {
228            for system in &self.systems {
229                (system.system)(universe);
230            }
231        }
232        #[cfg(feature = "parallel")]
233        {
234            if self.workers.len() <= 1 {
235                for system in &self.systems {
236                    (system.system)(universe);
237                }
238                return;
239            }
240            let mut systems_last_duration = self.systems_last_duration.borrow_mut();
241            let mut systems_preferred_worker = self.systems_preferred_worker.borrow_mut();
242            let mut systems_left = (0..self.systems.len()).into_iter().collect::<Vec<_>>();
243            let mut load = vec![(false, Duration::default()); self.workers.len()];
244            let mut sorted_load = (0..self.workers.len()).into_iter().collect::<Vec<_>>();
245            let mut resources = self
246                .systems
247                .iter()
248                .flat_map(|s| s.reads.iter().chain(s.writes.iter()))
249                .map(|id| (*id, Access::default()))
250                .collect::<HashMap<_, _>>();
251            loop {
252                let (lock, cvar) = &*self.notifier;
253                let mut guard = cvar
254                    .wait_while(lock.lock().unwrap(), |pending| *pending)
255                    .unwrap();
256                if systems_left.is_empty() {
257                    break;
258                }
259                while let Ok((worker_index, duration)) = self.receiver.try_recv() {
260                    let load = &mut load[worker_index];
261                    load.0 = false;
262                    if let Some((system, duration)) = duration {
263                        load.1 += duration;
264                        let found = self.systems.iter().position(|s| {
265                            let a = s.system as *const ();
266                            let b = system as *const ();
267                            a == b
268                        });
269                        if let Some(system_index) = found {
270                            systems_last_duration[system_index] = duration;
271                            if self.systems[system_index].lock_on_single_thread {
272                                systems_preferred_worker[system_index] = Some(worker_index);
273                            }
274                            for id in &self.systems[system_index].reads {
275                                if let Some(access) = resources.get_mut(id) {
276                                    access.release_read();
277                                }
278                            }
279                            for id in &self.systems[system_index].writes {
280                                if let Some(access) = resources.get_mut(id) {
281                                    access.release_write();
282                                }
283                            }
284                        }
285                    }
286                }
287                sorted_load.sort_by(|a, b| load[*a].1.cmp(&load[*b].1));
288                systems_left.sort_by(|a, b| {
289                    self.systems[*a]
290                        .layer
291                        .cmp(&self.systems[*b].layer)
292                        .then_with(|| systems_last_duration[*a].cmp(&systems_last_duration[*b]))
293                });
294                *guard = true;
295                let layer = self.systems[*systems_left.first().unwrap()].layer;
296                let mut chunk_size = systems_left.iter().fold(0, |a, v| {
297                    a + if self.systems[*v].layer == layer {
298                        1
299                    } else {
300                        0
301                    }
302                });
303                for index in sorted_load.iter().copied() {
304                    let mut load = &mut load[index];
305                    if load.0 {
306                        continue;
307                    }
308                    let worker = &self.workers[index];
309                    let found = Self::find_system_to_run(
310                        &systems_left[..chunk_size],
311                        &self.systems,
312                        &resources,
313                        index,
314                        &systems_preferred_worker,
315                    );
316                    if let Some(index) = found {
317                        #[allow(mutable_transmutes)]
318                        #[allow(clippy::transmute_ptr_to_ptr)]
319                        let universe = unsafe { std::mem::transmute(&mut *universe) };
320                        let msg = Message {
321                            universe,
322                            system: self.systems[index].system,
323                        };
324                        if worker.sender.send(Some(msg)).is_ok() {
325                            for id in &self.systems[index].reads {
326                                if let Some(access) = resources.get_mut(id) {
327                                    access.acquire_read();
328                                }
329                            }
330                            for id in &self.systems[index].writes {
331                                if let Some(access) = resources.get_mut(id) {
332                                    access.acquire_write();
333                                }
334                            }
335                            if let Some(index) = systems_left.iter().position(|i| *i == index) {
336                                systems_left.swap_remove(index);
337                            }
338                            load.0 = true;
339                            chunk_size -= 1;
340                        }
341                    }
342                }
343            }
344        }
345    }
346}
347
/// Shuts the worker pool down one worker at a time: each is sent the `None`
/// stop signal and joined before the next one is stopped.
#[cfg(feature = "parallel")]
impl Drop for JobsPipelineEngine {
    fn drop(&mut self) {
        self.workers.drain(..).for_each(|Worker { sender, handle }| {
            // Errors are deliberately ignored: `send` fails if the thread has
            // already exited, and `join` reports a thread that panicked.
            let _ = sender.send(None);
            let _ = handle.join();
        });
    }
}