use std::{future::Future, marker::PhantomData, num::NonZero, time::Duration};

use super::{
    stream::{PipelineDispatch, WgpuStream},
    WgpuStorage,
};
use crate::compiler::base::WgpuCompiler;
use crate::timestamps::KernelTimestamps;
use alloc::sync::Arc;
use cubecl_common::future;
use cubecl_core::{compute::DebugInformation, prelude::*, server::Handle, Feature, KernelId};
use cubecl_runtime::{
    debug::{DebugLogger, ProfileLevel},
    memory_management::{MemoryHandle, MemoryLock, MemoryManagement},
    server::{self, ComputeServer},
    storage::{BindingResource, ComputeStorage},
    ExecutionMode, TimestampsError, TimestampsResult,
};
use hashbrown::HashMap;
use wgpu::ComputePipeline;

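/// Wgpu compute server: holds the device and queue, the kernel pipeline cache,
/// memory management state, and the stream of tasks queued for execution.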
#[derive(Debug)]
pub struct WgpuServer<C: WgpuCompiler> {
    memory_management: MemoryManagement<WgpuStorage>,
    pub(crate) device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    logger: DebugLogger,
    storage_locked: MemoryLock,
    duration_profiled: Option<Duration>,
    stream: WgpuStream,
    pub compilation_options: C::CompilationOptions,
    _compiler: PhantomData<C>,
}

impl<C: WgpuCompiler> WgpuServer<C> {
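    /// Create a new server.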
    pub fn new(
        memory_management: MemoryManagement<WgpuStorage>,
        compilation_options: C::CompilationOptions,
        device: Arc<wgpu::Device>,
        queue: Arc<wgpu::Queue>,
        tasks_max: usize,
    ) -> Self {
        let logger = DebugLogger::default();
        let mut timestamps = KernelTimestamps::Disabled;

        if logger.profile_level().is_some() {
            timestamps.enable(&device);
        }

        let stream = WgpuStream::new(device.clone(), queue.clone(), timestamps, tasks_max);

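        // Rough heuristic for how many buffers a single task binds; together with
        // `tasks_max` it sizes the lock threshold that triggers a flush of pending writes.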
        let estimated_buffers_per_task = 4;
        let storage_locked = MemoryLock::new(tasks_max * estimated_buffers_per_task);

        Self {
            memory_management,
            compilation_options,
            device: device.clone(),
            queue: queue.clone(),
            storage_locked,
            pipelines: HashMap::new(),
            logger,
            duration_profiled: None,
            stream,
            _compiler: PhantomData,
        }
    }

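    /// Fetch the compute pipeline for a kernel, compiling and caching it on first use.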
    fn pipeline(
        &mut self,
        kernel: <Self as ComputeServer>::Kernel,
        mode: ExecutionMode,
    ) -> Arc<ComputePipeline> {
        let mut kernel_id = kernel.id();
        kernel_id.mode(mode);

        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
            return pipeline.clone();
        }

        let mut compile = <C as WgpuCompiler>::compile(self, kernel, mode);

        if self.logger.is_activated() {
            compile.debug_info = Some(DebugInformation::new("wgsl", kernel_id.clone()));
        }

        let compile = self.logger.debug(compile);
        let pipeline = C::create_pipeline(self, compile);

        self.pipelines.insert(kernel_id.clone(), pipeline.clone());

        pipeline
    }

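    /// Bookkeeping that must run every time queued work has been flushed to the GPU.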
    fn on_flushed(&mut self) {
        self.storage_locked.clear_locked();

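        // Release memory that is no longer referenced and perform any deferred deallocations.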
        self.memory_management.cleanup();
        self.memory_management.storage().perform_deallocations();
    }
}

impl<C: WgpuCompiler> ComputeServer for WgpuServer<C> {
    type Kernel = Box<dyn CubeTask<C>>;
    type Storage = WgpuStorage;
    type Feature = Feature;

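    // Read the given bindings back from the GPU asynchronously.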
    fn read(
        &mut self,
        bindings: Vec<server::Binding>,
    ) -> impl Future<Output = Vec<Vec<u8>>> + Send + 'static {
        let resources = bindings
            .into_iter()
            .map(|binding| {
                let rb = self.get_resource(binding);
                let resource = rb.resource();

                (resource.buffer.clone(), resource.offset(), resource.size())
            })
            .collect();

        let fut = self.stream.read_buffers(resources);
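        // The read-back has been recorded on the stream, so clean up allocations now.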
        self.on_flushed();

        fut
    }

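    // Resolve a binding to its underlying storage resource, applying any start/end offsets.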
    fn get_resource(&mut self, binding: server::Binding) -> BindingResource<Self> {
        let handle = self.memory_management.get(binding.memory.clone());

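        // Lock the buffer: it may be referenced by queued work, so it must not be
        // handed out for new writes until the queue has been flushed.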
        self.storage_locked.add_locked(handle.id);

        let handle = match binding.offset_start {
            Some(offset) => handle.offset_start(offset),
            None => handle,
        };
        let handle = match binding.offset_end {
            Some(offset) => handle.offset_end(offset),
            None => handle,
        };
        let resource = self.memory_management.storage().get(&handle);
        BindingResource::new(binding, resource)
    }

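    // Copy `data` into a freshly reserved allocation and return a handle to it.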
    fn create(&mut self, data: &[u8]) -> server::Handle {
        let num_bytes = data.len() as u64;

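        // Copies into a wgpu buffer must respect the copy alignment, so pad the
        // allocation up to the next aligned length.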
        let align = wgpu::COPY_BUFFER_ALIGNMENT;
        let aligned_len = num_bytes.div_ceil(align) * align;

        let memory = self
            .memory_management
            .reserve(aligned_len, Some(&self.storage_locked));

        if let Some(len) = NonZero::new(aligned_len) {
            let resource_handle = self.memory_management.get(memory.clone().binding());

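            // Don't reuse this buffer for further writes until the queue has been flushed.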
            self.storage_locked.add_locked(resource_handle.id);

            let resource = self.memory_management.storage().get(&resource_handle);

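            // Stage the data through the queue; it is copied into the GPU buffer at the
            // next queue submission.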
            self.queue
                .write_buffer_with(&resource.buffer, resource.offset(), len)
                .expect("Failed to write to staging buffer.")[0..data.len()]
                .copy_from_slice(data);

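            // Too many staged writes are pending: flush so locked buffers become reusable.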
            if self.storage_locked.has_reached_threshold() {
                self.flush();
            }
        }

        Handle::new(memory, None, None, aligned_len)
    }

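    // Reserve an uninitialized allocation of `size` bytes.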
    fn empty(&mut self, size: usize) -> server::Handle {
        server::Handle::new(
            self.memory_management.reserve(size as u64, None),
            None,
            None,
            size as u64,
        )
    }

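    // Compile (or fetch from cache) the kernel's pipeline and enqueue a dispatch for it.
    // When profiling is enabled, the stream is synchronized around the dispatch so that
    // the logged duration covers only this kernel.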
    unsafe fn execute(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Vec<server::Binding>,
        mode: ExecutionMode,
    ) {
        let profile_level = self.logger.profile_level();
        let profile_info = if profile_level.is_some() {
            Some((kernel.name(), kernel.id()))
        } else {
            None
        };

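        // Fold time spent before this kernel into the profiled total, so the
        // measurement taken after the dispatch covers only this kernel.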
        if profile_level.is_some() {
            let fut = self.stream.sync_elapsed();
            if let Ok(duration) = future::block_on(fut) {
                if let Some(profiled) = &mut self.duration_profiled {
                    *profiled += duration;
                } else {
                    self.duration_profiled = Some(duration);
                }
            }
        }

        let pipeline = self.pipeline(kernel, mode);

        let resources: Vec<_> = bindings
            .iter()
            .map(|binding| self.get_resource(binding.clone()).into_resource())
            .collect();

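        // A dynamic cube count reads the dispatch size from a GPU buffer; a static one
        // is encoded directly.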
        let dispatch = match count.clone() {
            CubeCount::Dynamic(binding) => {
                PipelineDispatch::Dynamic(self.get_resource(binding).into_resource())
            }
            CubeCount::Static(x, y, z) => PipelineDispatch::Static(x, y, z),
        };

        if self
            .stream
            .register(pipeline, resources, dispatch, &self.storage_locked)
        {
            self.on_flushed();
        }

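        // When profiling, measure how long this dispatch took and log it at the
        // requested level of detail.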
        if let Some(level) = profile_level {
            let (name, kernel_id) = profile_info.unwrap();

            if let Ok(duration) = future::block_on(self.stream.sync_elapsed()) {
                if let Some(profiled) = &mut self.duration_profiled {
                    *profiled += duration;
                } else {
                    self.duration_profiled = Some(duration);
                }

                let info = match level {
                    ProfileLevel::Basic | ProfileLevel::Medium => {
                        if let Some(val) = name.split("<").next() {
                            val.split("::").last().unwrap_or(name).to_string()
                        } else {
                            name.to_string()
                        }
                    }
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                };
                self.logger.register_profiled(info, duration);
            }
        }
    }

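    // Submit all queued work to the GPU.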
    fn flush(&mut self) {
        self.stream.flush();
        self.on_flushed();
    }

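    // Wait for all queued work to complete on the GPU.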
    fn sync(&mut self) -> impl Future<Output = ()> + 'static {
        self.logger.profile_summary();
        let fut = self.stream.sync();
        self.on_flushed();

        fut
    }

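    // Like `sync`, but also report the GPU time elapsed since the last call, folding in
    // any duration already accumulated while profiling.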
    fn sync_elapsed(&mut self) -> impl Future<Output = TimestampsResult> + 'static {
        self.logger.profile_summary();

        let fut = self.stream.sync_elapsed();
        self.on_flushed();

        let profiled = self.duration_profiled;
        self.duration_profiled = None;

        async move {
            match fut.await {
                Ok(duration) => match profiled {
                    Some(profiled) => Ok(duration + profiled),
                    None => Ok(duration),
                },
                Err(err) => match err {
                    TimestampsError::Disabled => Err(err),
                    TimestampsError::Unavailable => match profiled {
                        Some(profiled) => Ok(profiled),
                        None => Err(err),
                    },
                    TimestampsError::Unknown(_) => Err(err),
                },
            }
        }
    }

    fn memory_usage(&self) -> cubecl_runtime::memory_management::MemoryUsage {
        self.memory_management.memory_usage()
    }

    fn enable_timestamps(&mut self) {
        self.stream.timestamps.enable(&self.device);
    }

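    // Only disable timestamps if the logger does not need them for profiling.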
    fn disable_timestamps(&mut self) {
        if self.logger.profile_level().is_none() {
            self.stream.timestamps.disable();
        }
    }
}