Skip to main content

ronn_graph/passes/
layout.rs

1use super::{OptimizationPass, PassStats};
2use crate::error::Result;
3use ronn_core::ModelGraph;
4use tracing::debug;
5
/// Optimization pass that picks a tensor memory layout for a [`ModelGraph`]
/// and inserts the layout conversions that choice requires.
///
/// Typical layout choices are NCHW vs NHWC for image-like tensors, or
/// row-major vs column-major for matrices. The pass carries no state, so it
/// is a unit struct; `Default` is derived so it can be built generically.
#[derive(Debug, Clone, Copy, Default)]
pub struct LayoutOptimizationPass;
9
10impl OptimizationPass for LayoutOptimizationPass {
11    fn name(&self) -> &str {
12        "LayoutOptimization"
13    }
14
15    fn run(&self, graph: &mut ModelGraph) -> Result<PassStats> {
16        let mut stats = PassStats::default();
17
18        // Analyze the graph to determine optimal layout
19        let layout = self.determine_optimal_layout(graph)?;
20        debug!("Determined optimal layout: {:?}", layout);
21
22        // Insert layout transformation nodes where needed
23        stats.nodes_modified += self.insert_layout_transforms(graph, layout)?;
24
25        debug!(
26            "Layout optimization completed: {} layout transforms inserted",
27            stats.nodes_modified
28        );
29
30        Ok(stats)
31    }
32}
33
/// Dimension ordering of a four-dimensional (batch, channel, height, width)
/// tensor in memory.
// The variant names are the standard layout abbreviations; camel-casing them
// (`Nchw`) would make them harder to recognise, so the clippy lint that
// flags all-caps identifiers is silenced deliberately.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TensorLayout {
    /// Batch, Channels, Height, Width.
    NCHW,
    /// Batch, Height, Width, Channels.
    NHWC,
}
39
40impl LayoutOptimizationPass {
41    /// Determine the optimal layout based on operations in the graph
42    fn determine_optimal_layout(&self, graph: &ModelGraph) -> Result<TensorLayout> {
43        // Count conv operations - they prefer NCHW on GPU
44        let mut conv_count = 0;
45        let mut other_count = 0;
46
47        for node in graph.nodes() {
48            match node.op_type.as_str() {
49                "Conv" | "MaxPool" | "AveragePool" => conv_count += 1,
50                _ => other_count += 1,
51            }
52        }
53
54        // If mostly convolutions, use NCHW (better for GPU)
55        // Otherwise use NHWC (better for CPU)
56        if conv_count > other_count / 2 {
57            Ok(TensorLayout::NCHW)
58        } else {
59            Ok(TensorLayout::NHWC)
60        }
61    }
62
63    /// Insert layout transformation nodes where needed
64    fn insert_layout_transforms(
65        &self,
66        graph: &mut ModelGraph,
67        _target_layout: TensorLayout,
68    ) -> Result<usize> {
69        // Find places where layout needs to change
70        // Insert Transpose nodes to convert between layouts
71        // This is a simplified version
72        Ok(0)
73    }
74}