Skip to main content

synth_cfg/
lib.rs

1//! Control Flow Graph (CFG) for WebAssembly Functions
2//!
3//! This crate provides CFG construction and analysis for WebAssembly functions,
4//! enabling proper branch target resolution and optimization.
5//!
6//! ## Key Concepts
7//!
8//! - **Basic Block**: A maximal sequence of instructions with single entry and exit
9//! - **CFG Edge**: Control flow from one basic block to another
10//! - **Dominance**: Block A dominates B if all paths to B go through A
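//! - **Loop**: A natural loop: a header block that dominates the source of a back edge, plus every block that can reach that back edge without leaving the header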
12//!
13//! ## Usage
14//!
15//! ```ignore
16//! use synth_cfg::{CfgBuilder, BasicBlock};
17//!
18//! let mut builder = CfgBuilder::new();
19//! builder.add_instruction(/* ... */);
20//! let cfg = builder.build();
21//! ```
22
23use std::collections::{HashMap, HashSet, VecDeque};
24
/// A basic block in the CFG
///
/// A block covers the half-open instruction range `start..end` and stores its
/// control-flow edges as lists of [`BlockId`]s in both directions.
#[derive(Debug, Clone)]
pub struct BasicBlock {
    /// Unique ID for this block; `CfgBuilder::build` uses it as the key in `Cfg::blocks`
    pub id: BlockId,

    /// Start index in instruction stream (inclusive)
    pub start: usize,

    /// End index (exclusive) in instruction stream; an empty block has `start == end`
    pub end: usize,

    /// Successor blocks (control flow targets)
    pub successors: Vec<BlockId>,

    /// Predecessor blocks (blocks that jump here)
    pub predecessors: Vec<BlockId>,

    /// Label depth (for structured control flow); `CfgBuilder` initialises it to 0
    pub label_depth: usize,

    /// Whether this block is a loop header; `CfgBuilder` initialises it to `false`
    pub is_loop_header: bool,
}
49
/// Block identifier
///
/// A plain `usize`; `CfgBuilder` hands out IDs sequentially starting at 0 for
/// the entry block.
pub type BlockId = usize;
52
/// Control Flow Graph
///
/// Owns every [`BasicBlock`] keyed by its ID, together with the entry/exit
/// markers and the loops recorded by `Cfg::detect_loops`.
#[derive(Debug, Clone)]
pub struct Cfg {
    /// All basic blocks, keyed by block ID
    pub blocks: HashMap<BlockId, BasicBlock>,

    /// Entry block ID; traversals such as `blocks_rpo` and `reachable_blocks` start here
    pub entry: BlockId,

    /// Exit block ID (if function has explicit return); `CfgBuilder::build` leaves this `None`
    pub exit: Option<BlockId>,

    /// Loop information, filled in by `Cfg::detect_loops`
    pub loops: Vec<Loop>,
}
68
/// Loop information
///
/// Describes one loop found by `Cfg::detect_loops` from a back edge, i.e. an
/// edge whose target dominates its source.
#[derive(Debug, Clone)]
pub struct Loop {
    /// Loop header block (the target of the back edge)
    pub header: BlockId,

    /// Blocks in the loop body; always includes the header itself
    pub body: HashSet<BlockId>,

    /// Loop depth (nested loops have higher depth); an outermost loop has depth 1
    pub depth: usize,
}
81
82impl Cfg {
83    /// Get a basic block by ID
84    pub fn block(&self, id: BlockId) -> Option<&BasicBlock> {
85        self.blocks.get(&id)
86    }
87
88    /// Get a mutable basic block by ID
89    pub fn block_mut(&mut self, id: BlockId) -> Option<&mut BasicBlock> {
90        self.blocks.get_mut(&id)
91    }
92
93    /// Iterate over all blocks in RPO (Reverse Post-Order)
94    pub fn blocks_rpo(&self) -> Vec<BlockId> {
95        let mut visited = HashSet::new();
96        let mut post_order = Vec::new();
97
98        self.dfs_post_order(self.entry, &mut visited, &mut post_order);
99
100        post_order.reverse();
101        post_order
102    }
103
104    fn dfs_post_order(
105        &self,
106        block_id: BlockId,
107        visited: &mut HashSet<BlockId>,
108        post_order: &mut Vec<BlockId>,
109    ) {
110        if visited.contains(&block_id) {
111            return;
112        }
113        visited.insert(block_id);
114
115        if let Some(block) = self.blocks.get(&block_id) {
116            for &succ in &block.successors {
117                self.dfs_post_order(succ, visited, post_order);
118            }
119        }
120
121        post_order.push(block_id);
122    }
123
124    /// Compute dominator tree
125    pub fn dominators(&self) -> HashMap<BlockId, BlockId> {
126        let mut doms = HashMap::new();
127        doms.insert(self.entry, self.entry);
128
129        let rpo = self.blocks_rpo();
130        let mut changed = true;
131
132        while changed {
133            changed = false;
134            for &block_id in &rpo {
135                if block_id == self.entry {
136                    continue;
137                }
138
139                let block = self.blocks.get(&block_id).unwrap();
140
141                // Find immediate dominator
142                let mut new_idom = None;
143                for &pred in &block.predecessors {
144                    if doms.contains_key(&pred) {
145                        new_idom = Some(if let Some(curr_idom) = new_idom {
146                            self.intersect(curr_idom, pred, &doms, &rpo)
147                        } else {
148                            pred
149                        });
150                    }
151                }
152
153                if let Some(new_idom) = new_idom
154                    && doms.get(&block_id) != Some(&new_idom)
155                {
156                    doms.insert(block_id, new_idom);
157                    changed = true;
158                }
159            }
160        }
161
162        doms
163    }
164
165    fn intersect(
166        &self,
167        mut b1: BlockId,
168        mut b2: BlockId,
169        doms: &HashMap<BlockId, BlockId>,
170        rpo: &[BlockId],
171    ) -> BlockId {
172        let rpo_map: HashMap<BlockId, usize> =
173            rpo.iter().enumerate().map(|(i, &b)| (b, i)).collect();
174
175        while b1 != b2 {
176            while rpo_map[&b1] > rpo_map[&b2] {
177                b1 = doms[&b1];
178            }
179            while rpo_map[&b2] > rpo_map[&b1] {
180                b2 = doms[&b2];
181            }
182        }
183
184        b1
185    }
186
187    /// Detect natural loops in the CFG
188    pub fn detect_loops(&mut self) {
189        let doms = self.dominators();
190        let mut loops = Vec::new();
191
192        // Find back edges (edges where target dominates source)
193        for (block_id, block) in &self.blocks {
194            for &succ in &block.successors {
195                if doms.contains_key(block_id) && self.dominates(succ, *block_id, &doms) {
196                    // Back edge found: block_id -> succ is a back edge
197                    // succ is the loop header
198                    let body = self.find_loop_body(succ, *block_id);
199                    loops.push(Loop {
200                        header: succ,
201                        body,
202                        depth: 0, // Will be computed later
203                    });
204                }
205            }
206        }
207
208        // Compute loop depths
209        for i in 0..loops.len() {
210            let mut depth = 1;
211            for j in 0..loops.len() {
212                if i != j && loops[j].body.contains(&loops[i].header) {
213                    depth += 1;
214                }
215            }
216            loops[i].depth = depth;
217        }
218
219        self.loops = loops;
220    }
221
222    fn dominates(
223        &self,
224        dominator: BlockId,
225        block: BlockId,
226        doms: &HashMap<BlockId, BlockId>,
227    ) -> bool {
228        let mut current = block;
229        loop {
230            if current == dominator {
231                return true;
232            }
233            if let Some(&idom) = doms.get(&current) {
234                if idom == current {
235                    return false; // Reached entry
236                }
237                current = idom;
238            } else {
239                return false;
240            }
241        }
242    }
243
244    fn find_loop_body(&self, header: BlockId, back_edge_source: BlockId) -> HashSet<BlockId> {
245        let mut body = HashSet::new();
246        body.insert(header);
247
248        let mut worklist = VecDeque::new();
249        worklist.push_back(back_edge_source);
250
251        while let Some(block_id) = worklist.pop_front() {
252            if !body.contains(&block_id) {
253                body.insert(block_id);
254
255                if let Some(block) = self.blocks.get(&block_id) {
256                    for &pred in &block.predecessors {
257                        worklist.push_back(pred);
258                    }
259                }
260            }
261        }
262
263        body
264    }
265
266    /// Find all blocks reachable from entry (helper for optimization)
267    pub fn reachable_blocks(&self) -> HashSet<BlockId> {
268        let mut reachable = HashSet::new();
269        let mut worklist = VecDeque::new();
270        worklist.push_back(self.entry);
271
272        while let Some(block_id) = worklist.pop_front() {
273            if reachable.contains(&block_id) {
274                continue;
275            }
276            reachable.insert(block_id);
277
278            if let Some(block) = self.blocks.get(&block_id) {
279                for &succ in &block.successors {
280                    worklist.push_back(succ);
281                }
282            }
283        }
284
285        reachable
286    }
287
288    /// Merge basic blocks (CFG optimization)
289    /// Merge block B into block A if:
290    /// - A has only one successor (B)
291    /// - B has only one predecessor (A)
292    /// - B is not the entry block
293    ///
294    /// Returns the number of blocks merged
295    pub fn merge_blocks(&mut self) -> usize {
296        let mut merged_count = 0;
297        let mut changed = true;
298
299        while changed {
300            changed = false;
301            let blocks: Vec<BlockId> = self.blocks.keys().copied().collect();
302
303            for block_a_id in blocks {
304                let can_merge = {
305                    let block_a = match self.blocks.get(&block_a_id) {
306                        Some(b) => b,
307                        None => continue,
308                    };
309
310                    if block_a.successors.len() != 1 {
311                        continue;
312                    }
313
314                    let block_b_id = block_a.successors[0];
315                    if block_b_id == self.entry {
316                        continue;
317                    }
318
319                    let block_b = match self.blocks.get(&block_b_id) {
320                        Some(b) => b,
321                        None => continue,
322                    };
323
324                    if block_b.predecessors.len() != 1 {
325                        continue;
326                    }
327
328                    Some((block_a_id, block_b_id, block_b.successors.clone()))
329                };
330
331                if let Some((a_id, b_id, b_successors)) = can_merge {
332                    // Get the end position from block B before borrowing A mutably
333                    let b_end = self.blocks.get(&b_id).unwrap().end;
334
335                    // Merge B into A
336                    if let Some(block_a) = self.blocks.get_mut(&a_id) {
337                        block_a.end = b_end;
338                        block_a.successors = b_successors.clone();
339                    }
340
341                    // Update successors' predecessors
342                    for succ_id in &b_successors {
343                        if let Some(succ) = self.blocks.get_mut(succ_id) {
344                            succ.predecessors.retain(|&p| p != b_id);
345                            if !succ.predecessors.contains(&a_id) {
346                                succ.predecessors.push(a_id);
347                            }
348                        }
349                    }
350
351                    // Remove block B
352                    self.blocks.remove(&b_id);
353                    merged_count += 1;
354                    changed = true;
355                    break; // Restart to avoid concurrent modification issues
356                }
357            }
358        }
359
360        merged_count
361    }
362
363    /// Eliminate unreachable blocks (CFG optimization)
364    /// Removes blocks that cannot be reached from the entry block
365    /// Returns the number of blocks eliminated
366    pub fn eliminate_unreachable(&mut self) -> usize {
367        let reachable = self.reachable_blocks();
368        let all_blocks: Vec<BlockId> = self.blocks.keys().copied().collect();
369
370        let mut removed_count = 0;
371        for block_id in all_blocks {
372            if !reachable.contains(&block_id) {
373                // Remove unreachable block
374                if let Some(block) = self.blocks.remove(&block_id) {
375                    // Clean up references from other blocks
376                    for succ_id in &block.successors {
377                        if let Some(succ) = self.blocks.get_mut(succ_id) {
378                            succ.predecessors.retain(|&p| p != block_id);
379                        }
380                    }
381                    for pred_id in &block.predecessors {
382                        if let Some(pred) = self.blocks.get_mut(pred_id) {
383                            pred.successors.retain(|&s| s != block_id);
384                        }
385                    }
386                    removed_count += 1;
387                }
388            }
389        }
390
391        removed_count
392    }
393
394    /// Simplify branches (CFG optimization)
395    /// Simplifies control flow by:
396    /// - Removing branches to the immediate next block (fall-through)
397    /// - Collapsing chains of unconditional branches
398    ///
399    /// Returns the number of branches simplified
400    pub fn simplify_branches(&mut self) -> usize {
401        let mut simplified_count = 0;
402        let blocks: Vec<BlockId> = self.blocks.keys().copied().collect();
403
404        for block_id in blocks {
405            let block = match self.blocks.get(&block_id) {
406                Some(b) => b,
407                None => continue,
408            };
409
410            // Check if this block has a single successor that is just a trampoline
411            if block.successors.len() == 1 {
412                let succ_id = block.successors[0];
413                let succ = match self.blocks.get(&succ_id) {
414                    Some(b) => b,
415                    None => continue,
416                };
417
418                // If successor is an empty trampoline with one successor, bypass it
419                if succ.start == succ.end && succ.successors.len() == 1 && succ_id != self.entry {
420                    let final_target = succ.successors[0];
421
422                    // Update current block to point to final target
423                    if let Some(block_mut) = self.blocks.get_mut(&block_id) {
424                        block_mut.successors = vec![final_target];
425                    }
426
427                    // Update final target's predecessors
428                    if let Some(final_block) = self.blocks.get_mut(&final_target) {
429                        if !final_block.predecessors.contains(&block_id) {
430                            final_block.predecessors.push(block_id);
431                        }
432                        final_block.predecessors.retain(|&p| p != succ_id);
433                    }
434
435                    simplified_count += 1;
436                }
437            }
438        }
439
440        simplified_count
441    }
442}
443
444/// Builder for constructing CFGs
445pub struct CfgBuilder {
446    blocks: Vec<BasicBlock>,
447    current_block: Option<BlockId>,
448    next_block_id: BlockId,
449    instruction_count: usize,
450    block_starts: HashMap<usize, BlockId>,
451    _pending_branches: Vec<(BlockId, usize)>, // (source block, target instruction)
452}
453
454impl CfgBuilder {
455    pub fn new() -> Self {
456        let entry_block = BasicBlock {
457            id: 0,
458            start: 0,
459            end: 0,
460            successors: Vec::new(),
461            predecessors: Vec::new(),
462            label_depth: 0,
463            is_loop_header: false,
464        };
465
466        Self {
467            blocks: vec![entry_block],
468            current_block: Some(0),
469            next_block_id: 1,
470            instruction_count: 0,
471            block_starts: HashMap::from([(0, 0)]),
472            _pending_branches: Vec::new(),
473        }
474    }
475
476    /// Add an instruction to the current block
477    pub fn add_instruction(&mut self) {
478        if let Some(current_id) = self.current_block
479            && let Some(block) = self.blocks.get_mut(current_id)
480        {
481            block.end = self.instruction_count + 1;
482        }
483        self.instruction_count += 1;
484    }
485
486    /// Start a new basic block
487    pub fn start_block(&mut self) -> BlockId {
488        let block_id = self.next_block_id;
489        self.next_block_id += 1;
490
491        let block = BasicBlock {
492            id: block_id,
493            start: self.instruction_count,
494            end: self.instruction_count,
495            successors: Vec::new(),
496            predecessors: Vec::new(),
497            label_depth: 0,
498            is_loop_header: false,
499        };
500
501        self.blocks.push(block);
502        self.block_starts.insert(self.instruction_count, block_id);
503        self.current_block = Some(block_id);
504
505        block_id
506    }
507
508    /// Add a branch from current block to target block
509    pub fn add_branch(&mut self, target: BlockId) {
510        if let Some(current_id) = self.current_block {
511            if let Some(current_block) = self.blocks.iter_mut().find(|b| b.id == current_id)
512                && !current_block.successors.contains(&target)
513            {
514                current_block.successors.push(target);
515            }
516
517            if let Some(target_block) = self.blocks.iter_mut().find(|b| b.id == target)
518                && !target_block.predecessors.contains(&current_id)
519            {
520                target_block.predecessors.push(current_id);
521            }
522        }
523    }
524
525    /// Set the current block (for test purposes)
526    pub fn set_current_block(&mut self, block_id: BlockId) {
527        self.current_block = Some(block_id);
528    }
529
530    /// Mark current block as ending with a terminator
531    pub fn terminate_block(&mut self) {
532        self.current_block = None;
533    }
534
535    /// Build the final CFG
536    pub fn build(self) -> Cfg {
537        let blocks: HashMap<BlockId, BasicBlock> =
538            self.blocks.into_iter().map(|b| (b.id, b)).collect();
539
540        let mut cfg = Cfg {
541            blocks,
542            entry: 0,
543            exit: None,
544            loops: Vec::new(),
545        };
546
547        cfg.detect_loops();
548        cfg
549    }
550}
551
552impl Default for CfgBuilder {
553    fn default() -> Self {
554        Self::new()
555    }
556}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a new block and fills it with `instructions` instructions.
    fn block_with(builder: &mut CfgBuilder, instructions: usize) -> BlockId {
        let id = builder.start_block();
        (0..instructions).for_each(|_| builder.add_instruction());
        id
    }

    /// Adds the edge `from -> to` by making `from` the current block first.
    fn connect(builder: &mut CfgBuilder, from: BlockId, to: BlockId) {
        builder.set_current_block(from);
        builder.add_branch(to);
    }

    #[test]
    fn test_empty_cfg() {
        let cfg = CfgBuilder::new().build();

        assert_eq!(cfg.blocks.len(), 1);
        assert_eq!(cfg.entry, 0);
    }

    #[test]
    fn test_simple_cfg() {
        let mut builder = CfgBuilder::new();

        // Entry holds instructions 0 and 1, the body holds 2 and 3.
        builder.add_instruction();
        builder.add_instruction();
        let block1 = block_with(&mut builder, 2);

        connect(&mut builder, 0, block1);

        let cfg = builder.build();

        assert_eq!(cfg.blocks.len(), 2);
        assert_eq!(cfg.block(0).unwrap().successors, vec![1]);
        assert_eq!(cfg.block(1).unwrap().predecessors, vec![0]);
    }

    #[test]
    fn test_loop_detection() {
        let mut builder = CfgBuilder::new();
        builder.add_instruction();

        let loop_header = block_with(&mut builder, 1);
        let loop_body = block_with(&mut builder, 1);
        let exit = block_with(&mut builder, 1);

        // entry -> header -> {body, exit}, with body -> header as back edge
        connect(&mut builder, 0, loop_header);
        connect(&mut builder, loop_header, loop_body);
        connect(&mut builder, loop_header, exit);
        connect(&mut builder, loop_body, loop_header);

        let cfg = builder.build();

        // Exactly one loop, headed by `loop_header` and containing the body.
        assert_eq!(cfg.loops.len(), 1);
        let found = &cfg.loops[0];
        assert_eq!(found.header, loop_header);
        assert!(found.body.contains(&loop_header));
        assert!(found.body.contains(&loop_body));
    }

    #[test]
    fn test_rpo_order() {
        let mut builder = CfgBuilder::new();

        let b1 = block_with(&mut builder, 0);
        let b2 = block_with(&mut builder, 0);

        connect(&mut builder, 0, b1);
        connect(&mut builder, b1, b2);

        let rpo = builder.build().blocks_rpo();

        // The entry block leads the reverse post-order.
        assert_eq!(rpo[0], 0);
    }

    #[test]
    fn test_dominators() {
        let mut builder = CfgBuilder::new();

        let b1 = block_with(&mut builder, 0);
        let b2 = block_with(&mut builder, 0);

        connect(&mut builder, 0, b1);
        connect(&mut builder, b1, b2);

        let doms = builder.build().dominators();

        // The entry is its own dominator and the immediate dominator of b1;
        // b1 in turn immediately dominates b2.
        assert_eq!(doms[&0], 0);
        assert_eq!(doms[&b1], 0);
        assert_eq!(doms[&b2], b1);
    }

    #[test]
    fn test_merge_blocks() {
        let mut builder = CfgBuilder::new();
        builder.add_instruction();

        // Straight-line chain: entry -> b1 -> b2
        let b1 = block_with(&mut builder, 1);
        let b2 = block_with(&mut builder, 1);

        connect(&mut builder, 0, b1);
        connect(&mut builder, b1, b2);

        let mut cfg = builder.build();
        assert_eq!(cfg.blocks.len(), 3);

        // Both b1 and b2 collapse into the entry block.
        let merged = cfg.merge_blocks();
        assert_eq!(merged, 2);
        assert_eq!(cfg.blocks.len(), 1);
    }

    #[test]
    fn test_eliminate_unreachable() {
        let mut builder = CfgBuilder::new();
        builder.add_instruction();

        let b1 = block_with(&mut builder, 1);
        // b2 never receives an incoming edge, so it is unreachable.
        let b2 = block_with(&mut builder, 1);

        connect(&mut builder, 0, b1);

        let mut cfg = builder.build();
        assert_eq!(cfg.blocks.len(), 3);

        let removed = cfg.eliminate_unreachable();
        assert_eq!(removed, 1);
        assert_eq!(cfg.blocks.len(), 2);
        assert!(cfg.blocks.contains_key(&0));
        assert!(cfg.blocks.contains_key(&b1));
        assert!(!cfg.blocks.contains_key(&b2));
    }

    #[test]
    fn test_simplify_branches() {
        let mut builder = CfgBuilder::new();
        builder.add_instruction();

        // entry -> trampoline (no instructions) -> target
        let trampoline = block_with(&mut builder, 0);
        let target = block_with(&mut builder, 1);

        connect(&mut builder, 0, trampoline);
        connect(&mut builder, trampoline, target);

        let mut cfg = builder.build();

        let simplified = cfg.simplify_branches();
        assert_eq!(simplified, 1);

        // The entry now jumps straight to the target.
        assert_eq!(cfg.block(0).unwrap().successors, vec![target]);
    }

    #[test]
    fn test_reachable_blocks() {
        let mut builder = CfgBuilder::new();
        builder.add_instruction();

        let b1 = block_with(&mut builder, 1);
        let b2 = block_with(&mut builder, 1);
        let b3 = block_with(&mut builder, 1);

        // entry -> b1 -> b2; nothing leads to b3
        connect(&mut builder, 0, b1);
        connect(&mut builder, b1, b2);

        let reachable = builder.build().reachable_blocks();

        assert_eq!(reachable.len(), 3);
        for id in [0, b1, b2] {
            assert!(reachable.contains(&id));
        }
        assert!(!reachable.contains(&b3));
    }

    #[test]
    fn test_optimization_pipeline() {
        let mut builder = CfgBuilder::new();
        builder.add_instruction();

        // A CFG offering work for every pass: a mergeable chain, a block
        // nothing branches to, and an empty trampoline in front of the target.
        let b1 = block_with(&mut builder, 1);
        let b2 = block_with(&mut builder, 1);
        let _unreachable = block_with(&mut builder, 1);
        let trampoline = block_with(&mut builder, 0);
        let target = block_with(&mut builder, 1);

        connect(&mut builder, 0, b1);
        connect(&mut builder, b1, b2);
        connect(&mut builder, b2, trampoline);
        connect(&mut builder, trampoline, target);

        let mut cfg = builder.build();
        let initial_blocks = cfg.blocks.len();

        // Run the passes in pipeline order.
        let eliminated = cfg.eliminate_unreachable();
        let simplified = cfg.simplify_branches();
        let merged = cfg.merge_blocks();

        assert!(eliminated > 0 || simplified > 0 || merged > 0);
        assert!(cfg.blocks.len() < initial_blocks);
    }
}