// rant_simulator/scan/mod.rs

1use rayon::prelude::*;
2
3pub mod adapters;
4pub mod generators;
5
/// Source of scan points for a serial [`scan`].
///
/// An implementor describes *which* points of a parameter space should be
/// simulated; a [`ParameterAdapter`] later turns each point into an initial
/// state and a parameter set.
pub trait VectorGenerator {
    /// The type of a single scan point.
    type Vector;

    /// Expected number of scan vectors produced by
    /// [`generate_scan_vectors`](Self::generate_scan_vectors).
    ///
    /// NOTE(review): nothing here states whether this count must be exact or
    /// is only an estimate — confirm against the implementors.
    fn size_hint(&self) -> usize;

    /// Consumes the generator and returns an iterator over all scan points.
    fn generate_scan_vectors(self) -> impl Iterator<Item = Self::Vector>;
}
12
/// Source of scan points for [`scan_parallel`], split into chunks that can be
/// processed independently of one another.
pub trait ParallelVectorGenerator {
    /// The type of a single scan point.
    type Vector;

    /// Expected total number of scan vectors across all chunks.
    ///
    /// NOTE(review): whether this must be exact is not specified here — confirm.
    fn size_hint(&self) -> usize;

    /// Number of chunks yielded by
    /// [`generate_scan_vectors`](Self::generate_scan_vectors).
    ///
    /// Presumably this must equal the actual number of chunks produced; the
    /// disabled channel-based scan in this module waits for exactly this many
    /// result batches.
    fn num_chunks(&self) -> usize;

    /// Consumes the generator and returns an iterator over chunks, where each
    /// chunk is itself a serial iterator over scan points.
    ///
    /// Both the outer and the inner iterators are `Send` so that chunks can be
    /// bridged into rayon and moved to worker threads.
    fn generate_scan_vectors(
        self,
    ) -> impl Iterator<Item = impl Iterator<Item = Self::Vector> + Send> + Send;
}
22
/// Maps a scan point to the inputs of one simulation run.
///
/// `State` is the initial state handed (by value) to the simulation and
/// `Parameters` is the parameter set the simulation borrows while it runs.
pub trait ParameterAdapter<State, Parameters> {
    /// The type of scan point this adapter accepts; it must match the
    /// generator's `Vector` type when used with [`scan`] or [`scan_parallel`].
    type Vector;

    /// Builds the initial state and the parameters for a single scan point.
    fn compute_initial_state_and_parameters(
        &self,
        vector: Self::Vector,
    ) -> (State, Parameters);
}
28
29pub fn scan<Vector, State, Parameters, Result>(
30    vector_generator: impl VectorGenerator<Vector = Vector>,
31    parameter_adapter: impl ParameterAdapter<State, Parameters, Vector = Vector>,
32    simulate: impl Fn(State, &Parameters) -> Result,
33) -> impl Iterator<Item = (State, Parameters, Result)>
34where
35    State: Default + Clone,
36{
37    let scan_points = vector_generator.generate_scan_vectors();
38
39    scan_points.map(move |scan_point| {
40        let (initial_state, parameters) =
41            parameter_adapter.compute_initial_state_and_parameters(scan_point);
42
43        let result = simulate(initial_state.clone(), &parameters);
44        (initial_state, parameters, result)
45    })
46}
47
48pub fn scan_parallel<Vector, State, Parameters, Result>(
49    vector_generator: impl ParallelVectorGenerator<Vector = Vector> + 'static,
50    parameter_adapter: impl ParameterAdapter<State, Parameters, Vector = Vector>
51        + Sync
52        + Send
53        + Copy
54        + 'static,
55    simulate: impl Fn(State, &Parameters) -> Result + Sync + Send + Copy + 'static,
56) -> impl ParallelIterator<Item = (State, Parameters, Result)>
57where
58    Vector: Send,
59    State: Default + Clone + Send + Sync,
60    Parameters: Send + Sync,
61    Result: Send + Sync,
62{
63    vector_generator
64        .generate_scan_vectors()
65        .par_bridge()
66        .map(move |scan_points| {
67            scan_points.map(move |scan_point| {
68                let (initial_state, parameters) =
69                    parameter_adapter.compute_initial_state_and_parameters(scan_point);
70
71                let result = simulate(initial_state.clone(), &parameters);
72                (initial_state, parameters, result)
73            })
74        })
75        .flatten_iter()
76}
77
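// TODO: replace the `expect` calls in the disabled implementation below with proper error propagation before re-enabling it.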
79/*
80pub fn scan_parallel_channels<Vector, State, Parameters, Result>(
81    vector_generator: impl ParallelVectorGenerator<Vector = Vector> + Send + 'static,
82    parameter_adapter: impl ParameterAdapter<State, Parameters, Vector = Vector>
83        + Clone
84        + Send
85        + Sync
86        + 'static,
87    simulate: impl Fn(State, &Parameters) -> Result + Clone + Send + 'static,
88) -> Vec<(State, Parameters, Result)>
89where
90    Vector: Send + 'static,
91    State: Default + Clone + Send + 'static,
92    Parameters: Send + 'static,
93    Result: Send + 'static,
94{
95    let num_workers = 12; // TODO: as optional parameter, else depending on processor
96
97    let mut results = Vec::with_capacity(vector_generator.size_hint());
98
99    let (scan_vector_sender, scan_vector_receiver) =
100        crossbeam_channel::bounded::<Box<dyn Iterator<Item = Vector> + Send>>(num_workers);
101    let (result_sender, result_receiver) =
102        crossbeam_channel::bounded::<Vec<(State, Parameters, Result)>>(num_workers);
103
104    let mut worker_threads = Vec::with_capacity(num_workers);
105    for _ in 0..num_workers {
106        let scan_vector_receiver = scan_vector_receiver.clone();
107        let result_sender = result_sender.clone();
108        let parameter_adapter = parameter_adapter.clone();
109        let simulate = simulate.clone();
110
111        let worker_thread = thread::spawn(move || {
112            for scan_vector_chunk in scan_vector_receiver {
113                let results = scan_vector_chunk
114                    .map(|scan_vector| {
115                        let (initial_state, parameters) =
116                            parameter_adapter.compute_initial_state_and_parameters(scan_vector);
117                        let result = simulate(initial_state.clone(), &parameters);
118                        (initial_state, parameters, result)
119                    })
120                    .collect();
121                result_sender.send(results).expect("could not send results");
122            }
123        });
124        worker_threads.push(worker_thread);
125    }
126
127    let num_chunks = vector_generator.num_chunks();
128    let scan_point_thread = thread::spawn(move || {
129        let scan_point_chunks = vector_generator.generate_scan_vectors();
130        for chunk in scan_point_chunks {
131            scan_vector_sender
132                .send(Box::new(chunk))
133                .expect("could not send scan vector chunk")
134        }
135    });
136
137    for _ in 0..num_chunks {
138        let mut result_chunk = result_receiver.recv().expect("could not receive result");
139        results.append(&mut result_chunk);
140    }
141
142    scan_point_thread
143        .join()
144        .expect("could not join thread that sends scan points");
145    for worker_thread in worker_threads {
146        worker_thread.join().expect("could not join worker thread");
147    }
148
149    results
150}
151*/