1use scheduler::Instant;
2use std::{
3 cell::LazyCell,
4 collections::{HashMap, VecDeque},
5 hash::{DefaultHasher, Hash, Hasher},
6 sync::{
7 Arc,
8 atomic::{AtomicBool, Ordering},
9 },
10 thread::ThreadId,
11};
12
13use serde::{Deserialize, Serialize};
14
15use crate::SharedString;
16
/// One recorded task execution.
///
/// `location` is the source location the timing is attributed to and `start`
/// is when it began. `end` is `None` while the task has not finished yet;
/// `ThreadTimings::add_task_timing` fills it in later when it receives a timing
/// with the same `location` and `start`.
#[doc(hidden)]
#[derive(Debug, Copy, Clone)]
pub struct TaskTiming {
    pub location: &'static core::panic::Location<'static>,
    pub start: Instant,
    pub end: Option<Instant>,
}
24
25#[doc(hidden)]
26#[derive(Debug, Clone)]
27pub struct ThreadTaskTimings {
28 pub thread_name: Option<String>,
29 pub thread_id: ThreadId,
30 pub timings: Vec<TaskTiming>,
31 pub total_pushed: u64,
32}
33
34impl ThreadTaskTimings {
35 pub fn convert(timings: &[GlobalThreadTimings]) -> Vec<Self> {
37 timings
38 .iter()
39 .filter_map(|t| match t.timings.upgrade() {
40 Some(timings) => Some((t.thread_id, timings)),
41 _ => None,
42 })
43 .map(|(thread_id, timings)| {
44 let timings = timings.lock();
45 let thread_name = timings.thread_name.clone();
46 let total_pushed = timings.total_pushed;
47 let timings = &timings.timings;
48
49 let mut vec = Vec::with_capacity(timings.len());
50 let (s1, s2) = timings.as_slices();
51 vec.extend_from_slice(s1);
52 vec.extend_from_slice(s2);
53
54 ThreadTaskTimings {
55 thread_name,
56 thread_id,
57 timings: vec,
58 total_pushed,
59 }
60 })
61 .collect()
62 }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct SerializedLocation {
68 pub file: SharedString,
70 pub line: u32,
72 pub column: u32,
74}
75
76impl From<&core::panic::Location<'static>> for SerializedLocation {
77 fn from(value: &core::panic::Location<'static>) -> Self {
78 SerializedLocation {
79 file: value.file().into(),
80 line: value.line(),
81 column: value.column(),
82 }
83 }
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct SerializedTaskTiming {
89 pub location: SerializedLocation,
91 pub start: u128,
93 pub duration: u128,
95}
96
97impl SerializedTaskTiming {
98 pub fn convert(anchor: Instant, timings: &[TaskTiming]) -> Vec<SerializedTaskTiming> {
104 let serialized = timings
105 .iter()
106 .map(|timing| {
107 let start = timing.start.duration_since(anchor).as_nanos();
108 let duration = timing
109 .end
110 .unwrap_or_else(|| Instant::now())
111 .duration_since(timing.start)
112 .as_nanos();
113 SerializedTaskTiming {
114 location: timing.location.into(),
115 start,
116 duration,
117 }
118 })
119 .collect::<Vec<_>>();
120
121 serialized
122 }
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct SerializedThreadTaskTimings {
128 pub thread_name: Option<String>,
130 pub thread_id: u64,
132 pub timings: Vec<SerializedTaskTiming>,
134}
135
136impl SerializedThreadTaskTimings {
137 pub fn convert(anchor: Instant, timings: ThreadTaskTimings) -> SerializedThreadTaskTimings {
143 let serialized_timings = SerializedTaskTiming::convert(anchor, &timings.timings);
144
145 let mut hasher = DefaultHasher::new();
146 timings.thread_id.hash(&mut hasher);
147 let thread_id = hasher.finish();
148
149 SerializedThreadTaskTimings {
150 thread_name: timings.thread_name,
151 thread_id,
152 timings: serialized_timings,
153 }
154 }
155}
156
/// The timings of one thread that `ProfilingCollector::collect_unseen` had not
/// reported before.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub struct ThreadTimingsDelta {
    /// The thread's `ThreadId`, hashed with `DefaultHasher`.
    pub thread_id: u64,
    /// Thread name, if the thread had one.
    pub thread_name: Option<String>,
    /// Timings not included in an earlier delta from the same collector. A
    /// trailing timing that has not finished yet is held back for a later call.
    pub new_timings: Vec<SerializedTaskTiming>,
}
168
169#[doc(hidden)]
171pub struct ProfilingCollector {
172 startup_time: Instant,
173 cursors: HashMap<ThreadId, u64>,
174}
175
176impl ProfilingCollector {
177 pub fn new(startup_time: Instant) -> Self {
178 Self {
179 startup_time,
180 cursors: HashMap::default(),
181 }
182 }
183
184 pub fn startup_time(&self) -> Instant {
185 self.startup_time
186 }
187
188 pub fn collect_unseen(
189 &mut self,
190 all_timings: Vec<ThreadTaskTimings>,
191 ) -> Vec<ThreadTimingsDelta> {
192 let mut deltas = Vec::with_capacity(all_timings.len());
193
194 for thread in all_timings {
195 let mut hasher = DefaultHasher::new();
196 thread.thread_id.hash(&mut hasher);
197 let hashed_id = hasher.finish();
198
199 let prev_cursor = self.cursors.get(&thread.thread_id).copied().unwrap_or(0);
200 let buffer_len = thread.timings.len() as u64;
201 let buffer_start = thread.total_pushed.saturating_sub(buffer_len);
202
203 let mut slice = if prev_cursor < buffer_start {
204 thread.timings.as_slice()
207 } else {
208 let skip = (prev_cursor - buffer_start) as usize;
209 &thread.timings[skip.min(thread.timings.len())..]
210 };
211
212 let incomplete_at_end = slice.last().is_some_and(|t| t.end.is_none());
214 if incomplete_at_end {
215 slice = &slice[..slice.len() - 1];
216 }
217
218 let cursor_advance = if incomplete_at_end {
219 thread.total_pushed.saturating_sub(1)
220 } else {
221 thread.total_pushed
222 };
223
224 self.cursors.insert(thread.thread_id, cursor_advance);
225
226 if slice.is_empty() {
227 continue;
228 }
229
230 let new_timings = SerializedTaskTiming::convert(self.startup_time, slice);
231
232 deltas.push(ThreadTimingsDelta {
233 thread_id: hashed_id,
234 thread_name: thread.thread_name,
235 new_timings,
236 });
237 }
238
239 deltas
240 }
241
242 pub fn reset(&mut self) {
243 self.cursors.clear();
244 }
245}
246
/// Per-thread cap on buffered timings: as many `TaskTiming` entries as fit in
/// 16 MiB. `ThreadTimings::add_task_timing` evicts the oldest entries to stay
/// within it.
const MAX_TASK_TIMINGS: usize = (16 * 1024 * 1024) / core::mem::size_of::<TaskTiming>();
251
/// A thread's buffered timings: new entries are pushed at the back, and the
/// oldest are evicted from the front.
#[doc(hidden)]
pub(crate) type TaskTimings = VecDeque<TaskTiming>;

/// A thread's timing state behind a spin lock. The owning thread's
/// `THREAD_TIMINGS` holds the strong `Arc`; the global registry holds a `Weak`.
#[doc(hidden)]
pub type GuardedTaskTimings = spin::Mutex<ThreadTimings>;

/// Registry entry for one thread's timings.
///
/// Holds only a weak reference, so being registered does not keep the
/// thread's buffer alive.
#[doc(hidden)]
pub struct GlobalThreadTimings {
    pub thread_id: ThreadId,
    pub timings: std::sync::Weak<GuardedTaskTimings>,
}

/// Registry of threads whose `THREAD_TIMINGS` has been initialized.
///
/// Entries are pushed on first use of the thread-local and removed by
/// `ThreadTimings::drop`. This is a spin lock; it must not be re-acquired by
/// a thread that already holds it.
#[doc(hidden)]
pub static GLOBAL_THREAD_TIMINGS: spin::Mutex<Vec<GlobalThreadTimings>> =
    spin::Mutex::new(Vec::new());
267
268thread_local! {
269 #[doc(hidden)]
270 pub static THREAD_TIMINGS: LazyCell<Arc<GuardedTaskTimings>> = LazyCell::new(|| {
271 let current_thread = std::thread::current();
272 let thread_name = current_thread.name();
273 let thread_id = current_thread.id();
274 let timings = ThreadTimings::new(thread_name.map(|e| e.to_string()), thread_id);
275 let timings = Arc::new(spin::Mutex::new(timings));
276
277 {
278 let timings = Arc::downgrade(&timings);
279 let global_timings = GlobalThreadTimings {
280 thread_id: std::thread::current().id(),
281 timings,
282 };
283 GLOBAL_THREAD_TIMINGS.lock().push(global_timings);
284 }
285
286 timings
287 });
288}
289
290#[doc(hidden)]
291pub struct ThreadTimings {
292 pub thread_name: Option<String>,
293 pub thread_id: ThreadId,
294 pub timings: TaskTimings,
295 pub total_pushed: u64,
296}
297
298impl ThreadTimings {
299 pub fn new(thread_name: Option<String>, thread_id: ThreadId) -> Self {
300 ThreadTimings {
301 thread_name,
302 thread_id,
303 timings: TaskTimings::new(),
304 total_pushed: 0,
305 }
306 }
307
308 pub fn add_task_timing(&mut self, timing: TaskTiming) {
312 if let Some(last_timing) = self.timings.back_mut()
313 && last_timing.location == timing.location
314 && last_timing.start == timing.start
315 {
316 last_timing.end = timing.end;
317 } else {
318 while self.timings.len() + 1 > MAX_TASK_TIMINGS {
319 self.timings.pop_front();
321 }
322 self.timings.push_back(timing);
323 self.total_pushed += 1;
324 }
325 }
326
327 pub fn get_thread_task_timings(&self) -> ThreadTaskTimings {
328 ThreadTaskTimings {
329 thread_name: self.thread_name.clone(),
330 thread_id: self.thread_id,
331 timings: self.timings.iter().cloned().collect(),
332 total_pushed: self.total_pushed,
333 }
334 }
335}
336
337impl Drop for ThreadTimings {
338 fn drop(&mut self) {
339 let mut thread_timings = GLOBAL_THREAD_TIMINGS.lock();
340
341 let Some((index, _)) = thread_timings
342 .iter()
343 .enumerate()
344 .find(|(_, t)| t.thread_id == self.thread_id)
345 else {
346 return;
347 };
348 thread_timings.swap_remove(index);
349 }
350}
351
352#[doc(hidden)]
353pub fn add_task_timing(timing: TaskTiming) {
354 if !PROFILER_ENABLED.load(Ordering::Acquire) {
355 return;
356 }
357 THREAD_TIMINGS.with(|timings| {
358 timings.lock().add_task_timing(timing);
359 });
360}
361
362#[doc(hidden)]
363pub fn get_current_thread_task_timings() -> ThreadTaskTimings {
364 THREAD_TIMINGS.with(|timings| timings.lock().get_thread_task_timings())
365}
366
/// Global switch for task-timing collection; starts disabled. Toggled via `set_enabled`.
static PROFILER_ENABLED: AtomicBool = AtomicBool::new(false);
368
369pub fn set_enabled(enabled: bool) -> bool {
375 if PROFILER_ENABLED.swap(enabled, Ordering::AcqRel) == enabled {
376 return false;
377 }
378
379 if !enabled {
380 for global in GLOBAL_THREAD_TIMINGS.lock().iter() {
381 if let Some(timings) = global.timings.upgrade() {
382 let mut timings = timings.lock();
383 timings.timings.clear();
384 timings.timings.shrink_to_fit();
385 timings.total_pushed = 0;
386 }
387 }
388 }
389 true
390}