use std::collections::HashMap;

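/// Tensor memory layouts the optimizer can choose between.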
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MemoryLayout {
    /// Batch, channels, height, width (the common framework default).
    NCHW,
    /// Batch, height, width, channels (channels-last).
    NHWC,
    /// Channels, height, width, batch.
    CHWN,
    /// Batch, sequence, feature (sequence data).
    BSF,
    /// Sequence, batch, feature (sequence data).
    SBF,
}

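/// Capabilities of the target device, used to steer layout decisions.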
#[derive(Clone, Debug)]
pub struct DeviceInfo {
    pub has_tensor_cores: bool,
    /// CUDA compute capability as (major, minor).
    pub compute_capability: (u32, u32),
    /// Peak memory bandwidth in GB/s.
    pub memory_bandwidth: f64,
    pub is_ampere_or_newer: bool,
}

impl DeviceInfo {
    #[cfg(feature = "cuda")]
    pub fn detect() -> Self {
        // Hardcoded A100-class defaults (SM 8.0, ~1555 GB/s); a real
        // implementation would query the CUDA runtime instead.
        Self {
            has_tensor_cores: true,
            compute_capability: (8, 0),
            memory_bandwidth: 1555.0,
            is_ampere_or_newer: true,
        }
    }

    #[cfg(not(feature = "cuda"))]
    pub fn detect() -> Self {
        // Without CUDA support, report a device with no GPU capabilities.
        Self {
            has_tensor_cores: false,
            compute_capability: (0, 0),
            memory_bandwidth: 0.0,
            is_ampere_or_newer: false,
        }
    }
}

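/// Operation kinds the optimizer knows how to pick a layout for.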
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum OperationType {
    Conv2d { kernel: (usize, usize), stride: (usize, usize) },
    MatMul { m: usize, n: usize, k: usize },
    BatchNorm,
    LayerNorm,
    Attention { heads: usize, seq_len: usize },
    ElementWise,
    Pooling,
}

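/// Chooses a memory layout per operation for the detected device and
/// caches each decision, since layout selection is deterministic per op.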
pub struct LayoutOptimizer {
    device_info: DeviceInfo,
    layout_cache: HashMap<OperationType, MemoryLayout>,
}

impl LayoutOptimizer {
    pub fn new() -> Self {
        Self {
            device_info: DeviceInfo::detect(),
            layout_cache: HashMap::new(),
        }
    }

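    /// Returns the preferred layout for `op`, consulting the cache first.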
    pub fn choose_layout(&mut self, op: &OperationType) -> MemoryLayout {
        if let Some(&layout) = self.layout_cache.get(op) {
            return layout;
        }

        let layout = self.compute_optimal_layout(op);
        self.layout_cache.insert(op.clone(), layout);
        layout
    }

    fn compute_optimal_layout(&self, op: &OperationType) -> MemoryLayout {
        match op {
            OperationType::Conv2d { kernel, stride } => {
                if self.device_info.has_tensor_cores {
                    // Tensor cores strongly favor channels-last convolutions.
                    MemoryLayout::NHWC
                } else if kernel.0 == 1 && kernel.1 == 1 {
                    // 1x1 convolutions behave like matmuls; keep NCHW.
                    MemoryLayout::NCHW
                } else if stride.0 > 1 || stride.1 > 1 {
                    MemoryLayout::NHWC
                } else {
                    MemoryLayout::NCHW
                }
            },

            // Matmul operands use plain row-major (NCHW-style) storage on
            // every path, including the tensor-core one.
            OperationType::MatMul { .. } => MemoryLayout::NCHW,

            OperationType::BatchNorm => MemoryLayout::NCHW,

            OperationType::LayerNorm => MemoryLayout::NCHW,

            // Attention tensors use batch-major BSF regardless of seq_len.
            OperationType::Attention { .. } => MemoryLayout::BSF,

            OperationType::ElementWise => MemoryLayout::NCHW,

            OperationType::Pooling => {
                if self.device_info.has_tensor_cores {
                    MemoryLayout::NHWC
                } else {
                    MemoryLayout::NCHW
                }
            },
        }
    }

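    /// Converts `data` from one layout to another. `shape` is always given
    /// in logical (N, C, H, W) order, regardless of the physical layout.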
    pub fn transform_layout(
        &self,
        data: &[f32],
        from: MemoryLayout,
        to: MemoryLayout,
        shape: &[usize],
    ) -> Vec<f32> {
        if from == to {
            return data.to_vec();
        }

        match (from, to) {
            (MemoryLayout::NCHW, MemoryLayout::NHWC) => self.nchw_to_nhwc(data, shape),
            (MemoryLayout::NHWC, MemoryLayout::NCHW) => self.nhwc_to_nchw(data, shape),
            // Other layout pairs are not implemented; fall back to a copy.
            _ => data.to_vec(),
        }
    }

    fn nchw_to_nhwc(&self, data: &[f32], shape: &[usize]) -> Vec<f32> {
        // `shape` is given in logical (N, C, H, W) order.
        let n = shape[0];
        let c = shape[1];
        let h = shape[2];
        let w = shape[3];
        debug_assert_eq!(data.len(), n * c * h * w);

        let mut output = vec![0.0f32; data.len()];

        // Loop order matches the NCHW input so reads are sequential.
        for batch in 0..n {
            for channel in 0..c {
                for height in 0..h {
                    for width in 0..w {
                        let nchw_idx = ((batch * c + channel) * h + height) * w + width;
                        let nhwc_idx = ((batch * h + height) * w + width) * c + channel;
                        output[nhwc_idx] = data[nchw_idx];
                    }
                }
            }
        }

        output
    }

    fn nhwc_to_nchw(&self, data: &[f32], shape: &[usize]) -> Vec<f32> {
        // `shape` is still given in logical (N, C, H, W) order, even though
        // the input buffer is laid out as NHWC.
        let n = shape[0];
        let c = shape[1];
        let h = shape[2];
        let w = shape[3];
        debug_assert_eq!(data.len(), n * c * h * w);

        let mut output = vec![0.0f32; data.len()];

        // Loop order matches the NHWC input so reads are sequential.
        for batch in 0..n {
            for height in 0..h {
                for width in 0..w {
                    for channel in 0..c {
                        let nhwc_idx = ((batch * h + height) * w + width) * c + channel;
                        let nchw_idx = ((batch * c + channel) * h + height) * w + width;
                        output[nchw_idx] = data[nhwc_idx];
                    }
                }
            }
        }

        output
    }

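    /// Estimates relative throughput for running `op` in `layout`,
    /// normalized to a 1.0 baseline.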
    pub fn estimate_performance(
        &self,
        op: &OperationType,
        layout: MemoryLayout,
    ) -> f64 {
        match (op, layout) {
            (OperationType::Conv2d { .. }, MemoryLayout::NHWC)
                if self.device_info.has_tensor_cores => 1.3,
            (OperationType::Conv2d { .. }, MemoryLayout::NCHW) => 1.0,
            (OperationType::MatMul { .. }, _) if self.device_info.has_tensor_cores => 1.5,
            _ => 1.0,
        }
    }

    pub fn clear_cache(&mut self) {
        self.layout_cache.clear();
    }
}

impl Default for LayoutOptimizer {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layout_selection() {
        let mut optimizer = LayoutOptimizer::new();

        let conv_op = OperationType::Conv2d {
            kernel: (3, 3),
            stride: (1, 1),
        };

        let layout = optimizer.choose_layout(&conv_op);
        assert!(layout == MemoryLayout::NCHW || layout == MemoryLayout::NHWC);
    }

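    // Cache check: repeated lookups for the same op must agree, since
    // choose_layout memoizes its first decision.
    #[test]
    fn test_layout_cache_consistency() {
        let mut optimizer = LayoutOptimizer::new();

        let op = OperationType::MatMul { m: 32, n: 32, k: 32 };
        let first = optimizer.choose_layout(&op);
        let second = optimizer.choose_layout(&op);
        assert_eq!(first, second);
    }
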
    #[test]
    fn test_layout_transformation() {
        let optimizer = LayoutOptimizer::new();

        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let shape = vec![1, 2, 2, 2];

        let nhwc = optimizer.transform_layout(
            &data,
            MemoryLayout::NCHW,
            MemoryLayout::NHWC,
            &shape,
        );

        assert_eq!(nhwc.len(), data.len());
    }

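    // Round-trip check: NCHW -> NHWC -> NCHW must reproduce the input
    // exactly, since the two transforms are inverses of each other.
    #[test]
    fn test_layout_round_trip() {
        let optimizer = LayoutOptimizer::new();

        let data: Vec<f32> = (1..=24).map(|x| x as f32).collect();
        let shape = vec![2, 3, 2, 2]; // N=2, C=3, H=2, W=2

        let nhwc = optimizer.transform_layout(&data, MemoryLayout::NCHW, MemoryLayout::NHWC, &shape);
        let back = optimizer.transform_layout(&nhwc, MemoryLayout::NHWC, MemoryLayout::NCHW, &shape);
        assert_eq!(back, data);
    }
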
    #[test]
    fn test_performance_estimate() {
        let optimizer = LayoutOptimizer::new();

        let conv_op = OperationType::Conv2d {
            kernel: (3, 3),
            stride: (1, 1),
        };

        let perf = optimizer.estimate_performance(&conv_op, MemoryLayout::NHWC);
        assert!(perf >= 1.0);
    }
}