//! `cubecl_wgpu/compute/server.rs` — wgpu compute server implementation.
1use super::storage::{WgpuResource, WgpuStorage};
2use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
3use crate::{AutoCompiler, AutoRepresentation};
4use alloc::sync::Arc;
5use cubecl_common::{
6    backtrace::BackTrace,
7    bytes::Bytes,
8    profile::{ProfileDuration, TimingMethod},
9    stream_id::StreamId,
10};
11use cubecl_core::{
12    MemoryConfiguration, WgpuCompilationOptions,
13    future::DynFut,
14    prelude::*,
15    server::{
16        Allocation, AllocationDescriptor, Binding, Bindings, CopyDescriptor, ExecutionError,
17        IoError, LaunchError, ProfileError, ProfilingToken, ResourceLimitError,
18        ServerCommunication, ServerUtilities,
19    },
20    zspace::{Strides, strides},
21};
22#[cfg(feature = "spirv")]
23use cubecl_core::{cache::CacheOption, compilation_cache::CompilationCache, hash::StableHash};
24use cubecl_ir::MemoryDeviceProperties;
25use cubecl_runtime::{
26    compiler::CubeTask,
27    config::GlobalConfig,
28    logging::ServerLogger,
29    memory_management::{MemoryAllocationMode, offset_handles},
30    server::ComputeServer,
31    storage::BindingResource,
32    stream::scheduler::{SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy},
33    validation::{validate_cube_dim, validate_units},
34};
35use hashbrown::HashMap;
36use wgpu::ComputePipeline;
37
/// Wgpu compute server.
///
/// Owns the wgpu device handle, an in-memory cache of compiled compute
/// pipelines, and the multi-stream scheduler that records and dispatches
/// GPU work.
#[derive(Debug)]
pub struct WgpuServer {
    pub(crate) device: wgpu::Device,
    /// Compiled pipelines keyed by kernel id (the id includes the execution mode).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    /// Multi-stream scheduler used for all task registration and flushing.
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    /// On-disk cache of compiled SPIR-V kernels; `None` when caching is disabled.
    #[cfg(feature = "spirv")]
    pub(crate) spirv_cache:
        Option<CompilationCache<(u64, StableHash), cubecl_spirv::SpirvCacheEntry>>,
    pub compilation_options: WgpuCompilationOptions,
    /// Active wgpu backend (Vulkan, Metal, …); selects the kernel compiler.
    pub(crate) backend: wgpu::Backend,
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}
51
impl ServerCommunication for WgpuServer {
    // Direct server-to-server communication is not supported by this backend.
    const SERVER_COMM_ENABLED: bool = false;
}
55
impl WgpuServer {
    /// Create a new server.
    ///
    /// Builds the scheduling backend over the given `device`/`queue`, then
    /// wraps it in a multi-stream scheduler whose stream count comes from the
    /// global configuration.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        memory_properties: MemoryDeviceProperties,
        memory_config: MemoryConfiguration,
        compilation_options: WgpuCompilationOptions,
        device: wgpu::Device,
        queue: wgpu::Queue,
        tasks_max: usize,
        backend: wgpu::Backend,
        timing_method: TimingMethod,
        utilities: ServerUtilities<Self>,
    ) -> Self {
        let backend_scheduler = ScheduledWgpuBackend::new(
            device.clone(),
            queue.clone(),
            memory_properties,
            memory_config,
            timing_method,
            tasks_max,
            utilities.logger.clone(),
        );

        let config = GlobalConfig::get();
        let max_streams = config.streaming.max_streams;

        Self {
            compilation_options,
            device,
            pipelines: HashMap::new(),
            scheduler: SchedulerMultiStream::new(
                utilities.logger.clone(),
                backend_scheduler,
                SchedulerMultiStreamOptions {
                    max_streams,
                    max_tasks: tasks_max,
                    strategy: SchedulerStrategy::Interleave,
                },
            ),
            // Enable the on-disk SPIR-V cache only when a cache root is
            // configured globally.
            #[cfg(feature = "spirv")]
            spirv_cache: {
                let config = cubecl_runtime::config::GlobalConfig::get();
                if let Some(cache) = &config.compilation.cache {
                    let root = cache.root();
                    Some(CompilationCache::new(
                        "spirv",
                        CacheOption::default().name("vulkan").root(root),
                    ))
                } else {
                    None
                }
            },
            backend,
            utilities: Arc::new(utilities),
        }
    }

    /// Resolve every buffer binding into its underlying wgpu resource so the
    /// resources stay alive for the duration of the scheduled task.
    fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
        // Store all the resources we'll be using. This could be eliminated if
        // there was a way to tie the lifetime of the resource to the memory handle.
        let resources = bindings
            .buffers
            .iter()
            .map(|b| {
                let stream = self.scheduler.stream(&b.stream);
                stream.mem_manage.get_resource(b.clone()).unwrap()
            })
            .collect::<Vec<_>>();

        BindingsResource {
            resources,
            metadata: bindings.metadata,
            scalars: bindings.scalars,
        }
    }

    /// Get or build the compute pipeline for `kernel` under the given
    /// execution `mode`.
    ///
    /// Lookup order: the in-memory pipeline map, then (with the `spirv`
    /// feature) the on-disk compilation cache, then a fresh compile. Freshly
    /// compiled SPIR-V kernels are written back to the on-disk cache.
    fn pipeline(
        &mut self,
        kernel: <Self as ComputeServer>::Kernel,
        bindings: &Bindings,
        mode: ExecutionMode,
    ) -> Result<Arc<ComputePipeline>, LaunchError> {
        // The kernel id is specialized by execution mode, so checked and
        // unchecked variants get distinct pipelines.
        let mut kernel_id = kernel.id();
        kernel_id.mode(mode);

        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
            return Ok(pipeline.clone());
        }

        // `Some(Ok(_))` = pipeline restored from the on-disk cache;
        // `Some(Err(key))` = cache miss, carrying the key to insert under
        // after compiling (see the spirv block below).
        let cached = self.load_cached_pipeline(&kernel_id, bindings, mode)?;

        if let Some(Ok(pipeline)) = cached {
            self.pipelines.insert(kernel_id, pipeline.clone());
            return Ok(pipeline);
        }

        validate_cube_dim(&self.utilities.properties, &kernel_id)?;
        validate_units(&self.utilities.properties, &kernel_id)?;

        let mut compiler = compiler(self.backend);
        let mut compiled = compiler.compile(self, kernel, mode)?;

        // Attach debug info only when compilation logging is active.
        if self.scheduler.logger.compilation_activated() {
            compiled.debug_info = Some(DebugInformation::new(
                compiler.lang_tag(),
                kernel_id.clone(),
            ));
        }
        self.scheduler.logger.log_compilation(&compiled);

        self.validate_shared(&compiled.repr)?;

        // /!\ Do not delete the following commented code.
        // This is useful while working on the metal compiler.
        // Also the errors are printed nicely which is not the case when this is the runtime
        // that does it.
        // println!("SOURCE:\n{}", compile.source);
        // {
        //     // Write shader in metal file then compile it for error
        //     std::fs::write("shader.metal", &compile.source).expect("should write to file");
        //     let _status = std::process::Command::new("xcrun")
        //         .args(vec![
        //             "-sdk",
        //             "macosx",
        //             "metal",
        //             "-o",
        //             "shader.ir",
        //             "-c",
        //             "shader.metal",
        //         ])
        //         .status()
        //         .expect("should launch the command");
        //     // std::process::exit(status.code().unwrap());
        // }
        let repr = compiled.repr.as_ref().map(|it| it.as_ref());
        let module = self.create_module(&compiled.entrypoint_name, repr, &compiled.source, mode)?;
        let pipeline = self.create_pipeline(&compiled.entrypoint_name, repr, module, bindings);
        self.pipelines.insert(kernel_id.clone(), pipeline.clone());

        // On a SPIR-V cache miss, persist the freshly compiled kernel. A
        // failed insert only logs a warning; the pipeline is still returned.
        #[cfg(feature = "spirv")]
        if let Some(Err(key)) = cached
            && let Some(crate::AutoRepresentation::SpirV(kernel)) = compiled.repr
        {
            let cache = self.spirv_cache.as_mut().unwrap();
            let result = cache.insert(
                key,
                cubecl_spirv::SpirvCacheEntry::new(compiled.entrypoint_name, kernel),
            );
            if let Err(err) = result {
                log::warn!("Unable to save the SPIR-V {err:?}");
            }
        }

        Ok(pipeline)
    }

    /// Check that the kernel's shared-memory usage fits within the hardware
    /// limit, returning a `ResourceLimitError` (wrapped in `LaunchError`)
    /// when it does not. A kernel with no representation is accepted as-is.
    fn validate_shared(&self, repr: &Option<crate::AutoRepresentation>) -> Result<(), LaunchError> {
        let shared_bytes = repr.as_ref().map(|repr| match repr {
            AutoRepresentation::Wgsl(repr) => repr.shared_memory_bytes(),
            #[cfg(feature = "msl")]
            AutoRepresentation::Msl(repr) => repr.shared_memory_size(),
            #[cfg(feature = "spirv")]
            AutoRepresentation::SpirV(repr) => repr.shared_size,
        });
        let max_smem = self.utilities.properties.hardware.max_shared_memory_size;
        if let Some(shared_bytes) = shared_bytes
            && shared_bytes > max_smem
        {
            Err(ResourceLimitError::SharedMemory {
                requested: shared_bytes,
                max: max_smem,
                backtrace: BackTrace::capture(),
            }
            .into())
        } else {
            Ok(())
        }
    }
}
236
impl ComputeServer for WgpuServer {
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    type Info = wgpu::Backend;

