// cubecl_wgpu/compute/server.rs

1use super::storage::{WgpuResource, WgpuStorage};
2use crate::AutoCompiler;
3use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
4use alloc::sync::Arc;
5use cubecl_common::backtrace::BackTrace;
6use cubecl_common::bytes::Bytes;
7use cubecl_common::profile::{ProfileDuration, TimingMethod};
8use cubecl_common::stream_id::StreamId;
9use cubecl_core::future::DynFut;
10use cubecl_core::server::{Allocation, AllocationDescriptor, ExecutionError, IoError, LaunchError};
11use cubecl_core::server::{ProfileError, ProfilingToken, ServerCommunication, ServerUtilities};
12use cubecl_core::{
13    MemoryConfiguration, WgpuCompilationOptions,
14    prelude::*,
15    server::{Binding, Bindings, CopyDescriptor},
16};
17use cubecl_runtime::compiler::CompilationError;
18use cubecl_runtime::logging::ServerLogger;
19use cubecl_runtime::memory_management::{MemoryAllocationMode, offset_handles};
20use cubecl_runtime::stream::scheduler::{
21    SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy,
22};
23use cubecl_runtime::{compiler::CubeTask, config::GlobalConfig};
24use cubecl_runtime::{
25    memory_management::MemoryDeviceProperties, server::ComputeServer, storage::BindingResource,
26};
27use hashbrown::HashMap;
28use wgpu::ComputePipeline;
29
/// Wgpu compute server.
///
/// Owns the wgpu device handle, a cache of compiled compute pipelines, and a
/// multi-stream scheduler that orders and executes GPU work.
#[derive(Debug)]
pub struct WgpuServer {
    /// Raw wgpu device; used for limits queries (e.g. buffer offset alignment)
    /// and pipeline creation.
    pub(crate) device: wgpu::Device,
    /// Cache of compiled pipelines, keyed by kernel id (which includes the
    /// execution mode — see `Self::pipeline`).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    /// Multi-stream scheduler backing all task registration and execution.
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    /// Options forwarded to the kernel compiler.
    pub compilation_options: WgpuCompilationOptions,
    /// Backend (Vulkan, Metal, ...) the device runs on; selects the compiler
    /// in `compiler()`.
    pub(crate) backend: wgpu::Backend,
    /// Shared server utilities (logger, ...).
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}
40
// Direct server-to-server communication is disabled for the wgpu server.
impl ServerCommunication for WgpuServer {
    const SERVER_COMM_ENABLED: bool = false;
}
44
45impl WgpuServer {
46    /// Create a new server.
47    #[allow(clippy::too_many_arguments)]
48    pub fn new(
49        memory_properties: MemoryDeviceProperties,
50        memory_config: MemoryConfiguration,
51        compilation_options: WgpuCompilationOptions,
52        device: wgpu::Device,
53        queue: wgpu::Queue,
54        tasks_max: usize,
55        backend: wgpu::Backend,
56        timing_method: TimingMethod,
57        utilities: ServerUtilities<Self>,
58    ) -> Self {
59        let backend_scheduler = ScheduledWgpuBackend::new(
60            device.clone(),
61            queue.clone(),
62            memory_properties,
63            memory_config,
64            timing_method,
65            tasks_max,
66            utilities.logger.clone(),
67        );
68
69        let config = GlobalConfig::get();
70        let max_streams = config.streaming.max_streams;
71
72        Self {
73            compilation_options,
74            device,
75            pipelines: HashMap::new(),
76            scheduler: SchedulerMultiStream::new(
77                utilities.logger.clone(),
78                backend_scheduler,
79                SchedulerMultiStreamOptions {
80                    max_streams,
81                    max_tasks: tasks_max,
82                    strategy: SchedulerStrategy::Interleave,
83                },
84            ),
85            backend,
86            utilities: Arc::new(utilities),
87        }
88    }
89
90    fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
91        // Store all the resources we'll be using. This could be eliminated if
92        // there was a way to tie the lifetime of the resource to the memory handle.
93        let resources = bindings
94            .buffers
95            .iter()
96            .map(|b| {
97                let stream = self.scheduler.stream(&b.stream);
98                stream.mem_manage.get_resource(b.clone()).unwrap()
99            })
100            .collect::<Vec<_>>();
101
102        BindingsResource {
103            resources,
104            metadata: bindings.metadata,
105            scalars: bindings.scalars,
106        }
107    }
108
109    fn pipeline(
110        &mut self,
111        kernel: <Self as ComputeServer>::Kernel,
112        mode: ExecutionMode,
113    ) -> Result<Arc<ComputePipeline>, CompilationError> {
114        let mut kernel_id = kernel.id();
115        kernel_id.mode(mode);
116
117        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
118            return Ok(pipeline.clone());
119        }
120
121        let mut compiler = compiler(self.backend);
122        let mut compile = compiler.compile(self, kernel, mode)?;
123
124        if self.scheduler.logger.compilation_activated() {
125            compile.debug_info = Some(DebugInformation::new(
126                compiler.lang_tag(),
127                kernel_id.clone(),
128            ));
129        }
130        self.scheduler.logger.log_compilation(&compile);
131        // /!\ Do not delete the following commented code.
132        // This is useful while working on the metal compiler.
133        // Also the errors are printed nicely which is not the case when this is the runtime
134        // that does it.
135        // println!("SOURCE:\n{}", compile.source);
136        // {
137        //     // Write shader in metal file then compile it for error
138        //     std::fs::write("shader.metal", &compile.source).expect("should write to file");
139        //     let _status = std::process::Command::new("xcrun")
140        //         .args(vec![
141        //             "-sdk",
142        //             "macosx",
143        //             "metal",
144        //             "-o",
145        //             "shader.ir",
146        //             "-c",
147        //             "shader.metal",
148        //         ])
149        //         .status()
150        //         .expect("should launch the command");
151        //     // std::process::exit(status.code().unwrap());
152        // }
153        let pipeline = self.create_pipeline(compile, mode)?;
154        self.pipelines.insert(kernel_id.clone(), pipeline.clone());
155
156        Ok(pipeline)
157    }
158}
159
impl ComputeServer for WgpuServer {
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    type Info = wgpu::Backend;

