// cubecl_wgpu/compute/server.rs

1use super::storage::{WgpuResource, WgpuStorage};
2use crate::AutoCompiler;
3use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
4use alloc::sync::Arc;
5use cubecl_common::bytes::Bytes;
6use cubecl_common::profile::{ProfileDuration, TimingMethod};
7use cubecl_common::stream_id::StreamId;
8use cubecl_core::future::DynFut;
9use cubecl_core::server::{Allocation, AllocationDescriptor, ExecutionError, IoError, LaunchError};
10use cubecl_core::server::{ProfileError, ProfilingToken, ServerCommunication, ServerUtilities};
11use cubecl_core::{
12    MemoryConfiguration, WgpuCompilationOptions,
13    prelude::*,
14    server::{Binding, Bindings, CopyDescriptor},
15};
16use cubecl_runtime::compiler::CompilationError;
17use cubecl_runtime::logging::ServerLogger;
18use cubecl_runtime::memory_management::{MemoryAllocationMode, offset_handles};
19use cubecl_runtime::stream::scheduler::{
20    SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy,
21};
22use cubecl_runtime::{compiler::CubeTask, config::GlobalConfig};
23use cubecl_runtime::{
24    memory_management::MemoryDeviceProperties, server::ComputeServer, storage::BindingResource,
25};
26use hashbrown::HashMap;
27use wgpu::ComputePipeline;
28
/// Wgpu compute server.
#[derive(Debug)]
pub struct WgpuServer {
    /// Handle to the underlying wgpu device; used for querying limits and creating pipelines.
    pub(crate) device: wgpu::Device,
    /// Cache of compiled compute pipelines, keyed by kernel id (the id includes
    /// the execution mode, since it affects codegen).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    /// Multi-stream scheduler that orders and dispatches tasks onto the wgpu backend.
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    /// Options forwarded to the kernel compiler.
    pub compilation_options: WgpuCompilationOptions,
    /// Active wgpu backend (Vulkan, Metal, ...); selects which compiler is used.
    pub(crate) backend: wgpu::Backend,
    /// Shared server utilities (logger, etc.), handed out via `utilities()`.
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}
39
// Direct server-to-server communication is not supported by the wgpu server.
impl ServerCommunication for WgpuServer {
    const SERVER_COMM_ENABLED: bool = false;
}
43
44impl WgpuServer {
45    /// Create a new server.
46    #[allow(clippy::too_many_arguments)]
47    pub fn new(
48        memory_properties: MemoryDeviceProperties,
49        memory_config: MemoryConfiguration,
50        compilation_options: WgpuCompilationOptions,
51        device: wgpu::Device,
52        queue: wgpu::Queue,
53        tasks_max: usize,
54        backend: wgpu::Backend,
55        timing_method: TimingMethod,
56        utilities: ServerUtilities<Self>,
57    ) -> Self {
58        let backend_scheduler = ScheduledWgpuBackend::new(
59            device.clone(),
60            queue.clone(),
61            memory_properties,
62            memory_config,
63            timing_method,
64            tasks_max,
65            utilities.logger.clone(),
66        );
67
68        let config = GlobalConfig::get();
69        let max_streams = config.streaming.max_streams;
70
71        Self {
72            compilation_options,
73            device,
74            pipelines: HashMap::new(),
75            scheduler: SchedulerMultiStream::new(
76                utilities.logger.clone(),
77                backend_scheduler,
78                SchedulerMultiStreamOptions {
79                    max_streams,
80                    max_tasks: tasks_max,
81                    strategy: SchedulerStrategy::Interleave,
82                },
83            ),
84            backend,
85            utilities: Arc::new(utilities),
86        }
87    }
88
89    fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
90        // Store all the resources we'll be using. This could be eliminated if
91        // there was a way to tie the lifetime of the resource to the memory handle.
92        let resources = bindings
93            .buffers
94            .iter()
95            .map(|b| {
96                let stream = self.scheduler.stream(&b.stream);
97                stream.mem_manage.get_resource(b.clone())
98            })
99            .collect::<Vec<_>>();
100
101        BindingsResource {
102            resources,
103            metadata: bindings.metadata,
104            scalars: bindings.scalars,
105        }
106    }
107
108    fn pipeline(
109        &mut self,
110        kernel: <Self as ComputeServer>::Kernel,
111        mode: ExecutionMode,
112    ) -> Result<Arc<ComputePipeline>, CompilationError> {
113        let mut kernel_id = kernel.id();
114        kernel_id.mode(mode);
115
116        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
117            return Ok(pipeline.clone());
118        }
119
120        let mut compiler = compiler(self.backend);
121        let mut compile = compiler.compile(self, kernel, mode)?;
122
123        if self.scheduler.logger.compilation_activated() {
124            compile.debug_info = Some(DebugInformation::new(
125                compiler.lang_tag(),
126                kernel_id.clone(),
127            ));
128        }
129        self.scheduler.logger.log_compilation(&compile);
130        // /!\ Do not delete the following commented code.
131        // This is useful while working on the metal compiler.
132        // Also the errors are printed nicely which is not the case when this is the runtime
133        // that does it.
134        // println!("SOURCE:\n{}", compile.source);
135        // {
136        //     // Write shader in metal file then compile it for error
137        //     std::fs::write("shader.metal", &compile.source).expect("should write to file");
138        //     let _status = std::process::Command::new("xcrun")
139        //         .args(vec![
140        //             "-sdk",
141        //             "macosx",
142        //             "metal",
143        //             "-o",
144        //             "shader.ir",
145        //             "-c",
146        //             "shader.metal",
147        //         ])
148        //         .status()
149        //         .expect("should launch the command");
150        //     // std::process::exit(status.code().unwrap());
151        // }
152        let pipeline = self.create_pipeline(compile, mode)?;
153        self.pipelines.insert(kernel_id.clone(), pipeline.clone());
154
155        Ok(pipeline)
156    }
157}
158
impl ComputeServer for WgpuServer {
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    type Info = wgpu::Backend;

