1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
use crate::{
    codegen::{execute_static, StaticHandle, WorkgroupLaunch},
    element::WgpuElement,
    tensor::WgpuTensor,
};
use burn_tensor::Shape;

/// Creates a binary element-wise kernel.
///
/// Two forms are accepted:
///
/// * `binary!(operator: OP, input: LHS; RHS, elem: E)` declares the kernel types below
///   (instantiated with `E` as both input and output element type). It then launches them on
///   the tensors `LHS` and `RHS` through `crate::kernel::binary`, with in-place execution
///   enabled. The expansion names a const generic `D` (the tensor rank), which must be in
///   scope at the call site.
/// * `binary!(operator: OP, elem_in: I, elem_out: O)` only declares the three kernel types
///   `Ops<I, O>`, `OpsInplaceLhs<I, O>` and `OpsInplaceRhs<I, O>`, each implementing
///   `StaticKernelSource`. The `elem_in`/`elem_out` arguments are accepted but not used by
///   the expansion. The element types come from the `I`/`O` generic parameters when the
///   kernels are instantiated.
///
/// `OP` is called with the input element type (`I::elem_type()`). Its result is used as the
/// single operator of the generated shader body, whose local `0` is written out:
/// * `Ops` reads both inputs in the output layout and writes a new output array of
///   element type `O`;
/// * `OpsInplaceLhs` declares input 0 (lhs) read-write and writes the result back into it;
/// * `OpsInplaceRhs` declares input 1 (rhs) read-write and writes the result back into it.
#[macro_export]
macro_rules! binary {
    (
        operator: $ops:expr,
        input: $lhs:expr; $rhs:expr,
        elem: $elem:ty
    ) => {{
        // Path-qualified so the recursive call resolves no matter how the caller
        // brought this macro into scope.
        $crate::binary!(operator: $ops, elem_in: $elem, elem_out: $elem);

        $crate::kernel::binary::<Ops<$elem, $elem>, OpsInplaceLhs<$elem, $elem>, OpsInplaceRhs<$elem, $elem>, $elem, D>(
            $lhs, $rhs, true
        )
    }};

    (
        operator: $ops:expr,
        elem_in: $elem_in:ty,
        elem_out: $elem_out:ty
    ) => {
        pub struct Ops<I, O> {
            _i: core::marker::PhantomData<I>,
            _o: core::marker::PhantomData<O>,
        }
        pub struct OpsInplaceLhs<I, O> {
            _i: core::marker::PhantomData<I>,
            _o: core::marker::PhantomData<O>,
        }
        pub struct OpsInplaceRhs<I, O> {
            _i: core::marker::PhantomData<I>,
            _o: core::marker::PhantomData<O>,
        }

        // `$ops` is usually a closure literal, so calling it in place trips
        // `redundant_closure_call`.
        #[allow(clippy::redundant_closure_call)]
        impl<I, O> $crate::kernel::StaticKernelSource for Ops<I, O>
        where
            I: $crate::element::WgpuElement,
            O: $crate::element::WgpuElement
        {
            fn source() -> $crate::kernel::SourceTemplate {
                // Out-of-place: both inputs are read-only, result goes to a fresh output array.
                let shader = $crate::codegen::ElemWiseKernelCodegen::new()
                    .inputs(&[
                        $crate::codegen::Input::Array {
                            item: $crate::codegen::Item::Scalar(I::elem_type()),
                            visibility: $crate::codegen::Visibility::Read,
                            strategy: $crate::codegen::ReadingStrategy::OutputLayout,
                        },
                        $crate::codegen::Input::Array {
                            item: $crate::codegen::Item::Scalar(I::elem_type()),
                            visibility: $crate::codegen::Visibility::Read,
                            strategy: $crate::codegen::ReadingStrategy::OutputLayout,
                        },
                    ])
                    .body(&[$ops(I::elem_type())])
                    .outputs(&[$crate::codegen::Output::Array {
                        item: $crate::codegen::Item::Scalar(O::elem_type()),
                        local: 0,
                    }])
                    .compile();

                $crate::kernel::SourceTemplate::new(shader.to_string())
            }
        }

        // See the note on `Ops` for why this lint is allowed.
        #[allow(clippy::redundant_closure_call)]
        impl<I, O> $crate::kernel::StaticKernelSource
            for OpsInplaceLhs<I, O>
        where
            I: $crate::element::WgpuElement,
            O: $crate::element::WgpuElement
        {
            fn source() -> $crate::kernel::SourceTemplate {
                // In place on lhs: input 0 is read-write and receives the result.
                let shader = $crate::codegen::ElemWiseKernelCodegen::new()
                    .inputs(&[
                        $crate::codegen::Input::Array {
                            item: $crate::codegen::Item::Scalar(I::elem_type()),
                            visibility: $crate::codegen::Visibility::ReadWrite,
                            strategy: $crate::codegen::ReadingStrategy::Plain,
                        },
                        $crate::codegen::Input::Array {
                            item: $crate::codegen::Item::Scalar(I::elem_type()),
                            visibility: $crate::codegen::Visibility::Read,
                            strategy: $crate::codegen::ReadingStrategy::OutputLayout,
                        },
                    ])
                    .body(&[$ops(I::elem_type())])
                    .outputs(&[$crate::codegen::Output::Input {
                        item: $crate::codegen::Item::Scalar(I::elem_type()),
                        input: 0,
                        local: 0,
                    }])
                    .compile();

                $crate::kernel::SourceTemplate::new(shader.to_string())
            }
        }

        // See the note on `Ops` for why this lint is allowed.
        #[allow(clippy::redundant_closure_call)]
        impl<I, O> $crate::kernel::StaticKernelSource
            for OpsInplaceRhs<I, O>
        where
            I: $crate::element::WgpuElement,
            O: $crate::element::WgpuElement
        {
            fn source() -> $crate::kernel::SourceTemplate {
                // In place on rhs: input 1 is read-write and receives the result.
                let shader = $crate::codegen::ElemWiseKernelCodegen::new()
                    .inputs(&[
                        $crate::codegen::Input::Array {
                            item: $crate::codegen::Item::Scalar(I::elem_type()),
                            visibility: $crate::codegen::Visibility::Read,
                            strategy: $crate::codegen::ReadingStrategy::OutputLayout,
                        },
                        $crate::codegen::Input::Array {
                            item: $crate::codegen::Item::Scalar(I::elem_type()),
                            visibility: $crate::codegen::Visibility::ReadWrite,
                            strategy: $crate::codegen::ReadingStrategy::Plain,
                        },
                    ])
                    .body(&[$ops(I::elem_type())])
                    .outputs(&[$crate::codegen::Output::Input {
                        item: $crate::codegen::Item::Scalar(I::elem_type()),
                        input: 1,
                        local: 0,
                    }])
                    .compile();

                $crate::kernel::SourceTemplate::new(shader.to_string())
            }
        }
    };
}

