1use super::storage::{WgpuResource, WgpuStorage};
2use crate::AutoCompiler;
3use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
4use alloc::sync::Arc;
5use cubecl_common::{
6 backtrace::BackTrace,
7 bytes::Bytes,
8 profile::{ProfileDuration, TimingMethod},
9 stream_id::StreamId,
10};
11use cubecl_core::{
12 MemoryConfiguration, WgpuCompilationOptions,
13 future::DynFut,
14 prelude::*,
15 server::{
16 Allocation, AllocationDescriptor, Binding, Bindings, CopyDescriptor, ExecutionError,
17 IoError, LaunchError, ProfileError, ProfilingToken, ServerCommunication, ServerUtilities,
18 },
19};
20use cubecl_ir::MemoryDeviceProperties;
21use cubecl_runtime::{
22 compiler::{CompilationError, CubeTask},
23 config::GlobalConfig,
24 logging::ServerLogger,
25 memory_management::{MemoryAllocationMode, offset_handles},
26 server::ComputeServer,
27 storage::BindingResource,
28 stream::scheduler::{SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy},
29};
30use hashbrown::HashMap;
31use wgpu::ComputePipeline;
32
/// Compute server backed by wgpu.
///
/// Owns the compiled-pipeline cache and a multi-stream scheduler that orders
/// and executes tasks ([`ScheduleTask`]) against the wgpu device.
#[derive(Debug)]
pub struct WgpuServer {
    // Raw wgpu device; used for pipeline creation and for querying limits
    // (e.g. `min_storage_buffer_offset_alignment` in `create`).
    pub(crate) device: wgpu::Device,
    // Cache of compiled compute pipelines, keyed by kernel id. The execution
    // mode is folded into the id (see `pipeline`), so checked/unchecked
    // variants of the same kernel are cached separately.
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    // Multi-stream scheduler; all reads, writes and kernel launches are
    // registered here and executed per stream.
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    // Options forwarded to the kernel compiler.
    pub compilation_options: WgpuCompilationOptions,
    // Active wgpu backend; selects the compiler (SPIR-V / MSL / WGSL).
    pub(crate) backend: wgpu::Backend,
    // Shared server utilities (logger, ...), handed out via `utilities()`.
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}
43
/// Direct server-to-server communication is not supported by the wgpu server.
impl ServerCommunication for WgpuServer {
    const SERVER_COMM_ENABLED: bool = false;
}
47
impl WgpuServer {
    /// Creates a new wgpu server.
    ///
    /// Builds the scheduled backend from the device/queue and the memory
    /// settings, then wraps it in a multi-stream scheduler. The maximum
    /// number of streams comes from the global config; `tasks_max` bounds
    /// both the backend's task queue and the scheduler.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        memory_properties: MemoryDeviceProperties,
        memory_config: MemoryConfiguration,
        compilation_options: WgpuCompilationOptions,
        device: wgpu::Device,
        queue: wgpu::Queue,
        tasks_max: usize,
        backend: wgpu::Backend,
        timing_method: TimingMethod,
        utilities: ServerUtilities<Self>,
    ) -> Self {
        let backend_scheduler = ScheduledWgpuBackend::new(
            device.clone(),
            queue.clone(),
            memory_properties,
            memory_config,
            timing_method,
            tasks_max,
            utilities.logger.clone(),
        );

        let config = GlobalConfig::get();
        let max_streams = config.streaming.max_streams;

        Self {
            compilation_options,
            device,
            pipelines: HashMap::new(),
            scheduler: SchedulerMultiStream::new(
                utilities.logger.clone(),
                backend_scheduler,
                SchedulerMultiStreamOptions {
                    max_streams,
                    max_tasks: tasks_max,
                    strategy: SchedulerStrategy::Interleave,
                },
            ),
            backend,
            utilities: Arc::new(utilities),
        }
    }

    /// Resolves every buffer binding to its backing wgpu resource.
    ///
    /// Each binding is looked up on the stream it belongs to; metadata and
    /// scalars are passed through unchanged.
    ///
    /// # Panics
    /// Panics if a buffer's resource cannot be retrieved from its stream.
    fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
        let resources = bindings
            .buffers
            .iter()
            .map(|b| {
                let stream = self.scheduler.stream(&b.stream);
                stream.mem_manage.get_resource(b.clone()).unwrap()
            })
            .collect::<Vec<_>>();

        BindingsResource {
            resources,
            metadata: bindings.metadata,
            scalars: bindings.scalars,
        }
    }

    /// Returns the compute pipeline for `kernel` under `mode`, compiling and
    /// caching it on first use.
    ///
    /// The execution mode is folded into the kernel id so checked and
    /// unchecked compilations are cached independently.
    ///
    /// # Errors
    /// Returns a [`CompilationError`] if kernel compilation or pipeline
    /// creation fails.
    fn pipeline(
        &mut self,
        kernel: <Self as ComputeServer>::Kernel,
        mode: ExecutionMode,
    ) -> Result<Arc<ComputePipeline>, CompilationError> {
        let mut kernel_id = kernel.id();
        kernel_id.mode(mode);

        // Fast path: pipeline already compiled for this (kernel, mode) pair.
        if let Some(pipeline) = self.pipelines.get(&kernel_id) {
            return Ok(pipeline.clone());
        }

        let mut compiler = compiler(self.backend);
        let mut compile = compiler.compile(self, kernel, mode)?;

        // Attach debug info only when compilation logging is enabled, to
        // avoid the extra work on the hot path.
        if self.scheduler.logger.compilation_activated() {
            compile.debug_info = Some(DebugInformation::new(
                compiler.lang_tag(),
                kernel_id.clone(),
            ));
        }
        self.scheduler.logger.log_compilation(&compile);
        // NOTE(review): `create_pipeline` is defined elsewhere in this file/crate.
        let pipeline = self.create_pipeline(compile, mode)?;
        self.pipelines.insert(kernel_id.clone(), pipeline.clone());

        Ok(pipeline)
    }
}
162
impl ComputeServer for WgpuServer {
    // Kernels are type-erased cube tasks compiled through `AutoCompiler`.
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    // Runtime info exposed to callers: the active wgpu backend.
    type Info = wgpu::Backend;

