// cubecl_wgpu/compute/server.rs

use super::storage::{WgpuResource, WgpuStorage};
use crate::AutoCompiler;
use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
use alloc::sync::Arc;
use cubecl_common::bytes::Bytes;
use cubecl_common::profile::{ProfileDuration, TimingMethod};
use cubecl_common::stream_id::StreamId;
use cubecl_core::future::DynFut;
use cubecl_core::server::{ProfileError, ProfilingToken, ServerCommunication, ServerUtilities};
use cubecl_core::{
    MemoryConfiguration, WgpuCompilationOptions,
    prelude::*,
    server::{Binding, Bindings, CopyDescriptor},
};
use cubecl_core::{
    compute::{CubeTask, DebugInformation},
    server::{Allocation, AllocationDescriptor, IoError},
};
use cubecl_runtime::config::GlobalConfig;
use cubecl_runtime::logging::ServerLogger;
use cubecl_runtime::memory_management::{MemoryAllocationMode, offset_handles};
use cubecl_runtime::stream::scheduler::{
    SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy,
};
use cubecl_runtime::{
    memory_management::MemoryDeviceProperties, server::ComputeServer, storage::BindingResource,
};
use hashbrown::HashMap;
use wgpu::ComputePipeline;

/// Wgpu compute server.
#[derive(Debug)]
pub struct WgpuServer {
    pub(crate) device: wgpu::Device,
    /// Compiled compute pipelines, cached by kernel id (which includes the execution mode).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    /// Multi-stream scheduler holding the per-stream state (memory management, queue, profiling).
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    pub compilation_options: WgpuCompilationOptions,
    pub(crate) backend: wgpu::Backend,
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}

// Direct server-to-server communication is not supported by the wgpu server.
impl ServerCommunication for WgpuServer {
    const SERVER_COMM_ENABLED: bool = false;
}

impl WgpuServer {
    /// Create a new server.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        memory_properties: MemoryDeviceProperties,
        memory_config: MemoryConfiguration,
        compilation_options: WgpuCompilationOptions,
        device: wgpu::Device,
        queue: wgpu::Queue,
        tasks_max: usize,
        backend: wgpu::Backend,
        timing_method: TimingMethod,
        utilities: ServerUtilities<Self>,
    ) -> Self {
        let backend_scheduler = ScheduledWgpuBackend::new(
            device.clone(),
            queue.clone(),
            memory_properties,
            memory_config,
            timing_method,
            tasks_max,
            utilities.logger.clone(),
        );

        let config = GlobalConfig::get();
        let max_streams = config.streaming.max_streams;

        Self {
            compilation_options,
            device,
            pipelines: HashMap::new(),
            scheduler: SchedulerMultiStream::new(
                utilities.logger.clone(),
                backend_scheduler,
                SchedulerMultiStreamOptions {
                    max_streams,
                    max_tasks: tasks_max,
                    strategy: SchedulerStrategy::Interleave,
                },
            ),
            backend,
            utilities: Arc::new(utilities),
        }
    }

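    /// Resolves every buffer binding into its backing [`WgpuResource`], pulling each
    /// resource from the stream that owns the binding, and bundles them with the
    /// metadata and scalar bindings for scheduling.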
    fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
        // Store all the resources we'll be using. This could be eliminated if
        // there was a way to tie the lifetime of the resource to the memory handle.
        let resources = bindings
            .buffers
            .iter()
            .map(|b| {
                let stream = self.scheduler.stream(&b.stream);
                stream.mem_manage.get_resource(b.clone())
            })
            .collect::<Vec<_>>();

        BindingsResource {
            resources,
            metadata: bindings.metadata,
            scalars: bindings.scalars,
        }
    }

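    /// Returns the compute pipeline for the given kernel and execution mode, compiling
    /// it and caching it by kernel id on first use.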
    fn pipeline(
        &mut self,
        kernel: <Self as ComputeServer>::Kernel,
        mode: ExecutionMode,
    ) -> Arc<ComputePipeline> {
        let mut kernel_id = kernel.id();
        kernel_id.mode(mode);

        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
            return pipeline.clone();
        }

        let mut compiler = compiler(self.backend);
        let mut compile = compiler.compile(self, kernel, mode);

        if self.scheduler.logger.compilation_activated() {
            compile.debug_info = Some(DebugInformation::new(
                compiler.lang_tag(),
                kernel_id.clone(),
            ));
        }
        self.scheduler.logger.log_compilation(&compile);
        // /!\ Do not delete the following commented code.
        // It is useful while working on the metal compiler. The errors are also printed
        // nicely here, which is not the case when the runtime does the compilation.
        // println!("SOURCE:\n{}", compile.source);
        // {
        //     // Write the shader to a metal file, then compile it to surface errors.
        //     std::fs::write("shader.metal", &compile.source).expect("should write to file");
        //     let _status = std::process::Command::new("xcrun")
        //         .args(vec![
        //             "-sdk",
        //             "macosx",
        //             "metal",
        //             "-o",
        //             "shader.ir",
        //             "-c",
        //             "shader.metal",
        //         ])
        //         .status()
        //         .expect("should launch the command");
        //     // std::process::exit(status.code().unwrap());
        // }
        let pipeline = self.create_pipeline(compile, mode);
        self.pipelines.insert(kernel_id.clone(), pipeline.clone());

        pipeline
    }
}

impl ComputeServer for WgpuServer {
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    type Info = wgpu::Backend;

    fn logger(&self) -> Arc<ServerLogger> {
        self.scheduler.logger.clone()
    }

    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        // TODO: Check if using a staging buffer is useful here.
        Err(IoError::UnsupportedIoOperation)
    }

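    /// Allocates one backing buffer large enough for all descriptors and slices it into
    /// sub-handles aligned to the device's `min_storage_buffer_offset_alignment`.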
    fn create(
        &mut self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<Allocation>, IoError> {
        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
        let strides = descriptors
            .iter()
            .map(|desc| contiguous_strides(desc.shape))
            .collect::<Vec<_>>();
        let sizes = descriptors
            .iter()
            .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
            .collect::<Vec<_>>();
        let total_size = sizes
            .iter()
            .map(|it| it.next_multiple_of(align))
            .sum::<usize>();

        let stream = self.scheduler.stream(&stream_id);
        let mem_handle = stream.empty(total_size as u64, stream_id)?;
        let handles = offset_handles(mem_handle, &sizes, align);

        Ok(handles
            .into_iter()
            .zip(strides)
            .map(|(handle, strides)| Allocation::new(handle, strides))
            .collect())
    }

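    /// Reads the given bindings back into host memory asynchronously. Only contiguous
    /// tensors are supported; every stream involved is executed before the read is issued.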
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Box::pin(async { Err(IoError::UnsupportedStrides) });
            }
            if !streams.contains(&desc.binding.stream) {
                streams.push(desc.binding.stream);
            }
            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding);
            resources.push((resource, desc.shape.to_vec(), desc.elem_size));
        }

        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }

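    /// Schedules writes of the given byte buffers into their target bindings. Only
    /// contiguous tensors are supported.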
    fn write(
        &mut self,
        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
        stream_id: StreamId,
    ) -> Result<(), IoError> {
        for (desc, data) in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Err(IoError::UnsupportedStrides);
            }

            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding.clone());
            let task = ScheduleTask::Write {
                data,
                buffer: resource,
            };

            self.scheduler.register(stream_id, task, [].into_iter());
        }

        Ok(())
    }

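    /// Returns the underlying [`WgpuResource`] for a binding, executing the streams
    /// involved first.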
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> BindingResource<WgpuResource> {
        let mut streams = vec![stream_id];
        if binding.stream != stream_id {
            streams.push(binding.stream);
        }
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&binding.stream);
        let resource = stream.mem_manage.get_resource(binding.clone());
        BindingResource::new(binding, resource)
    }

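    /// Schedules a kernel execution on the given stream, fetching the pipeline from the
    /// cache (compiling it on first use) and resolving all buffer bindings into GPU
    /// resources.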
    unsafe fn execute(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let pipeline = self.pipeline(kernel, mode);
        let buffers = bindings.buffers.clone();
        let resources = self.prepare_bindings(bindings);
        let task = ScheduleTask::Execute {
            pipeline,
            count,
            resources,
        };

        self.scheduler.register(stream_id, task, buffers.iter());
    }

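    /// Executes all pending scheduler tasks for the stream and flushes its work to the
    /// GPU queue.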
    fn flush(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.flush()
    }

    /// Waits for all scheduled work on the given stream to complete.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<()> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.sync()
    }

    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.start_profile()
    }

    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.end_profile(token)
    }

    fn memory_usage(
        &mut self,
        stream_id: StreamId,
    ) -> cubecl_runtime::memory_management::MemoryUsage {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_usage()
    }

    fn memory_cleanup(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_cleanup(true);
    }

    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.mode(mode);
    }
}

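/// Selects the compiler for the given backend: SPIR-V on Vulkan and MSL on Metal when
/// the corresponding features are enabled, and WGSL otherwise.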
fn compiler(backend: wgpu::Backend) -> AutoCompiler {
    match backend {
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}

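/// Computes row-major (contiguous) strides for `shape`.
///
/// For example, a shape of `[2, 3, 4]` yields strides `[12, 4, 1]`.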
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` guards against an underflow when the shape is empty (rank 0).
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}