Skip to main content

cubecl_ir/
plane.rs

1use core::fmt::Display;
2
3use crate::OperationReflect;
4
5use super::{BinaryOperator, UnaryOperator};
6use crate::TypeHash;
7
/// All plane operations.
///
/// Note that not all backends support plane (warp/subgroup) operations.
/// Use the [`crate::features::Features::plane`] flag to enable them.
///
/// Operands are carried either as a [`UnaryOperator`] (a single `input`) or as a
/// [`BinaryOperator`] (`lhs` and `rhs`); [`Plane::Elect`] takes no operand.
/// The `Display` impl renders every variant as a call-like expression, e.g.
/// `plane_sum(x)` or `plane_shuffle(x, y)`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
// NOTE(review): presumably `OperationReflect` derives a `PlaneOpCode` enum from these
// variants, so their order may be significant — confirm before reordering or inserting.
#[operation(opcode_name = PlaneOpCode)]
#[allow(dead_code, missing_docs)] // Some variants might not be used with different flags
pub enum Plane {
    /// Operand-free operation, rendered as `plane_elect()`.
    /// Presumably yields `true` on a single active unit of the plane (unconfirmed).
    Elect,
    /// Rendered as `plane_all(input)`.
    /// Presumably `true` iff `input` is true on every active unit (unconfirmed).
    All(UnaryOperator),
    /// Rendered as `plane_any(input)`.
    /// Presumably `true` iff `input` is true on at least one active unit (unconfirmed).
    Any(UnaryOperator),
    /// Rendered as `plane_ballot(input)`.
    /// Presumably packs each unit's `input` into a bit mask (unconfirmed).
    Ballot(UnaryOperator),
    /// Rendered as `plane_broadcast(lhs, rhs)`.
    /// Presumably `lhs` is the value and `rhs` the source unit index (unconfirmed).
    Broadcast(BinaryOperator),
    /// Rendered as `plane_shuffle(lhs, rhs)`.
    /// Presumably reads `lhs` from the unit whose index is `rhs` (unconfirmed).
    Shuffle(BinaryOperator),
    /// Rendered as `plane_shuffle_xor(lhs, rhs)`.
    /// Presumably the source unit index is `own_index ^ rhs` (unconfirmed).
    ShuffleXor(BinaryOperator),
    /// Rendered as `plane_shuffle_up(lhs, rhs)`.
    /// Presumably the source unit index is `own_index - rhs` (unconfirmed).
    ShuffleUp(BinaryOperator),
    /// Rendered as `plane_shuffle_down(lhs, rhs)`.
    /// Presumably the source unit index is `own_index + rhs` (unconfirmed).
    ShuffleDown(BinaryOperator),
    /// Rendered as `plane_sum(input)`; presumably a plane-wide sum reduction.
    Sum(UnaryOperator),
    /// Rendered as `plane_inclusive_sum(input)`; presumably an inclusive prefix sum.
    InclusiveSum(UnaryOperator),
    /// Rendered as `plane_exclusive_sum(input)`; presumably an exclusive prefix sum.
    ExclusiveSum(UnaryOperator),
    /// Rendered as `plane_product(input)`; presumably a plane-wide product reduction.
    Prod(UnaryOperator),
    /// Rendered as `plane_inclusive_product(input)`; presumably an inclusive prefix product.
    InclusiveProd(UnaryOperator),
    /// Rendered as `plane_exclusive_product(input)`; presumably an exclusive prefix product.
    ExclusiveProd(UnaryOperator),
    /// Rendered as `plane_min(input)`; presumably a plane-wide minimum reduction.
    Min(UnaryOperator),
    /// Rendered as `plane_max(input)`; presumably a plane-wide maximum reduction.
    Max(UnaryOperator),
}
35
36impl Display for Plane {
37    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
38        match self {
39            Plane::Elect => write!(f, "plane_elect()"),
40            Plane::All(op) => write!(f, "plane_all({})", op.input),
41            Plane::Any(op) => write!(f, "plane_any({})", op.input),
42            Plane::Ballot(op) => write!(f, "plane_ballot({})", op.input),
43            Plane::Broadcast(op) => {
44                write!(f, "plane_broadcast({}, {})", op.lhs, op.rhs)
45            }
46            Plane::Shuffle(op) => {
47                write!(f, "plane_shuffle({}, {})", op.lhs, op.rhs)
48            }
49            Plane::ShuffleXor(op) => {
50                write!(f, "plane_shuffle_xor({}, {})", op.lhs, op.rhs)
51            }
52            Plane::ShuffleUp(op) => {
53                write!(f, "plane_shuffle_up({}, {})", op.lhs, op.rhs)
54            }
55            Plane::ShuffleDown(op) => {
56                write!(f, "plane_shuffle_down({}, {})", op.lhs, op.rhs)
57            }
58            Plane::Sum(op) => write!(f, "plane_sum({})", op.input),
59            Plane::InclusiveSum(op) => write!(f, "plane_inclusive_sum({})", op.input),
60            Plane::ExclusiveSum(op) => write!(f, "plane_exclusive_sum({})", op.input),
61            Plane::Prod(op) => write!(f, "plane_product({})", op.input),
62            Plane::InclusiveProd(op) => write!(f, "plane_inclusive_product({})", op.input),
63            Plane::ExclusiveProd(op) => write!(f, "plane_exclusive_product({})", op.input),
64            Plane::Min(op) => write!(f, "plane_min({})", op.input),
65            Plane::Max(op) => write!(f, "plane_max({})", op.input),
66        }
67    }
68}