cubecl_wgpu/compute/server.rs

use super::storage::{WgpuResource, WgpuStorage};
use crate::AutoCompiler;
use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
use alloc::sync::Arc;
use cubecl_common::bytes::Bytes;
use cubecl_common::profile::{ProfileDuration, TimingMethod};
use cubecl_common::stream_id::StreamId;
use cubecl_core::future::DynFut;
use cubecl_core::server::{ProfileError, ProfilingToken, ServerCommunication, ServerUtilities};
use cubecl_core::{
    MemoryConfiguration, WgpuCompilationOptions,
    prelude::*,
    server::{Binding, Bindings, CopyDescriptor},
};
use cubecl_core::{
    compute::{CubeTask, DebugInformation},
    server::{Allocation, AllocationDescriptor, IoError},
};
use cubecl_runtime::config::GlobalConfig;
use cubecl_runtime::logging::ServerLogger;
use cubecl_runtime::memory_management::{MemoryAllocationMode, offset_handles};
use cubecl_runtime::stream::scheduler::{
    SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy,
};
use cubecl_runtime::{
    memory_management::MemoryDeviceProperties, server::ComputeServer, storage::BindingResource,
};
use hashbrown::HashMap;
use wgpu::ComputePipeline;

/// Wgpu compute server.
#[derive(Debug)]
pub struct WgpuServer {
    pub(crate) device: wgpu::Device,
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    pub compilation_options: WgpuCompilationOptions,
    pub(crate) backend: wgpu::Backend,
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}

impl ServerCommunication for WgpuServer {
    const SERVER_COMM_ENABLED: bool = false;
}

impl WgpuServer {
    /// Create a new server.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        memory_properties: MemoryDeviceProperties,
        memory_config: MemoryConfiguration,
        compilation_options: WgpuCompilationOptions,
        device: wgpu::Device,
        queue: wgpu::Queue,
        tasks_max: usize,
        backend: wgpu::Backend,
        timing_method: TimingMethod,
        utilities: ServerUtilities<Self>,
    ) -> Self {
        let backend_scheduler = ScheduledWgpuBackend::new(
            device.clone(),
            queue.clone(),
            memory_properties,
            memory_config,
            timing_method,
            tasks_max,
            utilities.logger.clone(),
        );

        let config = GlobalConfig::get();
        let max_streams = config.streaming.max_streams;

        Self {
            compilation_options,
            device,
            pipelines: HashMap::new(),
            scheduler: SchedulerMultiStream::new(
                utilities.logger.clone(),
                backend_scheduler,
                SchedulerMultiStreamOptions {
                    max_streams,
                    max_tasks: tasks_max,
                    strategy: SchedulerStrategy::Interleave,
                },
            ),
            backend,
            utilities: Arc::new(utilities),
        }
    }

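    /// Resolves every buffer binding into its backing [`WgpuResource`] on the stream that
    /// owns it, and bundles the resources with the bindings' metadata and scalars.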
    fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
        // Store all the resources we'll be using. This could be eliminated if
        // there was a way to tie the lifetime of the resource to the memory handle.
        let resources = bindings
            .buffers
            .iter()
            .map(|b| {
                let stream = self.scheduler.stream(&b.stream);
                stream.mem_manage.get_resource(b.clone())
            })
            .collect::<Vec<_>>();

        BindingsResource {
            resources,
            metadata: bindings.metadata,
            scalars: bindings.scalars,
        }
    }

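    /// Returns the compute pipeline for `kernel`, compiling and caching it on first use.
    /// The execution `mode` is part of the cache key.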
    fn pipeline(
        &mut self,
        kernel: <Self as ComputeServer>::Kernel,
        mode: ExecutionMode,
    ) -> Arc<ComputePipeline> {
        let mut kernel_id = kernel.id();
        kernel_id.mode(mode);

        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
            return pipeline.clone();
        }

        let mut compiler = compiler(self.backend);
        let mut compile = compiler.compile(self, kernel, mode);

        if self.scheduler.logger.compilation_activated() {
            compile.debug_info = Some(DebugInformation::new(
                compiler.lang_tag(),
                kernel_id.clone(),
            ));
        }
        self.scheduler.logger.log_compilation(&compile);
        // /!\ Do not delete the following commented code.
        // It is useful while working on the Metal compiler. It also prints compilation
        // errors nicely, which is not the case when the runtime does the compilation.
        // println!("SOURCE:\n{}", compile.source);
        // {
        //     // Write the shader to a Metal file, then compile it to check for errors.
        //     std::fs::write("shader.metal", &compile.source).expect("should write to file");
        //     let _status = std::process::Command::new("xcrun")
        //         .args(vec![
        //             "-sdk",
        //             "macosx",
        //             "metal",
        //             "-o",
        //             "shader.ir",
        //             "-c",
        //             "shader.metal",
        //         ])
        //         .status()
        //         .expect("should launch the command");
        //     // std::process::exit(status.code().unwrap());
        // }
        let pipeline = self.create_pipeline(compile, mode);
        self.pipelines.insert(kernel_id.clone(), pipeline.clone());

        pipeline
    }
}

impl ComputeServer for WgpuServer {
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    type Info = wgpu::Backend;

    fn logger(&self) -> Arc<ServerLogger> {
        self.scheduler.logger.clone()
    }

    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

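    /// Allocates one backing buffer large enough for all descriptors, rounding each size up
    /// to the device's storage-buffer offset alignment, then returns a handle offset into
    /// that buffer (with contiguous strides) for each descriptor.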
    fn create(
        &mut self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<Allocation>, IoError> {
        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
        let strides = descriptors
            .iter()
            .map(|desc| contiguous_strides(desc.shape))
            .collect::<Vec<_>>();
        let sizes = descriptors
            .iter()
            .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
            .collect::<Vec<_>>();
        let total_size = sizes
            .iter()
            .map(|it| it.next_multiple_of(align))
            .sum::<usize>();

        let stream = self.scheduler.stream(&stream_id);
        let mem_handle = stream.empty(total_size as u64, stream_id)?;
        let handles = offset_handles(mem_handle, &sizes, align);

        Ok(handles
            .into_iter()
            .zip(strides)
            .map(|(handle, strides)| Allocation::new(handle, strides))
            .collect())
    }

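    /// Reads the given bindings back to the host. Only contiguous layouts are supported;
    /// pending work on every stream that owns one of the bindings is executed before the
    /// read is issued on the calling stream.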
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Box::pin(async { Err(IoError::UnsupportedStrides) });
            }
            if !streams.contains(&desc.binding.stream) {
                streams.push(desc.binding.stream);
            }
            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding);
            resources.push((resource, desc.shape.to_vec(), desc.elem_size));
        }

        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }

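    /// Copies host data into the given bindings by registering one write task per
    /// descriptor on the calling stream. Only contiguous layouts are supported.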
    fn write(
        &mut self,
        descriptors: Vec<(CopyDescriptor<'_>, &[u8])>,
        stream_id: StreamId,
    ) -> Result<(), IoError> {
        for (desc, data) in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Err(IoError::UnsupportedStrides);
            }

            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding.clone());
            let task = ScheduleTask::Write {
                data: data.to_vec(),
                buffer: resource,
            };

            self.scheduler.register(stream_id, task, [].into_iter());
        }

        Ok(())
    }

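    /// Resolves a binding into its underlying wgpu resource, executing pending work on the
    /// calling stream and on the binding's owning stream first.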
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> BindingResource<WgpuResource> {
        let mut streams = vec![stream_id];
        if binding.stream != stream_id {
            streams.push(binding.stream);
        }
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&binding.stream);
        let resource = stream.mem_manage.get_resource(binding.clone());
        BindingResource::new(binding, resource)
    }

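    /// Registers a kernel launch on the calling stream: the pipeline is fetched (or
    /// compiled), the buffer bindings are resolved into resources, and the execute task is
    /// queued on the scheduler.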
    unsafe fn execute(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let pipeline = self.pipeline(kernel, mode);
        let buffers = bindings.buffers.clone();
        let resources = self.prepare_bindings(bindings);
        let task = ScheduleTask::Execute {
            pipeline,
            count,
            resources,
        };

        self.scheduler.register(stream_id, task, buffers.iter());
    }

    fn flush(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.flush()
    }

    /// Returns a future that resolves once all GPU work submitted on this stream has completed.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<()> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.sync()
    }

    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.start_profile()
    }

    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.end_profile(token)
    }

    fn memory_usage(
        &mut self,
        stream_id: StreamId,
    ) -> cubecl_runtime::memory_management::MemoryUsage {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_usage()
    }

    fn memory_cleanup(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_cleanup(true);
    }

    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.mode(mode);
    }
}

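/// Picks the compiler for the given wgpu backend: SPIR-V on Vulkan and MSL on Metal when
/// the corresponding features are enabled, falling back to WGSL otherwise.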
fn compiler(backend: wgpu::Backend) -> AutoCompiler {
    match backend {
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}

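/// Computes row-major (contiguous) strides for `shape`, e.g. `[2, 3, 4]` -> `[12, 4, 1]`.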
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` keeps rank-0 shapes from underflowing the range bound.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}