Skip to main content

ronn_graph/passes/
constant_folding.rs

1use super::{OptimizationPass, PassStats};
2use crate::error::Result;
3use ronn_core::{ModelGraph, NodeAttribute};
4use std::collections::HashSet;
5use tracing::debug;
6
/// Constant folding pass.
///
/// Intended to evaluate constant expressions in a [`ModelGraph`] at compile
/// time. In its current form the pass only *identifies* candidate nodes
/// (a foldable op type whose inputs are all constants) and does not rewrite
/// the graph. Candidate detection is also still a stub that reports no
/// constant inputs, so no nodes are identified yet.
#[derive(Debug, Default, Clone, Copy)]
pub struct ConstantFoldingPass;
9
10impl OptimizationPass for ConstantFoldingPass {
11    fn name(&self) -> &str {
12        "ConstantFolding"
13    }
14
15    fn run(&self, graph: &mut ModelGraph) -> Result<PassStats> {
16        let mut stats = PassStats::default();
17        let mut constants_folded = HashSet::new();
18
19        // Find nodes whose all inputs are constants
20        for node in graph.nodes() {
21            if Self::all_inputs_constant(&node.id.to_string(), graph)
22                && Self::is_foldable_op(&node.op_type)
23            {
24                debug!("Folding constant node: {} ({})", node.id, node.op_type);
25
26                // Execute the operation at compile time
27                // For now, we mark it as foldable - actual execution would happen here
28                constants_folded.insert(node.id.to_string());
29                stats.nodes_modified += 1;
30            }
31        }
32
33        debug!(
34            "Constant folding pass completed: {} constants folded",
35            constants_folded.len()
36        );
37
38        Ok(stats)
39    }
40}
41
42impl ConstantFoldingPass {
43    /// Check if all inputs to a node are constants
44    fn all_inputs_constant(node_id: &str, graph: &ModelGraph) -> bool {
45        // Check if all input tensors are initializers (constants)
46        // This would require access to the initializers map
47        // For now, return false (would be implemented with full graph context)
48        false
49    }
50
51    /// Check if an operation can be folded
52    fn is_foldable_op(op_type: &str) -> bool {
53        matches!(
54            op_type,
55            "Add" | "Sub" | "Mul" | "Div" | "Reshape" | "Transpose" | "Cast"
56        )
57    }
58}