1use crate::{
2 DeviceProperties,
3 kernel::KernelMetadata,
4 logging::ServerLogger,
5 memory_management::{
6 MemoryAllocationMode, MemoryHandle, MemoryUsage,
7 memory_pool::{SliceBinding, SliceHandle},
8 },
9 storage::{BindingResource, ComputeStorage},
10 tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
11};
12use alloc::collections::BTreeMap;
13use alloc::string::String;
14use alloc::sync::Arc;
15use alloc::vec;
16use alloc::vec::Vec;
17use core::fmt::Debug;
18use cubecl_common::{
19 ExecutionMode, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
20 stream_id::StreamId,
21};
22use cubecl_ir::StorageType;
23use thiserror::Error;
24
/// Errors that can be reported when profiling on a server.
///
/// This is the error type returned by `ComputeServer::end_profile`.
#[derive(Debug, Clone)]
pub enum ProfileError {
    /// A failure that fits no other variant; the `String` carries the details.
    Unknown(String),
    /// NOTE(review): presumably reported when the profiling token was never registered
    /// with the server — confirm against the implementors of `end_profile`.
    NotRegistered,
}
33
/// State bundled alongside a [`ComputeServer`]: device properties, server information
/// and the logger, plus Tracy profiling state when the `profile-tracy` feature is
/// enabled.
#[derive(Debug)]
pub struct ServerUtilities<Server: ComputeServer> {
    /// Reference instant; `ServerUtilities::new` sets it to the time of construction.
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// Tracy GPU context created for this server.
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// Properties of the device.
    pub properties: DeviceProperties,
    /// Server-specific information, as defined by `ComputeServer::Info`.
    pub info: Server::Info,
    /// Logger shared through an `Arc`.
    pub logger: Arc<ServerLogger>,
}
50
51impl<S: ComputeServer> ServerUtilities<S> {
52 pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
54 #[cfg(feature = "profile-tracy")]
56 let client = tracy_client::Client::start();
57
58 Self {
59 properties,
60 logger,
61 #[cfg(feature = "profile-tracy")]
63 gpu_client: client
64 .clone()
65 .new_gpu_context(
66 Some(&format!("{info:?}")),
67 tracy_client::GpuContextType::Invalid,
69 0, 1.0, )
72 .unwrap(),
73 #[cfg(feature = "profile-tracy")]
74 epoch_time: web_time::Instant::now(),
75 info,
76 }
77 }
78}
79
80pub trait ComputeServer:
85 Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
86where
87 Self: Sized,
88{
89 type Kernel: KernelMetadata;
91 type Info: Debug + Send + Sync;
93 type Storage: ComputeStorage;
95
96 fn create(
98 &mut self,
99 descriptors: Vec<AllocationDescriptor<'_>>,
100 stream_id: StreamId,
101 ) -> Result<Vec<Allocation>, IoError>;
102
103 fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
105 Err(IoError::UnsupportedIoOperation)
106 }
107
108 fn logger(&self) -> Arc<ServerLogger>;
110
111 fn utilities(&self) -> Arc<ServerUtilities<Self>>;
113
114 fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
116 let alloc = self
117 .create(
118 vec![AllocationDescriptor::new(
119 AllocationKind::Contiguous,
120 &[data.len()],
121 1,
122 )],
123 stream_id,
124 )?
125 .remove(0);
126 self.write(
127 vec![(
128 CopyDescriptor::new(
129 alloc.handle.clone().binding(),
130 &[data.len()],
131 &alloc.strides,
132 1,
133 ),
134 Bytes::from_bytes_vec(data.to_vec()),
135 )],
136 stream_id,
137 )?;
138 Ok(alloc.handle)
139 }
140
141 fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
143 let alloc = self
144 .create(
145 vec![AllocationDescriptor::new(
146 AllocationKind::Contiguous,
147 &[data.len()],
148 1,
149 )],
150 stream_id,
151 )?
152 .remove(0);
153 self.write(
154 vec![(
155 CopyDescriptor::new(
156 alloc.handle.clone().binding(),
157 &[data.len()],
158 &alloc.strides,
159 1,
160 ),
161 data,
162 )],
163 stream_id,
164 )?;
165 Ok(alloc.handle)
166 }
167
168 fn read<'a>(
170 &mut self,
171 descriptors: Vec<CopyDescriptor<'a>>,
172 stream_id: StreamId,
173 ) -> DynFut<Result<Vec<Bytes>, IoError>>;
174
175 fn write(
177 &mut self,
178 descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
179 stream_id: StreamId,
180 ) -> Result<(), IoError>;
181
182 fn sync(&mut self, stream_id: StreamId) -> DynFut<()>;
184
185 fn get_resource(
187 &mut self,
188 binding: Binding,
189 stream_id: StreamId,
190 ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
191
192 unsafe fn execute(
201 &mut self,
202 kernel: Self::Kernel,
203 count: CubeCount,
204 bindings: Bindings,
205 kind: ExecutionMode,
206 stream_id: StreamId,
207 );
208
209 fn flush(&mut self, stream_id: StreamId);
211
212 fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
214
215 fn memory_cleanup(&mut self, stream_id: StreamId);
217
218 fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
220
221 fn end_profile(
223 &mut self,
224 stream_id: StreamId,
225 token: ProfilingToken,
226 ) -> Result<ProfileDuration, ProfileError>;
227
228 fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
230}
231
232pub trait ServerCommunication {
235 const SERVER_COMM_ENABLED: bool;
237
238 #[allow(unused_variables)]
257 fn copy(
258 server_src: &mut Self,
259 server_dst: &mut Self,
260 src: CopyDescriptor<'_>,
261 stream_id_src: StreamId,
262 stream_id_dst: StreamId,
263 ) -> Result<Allocation, IoError> {
264 if !Self::SERVER_COMM_ENABLED {
265 panic!("Server-to-server communication is not supported by this server.");
266 } else {
267 panic!(
268 "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
269 );
270 }
271 }
272}
273
/// Token identifying a profiling session.
///
/// Returned by `ComputeServer::start_profile` and handed back to
/// `ComputeServer::end_profile`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Numeric identifier of the token.
    pub id: u64,
}
280
/// Handle to a slice of server memory, optionally narrowed by offsets at both ends.
///
/// The generated `new` constructor takes the fields in declaration order.
#[derive(new, Debug, PartialEq, Eq)]
pub struct Handle {
    /// Underlying memory slice handle.
    pub memory: SliceHandle,
    /// Amount skipped at the start of the slice; `None` counts as zero in `size()`.
    pub offset_start: Option<u64>,
    /// Amount skipped at the end of the slice; `None` counts as zero in `size()`.
    pub offset_end: Option<u64>,
    /// Stream this handle is associated with; compared to the current stream by
    /// `can_mut`.
    pub stream: cubecl_common::stream_id::StreamId,
    /// NOTE(review): meaning not visible in this file — presumably a position of the
    /// handle on its stream; confirm.
    pub cursor: u64,
    /// Size of the slice before the offsets are subtracted.
    size: u64,
}
297
/// Layout requested for an allocation.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// Contiguous layout; used by `create_with_data` and `create_with_bytes`.
    Contiguous,
    /// NOTE(review): presumably lets the server pick the layout (and thus the strides)
    /// it prefers — confirm against the implementors.
    Optimized,
}
307
/// Describes one allocation requested through `ComputeServer::create`.
///
/// The generated `new` constructor takes the fields in declaration order.
#[derive(new, Debug, Clone, Copy)]
pub struct AllocationDescriptor<'a> {
    /// Requested layout kind.
    pub kind: AllocationKind,
    /// Shape of the data to allocate.
    pub shape: &'a [usize],
    /// Size of one element (the byte-buffer helpers in this file pass `1`).
    pub elem_size: usize,
}
318
319impl<'a> AllocationDescriptor<'a> {
320 pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
322 AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
323 }
324
325 pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
327 AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
328 }
329}
330
/// Result of one allocation request: the handle plus the strides of the allocated data.
#[derive(new, Debug)]
pub struct Allocation {
    /// Handle to the allocated memory.
    pub handle: Handle,
    /// Strides of the allocated data.
    pub strides: Vec<usize>,
}
339
340#[derive(Debug, Error)]
343pub enum IoError {
344 #[error("can't allocate buffer of size")]
346 BufferTooBig(usize),
347 #[error("the provided strides are not supported for this operation")]
349 UnsupportedStrides,
350 #[error("couldn't find resource for that handle")]
352 InvalidHandle,
353 #[error("Unknown error happened during execution")]
355 Unknown(String),
356 #[error("The current IO operation is not supported")]
358 UnsupportedIoOperation,
359}
360
361impl Handle {
362 pub fn offset_start(mut self, offset: u64) -> Self {
364 if let Some(val) = &mut self.offset_start {
365 *val += offset;
366 } else {
367 self.offset_start = Some(offset);
368 }
369
370 self
371 }
372 pub fn offset_end(mut self, offset: u64) -> Self {
374 if let Some(val) = &mut self.offset_end {
375 *val += offset;
376 } else {
377 self.offset_end = Some(offset);
378 }
379
380 self
381 }
382
383 pub fn size(&self) -> u64 {
385 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
386 }
387}
388
/// Everything bound to a kernel launch; passed to `ComputeServer::execute`.
#[derive(Debug, Default)]
pub struct Bindings {
    /// Buffer bindings.
    pub buffers: Vec<Binding>,
    /// Metadata binding.
    pub metadata: MetadataBinding,
    /// Scalar bindings, keyed by storage type (at most one entry per type).
    pub scalars: BTreeMap<StorageType, ScalarBinding>,
    /// Tensor map bindings.
    pub tensor_maps: Vec<TensorMapBinding>,
}
402
403impl Bindings {
404 pub fn new() -> Self {
406 Self::default()
407 }
408
409 pub fn with_buffer(mut self, binding: Binding) -> Self {
411 self.buffers.push(binding);
412 self
413 }
414
415 pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
417 self.buffers.extend(bindings);
418 self
419 }
420
421 pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
423 self.scalars
424 .insert(ty, ScalarBinding::new(ty, length, data));
425 self
426 }
427
428 pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
430 self.scalars
431 .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
432 self
433 }
434
435 pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
437 self.metadata = meta;
438 self
439 }
440
441 pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
443 self.tensor_maps.extend(bindings);
444 self
445 }
446}
447
/// Metadata bound to a kernel launch, stored as `u32` words.
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// Metadata words.
    pub data: Vec<u32>,
    /// NOTE(review): presumably the length of the statically sized part of `data` —
    /// confirm against the code that builds the metadata.
    pub static_len: usize,
}
456
/// Scalars of one storage type bound to a kernel launch.
#[derive(new, Debug, Clone)]
pub struct ScalarBinding {
    /// Storage type of the scalars; also the key under which `Bindings` stores them.
    pub ty: StorageType,
    /// NOTE(review): presumably the number of scalars held in `data` — confirm.
    pub length: usize,
    /// Backing storage, kept as `u64` words; `data()` exposes it as bytes.
    pub data: Vec<u64>,
}
467
468impl ScalarBinding {
469 pub fn data(&self) -> &[u8] {
471 bytemuck::cast_slice(&self.data)
472 }
473}
474
/// Binding to a slice of server memory, as produced by `Handle::binding`.
///
/// The generated `new` constructor takes the fields in declaration order.
#[derive(new, Debug)]
pub struct Binding {
    /// Underlying memory slice binding.
    pub memory: SliceBinding,
    /// Amount skipped at the start of the slice; `None` counts as zero in `size()`.
    pub offset_start: Option<u64>,
    /// Amount skipped at the end of the slice; `None` counts as zero in `size()`.
    pub offset_end: Option<u64>,
    /// Stream the binding is associated with.
    pub stream: cubecl_common::stream_id::StreamId,
    /// NOTE(review): meaning not visible in this file — copied from the handle the
    /// binding was made from; confirm its semantics.
    pub cursor: u64,
    /// Size of the slice before the offsets are subtracted.
    size: u64,
}
491
492impl Binding {
493 pub fn size(&self) -> u64 {
495 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
496 }
497}
498
/// Describes the data involved in a read or write: where it lives and how it is laid
/// out.
///
/// The generated `new` constructor takes the fields in declaration order.
#[derive(new, Debug, Clone)]
pub struct CopyDescriptor<'a> {
    /// Binding of the memory being read or written.
    pub binding: Binding,
    /// Shape of the data.
    pub shape: &'a [usize],
    /// Strides of the data.
    pub strides: &'a [usize],
    /// Size of one element (the byte-buffer helpers in this file pass `1`).
    pub elem_size: usize,
}
511
/// A memory binding paired with the tensor map metadata describing it.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// Binding of the underlying memory.
    pub binding: Binding,
    /// Tensor map metadata for that memory.
    pub map: TensorMapMeta,
}
520
/// Metadata describing a tensor map.
///
/// NOTE(review): the field descriptions below follow the field and type names only;
/// the precise semantics are defined by the `tma` module and the backends — confirm
/// there before relying on them.
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Tensor map format.
    pub format: TensorMapFormat,
    /// Rank (presumably of the tensor described by `shape` and `strides`).
    pub rank: usize,
    /// Shape of the tensor.
    pub shape: Vec<usize>,
    /// Strides of the tensor.
    pub strides: Vec<usize>,
    /// Element strides (presumably one entry per dimension — TODO confirm).
    pub elem_stride: Vec<usize>,
    /// Interleave mode.
    pub interleave: TensorMapInterleave,
    /// Swizzle mode.
    pub swizzle: TensorMapSwizzle,
    /// Prefetch setting.
    pub prefetch: TensorMapPrefetch,
    /// Out-of-bounds fill setting.
    pub oob_fill: OobFill,
    /// Storage type of the elements.
    pub storage_ty: StorageType,
}
546
547impl Handle {
548 pub fn can_mut(&self) -> bool {
550 self.memory.can_mut() && self.stream == StreamId::current()
551 }
552}
553
554impl Handle {
555 pub fn binding(self) -> Binding {
557 Binding {
558 memory: MemoryHandle::binding(self.memory),
559 offset_start: self.offset_start,
560 offset_end: self.offset_end,
561 size: self.size,
562 stream: self.stream,
563 cursor: self.cursor,
564 }
565 }
566
567 pub fn copy_descriptor<'a>(
569 &'a self,
570 shape: &'a [usize],
571 strides: &'a [usize],
572 elem_size: usize,
573 ) -> CopyDescriptor<'a> {
574 CopyDescriptor {
575 shape,
576 strides,
577 elem_size,
578 binding: self.clone().binding(),
579 }
580 }
581}
582
583impl Clone for Handle {
584 fn clone(&self) -> Self {
585 Self {
586 memory: self.memory.clone(),
587 offset_start: self.offset_start,
588 offset_end: self.offset_end,
589 size: self.size,
590 stream: self.stream,
591 cursor: self.cursor,
592 }
593 }
594}
595
596impl Clone for Binding {
597 fn clone(&self) -> Self {
598 Self {
599 memory: self.memory.clone(),
600 offset_start: self.offset_start,
601 offset_end: self.offset_end,
602 size: self.size,
603 stream: self.stream,
604 cursor: self.cursor,
605 }
606 }
607}
608
/// Number of cubes to launch for a kernel; passed to `ComputeServer::execute`.
// NOTE(review): the lint is silenced presumably because `Dynamic` is much larger than
// `Static` and boxing it was judged not worth it — confirm.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Counts given directly, in `x`, `y`, `z` order.
    Static(u32, u32, u32),
    /// Counts provided through a binding rather than as values.
    /// NOTE(review): presumably read from that memory at launch time — confirm.
    Dynamic(Binding),
}
619
620impl CubeCount {
621 pub fn new_single() -> Self {
623 CubeCount::Static(1, 1, 1)
624 }
625
626 pub fn new_1d(x: u32) -> Self {
628 CubeCount::Static(x, 1, 1)
629 }
630
631 pub fn new_2d(x: u32, y: u32) -> Self {
633 CubeCount::Static(x, y, 1)
634 }
635
636 pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
638 CubeCount::Static(x, y, z)
639 }
640}
641
642impl Debug for CubeCount {
643 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
644 match self {
645 CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
646 CubeCount::Dynamic(_) => f.write_str("binding"),
647 }
648 }
649}
650
651impl Clone for CubeCount {
652 fn clone(&self) -> Self {
653 match self {
654 Self::Static(x, y, z) => Self::Static(*x, *y, *z),
655 Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
656 }
657 }
658}