1use crate::{
2 DeviceProperties,
3 client::ComputeClient,
4 compiler::CompilationError,
5 kernel::KernelMetadata,
6 logging::ServerLogger,
7 memory_management::{
8 MemoryAllocationMode, MemoryHandle, MemoryUsage,
9 memory_pool::{SliceBinding, SliceHandle},
10 },
11 runtime::Runtime,
12 storage::{BindingResource, ComputeStorage},
13 tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
14};
15use alloc::collections::BTreeMap;
16use alloc::string::String;
17use alloc::sync::Arc;
18use alloc::vec;
19use alloc::vec::Vec;
20use core::fmt::Debug;
21use cubecl_common::{
22 backtrace::BackTrace, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
23 stream_id::StreamId,
24};
25use cubecl_ir::StorageType;
26use serde::{Deserialize, Serialize};
27use thiserror::Error;
28
/// Error that can occur while profiling work on a compute server.
///
/// Each variant's `#[error]` attribute defines its `Display` message.
#[derive(Error, Clone)]
pub enum ProfileError {
    /// Profiling failed for a reason described by `reason`.
    #[error(
        "An unknown error happened during profiling\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
    )]
    Unknown {
        /// Human-readable description of the failure.
        reason: String,
        /// Backtrace attached to the error.
        backtrace: BackTrace,
    },

    /// No profiling was registered.
    /// NOTE(review): presumably returned when ending a profile that was never started — confirm.
    #[error("No profiling registered\nBacktrace:\n{backtrace}")]
    NotRegistered {
        /// Backtrace attached to the error.
        backtrace: BackTrace,
    },

    /// A kernel launch failed while profiling.
    #[error("A launch error happened during profiling\nCaused by:\n {0}")]
    Launch(#[from] LaunchError),

    /// Execution failed while profiling.
    #[error("An execution error happened during profiling\nCaused by:\n {0}")]
    Execution(#[from] ExecutionError),
}
58
59impl core::fmt::Debug for ProfileError {
60 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
61 f.write_fmt(format_args!("{self}"))
62 }
63}
64
/// Per-server data shared with clients: device properties, server-specific info and the logger.
pub struct ServerUtilities<Server: ComputeServer> {
    /// Instant recorded when the utilities were created (`profile-tracy` only).
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// Tracy GPU context created for this server (`profile-tracy` only).
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// Properties of the device the server runs on.
    pub properties: DeviceProperties,
    /// Server-specific information (the server's `ComputeServer::Info`).
    pub info: Server::Info,
    /// Logger shared with the server.
    pub logger: Arc<ServerLogger>,
}
80
81impl<Server: core::fmt::Debug> core::fmt::Debug for ServerUtilities<Server>
82where
83 Server: ComputeServer,
84 Server::Info: core::fmt::Debug,
85{
86 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
87 f.debug_struct("ServerUtilities")
88 .field("properties", &self.properties)
89 .field("info", &self.info)
90 .field("logger", &self.logger)
91 .finish()
92 }
93}
94
95impl<S: ComputeServer> ServerUtilities<S> {
96 pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
98 #[cfg(feature = "profile-tracy")]
100 let client = tracy_client::Client::start();
101
102 Self {
103 properties,
104 logger,
105 #[cfg(feature = "profile-tracy")]
107 gpu_client: client
108 .clone()
109 .new_gpu_context(
110 Some(&format!("{info:?}")),
111 tracy_client::GpuContextType::Invalid,
113 0, 1.0, )
116 .unwrap(),
117 #[cfg(feature = "profile-tracy")]
118 epoch_time: web_time::Instant::now(),
119 info,
120 }
121 }
122}
123
/// Error that can occur when launching a kernel.
///
/// Each variant's `#[error]` attribute defines its `Display` message. When the `std_io` cfg is
/// set, the type is serializable and backtraces are skipped during (de)serialization.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum LaunchError {
    /// The kernel failed to compile.
    #[error("A compilation error happened during launch\nCaused by:\n {0}")]
    CompilationError(#[from] CompilationError),

    /// The launch ran out of memory.
    #[error(
        "An out-of-memory error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
    )]
    OutOfMemory {
        /// Human-readable description of the failure.
        reason: String,
        /// Backtrace attached to the error (not serialized).
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// The launch failed for a reason described by `reason`.
    #[error(
        "An unknown error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
    )]
    Unknown {
        /// Human-readable description of the failure.
        reason: String,
        /// Backtrace attached to the error (not serialized).
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// An IO operation failed during the launch.
    #[error("An io error happened during launch\nCaused by:\n {0}")]
    IoError(#[from] IoError),
}
165
166impl core::fmt::Debug for LaunchError {
167 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
168 f.write_fmt(format_args!("{self}"))
169 }
170}
171
/// Error that happened while executing work on a server.
///
/// Unlike the other error types in this module, `Debug` here is the derived structural form.
#[derive(Error, Debug, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ExecutionError {
    /// Execution failed for a reason described by `reason`.
    #[error("An error happened during execution\nCaused by:\n {reason}\nBacktrace:\n{backtrace}")]
    Generic {
        /// Human-readable description of the failure.
        reason: String,
        /// Backtrace attached to the error (not serialized).
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
}
186
187pub trait ComputeServer:
192 Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
193where
194 Self: Sized,
195{
196 type Kernel: KernelMetadata;
198 type Info: Debug + Send + Sync;
200 type Storage: ComputeStorage;
202
203 fn create(
205 &mut self,
206 descriptors: Vec<AllocationDescriptor<'_>>,
207 stream_id: StreamId,
208 ) -> Result<Vec<Allocation>, IoError>;
209
210 fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
212 Err(IoError::UnsupportedIoOperation {
213 backtrace: BackTrace::capture(),
214 })
215 }
216
217 fn logger(&self) -> Arc<ServerLogger>;
219
220 fn utilities(&self) -> Arc<ServerUtilities<Self>>;
222
223 fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
225 let alloc = self
226 .create(
227 vec![AllocationDescriptor::new(
228 AllocationKind::Contiguous,
229 &[data.len()],
230 1,
231 )],
232 stream_id,
233 )?
234 .remove(0);
235 self.write(
236 vec![(
237 CopyDescriptor::new(
238 alloc.handle.clone().binding(),
239 &[data.len()],
240 &alloc.strides,
241 1,
242 ),
243 Bytes::from_bytes_vec(data.to_vec()),
244 )],
245 stream_id,
246 )?;
247 Ok(alloc.handle)
248 }
249
250 fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
252 let alloc = self
253 .create(
254 vec![AllocationDescriptor::new(
255 AllocationKind::Contiguous,
256 &[data.len()],
257 1,
258 )],
259 stream_id,
260 )?
261 .remove(0);
262 self.write(
263 vec![(
264 CopyDescriptor::new(
265 alloc.handle.clone().binding(),
266 &[data.len()],
267 &alloc.strides,
268 1,
269 ),
270 data,
271 )],
272 stream_id,
273 )?;
274 Ok(alloc.handle)
275 }
276
277 fn read<'a>(
279 &mut self,
280 descriptors: Vec<CopyDescriptor<'a>>,
281 stream_id: StreamId,
282 ) -> DynFut<Result<Vec<Bytes>, IoError>>;
283
284 fn write(
286 &mut self,
287 descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
288 stream_id: StreamId,
289 ) -> Result<(), IoError>;
290
291 fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>>;
293
294 fn get_resource(
296 &mut self,
297 binding: Binding,
298 stream_id: StreamId,
299 ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
300
301 unsafe fn launch(
310 &mut self,
311 kernel: Self::Kernel,
312 count: CubeCount,
313 bindings: Bindings,
314 kind: ExecutionMode,
315 stream_id: StreamId,
316 ) -> Result<(), LaunchError>;
317
318 fn flush(&mut self, stream_id: StreamId);
320
321 fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
323
324 fn memory_cleanup(&mut self, stream_id: StreamId);
326
327 fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
329
330 fn end_profile(
332 &mut self,
333 stream_id: StreamId,
334 token: ProfilingToken,
335 ) -> Result<ProfileDuration, ProfileError>;
336
337 fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
339}
340
341pub trait ServerCommunication {
344 const SERVER_COMM_ENABLED: bool;
346
347 #[allow(unused_variables)]
366 fn copy(
367 server_src: &mut Self,
368 server_dst: &mut Self,
369 src: CopyDescriptor<'_>,
370 stream_id_src: StreamId,
371 stream_id_dst: StreamId,
372 ) -> Result<Allocation, IoError> {
373 if !Self::SERVER_COMM_ENABLED {
374 panic!("Server-to-server communication is not supported by this server.");
375 } else {
376 panic!(
377 "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
378 );
379 }
380 }
381}
382
/// Token identifying a profiling session.
///
/// Returned by `ComputeServer::start_profile` and passed back to `ComputeServer::end_profile`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Identifier of the session.
    pub id: u64,
}
389
/// Owning handle to a slice of server memory, optionally narrowed by start/end offsets.
#[derive(new, Debug, PartialEq, Eq)]
pub struct Handle {
    /// Underlying memory slice.
    pub memory: SliceHandle,
    /// Amount trimmed from the start of the slice; `None` means nothing is trimmed.
    pub offset_start: Option<u64>,
    /// Amount trimmed from the end of the slice; `None` means nothing is trimmed.
    pub offset_end: Option<u64>,
    /// Stream the handle belongs to; `Handle::can_mut` requires it to be the current stream.
    pub stream: cubecl_common::stream_id::StreamId,
    /// NOTE(review): purpose not visible in this file. It is copied unchanged into
    /// `Binding::cursor` — document once confirmed.
    pub cursor: u64,
    /// Size of the whole slice before the offsets are applied (see `Handle::size`).
    size: u64,
}
406
/// Memory layout requested for a new allocation.
///
/// The strides actually used are reported in `Allocation::strides`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// Contiguous layout.
    Contiguous,
    /// Layout chosen by the server.
    /// NOTE(review): presumably may pad rows for performance — confirm.
    Optimized,
}
416
/// Description of one allocation passed to `ComputeServer::create`.
#[derive(new, Debug, Clone, Copy)]
pub struct AllocationDescriptor<'a> {
    /// Requested memory layout.
    pub kind: AllocationKind,
    /// Shape of the tensor to allocate.
    pub shape: &'a [usize],
    /// Size of one element (presumably in bytes; the byte helpers pass 1 for `u8` data).
    pub elem_size: usize,
}
427
428impl<'a> AllocationDescriptor<'a> {
429 pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
431 AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
432 }
433
434 pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
436 AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
437 }
438}
439
/// Result of one allocation made by `ComputeServer::create`.
#[derive(new, Debug)]
pub struct Allocation {
    /// Handle to the allocated memory.
    pub handle: Handle,
    /// Strides of the allocated tensor.
    pub strides: Vec<usize>,
}
448
449#[derive(Error, Clone)]
452#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
453pub enum IoError {
454 #[error("can't allocate buffer of size: {size}\n{backtrace}")]
456 BufferTooBig {
457 size: u64,
459 #[cfg_attr(std_io, serde(skip))]
461 backtrace: BackTrace,
462 },
463
464 #[error("the provided strides are not supported for this operation\n{backtrace}")]
466 UnsupportedStrides {
467 #[cfg_attr(std_io, serde(skip))]
469 backtrace: BackTrace,
470 },
471
472 #[error("couldn't find resource for that handle\n{backtrace}")]
474 InvalidHandle {
475 #[cfg_attr(std_io, serde(skip))]
477 backtrace: BackTrace,
478 },
479
480 #[error("Unknown error happened during execution\n{backtrace}")]
482 Unknown {
483 description: String,
485 #[cfg_attr(std_io, serde(skip))]
487 backtrace: BackTrace,
488 },
489
490 #[error("The current IO operation is not supported\n{backtrace}")]
492 UnsupportedIoOperation {
493 #[cfg_attr(std_io, serde(skip))]
495 backtrace: BackTrace,
496 },
497
498 #[error("Can't perform the IO operation because of a runtime error")]
500 Execution(#[from] ExecutionError),
501}
502
503impl core::fmt::Debug for IoError {
504 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
505 f.write_fmt(format_args!("{self}"))
506 }
507}
508
509impl Handle {
510 pub fn offset_start(mut self, offset: u64) -> Self {
512 if let Some(val) = &mut self.offset_start {
513 *val += offset;
514 } else {
515 self.offset_start = Some(offset);
516 }
517
518 self
519 }
520 pub fn offset_end(mut self, offset: u64) -> Self {
522 if let Some(val) = &mut self.offset_end {
523 *val += offset;
524 } else {
525 self.offset_end = Some(offset);
526 }
527
528 self
529 }
530
531 pub fn size(&self) -> u64 {
533 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
534 }
535}
536
/// Everything bound to a kernel launch: buffers, metadata, scalars and tensor maps.
#[derive(Debug, Default)]
pub struct Bindings {
    /// Buffer bindings, in the order they were added.
    pub buffers: Vec<Binding>,
    /// Metadata passed to the kernel.
    pub metadata: MetadataBinding,
    /// Scalars grouped by storage type: at most one `ScalarBinding` per type.
    pub scalars: BTreeMap<StorageType, ScalarBinding>,
    /// Tensor-map bindings, in the order they were added.
    pub tensor_maps: Vec<TensorMapBinding>,
}
550
551impl Bindings {
552 pub fn new() -> Self {
554 Self::default()
555 }
556
557 pub fn with_buffer(mut self, binding: Binding) -> Self {
559 self.buffers.push(binding);
560 self
561 }
562
563 pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
565 self.buffers.extend(bindings);
566 self
567 }
568
569 pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
571 self.scalars
572 .insert(ty, ScalarBinding::new(ty, length, data));
573 self
574 }
575
576 pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
578 self.scalars
579 .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
580 self
581 }
582
583 pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
585 self.metadata = meta;
586 self
587 }
588
589 pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
591 self.tensor_maps.extend(bindings);
592 self
593 }
594}
595
/// Metadata words passed to a kernel.
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// Metadata values.
    pub data: Vec<u32>,
    /// NOTE(review): presumably the number of leading entries of `data` that form the static
    /// part of the metadata — confirm with the metadata builder.
    pub static_len: usize,
}
604
/// Scalars of one storage type, held in a `u64` buffer.
#[derive(new, Debug, Clone)]
pub struct ScalarBinding {
    /// Storage type of the scalars.
    pub ty: StorageType,
    /// Number of scalars.
    /// NOTE(review): `data` may hold more bytes than `length` elements need — confirm.
    pub length: usize,
    /// Raw scalar storage; `ScalarBinding::data` exposes it as bytes.
    pub data: Vec<u64>,
}
615
616impl ScalarBinding {
617 pub fn data(&self) -> &[u8] {
619 bytemuck::cast_slice(&self.data)
620 }
621}
622
/// Binding form of a `Handle` (see `Handle::binding`), used in launches and copies.
#[derive(new, Debug)]
pub struct Binding {
    /// Binding to the underlying memory slice.
    pub memory: SliceBinding,
    /// Amount trimmed from the start of the slice; `None` means nothing is trimmed.
    pub offset_start: Option<u64>,
    /// Amount trimmed from the end of the slice; `None` means nothing is trimmed.
    pub offset_end: Option<u64>,
    /// Stream the originating handle belonged to.
    pub stream: cubecl_common::stream_id::StreamId,
    /// Copied from `Handle::cursor`. NOTE(review): meaning not visible here — confirm.
    pub cursor: u64,
    /// Size of the whole slice before the offsets are applied (see `Binding::size`).
    size: u64,
}
639
640impl Binding {
641 pub fn size(&self) -> u64 {
643 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
644 }
645}
646
/// Describes a strided memory region to read from or write to.
#[derive(new, Debug, Clone)]
pub struct CopyDescriptor<'a> {
    /// Memory being accessed.
    pub binding: Binding,
    /// Shape of the region.
    pub shape: &'a [usize],
    /// Strides of the region, one per dimension of `shape`.
    pub strides: &'a [usize],
    /// Size of one element.
    pub elem_size: usize,
}
659
/// A buffer binding paired with the tensor-map metadata describing how it is accessed.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// Buffer the tensor map refers to.
    pub binding: Binding,
    /// Tensor-map description.
    pub map: TensorMapMeta,
}
668
/// Metadata of a tensor map (see `crate::tma`).
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Tensor-map format.
    pub format: TensorMapFormat,
    /// Number of dimensions.
    pub rank: usize,
    /// Shape of the tensor.
    pub shape: Vec<usize>,
    /// Strides of the tensor.
    pub strides: Vec<usize>,
    /// Per-dimension element strides.
    /// NOTE(review): presumably the traversal step per dimension — confirm.
    pub elem_stride: Vec<usize>,
    /// Interleave mode.
    pub interleave: TensorMapInterleave,
    /// Swizzle mode.
    pub swizzle: TensorMapSwizzle,
    /// Prefetch mode.
    pub prefetch: TensorMapPrefetch,
    /// Fill behavior for out-of-bounds accesses.
    pub oob_fill: OobFill,
    /// Storage type of the elements.
    pub storage_ty: StorageType,
}
694
695impl Handle {
696 pub fn can_mut(&self) -> bool {
698 self.memory.can_mut() && self.stream == StreamId::current()
699 }
700}
701
702impl Handle {
703 pub fn binding(self) -> Binding {
705 Binding {
706 memory: MemoryHandle::binding(self.memory),
707 offset_start: self.offset_start,
708 offset_end: self.offset_end,
709 size: self.size,
710 stream: self.stream,
711 cursor: self.cursor,
712 }
713 }
714
715 pub fn copy_descriptor<'a>(
717 &'a self,
718 shape: &'a [usize],
719 strides: &'a [usize],
720 elem_size: usize,
721 ) -> CopyDescriptor<'a> {
722 CopyDescriptor {
723 shape,
724 strides,
725 elem_size,
726 binding: self.clone().binding(),
727 }
728 }
729}
730
731impl Clone for Handle {
732 fn clone(&self) -> Self {
733 Self {
734 memory: self.memory.clone(),
735 offset_start: self.offset_start,
736 offset_end: self.offset_end,
737 size: self.size,
738 stream: self.stream,
739 cursor: self.cursor,
740 }
741 }
742}
743
744impl Clone for Binding {
745 fn clone(&self) -> Self {
746 Self {
747 memory: self.memory.clone(),
748 offset_start: self.offset_start,
749 offset_end: self.offset_end,
750 size: self.size,
751 stream: self.stream,
752 cursor: self.cursor,
753 }
754 }
755}
756
/// Number of cubes to launch.
// `Dynamic` carries a whole `Binding`, which is much larger than `Static`'s three `u32`s;
// the size-difference lint is silenced instead of boxing the binding.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Cube count known on the host, as `(x, y, z)`.
    Static(u32, u32, u32),
    /// Cube count supplied through a buffer binding.
    /// NOTE(review): presumably read on the device at launch time — confirm.
    Dynamic(Binding),
}
767
/// Cube count chosen for a requested number of cubes, recording whether it matches exactly.
pub enum CubeCountSelection {
    /// The cube count launches exactly the requested number of cubes.
    Exact(CubeCount),
    /// The cube count launches a different number of cubes than requested (the spread rounds
    /// up). The `u32` is the number of cubes actually launched.
    Approx(CubeCount, u32),
}
777
778impl CubeCountSelection {
779 pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
781 let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
782
783 let num_cubes_actual = cube_count[0] * cube_count[1] * cube_count[2];
784 let cube_count = CubeCount::Static(cube_count[0], cube_count[1], cube_count[2]);
785
786 match num_cubes_actual == num_cubes {
787 true => CubeCountSelection::Exact(cube_count),
788 false => CubeCountSelection::Approx(cube_count, num_cubes_actual),
789 }
790 }
791
792 pub fn has_idle(&self) -> bool {
794 matches!(self, Self::Approx(..))
795 }
796
797 pub fn cube_count(self) -> CubeCount {
799 match self {
800 CubeCountSelection::Exact(cube_count) => cube_count,
801 CubeCountSelection::Approx(cube_count, _) => cube_count,
802 }
803 }
804}
805
806impl From<CubeCountSelection> for CubeCount {
807 fn from(value: CubeCountSelection) -> Self {
808 value.cube_count()
809 }
810}
811
812impl CubeCount {
813 pub fn new_single() -> Self {
815 CubeCount::Static(1, 1, 1)
816 }
817
818 pub fn new_1d(x: u32) -> Self {
820 CubeCount::Static(x, 1, 1)
821 }
822
823 pub fn new_2d(x: u32, y: u32) -> Self {
825 CubeCount::Static(x, y, 1)
826 }
827
828 pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
830 CubeCount::Static(x, y, z)
831 }
832}
833
834impl Debug for CubeCount {
835 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
836 match self {
837 CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
838 CubeCount::Dynamic(_) => f.write_str("binding"),
839 }
840 }
841}
842
843impl Clone for CubeCount {
844 fn clone(&self) -> Self {
845 match self {
846 Self::Static(x, y, z) => Self::Static(*x, *y, *z),
847 Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
848 }
849 }
850}
851
/// Dimensions of a cube along x, y and z; `CubeDim::num_elems` is their product.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, serde::Serialize, serde::Deserialize)]
#[allow(missing_docs)]
pub struct CubeDim {
    /// Extent along x.
    pub x: u32,
    /// Extent along y.
    pub y: u32,
    /// Extent along z.
    pub z: u32,
}
863
864impl CubeDim {
865 pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
873 let properties = client.properties();
874 let plane_size = properties.hardware.plane_size_max;
875 let plane_count = Self::calculate_plane_count_per_cube(
876 working_units as u32,
877 plane_size,
878 properties.hardware.num_cpu_cores,
879 );
880
881 Self::new_2d(plane_size, plane_count)
882 }
883
884 fn calculate_plane_count_per_cube(
885 working_units: u32,
886 plane_dim: u32,
887 num_cpu_cores: Option<u32>,
888 ) -> u32 {
889 match num_cpu_cores {
890 Some(num_cores) => core::cmp::min(num_cores, working_units),
891 None => {
892 let plane_count_max = core::cmp::max(1, working_units / plane_dim);
893
894 const NUM_PLANE_MAX: u32 = 8u32;
896 const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
897 let plane_count_max_log2 =
898 core::cmp::min(NUM_PLANE_MAX_LOG2, u32::ilog2(plane_count_max));
899 2u32.pow(plane_count_max_log2)
900 }
901 }
902 }
903
904 pub const fn new_single() -> Self {
906 Self { x: 1, y: 1, z: 1 }
907 }
908
909 pub const fn new_1d(x: u32) -> Self {
911 Self { x, y: 1, z: 1 }
912 }
913
914 pub const fn new_2d(x: u32, y: u32) -> Self {
916 Self { x, y, z: 1 }
917 }
918
919 pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
922 Self { x, y, z }
923 }
924
925 pub const fn num_elems(&self) -> u32 {
927 self.x * self.y * self.z
928 }
929
930 pub const fn can_contain(&self, other: CubeDim) -> bool {
932 self.x >= other.x && self.y >= other.y && self.z >= other.z
933 }
934}
935
/// Execution mode of a kernel launch; `Checked` is the default.
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy, Serialize, Deserialize)]
pub enum ExecutionMode {
    /// Default mode.
    /// NOTE(review): presumably adds out-of-bounds checks that `Unchecked` omits — confirm.
    #[default]
    Checked,
    /// Mode without those checks; see `ComputeServer::launch`, which is `unsafe`.
    Unchecked,
}
945
946fn cube_count_spread(max: &CubeCount, num_cubes: u32) -> [u32; 3] {
947 let max_cube_counts = match max {
948 CubeCount::Static(x, y, z) => [*x, *y, *z],
949 CubeCount::Dynamic(_) => panic!("No static max cube count"),
950 };
951 let mut num_cubes = [num_cubes, 1, 1];
952 let base = 2;
953
954 let mut reduce_count = |i: usize| {
955 if num_cubes[i] <= max_cube_counts[i] {
956 return true;
957 }
958
959 loop {
960 num_cubes[i] = num_cubes[i].div_ceil(base);
961 num_cubes[i + 1] *= base;
962
963 if num_cubes[i] <= max_cube_counts[i] {
964 return false;
965 }
966 }
967 };
968
969 for i in 0..2 {
970 if reduce_count(i) {
971 break;
972 }
973 }
974
975 num_cubes
976}
977
#[cfg(test)]
mod tests {
    use super::*;

    /// Spreads `required` cubes over the static limit `max` and compares against `expected`.
    fn check_spread(max: [u32; 3], required: u32, expected: [u32; 3]) {
        let limit = CubeCount::Static(max[0], max[1], max[2]);
        assert_eq!(cube_count_spread(&limit, required), expected);
    }

    #[test]
    fn safe_num_cubes_even() {
        check_spread([32, 32, 32], 2048, [32, 32, 2]);
    }

    #[test]
    fn safe_num_cubes_odd() {
        check_spread([48, 32, 16], 3177, [25, 32, 4]);
    }
}