1use super::storage::{WgpuResource, WgpuStorage};
2use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
3use crate::{AutoCompiler, AutoRepresentation};
4use alloc::sync::Arc;
5use cubecl_common::{
6 backtrace::BackTrace,
7 bytes::Bytes,
8 profile::{ProfileDuration, TimingMethod},
9 stream_id::StreamId,
10};
11use cubecl_core::{
12 MemoryConfiguration, WgpuCompilationOptions,
13 future::DynFut,
14 prelude::*,
15 server::{
16 Allocation, AllocationDescriptor, Binding, Bindings, CopyDescriptor, ExecutionError,
17 IoError, LaunchError, ProfileError, ProfilingToken, ResourceLimitError,
18 ServerCommunication, ServerUtilities,
19 },
20 zspace::{Strides, strides},
21};
22#[cfg(feature = "spirv")]
23use cubecl_core::{cache::CacheOption, compilation_cache::CompilationCache, hash::StableHash};
24use cubecl_ir::MemoryDeviceProperties;
25use cubecl_runtime::{
26 compiler::CubeTask,
27 config::GlobalConfig,
28 logging::ServerLogger,
29 memory_management::{MemoryAllocationMode, offset_handles},
30 server::ComputeServer,
31 storage::BindingResource,
32 stream::scheduler::{SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy},
33 validation::{validate_cube_dim, validate_units},
34};
35use hashbrown::HashMap;
36use wgpu::ComputePipeline;
37
/// Compute server backed by wgpu, dispatching work through a multi-stream
/// scheduler and caching compiled compute pipelines per kernel id.
#[derive(Debug)]
pub struct WgpuServer {
    /// Logical wgpu device, used for module/pipeline creation and limit queries.
    pub(crate) device: wgpu::Device,
    /// In-memory cache of compiled pipelines, keyed by kernel id (the id is
    /// specialized per execution mode — see `pipeline`).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    /// Multi-stream scheduler that owns the per-stream state and task queues.
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    /// Optional on-disk SPIR-V compilation cache; `None` when the global
    /// config has no cache root configured.
    #[cfg(feature = "spirv")]
    pub(crate) spirv_cache:
        Option<CompilationCache<(u64, StableHash), cubecl_spirv::SpirvCacheEntry>>,
    pub compilation_options: WgpuCompilationOptions,
    /// Which wgpu backend is active (Vulkan, Metal, …); selects the compiler.
    pub(crate) backend: wgpu::Backend,
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}
51
/// Direct server-to-server communication is not supported on wgpu.
impl ServerCommunication for WgpuServer {
    const SERVER_COMM_ENABLED: bool = false;
}
55
impl WgpuServer {
    /// Build a new wgpu compute server wired to a multi-stream scheduler.
    ///
    /// `tasks_max` bounds how many tasks a stream buffers (also the scheduler's
    /// `max_tasks`), and `timing_method` selects how profiling durations are
    /// measured.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        memory_properties: MemoryDeviceProperties,
        memory_config: MemoryConfiguration,
        compilation_options: WgpuCompilationOptions,
        device: wgpu::Device,
        queue: wgpu::Queue,
        tasks_max: usize,
        backend: wgpu::Backend,
        timing_method: TimingMethod,
        utilities: ServerUtilities<Self>,
    ) -> Self {
        let backend_scheduler = ScheduledWgpuBackend::new(
            device.clone(),
            queue.clone(),
            memory_properties,
            memory_config,
            timing_method,
            tasks_max,
            utilities.logger.clone(),
        );

        let config = GlobalConfig::get();
        let max_streams = config.streaming.max_streams;

        Self {
            compilation_options,
            device,
            pipelines: HashMap::new(),
            scheduler: SchedulerMultiStream::new(
                utilities.logger.clone(),
                backend_scheduler,
                SchedulerMultiStreamOptions {
                    max_streams,
                    max_tasks: tasks_max,
                    strategy: SchedulerStrategy::Interleave,
                },
            ),
            // Only create the on-disk SPIR-V cache when the global config
            // provides a cache root; otherwise compilation is never persisted.
            #[cfg(feature = "spirv")]
            spirv_cache: {
                let config = cubecl_runtime::config::GlobalConfig::get();
                if let Some(cache) = &config.compilation.cache {
                    let root = cache.root();
                    Some(CompilationCache::new(
                        "spirv",
                        CacheOption::default().name("vulkan").root(root),
                    ))
                } else {
                    None
                }
            },
            backend,
            utilities: Arc::new(utilities),
        }
    }

    /// Resolve every buffer binding to its GPU resource via the binding's own
    /// stream, carrying metadata and scalars through unchanged.
    ///
    /// NOTE(review): `get_resource(..).unwrap()` panics if a binding cannot be
    /// resolved — presumably all bindings were allocated by this server; confirm.
    fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
        let resources = bindings
            .buffers
            .iter()
            .map(|b| {
                let stream = self.scheduler.stream(&b.stream);
                stream.mem_manage.get_resource(b.clone()).unwrap()
            })
            .collect::<Vec<_>>();

        BindingsResource {
            resources,
            metadata: bindings.metadata,
            scalars: bindings.scalars,
        }
    }

    /// Get or build the compute pipeline for `kernel` under `mode`.
    ///
    /// Lookup order: in-memory `pipelines` map, then `load_cached_pipeline`
    /// (a `Some(Ok(_))` result is a hit), and finally a fresh compile. When the
    /// cache lookup missed with `Some(Err(key))`, the freshly compiled SPIR-V
    /// is inserted back under that key.
    fn pipeline(
        &mut self,
        kernel: <Self as ComputeServer>::Kernel,
        bindings: &Bindings,
        mode: ExecutionMode,
    ) -> Result<Arc<ComputePipeline>, LaunchError> {
        // The kernel id is specialized by execution mode, so checked and
        // unchecked variants cache separately.
        let mut kernel_id = kernel.id();
        kernel_id.mode(mode);

        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
            return Ok(pipeline.clone());
        }

        let cached = self.load_cached_pipeline(&kernel_id, bindings, mode)?;

        if let Some(Ok(pipeline)) = cached {
            self.pipelines.insert(kernel_id, pipeline.clone());
            return Ok(pipeline);
        }

        // Validate launch limits before paying for compilation.
        validate_cube_dim(&self.utilities.properties, &kernel_id)?;
        validate_units(&self.utilities.properties, &kernel_id)?;

        let mut compiler = compiler(self.backend);
        let mut compiled = compiler.compile(self, kernel, mode)?;

        // Attach debug info only when compilation logging is enabled, to avoid
        // the extra bookkeeping otherwise.
        if self.scheduler.logger.compilation_activated() {
            compiled.debug_info = Some(DebugInformation::new(
                compiler.lang_tag(),
                kernel_id.clone(),
            ));
        }
        self.scheduler.logger.log_compilation(&compiled);

        // Reject the kernel before module creation if it exceeds the device's
        // shared-memory limit.
        self.validate_shared(&compiled.repr)?;

        let repr = compiled.repr.as_ref().map(|it| it.as_ref());
        let module = self.create_module(&compiled.entrypoint_name, repr, &compiled.source, mode)?;
        let pipeline = self.create_pipeline(&compiled.entrypoint_name, repr, module, bindings);
        self.pipelines.insert(kernel_id.clone(), pipeline.clone());

        // Persist freshly compiled SPIR-V for future runs. A failed cache
        // write is non-fatal: log a warning and keep the pipeline.
        #[cfg(feature = "spirv")]
        if let Some(Err(key)) = cached
            && let Some(crate::AutoRepresentation::SpirV(kernel)) = compiled.repr
        {
            let cache = self.spirv_cache.as_mut().unwrap();
            let result = cache.insert(
                key,
                cubecl_spirv::SpirvCacheEntry::new(compiled.entrypoint_name, kernel),
            );
            if let Err(err) = result {
                log::warn!("Unable to save the SPIR-V {err:?}");
            }
        }

        Ok(pipeline)
    }

    /// Reject kernels whose reported shared-memory usage exceeds the hardware
    /// limit from the device properties. `None` representations (no repr
    /// available) are accepted as-is.
    fn validate_shared(&self, repr: &Option<crate::AutoRepresentation>) -> Result<(), LaunchError> {
        let shared_bytes = repr.as_ref().map(|repr| match repr {
            AutoRepresentation::Wgsl(repr) => repr.shared_memory_bytes(),
            #[cfg(feature = "msl")]
            AutoRepresentation::Msl(repr) => repr.shared_memory_size(),
            #[cfg(feature = "spirv")]
            AutoRepresentation::SpirV(repr) => repr.shared_size,
        });
        let max_smem = self.utilities.properties.hardware.max_shared_memory_size;
        if let Some(shared_bytes) = shared_bytes
            && shared_bytes > max_smem
        {
            Err(ResourceLimitError::SharedMemory {
                requested: shared_bytes,
                max: max_smem,
                backtrace: BackTrace::capture(),
            }
            .into())
        } else {
            Ok(())
        }
    }
}
236
impl ComputeServer for WgpuServer {
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    type Info = wgpu::Backend;

