cairo-lang-lowering 2.19.0

Cairo lowering phase.
Documentation
//! Tests for the dataflow analysis framework.

use cairo_lang_semantic::test_utils::setup_test_function;
use cairo_lang_test_utils::parse_test_file::TestRunnerResult;
use cairo_lang_utils::Intern;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;

use super::core::{DataflowAnalyzer, Direction, StatementLocation};
use super::def_site::DefSiteAnalysis;
use super::dominator::Dominators;
use super::forward::ForwardDataflowAnalysis;
use super::topological_order::TopologicalOrder;
use super::use_sites::UseSites;
use crate::db::LoweringGroup;
use crate::ids::{ConcreteFunctionWithBodyId, FunctionWithBodyLongId};
use crate::test_utils::{LoweringDatabaseForTesting, formatted_lowered};
use crate::{Block, BlockEnd, BlockId, Lowered, LoweringStage};

// ============================================================================
// Golden test runner — dispatches to the analysis specified by the `analysis`
// arg in the test data file.  Adding a new analysis requires a new match arm
// in `test_analysis`, a `test_data/<name>` file, and (presumably) a matching
// entry in the `test_file_test!` list below.
// ============================================================================

// Registers the golden-file tests that are executed through `test_analysis`.
// NOTE(review): the per-argument notes below are inferred from the call shape only — confirm
// against the `test_file_test!` definition in `cairo_lang_test_utils`.
cairo_lang_test_utils::test_file_test!(
    analysis,
    // Directory that holds the test data files.
    "src/analysis/test_data",
    {
        // One entry per analysis; the string presumably names a file in the directory above.
        dominator: "dominator",
        def_site: "def_site",
        use_sites: "use_sites",
        topological_order: "topological_order",
    },
    // Runner function invoked for each test case.
    test_analysis,
    // Presumably the argument names a test case may pass; `test_analysis` reads `analysis`.
    ["analysis"]
);

/// Runs one golden test case: lowers the test function, then renders the analysis selected by
/// the `analysis` argument.
///
/// The produced outputs are `semantic_diagnostics`, `lowering` (the formatted lowered function)
/// and `result` (the pretty-printed `Debug` form of the selected analysis). If lowering fails,
/// `lowering` is `"Lowering failed."` and `result` is empty.
///
/// # Panics
///
/// Panics when the `analysis` argument is missing, or when lowering succeeded and the argument
/// names an analysis that has no match arm below.
fn test_analysis(
    inputs: &OrderedHashMap<String, String>,
    args: &OrderedHashMap<String, String>,
) -> TestRunnerResult {
    let analysis_name = args.get("analysis").expect("test requires `analysis` arg");

    let db = &mut LoweringDatabaseForTesting::default();
    let (test_function, semantic_diagnostics) = setup_test_function(db, inputs).split();
    let function_id =
        ConcreteFunctionWithBodyId::from_semantic(db, test_function.concrete_function_id);

    let lowering = db.lowered_body(function_id, LoweringStage::PostBaseline);
    let (lowering_str, result_str) = match lowering {
        Ok(lowered) => {
            // Render the lowered function first, then run the requested analysis on it.
            let rendered = formatted_lowered(db, Some(lowered));
            let dump = match analysis_name.as_str() {
                "dominator" => format!("{:#?}", Dominators::analyze(lowered)),
                "def_site" => format!("{:#?}", DefSiteAnalysis::analyze(lowered)),
                "use_sites" => format!("{:#?}", UseSites::analyze(lowered)),
                "topological_order" => format!("{:#?}", TopologicalOrder::analyze(lowered)),
                _ => panic!("unknown analysis: {analysis_name}"),
            };
            (rendered, dump)
        }
        // No lowered body to analyze: report the failure and leave the result empty.
        Err(_) => ("Lowering failed.".to_string(), String::new()),
    };

    let outputs = OrderedHashMap::from([
        ("semantic_diagnostics".into(), semantic_diagnostics),
        ("lowering".into(), lowering_str),
        ("result".into(), result_str),
    ]);
    TestRunnerResult::success(outputs)
}

// ============================================================================
// Block-level Analysis: Count blocks (demonstrates transfer_block override)
// ============================================================================

/// A simple block-level analyzer that counts blocks.
///
/// Demonstrates overriding `transfer_block` for coarse-grained analysis.
#[derive(Default)]
struct BlockCounter {
    /// Number of times `transfer_block` has been invoked on this analyzer.
    block_count: usize,
}

// Forward, block-granular analysis: the flowing info is a running count of transferred blocks.
impl<'db, 'a> DataflowAnalyzer<'db, 'a> for BlockCounter {
    /// Count of blocks seen on the way to the current point.
    type Info = usize;

    const DIRECTION: Direction = Direction::Forward;

    /// The count starts at zero.
    fn initial_info(&mut self, _block_id: BlockId, _block_end: &'a BlockEnd<'db>) -> Self::Info {
        0
    }

    /// Joins two incoming counts by keeping the larger one.
    fn merge(
        &mut self,
        _lowered: &Lowered<'db>,
        _statement_location: StatementLocation,
        info1: Self::Info,
        info2: Self::Info,
    ) -> Self::Info {
        usize::max(info1, info2)
    }

