Skip to main content

runmat_accelerate/
telemetry.rs

1use std::collections::HashMap;
2use std::sync::{
3    atomic::{AtomicU64, Ordering},
4    Mutex,
5};
6
7use runmat_accelerate_api::{
8    KernelAttrTelemetry, KernelLaunchTelemetry, ProviderDispatchStats, ProviderFallbackStat,
9    ProviderTelemetry,
10};
11
/// Maximum number of kernel-launch events kept in the launch log; once the log is full,
/// the oldest events are evicted before a new one is appended.
const MAX_KERNEL_LAUNCH_EVENTS: usize = 64;
13
/// Telemetry counters for accelerated dispatches, updated through `&self`.
///
/// Each dispatch category pairs an invocation count with an accumulated wall time in
/// nanoseconds. All scalar counters are `AtomicU64`s; the per-reason solve-fallback
/// tally and the bounded kernel-launch log are guarded by `Mutex`es. `Default` yields
/// an instance with every counter at zero and both collections empty.
#[derive(Default)]
pub struct AccelTelemetry {
    /// Number of fused elementwise dispatches recorded.
    fused_elementwise_count: AtomicU64,
    /// Summed wall time (ns) of fused elementwise dispatches.
    fused_elementwise_wall_ns: AtomicU64,
    /// Number of fused reduction dispatches recorded.
    fused_reduction_count: AtomicU64,
    /// Summed wall time (ns) of fused reduction dispatches.
    fused_reduction_wall_ns: AtomicU64,
    /// Number of matmul dispatches recorded.
    matmul_count: AtomicU64,
    /// Summed wall time (ns) of matmul dispatches.
    matmul_wall_ns: AtomicU64,
    /// Number of linsolve dispatches recorded.
    linsolve_count: AtomicU64,
    /// Summed wall time (ns) of linsolve dispatches.
    linsolve_wall_ns: AtomicU64,
    /// Number of mldivide dispatches recorded.
    mldivide_count: AtomicU64,
    /// Summed wall time (ns) of mldivide dispatches.
    mldivide_wall_ns: AtomicU64,
    /// Number of mrdivide dispatches recorded.
    mrdivide_count: AtomicU64,
    /// Summed wall time (ns) of mrdivide dispatches.
    mrdivide_wall_ns: AtomicU64,
    /// Total bytes reported through `record_upload_bytes`.
    upload_bytes: AtomicU64,
    /// Total bytes reported through `record_download_bytes`.
    download_bytes: AtomicU64,
    /// Number of solve fallbacks, keyed by a static reason string.
    solve_fallbacks: Mutex<HashMap<&'static str, u64>>,
    /// Most recent kernel launches in arrival order (oldest first), capped at
    /// `MAX_KERNEL_LAUNCH_EVENTS` entries.
    kernel_launches: Mutex<Vec<KernelLaunchTelemetry>>,
}
33
34impl AccelTelemetry {
35    pub fn new() -> Self {
36        Self::default()
37    }
38
39    pub fn record_upload_bytes(&self, bytes: u64) {
40        if bytes > 0 {
41            self.upload_bytes.fetch_add(bytes, Ordering::Relaxed);
42        }
43    }
44
45    pub fn record_download_bytes(&self, bytes: u64) {
46        if bytes > 0 {
47            self.download_bytes.fetch_add(bytes, Ordering::Relaxed);
48        }
49    }
50
51    pub fn record_fused_elementwise(&self, wall_ns: u64) {
52        self.fused_elementwise_count.fetch_add(1, Ordering::Relaxed);
53        if wall_ns > 0 {
54            self.fused_elementwise_wall_ns
55                .fetch_add(wall_ns, Ordering::Relaxed);
56        }
57    }
58
59    pub fn record_fused_reduction(&self, wall_ns: u64) {
60        self.fused_reduction_count.fetch_add(1, Ordering::Relaxed);
61        if wall_ns > 0 {
62            self.fused_reduction_wall_ns
63                .fetch_add(wall_ns, Ordering::Relaxed);
64        }
65    }
66
67    pub fn record_matmul(&self, wall_ns: u64) {
68        self.matmul_count.fetch_add(1, Ordering::Relaxed);
69        if wall_ns > 0 {
70            self.matmul_wall_ns.fetch_add(wall_ns, Ordering::Relaxed);
71        }
72    }
73
74    pub fn record_linsolve(&self, wall_ns: u64) {
75        self.linsolve_count.fetch_add(1, Ordering::Relaxed);
76        if wall_ns > 0 {
77            self.linsolve_wall_ns.fetch_add(wall_ns, Ordering::Relaxed);
78        }
79    }
80
81    pub fn record_mldivide(&self, wall_ns: u64) {
82        self.mldivide_count.fetch_add(1, Ordering::Relaxed);
83        if wall_ns > 0 {
84            self.mldivide_wall_ns.fetch_add(wall_ns, Ordering::Relaxed);
85        }
86    }
87
88    pub fn record_mrdivide(&self, wall_ns: u64) {
89        self.mrdivide_count.fetch_add(1, Ordering::Relaxed);
90        if wall_ns > 0 {
91            self.mrdivide_wall_ns.fetch_add(wall_ns, Ordering::Relaxed);
92        }
93    }
94
95    pub fn record_solve_fallback(&self, reason: &'static str) {
96        if let Ok(mut guard) = self.solve_fallbacks.lock() {
97            *guard.entry(reason).or_insert(0) += 1;
98        }
99    }
100
101    pub fn reset(&self) {
102        self.fused_elementwise_count.store(0, Ordering::Relaxed);
103        self.fused_elementwise_wall_ns.store(0, Ordering::Relaxed);
104        self.fused_reduction_count.store(0, Ordering::Relaxed);
105        self.fused_reduction_wall_ns.store(0, Ordering::Relaxed);
106        self.matmul_count.store(0, Ordering::Relaxed);
107        self.matmul_wall_ns.store(0, Ordering::Relaxed);
108        self.linsolve_count.store(0, Ordering::Relaxed);
109        self.linsolve_wall_ns.store(0, Ordering::Relaxed);
110        self.mldivide_count.store(0, Ordering::Relaxed);
111        self.mldivide_wall_ns.store(0, Ordering::Relaxed);
112        self.mrdivide_count.store(0, Ordering::Relaxed);
113        self.mrdivide_wall_ns.store(0, Ordering::Relaxed);
114        self.upload_bytes.store(0, Ordering::Relaxed);
115        self.download_bytes.store(0, Ordering::Relaxed);
116        if let Ok(mut guard) = self.solve_fallbacks.lock() {
117            guard.clear();
118        }
119        if let Ok(mut guard) = self.kernel_launches.lock() {
120            guard.clear();
121        }
122    }
123
124    pub fn snapshot(
125        &self,
126        fusion_cache_hits: u64,
127        fusion_cache_misses: u64,
128        bind_group_cache_hits: u64,
129        bind_group_cache_misses: u64,
130        bind_group_cache_by_layout: Option<Vec<runmat_accelerate_api::BindGroupLayoutTelemetry>>,
131    ) -> ProviderTelemetry {
132        let kernel_launches = self
133            .kernel_launches
134            .lock()
135            .map(|events| events.clone())
136            .unwrap_or_default();
137        let solve_fallbacks = self
138            .solve_fallbacks
139            .lock()
140            .map(|reasons| {
141                let mut stats: Vec<ProviderFallbackStat> = reasons
142                    .iter()
143                    .map(|(reason, count)| ProviderFallbackStat {
144                        reason: (*reason).to_string(),
145                        count: *count,
146                    })
147                    .collect();
148                stats.sort_by(|a, b| a.reason.cmp(&b.reason));
149                stats
150            })
151            .unwrap_or_default();
152        ProviderTelemetry {
153            fused_elementwise: ProviderDispatchStats {
154                count: self.fused_elementwise_count.load(Ordering::Relaxed),
155                total_wall_time_ns: self.fused_elementwise_wall_ns.load(Ordering::Relaxed),
156            },
157            fused_reduction: ProviderDispatchStats {
158                count: self.fused_reduction_count.load(Ordering::Relaxed),
159                total_wall_time_ns: self.fused_reduction_wall_ns.load(Ordering::Relaxed),
160            },
161            matmul: ProviderDispatchStats {
162                count: self.matmul_count.load(Ordering::Relaxed),
163                total_wall_time_ns: self.matmul_wall_ns.load(Ordering::Relaxed),
164            },
165            linsolve: ProviderDispatchStats {
166                count: self.linsolve_count.load(Ordering::Relaxed),
167                total_wall_time_ns: self.linsolve_wall_ns.load(Ordering::Relaxed),
168            },
169            mldivide: ProviderDispatchStats {
170                count: self.mldivide_count.load(Ordering::Relaxed),
171                total_wall_time_ns: self.mldivide_wall_ns.load(Ordering::Relaxed),
172            },
173            mrdivide: ProviderDispatchStats {
174                count: self.mrdivide_count.load(Ordering::Relaxed),
175                total_wall_time_ns: self.mrdivide_wall_ns.load(Ordering::Relaxed),
176            },
177            upload_bytes: self.upload_bytes.load(Ordering::Relaxed),
178            download_bytes: self.download_bytes.load(Ordering::Relaxed),
179            solve_fallbacks,
180            fusion_cache_hits,
181            fusion_cache_misses,
182            bind_group_cache_hits,
183            bind_group_cache_misses,
184            bind_group_cache_by_layout,
185            kernel_launches,
186        }
187    }
188}
189
/// Converts `duration` to whole nanoseconds, clamping to `u64::MAX` when the
/// nanosecond count does not fit in 64 bits.
fn saturating_duration_ns(duration: std::time::Duration) -> u64 {
    let total_nanos = duration.as_nanos();
    if total_nanos > u128::from(u64::MAX) {
        u64::MAX
    } else {
        // Guarded above: the value fits, so this cast cannot truncate.
        total_nanos as u64
    }
}
193
194impl AccelTelemetry {
195    pub fn record_fused_elementwise_duration(&self, duration: std::time::Duration) {
196        self.record_fused_elementwise(saturating_duration_ns(duration));
197    }
198
199    pub fn record_fused_reduction_duration(&self, duration: std::time::Duration) {
200        self.record_fused_reduction(saturating_duration_ns(duration));
201    }
202
203    pub fn record_matmul_duration(&self, duration: std::time::Duration) {
204        self.record_matmul(saturating_duration_ns(duration));
205    }
206
207    pub fn record_linsolve_duration(&self, duration: std::time::Duration) {
208        self.record_linsolve(saturating_duration_ns(duration));
209    }
210
211    pub fn record_mldivide_duration(&self, duration: std::time::Duration) {
212        self.record_mldivide(saturating_duration_ns(duration));
213    }
214
215    pub fn record_mrdivide_duration(&self, duration: std::time::Duration) {
216        self.record_mrdivide(saturating_duration_ns(duration));
217    }
218
219    pub fn record_kernel_launch(
220        &self,
221        kernel: &'static str,
222        precision: Option<&str>,
223        shape: &[(&str, u64)],
224        tuning: &[(&str, u64)],
225    ) {
226        let event = KernelLaunchTelemetry {
227            kernel: kernel.to_string(),
228            precision: precision.map(|p| p.to_string()),
229            shape: Self::pairs_to_attrs(shape),
230            tuning: Self::pairs_to_attrs(tuning),
231        };
232        if let Ok(mut guard) = self.kernel_launches.lock() {
233            if guard.len() >= MAX_KERNEL_LAUNCH_EVENTS {
234                let drop = guard.len() + 1 - MAX_KERNEL_LAUNCH_EVENTS;
235                guard.drain(0..drop);
236            }
237            guard.push(event);
238        }
239    }
240
241    fn pairs_to_attrs(pairs: &[(&str, u64)]) -> Vec<KernelAttrTelemetry> {
242        pairs
243            .iter()
244            .map(|(k, v)| KernelAttrTelemetry {
245                key: (*k).to_string(),
246                value: *v,
247            })
248            .collect()
249    }
250}