cubecl_wgpu/compute/server.rs

1use super::storage::{WgpuResource, WgpuStorage};
2use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
3use crate::{AutoCompiler, AutoRepresentation};
4use alloc::sync::Arc;
5use cubecl_common::{
6    backtrace::BackTrace,
7    bytes::Bytes,
8    profile::{ProfileDuration, TimingMethod},
9    stream_id::StreamId,
10};
11use cubecl_core::server::{Binding, StreamErrorMode};
12use cubecl_core::zspace::Shape;
13use cubecl_core::{
14    MemoryConfiguration, WgpuCompilationOptions,
15    future::DynFut,
16    prelude::*,
17    server::{
18        CopyDescriptor, IoError, KernelArguments, LaunchError, ProfileError, ProfilingToken,
19        ResourceLimitError, ServerCommunication, ServerError, ServerUtilities,
20    },
21    zspace::{Strides, strides},
22};
23#[cfg(feature = "spirv")]
24use cubecl_core::{cache::CacheOption, compilation_cache::CompilationCache, hash::StableHash};
25use cubecl_ir::MemoryDeviceProperties;
26use cubecl_runtime::allocator::ContiguousMemoryLayoutPolicy;
27use cubecl_runtime::memory_management::{ManagedMemoryHandle, MemoryUsage};
28use cubecl_runtime::{
29    compiler::CubeTask,
30    config::{CubeClRuntimeConfig, RuntimeConfig},
31    logging::ServerLogger,
32    memory_management::MemoryAllocationMode,
33    server::ComputeServer,
34    storage::ManagedResource,
35    stream::scheduler::{SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy},
36    validation::{validate_cube_dim, validate_units},
37};
38use hashbrown::HashMap;
39use wgpu::ComputePipeline;
40
/// Wgpu compute server.
#[derive(Debug)]
pub struct WgpuServer {
    /// Handle to the underlying wgpu device, used to create shader modules and pipelines.
    pub(crate) device: wgpu::Device,
    // A buffer that can be used to store stream id without extra allocations.
    streams_pool: Vec<StreamId>,
    /// In-memory cache of compiled compute pipelines, keyed by kernel id
    /// (the id includes the execution mode).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    /// Multi-stream scheduler that batches and dispatches GPU work.
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    /// On-disk cache of compiled SPIR-V kernels; `None` when no cache root is
    /// configured in the runtime config.
    #[cfg(feature = "spirv")]
    pub(crate) spirv_cache:
        Option<CompilationCache<(u64, StableHash), cubecl_spirv::SpirvCacheEntry>>,
    /// Options forwarded to the kernel compiler.
    pub compilation_options: WgpuCompilationOptions,
    /// The active wgpu backend (Vulkan, Metal, ...); used to pick a compiler.
    pub(crate) backend: wgpu::Backend,
    /// Shared server utilities (logger, device properties, ...).
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}
56
impl ServerCommunication for WgpuServer {
    // Direct server-to-server communication is not supported by this backend.
    const SERVER_COMM_ENABLED: bool = false;
}
60
impl WgpuServer {
    /// Create a new server.
    ///
    /// Builds the scheduled wgpu backend, wires it into a multi-stream
    /// scheduler configured from the global runtime config, and (with the
    /// `spirv` feature) opens the on-disk SPIR-V compilation cache when a
    /// cache root is configured.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        memory_properties: MemoryDeviceProperties,
        memory_config: MemoryConfiguration,
        compilation_options: WgpuCompilationOptions,
        device: wgpu::Device,
        queue: wgpu::Queue,
        tasks_max: usize,
        backend: wgpu::Backend,
        timing_method: TimingMethod,
        utilities: ServerUtilities<Self>,
    ) -> Self {
        let backend_scheduler = ScheduledWgpuBackend::new(
            device.clone(),
            queue.clone(),
            memory_properties,
            memory_config,
            timing_method,
            tasks_max,
            utilities.logger.clone(),
        );

        let config = CubeClRuntimeConfig::get();
        let max_streams = config.streaming.max_streams;

        Self {
            compilation_options,
            streams_pool: Vec::new(),
            device,
            pipelines: HashMap::new(),
            scheduler: SchedulerMultiStream::new(
                utilities.logger.clone(),
                backend_scheduler,
                SchedulerMultiStreamOptions {
                    max_streams,
                    max_tasks: tasks_max,
                    strategy: SchedulerStrategy::Interleave,
                },
            ),
            #[cfg(feature = "spirv")]
            spirv_cache: {
                // Only cache compiled SPIR-V on disk when the runtime config
                // provides a cache root; otherwise kernels are recompiled
                // every run.
                let config = cubecl_runtime::config::CubeClRuntimeConfig::get();
                if let Some(cache) = &config.compilation.cache {
                    let root = cache.root();
                    Some(CompilationCache::new(
                        "spirv",
                        CacheOption::default().name("vulkan").root(root),
                    ))
                } else {
                    None
                }
            },
            backend,
            utilities: Arc::new(utilities),
        }
    }

    /// Resolve every buffer binding into its backing GPU resource.
    ///
    /// Returns an [`IoError`] if any handle can no longer be resolved by the
    /// memory manager of the stream it belongs to.
    fn prepare_bindings(&mut self, bindings: KernelArguments) -> Result<BindingsResource, IoError> {
        // Store all the resources we'll be using. This could be eliminated if
        // there was a way to tie the lifetime of the resource to the memory handle.
        let mut resources = Vec::with_capacity(bindings.buffers.len());

        for b in bindings.buffers.into_iter() {
            let stream = self.scheduler.stream(&b.stream);
            let resource = stream.mem_manage.get_resource(b)?;
            resources.push(resource);
        }

        Ok(BindingsResource {
            resources,
            info: bindings.info,
        })
    }

    /// Get or build the compute pipeline for `kernel` under `mode`.
    ///
    /// Lookup order: the in-memory pipeline map, then the cached-pipeline
    /// path (`load_cached_pipeline`), and finally a fresh compile. Freshly
    /// compiled SPIR-V kernels are written back to the on-disk cache.
    fn pipeline(
        &mut self,
        kernel: <Self as ComputeServer>::Kernel,
        bindings: &KernelArguments,
        mode: ExecutionMode,
    ) -> Result<Arc<ComputePipeline>, LaunchError> {
        // The execution mode is folded into the id: checked and unchecked
        // compilations of the same kernel yield distinct pipelines.
        let mut kernel_id = kernel.id();
        kernel_id.mode(mode);

        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
            return Ok(pipeline.clone());
        }

        // `Some(Ok(_))` is a cache hit; `Some(Err(key))` carries the cache
        // key under which to store the kernel after compiling it below.
        let cached = self.load_cached_pipeline(&kernel_id, bindings, mode)?;

        if let Some(Ok(pipeline)) = cached {
            self.pipelines.insert(kernel_id, pipeline.clone());
            return Ok(pipeline);
        }

        validate_cube_dim(&self.utilities.properties, &kernel_id)?;
        validate_units(&self.utilities.properties, &kernel_id)?;

        let mut compiler = compiler(self.backend, &self.compilation_options);
        let mut compiled = compiler.compile(self, kernel, mode)?;

        // Attach debug info only when compilation logging is active, to avoid
        // the extra kernel-id clone otherwise.
        if self.scheduler.logger.compilation_activated() {
            compiled.debug_info = Some(DebugInformation::new(
                compiler.lang_tag(),
                kernel_id.clone(),
            ));
        }
        self.scheduler.logger.log_compilation(&compiled);

        self.validate_shared(&compiled.repr)?;

        // /!\ Do not delete the following commented code.
        // This is useful while working on the metal compiler.
        // Also the errors are printed nicely which is not the case when this is the runtime
        // that does it.
        // println!("SOURCE:\n{}", compiled.source);
        // {
        //     // Write shader in metal file then compile it for error
        //     std::fs::write("shader.metal", &compiled.source).expect("should write to file");
        //     let _status = std::process::Command::new("xcrun")
        //         .args(vec![
        //             "-sdk",
        //             "macosx",
        //             "metal",
        //             "-o",
        //             "shader.ir",
        //             "-c",
        //             "shader.metal",
        //             "-w",
        //         ])
        //         .status()
        //         .expect("should launch the command");
        //     // std::process::exit(status.code().unwrap());
        // }
        let repr = compiled.repr.as_ref().map(|it| it.as_ref());
        let module = self.create_module(&compiled.entrypoint_name, repr, &compiled.source, mode)?;
        let pipeline = self.create_pipeline(&compiled.entrypoint_name, repr, module, bindings);
        self.pipelines.insert(kernel_id.clone(), pipeline.clone());

        // Persist the freshly compiled SPIR-V under the key returned by the
        // cache miss above. Failure to save is logged but non-fatal.
        // NOTE(review): the `unwrap` assumes `Some(Err(_))` is only returned
        // when `spirv_cache` is `Some` — confirm in `load_cached_pipeline`.
        #[cfg(feature = "spirv")]
        if let Some(Err(key)) = cached
            && let Some(crate::AutoRepresentation::SpirV(kernel)) = compiled.repr
        {
            let cache = self.spirv_cache.as_mut().unwrap();
            let result = cache.insert(
                key,
                cubecl_spirv::SpirvCacheEntry::new(compiled.entrypoint_name, kernel),
            );
            if let Err(err) = result {
                log::warn!("Unable to save the SPIR-V {err:?}");
            }
        }

        Ok(pipeline)
    }

    /// Reject kernels whose declared shared-memory usage exceeds the
    /// device's `max_shared_memory_size`. A missing representation skips the
    /// check.
    fn validate_shared(&self, repr: &Option<crate::AutoRepresentation>) -> Result<(), LaunchError> {
        let shared_bytes = repr.as_ref().map(|repr| match repr {
            AutoRepresentation::Wgsl(repr) => repr.shared_memory_bytes(),
            #[cfg(feature = "msl")]
            AutoRepresentation::Msl(repr) => repr.shared_memory_size(),
            #[cfg(feature = "spirv")]
            AutoRepresentation::SpirV(repr) => repr.shared_size,
        });
        let max_smem = self.utilities.properties.hardware.max_shared_memory_size;
        if let Some(shared_bytes) = shared_bytes
            && shared_bytes > max_smem
        {
            Err(ResourceLimitError::SharedMemory {
                requested: shared_bytes,
                max: max_smem,
                backtrace: BackTrace::capture(),
            }
            .into())
        } else {
            Ok(())
        }
    }
}
241
impl ComputeServer for WgpuServer {
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    type MemoryLayoutPolicy = ContiguousMemoryLayoutPolicy;
    type Info = wgpu::Backend;

