// trueno/blis/backend_selection.rs
//! Backend Selection and Cost Model
//!
//! Automatic selection between CPU (SIMD), CUDA (PTX), and wgpu (WGSL) backends
//! based on the 5× PCIe rule and roofline analysis.
//!
//! # Philosophy
//!
//! Uses Gregg & Hazelwood (2011) "5× PCIe rule": GPU worthwhile when
//! compute time exceeds 5× data transfer time.
//!
//! # References
//!
//! - Gregg, C., & Hazelwood, K. (2011). Where is the Data? Why You Cannot
//!   Debate CPU vs. GPU Performance Without the Answer. IEEE ISPASS.
//! - Volkov, V. (2010). Better Performance at Lower Occupancy.

#[cfg(target_arch = "x86_64")]
use std::arch::is_x86_feature_detected;
use std::fmt::Write;

use super::profiler::BlisProfiler;
use super::{gemm_blis, TruenoError};

/// Compute backend selected for kernel dispatch.
///
/// Maps to different ISA targets:
/// - Cpu: x86 asm (AVX2/AVX-512), ARM asm (NEON)
/// - Gpu: PTX (CUDA), wgpu compute shaders
/// - Wgpu: WGSL for cross-platform GPU (Vulkan/Metal/DX12/WebGPU)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeBackend {
    /// CPU SIMD backend (AVX2, AVX-512, NEON, SSE2)
    Cpu,
    /// NVIDIA GPU backend (PTX)
    Gpu,
    /// Cross-platform GPU backend (wgpu/WGSL)
    Wgpu,
    /// Scalar fallback (no SIMD)
    Scalar,
}

/// ComputeBrick hierarchy level
///
/// Mirrors the BLIS loop nest in the brick abstraction:
/// - Nano: Microkernel (MR×NR×K) - register file
/// - Micro: Midi loop (MC×NC×KC) - L1/L2 cache
/// - Meso: Macro loop (full M×N×K) - L3/DRAM
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BrickLevel {
    /// Register-level compute (MR×NR tile)
    Nano,
    /// Cache-level compute (MC×NC block)
    Micro,
    /// Memory-level compute (full matrix)
    Meso,
}

/// Cost model for backend selection
///
/// Based on Gregg & Hazelwood (2011): the GPU is only worthwhile when
/// compute time exceeds roughly 5× the PCIe transfer time.
#[derive(Debug, Clone)]
pub struct BackendCostModel {
    /// PCIe bandwidth in GB/s (e.g., 15.75 for PCIe 3.0 x16)
    pub pcie_bandwidth_gbps: f64,
    /// GPU peak TFLOP/s
    pub gpu_peak_tflops: f64,
    /// CPU peak GFLOP/s
    pub cpu_peak_gflops: f64,
    /// Minimum problem size for GPU (elements)
    pub gpu_min_elements: usize,
}

