use std::collections::HashMap;
use std::sync::{
    atomic::{AtomicU64, Ordering},
    Mutex, PoisonError,
};

use runmat_accelerate_api::{
    KernelAttrTelemetry, KernelLaunchTelemetry, ProviderDispatchStats, ProviderFallbackStat,
    ProviderTelemetry,
};
11
/// Maximum number of kernel-launch events kept in the telemetry buffer.
///
/// Once the buffer holds this many events, recording another one evicts the
/// oldest entries so the buffer never grows past this size.
const MAX_KERNEL_LAUNCH_EVENTS: usize = 64;
13
14#[derive(Default)]
15pub struct AccelTelemetry {
16 fused_elementwise_count: AtomicU64,
17 fused_elementwise_wall_ns: AtomicU64,
18 fused_reduction_count: AtomicU64,
19 fused_reduction_wall_ns: AtomicU64,
20 matmul_count: AtomicU64,
21 matmul_wall_ns: AtomicU64,
22 linsolve_count: AtomicU64,
23 linsolve_wall_ns: AtomicU64,
24 mldivide_count: AtomicU64,
25 mldivide_wall_ns: AtomicU64,
26 mrdivide_count: AtomicU64,
27 mrdivide_wall_ns: AtomicU64,
28 upload_bytes: AtomicU64,
29 download_bytes: AtomicU64,
30 solve_fallbacks: Mutex<HashMap<&'static str, u64>>,
31 kernel_launches: Mutex<Vec<KernelLaunchTelemetry>>,
32}
33
34impl AccelTelemetry {
35 pub fn new() -> Self {
36 Self::default()
37 }
38
39 pub fn record_upload_bytes(&self, bytes: u64) {
40 if bytes > 0 {
41 self.upload_bytes.fetch_add(bytes, Ordering::Relaxed);
42 }
43 }
44
45 pub fn record_download_bytes(&self, bytes: u64) {
46 if bytes > 0 {
47 self.download_bytes.fetch_add(bytes, Ordering::Relaxed);
48 }
49 }
50
51 pub fn record_fused_elementwise(&self, wall_ns: u64) {
52 self.fused_elementwise_count.fetch_add(1, Ordering::Relaxed);
53 if wall_ns > 0 {
54 self.fused_elementwise_wall_ns
55 .fetch_add(wall_ns, Ordering::Relaxed);
56 }
57 }
58
59 pub fn record_fused_reduction(&self, wall_ns: u64) {
60 self.fused_reduction_count.fetch_add(1, Ordering::Relaxed);
61 if wall_ns > 0 {
62 self.fused_reduction_wall_ns
63 .fetch_add(wall_ns, Ordering::Relaxed);
64 }
65 }
66
67 pub fn record_matmul(&self, wall_ns: u64) {
68 self.matmul_count.fetch_add(1, Ordering::Relaxed);
69 if wall_ns > 0 {
70 self.matmul_wall_ns.fetch_add(wall_ns, Ordering::Relaxed);
71 }
72 }
73
74 pub fn record_linsolve(&self, wall_ns: u64) {
75 self.linsolve_count.fetch_add(1, Ordering::Relaxed);
76 if wall_ns > 0 {
77 self.linsolve_wall_ns.fetch_add(wall_ns, Ordering::Relaxed);
78 }
79 }
80
81 pub fn record_mldivide(&self, wall_ns: u64) {
82 self.mldivide_count.fetch_add(1, Ordering::Relaxed);
83 if wall_ns > 0 {
84 self.mldivide_wall_ns.fetch_add(wall_ns, Ordering::Relaxed);
85 }
86 }
87
88 pub fn record_mrdivide(&self, wall_ns: u64) {
89 self.mrdivide_count.fetch_add(1, Ordering::Relaxed);
90 if wall_ns > 0 {
91 self.mrdivide_wall_ns.fetch_add(wall_ns, Ordering::Relaxed);
92 }
93 }
94
95 pub fn record_solve_fallback(&self, reason: &'static str) {
96 if let Ok(mut guard) = self.solve_fallbacks.lock() {
97 *guard.entry(reason).or_insert(0) += 1;
98 }
99 }
100
101 pub fn reset(&self) {
102 self.fused_elementwise_count.store(0, Ordering::Relaxed);
103 self.fused_elementwise_wall_ns.store(0, Ordering::Relaxed);
104 self.fused_reduction_count.store(0, Ordering::Relaxed);
105 self.fused_reduction_wall_ns.store(0, Ordering::Relaxed);
106 self.matmul_count.store(0, Ordering::Relaxed);
107 self.matmul_wall_ns.store(0, Ordering::Relaxed);
108 self.linsolve_count.store(0, Ordering::Relaxed);
109 self.linsolve_wall_ns.store(0, Ordering::Relaxed);
110 self.mldivide_count.store(0, Ordering::Relaxed);
111 self.mldivide_wall_ns.store(0, Ordering::Relaxed);
112 self.mrdivide_count.store(0, Ordering::Relaxed);
113 self.mrdivide_wall_ns.store(0, Ordering::Relaxed);
114 self.upload_bytes.store(0, Ordering::Relaxed);
115 self.download_bytes.store(0, Ordering::Relaxed);
116 if let Ok(mut guard) = self.solve_fallbacks.lock() {
117 guard.clear();
118 }
119 if let Ok(mut guard) = self.kernel_launches.lock() {
120 guard.clear();
121 }
122 }
123
124 pub fn snapshot(
125 &self,
126 fusion_cache_hits: u64,
127 fusion_cache_misses: u64,
128 bind_group_cache_hits: u64,
129 bind_group_cache_misses: u64,
130 bind_group_cache_by_layout: Option<Vec<runmat_accelerate_api::BindGroupLayoutTelemetry>>,
131 ) -> ProviderTelemetry {
132 let kernel_launches = self
133 .kernel_launches
134 .lock()
135 .map(|events| events.clone())
136 .unwrap_or_default();
137 let solve_fallbacks = self
138 .solve_fallbacks
139 .lock()
140 .map(|reasons| {
141 let mut stats: Vec<ProviderFallbackStat> = reasons
142 .iter()
143 .map(|(reason, count)| ProviderFallbackStat {
144 reason: (*reason).to_string(),
145 count: *count,
146 })
147 .collect();
148 stats.sort_by(|a, b| a.reason.cmp(&b.reason));
149 stats
150 })
151 .unwrap_or_default();
152 ProviderTelemetry {
153 fused_elementwise: ProviderDispatchStats {
154 count: self.fused_elementwise_count.load(Ordering::Relaxed),
155 total_wall_time_ns: self.fused_elementwise_wall_ns.load(Ordering::Relaxed),
156 },
157 fused_reduction: ProviderDispatchStats {
158 count: self.fused_reduction_count.load(Ordering::Relaxed),
159 total_wall_time_ns: self.fused_reduction_wall_ns.load(Ordering::Relaxed),
160 },
161 matmul: ProviderDispatchStats {
162 count: self.matmul_count.load(Ordering::Relaxed),
163 total_wall_time_ns: self.matmul_wall_ns.load(Ordering::Relaxed),
164 },
165 linsolve: ProviderDispatchStats {
166 count: self.linsolve_count.load(Ordering::Relaxed),
167 total_wall_time_ns: self.linsolve_wall_ns.load(Ordering::Relaxed),
168 },
169 mldivide: ProviderDispatchStats {
170 count: self.mldivide_count.load(Ordering::Relaxed),
171 total_wall_time_ns: self.mldivide_wall_ns.load(Ordering::Relaxed),
172 },
173 mrdivide: ProviderDispatchStats {
174 count: self.mrdivide_count.load(Ordering::Relaxed),
175 total_wall_time_ns: self.mrdivide_wall_ns.load(Ordering::Relaxed),
176 },
177 upload_bytes: self.upload_bytes.load(Ordering::Relaxed),
178 download_bytes: self.download_bytes.load(Ordering::Relaxed),
179 solve_fallbacks,
180 fusion_cache_hits,
181 fusion_cache_misses,
182 bind_group_cache_hits,
183 bind_group_cache_misses,
184 bind_group_cache_by_layout,
185 kernel_launches,
186 }
187 }
188}
189
/// Converts `duration` to whole nanoseconds, clamping to `u64::MAX` when the
/// value does not fit in 64 bits.
fn saturating_duration_ns(duration: std::time::Duration) -> u64 {
    let clamped = duration.as_nanos().min(u128::from(u64::MAX));
    // Lossless: `clamped` was limited to the `u64` range on the line above.
    clamped as u64
}
193
194impl AccelTelemetry {
195 pub fn record_fused_elementwise_duration(&self, duration: std::time::Duration) {
196 self.record_fused_elementwise(saturating_duration_ns(duration));
197 }
198
199 pub fn record_fused_reduction_duration(&self, duration: std::time::Duration) {
200 self.record_fused_reduction(saturating_duration_ns(duration));
201 }
202
203 pub fn record_matmul_duration(&self, duration: std::time::Duration) {
204 self.record_matmul(saturating_duration_ns(duration));
205 }
206
207 pub fn record_linsolve_duration(&self, duration: std::time::Duration) {
208 self.record_linsolve(saturating_duration_ns(duration));
209 }
210
211 pub fn record_mldivide_duration(&self, duration: std::time::Duration) {
212 self.record_mldivide(saturating_duration_ns(duration));
213 }
214
215 pub fn record_mrdivide_duration(&self, duration: std::time::Duration) {
216 self.record_mrdivide(saturating_duration_ns(duration));
217 }
218
219 pub fn record_kernel_launch(
220 &self,
221 kernel: &'static str,
222 precision: Option<&str>,
223 shape: &[(&str, u64)],
224 tuning: &[(&str, u64)],
225 ) {
226 let event = KernelLaunchTelemetry {
227 kernel: kernel.to_string(),
228 precision: precision.map(|p| p.to_string()),
229 shape: Self::pairs_to_attrs(shape),
230 tuning: Self::pairs_to_attrs(tuning),
231 };
232 if let Ok(mut guard) = self.kernel_launches.lock() {
233 if guard.len() >= MAX_KERNEL_LAUNCH_EVENTS {
234 let drop = guard.len() + 1 - MAX_KERNEL_LAUNCH_EVENTS;
235 guard.drain(0..drop);
236 }
237 guard.push(event);
238 }
239 }
240
241 fn pairs_to_attrs(pairs: &[(&str, u64)]) -> Vec<KernelAttrTelemetry> {
242 pairs
243 .iter()
244 .map(|(k, v)| KernelAttrTelemetry {
245 key: (*k).to_string(),
246 value: *v,
247 })
248 .collect()
249 }
250}