1use crate::{
2 DeviceProperties,
3 compiler::CompilationError,
4 kernel::KernelMetadata,
5 logging::ServerLogger,
6 memory_management::{
7 MemoryAllocationMode, MemoryHandle, MemoryUsage,
8 memory_pool::{SliceBinding, SliceHandle},
9 },
10 storage::{BindingResource, ComputeStorage},
11 tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
12};
13use alloc::collections::BTreeMap;
14use alloc::string::String;
15use alloc::sync::Arc;
16use alloc::vec;
17use alloc::vec::Vec;
18use core::fmt::Debug;
19use cubecl_common::{
20 ExecutionMode, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
21 stream_id::StreamId,
22};
23use cubecl_ir::StorageType;
24use thiserror::Error;
25
26#[derive(Debug, Clone)]
27pub enum ProfileError {
29 Unknown(String),
31 NotRegistered,
33 Launch(LaunchError),
35 Execution(ExecutionError),
37}
38
39impl From<LaunchError> for ProfileError {
40 fn from(val: LaunchError) -> Self {
41 ProfileError::Launch(val)
42 }
43}
44
45impl From<ExecutionError> for ProfileError {
46 fn from(val: ExecutionError) -> Self {
47 Self::Execution(val)
48 }
49}
50
/// Per-server utilities: device properties, backend information and the logger,
/// plus tracy profiling state when the `profile-tracy` feature is enabled.
#[derive(Debug)]
pub struct ServerUtilities<Server: ComputeServer> {
    /// Instant recorded when the utilities were created.
    ///
    /// NOTE(review): presumably used as the time origin for tracy GPU zones — confirm.
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// Tracy GPU context created for this server.
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// Properties of the device the server runs on.
    pub properties: DeviceProperties,
    /// Backend-specific information about the server.
    pub info: Server::Info,
    /// Logger shared with the server.
    pub logger: Arc<ServerLogger>,
}
67
68impl<S: ComputeServer> ServerUtilities<S> {
69 pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
71 #[cfg(feature = "profile-tracy")]
73 let client = tracy_client::Client::start();
74
75 Self {
76 properties,
77 logger,
78 #[cfg(feature = "profile-tracy")]
80 gpu_client: client
81 .clone()
82 .new_gpu_context(
83 Some(&format!("{info:?}")),
84 tracy_client::GpuContextType::Invalid,
86 0, 1.0, )
89 .unwrap(),
90 #[cfg(feature = "profile-tracy")]
91 epoch_time: web_time::Instant::now(),
92 info,
93 }
94 }
95}
96
/// Error that can happen when launching a kernel (see [`ComputeServer::launch`]).
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum LaunchError {
    /// The kernel failed to compile.
    CompilationError(CompilationError),
    /// The launch ran out of memory.
    OutOfMemory {
        /// Details about the failure.
        context: String,
    },
    /// A launch failure that fits no other variant.
    Unknown {
        /// Details about the failure.
        context: String,
    },
    /// An IO error prevented the launch.
    IoError(IoError),
}
121
/// Error reported by the server while executing work (for example from
/// [`ComputeServer::sync`]).
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ExecutionError {
    /// A single error described by free-form text.
    Generic {
        /// Description of the error.
        context: String,
    },
    /// Several execution errors grouped together.
    Composed {
        /// Description of the group.
        context: String,
        /// The grouped errors (may themselves be composed).
        errors: Vec<Self>,
    },
}
139
140impl From<CompilationError> for LaunchError {
141 fn from(value: CompilationError) -> Self {
142 Self::CompilationError(value)
143 }
144}
145
146impl From<IoError> for LaunchError {
147 fn from(value: IoError) -> Self {
148 Self::IoError(value)
149 }
150}
151
152impl core::fmt::Display for LaunchError {
153 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
154 match self {
155 LaunchError::CompilationError(err) => f.write_fmt(format_args!(
156 "A compilation error happened during launch: {err}"
157 )),
158 LaunchError::OutOfMemory { context } => f.write_fmt(format_args!(
159 "Out of memory error happened during launch: {context}"
160 )),
161 LaunchError::Unknown { context } => f.write_fmt(format_args!(
162 "An unknown error happened during launch: {context}"
163 )),
164 LaunchError::IoError(err) => {
165 f.write_fmt(format_args!("Can't launch because of an IO error: {err}"))
166 }
167 }
168 }
169}
170
171pub trait ComputeServer:
176 Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
177where
178 Self: Sized,
179{
180 type Kernel: KernelMetadata;
182 type Info: Debug + Send + Sync;
184 type Storage: ComputeStorage;
186
187 fn create(
189 &mut self,
190 descriptors: Vec<AllocationDescriptor<'_>>,
191 stream_id: StreamId,
192 ) -> Result<Vec<Allocation>, IoError>;
193
194 fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
196 Err(IoError::UnsupportedIoOperation)
197 }
198
199 fn logger(&self) -> Arc<ServerLogger>;
201
202 fn utilities(&self) -> Arc<ServerUtilities<Self>>;
204
205 fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
207 let alloc = self
208 .create(
209 vec![AllocationDescriptor::new(
210 AllocationKind::Contiguous,
211 &[data.len()],
212 1,
213 )],
214 stream_id,
215 )?
216 .remove(0);
217 self.write(
218 vec![(
219 CopyDescriptor::new(
220 alloc.handle.clone().binding(),
221 &[data.len()],
222 &alloc.strides,
223 1,
224 ),
225 Bytes::from_bytes_vec(data.to_vec()),
226 )],
227 stream_id,
228 )?;
229 Ok(alloc.handle)
230 }
231
232 fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
234 let alloc = self
235 .create(
236 vec![AllocationDescriptor::new(
237 AllocationKind::Contiguous,
238 &[data.len()],
239 1,
240 )],
241 stream_id,
242 )?
243 .remove(0);
244 self.write(
245 vec![(
246 CopyDescriptor::new(
247 alloc.handle.clone().binding(),
248 &[data.len()],
249 &alloc.strides,
250 1,
251 ),
252 data,
253 )],
254 stream_id,
255 )?;
256 Ok(alloc.handle)
257 }
258
259 fn read<'a>(
261 &mut self,
262 descriptors: Vec<CopyDescriptor<'a>>,
263 stream_id: StreamId,
264 ) -> DynFut<Result<Vec<Bytes>, IoError>>;
265
266 fn write(
268 &mut self,
269 descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
270 stream_id: StreamId,
271 ) -> Result<(), IoError>;
272
273 fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>>;
275
276 fn get_resource(
278 &mut self,
279 binding: Binding,
280 stream_id: StreamId,
281 ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
282
283 unsafe fn launch(
292 &mut self,
293 kernel: Self::Kernel,
294 count: CubeCount,
295 bindings: Bindings,
296 kind: ExecutionMode,
297 stream_id: StreamId,
298 ) -> Result<(), LaunchError>;
299
300 fn flush(&mut self, stream_id: StreamId);
302
303 fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
305
306 fn memory_cleanup(&mut self, stream_id: StreamId);
308
309 fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
311
312 fn end_profile(
314 &mut self,
315 stream_id: StreamId,
316 token: ProfilingToken,
317 ) -> Result<ProfileDuration, ProfileError>;
318
319 fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
321}
322
323pub trait ServerCommunication {
326 const SERVER_COMM_ENABLED: bool;
328
329 #[allow(unused_variables)]
348 fn copy(
349 server_src: &mut Self,
350 server_dst: &mut Self,
351 src: CopyDescriptor<'_>,
352 stream_id_src: StreamId,
353 stream_id_dst: StreamId,
354 ) -> Result<Allocation, IoError> {
355 if !Self::SERVER_COMM_ENABLED {
356 panic!("Server-to-server communication is not supported by this server.");
357 } else {
358 panic!(
359 "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
360 );
361 }
362 }
363}
364
/// Identifies a profiling session started with [`ComputeServer::start_profile`] and
/// passed back to [`ComputeServer::end_profile`].
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Identifier of the session.
    pub id: u64,
}
371
/// Owned handle to a (possibly trimmed) region of server memory.
// Field order defines the parameter order of the derived `new` constructor.
#[derive(new, Debug, PartialEq, Eq)]
pub struct Handle {
    /// Handle to the underlying memory slice.
    pub memory: SliceHandle,
    /// Amount trimmed from the start of the slice, if any (presumably in bytes).
    pub offset_start: Option<u64>,
    /// Amount trimmed from the end of the slice, if any (presumably in bytes).
    pub offset_end: Option<u64>,
    /// Stream associated with the handle; `can_mut` requires it to be the current stream.
    pub stream: cubecl_common::stream_id::StreamId,
    /// NOTE(review): meaning not visible in this file; it is copied as-is into bindings.
    pub cursor: u64,
    /// Size of the full slice before the offsets are applied; see `size()`.
    size: u64,
}
388
/// Layout requested for an allocation.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// A contiguous buffer.
    Contiguous,
    /// A backend-chosen layout; the resulting strides are reported in
    /// [`Allocation::strides`].
    ///
    /// NOTE(review): presumably optimized for device access (e.g. padded rows) — confirm.
    Optimized,
}
398
/// Describes a buffer to allocate with [`ComputeServer::create`].
// Field order defines the parameter order of the derived `new` constructor.
#[derive(new, Debug, Clone, Copy)]
pub struct AllocationDescriptor<'a> {
    /// Requested layout kind.
    pub kind: AllocationKind,
    /// Shape of the buffer, in elements.
    pub shape: &'a [usize],
    /// Size of one element, in bytes.
    pub elem_size: usize,
}
409
410impl<'a> AllocationDescriptor<'a> {
411 pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
413 AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
414 }
415
416 pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
418 AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
419 }
420}
421
/// Result of allocating a buffer with [`ComputeServer::create`].
#[derive(new, Debug)]
pub struct Allocation {
    /// Handle to the allocated memory.
    pub handle: Handle,
    /// Strides of the allocated buffer, to be used when copying into or out of it.
    pub strides: Vec<usize>,
}
430
431#[derive(Debug, Error, PartialEq, Eq, Clone, Hash)]
434#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
435pub enum IoError {
436 #[error("can't allocate buffer of size")]
438 BufferTooBig(usize),
439 #[error("the provided strides are not supported for this operation")]
441 UnsupportedStrides,
442 #[error("couldn't find resource for that handle")]
444 InvalidHandle,
445 #[error("Unknown error happened during execution")]
447 Unknown(String),
448 #[error("The current IO operation is not supported")]
450 UnsupportedIoOperation,
451 #[error("Can't perform the IO operation because of a runtime error")]
453 Execution(ExecutionError),
454}
455
456impl From<ExecutionError> for IoError {
457 fn from(value: ExecutionError) -> Self {
458 Self::Execution(value)
459 }
460}
461
462impl Handle {
463 pub fn offset_start(mut self, offset: u64) -> Self {
465 if let Some(val) = &mut self.offset_start {
466 *val += offset;
467 } else {
468 self.offset_start = Some(offset);
469 }
470
471 self
472 }
473 pub fn offset_end(mut self, offset: u64) -> Self {
475 if let Some(val) = &mut self.offset_end {
476 *val += offset;
477 } else {
478 self.offset_end = Some(offset);
479 }
480
481 self
482 }
483
484 pub fn size(&self) -> u64 {
486 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
487 }
488}
489
/// Resources bound to a kernel launch (see [`ComputeServer::launch`]).
#[derive(Debug, Default)]
pub struct Bindings {
    /// Buffer bindings.
    pub buffers: Vec<Binding>,
    /// Metadata passed to the kernel.
    pub metadata: MetadataBinding,
    /// Scalar values keyed by storage type; at most one binding per type.
    pub scalars: BTreeMap<StorageType, ScalarBinding>,
    /// Tensor map bindings.
    pub tensor_maps: Vec<TensorMapBinding>,
}
503
504impl Bindings {
505 pub fn new() -> Self {
507 Self::default()
508 }
509
510 pub fn with_buffer(mut self, binding: Binding) -> Self {
512 self.buffers.push(binding);
513 self
514 }
515
516 pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
518 self.buffers.extend(bindings);
519 self
520 }
521
522 pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
524 self.scalars
525 .insert(ty, ScalarBinding::new(ty, length, data));
526 self
527 }
528
529 pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
531 self.scalars
532 .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
533 self
534 }
535
536 pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
538 self.metadata = meta;
539 self
540 }
541
542 pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
544 self.tensor_maps.extend(bindings);
545 self
546 }
547}
548
/// Metadata words passed to a kernel.
// Field order defines the parameter order of the derived `new` constructor.
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// Metadata values.
    pub data: Vec<u32>,
    /// Length of the static part of `data`.
    ///
    /// NOTE(review): presumably the first `static_len` words are fixed-size metadata and
    /// the rest is dynamic (e.g. shapes/strides) — confirm.
    pub static_len: usize,
}
557
/// Scalar values of a single storage type passed to a kernel.
// Field order defines the parameter order of the derived `new` constructor.
#[derive(new, Debug, Clone)]
pub struct ScalarBinding {
    /// Storage type of the scalars.
    pub ty: StorageType,
    /// Number of scalars.
    ///
    /// NOTE(review): presumably counts elements of `ty`, not `u64` words — confirm.
    pub length: usize,
    /// Scalar values packed into `u64` words; `data()` exposes them as raw bytes.
    pub data: Vec<u64>,
}
568
569impl ScalarBinding {
570 pub fn data(&self) -> &[u8] {
572 bytemuck::cast_slice(&self.data)
573 }
574}
575
/// Binding to a (possibly trimmed) region of server memory, as handed to the server.
///
/// Produced from a [`Handle`] by `Handle::binding`, which copies every field.
// Field order defines the parameter order of the derived `new` constructor.
#[derive(new, Debug)]
pub struct Binding {
    /// Binding to the underlying memory slice.
    pub memory: SliceBinding,
    /// Amount trimmed from the start of the slice, if any (presumably in bytes).
    pub offset_start: Option<u64>,
    /// Amount trimmed from the end of the slice, if any (presumably in bytes).
    pub offset_end: Option<u64>,
    /// Stream associated with the binding.
    pub stream: cubecl_common::stream_id::StreamId,
    /// NOTE(review): meaning not visible in this file; copied as-is from the handle.
    pub cursor: u64,
    /// Size of the full slice before the offsets are applied; see `size()`.
    size: u64,
}
592
593impl Binding {
594 pub fn size(&self) -> u64 {
596 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
597 }
598}
599
/// Describes a strided region of a buffer to read or write
/// (see [`ComputeServer::read`] and [`ComputeServer::write`]).
// Field order defines the parameter order of the derived `new` constructor.
#[derive(new, Debug, Clone)]
pub struct CopyDescriptor<'a> {
    /// Buffer being accessed.
    pub binding: Binding,
    /// Shape of the region, in elements.
    pub shape: &'a [usize],
    /// Strides of the region, in elements.
    pub strides: &'a [usize],
    /// Size of one element, in bytes.
    pub elem_size: usize,
}
612
/// A buffer binding paired with the tensor map metadata describing how it is accessed.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// Buffer the tensor map points to.
    pub binding: Binding,
    /// Tensor map description.
    pub map: TensorMapMeta,
}
621
/// Description of a tensor map (TMA descriptor) over a buffer.
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Format of the tensor map.
    pub format: TensorMapFormat,
    /// Rank of the tensor.
    pub rank: usize,
    /// Shape of the tensor.
    pub shape: Vec<usize>,
    /// Strides of the tensor.
    pub strides: Vec<usize>,
    /// Per-dimension element strides.
    ///
    /// NOTE(review): presumably the traversal stride used when loading each dimension,
    /// distinct from `strides` — confirm.
    pub elem_stride: Vec<usize>,
    /// Interleave mode.
    pub interleave: TensorMapInterleave,
    /// Swizzle mode.
    pub swizzle: TensorMapSwizzle,
    /// Prefetch setting.
    pub prefetch: TensorMapPrefetch,
    /// Fill behavior for out-of-bounds elements.
    pub oob_fill: OobFill,
    /// Storage type of the tensor elements.
    pub storage_ty: StorageType,
}
647
648impl Handle {
649 pub fn can_mut(&self) -> bool {
651 self.memory.can_mut() && self.stream == StreamId::current()
652 }
653}
654
655impl Handle {
656 pub fn binding(self) -> Binding {
658 Binding {
659 memory: MemoryHandle::binding(self.memory),
660 offset_start: self.offset_start,
661 offset_end: self.offset_end,
662 size: self.size,
663 stream: self.stream,
664 cursor: self.cursor,
665 }
666 }
667
668 pub fn copy_descriptor<'a>(
670 &'a self,
671 shape: &'a [usize],
672 strides: &'a [usize],
673 elem_size: usize,
674 ) -> CopyDescriptor<'a> {
675 CopyDescriptor {
676 shape,
677 strides,
678 elem_size,
679 binding: self.clone().binding(),
680 }
681 }
682}
683
684impl Clone for Handle {
685 fn clone(&self) -> Self {
686 Self {
687 memory: self.memory.clone(),
688 offset_start: self.offset_start,
689 offset_end: self.offset_end,
690 size: self.size,
691 stream: self.stream,
692 cursor: self.cursor,
693 }
694 }
695}
696
697impl Clone for Binding {
698 fn clone(&self) -> Self {
699 Self {
700 memory: self.memory.clone(),
701 offset_start: self.offset_start,
702 offset_end: self.offset_end,
703 size: self.size,
704 stream: self.stream,
705 cursor: self.cursor,
706 }
707 }
708}
709
/// Number of cubes to launch along x, y and z.
// `Dynamic` holds a full `Binding` and is much larger than `Static`.
// NOTE(review): presumably left unboxed on purpose to avoid an allocation per launch — confirm.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Fixed `(x, y, z)` cube counts.
    Static(u32, u32, u32),
    /// Cube counts provided through a buffer binding.
    ///
    /// NOTE(review): presumably read on the device at launch time — confirm.
    Dynamic(Binding),
}
720
721impl CubeCount {
722 pub fn new_single() -> Self {
724 CubeCount::Static(1, 1, 1)
725 }
726
727 pub fn new_1d(x: u32) -> Self {
729 CubeCount::Static(x, 1, 1)
730 }
731
732 pub fn new_2d(x: u32, y: u32) -> Self {
734 CubeCount::Static(x, y, 1)
735 }
736
737 pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
739 CubeCount::Static(x, y, z)
740 }
741}
742
743impl Debug for CubeCount {
744 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
745 match self {
746 CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
747 CubeCount::Dynamic(_) => f.write_str("binding"),
748 }
749 }
750}
751
752impl Clone for CubeCount {
753 fn clone(&self) -> Self {
754 match self {
755 Self::Static(x, y, z) => Self::Static(*x, *y, *z),
756 Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
757 }
758 }
759}