1use super::storage::{WgpuResource, WgpuStorage};
2use crate::AutoCompiler;
3use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
4use alloc::sync::Arc;
5use cubecl_common::bytes::Bytes;
6use cubecl_common::profile::{ProfileDuration, TimingMethod};
7use cubecl_common::stream_id::StreamId;
8use cubecl_core::future::DynFut;
9use cubecl_core::server::{ProfileError, ProfilingToken, ServerCommunication, ServerUtilities};
10use cubecl_core::{
11 MemoryConfiguration, WgpuCompilationOptions,
12 prelude::*,
13 server::{Binding, Bindings, CopyDescriptor},
14};
15use cubecl_core::{
16 compute::{CubeTask, DebugInformation},
17 server::{Allocation, AllocationDescriptor, IoError},
18};
19use cubecl_runtime::config::GlobalConfig;
20use cubecl_runtime::logging::ServerLogger;
21use cubecl_runtime::memory_management::{MemoryAllocationMode, offset_handles};
22use cubecl_runtime::stream::scheduler::{
23 SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy,
24};
25use cubecl_runtime::{
26 memory_management::MemoryDeviceProperties, server::ComputeServer, storage::BindingResource,
27};
28use hashbrown::HashMap;
29use wgpu::ComputePipeline;
30
/// Compute server for the wgpu runtime.
///
/// Owns the wgpu device, a cache of compiled compute pipelines, and a
/// multi-stream scheduler that orders and dispatches work on the backend.
#[derive(Debug)]
pub struct WgpuServer {
    pub(crate) device: wgpu::Device,
    // Compiled pipelines cached by kernel id (the id also encodes the
    // execution mode, see `pipeline`).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    // Multi-stream scheduler executing tasks on the wgpu backend.
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    pub compilation_options: WgpuCompilationOptions,
    pub(crate) backend: wgpu::Backend,
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}
41
// Direct server-to-server communication is not supported by this backend.
impl ServerCommunication for WgpuServer {
    const SERVER_COMM_ENABLED: bool = false;
}
45
46impl WgpuServer {
47 #[allow(clippy::too_many_arguments)]
49 pub fn new(
50 memory_properties: MemoryDeviceProperties,
51 memory_config: MemoryConfiguration,
52 compilation_options: WgpuCompilationOptions,
53 device: wgpu::Device,
54 queue: wgpu::Queue,
55 tasks_max: usize,
56 backend: wgpu::Backend,
57 timing_method: TimingMethod,
58 utilities: ServerUtilities<Self>,
59 ) -> Self {
60 let backend_scheduler = ScheduledWgpuBackend::new(
61 device.clone(),
62 queue.clone(),
63 memory_properties,
64 memory_config,
65 timing_method,
66 tasks_max,
67 utilities.logger.clone(),
68 );
69
70 let config = GlobalConfig::get();
71 let max_streams = config.streaming.max_streams;
72
73 Self {
74 compilation_options,
75 device,
76 pipelines: HashMap::new(),
77 scheduler: SchedulerMultiStream::new(
78 utilities.logger.clone(),
79 backend_scheduler,
80 SchedulerMultiStreamOptions {
81 max_streams,
82 max_tasks: tasks_max,
83 strategy: SchedulerStrategy::Interleave,
84 },
85 ),
86 backend,
87 utilities: Arc::new(utilities),
88 }
89 }
90
91 fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
92 let resources = bindings
95 .buffers
96 .iter()
97 .map(|b| {
98 let stream = self.scheduler.stream(&b.stream);
99 stream.mem_manage.get_resource(b.clone())
100 })
101 .collect::<Vec<_>>();
102
103 BindingsResource {
104 resources,
105 metadata: bindings.metadata,
106 scalars: bindings.scalars,
107 }
108 }
109
110 fn pipeline(
111 &mut self,
112 kernel: <Self as ComputeServer>::Kernel,
113 mode: ExecutionMode,
114 ) -> Arc<ComputePipeline> {
115 let mut kernel_id = kernel.id();
116 kernel_id.mode(mode);
117
118 if let Some(pipeline) = self.pipelines.get(&kernel_id) {
119 return pipeline.clone();
120 }
121
122 let mut compiler = compiler(self.backend);
123 let mut compile = compiler.compile(self, kernel, mode);
124
125 if self.scheduler.logger.compilation_activated() {
126 compile.debug_info = Some(DebugInformation::new(
127 compiler.lang_tag(),
128 kernel_id.clone(),
129 ));
130 }
131 self.scheduler.logger.log_compilation(&compile);
132 let pipeline = self.create_pipeline(compile, mode);
155 self.pipelines.insert(kernel_id.clone(), pipeline.clone());
156
157 pipeline
158 }
159}
160
impl ComputeServer for WgpuServer {
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    type Info = wgpu::Backend;

