Skip to main content

gam_gpu/
calibration.rs

1use crate::device::GpuDeviceInfo;
2use crate::gpu_error::GpuError;
3use crate::policy::GpuDispatchPolicy;
4use faer::Side;
5use gam_linalg::faer_ndarray::FaerCholesky;
6use gam_runtime::warm_start::{Fingerprint, Fingerprinter};
7use ndarray::{Array1, Array2};
8use serde::{Deserialize, Serialize};
9use std::fs;
10use std::path::PathBuf;
11use std::time::Instant;
12
/// Version tag stored in every cache record and mixed into the device
/// fingerprint; a cached record with a different value is ignored on load.
const SCHEMA_VERSION: u32 = 1;
/// Directory components appended, in order, to the system temp dir to form
/// the cache directory.
const CACHE_ROOT_COMPONENTS: [&str; 4] = ["gam", "gpu", "policy", "v1"];
/// Square matrix sizes timed by the GEMM probe.
const GEMM_DIMS: [usize; 3] = [64, 128, 256];
/// Square SPD matrix sizes timed by the Cholesky (POTRF) probe.
const POTRF_DIMS: [usize; 3] = [64, 128, 256];
/// `(n, p)` shapes timed by the XᵀWX probe: `n` rows by `p` columns.
const XTWX_DIMS: [(usize, usize); 3] = [(2_048, 32), (4_096, 64), (8_192, 96)];
/// The GPU counts as the winner for a sample only when
/// `gpu_seconds <= cpu_seconds * GPU_WIN_RATIO`.
const GPU_WIN_RATIO: f64 = 0.95;
19
20#[derive(Clone, Debug, Serialize, Deserialize)]
21struct CachedCalibration {
22    schema_version: u32,
23    device_fingerprint: String,
24    policy: GpuDispatchPolicy,
25    measurements: Vec<MeasurementRecord>,
26}
27
28#[derive(Clone, Debug, Serialize, Deserialize)]
29struct MeasurementRecord {
30    operation: String,
31    rows: usize,
32    cols: usize,
33    inner: usize,
34    flops: usize,
35    cpu_seconds: f64,
36    gpu_seconds: f64,
37}
38
/// One CPU-vs-GPU timing sample for a single operation at a single size.
#[derive(Debug, Clone)]
struct Measurement {
    /// Probe name, e.g. `"gemm"`, `"potrf"` or `"xtwx"`.
    operation: &'static str,
    /// Row count of the probed problem.
    rows: usize,
    /// Column count of the probed problem.
    cols: usize,
    /// Inner dimension as recorded by the probe.
    inner: usize,
    /// Estimated floating-point operation count for the problem.
    flops: usize,
    /// Wall-clock seconds measured for the CPU implementation.
    cpu_seconds: f64,
    /// Wall-clock seconds measured for the GPU implementation.
    gpu_seconds: f64,
}
49
50pub(crate) fn calibrated_policy_for_device(device: &GpuDeviceInfo) -> GpuDispatchPolicy {
51    let fingerprint = device_fingerprint(device);
52    if let Some(cached) = load_cached_policy(fingerprint) {
53        log::info!(
54            "[GPU] loaded calibrated dispatch policy for {} ({fingerprint})",
55            device.name
56        );
57        return cached;
58    }
59
60    match calibrate_device(device, fingerprint) {
61        Ok(record) => {
62            let policy = record.policy.clone();
63            store_cached_policy(fingerprint, &record);
64            policy
65        }
66        Err(err) => {
67            log::warn!(
68                "[GPU] dispatch calibration unavailable for {}: {}; using default policy",
69                device.name,
70                err
71            );
72            GpuDispatchPolicy::default()
73        }
74    }
75}
76
77fn calibrate_device(
78    device: &GpuDeviceInfo,
79    fingerprint: Fingerprint,
80) -> Result<CachedCalibration, GpuError> {
81    let mut measurements = Vec::new();
82    measurements.extend(measure_gemm(device.ordinal)?);
83    measurements.extend(measure_potrf(device.ordinal)?);
84    measurements.extend(measure_xtwx(device.ordinal)?);
85    if measurements.is_empty() {
86        return Err(GpuError::CalibrationFailed {
87            reason: "no GPU calibration measurements completed".to_string(),
88        });
89    }
90
91    let mut policy = GpuDispatchPolicy::default();
92    if let Some(flops) = crossover_flops(&measurements, "gemm", policy.gemm_min_flops) {
93        policy.gemm_min_flops = flops;
94    }
95    if let Some(flops) = crossover_flops(&measurements, "xtwx", policy.xtwx_flops_min) {
96        policy.xtwx_flops_min = flops;
97    }
98    if let Some(rows) = crossover_rows(&measurements, "xtwx", policy.xtwx_n_min) {
99        policy.xtwx_n_min = rows;
100        policy.row_kernel_min_n = rows;
101        policy.fused_kernel_min_n = rows.saturating_mul(2);
102    }
103    if let Some(p) = crossover_rows(&measurements, "potrf", policy.potrf_min_p) {
104        policy.potrf_min_p = p;
105        policy.prefer_gpu_factorization_min_p = p;
106    }
107
108    log::info!(
109        "[GPU] calibrated dispatch policy for {} ({fingerprint}) from {} measurements",
110        device.name,
111        measurements.len()
112    );
113
114    Ok(CachedCalibration {
115        schema_version: SCHEMA_VERSION,
116        device_fingerprint: fingerprint.to_hex(),
117        policy,
118        measurements: measurements
119            .into_iter()
120            .map(Measurement::into_record)
121            .collect(),
122    })
123}
124
125fn measure_gemm(ordinal: usize) -> Result<Vec<Measurement>, GpuError> {
126    let mut out = Vec::with_capacity(GEMM_DIMS.len());
127    for dim in GEMM_DIMS {
128        let a = deterministic_matrix(dim, dim, 0.13);
129        let b = deterministic_matrix(dim, dim, 0.37);
130        let cpu_seconds = time_cpu(|| a.dot(&b))?;
131        let gpu_seconds = time_gpu(|| {
132            crate::blas::gemm_on_ordinal_cuda(ordinal, a.view(), b.view(), false, false)
133        })?;
134        out.push(Measurement {
135            operation: "gemm",
136            rows: dim,
137            cols: dim,
138            inner: dim,
139            flops: 2usize
140                .saturating_mul(dim)
141                .saturating_mul(dim)
142                .saturating_mul(dim),
143            cpu_seconds,
144            gpu_seconds,
145        });
146    }
147    Ok(out)
148}
149
150fn measure_potrf(ordinal: usize) -> Result<Vec<Measurement>, GpuError> {
151    let mut out = Vec::with_capacity(POTRF_DIMS.len());
152    for dim in POTRF_DIMS {
153        let a = deterministic_spd_matrix(dim);
154        let cpu_seconds = time_gpu_result(|| {
155            a.cholesky(Side::Lower)
156                .map(|factor| factor.lower_triangular())
157                .map_err(|err| format!("cpu POTRF failed: {err}"))
158        })?;
159        let gpu_seconds =
160            time_gpu_result(|| crate::solver::cholesky_lower_on_ordinal_gpu(ordinal, a.view()))?;
161        out.push(Measurement {
162            operation: "potrf",
163            rows: dim,
164            cols: dim,
165            inner: dim,
166            flops: dim.saturating_mul(dim).saturating_mul(dim) / 3,
167            cpu_seconds,
168            gpu_seconds,
169        });
170    }
171    Ok(out)
172}
173
174fn measure_xtwx(ordinal: usize) -> Result<Vec<Measurement>, GpuError> {
175    let mut out = Vec::with_capacity(XTWX_DIMS.len());
176    for (n, p) in XTWX_DIMS {
177        let x = deterministic_matrix(n, p, 0.61);
178        let w = deterministic_weights(n);
179        let cpu_seconds = time_cpu(|| cpu_xtwx(&x, &w))?;
180        let gpu_seconds =
181            time_gpu(|| crate::blas::xt_diag_x_on_ordinal_cuda(ordinal, x.view(), w.view()))?;
182        out.push(Measurement {
183            operation: "xtwx",
184            rows: n,
185            cols: p,
186            inner: p,
187            flops: 2usize.saturating_mul(n).saturating_mul(p).saturating_mul(p),
188            cpu_seconds,
189            gpu_seconds,
190        });
191    }
192    Ok(out)
193}
194
195fn time_cpu<F>(mut f: F) -> Result<f64, GpuError>
196where
197    F: FnMut() -> Array2<f64>,
198{
199    time_gpu_result(|| Result::<Array2<f64>, GpuError>::Ok(f()))
200}
201
202fn time_gpu<F>(mut f: F) -> Result<f64, GpuError>
203where
204    F: FnMut() -> Option<Array2<f64>>,
205{
206    time_gpu_result(|| {
207        f().ok_or_else(|| GpuError::CalibrationFailed {
208            reason: "GPU calibration kernel returned no result".to_string(),
209        })
210    })
211}
212
213fn time_gpu_result<F, E>(mut f: F) -> Result<f64, GpuError>
214where
215    F: FnMut() -> Result<Array2<f64>, E>,
216    E: std::fmt::Display,
217{
218    let start = Instant::now();
219    let out = f().map_err(|err| GpuError::CalibrationFailed {
220        reason: err.to_string(),
221    })?;
222    let elapsed = start.elapsed().as_secs_f64();
223    let checksum = out.iter().fold(0.0, |acc, value| acc + value.abs());
224    if elapsed.is_finite() && elapsed > 0.0 && checksum.is_finite() {
225        Ok(elapsed)
226    } else {
227        Err(GpuError::CalibrationFailed {
228            reason: format!(
229                "invalid calibration timing/checksum: elapsed={elapsed}, checksum={checksum}"
230            ),
231        })
232    }
233}
234
235fn crossover_flops(
236    measurements: &[Measurement],
237    operation: &'static str,
238    fallback: usize,
239) -> Option<usize> {
240    crossover_measurement(measurements, operation)
241        .map(|measurement| measurement.flops.max(1))
242        .or_else(|| {
243            measurements
244                .iter()
245                .filter(|measurement| measurement.operation == operation)
246                .map(|measurement| measurement.flops)
247                .max()
248                .map(|max_seen| fallback.max(max_seen.saturating_mul(2)))
249        })
250}
251
252fn crossover_rows(
253    measurements: &[Measurement],
254    operation: &'static str,
255    fallback: usize,
256) -> Option<usize> {
257    crossover_measurement(measurements, operation)
258        .map(|measurement| measurement.rows.max(1))
259        .or_else(|| {
260            measurements
261                .iter()
262                .filter(|measurement| measurement.operation == operation)
263                .map(|measurement| measurement.rows)
264                .max()
265                .map(|max_seen| fallback.max(max_seen.saturating_mul(2)))
266        })
267}
268
269fn crossover_measurement<'a>(
270    measurements: &'a [Measurement],
271    operation: &'static str,
272) -> Option<&'a Measurement> {
273    measurements
274        .iter()
275        .filter(|measurement| measurement.operation == operation)
276        .find(|measurement| measurement.gpu_seconds <= measurement.cpu_seconds * GPU_WIN_RATIO)
277}
278
279fn deterministic_matrix(rows: usize, cols: usize, phase: f64) -> Array2<f64> {
280    Array2::from_shape_fn((rows, cols), |(row, col)| {
281        let x = (row as f64 + 1.0) * 0.017 + (col as f64 + 1.0) * 0.031 + phase;
282        x.sin() + 0.25 * (2.0 * x).cos()
283    })
284}
285
286fn deterministic_spd_matrix(dim: usize) -> Array2<f64> {
287    let a = deterministic_matrix(dim, dim, 0.89);
288    let mut spd = a.t().dot(&a);
289    for idx in 0..dim {
290        spd[[idx, idx]] += dim as f64;
291    }
292    spd
293}
294
295fn deterministic_weights(n: usize) -> Array1<f64> {
296    Array1::from_shape_fn(n, |idx| 0.5 + ((idx as f64 + 1.0) * 0.019).sin().abs())
297}
298
299fn cpu_xtwx(x: &Array2<f64>, w: &Array1<f64>) -> Array2<f64> {
300    let mut weighted = x.clone();
301    for (mut row, weight) in weighted.outer_iter_mut().zip(w.iter()) {
302        row *= *weight;
303    }
304    x.t().dot(&weighted)
305}
306
307fn load_cached_policy(fingerprint: Fingerprint) -> Option<GpuDispatchPolicy> {
308    let path = cache_path(fingerprint);
309    let bytes = fs::read(path).ok()?;
310    let record: CachedCalibration = serde_json::from_slice(&bytes).ok()?;
311    if record.schema_version == SCHEMA_VERSION && record.device_fingerprint == fingerprint.to_hex()
312    {
313        Some(record.policy)
314    } else {
315        None
316    }
317}
318
319fn store_cached_policy(fingerprint: Fingerprint, record: &CachedCalibration) {
320    let path = cache_path(fingerprint);
321    if let Some(parent) = path.parent() {
322        if let Err(err) = fs::create_dir_all(parent) {
323            log::warn!("[GPU] unable to create calibration cache dir: {err}");
324            return;
325        }
326    }
327    let tmp = path.with_extension("json.tmp");
328    let bytes = match serde_json::to_vec_pretty(record) {
329        Ok(bytes) => bytes,
330        Err(err) => {
331            log::warn!("[GPU] unable to serialize calibration cache: {err}");
332            return;
333        }
334    };
335    if let Err(err) = fs::write(&tmp, bytes).and_then(|_| fs::rename(&tmp, &path)) {
336        log::warn!("[GPU] unable to write calibration cache: {err}");
337    }
338}
339
340fn cache_path(fingerprint: Fingerprint) -> PathBuf {
341    let mut root = std::env::temp_dir();
342    for component in CACHE_ROOT_COMPONENTS {
343        root.push(component);
344    }
345    root.push(format!("{fingerprint}.json"));
346    root
347}
348
349fn device_fingerprint(device: &GpuDeviceInfo) -> Fingerprint {
350    let mut fp = Fingerprinter::new();
351    fp.absorb_tag(b"gpu-dispatch-calibration");
352    fp.absorb_u64(b"schema-version", u64::from(SCHEMA_VERSION));
353    fp.absorb_str(b"name", &device.name);
354    fp.absorb_u64(
355        b"compute-major",
356        u64::try_from(device.capability.compute_major).unwrap_or(0),
357    );
358    fp.absorb_u64(
359        b"compute-minor",
360        u64::try_from(device.capability.compute_minor).unwrap_or(0),
361    );
362    fp.absorb_u64(b"sm-count", u64::try_from(device.sm_count).unwrap_or(0));
363    fp.absorb_u64(
364        b"max-threads-per-sm",
365        u64::try_from(device.max_threads_per_sm).unwrap_or(0),
366    );
367    fp.absorb_u64(
368        b"max-shared-mem-per-block",
369        device.max_shared_mem_per_block as u64,
370    );
371    fp.absorb_u64(b"l2-cache-bytes", device.l2_cache_bytes as u64);
372    fp.absorb_u64(b"total-mem-bytes", device.total_mem_bytes as u64);
373    fp.absorb_u64(b"ecc-enabled", bool_fingerprint_value(device.ecc_enabled));
374    fp.absorb_u64(b"integrated", bool_fingerprint_value(device.integrated));
375    fp.absorb_u64(b"mig-mode", bool_fingerprint_value(device.mig_mode));
376    fp.finalize()
377}
378
/// Encodes a flag for fingerprinting: `true` becomes `1`, `false` becomes `0`.
const fn bool_fingerprint_value(value: bool) -> u64 {
    // An `as` cast is used because `u64::from` is not callable in a `const fn`.
    value as u64
}
382
383impl Measurement {
384    fn into_record(self) -> MeasurementRecord {
385        MeasurementRecord {
386            operation: self.operation.to_string(),
387            rows: self.rows,
388            cols: self.cols,
389            inner: self.inner,
390            flops: self.flops,
391            cpu_seconds: self.cpu_seconds,
392            gpu_seconds: self.gpu_seconds,
393        }
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::GpuCapability;

    /// Builds a synthetic timing sample; `inner` mirrors `cols`.
    fn measurement(
        operation: &'static str,
        rows: usize,
        cols: usize,
        flops: usize,
        cpu_seconds: f64,
        gpu_seconds: f64,
    ) -> Measurement {
        let inner = cols;
        Measurement {
            operation,
            rows,
            cols,
            inner,
            flops,
            cpu_seconds,
            gpu_seconds,
        }
    }

    #[test]
    fn calibration_crossover_uses_first_measured_gpu_win() {
        // The GPU loses at 64, first clears the win ratio at 128, and wins
        // again at 256; the 128 sample must set the threshold.
        let measurements = [
            measurement("gemm", 64, 64, 524_288, 0.001, 0.004),
            measurement("gemm", 128, 128, 4_194_304, 0.010, 0.009),
            measurement("gemm", 256, 256, 33_554_432, 0.080, 0.010),
        ];

        let threshold = crossover_flops(&measurements, "gemm", 100_000_000);

        assert_eq!(threshold, Some(4_194_304));
    }

    #[test]
    fn calibration_crossover_raises_threshold_when_gpu_never_wins() {
        // The GPU is slower at every probed size.
        let measurements = [
            measurement("xtwx", 2_048, 32, 4_194_304, 0.001, 0.004),
            measurement("xtwx", 4_096, 64, 33_554_432, 0.010, 0.040),
            measurement("xtwx", 8_192, 96, 150_994_944, 0.080, 0.400),
        ];

        // Twice the largest probed FLOP count exceeds the fallback.
        let flops_threshold = crossover_flops(&measurements, "xtwx", 100_000_000);
        assert_eq!(flops_threshold, Some(301_989_888));

        // Twice the largest probed row count is below the fallback, which wins.
        let rows_threshold = crossover_rows(&measurements, "xtwx", 50_000);
        assert_eq!(rows_threshold, Some(50_000));
    }

    #[test]
    fn calibration_cache_key_tracks_device_fingerprint() {
        let device = GpuDeviceInfo {
            ordinal: 0,
            name: "unit-test GPU".to_string(),
            capability: GpuCapability::from_compute_capability(8, 0),
            sm_count: 108,
            max_threads_per_sm: 2048,
            max_shared_mem_per_block: 99_328,
            l2_cache_bytes: 40 * 1024 * 1024,
            total_mem_bytes: 80 * 1024 * 1024 * 1024,
            free_mem_bytes: 70 * 1024 * 1024 * 1024,
            ecc_enabled: true,
            integrated: false,
            mig_mode: false,
        };

        let fingerprint = device_fingerprint(&device);
        let path = cache_path(fingerprint);

        // The file name is the hex fingerprint plus `.json`.
        let expected_file_name = format!("{}.json", fingerprint.to_hex());
        assert!(path.ends_with(expected_file_name));

        // The cache root components appear contiguously, in order, in the path.
        let components: Vec<String> = path
            .components()
            .map(|component| component.as_os_str().to_string_lossy().into_owned())
            .collect();
        let contains_cache_root = components
            .windows(CACHE_ROOT_COMPONENTS.len())
            .any(|window| window == CACHE_ROOT_COMPONENTS);
        assert!(contains_cache_root);
    }
}
478}