    fn logger(&self) -> Arc<ServerLogger> {
        self.scheduler.logger.clone()
    }

    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    /// Staging buffers are not supported on this backend; always returns an
    /// [`IoError::UnsupportedIoOperation`].
    fn staging(
        &mut self,
        _sizes: &[usize],
        _stream_id: StreamId,
    ) -> Result<Vec<Bytes>, ServerError> {
        // TODO: Check if using a staging buffer is useful here.
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        }
        .into())
    }

    /// Allocate `size` bytes on the stream and bind the allocation to the
    /// provided memory handle.
    fn initialize_memory(&mut self, memory: ManagedMemoryHandle, size: u64, stream_id: StreamId) {
        let stream = self.scheduler.stream(&stream_id);
        let reserved = stream.empty(size).unwrap();
        stream.mem_manage.bind(reserved, memory);
    }

    /// Read the described buffers back to the CPU.
    ///
    /// Only contiguous (row-major) strides are supported. Every stream that
    /// owns one of the source handles is executed before the read so the
    /// data is up to date.
    fn read(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, ServerError>> {
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            if contiguous_strides(&desc.shape) != desc.strides {
                return Box::pin(async {
                    Err(IoError::UnsupportedStrides {
                        backtrace: BackTrace::capture(),
                    }
                    .into())
                });
            }
            // Collect each handle's owning stream (deduplicated) so all
            // pending work on it is flushed before reading.
            if !streams.contains(&desc.handle.stream) {
                streams.push(desc.handle.stream);
            }
            let stream = self.scheduler.stream(&desc.handle.stream);
            let resource = match stream.mem_manage.get_resource(desc.handle) {
                Ok(val) => val,
                Err(err) => return Box::pin(async move { Err(err.into()) }),
            };
            resources.push((resource, desc.shape, desc.elem_size));
        }

        self.scheduler.execute_streams(streams);

        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }

    /// Schedule writes of CPU data into the described buffers.
    ///
    /// Only contiguous strides are supported; failures are routed to the
    /// owning stream's error state rather than returned.
    fn write(&mut self, descriptors: Vec<(CopyDescriptor, Bytes)>, stream_id: StreamId) {
        for (desc, data) in descriptors {
            let stream = self.scheduler.stream(&desc.handle.stream);

            if contiguous_strides(&desc.shape) != desc.strides {
                stream.error(ServerError::Io(IoError::UnsupportedStrides {
                    backtrace: BackTrace::capture(),
                }));
                return;
            }

            let resource = match stream.mem_manage.get_resource(desc.handle) {
                Ok(r) => r,
                Err(err) => {
                    stream.error(ServerError::Io(err));
                    return;
                }
            };
            let task = ScheduleTask::Write {
                data,
                buffer: resource,
            };

            self.scheduler.register(stream_id, task, &[]);
        }
    }

    /// Resolve a binding to its underlying wgpu resource, flushing the
    /// binding's owning stream (and the caller's, if different) first.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> Result<ManagedResource<WgpuResource>, ServerError> {
        let mut streams = vec![stream_id];
        if binding.stream != stream_id {
            streams.push(binding.stream);
        }
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&binding.stream);
        // Keep the memory handle alive alongside the resolved resource.
        let memory = binding.memory.clone();
        let resource = stream.mem_manage.get_resource(binding)?;

        Ok(ManagedResource::new(memory, resource))
    }

    /// Schedule a kernel launch on `stream_id`.
    ///
    /// Compilation or binding failures are not returned: they are pushed
    /// onto the executing stream's error state and the launch is dropped.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        args: KernelArguments,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let pipeline = match self.pipeline(kernel, &args, mode) {
            Ok(val) => val,
            Err(err) => {
                // We make the stream that would execute the kernel in error.
                let stream = self.scheduler.stream(&stream_id);
                stream.errors.push(ServerError::Launch(err));
                return;
            }
        };

        // Reuse the pooled buffer to collect the streams every buffer
        // argument lives on, avoiding an allocation per launch.
        self.streams_pool.clear();
        args.buffers
            .iter()
            .for_each(|b| self.streams_pool.push(b.stream));

        let resources = match self.prepare_bindings(args) {
            Ok(val) => val,
            Err(err) => {
                // We make the stream that would execute the kernel in error.
                let stream = self.scheduler.stream(&stream_id);
                stream.errors.push(ServerError::Io(err));
                return;
            }
        };
        let task = ScheduleTask::Execute {
            pipeline,
            count,
            resources,
        };

        self.scheduler.register(stream_id, task, &self.streams_pool);
    }

    /// Execute all pending work on the stream and flush it, surfacing any
    /// accumulated stream errors.
    fn flush(&mut self, stream_id: StreamId) -> Result<(), ServerError> {
        self.scheduler.execute_streams(vec![stream_id]);

        let stream = self.scheduler.stream(&stream_id);

        stream.flush(StreamErrorMode {
            ignore: false,
            flush: true,
        })
    }

    /// Returns the total time of GPU work this sync completes.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ServerError>> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);

        stream.sync()
    }

    /// Begin a profiling region on the stream after draining pending work.
    fn start_profile(&mut self, stream_id: StreamId) -> Result<ProfilingToken, ServerError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.start_profile()
    }

    /// Close the profiling region identified by `token` and return its
    /// measured duration.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);

        stream.end_profile(token)
    }

    /// Report the stream's memory usage after draining pending work.
    fn memory_usage(&mut self, stream_id: StreamId) -> Result<MemoryUsage, ServerError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        Ok(stream.mem_manage.memory_usage())
    }

    /// Force a memory cleanup pass on the stream's memory manager.
    fn memory_cleanup(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_cleanup(true);
    }

    /// Switch the stream's memory allocation mode.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.mode(mode);
    }
}
446
/// Pick the compiler matching the active wgpu backend and the enabled cargo
/// features; WGSL is the portable fallback for every other combination.
fn compiler(backend: wgpu::Backend, options: &WgpuCompilationOptions) -> AutoCompiler {
    let _ = options; // Unused without `spirv` feature
    match backend {
        // SPIR-V is only selected on Vulkan when the device reports support.
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan if options.supports_vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}
457
458pub(crate) fn contiguous_strides(shape: &Shape) -> Strides {
459    let rank = shape.len();
460    let mut strides = strides![1; rank];
461    for i in (0..rank - 1).rev() {
462        strides[i] = strides[i + 1] * shape[i + 1];
463    }
464    strides
465}