71/// Modern AVX2 CPU peak compute in GFLOP/s
72const DEFAULT_CPU_PEAK_GFLOPS: f64 = 400.0;
73
74impl Default for BackendCostModel {
75    fn default() -> Self {
76        Self {
77            pcie_bandwidth_gbps: 15.75, // PCIe 3.0 x16
78            gpu_peak_tflops: 10.0,      // Mid-range GPU
79            cpu_peak_gflops: DEFAULT_CPU_PEAK_GFLOPS,
80            gpu_min_elements: 1_000_000, // ~1M elements
81        }
82    }
83}
84
85impl BackendCostModel {
86    /// Select optimal backend based on 5× PCIe rule
87    ///
88    /// # References
89    ///
90    /// Gregg, C., & Hazelwood, K. (2011). Where is the Data? Why You Cannot
91    /// Debate CPU vs. GPU Performance Without the Answer. IEEE ISPASS.
92    pub fn select_backend(&self, m: usize, n: usize, k: usize) -> ComputeBackend {
93        let flops = 2 * m * n * k;
94        let bytes = 4 * (m * k + k * n + m * n); // f32 = 4 bytes
95        let arithmetic_intensity = flops as f64 / bytes as f64;
96
97        // Ridge point: where compute = memory bandwidth
98        let ridge_point = self.gpu_peak_tflops * 1000.0 / self.pcie_bandwidth_gbps;
99
100        // GPU worthwhile if:
101        // 1. High arithmetic intensity (compute-bound)
102        // 2. Problem size exceeds minimum threshold
103        // 3. Transfer time is amortized (5× rule)
104        let elements = m * n * k;
105        if arithmetic_intensity > ridge_point && elements > self.gpu_min_elements {
106            // Check if wgpu available at runtime
107            #[cfg(feature = "wgpu")]
108            return ComputeBackend::Wgpu;
109
110            #[cfg(all(not(feature = "wgpu"), feature = "cuda"))]
111            return ComputeBackend::Gpu;
112
113            #[allow(unreachable_code)]
114            ComputeBackend::Cpu
115        } else {
116            // CPU is better for small problems or memory-bound workloads
117            #[cfg(target_arch = "x86_64")]
118            {
119                if is_x86_feature_detected!("avx2") {
120                    return ComputeBackend::Cpu;
121                }
122            }
123            #[cfg(target_arch = "aarch64")]
124            {
125                return ComputeBackend::Cpu;
126            }
127            ComputeBackend::Scalar
128        }
129    }
130
131    /// Estimate execution time in microseconds
132    pub fn estimate_time_us(&self, m: usize, n: usize, k: usize, backend: ComputeBackend) -> f64 {
133        let flops = 2.0 * m as f64 * n as f64 * k as f64;
134        let bytes = 4.0 * (m * k + k * n + m * n) as f64;
135
136        match backend {
137            ComputeBackend::Gpu | ComputeBackend::Wgpu => {
138                // Transfer time + compute time
139                let transfer_us = bytes / (self.pcie_bandwidth_gbps * 1e3);
140                let compute_us = flops / (self.gpu_peak_tflops * 1e6);
141                transfer_us + compute_us
142            }
143            ComputeBackend::Cpu => flops / (self.cpu_peak_gflops * 1e3),
144            ComputeBackend::Scalar => {
145                // Assume 1 GFLOP/s for scalar
146                flops / 1e3
147            }
148        }
149    }
150}
151
152/// Unified profiler for all backends
153///
154/// Collects metrics across CPU (RDTSC), GPU (CUDA events), and wgpu (timestamp queries)
155#[derive(Debug, Clone, Default)]
156pub struct UnifiedBrickProfiler {
157    /// CPU profiling stats
158    pub cpu_stats: BlisProfiler,
159    /// Selected backend for this run
160    pub backend: Option<ComputeBackend>,
161    /// Total elements processed
162    pub total_elements: u64,
163    /// Backend selection decisions
164    pub selection_history: Vec<(usize, usize, usize, ComputeBackend)>,
165}
166
167impl UnifiedBrickProfiler {
168    /// Create a new unified profiler
169    pub fn new() -> Self {
170        Self {
171            cpu_stats: BlisProfiler::enabled(),
172            backend: None,
173            total_elements: 0,
174            selection_history: Vec::new(),
175        }
176    }
177
178    /// Record backend selection
179    pub fn record_selection(&mut self, m: usize, n: usize, k: usize, backend: ComputeBackend) {
180        self.backend = Some(backend);
181        self.total_elements += (m * n) as u64;
182        self.selection_history.push((m, n, k, backend));
183    }
184
185    /// Get roofline analysis for current backend
186    pub fn roofline_analysis(&self, m: usize, n: usize, k: usize) -> RooflineResult {
187        let cost = BackendCostModel::default();
188        let flops = 2.0 * m as f64 * n as f64 * k as f64;
189        let bytes = 4.0 * (m * k + k * n + m * n) as f64;
190        let ai = flops / bytes;
191
192        let ridge_point = match self.backend.unwrap_or(ComputeBackend::Cpu) {
193            ComputeBackend::Gpu | ComputeBackend::Wgpu => {
194                cost.gpu_peak_tflops * 1000.0 / cost.pcie_bandwidth_gbps
195            }
196            ComputeBackend::Cpu | ComputeBackend::Scalar => {
197                cost.cpu_peak_gflops / 50.0 // ~50 GB/s memory bandwidth
198            }
199        };
200
201        if ai < ridge_point {
202            RooflineResult::MemoryBound { ai, ridge_point }
203        } else {
204            RooflineResult::ComputeBound { ai, ridge_point }
205        }
206    }
207
208    /// Generate summary report
209    pub fn summary(&self) -> String {
210        let mut s = String::new();
211        s.push_str("Unified Brick Profiler Summary\n");
212        s.push_str("==============================\n");
213        s.push_str(&format!("Backend: {:?}\n", self.backend.unwrap_or(ComputeBackend::Scalar)));
214        s.push_str(&format!("Total elements: {}\n", self.total_elements));
215        s.push_str(&format!("Selections: {} decisions\n", self.selection_history.len()));
216        s.push_str("\nCPU Stats:\n");
217        s.push_str(&self.cpu_stats.summary());
218        s
219    }
220}
221
/// Roofline model result
///
/// Classifies a workload by where its arithmetic intensity falls relative
/// to the ridge point of the roofline model.
#[derive(Debug, Clone, Copy)]
pub enum RooflineResult {
    /// Workload is memory-bound (AI < ridge point)
    MemoryBound {
        /// Arithmetic intensity (FLOP/byte)
        ai: f64,
        /// Ridge point where compute = memory
        ridge_point: f64,
    },
    /// Workload is compute-bound (AI > ridge point)
    ComputeBound {
        /// Arithmetic intensity (FLOP/byte)
        ai: f64,
        /// Ridge point where compute = memory
        ridge_point: f64,
    },
}

