Skip to main content

gam_gpu/
calibration.rs

1use crate::device::GpuDeviceInfo;
2use crate::gpu_error::GpuError;
3use crate::policy::GpuDispatchPolicy;
4use gam_linalg::faer_ndarray::FaerCholesky;
5use faer::Side;
6use gam_runtime::warm_start::{Fingerprint, Fingerprinter};
7use ndarray::{Array1, Array2};
8use serde::{Deserialize, Serialize};
9use std::fs;
10use std::path::PathBuf;
11use std::time::Instant;
12
/// Version of the on-disk calibration record. It is mixed into the device
/// fingerprint and re-checked when a cached record is loaded.
const SCHEMA_VERSION: u32 = 1;

/// Path components appended to `std::env::temp_dir()` to form the cache directory.
const CACHE_ROOT_COMPONENTS: [&str; 4] = ["gam", "gpu", "policy", "v1"];

/// Square sizes `n` benchmarked for GEMM (`n x n` times `n x n`), ascending.
const GEMM_DIMS: [usize; 3] = [64, 128, 256];

/// Sizes `n` of the `n x n` SPD matrices benchmarked for Cholesky (POTRF), ascending.
const POTRF_DIMS: [usize; 3] = [64, 128, 256];

/// `(n, p)` shapes of the design matrix benchmarked for `X^T W X`, ascending.
const XTWX_DIMS: [(usize, usize); 3] = [(2048, 32), (4096, 64), (8192, 96)];

