//! runmat-accelerate 0.4.4
//!
//! Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V).
//! Documentation
use std::collections::HashMap;
use std::sync::{
    atomic::{AtomicU64, Ordering},
    Mutex,
};

use runmat_accelerate_api::{
    KernelAttrTelemetry, KernelLaunchTelemetry, ProviderDispatchStats, ProviderFallbackStat,
    ProviderTelemetry,
};

/// Upper bound on retained kernel-launch events; older entries are evicted
/// first so the in-memory log stays small (see `record_kernel_launch`).
const MAX_KERNEL_LAUNCH_EVENTS: usize = 64;

/// Aggregated acceleration telemetry for a provider.
///
/// Counters are relaxed atomics so the recording hot path stays lock-free;
/// the fallback map and kernel-launch log are the only mutex-guarded state.
#[derive(Default)]
pub struct AccelTelemetry {
    /// Number of fused elementwise dispatches recorded.
    fused_elementwise_count: AtomicU64,
    /// Accumulated wall time (ns) across fused elementwise dispatches.
    fused_elementwise_wall_ns: AtomicU64,
    /// Number of fused reduction dispatches recorded.
    fused_reduction_count: AtomicU64,
    /// Accumulated wall time (ns) across fused reduction dispatches.
    fused_reduction_wall_ns: AtomicU64,
    /// Number of matrix-multiply dispatches recorded.
    matmul_count: AtomicU64,
    /// Accumulated wall time (ns) across matmul dispatches.
    matmul_wall_ns: AtomicU64,
    /// Number of `linsolve` dispatches recorded.
    linsolve_count: AtomicU64,
    /// Accumulated wall time (ns) across `linsolve` dispatches.
    linsolve_wall_ns: AtomicU64,
    /// Number of `mldivide` (left-divide) dispatches recorded.
    mldivide_count: AtomicU64,
    /// Accumulated wall time (ns) across `mldivide` dispatches.
    mldivide_wall_ns: AtomicU64,
    /// Number of `mrdivide` (right-divide) dispatches recorded.
    mrdivide_count: AtomicU64,
    /// Accumulated wall time (ns) across `mrdivide` dispatches.
    mrdivide_wall_ns: AtomicU64,
    /// Total bytes uploaded to the device.
    upload_bytes: AtomicU64,
    /// Total bytes downloaded from the device.
    download_bytes: AtomicU64,
    /// Per-reason counts of host fallbacks during solves, keyed by static strings.
    solve_fallbacks: Mutex<HashMap<&'static str, u64>>,
    /// Bounded log of recent kernel launches (capped at `MAX_KERNEL_LAUNCH_EVENTS`).
    kernel_launches: Mutex<Vec<KernelLaunchTelemetry>>,
}

impl AccelTelemetry {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn record_upload_bytes(&self, bytes: u64) {
        if bytes > 0 {
            self.upload_bytes.fetch_add(bytes, Ordering::Relaxed);
        }
    }

    pub fn record_download_bytes(&self, bytes: u64) {
        if bytes > 0 {
            self.download_bytes.fetch_add(bytes, Ordering::Relaxed);
        }
    }

    pub fn record_fused_elementwise(&self, wall_ns: u64) {
        self.fused_elementwise_count.fetch_add(1, Ordering::Relaxed);
        if wall_ns > 0 {
            self.fused_elementwise_wall_ns
                .fetch_add(wall_ns, Ordering::Relaxed);
        }
    }

    pub fn record_fused_reduction(&self, wall_ns: u64) {
        self.fused_reduction_count.fetch_add(1, Ordering::Relaxed);
        if wall_ns > 0 {
            self.fused_reduction_wall_ns
                .fetch_add(wall_ns, Ordering::Relaxed);
        }
    }

    pub fn record_matmul(&self, wall_ns: u64) {
        self.matmul_count.fetch_add(1, Ordering::Relaxed);
        if wall_ns > 0 {
            self.matmul_wall_ns.fetch_add(wall_ns, Ordering::Relaxed);
        }
    }

