1use super::Handle;
2use crate::{
3 client::ComputeClient,
4 compiler::CompilationError,
5 config::{GlobalConfig, compilation::BoundsCheckMode},
6 kernel::KernelMetadata,
7 logging::ServerLogger,
8 memory_management::{ManagedMemoryHandle, MemoryAllocationMode, MemoryUsage},
9 runtime::Runtime,
10 server::Binding,
11 storage::{ComputeStorage, ManagedResource},
12 tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
13};
14use alloc::boxed::Box;
15#[cfg(feature = "profile-tracy")]
16use alloc::format;
17use alloc::string::String;
18use alloc::sync::Arc;
19use alloc::vec::Vec;
20use core::fmt::Debug;
21use cubecl_common::{
22 backtrace::BackTrace,
23 bytes::Bytes,
24 device::{self, DeviceId},
25 future::DynFut,
26 profile::ProfileDuration,
27 stream_id::StreamId,
28};
29use cubecl_ir::{DeviceProperties, ElemType, StorageType};
30use cubecl_zspace::{Shape, Strides, metadata::Metadata};
31use thiserror::Error;
32
/// Errors that can be reported while profiling.
///
/// `Debug` is implemented by hand for this type and prints the same text as `Display`.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ProfileError {
    /// A profiling failure that is not covered by the other variants.
    #[error(
        "An unknown error happened during profiling\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
    )]
    Unknown {
        /// Description of the failure, printed under `Caused by:`.
        reason: String,
        /// Backtrace attached to the error; skipped by serde when `std_io` is set.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// No profiling was registered.
    #[error("No profiling registered\nBacktrace:\n{backtrace}")]
    NotRegistered {
        /// Backtrace attached to the error; skipped by serde when `std_io` is set.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// A [`LaunchError`] was raised while profiling (convertible with `From`).
    #[error("A launch error happened during profiling\nCaused by:\n {0}")]
    Launch(#[from] LaunchError),

    /// A boxed [`ServerError`] was raised while profiling (convertible with `From`).
    #[error("An execution error happened during profiling\nCaused by:\n {0}")]
    Server(#[from] Box<ServerError>),
}
65
66impl core::fmt::Debug for ProfileError {
67 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
68 f.write_fmt(format_args!("{self}"))
69 }
70}
71
/// Bundle of values attached to a compute server: device properties, logger, server
/// info, memory layout policy and the bounds-check mode.
pub struct ServerUtilities<Server: ComputeServer> {
    /// Instant captured when the utilities were built (only with `profile-tracy`).
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// Tracy GPU context created when the utilities were built (only with `profile-tracy`).
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// Properties of the device.
    pub properties: DeviceProperties,
    /// Checksum of `properties`, computed once by [`ServerUtilities::new`].
    pub properties_hash: u64,
    /// Server-specific information.
    pub info: Server::Info,
    /// Shared logger.
    pub logger: Arc<ServerLogger>,
    /// Policy used to lay out memory (see [`MemoryLayoutPolicy`]).
    pub layout_policy: Server::MemoryLayoutPolicy,
    /// Bounds-check mode read from the global compilation config by [`ServerUtilities::new`].
    pub check_mode: BoundsCheckMode,
}
93
/// Policy that turns memory layout descriptors into concrete layouts.
pub trait MemoryLayoutPolicy: Send + Sync + 'static {
    /// Applies the policy to `descriptors` on the stream `stream_id`.
    ///
    /// Returns a [`Handle`] together with a list of [`MemoryLayout`]s.
    // NOTE(review): presumably one layout per descriptor, in order, all backed by the
    // returned handle — confirm against the implementors.
    fn apply(
        &self,
        stream_id: StreamId,
        descriptors: &[MemoryLayoutDescriptor],
    ) -> (Handle, Vec<MemoryLayout>);
}
106
107impl<Server: core::fmt::Debug> core::fmt::Debug for ServerUtilities<Server>
108where
109 Server: ComputeServer,
110 Server::Info: core::fmt::Debug,
111{
112 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
113 f.debug_struct("ServerUtilities")
114 .field("properties", &self.properties)
115 .field("info", &self.info)
116 .field("logger", &self.logger)
117 .finish()
118 }
119}
120
121impl<S: ComputeServer> ServerUtilities<S> {
122 pub fn new(
124 properties: DeviceProperties,
125 logger: Arc<ServerLogger>,
126 info: S::Info,
127 allocator: S::MemoryLayoutPolicy,
128 ) -> Self {
129 #[cfg(feature = "profile-tracy")]
131 let client = tracy_client::Client::start();
132
133 Self {
134 properties_hash: properties.checksum(),
135 properties,
136 logger,
137 #[cfg(feature = "profile-tracy")]
139 gpu_client: client
140 .clone()
141 .new_gpu_context(
142 Some(&format!("{info:?}")),
143 tracy_client::GpuContextType::Invalid,
145 0, 1.0, )
148 .unwrap(),
149 #[cfg(feature = "profile-tracy")]
150 epoch_time: web_time::Instant::now(),
151 info,
152 layout_policy: allocator,
153 check_mode: GlobalConfig::get().compilation.check_mode,
154 }
155 }
156}
157
/// Errors that can be reported while launching a kernel.
///
/// `Debug` is implemented by hand for this type and prints the same text as `Display`.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum LaunchError {
    /// The kernel failed to compile (convertible from [`CompilationError`]).
    #[error("A compilation error happened during launch\nCaused by:\n {0}")]
    CompilationError(#[from] CompilationError),

    /// The launch ran out of memory.
    #[error(
        "An out-of-memory error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
    )]
    OutOfMemory {
        /// Description of the failure, printed under `Caused by:`.
        reason: String,
        /// Backtrace attached to the error; skipped by serde when `std_io` is set.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// The launch asked for more resources than allowed (see [`ResourceLimitError`]).
    #[error("Too many resources were requested during launch\n{0}")]
    TooManyResources(#[from] ResourceLimitError),

    /// A launch failure that is not covered by the other variants.
    #[error(
        "An unknown error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
    )]
    Unknown {
        /// Description of the failure, printed under `Caused by:`.
        reason: String,
        /// Backtrace attached to the error; skipped by serde when `std_io` is set.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// An [`IoError`] was raised during the launch (convertible with `From`).
    #[error("An io error happened during launch\nCaused by:\n {0}")]
    IoError(#[from] IoError),
}
198
/// Errors describing a launch request that exceeds a resource limit.
///
/// `Debug` is implemented by hand for this type and prints the same text as `Display`.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ResourceLimitError {
    /// More shared memory was requested than is available.
    #[error(
        "Too much shared memory requested.\nRequested {requested} bytes, maximum {max} bytes available.\nBacktrace\n{backtrace}"
    )]
    SharedMemory {
        /// Requested amount, in bytes.
        requested: usize,
        /// Maximum available amount, in bytes.
        max: usize,
        /// Backtrace attached to the error; skipped by serde when `std_io` is set.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// The total unit count exceeds the maximum.
    #[error(
        "Total unit count exceeds maximum.\nRequested {requested} units, max units is {max}.\nBacktrace\n{backtrace}"
    )]
    Units {
        /// Requested number of units.
        requested: u32,
        /// Maximum number of units.
        max: u32,
        /// Backtrace attached to the error; skipped by serde when `std_io` is set.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// The cube dimensions exceed the maximum bounds.
    #[error(
        "Cube dim exceeds maximum bounds.\nRequested {requested:?}, max is {max:?}.\nBacktrace\n{backtrace}"
    )]
    CubeDim {
        /// Requested cube dimensions as a 3-tuple.
        requested: (u32, u32, u32),
        /// Maximum cube dimensions as a 3-tuple.
        max: (u32, u32, u32),
        /// Backtrace attached to the error; skipped by serde when `std_io` is set.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
}
243
244impl core::fmt::Debug for LaunchError {
245 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
246 f.write_fmt(format_args!("{self}"))
247 }
248}
249
250impl core::fmt::Debug for ResourceLimitError {
251 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
252 f.write_fmt(format_args!("{self}"))
253 }
254}
255
256#[derive(Error, Debug, Clone)]
258#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
259pub enum ServerError {
260 #[error("An error happened during execution\nCaused by:\n {reason}\nBacktrace:\n{backtrace}")]
262 Generic {
263 reason: String,
265 #[cfg_attr(std_io, serde(skip))]
267 backtrace: BackTrace,
268 },
269
270 #[error("A launch error happened during profiling\nCaused by:\n {0}")]
272 Launch(#[from] LaunchError),
273
274 #[error("An execution error happened during profiling\nCaused by:\n {0}")]
276 Profile(#[from] ProfileError),
277
278 #[error("An execution error happened during profiling\nCaused by:\n {0}")]
280 Io(#[from] IoError),
281
282 #[error("The server is in an invalid state\nCaused by:\n {errors:?}")]
284 ServerUnhealthy {
285 errors: Vec<Self>,
287 #[cfg_attr(std_io, serde(skip))]
289 backtrace: BackTrace,
290 },
291}
292
/// Pair of flags describing how errors on a stream are handled.
// NOTE(review): this type is not used in the visible part of the file; the field
// descriptions below are inferred from the names only — confirm against the callers.
#[derive(Clone, Copy)]
pub struct StreamErrorMode {
    /// Presumably: whether errors are ignored.
    pub ignore: bool,
    /// Presumably: whether the stream is flushed.
    pub flush: bool,
}
301
/// Interface implemented by a compute backend: memory management, data transfer,
/// kernel launch and profiling. Every operation is addressed to a [`StreamId`].
///
/// Implementors must also be `Send`, `Debug`, [`ServerCommunication`] and
/// [`device::DeviceService`].
pub trait ComputeServer:
    Send + core::fmt::Debug + ServerCommunication + device::DeviceService + 'static
where
    Self: Sized,
{
    /// Kernel type accepted by [`ComputeServer::launch`].
    type Kernel: KernelMetadata;
    /// Server information, stored in [`ServerUtilities::info`].
    type Info: Debug + Send + Sync;
    /// Memory layout policy, stored in [`ServerUtilities::layout_policy`].
    type MemoryLayoutPolicy: MemoryLayoutPolicy;
    /// Storage whose `Resource` type is returned by [`ComputeServer::get_resource`].
    type Storage: ComputeStorage;

    /// Initializes the managed memory `memory` with the given `size` on `stream_id`.
    // NOTE(review): `size` is presumably in bytes — confirm.
    fn initialize_memory(&mut self, memory: ManagedMemoryHandle, size: u64, stream_id: StreamId);

    /// Requests staging buffers, one entry of `_sizes` per buffer.
    ///
    /// The default implementation does not support staging and always returns
    /// [`IoError::UnsupportedIoOperation`] wrapped in a [`ServerError`].
    fn staging(
        &mut self,
        _sizes: &[usize],
        _stream_id: StreamId,
    ) -> Result<Vec<Bytes>, ServerError> {
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        }
        .into())
    }

    /// Returns the server's logger.
    fn logger(&self) -> Arc<ServerLogger>;

    /// Returns the server's shared [`ServerUtilities`].
    fn utilities(&self) -> Arc<ServerUtilities<Self>>;

    /// Reads the data described by `descriptors`; the returned future resolves to the
    /// bytes that were read, or to a [`ServerError`].
    fn read(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, ServerError>>;

    /// Writes each `Bytes` payload to the location described by its [`CopyDescriptor`].
    fn write(&mut self, descriptors: Vec<(CopyDescriptor, Bytes)>, stream_id: StreamId);

    /// Returns a future that resolves once the stream is synchronized, or to a
    /// [`ServerError`].
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ServerError>>;

    /// Resolves `binding` to the managed storage resource behind it.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> Result<ManagedResource<<Self::Storage as ComputeStorage>::Resource>, ServerError>;

    /// Launches `kernel` with `count` cubes and the given arguments in execution
    /// mode `kind`.
    ///
    /// # Safety
    ///
    // NOTE(review): the safety contract is not stated in this file; it presumably
    // concerns out-of-bounds accesses when `kind` is `ExecutionMode::Unchecked` — confirm.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
        kind: ExecutionMode,
        stream_id: StreamId,
    );

    /// Flushes the stream.
    fn flush(&mut self, stream_id: StreamId) -> Result<(), ServerError>;

    /// Reports the current memory usage.
    fn memory_usage(&mut self, stream_id: StreamId) -> Result<MemoryUsage, ServerError>;

    /// Asks the server to clean up memory.
    fn memory_cleanup(&mut self, stream_id: StreamId);

    /// Starts profiling and returns the token to hand to [`ComputeServer::end_profile`].
    fn start_profile(&mut self, stream_id: StreamId) -> Result<ProfilingToken, ServerError>;

    /// Ends the profiling identified by `token` and returns the measured duration.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError>;

    /// Sets the memory allocation mode.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
}
400
/// Reduction operator passed to [`ServerCommunication::all_reduce`].
pub enum ReduceOperation {
    /// Sum reduction.
    Sum,
    /// Mean reduction.
    Mean,
}
408
409pub trait ServerCommunication {
412 const SERVER_COMM_ENABLED: bool;
414
415 #[allow(unused_variables)]
425 fn sync_collective(&mut self, stream_id: StreamId) -> Result<(), ServerError> {
426 todo!() }
428
429 #[allow(unused_variables)]
444 fn all_reduce(
445 &mut self,
446 src: Binding,
447 dst: Binding,
448 dtype: ElemType,
449 stream_id: StreamId,
450 op: ReduceOperation,
451 device_ids: Vec<DeviceId>,
452 ) -> Result<(), ServerError> {
453 unimplemented!()
454 }
455
456 #[allow(unused_variables)]
475 fn copy(
476 handle_dst: Handle,
477 server_src: &mut Self,
478 server_dst: &mut Self,
479 src: CopyDescriptor,
480 stream_id_src: StreamId,
481 stream_id_dst: StreamId,
482 ) -> Result<(), ServerError> {
483 if !Self::SERVER_COMM_ENABLED {
484 panic!("Server-to-server communication is not supported by this server.");
485 } else {
486 panic!(
487 "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
488 );
489 }
490 }
491}
492
/// Token returned by [`ComputeServer::start_profile`] and handed back to
/// [`ComputeServer::end_profile`].
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Numeric identifier of the token.
    pub id: u64,
}
499
/// Layout strategy requested by a [`MemoryLayoutDescriptor`].
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum MemoryLayoutStrategy {
    /// Contiguous layout.
    Contiguous,
    /// Optimized layout.
    // NOTE(review): presumably lets the policy pick strides/padding suited to the
    // device — confirm against the `MemoryLayoutPolicy` implementations.
    Optimized,
}
509
/// Describes memory to lay out: the strategy to use, the shape and the element size.
#[derive(new, Debug, Clone)]
pub struct MemoryLayoutDescriptor {
    /// Layout strategy to use.
    pub strategy: MemoryLayoutStrategy,
    /// Shape of the data.
    pub shape: Shape,
    /// Size of one element.
    // NOTE(review): presumably in bytes — confirm.
    pub elem_size: usize,
}
520
521impl MemoryLayoutDescriptor {
522 pub fn optimized(shape: Shape, elem_size: usize) -> Self {
524 MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Optimized, shape, elem_size)
525 }
526
527 pub fn contiguous(shape: Shape, elem_size: usize) -> Self {
529 MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Contiguous, shape, elem_size)
530 }
531}
532
/// A memory handle paired with the strides used to address it.
#[derive(Debug, Clone)]
pub struct MemoryLayout {
    /// Handle to the memory.
    pub memory: Handle,
    /// Strides of the layout.
    pub strides: Strides,
}
543
544impl MemoryLayout {
545 pub fn new(handle: Handle, strides: impl Into<Strides>) -> Self {
547 MemoryLayout {
548 memory: handle,
549 strides: strides.into(),
550 }
551 }
552}
553
/// Optional explanation attached to an error.
///
/// Can be built from a `&'static str` or a `String`; the default value carries no
/// text and is displayed as "No reason provided for the error".
#[derive(Default, Clone)]
pub struct Reason {
    // Private representation: static text, shared owned text, or nothing.
    inner: ReasonInner,
}
559
#[cfg(std_io)]
mod _reason_serde {
    use super::*;

