cubecl_reduce/
strategy.rs1use cubecl_core::{Feature, prelude::*};
2use serde::{Deserialize, Serialize};
3
4use crate::ReduceError;
5
6#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize)]
7pub struct ReduceStrategy {
8 pub use_planes: bool,
12
13 pub shared: bool,
16}
17
18impl ReduceStrategy {
19 pub fn validate<R: Runtime>(
20 self,
21 client: &ComputeClient<R::Server, R::Channel>,
22 ) -> Result<Self, ReduceError> {
23 if self.use_planes {
24 if !support_plane::<R>(client) {
25 return Err(ReduceError::PlanesUnavailable);
26 }
27 if !precise_plane_dim::<R>(client) {
28 return Err(ReduceError::ImprecisePlaneDim);
29 }
30 }
31
32 Ok(self)
33 }
34
35 pub fn new<R: Runtime>(client: &ComputeClient<R::Server, R::Channel>, shared: bool) -> Self {
36 Self {
37 use_planes: support_plane::<R>(client) && precise_plane_dim::<R>(client),
38 shared,
39 }
40 }
41}
42
43fn support_plane<R: Runtime>(client: &ComputeClient<R::Server, R::Channel>) -> bool {
44 client.properties().feature_enabled(Feature::Plane)
45}
46
47fn precise_plane_dim<R: Runtime>(client: &ComputeClient<R::Server, R::Channel>) -> bool {
48 let hw_props = &client.properties().hardware;
49 hw_props.plane_size_min == hw_props.plane_size_max
50}