    /// Overrides the block transfer so a block is handled in one step, without iterating its
    /// statements: bumps both the flowing count and the analyzer-wide total.
    fn transfer_block(
        &mut self,
        info: &mut Self::Info,
        _block_id: BlockId,
        _block: &'a Block<'db>,
    ) {
        *info += 1;
        self.block_count += 1;
    }
}

// ============================================================================
// Statement-level Analysis: Track reachable blocks (uses default transfer_block)
// ============================================================================

/// A simple forward analyzer that tracks which blocks are reachable.
///
/// Demonstrates using the default `transfer_block` together with the statement-level
/// `transfer_stmt`.
#[derive(Default)]
struct ReachabilityAnalyzer {
    /// Every block id that `visit_block_start` has been called with.
    reachable_blocks: OrderedHashSet<BlockId>,
}

// Forward analysis whose flowing info is the set of blocks visited so far.
impl<'db, 'a> DataflowAnalyzer<'db, 'a> for ReachabilityAnalyzer {
    /// Set of blocks visited to reach the current point.
    type Info = OrderedHashSet<BlockId>;

    const DIRECTION: Direction = Direction::Forward;

    /// Nothing has been visited at the start.
    fn initial_info(&mut self, _block_id: BlockId, _block_end: &'a BlockEnd<'db>) -> Self::Info {
        OrderedHashSet::default()
    }

    /// Joins two incoming sets by taking their union (`info2` is folded into `info1`).
    fn merge(
        &mut self,
        _lowered: &Lowered<'db>,
        _statement_location: StatementLocation,
        mut info1: Self::Info,
        info2: Self::Info,
    ) -> Self::Info {
        info1.extend(info2);
        info1
    }

    // `transfer_block` and `transfer_stmt` are intentionally left at their defaults: the default
    // block transfer iterates statements, and statements do not affect reachability.

    /// Records the block both in the flowing set and in the analyzer-wide set.
    fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block<'db>) {
        info.insert(block_id);
        self.reachable_blocks.insert(block_id);
    }
}

// ============================================================================
// Tests
// ============================================================================

/// A block-level analyzer (one that overrides `transfer_block`) must be driven over at least two
/// blocks when the analyzed function branches.
#[test]
fn test_block_level_analysis() {
    let db = LoweringDatabaseForTesting::default();
    let inputs = OrderedHashMap::from([
        (
            String::from("function_code"),
            String::from("fn foo(x: bool) -> felt252 { if x { 1 } else { 2 } }"),
        ),
        (String::from("function_name"), String::from("foo")),
        (String::from("module_code"), String::new()),
    ]);
    let (test_function, _) = setup_test_function(&db, &inputs).split();
    let function_id = FunctionWithBodyLongId::Semantic(test_function.function_id).intern(&db);
    let lowered = db.function_with_body_lowering(function_id).unwrap();

    let mut analysis = ForwardDataflowAnalysis::new(lowered, BlockCounter::default());
    // Only the analyzer's side-effect counter matters here; the returned info is ignored.
    let _ = analysis.run();

    let counted = analysis.analyzer.block_count;
    assert!(counted >= 2, "Expected at least 2 blocks, got {counted}");
}

/// For an empty function, the forward analysis must visit at least the root block.
#[test]
fn test_forward_single_block() {
    let db = LoweringDatabaseForTesting::default();
    let inputs = OrderedHashMap::from([
        (String::from("function_code"), String::from("fn foo() {}")),
        (String::from("function_name"), String::from("foo")),
        (String::from("module_code"), String::new()),
    ]);
    let (test_function, _) = setup_test_function(&db, &inputs).split();
    let function_id = FunctionWithBodyLongId::Semantic(test_function.function_id).intern(&db);
    let lowered = db.function_with_body_lowering(function_id).unwrap();

    let mut analysis = ForwardDataflowAnalysis::new(lowered, ReachabilityAnalyzer::default());
    // The per-block info is not inspected; only the analyzer's visited set is checked.
    let _ = analysis.run();

    let visited = &analysis.analyzer.reachable_blocks;
    assert!(!visited.is_empty());
    assert!(visited.contains(&BlockId::root()));
}

/// For a branching function, the forward analysis must visit several blocks and produce exit
/// info for every block it visited.
#[test]
fn test_forward_with_branching() {
    let db = LoweringDatabaseForTesting::default();
    // The `if`/`else` makes the lowering contain more than one block.
    let inputs = OrderedHashMap::from([
        (
            String::from("function_code"),
            String::from("fn foo(x: bool) -> felt252 { if x { 1 } else { 2 } }"),
        ),
        (String::from("function_name"), String::from("foo")),
        (String::from("module_code"), String::new()),
    ]);
    let (test_function, _) = setup_test_function(&db, &inputs).split();
    let function_id = FunctionWithBodyLongId::Semantic(test_function.function_id).intern(&db);
    let lowered = db.function_with_body_lowering(function_id).unwrap();

    let mut analysis = ForwardDataflowAnalysis::new(lowered, ReachabilityAnalyzer::default());
    // Keep an owned copy of the result so the analyzer can still be inspected afterwards.
    let exit_info = analysis.run().clone();

    let visited = &analysis.analyzer.reachable_blocks;
    assert!(visited.len() >= 2, "Expected at least 2 reachable blocks with branching");

    // Each visited block must have an entry in the analysis result.
    for block_id in visited {
        assert!(exit_info[block_id.0].is_some(), "Block {block_id:?} should have exit info");
    }
}