/// The GPU counts as winning only when its time is at most this fraction of the
/// CPU time, so near-ties stay on the CPU.
const GPU_WIN_RATIO: f64 = 0.95;
19
20#[derive(Clone, Debug, Serialize, Deserialize)]
21struct CachedCalibration {
22    schema_version: u32,
23    device_fingerprint: String,
24    policy: GpuDispatchPolicy,
25    measurements: Vec<MeasurementRecord>,
26}
27
28#[derive(Clone, Debug, Serialize, Deserialize)]
29struct MeasurementRecord {
30    operation: String,
31    rows: usize,
32    cols: usize,
33    inner: usize,
34    flops: usize,
35    cpu_seconds: f64,
36    gpu_seconds: f64,
37}
38
/// One CPU-versus-GPU timing sample for a single operation and problem size.
#[derive(Clone, Debug)]
struct Measurement {
    /// Operation tag used to group samples (`"gemm"`, `"potrf"` or `"xtwx"`).
    operation: &'static str,
    /// Row dimension of the benchmarked problem.
    rows: usize,
    /// Column dimension of the benchmarked problem.
    cols: usize,
    /// Inner dimension as recorded by the corresponding `measure_*` function.
    inner: usize,
    /// Estimated floating-point operation count for one invocation.
    flops: usize,
    /// Wall-clock seconds reported by the CPU timing helper.
    cpu_seconds: f64,
    /// Wall-clock seconds reported by the GPU timing helper.
    gpu_seconds: f64,
}
49
50pub(crate) fn calibrated_policy_for_device(device: &GpuDeviceInfo) -> GpuDispatchPolicy {
51    let fingerprint = device_fingerprint(device);
52    if let Some(cached) = load_cached_policy(fingerprint) {
53        log::info!(
54            "[GPU] loaded calibrated dispatch policy for {} ({fingerprint})",
55            device.name
56        );
57        return cached;
58    }
59
60    match calibrate_device(device, fingerprint) {
61        Ok(record) => {
62            let policy = record.policy.clone();
63            store_cached_policy(fingerprint, &record);
64            policy
65        }
66        Err(err) => {
67            log::warn!(
68                "[GPU] dispatch calibration unavailable for {}: {}; using default policy",
69                device.name,
70                err
71            );
72            GpuDispatchPolicy::default()
73        }
74    }
75}
76
77fn calibrate_device(
78    device: &GpuDeviceInfo,
79    fingerprint: Fingerprint,
80) -> Result<CachedCalibration, GpuError> {
81    let mut measurements = Vec::new();
82    measurements.extend(measure_gemm(device.ordinal)?);
83    measurements.extend(measure_potrf(device.ordinal)?);
84    measurements.extend(measure_xtwx(device.ordinal)?);
85    if measurements.is_empty() {
86        return Err(GpuError::CalibrationFailed {
87            reason: "no GPU calibration measurements completed".to_string(),
88        });
89    }
90
91    let mut policy = GpuDispatchPolicy::default();
92    if let Some(flops) = crossover_flops(&measurements, "gemm", policy.gemm_min_flops) {
93        policy.gemm_min_flops = flops;
94    }
95    if let Some(flops) = crossover_flops(&measurements, "xtwx", policy.xtwx_flops_min) {
96        policy.xtwx_flops_min = flops;
97    }
98    if let Some(rows) = crossover_rows(&measurements, "xtwx", policy.xtwx_n_min) {
99        policy.xtwx_n_min = rows;
100        policy.row_kernel_min_n = rows;
101        policy.fused_kernel_min_n = rows.saturating_mul(2);
102    }
103    if let Some(p) = crossover_rows(&measurements, "potrf", policy.potrf_min_p) {
104        policy.potrf_min_p = p;
105        policy.prefer_gpu_factorization_min_p = p;
106    }
107
108    log::info!(
109        "[GPU] calibrated dispatch policy for {} ({fingerprint}) from {} measurements",
110        device.name,
111        measurements.len()
112    );
113
114    Ok(CachedCalibration {
115        schema_version: SCHEMA_VERSION,
116        device_fingerprint: fingerprint.to_hex(),
117        policy,
118        measurements: measurements
119            .into_iter()
120            .map(Measurement::into_record)
121            .collect(),
122    })
123}
124
125fn measure_gemm(ordinal: usize) -> Result<Vec<Measurement>, GpuError> {
126    let mut out = Vec::with_capacity(GEMM_DIMS.len());
127    for dim in GEMM_DIMS {
128        let a = deterministic_matrix(dim, dim, 0.13);
129        let b = deterministic_matrix(dim, dim, 0.37);
130        let cpu_seconds = time_cpu(|| a.dot(&b))?;
131        let gpu_seconds = time_gpu(|| {
132            crate::blas::gemm_on_ordinal_cuda(ordinal, a.view(), b.view(), false, false)
133        })?;
134        out.push(Measurement {
135            operation: "gemm",
136            rows: dim,
137            cols: dim,
138            inner: dim,
139            flops: 2usize
140                .saturating_mul(dim)
141                .saturating_mul(dim)
142                .saturating_mul(dim),
143            cpu_seconds,
144            gpu_seconds,
145        });
146    }
147    Ok(out)
148}
149
150fn measure_potrf(ordinal: usize) -> Result<Vec<Measurement>, GpuError> {
151    let mut out = Vec::with_capacity(POTRF_DIMS.len());
152    for dim in POTRF_DIMS {
153        let a = deterministic_spd_matrix(dim);
154        let cpu_seconds = time_gpu_result(|| {
155            a.cholesky(Side::Lower)
156                .map(|factor| factor.lower_triangular())
157                .map_err(|err| format!("cpu POTRF failed: {err}"))
158        })?;
159        let gpu_seconds = time_gpu_result(|| {
160            crate::solver::cholesky_lower_on_ordinal_gpu(ordinal, a.view())
161        })?;
162        out.push(Measurement {
163            operation: "potrf",
164            rows: dim,
165            cols: dim,
166            inner: dim,
167            flops: dim.saturating_mul(dim).saturating_mul(dim) / 3,
168            cpu_seconds,
169            gpu_seconds,
170        });
171    }
172    Ok(out)
173}
174
175fn measure_xtwx(ordinal: usize) -> Result<Vec<Measurement>, GpuError> {
176    let mut out = Vec::with_capacity(XTWX_DIMS.len());
177    for (n, p) in XTWX_DIMS {
178        let x = deterministic_matrix(n, p, 0.61);
179        let w = deterministic_weights(n);
180        let cpu_seconds = time_cpu(|| cpu_xtwx(&x, &w))?;
181        let gpu_seconds =
182            time_gpu(|| crate::blas::xt_diag_x_on_ordinal_cuda(ordinal, x.view(), w.view()))?;
183        out.push(Measurement {
184            operation: "xtwx",
185            rows: n,
186            cols: p,
187            inner: p,
188            flops: 2usize.saturating_mul(n).saturating_mul(p).saturating_mul(p),
189            cpu_seconds,
190            gpu_seconds,
191        });
192    }
193    Ok(out)
194}
195
196fn time_cpu<F>(mut f: F) -> Result<f64, GpuError>
197where
198    F: FnMut() -> Array2<f64>,
199{
200    time_gpu_result(|| Result::<Array2<f64>, GpuError>::Ok(f()))
201}
202
203fn time_gpu<F>(mut f: F) -> Result<f64, GpuError>
204where
205    F: FnMut() -> Option<Array2<f64>>,
206{
207    time_gpu_result(|| {
208        f().ok_or_else(|| GpuError::CalibrationFailed {
209            reason: "GPU calibration kernel returned no result".to_string(),
210        })
211    })
212}
213
214fn time_gpu_result<F, E>(mut f: F) -> Result<f64, GpuError>
215where
216    F: FnMut() -> Result<Array2<f64>, E>,
217    E: std::fmt::Display,
218{
219    let start = Instant::now();
220    let out = f().map_err(|err| GpuError::CalibrationFailed {
221        reason: err.to_string(),
222    })?;
223    let elapsed = start.elapsed().as_secs_f64();
224    let checksum = out.iter().fold(0.0, |acc, value| acc + value.abs());
225    if elapsed.is_finite() && elapsed > 0.0 && checksum.is_finite() {
226        Ok(elapsed)
227    } else {
228        Err(GpuError::CalibrationFailed {
229            reason: format!(
230                "invalid calibration timing/checksum: elapsed={elapsed}, checksum={checksum}"
231            ),
232        })
233    }
234}
235
236fn crossover_flops(
237    measurements: &[Measurement],
238    operation: &'static str,
239    fallback: usize,
240) -> Option<usize> {
241    crossover_measurement(measurements, operation)
242        .map(|measurement| measurement.flops.max(1))
243        .or_else(|| {
244            measurements
245                .iter()
246                .filter(|measurement| measurement.operation == operation)
247                .map(|measurement| measurement.flops)
248                .max()
249                .map(|max_seen| fallback.max(max_seen.saturating_mul(2)))
250        })
251}
252
253fn crossover_rows(
254    measurements: &[Measurement],
255    operation: &'static str,
256    fallback: usize,
257) -> Option<usize> {
258    crossover_measurement(measurements, operation)
259        .map(|measurement| measurement.rows.max(1))
260        .or_else(|| {
261            measurements
262                .iter()
263                .filter(|measurement| measurement.operation == operation)
264                .map(|measurement| measurement.rows)
265                .max()
266                .map(|max_seen| fallback.max(max_seen.saturating_mul(2)))
267        })
268}
269
270fn crossover_measurement<'a>(
271    measurements: &'a [Measurement],
272    operation: &'static str,
273) -> Option<&'a Measurement> {
274    measurements
275        .iter()
276        .filter(|measurement| measurement.operation == operation)
277        .find(|measurement| measurement.gpu_seconds <= measurement.cpu_seconds * GPU_WIN_RATIO)
278}
279
280fn deterministic_matrix(rows: usize, cols: usize, phase: f64) -> Array2<f64> {
281    Array2::from_shape_fn((rows, cols), |(row, col)| {
282        let x = (row as f64 + 1.0) * 0.017 + (col as f64 + 1.0) * 0.031 + phase;
283        x.sin() + 0.25 * (2.0 * x).cos()
284    })
285}
286
287fn deterministic_spd_matrix(dim: usize) -> Array2<f64> {
288    let a = deterministic_matrix(dim, dim, 0.89);
289    let mut spd = a.t().dot(&a);
290    for idx in 0..dim {
291        spd[[idx, idx]] += dim as f64;
292    }
293    spd
294}
295
296fn deterministic_weights(n: usize) -> Array1<f64> {
297    Array1::from_shape_fn(n, |idx| 0.5 + ((idx as f64 + 1.0) * 0.019).sin().abs())
298}
299
300fn cpu_xtwx(x: &Array2<f64>, w: &Array1<f64>) -> Array2<f64> {
301    let mut weighted = x.clone();
302    for (mut row, weight) in weighted.outer_iter_mut().zip(w.iter()) {
303        row *= *weight;
304    }
305    x.t().dot(&weighted)
306}
307
308fn load_cached_policy(fingerprint: Fingerprint) -> Option<GpuDispatchPolicy> {
309    let path = cache_path(fingerprint);
310    let bytes = fs::read(path).ok()?;
311    let record: CachedCalibration = serde_json::from_slice(&bytes).ok()?;
312    if record.schema_version == SCHEMA_VERSION && record.device_fingerprint == fingerprint.to_hex()
313    {
314        Some(record.policy)
315    } else {
316        None
317    }
318}
319
320fn store_cached_policy(fingerprint: Fingerprint, record: &CachedCalibration) {
321    let path = cache_path(fingerprint);
322    if let Some(parent) = path.parent() {
323        if let Err(err) = fs::create_dir_all(parent) {
324            log::warn!("[GPU] unable to create calibration cache dir: {err}");
325            return;
326        }
327    }
328    let tmp = path.with_extension("json.tmp");
329    let bytes = match serde_json::to_vec_pretty(record) {
330        Ok(bytes) => bytes,
331        Err(err) => {
332            log::warn!("[GPU] unable to serialize calibration cache: {err}");
333            return;
334        }
335    };
336    if let Err(err) = fs::write(&tmp, bytes).and_then(|_| fs::rename(&tmp, &path)) {
337        log::warn!("[GPU] unable to write calibration cache: {err}");
338    }
339}
340
341fn cache_path(fingerprint: Fingerprint) -> PathBuf {
342    let mut root = std::env::temp_dir();
343    for component in CACHE_ROOT_COMPONENTS {
344        root.push(component);
345    }
346    root.push(format!("{fingerprint}.json"));
347    root
348}
349
350fn device_fingerprint(device: &GpuDeviceInfo) -> Fingerprint {
351    let mut fp = Fingerprinter::new();
352    fp.absorb_tag(b"gpu-dispatch-calibration");
353    fp.absorb_u64(b"schema-version", u64::from(SCHEMA_VERSION));
354    fp.absorb_str(b"name", &device.name);
355    fp.absorb_u64(
356        b"compute-major",
357        u64::try_from(device.capability.compute_major).unwrap_or(0),
358    );
359    fp.absorb_u64(
360        b"compute-minor",
361        u64::try_from(device.capability.compute_minor).unwrap_or(0),
362    );
363    fp.absorb_u64(b"sm-count", u64::try_from(device.sm_count).unwrap_or(0));
364    fp.absorb_u64(
365        b"max-threads-per-sm",
366        u64::try_from(device.max_threads_per_sm).unwrap_or(0),
367    );
368    fp.absorb_u64(
369        b"max-shared-mem-per-block",
370        device.max_shared_mem_per_block as u64,
371    );
372    fp.absorb_u64(b"l2-cache-bytes", device.l2_cache_bytes as u64);
373    fp.absorb_u64(b"total-mem-bytes", device.total_mem_bytes as u64);
374    fp.absorb_u64(b"ecc-enabled", bool_fingerprint_value(device.ecc_enabled));
375    fp.absorb_u64(b"integrated", bool_fingerprint_value(device.integrated));
376    fp.absorb_u64(b"mig-mode", bool_fingerprint_value(device.mig_mode));
377    fp.finalize()
378}
379
/// Encodes a flag for the fingerprinter: `true` becomes 1 and `false` becomes 0.
const fn bool_fingerprint_value(value: bool) -> u64 {
    value as u64
}
383
384impl Measurement {
385    fn into_record(self) -> MeasurementRecord {
386        MeasurementRecord {
387            operation: self.operation.to_string(),
388            rows: self.rows,
389            cols: self.cols,
390            inner: self.inner,
391            flops: self.flops,
392            cpu_seconds: self.cpu_seconds,
393            gpu_seconds: self.gpu_seconds,
394        }
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::GpuCapability;

    /// Builds a synthetic `Measurement` whose `inner` dimension mirrors `cols`.
    fn sample(
        operation: &'static str,
        rows: usize,
        cols: usize,
        flops: usize,
        cpu_seconds: f64,
        gpu_seconds: f64,
    ) -> Measurement {
        Measurement {
            operation,
            rows,
            cols,
            inner: cols,
            flops,
            cpu_seconds,
            gpu_seconds,
        }
    }

    /// Builds a discrete, ECC-enabled device description used for cache-key tests.
    fn unit_test_device() -> GpuDeviceInfo {
        GpuDeviceInfo {
            ordinal: 0,
            name: "unit-test GPU".to_string(),
            capability: GpuCapability::from_compute_capability(8, 0),
            sm_count: 108,
            max_threads_per_sm: 2048,
            max_shared_mem_per_block: 99_328,
            l2_cache_bytes: 40 * 1024 * 1024,
            total_mem_bytes: 80 * 1024 * 1024 * 1024,
            free_mem_bytes: 70 * 1024 * 1024 * 1024,
            ecc_enabled: true,
            integrated: false,
            mig_mode: false,
        }
    }

    #[test]
    fn calibration_crossover_uses_first_measured_gpu_win() {
        // The GPU loses at 64, then wins at 128 and 256: the crossover is the 128 sample.
        let measurements = [
            sample("gemm", 64, 64, 524_288, 0.001, 0.004),
            sample("gemm", 128, 128, 4_194_304, 0.010, 0.009),
            sample("gemm", 256, 256, 33_554_432, 0.080, 0.010),
        ];

        let threshold = crossover_flops(&measurements, "gemm", 100_000_000);
        assert_eq!(threshold, Some(4_194_304));
    }

    #[test]
    fn calibration_crossover_raises_threshold_when_gpu_never_wins() {
        let measurements = [
            sample("xtwx", 2_048, 32, 4_194_304, 0.001, 0.004),
            sample("xtwx", 4_096, 64, 33_554_432, 0.010, 0.040),
            sample("xtwx", 8_192, 96, 150_994_944, 0.080, 0.400),
        ];

        // No win: the flop threshold becomes twice the largest measured flop count...
        let flop_threshold = crossover_flops(&measurements, "xtwx", 100_000_000);
        assert_eq!(flop_threshold, Some(301_989_888));
        // ...while the row threshold keeps the larger fallback.
        let row_threshold = crossover_rows(&measurements, "xtwx", 50_000);
        assert_eq!(row_threshold, Some(50_000));
    }

    #[test]
    fn calibration_cache_key_tracks_device_fingerprint() {
        let device = unit_test_device();
        let fingerprint = device_fingerprint(&device);
        let path = cache_path(fingerprint);

        let expected_file = format!("{}.json", fingerprint.to_hex());
        assert!(path.ends_with(&expected_file));

        let components: Vec<String> = path
            .components()
            .map(|component| component.as_os_str().to_string_lossy().into_owned())
            .collect();
        let has_cache_root = components
            .windows(CACHE_ROOT_COMPONENTS.len())
            .any(|window| window == CACHE_ROOT_COMPONENTS);
        assert!(has_cache_root);
    }
}