fips_md/codegen/analysis/
simgraph.rs

//! Intermediate representation of the simulation code: one timeline per particle type,
//! with timelines connected to each other by synchronization barriers.
3
4use anyhow::{anyhow, Result};
5use slotmap::SlotMap;
6
7use std::{collections::{BTreeSet, HashMap, HashSet}, usize};
8
9use crate::{codegen::util, parser, parser::{Statement}, runtime::{InteractionID, InteractionQuantityID, ParticleID, SimulationBlockKind, SimulationID, SimulationParticleFilter}};
10
11use super::{Barrier, BarrierID, BarrierKind, SymbolTable};
12
13/// A simulation graph consists of a timeline for every type of particle
14/// Timelines can be connected by common synchronization barriers due to mutual
15/// interactions
16pub struct SimulationGraph {
17    pub timelines: HashMap<ParticleID, Timeline>,
18    pub barriers: SlotMap<BarrierID, Barrier>,
19    pub callbacks: HashMap<String, BTreeSet<BarrierID>>
20}
21
22/// A timeline is a chain of simulation nodes for a single particle type
23pub struct Timeline {
24    pub(crate) nodes: Vec<SimulationNode>,
25    pub(crate) particle_symbols: SymbolTable<()>
26}
27
28impl Timeline {
29    pub fn new(nodes: Vec<SimulationNode>, particle_symbols: SymbolTable<()>) -> Self {
30        Self { nodes, particle_symbols }
31    }
32}
33
34/// Internal proto-simgraph created during construction
35struct UnresolvedTimeline {
36    nodes: Vec<MaybeResolvedSimNode>
37}
38
39/// A single simulation node that is considered atomic for synchronization 
40/// purposes
41pub enum SimulationNode {
42    /// A block of statements
43    StatementBlock(StatementBlock),
44    /// A shared barrier between multiple timelines
45    CommonBarrier(BarrierID)
46}
47
48/// Simulation node that might not have been resolved yet
49enum MaybeResolvedSimNode {
50    /// Resolved node
51    Resolved(SimulationNode),
52    /// Unresolved barrier caused by interaction update
53    UnresolvedInteractionBarrier(InteractionID, Option<InteractionQuantityID>),
54    /// Unresolved barrier caused by call to Rust
55    UnresolvedCallBarrier(String)
56}
57
/// Range of simulation steps in canonical `start:end:step` form.
///
/// "Once" blocks are canonized to `step:step+1:1`, so `end` is presumably
/// exclusive (NOTE(review): confirm against the code consuming the range).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StepRange {
    /// First step of the range
    start: usize,
    /// End of the range
    end: usize,
    /// Stride between two consecutive steps of the range
    step: usize,
}
64
65pub struct StatementBlock {
66    pub step_range: StepRange,
67    pub statements: Vec<parser::Statement>,
68    pub local_symbols: SymbolTable<()>
69}
70
71impl SimulationGraph {
72    pub(crate) fn new<'a>(active_particles: HashSet<ParticleID>,
73        simulation_id: SimulationID,
74        indices: util::IndicesRef<'a>
75    ) -> Result<Self> {
76        // Convert active particles to set
77        let active_particles = active_particles.iter()
78            .map(|x| *x ).collect::<HashSet<_>>();
79        // Construct unresolved simulation graph
80        let mut unresolved_graph = active_particles.iter()
81            .map(|particle_id| {
82                let timeline = UnresolvedTimeline::new(*particle_id, simulation_id, indices);
83                timeline.map(|timeline| (*particle_id, timeline))
84            })
85            .collect::<Result<Vec<_>>>()?;
86        // Collect all interactions from the graph
87        let mut used_interactions = HashSet::new();
88        for (_, timeline) in &unresolved_graph {
89            for node in &timeline.nodes {
90                match node {
91                    MaybeResolvedSimNode::UnresolvedInteractionBarrier(interaction_id, _) => {
92                        used_interactions.insert(*interaction_id);
93                    }
94                    MaybeResolvedSimNode::Resolved(_) |
95                    MaybeResolvedSimNode::UnresolvedCallBarrier(_) => {}
96                }
97            }
98        }
99        // Create hashmap of affected particles for every interaction
100        let affected_particles_map = used_interactions.iter()
101            .map(|interaction_id| {
102                // Unwrap is safe due to interaction ids being resolved previously
103                indices.interactions.get(*interaction_id).unwrap()
104                    .get_affected_particles(indices.particles)
105                    // Map successful result to intersection with the list of active particles
106                    .map(|affected_particles| (
107                        *interaction_id,
108                        affected_particles.intersection(&active_particles)
109                            .map(|x| *x).collect::<HashSet<_>>()
110                    ))
111            })
112            .collect::<Result<HashMap<_,_>>>()?;
113        // Initialize barrier and callback maps
114        let mut barriers: SlotMap<BarrierID, Barrier> = SlotMap::new();
115        let mut callbacks: HashMap<String, BTreeSet<BarrierID>> = HashMap::new();
116        let mut timelines: HashMap<ParticleID, Timeline> = HashMap::new();
117        // Resolve all unresolved nodes for every timeline
118        while let Some((particle_id, timeline)) = unresolved_graph.pop() {
119            // Construct particle symbol table
120            let particle_def = indices.particles.get(particle_id).unwrap();
121            let particle_symbols = SymbolTable::from_particle(particle_def);
122            // Resolve nodes
123            let mut resolved_nodes = vec![];
124            for node in timeline.nodes {
125                resolved_nodes.push(match node {
126                    // Node already resolved => Passthrough
127                    MaybeResolvedSimNode::Resolved(node) => node,
128                    // Interaction barrier? Try to find all other barriers
129                    MaybeResolvedSimNode::UnresolvedInteractionBarrier(interaction_id, quantity_id) => {
130                        // Create a new barrier
131                        let affected_particles = affected_particles_map
132                            .get(&interaction_id).unwrap().clone();
133                        let barrier = Barrier::new(affected_particles,
134                            BarrierKind::InteractionBarrier(interaction_id, quantity_id));
135                        // Register barrier
136                        let barrier_id = barriers.insert(barrier);
137                        let affected_particles = &barriers.get(barrier_id)
138                            .unwrap().affected_particles; // "Trivially" safe unwrap
139                        // Resolve corresponding nodes in other timelines
140                        for (other_particle, other_timeline) in unresolved_graph.iter_mut()
141                            // Filter for particles affected by the same interaction
142                            .filter(|(other_particle,_)| affected_particles.contains(other_particle))
143                        {
144                            // If interaction and quantity id match, replace node with this barrier
145                            let mut found = false;
146                            for other_node in &mut other_timeline.nodes {
147                                if let MaybeResolvedSimNode::UnresolvedInteractionBarrier(other_interaction_id, other_quantity_id) = other_node {
148                                    if interaction_id == *other_interaction_id && quantity_id == *other_quantity_id {
149                                        *other_node = MaybeResolvedSimNode::Resolved(SimulationNode::CommonBarrier(barrier_id));
150                                        // Break from loop (we only want to replace the first matching node)
151                                        found = true;
152                                        break
153                                    }
154                                }
155                            }
156                            // Return error if no matching update statement is found for a particle type that interacts with
157                            // the current particle type
158                            if !found {
159                                return Err(anyhow!(
160                                    "No matching update statement for interaction quantity {}:{} found for particle {} (required by interaction with particle {})",
161                                    indices.interactions.get(interaction_id).unwrap().get_name(),
162                                    match quantity_id {
163                                        None => "",
164                                        Some(quantity_id) => indices.interactions.get(interaction_id).unwrap().get_quantity(quantity_id).unwrap().get_name(),
165                                    },
166                                    indices.particles.get(*other_particle).unwrap().get_name(),
167                                    indices.particles.get(particle_id).unwrap().get_name(),
168                                ));
169                            }
170                        }
171                        SimulationNode::CommonBarrier(barrier_id)
172                    }
173                    // Call barrier? Try to find corresponding call barriers in other timelines
174                    // or add one at the end of every other timeline
175                    MaybeResolvedSimNode::UnresolvedCallBarrier(call_name) => {
176                        // Create a new barrier (calls affect all particles)
177                        let barrier = Barrier::new(active_particles.clone(),
178                            BarrierKind::CallBarrier(call_name.clone()));
179                        // Register barrier
180                        let barrier_id = barriers.insert(barrier);
181                        match callbacks.get_mut(&call_name) {
182                            Some(barriers) => {
183                                barriers.insert(barrier_id);
184                            }
185                            None => {
186                                let mut barriers = BTreeSet::new();
187                                barriers.insert(barrier_id);
188                                callbacks.insert(call_name.clone(), barriers);
189                            }
190                        }
191                        // Find barriers in other timelines
192                        for (_, other_timeline) in unresolved_graph.iter_mut() {
193                            let mut found = false;
194                            for other_node in &mut other_timeline.nodes {
195                                if let MaybeResolvedSimNode::UnresolvedCallBarrier(other_call_name) = other_node {
196                                    if *other_call_name == call_name {
197                                        *other_node = MaybeResolvedSimNode::Resolved(SimulationNode::CommonBarrier(barrier_id));
198                                        found = true;
199                                        break
200                                    }
201                                }
202                            }
203                            // If not found, insert new call barrier at the end
204                            if !found {
205                                other_timeline.nodes.push(
206                                    MaybeResolvedSimNode::Resolved(SimulationNode::CommonBarrier(barrier_id))
207                                )
208                            }
209                        }
210                        SimulationNode::CommonBarrier(barrier_id)
211                    }
212                })
213            }
214            // Push resolved nodes to timelines
215            timelines.insert(particle_id, Timeline::new(resolved_nodes, particle_symbols));
216        }
217
218        // for (pid, timeline) in &timelines {
219        //     dbg!(indices.particles.get(*pid).unwrap().get_name());
220        //     for node in &timeline.nodes {
221        //         match node {
222        //             SimulationNode::StatementBlock(_) => {println!("STATEMENT BLOCK")},
223        //             SimulationNode::CommonBarrier(_) => {println!("BARRIER")},
224        //         }
225        //     }
226        // }
227
228        Ok(Self {
229            timelines,
230            barriers,
231            callbacks
232        })
233    }
234}
235
236impl UnresolvedTimeline {
237    pub(crate) fn new<'a>(particle_id: ParticleID,
238        simulation_id: SimulationID,
239        indices: util::IndicesRef<'a>
240    ) -> Result<Self> {
241        // Unwrap is safe due to caller (Runtime::compile) resolving simulation before
242        let simulation = indices.simulations.get_simulation(&simulation_id).unwrap();
243        // Resolve default particle (if any exists)
244        let default_particle_id = simulation.get_default_particle();
245        // Create a timeline for every particle type
246        let mut nodes = vec![];
247        for simblock in simulation.get_blocks() {
248            // Canonize step ranges from step and once blocks
249            let step_range = match &simblock.kind {
250                // Once block: step range is <only step>:<only step + 1>:1
251                SimulationBlockKind::Once(step) => {
252                    let step = util::unwrap_usize_constant(step)?;
253                    StepRange::new(step, step+1, 1)
254                }
255                // Step block: step range is <start or 0>:<end or MAX>:<step>
256                SimulationBlockKind::Step(step_range) => {
257                    let (start,end,step) = (
258                        util::unwrap_usize_constant(&step_range.start)?,
259                        util::unwrap_usize_constant(&step_range.end)?,
260                        util::unwrap_usize_constant(&step_range.step)?
261                    );
262                    StepRange::new(start, end, step)
263                }
264            };
265            // Create new statement block and fill it with statements
266            let mut current_node = vec![];
267            // Quick macro for pushing node to timeline and creating a fresh one
268            let push_node = |current_node: Vec<Statement>, nodes: &mut Vec<MaybeResolvedSimNode>| -> Result<Vec<Statement>> {
269                if !current_node.is_empty() {
270                    nodes.push(MaybeResolvedSimNode::Resolved(
271                        SimulationNode::StatementBlock(
272                            StatementBlock::new(step_range.clone(),current_node)?)));
273                }
274                Ok(vec![])
275            };
276            for statement_block in &simblock.statement_blocks {
277                // Check if subblock affects this particle
278                let affected = match &statement_block.particle_filter {
279                    // Check default particle
280                    SimulationParticleFilter::Default => match default_particle_id {
281                        None => false, // TODO: Should we issue a warning here?
282                        Some(default_particle_id) => particle_id == default_particle_id
283                    },
284                    SimulationParticleFilter::Single(filter_particle_id) => 
285                        particle_id == *filter_particle_id
286                };
287                // Add statements if particle is affected
288                if affected {
289                    for statement in &statement_block.statements {
290                        match &statement {
291                            // Let and assign statements just get added to the current node
292                            Statement::Let(_) | Statement::Assign(_) => { 
293                                current_node.push(statement.clone()) 
294                            }
295                            // Update statements cause a barrier
296                            Statement::Update(update_statement) => {
297                                let interaction_name = &update_statement.interaction;
298                                // Try to resolve interaction (and quantity)
299                                let (interaction_id, interaction) = indices.interactions.get_interaction_by_name(&interaction_name)
300                                    .ok_or(anyhow!("Cannot resolve interaction with name {}", &interaction_name))?;
301                                let quantity_id = match &update_statement.quantity {
302                                    None => None,
303                                    Some(quantity_name) => Some(interaction.get_quantity_by_name(&quantity_name)
304                                        .ok_or(anyhow!("Cannot resolve quantity {} of interaction {}",
305                                            quantity_name, interaction_name))?.0)
306                                };
307                                // Push old statement block and barrier for update
308                                current_node = push_node(current_node, &mut nodes)?;
309                                nodes.push(MaybeResolvedSimNode::UnresolvedInteractionBarrier(
310                                    interaction_id, quantity_id
311                                ))
312                            }
313                            // Call statements as well
314                            Statement::Call(call_statement) => {
315                                // Push old statement block and barrier for call
316                                current_node = push_node(current_node, &mut nodes)?;
317                                nodes.push(MaybeResolvedSimNode::UnresolvedCallBarrier(
318                                    call_statement.name.clone()
319                                ))
320                            }
321                        }
322                    }
323                }
324            }
325            push_node(current_node, &mut nodes)?;
326        }
327        // Return finished timeline
328        Ok(Self { nodes })
329    }
330}
331
332impl StatementBlock {
333    pub fn new(step_range: StepRange, statements: Vec<parser::Statement>) -> Result<Self> {
334        // Extract all locally defined symbols and create symbol table
335        let mut local_symbols = SymbolTable::new();
336        for statement in &statements {
337            if let parser::Statement::Let(statement) = statement {
338                local_symbols.add_local_symbol(statement.name.clone(), statement.typ.clone())?
339            }
340        }
341        Ok(Self {
342            step_range,
343            statements,
344            local_symbols
345        })
346    }
347}
348
349impl StepRange {
350    pub fn new(start: usize, end: usize, step: usize) -> Self {
351        Self {
352            start, end, step
353        }
354    }
355}