// batuta/backend.rs
//! Backend selection and cost model (per spec section 2.2)
//!
//! # Mixture-of-Experts (MoE) Routing
//!
//! This module implements adaptive backend selection using a Mixture-of-Experts approach.
//! The MoE router analyzes operation complexity and data size to select the optimal
//! compute backend (Scalar/SIMD/GPU).
//!
//! ## Operation Complexity Levels
//!
//! - **Low**: Element-wise operations (add, multiply, etc.) - Memory-bound, GPU rarely beneficial
//! - **Medium**: Reductions (dot product, sum, etc.) - Moderate compute, GPU at 100K+ elements
//! - **High**: Matrix operations (matmul, convolution) - Compute-intensive O(n²) or O(n³), GPU at 10K+ elements
//!
//! ## Usage Example
//!
//! ```rust
//! use batuta::backend::{BackendSelector, OpComplexity};
//!
//! let selector = BackendSelector::new();
//!
//! // Element-wise operation
//! let backend = selector.select_with_moe(OpComplexity::Low, 500_000);
//! // Returns: Scalar (at or below 1M threshold, memory-bound)
//!
//! // Matrix multiplication
//! let backend = selector.select_with_moe(OpComplexity::High, 50_000);
//! // Returns: GPU (above 10K threshold for O(n²) ops)
//! ```
//!
//! ## Performance Thresholds
//!
//! Based on empirical analysis and the 5× PCIe rule (Gregg & Hazelwood 2011):
//!
//! | Complexity | SIMD Threshold | GPU Threshold | Rationale |
//! |------------|---------------|---------------|-----------|
//! | Low        | 1M elements   | Never         | Memory-bound, PCIe overhead dominates |
//! | Medium     | 10K elements  | 100K elements | Moderate compute/transfer ratio |
//! | High       | 1K elements   | 10K elements  | O(n²/n³) complexity favors GPU |
//!
41use serde::{Deserialize, Serialize};
42
43#[cfg(feature = "trueno-integration")]
44use trueno::{Matrix, Vector};
45
/// Compute backend options, ordered from least to most hardware-specific.
///
/// Serializable via serde so a selection can be logged or persisted.
// `upper_case_acronyms` allowed deliberately: SIMD/GPU are the established
// spellings and renaming the variants would break serialized data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[allow(clippy::upper_case_acronyms)]
pub enum Backend {
    /// Scalar operations (baseline)
    Scalar,
    /// SIMD vectorization (AVX2, NEON)
    SIMD,
    /// GPU acceleration (WebGPU/Vulkan)
    GPU,
}
57
58impl std::fmt::Display for Backend {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 match self {
61 Backend::Scalar => write!(f, "Scalar"),
62 Backend::SIMD => write!(f, "SIMD"),
63 Backend::GPU => write!(f, "GPU"),
64 }
65 }
66}
67
/// Operation complexity for MoE (Mixture-of-Experts) routing.
///
/// Variants are declared in ascending compute intensity, so the derived
/// `Ord`/`PartialOrd` give `Low < Medium < High`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum OpComplexity {
    /// Simple operations (add, mul) - O(n), prefer SIMD unless very large
    Low,
    /// Moderate operations (dot, reduce) - O(n), GPU beneficial at 100K+ elements
    Medium,
    /// Complex operations (matmul, convolution) - O(n²) or O(n³), GPU beneficial at 10K+ elements
    High,
}
78
/// Cost model for backend selection
/// Based on spec section 2.2 lines 191-204
///
/// All fields are tunable through the `with_*` builder methods; defaults
/// come from the `Default` impl (PCIe 4.0 x16, A100-class GPU, 5× rule).
pub struct BackendSelector {
    /// PCIe bandwidth in bytes/sec (default: 32 GB/s for PCIe 4.0 x16)
    pcie_bandwidth: f64,

    /// GPU compute throughput in FLOPS (default: 20 TFLOPS for A100)
    // NOTE(review): despite the name, this field holds raw FLOPS (20e12),
    // not GFLOPS — the default and the divisions in select_backend agree.
    gpu_gflops: f64,

