Skip to main content

burn_fusion/ops/
binary.rs

/// Declares a fused operation `$name` that combines two float tensors into a
/// new float tensor.
///
/// * `$name` — identifier of the generated operation struct.
/// * `$ops`  — callable invoked as `$ops(lhs, rhs)` on the two fetched input
///   tensors; its return value becomes the output tensor.
///
/// The generated struct stores the operation's `BinaryOpIr` description plus a
/// `PhantomData` marker for the backend type `B`. Executing it reads the float
/// tensors referenced by `desc.lhs` and `desc.rhs` from the handle container
/// and registers the result as a float tensor under `desc.out.id`.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_ops {
    ($name:ident, $ops:expr) => {
        #[derive(Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> $name<B> {
            /// Wraps `desc`; the backend marker carries no data.
            fn new(desc: BinaryOpIr) -> Self {
                Self { _b: PhantomData, desc }
            }
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let desc = &self.desc;
                // Fetch the left-hand operand first, then the right-hand one.
                let (lhs, rhs) = (
                    handles.get_float_tensor::<B>(&desc.lhs),
                    handles.get_float_tensor::<B>(&desc.rhs),
                );
                let result = $ops(lhs, rhs);
                handles.register_float_tensor::<B>(&desc.out.id, result);
            }
        }
    };
}
34
/// Declares a fused operation `$name` that compares two float tensors and
/// produces a boolean tensor.
///
/// * `$name` — identifier of the generated operation struct.
/// * `$ops`  — callable invoked as `$ops(lhs, rhs, dtype)`, where `lhs`/`rhs`
///   are the two fetched float input tensors and `dtype` is the output
///   tensor's dtype converted with `Into`; its return value is registered as
///   the output bool tensor.
///
/// When executed, the generated operation reads the float tensors referenced
/// by `desc.lhs` and `desc.rhs` from the handle container and registers the
/// result as a bool tensor under `desc.out.id`.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_cmp_ops {
    (
        $name:ident,
        $ops:expr
    ) => {
        #[derive(Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> $name<B> {
            // Hand-written rather than `#[derive(new)]`: a derive named inside a
            // `macro_rules!` body is resolved at the invocation site, so the
            // derive version silently required `new` to be in scope wherever
            // this exported macro is used.
            fn new(desc: BinaryOpIr) -> Self {
                Self {
                    desc,
                    _b: PhantomData,
                }
            }
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
                let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);
                // The output tensor's dtype is forwarded to the comparison op.
                let output = $ops(lhs, rhs, self.desc.out.dtype.into());

                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
            }
        }
    };
}
59
/// Declares a fused operation `$name` that compares two int tensors and
/// produces a boolean tensor.
///
/// * `$name` — identifier of the generated operation struct.
/// * `$ops`  — callable invoked as `$ops(lhs, rhs, dtype)`, where `lhs`/`rhs`
///   are the two fetched int input tensors and `dtype` is the output tensor's
///   dtype converted with `Into`; its return value is registered as the
///   output bool tensor.
///
/// The generated struct stores the operation's `BinaryOpIr` description plus a
/// `PhantomData` marker for the backend type `B`.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_cmp_ops {
    ($name:ident, $ops:expr) => {
        #[derive(Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> $name<B> {
            /// Wraps `desc`; the backend marker carries no data.
            fn new(desc: BinaryOpIr) -> Self {
                Self { _b: PhantomData, desc }
            }
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let desc = &self.desc;
                let left = handles.get_int_tensor::<B>(&desc.lhs);
                let right = handles.get_int_tensor::<B>(&desc.rhs);
                // The output tensor's dtype is passed through as the third argument.
                let mask = $ops(left, right, desc.out.dtype.into());
                handles.register_bool_tensor::<B>(&desc.out.id, mask);
            }
        }
    };
}
93
/// Declares a fused operation `$name` that combines two int tensors into a new
/// int tensor.
///
/// * `$name` — identifier of the generated operation struct.
/// * `$ops`  — callable invoked as `$ops(lhs, rhs)` on the two fetched input
///   tensors; its return value is registered as the output int tensor.
///
/// When executed, the generated operation reads the int tensors referenced by
/// `desc.lhs` and `desc.rhs` from the handle container and registers the
/// result under `desc.out.id`.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_ops {
    (
        $name:ident,
        $ops:expr
    ) => {
        #[derive(Debug)]
        struct $name<B: FusionBackend> {
            desc: BinaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> $name<B> {
            // Hand-written rather than `#[derive(new)]`: a derive named inside a
            // `macro_rules!` body is resolved at the invocation site, so the
            // derive version silently required `new` to be in scope wherever
            // this exported macro is used.
            fn new(desc: BinaryOpIr) -> Self {
                Self {
                    desc,
                    _b: PhantomData,
                }
            }
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);
                let rhs = handles.get_int_tensor::<B>(&self.desc.rhs);
                let output = $ops(lhs, rhs);

                handles.register_int_tensor::<B>(&self.desc.out.id, output);
            }
        }
    };
}