cubecl_matmul/components/
resource.rs

1use cubecl_core::prelude::*;
2
3use crate::components::InvalidConfigError;
4
/// Number of compute primitives required by some component, specified as either units or planes.
///
/// A plane is a group of `plane_dim` units; the conversion methods on this type
/// take `plane_dim` as an argument to translate a unit count into a plane count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComputeResources {
    /// A number of individual units.
    Units(u32),
    /// A number of whole planes.
    Planes(u32),
}
10
11impl ComputeResources {
12    /// Ensures [ComputeResources] is Planes variant, converting
13    /// units using plane_dim, the number of units in a plane.
14    ///
15    /// Will fail if the number of units does not correspond to an exact number of planes
16    pub fn as_plane_resources(self, plane_dim: u32) -> Result<Self, InvalidConfigError> {
17        match self {
18            ComputeResources::Units(units) => {
19                if units % plane_dim == 0 {
20                    Ok(ComputeResources::Planes(units / plane_dim))
21                } else {
22                    Err(Box::new(format!(
23                        "Number of units {units:?} should be divisible by plane_dim {plane_dim:?}"
24                    )))
25                }
26            }
27            ComputeResources::Planes(_) => Ok(self),
28        }
29    }
30
31    /// Make a [CubeDim] from specified resources.
32    ///
33    /// Obtained CubeDim is always (plane_dim, number_of_planes, 1)
34    ///
35    /// Will fail if the number of units does not correspond to an exact number of planes
36    pub fn to_cube_dim(self, plane_dim: u32) -> Result<CubeDim, InvalidConfigError> {
37        match self {
38            ComputeResources::Units(_) => {
39                self.as_plane_resources(plane_dim)?.to_cube_dim(plane_dim)
40            }
41            ComputeResources::Planes(num_planes) => Ok(CubeDim::new_2d(plane_dim, num_planes)),
42        }
43    }
44
45    /// Get the number of planes
46    ///
47    /// Will fail if the number of units does not correspond to an exact number of planes
48    pub(crate) fn num_planes(self, plane_dim: u32) -> Result<u32, InvalidConfigError> {
49        let plane_resources = self.as_plane_resources(plane_dim)?;
50        if let ComputeResources::Planes(num_planes) = plane_resources {
51            Ok(num_planes)
52        } else {
53            unreachable!()
54        }
55    }
56}