1use super::storage::{WgpuResource, WgpuStorage};
2use crate::AutoCompiler;
3use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
4use alloc::sync::Arc;
5use cubecl_common::backtrace::BackTrace;
6use cubecl_common::bytes::Bytes;
7use cubecl_common::profile::{ProfileDuration, TimingMethod};
8use cubecl_common::stream_id::StreamId;
9use cubecl_core::future::DynFut;
10use cubecl_core::server::{Allocation, AllocationDescriptor, ExecutionError, IoError, LaunchError};
11use cubecl_core::server::{ProfileError, ProfilingToken, ServerCommunication, ServerUtilities};
12use cubecl_core::{
13 MemoryConfiguration, WgpuCompilationOptions,
14 prelude::*,
15 server::{Binding, Bindings, CopyDescriptor},
16};
17use cubecl_runtime::compiler::CompilationError;
18use cubecl_runtime::logging::ServerLogger;
19use cubecl_runtime::memory_management::{MemoryAllocationMode, offset_handles};
20use cubecl_runtime::stream::scheduler::{
21 SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy,
22};
23use cubecl_runtime::{compiler::CubeTask, config::GlobalConfig};
24use cubecl_runtime::{
25 memory_management::MemoryDeviceProperties, server::ComputeServer, storage::BindingResource,
26};
27use hashbrown::HashMap;
28use wgpu::ComputePipeline;
29
/// Compute server for the wgpu runtime.
///
/// Owns the compiled-pipeline cache and a multi-stream scheduler that
/// dispatches scheduled tasks to the underlying wgpu device.
#[derive(Debug)]
pub struct WgpuServer {
    // Handle to the wgpu device, used for pipeline creation and limit queries.
    pub(crate) device: wgpu::Device,
    // Cache of compiled compute pipelines, keyed by kernel id (the id also
    // encodes the execution mode — see `pipeline()`).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    // Multi-stream scheduler that orders and executes tasks on the wgpu backend.
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    // Options forwarded to kernel compilation.
    pub compilation_options: WgpuCompilationOptions,
    // Active wgpu backend (Vulkan, Metal, ...); selects the kernel compiler.
    pub(crate) backend: wgpu::Backend,
    // Shared server utilities (logger, ...), handed out via `utilities()`.
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}
40
impl ServerCommunication for WgpuServer {
    // Direct server-to-server communication is not supported by this backend.
    const SERVER_COMM_ENABLED: bool = false;
}
44
45impl WgpuServer {
46 #[allow(clippy::too_many_arguments)]
48 pub fn new(
49 memory_properties: MemoryDeviceProperties,
50 memory_config: MemoryConfiguration,
51 compilation_options: WgpuCompilationOptions,
52 device: wgpu::Device,
53 queue: wgpu::Queue,
54 tasks_max: usize,
55 backend: wgpu::Backend,
56 timing_method: TimingMethod,
57 utilities: ServerUtilities<Self>,
58 ) -> Self {
59 let backend_scheduler = ScheduledWgpuBackend::new(
60 device.clone(),
61 queue.clone(),
62 memory_properties,
63 memory_config,
64 timing_method,
65 tasks_max,
66 utilities.logger.clone(),
67 );
68
69 let config = GlobalConfig::get();
70 let max_streams = config.streaming.max_streams;
71
72 Self {
73 compilation_options,
74 device,
75 pipelines: HashMap::new(),
76 scheduler: SchedulerMultiStream::new(
77 utilities.logger.clone(),
78 backend_scheduler,
79 SchedulerMultiStreamOptions {
80 max_streams,
81 max_tasks: tasks_max,
82 strategy: SchedulerStrategy::Interleave,
83 },
84 ),
85 backend,
86 utilities: Arc::new(utilities),
87 }
88 }
89
90 fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
91 let resources = bindings
94 .buffers
95 .iter()
96 .map(|b| {
97 let stream = self.scheduler.stream(&b.stream);
98 stream.mem_manage.get_resource(b.clone()).unwrap()
99 })
100 .collect::<Vec<_>>();
101
102 BindingsResource {
103 resources,
104 metadata: bindings.metadata,
105 scalars: bindings.scalars,
106 }
107 }
108
109 fn pipeline(
110 &mut self,
111 kernel: <Self as ComputeServer>::Kernel,
112 mode: ExecutionMode,
113 ) -> Result<Arc<ComputePipeline>, CompilationError> {
114 let mut kernel_id = kernel.id();
115 kernel_id.mode(mode);
116
117 if let Some(pipeline) = self.pipelines.get(&kernel_id) {
118 return Ok(pipeline.clone());
119 }
120
121 let mut compiler = compiler(self.backend);
122 let mut compile = compiler.compile(self, kernel, mode)?;
123
124 if self.scheduler.logger.compilation_activated() {
125 compile.debug_info = Some(DebugInformation::new(
126 compiler.lang_tag(),
127 kernel_id.clone(),
128 ));
129 }
130 self.scheduler.logger.log_compilation(&compile);
131 let pipeline = self.create_pipeline(compile, mode)?;
154 self.pipelines.insert(kernel_id.clone(), pipeline.clone());
155
156 Ok(pipeline)
157 }
158}
159
impl ComputeServer for WgpuServer {
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    type Info = wgpu::Backend;

    /// Returns the logger shared with the scheduler.
    fn logger(&self) -> Arc<ServerLogger> {
        self.scheduler.logger.clone()
    }

