1use super::storage::{WgpuResource, WgpuStorage};
2use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
3use crate::{AutoCompiler, AutoRepresentation};
4use alloc::sync::Arc;
5use cubecl_common::{
6 backtrace::BackTrace,
7 bytes::Bytes,
8 profile::{ProfileDuration, TimingMethod},
9 stream_id::StreamId,
10};
11use cubecl_core::{
12 MemoryConfiguration, WgpuCompilationOptions,
13 future::DynFut,
14 prelude::*,
15 server::{
16 Allocation, AllocationDescriptor, Binding, Bindings, CopyDescriptor, ExecutionError,
17 IoError, LaunchError, ProfileError, ProfilingToken, ResourceLimitError,
18 ServerCommunication, ServerUtilities,
19 },
20};
21#[cfg(feature = "spirv")]
22use cubecl_core::{cache::CacheOption, compilation_cache::CompilationCache, hash::StableHash};
23use cubecl_ir::MemoryDeviceProperties;
24use cubecl_runtime::{
25 compiler::CubeTask,
26 config::GlobalConfig,
27 logging::ServerLogger,
28 memory_management::{MemoryAllocationMode, offset_handles},
29 server::ComputeServer,
30 storage::BindingResource,
31 stream::scheduler::{SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy},
32 validation::{validate_cube_dim, validate_units},
33};
34use hashbrown::HashMap;
35use wgpu::ComputePipeline;
36
/// Compute server for the wgpu backend.
///
/// Owns the compiled-pipeline cache and the multi-stream scheduler that
/// orders and submits GPU work to the underlying [`wgpu::Device`].
#[derive(Debug)]
pub struct WgpuServer {
    pub(crate) device: wgpu::Device,
    // In-memory cache of compiled pipelines, keyed by kernel id
    // (the id includes the execution mode — see `pipeline`).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    // Multi-stream scheduler wrapping the wgpu backend.
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    // On-disk cache of compiled SPIR-V kernels; `None` when no cache root is
    // configured in the global config.
    #[cfg(feature = "spirv")]
    pub(crate) spirv_cache:
        Option<CompilationCache<(u64, StableHash), cubecl_spirv::SpirvCacheEntry>>,
    pub compilation_options: WgpuCompilationOptions,
    pub(crate) backend: wgpu::Backend,
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}
50
/// Direct server-to-server communication is not supported by the wgpu backend.
impl ServerCommunication for WgpuServer {
    // Explicitly opt out of server-to-server transfers.
    const SERVER_COMM_ENABLED: bool = false;
}
54
55impl WgpuServer {
56 #[allow(clippy::too_many_arguments)]
58 pub fn new(
59 memory_properties: MemoryDeviceProperties,
60 memory_config: MemoryConfiguration,
61 compilation_options: WgpuCompilationOptions,
62 device: wgpu::Device,
63 queue: wgpu::Queue,
64 tasks_max: usize,
65 backend: wgpu::Backend,
66 timing_method: TimingMethod,
67 utilities: ServerUtilities<Self>,
68 ) -> Self {
69 let backend_scheduler = ScheduledWgpuBackend::new(
70 device.clone(),
71 queue.clone(),
72 memory_properties,
73 memory_config,
74 timing_method,
75 tasks_max,
76 utilities.logger.clone(),
77 );
78
79 let config = GlobalConfig::get();
80 let max_streams = config.streaming.max_streams;
81
82 Self {
83 compilation_options,
84 device,
85 pipelines: HashMap::new(),
86 scheduler: SchedulerMultiStream::new(
87 utilities.logger.clone(),
88 backend_scheduler,
89 SchedulerMultiStreamOptions {
90 max_streams,
91 max_tasks: tasks_max,
92 strategy: SchedulerStrategy::Interleave,
93 },
94 ),
95 #[cfg(feature = "spirv")]
96 spirv_cache: {
97 let config = cubecl_runtime::config::GlobalConfig::get();
98 if let Some(cache) = &config.compilation.cache {
99 let root = cache.root();
100 Some(CompilationCache::new(
101 "spirv",
102 CacheOption::default().name("vulkan").root(root),
103 ))
104 } else {
105 None
106 }
107 },
108 backend,
109 utilities: Arc::new(utilities),
110 }
111 }
112
113 fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
114 let resources = bindings
117 .buffers
118 .iter()
119 .map(|b| {
120 let stream = self.scheduler.stream(&b.stream);
121 stream.mem_manage.get_resource(b.clone()).unwrap()
122 })
123 .collect::<Vec<_>>();
124
125 BindingsResource {
126 resources,
127 metadata: bindings.metadata,
128 scalars: bindings.scalars,
129 }
130 }
131
    /// Returns the compute pipeline for `kernel`, compiling it on a cache miss.
    ///
    /// Lookup order: in-memory pipeline cache, then the cached-pipeline loader,
    /// and finally a fresh compilation. Freshly compiled SPIR-V kernels are
    /// persisted back to the on-disk cache when that feature is enabled.
    fn pipeline(
        &mut self,
        kernel: <Self as ComputeServer>::Kernel,
        bindings: &Bindings,
        mode: ExecutionMode,
    ) -> Result<Arc<ComputePipeline>, LaunchError> {
        // The execution mode is folded into the cache key.
        let mut kernel_id = kernel.id();
        kernel_id.mode(mode);

        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
            return Ok(pipeline.clone());
        }

        // `Some(Ok(_))` is a cache hit; `Some(Err(key))` presumably carries the
        // key to insert under after compiling — confirm with `load_cached_pipeline`.
        let cached = self.load_cached_pipeline(&kernel_id, bindings, mode)?;

        if let Some(Ok(pipeline)) = cached {
            self.pipelines.insert(kernel_id, pipeline.clone());
            return Ok(pipeline);
        }

        // Validate launch dimensions against device properties before compiling.
        validate_cube_dim(&self.utilities.properties, &kernel_id)?;
        validate_units(&self.utilities.properties, &kernel_id)?;

        let mut compiler = compiler(self.backend);
        let mut compiled = compiler.compile(self, kernel, mode)?;

        // Attach debug info only when compilation logging is active.
        if self.scheduler.logger.compilation_activated() {
            compiled.debug_info = Some(DebugInformation::new(
                compiler.lang_tag(),
                kernel_id.clone(),
            ));
        }
        self.scheduler.logger.log_compilation(&compiled);

        // Reject kernels requesting more shared memory than the hardware exposes.
        self.validate_shared(&compiled.repr)?;

        let repr = compiled.repr.as_ref().map(|it| it.as_ref());
        let module = self.create_module(&compiled.entrypoint_name, repr, &compiled.source, mode)?;
        let pipeline = self.create_pipeline(&compiled.entrypoint_name, repr, module, bindings);
        self.pipelines.insert(kernel_id.clone(), pipeline.clone());

        // Write the freshly compiled SPIR-V kernel back to the on-disk cache.
        #[cfg(feature = "spirv")]
        if let Some(Err(key)) = cached
            && let Some(crate::AutoRepresentation::SpirV(kernel)) = compiled.repr
        {
            // A `Some(Err(_))` result implies the cache was queried, so it must exist.
            let cache = self.spirv_cache.as_mut().unwrap();
            let result = cache.insert(
                key,
                cubecl_spirv::SpirvCacheEntry::new(compiled.entrypoint_name, kernel),
            );
            if let Err(err) = result {
                // Cache persistence is best-effort; failure only costs recompilation.
                log::warn!("Unable to save the SPIR-V {err:?}");
            }
        }

        Ok(pipeline)
    }
