1use crate::{
2 memory_management::{
3 MemoryHandle, MemoryUsage,
4 memory_pool::{SliceBinding, SliceHandle},
5 },
6 storage::{BindingResource, ComputeStorage},
7 tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
8};
9use alloc::collections::BTreeMap;
10use alloc::vec::Vec;
11use core::{fmt::Debug, future::Future};
12use cubecl_common::{ExecutionMode, benchmark::ProfileDuration};
13use cubecl_ir::Elem;
14
/// Interface implemented by a compute backend server.
///
/// Declares the operations a backend exposes: reading memory back to the host,
/// creating/allocating memory, launching kernels, flushing and synchronizing,
/// reporting and cleaning up memory, and profiling.
pub trait ComputeServer: Send + core::fmt::Debug
where
    Self: Sized,
{
    /// Kernel type accepted by [`ComputeServer::execute`].
    type Kernel: Send;
    /// Backend-specific information type.
    ///
    /// NOTE(review): no method of this trait uses it; presumably consumed by the
    /// client/channel layer — confirm.
    type Info: Debug + Send + Sync;
    /// Storage backend; its `Resource` type is what [`ComputeServer::get_resource`]
    /// hands back (wrapped in a [`BindingResource`]).
    type Storage: ComputeStorage;
    /// Feature identifier type (totally ordered and `Copy`).
    ///
    /// NOTE(review): presumably used to describe device capabilities elsewhere — confirm.
    type Feature: Ord + Copy + Debug + Send + Sync;

    /// Reads back the memory referenced by `bindings`.
    ///
    /// The returned future is `Send + 'static`, so it does not borrow `self` and
    /// may be awaited after this call returns. Presumably yields one byte buffer
    /// per binding, in input order — confirm against implementations.
    fn read(
        &mut self,
        bindings: Vec<Binding>,
    ) -> impl Future<Output = Vec<Vec<u8>>> + Send + 'static;

    /// Like [`ComputeServer::read`], but each binding carries its tensor layout
    /// (shape, strides, element size) via [`BindingWithMeta`].
    ///
    /// NOTE(review): presumably the layout lets the backend read strided data
    /// correctly — confirm.
    fn read_tensor(
        &mut self,
        bindings: Vec<BindingWithMeta>,
    ) -> impl Future<Output = Vec<Vec<u8>>> + Send + 'static;

    /// Returns the storage resource backing `binding`.
    fn get_resource(
        &mut self,
        binding: Binding,
    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;

    /// Creates a resource from the bytes in `data` and returns a [`Handle`] to it.
    fn create(&mut self, data: &[u8]) -> Handle;

    /// Creates a tensor resource from `data` with the given `shape` and
    /// element size in bytes.
    ///
    /// Returns the handle together with a `Vec<usize>`; presumably the strides
    /// chosen by the backend — confirm.
    fn create_tensor(
        &mut self,
        data: &[u8],
        shape: &[usize],
        elem_size: usize,
    ) -> (Handle, Vec<usize>);

    /// Reserves a resource of `size` (presumably bytes) without supplying
    /// initial contents.
    fn empty(&mut self, size: usize) -> Handle;

    /// Reserves a tensor resource for `shape`/`elem_size` without initial
    /// contents; the returned `Vec<usize>` is presumably its strides — confirm.
    fn empty_tensor(&mut self, shape: &[usize], elem_size: usize) -> (Handle, Vec<usize>);

    /// Launches `kernel` over `count` cubes with the given `bindings`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller contract is not stated in this file. Document the
    /// invariants callers must uphold (e.g. what `kind` permits and what the
    /// bindings must satisfy) — TODO.
    unsafe fn execute(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        kind: ExecutionMode,
    );

    /// Flushes pending work to the device (presumably without waiting for completion).
    fn flush(&mut self);

    /// Returns a `Send + 'static` future that resolves once the server is
    /// synchronized; presumably all previously submitted work has completed.
    fn sync(&mut self) -> impl Future<Output = ()> + Send + 'static;

    /// Returns the server's current [`MemoryUsage`].
    fn memory_usage(&self) -> MemoryUsage;

    /// Asks the server to clean up memory; what gets released is
    /// implementation-defined.
    fn memory_cleanup(&mut self);

    /// Begins a profiling region.
    fn start_profile(&mut self);

    /// Ends the profiling region (presumably the one opened by
    /// [`ComputeServer::start_profile`]) and returns its [`ProfileDuration`].
    fn end_profile(&mut self) -> ProfileDuration;
}
105
/// Handle to a slice of server memory, optionally narrowed by start/end offsets.
///
/// Returned by [`ComputeServer::create`], [`ComputeServer::empty`] and their
/// tensor variants; turned into a [`Binding`] with `Handle::binding`.
#[derive(new, Debug)]
pub struct Handle {
    /// The underlying memory slice.
    pub memory: SliceHandle,
    /// Amount trimmed from the start of the slice (presumably bytes — confirm);
    /// `None` is treated as 0.
    pub offset_start: Option<u64>,
    /// Amount trimmed from the end of the slice; `None` is treated as 0.
    pub offset_end: Option<u64>,
    /// Size of the underlying slice before any offsets are applied.
    size: u64,
}
118
119impl Handle {
120 pub fn offset_start(mut self, offset: u64) -> Self {
122 if let Some(val) = &mut self.offset_start {
123 *val += offset;
124 } else {
125 self.offset_start = Some(offset);
126 }
127
128 self
129 }
130 pub fn offset_end(mut self, offset: u64) -> Self {
132 if let Some(val) = &mut self.offset_end {
133 *val += offset;
134 } else {
135 self.offset_end = Some(offset);
136 }
137
138 self
139 }
140
141 pub fn size(&self) -> u64 {
143 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
144 }
145}
146
/// Everything bound to a single kernel launch (passed to [`ComputeServer::execute`]).
#[derive(Debug, Default)]
pub struct Bindings {
    /// Buffer bindings, in insertion order.
    pub buffers: Vec<Binding>,
    /// Metadata values for the launch.
    pub metadata: MetadataBinding,
    /// Scalar bindings keyed by element type; the map allows at most one
    /// [`ScalarBinding`] per [`Elem`].
    pub scalars: BTreeMap<Elem, ScalarBinding>,
    /// Tensor-map bindings, in insertion order.
    pub tensor_maps: Vec<TensorMapBinding>,
}
160
161impl Bindings {
162 pub fn new() -> Self {
164 Self::default()
165 }
166
167 pub fn with_buffer(mut self, binding: Binding) -> Self {
169 self.buffers.push(binding);
170 self
171 }
172
173 pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
175 self.buffers.extend(bindings);
176 self
177 }
178
179 pub fn with_scalar(mut self, elem: Elem, length: usize, data: Vec<u64>) -> Self {
181 self.scalars
182 .insert(elem, ScalarBinding::new(elem, length, data));
183 self
184 }
185
186 pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
188 self.scalars
189 .extend(bindings.into_iter().map(|binding| (binding.elem, binding)));
190 self
191 }
192
193 pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
195 self.metadata = meta;
196 self
197 }
198
199 pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
201 self.tensor_maps.extend(bindings);
202 self
203 }
204}
205
/// Metadata values supplied to a kernel launch, stored as `u32` words.
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// Raw metadata words.
    pub data: Vec<u32>,
    /// NOTE(review): presumably the number of leading entries of `data` that are
    /// static (as opposed to dynamically sized) — confirm with the consumer.
    pub static_len: usize,
}
214
/// Scalar arguments of a single element type, packed into `u64` words.
#[derive(new, Debug)]
pub struct ScalarBinding {
    /// Element type of the scalars; also the key under which [`Bindings`] stores this binding.
    pub elem: Elem,
    /// NOTE(review): presumably the number of scalar values of type `elem`, which
    /// need not equal `data.len()` — confirm.
    pub length: usize,
    /// Packed scalar payload; `ScalarBinding::data()` exposes it as bytes.
    pub data: Vec<u64>,
}
225
226impl ScalarBinding {
227 pub fn data(&self) -> &[u8] {
229 bytemuck::cast_slice(&self.data)
230 }
231}
232
/// A memory binding for kernel execution or readback, produced by `Handle::binding`.
#[derive(new, Debug)]
pub struct Binding {
    /// The bound memory slice.
    pub memory: SliceBinding,
    /// Start offset carried over from the originating [`Handle`]; `None` means none.
    pub offset_start: Option<u64>,
    /// End offset carried over from the originating [`Handle`]; `None` means none.
    pub offset_end: Option<u64>,
}
243
/// A [`Binding`] together with the tensor layout needed to read it
/// (consumed by [`ComputeServer::read_tensor`]).
#[derive(new, Debug)]
pub struct BindingWithMeta {
    /// The memory binding.
    pub binding: Binding,
    /// Tensor shape.
    pub shape: Vec<usize>,
    /// Tensor strides (presumably in elements, one per dimension — confirm).
    pub strides: Vec<usize>,
    /// Size of one element in bytes.
    pub elem_size: usize,
}
256
/// A memory binding paired with the tensor-map description used to access it.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// The bound memory.
    pub binding: Binding,
    /// Tensor-map layout/configuration for `binding`.
    pub map: TensorMapMeta,
}
265
/// Layout and access configuration for a tensor map (option types come from the
/// crate's `tma` module).
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Tensor-map format.
    pub format: TensorMapFormat,
    /// Number of dimensions; presumably equals `shape.len()` — confirm.
    pub rank: usize,
    /// Tensor shape.
    pub shape: Vec<usize>,
    /// Tensor strides (units not established here — confirm).
    pub strides: Vec<usize>,
    /// Per-dimension element stride — NOTE(review): confirm semantics with the backend.
    pub elem_stride: Vec<usize>,
    /// Interleave mode.
    pub interleave: TensorMapInterleave,
    /// Swizzle mode.
    pub swizzle: TensorMapSwizzle,
    /// Prefetch mode.
    pub prefetch: TensorMapPrefetch,
    /// Fill behavior for out-of-bounds accesses.
    pub oob_fill: OobFill,
    /// Element type of the tensor.
    pub elem: Elem,
}
291
292impl Handle {
293 pub fn can_mut(&self) -> bool {
295 self.memory.can_mut()
296 }
297}
298
299impl Handle {
300 pub fn binding(self) -> Binding {
302 Binding {
303 memory: MemoryHandle::binding(self.memory),
304 offset_start: self.offset_start,
305 offset_end: self.offset_end,
306 }
307 }
308
309 pub fn binding_with_meta(
311 self,
312 shape: Vec<usize>,
313 strides: Vec<usize>,
314 elem_size: usize,
315 ) -> BindingWithMeta {
316 BindingWithMeta {
317 shape,
318 strides,
319 elem_size,
320 binding: self.binding(),
321 }
322 }
323}
324
325impl Clone for Handle {
326 fn clone(&self) -> Self {
327 Self {
328 memory: self.memory.clone(),
329 offset_start: self.offset_start,
330 offset_end: self.offset_end,
331 size: self.size,
332 }
333 }
334}
335
336impl Clone for Binding {
337 fn clone(&self) -> Self {
338 Self {
339 memory: self.memory.clone(),
340 offset_start: self.offset_start,
341 offset_end: self.offset_end,
342 }
343 }
344}
345
/// Number of cubes to launch for a kernel, either fixed or supplied by a binding.
// `Dynamic` holds a whole `Binding` (memory binding plus two `Option<u64>`),
// which is far larger than `Static`'s three `u32`s; the lint is silenced rather
// than boxing the binding. NOTE(review): presumably to avoid a heap allocation
// on the launch path — confirm.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Fixed cube counts along x, y and z.
    Static(u32, u32, u32),
    /// Cube count presumably read from device memory at launch time — confirm.
    Dynamic(Binding),
}
356
357impl CubeCount {
358 pub fn new_single() -> Self {
360 CubeCount::Static(1, 1, 1)
361 }
362
363 pub fn new_1d(x: u32) -> Self {
365 CubeCount::Static(x, 1, 1)
366 }
367
368 pub fn new_2d(x: u32, y: u32) -> Self {
370 CubeCount::Static(x, y, 1)
371 }
372
373 pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
375 CubeCount::Static(x, y, z)
376 }
377}
378
379impl Debug for CubeCount {
380 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
381 match self {
382 CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
383 CubeCount::Dynamic(_) => f.write_str("binding"),
384 }
385 }
386}
387
388impl Clone for CubeCount {
389 fn clone(&self) -> Self {
390 match self {
391 Self::Static(x, y, z) => Self::Static(*x, *y, *z),
392 Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
393 }
394 }
395}