1use crate::{
2 client::ComputeClient,
3 compiler::CompilationError,
4 kernel::KernelMetadata,
5 logging::ServerLogger,
6 memory_management::{
7 MemoryAllocationMode, MemoryHandle, MemoryUsage,
8 memory_pool::{SliceBinding, SliceHandle},
9 },
10 runtime::Runtime,
11 storage::{BindingResource, ComputeStorage},
12 tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
13};
14use alloc::collections::BTreeMap;
15use alloc::string::String;
16use alloc::sync::Arc;
17use alloc::vec;
18use alloc::vec::Vec;
19use core::fmt::Debug;
20use cubecl_common::{
21 backtrace::BackTrace, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
22 stream_id::StreamId,
23};
24use cubecl_ir::{DeviceProperties, StorageType};
25use serde::{Deserialize, Serialize};
26use thiserror::Error;
27
/// Errors that can be reported while profiling.
///
/// `Debug` is implemented manually for this type and prints the same text as `Display`.
#[derive(Error, Clone)]
pub enum ProfileError {
    /// A profiling failure that has no dedicated variant.
    #[error(
        "An unknown error happened during profiling\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
    )]
    Unknown {
        /// Human-readable cause; rendered in the error message.
        reason: String,
        /// Backtrace rendered at the end of the error message.
        backtrace: BackTrace,
    },

    /// No profiling was registered.
    #[error("No profiling registered\nBacktrace:\n{backtrace}")]
    NotRegistered {
        /// Backtrace rendered at the end of the error message.
        backtrace: BackTrace,
    },

    /// A [`LaunchError`] surfaced while profiling (convertible with `?` through `From`).
    #[error("A launch error happened during profiling\nCaused by:\n {0}")]
    Launch(#[from] LaunchError),

    /// An [`ExecutionError`] surfaced while profiling (convertible with `?` through `From`).
    #[error("An execution error happened during profiling\nCaused by:\n {0}")]
    Execution(#[from] ExecutionError),
}
57
58impl core::fmt::Debug for ProfileError {
59 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
60 f.write_fmt(format_args!("{self}"))
61 }
62}
63
/// Data bundled alongside a [`ComputeServer`]: device properties, server
/// info and the logger (plus tracy state when `profile-tracy` is enabled).
pub struct ServerUtilities<Server: ComputeServer> {
    /// Instant captured when the utilities were created (see `ServerUtilities::new`).
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// Tracy GPU context created from the tracy client in `ServerUtilities::new`.
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// Properties of the device.
    pub properties: DeviceProperties,
    /// Server-specific information, of type [`ComputeServer::Info`].
    pub info: Server::Info,
    /// Shared logger of the server.
    pub logger: Arc<ServerLogger>,
}
79
80impl<Server: core::fmt::Debug> core::fmt::Debug for ServerUtilities<Server>
81where
82 Server: ComputeServer,
83 Server::Info: core::fmt::Debug,
84{
85 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
86 f.debug_struct("ServerUtilities")
87 .field("properties", &self.properties)
88 .field("info", &self.info)
89 .field("logger", &self.logger)
90 .finish()
91 }
92}
93
94impl<S: ComputeServer> ServerUtilities<S> {
95 pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
97 #[cfg(feature = "profile-tracy")]
99 let client = tracy_client::Client::start();
100
101 Self {
102 properties,
103 logger,
104 #[cfg(feature = "profile-tracy")]
106 gpu_client: client
107 .clone()
108 .new_gpu_context(
109 Some(&format!("{info:?}")),
110 tracy_client::GpuContextType::Invalid,
112 0, 1.0, )
115 .unwrap(),
116 #[cfg(feature = "profile-tracy")]
117 epoch_time: web_time::Instant::now(),
118 info,
119 }
120 }
121}
122
123#[derive(Error, Clone)]
130#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
131pub enum LaunchError {
132 #[error("A compilation error happened during launch\nCaused by:\n {0}")]
134 CompilationError(#[from] CompilationError),
135
136 #[error(
138 "An out-of-memory error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
139 )]
140 OutOfMemory {
141 reason: String,
143 #[cfg_attr(std_io, serde(skip))]
145 backtrace: BackTrace,
146 },
147
148 #[error(
150 "An unknown error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
151 )]
152 Unknown {
153 reason: String,
155 #[cfg_attr(std_io, serde(skip))]
157 backtrace: BackTrace,
158 },
159
160 #[error("An io error happened during launch\nCaused by:\n {0}")]
162 IoError(#[from] IoError),
163}
164
165impl core::fmt::Debug for LaunchError {
166 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
167 f.write_fmt(format_args!("{self}"))
168 }
169}
170
/// Errors that can be reported during execution.
///
/// Unlike the other error enums in this file, `Debug` is derived here rather
/// than forwarded to `Display`.
#[derive(Error, Debug, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ExecutionError {
    /// A generic execution failure described by `reason`.
    #[error("An error happened during execution\nCaused by:\n {reason}\nBacktrace:\n{backtrace}")]
    Generic {
        /// Human-readable cause; rendered in the error message.
        reason: String,
        /// Backtrace rendered in the message; skipped by serde when `std_io` is set.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
}
185
186pub trait ComputeServer:
191 Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
192where
193 Self: Sized,
194{
195 type Kernel: KernelMetadata;
197 type Info: Debug + Send + Sync;
199 type Storage: ComputeStorage;
201
202 fn create(
204 &mut self,
205 descriptors: Vec<AllocationDescriptor<'_>>,
206 stream_id: StreamId,
207 ) -> Result<Vec<Allocation>, IoError>;
208
209 fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
211 Err(IoError::UnsupportedIoOperation {
212 backtrace: BackTrace::capture(),
213 })
214 }
215
216 fn logger(&self) -> Arc<ServerLogger>;
218
219 fn utilities(&self) -> Arc<ServerUtilities<Self>>;
221
222 fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
224 let alloc = self
225 .create(
226 vec![AllocationDescriptor::new(
227 AllocationKind::Contiguous,
228 &[data.len()],
229 1,
230 )],
231 stream_id,
232 )?
233 .remove(0);
234 self.write(
235 vec![(
236 CopyDescriptor::new(
237 alloc.handle.clone().binding(),
238 &[data.len()],
239 &alloc.strides,
240 1,
241 ),
242 Bytes::from_bytes_vec(data.to_vec()),
243 )],
244 stream_id,
245 )?;
246 Ok(alloc.handle)
247 }
248
249 fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
251 let alloc = self
252 .create(
253 vec![AllocationDescriptor::new(
254 AllocationKind::Contiguous,
255 &[data.len()],
256 1,
257 )],
258 stream_id,
259 )?
260 .remove(0);
261 self.write(
262 vec![(
263 CopyDescriptor::new(
264 alloc.handle.clone().binding(),
265 &[data.len()],
266 &alloc.strides,
267 1,
268 ),
269 data,
270 )],
271 stream_id,
272 )?;
273 Ok(alloc.handle)
274 }
275
276 fn read<'a>(
278 &mut self,
279 descriptors: Vec<CopyDescriptor<'a>>,
280 stream_id: StreamId,
281 ) -> DynFut<Result<Vec<Bytes>, IoError>>;
282
283 fn write(
285 &mut self,
286 descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
287 stream_id: StreamId,
288 ) -> Result<(), IoError>;
289
290 fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>>;
292
293 fn get_resource(
295 &mut self,
296 binding: Binding,
297 stream_id: StreamId,
298 ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
299
300 unsafe fn launch(
309 &mut self,
310 kernel: Self::Kernel,
311 count: CubeCount,
312 bindings: Bindings,
313 kind: ExecutionMode,
314 stream_id: StreamId,
315 ) -> Result<(), LaunchError>;
316
317 fn flush(&mut self, stream_id: StreamId);
319
320 fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
322
323 fn memory_cleanup(&mut self, stream_id: StreamId);
325
326 fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
328
329 fn end_profile(
331 &mut self,
332 stream_id: StreamId,
333 token: ProfilingToken,
334 ) -> Result<ProfileDuration, ProfileError>;
335
336 fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
338}
339
340pub trait ServerCommunication {
343 const SERVER_COMM_ENABLED: bool;
345
346 #[allow(unused_variables)]
365 fn copy(
366 server_src: &mut Self,
367 server_dst: &mut Self,
368 src: CopyDescriptor<'_>,
369 stream_id_src: StreamId,
370 stream_id_dst: StreamId,
371 ) -> Result<Allocation, IoError> {
372 if !Self::SERVER_COMM_ENABLED {
373 panic!("Server-to-server communication is not supported by this server.");
374 } else {
375 panic!(
376 "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
377 );
378 }
379 }
380}
381
/// Token returned by `ComputeServer::start_profile` and handed back to
/// `ComputeServer::end_profile`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Numeric identifier of the profiling.
    pub id: u64,
}
388
/// Handle to a region of server memory.
///
/// Consumed by `Handle::binding` to produce a [`Binding`].
#[derive(new, Debug, PartialEq, Eq)]
pub struct Handle {
    /// Handle to the underlying memory slice.
    pub memory: SliceHandle,
    /// Optional offset from the start of the memory; subtracted from the total size by `size()`.
    pub offset_start: Option<u64>,
    /// Optional offset from the end of the memory; subtracted from the total size by `size()`.
    pub offset_end: Option<u64>,
    /// Stream the handle is associated with; compared with the current stream by `can_mut()`.
    pub stream: cubecl_common::stream_id::StreamId,
    /// NOTE(review): presumably the stream position at which the handle was created — confirm.
    pub cursor: u64,
    /// Total size before the offsets are applied (unit presumably bytes — confirm).
    size: u64,
}
405
/// Kind of allocation requested through an [`AllocationDescriptor`].
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// Contiguous allocation; used by the `create_with_data` / `create_with_bytes` helpers.
    Contiguous,
    /// NOTE(review): presumably lets the backend choose the layout (strides) — confirm.
    Optimized,
}
415
/// Describes one allocation request passed to `ComputeServer::create`.
#[derive(new, Debug, Clone, Copy)]
pub struct AllocationDescriptor<'a> {
    /// Kind of allocation requested.
    pub kind: AllocationKind,
    /// Shape of the allocation, in elements.
    pub shape: &'a [usize],
    /// Size of one element (unit presumably bytes — confirm).
    pub elem_size: usize,
}
426
427impl<'a> AllocationDescriptor<'a> {
428 pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
430 AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
431 }
432
433 pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
435 AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
436 }
437}
438
/// Result of an allocation: the handle together with the strides to use with it.
#[derive(new, Debug)]
pub struct Allocation {
    /// Handle to the allocated memory.
    pub handle: Handle,
    /// Strides of the allocation.
    pub strides: Vec<usize>,
}
447
448#[derive(Error, Clone)]
451#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
452pub enum IoError {
453 #[error("can't allocate buffer of size: {size}\n{backtrace}")]
455 BufferTooBig {
456 size: u64,
458 #[cfg_attr(std_io, serde(skip))]
460 backtrace: BackTrace,
461 },
462
463 #[error("the provided strides are not supported for this operation\n{backtrace}")]
465 UnsupportedStrides {
466 #[cfg_attr(std_io, serde(skip))]
468 backtrace: BackTrace,
469 },
470
471 #[error("couldn't find resource for that handle\n{backtrace}")]
473 InvalidHandle {
474 #[cfg_attr(std_io, serde(skip))]
476 backtrace: BackTrace,
477 },
478
479 #[error("Unknown error happened during execution\n{backtrace}")]
481 Unknown {
482 description: String,
484 #[cfg_attr(std_io, serde(skip))]
486 backtrace: BackTrace,
487 },
488
489 #[error("The current IO operation is not supported\n{backtrace}")]
491 UnsupportedIoOperation {
492 #[cfg_attr(std_io, serde(skip))]
494 backtrace: BackTrace,
495 },
496
497 #[error("Can't perform the IO operation because of a runtime error")]
499 Execution(#[from] ExecutionError),
500}
501
502impl core::fmt::Debug for IoError {
503 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
504 f.write_fmt(format_args!("{self}"))
505 }
506}
507
508impl Handle {
509 pub fn offset_start(mut self, offset: u64) -> Self {
511 if let Some(val) = &mut self.offset_start {
512 *val += offset;
513 } else {
514 self.offset_start = Some(offset);
515 }
516
517 self
518 }
519 pub fn offset_end(mut self, offset: u64) -> Self {
521 if let Some(val) = &mut self.offset_end {
522 *val += offset;
523 } else {
524 self.offset_end = Some(offset);
525 }
526
527 self
528 }
529
530 pub fn size(&self) -> u64 {
532 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
533 }
534}
535
/// Everything bound to a kernel launch: buffers, metadata, scalars and tensor maps.
#[derive(Debug, Default)]
pub struct Bindings {
    /// Buffer bindings.
    pub buffers: Vec<Binding>,
    /// Metadata binding.
    pub metadata: MetadataBinding,
    /// Scalar bindings, keyed by storage type (at most one binding per type).
    pub scalars: BTreeMap<StorageType, ScalarBinding>,
    /// Tensor map bindings.
    pub tensor_maps: Vec<TensorMapBinding>,
}
549
550impl Bindings {
551 pub fn new() -> Self {
553 Self::default()
554 }
555
556 pub fn with_buffer(mut self, binding: Binding) -> Self {
558 self.buffers.push(binding);
559 self
560 }
561
562 pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
564 self.buffers.extend(bindings);
565 self
566 }
567
568 pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
570 self.scalars
571 .insert(ty, ScalarBinding::new(ty, length, data));
572 self
573 }
574
575 pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
577 self.scalars
578 .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
579 self
580 }
581
582 pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
584 self.metadata = meta;
585 self
586 }
587
588 pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
590 self.tensor_maps.extend(bindings);
591 self
592 }
593}
594
/// Metadata passed along with a kernel launch.
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// Metadata values.
    pub data: Vec<u64>,
    /// NOTE(review): presumably the length of the statically-sized part of `data` — confirm.
    pub static_len: usize,
}
603
/// Scalar values of one storage type, passed to a kernel launch.
#[derive(new, Debug, Clone)]
pub struct ScalarBinding {
    /// Storage type of the scalars.
    pub ty: StorageType,
    /// Number of scalars (NOTE(review): presumably counted in elements of `ty` — confirm).
    pub length: usize,
    /// Scalar data stored as `u64` words; exposed as bytes by `data()`.
    pub data: Vec<u64>,
}
614
615impl ScalarBinding {
616 pub fn data(&self) -> &[u8] {
618 bytemuck::cast_slice(&self.data)
619 }
620}
621
/// Binding to a region of server memory, produced from a [`Handle`] by `Handle::binding`.
#[derive(new, Debug)]
pub struct Binding {
    /// Binding to the underlying memory slice.
    pub memory: SliceBinding,
    /// Optional offset from the start of the memory; subtracted from the total size by `size()`.
    pub offset_start: Option<u64>,
    /// Optional offset from the end of the memory; subtracted from the total size by `size()`.
    pub offset_end: Option<u64>,
    /// Stream the binding is associated with (copied from the handle).
    pub stream: cubecl_common::stream_id::StreamId,
    /// Cursor value copied from the handle.
    pub cursor: u64,
    /// Total size before the offsets are applied (unit presumably bytes — confirm).
    size: u64,
}
638
639impl Binding {
640 pub fn size(&self) -> u64 {
642 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
643 }
644}
645
/// Describes the memory involved in a read, write or copy: a binding plus its layout.
#[derive(new, Debug, Clone)]
pub struct CopyDescriptor<'a> {
    /// Binding of the memory to access.
    pub binding: Binding,
    /// Shape of the data.
    pub shape: &'a [usize],
    /// Strides of the data.
    pub strides: &'a [usize],
    /// Size of one element (unit presumably bytes — confirm).
    pub elem_size: usize,
}
658
/// A tensor map bound to a kernel launch: the memory binding plus the map's metadata.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// Binding of the memory the map refers to.
    pub binding: Binding,
    /// Metadata describing the tensor map.
    pub map: TensorMapMeta,
}
667
/// Metadata describing a tensor map.
///
/// NOTE(review): field semantics presumably follow the backend's tensor-map
/// (TMA) descriptor API — confirm against the `tma` module.
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Format of the tensor map.
    pub format: TensorMapFormat,
    /// Rank (number of dimensions) of the tensor.
    pub rank: usize,
    /// Shape of the tensor.
    pub shape: Vec<usize>,
    /// Strides of the tensor.
    pub strides: Vec<usize>,
    /// Element stride, one entry per dimension (presumably — confirm).
    pub elem_stride: Vec<usize>,
    /// Interleave setting.
    pub interleave: TensorMapInterleave,
    /// Swizzle setting.
    pub swizzle: TensorMapSwizzle,
    /// Prefetch setting.
    pub prefetch: TensorMapPrefetch,
    /// Fill behaviour for out-of-bounds accesses.
    pub oob_fill: OobFill,
    /// Storage type of the tensor's elements.
    pub storage_ty: StorageType,
}
693
694impl Handle {
695 pub fn can_mut(&self) -> bool {
697 self.memory.can_mut() && self.stream == StreamId::current()
698 }
699}
700
701impl Handle {
702 pub fn binding(self) -> Binding {
704 Binding {
705 memory: MemoryHandle::binding(self.memory),
706 offset_start: self.offset_start,
707 offset_end: self.offset_end,
708 size: self.size,
709 stream: self.stream,
710 cursor: self.cursor,
711 }
712 }
713
714 pub fn copy_descriptor<'a>(
716 &'a self,
717 shape: &'a [usize],
718 strides: &'a [usize],
719 elem_size: usize,
720 ) -> CopyDescriptor<'a> {
721 CopyDescriptor {
722 shape,
723 strides,
724 elem_size,
725 binding: self.clone().binding(),
726 }
727 }
728}
729
730impl Clone for Handle {
731 fn clone(&self) -> Self {
732 Self {
733 memory: self.memory.clone(),
734 offset_start: self.offset_start,
735 offset_end: self.offset_end,
736 size: self.size,
737 stream: self.stream,
738 cursor: self.cursor,
739 }
740 }
741}
742
743impl Clone for Binding {
744 fn clone(&self) -> Self {
745 Self {
746 memory: self.memory.clone(),
747 offset_start: self.offset_start,
748 offset_end: self.offset_end,
749 size: self.size,
750 stream: self.stream,
751 cursor: self.cursor,
752 }
753 }
754}
755
/// Number of cubes to launch, either known up front or provided through a binding.
// The `Dynamic` variant is much larger than `Static`; the size difference is accepted.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Fixed cube count along x, y and z.
    Static(u32, u32, u32),
    /// Cube count provided through a memory binding
    /// (NOTE(review): presumably read on the device at launch time — confirm).
    Dynamic(Binding),
}
766
/// Outcome of spreading a requested number of cubes over the three cube-count dimensions.
pub enum CubeCountSelection {
    /// The selected cube count launches exactly the requested number of cubes.
    Exact(CubeCount),
    /// The selected cube count launches a different number of cubes than
    /// requested; the second field is the number actually launched (the
    /// product of the three dimensions).
    Approx(CubeCount, u32),
}
776
777impl CubeCountSelection {
778 pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
780 let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
781
782 let num_cubes_actual = cube_count[0] * cube_count[1] * cube_count[2];
783 let cube_count = CubeCount::Static(cube_count[0], cube_count[1], cube_count[2]);
784
785 match num_cubes_actual == num_cubes {
786 true => CubeCountSelection::Exact(cube_count),
787 false => CubeCountSelection::Approx(cube_count, num_cubes_actual),
788 }
789 }
790
791 pub fn has_idle(&self) -> bool {
793 matches!(self, Self::Approx(..))
794 }
795
796 pub fn cube_count(self) -> CubeCount {
798 match self {
799 CubeCountSelection::Exact(cube_count) => cube_count,
800 CubeCountSelection::Approx(cube_count, _) => cube_count,
801 }
802 }
803}
804
805impl From<CubeCountSelection> for CubeCount {
806 fn from(value: CubeCountSelection) -> Self {
807 value.cube_count()
808 }
809}
810
811impl CubeCount {
812 pub fn new_single() -> Self {
814 CubeCount::Static(1, 1, 1)
815 }
816
817 pub fn new_1d(x: u32) -> Self {
819 CubeCount::Static(x, 1, 1)
820 }
821
822 pub fn new_2d(x: u32, y: u32) -> Self {
824 CubeCount::Static(x, y, 1)
825 }
826
827 pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
829 CubeCount::Static(x, y, z)
830 }
831}
832
833impl Debug for CubeCount {
834 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
835 match self {
836 CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
837 CubeCount::Dynamic(_) => f.write_str("binding"),
838 }
839 }
840}
841
842impl Clone for CubeCount {
843 fn clone(&self) -> Self {
844 match self {
845 Self::Static(x, y, z) => Self::Static(*x, *y, *z),
846 Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
847 }
848 }
849}
850
/// Dimensions of a cube along x, y and z; `num_elems()` returns their product.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, serde::Serialize, serde::Deserialize)]
#[allow(missing_docs)]
pub struct CubeDim {
    /// Extent along x.
    pub x: u32,
    /// Extent along y.
    pub y: u32,
    /// Extent along z.
    pub z: u32,
}
862
863impl CubeDim {
864 pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
872 let properties = client.properties();
873 let plane_size = properties.hardware.plane_size_max;
874 let plane_count = Self::calculate_plane_count_per_cube(
875 working_units as u32,
876 plane_size,
877 properties.hardware.num_cpu_cores,
878 );
879
880 let limit = properties.hardware.max_units_per_cube / plane_size;
882
883 Self::new_2d(plane_size, u32::min(limit, plane_count))
884 }
885
886 fn calculate_plane_count_per_cube(
887 working_units: u32,
888 plane_dim: u32,
889 num_cpu_cores: Option<u32>,
890 ) -> u32 {
891 match num_cpu_cores {
892 Some(num_cores) => core::cmp::min(num_cores, working_units),
893 None => {
894 let plane_count_max = core::cmp::max(1, working_units / plane_dim);
895
896 const NUM_PLANE_MAX: u32 = 8u32;
898 const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
899 let plane_count_max_log2 =
900 core::cmp::min(NUM_PLANE_MAX_LOG2, u32::ilog2(plane_count_max));
901 2u32.pow(plane_count_max_log2)
902 }
903 }
904 }
905
906 pub const fn new_single() -> Self {
908 Self { x: 1, y: 1, z: 1 }
909 }
910
911 pub const fn new_1d(x: u32) -> Self {
913 Self { x, y: 1, z: 1 }
914 }
915
916 pub const fn new_2d(x: u32, y: u32) -> Self {
918 Self { x, y, z: 1 }
919 }
920
921 pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
924 Self { x, y, z }
925 }
926
927 pub const fn num_elems(&self) -> u32 {
929 self.x * self.y * self.z
930 }
931
932 pub const fn can_contain(&self, other: CubeDim) -> bool {
934 self.x >= other.x && self.y >= other.y && self.z >= other.z
935 }
936}
937
/// Execution mode passed to `ComputeServer::launch`.
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy, Serialize, Deserialize)]
pub enum ExecutionMode {
    /// The default mode.
    /// NOTE(review): presumably enables runtime checks such as bounds checks — confirm.
    #[default]
    Checked,
    /// NOTE(review): presumably skips those checks — confirm.
    Unchecked,
}
947
/// Spreads `num_cubes` over three dimensions so that the first two respect
/// their limits in `max`.
///
/// Starting from `[num_cubes, 1, 1]`, an axis that exceeds its limit is
/// repeatedly halved (rounding up) while the next axis is doubled. Because of
/// the rounding, the product of the result can exceed `num_cubes`. The last
/// axis is never checked against its limit.
fn cube_count_spread(max: &(u32, u32, u32), num_cubes: u32) -> [u32; 3] {
    const BASE: u32 = 2;

    let limits = [max.0, max.1, max.2];
    let mut counts = [num_cubes, 1, 1];

    for axis in 0..2 {
        // Once an axis fits without spilling, later axes are untouched and fit too.
        if counts[axis] <= limits[axis] {
            break;
        }

        // Spill into the next axis until this one fits.
        while counts[axis] > limits[axis] {
            counts[axis] = counts[axis].div_ceil(BASE);
            counts[axis + 1] *= BASE;
        }
    }

    counts
}
976
#[cfg(test)]
mod tests {
    use super::*;

    /// A power-of-two request spills evenly into the second and third axes.
    #[test_log::test]
    fn safe_num_cubes_even() {
        let spread = cube_count_spread(&(32, 32, 32), 2048);

        assert_eq!(spread, [32, 32, 2]);
    }

    /// An odd request is rounded up while spilling, so the product exceeds the request.
    #[test_log::test]
    fn safe_num_cubes_odd() {
        let spread = cube_count_spread(&(48, 32, 16), 3177);

        assert_eq!(spread, [25, 32, 4]);
    }
}