    /// Returns the shared server utilities.
    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    /// Staging buffers are not supported by this backend; always errors.
    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        })
    }

    /// Allocates one backing buffer large enough for all `descriptors` and
    /// splits it into per-descriptor handles with contiguous strides.
    ///
    /// Each sub-allocation size is rounded up to the device's
    /// `min_storage_buffer_offset_alignment` so every handle starts at a
    /// legal storage-buffer offset.
    fn create(
        &mut self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<Allocation>, IoError> {
        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
        let strides = descriptors
            .iter()
            .map(|desc| contiguous_strides(desc.shape))
            .collect::<Vec<_>>();
        let sizes = descriptors
            .iter()
            .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
            .collect::<Vec<_>>();
        let total_size = sizes
            .iter()
            .map(|it| it.next_multiple_of(align))
            .sum::<usize>();

        let stream = self.scheduler.stream(&stream_id);
        let mem_handle = stream.empty(total_size as u64, stream_id)?;
        let handles = offset_handles(mem_handle, &sizes, align);

        Ok(handles
            .into_iter()
            .zip(strides)
            .map(|(handle, strides)| Allocation::new(handle, strides))
            .collect())
    }

    /// Reads the given copy descriptors back as bytes.
    ///
    /// Only contiguous layouts are supported; a non-contiguous descriptor
    /// yields `IoError::UnsupportedStrides`. Every stream referenced by a
    /// binding (plus `stream_id`) is executed first so pending work is
    /// submitted before the read is issued on `stream_id`'s stream.
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Box::pin(async {
                    Err(IoError::UnsupportedStrides {
                        backtrace: BackTrace::capture(),
                    })
                });
            }
            // Collect each binding's stream so it gets executed below.
            if !streams.contains(&desc.binding.stream) {
                streams.push(desc.binding.stream);
            }
            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = match stream.mem_manage.get_resource(desc.binding) {
                Ok(val) => val,
                Err(err) => return Box::pin(async move { Err(err) }),
            };
            resources.push((resource, desc.shape.to_vec(), desc.elem_size));
        }

        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }

    /// Writes each `Bytes` payload into its destination binding.
    ///
    /// Only contiguous layouts are supported. Each write is registered as a
    /// task on `stream_id` with no extra binding dependencies
    /// (`[].into_iter()`), so ordering relies on the scheduler's task order.
    fn write(
        &mut self,
        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
        stream_id: StreamId,
    ) -> Result<(), IoError> {
        for (desc, data) in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Err(IoError::UnsupportedStrides {
                    backtrace: BackTrace::capture(),
                });
            }

            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding.clone())?;
            let task = ScheduleTask::Write {
                data,
                buffer: resource,
            };

            self.scheduler.register(stream_id, task, [].into_iter());
        }

        Ok(())
    }

    /// Resolves `binding` to its underlying wgpu resource, executing the
    /// calling stream and the binding's stream first (when they differ) so
    /// pending work on either is submitted.
    ///
    /// NOTE(review): `unwrap` panics if the resource is gone — presumably
    /// the binding is guaranteed live by the caller; confirm.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> BindingResource<WgpuResource> {
        let mut streams = vec![stream_id];
        if binding.stream != stream_id {
            streams.push(binding.stream);
        }
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&binding.stream);
        let resource = stream.mem_manage.get_resource(binding.clone()).unwrap();
        BindingResource::new(binding, resource)
    }

    /// Compiles (or fetches) the kernel's pipeline and registers an execute
    /// task on `stream_id`, depending on the kernel's buffer bindings.
    ///
    /// # Safety
    /// Caller must uphold the `ComputeServer::launch` contract — the
    /// bindings must match what the kernel expects (see trait docs).
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let pipeline = self.pipeline(kernel, mode)?;
        // Keep the buffer list: it is passed as the task's dependencies.
        let buffers = bindings.buffers.clone();
        let resources = self.prepare_bindings(bindings);
        let task = ScheduleTask::Execute {
            pipeline,
            count,
            resources,
        };

        self.scheduler.register(stream_id, task, buffers.iter());

        Ok(())
    }

    /// Executes pending tasks on `stream_id` and flushes the stream.
    fn flush(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.flush()
    }

    /// Executes pending tasks on `stream_id` and returns a future that
    /// resolves when the stream's submitted work completes.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.sync()
    }

    /// Drains pending tasks, then starts a profiling scope on the stream.
    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.start_profile()
    }

    /// Drains pending tasks, then closes the profiling scope identified by
    /// `token` and returns its measured duration.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.end_profile(token)
    }

    /// Reports memory usage for `stream_id`'s memory manager after draining
    /// its pending tasks.
    fn memory_usage(
        &mut self,
        stream_id: StreamId,
    ) -> cubecl_runtime::memory_management::MemoryUsage {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_usage()
    }

    /// Drains pending tasks, then asks the memory manager to clean up.
    // NOTE(review): the `true` flag presumably forces an explicit/aggressive
    // cleanup — confirm against `memory_cleanup`'s signature.
    fn memory_cleanup(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_cleanup(true);
    }

    /// Drains pending tasks, then switches the stream's allocation mode.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.mode(mode);
    }
}
353
/// Selects the kernel compiler matching the active wgpu backend.
///
/// Vulkan maps to SPIR-V and Metal to MSL, but only when the corresponding
/// cargo feature (`spirv` / `msl`) is compiled in; every other case falls
/// back to the WGSL compiler.
fn compiler(backend: wgpu::Backend) -> AutoCompiler {
    match backend {
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}
363
/// Computes row-major (contiguous) strides for `shape`, in elements.
///
/// The last dimension has stride 1 and each preceding dimension's stride is
/// the product of all trailing dimension sizes. An empty shape yields an
/// empty stride vector.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` guards rank == 0: the original `0..rank - 1` would
    // underflow, panicking in debug builds and wrapping in release builds.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}