211
212 fn validate_shared(&self, repr: &Option<crate::AutoRepresentation>) -> Result<(), LaunchError> {
213 let shared_bytes = repr.as_ref().map(|repr| match repr {
214 AutoRepresentation::Wgsl(repr) => repr.shared_memory_bytes(),
215 #[cfg(feature = "msl")]
216 AutoRepresentation::Msl(repr) => repr.shared_memory_size(),
217 #[cfg(feature = "spirv")]
218 AutoRepresentation::SpirV(repr) => repr.shared_size,
219 });
220 let max_smem = self.utilities.properties.hardware.max_shared_memory_size;
221 if let Some(shared_bytes) = shared_bytes
222 && shared_bytes > max_smem
223 {
224 Err(ResourceLimitError::SharedMemory {
225 requested: shared_bytes,
226 max: max_smem,
227 backtrace: BackTrace::capture(),
228 }
229 .into())
230 } else {
231 Ok(())
232 }
233 }
234}
235
236impl ComputeServer for WgpuServer {
237 type Kernel = Box<dyn CubeTask<AutoCompiler>>;
238 type Storage = WgpuStorage;
239 type Info = wgpu::Backend;
240
241 fn logger(&self) -> Arc<ServerLogger> {
242 self.scheduler.logger.clone()
243 }
244
245 fn utilities(&self) -> Arc<ServerUtilities<Self>> {
246 self.utilities.clone()
247 }
248
    /// Staging buffers are not supported by this server; always returns
    /// [`IoError::UnsupportedIoOperation`].
    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        })
    }
