ghostflow_core/layout/
mod.rs

//! Memory Layout Optimizer
//!
//! Automatically chooses the best memory layout for each operation based on
//! the operation type and the target device's capabilities; this is one of
//! the main levers for closing the performance gap with JAX.
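//!
//! A minimal usage sketch (it assumes this module is exposed as
//! `ghostflow_core::layout`; the snippet is not compiled as a doc-test):
//!
//! ```ignore
//! use ghostflow_core::layout::{LayoutOptimizer, MemoryLayout, OperationType};
//!
//! let mut optimizer = LayoutOptimizer::new();
//! let conv = OperationType::Conv2d { kernel: (3, 3), stride: (1, 1) };
//! let layout = optimizer.choose_layout(&conv);
//! assert!(layout == MemoryLayout::NCHW || layout == MemoryLayout::NHWC);
//! ```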

use std::collections::HashMap;

/// Memory layout formats
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MemoryLayout {
    /// Batch, Channel, Height, Width (PyTorch default)
    NCHW,
    /// Batch, Height, Width, Channel (TensorFlow default, Tensor Core friendly)
    NHWC,
    /// Channel, Height, Width, Batch (rarely used)
    CHWN,
    /// Batch, Sequence, Features (for transformers)
    BSF,
    /// Sequence, Batch, Features (for RNNs)
    SBF,
}
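
// Index arithmetic relating the two most common layouts (this is exactly what
// the transform helpers further down implement): for a tensor of logical
// shape [N, C, H, W],
//   NCHW offset = ((n * C + c) * H + h) * W + w
//   NHWC offset = ((n * H + h) * W + w) * C + c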

/// Device capabilities
#[derive(Clone, Debug)]
pub struct DeviceInfo {
    pub has_tensor_cores: bool,
    pub compute_capability: (u32, u32),
    pub memory_bandwidth: f64, // GB/s
    pub is_ampere_or_newer: bool,
}

impl DeviceInfo {
    /// Detect device capabilities
    #[cfg(feature = "cuda")]
    pub fn detect() -> Self {
        // Placeholder: a real implementation would query the active CUDA
        // device. For now, assume an A100-class GPU (Ampere, SM 8.0,
        // ~1555 GB/s of memory bandwidth).
        Self {
            has_tensor_cores: true,
            compute_capability: (8, 0), // Ampere
            memory_bandwidth: 1555.0,
            is_ampere_or_newer: true,
        }
    }

    #[cfg(not(feature = "cuda"))]
    pub fn detect() -> Self {
        Self {
            has_tensor_cores: false,
            compute_capability: (0, 0),
            memory_bandwidth: 0.0,
            is_ampere_or_newer: false,
        }
    }
}

/// Operation types for layout selection
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum OperationType {
    Conv2d { kernel: (usize, usize), stride: (usize, usize) },
    MatMul { m: usize, n: usize, k: usize },
    BatchNorm,
    LayerNorm,
    Attention { heads: usize, seq_len: usize },
    ElementWise,
    Pooling,
}

/// Layout optimizer
pub struct LayoutOptimizer {
    device_info: DeviceInfo,
    layout_cache: HashMap<OperationType, MemoryLayout>,
}

impl LayoutOptimizer {
    /// Create a new layout optimizer
    pub fn new() -> Self {
        Self {
            device_info: DeviceInfo::detect(),
            layout_cache: HashMap::new(),
        }
    }

    /// Choose optimal layout for an operation
    pub fn choose_layout(&mut self, op: &OperationType) -> MemoryLayout {
        // Check cache first
        if let Some(&layout) = self.layout_cache.get(op) {
            return layout;
        }

        let layout = self.compute_optimal_layout(op);
        self.layout_cache.insert(op.clone(), layout);
        layout
    }

    /// Compute optimal layout based on operation and device
    fn compute_optimal_layout(&self, op: &OperationType) -> MemoryLayout {
        match op {
            // Convolution layout selection
            OperationType::Conv2d { kernel, stride } => {
                if self.device_info.has_tensor_cores {
                    // Tensor cores prefer NHWC
                    MemoryLayout::NHWC
                } else if kernel.0 == 1 && kernel.1 == 1 {
                    // 1x1 convolutions are memory-bound, use NCHW
                    MemoryLayout::NCHW
                } else if stride.0 > 1 || stride.1 > 1 {
                    // Strided convolutions benefit from NHWC
                    MemoryLayout::NHWC
                } else {
                    // Default to cuDNN's preference
                    MemoryLayout::NCHW
                }
            },

            // Matrix multiplication: the 4-D image layouts don't really apply
            // to 2-D GEMM, so NCHW serves as the default in both branches; the
            // alignment check only documents when the tensor-core path is
            // expected to be taken.
            OperationType::MatMul { m, n, k } => {
                if self.device_info.has_tensor_cores && m % 16 == 0 && n % 16 == 0 && k % 16 == 0 {
                    // Tensor cores work best with dimensions aligned to multiples of 16
                    MemoryLayout::NCHW
                } else {
                    MemoryLayout::NCHW
                }
            },

            // BatchNorm prefers NCHW for better vectorization
            OperationType::BatchNorm => MemoryLayout::NCHW,

            // LayerNorm is largely layout-agnostic; keep NCHW
            OperationType::LayerNorm => MemoryLayout::NCHW,

            // Attention operations: BSF in both branches today, but long
            // sequences are the case that actually benefits from it.
            OperationType::Attention { heads: _heads, seq_len } => {
                if *seq_len > 512 {
                    // Long sequences benefit from BSF layout
                    MemoryLayout::BSF
                } else {
                    // Short sequences can use either; default to BSF
                    MemoryLayout::BSF
                }
            },

            // Element-wise operations are layout-agnostic
            OperationType::ElementWise => MemoryLayout::NCHW,

            // Pooling prefers NHWC on tensor cores
            OperationType::Pooling => {
                if self.device_info.has_tensor_cores {
                    MemoryLayout::NHWC
                } else {
                    MemoryLayout::NCHW
                }
            },
        }
    }

    /// Transform tensor from one layout to another
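    ///
    /// Note: `shape` is always interpreted in NCHW order ([N, C, H, W]),
    /// regardless of `from`. Layout pairs without a dedicated conversion
    /// currently fall back to a plain copy of the input.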
    pub fn transform_layout(
        &self,
        data: &[f32],
        from: MemoryLayout,
        to: MemoryLayout,
        shape: &[usize],
    ) -> Vec<f32> {
        if from == to {
            return data.to_vec();
        }

        match (from, to) {
            (MemoryLayout::NCHW, MemoryLayout::NHWC) => {
                self.nchw_to_nhwc(data, shape)
            },
            (MemoryLayout::NHWC, MemoryLayout::NCHW) => {
                self.nhwc_to_nchw(data, shape)
            },
            _ => data.to_vec(), // Fallback
        }
    }

    /// Convert NCHW to NHWC
    fn nchw_to_nhwc(&self, data: &[f32], shape: &[usize]) -> Vec<f32> {
        // shape is in NCHW format: [N, C, H, W]
        let n = shape[0];
        let c = shape[1];
        let h = shape[2];
        let w = shape[3];

        let mut output = vec![0.0f32; data.len()];

        for batch in 0..n {
            for channel in 0..c {
                for height in 0..h {
                    for width in 0..w {
                        let nchw_idx = ((batch * c + channel) * h + height) * w + width;
                        let nhwc_idx = ((batch * h + height) * w + width) * c + channel;
                        output[nhwc_idx] = data[nchw_idx];
                    }
                }
            }
        }

        output
    }

    /// Convert NHWC to NCHW
    fn nhwc_to_nchw(&self, data: &[f32], shape: &[usize]) -> Vec<f32> {
        // shape is in NCHW format: [N, C, H, W]
        let n = shape[0];
        let c = shape[1];
        let h = shape[2];
        let w = shape[3];

        let mut output = vec![0.0f32; data.len()];

        for batch in 0..n {
            for height in 0..h {
                for width in 0..w {
                    for channel in 0..c {
                        let nhwc_idx = ((batch * h + height) * w + width) * c + channel;
                        let nchw_idx = ((batch * c + channel) * h + height) * w + width;
                        output[nchw_idx] = data[nhwc_idx];
                    }
                }
            }
        }

        output
    }

    /// Get performance estimate for a layout choice
    pub fn estimate_performance(
        &self,
        op: &OperationType,
        layout: MemoryLayout,
    ) -> f64 {
        // Estimate relative performance (1.0 = baseline)
        match (op, layout) {
            (OperationType::Conv2d { .. }, MemoryLayout::NHWC) if self.device_info.has_tensor_cores => {
                1.3 // 30% faster with tensor cores
            },
            (OperationType::Conv2d { .. }, MemoryLayout::NCHW) => {
                1.0 // Baseline
            },
            (OperationType::MatMul { .. }, _) if self.device_info.has_tensor_cores => {
                1.5 // 50% faster with tensor cores
            },
            _ => 1.0,
        }
    }

    /// Clear layout cache
    pub fn clear_cache(&mut self) {
        self.layout_cache.clear();
    }
}

impl Default for LayoutOptimizer {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layout_selection() {
        let mut optimizer = LayoutOptimizer::new();

        let conv_op = OperationType::Conv2d {
            kernel: (3, 3),
            stride: (1, 1),
        };

        let layout = optimizer.choose_layout(&conv_op);
        assert!(layout == MemoryLayout::NCHW || layout == MemoryLayout::NHWC);
    }

    #[test]
    fn test_layout_transformation() {
        let optimizer = LayoutOptimizer::new();

        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let shape = vec![1, 2, 2, 2]; // N=1, C=2, H=2, W=2

        let nhwc = optimizer.transform_layout(
            &data,
            MemoryLayout::NCHW,
            MemoryLayout::NHWC,
            &shape,
        );

        assert_eq!(nhwc.len(), data.len());
    }

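    // Round-trip check: converting NCHW -> NHWC and back to NCHW should
    // reproduce the input buffer exactly.
    #[test]
    fn test_layout_round_trip() {
        let optimizer = LayoutOptimizer::new();

        let data: Vec<f32> = (0..24).map(|v| v as f32).collect();
        let shape = vec![2, 3, 2, 2]; // N=2, C=3, H=2, W=2

        let nhwc = optimizer.transform_layout(&data, MemoryLayout::NCHW, MemoryLayout::NHWC, &shape);
        let back = optimizer.transform_layout(&nhwc, MemoryLayout::NHWC, MemoryLayout::NCHW, &shape);

        assert_eq!(back, data);
    }
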
    #[test]
    fn test_performance_estimate() {
        let optimizer = LayoutOptimizer::new();

        let conv_op = OperationType::Conv2d {
            kernel: (3, 3),
            stride: (1, 1),
        };

        let perf = optimizer.estimate_performance(&conv_op, MemoryLayout::NHWC);
        assert!(perf >= 1.0);
    }
}