// burn_fusion/ops/binary.rs

1use burn_ir::{BinaryOpIr, TensorIr};
2
3#[derive(Debug)]
4pub enum BinaryOpError {
5    #[allow(dead_code)]
6    /// Binary op data type mismatch.
7    DTypeMismatch {
8        lhs: burn_tensor::DType,
9        rhs: burn_tensor::DType,
10    },
11}
12
13// Until we have floating point type promotion, check that lhs and rhs dtypes are the same.
14pub(crate) fn check_binary_op(desc: BinaryOpIr) -> Result<BinaryOpIr, BinaryOpError> {
15    check_binary_op_types(&desc.lhs, &desc.rhs)?;
16    Ok(desc)
17}
18
19pub(crate) fn check_binary_op_types(lhs: &TensorIr, rhs: &TensorIr) -> Result<(), BinaryOpError> {
20    if lhs.dtype != rhs.dtype {
21        Err(BinaryOpError::DTypeMismatch {
22            lhs: lhs.dtype,
23            rhs: rhs.dtype,
24        })
25    } else {
26        Ok(())
27    }
28}
29
/// Declares a float binary operation struct `$name` whose result is a float tensor.
///
/// The generated `new` runs `check_binary_op` on the description and panics if the operand
/// dtypes differ. `execute` fetches both float operands from the handle container, applies
/// `$ops(lhs, rhs)`, and registers the output as a float tensor under `desc.out.id`.
///
/// The expansion names `FusionBackend`, `Operation`, `HandleContainer`, `BinaryOpIr` and
/// `PhantomData` unqualified, so they must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_ops {
    (
        $name:ident,
        $ops:expr
    ) => {
        #[derive(Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> $name<B> {
            fn new(desc: BinaryOpIr) -> Self {
                Self {
                    // `expect` is a method, so unlike `panic!` it is unaffected by
                    // `local_inner_macros`; it also appends the error's `Debug` output.
                    desc: $crate::ops::binary::check_binary_op(desc)
                        .expect("binary float op: lhs and rhs dtypes must match"),
                    _b: PhantomData,
                }
            }
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
                let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);
                let output = $ops(lhs, rhs);

                handles.register_float_tensor::<B>(&self.desc.out.id, output);
            }
        }
    };
}

/// Declares a float comparison operation struct `$name` whose result is a bool tensor.
///
/// There is no floating point type promotion yet, so the generated `new` runs
/// `check_binary_op` on the description and panics if the operand dtypes differ. Previously
/// this macro used `#[derive(new)]` and silently accepted mismatched dtypes. `execute` fetches
/// both float operands, applies `$ops(lhs, rhs)`, and registers the output as a bool tensor
/// under `desc.out.id`.
///
/// The expansion names `FusionBackend`, `Operation`, `HandleContainer`, `BinaryOpIr` and
/// `PhantomData` unqualified, so they must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_cmp_ops {
    (
        $name:ident,
        $ops:expr
    ) => {
        #[derive(Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> $name<B> {
            fn new(desc: BinaryOpIr) -> Self {
                Self {
                    // Same validation as the other float binary ops; `expect` (a method) is
                    // safe under `local_inner_macros`, whereas `panic!` would not resolve.
                    desc: $crate::ops::binary::check_binary_op(desc)
                        .expect("binary float comparison: lhs and rhs dtypes must match"),
                    _b: PhantomData,
                }
            }
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
                let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);
                let output = $ops(lhs, rhs);

                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
            }
        }
    };
}

/// Declares an integer comparison operation struct `$name` whose result is a bool tensor.
///
/// The generated `new` runs `check_binary_op` on the description and panics if the operand
/// dtypes differ. `execute` fetches both int operands from the handle container, applies
/// `$ops(lhs, rhs)`, and registers the output as a bool tensor under `desc.out.id`.
///
/// The expansion names `FusionBackend`, `Operation`, `HandleContainer`, `BinaryOpIr` and
/// `PhantomData` unqualified, so they must be in scope at the invocation site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_cmp_ops {
    (
        $name:ident,
        $ops:expr
    ) => {
        #[derive(Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> $name<B> {
            fn new(desc: BinaryOpIr) -> Self {
                Self {
                    // `expect` is a method, so unlike `panic!` it is unaffected by
                    // `local_inner_macros`; it also appends the error's `Debug` output.
                    desc: $crate::ops::binary::check_binary_op(desc)
                        .expect("binary int comparison: lhs and rhs dtypes must match"),
                    _b: PhantomData,
                }
            }
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);
                let rhs = handles.get_int_tensor::<B>(&self.desc.rhs);
                let output = $ops(lhs, rhs);

                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
            }
        }
    };
}

/// Declares an integer binary operation struct `$name` whose result is an int tensor.
///
/// The constructor comes from `#[derive(new)]` and performs no dtype validation. `execute`
/// fetches both int operands (left first), applies `$ops(lhs, rhs)`, and registers the output
/// as an int tensor under `desc.out.id`.
///
/// The expansion names `FusionBackend`, `Operation`, `HandleContainer`, `BinaryOpIr`,
/// `PhantomData` and the `new` derive unqualified, so they must be in scope where it is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_ops {
    ($name:ident, $ops:expr) => {
        #[derive(Debug, new)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let desc = &self.desc;
                // Operands are fetched left to right, matching the argument order of `$ops`.
                let (lhs, rhs) = (
                    handles.get_int_tensor::<B>(&desc.lhs),
                    handles.get_int_tensor::<B>(&desc.rhs),
                );
                let output = $ops(lhs, rhs);
                handles.register_int_tensor::<B>(&desc.out.id, output);
            }
        }
    };
}