241impl RooflineResult {
242    /// Get arithmetic intensity
243    pub fn arithmetic_intensity(&self) -> f64 {
244        match self {
245            RooflineResult::MemoryBound { ai, .. } => *ai,
246            RooflineResult::ComputeBound { ai, .. } => *ai,
247        }
248    }
249
250    /// Check if compute-bound
251    pub fn is_compute_bound(&self) -> bool {
252        matches!(self, RooflineResult::ComputeBound { .. })
253    }
254}
255
/// PTX microkernel definition (for documentation and future CUDA support)
///
/// A specification for the GPU microkernel; actual PTX code generation is
/// the job of the trueno-ptx crate.
///
/// # References
///
/// - NVIDIA PTX ISA Reference Manual
/// - Volkov, V. (2010). Better Performance at Lower Occupancy.
#[derive(Debug, Clone)]
pub struct PtxMicrokernelSpec {
    /// PTX version (e.g., "8.0")
    pub ptx_version: &'static str,
    /// Target SM architecture (e.g., "sm_80")
    pub sm_target: &'static str,
    /// Register count per thread
    pub registers_per_thread: u32,
    /// Shared memory bytes per block
    pub smem_bytes: usize,
    /// Thread block dimensions
    pub block_dim: (u32, u32, u32),
    /// Tile dimensions (MR, NR)
    pub tile_dim: (usize, usize),
}

281impl Default for PtxMicrokernelSpec {
282    fn default() -> Self {
283        Self {
284            ptx_version: "8.0",
285            sm_target: "sm_80",
286            registers_per_thread: 64,
287            smem_bytes: 48 * 1024, // 48KB shared memory
288            block_dim: (16, 16, 1),
289            tile_dim: (16, 16), // 16x16 output tile per warp
290        }
291    }
292}
293
/// WGSL microkernel specification (for wgpu backend)
///
/// Describes the compute shader used for matrix multiplication.
#[derive(Debug, Clone)]
pub struct WgslMicrokernelSpec {
    /// Workgroup size (x, y, z)
    pub workgroup_size: (u32, u32, u32),
    /// Tile dimensions (MR, NR)
    pub tile_dim: (usize, usize),
    /// Use shared memory for tiling
    pub use_shared_memory: bool,
}