    /// Minimum dispatch ratio (default: 5× per Gregg & Hazelwood 2011)
    min_dispatch_ratio: f64,
}
91
92impl Default for BackendSelector {
93 fn default() -> Self {
94 Self {
95 pcie_bandwidth: 32e9, // 32 GB/s
96 gpu_gflops: 20e12, // 20 TFLOPS
97 min_dispatch_ratio: 5.0, // 5× rule
98 }
99 }
100}
101
102impl BackendSelector {
103 pub fn new() -> Self {
104 Self::default()
105 }
106
107 /// Configure custom PCIe bandwidth (must be > 0).
108 pub fn with_pcie_bandwidth(mut self, bandwidth: f64) -> Self {
109 assert!(bandwidth > 0.0, "PCIe bandwidth must be > 0");
110 self.pcie_bandwidth = bandwidth;
111 self
112 }
113
114 /// Configure custom GPU throughput (must be > 0).
115 pub fn with_gpu_gflops(mut self, gflops: f64) -> Self {
116 assert!(gflops > 0.0, "GPU GFLOPS must be > 0");
117 self.gpu_gflops = gflops;
118 self
119 }
120
121 /// Configure custom dispatch ratio threshold
122 pub fn with_min_dispatch_ratio(mut self, ratio: f64) -> Self {
123 self.min_dispatch_ratio = ratio;
124 self
125 }
126
127 /// Select optimal backend based on workload characteristics
128 ///
129 /// # Arguments
130 /// * `data_bytes` - Amount of data to transfer (host → device)
131 /// * `flops` - Floating point operations required
132 ///
133 /// # Returns
134 /// Recommended backend based on cost model
135 ///
136 /// # Cost Model
137 /// GPU dispatch is beneficial when:
138 /// ```text
139 /// compute_time > min_dispatch_ratio × transfer_time
140 /// ```
141 ///
142 /// Per Gregg & Hazelwood (2011), the 5× rule accounts for:
143 /// - Host→Device transfer (PCIe overhead)
144 /// - Kernel launch latency
145 /// - Device→Host transfer
146 /// - CPU-GPU synchronization
147 pub fn select_backend(&self, data_bytes: usize, flops: u64) -> Backend {
148 // Calculate transfer time (seconds)
149 let transfer_s = data_bytes as f64 / self.pcie_bandwidth;
150
151 // Calculate compute time (seconds)
152 let compute_s = flops as f64 / self.gpu_gflops;
153
154 // Apply 5× dispatch rule
155 if compute_s > self.min_dispatch_ratio * transfer_s {
156 Backend::GPU
157 } else {
158 // Fallback to SIMD for intermediate workloads
159 Backend::SIMD
160 }
161 }
162
163 /// Select backend for matrix multiplication
164 ///
165 /// # Arguments
166 /// * `m`, `n`, `k` - Matrix dimensions (M×K) × (K×N) = (M×N)
167 ///
168 /// # Complexity
169 /// - Data: O(mk + kn + mn) = O(mk + kn + mn) bytes
170 /// - FLOPs: O(2mnk) = O(mnk) operations
171 pub fn select_for_matmul(&self, m: usize, n: usize, k: usize) -> Backend {
172 // Data size: two input matrices + output (f32 = 4 bytes)
173 let data_bytes = (m * k + k * n + m * n) * 4;
174
175 // FLOPs: 2mnk (multiply-add per element)
176 let flops = (2 * m * n * k) as u64;
177
178 self.select_backend(data_bytes, flops)
179 }
180
181 /// Select backend for vector operations
182 ///
183 /// # Arguments
184 /// * `n` - Vector length
185 /// * `ops_per_element` - Operations per element (e.g., 2 for dot product)
186 pub fn select_for_vector_op(&self, n: usize, ops_per_element: u64) -> Backend {
187 // Data size: typically two input vectors + output (f32 = 4 bytes)
188 let data_bytes = n * 3 * 4;
189
190 // FLOPs
191 let flops = n as u64 * ops_per_element;
192
193 self.select_backend(data_bytes, flops)
194 }
195
196 /// Select backend for element-wise operations
197 ///
198 /// Element-wise ops are memory-bound, so GPU is rarely beneficial
199 pub fn select_for_elementwise(&self, n: usize) -> Backend {
200 // Element-wise ops: 1 FLOP per element, memory-bound
201 // GPU overhead rarely justified
202 if n > 1_000_000 {
203 Backend::SIMD
204 } else {
205 Backend::Scalar
206 }
207 }
208
209 /// MoE (Mixture-of-Experts) routing: select backend based on operation complexity
210 ///
211 /// # Arguments
212 /// * `complexity` - Operation complexity (Low/Medium/High)
213 /// * `data_size` - Number of elements in the operation
214 ///
215 /// # Returns
216 /// Recommended backend using adaptive thresholds per complexity level
217 ///
218 /// # MoE Thresholds (per empirical performance analysis)
219 /// - **Low complexity** (element-wise): SIMD at 1M+ elements, never GPU
220 /// - **Medium complexity** (reductions): SIMD at 10K+, GPU at 100K+ elements
221 /// - **High complexity** (matmul): SIMD at 1K+, GPU at 10K+ elements
222 pub fn select_with_moe(&self, complexity: OpComplexity, data_size: usize) -> Backend {
223 match complexity {
224 OpComplexity::Low => {
225 // Element-wise: memory-bound, GPU overhead not justified
226 if data_size > 1_000_000 {
227 Backend::SIMD
228 } else {
229 Backend::Scalar
230 }
231 }
232 OpComplexity::Medium => {
233 // Reductions (dot product, sum): moderate compute
234 if data_size > 100_000 {
235 Backend::GPU
236 } else if data_size > 10_000 {
237 Backend::SIMD
238 } else {
239 Backend::Scalar
240 }
241 }
242 OpComplexity::High => {
243 // Matrix operations: compute-intensive, O(n²) or O(n³)
244 if data_size > 10_000 {
245 Backend::GPU
246 } else if data_size > 1_000 {
247 Backend::SIMD
248 } else {
249 Backend::Scalar
250 }
251 }
252 }
253 }
254
255 /// Map Batuta Backend to Trueno Backend
256 #[cfg(feature = "trueno-integration")]
257 pub fn to_trueno_backend(backend: Backend) -> trueno::Backend {
258 match backend {
259 Backend::Scalar => trueno::Backend::Scalar,
260 Backend::SIMD => trueno::Backend::Auto, // Let Trueno pick best SIMD (AVX2/NEON)
261 Backend::GPU => trueno::Backend::GPU,
262 }
263 }
264
265 /// Perform vector addition using Trueno with selected backend
266 #[cfg(feature = "trueno-integration")]
267 pub fn vector_add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, String> {
268 if a.len() != b.len() {
269 return Err("Vector lengths must match".to_string());
270 }
271
272 let _backend = self.select_with_moe(OpComplexity::Low, a.len());
273
274 let vec_a: Vector<f32> = Vector::from_slice(a);
275 let vec_b: Vector<f32> = Vector::from_slice(b);
276
277 match vec_a.add(&vec_b) {
278 Ok(result) => Ok(result.as_slice().to_vec()),
279 Err(e) => Err(format!("Trueno error: {}", e)),
280 }
281 }
282
283 /// Perform matrix multiplication using Trueno with selected backend
284 #[cfg(feature = "trueno-integration")]
285 pub fn matrix_multiply(
286 &self,
287 a: &[f32],
288 b: &[f32],
289 m: usize,
290 n: usize,
291 k: usize,
292 ) -> Result<Vec<f32>, String> {
293 // Matrix A: m×k, Matrix B: k×n, Result: m×n
294 if a.len() != m * k {
295 return Err(format!("Matrix A size mismatch: expected {}, got {}", m * k, a.len()));
296 }
297 if b.len() != k * n {
298 return Err(format!("Matrix B size mismatch: expected {}, got {}", k * n, b.len()));
299 }
300
301 let _backend = self.select_for_matmul(m, n, k);
302
303 // Create matrices using from_vec (Trueno API)
304 let mat_a: Matrix<f32> = match Matrix::from_vec(m, k, a.to_vec()) {
305 Ok(m) => m,
306 Err(e) => return Err(format!("Trueno error creating matrix A: {}", e)),
307 };
308
309 let mat_b: Matrix<f32> = match Matrix::from_vec(k, n, b.to_vec()) {
310 Ok(m) => m,
311 Err(e) => return Err(format!("Trueno error creating matrix B: {}", e)),
312 };
313
314 match mat_a.matmul(&mat_b) {
315 Ok(result) => Ok(result.as_slice().to_vec()),
316 Err(e) => Err(format!("Trueno error in matmul: {}", e)),
317 }
318 }
319}
320
// Unit tests live in a sibling file to keep this module focused; the
// `#[path]` attribute points the module at backend_tests.rs.
#[cfg(test)]
#[path = "backend_tests.rs"]
mod tests;