use cairo_lang_semantic::test_utils::setup_test_function;
use cairo_lang_test_utils::parse_test_file::TestRunnerResult;
use cairo_lang_utils::Intern;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use super::core::{DataflowAnalyzer, Direction, StatementLocation};
use super::def_site::DefSiteAnalysis;
use super::dominator::Dominators;
use super::forward::ForwardDataflowAnalysis;
use super::topological_order::TopologicalOrder;
use super::use_sites::UseSites;
use crate::db::LoweringGroup;
use crate::ids::{ConcreteFunctionWithBodyId, FunctionWithBodyLongId};
use crate::test_utils::{LoweringDatabaseForTesting, formatted_lowered};
use crate::{Block, BlockEnd, BlockId, Lowered, LoweringStage};
// Registers the file-driven analysis test suite `analysis`: each `name: "file"` entry refers to a
// test-data file under `src/analysis/test_data`, with `test_analysis` as the runner for every case.
// NOTE(review): `["analysis"]` presumably lists the per-case arguments the runner may receive
// (`test_analysis` reads the `analysis` arg) — confirm against the `test_file_test!` definition.
cairo_lang_test_utils::test_file_test!(
analysis,
"src/analysis/test_data",
{
dominator: "dominator",
def_site: "def_site",
use_sites: "use_sites",
topological_order: "topological_order",
},
test_analysis,
["analysis"]
);
/// Runner for the file-based analysis tests.
///
/// Sets up the test function described by `inputs` and lowers it to the `PostBaseline` stage.
/// If lowering succeeds, it runs the analysis named by the `analysis` argument and renders the
/// result with `{:#?}`.
///
/// The returned outputs are:
/// - `semantic_diagnostics`: the diagnostics from setting up the test function;
/// - `lowering`: the formatted lowering, or `"Lowering failed."`;
/// - `result`: the analysis dump, which is empty when lowering failed.
///
/// # Panics
///
/// Panics if the `analysis` argument is missing or names an unsupported analysis.
fn test_analysis(
    inputs: &OrderedHashMap<String, String>,
    args: &OrderedHashMap<String, String>,
) -> TestRunnerResult {
    let analysis_name = args.get("analysis").expect("test requires `analysis` arg");
    let db = &mut LoweringDatabaseForTesting::default();
    let (test_function, semantic_diagnostics) = setup_test_function(db, inputs).split();
    let function_id =
        ConcreteFunctionWithBodyId::from_semantic(db, test_function.concrete_function_id);

    let (lowering_str, result_str) =
        match db.lowered_body(function_id, LoweringStage::PostBaseline) {
            Ok(lowered) => {
                // The lowering is rendered before the analysis runs.
                let rendered_lowering = formatted_lowered(db, Some(lowered));
                let rendered_result = match analysis_name.as_str() {
                    "dominator" => format!("{:#?}", Dominators::analyze(lowered)),
                    "def_site" => format!("{:#?}", DefSiteAnalysis::analyze(lowered)),
                    "use_sites" => format!("{:#?}", UseSites::analyze(lowered)),
                    "topological_order" => format!("{:#?}", TopologicalOrder::analyze(lowered)),
                    unknown => panic!("unknown analysis: {unknown}"),
                };
                (rendered_lowering, rendered_result)
            }
            Err(_) => ("Lowering failed.".to_string(), String::new()),
        };

    TestRunnerResult::success(OrderedHashMap::from([
        ("semantic_diagnostics".into(), semantic_diagnostics),
        ("lowering".into(), lowering_str),
        ("result".into(), result_str),
    ]))
}
/// Test analyzer that counts how many times `transfer_block` is invoked.
///
/// Its per-block info is a running count of blocks traversed. When two infos are merged, the
/// larger count wins.
#[derive(Default)]
struct BlockCounter {
    /// Number of `transfer_block` calls observed during the analysis run.
    block_count: usize,
}

impl<'db, 'a> DataflowAnalyzer<'db, 'a> for BlockCounter {
    type Info = usize;

    const DIRECTION: Direction = Direction::Forward;

    /// Every block starts with a count of zero.
    fn initial_info(&mut self, _block_id: BlockId, _block_end: &'a BlockEnd<'db>) -> Self::Info {
        Default::default()
    }

    /// Keeps the larger of the two incoming counts.
    fn merge(
        &mut self,
        _lowered: &Lowered<'db>,
        _statement_location: StatementLocation,
        info1: Self::Info,
        info2: Self::Info,
    ) -> Self::Info {
        std::cmp::max(info1, info2)
    }