255
256 fn create(
257 &mut self,
258 descriptors: Vec<AllocationDescriptor<'_>>,
259 stream_id: StreamId,
260 ) -> Result<Vec<Allocation>, IoError> {
261 let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
262 let strides = descriptors
263 .iter()
264 .map(|desc| contiguous_strides(desc.shape))
265 .collect::<Vec<_>>();
266 let sizes = descriptors
267 .iter()
268 .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
269 .collect::<Vec<_>>();
270 let total_size = sizes
271 .iter()
272 .map(|it| it.next_multiple_of(align))
273 .sum::<usize>();
274
275 let stream = self.scheduler.stream(&stream_id);
276 let mem_handle = stream.empty(total_size as u64, stream_id)?;
277 let handles = offset_handles(mem_handle, &sizes, align);
278
279 Ok(handles
280 .into_iter()
281 .zip(strides)
282 .map(|(handle, strides)| Allocation::new(handle, strides))
283 .collect())
284 }
285
    /// Reads the given bindings back to the host asynchronously.
    ///
    /// Only contiguous (row-major) layouts are supported. Every stream that
    /// owns one of the bindings is executed before the read is issued so the
    /// data is up to date; the read itself happens on the caller's stream.
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
        // Streams to flush: the caller's plus every stream owning a binding.
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            // Strided (non-contiguous) reads are not supported.
            if contiguous_strides(desc.shape) != desc.strides {
                return Box::pin(async {
                    Err(IoError::UnsupportedStrides {
                        backtrace: BackTrace::capture(),
                    })
                });
            }
            if !streams.contains(&desc.binding.stream) {
                streams.push(desc.binding.stream);
            }
            let stream = self.scheduler.stream(&desc.binding.stream);
            // Errors are surfaced through the returned future, not a panic.
            let resource = match stream.mem_manage.get_resource(desc.binding) {
                Ok(val) => val,
                Err(err) => return Box::pin(async move { Err(err) }),
            };
            resources.push((resource, desc.shape.to_vec(), desc.elem_size));
        }

        // Flush all involved streams, then perform the read on the caller's stream.
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }
316
317 fn write(
318 &mut self,
319 descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
320 stream_id: StreamId,
321 ) -> Result<(), IoError> {
322 for (desc, data) in descriptors {
323 if contiguous_strides(desc.shape) != desc.strides {
324 return Err(IoError::UnsupportedStrides {
325 backtrace: BackTrace::capture(),
326 });
327 }
328
329 let stream = self.scheduler.stream(&desc.binding.stream);
330 let resource = stream.mem_manage.get_resource(desc.binding.clone())?;
331 let task = ScheduleTask::Write {
332 data,
333 buffer: resource,
334 };
335
336 self.scheduler.register(stream_id, task, [].into_iter());
337 }
338
339 Ok(())
340 }
341
342 fn get_resource(
343 &mut self,
344 binding: Binding,
345 stream_id: StreamId,
346 ) -> BindingResource<WgpuResource> {
347 let mut streams = vec![stream_id];
348 if binding.stream != stream_id {
349 streams.push(binding.stream);
350 }
351 self.scheduler.execute_streams(streams);
352 let stream = self.scheduler.stream(&binding.stream);
353 let resource = stream.mem_manage.get_resource(binding.clone()).unwrap();
354 BindingResource::new(binding, resource)
355 }
356
357 unsafe fn launch(
358 &mut self,
359 kernel: Self::Kernel,
360 count: CubeCount,
361 bindings: Bindings,
362 mode: ExecutionMode,
363 stream_id: StreamId,
364 ) -> Result<(), LaunchError> {
365 let pipeline = self.pipeline(kernel, &bindings, mode)?;
366 let buffers = bindings.buffers.clone();
367 let resources = self.prepare_bindings(bindings);
368 let task = ScheduleTask::Execute {
369 pipeline,
370 count,
371 resources,
372 };
373
374 self.scheduler.register(stream_id, task, buffers.iter());
375
376 Ok(())
377 }
378
379 fn flush(&mut self, stream_id: StreamId) {
380 self.scheduler.execute_streams(vec![stream_id]);
381 let stream = self.scheduler.stream(&stream_id);
382 stream.flush()
383 }
384
385 fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
387 self.scheduler.execute_streams(vec![stream_id]);
388 let stream = self.scheduler.stream(&stream_id);
389 stream.sync()
390 }
391
392 fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
393 self.scheduler.execute_streams(vec![stream_id]);
394 let stream = self.scheduler.stream(&stream_id);
395 stream.start_profile()
396 }
397
398 fn end_profile(
399 &mut self,
400 stream_id: StreamId,
401 token: ProfilingToken,
402 ) -> Result<ProfileDuration, ProfileError> {
403 self.scheduler.execute_streams(vec![stream_id]);
404 let stream = self.scheduler.stream(&stream_id);
405 stream.end_profile(token)
406 }
407
408 fn memory_usage(
409 &mut self,
410 stream_id: StreamId,
411 ) -> cubecl_runtime::memory_management::MemoryUsage {
412 self.scheduler.execute_streams(vec![stream_id]);
413 let stream = self.scheduler.stream(&stream_id);
414 stream.mem_manage.memory_usage()
415 }
416
417 fn memory_cleanup(&mut self, stream_id: StreamId) {
418 self.scheduler.execute_streams(vec![stream_id]);
419 let stream = self.scheduler.stream(&stream_id);
420 stream.mem_manage.memory_cleanup(true);
421 }
422
423 fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
424 self.scheduler.execute_streams(vec![stream_id]);
425 let stream = self.scheduler.stream(&stream_id);
426 stream.mem_manage.mode(mode);
427 }
428}
429
/// Selects the kernel compiler for the active wgpu backend: SPIR-V on Vulkan
/// and MSL on Metal when the corresponding features are enabled, WGSL for
/// everything else.
fn compiler(backend: wgpu::Backend) -> AutoCompiler {
    match backend {
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}
439
/// Computes contiguous (row-major) strides for `shape`, in elements.
///
/// The innermost dimension has stride 1 and each outer stride is the product
/// of all inner dimension sizes. An empty shape yields an empty stride vector.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` guards the rank == 0 case: `rank - 1` would underflow
    // `usize` and panic (debug) or index out of bounds (release).
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}