use std::collections::HashSet;
use cairo_lang_semantic::test_utils::setup_test_function;
use cairo_lang_utils::Intern;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use super::core::{DataflowAnalyzer, Direction, StatementLocation};
use super::forward::ForwardDataflowAnalysis;
use crate::db::LoweringGroup;
use crate::ids::FunctionWithBodyLongId;
use crate::test_utils::LoweringDatabaseForTesting;
use crate::{Block, BlockEnd, BlockId, Lowered};
/// Test analyzer that tallies how many times its `transfer_block` hook is invoked.
#[derive(Default)]
struct BlockCounter {
/// Running total of `transfer_block` invocations; starts at zero via `Default`.
block_count: usize,
}
/// Forward analyzer whose per-path state is a simple counter.
///
/// Each `transfer_block` call bumps both the state it is handed and the analyzer-wide
/// `block_count` total; merging two states keeps the larger counter.
impl<'db, 'a> DataflowAnalyzer<'db, 'a> for BlockCounter {
    type Info = usize;

    const DIRECTION: Direction = Direction::Forward;

    /// Starts every state at zero; the block id and block end are not consulted.
    fn initial_info(&mut self, _block_id: BlockId, _block_end: &'a BlockEnd<'db>) -> Self::Info {
        0_usize
    }

    /// Combines two incoming counters by returning the maximum of the two.
    fn merge(
        &mut self,
        _lowered: &Lowered<'db>,
        _statement_location: StatementLocation,
        info1: Self::Info,
        info2: Self::Info,
    ) -> Self::Info {
        if info2 > info1 { info2 } else { info1 }
    }

    /// Records one more processed block, both in the flowing state and in the global tally.
    fn transfer_block(
        &mut self,
        info: &mut Self::Info,
        _block_id: BlockId,
        _block: &'a Block<'db>,
    ) {
        *info += 1;
        self.block_count += 1;
    }
}
/// Test analyzer that remembers every block id passed to its `visit_block_start` hook.
#[derive(Default)]
struct ReachabilityAnalyzer {
/// Ids of all blocks seen by `visit_block_start` so far; empty on construction.
reachable_blocks: HashSet<BlockId>,
}
/// Forward analyzer whose per-path state is the set of block ids seen along that path.
///
/// `visit_block_start` adds the block id both to the flowing state and to the analyzer-wide
/// `reachable_blocks` set; merging two states yields their union.
impl<'db, 'a> DataflowAnalyzer<'db, 'a> for ReachabilityAnalyzer {
    type Info = HashSet<BlockId>;

    const DIRECTION: Direction = Direction::Forward;

    /// Starts every state as an empty set; the block id and block end are not consulted.
    fn initial_info(&mut self, _block_id: BlockId, _block_end: &'a BlockEnd<'db>) -> Self::Info {
        HashSet::default()
    }

    /// Combines two incoming states into the union of their block ids.
    fn merge(
        &mut self,
        _lowered: &Lowered<'db>,
        _statement_location: StatementLocation,
        info1: Self::Info,
        info2: Self::Info,
    ) -> Self::Info {
        info1.into_iter().chain(info2).collect()
    }

    /// Marks `block_id` as seen in the flowing state and in the analyzer-wide set.
    fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block<'db>) {
        info.insert(block_id);
        self.reachable_blocks.insert(block_id);
    }
}
/// Runs `BlockCounter` over the lowering of a function containing an `if`/`else` and checks
/// that `transfer_block` was invoked at least twice.
#[test]
fn test_block_level_analysis() {
    let db = LoweringDatabaseForTesting::default();
    let inputs = OrderedHashMap::from(
        [
            ("function_code", "fn foo(x: bool) -> felt252 { if x { 1 } else { 2 } }"),
            ("function_name", "foo"),
            ("module_code", ""),
        ]
        .map(|(key, value)| (key.to_string(), value.to_string())),
    );
    let (test_function, _) = setup_test_function(&db, &inputs).split();
    let function_id = FunctionWithBodyLongId::Semantic(test_function.function_id).intern(&db);
    let lowered = db.function_with_body_lowering(function_id).unwrap();

    let mut analysis = ForwardDataflowAnalysis::new(lowered, BlockCounter::default());
    // Only the side effects recorded on the analyzer matter here; the returned info is unused.
    let _ = analysis.run();

    let visited = analysis.analyzer.block_count;
    assert!(visited >= 2, "Expected at least 2 blocks, got {}", visited);
}
/// Runs `ReachabilityAnalyzer` over the lowering of an empty function and checks that the
/// root block was recorded as visited.
#[test]
fn test_forward_single_block() {
    let db = LoweringDatabaseForTesting::default();
    let inputs = OrderedHashMap::from(
        [("function_code", "fn foo() {}"), ("function_name", "foo"), ("module_code", "")]
            .map(|(key, value)| (key.to_string(), value.to_string())),
    );
    let (test_function, _) = setup_test_function(&db, &inputs).split();
    let function_id = FunctionWithBodyLongId::Semantic(test_function.function_id).intern(&db);
    let lowered = db.function_with_body_lowering(function_id).unwrap();

    let mut analysis = ForwardDataflowAnalysis::new(lowered, ReachabilityAnalyzer::default());
    // Only the side effects recorded on the analyzer matter here; the returned info is unused.
    let _ = analysis.run();

    let reachable = &analysis.analyzer.reachable_blocks;
    assert!(!reachable.is_empty());
    assert!(reachable.contains(&BlockId::root()));
}
/// Runs `ReachabilityAnalyzer` over the lowering of a function containing an `if`/`else` and
/// checks that at least two blocks were visited and that each visited block has an entry in
/// the info returned by `run`.
#[test]
fn test_forward_with_branching() {
    let db = LoweringDatabaseForTesting::default();
    let inputs = OrderedHashMap::from(
        [
            ("function_code", "fn foo(x: bool) -> felt252 { if x { 1 } else { 2 } }"),
            ("function_name", "foo"),
            ("module_code", ""),
        ]
        .map(|(key, value)| (key.to_string(), value.to_string())),
    );
    let (test_function, _) = setup_test_function(&db, &inputs).split();
    let function_id = FunctionWithBodyLongId::Semantic(test_function.function_id).intern(&db);
    let lowered = db.function_with_body_lowering(function_id).unwrap();

    let mut analysis = ForwardDataflowAnalysis::new(lowered, ReachabilityAnalyzer::default());
    // Clone the result so the analyzer's own state can be inspected afterwards.
    let exit_info = analysis.run().clone();

    let reachable = &analysis.analyzer.reachable_blocks;
    assert!(reachable.len() >= 2, "Expected at least 2 reachable blocks with branching");
    reachable.iter().for_each(|block_id| {
        assert!(exit_info[block_id.0].is_some(), "Block {:?} should have exit info", block_id);
    });
}