307impl Default for WgslMicrokernelSpec {
308    fn default() -> Self {
309        Self { workgroup_size: (8, 8, 1), tile_dim: (8, 8), use_shared_memory: true }
310    }
311}
312
313impl WgslMicrokernelSpec {
314    /// Generate WGSL shader source
315    ///
316    /// This generates a basic tiled GEMM shader. For production use,
317    /// this would be optimized with coalesced memory access and bank conflict avoidance.
318    pub fn generate_wgsl(&self) -> String {
319        format!(
320            r#"// WGSL GEMM Microkernel
321// Generated by trueno BLIS module
322// Tile: {}x{}, Workgroup: {}x{}x{}
323
324struct GemmParams {{
325    m: u32,
326    n: u32,
327    k: u32,
328    alpha: f32,
329    beta: f32,
330}}
331
332@group(0) @binding(0) var<uniform> params: GemmParams;
333@group(0) @binding(1) var<storage, read> a: array<f32>;
334@group(0) @binding(2) var<storage, read> b: array<f32>;
335@group(0) @binding(3) var<storage, read_write> c: array<f32>;
336
337var<workgroup> tile_a: array<f32, {tile_a_size}>;
338var<workgroup> tile_b: array<f32, {tile_b_size}>;
339
340@compute @workgroup_size({wx}, {wy}, {wz})
341fn main(
342    @builtin(global_invocation_id) global_id: vec3<u32>,
343    @builtin(local_invocation_id) local_id: vec3<u32>,
344    @builtin(workgroup_id) group_id: vec3<u32>,
345) {{
346    let row = global_id.y;
347    let col = global_id.x;
348
349    if (row >= params.m || col >= params.n) {{
350        return;
351    }}
352
353    var sum: f32 = 0.0;
354
355    // Tile over K dimension
356    let num_tiles = (params.k + {tile_k}u - 1u) / {tile_k}u;
357
358    for (var t: u32 = 0u; t < num_tiles; t++) {{
359        let k_base = t * {tile_k}u;
360
361        // Load tile_a and tile_b into shared memory
362        // (simplified - production code would have proper coalescing)
363        let k_idx = k_base + local_id.x;
364        if (row < params.m && k_idx < params.k) {{
365            tile_a[local_id.y * {tile_k}u + local_id.x] = a[row * params.k + k_idx];
366        }}
367        if (k_idx < params.k && col < params.n) {{
368            tile_b[local_id.y * {tile_k}u + local_id.x] = b[k_idx * params.n + col];
369        }}
370
371        workgroupBarrier();
372
373        // Compute partial sum
374        for (var kk: u32 = 0u; kk < {tile_k}u; kk++) {{
375            if (k_base + kk < params.k) {{
376                sum += tile_a[local_id.y * {tile_k}u + kk] * tile_b[kk * {tile_k}u + local_id.x];
377            }}
378        }}
379
380        workgroupBarrier();
381    }}
382
383    // Store result
384    let c_idx = row * params.n + col;
385    c[c_idx] = params.alpha * sum + params.beta * c[c_idx];
386}}
387"#,
388            self.tile_dim.0,
389            self.tile_dim.1,
390            self.workgroup_size.0,
391            self.workgroup_size.1,
392            self.workgroup_size.2,
393            tile_a_size = self.tile_dim.0 * self.tile_dim.0,
394            tile_b_size = self.tile_dim.0 * self.tile_dim.1,
395            wx = self.workgroup_size.0,
396            wy = self.workgroup_size.1,
397            wz = self.workgroup_size.2,
398            tile_k = self.tile_dim.0,
399        )
400    }
401}
402
403/// GEMM with automatic backend selection
404///
405/// Uses the 5× PCIe rule to select between CPU (asm) and GPU (PTX/WGSL) backends.
406pub fn gemm_auto(
407    m: usize,
408    n: usize,
409    k: usize,
410    a: &[f32],
411    b: &[f32],
412    c: &mut [f32],
413    profiler: Option<&mut UnifiedBrickProfiler>,
414) -> Result<(), TruenoError> {
415    let cost_model = BackendCostModel::default();
416    let backend = cost_model.select_backend(m, n, k);
417
418    if let Some(prof) = profiler {
419        prof.record_selection(m, n, k, backend);
420    }
421
422    match backend {
423        ComputeBackend::Cpu | ComputeBackend::Scalar => {
424            // Use BLIS CPU implementation
425            gemm_blis(m, n, k, a, b, c, None)
426        }
427        ComputeBackend::Gpu => {
428            // PTX backend (stub - requires CUDA support)
429            // For now, fall back to CPU
430            gemm_blis(m, n, k, a, b, c, None)
431        }
432        ComputeBackend::Wgpu => {
433            // WGSL backend (stub - requires wgpu support)
434            // For now, fall back to CPU
435            gemm_blis(m, n, k, a, b, c, None)
436        }
437    }
438}