use crate::{
    client::ComputeClient,
    compiler::CompilationError,
    kernel::KernelMetadata,
    logging::ServerLogger,
    memory_management::{
        MemoryAllocationMode, MemoryHandle, MemoryUsage,
        memory_pool::{SliceBinding, SliceHandle},
    },
    runtime::Runtime,
    storage::{BindingResource, ComputeStorage},
    tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
};
use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::fmt::Debug;
use cubecl_common::{
    backtrace::BackTrace, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
    stream_id::StreamId,
};
use cubecl_ir::{DeviceProperties, StorageType};
use serde::{Deserialize, Serialize};
use thiserror::Error;

/// Errors that can happen while profiling a server.
#[derive(Error, Clone)]
pub enum ProfileError {
    /// An error without a more specific category.
    #[error(
        "An unknown error happened during profiling\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
    )]
    Unknown {
        reason: String,
        backtrace: BackTrace,
    },

    /// No profiling session was registered for the given token.
    #[error("No profiling registered\nBacktrace:\n{backtrace}")]
    NotRegistered { backtrace: BackTrace },

    /// A kernel launch failed while profiling.
    #[error("A launch error happened during profiling\nCaused by:\n {0}")]
    Launch(#[from] LaunchError),

    /// Execution failed while profiling.
    #[error("An execution error happened during profiling\nCaused by:\n {0}")]
    Execution(#[from] ExecutionError),
}

impl core::fmt::Debug for ProfileError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!("{self}"))
    }
}

/// Utilities shared by all servers of a given runtime.
pub struct ServerUtilities<Server: ComputeServer> {
    /// Reference instant used to convert GPU timestamps for Tracy.
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// The Tracy GPU profiling context.
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// The properties of the device this server runs on.
    pub properties: DeviceProperties,
    /// A checksum of the device properties.
    pub properties_hash: u64,
    /// Runtime-specific information about the server.
    pub info: Server::Info,
    /// The logger used by the server.
    pub logger: Arc<ServerLogger>,
}

impl<Server: core::fmt::Debug> core::fmt::Debug for ServerUtilities<Server>
where
    Server: ComputeServer,
    Server::Info: core::fmt::Debug,
{
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        f.debug_struct("ServerUtilities")
            .field("properties", &self.properties)
            .field("info", &self.info)
            .field("logger", &self.logger)
            .finish()
    }
}

impl<S: ComputeServer> ServerUtilities<S> {
    /// Creates the utilities from the device properties, logger, and runtime info.
    pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
        #[cfg(feature = "profile-tracy")]
        let client = tracy_client::Client::start();

        Self {
            properties_hash: properties.checksum(),
            properties,
            logger,
            #[cfg(feature = "profile-tracy")]
            gpu_client: client
                .clone()
                .new_gpu_context(
                    Some(&format!("{info:?}")),
                    tracy_client::GpuContextType::Invalid,
                    0,
                    1.0,
                )
                .unwrap(),
            #[cfg(feature = "profile-tracy")]
            epoch_time: web_time::Instant::now(),
            info,
        }
    }
}

/// Errors that can happen when launching a kernel.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum LaunchError {
    /// The kernel failed to compile.
    #[error("A compilation error happened during launch\nCaused by:\n {0}")]
    CompilationError(#[from] CompilationError),

    /// Not enough memory was available to launch the kernel.
    #[error(
        "An out-of-memory error happened during launch\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
    )]
    OutOfMemory {
        reason: String,
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// The launch exceeded a hardware resource limit.
    #[error("Too many resources were requested during launch\n{0}")]
    TooManyResources(#[from] ResourceLimitError),

    /// An error without a more specific category.
    #[error(
        "An unknown error happened during launch\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
    )]
    Unknown {
        reason: String,
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// An IO error happened while preparing the launch.
    #[error("An IO error happened during launch\nCaused by:\n {0}")]
    IoError(#[from] IoError),
}

/// Errors raised when a launch exceeds the hardware limits of the device.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ResourceLimitError {
    /// More shared memory was requested than the device provides.
    #[error(
        "Too much shared memory requested.\nRequested {requested} bytes, maximum {max} bytes available.\nBacktrace:\n{backtrace}"
    )]
    SharedMemory {
        requested: usize,
        max: usize,
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// The total number of units per cube exceeds the device maximum.
    #[error(
        "Total unit count exceeds maximum.\nRequested {requested} units, max units is {max}.\nBacktrace:\n{backtrace}"
    )]
    Units {
        requested: u32,
        max: u32,
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// The requested cube dimensions exceed the device maximum.
    #[error(
        "Cube dim exceeds maximum bounds.\nRequested {requested:?}, max is {max:?}.\nBacktrace:\n{backtrace}"
    )]
    CubeDim {
        requested: (u32, u32, u32),
        max: (u32, u32, u32),
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
}

impl core::fmt::Debug for LaunchError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!("{self}"))
    }
}

impl core::fmt::Debug for ResourceLimitError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!("{self}"))
    }
}

/// Errors that can happen while executing previously launched work.
#[derive(Error, Debug, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum ExecutionError {
    /// A generic execution error.
    #[error("An error happened during execution\nCaused by:\n {reason}\nBacktrace:\n{backtrace}")]
    Generic {
        reason: String,
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
}

/// The compute server handles memory management and kernel execution for a runtime.
pub trait ComputeServer:
    Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
where
    Self: Sized,
{
    /// The kernel type defines the kernel sources for this server.
    type Kernel: KernelMetadata;
    /// Runtime-specific information about the server.
    type Info: Debug + Send + Sync;
    /// The storage type defines how data is stored and accessed.
    type Storage: ComputeStorage;

    /// Allocates memory for the given descriptors and returns the allocations.
    fn create(
        &mut self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<Allocation>, IoError>;

    /// Reserves staging buffers of the given sizes, if supported by the server.
    fn staging(&mut self, _sizes: &[usize], _stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        Err(IoError::UnsupportedIoOperation {
            backtrace: BackTrace::capture(),
        })
    }

    /// Returns the logger used by this server.
    fn logger(&self) -> Arc<ServerLogger>;

    /// Returns the utilities shared by this server.
    fn utilities(&self) -> Arc<ServerUtilities<Self>>;

    /// Allocates a contiguous buffer and writes `data` into it.
    fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
        let alloc = self
            .create(
                vec![AllocationDescriptor::new(
                    AllocationKind::Contiguous,
                    &[data.len()],
                    1,
                )],
                stream_id,
            )?
            .remove(0);
        self.write(
            vec![(
                CopyDescriptor::new(
                    alloc.handle.clone().binding(),
                    &[data.len()],
                    &alloc.strides,
                    1,
                ),
                Bytes::from_bytes_vec(data.to_vec()),
            )],
            stream_id,
        )?;
        Ok(alloc.handle)
    }

    /// Allocates a contiguous buffer and writes the given [`Bytes`] into it.
    fn create_with_bytes(&mut self, data: Bytes, stream_id: StreamId) -> Result<Handle, IoError> {
        let alloc = self
            .create(
                vec![AllocationDescriptor::new(
                    AllocationKind::Contiguous,
                    &[data.len()],
                    1,
                )],
                stream_id,
            )?
            .remove(0);
        self.write(
            vec![(
                CopyDescriptor::new(
                    alloc.handle.clone().binding(),
                    &[data.len()],
                    &alloc.strides,
                    1,
                ),
                data,
            )],
            stream_id,
        )?;
        Ok(alloc.handle)
    }

    /// Reads the resources described by `descriptors` back to the host.
    fn read<'a>(
        &mut self,
        descriptors: Vec<CopyDescriptor<'a>>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>>;

    /// Writes the given data to the resources described by `descriptors`.
    fn write(
        &mut self,
        descriptors: Vec<(CopyDescriptor<'_>, Bytes)>,
        stream_id: StreamId,
    ) -> Result<(), IoError>;

    /// Waits until all work queued on the stream has completed.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>>;

    /// Returns the underlying storage resource for the given binding.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;

    /// Launches the given kernel.
    ///
    /// # Safety
    ///
    /// When running in [`ExecutionMode::Unchecked`], out-of-bounds accesses are not guarded
    /// against, so the caller must ensure the kernel never reads or writes outside its bindings.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        kind: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError>;

    /// Flushes all queued work on the stream to the device.
    fn flush(&mut self, stream_id: StreamId);

    /// Returns the current memory usage of the server.
    fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;

    /// Asks the server to release as much memory as possible.
    fn memory_cleanup(&mut self, stream_id: StreamId);

    /// Starts a profiling session and returns its token.
    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;

    /// Ends the profiling session identified by `token` and returns the measured duration.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError>;

    /// Changes the memory allocation mode of the server.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
}

/// Direct server-to-server communication, when supported by the runtime.
pub trait ServerCommunication {
    /// Whether this server supports direct server-to-server copies.
    const SERVER_COMM_ENABLED: bool;

    /// Copies the resource described by `src` from `server_src` into a new allocation on
    /// `server_dst`.
    #[allow(unused_variables)]
    fn copy(
        server_src: &mut Self,
        server_dst: &mut Self,
        src: CopyDescriptor<'_>,
        stream_id_src: StreamId,
        stream_id_dst: StreamId,
    ) -> Result<Allocation, IoError> {
        if !Self::SERVER_COMM_ENABLED {
            panic!("Server-to-server communication is not supported by this server.");
        } else {
            panic!(
                "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
            );
        }
    }
}

/// A token identifying a profiling session started with [`ComputeServer::start_profile`].
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    pub id: u64,
}

/// A handle to server memory, returned when allocating.
#[derive(new, Debug, PartialEq, Eq)]
pub struct Handle {
    /// The memory slice backing this handle.
    pub memory: SliceHandle,
    /// Offset in bytes from the start of the slice, if any.
    pub offset_start: Option<u64>,
    /// Offset in bytes from the end of the slice, if any.
    pub offset_end: Option<u64>,
    /// The stream associated with this handle.
    pub stream: StreamId,
    /// The stream cursor associated with this handle.
    pub cursor: u64,
    size: u64,
}

/// How memory is laid out for an allocation.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// A contiguous, densely packed layout.
    Contiguous,
    /// A layout the runtime may optimize, for instance with padded strides.
    Optimized,
}

/// Describes a single allocation request.
#[derive(new, Debug, Clone, Copy)]
pub struct AllocationDescriptor<'a> {
    /// The kind of layout to allocate.
    pub kind: AllocationKind,
    /// The shape of the tensor to allocate.
    pub shape: &'a [usize],
    /// The size of each element in bytes.
    pub elem_size: usize,
}

impl<'a> AllocationDescriptor<'a> {
    /// Creates a descriptor with an [`AllocationKind::Optimized`] layout.
    pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
        AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
    }

    /// Creates a descriptor with an [`AllocationKind::Contiguous`] layout.
    pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
        AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
    }
}

/// The result of an allocation: a handle and the strides of the allocated buffer.
#[derive(new, Debug)]
pub struct Allocation {
    pub handle: Handle,
    pub strides: Vec<usize>,
}

/// Errors that can happen during IO operations.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum IoError {
    /// The requested buffer is larger than what the runtime can allocate.
    #[error("Can't allocate a buffer of size {size}\nBacktrace:\n{backtrace}")]
    BufferTooBig {
        size: u64,
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// The provided strides are not supported for this operation.
    #[error("The provided strides are not supported for this operation\nBacktrace:\n{backtrace}")]
    UnsupportedStrides {
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// No resource was found for the given handle.
    #[error("Couldn't find a resource for the given handle\nBacktrace:\n{backtrace}")]
    InvalidHandle {
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// An error without a more specific category.
    #[error(
        "An unknown error happened during IO\nCaused by:\n {description}\nBacktrace:\n{backtrace}"
    )]
    Unknown {
        description: String,
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// The server does not support this IO operation.
    #[error("The current IO operation is not supported\nBacktrace:\n{backtrace}")]
    UnsupportedIoOperation {
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },

    /// The IO operation failed because of a runtime execution error.
    #[error("Can't perform the IO operation because of a runtime error\nCaused by:\n {0}")]
    Execution(#[from] ExecutionError),
}

impl core::fmt::Debug for IoError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!("{self}"))
    }
}

impl Handle {
    /// Increases the start offset of this handle by `offset` bytes.
    pub fn offset_start(mut self, offset: u64) -> Self {
        if let Some(val) = &mut self.offset_start {
            *val += offset;
        } else {
            self.offset_start = Some(offset);
        }

        self
    }

    /// Increases the end offset of this handle by `offset` bytes.
    pub fn offset_end(mut self, offset: u64) -> Self {
        if let Some(val) = &mut self.offset_end {
            *val += offset;
        } else {
            self.offset_end = Some(offset);
        }

        self
    }

    /// Returns the size of the handle in bytes, taking the offsets into account.
    pub fn size(&self) -> u64 {
        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
    }
}

/// The set of resources bound to a kernel launch.
#[derive(Debug, Default)]
pub struct Bindings {
    /// The buffer bindings.
    pub buffers: Vec<Binding>,
    /// The metadata binding (e.g. shapes and strides).
    pub metadata: MetadataBinding,
    /// The scalar bindings, grouped by storage type.
    pub scalars: BTreeMap<StorageType, ScalarBinding>,
    /// The tensor map bindings.
    pub tensor_maps: Vec<TensorMapBinding>,
}

impl Bindings {
    /// Creates an empty set of bindings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a buffer binding.
    pub fn with_buffer(mut self, binding: Binding) -> Self {
        self.buffers.push(binding);
        self
    }

    /// Adds several buffer bindings.
    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
        self.buffers.extend(bindings);
        self
    }

    /// Adds a scalar binding for the given storage type.
    pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
        self.scalars
            .insert(ty, ScalarBinding::new(ty, length, data));
        self
    }

    /// Adds several scalar bindings, keyed by their storage type.
    pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
        self.scalars
            .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
        self
    }

    /// Sets the metadata binding.
    pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
        self.metadata = meta;
        self
    }

    /// Adds tensor map bindings.
    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
        self.tensor_maps.extend(bindings);
        self
    }
}

/// The metadata buffer passed to a kernel.
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// The raw metadata values.
    pub data: Vec<u64>,
    /// The length of the statically known part of the metadata.
    pub static_len: usize,
}

/// A set of scalar values of a single storage type, passed to a kernel.
#[derive(new, Debug, Clone)]
pub struct ScalarBinding {
    /// The storage type of the scalars.
    pub ty: StorageType,
    /// The number of scalars in the binding.
    pub length: usize,
    /// The scalar values, packed into `u64` words.
    pub data: Vec<u64>,
}

impl ScalarBinding {
    /// Returns the scalar data as a byte slice.
    pub fn data(&self) -> &[u8] {
        bytemuck::cast_slice(&self.data)
    }
}

/// A binding to a server resource, consumed when launching a kernel.
#[derive(new, Debug)]
pub struct Binding {
    /// The memory slice backing this binding.
    pub memory: SliceBinding,
    /// Offset in bytes from the start of the slice, if any.
    pub offset_start: Option<u64>,
    /// Offset in bytes from the end of the slice, if any.
    pub offset_end: Option<u64>,
    /// The stream associated with this binding.
    pub stream: StreamId,
    /// The stream cursor associated with this binding.
    pub cursor: u64,
    size: u64,
}

impl Binding {
    /// Returns the size of the binding in bytes, taking the offsets into account.
    pub fn size(&self) -> u64 {
        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
    }
}

/// Describes a strided region of memory to copy from or to.
#[derive(new, Debug, Clone)]
pub struct CopyDescriptor<'a> {
    /// The binding pointing to the memory region.
    pub binding: Binding,
    /// The shape of the region.
    pub shape: &'a [usize],
    /// The strides of the region.
    pub strides: &'a [usize],
    /// The size of each element in bytes.
    pub elem_size: usize,
}

/// A binding with an associated tensor map description.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// The binding to the tensor memory.
    pub binding: Binding,
    /// The tensor map metadata.
    pub map: TensorMapMeta,
}

/// Metadata describing a tensor map.
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// The tensor map format.
    pub format: TensorMapFormat,
    /// The rank of the tensor.
    pub rank: usize,
    /// The shape of the tensor.
    pub shape: Vec<usize>,
    /// The strides of the tensor.
    pub strides: Vec<usize>,
    /// The element stride along each dimension.
    pub elem_stride: Vec<usize>,
    /// The interleave mode.
    pub interleave: TensorMapInterleave,
    /// The swizzle mode.
    pub swizzle: TensorMapSwizzle,
    /// The prefetch setting.
    pub prefetch: TensorMapPrefetch,
    /// The out-of-bounds fill value.
    pub oob_fill: OobFill,
    /// The storage type of the elements.
    pub storage_ty: StorageType,
}

impl Handle {
    /// Whether the handle can be safely mutated: the underlying memory allows mutation and the
    /// current stream matches the handle's stream.
    pub fn can_mut(&self) -> bool {
        self.memory.can_mut() && self.stream == StreamId::current()
    }
}

impl Handle {
    /// Converts the handle into a [`Binding`] that can be passed to a kernel.
    pub fn binding(self) -> Binding {
        Binding {
            memory: MemoryHandle::binding(self.memory),
            offset_start: self.offset_start,
            offset_end: self.offset_end,
            size: self.size,
            stream: self.stream,
            cursor: self.cursor,
        }
    }

    /// Creates a [`CopyDescriptor`] for this handle with the given shape, strides, and element size.
    pub fn copy_descriptor<'a>(
        &'a self,
        shape: &'a [usize],
        strides: &'a [usize],
        elem_size: usize,
    ) -> CopyDescriptor<'a> {
        CopyDescriptor {
            shape,
            strides,
            elem_size,
            binding: self.clone().binding(),
        }
    }
}

impl Clone for Handle {
    fn clone(&self) -> Self {
        Self {
            memory: self.memory.clone(),
            offset_start: self.offset_start,
            offset_end: self.offset_end,
            size: self.size,
            stream: self.stream,
            cursor: self.cursor,
        }
    }
}

impl Clone for Binding {
    fn clone(&self) -> Self {
        Self {
            memory: self.memory.clone(),
            offset_start: self.offset_start,
            offset_end: self.offset_end,
            size: self.size,
            stream: self.stream,
            cursor: self.cursor,
        }
    }
}

/// The number of cubes to launch along each dimension.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// A static cube count known at launch time.
    Static(u32, u32, u32),
    /// A cube count read from a buffer on the device.
    Dynamic(Binding),
}

/// A cube count computed from a required number of cubes, which may be exact or approximate.
pub enum CubeCountSelection {
    /// The cube count matches the requested number of cubes exactly.
    Exact(CubeCount),
    /// The cube count launches more cubes than requested; the actual total is given.
    Approx(CubeCount, u32),
}

impl CubeCountSelection {
    /// Spreads `num_cubes` across the dimensions allowed by the device.
    pub fn new<R: Runtime>(client: &ComputeClient<R>, num_cubes: u32) -> Self {
        let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);

        let num_cubes_actual = cube_count[0] * cube_count[1] * cube_count[2];
        let cube_count = CubeCount::Static(cube_count[0], cube_count[1], cube_count[2]);

        match num_cubes_actual == num_cubes {
            true => CubeCountSelection::Exact(cube_count),
            false => CubeCountSelection::Approx(cube_count, num_cubes_actual),
        }
    }

    /// Whether some launched cubes will be idle because the count is approximate.
    pub fn has_idle(&self) -> bool {
        matches!(self, Self::Approx(..))
    }

    /// Returns the underlying [`CubeCount`].
    pub fn cube_count(self) -> CubeCount {
        match self {
            CubeCountSelection::Exact(cube_count) => cube_count,
            CubeCountSelection::Approx(cube_count, _) => cube_count,
        }
    }
}

impl From<CubeCountSelection> for CubeCount {
    fn from(value: CubeCountSelection) -> Self {
        value.cube_count()
    }
}

impl CubeCount {
    /// Creates a cube count of a single cube.
    pub fn new_single() -> Self {
        CubeCount::Static(1, 1, 1)
    }

    /// Creates a 1D cube count.
    pub fn new_1d(x: u32) -> Self {
        CubeCount::Static(x, 1, 1)
    }

    /// Creates a 2D cube count.
    pub fn new_2d(x: u32, y: u32) -> Self {
        CubeCount::Static(x, y, 1)
    }

    /// Creates a 3D cube count.
    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
        CubeCount::Static(x, y, z)
    }
}

impl Debug for CubeCount {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
            CubeCount::Dynamic(_) => f.write_str("binding"),
        }
    }
}

impl Clone for CubeCount {
    fn clone(&self) -> Self {
        match self {
            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
        }
    }
}

/// The number of units per cube along each dimension.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, serde::Serialize, serde::Deserialize)]
#[allow(missing_docs)]
pub struct CubeDim {
    pub x: u32,
    pub y: u32,
    pub z: u32,
}

impl CubeDim {
    /// Chooses a cube dimension suited to the device for the given number of working units.
    ///
    /// The x dimension is set to the maximum plane size and the y dimension to the number of
    /// planes per cube, capped by the maximum number of units per cube.
    pub fn new<R: Runtime>(client: &ComputeClient<R>, working_units: usize) -> Self {
        let properties = client.properties();
        let plane_size = properties.hardware.plane_size_max;
        let plane_count = Self::calculate_plane_count_per_cube(
            working_units as u32,
            plane_size,
            properties.hardware.num_cpu_cores,
        );

        let limit = properties.hardware.max_units_per_cube / plane_size;

        Self::new_2d(plane_size, u32::min(limit, plane_count))
    }

    fn calculate_plane_count_per_cube(
        working_units: u32,
        plane_dim: u32,
        num_cpu_cores: Option<u32>,
    ) -> u32 {
        match num_cpu_cores {
            // When the number of CPU cores is known, use one plane per core, capped by the
            // number of working units.
            Some(num_cores) => core::cmp::min(num_cores, working_units),
            // Otherwise, use the largest power of two of planes covered by the working units,
            // capped at `NUM_PLANE_MAX`.
            None => {
                let plane_count_max = core::cmp::max(1, working_units / plane_dim);

                const NUM_PLANE_MAX: u32 = 8u32;
                const NUM_PLANE_MAX_LOG2: u32 = NUM_PLANE_MAX.ilog2();
                let plane_count_max_log2 =
                    core::cmp::min(NUM_PLANE_MAX_LOG2, u32::ilog2(plane_count_max));
                2u32.pow(plane_count_max_log2)
            }
        }
    }

    /// Creates a cube dimension of a single unit.
    pub const fn new_single() -> Self {
        Self { x: 1, y: 1, z: 1 }
    }

    /// Creates a 1D cube dimension.
    pub const fn new_1d(x: u32) -> Self {
        Self { x, y: 1, z: 1 }
    }

    /// Creates a 2D cube dimension.
    pub const fn new_2d(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Creates a 3D cube dimension.
    pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Returns the total number of units in the cube.
    pub const fn num_elems(&self) -> u32 {
        self.x * self.y * self.z
    }

    /// Whether this cube dimension is at least as large as `other` along every dimension.
    pub const fn can_contain(&self, other: CubeDim) -> bool {
        self.x >= other.x && self.y >= other.y && self.z >= other.z
    }
}

impl From<(u32, u32, u32)> for CubeDim {
    fn from(value: (u32, u32, u32)) -> Self {
        CubeDim::new_3d(value.0, value.1, value.2)
    }
}

impl From<CubeDim> for (u32, u32, u32) {
    fn from(val: CubeDim) -> Self {
        (val.x, val.y, val.z)
    }
}

/// Whether bounds checks are inserted into kernels.
#[derive(Default, Hash, PartialEq, Eq, Clone, Debug, Copy, Serialize, Deserialize)]
pub enum ExecutionMode {
    /// Out-of-bounds accesses are guarded against.
    #[default]
    Checked,
    /// No bounds checks are performed; the caller is responsible for safety.
    Unchecked,
}

/// Spreads `num_cubes` over the x, y, and z dimensions so that each dimension stays within the
/// maximum cube count of the device. The result may cover more cubes than requested due to
/// rounding.
fn cube_count_spread(max: &(u32, u32, u32), num_cubes: u32) -> [u32; 3] {
    let max_cube_counts = [max.0, max.1, max.2];
    let mut num_cubes = [num_cubes, 1, 1];
    let base = 2;

    let mut reduce_count = |i: usize| {
        if num_cubes[i] <= max_cube_counts[i] {
            return true;
        }

        // Halve the current dimension and double the next one until the current dimension fits
        // within its maximum.
        loop {
            num_cubes[i] = num_cubes[i].div_ceil(base);
            num_cubes[i + 1] *= base;

            if num_cubes[i] <= max_cube_counts[i] {
                return false;
            }
        }
    };

    for i in 0..2 {
        if reduce_count(i) {
            break;
        }
    }

    num_cubes
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test_log::test]
    fn safe_num_cubes_even() {
        let max = (32, 32, 32);
        let required = 2048;

        let actual = cube_count_spread(&max, required);
        let expected = [32, 32, 2];
        assert_eq!(actual, expected);
    }
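
    // A small additional check (a sketch, not from the original source): when the requested
    // count already fits in the x dimension, the spread should leave it untouched.
    #[test_log::test]
    fn safe_num_cubes_already_fits() {
        let max = (32, 32, 32);
        let required = 16;

        let actual = cube_count_spread(&max, required);
        let expected = [16, 1, 1];
        assert_eq!(actual, expected);
    }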

    #[test_log::test]
    fn safe_num_cubes_odd() {
        let max = (48, 32, 16);
        let required = 3177;

        let actual = cube_count_spread(&max, required);
        let expected = [25, 32, 4];
        assert_eq!(actual, expected);
    }
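
    // Additional sketches (not from the original source) exercising the `CubeDim` helpers and
    // the plane-count heuristic; the expected values follow directly from the code above.
    #[test_log::test]
    fn cube_dim_num_elems_and_can_contain() {
        let dim = CubeDim::new_2d(32, 4);

        assert_eq!(dim.num_elems(), 128);
        assert!(dim.can_contain(CubeDim::new_2d(32, 2)));
        assert!(!dim.can_contain(CubeDim::new_3d(32, 4, 2)));
    }

    #[test_log::test]
    fn plane_count_per_cube_heuristic() {
        // Without a CPU core count, the plane count is the largest power of two of planes
        // covered by the working units, capped at 8 planes.
        assert_eq!(CubeDim::calculate_plane_count_per_cube(256, 32, None), 8);
        assert_eq!(CubeDim::calculate_plane_count_per_cube(64, 32, None), 2);
        assert_eq!(CubeDim::calculate_plane_count_per_cube(16, 32, None), 1);
        // With a known CPU core count, the plane count is capped by the number of cores.
        assert_eq!(
            CubeDim::calculate_plane_count_per_cube(1000, 32, Some(4)),
            4
        );
    }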
}