ronn_graph/passes/
layout.rs1use super::{OptimizationPass, PassStats};
2use crate::error::Result;
3use ronn_core::ModelGraph;
4use tracing::debug;
5
/// Graph optimization pass that picks a target tensor layout (NCHW or NHWC)
/// for a model graph and inserts layout transforms to reach it.
///
/// The pass is stateless, so it is a unit struct that can be freely copied.
#[derive(Debug, Clone, Copy, Default)]
pub struct LayoutOptimizationPass;

10impl OptimizationPass for LayoutOptimizationPass {
11 fn name(&self) -> &str {
12 "LayoutOptimization"
13 }
14
15 fn run(&self, graph: &mut ModelGraph) -> Result<PassStats> {
16 let mut stats = PassStats::default();
17
18 let layout = self.determine_optimal_layout(graph)?;
20 debug!("Determined optimal layout: {:?}", layout);
21
22 stats.nodes_modified += self.insert_layout_transforms(graph, layout)?;
24
25 debug!(
26 "Layout optimization completed: {} layout transforms inserted",
27 stats.nodes_modified
28 );
29
30 Ok(stats)
31 }
32}
33
/// Memory ordering of a 4-D activation tensor.
///
/// Variant names intentionally mirror the conventional acronyms; they also
/// appear verbatim in the pass's `{:?}` debug logging.
// Allowed: the all-caps names are the standard layout acronyms, and renaming
// them would change the logged `Debug` output.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TensorLayout {
    /// Channels-first: (batch, channels, height, width).
    NCHW,
    /// Channels-last: (batch, height, width, channels).
    NHWC,
}

40impl LayoutOptimizationPass {
41 fn determine_optimal_layout(&self, graph: &ModelGraph) -> Result<TensorLayout> {
43 let mut conv_count = 0;
45 let mut other_count = 0;
46
47 for node in graph.nodes() {
48 match node.op_type.as_str() {
49 "Conv" | "MaxPool" | "AveragePool" => conv_count += 1,
50 _ => other_count += 1,
51 }
52 }
53
54 if conv_count > other_count / 2 {
57 Ok(TensorLayout::NCHW)
58 } else {
59 Ok(TensorLayout::NHWC)
60 }
61 }
62
63 fn insert_layout_transforms(
65 &self,
66 graph: &mut ModelGraph,
67 _target_layout: TensorLayout,
68 ) -> Result<usize> {
69 Ok(0)
73 }
74}