//! Lock-light telemetry counters for the acceleration provider.
//!
//! Dispatch counts and wall-clock totals are plain atomics; only the
//! solve-fallback map and the bounded kernel-launch log sit behind a `Mutex`.
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicU64, Ordering},
Mutex,
};
use runmat_accelerate_api::{
KernelAttrTelemetry, KernelLaunchTelemetry, ProviderDispatchStats, ProviderFallbackStat,
ProviderTelemetry,
};
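/// Upper bound on buffered kernel-launch events; the oldest are evicted first.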
const MAX_KERNEL_LAUNCH_EVENTS: usize = 64;
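/// Aggregated telemetry for accelerated dispatches: per-operation counts and
/// wall-clock totals, transfer byte counters, solve fallback reasons, and a
/// bounded log of recent kernel launches.
///
/// A minimal usage sketch (the zero arguments to `snapshot` stand in for
/// cache statistics tracked elsewhere by the provider):
///
/// ```ignore
/// let telemetry = AccelTelemetry::new();
/// telemetry.record_matmul_duration(std::time::Duration::from_micros(250));
/// telemetry.record_solve_fallback("non-square matrix");
/// let snapshot = telemetry.snapshot(0, 0, 0, 0, None);
/// assert_eq!(snapshot.matmul.count, 1);
/// ```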
#[derive(Default)]
pub struct AccelTelemetry {
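    // Count and cumulative wall-clock nanoseconds per dispatch family.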
fused_elementwise_count: AtomicU64,
fused_elementwise_wall_ns: AtomicU64,
fused_reduction_count: AtomicU64,
fused_reduction_wall_ns: AtomicU64,
matmul_count: AtomicU64,
matmul_wall_ns: AtomicU64,
linsolve_count: AtomicU64,
linsolve_wall_ns: AtomicU64,
mldivide_count: AtomicU64,
mldivide_wall_ns: AtomicU64,
mrdivide_count: AtomicU64,
mrdivide_wall_ns: AtomicU64,
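    // Host <-> device transfer volume in bytes.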
upload_bytes: AtomicU64,
download_bytes: AtomicU64,
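    // Fallback reasons (static strings) mapped to occurrence counts, plus a
    // bounded log of recent kernel launches; both are written rarely relative
    // to the atomic counters, so a Mutex is acceptable here.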
solve_fallbacks: Mutex<HashMap<&'static str, u64>>,
kernel_launches: Mutex<Vec<KernelLaunchTelemetry>>,
}
impl AccelTelemetry {
pub fn new() -> Self {
Self::default()
}
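    /// Accumulates host-to-device transfer volume; zero-byte uploads are ignored.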
pub fn record_upload_bytes(&self, bytes: u64) {
if bytes > 0 {
self.upload_bytes.fetch_add(bytes, Ordering::Relaxed);
}
}
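    /// Accumulates device-to-host transfer volume; zero-byte downloads are ignored.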
pub fn record_download_bytes(&self, bytes: u64) {
if bytes > 0 {
self.download_bytes.fetch_add(bytes, Ordering::Relaxed);
}
}
    /// Shared bump logic for dispatch counters: always increments the count,
    /// and accumulates wall time only when a nonzero measurement is provided.
    fn record_dispatch(count: &AtomicU64, wall_total_ns: &AtomicU64, wall_ns: u64) {
        count.fetch_add(1, Ordering::Relaxed);
        if wall_ns > 0 {
            wall_total_ns.fetch_add(wall_ns, Ordering::Relaxed);
        }
    }
    pub fn record_fused_elementwise(&self, wall_ns: u64) {
        Self::record_dispatch(
            &self.fused_elementwise_count,
            &self.fused_elementwise_wall_ns,
            wall_ns,
        );
    }
    pub fn record_fused_reduction(&self, wall_ns: u64) {
        Self::record_dispatch(
            &self.fused_reduction_count,
            &self.fused_reduction_wall_ns,
            wall_ns,
        );
    }
    pub fn record_matmul(&self, wall_ns: u64) {
        Self::record_dispatch(&self.matmul_count, &self.matmul_wall_ns, wall_ns);
    }
    pub fn record_linsolve(&self, wall_ns: u64) {
        Self::record_dispatch(&self.linsolve_count, &self.linsolve_wall_ns, wall_ns);
    }
    pub fn record_mldivide(&self, wall_ns: u64) {
        Self::record_dispatch(&self.mldivide_count, &self.mldivide_wall_ns, wall_ns);
    }
    pub fn record_mrdivide(&self, wall_ns: u64) {
        Self::record_dispatch(&self.mrdivide_count, &self.mrdivide_wall_ns, wall_ns);
    }
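    /// Increments the counter for one fallback reason. A poisoned lock drops
    /// the sample rather than panicking.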
pub fn record_solve_fallback(&self, reason: &'static str) {
if let Ok(mut guard) = self.solve_fallbacks.lock() {
            *guard.entry(reason).or_default() += 1;
}
}
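    /// Resets every counter, the fallback map, and the kernel-launch log.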
pub fn reset(&self) {
self.fused_elementwise_count.store(0, Ordering::Relaxed);
self.fused_elementwise_wall_ns.store(0, Ordering::Relaxed);
self.fused_reduction_count.store(0, Ordering::Relaxed);
self.fused_reduction_wall_ns.store(0, Ordering::Relaxed);
self.matmul_count.store(0, Ordering::Relaxed);
self.matmul_wall_ns.store(0, Ordering::Relaxed);
self.linsolve_count.store(0, Ordering::Relaxed);
self.linsolve_wall_ns.store(0, Ordering::Relaxed);
self.mldivide_count.store(0, Ordering::Relaxed);
self.mldivide_wall_ns.store(0, Ordering::Relaxed);
self.mrdivide_count.store(0, Ordering::Relaxed);
self.mrdivide_wall_ns.store(0, Ordering::Relaxed);
self.upload_bytes.store(0, Ordering::Relaxed);
self.download_bytes.store(0, Ordering::Relaxed);
if let Ok(mut guard) = self.solve_fallbacks.lock() {
guard.clear();
}
if let Ok(mut guard) = self.kernel_launches.lock() {
guard.clear();
}
}
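    /// Builds a point-in-time `ProviderTelemetry`. Fusion and bind-group cache
    /// statistics are owned by the caller and passed in; everything else is
    /// read from this struct's counters.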
pub fn snapshot(
&self,
fusion_cache_hits: u64,
fusion_cache_misses: u64,
bind_group_cache_hits: u64,
bind_group_cache_misses: u64,
bind_group_cache_by_layout: Option<Vec<runmat_accelerate_api::BindGroupLayoutTelemetry>>,
) -> ProviderTelemetry {
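        // Poisoned locks degrade to empty collections instead of panicking.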
let kernel_launches = self
.kernel_launches
.lock()
.map(|events| events.clone())
.unwrap_or_default();
let solve_fallbacks = self
.solve_fallbacks
.lock()
.map(|reasons| {
let mut stats: Vec<ProviderFallbackStat> = reasons
.iter()
.map(|(reason, count)| ProviderFallbackStat {
reason: (*reason).to_string(),
count: *count,
})
.collect();
stats.sort_by(|a, b| a.reason.cmp(&b.reason));
stats
})
.unwrap_or_default();
ProviderTelemetry {
fused_elementwise: ProviderDispatchStats {
count: self.fused_elementwise_count.load(Ordering::Relaxed),
total_wall_time_ns: self.fused_elementwise_wall_ns.load(Ordering::Relaxed),
},
fused_reduction: ProviderDispatchStats {
count: self.fused_reduction_count.load(Ordering::Relaxed),
total_wall_time_ns: self.fused_reduction_wall_ns.load(Ordering::Relaxed),
},
matmul: ProviderDispatchStats {
count: self.matmul_count.load(Ordering::Relaxed),
total_wall_time_ns: self.matmul_wall_ns.load(Ordering::Relaxed),
},
linsolve: ProviderDispatchStats {
count: self.linsolve_count.load(Ordering::Relaxed),
total_wall_time_ns: self.linsolve_wall_ns.load(Ordering::Relaxed),
},
mldivide: ProviderDispatchStats {
count: self.mldivide_count.load(Ordering::Relaxed),
total_wall_time_ns: self.mldivide_wall_ns.load(Ordering::Relaxed),
},
mrdivide: ProviderDispatchStats {
count: self.mrdivide_count.load(Ordering::Relaxed),
total_wall_time_ns: self.mrdivide_wall_ns.load(Ordering::Relaxed),
},
upload_bytes: self.upload_bytes.load(Ordering::Relaxed),
download_bytes: self.download_bytes.load(Ordering::Relaxed),
solve_fallbacks,
fusion_cache_hits,
fusion_cache_misses,
bind_group_cache_hits,
bind_group_cache_misses,
bind_group_cache_by_layout,
kernel_launches,
}
}
}
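/// Converts a `Duration` to whole nanoseconds, saturating at `u64::MAX`.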
fn saturating_duration_ns(duration: std::time::Duration) -> u64 {
duration.as_nanos().min(u64::MAX as u128) as u64
}
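// Convenience wrappers that accept `Duration` directly, plus kernel-launch
// bookkeeping.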
impl AccelTelemetry {
pub fn record_fused_elementwise_duration(&self, duration: std::time::Duration) {
self.record_fused_elementwise(saturating_duration_ns(duration));
}
pub fn record_fused_reduction_duration(&self, duration: std::time::Duration) {
self.record_fused_reduction(saturating_duration_ns(duration));
}
pub fn record_matmul_duration(&self, duration: std::time::Duration) {
self.record_matmul(saturating_duration_ns(duration));
}
pub fn record_linsolve_duration(&self, duration: std::time::Duration) {
self.record_linsolve(saturating_duration_ns(duration));
}
pub fn record_mldivide_duration(&self, duration: std::time::Duration) {
self.record_mldivide(saturating_duration_ns(duration));
}
pub fn record_mrdivide_duration(&self, duration: std::time::Duration) {
self.record_mrdivide(saturating_duration_ns(duration));
}
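    /// Records one kernel launch with its precision, shape, and tuning
    /// attributes, keeping only the most recent `MAX_KERNEL_LAUNCH_EVENTS`.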
pub fn record_kernel_launch(
&self,
kernel: &'static str,
precision: Option<&str>,
shape: &[(&str, u64)],
tuning: &[(&str, u64)],
) {
let event = KernelLaunchTelemetry {
kernel: kernel.to_string(),
precision: precision.map(|p| p.to_string()),
shape: Self::pairs_to_attrs(shape),
tuning: Self::pairs_to_attrs(tuning),
};
if let Ok(mut guard) = self.kernel_launches.lock() {
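            // Evict the oldest entries so that, after the push below, at most
            // MAX_KERNEL_LAUNCH_EVENTS events remain in the log.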
if guard.len() >= MAX_KERNEL_LAUNCH_EVENTS {
let drop = guard.len() + 1 - MAX_KERNEL_LAUNCH_EVENTS;
guard.drain(0..drop);
}
guard.push(event);
}
}
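    /// Converts `(key, value)` pairs into owned telemetry attribute records.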
fn pairs_to_attrs(pairs: &[(&str, u64)]) -> Vec<KernelAttrTelemetry> {
pairs
.iter()
.map(|(k, v)| KernelAttrTelemetry {
key: (*k).to_string(),
value: *v,
})
.collect()
}
}