    pub fn record_linsolve(&self, wall_ns: u64) {
        self.linsolve_count.fetch_add(1, Ordering::Relaxed);
        if wall_ns > 0 {
            self.linsolve_wall_ns.fetch_add(wall_ns, Ordering::Relaxed);
        }
    }

    pub fn record_mldivide(&self, wall_ns: u64) {
        self.mldivide_count.fetch_add(1, Ordering::Relaxed);
        if wall_ns > 0 {
            self.mldivide_wall_ns.fetch_add(wall_ns, Ordering::Relaxed);
        }
    }

    pub fn record_mrdivide(&self, wall_ns: u64) {
        self.mrdivide_count.fetch_add(1, Ordering::Relaxed);
        if wall_ns > 0 {
            self.mrdivide_wall_ns.fetch_add(wall_ns, Ordering::Relaxed);
        }
    }

    pub fn record_solve_fallback(&self, reason: &'static str) {
        if let Ok(mut guard) = self.solve_fallbacks.lock() {
            *guard.entry(reason).or_insert(0) += 1;
        }
    }

    pub fn reset(&self) {
        self.fused_elementwise_count.store(0, Ordering::Relaxed);
        self.fused_elementwise_wall_ns.store(0, Ordering::Relaxed);
        self.fused_reduction_count.store(0, Ordering::Relaxed);
        self.fused_reduction_wall_ns.store(0, Ordering::Relaxed);
        self.matmul_count.store(0, Ordering::Relaxed);
        self.matmul_wall_ns.store(0, Ordering::Relaxed);
        self.linsolve_count.store(0, Ordering::Relaxed);
        self.linsolve_wall_ns.store(0, Ordering::Relaxed);
        self.mldivide_count.store(0, Ordering::Relaxed);
        self.mldivide_wall_ns.store(0, Ordering::Relaxed);
        self.mrdivide_count.store(0, Ordering::Relaxed);
        self.mrdivide_wall_ns.store(0, Ordering::Relaxed);
        self.upload_bytes.store(0, Ordering::Relaxed);
        self.download_bytes.store(0, Ordering::Relaxed);
        if let Ok(mut guard) = self.solve_fallbacks.lock() {
            guard.clear();
        }
        if let Ok(mut guard) = self.kernel_launches.lock() {
            guard.clear();
        }
    }

    pub fn snapshot(
        &self,
        fusion_cache_hits: u64,
        fusion_cache_misses: u64,
        bind_group_cache_hits: u64,
        bind_group_cache_misses: u64,
        bind_group_cache_by_layout: Option<Vec<runmat_accelerate_api::BindGroupLayoutTelemetry>>,
    ) -> ProviderTelemetry {
        let kernel_launches = self
            .kernel_launches
            .lock()
            .map(|events| events.clone())
            .unwrap_or_default();
        let solve_fallbacks = self
            .solve_fallbacks
            .lock()
            .map(|reasons| {
                let mut stats: Vec<ProviderFallbackStat> = reasons
                    .iter()
                    .map(|(reason, count)| ProviderFallbackStat {
                        reason: (*reason).to_string(),
                        count: *count,
                    })
                    .collect();
                stats.sort_by(|a, b| a.reason.cmp(&b.reason));
                stats
            })
            .unwrap_or_default();
        ProviderTelemetry {
            fused_elementwise: ProviderDispatchStats {
                count: self.fused_elementwise_count.load(Ordering::Relaxed),
                total_wall_time_ns: self.fused_elementwise_wall_ns.load(Ordering::Relaxed),
            },
            fused_reduction: ProviderDispatchStats {
                count: self.fused_reduction_count.load(Ordering::Relaxed),
                total_wall_time_ns: self.fused_reduction_wall_ns.load(Ordering::Relaxed),
            },
            matmul: ProviderDispatchStats {
                count: self.matmul_count.load(Ordering::Relaxed),
                total_wall_time_ns: self.matmul_wall_ns.load(Ordering::Relaxed),
            },
            linsolve: ProviderDispatchStats {
                count: self.linsolve_count.load(Ordering::Relaxed),
                total_wall_time_ns: self.linsolve_wall_ns.load(Ordering::Relaxed),
            },
            mldivide: ProviderDispatchStats {
                count: self.mldivide_count.load(Ordering::Relaxed),
                total_wall_time_ns: self.mldivide_wall_ns.load(Ordering::Relaxed),
            },
            mrdivide: ProviderDispatchStats {
                count: self.mrdivide_count.load(Ordering::Relaxed),
                total_wall_time_ns: self.mrdivide_wall_ns.load(Ordering::Relaxed),
            },
            upload_bytes: self.upload_bytes.load(Ordering::Relaxed),
            download_bytes: self.download_bytes.load(Ordering::Relaxed),
            solve_fallbacks,
            fusion_cache_hits,
            fusion_cache_misses,
            bind_group_cache_hits,
            bind_group_cache_misses,
            bind_group_cache_by_layout,
            kernel_launches,
        }
    }
}

