1use super::Handle;
2use crate::{
3 client::ComputeClient,
4 compiler::CompilationError,
5 config::{CubeClRuntimeConfig, RuntimeConfig, compilation::BoundsCheckMode},
6 kernel::KernelMetadata,
7 logging::ServerLogger,
8 memory_management::{ManagedMemoryHandle, MemoryAllocationMode, MemoryUsage},
9 runtime::Runtime,
10 server::Binding,
11 storage::{ComputeStorage, ManagedResource},
12 tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
13};
14use ahash::AHasher;
15use alloc::boxed::Box;
16#[cfg(feature = "profile-tracy")]
17use alloc::format;
18use alloc::string::String;
19use alloc::sync::Arc;
20use alloc::vec::Vec;
21use core::{
22 fmt::Debug,
23 hash::{Hash, Hasher},
24};
25use cubecl_common::{
26 backtrace::BackTrace,
27 bytes::Bytes,
28 device::{self, DeviceId},
29 future::DynFut,
30 profile::ProfileDuration,
31 stream_id::StreamId,
32 stub::RwLock,
33};
34use cubecl_ir::{DeviceProperties, ElemType, StorageType};
35use cubecl_zspace::{Shape, Strides, metadata::Metadata};
36use hashbrown::HashSet;
37use thiserror::Error;
38
39#[derive(Error, Clone)]
40#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
41pub enum ProfileError {
43 #[error(
45 "An unknown error happened during profiling\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
46 )]
47 Unknown {
48 reason: String,
50 #[cfg_attr(std_io, serde(skip))]
52 backtrace: BackTrace,
53 },
54
55 #[error("No profiling registered\nBacktrace:\n{backtrace}")]
57 NotRegistered {
58 #[cfg_attr(std_io, serde(skip))]
60 backtrace: BackTrace,
61 },
62
63 #[error("A launch error happened during profiling\nCaused by:\n {0}")]
65 Launch(#[from] LaunchError),
66
67 #[error("An execution error happened during profiling\nCaused by:\n {0}")]
69 Server(#[from] Box<ServerError>),
70}
71
72impl core::fmt::Debug for ProfileError {
73 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
74 f.write_fmt(format_args!("{self}"))
75 }
76}
77
78pub struct ServerUtilities<Server: ComputeServer> {
80 #[cfg(feature = "profile-tracy")]
82 pub epoch_time: web_time::Instant,
83 #[cfg(feature = "profile-tracy")]
85 pub gpu_client: tracy_client::GpuContext,
86 pub properties: DeviceProperties,
88 pub properties_hash: u64,
90 pub info: Server::Info,
92 pub logger: Arc<ServerLogger>,
94 pub layout_policy: Server::MemoryLayoutPolicy,
96 pub check_mode: BoundsCheckMode,
98 pub initialized_comms: RwLock<HashSet<CommunicationId>>,
100}
101
102pub trait MemoryLayoutPolicy: Send + Sync + 'static {
104 fn apply(
109 &self,
110 stream_id: StreamId,
111 descriptors: &[MemoryLayoutDescriptor],
112 ) -> (Handle, Vec<MemoryLayout>);
113}
114
115impl<Server: core::fmt::Debug> core::fmt::Debug for ServerUtilities<Server>
116where
117 Server: ComputeServer,
118 Server::Info: core::fmt::Debug,
119{
120 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
121 f.debug_struct("ServerUtilities")
122 .field("properties", &self.properties)
123 .field("info", &self.info)
124 .field("logger", &self.logger)
125 .finish()
126 }
127}
128
129impl<S: ComputeServer> ServerUtilities<S> {
130 pub fn new(
132 properties: DeviceProperties,
133 logger: Arc<ServerLogger>,
134 info: S::Info,
135 allocator: S::MemoryLayoutPolicy,
136 ) -> Self {
137 #[cfg(feature = "profile-tracy")]
139 let client = tracy_client::Client::start();
140
141 Self {
142 properties_hash: properties.checksum(),
143 properties,
144 logger,
145 #[cfg(feature = "profile-tracy")]
147 gpu_client: client
148 .clone()
149 .new_gpu_context(
150 Some(&format!("{info:?}")),
151 tracy_client::GpuContextType::Invalid,
153 0, 1.0, )
156 .unwrap(),
157 #[cfg(feature = "profile-tracy")]
158 epoch_time: web_time::Instant::now(),
159 info,
160 layout_policy: allocator,
161 check_mode: CubeClRuntimeConfig::get().compilation.check_mode,
162 initialized_comms: RwLock::new(HashSet::default()),
163 }
164 }
165}
166
167#[derive(Error, Clone)]
169#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
170pub enum LaunchError {
171 #[error("A compilation error happened during launch\nCaused by:\n {0}")]
173 CompilationError(#[from] CompilationError),
174
175 #[error(
177 "An out-of-memory error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
178 )]
179 OutOfMemory {
180 reason: String,
182 #[cfg_attr(std_io, serde(skip))]
184 backtrace: BackTrace,
185 },
186
187 #[error("Too many resources were requested during launch\n{0}")]
189 TooManyResources(#[from] ResourceLimitError),
190
191 #[error(
193 "An unknown error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
194 )]
195 Unknown {
196 reason: String,
198 #[cfg_attr(std_io, serde(skip))]
200 backtrace: BackTrace,
201 },
202
203 #[error("An io error happened during launch\nCaused by:\n {0}")]
205 IoError(#[from] IoError),
206}
207
208#[derive(Error, Clone)]
210#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
211pub enum ResourceLimitError {
212 #[error(
214 "Too much shared memory requested.\nRequested {requested} bytes, maximum {max} bytes available.\nBacktrace\n{backtrace}"
215 )]
216 SharedMemory {
217 requested: usize,
219 max: usize,
221 #[cfg_attr(std_io, serde(skip))]
223 backtrace: BackTrace,
224 },
225 #[error(
227 "Total unit count exceeds maximum.\nRequested {requested} units, max units is {max}.\nBacktrace\n{backtrace}"
228 )]
229 Units {
230 requested: u32,
232 max: u32,
234 #[cfg_attr(std_io, serde(skip))]
236 backtrace: BackTrace,
237 },
238 #[error(
240 "Cube dim exceeds maximum bounds.\nRequested {requested:?}, max is {max:?}.\nBacktrace\n{backtrace}"
241 )]
242 CubeDim {
243 requested: (u32, u32, u32),
245 max: (u32, u32, u32),
247 #[cfg_attr(std_io, serde(skip))]
249 backtrace: BackTrace,
250 },
251}
252
253impl core::fmt::Debug for LaunchError {
254 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
255 f.write_fmt(format_args!("{self}"))
256 }
257}
258
259impl core::fmt::Debug for ResourceLimitError {
260 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
261 f.write_fmt(format_args!("{self}"))
262 }
263}
264
265#[derive(Error, Debug, Clone)]
267#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
268pub enum ServerError {
269 #[error("An error happened during execution\nCaused by:\n {reason}\nBacktrace:\n{backtrace}")]
271 Generic {
272 reason: String,
274 #[cfg_attr(std_io, serde(skip))]
276 backtrace: BackTrace,
277 },
278
279 #[error("A launch error happened during profiling\nCaused by:\n {0}")]
281 Launch(#[from] LaunchError),
282
283 #[error("An execution error happened during profiling\nCaused by:\n {0}")]
285 Profile(#[from] ProfileError),
286
287 #[error("An execution error happened during profiling\nCaused by:\n {0}")]
289 Io(#[from] IoError),
290
291 #[error("The server is in an invalid state\nCaused by:\n {errors:?}")]
293 ServerUnhealthy {
294 errors: Vec<Self>,
296 #[cfg_attr(std_io, serde(skip))]
298 backtrace: BackTrace,
299 },
300}
301
/// Controls how errors on a stream are handled.
///
/// NOTE(review): the flags are consumed outside this file; the descriptions below are
/// inferred from the field names — confirm.
#[derive(Clone, Copy)]
pub struct StreamErrorMode {
    /// Whether stream errors should be ignored.
    pub ignore: bool,
    /// Whether the stream should be flushed.
    pub flush: bool,
}
310
311pub trait ComputeServer:
316 Send + core::fmt::Debug + ServerCommunication + device::DeviceService + 'static
317where
318 Self: Sized,
319{
320 type Kernel: KernelMetadata;
322 type Info: Debug + Send + Sync;
324 type MemoryLayoutPolicy: MemoryLayoutPolicy;
326 type Storage: ComputeStorage;
328
329 fn initialize_memory(&mut self, memory: ManagedMemoryHandle, size: u64, stream_id: StreamId);
331
332 fn staging(
334 &mut self,
335 _sizes: &[usize],
336 _stream_id: StreamId,
337 ) -> Result<Vec<Bytes>, ServerError> {
338 Err(IoError::UnsupportedIoOperation {
339 backtrace: BackTrace::capture(),
340 }
341 .into())
342 }
343
344 fn logger(&self) -> Arc<ServerLogger>;
346
347 fn utilities(&self) -> Arc<ServerUtilities<Self>>;
349
350 fn read(
352 &mut self,
353 descriptors: Vec<CopyDescriptor>,
354 stream_id: StreamId,
355 ) -> DynFut<Result<Vec<Bytes>, ServerError>>;
356
357 fn write(&mut self, descriptors: Vec<(CopyDescriptor, Bytes)>, stream_id: StreamId);
359
360 fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ServerError>>;
362
363 fn get_resource(
365 &mut self,
366 binding: Binding,
367 stream_id: StreamId,
368 ) -> Result<ManagedResource<<Self::Storage as ComputeStorage>::Resource>, ServerError>;
369
370 unsafe fn launch(
379 &mut self,
380 kernel: Self::Kernel,
381 count: CubeCount,
382 bindings: KernelArguments,
383 kind: ExecutionMode,
384 stream_id: StreamId,
385 );
386
387 fn flush(&mut self, stream_id: StreamId) -> Result<(), ServerError>;
389
390 fn memory_usage(&mut self, stream_id: StreamId) -> Result<MemoryUsage, ServerError>;
392
393 fn memory_cleanup(&mut self, stream_id: StreamId);
395
396 fn start_profile(&mut self, stream_id: StreamId) -> Result<ProfilingToken, ServerError>;
398
399 fn end_profile(
401 &mut self,
402 stream_id: StreamId,
403 token: ProfilingToken,
404 ) -> Result<ProfileDuration, ProfileError>;
405
406 fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
408}
409
410#[derive(Clone, Debug, Hash, Eq, PartialEq)]
412pub struct CommunicationId {
413 pub id: u64,
415}
416
417impl From<Vec<DeviceId>> for CommunicationId {
418 fn from(mut value: Vec<DeviceId>) -> Self {
419 value.sort();
421 let mut hasher = AHasher::default();
422 value.hash(&mut hasher);
423 CommunicationId {
424 id: hasher.finish(),
425 }
426 }
427}
428
/// Reduction applied by [`ServerCommunication::all_reduce`].
pub enum ReduceOperation {
    /// Sum of the values across devices.
    Sum,
    /// Mean of the values across devices.
    Mean,
}
436
437pub trait ServerCommunication {
440 const SERVER_COMM_ENABLED: bool;
442
443 #[allow(unused_variables)]
453 fn sync_collective(&mut self, stream_id: StreamId) -> Result<(), ServerError> {
454 todo!() }
456
457 #[allow(unused_variables)]
467 fn comm_init(&mut self, device_ids: Vec<DeviceId>) -> Result<(), ServerError> {
468 unimplemented!()
469 }
470
471 #[allow(unused_variables)]
487 fn all_reduce(
488 &mut self,
489 src: Binding,
490 dst: Binding,
491 dtype: ElemType,
492 stream_id: StreamId,
493 op: ReduceOperation,
494 device_ids: Vec<DeviceId>,
495 ) -> Result<(), ServerError> {
496 unimplemented!()
497 }
498
499 #[allow(unused_variables)]
512 fn send(
513 &mut self,
514 desc: CopyDescriptor,
515 dtype: ElemType,
516 stream_id: StreamId,
517 device_id_dst: DeviceId,
518 ) -> Result<(), ServerError> {
519 unimplemented!()
520 }
521
522 #[allow(unused_variables)]
535 fn recv(
536 &mut self,
537 handle: Handle,
538 dtype: ElemType,
539 stream_id: StreamId,
540 device_id_src: DeviceId,
541 ) -> Result<(), ServerError> {
542 unimplemented!()
543 }
544}
545
/// Token identifying a profiling session started with [`ComputeServer::start_profile`].
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Server-assigned identifier of the session.
    pub id: u64,
}

/// How the memory of a tensor should be laid out.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum MemoryLayoutStrategy {
    /// Contiguous layout.
    Contiguous,
    /// Layout chosen by the server's [`MemoryLayoutPolicy`]; may differ from contiguous
    /// (confirm per backend).
    Optimized,
}
562
563#[derive(new, Debug, Clone)]
565pub struct MemoryLayoutDescriptor {
566 pub strategy: MemoryLayoutStrategy,
568 pub shape: Shape,
570 pub elem_size: usize,
572}
573
574impl MemoryLayoutDescriptor {
575 pub fn optimized(shape: Shape, elem_size: usize) -> Self {
577 MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Optimized, shape, elem_size)
578 }
579
580 pub fn contiguous(shape: Shape, elem_size: usize) -> Self {
582 MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Contiguous, shape, elem_size)
583 }
584}
585
586#[derive(Debug, Clone)]
588pub struct MemoryLayout {
589 pub memory: Handle,
591 pub strides: Strides,
595}
596
597impl MemoryLayout {
598 pub fn new(handle: Handle, strides: impl Into<Strides>) -> Self {
600 MemoryLayout {
601 memory: handle,
602 strides: strides.into(),
603 }
604 }
605}
606
607#[derive(Default, Clone)]
609pub struct Reason {
610 inner: ReasonInner,
611}
612
#[cfg(std_io)]
mod _reason_serde {
    //! `serde` support for [`Reason`]: it is serialized as its display text and always
    //! deserialized as an owned (dynamic) reason, so a missing reason round-trips as its
    //! placeholder text.
    use super::*;