    use alloc::string::ToString;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    impl Serialize for Reason {
        /// Serializes the reason as the string produced by its `Display` output.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let rendered = self.to_string();
            serializer.serialize_str(&rendered)
        }
    }

    impl<'de> Deserialize<'de> for Reason {
        /// Deserializes a string into a dynamic reason.
        ///
        /// The original representation is not recovered: whatever was serialized
        /// (including the "no reason" placeholder text) comes back as owned text.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let text = String::deserialize(deserializer)?;
            let inner = ReasonInner::Dynamic(Arc::new(text));

            Ok(Reason { inner })
        }
    }
}
593
/// Private representation of a [`Reason`].
#[derive(Default, Clone)]
enum ReasonInner {
    /// Text borrowed for the `'static` lifetime.
    Static(&'static str),
    /// Owned text behind an `Arc`, so cloning does not copy the string.
    Dynamic(Arc<String>),
    /// No reason was given; this is the default.
    #[default]
    NotProvided,
}
601
602impl core::fmt::Display for Reason {
603 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
604 match &self.inner {
605 ReasonInner::Static(content) => f.write_str(content),
606 ReasonInner::Dynamic(content) => f.write_str(content),
607 ReasonInner::NotProvided => f.write_str("No reason provided for the error"),
608 }
609 }
610}
611
612impl core::fmt::Debug for Reason {
613 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
614 core::fmt::Display::fmt(&self, f)
615 }
616}
617
618impl From<&'static str> for Reason {
619 fn from(value: &'static str) -> Self {
620 Self {
621 inner: ReasonInner::Static(value),
622 }
623 }
624}
625
626impl From<String> for Reason {
627 fn from(value: String) -> Self {
628 Self {
629 inner: ReasonInner::Dynamic(Arc::new(value)),
630 }
631 }
632}
633
634#[derive(Error, Clone)]
637#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
638pub enum IoError {
639 #[error("can't allocate buffer of size: {size}\n{backtrace}")]
641 BufferTooBig {
642 size: u64,
644 #[cfg_attr(std_io, serde(skip))]
646 backtrace: BackTrace,
647 },
648
649 #[error("the provided strides are not supported for this operation\n{backtrace}")]
651 UnsupportedStrides {
652 #[cfg_attr(std_io, serde(skip))]
654 backtrace: BackTrace,
655 },
656
657 #[error("couldn't find resource for that handle: {reason}\n{backtrace}")]
659 NotFound {
660 #[cfg_attr(std_io, serde(skip))]
662 backtrace: BackTrace,
663 reason: Reason,
665 },
666
667 #[error("couldn't free the handle, since it is currently in used. \n{backtrace}")]
669 FreeError {
670 #[cfg_attr(std_io, serde(skip))]
672 backtrace: BackTrace,
673 },
674
675 #[error("Unknown error happened during execution\n{backtrace}")]
677 Unknown {
678 description: String,
680 #[cfg_attr(std_io, serde(skip))]
682 backtrace: BackTrace,
683 },
684
685 #[error("The current IO operation is not supported\n{backtrace}")]
687 UnsupportedIoOperation {
688 #[cfg_attr(std_io, serde(skip))]
690 backtrace: BackTrace,
691 },
692
693 #[error("Can't perform the IO operation because of a runtime error: {0}")]
695 Execution(#[from] Box<ServerError>),
696}
697
698impl core::fmt::Debug for IoError {
699 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
700 f.write_fmt(format_args!("{self}"))
701 }
702}
703
/// Arguments passed to a kernel launch: buffer bindings, metadata and tensor maps.
#[derive(Debug, Default)]
pub struct KernelArguments {
    /// Buffer bindings.
    pub buffers: Vec<Binding>,
    /// Metadata passed alongside the buffers.
    pub info: MetadataBindingInfo,
    /// Tensor map bindings.
    pub tensor_maps: Vec<TensorMapBinding>,
}
715
716impl core::fmt::Display for KernelArguments {
717 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
718 f.write_str("KernelArguments")?;
719 for b in self.buffers.iter() {
720 f.write_fmt(format_args!("\n - buffer: {b:?}\n"))?;
721 }
722
723 Ok(())
724 }
725}
726
727impl KernelArguments {
728 pub fn new() -> Self {
730 Self::default()
731 }
732
733 pub fn with_buffer(mut self, binding: Binding) -> Self {
735 self.buffers.push(binding);
736 self
737 }
738
739 pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
741 self.buffers.extend(bindings);
742 self
743 }
744
745 pub fn with_info(mut self, info: MetadataBindingInfo) -> Self {
747 self.info = info;
748 self
749 }
750
751 pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
753 self.tensor_maps.extend(bindings);
754 self
755 }
756}
757
/// Metadata passed to a kernel as a flat list of `u64` values.
#[derive(new, Debug, Default)]
pub struct MetadataBindingInfo {
    /// The metadata values.
    pub data: Vec<u64>,
    /// Offset of the dynamic metadata.
    // NOTE(review): presumably an index into `data` where the dynamically sized part
    // starts — confirm against the code that builds and consumes it.
    pub dynamic_metadata_offset: usize,
}
769
770impl MetadataBindingInfo {
771 pub fn custom(data: Vec<u64>) -> Self {
773 Self::new(data, 0)
774 }
775}
776
/// Describes data to copy: where it lives and how it is laid out.
#[derive(new, Debug)]
pub struct CopyDescriptor {
    /// Binding of the memory holding the data.
    pub handle: Binding,
    /// Shape of the data.
    pub shape: Shape,
    /// Strides of the data.
    pub strides: Strides,
    /// Size of one element.
    // NOTE(review): presumably in bytes — confirm.
    pub elem_size: usize,
}
789
/// A memory binding together with the tensor map metadata describing it.
#[derive(new, Debug)]
pub struct TensorMapBinding {
    /// Binding of the underlying memory.
    pub binding: Binding,
    /// Tensor map metadata for that memory.
    pub map: TensorMapMeta,
}
798
/// Metadata describing a tensor map.
// NOTE(review): the field descriptions below only restate the field names and types;
// the exact semantics live in the `tma` module — confirm there.
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Tensor map format.
    pub format: TensorMapFormat,
    /// Tensor metadata.
    pub metadata: Metadata,
    /// Element strides.
    pub elem_stride: Strides,
    /// Interleave setting.
    pub interleave: TensorMapInterleave,
    /// Swizzle setting.
    pub swizzle: TensorMapSwizzle,
    /// Prefetch setting.
    pub prefetch: TensorMapPrefetch,
    /// Out-of-bounds fill setting.
    pub oob_fill: OobFill,
    /// Storage type of the elements.
    pub storage_ty: StorageType,
}
820
/// Number of cubes to launch, either known up front or provided through a binding.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Cube counts along x, y and z, known at launch time.
    Static(u32, u32, u32),
    /// Cube counts provided through a binding.
    // NOTE(review): presumably a device buffer holding the three counts — confirm.
    Dynamic(Binding),
}
831
/// Result of spreading a requested number of cubes over the three launch dimensions.
pub enum CubeCountSelection {
    /// The selected cube count launches exactly the requested number of cubes.
    Exact(CubeCount),
    /// The selected cube count launches a different number of cubes than requested;
    /// the second field is the number actually launched.
    Approx(CubeCount, u32),
}
841
842impl CubeCountSelection {
843 pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
845 let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
846
847 let num_cubes_actual = cube_count[0] * cube_count[1] * cube_count[2];
848 let cube_count = CubeCount::Static(cube_count[0], cube_count[1], cube_count[2]);
849
850 match num_cubes_actual == num_cubes {
851 true => CubeCountSelection::Exact(cube_count),
852 false => CubeCountSelection::Approx(cube_count, num_cubes_actual),
853 }
854 }
855
856 pub fn has_idle(&self) -> bool {
858 matches!(self, Self::Approx(..))
859 }
860
861 pub fn cube_count(self) -> CubeCount {
863 match self {
864 CubeCountSelection::Exact(cube_count) => cube_count,
865 CubeCountSelection::Approx(cube_count, _) => cube_count,
866 }
867 }
868}
869
870impl From<CubeCountSelection> for CubeCount {
871 fn from(value: CubeCountSelection) -> Self {
872 value.cube_count()
873 }
874}
875
876impl CubeCount {
877 pub fn new_single() -> Self {
879 CubeCount::Static(1, 1, 1)
880 }
881
882 pub fn new_1d(x: u32) -> Self {
884 CubeCount::Static(x, 1, 1)
885 }
886
887 pub fn new_2d(x: u32, y: u32) -> Self {
889 CubeCount::Static(x, y, 1)
890 }
891
892 pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
894 CubeCount::Static(x, y, z)
895 }
896
897 pub fn is_empty(&self) -> bool {
899 match self {
900 Self::Static(x, y, z) => *x == 0 || *y == 0 || *z == 0,
901 Self::Dynamic(_) => false,
902 }
903 }
904}
905
906impl Debug for CubeCount {
907 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
908 match self {
909 CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
910 CubeCount::Dynamic(_) => f.write_str("binding"),
911 }
912 }
913}
914
915impl Clone for CubeCount {
916 fn clone(&self) -> Self {
917 match self {
918 Self::Static(x, y, z) => Self::Static(*x, *y, *z),
919 Self::Dynamic(binding) => Self::Dynamic(binding.clone()),
920 }
921 }
922}
923
/// Dimensions of a cube along the x, y and z axes.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
#[allow(missing_docs)]
pub struct CubeDim {
    /// Size along x.
    pub x: u32,
    /// Size along y.
    pub y: u32,
    /// Size along z.
    pub z: u32,
}
936
937impl CubeDim {
938 pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
946 let properties = client.properties();
947 let plane_size = properties.hardware.plane_size_max;
948 let plane_count = Self::calculate_plane_count_per_cube(
949 working_units as u32,
950 plane_size,
951 properties.hardware.num_cpu_cores,
952 );
953
954 let limit = properties.hardware.max_units_per_cube / plane_size;
956
957 Self::new_2d(plane_size, u32::min(limit, plane_count).max(1))
959 }
960
961 fn calculate_plane_count_per_cube(
962 working_units: u32,
963 plane_dim: u32,
964 num_cpu_cores: Option<u32>,
965 ) -> u32 {
966 match num_cpu_cores {
967 Some(num_cores) => core::cmp::min(num_cores, working_units),
968 None => {
969 let plane_count_max = core::cmp::max(1, working_units / plane_dim);
970
971 const NUM_PLANE_MAX: u32 = 8u32;
973 const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
974 let plane_count_max_log2 =
975 core::cmp::min(NUM_PLANE_MAX_LOG2, u32::ilog2(plane_count_max));
976 2u32.pow(plane_count_max_log2)
977 }
978 }
979 }
980
981 pub const fn new_single() -> Self {
983 Self { x: 1, y: 1, z: 1 }
984 }
985
986 pub const fn new_1d(x: u32) -> Self {
988 Self { x, y: 1, z: 1 }
989 }
990
991 pub const fn new_2d(x: u32, y: u32) -> Self {
993 Self { x, y, z: 1 }
994 }
995
996 pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
999 Self { x, y, z }
1000 }
1001
1002 pub const fn num_elems(&self) -> u32 {
1004 self.x * self.y * self.z
1005 }
1006
1007 pub const fn can_contain(&self, other: CubeDim) -> bool {
1009 self.x >= other.x && self.y >= other.y && self.z >= other.z
1010 }
1011}
1012
1013impl From<(u32, u32, u32)> for CubeDim {
1014 fn from(value: (u32, u32, u32)) -> Self {
1015 CubeDim::new_3d(value.0, value.1, value.2)
1016 }
1017}
1018
1019impl From<CubeDim> for (u32, u32, u32) {
1020 fn from(val: CubeDim) -> Self {
1021 (val.x, val.y, val.z)
1022 }
1023}
1024
/// Execution mode passed to [`ComputeServer::launch`].
// NOTE(review): the variant descriptions are inferred from the names; the exact
// checks performed in each mode are not visible in this file — confirm.
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ExecutionMode {
    /// Checked execution; this is the default mode.
    #[default]
    Checked,
    /// Validating execution.
    Validate,
    /// Unchecked execution.
    Unchecked,
}
1037
/// Spreads `num_cubes` over three launch dimensions so that the x and y counts fit
/// within `max`.
///
/// Everything starts on the x axis. While an axis exceeds its limit, its count is
/// halved (rounding up) and the next axis is doubled, so the product never drops
/// below `num_cubes` but may exceed it because of the rounding. The z axis only
/// receives the overflow from y and is not checked against `max.2`.
///
/// A limit of 0 is treated as 1: halving can never bring a non-zero count below 1,
/// so a zero limit would otherwise never be satisfied. Doubling saturates at
/// `u32::MAX` instead of overflowing.
fn cube_count_spread(max: &(u32, u32, u32), num_cubes: u32) -> [u32; 3] {
    const BASE: u32 = 2;

    // Only x and y are ever reduced, so only their limits are needed.
    let limits = [max.0.max(1), max.1.max(1)];
    let mut counts = [num_cubes, 1, 1];

    for axis in 0..limits.len() {
        // Once an axis fits without reduction, the following axes were never grown.
        if counts[axis] <= limits[axis] {
            break;
        }

        while counts[axis] > limits[axis] {
            counts[axis] = counts[axis].div_ceil(BASE);
            counts[axis + 1] = counts[axis + 1].saturating_mul(BASE);
        }
    }

    counts
}
1066
#[cfg(test)]
mod tests {
    use super::*;

    // A power-of-two request spreads without rounding: 32 * 32 * 2 == 2048.
    #[test_log::test]
    fn safe_num_cubes_even() {
        let max = (32, 32, 32);
        let required = 2048;

        let actual = cube_count_spread(&max, required);
        let expected = [32, 32, 2];
        assert_eq!(actual, expected);
    }

    // An odd request is rounded up while halving: 25 * 32 * 4 == 3200 >= 3177.
    #[test_log::test]
    fn safe_num_cubes_odd() {
        let max = (48, 32, 16);
        let required = 3177;

        let actual = cube_count_spread(&max, required);
        let expected = [25, 32, 4];
        assert_eq!(actual, expected);
    }
}