duskphantom_middle/analysis/
memory_ssa.rs

1// Copyright 2024 Duskphantom Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// SPDX-License-Identifier: Apache-2.0
16
17use std::collections::{HashMap, HashSet};
18
19use crate::ir::instruction::downcast_ref;
20use crate::{
21    analysis::dominator_tree::DominatorTree,
22    context,
23    ir::{
24        instruction::{misc_inst::Call, InstType},
25        BBPtr, Constant, FunPtr, InstPtr, Operand,
26    },
27    Program,
28};
29use anyhow::{anyhow, Context, Result};
30use duskphantom_utils::mem::{ObjPool, ObjPtr};
31
32use super::{alias_analysis::EffectRange, effect_analysis::EffectAnalysis};
33
34pub type NodePtr = ObjPtr<Node>;
35
36/// MemorySSA analyzer.
37/// Reference: https://llvm.org/docs/MemorySSA.html
38/// My version is different by analyzing the effect of function calls.
39pub struct MemorySSA<'a> {
40    builder: MemorySSABuilder,
41    functions: Vec<FunPtr>,
42    inst_to_node: HashMap<InstPtr, NodePtr>,
43    block_to_node: HashMap<BBPtr, NodePtr>,
44    node_to_block: HashMap<NodePtr, BBPtr>,
45    node_to_user: HashMap<NodePtr, HashSet<NodePtr>>,
46    pub effect_analysis: &'a EffectAnalysis,
47}
48
49impl<'a> MemorySSA<'a> {
50    /// Build MemorySSA for program.
51    pub fn new(program: &Program, effect_analysis: &'a EffectAnalysis) -> Self {
52        let mut memory_ssa = Self {
53            builder: MemorySSABuilder {
54                node_pool: ObjPool::new(),
55                counter: 0,
56            },
57            inst_to_node: HashMap::new(),
58            block_to_node: HashMap::new(),
59            node_to_block: HashMap::new(),
60            node_to_user: HashMap::new(),
61            effect_analysis,
62            functions: program.module.functions.clone(),
63        };
64        for func in program.module.functions.iter() {
65            memory_ssa.run(*func);
66        }
67        memory_ssa
68    }
69
70    /// Get node from instruction.
71    pub fn get_inst_node(&self, inst: InstPtr) -> Option<NodePtr> {
72        self.inst_to_node.get(&inst).cloned()
73    }
74
75    /// Get node from block.
76    pub fn get_block_node(&self, bb: BBPtr) -> Option<NodePtr> {
77        self.block_to_node.get(&bb).cloned()
78    }
79
80    /// Get block from node.
81    pub fn get_node_block(&self, node: NodePtr) -> Option<BBPtr> {
82        self.node_to_block.get(&node).cloned()
83    }
84
85    /// Get all users of a node.
86    pub fn get_user(&self, node: NodePtr) -> HashSet<NodePtr> {
87        self.node_to_user.get(&node).cloned().unwrap_or_default()
88    }
89
90    /// Assume `src` writes memory, predict content read from `dst`.
91    pub fn predict_read(
92        &self,
93        src: NodePtr,
94        dst: InstPtr,
95        func: FunPtr,
96    ) -> Result<Option<Operand>> {
97        let use_range = &self
98            .effect_analysis
99            .inst_effect
100            .get(&dst)
101            .ok_or_else(|| anyhow!("{} effect not found", dst.gen_llvm_ir()))
102            .with_context(|| context!())?
103            .use_range;
104        match *src {
105            Node::Entry(_) => {
106                // In main function if read from entry, load from global variable initializer
107                if func.is_main() {
108                    if let Some(ptr) = use_range.get_single() {
109                        if let Some(op) = readonly_deref(ptr, vec![Some(0)]) {
110                            return Ok(Some(op));
111                        }
112                    }
113                }
114                Ok(None)
115            }
116            Node::Normal(_, _, src, inst) => {
117                let def_range = &self
118                    .effect_analysis
119                    .inst_effect
120                    .get(&inst)
121                    .ok_or_else(|| anyhow!("{} effect not found", inst.gen_llvm_ir()))
122                    .with_context(|| context!())?
123                    .def_range;
124
125                // If range does not alias, recurse into sub-node
126                if !def_range.can_alias(use_range) {
127                    if let Some(src) = src {
128                        return self.predict_read(src, dst, func);
129                    }
130                    return Err(anyhow!("{} is not MemoryDef", inst.gen_llvm_ir()))
131                        .with_context(|| context!());
132                }
133
134                // If MemoryDef is store, return store operand if it's exact hit
135                // TODO check equality with GVN
136                if inst.get_type() == InstType::Store {
137                    if def_range == use_range {
138                        let store_op = inst
139                            .get_operand()
140                            .first()
141                            .ok_or_else(|| anyhow!("{} should be store", inst.gen_llvm_ir()))
142                            .with_context(|| context!())?;
143                        return Ok(Some(store_op.clone()));
144                    }
145                    return Ok(None);
146                }
147
148                // If MemoryDef is memset, replace with constant
149                // (we assume this memset sets 0 and is large enough)
150                if inst.get_type() == InstType::Call {
151                    let call = downcast_ref::<Call>(inst.as_ref().as_ref());
152                    if call.func.name.contains("memset") {
153                        return Ok(Some(Operand::Constant(0.into())));
154                    }
155                }
156                Ok(None)
157            }
158            Node::Phi(_, _, _) => {
159                // TODO phi translation
160                Ok(None)
161            }
162        }
163    }
164
165    /// Dump MemorySSA result to string.
166    pub fn dump(&self) -> String {
167        let mut result = String::new();
168        for func in self.functions.iter() {
169            if func.is_lib() {
170                continue;
171            }
172            result += &format!("MemorySSA for function: {}\n", func.name);
173            for bb in func.dfs_iter() {
174                result += &format!(
175                    "{}:    ; preds = {}\n",
176                    bb.name,
177                    bb.get_pred_bb()
178                        .iter()
179                        .map(|bb| bb.name.clone())
180                        .collect::<Vec<_>>()
181                        .join(", ")
182                );
183                if let Some(node) = self.block_to_node.get(&bb) {
184                    result += &self.dump_node(*node);
185                    result += "\n";
186                }
187                for inst in bb.iter() {
188                    if let Some(node) = self.inst_to_node.get(&inst) {
189                        result += &self.dump_node(*node);
190                        result += "\n";
191                    }
192                    result += &inst.gen_llvm_ir();
193                    result += "\n";
194                }
195                result += "\n";
196            }
197        }
198        result
199    }
200
201    /// Dump a node to string.
202    pub fn dump_node(&self, node: NodePtr) -> String {
203        match node.as_ref() {
204            Node::Entry(id) => format!("; {} (liveOnEntry)", id),
205            Node::Normal(id, use_node, def_node, _) => {
206                let mut result: Vec<String> = Vec::new();
207                if let Some(use_node) = use_node {
208                    result.push(format!("; MemoryUse({})", use_node.get_id()));
209                }
210                if let Some(def_node) = def_node {
211                    result.push(format!("; {} = MemoryDef({})", id, def_node.get_id()));
212                }
213                result.join("\n")
214            }
215            Node::Phi(id, arg, _) => {
216                let mut args: Vec<String> = Vec::new();
217                for (bb, node) in arg {
218                    args.push(format!("[{}, {}]", node.get_id(), bb.name));
219                }
220                format!("; {} = MemoryPhi({})", id, args.join(", "))
221            }
222        }
223    }
224
225    /// Remove a node, update use-def chain.
226    ///
227    /// # Panics
228    /// Do not remove used phi node or entry node with this function!
229    pub fn remove_node(&mut self, node: NodePtr) {
230        let used_nodes = node.get_used_node();
231        for used_node in &used_nodes {
232            self.node_to_user.get_mut(used_node).unwrap().remove(&node);
233        }
234
235        // Wire users of this node to the first used node
236        let users = self.get_user(node);
237        for mut user in users {
238            let cloned_user = user;
239            let user_used_nodes = user.get_used_node_mut();
240            for user_used_node in user_used_nodes {
241                if *user_used_node == node {
242                    let used_node = used_nodes[0];
243                    *user_used_node = used_node;
244                    self.node_to_user
245                        .get_mut(&used_node)
246                        .unwrap()
247                        .insert(cloned_user);
248                }
249            }
250        }
251    }
252
253    /// Build MemorySSA for function.
254    fn run(&mut self, func: FunPtr) {
255        let Some(entry) = func.entry else {
256            return;
257        };
258
259        // Add entry node
260        let mut range_to_node = RangeToNode::new();
261        let entry_node = self.builder.get_entry();
262        self.block_to_node.insert(entry, entry_node);
263        self.node_to_block.insert(entry_node, entry);
264        range_to_node.insert(EffectRange::All, entry_node);
265
266        // Insert empty phi nodes
267        let phi_insertions = self.insert_empty_phi(func);
268
269        // Add other nodes
270        self.add_node_start_from(
271            None,
272            entry,
273            &mut HashSet::new(),
274            &mut range_to_node,
275            &phi_insertions,
276        )
277    }
278
279    /// Add nodes starting from `current_bb`.
280    fn add_node_start_from(
281        &mut self,
282        parent_bb: Option<BBPtr>,
283        current_bb: BBPtr,
284        visited: &mut HashSet<BBPtr>,
285        range_to_node: &mut RangeToNode,
286        phi_insertions: &HashMap<BBPtr, PhiInsertion>,
287    ) {
288        // Add argument for "phi" instruction
289        if let Some(mut phi) = phi_insertions.get(&current_bb).and_then(|p| p.get()) {
290            let value = range_to_node.get_def(phi.get_effect_range()).unwrap();
291            phi.add_phi_arg((parent_bb.unwrap(), value));
292            self.node_to_user.entry(value).or_default().insert(phi);
293            range_to_node.insert(phi.get_effect_range().clone(), phi);
294        }
295
296        // Do not continue if visited
297        // Argument of "phi" instruction need to be added multiple times,
298        // so that part is before this check
299        if visited.contains(&current_bb) {
300            return;
301        }
302        visited.insert(current_bb);
303
304        // Build MemorySSA for each node
305        for inst in current_bb.iter() {
306            if let Some(effect) = self.effect_analysis.inst_effect.get(&inst) {
307                let def_range = effect.def_range.clone();
308                let use_range = effect.use_range.clone();
309                let def_node = range_to_node.get_def(&def_range);
310                let use_node = range_to_node.get_use(&use_range);
311                let new_node = self.create_normal_node(use_node, def_node, inst);
312                range_to_node.insert(def_range, new_node);
313            }
314        }
315
316        // Visit all successors
317        let successors = current_bb.get_succ_bb();
318        for succ in successors {
319            self.add_node_start_from(
320                Some(current_bb),
321                *succ,
322                visited,
323                &mut range_to_node.branch(),
324                phi_insertions,
325            );
326        }
327    }
328
329    /// Create a normal node.
330    fn create_normal_node(
331        &mut self,
332        use_node: Option<NodePtr>,
333        def_node: Option<NodePtr>,
334        inst: InstPtr,
335    ) -> NodePtr {
336        let node = self.builder.get_normal_node(use_node, def_node, inst);
337        self.inst_to_node.insert(inst, node);
338        if let Some(use_node) = use_node {
339            self.node_to_user.entry(use_node).or_default().insert(node);
340        }
341        if let Some(def_node) = def_node {
342            self.node_to_user.entry(def_node).or_default().insert(node);
343        }
344        self.node_to_block
345            .insert(node, inst.get_parent_bb().unwrap());
346        node
347    }
348
349    /// Insert empty "phi" for basic blocks starting from `entry`
350    /// Returns a mapping from basic block to phi insertions
351    #[allow(unused)]
352    fn insert_empty_phi(&mut self, func: FunPtr) -> HashMap<BBPtr, PhiInsertion> {
353        let entry = func.entry.unwrap();
354        let mut phi_insertions: HashMap<BBPtr, PhiInsertion> = HashMap::new();
355        let mut dom_tree = DominatorTree::new(func);
356
357        for bb in func.dfs_iter() {
358            for inst in bb.iter() {
359                if let Some(effect) = self.effect_analysis.inst_effect.get(&inst) {
360                    // Only insert phi for stores
361                    if effect.def_range.is_empty() {
362                        continue;
363                    }
364
365                    // Insert phi with DFS on dominance frontier tree
366                    let mut visited = HashSet::new();
367                    let mut positions: Vec<(BBPtr, EffectRange)> = Vec::new();
368                    positions.push((bb, effect.def_range.clone()));
369                    while let Some((position, range)) = positions.pop() {
370                        if visited.contains(&position) {
371                            continue;
372                        }
373                        visited.insert(position);
374                        let df = dom_tree.get_df(position);
375
376                        // Insert phi for each dominance frontier, update effect range
377                        for bb in df {
378                            let phi = self.builder.get_phi(range.clone());
379                            let phi = phi_insertions.entry(bb).or_default().insert(phi);
380                            self.block_to_node.insert(bb, phi);
381                            self.node_to_block.insert(phi, bb);
382                            positions.push((bb, phi.get_effect_range().clone()));
383                        }
384                    }
385                }
386            }
387        }
388
389        // Return result
390        phi_insertions
391    }
392}
393
394/// Memory pool for MemorySSA nodes.
395struct MemorySSABuilder {
396    node_pool: ObjPool<Node>,
397    counter: usize,
398}
399
400impl MemorySSABuilder {
401    /// Allocate a new node.
402    fn new_node(&mut self, node: Node) -> NodePtr {
403        self.node_pool.alloc(node)
404    }
405
406    /// Returns a unique ID.
407    fn next_counter(&mut self) -> usize {
408        let counter = self.counter;
409        self.counter += 1;
410        counter
411    }
412
413    /// Get an entry node.
414    fn get_entry(&mut self) -> NodePtr {
415        let next_counter = self.next_counter();
416        self.new_node(Node::Entry(next_counter))
417    }
418
419    /// Get a normal node.
420    fn get_normal_node(
421        &mut self,
422        use_node: Option<NodePtr>,
423        def_node: Option<NodePtr>,
424        inst: InstPtr,
425    ) -> NodePtr {
426        let next_counter = self.next_counter();
427        self.new_node(Node::Normal(next_counter, use_node, def_node, inst))
428    }
429
430    /// Get a phi node.
431    fn get_phi(&mut self, range: EffectRange) -> NodePtr {
432        let next_counter = self.next_counter();
433        self.new_node(Node::Phi(next_counter, Vec::new(), range))
434    }
435}
436
437/// Memory SSA node.
438/// Function in Node does not maintain use-def chain.
439pub enum Node {
440    /// Entry(id) represents the memory state at the beginning of the function.
441    Entry(usize),
442
443    /// Normal(id, use_node, def_node, inst) represents a memory state after an instruction.
444    Normal(usize, Option<NodePtr>, Option<NodePtr>, InstPtr),
445
446    /// Phi(id, args, range) represents a phi node.
447    Phi(usize, Vec<(BBPtr, NodePtr)>, EffectRange),
448}
449
450impl Node {
451    /// Get instruction if it's a normal node.
452    pub fn get_inst(&self) -> Option<InstPtr> {
453        match self {
454            Node::Normal(_, _, _, inst) => Some(*inst),
455            _ => None,
456        }
457    }
458
459    /// Get ID of the node.
460    pub fn get_id(&self) -> usize {
461        match self {
462            Node::Entry(id) => *id,
463            Node::Normal(id, _, _, _) => *id,
464            Node::Phi(id, _, _) => *id,
465        }
466    }
467
468    /// Get use node.
469    /// Use node is the node that is read from.
470    pub fn get_use_node(&self) -> NodePtr {
471        match self {
472            Node::Normal(_, use_node, _, _) => use_node.unwrap(),
473            _ => panic!("not a normal node"),
474        }
475    }
476
477    /// Get used nodes.
478    /// Used nodes contains both use and def nodes.
479    pub fn get_used_node(&self) -> Vec<NodePtr> {
480        match self {
481            Node::Normal(_, use_node, def_node, _) => {
482                let mut result = Vec::new();
483                if let Some(node) = use_node {
484                    result.push(*node);
485                }
486                if let Some(node) = def_node {
487                    result.push(*node);
488                }
489                result
490            }
491            Node::Phi(_, args, _) => args.iter().map(|(_, node)| *node).collect(),
492            _ => Vec::new(),
493        }
494    }
495
496    /// Get mutable used nodes.
497    /// Used nodes contains both use and def nodes.
498    pub fn get_used_node_mut(&mut self) -> Vec<&mut NodePtr> {
499        match self {
500            Node::Normal(_, use_node, def_node, _) => {
501                let mut result = Vec::new();
502                if let Some(node) = use_node {
503                    result.push(node);
504                }
505                if let Some(node) = def_node {
506                    result.push(node);
507                }
508                result
509            }
510            Node::Phi(_, args, _) => args.iter_mut().map(|(_, node)| node).collect(),
511            _ => Vec::new(),
512        }
513    }
514
515    /// Add an argument to a phi node.
516    fn add_phi_arg(&mut self, arg: (BBPtr, NodePtr)) {
517        match self {
518            Node::Phi(_, args, _) => args.push(arg),
519            _ => panic!("not a phi node"),
520        }
521    }
522
523    /// Get effect range of a phi node.
524    fn get_effect_range(&self) -> &EffectRange {
525        match self {
526            Node::Phi(_, _, range) => range,
527            _ => panic!("not a phi node"),
528        }
529    }
530
531    /// Merge effect range of a phi node.
532    fn merge_effect_range(&mut self, another: &EffectRange) {
533        match self {
534            Node::Phi(_, _, range) => range.merge(another),
535            _ => panic!("not a phi node"),
536        }
537    }
538}
539
540/// Phi insertion for a block. (Some(Node) or None)
541pub struct PhiInsertion(Option<NodePtr>);
542
543impl PhiInsertion {
544    /// Initialize an empty phi insertion.
545    pub fn new() -> Self {
546        Self(None)
547    }
548
549    /// Insert an empty phi node.
550    /// Returns the inserted or merged phi node.
551    pub fn insert(&mut self, phi: NodePtr) -> NodePtr {
552        if let Some(node) = self.0.as_mut() {
553            node.merge_effect_range(phi.get_effect_range());
554            return *node;
555        }
556        self.0 = Some(phi);
557        phi
558    }
559
560    /// Get containing phi node.
561    pub fn get(&self) -> Option<NodePtr> {
562        self.0
563    }
564}
565
566impl Default for PhiInsertion {
567    fn default() -> Self {
568        Self::new()
569    }
570}
571
572/// Framed mapping from range to node.
573pub enum RangeToNode<'a> {
574    Root(RangeToNodeFrame),
575    Leaf(RangeToNodeFrame, &'a RangeToNode<'a>),
576}
577
578impl Default for RangeToNode<'_> {
579    fn default() -> Self {
580        Self::Root(RangeToNodeFrame::default())
581    }
582}
583
584impl<'a> RangeToNode<'a> {
585    /// Create a new FrameMap.
586    pub fn new() -> Self {
587        Self::default()
588    }
589
590    /// Get the last frame.
591    pub fn last_frame(&mut self) -> &mut RangeToNodeFrame {
592        match self {
593            Self::Root(map) => map,
594            Self::Leaf(map, _) => map,
595        }
596    }
597
598    /// Insert a new element into the last frame.
599    pub fn insert(&mut self, k: EffectRange, v: NodePtr) {
600        if k.is_empty() {
601            return;
602        }
603        self.last_frame().insert(k, v);
604    }
605
606    /// Get an element from all frames with def range.
607    /// Returns the element.
608    pub fn get_def(&self, def_range: &EffectRange) -> Option<NodePtr> {
609        if def_range.is_empty() {
610            return None;
611        }
612        let mut map = self;
613        loop {
614            match map {
615                Self::Root(m) => return m.get_def(),
616                Self::Leaf(m, parent) => {
617                    if let Some(v) = m.get_def() {
618                        return Some(v);
619                    }
620                    map = parent;
621                }
622            }
623        }
624    }
625
626    /// Get an element from all frames with use range.
627    pub fn get_use(&self, use_range: &EffectRange) -> Option<NodePtr> {
628        if use_range.is_empty() {
629            return None;
630        }
631        let mut map = self;
632        loop {
633            match map {
634                Self::Root(m) => return m.get_use(use_range),
635                Self::Leaf(m, parent) => {
636                    if let Some(v) = m.get_use(use_range) {
637                        return Some(v);
638                    }
639                    map = parent;
640                }
641            }
642        }
643    }
644
645    /// Make a branch on the frame map.
646    /// Modifications on the new branch will not affect the original one.
647    /// This is useful when implementing scopes.
648    pub fn branch(&'a self) -> Self {
649        Self::Leaf(RangeToNodeFrame::default(), self)
650    }
651}
652
653/// One frame of range to node mapping.
654#[derive(Default)]
655pub struct RangeToNodeFrame(Vec<(EffectRange, NodePtr)>);
656
657impl RangeToNodeFrame {
658    pub fn insert(&mut self, k: EffectRange, v: NodePtr) {
659        self.0.push((k, v));
660    }
661
662    /// Get the last definition.
663    pub fn get_def(&self) -> Option<NodePtr> {
664        self.0.last().map(|(_, n)| *n)
665    }
666
667    /// Get an element from the frame with given use range.
668    pub fn get_use(&self, use_range: &EffectRange) -> Option<NodePtr> {
669        for (key, value) in self.0.iter().rev() {
670            if key.can_alias(use_range) {
671                return Some(*value);
672            }
673        }
674        None
675    }
676}
677
678/// Deref readonly array operand with maybe-constant indices.
679fn readonly_deref(op: &Operand, mut index: Vec<Option<i32>>) -> Option<Operand> {
680    match op {
681        Operand::Global(gvar) => {
682            let mut val = &gvar.as_ref().initializer;
683
684            // For a[0][1][2], it translates to something like `gep (gep a, 0, 0), 0, 1, 2`
685            // Calling `readonly_deref` on it will first push `2, 1` to index array, and then push `0 + 0, 0` (reversed!)
686            // To get the final value, we iterate the index array in reverse order
687            // The last index is skipped because it moves the base pointer, that's not valid for global var
688            for i in index.iter().rev().skip(1) {
689                if let Constant::Array(arr) = val {
690                    if let Some(i) = i {
691                        if let Some(element) = arr.get(*i as usize) {
692                            val = element;
693                        } else {
694                            return None;
695                        }
696                    } else {
697                        return None;
698                    }
699                } else if let Constant::Zero(_) = val {
700                    return Some(Operand::Constant(0.into()));
701                } else {
702                    return None;
703                }
704            }
705            Some(val.clone().into())
706        }
707        Operand::Instruction(inst) => {
708            if inst.get_type() == InstType::GetElementPtr {
709                // For example, if we have `index = [out_ix1, out_ix0]`, and `inst = gep %ptr, ix0, ix1, ix2`
710                // We would try to combine `ix2` with `out_ix0` first, so that index is `[out_ix1, out_ix0 + ix2]`
711                // The index array is in reverse order, and will be read reversed when indexing array
712                let last_const = if let (Some(Some(i)), Some(Operand::Constant(Constant::Int(j)))) =
713                    (index.last_mut(), inst.get_operand().last())
714                {
715                    *i += *j;
716                    true
717                } else {
718                    false
719                };
720
721                // If there is non-constant in last index, set it to None
722                if !last_const {
723                    if let Some(i) = index.last_mut() {
724                        *i = None;
725                    }
726                }
727
728                // Add `ix1, ix0` to index array, so that it becomes `[out_ix1, out_ix0 + ix2, ix1, ix0]`
729                index.extend(
730                    inst.get_operand()
731                        .iter()
732                        .skip(1)
733                        .rev()
734                        .skip(1)
735                        .cloned()
736                        .map(|op| {
737                            if let Operand::Constant(Constant::Int(i)) = op {
738                                Some(i)
739                            } else {
740                                None
741                            }
742                        }),
743                );
744
745                // Recurse into sub-expression
746                return readonly_deref(inst.get_operand().first().unwrap(), index);
747            }
748            None
749        }
750        _ => None,
751    }
752}