1use super::storage::{WgpuResource, WgpuStorage};
2use crate::AutoCompiler;
3use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
4use alloc::sync::Arc;
5use cubecl_common::bytes::Bytes;
6use cubecl_common::profile::{ProfileDuration, TimingMethod};
7use cubecl_common::stream_id::StreamId;
8use cubecl_core::future::DynFut;
9use cubecl_core::server::{Allocation, AllocationDescriptor, ExecutionError, IoError, LaunchError};
10use cubecl_core::server::{ProfileError, ProfilingToken, ServerCommunication, ServerUtilities};
11use cubecl_core::{
12 MemoryConfiguration, WgpuCompilationOptions,
13 prelude::*,
14 server::{Binding, Bindings, CopyDescriptor},
15};
16use cubecl_runtime::compiler::CompilationError;
17use cubecl_runtime::logging::ServerLogger;
18use cubecl_runtime::memory_management::{MemoryAllocationMode, offset_handles};
19use cubecl_runtime::stream::scheduler::{
20 SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy,
21};
22use cubecl_runtime::{compiler::CubeTask, config::GlobalConfig};
23use cubecl_runtime::{
24 memory_management::MemoryDeviceProperties, server::ComputeServer, storage::BindingResource,
25};
26use hashbrown::HashMap;
27use wgpu::ComputePipeline;
28
/// Compute server backed by wgpu, executing kernel tasks through a
/// multi-stream scheduler.
#[derive(Debug)]
pub struct WgpuServer {
    /// Underlying wgpu device, used for pipeline creation and to query limits.
    pub(crate) device: wgpu::Device,
    /// Cache of compiled compute pipelines, keyed by kernel id (the id also
    /// encodes the execution mode).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    /// Multi-stream scheduler owning per-stream state (memory management,
    /// task queues, profiling).
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    /// Options forwarded to kernel compilation.
    pub compilation_options: WgpuCompilationOptions,
    /// Active wgpu backend (Vulkan, Metal, ...); selects which compiler is used.
    pub(crate) backend: wgpu::Backend,
    /// Shared server utilities (logger, etc.).
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}
39
// Direct server-to-server communication is not supported by the wgpu server.
impl ServerCommunication for WgpuServer {
    const SERVER_COMM_ENABLED: bool = false;
}
43
44impl WgpuServer {
45 #[allow(clippy::too_many_arguments)]
47 pub fn new(
48 memory_properties: MemoryDeviceProperties,
49 memory_config: MemoryConfiguration,
50 compilation_options: WgpuCompilationOptions,
51 device: wgpu::Device,
52 queue: wgpu::Queue,
53 tasks_max: usize,
54 backend: wgpu::Backend,
55 timing_method: TimingMethod,
56 utilities: ServerUtilities<Self>,
57 ) -> Self {
58 let backend_scheduler = ScheduledWgpuBackend::new(
59 device.clone(),
60 queue.clone(),
61 memory_properties,
62 memory_config,
63 timing_method,
64 tasks_max,
65 utilities.logger.clone(),
66 );
67
68 let config = GlobalConfig::get();
69 let max_streams = config.streaming.max_streams;
70
71 Self {
72 compilation_options,
73 device,
74 pipelines: HashMap::new(),
75 scheduler: SchedulerMultiStream::new(
76 utilities.logger.clone(),
77 backend_scheduler,
78 SchedulerMultiStreamOptions {
79 max_streams,
80 max_tasks: tasks_max,
81 strategy: SchedulerStrategy::Interleave,
82 },
83 ),
84 backend,
85 utilities: Arc::new(utilities),
86 }
87 }
88
89 fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
90 let resources = bindings
93 .buffers
94 .iter()
95 .map(|b| {
96 let stream = self.scheduler.stream(&b.stream);
97 stream.mem_manage.get_resource(b.clone())
98 })
99 .collect::<Vec<_>>();
100
101 BindingsResource {
102 resources,
103 metadata: bindings.metadata,
104 scalars: bindings.scalars,
105 }
106 }
107
108 fn pipeline(
109 &mut self,
110 kernel: <Self as ComputeServer>::Kernel,
111 mode: ExecutionMode,
112 ) -> Result<Arc<ComputePipeline>, CompilationError> {
113 let mut kernel_id = kernel.id();
114 kernel_id.mode(mode);
115
116 if let Some(pipeline) = self.pipelines.get(&kernel_id) {
117 return Ok(pipeline.clone());
118 }
119
120 let mut compiler = compiler(self.backend);
121 let mut compile = compiler.compile(self, kernel, mode)?;
122
123 if self.scheduler.logger.compilation_activated() {
124 compile.debug_info = Some(DebugInformation::new(
125 compiler.lang_tag(),
126 kernel_id.clone(),
127 ));
128 }
129 self.scheduler.logger.log_compilation(&compile);
130 let pipeline = self.create_pipeline(compile, mode)?;
153 self.pipelines.insert(kernel_id.clone(), pipeline.clone());
154
155 Ok(pipeline)
156 }
157}
158
impl ComputeServer for WgpuServer {
    type Kernel = Box<dyn CubeTask<AutoCompiler>>;
    type Storage = WgpuStorage;
    type Info = wgpu::Backend;

