// burn_fusion/ops/binary.rs

/// Defines a fusion operation struct for a binary op on two float tensors that
/// produces a float tensor.
///
/// Usage: `binary_float_ops!(Name, op);` (a trailing comma is also accepted).
///
/// The generated `Name<B: FusionBackend>` stores a `BinaryOpIr` description and
/// implements `Operation<B::FusionRuntime>`. When executed, it reads the `lhs`
/// and `rhs` float tensors from the handle container, calls `op(lhs, rhs)`, and
/// registers the result as a float tensor under `desc.out.id`.
///
/// `FusionBackend`, `BinaryOpIr`, `Operation`, `HandleContainer` and
/// `PhantomData` are referenced unqualified, so they must be in scope at the
/// call site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_ops {
    (
        $name:ident,
        $ops:expr $(,)?
    ) => {
        /// Float -> float binary operation generated by `binary_float_ops!`.
        #[derive(Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> $name<B> {
            // Hand-written constructor: needs no `new` derive macro at the call site.
            fn new(desc: BinaryOpIr) -> Self {
                Self {
                    desc,
                    _b: PhantomData,
                }
            }
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
                let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);
                let output = $ops(lhs, rhs);

                handles.register_float_tensor::<B>(&self.desc.out.id, output);
            }
        }
    };
}

/// Defines a fusion operation struct for a comparison between two float tensors
/// that produces a bool tensor.
///
/// Usage: `binary_float_cmp_ops!(Name, op);` (a trailing comma is also accepted).
///
/// The generated `Name<B: FusionBackend>` stores a `BinaryOpIr` description and
/// implements `Operation<B::FusionRuntime>`. When executed, it reads the `lhs`
/// and `rhs` float tensors from the handle container, calls
/// `op(lhs, rhs, desc.out.dtype.into())`, and registers the result as a bool
/// tensor under `desc.out.id`.
///
/// The constructor comes from `#[derive(new)]` (presumably the `derive_new`
/// crate — confirm), so that derive must be in scope at the call site, along with
/// `FusionBackend`, `BinaryOpIr`, `Operation`, `HandleContainer` and `PhantomData`.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_cmp_ops {
    (
        $name:ident,
        $ops:expr $(,)?
    ) => {
        /// Float -> bool comparison operation generated by `binary_float_cmp_ops!`.
        #[derive(new, Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
                let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);
                // The output dtype is forwarded, converted via `Into` to whatever
                // type the op takes as its third argument.
                let output = $ops(lhs, rhs, self.desc.out.dtype.into());

                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
            }
        }
    };
}

/// Defines a fusion operation struct for a comparison between two int tensors
/// that produces a bool tensor.
///
/// Usage: `binary_int_cmp_ops!(Name, op);` (a trailing comma is also accepted).
///
/// The generated `Name<B: FusionBackend>` stores a `BinaryOpIr` description and
/// implements `Operation<B::FusionRuntime>`. When executed, it reads the `lhs`
/// and `rhs` int tensors from the handle container, calls
/// `op(lhs, rhs, desc.out.dtype.into())`, and registers the result as a bool
/// tensor under `desc.out.id`.
///
/// `FusionBackend`, `BinaryOpIr`, `Operation`, `HandleContainer` and
/// `PhantomData` are referenced unqualified, so they must be in scope at the
/// call site.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_cmp_ops {
    (
        $name:ident,
        $ops:expr $(,)?
    ) => {
        /// Int -> bool comparison operation generated by `binary_int_cmp_ops!`.
        #[derive(Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> $name<B> {
            // Hand-written constructor: needs no `new` derive macro at the call site.
            fn new(desc: BinaryOpIr) -> Self {
                Self {
                    desc,
                    _b: PhantomData,
                }
            }
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);
                let rhs = handles.get_int_tensor::<B>(&self.desc.rhs);
                // The output dtype is forwarded, converted via `Into` to whatever
                // type the op takes as its third argument.
                let output = $ops(lhs, rhs, self.desc.out.dtype.into());

                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
            }
        }
    };
}

/// Defines a fusion operation struct for a binary op on two int tensors that
/// produces an int tensor.
///
/// Usage: `binary_int_ops!(Name, op);` (a trailing comma is also accepted).
///
/// The generated `Name<B: FusionBackend>` stores a `BinaryOpIr` description and
/// implements `Operation<B::FusionRuntime>`. When executed, it reads the `lhs`
/// and `rhs` int tensors from the handle container, calls `op(lhs, rhs)`, and
/// registers the result as an int tensor under `desc.out.id`.
///
/// The constructor comes from `#[derive(new)]` (presumably the `derive_new`
/// crate — confirm), so that derive must be in scope at the call site, along with
/// `FusionBackend`, `BinaryOpIr`, `Operation`, `HandleContainer` and `PhantomData`.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_ops {
    (
        $name:ident,
        $ops:expr $(,)?
    ) => {
        /// Int -> int binary operation generated by `binary_int_ops!`.
        #[derive(new, Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);
                let rhs = handles.get_int_tensor::<B>(&self.desc.rhs);
                let output = $ops(lhs, rhs);

                handles.register_int_tensor::<B>(&self.desc.out.id, output);
            }
        }
    };
}