1#[cfg(target_arch = "x86_64")]
18use std::arch::is_x86_feature_detected;
19
20use super::profiler::BlisProfiler;
21use super::{gemm_blis, TruenoError};
22
/// Execution backend chosen by [`BackendCostModel::select_backend`] for a GEMM.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeBackend {
    /// Optimized CPU path (selected when AVX2 is detected on x86_64, or
    /// unconditionally on aarch64 — see `select_backend`).
    Cpu,
    /// CUDA GPU path; only selected when the `cuda` cargo feature is enabled.
    Gpu,
    /// WebGPU/WGSL path; only selected when the `wgpu` cargo feature is enabled.
    Wgpu,
    /// Plain scalar fallback when no SIMD capability is detected.
    Scalar,
}
39
/// Granularity level of a "brick" (tiling unit) in the GEMM hierarchy.
///
/// NOTE(review): not referenced anywhere in this chunk of the module —
/// presumably consumed elsewhere in the crate; verify before removing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BrickLevel {
    /// Smallest tiling unit.
    Nano,
    /// Intermediate tiling unit.
    Micro,
    /// Largest of the three tiling units.
    Meso,
}
55
/// Roofline-style cost model used to pick a [`ComputeBackend`] for a GEMM.
#[derive(Debug, Clone)]
pub struct BackendCostModel {
    /// Host <-> GPU transfer bandwidth in GB/s (PCIe).
    pub pcie_bandwidth_gbps: f64,
    /// GPU peak throughput in TFLOP/s.
    pub gpu_peak_tflops: f64,
    /// CPU peak throughput in GFLOP/s.
    pub cpu_peak_gflops: f64,
    /// Minimum problem size before GPU offload is considered.
    /// NOTE(review): `select_backend` compares this against `m * n * k`
    /// (the MAC count), not the number of matrix elements — confirm intent.
    pub gpu_min_elements: usize,
}
70
/// Default assumed CPU peak throughput (GFLOP/s) for the cost model.
const DEFAULT_CPU_PEAK_GFLOPS: f64 = 400.0;
73
74impl Default for BackendCostModel {
75 fn default() -> Self {
76 Self {
77 pcie_bandwidth_gbps: 15.75, gpu_peak_tflops: 10.0, cpu_peak_gflops: DEFAULT_CPU_PEAK_GFLOPS,
80 gpu_min_elements: 1_000_000, }
82 }
83}
84
85impl BackendCostModel {
86 pub fn select_backend(&self, m: usize, n: usize, k: usize) -> ComputeBackend {
93 let flops = 2 * m * n * k;
94 let bytes = 4 * (m * k + k * n + m * n); let arithmetic_intensity = flops as f64 / bytes as f64;
96
97 let ridge_point = self.gpu_peak_tflops * 1000.0 / self.pcie_bandwidth_gbps;
99
100 let elements = m * n * k;
105 if arithmetic_intensity > ridge_point && elements > self.gpu_min_elements {
106 #[cfg(feature = "wgpu")]
108 return ComputeBackend::Wgpu;
109
110 #[cfg(all(not(feature = "wgpu"), feature = "cuda"))]
111 return ComputeBackend::Gpu;
112
113 #[allow(unreachable_code)]
114 ComputeBackend::Cpu
115 } else {
116 #[cfg(target_arch = "x86_64")]
118 {
119 if is_x86_feature_detected!("avx2") {
120 return ComputeBackend::Cpu;
121 }
122 }
123 #[cfg(target_arch = "aarch64")]
124 {
125 return ComputeBackend::Cpu;
126 }
127 ComputeBackend::Scalar
128 }
129 }
130
131 pub fn estimate_time_us(&self, m: usize, n: usize, k: usize, backend: ComputeBackend) -> f64 {
133 let flops = 2.0 * m as f64 * n as f64 * k as f64;
134 let bytes = 4.0 * (m * k + k * n + m * n) as f64;
135
136 match backend {
137 ComputeBackend::Gpu | ComputeBackend::Wgpu => {
138 let transfer_us = bytes / (self.pcie_bandwidth_gbps * 1e3);
140 let compute_us = flops / (self.gpu_peak_tflops * 1e6);
141 transfer_us + compute_us
142 }
143 ComputeBackend::Cpu => flops / (self.cpu_peak_gflops * 1e3),
144 ComputeBackend::Scalar => {
145 flops / 1e3
147 }
148 }
149 }
150}
151
/// Profiler that tracks backend-selection decisions across GEMM calls.
///
/// NOTE(review): `#[derive(Default)]` builds `cpu_stats` via
/// `BlisProfiler::default()`, while `new()` uses `BlisProfiler::enabled()` —
/// the two constructors are not equivalent; confirm that is intended.
#[derive(Debug, Clone, Default)]
pub struct UnifiedBrickProfiler {
    /// CPU-side BLIS profiling statistics.
    pub cpu_stats: BlisProfiler,
    /// Most recently selected backend, if any selection was recorded.
    pub backend: Option<ComputeBackend>,
    /// Running total of output elements (`m * n`) across recorded selections.
    pub total_elements: u64,
    /// Every recorded decision as `(m, n, k, backend)`.
    pub selection_history: Vec<(usize, usize, usize, ComputeBackend)>,
}
166
167impl UnifiedBrickProfiler {
168 pub fn new() -> Self {
170 Self {
171 cpu_stats: BlisProfiler::enabled(),
172 backend: None,
173 total_elements: 0,
174 selection_history: Vec::new(),
175 }
176 }
177
178 pub fn record_selection(&mut self, m: usize, n: usize, k: usize, backend: ComputeBackend) {
180 self.backend = Some(backend);
181 self.total_elements += (m * n) as u64;
182 self.selection_history.push((m, n, k, backend));
183 }
184
185 pub fn roofline_analysis(&self, m: usize, n: usize, k: usize) -> RooflineResult {
187 let cost = BackendCostModel::default();
188 let flops = 2.0 * m as f64 * n as f64 * k as f64;
189 let bytes = 4.0 * (m * k + k * n + m * n) as f64;
190 let ai = flops / bytes;
191
192 let ridge_point = match self.backend.unwrap_or(ComputeBackend::Cpu) {
193 ComputeBackend::Gpu | ComputeBackend::Wgpu => {
194 cost.gpu_peak_tflops * 1000.0 / cost.pcie_bandwidth_gbps
195 }
196 ComputeBackend::Cpu | ComputeBackend::Scalar => {
197 cost.cpu_peak_gflops / 50.0 }
199 };
200
201 if ai < ridge_point {
202 RooflineResult::MemoryBound { ai, ridge_point }
203 } else {
204 RooflineResult::ComputeBound { ai, ridge_point }
205 }
206 }
207
208 pub fn summary(&self) -> String {
210 let mut s = String::new();
211 s.push_str("Unified Brick Profiler Summary\n");
212 s.push_str("==============================\n");
213 s.push_str(&format!("Backend: {:?}\n", self.backend.unwrap_or(ComputeBackend::Scalar)));
214 s.push_str(&format!("Total elements: {}\n", self.total_elements));
215 s.push_str(&format!("Selections: {} decisions\n", self.selection_history.len()));
216 s.push_str("\nCPU Stats:\n");
217 s.push_str(&self.cpu_stats.summary());
218 s
219 }
220}
221
/// Outcome of a roofline classification for a single GEMM shape.
#[derive(Debug, Clone, Copy)]
pub enum RooflineResult {
    /// Arithmetic intensity is below the ridge point: bandwidth-limited.
    MemoryBound {
        /// Arithmetic intensity in FLOPs per byte.
        ai: f64,
        /// Ridge point the intensity was compared against.
        ridge_point: f64,
    },
    /// Arithmetic intensity is at or above the ridge point: compute-limited.
    ComputeBound {
        /// Arithmetic intensity in FLOPs per byte.
        ai: f64,
        /// Ridge point the intensity was compared against.
        ridge_point: f64,
    },
}
240
241impl RooflineResult {
242 pub fn arithmetic_intensity(&self) -> f64 {
244 match self {
245 RooflineResult::MemoryBound { ai, .. } => *ai,
246 RooflineResult::ComputeBound { ai, .. } => *ai,
247 }
248 }
249
250 pub fn is_compute_bound(&self) -> bool {
252 matches!(self, RooflineResult::ComputeBound { .. })
253 }
254}
255
/// Launch/codegen parameters describing a CUDA PTX GEMM microkernel.
#[derive(Debug, Clone)]
pub struct PtxMicrokernelSpec {
    /// Target PTX ISA version string (e.g. "8.0").
    pub ptx_version: &'static str,
    /// Target SM architecture string (e.g. "sm_80").
    pub sm_target: &'static str,
    /// Register budget per thread.
    pub registers_per_thread: u32,
    /// Shared memory budget per block, in bytes.
    pub smem_bytes: usize,
    /// CUDA block dimensions (x, y, z).
    pub block_dim: (u32, u32, u32),
    /// Output tile computed per block, as (rows, cols).
    pub tile_dim: (usize, usize),
}
280
281impl Default for PtxMicrokernelSpec {
282 fn default() -> Self {
283 Self {
284 ptx_version: "8.0",
285 sm_target: "sm_80",
286 registers_per_thread: 64,
287 smem_bytes: 48 * 1024, block_dim: (16, 16, 1),
289 tile_dim: (16, 16), }
291 }
292}
293
/// Parameters describing a generated WGSL GEMM microkernel.
#[derive(Debug, Clone)]
pub struct WgslMicrokernelSpec {
    /// Compute shader workgroup size (x, y, z).
    pub workgroup_size: (u32, u32, u32),
    /// Output tile per workgroup, as (rows, cols).
    pub tile_dim: (usize, usize),
    /// Whether the kernel should stage tiles in workgroup (shared) memory.
    /// NOTE(review): `generate_wgsl` does not consult this flag — it always
    /// emits the shared-memory variant; confirm whether that is intended.
    pub use_shared_memory: bool,
}
306
307impl Default for WgslMicrokernelSpec {
308 fn default() -> Self {
309 Self { workgroup_size: (8, 8, 1), tile_dim: (8, 8), use_shared_memory: true }
310 }
311}
312
313impl WgslMicrokernelSpec {
314 pub fn generate_wgsl(&self) -> String {
319 format!(
320 r#"// WGSL GEMM Microkernel
321// Generated by trueno BLIS module
322// Tile: {}x{}, Workgroup: {}x{}x{}
323
324struct GemmParams {{
325 m: u32,
326 n: u32,
327 k: u32,
328 alpha: f32,
329 beta: f32,
330}}
331
332@group(0) @binding(0) var<uniform> params: GemmParams;
333@group(0) @binding(1) var<storage, read> a: array<f32>;
334@group(0) @binding(2) var<storage, read> b: array<f32>;
335@group(0) @binding(3) var<storage, read_write> c: array<f32>;
336
337var<workgroup> tile_a: array<f32, {tile_a_size}>;
338var<workgroup> tile_b: array<f32, {tile_b_size}>;
339
340@compute @workgroup_size({wx}, {wy}, {wz})
341fn main(
342 @builtin(global_invocation_id) global_id: vec3<u32>,
343 @builtin(local_invocation_id) local_id: vec3<u32>,
344 @builtin(workgroup_id) group_id: vec3<u32>,
345) {{
346 let row = global_id.y;
347 let col = global_id.x;
348
349 if (row >= params.m || col >= params.n) {{
350 return;
351 }}
352
353 var sum: f32 = 0.0;
354
355 // Tile over K dimension
356 let num_tiles = (params.k + {tile_k}u - 1u) / {tile_k}u;
357
358 for (var t: u32 = 0u; t < num_tiles; t++) {{
359 let k_base = t * {tile_k}u;
360
361 // Load tile_a and tile_b into shared memory
362 // (simplified - production code would have proper coalescing)
363 let k_idx = k_base + local_id.x;
364 if (row < params.m && k_idx < params.k) {{
365 tile_a[local_id.y * {tile_k}u + local_id.x] = a[row * params.k + k_idx];
366 }}
367 if (k_idx < params.k && col < params.n) {{
368 tile_b[local_id.y * {tile_k}u + local_id.x] = b[k_idx * params.n + col];
369 }}
370
371 workgroupBarrier();
372
373 // Compute partial sum
374 for (var kk: u32 = 0u; kk < {tile_k}u; kk++) {{
375 if (k_base + kk < params.k) {{
376 sum += tile_a[local_id.y * {tile_k}u + kk] * tile_b[kk * {tile_k}u + local_id.x];
377 }}
378 }}
379
380 workgroupBarrier();
381 }}
382
383 // Store result
384 let c_idx = row * params.n + col;
385 c[c_idx] = params.alpha * sum + params.beta * c[c_idx];
386}}
387"#,
388 self.tile_dim.0,
389 self.tile_dim.1,
390 self.workgroup_size.0,
391 self.workgroup_size.1,
392 self.workgroup_size.2,
393 tile_a_size = self.tile_dim.0 * self.tile_dim.0,
394 tile_b_size = self.tile_dim.0 * self.tile_dim.1,
395 wx = self.workgroup_size.0,
396 wy = self.workgroup_size.1,
397 wz = self.workgroup_size.2,
398 tile_k = self.tile_dim.0,
399 )
400 }
401}
402
403pub fn gemm_auto(
407 m: usize,
408 n: usize,
409 k: usize,
410 a: &[f32],
411 b: &[f32],
412 c: &mut [f32],
413 profiler: Option<&mut UnifiedBrickProfiler>,
414) -> Result<(), TruenoError> {
415 let cost_model = BackendCostModel::default();
416 let backend = cost_model.select_backend(m, n, k);
417
418 if let Some(prof) = profiler {
419 prof.record_selection(m, n, k, backend);
420 }
421
422 match backend {
423 ComputeBackend::Cpu | ComputeBackend::Scalar => {
424 gemm_blis(m, n, k, a, b, c, None)
426 }
427 ComputeBackend::Gpu => {
428 gemm_blis(m, n, k, a, b, c, None)
431 }
432 ComputeBackend::Wgpu => {
433 gemm_blis(m, n, k, a, b, c, None)
436 }
437 }
438}