use std::sync::Arc;
use cubecl_common::profile::{Duration, Instant, ProfileDuration, ProfileTicks};
use cubecl_core::server::{ProfileError, ProfilingToken};
use hashbrown::HashMap;
use wgpu::{QUERY_SIZE, QuerySet, QuerySetDescriptor, QueryType};
/// Identifier under which a query set is stored in `QueryProfiler::query_sets`.
type QuerySetId = u64;
/// Bookkeeping for timestamp-query based profiling on a wgpu device.
///
/// Profiles are identified by a `ProfilingToken`; each one records the id of the query set
/// it starts on and the id of the query set that is current when it is stopped.
#[derive(Debug)]
pub struct QueryProfiler {
/// Start/end query set ids of every profile that has been started but not yet stopped.
timestamps: HashMap<ProfilingToken, Timestamp>,
/// Tokens handed out by `start_profile` that have not yet been assigned a start query set;
/// drained by `init_query_set`.
init_tokens: Vec<ProfilingToken>,
/// Query sets released by `cleanup_query_sets`, reused by `new_query_set` before creating
/// a new one on the device.
query_set_pool: Vec<QuerySet>,
/// Query sets currently registered, keyed by their id.
query_sets: HashMap<QuerySetId, QuerySetItem>,
/// Id of the most recently allocated query set, if any; used as the end of a profile when
/// it is stopped.
current: Option<u64>,
/// Id that the next `ProfilingToken` will receive.
counter_token: u64,
/// Id that the next query set will receive.
counter_query_set: u64,
/// Ids of query sets whose reference count dropped to zero and that are waiting to be moved
/// back into `query_set_pool`.
cleanups: Vec<QuerySetId>,
/// Value of `Queue::get_timestamp_period`, used as the multiplier that turns tick counts
/// into nanoseconds.
queue_period: f64,
/// Tick value subtracted from every resolved timestamp: a device timestamp sampled at
/// construction when the `profile-tracy` feature is enabled, `0` otherwise.
epoch_tick: u64,
/// Host instant captured at construction; resolved durations are added to it.
epoch_instant: Instant,
}
/// Query set ids that delimit one profile.
#[derive(Debug)]
pub struct Timestamp {
/// Id of the query set assigned to the profile by `init_query_set`; `None` until then.
start: Option<u64>,
/// Id of the query set that was current when the profile was stopped; `None` until then.
end: Option<u64>,
}
/// A registered query set together with its reference count.
#[derive(Debug)]
struct QuerySetItem {
/// The wgpu query set itself.
query_set: QuerySet,
/// Number of profiles that use this query set as their start and have not been stopped
/// yet; decremented in `stop_profile_setup`.
num_ref: u32,
}
/// Creates a buffer that `count` queries can be resolved into.
///
/// The buffer is `QUERY_SIZE * count` bytes large and can be used as a query-resolve
/// destination and as a copy source.
fn create_resolve_buffer(device: &wgpu::Device, count: u32) -> wgpu::Buffer {
    // Widen before multiplying: the product of two `u32` values can overflow 32 bits (panic
    // in debug builds, silently undersized buffer in release builds).
    let size = u64::from(QUERY_SIZE) * u64::from(count);
    device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("CubeCL gpu -> cpu resolve buffer"),
        size,
        usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    })
}
/// Creates a host-readable buffer able to hold `count` query results.
///
/// The buffer is `QUERY_SIZE * count` bytes large and can be used as a copy destination and
/// mapped for reading.
fn create_map_buffer(device: &wgpu::Device, count: u32) -> wgpu::Buffer {
    // Widen before multiplying: the product of two `u32` values can overflow 32 bits (panic
    // in debug builds, silently undersized buffer in release builds).
    let size = u64::from(QUERY_SIZE) * u64::from(count);
    device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("CubeCL gpu -> cpu map buffer"),
        size,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    })
}
/// Samples one timestamp on the device and returns its raw tick value.
///
/// Only compiled when the `profile-tracy` feature is enabled. The function blocks: it polls
/// the device with `PollType::Wait` both before recording the commands and after requesting
/// the read-back mapping.
///
/// # Panics
///
/// Panics if polling the device fails or if the mapped range is not exactly eight bytes.
#[cfg(feature = "profile-tracy")]
fn get_cur_timestamp(queue: &wgpu::Queue, device: &wgpu::Device) -> u64 {
    use wgpu::BufferAddress;

    device.poll(wgpu::PollType::Wait).unwrap();

    let sync_query = device.create_query_set(&wgpu::QuerySetDescriptor {
        label: Some("CubeCL gpu -> cpu sync query_set"),
        ty: wgpu::QueryType::Timestamp,
        count: 1,
    });
    let resolved = create_resolve_buffer(device, 1);
    let readback = create_map_buffer(device, 1);

    // First encoder: write the timestamp into query 0 and resolve it.
    let mut stamp_encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("wgpu-profiler gpu -> cpu query timestamp"),
    });
    let pass_descriptor = wgpu::ComputePassDescriptor {
        label: Some("Write timestamp pass"),
        timestamp_writes: Some(wgpu::ComputePassTimestampWrites {
            query_set: &sync_query,
            beginning_of_pass_write_index: None,
            end_of_pass_write_index: Some(0),
        }),
    };
    // The pass records nothing else, so it is ended right away.
    drop(stamp_encoder.begin_compute_pass(&pass_descriptor));
    stamp_encoder.write_timestamp(&sync_query, 0);
    stamp_encoder.resolve_query_set(&sync_query, 0..1, &resolved, 0);

    // Second encoder: copy the resolved value into the mappable buffer.
    let mut transfer_encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("wgpu-profiler gpu -> cpu copy timestamp"),
    });
    transfer_encoder.copy_buffer_to_buffer(
        &resolved,
        0,
        &readback,
        0,
        Some(QUERY_SIZE as BufferAddress),
    );

    queue.submit([stamp_encoder.finish(), transfer_encoder.finish()]);

    // The mapping result is ignored by the callback; the blocking poll below is what the
    // read of the mapped range relies on.
    readback.slice(..).map_async(wgpu::MapMode::Read, |_| ());
    device.poll(wgpu::PollType::Wait).unwrap();

    let mapped = readback.slice(..).get_mapped_range();
    let raw: [u8; 8] = (*mapped).try_into().unwrap();
    u64::from_le_bytes(raw)
}
impl QueryProfiler {
/// Creates an empty profiler for the given queue.
///
/// The timestamp period is read from `queue`, and the host instant taken here becomes the
/// origin that resolved durations are added to. With the `profile-tracy` feature enabled
/// a device timestamp is sampled as the tick origin; otherwise the tick origin is `0` and
/// `device` is unused.
pub fn new(queue: &wgpu::Queue, #[allow(unused)] device: &wgpu::Device) -> Self {
    #[cfg(feature = "profile-tracy")]
    let epoch_tick = get_cur_timestamp(queue, device);
    #[cfg(not(feature = "profile-tracy"))]
    let epoch_tick = 0;

    let epoch_instant = Instant::now();
    let queue_period = queue.get_timestamp_period() as f64;

    Self {
        timestamps: HashMap::new(),
        init_tokens: Vec::new(),
        query_set_pool: Vec::new(),
        query_sets: HashMap::new(),
        current: None,
        counter_token: 0,
        counter_query_set: 0,
        cleanups: Vec::new(),
        queue_period,
        epoch_tick,
        epoch_instant,
    }
}
/// Starts a new profile and returns the token identifying it.
///
/// The token is queued in `init_tokens` so that the next call to `init_query_set` assigns
/// it a start query set, and an empty [`Timestamp`] is recorded for it.
pub fn start_profile(&mut self) -> ProfilingToken {
    let id = self.counter_token;
    self.counter_token += 1;

    let token = ProfilingToken { id };
    let pending = Timestamp {
        start: None,
        end: None,
    };

    self.init_tokens.push(token);
    self.timestamps.insert(token, pending);
    token
}
/// Stops the profile identified by `token` and records the commands that read back its
/// timestamps.
///
/// Query 0 of the profile's start query set and query 1 of the current query set are
/// resolved and copied, in that order, into a two-entry mappable buffer, which is returned.
///
/// Returns `None` when the profile was never assigned a start query set or when no query
/// set has been allocated yet; nothing is recorded on `encoder` in that case.
///
/// # Panics
///
/// Panics if `token` is not a running profile, or if the start or current query set is no
/// longer registered.
pub fn stop_profile_setup(
    &mut self,
    token: ProfilingToken,
    device: &wgpu::Device,
    encoder: &mut wgpu::CommandEncoder,
) -> Option<wgpu::Buffer> {
    let mut record = self.timestamps.remove(&token).unwrap();
    record.end = self.current;
    let (Some(start_id), Some(end_id)) = (record.start, record.end) else {
        return None;
    };

    // This profile no longer keeps its start query set alive; schedule the query set for
    // recycling once the last such reference is gone.
    let start_item = self.query_sets.get_mut(&start_id).unwrap();
    start_item.num_ref -= 1;
    if start_item.num_ref == 0 {
        self.cleanups.push(start_id);
    }

    let resolve_start = create_resolve_buffer(device, 1);
    let resolve_end = create_resolve_buffer(device, 1);
    let map_buffer = create_map_buffer(device, 2);

    let start_set = &self.query_sets.get(&start_id).unwrap().query_set;
    let end_set = &self.query_sets.get(&end_id).unwrap().query_set;

    let entry_size = QUERY_SIZE as u64;
    encoder.resolve_query_set(start_set, 0..1, &resolve_start, 0);
    encoder.resolve_query_set(end_set, 1..2, &resolve_end, 0);
    encoder.copy_buffer_to_buffer(&resolve_start, 0, &map_buffer, 0, entry_size);
    encoder.copy_buffer_to_buffer(&resolve_end, 0, &map_buffer, entry_size, entry_size);

    Some(map_buffer)
}
/// Wraps the read-back of a stopped profile into a [`ProfileDuration`].
///
/// With `map_buffer == None` the returned duration is an empty interval starting and ending
/// at the instant of this call. Otherwise the returned future maps the buffer for reading,
/// waits for the mapping result, drops `poll_signal`, reads the two tick values, subtracts
/// the tick origin (saturating at zero), scales them by the queue's timestamp period and
/// offsets the host-side origin instant by the resulting durations.
///
/// This function always returns `Ok`.
///
/// # Panics
///
/// The returned future panics if the mapping result cannot be received, if mapping fails,
/// or if the mapped bytes cannot be viewed as `u64` values.
pub fn stop_profile(
    &self,
    map_buffer: Option<wgpu::Buffer>,
    poll_signal: Arc<()>,
) -> Result<ProfileDuration, ProfileError> {
    let Some(readback) = map_buffer else {
        let now = Instant::now();
        return Ok(ProfileDuration::new_device_time(async move {
            ProfileTicks::from_start_end(now, now)
        }));
    };

    // Copy what the future needs so it does not borrow `self`.
    let nanos_per_tick = self.queue_period;
    let tick_origin = self.epoch_tick;
    let time_origin = self.epoch_instant;

    let resolve = async move {
        let (tx, rx) = async_channel::bounded(1);
        readback
            .slice(..)
            .map_async(wgpu::MapMode::Read, move |status| {
                let _ = tx.try_send(status);
            });
        let status = rx
            .recv()
            .await
            .expect("Unable to receive buffer slice result.");
        status.expect("Failed to map buffer");
        core::mem::drop(poll_signal);

        let mapped = readback.slice(..).get_mapped_range();
        let ticks: &[u64] = bytemuck::try_cast_slice(&mapped).unwrap();
        let start_ticks = ticks[0].saturating_sub(tick_origin);
        let end_ticks = ticks[1].saturating_sub(tick_origin);
        drop(mapped);
        readback.unmap();

        let to_instant = |elapsed_ticks: u64| {
            time_origin + Duration::from_nanos((elapsed_ticks as f64 * nanos_per_tick) as u64)
        };
        ProfileTicks::from_start_end(to_instant(start_ticks), to_instant(end_ticks))
    };

    Ok(ProfileDuration::new_device_time(resolve))
}
/// Assigns a start query set to every pending profile and returns that query set.
///
/// Returns `None` when no pending profile needed a query set. Query sets scheduled for
/// cleanup are recycled in either case.
pub fn register_profile_device(&mut self, device: &wgpu::Device) -> Option<&QuerySet> {
    let info = self.init_query_set()?;
    let item = self.new_query_set(info, device);
    Some(&item.query_set)
}
/// Registers a query set under the id in `query_set_info` with the given reference count.
///
/// A pooled query set is reused when one is available; otherwise a two-entry timestamp
/// query set is created on `device`. Returns the freshly registered item.
fn new_query_set(
    &mut self,
    query_set_info: (u64, u32),
    device: &wgpu::Device,
) -> &mut QuerySetItem {
    let (id, num_ref) = query_set_info;

    let query_set = self.query_set_pool.pop().unwrap_or_else(|| {
        device.create_query_set(&QuerySetDescriptor {
            label: Some("CubeCL profile queries"),
            ty: QueryType::Timestamp,
            count: 2,
        })
    });

    self.query_sets
        .insert(id, QuerySetItem { query_set, num_ref });
    self.query_sets.get_mut(&id).unwrap()
}
fn init_query_set(&mut self) -> Option<(QuerySetId, u32)> {
let mut query_set_id = None;
let mut count = 0;
for token in self.init_tokens.drain(..) {
if let Some(Timestamp { start, .. }) = &mut self.timestamps.get_mut(&token) {
count += 1;
let id = match query_set_id {
Some(id) => id,
None => {
let id = self.counter_query_set;
self.counter_query_set += 1;
self.current = Some(id);
query_set_id = Some(id);
id
}
};
*start = Some(id);
}
}
self.cleanup_query_sets();
query_set_id.map(|v| (v, count))
}
fn cleanup_query_sets(&mut self) {
for key in self.cleanups.drain(..) {
let removed = self
.query_sets
.remove(&key)
.expect("Unknown query set cleaned up");
self.query_set_pool.push(removed.query_set);
}
}
}