    use alloc::string::ToString;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    impl Serialize for Reason {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let text = self.to_string();
            serializer.serialize_str(&text)
        }
    }

    impl<'de> Deserialize<'de> for Reason {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            String::deserialize(deserializer).map(|text| Reason {
                inner: ReasonInner::Dynamic(Arc::new(text)),
            })
        }
    }
}
646
647#[derive(Default, Clone)]
648enum ReasonInner {
649 Static(&'static str),
650 Dynamic(Arc<String>),
651 #[default]
652 NotProvided,
653}
654
655impl core::fmt::Display for Reason {
656 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
657 match &self.inner {
658 ReasonInner::Static(content) => f.write_str(content),
659 ReasonInner::Dynamic(content) => f.write_str(content),
660 ReasonInner::NotProvided => f.write_str("No reason provided for the error"),
661 }
662 }
663}
664
665impl core::fmt::Debug for Reason {
666 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
667 core::fmt::Display::fmt(&self, f)
668 }
669}
670
671impl From<&'static str> for Reason {
672 fn from(value: &'static str) -> Self {
673 Self {
674 inner: ReasonInner::Static(value),
675 }
676 }
677}
678
679impl From<String> for Reason {
680 fn from(value: String) -> Self {
681 Self {
682 inner: ReasonInner::Dynamic(Arc::new(value)),
683 }
684 }
685}
686
687#[derive(Error, Clone)]
690#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
691pub enum IoError {
692 #[error("can't allocate buffer of size: {size}\n{backtrace}")]
694 BufferTooBig {
695 size: u64,
697 #[cfg_attr(std_io, serde(skip))]
699 backtrace: BackTrace,
700 },
701
702 #[error("the provided strides are not supported for this operation\n{backtrace}")]
704 UnsupportedStrides {
705 #[cfg_attr(std_io, serde(skip))]
707 backtrace: BackTrace,
708 },
709
710 #[error("couldn't find resource for that handle: {reason}\n{backtrace}")]
712 NotFound {
713 #[cfg_attr(std_io, serde(skip))]
715 backtrace: BackTrace,
716 reason: Reason,
718 },
719
720 #[error("couldn't free the handle, since it is currently in used. \n{backtrace}")]
722 FreeError {
723 #[cfg_attr(std_io, serde(skip))]
725 backtrace: BackTrace,
726 },
727
728 #[error("Unknown error happened during execution\n{backtrace}")]
730 Unknown {
731 description: String,
733 #[cfg_attr(std_io, serde(skip))]
735 backtrace: BackTrace,
736 },
737
738 #[error("The current IO operation is not supported\n{backtrace}")]
740 UnsupportedIoOperation {
741 #[cfg_attr(std_io, serde(skip))]
743 backtrace: BackTrace,
744 },
745
746 #[error("Can't perform the IO operation because of a runtime error: {0}")]
748 Execution(#[from] Box<ServerError>),
749}
750
751impl core::fmt::Debug for IoError {
752 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
753 f.write_fmt(format_args!("{self}"))
754 }
755}
756
757#[derive(Debug, Default)]
759pub struct KernelArguments {
760 pub buffers: Vec<Binding>,
762 pub info: MetadataBindingInfo,
765 pub tensor_maps: Vec<TensorMapBinding>,
767}
768
769impl core::fmt::Display for KernelArguments {
770 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
771 f.write_str("KernelArguments")?;
772 for b in self.buffers.iter() {
773 f.write_fmt(format_args!("\n - buffer: {b:?}\n"))?;
774 }
775
776 Ok(())
777 }
778}
779
780impl KernelArguments {
781 pub fn new() -> Self {
783 Self::default()
784 }
785
786 pub fn with_buffer(mut self, binding: Binding) -> Self {
788 self.buffers.push(binding);
789 self
790 }
791
792 pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
794 self.buffers.extend(bindings);
795 self
796 }
797
798 pub fn with_info(mut self, info: MetadataBindingInfo) -> Self {
800 self.info = info;
801 self
802 }
803
804 pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
806 self.tensor_maps.extend(bindings);
807 self
808 }
809}
810
811#[derive(new, Debug, Default)]
816pub struct MetadataBindingInfo {
817 pub data: Vec<u64>,
819 pub dynamic_metadata_offset: usize,
821}
822
823impl MetadataBindingInfo {
824 pub fn custom(data: Vec<u64>) -> Self {
826 Self::new(data, 0)
827 }
828}
829
830#[derive(new, Debug)]
832pub struct CopyDescriptor {
833 pub handle: Binding,
835 pub shape: Shape,
837 pub strides: Strides,
839 pub elem_size: usize,
841}
842
843#[derive(new, Debug)]
845pub struct TensorMapBinding {
846 pub binding: Binding,
848 pub map: TensorMapMeta,
850}
851
852#[derive(Debug, Clone)]
854pub struct TensorMapMeta {
855 pub format: TensorMapFormat,
857 pub metadata: Metadata,
859 pub elem_stride: Strides,
862 pub interleave: TensorMapInterleave,
864 pub swizzle: TensorMapSwizzle,
866 pub prefetch: TensorMapPrefetch,
868 pub oob_fill: OobFill,
870 pub storage_ty: StorageType,
872}
873
874#[allow(clippy::large_enum_variant)]
878pub enum CubeCount {
879 Static(u32, u32, u32),
881 Dynamic(Binding),
883}
884
885pub enum CubeCountSelection {
887 Exact(CubeCount),
889 Approx(CubeCount, u32),
893}
894
895impl CubeCountSelection {
896 pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
898 let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
899
900 let num_cubes_actual = cube_count[0] * cube_count[1] * cube_count[2];
901 let cube_count = CubeCount::Static(cube_count[0], cube_count[1], cube_count[2]);
902
903 match num_cubes_actual == num_cubes {
904 true => CubeCountSelection::Exact(cube_count),
905 false => CubeCountSelection::Approx(cube_count, num_cubes_actual),
906 }
907 }
908
909 pub fn has_idle(&self) -> bool {
911 matches!(self, Self::Approx(..))
912 }
913
914 pub fn cube_count(self) -> CubeCount {
916 match self {
917 CubeCountSelection::Exact(cube_count) => cube_count,
918 CubeCountSelection::Approx(cube_count, _) => cube_count,
919 }
920 }
921}
922
923impl From<CubeCountSelection> for CubeCount {
924 fn from(value: CubeCountSelection) -> Self {
925 value.cube_count()
926 }
927}
928
929impl CubeCount {
930 pub fn new_single() -> Self {
932 CubeCount::Static(1, 1, 1)
933 }
934
935 pub fn new_1d(x: u32) -> Self {
937 CubeCount::Static(x, 1, 1)
938 }
939
940 pub fn new_2d(x: u32, y: u32) -> Self {
942 CubeCount::Static(x, y, 1)
943 }
944
945 pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
947 CubeCount::Static(x, y, z)
948 }
949
950 pub fn is_empty(&self) -> bool {
952 match self {
953 Self::Static(x, y, z) => *x == 0 || *y == 0 || *z == 0,
954 Self::Dynamic(_) => false,
955 }
956 }
957}
958
959impl Debug for CubeCount {
960 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
961 match self {
962 CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
963 CubeCount::Dynamic(_) => f.write_str("binding"),
964 }
965 }
966}
967
968impl Clone for CubeCount {
969 fn clone(&self) -> Self {
970 match self {
971 Self::Static(x, y, z) => Self::Static(*x, *y, *z),
972 Self::Dynamic(binding) => Self::Dynamic(binding.clone()),
973 }
974 }
975}
976
/// Dimensions of a cube: number of units along x, y and z.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
#[allow(missing_docs)]
pub struct CubeDim {
    /// Units along x.
    pub x: u32,
    /// Units along y.
    pub y: u32,
    /// Units along z.
    pub z: u32,
}
989
990impl CubeDim {
991 pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
999 let properties = client.properties();
1000 let plane_size = properties.hardware.plane_size_max;
1001 let plane_count = Self::calculate_plane_count_per_cube(
1002 working_units as u32,
1003 plane_size,
1004 properties.hardware.num_cpu_cores,
1005 );
1006
1007 let limit = properties.hardware.max_units_per_cube / plane_size;
1009
1010 Self::new_2d(plane_size, u32::min(limit, plane_count).max(1))
1012 }
1013
1014 fn calculate_plane_count_per_cube(
1015 working_units: u32,
1016 plane_dim: u32,
1017 num_cpu_cores: Option<u32>,
1018 ) -> u32 {
1019 match num_cpu_cores {
1020 Some(num_cores) => core::cmp::min(num_cores, working_units),
1021 None => {
1022 let plane_count_max = core::cmp::max(1, working_units / plane_dim);
1023
1024 const NUM_PLANE_MAX: u32 = 8u32;
1026 const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
1027 let plane_count_max_log2 =
1028 core::cmp::min(NUM_PLANE_MAX_LOG2, u32::ilog2(plane_count_max));
1029 2u32.pow(plane_count_max_log2)
1030 }
1031 }
1032 }
1033
1034 pub const fn new_single() -> Self {
1036 Self { x: 1, y: 1, z: 1 }
1037 }
1038
1039 pub const fn new_1d(x: u32) -> Self {
1041 Self { x, y: 1, z: 1 }
1042 }
1043
1044 pub const fn new_2d(x: u32, y: u32) -> Self {
1046 Self { x, y, z: 1 }
1047 }
1048
1049 pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
1052 Self { x, y, z }
1053 }
1054
1055 pub const fn num_elems(&self) -> u32 {
1057 self.x * self.y * self.z
1058 }
1059
1060 pub const fn can_contain(&self, other: CubeDim) -> bool {
1062 self.x >= other.x && self.y >= other.y && self.z >= other.z
1063 }
1064}
1065
1066impl From<(u32, u32, u32)> for CubeDim {
1067 fn from(value: (u32, u32, u32)) -> Self {
1068 CubeDim::new_3d(value.0, value.1, value.2)
1069 }
1070}
1071
1072impl From<CubeDim> for (u32, u32, u32) {
1073 fn from(val: CubeDim) -> Self {
1074 (val.x, val.y, val.z)
1075 }
1076}
1077
/// How strictly a kernel launch is checked.
///
/// NOTE(review): each mode's exact meaning is defined by the backends/compiler; the
/// descriptions below are inferred from the variant names — confirm.
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ExecutionMode {
    /// Checked execution; the default mode.
    #[default]
    Checked,
    /// Execution with additional validation.
    Validate,
    /// Execution without checks.
    Unchecked,
}
1090
/// Spreads `num_cubes` over the x, y and z axes so that x and y stay within `max`.
///
/// All cubes start on x. An axis above its limit is repeatedly halved (rounding up) while
/// the next axis doubles. Only x and y are reduced, so z may still exceed `max.2` when the
/// total cannot fit. Because of the rounding, the product of the result can exceed
/// `num_cubes`. An axis that needs reducing must have a non-zero limit; otherwise the
/// halving never reaches it.
fn cube_count_spread(max: &(u32, u32, u32), num_cubes: u32) -> [u32; 3] {
    const BASE: u32 = 2;
    let limits = [max.0, max.1, max.2];
    let mut counts = [num_cubes, 1, 1];

    for axis in 0..2 {
        // If this axis already fits, nothing spills into the next one; stop here.
        if counts[axis] <= limits[axis] {
            break;
        }
        while counts[axis] > limits[axis] {
            counts[axis] = counts[axis].div_ceil(BASE);
            counts[axis + 1] *= BASE;
        }
    }

    counts
}
1119
#[cfg(test)]
mod tests {
    use super::*;

