1use super::storage::{WgpuResource, WgpuStorage};
2use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
3use crate::{AutoCompiler, AutoRepresentation};
4use alloc::sync::Arc;
5use cubecl_common::{
6 backtrace::BackTrace,
7 bytes::Bytes,
8 profile::{ProfileDuration, TimingMethod},
9 stream_id::StreamId,
10};
11use cubecl_core::server::{Binding, StreamErrorMode};
12use cubecl_core::zspace::Shape;
13use cubecl_core::{
14 MemoryConfiguration, WgpuCompilationOptions,
15 future::DynFut,
16 prelude::*,
17 server::{
18 CopyDescriptor, IoError, KernelArguments, LaunchError, ProfileError, ProfilingToken,
19 ResourceLimitError, ServerCommunication, ServerError, ServerUtilities,
20 },
21 zspace::{Strides, strides},
22};
23#[cfg(feature = "spirv")]
24use cubecl_core::{cache::CacheOption, compilation_cache::CompilationCache, hash::StableHash};
25use cubecl_ir::MemoryDeviceProperties;
26use cubecl_runtime::allocator::ContiguousMemoryLayoutPolicy;
27use cubecl_runtime::memory_management::{ManagedMemoryHandle, MemoryUsage};
28use cubecl_runtime::{
29 compiler::CubeTask,
30 config::{CubeClRuntimeConfig, RuntimeConfig},
31 logging::ServerLogger,
32 memory_management::MemoryAllocationMode,
33 server::ComputeServer,
34 storage::ManagedResource,
35 stream::scheduler::{SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy},
36 validation::{validate_cube_dim, validate_units},
37};
38use hashbrown::HashMap;
39use wgpu::ComputePipeline;
40
/// Compute server for the wgpu runtime.
///
/// Owns the compiled-pipeline cache and the multi-stream scheduler that
/// feeds work to the underlying `wgpu` device/queue.
#[derive(Debug)]
pub struct WgpuServer {
    pub(crate) device: wgpu::Device,
    // Scratch buffer reused by `launch` to collect the stream ids of the
    // kernel's buffer arguments; cleared before each use to avoid
    // reallocating per launch.
    streams_pool: Vec<StreamId>,
    // Cache of compiled pipelines, keyed by kernel id (which includes the
    // execution mode — see `pipeline`).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    // Multi-stream scheduler driving the wgpu backend.
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    // On-disk cache of compiled SPIR-V modules; `None` when compilation
    // caching is disabled in the runtime configuration.
    #[cfg(feature = "spirv")]
    pub(crate) spirv_cache:
        Option<CompilationCache<(u64, StableHash), cubecl_spirv::SpirvCacheEntry>>,
    pub compilation_options: WgpuCompilationOptions,
    pub(crate) backend: wgpu::Backend,
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}
56
impl ServerCommunication for WgpuServer {
    // Direct server-to-server communication is not supported by the wgpu
    // backend; data exchange has to go through host reads/writes.
    const SERVER_COMM_ENABLED: bool = false;
}
60
impl WgpuServer {
    /// Creates a new server backed by the given wgpu device/queue pair.
    ///
    /// Builds the scheduling backend, the multi-stream scheduler (stream
    /// count comes from the runtime config, task batching from
    /// `tasks_max`), and — with the `spirv` feature — the on-disk SPIR-V
    /// compilation cache when caching is enabled in the runtime config.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        memory_properties: MemoryDeviceProperties,
        memory_config: MemoryConfiguration,
        compilation_options: WgpuCompilationOptions,
        device: wgpu::Device,
        queue: wgpu::Queue,
        tasks_max: usize,
        backend: wgpu::Backend,
        timing_method: TimingMethod,
        utilities: ServerUtilities<Self>,
    ) -> Self {
        let backend_scheduler = ScheduledWgpuBackend::new(
            device.clone(),
            queue.clone(),
            memory_properties,
            memory_config,
            timing_method,
            tasks_max,
            utilities.logger.clone(),
        );

        let config = CubeClRuntimeConfig::get();
        let max_streams = config.streaming.max_streams;

        Self {
            compilation_options,
            streams_pool: Vec::new(),
            device,
            pipelines: HashMap::new(),
            scheduler: SchedulerMultiStream::new(
                utilities.logger.clone(),
                backend_scheduler,
                SchedulerMultiStreamOptions {
                    max_streams,
                    max_tasks: tasks_max,
                    strategy: SchedulerStrategy::Interleave,
                },
            ),
            #[cfg(feature = "spirv")]
            spirv_cache: {
                // Re-reads the runtime config here because this block is
                // feature-gated; only present when compilation caching is
                // configured.
                let config = cubecl_runtime::config::CubeClRuntimeConfig::get();
                if let Some(cache) = &config.compilation.cache {
                    let root = cache.root();
                    Some(CompilationCache::new(
                        "spirv",
                        CacheOption::default().name("vulkan").root(root),
                    ))
                } else {
                    None
                }
            },
            backend,
            utilities: Arc::new(utilities),
        }
    }

    /// Resolves every buffer binding into its backing GPU resource.
    ///
    /// Consumes `bindings.buffers`; each buffer is looked up on the stream
    /// it belongs to. Fails with the first [`IoError`] returned by the
    /// memory manager.
    fn prepare_bindings(&mut self, bindings: KernelArguments) -> Result<BindingsResource, IoError> {
        let mut resources = Vec::with_capacity(bindings.buffers.len());

        for b in bindings.buffers.into_iter() {
            let stream = self.scheduler.stream(&b.stream);
            let resource = stream.mem_manage.get_resource(b)?;
            resources.push(resource);
        }

        Ok(BindingsResource {
            resources,
            info: bindings.info,
        })
    }

    /// Returns the compute pipeline for `kernel`, compiling it on a miss.
    ///
    /// Lookup order: in-memory `pipelines` cache, then the persistent
    /// pipeline cache (`load_cached_pipeline`), then a fresh compile.
    /// Freshly compiled kernels are validated (cube dim, unit count,
    /// shared-memory budget) and, with the `spirv` feature, persisted to
    /// the on-disk SPIR-V cache.
    fn pipeline(
        &mut self,
        kernel: <Self as ComputeServer>::Kernel,
        bindings: &KernelArguments,
        mode: ExecutionMode,
    ) -> Result<Arc<ComputePipeline>, LaunchError> {
        // The execution mode is folded into the id so checked/unchecked
        // variants of the same kernel cache separately.
        let mut kernel_id = kernel.id();
        kernel_id.mode(mode);

        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
            return Ok(pipeline.clone());
        }

        // `Some(Ok(_))` = cache hit; `Some(Err(key))` = miss with the key
        // to store the compile result under; `None` = caching unavailable.
        let cached = self.load_cached_pipeline(&kernel_id, bindings, mode)?;

        if let Some(Ok(pipeline)) = cached {
            self.pipelines.insert(kernel_id, pipeline.clone());
            return Ok(pipeline);
        }

        validate_cube_dim(&self.utilities.properties, &kernel_id)?;
        validate_units(&self.utilities.properties, &kernel_id)?;

        let mut compiler = compiler(self.backend, &self.compilation_options);
        let mut compiled = compiler.compile(self, kernel, mode)?;

        if self.scheduler.logger.compilation_activated() {
            compiled.debug_info = Some(DebugInformation::new(
                compiler.lang_tag(),
                kernel_id.clone(),
            ));
        }
        self.scheduler.logger.log_compilation(&compiled);

        self.validate_shared(&compiled.repr)?;

        let repr = compiled.repr.as_ref().map(|it| it.as_ref());
        let module = self.create_module(&compiled.entrypoint_name, repr, &compiled.source, mode)?;
        let pipeline = self.create_pipeline(&compiled.entrypoint_name, repr, module, bindings);
        self.pipelines.insert(kernel_id.clone(), pipeline.clone());

        // Persist the freshly compiled SPIR-V under the key handed back by
        // the earlier cache miss. NOTE(review): the `unwrap` assumes
        // `load_cached_pipeline` only returns `Some(Err(_))` when
        // `spirv_cache` is populated — confirm against that method.
        #[cfg(feature = "spirv")]
        if let Some(Err(key)) = cached
            && let Some(crate::AutoRepresentation::SpirV(kernel)) = compiled.repr
        {
            let cache = self.spirv_cache.as_mut().unwrap();
            let result = cache.insert(
                key,
                cubecl_spirv::SpirvCacheEntry::new(compiled.entrypoint_name, kernel),
            );
            // Cache-write failures are non-fatal: log and keep going.
            if let Err(err) = result {
                log::warn!("Unable to save the SPIR-V {err:?}");
            }
        }

        Ok(pipeline)
    }

    /// Checks the compiled kernel's shared-memory usage against the
    /// device limit.
    ///
    /// Returns `Ok(())` when there is no representation (nothing to
    /// check) or the usage fits; otherwise a
    /// [`ResourceLimitError::SharedMemory`] launch error.
    fn validate_shared(&self, repr: &Option<crate::AutoRepresentation>) -> Result<(), LaunchError> {
        let shared_bytes = repr.as_ref().map(|repr| match repr {
            AutoRepresentation::Wgsl(repr) => repr.shared_memory_bytes(),
            #[cfg(feature = "msl")]
            AutoRepresentation::Msl(repr) => repr.shared_memory_size(),
            #[cfg(feature = "spirv")]
            AutoRepresentation::SpirV(repr) => repr.shared_size,
        });
        let max_smem = self.utilities.properties.hardware.max_shared_memory_size;
        if let Some(shared_bytes) = shared_bytes
            && shared_bytes > max_smem
        {
            Err(ResourceLimitError::SharedMemory {
                requested: shared_bytes,
                max: max_smem,
                backtrace: BackTrace::capture(),
            }
            .into())
        } else {
            Ok(())
        }
    }
}
241
impl ComputeServer for WgpuServer {
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    type MemoryLayoutPolicy = ContiguousMemoryLayoutPolicy;
    type Info = wgpu::Backend;

