1use crate::{
2 client::ComputeClient,
3 compiler::CompilationError,
4 kernel::KernelMetadata,
5 logging::ServerLogger,
6 memory_management::{
7 MemoryAllocationMode, MemoryHandle, MemoryUsage,
8 memory_pool::{SliceBinding, SliceHandle},
9 },
10 runtime::Runtime,
11 storage::{BindingResource, ComputeStorage},
12 tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
13};
14use alloc::collections::BTreeMap;
15#[cfg(feature = "profile-tracy")]
16use alloc::format;
17use alloc::string::String;
18use alloc::sync::Arc;
19use alloc::vec;
20use alloc::vec::Vec;
21use core::fmt::Debug;
22use cubecl_common::{
23 backtrace::BackTrace, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
24 stream_id::StreamId,
25};
26use cubecl_ir::{DeviceProperties, StorageType};
27use cubecl_zspace::{Strides, metadata::Metadata};
28use serde::{Deserialize, Serialize};
29use thiserror::Error;
30
/// Errors that can be reported while profiling.
///
/// `Debug` is implemented by hand for this type and prints the same text as `Display`.
#[derive(Error, Clone)]
pub enum ProfileError {
    /// Profiling failed for a reason not covered by the other variants.
    #[error(
        "An unknown error happened during profiling\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
    )]
    Unknown {
        /// Human-readable description of the failure.
        reason: String,
        /// Backtrace attached to the error.
        backtrace: BackTrace,
    },

    /// No profiling was registered.
    #[error("No profiling registered\nBacktrace:\n{backtrace}")]
    NotRegistered {
        /// Backtrace attached to the error.
        backtrace: BackTrace,
    },

    /// A [`LaunchError`] surfaced during profiling (convertible through `#[from]`).
    #[error("A launch error happened during profiling\nCaused by:\n {0}")]
    Launch(#[from] LaunchError),

    /// An [`ExecutionError`] surfaced during profiling (convertible through `#[from]`).
    #[error("An execution error happened during profiling\nCaused by:\n {0}")]
    Execution(#[from] ExecutionError),
}
60
61impl core::fmt::Debug for ProfileError {
62 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
63 f.write_fmt(format_args!("{self}"))
64 }
65}
66
/// Utilities attached to a [`ComputeServer`]: device properties, server info and logger.
pub struct ServerUtilities<Server: ComputeServer> {
    /// Reference instant for profiling (only with the `profile-tracy` feature).
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// Tracy GPU context (only with the `profile-tracy` feature).
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// Properties of the device.
    pub properties: DeviceProperties,
    /// Checksum of [`Self::properties`].
    pub properties_hash: u64,
    /// Server-specific information.
    pub info: Server::Info,
    /// Logger shared through an [`Arc`].
    pub logger: Arc<ServerLogger>,
}
84
85impl<Server: core::fmt::Debug> core::fmt::Debug for ServerUtilities<Server>
86where
87 Server: ComputeServer,
88 Server::Info: core::fmt::Debug,
89{
90 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
91 f.debug_struct("ServerUtilities")
92 .field("properties", &self.properties)
93 .field("info", &self.info)
94 .field("logger", &self.logger)
95 .finish()
96 }
97}
98
99impl<S: ComputeServer> ServerUtilities<S> {
100 pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
102 #[cfg(feature = "profile-tracy")]
104 let client = tracy_client::Client::start();
105
106 Self {
107 properties_hash: properties.checksum(),
108 properties,
109 logger,
110 #[cfg(feature = "profile-tracy")]
112 gpu_client: client
113 .clone()
114 .new_gpu_context(
115 Some(&format!("{info:?}")),
116 tracy_client::GpuContextType::Invalid,
118 0, 1.0, )
121 .unwrap(),
122 #[cfg(feature = "profile-tracy")]
123 epoch_time: web_time::Instant::now(),
124 info,
125 }
126 }
127}
128
/// Errors that can be reported when launching a kernel.
///
/// `Debug` is implemented by hand for this type and prints the same text as `Display`.
/// Serde support is derived only when the `std_io` cfg is set; backtraces are skipped.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum LaunchError {
    /// Compilation failed (convertible from [`CompilationError`] through `#[from]`).
    #[error("A compilation error happened during launch\nCaused by:\n {0}")]
    CompilationError(#[from] CompilationError),

    /// The launch ran out of memory.
    #[error(
        "An out-of-memory error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
    )]
    OutOfMemory {
        /// Human-readable description of the failure.
        reason: String,
        /// Backtrace attached to the error (not serialized).
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// The launch requested more resources than allowed; see [`ResourceLimitError`].
    #[error("Too many resources were requested during launch\n{0}")]
    TooManyResources(#[from] ResourceLimitError),

    /// The launch failed for a reason not covered by the other variants.
    #[error(
        "An unknown error happened during launch\nCaused by:\n {reason}\nBacktrace\n{backtrace}"
    )]
    Unknown {
        /// Human-readable description of the failure.
        reason: String,
        /// Backtrace attached to the error (not serialized).
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// An [`IoError`] surfaced during launch (convertible through `#[from]`).
    #[error("An io error happened during launch\nCaused by:\n {0}")]
    IoError(#[from] IoError),
}
169
/// Errors describing a launch that requested more resources than the stated maximum.
///
/// `Debug` is implemented by hand for this type and prints the same text as `Display`.
/// Serde support is derived only when the `std_io` cfg is set; backtraces are skipped.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ResourceLimitError {
    /// More shared memory was requested than is available.
    #[error(
        "Too much shared memory requested.\nRequested {requested} bytes, maximum {max} bytes available.\nBacktrace\n{backtrace}"
    )]
    SharedMemory {
        /// Requested amount, in bytes.
        requested: usize,
        /// Maximum available amount, in bytes.
        max: usize,
        /// Backtrace attached to the error (not serialized).
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// The total number of units exceeds the maximum.
    #[error(
        "Total unit count exceeds maximum.\nRequested {requested} units, max units is {max}.\nBacktrace\n{backtrace}"
    )]
    Units {
        /// Requested number of units.
        requested: u32,
        /// Maximum number of units.
        max: u32,
        /// Backtrace attached to the error (not serialized).
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// The cube dimensions exceed the maximum bounds.
    #[error(
        "Cube dim exceeds maximum bounds.\nRequested {requested:?}, max is {max:?}.\nBacktrace\n{backtrace}"
    )]
    CubeDim {
        /// Requested cube dimensions as a three-component tuple.
        requested: (u32, u32, u32),
        /// Maximum cube dimensions as a three-component tuple.
        max: (u32, u32, u32),
        /// Backtrace attached to the error (not serialized).
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
}
214
215impl core::fmt::Debug for LaunchError {
216 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
217 f.write_fmt(format_args!("{self}"))
218 }
219}
220
221impl core::fmt::Debug for ResourceLimitError {
222 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
223 f.write_fmt(format_args!("{self}"))
224 }
225}
226
/// Errors that can be reported during execution.
///
/// Serde support is derived only when the `std_io` cfg is set; the backtrace is skipped.
#[derive(Error, Debug, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ExecutionError {
    /// A generic execution failure described by `reason`.
    #[error("An error happened during execution\nCaused by:\n {reason}\nBacktrace:\n{backtrace}")]
    Generic {
        /// Human-readable description of the failure.
        reason: String,
        /// Backtrace attached to the error (not serialized).
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
}
241
242pub trait ComputeServer:
247 Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
248where
249 Self: Sized,
250{
251 type Kernel: KernelMetadata;
253 type Info: Debug + Send + Sync;
255 type Storage: ComputeStorage;
257
258 fn create(
260 &mut self,
261 descriptors: Vec<AllocationDescriptor<'_>>,
262 stream_id: StreamId,
263 ) -> Result<Vec<Allocation>, IoError>;
264
265 fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
267 Err(IoError::UnsupportedIoOperation {
268 backtrace: BackTrace::capture(),
269 })
270 }
271
272 fn logger(&self) -> Arc<ServerLogger>;
274
275 fn utilities(&self) -> Arc<ServerUtilities<Self>>;
277
278 fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
280 let alloc = self
281 .create(
282 vec![AllocationDescriptor::new(
283 AllocationKind::Contiguous,
284 &[data.len()],
285 1,
286 )],
287 stream_id,
288 )?
289 .remove(0);
290 self.write(
291 vec![(
292 CopyDescriptor::new(
293 alloc.handle.clone().binding(),
294 &[data.len()],
295 &alloc.strides,
296 1,
297 ),
298 Bytes::from_bytes_vec(data.to_vec()),
299 )],
300 stream_id,
301 )?;
302 Ok(alloc.handle)
303 }
304
305 fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
307 let alloc = self
308 .create(
309 vec![AllocationDescriptor::new(
310 AllocationKind::Contiguous,
311 &[data.len()],
312 1,
313 )],
314 stream_id,
315 )?
316 .remove(0);
317 self.write(
318 vec![(
319 CopyDescriptor::new(
320 alloc.handle.clone().binding(),
321 &[data.len()],
322 &alloc.strides,
323 1,
324 ),
325 data,
326 )],
327 stream_id,
328 )?;
329 Ok(alloc.handle)
330 }
331
332 fn read<'a>(
334 &mut self,
335 descriptors: Vec<CopyDescriptor<'a>>,
336 stream_id: StreamId,
337 ) -> DynFut<Result<Vec<Bytes>, IoError>>;
338
339 fn write(
341 &mut self,
342 descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
343 stream_id: StreamId,
344 ) -> Result<(), IoError>;
345
346 fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>>;
348
349 fn get_resource(
351 &mut self,
352 binding: Binding,
353 stream_id: StreamId,
354 ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
355
356 unsafe fn launch(
365 &mut self,
366 kernel: Self::Kernel,
367 count: CubeCount,
368 bindings: Bindings,
369 kind: ExecutionMode,
370 stream_id: StreamId,
371 ) -> Result<(), LaunchError>;
372
373 fn flush(&mut self, stream_id: StreamId);
375
376 fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
378
379 fn memory_cleanup(&mut self, stream_id: StreamId);
381
382 fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
384
385 fn end_profile(
387 &mut self,
388 stream_id: StreamId,
389 token: ProfilingToken,
390 ) -> Result<ProfileDuration, ProfileError>;
391
392 fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
394}
395
396pub trait ServerCommunication {
399 const SERVER_COMM_ENABLED: bool;
401
402 #[allow(unused_variables)]
421 fn copy(
422 server_src: &mut Self,
423 server_dst: &mut Self,
424 src: CopyDescriptor<'_>,
425 stream_id_src: StreamId,
426 stream_id_dst: StreamId,
427 ) -> Result<Allocation, IoError> {
428 if !Self::SERVER_COMM_ENABLED {
429 panic!("Server-to-server communication is not supported by this server.");
430 } else {
431 panic!(
432 "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
433 );
434 }
435 }
436}
437
/// Token returned by `ComputeServer::start_profile` and consumed by
/// `ComputeServer::end_profile`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Identifier of the token.
    pub id: u64,
}
444
/// Server handle to a slice of memory, with optional offsets at both ends.
#[derive(new, Debug, PartialEq, Eq)]
pub struct Handle {
    /// Handle to the underlying memory slice.
    pub memory: SliceHandle,
    /// Optional offset at the start; `size()` subtracts it from the stored size.
    pub offset_start: Option<u64>,
    /// Optional offset at the end; `size()` subtracts it from the stored size.
    pub offset_end: Option<u64>,
    /// Stream associated with the handle; `can_mut()` compares it with the current stream.
    pub stream: cubecl_common::stream_id::StreamId,
    /// Cursor value carried with the handle.
    // NOTE(review): the meaning of `cursor` is not visible in this file — confirm.
    pub cursor: u64,
    /// Size before offsets are applied (presumably in bytes — confirm); read via `size()`.
    size: u64,
}
461
/// Kind of allocation requested through an [`AllocationDescriptor`].
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// Contiguous allocation (used by the default `create_with_data` / `create_with_bytes`).
    Contiguous,
    /// Optimized allocation.
    // NOTE(review): presumably lets the server pick the memory layout — confirm.
    Optimized,
}
471
/// Description of an allocation to create: its kind, shape and element size.
#[derive(new, Debug, Clone, Copy)]
pub struct AllocationDescriptor<'a> {
    /// Kind of allocation.
    pub kind: AllocationKind,
    /// Shape of the allocation.
    pub shape: &'a [usize],
    /// Size of one element (presumably in bytes — confirm).
    pub elem_size: usize,
}
482
483impl<'a> AllocationDescriptor<'a> {
484 pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
486 AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
487 }
488
489 pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
491 AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
492 }
493}
494
/// Result of an allocation: the handle to the memory and the strides of the data.
#[derive(Debug)]
pub struct Allocation {
    /// Handle to the allocated memory.
    pub handle: Handle,
    /// Strides associated with the allocation.
    pub strides: Strides,
}
503
504impl Allocation {
505 pub fn new(handle: Handle, strides: impl Into<Strides>) -> Self {
507 Allocation {
508 handle,
509 strides: strides.into(),
510 }
511 }
512}
513
514#[derive(Error, Clone)]
517#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
518pub enum IoError {
519 #[error("can't allocate buffer of size: {size}\n{backtrace}")]
521 BufferTooBig {
522 size: u64,
524 #[cfg_attr(std_io, serde(skip))]
526 backtrace: BackTrace,
527 },
528
529 #[error("the provided strides are not supported for this operation\n{backtrace}")]
531 UnsupportedStrides {
532 #[cfg_attr(std_io, serde(skip))]
534 backtrace: BackTrace,
535 },
536
537 #[error("couldn't find resource for that handle\n{backtrace}")]
539 InvalidHandle {
540 #[cfg_attr(std_io, serde(skip))]
542 backtrace: BackTrace,
543 },
544
545 #[error("Unknown error happened during execution\n{backtrace}")]
547 Unknown {
548 description: String,
550 #[cfg_attr(std_io, serde(skip))]
552 backtrace: BackTrace,
553 },
554
555 #[error("The current IO operation is not supported\n{backtrace}")]
557 UnsupportedIoOperation {
558 #[cfg_attr(std_io, serde(skip))]
560 backtrace: BackTrace,
561 },
562
563 #[error("Can't perform the IO operation because of a runtime error")]
565 Execution(#[from] ExecutionError),
566}
567
568impl core::fmt::Debug for IoError {
569 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
570 f.write_fmt(format_args!("{self}"))
571 }
572}
573
574impl Handle {
575 pub fn offset_start(mut self, offset: u64) -> Self {
577 if let Some(val) = &mut self.offset_start {
578 *val += offset;
579 } else {
580 self.offset_start = Some(offset);
581 }
582
583 self
584 }
585 pub fn offset_end(mut self, offset: u64) -> Self {
587 if let Some(val) = &mut self.offset_end {
588 *val += offset;
589 } else {
590 self.offset_end = Some(offset);
591 }
592
593 self
594 }
595
596 pub fn size(&self) -> u64 {
598 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
599 }
600}
601
/// Everything bound to a kernel launch: buffers, metadata, scalars and tensor maps.
#[derive(Debug, Default)]
pub struct Bindings {
    /// Buffer bindings.
    pub buffers: Vec<Binding>,
    /// Metadata binding.
    pub metadata: MetadataBinding,
    /// Scalar bindings, keyed by their storage type (one entry per type).
    pub scalars: BTreeMap<StorageType, ScalarBinding>,
    /// Tensor map bindings.
    pub tensor_maps: Vec<TensorMapBinding>,
}
615
616impl Bindings {
617 pub fn new() -> Self {
619 Self::default()
620 }
621
622 pub fn with_buffer(mut self, binding: Binding) -> Self {
624 self.buffers.push(binding);
625 self
626 }
627
628 pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
630 self.buffers.extend(bindings);
631 self
632 }
633
634 pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
636 self.scalars
637 .insert(ty, ScalarBinding::new(ty, length, data));
638 self
639 }
640
641 pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
643 self.scalars
644 .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
645 self
646 }
647
648 pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
650 self.metadata = meta;
651 self
652 }
653
654 pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
656 self.tensor_maps.extend(bindings);
657 self
658 }
659}
660
/// Metadata passed along with a kernel launch.
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// Metadata values.
    pub data: Vec<u64>,
    /// Length of the static part of the metadata.
    // NOTE(review): presumably the number of leading `data` entries that are static — confirm.
    pub static_len: usize,
}
669
/// Scalars of a single storage type passed to a kernel launch.
#[derive(new, Debug, Clone)]
pub struct ScalarBinding {
    /// Storage type of the scalars.
    pub ty: StorageType,
    /// Length of the binding.
    // NOTE(review): presumably the number of scalars stored in `data` — confirm.
    pub length: usize,
    /// Scalar data stored as `u64` words; `data()` exposes it as bytes.
    pub data: Vec<u64>,
}
680
681impl ScalarBinding {
682 pub fn data(&self) -> &[u8] {
684 bytemuck::cast_slice(&self.data)
685 }
686}
687
/// Binding to a slice of memory, produced from a [`Handle`] by `Handle::binding`.
#[derive(new, Debug)]
pub struct Binding {
    /// Binding to the underlying memory slice.
    pub memory: SliceBinding,
    /// Optional offset at the start; `size()` subtracts it from the stored size.
    pub offset_start: Option<u64>,
    /// Optional offset at the end; `size()` subtracts it from the stored size.
    pub offset_end: Option<u64>,
    /// Stream associated with the binding.
    pub stream: cubecl_common::stream_id::StreamId,
    /// Cursor value carried with the binding.
    // NOTE(review): the meaning of `cursor` is not visible in this file — confirm.
    pub cursor: u64,
    /// Size before offsets are applied (presumably in bytes — confirm); read via `size()`.
    size: u64,
}
704
705impl Binding {
706 pub fn size(&self) -> u64 {
708 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
709 }
710}
711
/// Description of a copy: the memory binding plus the layout of the data in it.
#[derive(new, Debug, Clone)]
pub struct CopyDescriptor<'a> {
    /// Binding to the memory involved in the copy.
    pub binding: Binding,
    /// Shape of the data.
    pub shape: &'a [usize],
    /// Strides of the data.
    pub strides: &'a [usize],
    /// Size of one element (presumably in bytes — confirm).
    pub elem_size: usize,
}
724
/// A memory binding together with the tensor map metadata that describes it.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// Binding to the memory.
    pub binding: Binding,
    /// Tensor map metadata.
    pub map: TensorMapMeta,
}
733
/// Metadata describing a tensor map.
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Tensor map format.
    pub format: TensorMapFormat,
    /// Tensor metadata.
    pub metadata: Metadata,
    /// Element strides.
    // NOTE(review): presumably the stride, in elements, along each dimension — confirm.
    pub elem_stride: Strides,
    /// Interleave setting.
    pub interleave: TensorMapInterleave,
    /// Swizzle setting.
    pub swizzle: TensorMapSwizzle,
    /// Prefetch setting.
    pub prefetch: TensorMapPrefetch,
    /// Out-of-bounds fill setting.
    pub oob_fill: OobFill,
    /// Storage type of the elements.
    pub storage_ty: StorageType,
}
755
756impl Handle {
757 pub fn can_mut(&self) -> bool {
759 self.memory.can_mut() && self.stream == StreamId::current()
760 }
761}
762
763impl Handle {
764 pub fn binding(self) -> Binding {
766 Binding {
767 memory: MemoryHandle::binding(self.memory),
768 offset_start: self.offset_start,
769 offset_end: self.offset_end,
770 size: self.size,
771 stream: self.stream,
772 cursor: self.cursor,
773 }
774 }
775
776 pub fn copy_descriptor<'a>(
778 &'a self,
779 shape: &'a [usize],
780 strides: &'a [usize],
781 elem_size: usize,
782 ) -> CopyDescriptor<'a> {
783 CopyDescriptor {
784 shape,
785 strides,
786 elem_size,
787 binding: self.clone().binding(),
788 }
789 }
790}
791
792impl Clone for Handle {
793 fn clone(&self) -> Self {
794 Self {
795 memory: self.memory.clone(),
796 offset_start: self.offset_start,
797 offset_end: self.offset_end,
798 size: self.size,
799 stream: self.stream,
800 cursor: self.cursor,
801 }
802 }
803}
804
805impl Clone for Binding {
806 fn clone(&self) -> Self {
807 Self {
808 memory: self.memory.clone(),
809 offset_start: self.offset_start,
810 offset_end: self.offset_end,
811 size: self.size,
812 stream: self.stream,
813 cursor: self.cursor,
814 }
815 }
816}
817
/// Number of cubes to launch, either known up front or read from a binding.
// The two variants differ in size, hence the lint exemption.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Cube count given directly as three components (x, y, z).
    Static(u32, u32, u32),
    /// Cube count provided through a memory binding.
    // NOTE(review): presumably the binding holds three `u32` values — confirm.
    Dynamic(Binding),
}
828
/// Outcome of choosing a cube count for a requested number of cubes.
pub enum CubeCountSelection {
    /// The chosen cube count launches exactly the requested number of cubes.
    Exact(CubeCount),
    /// The chosen cube count launches a different number of cubes than requested;
    /// the second field is the number actually launched.
    Approx(CubeCount, u32),
}
838
839impl CubeCountSelection {
840 pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
842 let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
843
844 let num_cubes_actual = cube_count[0] * cube_count[1] * cube_count[2];
845 let cube_count = CubeCount::Static(cube_count[0], cube_count[1], cube_count[2]);
846
847 match num_cubes_actual == num_cubes {
848 true => CubeCountSelection::Exact(cube_count),
849 false => CubeCountSelection::Approx(cube_count, num_cubes_actual),
850 }
851 }
852
853 pub fn has_idle(&self) -> bool {
855 matches!(self, Self::Approx(..))
856 }
857
858 pub fn cube_count(self) -> CubeCount {
860 match self {
861 CubeCountSelection::Exact(cube_count) => cube_count,
862 CubeCountSelection::Approx(cube_count, _) => cube_count,
863 }
864 }
865}
866
867impl From<CubeCountSelection> for CubeCount {
868 fn from(value: CubeCountSelection) -> Self {
869 value.cube_count()
870 }
871}
872
873impl CubeCount {
874 pub fn new_single() -> Self {
876 CubeCount::Static(1, 1, 1)
877 }
878
879 pub fn new_1d(x: u32) -> Self {
881 CubeCount::Static(x, 1, 1)
882 }
883
884 pub fn new_2d(x: u32, y: u32) -> Self {
886 CubeCount::Static(x, y, 1)
887 }
888
889 pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
891 CubeCount::Static(x, y, z)
892 }
893}
894
895impl Debug for CubeCount {
896 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
897 match self {
898 CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
899 CubeCount::Dynamic(_) => f.write_str("binding"),
900 }
901 }
902}
903
904impl Clone for CubeCount {
905 fn clone(&self) -> Self {
906 match self {
907 Self::Static(x, y, z) => Self::Static(*x, *y, *z),
908 Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
909 }
910 }
911}
912
/// Dimensions of a cube along the x, y and z axes.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, serde::Serialize, serde::Deserialize)]
#[allow(missing_docs)]
pub struct CubeDim {
    /// Size along the x axis.
    pub x: u32,
    /// Size along the y axis.
    pub y: u32,
    /// Size along the z axis.
    pub z: u32,
}
924
925impl CubeDim {
926 pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
934 let properties = client.properties();
935 let plane_size = properties.hardware.plane_size_max;
936 let plane_count = Self::calculate_plane_count_per_cube(
937 working_units as u32,
938 plane_size,
939 properties.hardware.num_cpu_cores,
940 );
941
942 let limit = properties.hardware.max_units_per_cube / plane_size;
944
945 Self::new_2d(plane_size, u32::min(limit, plane_count))
946 }
947
948 fn calculate_plane_count_per_cube(
949 working_units: u32,
950 plane_dim: u32,
951 num_cpu_cores: Option<u32>,
952 ) -> u32 {
953 match num_cpu_cores {
954 Some(num_cores) => core::cmp::min(num_cores, working_units),
955 None => {
956 let plane_count_max = core::cmp::max(1, working_units / plane_dim);
957
958 const NUM_PLANE_MAX: u32 = 8u32;
960 const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
961 let plane_count_max_log2 =
962 core::cmp::min(NUM_PLANE_MAX_LOG2, u32::ilog2(plane_count_max));
963 2u32.pow(plane_count_max_log2)
964 }
965 }
966 }
967
968 pub const fn new_single() -> Self {
970 Self { x: 1, y: 1, z: 1 }
971 }
972
973 pub const fn new_1d(x: u32) -> Self {
975 Self { x, y: 1, z: 1 }
976 }
977
978 pub const fn new_2d(x: u32, y: u32) -> Self {
980 Self { x, y, z: 1 }
981 }
982
983 pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
986 Self { x, y, z }
987 }
988
989 pub const fn num_elems(&self) -> u32 {
991 self.x * self.y * self.z
992 }
993
994 pub const fn can_contain(&self, other: CubeDim) -> bool {
996 self.x >= other.x && self.y >= other.y && self.z >= other.z
997 }
998}
999
1000impl From<(u32, u32, u32)> for CubeDim {
1001 fn from(value: (u32, u32, u32)) -> Self {
1002 CubeDim::new_3d(value.0, value.1, value.2)
1003 }
1004}
1005
1006impl From<CubeDim> for (u32, u32, u32) {
1007 fn from(val: CubeDim) -> Self {
1008 (val.x, val.y, val.z)
1009 }
1010}
1011
/// Execution mode passed to `ComputeServer::launch`.
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy, Serialize, Deserialize)]
pub enum ExecutionMode {
    /// Checked execution; this is the default mode.
    // NOTE(review): presumably kernels are compiled with bounds checks — confirm.
    #[default]
    Checked,
    /// Unchecked execution.
    // NOTE(review): presumably bounds checks are omitted in this mode — confirm.
    Unchecked,
}
1021
/// Spreads `num_cubes` over three axes so that the first two axes respect their
/// maximum in `max`.
///
/// Starting from `[num_cubes, 1, 1]`, an axis that exceeds its maximum is repeatedly
/// halved (rounding up) while the next axis is doubled. Because of the rounding, the
/// product of the result can be larger than `num_cubes`. The last axis is never
/// checked against its maximum.
fn cube_count_spread(max: &(u32, u32, u32), num_cubes: u32) -> [u32; 3] {
    const BASE: u32 = 2;

    let limits = [max.0, max.1, max.2];
    let mut spread = [num_cubes, 1, 1];

    for axis in 0..2 {
        // Once an axis fits as-is, nothing spilled into the following axes.
        if spread[axis] <= limits[axis] {
            break;
        }

        while spread[axis] > limits[axis] {
            spread[axis] = spread[axis].div_ceil(BASE);
            spread[axis + 1] *= BASE;
        }
    }

    spread
}
1050
#[cfg(test)]
mod tests {
    use super::*;

    /// A power-of-two request splits evenly across the three axes.
    #[test_log::test]
    fn safe_num_cubes_even() {
        let spread = cube_count_spread(&(32, 32, 32), 2048);

        assert_eq!(spread, [32, 32, 2]);
    }

    /// An odd request is rounded up while being split across the axes.
    #[test_log::test]
    fn safe_num_cubes_odd() {
        let spread = cube_count_spread(&(48, 32, 16), 3177);

        assert_eq!(spread, [25, 32, 4]);
    }
}