sparta/
fixpoint_iter.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8use std::borrow::Cow;
9use std::collections::HashMap;
10use std::collections::HashSet;
11use std::collections::VecDeque;
12use std::sync::atomic::AtomicU32;
13use std::sync::atomic::Ordering;
14
15use crate::datatype::AbstractDomain;
16use crate::graph::Graph;
17use crate::graph::SuccessorNodes;
18use crate::wpo::WeakPartialOrdering;
19use crate::wpo::WpoIdx;
20
21// Unlike the C++ version, we don't treat this as a base interface
22// for all iterators, instead, it's only a trait for concrete type
23// to analyze nodes and edges. And this can be composited in the
24// fixpoint iterator type. We will revisit this part in the following
25// diff for analyzer interface.
26pub trait FixpointIteratorTransformer<G: Graph, D: AbstractDomain> {
27    /// The *current_state* could be updated in place.
28    fn analyze_node(&mut self, n: G::NodeId, current_state: &mut D);
29
30    fn analyze_edge(&mut self, e: G::EdgeId, exit_state_at_src: &D) -> D;
31}
32
33pub struct MonotonicFixpointIteratorContext<G: Graph, D: AbstractDomain> {
34    init_value: D,
35    local_iterations: HashMap<G::NodeId, u32>,
36    global_iterations: HashMap<G::NodeId, u32>,
37}
38
39impl<G, D> MonotonicFixpointIteratorContext<G, D>
40where
41    G: Graph,
42    D: AbstractDomain,
43{
44    pub fn get_local_iterations_for(&self, n: G::NodeId) -> u32 {
45        *self.local_iterations.get(&n).unwrap_or(&0)
46    }
47
48    pub fn get_global_iterations_for(&self, n: G::NodeId) -> u32 {
49        *self.global_iterations.get(&n).unwrap_or(&0)
50    }
51
52    pub fn get_init_value(&self) -> &D {
53        &self.init_value
54    }
55
56    fn increase_iteration_count(n: G::NodeId, table: &mut HashMap<G::NodeId, u32>) {
57        *table.entry(n).or_default() += 1;
58    }
59
60    pub fn increase_iteration_count_for(&mut self, n: G::NodeId) {
61        Self::increase_iteration_count(n, &mut self.local_iterations);
62        Self::increase_iteration_count(n, &mut self.global_iterations);
63    }
64
65    pub fn reset_local_iteration_count_for(&mut self, n: G::NodeId) {
66        *self.local_iterations.entry(n).or_default() = 0;
67    }
68
69    pub fn new(init_value: D) -> Self {
70        Self {
71            init_value,
72            local_iterations: Default::default(),
73            global_iterations: Default::default(),
74        }
75    }
76
77    pub fn with_nodes(mut self, nodes: &HashSet<G::NodeId>) -> Self {
78        for &node in nodes {
79            *self.global_iterations.entry(node).or_default() = 0;
80            *self.local_iterations.entry(node).or_default() = 0;
81        }
82        self
83    }
84}
85
86pub struct MonotonicFixpointIterator<
87    'g,
88    G: Graph,
89    D: AbstractDomain,
90    T: FixpointIteratorTransformer<G, D>,
91> {
92    graph: &'g G,
93    entry_states: HashMap<G::NodeId, D>,
94    exit_states: HashMap<G::NodeId, D>,
95    transformer: T,
96    wpo: WeakPartialOrdering<G::NodeId>,
97}
98
99impl<'g, G, D, T> MonotonicFixpointIterator<'g, G, D, T>
100where
101    G: Graph,
102    D: AbstractDomain,
103    T: FixpointIteratorTransformer<G, D>,
104{
105    pub fn new<SN>(g: &'g G, cfg_size_hint: usize, transformer: T, successors_nodes: &SN) -> Self
106    where
107        SN: SuccessorNodes<NodeId = G::NodeId>,
108    {
109        let wpo = WeakPartialOrdering::new(g.entry(), g.size(), successors_nodes);
110        Self {
111            graph: g,
112            entry_states: HashMap::with_capacity(cfg_size_hint),
113            exit_states: HashMap::with_capacity(cfg_size_hint),
114            transformer,
115            wpo,
116        }
117    }
118
119    pub fn run(&mut self, init_value: D) {
120        self.clear();
121
122        let mut context = MonotonicFixpointIteratorContext::new(init_value);
123        let wpo_counter: Vec<AtomicU32> =
124            (0..self.wpo.size()).map(|_| Default::default()).collect();
125
126        let mut worklist = VecDeque::new();
127        let entry_idx = self.wpo.get_entry();
128        worklist.push_front(entry_idx);
129        assert_eq!(self.wpo.get_num_preds(entry_idx), 0);
130
131        let mut process_node = |wpo_idx: WpoIdx, worklist: &mut VecDeque<WpoIdx>| {
132            assert_eq!(
133                wpo_counter[wpo_idx as usize].load(Ordering::Relaxed),
134                self.wpo.get_num_preds(wpo_idx)
135            );
136
137            wpo_counter[wpo_idx as usize].store(0, Ordering::Relaxed);
138
139            if !self.wpo.is_exit(wpo_idx) {
140                self.analyze_vertex(&context, self.wpo.get_node(wpo_idx));
141
142                for &succ_idx in self.wpo.get_successors(wpo_idx) {
143                    let old_counter =
144                        wpo_counter[succ_idx as usize].fetch_add(1, Ordering::Relaxed);
145                    if old_counter + 1 == self.wpo.get_num_preds(succ_idx) {
146                        worklist.push_back(succ_idx);
147                    }
148                }
149
150                return;
151            }
152
153            let head_idx = self.wpo.get_head_of_exit(wpo_idx);
154            let head = self.wpo.get_node(head_idx);
155            let current_state = self.entry_states.entry(head).or_insert_with(D::bottom);
156            let mut new_state = D::bottom();
157            Self::compute_entry_state(
158                self.graph,
159                &self.exit_states,
160                &mut self.transformer,
161                &context,
162                head,
163                &mut new_state,
164            );
165
166            if new_state.leq(current_state) {
167                context.reset_local_iteration_count_for(head);
168                *current_state = new_state;
169
170                for &succ_idx in self.wpo.get_successors(wpo_idx) {
171                    let old_counter =
172                        wpo_counter[succ_idx as usize].fetch_add(1, Ordering::Relaxed);
173                    if old_counter + 1 == self.wpo.get_num_preds(succ_idx) {
174                        worklist.push_back(succ_idx);
175                    }
176                }
177            } else {
178                Self::extrapolate(&context, head, current_state, new_state);
179                context.increase_iteration_count_for(head);
180                for (&component_idx, &num) in self.wpo.get_num_outer_preds(wpo_idx) {
181                    assert!(component_idx != entry_idx);
182                    let old_counter =
183                        wpo_counter[component_idx as usize].fetch_add(num, Ordering::Relaxed);
184                    if old_counter + num == self.wpo.get_num_preds(component_idx) {
185                        worklist.push_back(component_idx);
186                    }
187                }
188
189                if head_idx == entry_idx {
190                    worklist.push_back(head_idx);
191                }
192            }
193        };
194
195        while let Some(idx) = worklist.pop_front() {
196            process_node(idx, &mut worklist);
197        }
198
199        for counter in wpo_counter {
200            assert_eq!(counter.load(Ordering::Relaxed), 0);
201        }
202    }
203
204    /// Default strategy for applying widening operator (apply
205    /// join at the first iteration and then widening in all
206    /// rest iterations).
207    pub fn extrapolate(
208        context: &MonotonicFixpointIteratorContext<G, D>,
209        n: G::NodeId,
210        current_state: &mut D,
211        new_state: D,
212    ) {
213        if 0 == context.get_global_iterations_for(n) {
214            // TODO: we need to revisit this design, should we use
215            // move or clone of domain?
216            current_state.join_with(new_state);
217        } else {
218            current_state.widen_with(new_state);
219        }
220    }
221
222    fn get_state_at_or_bottom(states: &HashMap<G::NodeId, D>, n: G::NodeId) -> Cow<'_, D> {
223        if let Some(state) = states.get(&n) {
224            Cow::Borrowed(state)
225        } else {
226            Cow::Owned(D::bottom())
227        }
228    }
229
230    pub fn get_entry_state_at(&self, n: G::NodeId) -> Cow<'_, D> {
231        Self::get_state_at_or_bottom(&self.entry_states, n)
232    }
233
234    pub fn get_exit_state_at(&self, n: G::NodeId) -> Cow<'_, D> {
235        Self::get_state_at_or_bottom(&self.exit_states, n)
236    }
237
238    pub fn clear(&mut self) {
239        self.entry_states.clear();
240        self.entry_states.shrink_to_fit();
241        self.exit_states.clear();
242        self.exit_states.shrink_to_fit();
243    }
244
245    pub fn set_all_to_bottom(&mut self, all_nodes: &HashSet<G::NodeId>) {
246        for &node in all_nodes {
247            self.entry_states
248                .entry(node)
249                .and_modify(|s| *s = D::bottom())
250                .or_insert_with(D::bottom);
251            self.exit_states
252                .entry(node)
253                .and_modify(|s| *s = D::bottom())
254                .or_insert_with(D::bottom);
255        }
256    }
257
258    pub fn compute_entry_state(
259        graph: &'g G,
260        exit_states: &HashMap<G::NodeId, D>,
261        transformer: &mut T,
262        context: &MonotonicFixpointIteratorContext<G, D>,
263        n: G::NodeId,
264        entry_state: &mut D,
265    ) {
266        if n == graph.entry() {
267            entry_state.join_with(context.get_init_value().clone());
268        }
269
270        for e in graph.predecessors(n) {
271            let d = Self::get_state_at_or_bottom(exit_states, graph.source(e));
272            entry_state.join_with(transformer.analyze_edge(e, &d));
273        }
274    }
275
276    pub fn analyze_vertex(
277        &mut self,
278        context: &MonotonicFixpointIteratorContext<G, D>,
279        n: G::NodeId,
280    ) {
281        let entry_state = self.entry_states.entry(n).or_insert_with(D::bottom);
282        Self::compute_entry_state(
283            self.graph,
284            &self.exit_states,
285            &mut self.transformer,
286            context,
287            n,
288            entry_state,
289        );
290        let exit_state = self
291            .exit_states
292            .entry(n)
293            .and_modify(|s| *s = entry_state.clone())
294            .or_insert_with(|| entry_state.clone());
295        self.transformer.analyze_node(n, exit_state);
296    }
297}