//! WGPU compute server (`cubecl_wgpu/compute/server.rs`).

1use super::storage::{WgpuResource, WgpuStorage};
2use crate::AutoCompiler;
3use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
4use alloc::sync::Arc;
5use cubecl_common::{
6    backtrace::BackTrace,
7    bytes::Bytes,
8    profile::{ProfileDuration, TimingMethod},
9    stream_id::StreamId,
10};
11use cubecl_core::{
12    MemoryConfiguration, WgpuCompilationOptions,
13    future::DynFut,
14    prelude::*,
15    server::{
16        Allocation, AllocationDescriptor, Binding, Bindings, CopyDescriptor, ExecutionError,
17        IoError, LaunchError, ProfileError, ProfilingToken, ServerCommunication, ServerUtilities,
18    },
19};
20use cubecl_ir::MemoryDeviceProperties;
21use cubecl_runtime::{
22    compiler::{CompilationError, CubeTask},
23    config::GlobalConfig,
24    logging::ServerLogger,
25    memory_management::{MemoryAllocationMode, offset_handles},
26    server::ComputeServer,
27    storage::BindingResource,
28    stream::scheduler::{SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy},
29};
30use hashbrown::HashMap;
31use wgpu::ComputePipeline;
32
/// Wgpu compute server.
#[derive(Debug)]
pub struct WgpuServer {
    // Raw wgpu device handle; also used to query hardware limits (e.g. buffer
    // offset alignment in `create`).
    pub(crate) device: wgpu::Device,
    // Cache of compiled compute pipelines, keyed by kernel id (which includes
    // the execution mode — see `pipeline`). Avoids recompiling shaders.
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    // Multi-stream scheduler that owns the per-stream state (memory manager,
    // task queues) and the shared `ServerLogger`.
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    /// Options passed to the kernel compiler.
    pub compilation_options: WgpuCompilationOptions,
    // Active wgpu backend (Vulkan, Metal, ...); selects the compiler flavor.
    pub(crate) backend: wgpu::Backend,
    // Shared server utilities (logger, etc.), handed out via `utilities()`.
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}
43
impl ServerCommunication for WgpuServer {
    // Direct server-to-server communication is not supported by this backend.
    const SERVER_COMM_ENABLED: bool = false;
}
47
48impl WgpuServer {
49    /// Create a new server.
50    #[allow(clippy::too_many_arguments)]
51    pub fn new(
52        memory_properties: MemoryDeviceProperties,
53        memory_config: MemoryConfiguration,
54        compilation_options: WgpuCompilationOptions,
55        device: wgpu::Device,
56        queue: wgpu::Queue,
57        tasks_max: usize,
58        backend: wgpu::Backend,
59        timing_method: TimingMethod,
60        utilities: ServerUtilities<Self>,
61    ) -> Self {
62        let backend_scheduler = ScheduledWgpuBackend::new(
63            device.clone(),
64            queue.clone(),
65            memory_properties,
66            memory_config,
67            timing_method,
68            tasks_max,
69            utilities.logger.clone(),
70        );
71
72        let config = GlobalConfig::get();
73        let max_streams = config.streaming.max_streams;
74
75        Self {
76            compilation_options,
77            device,
78            pipelines: HashMap::new(),
79            scheduler: SchedulerMultiStream::new(
80                utilities.logger.clone(),
81                backend_scheduler,
82                SchedulerMultiStreamOptions {
83                    max_streams,
84                    max_tasks: tasks_max,
85                    strategy: SchedulerStrategy::Interleave,
86                },
87            ),
88            backend,
89            utilities: Arc::new(utilities),
90        }
91    }
92
93    fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
94        // Store all the resources we'll be using. This could be eliminated if
95        // there was a way to tie the lifetime of the resource to the memory handle.
96        let resources = bindings
97            .buffers
98            .iter()
99            .map(|b| {
100                let stream = self.scheduler.stream(&b.stream);
101                stream.mem_manage.get_resource(b.clone()).unwrap()
102            })
103            .collect::<Vec<_>>();
104
105        BindingsResource {
106            resources,
107            metadata: bindings.metadata,
108            scalars: bindings.scalars,
109        }
110    }
111
112    fn pipeline(
113        &mut self,
114        kernel: <Self as ComputeServer>::Kernel,
115        mode: ExecutionMode,
116    ) -> Result<Arc<ComputePipeline>, CompilationError> {
117        let mut kernel_id = kernel.id();
118        kernel_id.mode(mode);
119
120        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
121            return Ok(pipeline.clone());
122        }
123
124        let mut compiler = compiler(self.backend);
125        let mut compile = compiler.compile(self, kernel, mode)?;
126
127        if self.scheduler.logger.compilation_activated() {
128            compile.debug_info = Some(DebugInformation::new(
129                compiler.lang_tag(),
130                kernel_id.clone(),
131            ));
132        }
133        self.scheduler.logger.log_compilation(&compile);
134        // /!\ Do not delete the following commented code.
135        // This is useful while working on the metal compiler.
136        // Also the errors are printed nicely which is not the case when this is the runtime
137        // that does it.
138        // println!("SOURCE:\n{}", compile.source);
139        // {
140        //     // Write shader in metal file then compile it for error
141        //     std::fs::write("shader.metal", &compile.source).expect("should write to file");
142        //     let _status = std::process::Command::new("xcrun")
143        //         .args(vec![
144        //             "-sdk",
145        //             "macosx",
146        //             "metal",
147        //             "-o",
148        //             "shader.ir",
149        //             "-c",
150        //             "shader.metal",
151        //         ])
152        //         .status()
153        //         .expect("should launch the command");
154        //     // std::process::exit(status.code().unwrap());
155        // }
156        let pipeline = self.create_pipeline(compile, mode)?;
157        self.pipelines.insert(kernel_id.clone(), pipeline.clone());
158
159        Ok(pipeline)
160    }
161}
162
impl ComputeServer for WgpuServer {
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    type Info = wgpu::Backend;