    fn logger(&self) -> Arc<ServerLogger> {
        self.scheduler.logger.clone()
    }

    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    /// Staging buffers are not implemented for wgpu; always returns
    /// `IoError::UnsupportedIoOperation`.
    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        // TODO: Check if using a staging buffer is useful here.
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        })
    }

    /// Allocate one contiguous buffer covering all descriptors and slice it
    /// into per-descriptor handles, each offset aligned to the device's
    /// minimum storage-buffer offset alignment.
    fn create(
        &mut self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<Allocation>, IoError> {
        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
        let strides = descriptors
            .iter()
            .map(|desc| contiguous_strides(desc.shape))
            .collect::<Vec<_>>();
        let sizes = descriptors
            .iter()
            .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
            .collect::<Vec<_>>();
        // Round each sub-allocation up so the next handle starts aligned.
        let total_size = sizes
            .iter()
            .map(|it| it.next_multiple_of(align))
            .sum::<usize>();

        let stream = self.scheduler.stream(&stream_id);
        let mem_handle = stream.empty(total_size as u64, stream_id)?;
        let handles = offset_handles(mem_handle, &sizes, align);

        Ok(handles
            .into_iter()
            .zip(strides)
            .map(|(handle, strides)| Allocation::new(handle, strides))
            .collect())
    }

    /// Read buffers back to the host.
    ///
    /// Only contiguous layouts are supported; every stream that owns one of
    /// the bindings is flushed before the read is issued on `stream_id`.
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            // Reject non-contiguous layouts up front.
            if &*contiguous_strides(desc.shape) != desc.strides {
                return Box::pin(async {
                    Err(IoError::UnsupportedStrides {
                        backtrace: BackTrace::capture(),
                    })
                });
            }
            // Collect every distinct stream involved so all pending work on
            // them is executed before reading.
            if !streams.contains(&desc.binding.stream) {
                streams.push(desc.binding.stream);
            }
            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = match stream.mem_manage.get_resource(desc.binding) {
                Ok(val) => val,
                Err(err) => return Box::pin(async move { Err(err) }),
            };
            resources.push((resource, desc.shape.into(), desc.elem_size));
        }

        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }

    /// Write host data into device buffers by registering one write task per
    /// descriptor on `stream_id`. Only contiguous layouts are supported.
    fn write(
        &mut self,
        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
        stream_id: StreamId,
    ) -> Result<(), IoError> {
        for (desc, data) in descriptors {
            if &*contiguous_strides(desc.shape) != desc.strides {
                return Err(IoError::UnsupportedStrides {
                    backtrace: BackTrace::capture(),
                });
            }

            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding.clone())?;
            let task = ScheduleTask::Write {
                data,
                buffer: resource,
            };

            self.scheduler.register(stream_id, task, [].into_iter());
        }

        Ok(())
    }

    /// Resolve a binding to its wgpu resource, flushing the calling stream
    /// (and the binding's owning stream, if different) first.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> BindingResource<WgpuResource> {
        let mut streams = vec![stream_id];
        if binding.stream != stream_id {
            streams.push(binding.stream);
        }
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&binding.stream);
        let resource = stream.mem_manage.get_resource(binding.clone()).unwrap();
        BindingResource::new(binding, resource)
    }

    /// Compile (or fetch) the pipeline for `kernel` and register an execute
    /// task on `stream_id`, with the binding buffers as dependencies.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let pipeline = self.pipeline(kernel, &bindings, mode)?;
        // Keep the buffer list so the scheduler can track task dependencies.
        let buffers = bindings.buffers.clone();
        let resources = self.prepare_bindings(bindings);
        let task = ScheduleTask::Execute {
            pipeline,
            count,
            resources,
        };

        self.scheduler.register(stream_id, task, buffers.iter());

        Ok(())
    }

    /// Execute all pending tasks on the stream and flush it to the GPU.
    fn flush(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.flush()
    }

    /// Returns the total time of GPU work this sync completes.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.sync()
    }

    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.start_profile()
    }

    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.end_profile(token)
    }

    fn memory_usage(
        &mut self,
        stream_id: StreamId,
    ) -> cubecl_runtime::memory_management::MemoryUsage {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_usage()
    }

    fn memory_cleanup(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_cleanup(true);
    }

    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.mode(mode);
    }
}
430
431fn compiler(backend: wgpu::Backend) -> AutoCompiler {
432    match backend {
433        #[cfg(feature = "spirv")]
434        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
435        #[cfg(feature = "msl")]
436        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
437        _ => AutoCompiler::Wgsl(Default::default()),
438    }
439}
440
441pub(crate) fn contiguous_strides(shape: &[usize]) -> Strides {
442    let rank = shape.len();
443    let mut strides = strides![1; rank];
444    for i in (0..rank - 1).rev() {
445        strides[i] = strides[i + 1] * shape[i + 1];
446    }
447    strides
448}