1use std::{collections::{HashMap, HashSet}, ops::Range, sync::{self, Arc, RwLock}};
2
3use anyhow::{anyhow, Result};
4use rand_xoshiro::Xoshiro256PlusPlus;
5use rand::SeedableRng;
6
7use crate::{codegen::{analysis::SimulationGraph, util, ExecutorContext}, runtime::{ParticleBorrowMut, Runtime}};
8
9use crate::runtime::{SimulationID, ParticleID};
10
11use super::{CallbackStateType, CallbackThread, CallbackType, GlobalContext, ThreadContext, WorkerThread, analysis::SymbolTable, neighbors::NeighborList, util::IndexRange};
12
/// A [`Runtime`] compiled for one simulation, together with the executor whose
/// threads run it.
///
/// NOTE: fields are dropped in declaration order, so `executor` is dropped before
/// `global_context`; keep that in mind before reordering fields.
pub struct CompiledRuntime {
    /// The simulation this runtime was compiled for (kept, but not read by the
    /// methods of this type).
    _simulation: SimulationID,
    /// Owns the worker threads and the callback thread; all stepping, callback
    /// registration and neighbor-list requests are forwarded to it.
    executor: SimulationExecutor,
    /// IDs of the particles that were present in the particle store when this
    /// runtime was compiled.
    _active_particles: HashSet<ParticleID>,
    /// Shared state (runtime, global symbol table, simulation graph); the same
    /// `Arc` is handed to the executor.
    global_context: Arc<GlobalContext>,
}
28
29impl CompiledRuntime {
30 pub(crate) fn new(runtime: Runtime, simulation: SimulationID) -> Result<Self> {
33 let indices = util::IndicesRef::new(
35 &runtime.particle_index,
36 &runtime.simulation_index,
37 &runtime.interaction_index,
38 );
39 let mut global_symbols = SymbolTable::new();
41 for (name, value) in &runtime.constants {
42 global_symbols.add_constant_from_substitution(name.clone(), value)?;
43 }
44 for (function_id, function_def) in runtime.function_index.get_functions() {
45 global_symbols.add_function(function_def.get_name().to_string(), function_id)?;
46 }
47 global_symbols.add_constant_f64("DT".into(), runtime.get_time_step())?;
49 let active_particles = runtime.particle_store.get_particles()
51 .map(|(particle_id, _)| particle_id)
52 .collect::<HashSet<_>>();
53 let simgraph = SimulationGraph::new(active_particles.clone(), simulation, indices)?;
55 let global_context = Arc::new(GlobalContext {
57 runtime, global_symbols, simgraph
58 });
59 let executor = SimulationExecutor::new(global_context.clone(), &active_particles)?;
61 Ok(Self {
63 global_context: global_context.clone(),
64 executor, _simulation: simulation,
65 _active_particles: active_particles
66 })
68 }
69
70 pub fn run_step(&self) {
72 self.executor.run();
74 self.wait();
77 }
78
79 fn wait(&self) {
80 self.executor.wait();
82 }
83
84 pub fn join(self) {
85 self.executor.join();
87 }
88
89 pub fn borrow_particle_mut(&self, particle_name: &str) -> Result<ParticleBorrowMut> {
91 self.global_context.borrow_particle_mut(particle_name)
92 }
93
94 pub fn define_callback(&mut self, name: &str,
96 callback: CallbackType,
97 callback_state: CallbackStateType)
98 -> Result<()> {
99 self.executor.register_callback(name, callback, callback_state)
100 }
101
102 pub fn undefine_callback(&mut self, name: &str)
104 -> Result<(CallbackType, CallbackStateType)> {
105 self.executor.unregister_callback(name)
106 }
107
108 pub fn rebuild_neighbor_list(&mut self, interaction_name: &str) -> Result<()> {
110 self.executor.rebuild_neighbor_list(interaction_name)
111 }
112
113 pub fn get_neighbor_lists(&mut self, interaction_name: &str) -> Result<HashMap<(ParticleID, Range<usize>), (Vec<usize>, Vec<usize>)>> {
114 self.executor.get_neighbor_lists(interaction_name)
115 }
116}
117
/// Owns the threads that execute a compiled simulation and the context they share.
///
/// NOTE: fields are dropped in declaration order (`workers`, then
/// `callback_thread`, then `executor_context`); keep that in mind before
/// reordering fields.
pub(crate) struct SimulationExecutor {
    /// One worker per index range produced by splitting each active particle's
    /// range across its configured thread count.
    workers: Vec<WorkerThread>,
    /// Thread that callbacks are registered with / unregistered from.
    callback_thread: CallbackThread,
    /// Barriers, step counter, neighbor lists, callback sender and global context
    /// shared (by clone) with every worker.
    executor_context: ExecutorContext
}
129
130impl SimulationExecutor {
131 pub(crate) fn new(global_context: Arc<GlobalContext>, active_particles: &HashSet<ParticleID>) -> Result<Self> {
132 let simgraph = &global_context.simgraph;
133 let particle_store = &global_context.runtime.particle_store;
134 let threads_per_particle = &global_context.runtime.threads_per_particle;
135 let interaction_index = &global_context.runtime.interaction_index;
136 let particle_index = &global_context.runtime.particle_index;
137 let enabled_interactions = &global_context.runtime.enabled_interactions;
138 let domain = &global_context.runtime.domain;
139 let rng_seeds = &global_context.runtime.rng_seeds;
140 let mut worker_index_ranges = HashMap::new();
142 for particle_id in active_particles {
143 let particle_count = particle_store
144 .get_particle(*particle_id).unwrap().get_particle_count();
145 worker_index_ranges.insert(*particle_id,
146 IndexRange::new(0, particle_count).split(
147 *threads_per_particle.get(particle_id).unwrap()
148 )
149 );
150 }
151 let mut neighbor_lists = HashMap::new();
153 for (interaction, interaction_def) in interaction_index.iter() {
154 let affected_particles = interaction_def.get_affected_particles(particle_index)?;
156 if let None = affected_particles.intersection(&active_particles).next() {
157 continue;
158 }
159 match enabled_interactions.get(&interaction) {
161 Some(details) => {
162 let position_blocks = affected_particles.intersection(&active_particles)
163 .map(|particle_id|
164 (*particle_id, worker_index_ranges.get(particle_id).unwrap().clone()))
165 .collect::<HashMap<_,_>>();
166 let cutoff_length = details.skin_factor * util::unwrap_f64_constant(&interaction_def.get_cutoff())?;
167 let bin_size = details.cell_size.unwrap_or(cutoff_length);
168 neighbor_lists.insert(interaction, Arc::new(RwLock::new(
169 NeighborList::new(
170 bin_size, cutoff_length,
171 domain.clone(),
172 details.num_workers,
173 details.rebuild_interval,
174 position_blocks,
175 particle_index,
176 particle_store
177 )?))
178 );
179 },
180 None => { continue; }
181 }
182 }
183 let barriers = simgraph.barriers.iter()
185 .map(|(barrier_id, barrier_def)| (
186 barrier_id,
187 Arc::new(sync::Barrier::new(
189 barrier_def.affected_particles.iter()
190 .map(|particle_id| worker_index_ranges.get(particle_id).unwrap().len())
191 .sum()
192 ))
193 ))
194 .collect::<HashMap<_,_>>();
195 let num_workers = worker_index_ranges.iter().map(|(_,ranges)| ranges.len()).sum();
197 let step_barrier = Arc::new(sync::Barrier::new(num_workers));
198 let call_end_barrier = Arc::new(sync::Barrier::new(num_workers+1));
199 let step_counter = Arc::new(RwLock::new(0));
201 let callback_thread = CallbackThread::new(call_end_barrier.clone(), num_workers, global_context.clone());
203 let executor_context = ExecutorContext {
205 barriers, step_barrier, step_counter,
206 call_end_barrier,
207 neighbor_lists,
208 call_sender: callback_thread.get_sender(),
209 global_context: global_context.clone()
210 };
211 let rng_seeds = match rng_seeds {
213 None => unimplemented!("Explicit RNG seeding required for now"),
214 Some(rng_seeds) => {
215 if rng_seeds.len() >= num_workers {
216 Ok(rng_seeds)
217 }
218 else {
219 Err(anyhow!("Not enough RNG seeds (need {}, got {})", num_workers, rng_seeds.len()))
220 }
221 }
222 }?;
223 let workers = worker_index_ranges.into_iter()
224 .zip(&rng_seeds[0..num_workers])
225 .map(|((particle_id, index_ranges), rng_seed)| {
226 let executor_context = &executor_context;
227 index_ranges.into_iter().map(move |particle_range| {
228 let rng = Xoshiro256PlusPlus::seed_from_u64(*rng_seed);
230 let normal_dist = rand_distr::Normal::new(0.0, 1.0)
231 .expect("Math is broken. All is lost.");
232 let thread_context = ThreadContext {
233 particle_id, particle_range, rng, normal_dist,
234 executor_context: executor_context.clone()
235 };
236 WorkerThread::spawn(thread_context)
238 })
239 })
240 .flatten()
241 .collect::<Vec<_>>();
242 for worker in &workers {
244 worker.wait_for_compilation();
245 }
246 Ok(Self {
247 workers, callback_thread, executor_context
248 })
249 }
250
251 pub(crate) fn run(&self) {
252 for worker in &self.workers {
254 worker.run_step();
255 }
256 }
257
258 pub(crate) fn wait(&self) {
259 for worker in &self.workers {
262 worker.wait();
263 }
264 }
265
266 pub(crate) fn join(self) {
267 for worker in self.workers {
268 worker.join();
269 }
270 self.callback_thread.join();
271 }
272
273 pub(crate) fn register_callback(&mut self, name: &str,
274 callback: CallbackType,
275 callback_state: CallbackStateType)
276 -> Result<()> {
277 let barriers = self.executor_context.global_context.simgraph.callbacks.get(name)
279 .ok_or(anyhow!("No callback named {} found in simulation graph", &name))?
280 .clone();
281 self.callback_thread.register_callback(barriers, callback, callback_state)
282 }
283
284 pub(crate) fn unregister_callback(&mut self, name: &str)
285 -> Result<(CallbackType, CallbackStateType)> {
286 let barriers = self.executor_context.global_context.simgraph.callbacks.get(name)
287 .ok_or(anyhow!("No callback named {} found in simulation graph", &name))?
288 .clone();
289 self.callback_thread.unregister_callback(barriers)
290 }
291
292 pub(crate) fn rebuild_neighbor_list(&mut self, interaction_name: &str) -> Result<()> {
293 let (interaction_id,_) = self.executor_context.global_context.runtime.interaction_index
294 .get_interaction_by_name(interaction_name)
295 .ok_or(anyhow!("Cannot find interaction with name {}", interaction_name))?;
296 self.executor_context.neighbor_lists.get_mut(&interaction_id)
297 .ok_or(anyhow!("There is no neighbor list for interaction with name {}", interaction_name))?
298 .write().unwrap()
299 .rebuild(&self.executor_context.global_context.runtime.particle_index,
300 &self.executor_context.global_context.runtime.particle_store);
301 Ok(())
302 }
303
304 pub fn get_neighbor_lists(&mut self, interaction_name: &str) -> Result<HashMap<(ParticleID, Range<usize>), (Vec<usize>, Vec<usize>)>> {
305 let (interaction_id,_) = self.executor_context.global_context.runtime.interaction_index
307 .get_interaction_by_name(interaction_name)
308 .ok_or(anyhow!("Cannot find interaction with name {}", interaction_name))?;
309 let neighbor_list = self.executor_context.neighbor_lists.get(&interaction_id)
311 .ok_or(anyhow!("There is no neighbor list for interaction with name {}", interaction_name))?
312 .read().unwrap();
313 Ok(neighbor_list.pos_blocks.iter()
315 .zip(neighbor_list.neighbor_lists.iter())
316 .map(|((particle_id, index_range), (neighbor_list_index, neighbor_list))| {
317 ((*particle_id, index_range.start..index_range.end),
318 (neighbor_list_index.clone(), neighbor_list.clone()))
319 })
320 .collect::<HashMap<_,_>>())
321 }
322}
323