cubecl_reduce/
strategy.rs1use cubecl_core::prelude::*;
2use cubecl_runtime::Plane;
3use serde::{Deserialize, Serialize};
4
5use crate::ReduceError;
6
7#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize)]
8pub struct ReduceStrategy {
9 pub use_planes: bool,
13
14 pub shared: bool,
17}
18
19impl ReduceStrategy {
20 pub fn validate<R: Runtime>(
21 self,
22 client: &ComputeClient<R::Server>,
23 ) -> Result<Self, ReduceError> {
24 if self.use_planes {
25 if !support_plane::<R>(client) {
26 return Err(ReduceError::PlanesUnavailable);
27 }
28 if !precise_plane_dim::<R>(client) {
29 return Err(ReduceError::ImprecisePlaneDim);
30 }
31 }
32
33 Ok(self)
34 }
35
36 pub fn new<R: Runtime>(client: &ComputeClient<R::Server>, shared: bool) -> Self {
37 Self {
38 use_planes: support_plane::<R>(client) && precise_plane_dim::<R>(client),
39 shared,
40 }
41 }
42}
43
44fn support_plane<R: Runtime>(client: &ComputeClient<R::Server>) -> bool {
45 client.properties().features.plane.contains(Plane::Ops)
46}
47
48fn precise_plane_dim<R: Runtime>(client: &ComputeClient<R::Server>) -> bool {
49 let hw_props = &client.properties().hardware;
50 hw_props.plane_size_min == hw_props.plane_size_max
51}