1use crate::device::GpuDeviceInfo;
2use crate::gpu_error::GpuError;
3use crate::policy::GpuDispatchPolicy;
4use faer::Side;
5use gam_linalg::faer_ndarray::FaerCholesky;
6use gam_runtime::warm_start::{Fingerprint, Fingerprinter};
7use ndarray::{Array1, Array2};
8use serde::{Deserialize, Serialize};
9use std::fs;
10use std::path::PathBuf;
11use std::time::Instant;
12
/// Version of the calibration record layout. It is stored in every cache entry and
/// mixed into the device fingerprint, so bumping it invalidates old entries.
const SCHEMA_VERSION: u32 = 1;
/// Directory components appended to the system temp dir to form the cache directory.
const CACHE_ROOT_COMPONENTS: [&str; 4] = ["gam", "gpu", "policy", "v1"];
/// Square matrix orders at which GEMM is timed on CPU and GPU (ascending).
const GEMM_DIMS: [usize; 3] = [64, 128, 256];
/// Matrix orders at which the Cholesky (POTRF) factorization is timed (ascending).
const POTRF_DIMS: [usize; 3] = [64, 128, 256];
/// `(n, p)` shapes of the design matrix used to time `X^T W X` (ascending).
const XTWX_DIMS: [(usize, usize); 3] = [(2_048, 32), (4_096, 64), (8_192, 96)];
/// A measurement counts as a GPU win only when the GPU time is at most this
/// fraction of the CPU time, i.e. the GPU must be at least 5% faster.
const GPU_WIN_RATIO: f64 = 0.95;
19
/// A calibration cache entry, written as pretty-printed JSON by `store_cached_policy`.
///
/// `load_cached_policy` only accepts an entry whose `schema_version` equals
/// `SCHEMA_VERSION` and whose `device_fingerprint` equals the current device's
/// fingerprint. It then returns just the `policy` field.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct CachedCalibration {
    // Layout version; compared with `SCHEMA_VERSION` when the entry is loaded.
    schema_version: u32,
    // Hex form (`Fingerprint::to_hex`) of the device this entry was calibrated for.
    device_fingerprint: String,
    // The calibrated dispatch thresholds handed back to callers.
    policy: GpuDispatchPolicy,
    // Raw timings the policy was derived from. They are written with the entry but
    // are not read back by `load_cached_policy`.
    measurements: Vec<MeasurementRecord>,
}
27
/// Serializable copy of a `Measurement`, persisted inside `CachedCalibration`.
///
/// It mirrors `Measurement` field-for-field. The only difference is that
/// `operation` is an owned `String`, presumably so the record can round-trip
/// through serde.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MeasurementRecord {
    // Operation label; this module produces "gemm", "potrf" and "xtwx".
    operation: String,
    // Problem dimensions as recorded by the `measure_*` functions: gemm/potrf use the
    // matrix order for all three, xtwx stores rows = n, cols = p, inner = p.
    rows: usize,
    cols: usize,
    inner: usize,
    // Estimated floating-point operation count for this sample.
    flops: usize,
    // Wall-clock seconds reported by the timing helpers for the CPU and GPU paths.
    cpu_seconds: f64,
    gpu_seconds: f64,
}
38
/// One CPU-vs-GPU timing sample for a single operation at a single problem size.
///
/// Produced by `measure_gemm`, `measure_potrf` and `measure_xtwx`, consumed by the
/// crossover helpers, and converted with `into_record` for persistence.
#[derive(Clone, Debug)]
struct Measurement {
    // Operation label: "gemm", "potrf" or "xtwx"; the crossover helpers filter on it.
    operation: &'static str,
    // Problem dimensions: gemm/potrf use the matrix order for all three,
    // xtwx stores rows = n (observations), cols = p, inner = p.
    rows: usize,
    cols: usize,
    inner: usize,
    // Estimated floating-point operation count (2*n^3 for gemm, n^3/3 for potrf,
    // 2*n*p*p for xtwx, all computed with saturating multiplication).
    flops: usize,
    // Wall-clock seconds reported by the timing helpers for each implementation.
    cpu_seconds: f64,
    gpu_seconds: f64,
}
49
50pub(crate) fn calibrated_policy_for_device(device: &GpuDeviceInfo) -> GpuDispatchPolicy {
51 let fingerprint = device_fingerprint(device);
52 if let Some(cached) = load_cached_policy(fingerprint) {
53 log::info!(
54 "[GPU] loaded calibrated dispatch policy for {} ({fingerprint})",
55 device.name
56 );
57 return cached;
58 }
59
60 match calibrate_device(device, fingerprint) {
61 Ok(record) => {
62 let policy = record.policy.clone();
63 store_cached_policy(fingerprint, &record);
64 policy
65 }
66 Err(err) => {
67 log::warn!(
68 "[GPU] dispatch calibration unavailable for {}: {}; using default policy",
69 device.name,
70 err
71 );
72 GpuDispatchPolicy::default()
73 }
74 }
75}
76
77fn calibrate_device(
78 device: &GpuDeviceInfo,
79 fingerprint: Fingerprint,
80) -> Result<CachedCalibration, GpuError> {
81 let mut measurements = Vec::new();
82 measurements.extend(measure_gemm(device.ordinal)?);
83 measurements.extend(measure_potrf(device.ordinal)?);
84 measurements.extend(measure_xtwx(device.ordinal)?);
85 if measurements.is_empty() {
86 return Err(GpuError::CalibrationFailed {
87 reason: "no GPU calibration measurements completed".to_string(),
88 });
89 }
90
91 let mut policy = GpuDispatchPolicy::default();
92 if let Some(flops) = crossover_flops(&measurements, "gemm", policy.gemm_min_flops) {
93 policy.gemm_min_flops = flops;
94 }
95 if let Some(flops) = crossover_flops(&measurements, "xtwx", policy.xtwx_flops_min) {
96 policy.xtwx_flops_min = flops;
97 }
98 if let Some(rows) = crossover_rows(&measurements, "xtwx", policy.xtwx_n_min) {
99 policy.xtwx_n_min = rows;
100 policy.row_kernel_min_n = rows;
101 policy.fused_kernel_min_n = rows.saturating_mul(2);
102 }
103 if let Some(p) = crossover_rows(&measurements, "potrf", policy.potrf_min_p) {
104 policy.potrf_min_p = p;
105 policy.prefer_gpu_factorization_min_p = p;
106 }
107
108 log::info!(
109 "[GPU] calibrated dispatch policy for {} ({fingerprint}) from {} measurements",
110 device.name,
111 measurements.len()
112 );
113
114 Ok(CachedCalibration {
115 schema_version: SCHEMA_VERSION,
116 device_fingerprint: fingerprint.to_hex(),
117 policy,
118 measurements: measurements
119 .into_iter()
120 .map(Measurement::into_record)
121 .collect(),
122 })
123}
124
125fn measure_gemm(ordinal: usize) -> Result<Vec<Measurement>, GpuError> {
126 let mut out = Vec::with_capacity(GEMM_DIMS.len());
127 for dim in GEMM_DIMS {
128 let a = deterministic_matrix(dim, dim, 0.13);
129 let b = deterministic_matrix(dim, dim, 0.37);
130 let cpu_seconds = time_cpu(|| a.dot(&b))?;
131 let gpu_seconds = time_gpu(|| {
132 crate::blas::gemm_on_ordinal_cuda(ordinal, a.view(), b.view(), false, false)
133 })?;
134 out.push(Measurement {
135 operation: "gemm",
136 rows: dim,
137 cols: dim,
138 inner: dim,
139 flops: 2usize
140 .saturating_mul(dim)
141 .saturating_mul(dim)
142 .saturating_mul(dim),
143 cpu_seconds,
144 gpu_seconds,
145 });
146 }
147 Ok(out)
148}
149
150fn measure_potrf(ordinal: usize) -> Result<Vec<Measurement>, GpuError> {
151 let mut out = Vec::with_capacity(POTRF_DIMS.len());
152 for dim in POTRF_DIMS {
153 let a = deterministic_spd_matrix(dim);
154 let cpu_seconds = time_gpu_result(|| {
155 a.cholesky(Side::Lower)
156 .map(|factor| factor.lower_triangular())
157 .map_err(|err| format!("cpu POTRF failed: {err}"))
158 })?;
159 let gpu_seconds =
160 time_gpu_result(|| crate::solver::cholesky_lower_on_ordinal_gpu(ordinal, a.view()))?;
161 out.push(Measurement {
162 operation: "potrf",
163 rows: dim,
164 cols: dim,
165 inner: dim,
166 flops: dim.saturating_mul(dim).saturating_mul(dim) / 3,
167 cpu_seconds,
168 gpu_seconds,
169 });
170 }
171 Ok(out)
172}
173
174fn measure_xtwx(ordinal: usize) -> Result<Vec<Measurement>, GpuError> {
175 let mut out = Vec::with_capacity(XTWX_DIMS.len());
176 for (n, p) in XTWX_DIMS {
177 let x = deterministic_matrix(n, p, 0.61);
178 let w = deterministic_weights(n);
179 let cpu_seconds = time_cpu(|| cpu_xtwx(&x, &w))?;
180 let gpu_seconds =
181 time_gpu(|| crate::blas::xt_diag_x_on_ordinal_cuda(ordinal, x.view(), w.view()))?;
182 out.push(Measurement {
183 operation: "xtwx",
184 rows: n,
185 cols: p,
186 inner: p,
187 flops: 2usize.saturating_mul(n).saturating_mul(p).saturating_mul(p),
188 cpu_seconds,
189 gpu_seconds,
190 });
191 }
192 Ok(out)
193}
194
195fn time_cpu<F>(mut f: F) -> Result<f64, GpuError>
196where
197 F: FnMut() -> Array2<f64>,
198{
199 time_gpu_result(|| Result::<Array2<f64>, GpuError>::Ok(f()))
200}
201
202fn time_gpu<F>(mut f: F) -> Result<f64, GpuError>
203where
204 F: FnMut() -> Option<Array2<f64>>,
205{
206 time_gpu_result(|| {
207 f().ok_or_else(|| GpuError::CalibrationFailed {
208 reason: "GPU calibration kernel returned no result".to_string(),
209 })
210 })
211}
212
213fn time_gpu_result<F, E>(mut f: F) -> Result<f64, GpuError>
214where
215 F: FnMut() -> Result<Array2<f64>, E>,
216 E: std::fmt::Display,
217{
218 let start = Instant::now();
219 let out = f().map_err(|err| GpuError::CalibrationFailed {
220 reason: err.to_string(),
221 })?;
222 let elapsed = start.elapsed().as_secs_f64();
223 let checksum = out.iter().fold(0.0, |acc, value| acc + value.abs());
224 if elapsed.is_finite() && elapsed > 0.0 && checksum.is_finite() {
225 Ok(elapsed)
226 } else {
227 Err(GpuError::CalibrationFailed {
228 reason: format!(
229 "invalid calibration timing/checksum: elapsed={elapsed}, checksum={checksum}"
230 ),
231 })
232 }
233}
234
235fn crossover_flops(
236 measurements: &[Measurement],
237 operation: &'static str,
238 fallback: usize,
239) -> Option<usize> {
240 crossover_measurement(measurements, operation)
241 .map(|measurement| measurement.flops.max(1))
242 .or_else(|| {
243 measurements
244 .iter()
245 .filter(|measurement| measurement.operation == operation)
246 .map(|measurement| measurement.flops)
247 .max()
248 .map(|max_seen| fallback.max(max_seen.saturating_mul(2)))
249 })
250}
251
252fn crossover_rows(
253 measurements: &[Measurement],
254 operation: &'static str,
255 fallback: usize,
256) -> Option<usize> {
257 crossover_measurement(measurements, operation)
258 .map(|measurement| measurement.rows.max(1))
259 .or_else(|| {
260 measurements
261 .iter()
262 .filter(|measurement| measurement.operation == operation)
263 .map(|measurement| measurement.rows)
264 .max()
265 .map(|max_seen| fallback.max(max_seen.saturating_mul(2)))
266 })
267}
268
269fn crossover_measurement<'a>(
270 measurements: &'a [Measurement],
271 operation: &'static str,
272) -> Option<&'a Measurement> {
273 measurements
274 .iter()
275 .filter(|measurement| measurement.operation == operation)
276 .find(|measurement| measurement.gpu_seconds <= measurement.cpu_seconds * GPU_WIN_RATIO)
277}
278
279fn deterministic_matrix(rows: usize, cols: usize, phase: f64) -> Array2<f64> {
280 Array2::from_shape_fn((rows, cols), |(row, col)| {
281 let x = (row as f64 + 1.0) * 0.017 + (col as f64 + 1.0) * 0.031 + phase;
282 x.sin() + 0.25 * (2.0 * x).cos()
283 })
284}
285
286fn deterministic_spd_matrix(dim: usize) -> Array2<f64> {
287 let a = deterministic_matrix(dim, dim, 0.89);
288 let mut spd = a.t().dot(&a);
289 for idx in 0..dim {
290 spd[[idx, idx]] += dim as f64;
291 }
292 spd
293}
294
295fn deterministic_weights(n: usize) -> Array1<f64> {
296 Array1::from_shape_fn(n, |idx| 0.5 + ((idx as f64 + 1.0) * 0.019).sin().abs())
297}
298
299fn cpu_xtwx(x: &Array2<f64>, w: &Array1<f64>) -> Array2<f64> {
300 let mut weighted = x.clone();
301 for (mut row, weight) in weighted.outer_iter_mut().zip(w.iter()) {
302 row *= *weight;
303 }
304 x.t().dot(&weighted)
305}
306
307fn load_cached_policy(fingerprint: Fingerprint) -> Option<GpuDispatchPolicy> {
308 let path = cache_path(fingerprint);
309 let bytes = fs::read(path).ok()?;
310 let record: CachedCalibration = serde_json::from_slice(&bytes).ok()?;
311 if record.schema_version == SCHEMA_VERSION && record.device_fingerprint == fingerprint.to_hex()
312 {
313 Some(record.policy)
314 } else {
315 None
316 }
317}
318
319fn store_cached_policy(fingerprint: Fingerprint, record: &CachedCalibration) {
320 let path = cache_path(fingerprint);
321 if let Some(parent) = path.parent() {
322 if let Err(err) = fs::create_dir_all(parent) {
323 log::warn!("[GPU] unable to create calibration cache dir: {err}");
324 return;
325 }
326 }
327 let tmp = path.with_extension("json.tmp");
328 let bytes = match serde_json::to_vec_pretty(record) {
329 Ok(bytes) => bytes,
330 Err(err) => {
331 log::warn!("[GPU] unable to serialize calibration cache: {err}");
332 return;
333 }
334 };
335 if let Err(err) = fs::write(&tmp, bytes).and_then(|_| fs::rename(&tmp, &path)) {
336 log::warn!("[GPU] unable to write calibration cache: {err}");
337 }
338}
339
340fn cache_path(fingerprint: Fingerprint) -> PathBuf {
341 let mut root = std::env::temp_dir();
342 for component in CACHE_ROOT_COMPONENTS {
343 root.push(component);
344 }
345 root.push(format!("{fingerprint}.json"));
346 root
347}
348
349fn device_fingerprint(device: &GpuDeviceInfo) -> Fingerprint {
350 let mut fp = Fingerprinter::new();
351 fp.absorb_tag(b"gpu-dispatch-calibration");
352 fp.absorb_u64(b"schema-version", u64::from(SCHEMA_VERSION));
353 fp.absorb_str(b"name", &device.name);
354 fp.absorb_u64(
355 b"compute-major",
356 u64::try_from(device.capability.compute_major).unwrap_or(0),
357 );
358 fp.absorb_u64(
359 b"compute-minor",
360 u64::try_from(device.capability.compute_minor).unwrap_or(0),
361 );
362 fp.absorb_u64(b"sm-count", u64::try_from(device.sm_count).unwrap_or(0));
363 fp.absorb_u64(
364 b"max-threads-per-sm",
365 u64::try_from(device.max_threads_per_sm).unwrap_or(0),
366 );
367 fp.absorb_u64(
368 b"max-shared-mem-per-block",
369 device.max_shared_mem_per_block as u64,
370 );
371 fp.absorb_u64(b"l2-cache-bytes", device.l2_cache_bytes as u64);
372 fp.absorb_u64(b"total-mem-bytes", device.total_mem_bytes as u64);
373 fp.absorb_u64(b"ecc-enabled", bool_fingerprint_value(device.ecc_enabled));
374 fp.absorb_u64(b"integrated", bool_fingerprint_value(device.integrated));
375 fp.absorb_u64(b"mig-mode", bool_fingerprint_value(device.mig_mode));
376 fp.finalize()
377}
378
/// Encodes a boolean device flag as `1` (true) or `0` (false) for fingerprinting.
const fn bool_fingerprint_value(value: bool) -> u64 {
    value as u64
}
382
383impl Measurement {
384 fn into_record(self) -> MeasurementRecord {
385 MeasurementRecord {
386 operation: self.operation.to_string(),
387 rows: self.rows,
388 cols: self.cols,
389 inner: self.inner,
390 flops: self.flops,
391 cpu_seconds: self.cpu_seconds,
392 gpu_seconds: self.gpu_seconds,
393 }
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::GpuCapability;

    /// Builds a synthetic `Measurement` whose `inner` dimension equals `cols`.
    fn sample(
        operation: &'static str,
        rows: usize,
        cols: usize,
        flops: usize,
        cpu_seconds: f64,
        gpu_seconds: f64,
    ) -> Measurement {
        Measurement {
            inner: cols,
            operation,
            rows,
            cols,
            flops,
            cpu_seconds,
            gpu_seconds,
        }
    }

    #[test]
    fn calibration_crossover_uses_first_measured_gpu_win() {
        // 64 loses; 128 is the first win (0.009 <= 0.95 * 0.010); 256 also wins.
        let gemm_samples = [
            sample("gemm", 64, 64, 524_288, 0.001, 0.004),
            sample("gemm", 128, 128, 4_194_304, 0.010, 0.009),
            sample("gemm", 256, 256, 33_554_432, 0.080, 0.010),
        ];

        let threshold = crossover_flops(&gemm_samples, "gemm", 100_000_000);
        assert_eq!(threshold, Some(4_194_304));
    }

    #[test]
    fn calibration_crossover_raises_threshold_when_gpu_never_wins() {
        let xtwx_samples = [
            sample("xtwx", 2_048, 32, 4_194_304, 0.001, 0.004),
            sample("xtwx", 4_096, 64, 33_554_432, 0.010, 0.040),
            sample("xtwx", 8_192, 96, 150_994_944, 0.080, 0.400),
        ];

        // No GPU win: flops threshold = max(fallback, 2 * 150_994_944).
        let flops_threshold = crossover_flops(&xtwx_samples, "xtwx", 100_000_000);
        assert_eq!(flops_threshold, Some(301_989_888));

        // Rows: max(50_000, 2 * 8_192) keeps the larger fallback.
        let rows_threshold = crossover_rows(&xtwx_samples, "xtwx", 50_000);
        assert_eq!(rows_threshold, Some(50_000));
    }

    #[test]
    fn calibration_cache_key_tracks_device_fingerprint() {
        let device = GpuDeviceInfo {
            ordinal: 0,
            name: "unit-test GPU".to_string(),
            capability: GpuCapability::from_compute_capability(8, 0),
            sm_count: 108,
            max_threads_per_sm: 2048,
            max_shared_mem_per_block: 99_328,
            l2_cache_bytes: 40 * 1024 * 1024,
            total_mem_bytes: 80 * 1024 * 1024 * 1024,
            free_mem_bytes: 70 * 1024 * 1024 * 1024,
            ecc_enabled: true,
            integrated: false,
            mig_mode: false,
        };

        let fingerprint = device_fingerprint(&device);
        let path = cache_path(fingerprint);

        let expected_file = format!("{}.json", fingerprint.to_hex());
        assert!(path.ends_with(expected_file));

        let components: Vec<String> = path
            .components()
            .map(|component| component.as_os_str().to_string_lossy().into_owned())
            .collect();
        let has_cache_root = components
            .windows(CACHE_ROOT_COMPONENTS.len())
            .any(|window| window == CACHE_ROOT_COMPONENTS);
        assert!(has_cache_root);
    }
}