Skip to main content

cubecl_wgpu/compute/
server.rs

1use super::storage::{WgpuResource, WgpuStorage};
2use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
3use crate::{AutoCompiler, AutoRepresentation};
4use alloc::sync::Arc;
5use cubecl_common::{
6    backtrace::BackTrace,
7    bytes::Bytes,
8    profile::{ProfileDuration, TimingMethod},
9    stream_id::StreamId,
10};
11use cubecl_core::{
12    MemoryConfiguration, WgpuCompilationOptions,
13    future::DynFut,
14    prelude::*,
15    server::{
16        Allocation, AllocationDescriptor, Binding, Bindings, CopyDescriptor, ExecutionError,
17        IoError, LaunchError, ProfileError, ProfilingToken, ResourceLimitError,
18        ServerCommunication, ServerUtilities,
19    },
20};
21#[cfg(feature = "spirv")]
22use cubecl_core::{cache::CacheOption, compilation_cache::CompilationCache, hash::StableHash};
23use cubecl_ir::MemoryDeviceProperties;
24use cubecl_runtime::{
25    compiler::CubeTask,
26    config::GlobalConfig,
27    logging::ServerLogger,
28    memory_management::{MemoryAllocationMode, offset_handles},
29    server::ComputeServer,
30    storage::BindingResource,
31    stream::scheduler::{SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy},
32    validation::{validate_cube_dim, validate_units},
33};
34use hashbrown::HashMap;
35use wgpu::ComputePipeline;
36
37/// Wgpu compute server.
38#[derive(Debug)]
39pub struct WgpuServer {
40    pub(crate) device: wgpu::Device,
41    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
42    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
43    #[cfg(feature = "spirv")]
44    pub(crate) spirv_cache:
45        Option<CompilationCache<(u64, StableHash), cubecl_spirv::SpirvCacheEntry>>,
46    pub compilation_options: WgpuCompilationOptions,
47    pub(crate) backend: wgpu::Backend,
48    pub(crate) utilities: Arc<ServerUtilities<Self>>,
49}
50
51impl ServerCommunication for WgpuServer {
52    const SERVER_COMM_ENABLED: bool = false;
53}
54
55impl WgpuServer {
56    /// Create a new server.
57    #[allow(clippy::too_many_arguments)]
58    pub fn new(
59        memory_properties: MemoryDeviceProperties,
60        memory_config: MemoryConfiguration,
61        compilation_options: WgpuCompilationOptions,
62        device: wgpu::Device,
63        queue: wgpu::Queue,
64        tasks_max: usize,
65        backend: wgpu::Backend,
66        timing_method: TimingMethod,
67        utilities: ServerUtilities<Self>,
68    ) -> Self {
69        let backend_scheduler = ScheduledWgpuBackend::new(
70            device.clone(),
71            queue.clone(),
72            memory_properties,
73            memory_config,
74            timing_method,
75            tasks_max,
76            utilities.logger.clone(),
77        );
78
79        let config = GlobalConfig::get();
80        let max_streams = config.streaming.max_streams;
81
82        Self {
83            compilation_options,
84            device,
85            pipelines: HashMap::new(),
86            scheduler: SchedulerMultiStream::new(
87                utilities.logger.clone(),
88                backend_scheduler,
89                SchedulerMultiStreamOptions {
90                    max_streams,
91                    max_tasks: tasks_max,
92                    strategy: SchedulerStrategy::Interleave,
93                },
94            ),
95            #[cfg(feature = "spirv")]
96            spirv_cache: {
97                let config = cubecl_runtime::config::GlobalConfig::get();
98                if let Some(cache) = &config.compilation.cache {
99                    let root = cache.root();
100                    Some(CompilationCache::new(
101                        "spirv",
102                        CacheOption::default().name("vulkan").root(root),
103                    ))
104                } else {
105                    None
106                }
107            },
108            backend,
109            utilities: Arc::new(utilities),
110        }
111    }
112
113    fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
114        // Store all the resources we'll be using. This could be eliminated if
115        // there was a way to tie the lifetime of the resource to the memory handle.
116        let resources = bindings
117            .buffers
118            .iter()
119            .map(|b| {
120                let stream = self.scheduler.stream(&b.stream);
121                stream.mem_manage.get_resource(b.clone()).unwrap()
122            })
123            .collect::<Vec<_>>();
124
125        BindingsResource {
126            resources,
127            metadata: bindings.metadata,
128            scalars: bindings.scalars,
129        }
130    }
131
132    fn pipeline(
133        &mut self,
134        kernel: <Self as ComputeServer>::Kernel,
135        bindings: &Bindings,
136        mode: ExecutionMode,
137    ) -> Result<Arc<ComputePipeline>, LaunchError> {
138        let mut kernel_id = kernel.id();
139        kernel_id.mode(mode);
140
141        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
142            return Ok(pipeline.clone());
143        }
144
145        let cached = self.load_cached_pipeline(&kernel_id, bindings, mode)?;
146
147        if let Some(Ok(pipeline)) = cached {
148            self.pipelines.insert(kernel_id, pipeline.clone());
149            return Ok(pipeline);
150        }
151
152        validate_cube_dim(&self.utilities.properties, &kernel_id)?;
153        validate_units(&self.utilities.properties, &kernel_id)?;
154
155        let mut compiler = compiler(self.backend);
156        let mut compiled = compiler.compile(self, kernel, mode)?;
157
158        if self.scheduler.logger.compilation_activated() {
159            compiled.debug_info = Some(DebugInformation::new(
160                compiler.lang_tag(),
161                kernel_id.clone(),
162            ));
163        }
164        self.scheduler.logger.log_compilation(&compiled);
165
166        self.validate_shared(&compiled.repr)?;
167
168        // /!\ Do not delete the following commented code.
169        // This is useful while working on the metal compiler.
170        // Also the errors are printed nicely which is not the case when this is the runtime
171        // that does it.
172        // println!("SOURCE:\n{}", compile.source);
173        // {
174        //     // Write shader in metal file then compile it for error
175        //     std::fs::write("shader.metal", &compile.source).expect("should write to file");
176        //     let _status = std::process::Command::new("xcrun")
177        //         .args(vec![
178        //             "-sdk",
179        //             "macosx",
180        //             "metal",
181        //             "-o",
182        //             "shader.ir",
183        //             "-c",
184        //             "shader.metal",
185        //         ])
186        //         .status()
187        //         .expect("should launch the command");
188        //     // std::process::exit(status.code().unwrap());
189        // }
190        let repr = compiled.repr.as_ref().map(|it| it.as_ref());
191        let module = self.create_module(&compiled.entrypoint_name, repr, &compiled.source, mode)?;
192        let pipeline = self.create_pipeline(&compiled.entrypoint_name, repr, module, bindings);
193        self.pipelines.insert(kernel_id.clone(), pipeline.clone());
194
195        #[cfg(feature = "spirv")]
196        if let Some(Err(key)) = cached
197            && let Some(crate::AutoRepresentation::SpirV(kernel)) = compiled.repr
198        {
199            let cache = self.spirv_cache.as_mut().unwrap();
200            let result = cache.insert(
201                key,
202                cubecl_spirv::SpirvCacheEntry::new(compiled.entrypoint_name, kernel),
203            );
204            if let Err(err) = result {
205                log::warn!("Unable to save the SPIR-V {err:?}");
206            }
207        }
208
209        Ok(pipeline)
210    }
211
212    fn validate_shared(&self, repr: &Option<crate::AutoRepresentation>) -> Result<(), LaunchError> {
213        let shared_bytes = repr.as_ref().map(|repr| match repr {
214            AutoRepresentation::Wgsl(repr) => repr.shared_memory_bytes(),
215            #[cfg(feature = "msl")]
216            AutoRepresentation::Msl(repr) => repr.shared_memory_size(),
217            #[cfg(feature = "spirv")]
218            AutoRepresentation::SpirV(repr) => repr.shared_size,
219        });
220        let max_smem = self.utilities.properties.hardware.max_shared_memory_size;
221        if let Some(shared_bytes) = shared_bytes
222            && shared_bytes > max_smem
223        {
224            Err(ResourceLimitError::SharedMemory {
225                requested: shared_bytes,
226                max: max_smem,
227                backtrace: BackTrace::capture(),
228            }
229            .into())
230        } else {
231            Ok(())
232        }
233    }
234}
235
236impl ComputeServer for WgpuServer {
237    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
238    type Storage = WgpuStorage;
239    type Info = wgpu::Backend;
240
241    fn logger(&self) -> Arc<ServerLogger> {
242        self.scheduler.logger.clone()
243    }
244
245    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
246        self.utilities.clone()
247    }
248
249    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
250        // TODO: Check if using a staging buffer is useful here.
251        Err(IoError::UnsupportedIoOperation {
252            backtrace: BackTrace::capture(),
253        })
254    }
255
256    fn create(
257        &mut self,
258        descriptors: Vec<AllocationDescriptor<'_>>,
259        stream_id: StreamId,
260    ) -> Result<Vec<Allocation>, IoError> {
261        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
262        let strides = descriptors
263            .iter()
264            .map(|desc| contiguous_strides(desc.shape))
265            .collect::<Vec<_>>();
266        let sizes = descriptors
267            .iter()
268            .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
269            .collect::<Vec<_>>();
270        let total_size = sizes
271            .iter()
272            .map(|it| it.next_multiple_of(align))
273            .sum::<usize>();
274
275        let stream = self.scheduler.stream(&stream_id);
276        let mem_handle = stream.empty(total_size as u64, stream_id)?;
277        let handles = offset_handles(mem_handle, &sizes, align);
278
279        Ok(handles
280            .into_iter()
281            .zip(strides)
282            .map(|(handle, strides)| Allocation::new(handle, strides))
283            .collect())
284    }
285
286    fn read<'a>(
287        &mut self,
288        descriptors: Vec<CopyDescriptor<'a>>,
289        stream_id: StreamId,
290    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
291        let mut streams = vec![stream_id];
292        let mut resources = Vec::with_capacity(descriptors.len());
293        for desc in descriptors {
294            if contiguous_strides(desc.shape) != desc.strides {
295                return Box::pin(async {
296                    Err(IoError::UnsupportedStrides {
297                        backtrace: BackTrace::capture(),
298                    })
299                });
300            }
301            if !streams.contains(&desc.binding.stream) {
302                streams.push(desc.binding.stream);
303            }
304            let stream = self.scheduler.stream(&desc.binding.stream);
305            let resource = match stream.mem_manage.get_resource(desc.binding) {
306                Ok(val) => val,
307                Err(err) => return Box::pin(async move { Err(err) }),
308            };
309            resources.push((resource, desc.shape.to_vec(), desc.elem_size));
310        }
311
312        self.scheduler.execute_streams(streams);
313        let stream = self.scheduler.stream(&stream_id);
314        stream.read_resources(resources)
315    }
316
317    fn write(
318        &mut self,
319        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
320        stream_id: StreamId,
321    ) -> Result<(), IoError> {
322        for (desc, data) in descriptors {
323            if contiguous_strides(desc.shape) != desc.strides {
324                return Err(IoError::UnsupportedStrides {
325                    backtrace: BackTrace::capture(),
326                });
327            }
328
329            let stream = self.scheduler.stream(&desc.binding.stream);
330            let resource = stream.mem_manage.get_resource(desc.binding.clone())?;
331            let task = ScheduleTask::Write {
332                data,
333                buffer: resource,
334            };
335
336            self.scheduler.register(stream_id, task, [].into_iter());
337        }
338
339        Ok(())
340    }
341
342    fn get_resource(
343        &mut self,
344        binding: Binding,
345        stream_id: StreamId,
346    ) -> BindingResource<WgpuResource> {
347        let mut streams = vec![stream_id];
348        if binding.stream != stream_id {
349            streams.push(binding.stream);
350        }
351        self.scheduler.execute_streams(streams);
352        let stream = self.scheduler.stream(&binding.stream);
353        let resource = stream.mem_manage.get_resource(binding.clone()).unwrap();
354        BindingResource::new(binding, resource)
355    }
356
357    unsafe fn launch(
358        &mut self,
359        kernel: Self::Kernel,
360        count: CubeCount,
361        bindings: Bindings,
362        mode: ExecutionMode,
363        stream_id: StreamId,
364    ) -> Result<(), LaunchError> {
365        let pipeline = self.pipeline(kernel, &bindings, mode)?;
366        let buffers = bindings.buffers.clone();
367        let resources = self.prepare_bindings(bindings);
368        let task = ScheduleTask::Execute {
369            pipeline,
370            count,
371            resources,
372        };
373
374        self.scheduler.register(stream_id, task, buffers.iter());
375
376        Ok(())
377    }
378
379    fn flush(&mut self, stream_id: StreamId) {
380        self.scheduler.execute_streams(vec![stream_id]);
381        let stream = self.scheduler.stream(&stream_id);
382        stream.flush()
383    }
384
385    /// Returns the total time of GPU work this sync completes.
386    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
387        self.scheduler.execute_streams(vec![stream_id]);
388        let stream = self.scheduler.stream(&stream_id);
389        stream.sync()
390    }
391
392    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
393        self.scheduler.execute_streams(vec![stream_id]);
394        let stream = self.scheduler.stream(&stream_id);
395        stream.start_profile()
396    }
397
398    fn end_profile(
399        &mut self,
400        stream_id: StreamId,
401        token: ProfilingToken,
402    ) -> Result<ProfileDuration, ProfileError> {
403        self.scheduler.execute_streams(vec![stream_id]);
404        let stream = self.scheduler.stream(&stream_id);
405        stream.end_profile(token)
406    }
407
408    fn memory_usage(
409        &mut self,
410        stream_id: StreamId,
411    ) -> cubecl_runtime::memory_management::MemoryUsage {
412        self.scheduler.execute_streams(vec![stream_id]);
413        let stream = self.scheduler.stream(&stream_id);
414        stream.mem_manage.memory_usage()
415    }
416
417    fn memory_cleanup(&mut self, stream_id: StreamId) {
418        self.scheduler.execute_streams(vec![stream_id]);
419        let stream = self.scheduler.stream(&stream_id);
420        stream.mem_manage.memory_cleanup(true);
421    }
422
423    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
424        self.scheduler.execute_streams(vec![stream_id]);
425        let stream = self.scheduler.stream(&stream_id);
426        stream.mem_manage.mode(mode);
427    }
428}
429
430fn compiler(backend: wgpu::Backend) -> AutoCompiler {
431    match backend {
432        #[cfg(feature = "spirv")]
433        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
434        #[cfg(feature = "msl")]
435        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
436        _ => AutoCompiler::Wgsl(Default::default()),
437    }
438}
439
/// Returns the row-major (C-contiguous) strides, in elements, for `shape`.
///
/// The last dimension has stride 1, and each earlier stride is the product of all later
/// dimensions. An empty shape (a rank-0 scalar) yields an empty stride vector.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` keeps rank 0 from underflowing `rank - 1`. The underflow panicked
    // in debug builds and indexed out of bounds in release builds.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}