    /// Shared logger owned by the scheduler.
    fn logger(&self) -> Arc<ServerLogger> {
        self.scheduler.logger.clone()
    }

    /// Shared server utilities handle.
    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    /// Staging buffers are not supported: always returns
    /// `IoError::UnsupportedIoOperation`.
    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        // TODO: Check if using a staging buffer is useful here.
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        })
    }

    /// Allocate one backing buffer large enough for all descriptors and slice
    /// it into per-descriptor handles.
    ///
    /// Each sub-allocation is aligned to the device's
    /// `min_storage_buffer_offset_alignment`; strides are always contiguous
    /// (row-major) for the requested shapes.
    fn create(
        &mut self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<Allocation>, IoError> {
        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
        let strides = descriptors
            .iter()
            .map(|desc| contiguous_strides(desc.shape))
            .collect::<Vec<_>>();
        let sizes = descriptors
            .iter()
            .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
            .collect::<Vec<_>>();
        // Total size accounts for the padding needed to keep every
        // sub-allocation offset aligned.
        let total_size = sizes
            .iter()
            .map(|it| it.next_multiple_of(align))
            .sum::<usize>();

        let stream = self.scheduler.stream(&stream_id);
        let mem_handle = stream.empty(total_size as u64, stream_id)?;
        let handles = offset_handles(mem_handle, &sizes, align);

        Ok(handles
            .into_iter()
            .zip(strides)
            .map(|(handle, strides)| Allocation::new(handle, strides))
            .collect())
    }

    /// Read the given regions back to the host.
    ///
    /// Only contiguous layouts are supported. All streams owning one of the
    /// bindings are executed first so the data is up to date before the
    /// asynchronous read is issued on `stream_id`'s stream.
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            // Non-contiguous reads are rejected up front.
            if contiguous_strides(desc.shape) != desc.strides {
                return Box::pin(async {
                    Err(IoError::UnsupportedStrides {
                        backtrace: BackTrace::capture(),
                    })
                });
            }
            // Track every stream involved so they can all be flushed below.
            if !streams.contains(&desc.binding.stream) {
                streams.push(desc.binding.stream);
            }
            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = match stream.mem_manage.get_resource(desc.binding) {
                Ok(val) => val,
                Err(err) => return Box::pin(async move { Err(err) }),
            };
            resources.push((resource, desc.shape.to_vec(), desc.elem_size));
        }

        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }

    /// Write host data into the given regions by registering write tasks on
    /// `stream_id`'s stream.
    ///
    /// Only contiguous layouts are supported.
    fn write(
        &mut self,
        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
        stream_id: StreamId,
    ) -> Result<(), IoError> {
        for (desc, data) in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Err(IoError::UnsupportedStrides {
                    backtrace: BackTrace::capture(),
                });
            }

            let stream = self.scheduler.stream(&desc.binding.stream);
            // NOTE(review): the clone looks avoidable since `desc.binding` is
            // not used afterwards — confirm `CopyDescriptor` allows the move.
            let resource = stream.mem_manage.get_resource(desc.binding.clone())?;
            let task = ScheduleTask::Write {
                data,
                buffer: resource,
            };

            self.scheduler.register(stream_id, task, [].into_iter());
        }

        Ok(())
    }

    /// Resolve a binding to its GPU resource, executing the involved streams
    /// first so pending work on the binding has been submitted.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> BindingResource<WgpuResource> {
        let mut streams = vec![stream_id];
        if binding.stream != stream_id {
            streams.push(binding.stream);
        }
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&binding.stream);
        // A binding handed to this server is expected to resolve; failure is
        // an invariant violation.
        let resource = stream.mem_manage.get_resource(binding.clone()).unwrap();
        BindingResource::new(binding, resource)
    }

    /// Compile (or fetch from cache) the kernel's pipeline and register an
    /// execute task on `stream_id`'s stream.
    ///
    /// # Safety
    /// Callers must uphold the safety contract of `ComputeServer::launch`;
    /// with `ExecutionMode::Unchecked` the kernel is presumably compiled
    /// without bounds checks — confirm against the trait documentation.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let pipeline = self.pipeline(kernel, mode)?;
        // Keep the buffer bindings: `prepare_bindings` consumes `bindings`,
        // but the scheduler needs them to track dependencies on registration.
        let buffers = bindings.buffers.clone();
        let resources = self.prepare_bindings(bindings);
        let task = ScheduleTask::Execute {
            pipeline,
            count,
            resources,
        };

        self.scheduler.register(stream_id, task, buffers.iter());

        Ok(())
    }

    /// Execute pending tasks on the stream and flush it.
    fn flush(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.flush()
    }

    /// Returns the total time of GPU work this sync completes.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.sync()
    }

    /// Start profiling on the stream after submitting pending work.
    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.start_profile()
    }

    /// Finish the profiling session identified by `token`.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.end_profile(token)
    }

    /// Report current memory usage for the stream's memory manager.
    fn memory_usage(
        &mut self,
        stream_id: StreamId,
    ) -> cubecl_runtime::memory_management::MemoryUsage {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_usage()
    }

    /// Release unused memory on the stream.
    fn memory_cleanup(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        // NOTE(review): `true` presumably forces a full cleanup — confirm
        // against `mem_manage.memory_cleanup`'s documentation.
        stream.mem_manage.memory_cleanup(true);
    }

    /// Switch the stream's memory allocation mode.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.mode(mode);
    }
}
353
/// Select the compiler matching the wgpu backend.
///
/// Vulkan uses SPIR-V and Metal uses MSL when the respective features are
/// enabled; everything else falls back to WGSL.
fn compiler(backend: wgpu::Backend) -> AutoCompiler {
    match backend {
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}
363
/// Compute row-major (contiguous) strides for `shape`.
///
/// The innermost dimension has stride 1; each outer dimension's stride is the
/// product of all inner dimension extents. An empty shape yields an empty
/// stride vector (the original `0..rank - 1` underflowed on `rank == 0`).
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` keeps the range empty for rank 0 instead of
    // underflowing `usize`.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}