cubecl_wgpu/compute/
server.rs

1use super::WgpuResource;
2use super::{WgpuStorage, stream::WgpuStream};
3use crate::AutoCompiler;
4use alloc::sync::Arc;
5use cubecl_common::profile::{ProfileDuration, TimingMethod};
6use cubecl_core::compute::{CubeTask, DebugInformation};
7use cubecl_core::future::DynFut;
8use cubecl_core::server::{ProfileError, ProfilingToken};
9use cubecl_core::{
10    Feature, MemoryConfiguration, WgpuCompilationOptions,
11    prelude::*,
12    server::{Binding, BindingWithMeta, Bindings, Handle},
13};
14use cubecl_runtime::logging::ServerLogger;
15use cubecl_runtime::memory_management::offset_handles;
16use cubecl_runtime::{
17    memory_management::MemoryDeviceProperties,
18    server::{self, ComputeServer},
19    storage::BindingResource,
20};
21use hashbrown::HashMap;
22use wgpu::ComputePipeline;
23
/// Wgpu compute server.
#[derive(Debug)]
pub struct WgpuServer {
    /// The wgpu device, used to create pipelines and to query device limits
    /// (e.g. storage buffer offset alignment in `empty_tensors`).
    pub(crate) device: wgpu::Device,
    /// Cache of compiled compute pipelines, keyed by kernel id (the id also
    /// encodes the execution mode — see `pipeline`).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    /// Stream that owns command encoding, memory management, sync and profiling.
    stream: WgpuStream,
    /// Options forwarded to the kernel compiler.
    pub compilation_options: WgpuCompilationOptions,
    /// Active wgpu backend; selects which compiler is used (see `compiler`).
    pub(crate) backend: wgpu::Backend,
}
33
34impl WgpuServer {
35    /// Create a new server.
36    #[allow(clippy::too_many_arguments)]
37    pub fn new(
38        memory_properties: MemoryDeviceProperties,
39        memory_config: MemoryConfiguration,
40        compilation_options: WgpuCompilationOptions,
41        device: wgpu::Device,
42        queue: wgpu::Queue,
43        tasks_max: usize,
44        backend: wgpu::Backend,
45        timing_method: TimingMethod,
46    ) -> Self {
47        let stream = WgpuStream::new(
48            device.clone(),
49            queue.clone(),
50            memory_properties,
51            memory_config,
52            timing_method,
53            tasks_max,
54        );
55
56        Self {
57            compilation_options,
58            device,
59            pipelines: HashMap::new(),
60            stream,
61            backend,
62        }
63    }
64
65    fn pipeline(
66        &mut self,
67        kernel: <Self as ComputeServer>::Kernel,
68        mode: ExecutionMode,
69        logger: Arc<ServerLogger>,
70    ) -> Arc<ComputePipeline> {
71        let mut kernel_id = kernel.id();
72        kernel_id.mode(mode);
73
74        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
75            return pipeline.clone();
76        }
77
78        let mut compiler = compiler(self.backend);
79        let mut compile = compiler.compile(self, kernel, mode);
80
81        if logger.compilation_activated() {
82            compile.debug_info = Some(DebugInformation::new(
83                compiler.lang_tag(),
84                kernel_id.clone(),
85            ));
86        }
87        logger.log_compilation(&compile);
88        // /!\ Do not delete the following commented code.
89        // This is useful while working on the metal compiler.
90        // Also the errors are printed nicely which is not the case when this is the runtime
91        // that does it.
92        // println!("SOURCE:\n{}", compile.source);
93        // {
94        //     // Write shader in metal file then compile it for error
95        //     std::fs::write("shader.metal", &compile.source).expect("should write to file");
96        //     let _status = std::process::Command::new("xcrun")
97        //         .args(vec![
98        //             "-sdk",
99        //             "macosx",
100        //             "metal",
101        //             "-o",
102        //             "shader.ir",
103        //             "-c",
104        //             "shader.metal",
105        //         ])
106        //         .status()
107        //         .expect("should launch the command");
108        //     // std::process::exit(status.code().unwrap());
109        // }
110        let pipeline = self.create_pipeline(compile, mode);
111        self.pipelines.insert(kernel_id.clone(), pipeline.clone());
112
113        pipeline
114    }
115}
116
117impl ComputeServer for WgpuServer {
118    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
119    type Storage = WgpuStorage;
120    type Feature = Feature;
121    type Info = wgpu::Backend;
122
123    fn read(&mut self, bindings: Vec<Binding>) -> DynFut<Vec<Vec<u8>>> {
124        self.stream.read_buffers(bindings)
125    }
126
127    fn get_resource(&mut self, binding: Binding) -> BindingResource<WgpuResource> {
128        let resource = self.stream.mem_manage.get_resource(binding.clone());
129        BindingResource::new(binding, resource)
130    }
131
132    /// When we create a new handle from existing data, we use custom allocations so that we don't
133    /// have to execute the current pending tasks.
134    ///
135    /// This is important, otherwise the compute passes are going to be too small and we won't be able to
136    /// fully utilize the GPU.
137    fn create(&mut self, data: &[u8]) -> server::Handle {
138        self.stream.create(data)
139    }
140
141    fn empty(&mut self, size: usize) -> server::Handle {
142        self.stream.empty(size as u64)
143    }
144
145    unsafe fn execute(
146        &mut self,
147        kernel: Self::Kernel,
148        count: CubeCount,
149        bindings: Bindings,
150        mode: ExecutionMode,
151        logger: Arc<ServerLogger>,
152    ) {
153        let pipeline = self.pipeline(kernel, mode, logger);
154        self.stream.register(pipeline, bindings, &count);
155    }
156
157    fn flush(&mut self) {
158        // End the current compute pass.
159        self.stream.flush();
160    }
161
162    /// Returns the total time of GPU work this sync completes.
163    fn sync(&mut self) -> DynFut<()> {
164        self.stream.sync()
165    }
166
167    fn start_profile(&mut self) -> ProfilingToken {
168        self.stream.start_profile()
169    }
170
171    fn end_profile(&mut self, token: ProfilingToken) -> Result<ProfileDuration, ProfileError> {
172        self.stream.end_profile(token)
173    }
174
175    fn memory_usage(&self) -> cubecl_runtime::memory_management::MemoryUsage {
176        self.stream.mem_manage.memory_usage()
177    }
178
179    fn memory_cleanup(&mut self) {
180        self.stream.mem_manage.memory_cleanup(true);
181    }
182
183    fn read_tensor(&mut self, bindings: Vec<BindingWithMeta>) -> DynFut<Vec<Vec<u8>>> {
184        let expected_sizes = bindings
185            .iter()
186            .map(|it| it.shape.iter().product::<usize>() * it.elem_size)
187            .collect::<Vec<_>>();
188        let bindings = bindings.into_iter().map(|it| it.binding).collect();
189        let data = self.read(bindings);
190        Box::pin(async move {
191            let mut data = data.await;
192            for (data, expected_size) in data.iter_mut().zip(expected_sizes) {
193                data.truncate(expected_size);
194            }
195            data
196        })
197    }
198
199    fn create_tensors(
200        &mut self,
201        data: Vec<&[u8]>,
202        shapes: Vec<&[usize]>,
203        elem_size: Vec<usize>,
204    ) -> Vec<(Handle, Vec<usize>)> {
205        let handles_strides = self.empty_tensors(shapes.clone(), elem_size);
206
207        for i in 0..data.len() {
208            let data = data[i];
209            let (handle, _) = &handles_strides[i];
210            self.stream.copy_to_handle(handle.clone(), data);
211        }
212
213        handles_strides
214    }
215
216    fn empty_tensors(
217        &mut self,
218        shape: Vec<&[usize]>,
219        elem_size: Vec<usize>,
220    ) -> Vec<(Handle, Vec<usize>)> {
221        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
222        let strides = shape
223            .iter()
224            .map(|shape| contiguous_strides(shape))
225            .collect::<Vec<_>>();
226        let sizes = shape
227            .iter()
228            .map(|it| it.iter().product::<usize>())
229            .zip(elem_size)
230            .map(|(size, elem_size)| (size * elem_size).next_multiple_of(align))
231            .collect::<Vec<_>>();
232        let total_size = sizes.iter().sum::<usize>();
233
234        let mem_handle = self.empty(total_size);
235        let handles = offset_handles(mem_handle, &sizes);
236
237        handles.into_iter().zip(strides).collect()
238    }
239}
240
/// Pick the kernel compiler matching the wgpu `backend`.
///
/// Vulkan maps to the SPIR-V compiler and Metal to the MSL compiler, each only
/// when the corresponding cargo feature is enabled; every other backend (or a
/// disabled feature) falls back to WGSL.
fn compiler(backend: wgpu::Backend) -> AutoCompiler {
    match backend {
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}
250
/// Compute row-major (contiguous) strides for `shape`.
///
/// The last dimension has stride 1 and each preceding stride is the product of
/// all later dimension sizes, e.g. `[2, 3, 4] -> [12, 4, 1]`.
///
/// Handles rank 0 (empty shape) by returning an empty vector; the previous
/// `0..rank - 1` formulation underflowed `usize` in that case (a panic in
/// debug builds).
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `1..rank` is empty for rank 0 and 1, so no underflow and no bad indexing.
    for i in (1..rank).rev() {
        strides[i - 1] = strides[i] * shape[i];
    }
    strides
}