Skip to main content

open_gpui/
profiler.rs

1use itertools::Itertools;
2use open_gpui_scheduler::{Instant, SpawnTime};
3use std::{
4    cell::LazyCell,
5    collections::{HashMap, VecDeque},
6    hash::{DefaultHasher, Hash, Hasher},
7    hint::cold_path,
8    sync::{
9        Arc,
10        atomic::{AtomicBool, Ordering},
11    },
12    thread::ThreadId,
13    time::Duration,
14};
15
16mod actions;
17pub use actions::{ActionStatistics, ActionTiming, take_action_stats};
18pub(crate) use actions::{save_action_timing, update_running_action};
19
20use serde::{Deserialize, Serialize};
21
22use crate::{SharedString, TasksIncluded};
23
24#[doc(hidden)]
25pub fn get_all_timings(included: open_gpui::TasksIncluded) -> Vec<open_gpui::ThreadTaskTimings> {
26    let global_thread_timings = GLOBAL_THREAD_TIMINGS.lock();
27    ThreadTaskTimings::collect(&global_thread_timings, included)
28}
29
30#[doc(hidden)]
31pub fn get_current_thread_timings(included: TasksIncluded) -> open_gpui::ThreadTaskTimings {
32    open_gpui::profiler::get_current_thread_task_timings(included)
33}
34
35#[doc(hidden)]
36pub fn take_all_stats(included: TasksIncluded) -> Vec<open_gpui::ThreadTaskStatistics> {
37    let global_timings = GLOBAL_THREAD_TIMINGS.lock();
38    ThreadTaskStatistics::collect_and_reset(&global_timings, included)
39}
40
41#[doc(hidden)]
42#[derive(Debug, Copy, Clone)]
43pub struct YieldTime(pub Instant);
44
45#[doc(hidden)]
46#[derive(Copy, Clone)]
47pub struct TaskTiming {
48    pub location: &'static core::panic::Location<'static>,
49    pub spawned: SpawnTime,
50    pub start: Instant,
51    pub end: YieldTime,
52}
53
54impl std::fmt::Debug for TaskTiming {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        f.debug_struct("TaskTiming")
57            .field("location", &self.location)
58            .field("since_spawned", &self.spawned.0.elapsed())
59            .field("last_poll_duration", &self.poll_duration())
60            .field("total_runtime", &self.since_spawn())
61            .finish()
62    }
63}
64
65#[doc(hidden)]
66#[derive(Debug, Copy, Clone)]
67pub struct ActiveTiming {
68    pub location: &'static core::panic::Location<'static>,
69    pub spawned: SpawnTime,
70    pub start: Instant,
71}
72
73impl TaskTiming {
74    /// A task timing with a duration of zero. Any task will replace this in history.
75    pub fn placeholder() -> Self {
76        let now = Instant::now();
77        Self {
78            location: std::panic::Location::caller(),
79            spawned: SpawnTime(now),
80            start: now,
81            end: YieldTime(now),
82        }
83    }
84
85    #[inline(always)]
86    pub fn poll_duration(&self) -> Duration {
87        self.end.0 - self.start
88    }
89
90    #[inline(always)]
91    fn since_spawn(&self) -> Duration {
92        self.end.0 - self.spawned.0
93    }
94}
95
96#[doc(hidden)]
97#[derive(Debug, Clone)]
98pub struct ThreadTaskTimings {
99    pub thread_name: Option<String>,
100    pub thread_id: ThreadId,
101    pub timings: Vec<TaskTiming>,
102    pub stats: TaskStatistics,
103    pub total_pushed: u64,
104}
105
106impl ThreadTaskTimings {
107    /// Convert global thread timings into their structured format.
108    pub fn collect(timings: &[GlobalThreadTimings], included: TasksIncluded) -> Vec<Self> {
109        timings
110            .iter()
111            .filter_map(|t| match t.timings.upgrade() {
112                Some(timings) => Some((t.thread_id, timings)),
113                _ => None,
114            })
115            .map(|(thread_id, timings)| {
116                let timings = timings.lock();
117                let thread_name = timings.thread_name.clone();
118                let total_pushed = timings.total_pushed;
119                let completed = &timings.timings;
120
121                let mut vec = Vec::with_capacity(completed.len() + 1); // +1 for running task
122                let (s1, s2) = completed.as_slices();
123                vec.extend_from_slice(s1);
124                vec.extend_from_slice(s2);
125                if let TasksIncluded::CompletedAndRunning = included
126                    && let Some(running) = timings.running
127                {
128                    vec.push(TaskTiming {
129                        location: running.location,
130                        spawned: running.spawned,
131                        start: running.start,
132                        end: YieldTime(Instant::now()),
133                    })
134                }
135
136                ThreadTaskTimings {
137                    thread_name,
138                    thread_id,
139                    timings: vec,
140                    stats: timings.stats.clone(),
141                    total_pushed,
142                }
143            })
144            .collect()
145    }
146}
147
148#[doc(hidden)]
149#[derive(Debug)]
150pub struct ThreadTaskStatistics {
151    pub thread_name: Option<String>,
152    pub thread_id: ThreadId,
153    pub stats: TaskStatistics,
154}
155
156impl ThreadTaskStatistics {
157    pub fn collect_and_reset(
158        timings: &[GlobalThreadTimings],
159        include_running: TasksIncluded,
160    ) -> Vec<Self> {
161        timings
162            .iter()
163            .filter_map(|t| match t.timings.upgrade() {
164                Some(timings) => Some((t.thread_id, timings)),
165                _ => None,
166            })
167            .map(|(thread_id, timings)| {
168                let mut timings = timings.lock();
169                let thread_name = timings.thread_name.clone();
170
171                let mut stats = std::mem::take(&mut timings.stats);
172                if let TasksIncluded::CompletedAndRunning = include_running
173                    && let Some(ActiveTiming {
174                        location,
175                        spawned,
176                        start,
177                    }) = timings.running
178                {
179                    let end = YieldTime(Instant::now());
180                    let timing = TaskTiming {
181                        location,
182                        spawned,
183                        start,
184                        end,
185                    };
186                    stats.add_runtime(timing);
187                    stats.add_yield_timing(timing);
188                }
189
190                Self {
191                    thread_name,
192                    thread_id,
193                    stats,
194                }
195            })
196            .collect()
197    }
198}
199
200/// Serializable variant of [`core::panic::Location`]
201#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct SerializedLocation {
203    /// Name of the source file
204    pub file: SharedString,
205    /// Line in the source file
206    pub line: u32,
207    /// Column in the source file
208    pub column: u32,
209}
210
211impl From<&core::panic::Location<'static>> for SerializedLocation {
212    fn from(value: &core::panic::Location<'static>) -> Self {
213        SerializedLocation {
214            file: value.file().into(),
215            line: value.line(),
216            column: value.column(),
217        }
218    }
219}
220
221/// Serializable variant of [`TaskTiming`]
222#[derive(Debug, Clone, Serialize, Deserialize)]
223pub struct SerializedTaskTiming {
224    /// Location of the timing
225    pub location: SerializedLocation,
226    /// Time at which the measurement was reported in nanoseconds
227    pub start: u128,
228    /// Duration of the measurement in nanoseconds
229    pub duration: u128,
230}
231
232impl SerializedTaskTiming {
233    /// Convert an array of [`TaskTiming`] into their serializable format
234    ///
235    /// # Params
236    ///
237    /// `anchor` - [`Instant`] that should be earlier than all timings to use as base anchor
238    pub fn convert(anchor: Instant, timings: &[TaskTiming]) -> Vec<SerializedTaskTiming> {
239        let serialized = timings
240            .iter()
241            .map(|timing| {
242                let start = timing.start.duration_since(anchor).as_nanos();
243                let duration = timing.end.0.duration_since(timing.start).as_nanos();
244                SerializedTaskTiming {
245                    location: timing.location.into(),
246                    start,
247                    duration,
248                }
249            })
250            .collect::<Vec<_>>();
251
252        serialized
253    }
254
255    /// `anchor` - [`Instant`] that should be earlier than all timings to use as base anchor
256    pub fn from(anchor: Instant, timing: TaskTiming) -> SerializedTaskTiming {
257        let start = timing.start.duration_since(anchor).as_nanos();
258        let duration = timing.end.0.duration_since(timing.start).as_nanos();
259        SerializedTaskTiming {
260            location: timing.location.into(),
261            start,
262            duration,
263        }
264    }
265}
266
267/// Serializable variant of [`ThreadTaskTimings`]
268#[derive(Debug, Clone, Serialize, Deserialize)]
269pub struct SerializedThreadTaskTimings {
270    /// Thread name
271    pub thread_name: Option<String>,
272    /// Hash of the thread id
273    pub thread_id: u64,
274    /// Timing records for this thread
275    pub timings: Vec<SerializedTaskTiming>,
276}
277
278impl SerializedThreadTaskTimings {
279    /// Convert [`ThreadTaskTimings`] into their serializable format
280    ///
281    /// # Params
282    ///
283    /// `anchor` - [`Instant`] that should be earlier than all timings to use as base anchor
284    pub fn convert(anchor: Instant, timings: ThreadTaskTimings) -> SerializedThreadTaskTimings {
285        let serialized_timings = SerializedTaskTiming::convert(anchor, &timings.timings);
286
287        let mut hasher = DefaultHasher::new();
288        timings.thread_id.hash(&mut hasher);
289        let thread_id = hasher.finish();
290
291        SerializedThreadTaskTimings {
292            thread_name: timings.thread_name,
293            thread_id,
294            timings: serialized_timings,
295        }
296    }
297}
298
299#[doc(hidden)]
300#[derive(Debug, Clone)]
301pub struct ThreadTimingsDelta {
302    /// Hashed thread id
303    pub thread_id: u64,
304    /// Thread name, if known
305    pub thread_name: Option<String>,
306    /// New timings since the last call. If the circular buffer wrapped around
307    /// since the previous poll, some entries may have been lost.
308    pub new_timings: Vec<SerializedTaskTiming>,
309}
310
311/// Tracks which timing events have already been seen so that callers can request only unseen events.
312#[doc(hidden)]
313pub struct ProfilingCollector {
314    startup_time: Instant,
315    cursors: HashMap<ThreadId, u64>,
316}
317
318impl ProfilingCollector {
319    pub fn new(startup_time: Instant) -> Self {
320        Self {
321            startup_time,
322            cursors: HashMap::default(),
323        }
324    }
325
326    pub fn startup_time(&self) -> Instant {
327        self.startup_time
328    }
329
330    pub fn collect_unseen(
331        &mut self,
332        all_timings: Vec<ThreadTaskTimings>,
333    ) -> Vec<ThreadTimingsDelta> {
334        let mut deltas = Vec::with_capacity(all_timings.len());
335
336        for thread in all_timings {
337            let mut hasher = DefaultHasher::new();
338            thread.thread_id.hash(&mut hasher);
339            let hashed_id = hasher.finish();
340
341            let prev_cursor = self.cursors.get(&thread.thread_id).copied().unwrap_or(0);
342            let buffer_len = thread.timings.len() as u64;
343            let buffer_start = thread.total_pushed.saturating_sub(buffer_len);
344
345            let mut slice = if prev_cursor < buffer_start {
346                // Cursor fell behind the buffer — some entries were evicted.
347                // Return everything still in the buffer.
348                thread.timings.as_slice()
349            } else {
350                let skip = (prev_cursor - buffer_start) as usize;
351                &thread.timings[skip.min(thread.timings.len())..]
352            };
353
354            let cursor_advance = thread.total_pushed;
355            self.cursors.insert(thread.thread_id, cursor_advance);
356
357            if slice.is_empty() {
358                continue;
359            }
360
361            let new_timings = SerializedTaskTiming::convert(self.startup_time, slice);
362
363            deltas.push(ThreadTimingsDelta {
364                thread_id: hashed_id,
365                thread_name: thread.thread_name,
366                new_timings,
367            });
368        }
369
370        deltas
371    }
372
373    pub fn reset(&mut self) {
374        self.cursors.clear();
375    }
376}
377
378// Allow 16MiB of task timing entries.
379// VecDeque grows by doubling its capacity when full, so keep this a power of 2 to avoid wasting
380// memory.
381const MAX_TASK_TIMINGS: usize = (16 * 1024 * 1024) / core::mem::size_of::<TaskTiming>();
382
383#[doc(hidden)]
384pub(crate) type TaskTimings = VecDeque<TaskTiming>;
385
386#[doc(hidden)]
387pub type GuardedTaskTimings = spin::Mutex<ThreadTimings>;
388
389#[doc(hidden)]
390pub struct GlobalThreadTimings {
391    pub thread_id: ThreadId,
392    pub timings: std::sync::Weak<GuardedTaskTimings>,
393}
394
395#[doc(hidden)]
396#[derive(Debug, Clone)]
397pub struct TaskStatistics {
398    pub poll_time_to_beat: Duration,
399    pub runtime_to_beat: Duration,
400    pub longest_poll_times: [TaskTiming; 5],
401    pub longest_runtimes: [TaskTiming; 5],
402}
403
404impl std::fmt::Display for TaskStatistics {
405    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406        f.write_str("Tasks that blocked the longest before yielding\n")?;
407        for timing in self.longest_poll_times {
408            f.write_fmt(format_args!(
409                "{:<20} - {}:{}\n",
410                format!("{:?}", timing.poll_duration()),
411                timing.location.file(),
412                timing.location.column()
413            ))?;
414        }
415        f.write_str("Tasks that ran the longest\n")?;
416        for timing in self.longest_runtimes {
417            f.write_fmt(format_args!(
418                "{:<20} - {}:{}\n",
419                format!("{:?}", timing.since_spawn()),
420                timing.location.file(),
421                timing.location.column()
422            ))?;
423        }
424        Ok(())
425    }
426}
427
428impl Default for TaskStatistics {
429    fn default() -> Self {
430        Self {
431            // Do not track polls that are not problematic
432            // this keeps more calls on the fast path
433            poll_time_to_beat: Duration::from_micros(100),
434            runtime_to_beat: Duration::from_micros(100),
435            longest_poll_times: [TaskTiming::placeholder(); 5],
436            longest_runtimes: [TaskTiming::placeholder(); 5],
437        }
438    }
439}
440
441impl TaskStatistics {
442    #[inline(always)]
443    fn add_yield_timing(&mut self, task: TaskTiming) {
444        let yielded_after = task.poll_duration();
445        if yielded_after >= self.poll_time_to_beat {
446            cold_path(); // most tasks are not the worst, optimize for that
447            let to_replace = self
448                .longest_poll_times
449                .iter()
450                .position_min_by_key(|task| task.since_spawn())
451                .expect("guarded by the comparison with nth_longest_yield_time");
452            self.longest_poll_times[to_replace] = task;
453
454            self.poll_time_to_beat = self
455                .longest_poll_times
456                .iter()
457                .map(|task| task.since_spawn())
458                .min()
459                .expect("never empty");
460        }
461    }
462
463    #[inline(always)]
464    fn add_runtime(&mut self, task: TaskTiming) {
465        let runtime = task.since_spawn();
466        if runtime >= self.runtime_to_beat {
467            cold_path(); // most tasks are not the worst, optimize for that
468            let to_replace = self
469                .longest_runtimes
470                .iter()
471                .position_min_by_key(|task| task.since_spawn())
472                .expect("guarded by the comparison with nth_longest_yield_time");
473            self.longest_runtimes[to_replace] = task;
474
475            self.runtime_to_beat = self
476                .longest_runtimes
477                .iter()
478                .map(|task| task.since_spawn())
479                .min()
480                .expect("never empty");
481        }
482    }
483}
484
485#[doc(hidden)]
486pub static GLOBAL_THREAD_TIMINGS: spin::Mutex<Vec<GlobalThreadTimings>> =
487    spin::Mutex::new(Vec::new());
488
489thread_local! {
490    #[doc(hidden)]
491    pub static THREAD_TIMINGS: LazyCell<Arc<GuardedTaskTimings>> = LazyCell::new(|| {
492        let current_thread = std::thread::current();
493        let thread_name = current_thread.name();
494        let thread_id = current_thread.id();
495        let timings = ThreadTimings::new(thread_name.map(|e| e.to_string()), thread_id);
496        let timings = Arc::new(spin::Mutex::new(timings));
497
498        {
499            let timings = Arc::downgrade(&timings);
500            let global_timings = GlobalThreadTimings {
501                thread_id: std::thread::current().id(),
502                timings,
503            };
504            GLOBAL_THREAD_TIMINGS.lock().push(global_timings);
505        }
506
507        timings
508    });
509}
510
511#[doc(hidden)]
512pub struct ThreadTimings {
513    pub thread_name: Option<String>,
514    pub thread_id: ThreadId,
515    pub timings: TaskTimings,
516    pub running: Option<ActiveTiming>,
517    pub stats: TaskStatistics,
518    pub total_pushed: u64,
519}
520
521impl ThreadTimings {
522    pub fn new(thread_name: Option<String>, thread_id: ThreadId) -> Self {
523        ThreadTimings {
524            thread_name,
525            thread_id,
526            timings: TaskTimings::new(),
527            stats: TaskStatistics::default(),
528            total_pushed: 0,
529            running: None,
530        }
531    }
532
533    pub fn update_running_task(
534        &mut self,
535        spawned: SpawnTime,
536        location: &'static std::panic::Location<'_>,
537    ) {
538        let start = Instant::now();
539        self.running = Some(ActiveTiming {
540            spawned,
541            location,
542            start,
543        });
544    }
545
546    pub fn save_task_timing(&mut self, ended: YieldTime) {
547        let ActiveTiming {
548            location,
549            start,
550            spawned,
551        } = self
552            .running
553            .take()
554            .expect("this function is only ever called after register_task_start");
555
556        let timing = TaskTiming {
557            location,
558            spawned,
559            start,
560            end: ended,
561        };
562        self.stats.add_yield_timing(timing);
563        self.stats.add_runtime(timing);
564
565        if trace_enabled() {
566            cold_path(); // optimize for when the profiling is off
567            if self.timings.len() >= MAX_TASK_TIMINGS {
568                self.timings.pop_front();
569            }
570            self.timings.push_back(timing);
571            self.total_pushed += 1;
572        }
573    }
574
575    // Running tasks are included in the reliability trace, which is written
576    // whenever the foreground executor makes no progress for > n seconds
577    pub fn get_thread_task_timings(&self, includes: TasksIncluded) -> ThreadTaskTimings {
578        ThreadTaskTimings {
579            thread_name: self.thread_name.clone(),
580            thread_id: self.thread_id,
581            timings: self
582                .timings
583                .iter()
584                .cloned()
585                .chain(
586                    self.running
587                        .filter(|_| matches!(includes, TasksIncluded::CompletedAndRunning))
588                        .map(|running| TaskTiming {
589                            spawned: running.spawned,
590                            location: running.location,
591                            start: running.start,
592                            end: YieldTime(Instant::now()),
593                        }),
594                )
595                .collect(),
596            stats: self.stats.clone(),
597            total_pushed: self.total_pushed,
598        }
599    }
600}
601
602impl Drop for ThreadTimings {
603    fn drop(&mut self) {
604        let mut thread_timings = GLOBAL_THREAD_TIMINGS.lock();
605
606        let Some((index, _)) = thread_timings
607            .iter()
608            .enumerate()
609            .find(|(_, t)| t.thread_id == self.thread_id)
610        else {
611            return;
612        };
613        thread_timings.swap_remove(index);
614    }
615}
616
617#[doc(hidden)]
618pub fn update_running_task(spawned: SpawnTime, location: &'static std::panic::Location<'_>) {
619    THREAD_TIMINGS.with(|timings| {
620        timings.lock().update_running_task(spawned, location);
621    });
622}
623
624#[doc(hidden)]
625pub fn save_task_timing() {
626    let yielded_at = YieldTime(Instant::now());
627    THREAD_TIMINGS.with(|timings| {
628        timings.lock().save_task_timing(yielded_at);
629    });
630}
631
632#[doc(hidden)]
633pub fn get_current_thread_task_timings(include_running: TasksIncluded) -> ThreadTaskTimings {
634    THREAD_TIMINGS.with(|timings| timings.lock().get_thread_task_timings(include_running))
635}
636
/// Whether task timing traces are buffered. Starts disabled.
static PROFILER_ENABLED: AtomicBool = AtomicBool::new(false);
638
639/// Enables or disables task timing trace collection at runtime.
640///
641/// When transitioning from enabled to disabled, `add_task_timing` becomes a
642/// cheaper since only cheap statistics are gathered. The existing per-thread
643/// buffers for traces are cleared so stale data isn't reported after a later
644/// re-enable. Calls with the current value are a no-op.
645pub fn set_trace_enabled(enabled: bool) -> bool {
646    if PROFILER_ENABLED.swap(enabled, Ordering::AcqRel) == enabled {
647        return false;
648    }
649
650    if !enabled {
651        for global in GLOBAL_THREAD_TIMINGS.lock().iter() {
652            if let Some(timings) = global.timings.upgrade() {
653                let mut timings = timings.lock();
654                timings.timings.clear();
655                timings.timings.shrink_to_fit();
656                timings.total_pushed = 0;
657            }
658        }
659    }
660    true
661}
662
663/// Returns whether task timing tracing is enabled.
664pub fn trace_enabled() -> bool {
665    PROFILER_ENABLED.load(Ordering::Relaxed)
666}