burn_fusion/ops/
binary.rs1use burn_ir::{BinaryOpIr, TensorIr};
2
3#[derive(Debug)]
4pub enum BinaryOpError {
5 #[allow(dead_code)]
6 DTypeMismatch {
8 lhs: burn_tensor::DType,
9 rhs: burn_tensor::DType,
10 },
11}
12
13pub(crate) fn check_binary_op(desc: BinaryOpIr) -> Result<BinaryOpIr, BinaryOpError> {
15 check_binary_op_types(&desc.lhs, &desc.rhs)?;
16 Ok(desc)
17}
18
19pub(crate) fn check_binary_op_types(lhs: &TensorIr, rhs: &TensorIr) -> Result<(), BinaryOpError> {
20 if lhs.dtype != rhs.dtype {
21 Err(BinaryOpError::DTypeMismatch {
22 lhs: lhs.dtype,
23 rhs: rhs.dtype,
24 })
25 } else {
26 Ok(())
27 }
28}
29
/// Declares a fusion operation struct `$name` that applies the binary function
/// `$ops` to two float tensors and registers its output as a float tensor.
///
/// The generated `new` constructor runs the description through
/// `check_binary_op` and panics if the two operands have different dtypes.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_ops {
    (
        $name:ident,
        $ops:expr
    ) => {
        #[derive(Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> $name<B> {
            fn new(desc: BinaryOpIr) -> Self {
                Self {
                    // Mismatched dtypes panic here; `expect` states the violated
                    // invariant in the panic message. Only method calls are used:
                    // `local_inner_macros` would rewrite bare macro calls such as
                    // `panic!` into `$crate::panic!`.
                    desc: $crate::ops::binary::check_binary_op(desc)
                        .expect("binary float op: lhs and rhs must have the same dtype"),
                    _b: PhantomData,
                }
            }
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
                let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);
                let output = $ops(lhs, rhs);

                handles.register_float_tensor::<B>(&self.desc.out.id, output);
            }
        }
    };
}

/// Declares a fusion operation struct `$name` that compares two float tensors
/// with `$ops` and registers the outcome as a bool tensor.
///
/// The constructor is generated by `derive(new)`; operand dtypes are not
/// validated by this macro.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_cmp_ops {
    (
        $name:ident,
        $ops:expr
    ) => {
        // NOTE(review): unlike `binary_float_ops`, this macro never calls
        // `check_binary_op`, so mismatched operand dtypes are not rejected at
        // construction time — confirm whether that is intentional.
        #[derive(new, Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let desc = &self.desc;
                // Arguments are evaluated left to right: lhs is fetched before rhs.
                let mask = $ops(
                    handles.get_float_tensor::<B>(&desc.lhs),
                    handles.get_float_tensor::<B>(&desc.rhs),
                );
                handles.register_bool_tensor::<B>(&desc.out.id, mask);
            }
        }
    };
}

/// Declares a fusion operation struct `$name` that compares two int tensors
/// with `$ops` and registers the outcome as a bool tensor.
///
/// The generated `new` constructor runs the description through
/// `check_binary_op` and panics if the two operands have different dtypes.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_cmp_ops {
    (
        $name:ident,
        $ops:expr
    ) => {
        #[derive(Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> $name<B> {
            fn new(desc: BinaryOpIr) -> Self {
                Self {
                    // Mismatched dtypes panic here; `expect` states the violated
                    // invariant in the panic message. Only method calls are used:
                    // `local_inner_macros` would rewrite bare macro calls such as
                    // `panic!` into `$crate::panic!`.
                    desc: $crate::ops::binary::check_binary_op(desc)
                        .expect("binary int comparison: lhs and rhs must have the same dtype"),
                    _b: PhantomData,
                }
            }
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);
                let rhs = handles.get_int_tensor::<B>(&self.desc.rhs);
                let output = $ops(lhs, rhs);

                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
            }
        }
    };
}

/// Declares a fusion operation struct `$name` that applies the binary function
/// `$ops` to two int tensors and registers its output as an int tensor.
///
/// The constructor is generated by `derive(new)`; operand dtypes are not
/// validated by this macro.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_ops {
    (
        $name:ident,
        $ops:expr
    ) => {
        // NOTE(review): unlike `binary_int_cmp_ops`, this macro never calls
        // `check_binary_op`, so mismatched operand dtypes are not rejected at
        // construction time — confirm whether that is intentional.
        #[derive(new, Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let desc = &self.desc;
                // Arguments are evaluated left to right: lhs is fetched before rhs.
                let result = $ops(
                    handles.get_int_tensor::<B>(&desc.lhs),
                    handles.get_int_tensor::<B>(&desc.rhs),
                );
                handles.register_int_tensor::<B>(&desc.out.id, result);
            }
        }
    };
}