    fn logger(&self) -> Arc<ServerLogger> {
        self.scheduler.logger.clone()
    }

    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    // Staging buffers are not supported on this backend.
    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        Err(IoError::UnsupportedIoOperation)
    }

    /// Allocate one contiguous memory block for all descriptors and slice it
    /// into per-descriptor handles, each offset aligned to the device's
    /// minimum storage-buffer offset alignment.
    fn create(
        &mut self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<Allocation>, IoError> {
        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
        let strides = descriptors
            .iter()
            .map(|desc| contiguous_strides(desc.shape))
            .collect::<Vec<_>>();
        let sizes = descriptors
            .iter()
            .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
            .collect::<Vec<_>>();
        // Each sub-allocation is rounded up so the next one starts aligned.
        let total_size = sizes
            .iter()
            .map(|it| it.next_multiple_of(align))
            .sum::<usize>();

        let stream = self.scheduler.stream(&stream_id);
        let mem_handle = stream.empty(total_size as u64, stream_id)?;
        let handles = offset_handles(mem_handle, &sizes, align);

        Ok(handles
            .into_iter()
            .zip(strides)
            .map(|(handle, strides)| Allocation::new(handle, strides))
            .collect())
    }

    /// Read back the given bindings asynchronously.
    ///
    /// Only contiguous layouts are supported. All streams that own one of the
    /// bindings (plus the caller's stream) are flushed before reading so the
    /// data is up to date.
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Box::pin(async { Err(IoError::UnsupportedStrides) });
            }
            // Collect the distinct set of streams involved in this read.
            if !streams.contains(&desc.binding.stream) {
                streams.push(desc.binding.stream);
            }
            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding);
            resources.push((resource, desc.shape.to_vec(), desc.elem_size));
        }

        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }

    /// Queue writes of `data` into the given bindings on the caller's stream.
    ///
    /// Only contiguous layouts are supported.
    fn write(
        &mut self,
        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
        stream_id: StreamId,
    ) -> Result<(), IoError> {
        for (desc, data) in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Err(IoError::UnsupportedStrides);
            }

            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding.clone());
            let task = ScheduleTask::Write {
                data,
                buffer: resource,
            };

            // NOTE(review): the task is registered on the caller's stream with
            // no binding dependencies, even when the resource belongs to
            // another stream — confirm this ordering is intended.
            self.scheduler.register(stream_id, task, [].into_iter());
        }

        Ok(())
    }

    /// Resolve a binding to its GPU resource, flushing the caller's stream
    /// (and the binding's owning stream, if different) first.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> BindingResource<WgpuResource> {
        let mut streams = vec![stream_id];
        if binding.stream != stream_id {
            streams.push(binding.stream);
        }
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&binding.stream);
        let resource = stream.mem_manage.get_resource(binding.clone());
        BindingResource::new(binding, resource)
    }

    /// Queue a kernel launch on `stream_id`.
    ///
    /// # Safety
    /// In unchecked mode, out-of-bounds reads and writes can happen.
    unsafe fn execute(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let pipeline = self.pipeline(kernel, mode);
        // Keep the buffer list: it is passed to the scheduler so it can track
        // dependencies between this task and the bindings it touches.
        let buffers = bindings.buffers.clone();
        let resources = self.prepare_bindings(bindings);
        let task = ScheduleTask::Execute {
            pipeline,
            count,
            resources,
        };

        self.scheduler.register(stream_id, task, buffers.iter());
    }

    /// Run all pending tasks on the stream and flush it to the device queue.
    fn flush(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.flush()
    }

    /// Run all pending tasks and return a future resolving when the stream's
    /// submitted work completes.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<()> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.sync()
    }

    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.start_profile()
    }

    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.end_profile(token)
    }

    fn memory_usage(
        &mut self,
        stream_id: StreamId,
    ) -> cubecl_runtime::memory_management::MemoryUsage {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_usage()
    }

    fn memory_cleanup(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_cleanup(true);
    }

    /// Change the memory allocation mode for the stream's memory manager.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.mode(mode);
    }
}
341
/// Select the kernel compiler matching the wgpu backend in use.
///
/// Backend-specific compilers are only available behind their cargo
/// features; everything else falls back to WGSL.
fn compiler(backend: wgpu::Backend) -> AutoCompiler {
    match backend {
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}
351
/// Compute row-major (contiguous) strides for `shape`.
///
/// The last dimension has stride 1 and each preceding dimension's stride is
/// the product of all later dimension sizes. An empty shape yields an empty
/// stride vector.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` keeps `rank == 0` from underflowing the range bound,
    // which would panic in debug builds (overflow) and index out of bounds
    // in release builds.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}