cubecl_reduce/
strategy.rs

1use cubecl_core::prelude::*;
2use cubecl_runtime::Plane;
3use serde::{Deserialize, Serialize};
4
5use crate::ReduceError;
6
7#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize)]
8pub struct ReduceStrategy {
9    /// If true and the compute client support plane instructions,
10    /// then try using them in the kernel. It could still be impossible to use
11    /// plane instructions depending on the memory layout of the tensors.
12    pub use_planes: bool,
13
14    /// If true, all units within a single cube cooperate to reduce a single item in the output.
15    /// Else, each unit or plane (if planes is true) reduce a single item by itself.
16    pub shared: bool,
17}
18
19impl ReduceStrategy {
20    pub fn validate<R: Runtime>(
21        self,
22        client: &ComputeClient<R::Server>,
23    ) -> Result<Self, ReduceError> {
24        if self.use_planes {
25            if !support_plane::<R>(client) {
26                return Err(ReduceError::PlanesUnavailable);
27            }
28            if !precise_plane_dim::<R>(client) {
29                return Err(ReduceError::ImprecisePlaneDim);
30            }
31        }
32
33        Ok(self)
34    }
35
36    pub fn new<R: Runtime>(client: &ComputeClient<R::Server>, shared: bool) -> Self {
37        Self {
38            use_planes: support_plane::<R>(client) && precise_plane_dim::<R>(client),
39            shared,
40        }
41    }
42}
43
44fn support_plane<R: Runtime>(client: &ComputeClient<R::Server>) -> bool {
45    client.properties().features.plane.contains(Plane::Ops)
46}
47
48fn precise_plane_dim<R: Runtime>(client: &ComputeClient<R::Server>) -> bool {
49    let hw_props = &client.properties().hardware;
50    hw_props.plane_size_min == hw_props.plane_size_max
51}