    /// Bumps both the flowing per-path count and the global call counter.
    fn transfer_block(
        &mut self,
        info: &mut Self::Info,
        _block_id: BlockId,
        _block: &'a Block<'db>,
    ) {
        *info += 1;
        self.block_count += 1;
    }
}
/// Test analyzer that records every block whose start is visited.
///
/// The flowing info is the set of blocks seen along the incoming paths. Merging two infos takes
/// their union, keeping `info1`'s insertion order first.
#[derive(Default)]
struct ReachabilityAnalyzer {
    /// All blocks passed to `visit_block_start` during the run, in first-visit order.
    reachable_blocks: OrderedHashSet<BlockId>,
}

impl<'db, 'a> DataflowAnalyzer<'db, 'a> for ReachabilityAnalyzer {
    type Info = OrderedHashSet<BlockId>;

    const DIRECTION: Direction = Direction::Forward;

    /// Every block starts with an empty set.
    fn initial_info(&mut self, _block_id: BlockId, _block_end: &'a BlockEnd<'db>) -> Self::Info {
        Default::default()
    }

    /// Unions `info2` into `info1` and returns the combined set.
    fn merge(
        &mut self,
        _lowered: &Lowered<'db>,
        _statement_location: StatementLocation,
        mut info1: Self::Info,
        info2: Self::Info,
    ) -> Self::Info {
        info1.extend(info2);
        info1
    }

    /// Marks the block as seen, both in the flowing info and in the global record.
    fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block<'db>) {
        info.insert(block_id);
        self.reachable_blocks.insert(block_id);
    }
}
/// Runs `BlockCounter` forward over a function with an `if`/`else` and checks that at least two
/// blocks were transferred.
#[test]
fn test_block_level_analysis() {
    let db = LoweringDatabaseForTesting::default();
    let inputs = OrderedHashMap::from(
        [
            ("function_code", "fn foo(x: bool) -> felt252 { if x { 1 } else { 2 } }"),
            ("function_name", "foo"),
            ("module_code", ""),
        ]
        .map(|(key, value)| (key.to_string(), value.to_string())),
    );
    let (test_function, _) = setup_test_function(&db, &inputs).split();
    let function_id = FunctionWithBodyLongId::Semantic(test_function.function_id).intern(&db);
    let lowered = db.function_with_body_lowering(function_id).unwrap();

    let mut analysis = ForwardDataflowAnalysis::new(lowered, BlockCounter::default());
    let _ = analysis.run();

    let transferred = analysis.analyzer.block_count;
    assert!(transferred >= 2, "Expected at least 2 blocks, got {}", transferred);
}
/// Runs `ReachabilityAnalyzer` forward over an empty function and checks that the root block is
/// recorded as reachable.
#[test]
fn test_forward_single_block() {
    let db = LoweringDatabaseForTesting::default();
    let inputs = OrderedHashMap::from(
        [("function_code", "fn foo() {}"), ("function_name", "foo"), ("module_code", "")]
            .map(|(key, value)| (key.to_string(), value.to_string())),
    );
    let (test_function, _) = setup_test_function(&db, &inputs).split();
    let function_id = FunctionWithBodyLongId::Semantic(test_function.function_id).intern(&db);
    let lowered = db.function_with_body_lowering(function_id).unwrap();

    let mut analysis = ForwardDataflowAnalysis::new(lowered, ReachabilityAnalyzer::default());
    let _ = analysis.run();

    // These assertion expressions are kept verbatim because their text is the panic message.
    assert!(!analysis.analyzer.reachable_blocks.is_empty());
    assert!(analysis.analyzer.reachable_blocks.contains(&BlockId::root()));
}
/// Runs `ReachabilityAnalyzer` forward over a branching function.
///
/// Checks that at least two blocks were reached and that every reached block has an entry in
/// the exit info returned by `run`.
#[test]
fn test_forward_with_branching() {
    let db = LoweringDatabaseForTesting::default();
    let inputs = OrderedHashMap::from(
        [
            ("function_code", "fn foo(x: bool) -> felt252 { if x { 1 } else { 2 } }"),
            ("function_name", "foo"),
            ("module_code", ""),
        ]
        .map(|(key, value)| (key.to_string(), value.to_string())),
    );
    let (test_function, _) = setup_test_function(&db, &inputs).split();
    let function_id = FunctionWithBodyLongId::Semantic(test_function.function_id).intern(&db);
    let lowered = db.function_with_body_lowering(function_id).unwrap();

    let mut analysis = ForwardDataflowAnalysis::new(lowered, ReachabilityAnalyzer::default());
    // NOTE(review): the clone presumably releases the borrow of `analysis` taken by `run()`, so
    // the analyzer can be inspected below — confirm against `run`'s signature.
    let exit_info = analysis.run().clone();

    let reachable = &analysis.analyzer.reachable_blocks;
    assert!(reachable.len() >= 2, "Expected at least 2 reachable blocks with branching");
    // Exit info is indexed by the block's numeric id.
    for block_id in reachable {
        assert!(exit_info[block_id.0].is_some(), "Block {:?} should have exit info", block_id);
    }
}