/// Launch a binary operation.
///
/// Runs an element-wise kernel over `lhs` and `rhs` and returns the resulting tensor.
///
/// When `inplace_enabled` is `true`, the result is written into one of the inputs if
/// possible, so no new buffer is allocated:
/// * if `lhs.can_mut_broadcast(&rhs)`, `KernelInplaceLhs` is launched and `lhs` is returned;
/// * otherwise, if `rhs.can_mut_broadcast(&lhs)`, `KernelInplaceRhs` is launched and `rhs`
///   is returned.
///
/// In every other case a new tensor is allocated on `lhs`'s client and device, and `Kernel`
/// writes into it. The output shape is the per-dimension broadcast of the two input shapes:
/// a dimension of size 1 takes the other operand's size, so a zero-sized dimension
/// broadcast against 1 stays 0. Incompatible sizes (different and neither equal to 1) are
/// not validated here; the larger of the two is used.
///
/// `Kernel`, `KernelInplaceLhs` and `KernelInplaceRhs` are the out-of-place, lhs-in-place
/// and rhs-in-place variants of the same operation, e.g. the `Ops`, `OpsInplaceLhs` and
/// `OpsInplaceRhs` types declared by the `binary!` macro.
pub fn binary<Kernel, KernelInplaceLhs, KernelInplaceRhs, E, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: WgpuTensor<E, D>,
    inplace_enabled: bool,
) -> WgpuTensor<E, D>
where
    Kernel: crate::kernel::StaticKernelSource,
    KernelInplaceLhs: crate::kernel::StaticKernelSource,
    KernelInplaceRhs: crate::kernel::StaticKernelSource,
    E: WgpuElement,
{
    if inplace_enabled && lhs.can_mut_broadcast(&rhs) {
        // No output handles: the kernel writes its result back into input 0 (lhs).
        execute_static::<KernelInplaceLhs, E>(
            &[
                StaticHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
                StaticHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
            ],
            &[],
            None,
            WorkgroupLaunch::Input { pos: 0 },
            // rhs is not returned, so its client can be moved into the launch.
            rhs.client,
        );

        lhs
    } else if inplace_enabled && rhs.can_mut_broadcast(&lhs) {
        // No output handles: the kernel writes its result back into input 1 (rhs).
        execute_static::<KernelInplaceRhs, E>(
            &[
                StaticHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
                StaticHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
            ],
            &[],
            None,
            WorkgroupLaunch::Input { pos: 1 },
            // lhs is not returned, so its client can be moved into the launch.
            lhs.client,
        );

        rhs
    } else {
        let dims_out: [usize; D] = core::array::from_fn(|index| {
            broadcast_dim(lhs.shape.dims[index], rhs.shape.dims[index])
        });

        let shape_out = Shape::new(dims_out);
        let num_elems = shape_out.num_elements();
        let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
        let out = WgpuTensor::new(lhs.client.clone(), lhs.device, shape_out, buffer);

        execute_static::<Kernel, E>(
            &[
                StaticHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
                StaticHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
            ],
            &[StaticHandle::new(
                &out.handle,
                &out.strides,
                &out.shape.dims,
            )],
            None,
            WorkgroupLaunch::Output { pos: 0 },
            lhs.client,
        );

        out
    }
}

/// Broadcasts one pair of dimension sizes: a size-1 dimension takes the other side's size.
fn broadcast_dim(dim_lhs: usize, dim_rhs: usize) -> usize {
    match (dim_lhs, dim_rhs) {
        // Checked explicitly (rather than via `max`) so that `0` vs `1` yields `0`.
        (1, other) | (other, 1) => other,
        // Equal sizes map to themselves; incompatible sizes fall back to the larger one.
        (a, b) => usize::max(a, b),
    }
}