fips_md/codegen/analysis/
simgraph.rs1use anyhow::{anyhow, Result};
5use slotmap::SlotMap;
6
7use std::{collections::{BTreeSet, HashMap, HashSet}, usize};
8
9use crate::{codegen::util, parser, parser::{Statement}, runtime::{InteractionID, InteractionQuantityID, ParticleID, SimulationBlockKind, SimulationID, SimulationParticleFilter}};
10
11use super::{Barrier, BarrierID, BarrierKind, SymbolTable};
12
13pub struct SimulationGraph {
17 pub timelines: HashMap<ParticleID, Timeline>,
18 pub barriers: SlotMap<BarrierID, Barrier>,
19 pub callbacks: HashMap<String, BTreeSet<BarrierID>>
20}
21
22pub struct Timeline {
24 pub(crate) nodes: Vec<SimulationNode>,
25 pub(crate) particle_symbols: SymbolTable<()>
26}
27
28impl Timeline {
29 pub fn new(nodes: Vec<SimulationNode>, particle_symbols: SymbolTable<()>) -> Self {
30 Self { nodes, particle_symbols }
31 }
32}
33
34struct UnresolvedTimeline {
36 nodes: Vec<MaybeResolvedSimNode>
37}
38
39pub enum SimulationNode {
42 StatementBlock(StatementBlock),
44 CommonBarrier(BarrierID)
46}
47
48enum MaybeResolvedSimNode {
50 Resolved(SimulationNode),
52 UnresolvedInteractionBarrier(InteractionID, Option<InteractionQuantityID>),
54 UnresolvedCallBarrier(String)
56}
57
/// Range of simulation steps `start..end` (end exclusive) advancing by `step`.
///
/// A block that runs only once at step `s` is represented as `s..s+1` with
/// step `1`.
///
/// Only plain `usize` data, so it is cheap to copy, compare and print.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct StepRange {
    /// First step (inclusive).
    start: usize,
    /// Last step (exclusive).
    end: usize,
    /// Increment between executed steps.
    step: usize,
}
64
65pub struct StatementBlock {
66 pub step_range: StepRange,
67 pub statements: Vec<parser::Statement>,
68 pub local_symbols: SymbolTable<()>
69}
70
71impl SimulationGraph {
72 pub(crate) fn new<'a>(active_particles: HashSet<ParticleID>,
73 simulation_id: SimulationID,
74 indices: util::IndicesRef<'a>
75 ) -> Result<Self> {
76 let active_particles = active_particles.iter()
78 .map(|x| *x ).collect::<HashSet<_>>();
79 let mut unresolved_graph = active_particles.iter()
81 .map(|particle_id| {
82 let timeline = UnresolvedTimeline::new(*particle_id, simulation_id, indices);
83 timeline.map(|timeline| (*particle_id, timeline))
84 })
85 .collect::<Result<Vec<_>>>()?;
86 let mut used_interactions = HashSet::new();
88 for (_, timeline) in &unresolved_graph {
89 for node in &timeline.nodes {
90 match node {
91 MaybeResolvedSimNode::UnresolvedInteractionBarrier(interaction_id, _) => {
92 used_interactions.insert(*interaction_id);
93 }
94 MaybeResolvedSimNode::Resolved(_) |
95 MaybeResolvedSimNode::UnresolvedCallBarrier(_) => {}
96 }
97 }
98 }
99 let affected_particles_map = used_interactions.iter()
101 .map(|interaction_id| {
102 indices.interactions.get(*interaction_id).unwrap()
104 .get_affected_particles(indices.particles)
105 .map(|affected_particles| (
107 *interaction_id,
108 affected_particles.intersection(&active_particles)
109 .map(|x| *x).collect::<HashSet<_>>()
110 ))
111 })
112 .collect::<Result<HashMap<_,_>>>()?;
113 let mut barriers: SlotMap<BarrierID, Barrier> = SlotMap::new();
115 let mut callbacks: HashMap<String, BTreeSet<BarrierID>> = HashMap::new();
116 let mut timelines: HashMap<ParticleID, Timeline> = HashMap::new();
117 while let Some((particle_id, timeline)) = unresolved_graph.pop() {
119 let particle_def = indices.particles.get(particle_id).unwrap();
121 let particle_symbols = SymbolTable::from_particle(particle_def);
122 let mut resolved_nodes = vec![];
124 for node in timeline.nodes {
125 resolved_nodes.push(match node {
126 MaybeResolvedSimNode::Resolved(node) => node,
128 MaybeResolvedSimNode::UnresolvedInteractionBarrier(interaction_id, quantity_id) => {
130 let affected_particles = affected_particles_map
132 .get(&interaction_id).unwrap().clone();
133 let barrier = Barrier::new(affected_particles,
134 BarrierKind::InteractionBarrier(interaction_id, quantity_id));
135 let barrier_id = barriers.insert(barrier);
137 let affected_particles = &barriers.get(barrier_id)
138 .unwrap().affected_particles; for (other_particle, other_timeline) in unresolved_graph.iter_mut()
141 .filter(|(other_particle,_)| affected_particles.contains(other_particle))
143 {
144 let mut found = false;
146 for other_node in &mut other_timeline.nodes {
147 if let MaybeResolvedSimNode::UnresolvedInteractionBarrier(other_interaction_id, other_quantity_id) = other_node {
148 if interaction_id == *other_interaction_id && quantity_id == *other_quantity_id {
149 *other_node = MaybeResolvedSimNode::Resolved(SimulationNode::CommonBarrier(barrier_id));
150 found = true;
152 break
153 }
154 }
155 }
156 if !found {
159 return Err(anyhow!(
160 "No matching update statement for interaction quantity {}:{} found for particle {} (required by interaction with particle {})",
161 indices.interactions.get(interaction_id).unwrap().get_name(),
162 match quantity_id {
163 None => "",
164 Some(quantity_id) => indices.interactions.get(interaction_id).unwrap().get_quantity(quantity_id).unwrap().get_name(),
165 },
166 indices.particles.get(*other_particle).unwrap().get_name(),
167 indices.particles.get(particle_id).unwrap().get_name(),
168 ));
169 }
170 }
171 SimulationNode::CommonBarrier(barrier_id)
172 }
173 MaybeResolvedSimNode::UnresolvedCallBarrier(call_name) => {
176 let barrier = Barrier::new(active_particles.clone(),
178 BarrierKind::CallBarrier(call_name.clone()));
179 let barrier_id = barriers.insert(barrier);
181 match callbacks.get_mut(&call_name) {
182 Some(barriers) => {
183 barriers.insert(barrier_id);
184 }
185 None => {
186 let mut barriers = BTreeSet::new();
187 barriers.insert(barrier_id);
188 callbacks.insert(call_name.clone(), barriers);
189 }
190 }
191 for (_, other_timeline) in unresolved_graph.iter_mut() {
193 let mut found = false;
194 for other_node in &mut other_timeline.nodes {
195 if let MaybeResolvedSimNode::UnresolvedCallBarrier(other_call_name) = other_node {
196 if *other_call_name == call_name {
197 *other_node = MaybeResolvedSimNode::Resolved(SimulationNode::CommonBarrier(barrier_id));
198 found = true;
199 break
200 }
201 }
202 }
203 if !found {
205 other_timeline.nodes.push(
206 MaybeResolvedSimNode::Resolved(SimulationNode::CommonBarrier(barrier_id))
207 )
208 }
209 }
210 SimulationNode::CommonBarrier(barrier_id)
211 }
212 })
213 }
214 timelines.insert(particle_id, Timeline::new(resolved_nodes, particle_symbols));
216 }
217
218 Ok(Self {
229 timelines,
230 barriers,
231 callbacks
232 })
233 }
234}
235
236impl UnresolvedTimeline {
237 pub(crate) fn new<'a>(particle_id: ParticleID,
238 simulation_id: SimulationID,
239 indices: util::IndicesRef<'a>
240 ) -> Result<Self> {
241 let simulation = indices.simulations.get_simulation(&simulation_id).unwrap();
243 let default_particle_id = simulation.get_default_particle();
245 let mut nodes = vec![];
247 for simblock in simulation.get_blocks() {
248 let step_range = match &simblock.kind {
250 SimulationBlockKind::Once(step) => {
252 let step = util::unwrap_usize_constant(step)?;
253 StepRange::new(step, step+1, 1)
254 }
255 SimulationBlockKind::Step(step_range) => {
257 let (start,end,step) = (
258 util::unwrap_usize_constant(&step_range.start)?,
259 util::unwrap_usize_constant(&step_range.end)?,
260 util::unwrap_usize_constant(&step_range.step)?
261 );
262 StepRange::new(start, end, step)
263 }
264 };
265 let mut current_node = vec![];
267 let push_node = |current_node: Vec<Statement>, nodes: &mut Vec<MaybeResolvedSimNode>| -> Result<Vec<Statement>> {
269 if !current_node.is_empty() {
270 nodes.push(MaybeResolvedSimNode::Resolved(
271 SimulationNode::StatementBlock(
272 StatementBlock::new(step_range.clone(),current_node)?)));
273 }
274 Ok(vec![])
275 };
276 for statement_block in &simblock.statement_blocks {
277 let affected = match &statement_block.particle_filter {
279 SimulationParticleFilter::Default => match default_particle_id {
281 None => false, Some(default_particle_id) => particle_id == default_particle_id
283 },
284 SimulationParticleFilter::Single(filter_particle_id) =>
285 particle_id == *filter_particle_id
286 };
287 if affected {
289 for statement in &statement_block.statements {
290 match &statement {
291 Statement::Let(_) | Statement::Assign(_) => {
293 current_node.push(statement.clone())
294 }
295 Statement::Update(update_statement) => {
297 let interaction_name = &update_statement.interaction;
298 let (interaction_id, interaction) = indices.interactions.get_interaction_by_name(&interaction_name)
300 .ok_or(anyhow!("Cannot resolve interaction with name {}", &interaction_name))?;
301 let quantity_id = match &update_statement.quantity {
302 None => None,
303 Some(quantity_name) => Some(interaction.get_quantity_by_name(&quantity_name)
304 .ok_or(anyhow!("Cannot resolve quantity {} of interaction {}",
305 quantity_name, interaction_name))?.0)
306 };
307 current_node = push_node(current_node, &mut nodes)?;
309 nodes.push(MaybeResolvedSimNode::UnresolvedInteractionBarrier(
310 interaction_id, quantity_id
311 ))
312 }
313 Statement::Call(call_statement) => {
315 current_node = push_node(current_node, &mut nodes)?;
317 nodes.push(MaybeResolvedSimNode::UnresolvedCallBarrier(
318 call_statement.name.clone()
319 ))
320 }
321 }
322 }
323 }
324 }
325 push_node(current_node, &mut nodes)?;
326 }
327 Ok(Self { nodes })
329 }
330}
331
332impl StatementBlock {
333 pub fn new(step_range: StepRange, statements: Vec<parser::Statement>) -> Result<Self> {
334 let mut local_symbols = SymbolTable::new();
336 for statement in &statements {
337 if let parser::Statement::Let(statement) = statement {
338 local_symbols.add_local_symbol(statement.name.clone(), statement.typ.clone())?
339 }
340 }
341 Ok(Self {
342 step_range,
343 statements,
344 local_symbols
345 })
346 }
347}
348
349impl StepRange {
350 pub fn new(start: usize, end: usize, step: usize) -> Self {
351 Self {
352 start, end, step
353 }
354 }
355}