    fn logger(&self) -> Arc<ServerLogger> {
        self.scheduler.logger.clone()
    }

    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    /// Host staging buffers are not supported by this server; always errors.
    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        })
    }

    /// Allocate one backing buffer for all descriptors and carve it into
    /// per-descriptor handles, each offset aligned to the device's
    /// `min_storage_buffer_offset_alignment`. Strides are row-major contiguous.
    fn create(
        &mut self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<Allocation>, IoError> {
        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
        let strides = descriptors
            .iter()
            .map(|desc| contiguous_strides(desc.shape))
            .collect::<Vec<_>>();
        let sizes = descriptors
            .iter()
            .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
            .collect::<Vec<_>>();
        // Each sub-allocation is rounded up so every handle starts aligned.
        let total_size = sizes
            .iter()
            .map(|it| it.next_multiple_of(align))
            .sum::<usize>();

        let stream = self.scheduler.stream(&stream_id);
        let mem_handle = stream.empty(total_size as u64, stream_id)?;
        let handles = offset_handles(mem_handle, &sizes, align);

        Ok(handles
            .into_iter()
            .zip(strides)
            .map(|(handle, strides)| Allocation::new(handle, strides))
            .collect())
    }

    /// Read the given bindings back to host memory.
    ///
    /// Only contiguous (row-major) layouts are accepted. Every stream owning
    /// one of the bindings is executed first, then the read is issued on
    /// `stream_id`'s stream.
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            if &*contiguous_strides(desc.shape) != desc.strides {
                return Box::pin(async {
                    Err(IoError::UnsupportedStrides {
                        backtrace: BackTrace::capture(),
                    })
                });
            }
            // Collect the distinct set of streams that own these bindings so
            // their pending work runs before the read.
            if !streams.contains(&desc.binding.stream) {
                streams.push(desc.binding.stream);
            }
            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = match stream.mem_manage.get_resource(desc.binding) {
                Ok(val) => val,
                Err(err) => return Box::pin(async move { Err(err) }),
            };
            resources.push((resource, desc.shape.into(), desc.elem_size));
        }

        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }

    /// Copy host bytes into device buffers by registering write tasks on
    /// `stream_id`. Only contiguous (row-major) layouts are accepted.
    fn write(
        &mut self,
        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
        stream_id: StreamId,
    ) -> Result<(), IoError> {
        for (desc, data) in descriptors {
            if &*contiguous_strides(desc.shape) != desc.strides {
                return Err(IoError::UnsupportedStrides {
                    backtrace: BackTrace::capture(),
                });
            }

            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding.clone())?;
            let task = ScheduleTask::Write {
                data,
                buffer: resource,
            };

            // No binding dependencies for a plain write task.
            self.scheduler.register(stream_id, task, [].into_iter());
        }

        Ok(())
    }

    /// Resolve a binding to its wgpu resource, first executing the caller's
    /// stream and (if different) the binding's owning stream.
    ///
    /// NOTE(review): the `.unwrap()` panics if the binding cannot be resolved.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> BindingResource<WgpuResource> {
        let mut streams = vec![stream_id];
        if binding.stream != stream_id {
            streams.push(binding.stream);
        }
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&binding.stream);
        let resource = stream.mem_manage.get_resource(binding.clone()).unwrap();
        BindingResource::new(binding, resource)
    }

    /// Compile (or fetch) the pipeline for `kernel` and register an execute
    /// task on `stream_id`, with the buffer bindings as dependencies.
    ///
    /// # Safety
    /// Marked `unsafe` because with an unchecked `ExecutionMode` the kernel's
    /// memory accesses are not validated (per the `ComputeServer` contract).
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let pipeline = self.pipeline(kernel, &bindings, mode)?;
        // Keep the raw buffer list: it is passed as the task's dependencies.
        let buffers = bindings.buffers.clone();
        let resources = self.prepare_bindings(bindings);
        let task = ScheduleTask::Execute {
            pipeline,
            count,
            resources,
        };

        self.scheduler.register(stream_id, task, buffers.iter());

        Ok(())
    }

    /// Execute pending tasks for the stream, then flush it to the device queue.
    fn flush(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.flush()
    }

    /// Execute pending tasks, then return a future resolving when the stream's
    /// submitted work completes.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.sync()
    }

    /// Execute pending tasks, then open a profiling scope on the stream.
    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.start_profile()
    }

    /// Execute pending tasks, then close the profiling scope identified by
    /// `token` and return the measured duration.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.end_profile(token)
    }

    /// Execute pending tasks, then report the stream's memory usage.
    fn memory_usage(
        &mut self,
        stream_id: StreamId,
    ) -> cubecl_runtime::memory_management::MemoryUsage {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_usage()
    }

    /// Execute pending tasks, then force a memory cleanup pass on the stream.
    fn memory_cleanup(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_cleanup(true);
    }

    /// Execute pending tasks, then switch the stream's allocation mode.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.mode(mode);
    }
}
430
/// Pick the compiler for the active wgpu backend: SPIR-V for Vulkan and MSL
/// for Metal (when those features are enabled), falling back to WGSL for
/// everything else.
fn compiler(backend: wgpu::Backend) -> AutoCompiler {
    match backend {
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}
440
441pub(crate) fn contiguous_strides(shape: &[usize]) -> Strides {
442 let rank = shape.len();
443 let mut strides = strides![1; rank];
444 for i in (0..rank - 1).rev() {
445 strides[i] = strides[i + 1] * shape[i + 1];
446 }
447 strides
448}