ronn_graph/passes/
constant_folding.rs1use super::{OptimizationPass, PassStats};
2use crate::error::Result;
3use ronn_core::{ModelGraph, NodeAttribute};
4use std::collections::HashSet;
5use tracing::debug;
6
/// Optimization pass that looks for nodes whose operation is foldable and
/// whose inputs are all constant.
///
/// The type carries no state, so it is cheap to copy and can be created with
/// `ConstantFoldingPass` or `ConstantFoldingPass::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConstantFoldingPass;
9
10impl OptimizationPass for ConstantFoldingPass {
11 fn name(&self) -> &str {
12 "ConstantFolding"
13 }
14
15 fn run(&self, graph: &mut ModelGraph) -> Result<PassStats> {
16 let mut stats = PassStats::default();
17 let mut constants_folded = HashSet::new();
18
19 for node in graph.nodes() {
21 if Self::all_inputs_constant(&node.id.to_string(), graph)
22 && Self::is_foldable_op(&node.op_type)
23 {
24 debug!("Folding constant node: {} ({})", node.id, node.op_type);
25
26 constants_folded.insert(node.id.to_string());
29 stats.nodes_modified += 1;
30 }
31 }
32
33 debug!(
34 "Constant folding pass completed: {} constants folded",
35 constants_folded.len()
36 );
37
38 Ok(stats)
39 }
40}
41
42impl ConstantFoldingPass {
43 fn all_inputs_constant(node_id: &str, graph: &ModelGraph) -> bool {
45 false
49 }
50
51 fn is_foldable_op(op_type: &str) -> bool {
53 matches!(
54 op_type,
55 "Add" | "Sub" | "Mul" | "Div" | "Reshape" | "Transpose" | "Cast"
56 )
57 }
58}