    /// Shared server logger (owned by the scheduler).
    fn logger(&self) -> Arc<ServerLogger> {
        self.scheduler.logger.clone()
    }

    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    /// Staging buffers are not supported by this server.
    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        Err(IoError::UnsupportedIoOperation)
    }

    /// Allocate memory for all `descriptors` in a single backing allocation,
    /// then slice it into per-descriptor handles.
    ///
    /// Each sub-allocation is aligned to the device's
    /// `min_storage_buffer_offset_alignment`. Returned allocations carry
    /// contiguous (row-major) strides.
    fn create(
        &mut self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<Allocation>, IoError> {
        let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
        let strides = descriptors
            .iter()
            .map(|desc| contiguous_strides(desc.shape))
            .collect::<Vec<_>>();
        let sizes = descriptors
            .iter()
            .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
            .collect::<Vec<_>>();
        // Total size accounts for per-slice alignment padding.
        let total_size = sizes
            .iter()
            .map(|it| it.next_multiple_of(align))
            .sum::<usize>();

        let stream = self.scheduler.stream(&stream_id);
        let mem_handle = stream.empty(total_size as u64, stream_id)?;
        // Split the single allocation into one offset handle per descriptor.
        let handles = offset_handles(mem_handle, &sizes, align);

        Ok(handles
            .into_iter()
            .zip(strides)
            .map(|(handle, strides)| Allocation::new(handle, strides))
            .collect())
    }

    /// Read the given buffers back to the host.
    ///
    /// Only contiguous (row-major) layouts are supported. All streams that own
    /// any of the requested bindings are executed before the read is issued on
    /// the calling stream.
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
        // Track every stream involved so pending work on them is flushed first.
        let mut streams = vec![stream_id];
        let mut resources = Vec::with_capacity(descriptors.len());
        for desc in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Box::pin(async { Err(IoError::UnsupportedStrides) });
            }
            if !streams.contains(&desc.binding.stream) {
                streams.push(desc.binding.stream);
            }
            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding);
            resources.push((resource, desc.shape.to_vec(), desc.elem_size));
        }

        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&stream_id);
        stream.read_resources(resources)
    }

    /// Write host data into the given buffers by enqueueing write tasks on the
    /// calling stream. Only contiguous (row-major) layouts are supported.
    fn write(
        &mut self,
        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
        stream_id: StreamId,
    ) -> Result<(), IoError> {
        for (desc, data) in descriptors {
            if contiguous_strides(desc.shape) != desc.strides {
                return Err(IoError::UnsupportedStrides);
            }

            // The target resource is resolved on the stream that owns the binding,
            // but the write task itself is registered on the calling stream.
            let stream = self.scheduler.stream(&desc.binding.stream);
            let resource = stream.mem_manage.get_resource(desc.binding.clone());
            let task = ScheduleTask::Write {
                data,
                buffer: resource,
            };

            self.scheduler.register(stream_id, task, [].into_iter());
        }

        Ok(())
    }

    /// Resolve a binding into its wgpu resource, executing the calling stream
    /// (and the binding's owning stream, if different) first.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> BindingResource<WgpuResource> {
        let mut streams = vec![stream_id];
        if binding.stream != stream_id {
            streams.push(binding.stream);
        }
        self.scheduler.execute_streams(streams);
        let stream = self.scheduler.stream(&binding.stream);
        let resource = stream.mem_manage.get_resource(binding.clone());
        BindingResource::new(binding, resource)
    }

    /// Launch a kernel: compile (or fetch) its pipeline, resolve its bindings,
    /// and enqueue an execute task on the calling stream.
    ///
    /// # Safety
    ///
    /// Follows the `ComputeServer::launch` contract; with
    /// `ExecutionMode::Unchecked` the kernel is run without bounds checks.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let pipeline = self.pipeline(kernel, mode)?;
        // Keep the buffer list: it is passed to the scheduler for dependency
        // tracking alongside the task.
        let buffers = bindings.buffers.clone();
        let resources = self.prepare_bindings(bindings);
        let task = ScheduleTask::Execute {
            pipeline,
            count,
            resources,
        };

        self.scheduler.register(stream_id, task, buffers.iter());

        Ok(())
    }

    /// Execute pending tasks on the stream and flush it to the GPU queue.
    fn flush(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.flush()
    }

    /// Execute pending tasks and return a future resolving when the stream's
    /// submitted work completes.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.sync()
    }

    /// Begin a profiling region on the stream after draining pending tasks.
    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.start_profile()
    }

    /// End the profiling region identified by `token`, returning the measured
    /// duration.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.end_profile(token)
    }

    /// Report current memory usage for the stream's memory manager.
    fn memory_usage(
        &mut self,
        stream_id: StreamId,
    ) -> cubecl_runtime::memory_management::MemoryUsage {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_usage()
    }

    /// Trigger a memory cleanup on the stream's memory manager.
    // NOTE(review): the `true` flag presumably forces an immediate/explicit
    // cleanup — confirm against the MemoryManagement API.
    fn memory_cleanup(&mut self, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.memory_cleanup(true);
    }

    /// Switch the stream's memory allocation mode.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        self.scheduler.execute_streams(vec![stream_id]);
        let stream = self.scheduler.stream(&stream_id);
        stream.mem_manage.mode(mode);
    }
}
341
/// Select the kernel compiler for the active wgpu backend.
///
/// Vulkan uses SPIR-V and Metal uses MSL when the corresponding crate features
/// are enabled; every other combination falls back to WGSL.
fn compiler(backend: wgpu::Backend) -> AutoCompiler {
    match backend {
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}
351
/// Compute contiguous (row-major) strides for `shape`.
///
/// The last dimension has stride 1 and each preceding dimension's stride is
/// the product of all dimension sizes after it. An empty shape yields an
/// empty stride vector.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // Iterate from the innermost dimension outwards. Using `1..rank` avoids
    // the `rank - 1` underflow the previous version hit for rank-0 shapes
    // (panic in debug builds, near-infinite loop range in release).
    for i in (1..rank).rev() {
        strides[i - 1] = strides[i] * shape[i];
    }
    strides
}