    /// Shared server logger (owned by the scheduler).
    fn logger(&self) -> Arc<ServerLogger> {
        self.scheduler.logger.clone()
    }

    /// Shared server utilities handed out to clients.
    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    /// Staging buffers are not implemented for this server; callers must fall
    /// back to regular writes.
    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        // TODO: Check if using a staging buffer is useful here.
        Err(IoError::UnsupportedIoOperation)
    }

    /// Allocate memory for each descriptor out of a single backing allocation.
    ///
    /// Each sub-allocation is padded to the device's storage-buffer offset
    /// alignment, and contiguous (row-major) strides are derived from each shape.
    fn create(
        &mut self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<Allocation>, IoError> {
        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
        let strides = descriptors
            .iter()
            .map(|desc| contiguous_strides(desc.shape))
            .collect::<Vec<_>>();
        let sizes = descriptors
            .iter()
            .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
            .collect::<Vec<_>>();
        // Total size accounts for per-allocation alignment padding.
        let total_size = sizes
            .iter()
            .map(|it| it.next_multiple_of(align))
            .sum::<usize>();

        let stream = self.scheduler.stream(&stream_id);
        let mem_handle = stream.empty(total_size as u64, stream_id)?;
        // Split the single handle into one offset handle per descriptor.
        let handles = offset_handles(mem_handle, &sizes, align);

        Ok(handles
            .into_iter()
            .zip(strides)
            .map(|(handle, strides)| Allocation::new(handle, strides))
            .collect())
    }

    /// Read the given bindings back to the host.
    ///
    /// Only contiguous (row-major) strides are supported. All streams that own
    /// one of the bindings are executed before the read is issued on the
    /// calling stream, so pending writes are visible.
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
        // Collect every stream involved (calling stream + each binding's stream).
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Box::pin(async { Err(IoError::UnsupportedStrides) });
            }
            if !streams.contains(&desc.binding.stream) {
                streams.push(desc.binding.stream);
            }
            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding);
            resources.push((resource, desc.shape.to_vec(), desc.elem_size));
        }

        // Flush pending work on all involved streams before reading.
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }

    /// Write host data into device buffers by registering write tasks on the
    /// calling stream. Only contiguous (row-major) strides are supported.
    fn write(
        &mut self,
        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
        stream_id: StreamId,
    ) -> Result<(), IoError> {
        for (desc, data) in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Err(IoError::UnsupportedStrides);
            }

            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding.clone());
            let task = ScheduleTask::Write {
                data,
                buffer: resource,
            };

            self.scheduler.register(stream_id, task, [].into_iter());
        }

        Ok(())
    }

    /// Resolve a binding to its underlying wgpu resource, executing the
    /// involved streams first so the resource reflects pending work.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> BindingResource<WgpuResource> {
        let mut streams = vec![stream_id];
        if binding.stream != stream_id {
            streams.push(binding.stream);
        }
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&binding.stream);
        let resource = stream.mem_manage.get_resource(binding.clone());
        BindingResource::new(binding, resource)
    }

    /// Compile (or fetch from cache) the kernel and register an execute task on
    /// the calling stream.
    ///
    /// # Safety
    /// With `ExecutionMode::Unchecked`, out-of-bounds accesses in the kernel
    /// are not guarded; the caller must guarantee the launch is valid.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let pipeline = self.pipeline(kernel, mode)?;
        // Keep the buffer list: the scheduler needs it to track dependencies.
        let buffers = bindings.buffers.clone();
        let resources = self.prepare_bindings(bindings);
        let task = ScheduleTask::Execute {
            pipeline,
            count,
            resources,
        };

        self.scheduler.register(stream_id, task, buffers.iter());

        Ok(())
    }

    /// Execute all pending tasks on the stream and flush it to the GPU queue.
    fn flush(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.flush()
    }

    /// Returns the total time of GPU work this sync completes.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.sync()
    }

    /// Begin a profiling window on the stream after draining pending tasks.
    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.start_profile()
    }

    /// Close the profiling window identified by `token` and return its duration.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.end_profile(token)
    }

    /// Report memory usage for the stream's memory manager.
    fn memory_usage(
        &mut self,
        stream_id: StreamId,
    ) -> cubecl_runtime::memory_management::MemoryUsage {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_usage()
    }

    /// Release cached memory on the stream.
    // NOTE(review): the `true` flag presumably forces an explicit/aggressive
    // cleanup — confirm against the memory manager's API.
    fn memory_cleanup(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_cleanup(true);
    }

    /// Switch the stream's memory allocation mode (e.g. for static allocation).
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.mode(mode);
    }
}
341
/// Select the kernel compiler matching the wgpu backend.
///
/// SPIR-V (Vulkan) and MSL (Metal) are only available when the corresponding
/// cargo feature is enabled; every other case falls back to WGSL.
fn compiler(backend: wgpu::Backend) -> AutoCompiler {
    match backend {
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}
351
/// Compute row-major (contiguous) strides for `shape`.
///
/// The innermost dimension has stride 1 and each preceding dimension's stride
/// is the product of all later dimension sizes. Returns an empty vector for an
/// empty (rank-0) shape.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` guards the rank-0 case: `0..rank - 1` would underflow
    // (panic in debug builds, wrap to a huge range in release).
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}