cubecl_ir/
plane.rs

1use core::fmt::Display;
2
3use crate::OperationReflect;
4
5use super::{BinaryOperator, UnaryOperator};
6use crate::TypeHash;
7
/// All plane operations.
///
/// Note that not all backends support plane (warp/subgroup) operations. Use the [runtime flag](crate::Feature::Plane).
///
/// Every variant except [`Plane::Elect`] carries its operands: a
/// [`UnaryOperator`] (one `input`) or, for [`Plane::Broadcast`], a
/// [`BinaryOperator`] (`lhs` and `rhs`). The [`Display`] implementation renders
/// each variant as a function-call style string followed by a newline.
// NOTE(review): the per-variant docs below describe only the operand shape and
// the textual form, which are what this file establishes. The precise
// semantics of each operation (e.g. which unit is elected, what `rhs` selects
// in a broadcast) are defined by the backends — confirm there before relying
// on the variant names alone.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
#[operation(opcode_name = PlaneOpCode)]
#[allow(dead_code, missing_docs)] // Some variants might not be used with different flags
pub enum Plane {
    /// Election operation; takes no operands. Formatted as `plane_elect()`.
    Elect,
    /// "All" operation over a single input. Formatted as `plane_all(<input>)`.
    All(UnaryOperator),
    /// "Any" operation over a single input. Formatted as `plane_any(<input>)`.
    Any(UnaryOperator),
    /// Ballot operation over a single input. Formatted as `plane_ballot(<input>)`.
    Ballot(UnaryOperator),
    /// Broadcast operation with two operands. Formatted as
    /// `plane_broadcast(<lhs>, <rhs>)`.
    Broadcast(BinaryOperator),
    /// Sum over a single input. Formatted as `plane_sum(<input>)`.
    Sum(UnaryOperator),
    /// Inclusive sum over a single input. Formatted as
    /// `plane_inclusive_sum(<input>)`.
    InclusiveSum(UnaryOperator),
    /// Exclusive sum over a single input. Formatted as
    /// `plane_exclusive_sum(<input>)`.
    ExclusiveSum(UnaryOperator),
    /// Product over a single input. Formatted as `plane_product(<input>)`.
    Prod(UnaryOperator),
    /// Inclusive product over a single input. Formatted as
    /// `plane_inclusive_product(<input>)`.
    InclusiveProd(UnaryOperator),
    /// Exclusive product over a single input. Formatted as
    /// `plane_exclusive_product(<input>)`.
    ExclusiveProd(UnaryOperator),
    /// Minimum over a single input. Formatted as `plane_min(<input>)`.
    Min(UnaryOperator),
    /// Maximum over a single input. Formatted as `plane_max(<input>)`.
    Max(UnaryOperator),
}
30
31impl Display for Plane {
32    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
33        match self {
34            Plane::Elect => writeln!(f, "plane_elect()"),
35            Plane::All(op) => writeln!(f, "plane_all({})", op.input),
36            Plane::Any(op) => writeln!(f, "plane_any({})", op.input),
37            Plane::Ballot(op) => writeln!(f, "plane_ballot({})", op.input),
38            Plane::Broadcast(op) => {
39                writeln!(f, "plane_broadcast({}, {})", op.lhs, op.rhs)
40            }
41            Plane::Sum(op) => writeln!(f, "plane_sum({})", op.input),
42            Plane::InclusiveSum(op) => writeln!(f, "plane_inclusive_sum({})", op.input),
43            Plane::ExclusiveSum(op) => writeln!(f, "plane_exclusive_sum({})", op.input),
44            Plane::Prod(op) => writeln!(f, "plane_product({})", op.input),
45            Plane::InclusiveProd(op) => writeln!(f, "plane_inclusive_product({})", op.input),
46            Plane::ExclusiveProd(op) => writeln!(f, "plane_exclusive_product({})", op.input),
47            Plane::Min(op) => writeln!(f, "plane_min({})", op.input),
48            Plane::Max(op) => writeln!(f, "plane_max({})", op.input),
49        }
50    }
51}