    fn logger(&self) -> Arc<ServerLogger> {
        self.scheduler.logger.clone()
    }

    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    /// Staging buffers are not supported on wgpu; always fails with
    /// `UnsupportedIoOperation`.
    fn staging(
        &mut self,
        _sizes: &[usize],
        _stream_id: StreamId,
    ) -> Result<Vec<Bytes>, ServerError> {
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        }
        .into())
    }

    /// Allocates `size` bytes on the stream and binds the reservation to
    /// the given memory handle.
    fn initialize_memory(&mut self, memory: ManagedMemoryHandle, size: u64, stream_id: StreamId) {
        let stream = self.scheduler.stream(&stream_id);
        // NOTE(review): allocation failure panics here rather than being
        // surfaced as a stream error — confirm that is intended.
        let reserved = stream.empty(size).unwrap();
        stream.mem_manage.bind(reserved, memory);
    }

    /// Reads the described regions back to the host.
    ///
    /// Only contiguous layouts are supported; non-contiguous strides fail
    /// with `UnsupportedStrides`. All streams owning a source handle are
    /// flushed (`execute_streams`) before the read is issued on
    /// `stream_id`'s stream.
    fn read(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, ServerError>> {
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            if contiguous_strides(&desc.shape) != desc.strides {
                return Box::pin(async {
                    Err(IoError::UnsupportedStrides {
                        backtrace: BackTrace::capture(),
                    }
                    .into())
                });
            }
            // Collect every distinct stream involved so pending work on
            // them is executed before reading.
            if !streams.contains(&desc.handle.stream) {
                streams.push(desc.handle.stream);
            }
            let stream = self.scheduler.stream(&desc.handle.stream);
            let resource = match stream.mem_manage.get_resource(desc.handle) {
                Ok(val) => val,
                Err(err) => return Box::pin(async move { Err(err.into()) }),
            };
            resources.push((resource, desc.shape, desc.elem_size));
        }

        self.scheduler.execute_streams(streams);

        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }

    /// Writes host data into the described regions.
    ///
    /// Only contiguous layouts are supported. Failures are recorded on the
    /// destination handle's stream (and abort the remaining writes) rather
    /// than returned; the write itself is registered as a task on
    /// `stream_id`.
    fn write(&mut self, descriptors: Vec<(CopyDescriptor, Bytes)>, stream_id: StreamId) {
        for (desc, data) in descriptors {
            let stream = self.scheduler.stream(&desc.handle.stream);

            if contiguous_strides(&desc.shape) != desc.strides {
                stream.error(ServerError::Io(IoError::UnsupportedStrides {
                    backtrace: BackTrace::capture(),
                }));
                return;
            }

            let resource = match stream.mem_manage.get_resource(desc.handle) {
                Ok(r) => r,
                Err(err) => {
                    stream.error(ServerError::Io(err));
                    return;
                }
            };
            let task = ScheduleTask::Write {
                data,
                buffer: resource,
            };

            self.scheduler.register(stream_id, task, &[]);
        }
    }

    /// Resolves a binding to its backing GPU resource, executing any
    /// pending work on both the caller's stream and the binding's stream
    /// first.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> Result<ManagedResource<WgpuResource>, ServerError> {
        let mut streams = vec![stream_id];
        if binding.stream != stream_id {
            streams.push(binding.stream);
        }
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&binding.stream);
        let memory = binding.memory.clone();
        let resource = stream.mem_manage.get_resource(binding)?;

        Ok(ManagedResource::new(memory, resource))
    }

    /// Queues a kernel launch on `stream_id`.
    ///
    /// Compilation or binding failures are recorded as errors on the
    /// stream instead of being returned. The streams of all buffer
    /// arguments are passed to the scheduler as dependencies of the task.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        args: KernelArguments,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let pipeline = match self.pipeline(kernel, &args, mode) {
            Ok(val) => val,
            Err(err) => {
                let stream = self.scheduler.stream(&stream_id);
                stream.errors.push(ServerError::Launch(err));
                return;
            }
        };

        // Gather dependency streams before `args` is consumed below.
        self.streams_pool.clear();
        args.buffers
            .iter()
            .for_each(|b| self.streams_pool.push(b.stream));

        let resources = match self.prepare_bindings(args) {
            Ok(val) => val,
            Err(err) => {
                let stream = self.scheduler.stream(&stream_id);
                stream.errors.push(ServerError::Io(err));
                return;
            }
        };
        let task = ScheduleTask::Execute {
            pipeline,
            count,
            resources,
        };

        self.scheduler.register(stream_id, task, &self.streams_pool);
    }

    /// Executes pending work on the stream and flushes it to the device,
    /// surfacing (not ignoring) any accumulated stream errors.
    fn flush(&mut self, stream_id: StreamId) -> Result<(), ServerError> {
        self.scheduler.execute_streams(vec![stream_id]);

        let stream = self.scheduler.stream(&stream_id);

        stream.flush(StreamErrorMode {
            ignore: false,
            flush: true,
        })
    }

    /// Executes pending work and returns a future resolving when the
    /// stream's submitted work completes.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ServerError>> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);

        stream.sync()
    }

    fn start_profile(&mut self, stream_id: StreamId) -> Result<ProfilingToken, ServerError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.start_profile()
    }

    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);

        stream.end_profile(token)
    }

    fn memory_usage(&mut self, stream_id: StreamId) -> Result<MemoryUsage, ServerError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        Ok(stream.mem_manage.memory_usage())
    }

    fn memory_cleanup(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_cleanup(true);
    }

    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.mode(mode);
    }
}
446
447fn compiler(backend: wgpu::Backend, options: &WgpuCompilationOptions) -> AutoCompiler {
448 let _ = options; match backend {
450 #[cfg(feature = "spirv")]
451 wgpu::Backend::Vulkan if options.supports_vulkan => AutoCompiler::SpirV(Default::default()),
452 #[cfg(feature = "msl")]
453 wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
454 _ => AutoCompiler::Wgsl(Default::default()),
455 }
456}
457
458pub(crate) fn contiguous_strides(shape: &Shape) -> Strides {
459 let rank = shape.len();
460 let mut strides = strides![1; rank];
461 for i in (0..rank - 1).rev() {
462 strides[i] = strides[i + 1] * shape[i + 1];
463 }
464 strides
465}