1use crate::{
2 kernel::KernelMetadata,
3 logging::ServerLogger,
4 memory_management::{
5 MemoryHandle, MemoryUsage,
6 memory_pool::{SliceBinding, SliceHandle},
7 },
8 storage::{BindingResource, ComputeStorage},
9 tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
10};
11use alloc::collections::BTreeMap;
12use alloc::string::String;
13use alloc::sync::Arc;
14use alloc::vec::Vec;
15use core::fmt::Debug;
16use cubecl_common::{ExecutionMode, future::DynFut, profile::ProfileDuration};
17use cubecl_ir::Elem;
18
/// Error returned when a profile cannot be completed (see `ComputeServer::end_profile`).
#[derive(Clone, Debug)]
pub enum ProfileError {
    /// A failure not covered by another variant, carrying a free-form description.
    Unknown(String),
    /// NOTE(review): presumably returned when the profiling token was never
    /// registered with the server — confirm against implementations.
    NotRegistered,
}
27
28pub trait ComputeServer: Send + core::fmt::Debug
33where
34 Self: Sized,
35{
36 type Kernel: KernelMetadata;
38 type Info: Debug + Send + Sync;
40 type Storage: ComputeStorage;
42 type Feature: Ord + Copy + Debug + Send + Sync;
44
45 fn read(&mut self, bindings: Vec<Binding>) -> DynFut<Vec<Vec<u8>>>;
47
48 fn read_tensor(&mut self, bindings: Vec<BindingWithMeta>) -> DynFut<Vec<Vec<u8>>>;
50
51 fn sync(&mut self) -> DynFut<()>;
53
54 fn get_resource(
56 &mut self,
57 binding: Binding,
58 ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
59
60 fn create(&mut self, data: &[u8]) -> Handle;
62
63 fn create_tensors(
69 &mut self,
70 data: Vec<&[u8]>,
71 shapes: Vec<&[usize]>,
72 elem_sizes: Vec<usize>,
73 ) -> Vec<(Handle, Vec<usize>)>;
74
75 fn empty(&mut self, size: usize) -> Handle;
77
78 fn empty_tensors(
80 &mut self,
81 shapes: Vec<&[usize]>,
82 elem_sizes: Vec<usize>,
83 ) -> Vec<(Handle, Vec<usize>)>;
84
85 unsafe fn execute(
94 &mut self,
95 kernel: Self::Kernel,
96 count: CubeCount,
97 bindings: Bindings,
98 kind: ExecutionMode,
99 logger: Arc<ServerLogger>,
100 );
101
102 fn flush(&mut self);
104
105 fn memory_usage(&self) -> MemoryUsage;
107
108 fn memory_cleanup(&mut self);
110
111 fn start_profile(&mut self) -> ProfilingToken;
113
114 fn end_profile(&mut self, token: ProfilingToken) -> Result<ProfileDuration, ProfileError>;
116}
117
/// Identifies a profiling session: returned by `ComputeServer::start_profile`
/// and handed back to `ComputeServer::end_profile`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ProfilingToken {
    /// Numeric identifier of the session.
    pub id: u64,
}
124
125#[derive(new, Debug)]
127pub struct Handle {
128 pub memory: SliceHandle,
130 pub offset_start: Option<u64>,
132 pub offset_end: Option<u64>,
134 size: u64,
136}
137
138impl Handle {
139 pub fn offset_start(mut self, offset: u64) -> Self {
141 if let Some(val) = &mut self.offset_start {
142 *val += offset;
143 } else {
144 self.offset_start = Some(offset);
145 }
146
147 self
148 }
149 pub fn offset_end(mut self, offset: u64) -> Self {
151 if let Some(val) = &mut self.offset_end {
152 *val += offset;
153 } else {
154 self.offset_end = Some(offset);
155 }
156
157 self
158 }
159
160 pub fn size(&self) -> u64 {
162 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
163 }
164}
165
166#[derive(Debug, Default)]
168pub struct Bindings {
169 pub buffers: Vec<Binding>,
171 pub metadata: MetadataBinding,
174 pub scalars: BTreeMap<Elem, ScalarBinding>,
176 pub tensor_maps: Vec<TensorMapBinding>,
178}
179
180impl Bindings {
181 pub fn new() -> Self {
183 Self::default()
184 }
185
186 pub fn with_buffer(mut self, binding: Binding) -> Self {
188 self.buffers.push(binding);
189 self
190 }
191
192 pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
194 self.buffers.extend(bindings);
195 self
196 }
197
198 pub fn with_scalar(mut self, elem: Elem, length: usize, data: Vec<u64>) -> Self {
200 self.scalars
201 .insert(elem, ScalarBinding::new(elem, length, data));
202 self
203 }
204
205 pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
207 self.scalars
208 .extend(bindings.into_iter().map(|binding| (binding.elem, binding)));
209 self
210 }
211
212 pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
214 self.metadata = meta;
215 self
216 }
217
218 pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
220 self.tensor_maps.extend(bindings);
221 self
222 }
223}
224
225#[derive(new, Debug, Default)]
227pub struct MetadataBinding {
228 pub data: Vec<u32>,
230 pub static_len: usize,
232}
233
234#[derive(new, Debug)]
236pub struct ScalarBinding {
237 pub elem: Elem,
239 pub length: usize,
241 pub data: Vec<u64>,
243}
244
245impl ScalarBinding {
246 pub fn data(&self) -> &[u8] {
248 bytemuck::cast_slice(&self.data)
249 }
250}
251
252#[derive(new, Debug)]
254pub struct Binding {
255 pub memory: SliceBinding,
257 pub offset_start: Option<u64>,
259 pub offset_end: Option<u64>,
261}
262
263#[derive(new, Debug)]
265pub struct BindingWithMeta {
266 pub binding: Binding,
268 pub shape: Vec<usize>,
270 pub strides: Vec<usize>,
272 pub elem_size: usize,
274}
275
276#[derive(new, Debug, Clone)]
278pub struct TensorMapBinding {
279 pub binding: Binding,
281 pub map: TensorMapMeta,
283}
284
285#[derive(Debug, Clone)]
287pub struct TensorMapMeta {
288 pub format: TensorMapFormat,
290 pub rank: usize,
292 pub shape: Vec<usize>,
294 pub strides: Vec<usize>,
296 pub elem_stride: Vec<usize>,
299 pub interleave: TensorMapInterleave,
301 pub swizzle: TensorMapSwizzle,
303 pub prefetch: TensorMapPrefetch,
305 pub oob_fill: OobFill,
307 pub elem: Elem,
309}
310
311impl Handle {
312 pub fn can_mut(&self) -> bool {
314 self.memory.can_mut()
315 }
316}
317
318impl Handle {
319 pub fn binding(self) -> Binding {
321 Binding {
322 memory: MemoryHandle::binding(self.memory),
323 offset_start: self.offset_start,
324 offset_end: self.offset_end,
325 }
326 }
327
328 pub fn binding_with_meta(
330 self,
331 shape: Vec<usize>,
332 strides: Vec<usize>,
333 elem_size: usize,
334 ) -> BindingWithMeta {
335 BindingWithMeta {
336 shape,
337 strides,
338 elem_size,
339 binding: self.binding(),
340 }
341 }
342}
343
344impl Clone for Handle {
345 fn clone(&self) -> Self {
346 Self {
347 memory: self.memory.clone(),
348 offset_start: self.offset_start,
349 offset_end: self.offset_end,
350 size: self.size,
351 }
352 }
353}
354
355impl Clone for Binding {
356 fn clone(&self) -> Self {
357 Self {
358 memory: self.memory.clone(),
359 offset_start: self.offset_start,
360 offset_end: self.offset_end,
361 }
362 }
363}
364
365#[allow(clippy::large_enum_variant)]
369pub enum CubeCount {
370 Static(u32, u32, u32),
372 Dynamic(Binding),
374}
375
376impl CubeCount {
377 pub fn new_single() -> Self {
379 CubeCount::Static(1, 1, 1)
380 }
381
382 pub fn new_1d(x: u32) -> Self {
384 CubeCount::Static(x, 1, 1)
385 }
386
387 pub fn new_2d(x: u32, y: u32) -> Self {
389 CubeCount::Static(x, y, 1)
390 }
391
392 pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
394 CubeCount::Static(x, y, z)
395 }
396}
397
398impl Debug for CubeCount {
399 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
400 match self {
401 CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
402 CubeCount::Dynamic(_) => f.write_str("binding"),
403 }
404 }
405}
406
407impl Clone for CubeCount {
408 fn clone(&self) -> Self {
409 match self {
410 Self::Static(x, y, z) => Self::Static(*x, *y, *z),
411 Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
412 }
413 }
414}