    /// Returns the logger shared with the scheduler.
    fn logger(&self) -> Arc<ServerLogger> {
        self.scheduler.logger.clone()
    }

    /// Returns the shared server utilities.
    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    /// Staging buffers are not supported on wgpu; always returns
    /// [`IoError::UnsupportedIoOperation`].
    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        })
    }

    /// Allocates memory for all descriptors out of a single backing buffer.
    ///
    /// Each size is rounded up to the device's storage-buffer offset
    /// alignment, one buffer of the combined size is reserved on
    /// `stream_id`'s stream, and per-descriptor handles are carved out of it
    /// at aligned offsets. Resulting strides are always contiguous
    /// (row-major).
    fn create(
        &mut self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<Allocation>, IoError> {
        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
        let strides = descriptors
            .iter()
            .map(|desc| contiguous_strides(desc.shape))
            .collect::<Vec<_>>();
        let sizes = descriptors
            .iter()
            .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
            .collect::<Vec<_>>();
        // Total size accounts for per-allocation alignment padding.
        let total_size = sizes
            .iter()
            .map(|it| it.next_multiple_of(align))
            .sum::<usize>();

        let stream = self.scheduler.stream(&stream_id);
        let mem_handle = stream.empty(total_size as u64, stream_id)?;
        let handles = offset_handles(mem_handle, &sizes, align);

        Ok(handles
            .into_iter()
            .zip(strides)
            .map(|(handle, strides)| Allocation::new(handle, strides))
            .collect())
    }

    /// Reads the described regions back to the host.
    ///
    /// Only contiguous (row-major) layouts are supported. Every stream that
    /// owns one of the bindings — plus the calling stream — is executed
    /// first so pending work is ordered before the read, which is then
    /// issued on the calling stream.
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Box::pin(async {
                    Err(IoError::UnsupportedStrides {
                        backtrace: BackTrace::capture(),
                    })
                });
            }
            // Collect every distinct stream involved so they can all be
            // executed before the read.
            if !streams.contains(&desc.binding.stream) {
                streams.push(desc.binding.stream);
            }
            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = match stream.mem_manage.get_resource(desc.binding) {
                Ok(val) => val,
                Err(err) => return Box::pin(async move { Err(err) }),
            };
            resources.push((resource, desc.shape.to_vec(), desc.elem_size));
        }

        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }

    /// Writes host data into the described buffer regions.
    ///
    /// Only contiguous (row-major) layouts are supported. Each write is
    /// registered as a task on the calling stream.
    fn write(
        &mut self,
        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
        stream_id: StreamId,
    ) -> Result<(), IoError> {
        for (desc, data) in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Err(IoError::UnsupportedStrides {
                    backtrace: BackTrace::capture(),
                });
            }

            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding.clone())?;
            let task = ScheduleTask::Write {
                data,
                buffer: resource,
            };

            // No extra buffer dependencies beyond the write target itself.
            self.scheduler.register(stream_id, task, [].into_iter());
        }

        Ok(())
    }

    /// Resolves a binding to its underlying wgpu resource.
    ///
    /// Executes the calling stream and, if different, the binding's owning
    /// stream first, so pending work on the buffer is ordered before use.
    ///
    /// # Panics
    /// Panics if the resource lookup fails.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> BindingResource<WgpuResource> {
        let mut streams = vec![stream_id];
        if binding.stream != stream_id {
            streams.push(binding.stream);
        }
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&binding.stream);
        let resource = stream.mem_manage.get_resource(binding.clone()).unwrap();
        BindingResource::new(binding, resource)
    }

    /// Compiles (or fetches from cache) the kernel's pipeline and registers
    /// an execute task on the calling stream, with the bound buffers as
    /// dependencies.
    ///
    /// # Safety
    /// NOTE(review): the safety contract is defined by the `ComputeServer`
    /// trait (not visible here) — presumably the caller must guarantee the
    /// bindings match what the kernel expects, especially in unchecked mode.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let pipeline = self.pipeline(kernel, mode)?;
        // Keep the buffer list: it is passed to the scheduler as the task's
        // dependency set after `bindings` is consumed below.
        let buffers = bindings.buffers.clone();
        let resources = self.prepare_bindings(bindings);
        let task = ScheduleTask::Execute {
            pipeline,
            count,
            resources,
        };

        self.scheduler.register(stream_id, task, buffers.iter());

        Ok(())
    }

    /// Executes pending tasks on the stream and flushes it.
    fn flush(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.flush()
    }

    /// Executes pending tasks on the stream and returns a future resolving
    /// when the stream's submitted work completes.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.sync()
    }

    /// Flushes pending tasks, then starts profiling on the stream.
    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.start_profile()
    }

    /// Flushes pending tasks, then ends the profiling session identified by
    /// `token`, returning the measured duration.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.end_profile(token)
    }

    /// Flushes pending tasks, then reports the stream's memory usage.
    fn memory_usage(
        &mut self,
        stream_id: StreamId,
    ) -> cubecl_runtime::memory_management::MemoryUsage {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_usage()
    }

    /// Flushes pending tasks, then triggers a memory cleanup on the stream.
    // NOTE(review): the `true` flag presumably forces an immediate/full
    // cleanup — confirm against the memory-management API.
    fn memory_cleanup(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_cleanup(true);
    }

    /// Flushes pending tasks, then switches the stream's allocation mode.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.mode(mode);
    }
}
356
357fn compiler(backend: wgpu::Backend) -> AutoCompiler {
358 match backend {
359 #[cfg(feature = "spirv")]
360 wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
361 #[cfg(feature = "msl")]
362 wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
363 _ => AutoCompiler::Wgsl(Default::default()),
364 }
365}
366
/// Computes row-major (C-contiguous) strides for `shape`, in elements.
///
/// The innermost dimension has stride 1; each earlier dimension's stride is
/// the product of all later dimension sizes. An empty shape yields an empty
/// vector.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    // Walk from the innermost dimension outwards, accumulating the product of
    // trailing dimension sizes. This also fixes the rank-0 case: the previous
    // indexed loop computed `0..rank - 1`, which underflows `usize` (panics in
    // debug builds) when `shape` is empty.
    let mut stride = 1;
    for (s, dim) in strides.iter_mut().zip(shape.iter()).rev() {
        *s = stride;
        stride *= dim;
    }
    strides
}