cubecl_wgpu/compute/
server.rs

1use std::{future::Future, marker::PhantomData, num::NonZero, time::Duration};
2
3use super::{
4    stream::{PipelineDispatch, WgpuStream},
5    WgpuStorage,
6};
7use crate::compiler::base::WgpuCompiler;
8use crate::timestamps::KernelTimestamps;
9use alloc::sync::Arc;
10use cubecl_common::future;
11use cubecl_core::{compute::DebugInformation, prelude::*, server::Handle, Feature, KernelId};
12use cubecl_runtime::{
13    debug::{DebugLogger, ProfileLevel},
14    memory_management::{MemoryHandle, MemoryLock, MemoryManagement},
15    server::{self, ComputeServer},
16    storage::{BindingResource, ComputeStorage},
17    ExecutionMode, TimestampsError, TimestampsResult,
18};
19use hashbrown::HashMap;
20use wgpu::ComputePipeline;
21
/// Wgpu compute server.
#[derive(Debug)]
pub struct WgpuServer<C: WgpuCompiler> {
    // Allocates, reuses and frees the GPU buffers backing handles (see `reserve`/`get`/`cleanup`).
    memory_management: MemoryManagement<WgpuStorage>,
    // Shared wgpu device handle; also used to enable timestamp queries.
    pub(crate) device: Arc<wgpu::Device>,
    // Queue used for staging-buffer writes in `create`.
    queue: Arc<wgpu::Queue>,
    // Cache of compiled pipelines, keyed by kernel id (which includes the execution mode).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    // Debug/profiling logger; its profile level drives timestamp collection.
    logger: DebugLogger,
    // Ids of storage buffers that may still be referenced by queued work and must
    // not be written to until the next flush.
    storage_locked: MemoryLock,
    // GPU time already accounted for by profiling syncs, folded into the next
    // `sync_elapsed` result.
    duration_profiled: Option<Duration>,
    // Stream that records kernel dispatches and buffer reads and submits them.
    stream: WgpuStream,
    /// Options forwarded to the compiler backend.
    pub compilation_options: C::CompilationOptions,
    _compiler: PhantomData<C>,
}
36
37impl<C: WgpuCompiler> WgpuServer<C> {
38    /// Create a new server.
39    pub fn new(
40        memory_management: MemoryManagement<WgpuStorage>,
41        compilation_options: C::CompilationOptions,
42        device: Arc<wgpu::Device>,
43        queue: Arc<wgpu::Queue>,
44        tasks_max: usize,
45    ) -> Self {
46        let logger = DebugLogger::default();
47        let mut timestamps = KernelTimestamps::Disabled;
48
49        if logger.profile_level().is_some() {
50            timestamps.enable(&device);
51        }
52
53        let stream = WgpuStream::new(device.clone(), queue.clone(), timestamps, tasks_max);
54
55        // Low estimate, but it makes sure there is no memory error from allocating too much
56        // at the same time.
57        let estimated_buffers_per_task = 4;
58        let storage_locked = MemoryLock::new(tasks_max * estimated_buffers_per_task);
59
60        Self {
61            memory_management,
62            compilation_options,
63            device: device.clone(),
64            queue: queue.clone(),
65            storage_locked,
66            pipelines: HashMap::new(),
67            logger,
68            duration_profiled: None,
69            stream,
70            _compiler: PhantomData,
71        }
72    }
73
74    fn pipeline(
75        &mut self,
76        kernel: <Self as ComputeServer>::Kernel,
77        mode: ExecutionMode,
78    ) -> Arc<ComputePipeline> {
79        let mut kernel_id = kernel.id();
80        kernel_id.mode(mode);
81
82        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
83            return pipeline.clone();
84        }
85
86        let mut compile = <C as WgpuCompiler>::compile(self, kernel, mode);
87
88        if self.logger.is_activated() {
89            compile.debug_info = Some(DebugInformation::new("wgsl", kernel_id.clone()));
90        }
91
92        let compile = self.logger.debug(compile);
93        let pipeline = C::create_pipeline(self, compile);
94
95        self.pipelines.insert(kernel_id.clone(), pipeline.clone());
96
97        pipeline
98    }
99
100    fn on_flushed(&mut self) {
101        self.storage_locked.clear_locked();
102
103        // Cleanup allocations and deallocations.
104        self.memory_management.cleanup();
105        self.memory_management.storage().perform_deallocations();
106    }
107}
108
impl<C: WgpuCompiler> ComputeServer for WgpuServer<C> {
    type Kernel = Box<dyn CubeTask<C>>;
    type Storage = WgpuStorage;
    type Feature = Feature;