/// Convert a `Duration` to whole nanoseconds, clamping to `u64::MAX` on overflow.
fn saturating_duration_ns(duration: std::time::Duration) -> u64 {
    u64::try_from(duration.as_nanos()).unwrap_or(u64::MAX)
}

impl AccelTelemetry {
    /// Record a fused elementwise dispatch timed as a `Duration`.
    pub fn record_fused_elementwise_duration(&self, duration: std::time::Duration) {
        let ns = saturating_duration_ns(duration);
        self.record_fused_elementwise(ns);
    }

    /// Record a fused reduction dispatch timed as a `Duration`.
    pub fn record_fused_reduction_duration(&self, duration: std::time::Duration) {
        let ns = saturating_duration_ns(duration);
        self.record_fused_reduction(ns);
    }

    /// Record a matmul dispatch timed as a `Duration`.
    pub fn record_matmul_duration(&self, duration: std::time::Duration) {
        let ns = saturating_duration_ns(duration);
        self.record_matmul(ns);
    }

    /// Record a `linsolve` dispatch timed as a `Duration`.
    pub fn record_linsolve_duration(&self, duration: std::time::Duration) {
        let ns = saturating_duration_ns(duration);
        self.record_linsolve(ns);
    }

    /// Record a `mldivide` dispatch timed as a `Duration`.
    pub fn record_mldivide_duration(&self, duration: std::time::Duration) {
        let ns = saturating_duration_ns(duration);
        self.record_mldivide(ns);
    }

    /// Record a `mrdivide` dispatch timed as a `Duration`.
    pub fn record_mrdivide_duration(&self, duration: std::time::Duration) {
        let ns = saturating_duration_ns(duration);
        self.record_mrdivide(ns);
    }

    /// Append a kernel-launch event to the log, evicting the oldest entries
    /// so the log never holds more than `MAX_KERNEL_LAUNCH_EVENTS`.
    pub fn record_kernel_launch(
        &self,
        kernel: &'static str,
        precision: Option<&str>,
        shape: &[(&str, u64)],
        tuning: &[(&str, u64)],
    ) {
        let entry = KernelLaunchTelemetry {
            kernel: kernel.to_string(),
            precision: precision.map(str::to_string),
            shape: Self::pairs_to_attrs(shape),
            tuning: Self::pairs_to_attrs(tuning),
        };
        if let Ok(mut events) = self.kernel_launches.lock() {
            // Make room for the incoming entry before pushing it.
            let excess = events.len().saturating_sub(MAX_KERNEL_LAUNCH_EVENTS - 1);
            if excess > 0 {
                events.drain(0..excess);
            }
            events.push(entry);
        }
    }

    /// Convert `(key, value)` pairs into the API attribute representation.
    fn pairs_to_attrs(pairs: &[(&str, u64)]) -> Vec<KernelAttrTelemetry> {
        let mut attrs = Vec::with_capacity(pairs.len());
        for &(key, value) in pairs {
            attrs.push(KernelAttrTelemetry {
                key: key.to_string(),
                value,
            });
        }
        attrs
    }
}