1use super::storage::{WgpuResource, WgpuStorage};
2use crate::AutoCompiler;
3use crate::schedule::{BindingsResource, ScheduleTask, ScheduledWgpuBackend};
4use alloc::sync::Arc;
5use cubecl_common::bytes::Bytes;
6use cubecl_common::profile::{ProfileDuration, TimingMethod};
7use cubecl_common::stream_id::StreamId;
8use cubecl_core::future::DynFut;
9use cubecl_core::server::{ProfileError, ProfilingToken, ServerCommunication, ServerUtilities};
10use cubecl_core::{
11 MemoryConfiguration, WgpuCompilationOptions,
12 prelude::*,
13 server::{Binding, Bindings, CopyDescriptor},
14};
15use cubecl_core::{
16 compute::{CubeTask, DebugInformation},
17 server::{Allocation, AllocationDescriptor, IoError},
18};
19use cubecl_runtime::config::GlobalConfig;
20use cubecl_runtime::logging::ServerLogger;
21use cubecl_runtime::memory_management::{MemoryAllocationMode, offset_handles};
22use cubecl_runtime::stream::scheduler::{
23 SchedulerMultiStream, SchedulerMultiStreamOptions, SchedulerStrategy,
24};
25use cubecl_runtime::{
26 memory_management::MemoryDeviceProperties, server::ComputeServer, storage::BindingResource,
27};
28use hashbrown::HashMap;
29use wgpu::ComputePipeline;
30
/// Compute server for the wgpu runtime.
///
/// Owns the compiled-pipeline cache and the multi-stream scheduler that
/// batches and dispatches work to the underlying wgpu device/queue.
#[derive(Debug)]
pub struct WgpuServer {
    // Logical wgpu device; used for querying limits and creating pipelines.
    pub(crate) device: wgpu::Device,
    // Cache of compiled compute pipelines, keyed by kernel id (the id
    // incorporates the execution mode, see `pipeline()`).
    pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
    // Multi-stream scheduler that orders, batches, and flushes tasks.
    scheduler: SchedulerMultiStream<ScheduledWgpuBackend>,
    /// Options forwarded to the kernel compiler.
    pub compilation_options: WgpuCompilationOptions,
    // Active wgpu backend (Vulkan, Metal, ...); selects the compiler in `compiler()`.
    pub(crate) backend: wgpu::Backend,
    // Shared server utilities (logger, etc.), handed out via `utilities()`.
    pub(crate) utilities: Arc<ServerUtilities<Self>>,
}
41
impl ServerCommunication for WgpuServer {
    // Direct server-to-server communication is disabled for this server.
    const SERVER_COMM_ENABLED: bool = false;
}
45
46impl WgpuServer {
47 #[allow(clippy::too_many_arguments)]
49 pub fn new(
50 memory_properties: MemoryDeviceProperties,
51 memory_config: MemoryConfiguration,
52 compilation_options: WgpuCompilationOptions,
53 device: wgpu::Device,
54 queue: wgpu::Queue,
55 tasks_max: usize,
56 backend: wgpu::Backend,
57 timing_method: TimingMethod,
58 utilities: ServerUtilities<Self>,
59 ) -> Self {
60 let backend_scheduler = ScheduledWgpuBackend::new(
61 device.clone(),
62 queue.clone(),
63 memory_properties,
64 memory_config,
65 timing_method,
66 tasks_max,
67 utilities.logger.clone(),
68 );
69
70 let config = GlobalConfig::get();
71 let max_streams = config.streaming.max_streams;
72
73 Self {
74 compilation_options,
75 device,
76 pipelines: HashMap::new(),
77 scheduler: SchedulerMultiStream::new(
78 utilities.logger.clone(),
79 backend_scheduler,
80 SchedulerMultiStreamOptions {
81 max_streams,
82 max_tasks: tasks_max,
83 strategy: SchedulerStrategy::Interleave,
84 },
85 ),
86 backend,
87 utilities: Arc::new(utilities),
88 }
89 }
90
91 fn prepare_bindings(&mut self, bindings: Bindings) -> BindingsResource {
92 let resources = bindings
95 .buffers
96 .iter()
97 .map(|b| {
98 let stream = self.scheduler.stream(&b.stream);
99 stream.mem_manage.get_resource(b.clone())
100 })
101 .collect::<Vec<_>>();
102
103 BindingsResource {
104 resources,
105 metadata: bindings.metadata,
106 scalars: bindings.scalars,
107 }
108 }
109
110 fn pipeline(
111 &mut self,
112 kernel: <Self as ComputeServer>::Kernel,
113 mode: ExecutionMode,
114 ) -> Arc<ComputePipeline> {
115 let mut kernel_id = kernel.id();
116 kernel_id.mode(mode);
117
118 if let Some(pipeline) = self.pipelines.get(&kernel_id) {
119 return pipeline.clone();
120 }
121
122 let mut compiler = compiler(self.backend);
123 let mut compile = compiler.compile(self, kernel, mode);
124
125 if self.scheduler.logger.compilation_activated() {
126 compile.debug_info = Some(DebugInformation::new(
127 compiler.lang_tag(),
128 kernel_id.clone(),
129 ));
130 }
131 self.scheduler.logger.log_compilation(&compile);
132 let pipeline = self.create_pipeline(compile, mode);
155 self.pipelines.insert(kernel_id.clone(), pipeline.clone());
156
157 pipeline
158 }
159}
160
161impl ComputeServer for WgpuServer {
162 type Kernel = Box<dyn CubeTask<AutoCompiler>>;
163 type Storage = WgpuStorage;
164 type Info = wgpu::Backend;
165
166 fn logger(&self) -> Arc<ServerLogger> {
167 self.scheduler.logger.clone()
168 }
169
170 fn utilities(&self) -> Arc<ServerUtilities<Self>> {
171 self.utilities.clone()
172 }
173
174 fn create(
175 &mut self,
176 descriptors: Vec<AllocationDescriptor<'_>>,
177 stream_id: StreamId,
178 ) -> Result<Vec<Allocation>, IoError> {
179 let align = self.device.limits().min_storage_buffer_offset_alignment as usize;
180 let strides = descriptors
181 .iter()
182 .map(|desc| contiguous_strides(desc.shape))
183 .collect::<Vec<_>>();
184 let sizes = descriptors
185 .iter()
186 .map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
187 .collect::<Vec<_>>();
188 let total_size = sizes
189 .iter()
190 .map(|it| it.next_multiple_of(align))
191 .sum::<usize>();
192
193 let stream = self.scheduler.stream(&stream_id);
194 let mem_handle = stream.empty(total_size as u64, stream_id)?;
195 let handles = offset_handles(mem_handle, &sizes, align);
196
197 Ok(handles
198 .into_iter()
199 .zip(strides)
200 .map(|(handle, strides)| Allocation::new(handle, strides))
201 .collect())
202 }
203
204 fn read<'a>(
205 &mut self,
206 descriptors: Vec<CopyDescriptor<'a>>,
207 stream_id: StreamId,
208 ) -> DynFut<Result<Vec<Bytes>, IoError>> {
209 let mut streams = vec![stream_id];
210 let mut resources = Vec::with_capacity(descriptors.len());
211 for desc in descriptors {
212 if contiguous_strides(desc.shape) != desc.strides {
213 return Box::pin(async { Err(IoError::UnsupportedStrides) });
214 }
215 if !streams.contains(&desc.binding.stream) {
216 streams.push(desc.binding.stream);
217 }
218 let stream = self.scheduler.stream(&desc.binding.stream);
219 let resource = stream.mem_manage.get_resource(desc.binding);
220 resources.push((resource, desc.shape.to_vec(), desc.elem_size));
221 }
222
223 self.scheduler.execute_streams(streams);
224 let stream = self.scheduler.stream(&stream_id);
225 stream.read_resources(resources)
226 }
227
228 fn write(
229 &mut self,
230 descriptors: Vec<(CopyDescriptor<'_>, &[u8])>,
231 stream_id: StreamId,
232 ) -> Result<(), IoError> {
233 for (desc, data) in descriptors {
234 if contiguous_strides(desc.shape) != desc.strides {
235 return Err(IoError::UnsupportedStrides);
236 }
237
238 let stream = self.scheduler.stream(&desc.binding.stream);
239 let resource = stream.mem_manage.get_resource(desc.binding.clone());
240 let task = ScheduleTask::Write {
241 data: data.to_vec(),
242 buffer: resource,
243 };
244
245 self.scheduler.register(stream_id, task, [].into_iter());
246 }
247
248 Ok(())
249 }
250
251 fn get_resource(
252 &mut self,
253 binding: Binding,
254 stream_id: StreamId,
255 ) -> BindingResource<WgpuResource> {
256 let mut streams = vec![stream_id];
257 if binding.stream != stream_id {
258 streams.push(binding.stream);
259 }
260 self.scheduler.execute_streams(streams);
261 let stream = self.scheduler.stream(&binding.stream);
262 let resource = stream.mem_manage.get_resource(binding.clone());
263 BindingResource::new(binding, resource)
264 }
265
266 unsafe fn execute(
267 &mut self,
268 kernel: Self::Kernel,
269 count: CubeCount,
270 bindings: Bindings,
271 mode: ExecutionMode,
272 stream_id: StreamId,
273 ) {
274 let pipeline = self.pipeline(kernel, mode);
275 let buffers = bindings.buffers.clone();
276 let resources = self.prepare_bindings(bindings);
277 let task = ScheduleTask::Execute {
278 pipeline,
279 count,
280 resources,
281 };
282
283 self.scheduler.register(stream_id, task, buffers.iter());
284 }
285
286 fn flush(&mut self, stream_id: StreamId) {
287 self.scheduler.execute_streams(vec![stream_id]);
288 let stream = self.scheduler.stream(&stream_id);
289 stream.flush()
290 }
291
292 fn sync(&mut self, stream_id: StreamId) -> DynFut<()> {
294 self.scheduler.execute_streams(vec![stream_id]);
295 let stream = self.scheduler.stream(&stream_id);
296 stream.sync()
297 }
298
299 fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
300 self.scheduler.execute_streams(vec![stream_id]);
301 let stream = self.scheduler.stream(&stream_id);
302 stream.start_profile()
303 }
304
305 fn end_profile(
306 &mut self,
307 stream_id: StreamId,
308 token: ProfilingToken,
309 ) -> Result<ProfileDuration, ProfileError> {
310 self.scheduler.execute_streams(vec![stream_id]);
311 let stream = self.scheduler.stream(&stream_id);
312 stream.end_profile(token)
313 }
314
315 fn memory_usage(
316 &mut self,
317 stream_id: StreamId,
318 ) -> cubecl_runtime::memory_management::MemoryUsage {
319 self.scheduler.execute_streams(vec![stream_id]);
320 let stream = self.scheduler.stream(&stream_id);
321 stream.mem_manage.memory_usage()
322 }
323
324 fn memory_cleanup(&mut self, stream_id: StreamId) {
325 self.scheduler.execute_streams(vec![stream_id]);
326 let stream = self.scheduler.stream(&stream_id);
327 stream.mem_manage.memory_cleanup(true);
328 }
329
330 fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
331 self.scheduler.execute_streams(vec![stream_id]);
332 let stream = self.scheduler.stream(&stream_id);
333 stream.mem_manage.mode(mode);
334 }
335}
336
/// Selects the kernel compiler for the active wgpu backend.
///
/// Vulkan uses SPIR-V (when the `spirv` feature is enabled) and Metal uses
/// MSL (when the `msl` feature is enabled); all other backends — or when
/// the feature is disabled — fall back to WGSL.
fn compiler(backend: wgpu::Backend) -> AutoCompiler {
    match backend {
        #[cfg(feature = "spirv")]
        wgpu::Backend::Vulkan => AutoCompiler::SpirV(Default::default()),
        #[cfg(feature = "msl")]
        wgpu::Backend::Metal => AutoCompiler::Msl(Default::default()),
        _ => AutoCompiler::Wgsl(Default::default()),
    }
}
346
/// Computes row-major (C-contiguous) strides for `shape`.
///
/// The last dimension has stride 1 and each earlier dimension's stride is
/// the product of all later extents, e.g. `[2, 3, 4]` -> `[12, 4, 1]`.
///
/// Fix: the original `0..rank - 1` underflowed `usize` for an empty shape
/// (panic in debug builds); `saturating_sub` makes rank 0 return an empty
/// stride vector instead.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // Walk from the second-to-last dimension backwards, accumulating the
    // product of trailing extents.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}