cubecl_runtime/
timestamp_profiler.rs

1use cubecl_common::{
2    backtrace::BackTrace,
3    profile::{Instant, ProfileDuration},
4};
5use hashbrown::HashMap;
6
7use crate::server::{ProfileError, ProfilingToken};
8
9#[derive(Default, Debug)]
10/// A simple struct to keep track of timestamps for kernel execution.
11/// This should be used for servers that do not have native device profiling.
12pub struct TimestampProfiler {
13    state: HashMap<ProfilingToken, State>,
14    counter: u64,
15}
16
17#[derive(Debug)]
18enum State {
19    Start(Instant),
20    Error(ProfileError),
21}
22
23impl TimestampProfiler {
24    /// If there is some profiling registered.
25    pub fn is_empty(&self) -> bool {
26        self.state.is_empty()
27    }
28    /// Start measuring
29    pub fn start(&mut self) -> ProfilingToken {
30        let token = ProfilingToken { id: self.counter };
31        self.counter += 1;
32        self.state.insert(token, State::Start(Instant::now()));
33        token
34    }
35
36    /// Stop measuring
37    pub fn stop(&mut self, token: ProfilingToken) -> Result<ProfileDuration, ProfileError> {
38        let state = self.state.remove(&token);
39        let start = match state {
40            Some(val) => match val {
41                State::Start(instant) => instant,
42                State::Error(profile_error) => return Err(profile_error),
43            },
44            None => {
45                return Err(ProfileError::NotRegistered {
46                    backtrace: BackTrace::capture(),
47                });
48            }
49        };
50        Ok(ProfileDuration::new_system_time(start, Instant::now()))
51    }
52
53    /// Register an error during profiling.
54    pub fn error(&mut self, error: ProfileError) {
55        self.state
56            .iter_mut()
57            .for_each(|(_, state)| *state = State::Error(error.clone()));
58    }
59}