cubecl_reduce/
strategy.rs

1use cubecl_core::{Feature, prelude::*};
2use serde::{Deserialize, Serialize};
3
4use crate::ReduceError;
5
6#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize)]
7pub struct ReduceStrategy {
8    /// If true and the compute client support plane instructions,
9    /// then try using them in the kernel. It could still be impossible to use
10    /// plane instructions depending on the memory layout of the tensors.
11    pub use_planes: bool,
12
13    /// If true, all units within a single cube cooperate to reduce a single item in the output.
14    /// Else, each unit or plane (if planes is true) reduce a single item by itself.
15    pub shared: bool,
16}
17
18impl ReduceStrategy {
19    pub fn validate<R: Runtime>(
20        self,
21        client: &ComputeClient<R::Server, R::Channel>,
22    ) -> Result<Self, ReduceError> {
23        if self.use_planes {
24            if !support_plane::<R>(client) {
25                return Err(ReduceError::PlanesUnavailable);
26            }
27            if !precise_plane_dim::<R>(client) {
28                return Err(ReduceError::ImprecisePlaneDim);
29            }
30        }
31
32        Ok(self)
33    }
34
35    pub fn new<R: Runtime>(client: &ComputeClient<R::Server, R::Channel>, shared: bool) -> Self {
36        Self {
37            use_planes: support_plane::<R>(client) && precise_plane_dim::<R>(client),
38            shared,
39        }
40    }
41}
42
43fn support_plane<R: Runtime>(client: &ComputeClient<R::Server, R::Channel>) -> bool {
44    client.properties().feature_enabled(Feature::Plane)
45}
46
47fn precise_plane_dim<R: Runtime>(client: &ComputeClient<R::Server, R::Channel>) -> bool {
48    let hw_props = &client.properties().hardware;
49    hw_props.plane_size_min == hw_props.plane_size_max
50}