    #[test_log::test]
    fn safe_num_cubes_even() {
        let max = (32, 32, 32);
        let required = 2048;

        let actual = cube_count_spread(&max, required);
        let expected = [32, 32, 2];
        assert_eq!(actual, expected);
    }

    #[test_log::test]
    fn safe_num_cubes_odd() {
        let max = (48, 32, 16);
        let required = 3177;

        let actual = cube_count_spread(&max, required);
        let expected = [25, 32, 4];
        assert_eq!(actual, expected);
    }

    /// A count that already fits on x is left untouched.
    #[test_log::test]
    fn safe_num_cubes_fits_on_x() {
        let actual = cube_count_spread(&(64, 64, 64), 10);
        assert_eq!(actual, [10, 1, 1]);
    }

    /// Requesting zero cubes yields an empty grid on x.
    #[test_log::test]
    fn safe_num_cubes_zero() {
        let actual = cube_count_spread(&(64, 64, 64), 0);
        assert_eq!(actual, [0, 1, 1]);
    }

    /// Rounding up while halving can make the grid larger than the request.
    #[test_log::test]
    fn safe_num_cubes_rounds_up() {
        // 5 > 4: x becomes ceil(5 / 2) = 3 and y becomes 2, i.e. 6 cubes for 5 requested.
        let actual = cube_count_spread(&(4, 4, 4), 5);
        assert_eq!(actual, [3, 2, 1]);
    }
}
1143}