1use crate::device::GpuDeviceInfo;
2use crate::gpu_error::GpuError;
3use crate::policy::GpuDispatchPolicy;
4use gam_linalg::faer_ndarray::FaerCholesky;
5use faer::Side;
6use gam_runtime::warm_start::{Fingerprint, Fingerprinter};
7use ndarray::{Array1, Array2};
8use serde::{Deserialize, Serialize};
9use std::fs;
10use std::path::PathBuf;
11use std::time::Instant;
12
/// Version of the cached calibration record. A cached file is only accepted
/// when its stored `schema_version` equals this value; the value is also
/// absorbed into the device fingerprint.
const SCHEMA_VERSION: u32 = 1;
/// Directory components pushed, in order, beneath `std::env::temp_dir()` to
/// form the calibration cache directory.
const CACHE_ROOT_COMPONENTS: [&str; 4] = ["gam", "gpu", "policy", "v1"];
/// Square matrix dimensions timed by the GEMM benchmark (`dim x dim` times `dim x dim`).
const GEMM_DIMS: [usize; 3] = [64, 128, 256];
/// Square matrix dimensions timed by the Cholesky (POTRF) benchmark.
const POTRF_DIMS: [usize; 3] = [64, 128, 256];
/// `(rows, cols)` shapes of the matrix `x` timed by the X'WX benchmark.
const XTWX_DIMS: [(usize, usize); 3] = [(2048, 32), (4096, 64), (8192, 96)];
/// The GPU counts as winning a measurement only when its time is at most this
/// fraction of the CPU time (`gpu_seconds <= cpu_seconds * GPU_WIN_RATIO`).
const GPU_WIN_RATIO: f64 = 0.95;
19
/// Calibration result as it is serialized to (and read back from) the JSON
/// cache file.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct CachedCalibration {
    /// Value of `SCHEMA_VERSION` at the time the record was written.
    schema_version: u32,
    /// Hex rendering (`Fingerprint::to_hex`) of the device fingerprint the
    /// calibration was run for; compared again when the record is loaded.
    device_fingerprint: String,
    /// Dispatch policy derived from `measurements`.
    policy: GpuDispatchPolicy,
    /// Timings the policy was derived from.
    measurements: Vec<MeasurementRecord>,
}
27
/// Serializable form of a single [`Measurement`]; identical fields, but the
/// operation name is stored as an owned `String`.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MeasurementRecord {
    /// Benchmark name (`"gemm"`, `"potrf"` or `"xtwx"` for the benchmarks in this file).
    operation: String,
    /// Row count of the benchmarked problem.
    rows: usize,
    /// Column count of the benchmarked problem.
    cols: usize,
    /// Inner dimension recorded for the benchmarked problem.
    inner: usize,
    /// Operation-count estimate computed by the benchmark for this problem size.
    flops: usize,
    /// Wall-clock time of the CPU implementation, in seconds.
    cpu_seconds: f64,
    /// Wall-clock time of the GPU implementation, in seconds.
    gpu_seconds: f64,
}
38
/// In-memory CPU-vs-GPU timing for one benchmarked problem size. Converted to
/// a [`MeasurementRecord`] via `Measurement::into_record` before being cached.
#[derive(Clone, Debug)]
struct Measurement {
    /// Benchmark name; the crossover helpers filter measurements on this value.
    operation: &'static str,
    /// Row count of the benchmarked problem.
    rows: usize,
    /// Column count of the benchmarked problem.
    cols: usize,
    /// Inner dimension recorded for the benchmarked problem.
    inner: usize,
    /// Operation-count estimate computed by the benchmark for this problem size.
    flops: usize,
    /// Wall-clock time of the CPU implementation, in seconds.
    cpu_seconds: f64,
    /// Wall-clock time of the GPU implementation, in seconds.
    gpu_seconds: f64,
}
49
50pub(crate) fn calibrated_policy_for_device(device: &GpuDeviceInfo) -> GpuDispatchPolicy {
51 let fingerprint = device_fingerprint(device);
52 if let Some(cached) = load_cached_policy(fingerprint) {
53 log::info!(
54 "[GPU] loaded calibrated dispatch policy for {} ({fingerprint})",
55 device.name
56 );
57 return cached;
58 }
59
60 match calibrate_device(device, fingerprint) {
61 Ok(record) => {
62 let policy = record.policy.clone();
63 store_cached_policy(fingerprint, &record);
64 policy
65 }
66 Err(err) => {
67 log::warn!(
68 "[GPU] dispatch calibration unavailable for {}: {}; using default policy",
69 device.name,
70 err
71 );
72 GpuDispatchPolicy::default()
73 }
74 }
75}
76
77fn calibrate_device(
78 device: &GpuDeviceInfo,
79 fingerprint: Fingerprint,
80) -> Result<CachedCalibration, GpuError> {
81 let mut measurements = Vec::new();
82 measurements.extend(measure_gemm(device.ordinal)?);
83 measurements.extend(measure_potrf(device.ordinal)?);
84 measurements.extend(measure_xtwx(device.ordinal)?);
85 if measurements.is_empty() {
86 return Err(GpuError::CalibrationFailed {
87 reason: "no GPU calibration measurements completed".to_string(),
88 });
89 }
90
91 let mut policy = GpuDispatchPolicy::default();
92 if let Some(flops) = crossover_flops(&measurements, "gemm", policy.gemm_min_flops) {
93 policy.gemm_min_flops = flops;
94 }
95 if let Some(flops) = crossover_flops(&measurements, "xtwx", policy.xtwx_flops_min) {
96 policy.xtwx_flops_min = flops;
97 }
98 if let Some(rows) = crossover_rows(&measurements, "xtwx", policy.xtwx_n_min) {
99 policy.xtwx_n_min = rows;
100 policy.row_kernel_min_n = rows;
101 policy.fused_kernel_min_n = rows.saturating_mul(2);
102 }
103 if let Some(p) = crossover_rows(&measurements, "potrf", policy.potrf_min_p) {
104 policy.potrf_min_p = p;
105 policy.prefer_gpu_factorization_min_p = p;
106 }
107
108 log::info!(
109 "[GPU] calibrated dispatch policy for {} ({fingerprint}) from {} measurements",
110 device.name,
111 measurements.len()
112 );
113
114 Ok(CachedCalibration {
115 schema_version: SCHEMA_VERSION,
116 device_fingerprint: fingerprint.to_hex(),
117 policy,
118 measurements: measurements
119 .into_iter()
120 .map(Measurement::into_record)
121 .collect(),
122 })
123}
124
125fn measure_gemm(ordinal: usize) -> Result<Vec<Measurement>, GpuError> {
126 let mut out = Vec::with_capacity(GEMM_DIMS.len());
127 for dim in GEMM_DIMS {
128 let a = deterministic_matrix(dim, dim, 0.13);
129 let b = deterministic_matrix(dim, dim, 0.37);
130 let cpu_seconds = time_cpu(|| a.dot(&b))?;
131 let gpu_seconds = time_gpu(|| {
132 crate::blas::gemm_on_ordinal_cuda(ordinal, a.view(), b.view(), false, false)
133 })?;
134 out.push(Measurement {
135 operation: "gemm",
136 rows: dim,
137 cols: dim,
138 inner: dim,
139 flops: 2usize
140 .saturating_mul(dim)
141 .saturating_mul(dim)
142 .saturating_mul(dim),
143 cpu_seconds,
144 gpu_seconds,
145 });
146 }
147 Ok(out)
148}
149
150fn measure_potrf(ordinal: usize) -> Result<Vec<Measurement>, GpuError> {
151 let mut out = Vec::with_capacity(POTRF_DIMS.len());
152 for dim in POTRF_DIMS {
153 let a = deterministic_spd_matrix(dim);
154 let cpu_seconds = time_gpu_result(|| {
155 a.cholesky(Side::Lower)
156 .map(|factor| factor.lower_triangular())
157 .map_err(|err| format!("cpu POTRF failed: {err}"))
158 })?;
159 let gpu_seconds = time_gpu_result(|| {
160 crate::solver::cholesky_lower_on_ordinal_gpu(ordinal, a.view())
161 })?;
162 out.push(Measurement {
163 operation: "potrf",
164 rows: dim,
165 cols: dim,
166 inner: dim,
167 flops: dim.saturating_mul(dim).saturating_mul(dim) / 3,
168 cpu_seconds,
169 gpu_seconds,
170 });
171 }
172 Ok(out)
173}
174
175fn measure_xtwx(ordinal: usize) -> Result<Vec<Measurement>, GpuError> {
176 let mut out = Vec::with_capacity(XTWX_DIMS.len());
177 for (n, p) in XTWX_DIMS {
178 let x = deterministic_matrix(n, p, 0.61);
179 let w = deterministic_weights(n);
180 let cpu_seconds = time_cpu(|| cpu_xtwx(&x, &w))?;
181 let gpu_seconds =
182 time_gpu(|| crate::blas::xt_diag_x_on_ordinal_cuda(ordinal, x.view(), w.view()))?;
183 out.push(Measurement {
184 operation: "xtwx",
185 rows: n,
186 cols: p,
187 inner: p,
188 flops: 2usize.saturating_mul(n).saturating_mul(p).saturating_mul(p),
189 cpu_seconds,
190 gpu_seconds,
191 });
192 }
193 Ok(out)
194}
195
196fn time_cpu<F>(mut f: F) -> Result<f64, GpuError>
197where
198 F: FnMut() -> Array2<f64>,
199{
200 time_gpu_result(|| Result::<Array2<f64>, GpuError>::Ok(f()))
201}
202
203fn time_gpu<F>(mut f: F) -> Result<f64, GpuError>
204where
205 F: FnMut() -> Option<Array2<f64>>,
206{
207 time_gpu_result(|| {
208 f().ok_or_else(|| GpuError::CalibrationFailed {
209 reason: "GPU calibration kernel returned no result".to_string(),
210 })
211 })
212}
213
214fn time_gpu_result<F, E>(mut f: F) -> Result<f64, GpuError>
215where
216 F: FnMut() -> Result<Array2<f64>, E>,
217 E: std::fmt::Display,
218{
219 let start = Instant::now();
220 let out = f().map_err(|err| GpuError::CalibrationFailed {
221 reason: err.to_string(),
222 })?;
223 let elapsed = start.elapsed().as_secs_f64();
224 let checksum = out.iter().fold(0.0, |acc, value| acc + value.abs());
225 if elapsed.is_finite() && elapsed > 0.0 && checksum.is_finite() {
226 Ok(elapsed)
227 } else {
228 Err(GpuError::CalibrationFailed {
229 reason: format!(
230 "invalid calibration timing/checksum: elapsed={elapsed}, checksum={checksum}"
231 ),
232 })
233 }
234}
235
236fn crossover_flops(
237 measurements: &[Measurement],
238 operation: &'static str,
239 fallback: usize,
240) -> Option<usize> {
241 crossover_measurement(measurements, operation)
242 .map(|measurement| measurement.flops.max(1))
243 .or_else(|| {
244 measurements
245 .iter()
246 .filter(|measurement| measurement.operation == operation)
247 .map(|measurement| measurement.flops)
248 .max()
249 .map(|max_seen| fallback.max(max_seen.saturating_mul(2)))
250 })
251}
252
253fn crossover_rows(
254 measurements: &[Measurement],
255 operation: &'static str,
256 fallback: usize,
257) -> Option<usize> {
258 crossover_measurement(measurements, operation)
259 .map(|measurement| measurement.rows.max(1))
260 .or_else(|| {
261 measurements
262 .iter()
263 .filter(|measurement| measurement.operation == operation)
264 .map(|measurement| measurement.rows)
265 .max()
266 .map(|max_seen| fallback.max(max_seen.saturating_mul(2)))
267 })
268}
269
270fn crossover_measurement<'a>(
271 measurements: &'a [Measurement],
272 operation: &'static str,
273) -> Option<&'a Measurement> {
274 measurements
275 .iter()
276 .filter(|measurement| measurement.operation == operation)
277 .find(|measurement| measurement.gpu_seconds <= measurement.cpu_seconds * GPU_WIN_RATIO)
278}
279
280fn deterministic_matrix(rows: usize, cols: usize, phase: f64) -> Array2<f64> {
281 Array2::from_shape_fn((rows, cols), |(row, col)| {
282 let x = (row as f64 + 1.0) * 0.017 + (col as f64 + 1.0) * 0.031 + phase;
283 x.sin() + 0.25 * (2.0 * x).cos()
284 })
285}
286
287fn deterministic_spd_matrix(dim: usize) -> Array2<f64> {
288 let a = deterministic_matrix(dim, dim, 0.89);
289 let mut spd = a.t().dot(&a);
290 for idx in 0..dim {
291 spd[[idx, idx]] += dim as f64;
292 }
293 spd
294}
295
296fn deterministic_weights(n: usize) -> Array1<f64> {
297 Array1::from_shape_fn(n, |idx| 0.5 + ((idx as f64 + 1.0) * 0.019).sin().abs())
298}
299
300fn cpu_xtwx(x: &Array2<f64>, w: &Array1<f64>) -> Array2<f64> {
301 let mut weighted = x.clone();
302 for (mut row, weight) in weighted.outer_iter_mut().zip(w.iter()) {
303 row *= *weight;
304 }
305 x.t().dot(&weighted)
306}
307
308fn load_cached_policy(fingerprint: Fingerprint) -> Option<GpuDispatchPolicy> {
309 let path = cache_path(fingerprint);
310 let bytes = fs::read(path).ok()?;
311 let record: CachedCalibration = serde_json::from_slice(&bytes).ok()?;
312 if record.schema_version == SCHEMA_VERSION && record.device_fingerprint == fingerprint.to_hex()
313 {
314 Some(record.policy)
315 } else {
316 None
317 }
318}
319
320fn store_cached_policy(fingerprint: Fingerprint, record: &CachedCalibration) {
321 let path = cache_path(fingerprint);
322 if let Some(parent) = path.parent() {
323 if let Err(err) = fs::create_dir_all(parent) {
324 log::warn!("[GPU] unable to create calibration cache dir: {err}");
325 return;
326 }
327 }
328 let tmp = path.with_extension("json.tmp");
329 let bytes = match serde_json::to_vec_pretty(record) {
330 Ok(bytes) => bytes,
331 Err(err) => {
332 log::warn!("[GPU] unable to serialize calibration cache: {err}");
333 return;
334 }
335 };
336 if let Err(err) = fs::write(&tmp, bytes).and_then(|_| fs::rename(&tmp, &path)) {
337 log::warn!("[GPU] unable to write calibration cache: {err}");
338 }
339}
340
341fn cache_path(fingerprint: Fingerprint) -> PathBuf {
342 let mut root = std::env::temp_dir();
343 for component in CACHE_ROOT_COMPONENTS {
344 root.push(component);
345 }
346 root.push(format!("{fingerprint}.json"));
347 root
348}
349
350fn device_fingerprint(device: &GpuDeviceInfo) -> Fingerprint {
351 let mut fp = Fingerprinter::new();
352 fp.absorb_tag(b"gpu-dispatch-calibration");
353 fp.absorb_u64(b"schema-version", u64::from(SCHEMA_VERSION));
354 fp.absorb_str(b"name", &device.name);
355 fp.absorb_u64(
356 b"compute-major",
357 u64::try_from(device.capability.compute_major).unwrap_or(0),
358 );
359 fp.absorb_u64(
360 b"compute-minor",
361 u64::try_from(device.capability.compute_minor).unwrap_or(0),
362 );
363 fp.absorb_u64(b"sm-count", u64::try_from(device.sm_count).unwrap_or(0));
364 fp.absorb_u64(
365 b"max-threads-per-sm",
366 u64::try_from(device.max_threads_per_sm).unwrap_or(0),
367 );
368 fp.absorb_u64(
369 b"max-shared-mem-per-block",
370 device.max_shared_mem_per_block as u64,
371 );
372 fp.absorb_u64(b"l2-cache-bytes", device.l2_cache_bytes as u64);
373 fp.absorb_u64(b"total-mem-bytes", device.total_mem_bytes as u64);
374 fp.absorb_u64(b"ecc-enabled", bool_fingerprint_value(device.ecc_enabled));
375 fp.absorb_u64(b"integrated", bool_fingerprint_value(device.integrated));
376 fp.absorb_u64(b"mig-mode", bool_fingerprint_value(device.mig_mode));
377 fp.finalize()
378}
379
/// Encodes a flag for fingerprinting: `true` becomes 1, `false` becomes 0.
const fn bool_fingerprint_value(value: bool) -> u64 {
    match value {
        true => 1,
        false => 0,
    }
}
383
384impl Measurement {
385 fn into_record(self) -> MeasurementRecord {
386 MeasurementRecord {
387 operation: self.operation.to_string(),
388 rows: self.rows,
389 cols: self.cols,
390 inner: self.inner,
391 flops: self.flops,
392 cpu_seconds: self.cpu_seconds,
393 gpu_seconds: self.gpu_seconds,
394 }
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::GpuCapability;

    /// Builds a synthetic measurement; `inner` mirrors `cols`.
    fn measurement(
        operation: &'static str,
        rows: usize,
        cols: usize,
        flops: usize,
        cpu_seconds: f64,
        gpu_seconds: f64,
    ) -> Measurement {
        Measurement {
            operation,
            rows,
            cols,
            inner: cols,
            flops,
            cpu_seconds,
            gpu_seconds,
        }
    }

    #[test]
    fn calibration_crossover_uses_first_measured_gpu_win() {
        // 64: GPU slower. 128: GPU within the win ratio (first win). 256: clear win.
        let samples = [
            measurement("gemm", 64, 64, 524_288, 0.001, 0.004),
            measurement("gemm", 128, 128, 4_194_304, 0.010, 0.009),
            measurement("gemm", 256, 256, 33_554_432, 0.080, 0.010),
        ];

        let threshold = crossover_flops(&samples, "gemm", 100_000_000);

        assert_eq!(threshold, Some(4_194_304));
    }

    #[test]
    fn calibration_crossover_raises_threshold_when_gpu_never_wins() {
        // The GPU is slower at every size.
        let samples = [
            measurement("xtwx", 2_048, 32, 4_194_304, 0.001, 0.004),
            measurement("xtwx", 4_096, 64, 33_554_432, 0.010, 0.040),
            measurement("xtwx", 8_192, 96, 150_994_944, 0.080, 0.400),
        ];

        // Twice the largest measured FLOP count exceeds the fallback ...
        let flops_threshold = crossover_flops(&samples, "xtwx", 100_000_000);
        assert_eq!(flops_threshold, Some(301_989_888));

        // ... whereas twice the largest row count does not.
        let rows_threshold = crossover_rows(&samples, "xtwx", 50_000);
        assert_eq!(rows_threshold, Some(50_000));
    }

    #[test]
    fn calibration_cache_key_tracks_device_fingerprint() {
        let device = GpuDeviceInfo {
            ordinal: 0,
            name: "unit-test GPU".to_string(),
            capability: GpuCapability::from_compute_capability(8, 0),
            sm_count: 108,
            max_threads_per_sm: 2048,
            max_shared_mem_per_block: 99_328,
            l2_cache_bytes: 40 * 1024 * 1024,
            total_mem_bytes: 80 * 1024 * 1024 * 1024,
            free_mem_bytes: 70 * 1024 * 1024 * 1024,
            ecc_enabled: true,
            integrated: false,
            mig_mode: false,
        };

        let fingerprint = device_fingerprint(&device);
        let path = cache_path(fingerprint);

        // The file name is the hex fingerprint plus ".json".
        let expected_file_name = format!("{}.json", fingerprint.to_hex());
        assert!(path.ends_with(expected_file_name));

        // The cache root components appear consecutively somewhere in the path.
        let components: Vec<String> = path
            .components()
            .map(|component| component.as_os_str().to_string_lossy().into_owned())
            .collect();
        let contains_cache_root = components
            .windows(CACHE_ROOT_COMPONENTS.len())
            .any(|window| window == CACHE_ROOT_COMPONENTS);
        assert!(contains_cache_root);
    }
}