    /// Read the data behind each binding back from the GPU.
    ///
    /// Returns a future resolving to one byte vector per input binding, in the
    /// same order. The read is recorded on the stream before `on_flushed` runs,
    /// so the returned future captures the submitted work.
    fn read(
        &mut self,
        bindings: Vec<server::Binding>,
    ) -> impl Future<Output = Vec<Vec<u8>>> + Send + 'static {
        let resources = bindings
            .into_iter()
            .map(|binding| {
                let rb = self.get_resource(binding);
                let resource = rb.resource();

                (resource.buffer.clone(), resource.offset(), resource.size())
            })
            .collect();

        // Clear compute pass.
        let fut = self.stream.read_buffers(resources);
        self.on_flushed();

        fut
    }

    /// Resolve a binding to its underlying storage resource, applying any
    /// start/end offsets, and lock the buffer against writes until the next
    /// flush.
    fn get_resource(&mut self, binding: server::Binding) -> BindingResource<Self> {
        let handle = self.memory_management.get(binding.memory.clone());

        // Keep track of any buffer that might be used in the wgpu queue, as we cannot copy into them
        // after they have any outstanding compute work. Calling get_resource repeatedly
        // will add duplicates to this, but that is ok.
        self.storage_locked.add_locked(handle.id);

        let handle = match binding.offset_start {
            Some(offset) => handle.offset_start(offset),
            None => handle,
        };
        let handle = match binding.offset_end {
            Some(offset) => handle.offset_end(offset),
            None => handle,
        };
        let resource = self.memory_management.storage().get(&handle);
        BindingResource::new(binding, resource)
    }

    /// When we create a new handle from existing data, we use custom allocations so that we don't
    /// have to execute the current pending tasks.
    ///
    /// This is important, otherwise the compute passes are going to be too small and we won't be able to
    /// fully utilize the GPU.
    fn create(&mut self, data: &[u8]) -> server::Handle {
        let num_bytes = data.len() as u64;

        // Copying into a buffer has to be 4 byte aligned. We can safely do so, as
        // memory is 32 bytes aligned (see WgpuStorage).
        let align = wgpu::COPY_BUFFER_ALIGNMENT;
        let aligned_len = num_bytes.div_ceil(align) * align;

        // Reserve memory on some storage we haven't yet used this command queue for compute
        // or copying.
        let memory = self
            .memory_management
            .reserve(aligned_len, Some(&self.storage_locked));

        // Zero-length creates skip the copy entirely; the (empty) handle is still returned.
        if let Some(len) = NonZero::new(aligned_len) {
            let resource_handle = self.memory_management.get(memory.clone().binding());

            // Dont re-use this handle for writing until the queue is flushed. All writes
            // happen at the start of the submission.
            self.storage_locked.add_locked(resource_handle.id);

            let resource = self.memory_management.storage().get(&resource_handle);

            // Write to the staging buffer. Next queue submission this will copy the data to the GPU.
            // Note: `len` is the aligned length; only the first `data.len()` bytes are
            // initialized from `data`, any padding bytes are left as-is.
            self.queue
                .write_buffer_with(&resource.buffer, resource.offset(), len)
                .expect("Failed to write to staging buffer.")[0..data.len()]
                .copy_from_slice(data);

            // If too many handles are locked, we flush.
            if self.storage_locked.has_reached_threshold() {
                self.flush();
            }
        }

        Handle::new(memory, None, None, aligned_len)
    }

    /// Reserve an uninitialized handle of `size` bytes. No locking is needed
    /// here since nothing is copied into the buffer.
    fn empty(&mut self, size: usize) -> server::Handle {
        server::Handle::new(
            self.memory_management.reserve(size as u64, None),
            None,
            None,
            size as u64,
        )
    }

    /// Execute `kernel` over `count` cubes with the given bindings.
    ///
    /// When profiling is active, the stream is synced before and after the
    /// dispatch so the kernel's duration can be measured and logged.
    ///
    /// # Safety
    /// NOTE(review): the safety contract comes from the `ComputeServer` trait —
    /// presumably running with `ExecutionMode::Unchecked` skips bounds checks;
    /// confirm against the trait's documentation.
    unsafe fn execute(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Vec<server::Binding>,
        mode: ExecutionMode,
    ) {
        // Check for any profiling work to be done before execution.
        let profile_level = self.logger.profile_level();
        let profile_info = if profile_level.is_some() {
            Some((kernel.name(), kernel.id()))
        } else {
            None
        };

        // Drain the time of previously queued work into `duration_profiled`
        // so the post-dispatch sync below measures only this kernel.
        if profile_level.is_some() {
            let fut = self.stream.sync_elapsed();
            if let Ok(duration) = future::block_on(fut) {
                if let Some(profiled) = &mut self.duration_profiled {
                    *profiled += duration;
                } else {
                    self.duration_profiled = Some(duration);
                }
            }
        }

        // Start execution.
        let pipeline = self.pipeline(kernel, mode);

        // Store all the resources we'll be using. This could be eliminated if
        // there was a way to tie the lifetime of the resource to the memory handle.
        let resources: Vec<_> = bindings
            .iter()
            .map(|binding| self.get_resource(binding.clone()).into_resource())
            .collect();

        // First resolve the dispatch buffer if needed. The weird ordering is because the lifetime of this
        // needs to be longer than the compute pass, so we can't do this just before dispatching.
        let dispatch = match count.clone() {
            CubeCount::Dynamic(binding) => {
                PipelineDispatch::Dynamic(self.get_resource(binding).into_resource())
            }
            CubeCount::Static(x, y, z) => PipelineDispatch::Static(x, y, z),
        };

        // `register` returns true when it flushed the queue internally.
        if self
            .stream
            .register(pipeline, resources, dispatch, &self.storage_locked)
        {
            self.on_flushed();
        }

        // If profiling, write out results.
        if let Some(level) = profile_level {
            // Safe: `profile_info` is Some exactly when `profile_level` is Some.
            let (name, kernel_id) = profile_info.unwrap();

            // Wait for the dispatched work so its duration can be measured.
            if let Ok(duration) = future::block_on(self.stream.sync_elapsed()) {
                if let Some(profiled) = &mut self.duration_profiled {
                    *profiled += duration;
                } else {
                    self.duration_profiled = Some(duration);
                }

                let info = match level {
                    // Basic/Medium: strip generic parameters and module path
                    // from the kernel name for a compact label.
                    ProfileLevel::Basic | ProfileLevel::Medium => {
                        if let Some(val) = name.split("<").next() {
                            val.split("::").last().unwrap_or(name).to_string()
                        } else {
                            name.to_string()
                        }
                    }
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                };
                self.logger.register_profiled(info, duration);
            }
        }
    }

    /// Submit all pending work to the GPU queue.
    fn flush(&mut self) {
        // End the current compute pass.
        self.stream.flush();
        self.on_flushed();
    }

    /// Waits until all submitted GPU work has completed.
    fn sync(&mut self) -> impl Future<Output = ()> + 'static {
        self.logger.profile_summary();
        let fut = self.stream.sync();
        self.on_flushed();

        fut
    }

    /// Returns the total time of GPU work this sync completes.
    ///
    /// Any duration already accumulated by profiling syncs
    /// (`duration_profiled`) is folded into the result and then reset.
    fn sync_elapsed(&mut self) -> impl Future<Output = TimestampsResult> + 'static {
        self.logger.profile_summary();

        let fut = self.stream.sync_elapsed();
        self.on_flushed();

        let profiled = self.duration_profiled;
        self.duration_profiled = None;

        async move {
            match fut.await {
                Ok(duration) => match profiled {
                    Some(profiled) => Ok(duration + profiled),
                    None => Ok(duration),
                },
                Err(err) => match err {
                    TimestampsError::Disabled => Err(err),
                    // Timestamps unavailable for this sync: fall back to the
                    // previously profiled duration when there is one.
                    TimestampsError::Unavailable => match profiled {
                        Some(profiled) => Ok(profiled),
                        None => Err(err),
                    },
                    TimestampsError::Unknown(_) => Err(err),
                },
            }
        }
    }

    /// Report current memory usage of the server's allocator.
    fn memory_usage(&self) -> cubecl_runtime::memory_management::MemoryUsage {
        self.memory_management.memory_usage()
    }

    /// Turn on GPU timestamp collection for subsequent work.
    fn enable_timestamps(&mut self) {
        self.stream.timestamps.enable(&self.device);
    }

    /// Turn off GPU timestamp collection, unless profiling still needs it.
    fn disable_timestamps(&mut self) {
        // Only disable timestamps if profiling isn't enabled.
        if self.logger.profile_level().is_none() {
            self.stream.timestamps.disable();
        }
    }
}