    /// Shared server logger (owned by the multi-stream scheduler).
    fn logger(&self) -> Arc<ServerLogger> {
        self.scheduler.logger.clone()
    }

    /// Shared server utilities handed to this server at construction.
    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    /// Staging buffers are not supported on this backend; always errors.
    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        // TODO: Check if using a staging buffer is useful here.
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        })
    }

    /// Allocates all descriptors in a single backing buffer and slices it into
    /// per-descriptor handles, each aligned to the device's minimum storage
    /// buffer offset alignment. Strides are always contiguous (row-major).
    fn create(
        &mut self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<Allocation>, IoError> {
        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
        let strides = descriptors
            .iter()
            .map(|desc| contiguous_strides(desc.shape))
            .collect::<Vec<_>>();
        // Byte size of each allocation: product of shape dims * element size.
        let sizes = descriptors
            .iter()
            .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
            .collect::<Vec<_>>();
        // Each sub-allocation start must be aligned, so round every size up
        // before summing into the single backing allocation.
        let total_size = sizes
            .iter()
            .map(|it| it.next_multiple_of(align))
            .sum::<usize>();

        let stream = self.scheduler.stream(&stream_id);
        let mem_handle = stream.empty(total_size as u64, stream_id)?;
        let handles = offset_handles(mem_handle, &sizes, align);

        Ok(handles
            .into_iter()
            .zip(strides)
            .map(|(handle, strides)| Allocation::new(handle, strides))
            .collect())
    }

    /// Reads the given bindings back to the host.
    ///
    /// Rejects non-contiguous strides. All streams touched by any binding
    /// (plus the calling stream) are executed before the read is issued, so
    /// pending writes to the source buffers are flushed first.
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
        // Collect every distinct stream involved so they can all be flushed.
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Box::pin(async {
                    Err(IoError::UnsupportedStrides {
                        backtrace: BackTrace::capture(),
                    })
                });
            }
            if !streams.contains(&desc.binding.stream) {
                streams.push(desc.binding.stream);
            }
            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = match stream.mem_manage.get_resource(desc.binding) {
                Ok(val) => val,
                Err(err) => return Box::pin(async move { Err(err) }),
            };
            resources.push((resource, desc.shape.to_vec(), desc.elem_size));
        }

        // Flush pending work on every involved stream before reading.
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }

    /// Schedules host-to-device writes on the calling stream.
    ///
    /// Rejects non-contiguous strides. The write is registered as a task
    /// rather than performed immediately.
    fn write(
        &mut self,
        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
        stream_id: StreamId,
    ) -> Result<(), IoError> {
        for (desc, data) in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Err(IoError::UnsupportedStrides {
                    backtrace: BackTrace::capture(),
                });
            }

            // The destination resource may live on a different stream than
            // the one the task is registered on.
            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding.clone())?;
            let task = ScheduleTask::Write {
                data,
                buffer: resource,
            };

            // NOTE(review): the task is registered with no binding
            // dependencies ([]), unlike `launch` which passes its buffers —
            // presumably ordering is guaranteed elsewhere; verify.
            self.scheduler.register(stream_id, task, [].into_iter());
        }

        Ok(())
    }

    /// Resolves a binding to its wgpu resource, flushing the calling stream
    /// and the binding's owning stream first so the resource is up to date.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> BindingResource<WgpuResource> {
        let mut streams = vec![stream_id];
        if binding.stream != stream_id {
            streams.push(binding.stream);
        }
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&binding.stream);
        let resource = stream.mem_manage.get_resource(binding.clone()).unwrap();
        BindingResource::new(binding, resource)
    }

    /// Compiles (or fetches the cached) pipeline for `kernel` and registers an
    /// execute task on the calling stream, with the bound buffers declared as
    /// dependencies for scheduling.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let pipeline = self.pipeline(kernel, mode)?;
        // Clone the buffer list before `bindings` is consumed below; it is
        // still needed as the task's dependency set.
        let buffers = bindings.buffers.clone();
        let resources = self.prepare_bindings(bindings);
        let task = ScheduleTask::Execute {
            pipeline,
            count,
            resources,
        };

        self.scheduler.register(stream_id, task, buffers.iter());

        Ok(())
    }

    /// Executes all pending tasks on the stream, then flushes it to the GPU.
    fn flush(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.flush()
    }

    /// Returns the total time of GPU work this sync completes.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.sync()
    }

    /// Starts a profiling window on the stream after flushing pending tasks.
    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.start_profile()
    }

    /// Ends the profiling window identified by `token` after flushing pending
    /// tasks, returning the measured duration.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.end_profile(token)
    }

    /// Reports the stream's memory usage after flushing pending tasks.
    fn memory_usage(
        &mut self,
        stream_id: StreamId,
    ) -> cubecl_runtime::memory_management::MemoryUsage {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_usage()
    }

    /// Forces a memory cleanup pass on the stream (the `true` flag requests an
    /// explicit/full cleanup) after flushing pending tasks.
    fn memory_cleanup(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_cleanup(true);
    }

    /// Switches the stream's allocator mode after flushing pending tasks.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.mode(mode);
    }
}
356
/// Selects the kernel compiler flavor for the active wgpu backend.
///
/// SPIR-V (Vulkan) and MSL (Metal) are only available when their cargo
/// features are enabled; everything else — including those backends without
/// the feature — falls back to WGSL.
fn compiler(backend: wgpu::Backend) -> AutoCompiler {
    match backend {
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}
366
/// Computes the row-major (contiguous) strides for `shape`.
///
/// The last dimension has stride 1; each earlier dimension's stride is the
/// product of all dimension sizes after it. An empty shape yields an empty
/// vector.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` guards `rank == 0`: the original `rank - 1` underflowed
    // on an empty shape (overflow panic in debug, out-of-bounds in release).
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}