1use crate::{
2 DeviceProperties,
3 compiler::CompilationError,
4 kernel::KernelMetadata,
5 logging::ServerLogger,
6 memory_management::{
7 MemoryAllocationMode, MemoryHandle, MemoryUsage,
8 memory_pool::{SliceBinding, SliceHandle},
9 },
10 storage::{BindingResource, ComputeStorage},
11 tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
12};
13use alloc::collections::BTreeMap;
14use alloc::string::String;
15use alloc::sync::Arc;
16use alloc::vec;
17use alloc::vec::Vec;
18use core::fmt::Debug;
19use cubecl_common::{
20 ExecutionMode, backtrace::BackTrace, bytes::Bytes, device, future::DynFut,
21 profile::ProfileDuration, stream_id::StreamId,
22};
23use cubecl_ir::StorageType;
24use thiserror::Error;
25
26#[derive(Error, Clone)]
27pub enum ProfileError {
29 #[error(
31 "An unknown error happened during profiling\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
32 )]
33 Unknown {
34 reason: String,
36 backtrace: BackTrace,
38 },
39
40 #[error("No profiling registered\nBacktrace:\n{backtrace}")]
42 NotRegistered {
43 backtrace: BackTrace,
45 },
46
47 #[error("A launch error happened during profiling\nCaused by:\n {0}")]
49 Launch(#[from] LaunchError),
50
51 #[error("An execution error happened during profiling\nCaused by:\n {0}")]
53 Execution(#[from] ExecutionError),
54}
55
56impl core::fmt::Debug for ProfileError {
57 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
58 f.write_fmt(format_args!("{self}"))
59 }
60}
61
62#[derive(Debug)]
63pub struct ServerUtilities<Server: ComputeServer> {
65 #[cfg(feature = "profile-tracy")]
67 pub epoch_time: web_time::Instant,
68 #[cfg(feature = "profile-tracy")]
70 pub gpu_client: tracy_client::GpuContext,
71 pub properties: DeviceProperties,
73 pub info: Server::Info,
75 pub logger: Arc<ServerLogger>,
77}
78
79impl<S: ComputeServer> ServerUtilities<S> {
80 pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
82 #[cfg(feature = "profile-tracy")]
84 let client = tracy_client::Client::start();
85
86 Self {
87 properties,
88 logger,
89 #[cfg(feature = "profile-tracy")]
91 gpu_client: client
92 .clone()
93 .new_gpu_context(
94 Some(&format!("{info:?}")),
95 tracy_client::GpuContextType::Invalid,
97 0, 1.0, )
100 .unwrap(),
101 #[cfg(feature = "profile-tracy")]
102 epoch_time: web_time::Instant::now(),
103 info,
104 }
105 }
106}
107
108#[derive(Error, Clone)]
115#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
116pub enum LaunchError {
117 #[error("A compilation error happened during launch\nCaused by:\n {0}")]
119 CompilationError(#[from] CompilationError),
120
121 #[error(
123 "An out-of-memory error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
124 )]
125 OutOfMemory {
126 reason: String,
128 #[cfg_attr(std_io, serde(skip))]
130 backtrace: BackTrace,
131 },
132
133 #[error(
135 "An unknown error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
136 )]
137 Unknown {
138 reason: String,
140 #[cfg_attr(std_io, serde(skip))]
142 backtrace: BackTrace,
143 },
144
145 #[error("An io error happened during launch\nCaused by:\n {0}")]
147 IoError(#[from] IoError),
148}
149
150impl core::fmt::Debug for LaunchError {
151 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
152 f.write_fmt(format_args!("{self}"))
153 }
154}
155
156#[derive(Error, Debug, Clone)]
158#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
159pub enum ExecutionError {
160 #[error("An error happened during execution\nCaused by:\n {reason}\nBacktrace:\n{backtrace}")]
162 Generic {
163 reason: String,
165 #[cfg_attr(std_io, serde(skip))]
167 backtrace: BackTrace,
168 },
169}
170
171pub trait ComputeServer:
176 Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
177where
178 Self: Sized,
179{
180 type Kernel: KernelMetadata;
182 type Info: Debug + Send + Sync;
184 type Storage: ComputeStorage;
186
187 fn create(
189 &mut self,
190 descriptors: Vec<AllocationDescriptor<'_>>,
191 stream_id: StreamId,
192 ) -> Result<Vec<Allocation>, IoError>;
193
194 fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
196 Err(IoError::UnsupportedIoOperation {
197 backtrace: BackTrace::capture(),
198 })
199 }
200
201 fn logger(&self) -> Arc<ServerLogger>;
203
204 fn utilities(&self) -> Arc<ServerUtilities<Self>>;
206
207 fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
209 let alloc = self
210 .create(
211 vec![AllocationDescriptor::new(
212 AllocationKind::Contiguous,
213 &[data.len()],
214 1,
215 )],
216 stream_id,
217 )?
218 .remove(0);
219 self.write(
220 vec![(
221 CopyDescriptor::new(
222 alloc.handle.clone().binding(),
223 &[data.len()],
224 &alloc.strides,
225 1,
226 ),
227 Bytes::from_bytes_vec(data.to_vec()),
228 )],
229 stream_id,
230 )?;
231 Ok(alloc.handle)
232 }
233
234 fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
236 let alloc = self
237 .create(
238 vec![AllocationDescriptor::new(
239 AllocationKind::Contiguous,
240 &[data.len()],
241 1,
242 )],
243 stream_id,
244 )?
245 .remove(0);
246 self.write(
247 vec![(
248 CopyDescriptor::new(
249 alloc.handle.clone().binding(),
250 &[data.len()],
251 &alloc.strides,
252 1,
253 ),
254 data,
255 )],
256 stream_id,
257 )?;
258 Ok(alloc.handle)
259 }
260
261 fn read<'a>(
263 &mut self,
264 descriptors: Vec<CopyDescriptor<'a>>,
265 stream_id: StreamId,
266 ) -> DynFut<Result<Vec<Bytes>, IoError>>;
267
268 fn write(
270 &mut self,
271 descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
272 stream_id: StreamId,
273 ) -> Result<(), IoError>;
274
275 fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>>;
277
278 fn get_resource(
280 &mut self,
281 binding: Binding,
282 stream_id: StreamId,
283 ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
284
285 unsafe fn launch(
294 &mut self,
295 kernel: Self::Kernel,
296 count: CubeCount,
297 bindings: Bindings,
298 kind: ExecutionMode,
299 stream_id: StreamId,
300 ) -> Result<(), LaunchError>;
301
302 fn flush(&mut self, stream_id: StreamId);
304
305 fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
307
308 fn memory_cleanup(&mut self, stream_id: StreamId);
310
311 fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
313
314 fn end_profile(
316 &mut self,
317 stream_id: StreamId,
318 token: ProfilingToken,
319 ) -> Result<ProfileDuration, ProfileError>;
320
321 fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
323}
324
325pub trait ServerCommunication {
328 const SERVER_COMM_ENABLED: bool;
330
331 #[allow(unused_variables)]
350 fn copy(
351 server_src: &mut Self,
352 server_dst: &mut Self,
353 src: CopyDescriptor<'_>,
354 stream_id_src: StreamId,
355 stream_id_dst: StreamId,
356 ) -> Result<Allocation, IoError> {
357 if !Self::SERVER_COMM_ENABLED {
358 panic!("Server-to-server communication is not supported by this server.");
359 } else {
360 panic!(
361 "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
362 );
363 }
364 }
365}
366
/// Token identifying a profiling session, returned when profiling starts and passed back
/// when it ends.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ProfilingToken {
    /// Identifier of the profiling session.
    pub id: u64,
}
373
374#[derive(new, Debug, PartialEq, Eq)]
376pub struct Handle {
377 pub memory: SliceHandle,
379 pub offset_start: Option<u64>,
381 pub offset_end: Option<u64>,
383 pub stream: cubecl_common::stream_id::StreamId,
385 pub cursor: u64,
387 size: u64,
389}
390
/// Kind of memory layout requested for an allocation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AllocationKind {
    /// Contiguous layout.
    Contiguous,
    /// Layout chosen by the server.
    // NOTE(review): presumably allows padded/pitched strides — confirm with the servers.
    Optimized,
}
400
401#[derive(new, Debug, Clone, Copy)]
403pub struct AllocationDescriptor<'a> {
404 pub kind: AllocationKind,
406 pub shape: &'a [usize],
408 pub elem_size: usize,
410}
411
412impl<'a> AllocationDescriptor<'a> {
413 pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
415 AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
416 }
417
418 pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
420 AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
421 }
422}
423
424#[derive(new, Debug)]
426pub struct Allocation {
427 pub handle: Handle,
429 pub strides: Vec<usize>,
431}
432
433#[derive(Error, Clone)]
436#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
437pub enum IoError {
438 #[error("can't allocate buffer of size: {size}\n{backtrace}")]
440 BufferTooBig {
441 size: u64,
443 #[cfg_attr(std_io, serde(skip))]
445 backtrace: BackTrace,
446 },
447
448 #[error("the provided strides are not supported for this operation\n{backtrace}")]
450 UnsupportedStrides {
451 #[cfg_attr(std_io, serde(skip))]
453 backtrace: BackTrace,
454 },
455
456 #[error("couldn't find resource for that handle\n{backtrace}")]
458 InvalidHandle {
459 #[cfg_attr(std_io, serde(skip))]
461 backtrace: BackTrace,
462 },
463
464 #[error("Unknown error happened during execution\n{backtrace}")]
466 Unknown {
467 description: String,
469 #[cfg_attr(std_io, serde(skip))]
471 backtrace: BackTrace,
472 },
473
474 #[error("The current IO operation is not supported\n{backtrace}")]
476 UnsupportedIoOperation {
477 #[cfg_attr(std_io, serde(skip))]
479 backtrace: BackTrace,
480 },
481
482 #[error("Can't perform the IO operation because of a runtime error")]
484 Execution(#[from] ExecutionError),
485}
486
487impl core::fmt::Debug for IoError {
488 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
489 f.write_fmt(format_args!("{self}"))
490 }
491}
492
493impl Handle {
494 pub fn offset_start(mut self, offset: u64) -> Self {
496 if let Some(val) = &mut self.offset_start {
497 *val += offset;
498 } else {
499 self.offset_start = Some(offset);
500 }
501
502 self
503 }
504 pub fn offset_end(mut self, offset: u64) -> Self {
506 if let Some(val) = &mut self.offset_end {
507 *val += offset;
508 } else {
509 self.offset_end = Some(offset);
510 }
511
512 self
513 }
514
515 pub fn size(&self) -> u64 {
517 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
518 }
519}
520
521#[derive(Debug, Default)]
523pub struct Bindings {
524 pub buffers: Vec<Binding>,
526 pub metadata: MetadataBinding,
529 pub scalars: BTreeMap<StorageType, ScalarBinding>,
531 pub tensor_maps: Vec<TensorMapBinding>,
533}
534
535impl Bindings {
536 pub fn new() -> Self {
538 Self::default()
539 }
540
541 pub fn with_buffer(mut self, binding: Binding) -> Self {
543 self.buffers.push(binding);
544 self
545 }
546
547 pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
549 self.buffers.extend(bindings);
550 self
551 }
552
553 pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
555 self.scalars
556 .insert(ty, ScalarBinding::new(ty, length, data));
557 self
558 }
559
560 pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
562 self.scalars
563 .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
564 self
565 }
566
567 pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
569 self.metadata = meta;
570 self
571 }
572
573 pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
575 self.tensor_maps.extend(bindings);
576 self
577 }
578}
579
580#[derive(new, Debug, Default)]
582pub struct MetadataBinding {
583 pub data: Vec<u32>,
585 pub static_len: usize,
587}
588
589#[derive(new, Debug, Clone)]
591pub struct ScalarBinding {
592 pub ty: StorageType,
594 pub length: usize,
596 pub data: Vec<u64>,
598}
599
600impl ScalarBinding {
601 pub fn data(&self) -> &[u8] {
603 bytemuck::cast_slice(&self.data)
604 }
605}
606
607#[derive(new, Debug)]
609pub struct Binding {
610 pub memory: SliceBinding,
612 pub offset_start: Option<u64>,
614 pub offset_end: Option<u64>,
616 pub stream: cubecl_common::stream_id::StreamId,
618 pub cursor: u64,
620 size: u64,
622}
623
624impl Binding {
625 pub fn size(&self) -> u64 {
627 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
628 }
629}
630
631#[derive(new, Debug, Clone)]
633pub struct CopyDescriptor<'a> {
634 pub binding: Binding,
636 pub shape: &'a [usize],
638 pub strides: &'a [usize],
640 pub elem_size: usize,
642}
643
644#[derive(new, Debug, Clone)]
646pub struct TensorMapBinding {
647 pub binding: Binding,
649 pub map: TensorMapMeta,
651}
652
653#[derive(Debug, Clone)]
655pub struct TensorMapMeta {
656 pub format: TensorMapFormat,
658 pub rank: usize,
660 pub shape: Vec<usize>,
662 pub strides: Vec<usize>,
664 pub elem_stride: Vec<usize>,
667 pub interleave: TensorMapInterleave,
669 pub swizzle: TensorMapSwizzle,
671 pub prefetch: TensorMapPrefetch,
673 pub oob_fill: OobFill,
675 pub storage_ty: StorageType,
677}
678
679impl Handle {
680 pub fn can_mut(&self) -> bool {
682 self.memory.can_mut() && self.stream == StreamId::current()
683 }
684}
685
686impl Handle {
687 pub fn binding(self) -> Binding {
689 Binding {
690 memory: MemoryHandle::binding(self.memory),
691 offset_start: self.offset_start,
692 offset_end: self.offset_end,
693 size: self.size,
694 stream: self.stream,
695 cursor: self.cursor,
696 }
697 }
698
699 pub fn copy_descriptor<'a>(
701 &'a self,
702 shape: &'a [usize],
703 strides: &'a [usize],
704 elem_size: usize,
705 ) -> CopyDescriptor<'a> {
706 CopyDescriptor {
707 shape,
708 strides,
709 elem_size,
710 binding: self.clone().binding(),
711 }
712 }
713}
714
715impl Clone for Handle {
716 fn clone(&self) -> Self {
717 Self {
718 memory: self.memory.clone(),
719 offset_start: self.offset_start,
720 offset_end: self.offset_end,
721 size: self.size,
722 stream: self.stream,
723 cursor: self.cursor,
724 }
725 }
726}
727
728impl Clone for Binding {
729 fn clone(&self) -> Self {
730 Self {
731 memory: self.memory.clone(),
732 offset_start: self.offset_start,
733 offset_end: self.offset_end,
734 size: self.size,
735 stream: self.stream,
736 cursor: self.cursor,
737 }
738 }
739}
740
741#[allow(clippy::large_enum_variant)]
745pub enum CubeCount {
746 Static(u32, u32, u32),
748 Dynamic(Binding),
750}
751
752impl CubeCount {
753 pub fn new_single() -> Self {
755 CubeCount::Static(1, 1, 1)
756 }
757
758 pub fn new_1d(x: u32) -> Self {
760 CubeCount::Static(x, 1, 1)
761 }
762
763 pub fn new_2d(x: u32, y: u32) -> Self {
765 CubeCount::Static(x, y, 1)
766 }
767
768 pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
770 CubeCount::Static(x, y, z)
771 }
772}
773
774impl Debug for CubeCount {
775 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
776 match self {
777 CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
778 CubeCount::Dynamic(_) => f.write_str("binding"),
779 }
780 }
781}
782
783impl Clone for CubeCount {
784 fn clone(&self) -> Self {
785 match self {
786 Self::Static(x, y, z) => Self::Static(*x, *y, *z),
787 Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
788 }
789 }
790}