cubecl_runtime/
timestamp_profiler.rs

1use cubecl_common::profile::{Instant, ProfileDuration};
2use hashbrown::HashMap;
3
4use crate::server::{ProfileError, ProfilingToken};
5
6#[derive(Default, Debug)]
7/// A simple struct to keep track of timestamps for kernel execution.
8/// This should be used for servers that do not have native device profiling.
9pub struct TimestampProfiler {
10    state: HashMap<ProfilingToken, State>,
11    counter: u64,
12}
13
14#[derive(Debug)]
15enum State {
16    Start(Instant),
17    Error(ProfileError),
18}
19
20impl TimestampProfiler {
21    /// If there is some profiling registered.
22    pub fn is_empty(&self) -> bool {
23        self.state.is_empty()
24    }
25    /// Start measuring
26    pub fn start(&mut self) -> ProfilingToken {
27        let token = ProfilingToken { id: self.counter };
28        self.counter += 1;
29        self.state.insert(token, State::Start(Instant::now()));
30        token
31    }
32
33    /// Stop measuring
34    pub fn stop(&mut self, token: ProfilingToken) -> Result<ProfileDuration, ProfileError> {
35        let state = self.state.remove(&token);
36        let start = match state {
37            Some(val) => match val {
38                State::Start(instant) => instant,
39                State::Error(profile_error) => return Err(profile_error),
40            },
41            None => return Err(ProfileError::NotRegistered),
42        };
43        Ok(ProfileDuration::new_system_time(start, Instant::now()))
44    }
45
46    /// Register an error during profiling.
47    pub fn error(&mut self, error: ProfileError) {
48        self.state
49            .iter_mut()
50            .for_each(|(_, state)| *state = State::Error(error.clone()));
51    }
52}