use super::WgpuResource;
use super::{WgpuStorage, stream::WgpuStream};
use crate::AutoCompiler;
use alloc::sync::Arc;
use cubecl_common::profile::{ProfileDuration, TimingMethod};
use cubecl_core::compute::{CubeTask, DebugInformation};
use cubecl_core::future::DynFut;
use cubecl_core::server::{ProfileError, ProfilingToken};
use cubecl_core::{
    Feature, MemoryConfiguration, WgpuCompilationOptions,
    prelude::*,
    server::{Binding, BindingWithMeta, Bindings, Handle},
};
use cubecl_runtime::logging::ServerLogger;
use cubecl_runtime::memory_management::offset_handles;
use cubecl_runtime::{
    memory_management::MemoryDeviceProperties,
    server::{self, ComputeServer},
    storage::BindingResource,
};
use hashbrown::HashMap;
use wgpu::ComputePipeline;

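/// Compute server for the wgpu runtime.
///
/// It compiles kernels into [`ComputePipeline`]s, caches them by kernel id, and
/// dispatches work on a [`WgpuStream`].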
#[derive(Debug)]
pub struct WgpuServer {
    pub(crate) device: wgpu::Device,
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    stream: WgpuStream,
    pub compilation_options: WgpuCompilationOptions,
    pub(crate) backend: wgpu::Backend,
}

impl WgpuServer {
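    /// Create a new server from a wgpu device and queue, using the given memory
    /// properties, memory configuration, and compilation options.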
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        memory_properties: MemoryDeviceProperties,
        memory_config: MemoryConfiguration,
        compilation_options: WgpuCompilationOptions,
        device: wgpu::Device,
        queue: wgpu::Queue,
        tasks_max: usize,
        backend: wgpu::Backend,
        timing_method: TimingMethod,
    ) -> Self {
        let stream = WgpuStream::new(
            device.clone(),
            queue.clone(),
            memory_properties,
            memory_config,
            timing_method,
            tasks_max,
        );

        Self {
            compilation_options,
            device,
            pipelines: HashMap::new(),
            stream,
            backend,
        }
    }

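    /// Return the compute pipeline for `kernel` under the given execution mode,
    /// compiling it (and logging the compilation) on first use and caching it
    /// for subsequent calls.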
    fn pipeline(
        &mut self,
        kernel: <Self as ComputeServer>::Kernel,
        mode: ExecutionMode,
        logger: Arc<ServerLogger>,
    ) -> Arc<ComputePipeline> {
        let mut kernel_id = kernel.id();
        kernel_id.mode(mode);

        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
            return pipeline.clone();
        }

        let mut compiler = compiler(self.backend);
        let mut compile = compiler.compile(self, kernel, mode);

        if logger.compilation_activated() {
            compile.debug_info = Some(DebugInformation::new(
                compiler.lang_tag(),
                kernel_id.clone(),
            ));
        }
        logger.log_compilation(&compile);
        let pipeline = self.create_pipeline(compile, mode);
        self.pipelines.insert(kernel_id.clone(), pipeline.clone());

        pipeline
    }
}

impl ComputeServer for WgpuServer {
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    type Feature = Feature;
    type Info = wgpu::Backend;

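    // Buffer reads are delegated to the stream; the returned future resolves to
    // one byte vector per binding.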
    fn read(&mut self, bindings: Vec<Binding>) -> DynFut<Vec<Vec<u8>>> {
        self.stream.read_buffers(bindings)
    }

    fn get_resource(&mut self, binding: Binding) -> BindingResource<WgpuResource> {
        let resource = self.stream.mem_manage.get_resource(binding.clone());
        BindingResource::new(binding, resource)
    }

    fn create(&mut self, data: &[u8]) -> server::Handle {
        self.stream.create(data)
    }

    fn empty(&mut self, size: usize) -> server::Handle {
        self.stream.empty(size as u64)
    }

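    // Compile (or fetch the cached) pipeline for this kernel, then register the
    // dispatch with the stream.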
    unsafe fn execute(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        logger: Arc<ServerLogger>,
    ) {
        let pipeline = self.pipeline(kernel, mode, logger);
        self.stream.register(pipeline, bindings, &count);
    }

    fn flush(&mut self) {
        self.stream.flush();
    }

    fn sync(&mut self) -> DynFut<()> {
        self.stream.sync()
    }

    fn start_profile(&mut self) -> ProfilingToken {
        self.stream.start_profile()
    }

    fn end_profile(&mut self, token: ProfilingToken) -> Result<ProfileDuration, ProfileError> {
        self.stream.end_profile(token)
    }

    fn memory_usage(&self) -> cubecl_runtime::memory_management::MemoryUsage {
        self.stream.mem_manage.memory_usage()
    }

    fn memory_cleanup(&mut self) {
        self.stream.mem_manage.memory_cleanup(true);
    }

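    // Tensor reads go through `read`, then each buffer is truncated to the
    // logical tensor size (shape product * element size).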
    fn read_tensor(&mut self, bindings: Vec<BindingWithMeta>) -> DynFut<Vec<Vec<u8>>> {
        let expected_sizes = bindings
            .iter()
            .map(|it| it.shape.iter().product::<usize>() * it.elem_size)
            .collect::<Vec<_>>();
        let bindings = bindings.into_iter().map(|it| it.binding).collect();
        let data = self.read(bindings);
        Box::pin(async move {
            let mut data = data.await;
            for (data, expected_size) in data.iter_mut().zip(expected_sizes) {
                data.truncate(expected_size);
            }
            data
        })
    }

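    // Allocate all tensors up front, then copy each data slice into its handle.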
    fn create_tensors(
        &mut self,
        data: Vec<&[u8]>,
        shapes: Vec<&[usize]>,
        elem_size: Vec<usize>,
    ) -> Vec<(Handle, Vec<usize>)> {
        let handles_strides = self.empty_tensors(shapes.clone(), elem_size);

        for (data, (handle, _)) in data.into_iter().zip(&handles_strides) {
            self.stream.copy_to_handle(handle.clone(), data);
        }

        handles_strides
    }

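    // Tensors are packed into a single allocation; each sub-buffer is offset by
    // a multiple of the storage buffer offset alignment.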
    fn empty_tensors(
        &mut self,
        shape: Vec<&[usize]>,
        elem_size: Vec<usize>,
    ) -> Vec<(Handle, Vec<usize>)> {
        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
        let strides = shape
            .iter()
            .map(|shape| contiguous_strides(shape))
            .collect::<Vec<_>>();
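        // Round each tensor's byte size up to the device's storage buffer offset
        // alignment so every sub-handle starts at a valid offset.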
        let sizes = shape
            .iter()
            .map(|it| it.iter().product::<usize>())
            .zip(elem_size)
            .map(|(size, elem_size)| (size * elem_size).next_multiple_of(align))
            .collect::<Vec<_>>();
        let total_size = sizes.iter().sum::<usize>();

        let mem_handle = self.empty(total_size);
        let handles = offset_handles(mem_handle, &sizes);

        handles.into_iter().zip(strides).collect()
    }
}

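/// Select the compiler for the given backend: SPIR-V on Vulkan and MSL on Metal
/// when the corresponding features are enabled, WGSL otherwise.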
fn compiler(backend: wgpu::Backend) -> AutoCompiler {
    match backend {
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}

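/// Compute row-major (contiguous) strides for `shape`, e.g. `[2, 3, 4]` yields
/// `[12, 4, 1]`.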
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    for i in (0..rank - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}