fips_md/codegen/compiled.rs

use std::{collections::{HashMap, HashSet}, ops::Range, sync::{self, Arc, RwLock}};

use anyhow::{anyhow, Result};
use rand_xoshiro::Xoshiro256PlusPlus;
use rand::SeedableRng;

use crate::{codegen::{analysis::SimulationGraph, util, ExecutorContext}, runtime::{ParticleBorrowMut, Runtime}};

use crate::runtime::{SimulationID, ParticleID};

use super::{CallbackStateType, CallbackThread, CallbackType, GlobalContext, ThreadContext, WorkerThread, analysis::SymbolTable, neighbors::NeighborList, util::IndexRange};

/// Wrapper struct for the runtime after compilation
///
/// This wrapper is intended to make it impossible to mutate the internal state in
/// ways that would break the compiled simulation (thus encapsulating all the
/// unsafety of code generation).
pub struct CompiledRuntime {
    /// ID of the simulation we compiled for
    _simulation: SimulationID,
    /// Simulation executor (this spawns the threads for the actual computation)
    executor: SimulationExecutor,
    /// Active particle types (i.e. particle types that have at least one instance)
    _active_particles: HashSet<ParticleID>,
    /// Global compilation context
    global_context: Arc<GlobalContext>,
}

impl CompiledRuntime {
    /// Create a new compiled runtime from a regular runtime structure.
    /// The second parameter determines the simulation that will be compiled.
    pub(crate) fn new(runtime: Runtime, simulation: SimulationID) -> Result<Self> {
        // Create index shorthands
        let indices = util::IndicesRef::new(
            &runtime.particle_index,
            &runtime.simulation_index,
            &runtime.interaction_index,
        );
        // Construct the global symbol table
        let mut global_symbols = SymbolTable::new();
        for (name, value) in &runtime.constants {
            global_symbols.add_constant_from_substitution(name.clone(), value)?;
        }
        for (function_id, function_def) in runtime.function_index.get_functions() {
            global_symbols.add_function(function_def.get_name().to_string(), function_id)?;
        }
        // Insert additional constants
        global_symbols.add_constant_f64("DT".into(), runtime.get_time_step())?;
        // Determine active particles
        let active_particles = runtime.particle_store.get_particles()
            .map(|(particle_id, _)| particle_id)
            .collect::<HashSet<_>>();
        // Construct the simulation graph
        let simgraph = SimulationGraph::new(active_particles.clone(), simulation, indices)?;
        // Create the global context (factored out because it is shared with all worker threads)
        let global_context = Arc::new(GlobalContext {
            runtime, global_symbols, simgraph
        });
        // Create the simulation executor
        let executor = SimulationExecutor::new(global_context.clone(), &active_particles)?;
        // Return the compiled runtime
        Ok(Self {
            global_context: global_context.clone(),
            executor, _simulation: simulation,
            _active_particles: active_particles
            // callbacks: HashMap::new()
        })
    }

    /// Run a single time step
    pub fn run_step(&self) {
        // BIG TODO
        self.executor.run();
        // We need to wait for all workers to complete (otherwise we could create
        // a data race)
        self.wait();
    }

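    /// Wait for all worker threads to complete the current step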
    fn wait(&self) {
        // BIG TODO
        self.executor.wait();
    }

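    /// Shut down the compiled runtime, joining all worker and callback threads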
    pub fn join(self) {
        // BIG TODO
        self.executor.join();
    }

    /// Borrow data belonging to a particle type
    pub fn borrow_particle_mut(&self, particle_name: &str) -> Result<ParticleBorrowMut> {
        self.global_context.borrow_particle_mut(particle_name)
    }

    /// Define a new callback
    pub fn define_callback(&mut self, name: &str,
        callback: CallbackType,
        callback_state: CallbackStateType)
    -> Result<()> {
        self.executor.register_callback(name, callback, callback_state)
    }

    /// Undefine an existing callback (returning the callback and its state)
    pub fn undefine_callback(&mut self, name: &str)
    -> Result<(CallbackType, CallbackStateType)> {
        self.executor.unregister_callback(name)
    }

    /// Manually rebuild the neighbor list for a given interaction
    pub fn rebuild_neighbor_list(&mut self, interaction_name: &str) -> Result<()> {
        self.executor.rebuild_neighbor_list(interaction_name)
    }

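    /// Get a copy of the current neighbor lists for a given interaction,
    /// keyed by particle type and index range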
    pub fn get_neighbor_lists(&mut self, interaction_name: &str) -> Result<HashMap<(ParticleID, Range<usize>), (Vec<usize>, Vec<usize>)>> {
        self.executor.get_neighbor_lists(interaction_name)
    }
}

/// Execution handle to a single simulation
pub(crate) struct SimulationExecutor {
    /// Worker threads
    workers: Vec<WorkerThread>,
    /// Callback thread
    callback_thread: CallbackThread,
    // /// Global context
    // global_context: Arc<GlobalContext>,
    /// Executor context
    executor_context: ExecutorContext
}

impl SimulationExecutor {
    pub(crate) fn new(global_context: Arc<GlobalContext>, active_particles: &HashSet<ParticleID>) -> Result<Self> {
        let simgraph = &global_context.simgraph;
        let particle_store = &global_context.runtime.particle_store;
        let threads_per_particle = &global_context.runtime.threads_per_particle;
        let interaction_index = &global_context.runtime.interaction_index;
        let particle_index = &global_context.runtime.particle_index;
        let enabled_interactions = &global_context.runtime.enabled_interactions;
        let domain = &global_context.runtime.domain;
        let rng_seeds = &global_context.runtime.rng_seeds;
        // Determine the index ranges for the worker threads
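        // Each particle type's index range is split into one contiguous sub-range per worker thread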
        let mut worker_index_ranges = HashMap::new();
        for particle_id in active_particles {
            let particle_count = particle_store
                .get_particle(*particle_id).unwrap().get_particle_count();
            worker_index_ranges.insert(*particle_id,
                IndexRange::new(0, particle_count).split(
                    *threads_per_particle.get(particle_id).unwrap()
                )
            );
        }
        // Determine neighbor lists to maintain
        let mut neighbor_lists = HashMap::new();
        for (interaction, interaction_def) in interaction_index.iter() {
            // Skip interactions that do not affect any active particle types
            let affected_particles = interaction_def.get_affected_particles(particle_index)?;
            if affected_particles.intersection(active_particles).next().is_none() {
                continue;
            }
            // Skip interactions that have not been enabled
            match enabled_interactions.get(&interaction) {
                Some(details) => {
                    let position_blocks = affected_particles.intersection(active_particles)
                        .map(|particle_id|
                            (*particle_id, worker_index_ranges.get(particle_id).unwrap().clone()))
                        .collect::<HashMap<_,_>>();
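                    // The neighbor-list cutoff is the interaction cutoff scaled by the skin factor;
                    // the bin (cell) size defaults to this cutoff unless a cell size is configured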
                    let cutoff_length = details.skin_factor * util::unwrap_f64_constant(&interaction_def.get_cutoff())?;
                    let bin_size = details.cell_size.unwrap_or(cutoff_length);
                    neighbor_lists.insert(interaction, Arc::new(RwLock::new(
                        NeighborList::new(
                            bin_size, cutoff_length,
                            domain.clone(),
                            details.num_workers,
                            details.rebuild_interval,
                            position_blocks,
                            particle_index,
                            particle_store
                        )?))
                    );
                },
                None => { continue; }
            }
        }
        // Create barriers
        let barriers = simgraph.barriers.iter()
            .map(|(barrier_id, barrier_def)| (
                barrier_id,
                // Barrier count is sum of thread counts for each affected particle
                Arc::new(sync::Barrier::new(
                    barrier_def.affected_particles.iter()
                        .map(|particle_id| worker_index_ranges.get(particle_id).unwrap().len())
                        .sum()
                ))
            ))
            .collect::<HashMap<_,_>>();
        // Total number of workers is the sum of the number of workers of each particle type
        let num_workers = worker_index_ranges.values().map(|ranges| ranges.len()).sum();
        let step_barrier = Arc::new(sync::Barrier::new(num_workers));
        let call_end_barrier = Arc::new(sync::Barrier::new(num_workers + 1));
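        // call_end_barrier has one participant more than the workers: the callback thread, which receives a clone below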
        // The step counter always starts at zero
        let step_counter = Arc::new(RwLock::new(0));
        // Create the callback thread
        let callback_thread = CallbackThread::new(call_end_barrier.clone(), num_workers, global_context.clone());
        // Create context
        let executor_context = ExecutorContext {
            barriers, step_barrier, step_counter,
            call_end_barrier,
            neighbor_lists,
            call_sender: callback_thread.get_sender(),
            global_context: global_context.clone()
        };
        // Create worker threads
        let rng_seeds = match rng_seeds {
            None => unimplemented!("Explicit RNG seeding required for now"),
            Some(rng_seeds) => {
                if rng_seeds.len() >= num_workers {
                    Ok(rng_seeds)
                }
                else {
                    Err(anyhow!("Not enough RNG seeds (need {}, got {})", num_workers, rng_seeds.len()))
                }
            }
        }?;
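        // Note: the seed slice is zipped per particle type below, so all worker
        // threads of a given particle type currently start from the same seed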
        let workers = worker_index_ranges.into_iter()
            .zip(&rng_seeds[0..num_workers])
            .map(|((particle_id, index_ranges), rng_seed)| {
                let executor_context = &executor_context;
                index_ranges.into_iter().map(move |particle_range| {
                    // Create thread-specific context
                    let rng = Xoshiro256PlusPlus::seed_from_u64(*rng_seed);
                    let normal_dist = rand_distr::Normal::new(0.0, 1.0)
                        .expect("Math is broken. All is lost.");
                    let thread_context = ThreadContext {
                        particle_id, particle_range, rng, normal_dist,
                        executor_context: executor_context.clone()
                    };
                    // Spawn worker thread
                    WorkerThread::spawn(thread_context)
                })
            })
            .flatten()
            .collect::<Vec<_>>();
        // Wait for workers to finish compiling
        for worker in &workers {
            worker.wait_for_compilation();
        }
        Ok(Self {
            workers, callback_thread, executor_context
        })
    }

    pub(crate) fn run(&self) {
        // BIG TODO
        for worker in &self.workers {
            worker.run_step();
        }
    }

    pub(crate) fn wait(&self) {
        // No need to wait for the callback thread, since the workers will only
        // be idle if the callback thread is idle too
        for worker in &self.workers {
            worker.wait();
        }
    }

    pub(crate) fn join(self) {
        for worker in self.workers {
            worker.join();
        }
        self.callback_thread.join();
    }

    pub(crate) fn register_callback(&mut self, name: &str,
        callback: CallbackType,
        callback_state: CallbackStateType)
    -> Result<()> {
        // Try to resolve callback
        let barriers = self.executor_context.global_context.simgraph.callbacks.get(name)
            .ok_or(anyhow!("No callback named {} found in simulation graph", &name))?
            .clone();
        self.callback_thread.register_callback(barriers, callback, callback_state)
    }

    pub(crate) fn unregister_callback(&mut self, name: &str)
    -> Result<(CallbackType, CallbackStateType)> {
        let barriers = self.executor_context.global_context.simgraph.callbacks.get(name)
            .ok_or(anyhow!("No callback named {} found in simulation graph", &name))?
            .clone();
        self.callback_thread.unregister_callback(barriers)
    }

    pub(crate) fn rebuild_neighbor_list(&mut self, interaction_name: &str) -> Result<()> {
        let (interaction_id, _) = self.executor_context.global_context.runtime.interaction_index
            .get_interaction_by_name(interaction_name)
            .ok_or(anyhow!("Cannot find interaction with name {}", interaction_name))?;
        self.executor_context.neighbor_lists.get_mut(&interaction_id)
            .ok_or(anyhow!("There is no neighbor list for interaction with name {}", interaction_name))?
            .write().unwrap()
            .rebuild(&self.executor_context.global_context.runtime.particle_index,
                &self.executor_context.global_context.runtime.particle_store);
        Ok(())
    }

    pub fn get_neighbor_lists(&mut self, interaction_name: &str) -> Result<HashMap<(ParticleID, Range<usize>), (Vec<usize>, Vec<usize>)>> {
        // Check if interaction exists
        let (interaction_id, _) = self.executor_context.global_context.runtime.interaction_index
            .get_interaction_by_name(interaction_name)
            .ok_or(anyhow!("Cannot find interaction with name {}", interaction_name))?;
        // Get read handle to neighbor list structure
        let neighbor_list = self.executor_context.neighbor_lists.get(&interaction_id)
            .ok_or(anyhow!("There is no neighbor list for interaction with name {}", interaction_name))?
            .read().unwrap();
        // Zip position blocks and (sub) neighbor lists together and clone lists for return value
        Ok(neighbor_list.pos_blocks.iter()
            .zip(neighbor_list.neighbor_lists.iter())
            .map(|((particle_id, index_range), (neighbor_list_index, neighbor_list))| {
                ((*particle_id, index_range.start..index_range.end),
                (neighbor_list_index.clone(), neighbor_list.clone()))
            })
            .collect::<HashMap<_,_>>())
    }
}