1use ndarray::{Array4, ArrayView4};
4use ocl::{flags, prm::Uint3, Buffer, Kernel};
5
6use std::{borrow::Cow, convert::TryFrom};
7
8use crate::{
9 base::Base,
10 params::{OutputParams, WithParams},
11 ConvElement, Params,
12};
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
16pub struct FeatureMapShape {
17 pub batch_size: u32,
19 pub width: u32,
21 pub height: u32,
23 pub channels: u32,
25}
26
27impl FeatureMapShape {
28 fn from_nhwc_slice(shape: &[usize]) -> Self {
29 assert_eq!(shape.len(), 4);
30 FeatureMapShape {
31 batch_size: u32::try_from(shape[0]).expect("Cannot convert batch size to `u32`"),
32 height: u32::try_from(shape[1]).expect("Cannot convert height to `u32`"),
33 width: u32::try_from(shape[2]).expect("Cannot convert width to `u32`"),
34 channels: u32::try_from(shape[3]).expect("Cannot convert channel count to `u32`"),
35 }
36 }
37
38 fn from_nchw_slice(shape: &[usize]) -> Self {
39 assert_eq!(shape.len(), 4);
40 FeatureMapShape {
41 batch_size: u32::try_from(shape[0]).expect("Cannot convert batch size to `u32`"),
42 height: u32::try_from(shape[2]).expect("Cannot convert height to `u32`"),
43 width: u32::try_from(shape[3]).expect("Cannot convert width to `u32`"),
44 channels: u32::try_from(shape[1]).expect("Cannot convert channel count to `u32`"),
45 }
46 }
47
48 fn buffer_len(self) -> usize {
49 self.batch_size as usize
50 * self.width as usize
51 * self.height as usize
52 * self.channels as usize
53 }
54
55 fn as_array(self, layout: Layout) -> [usize; 4] {
56 match layout {
57 Layout::ChannelsFirst => [
58 self.batch_size as usize,
59 self.channels as usize,
60 self.height as usize,
61 self.width as usize,
62 ],
63 Layout::ChannelsLast => [
64 self.batch_size as usize,
65 self.height as usize,
66 self.width as usize,
67 self.channels as usize,
68 ],
69 }
70 }
71}
72
/// Order of axes of a 4-dimensional feature map in memory.
///
/// The enum is `repr(u8)` with explicit discriminants. A `Layout` value is forwarded to the
/// OpenCL kernel as part of `OutputParams`, so the numeric values are presumably relied upon
/// device-side. NOTE(review): keep them stable unless the kernel is updated too.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(u8)]
pub enum Layout {
    /// Channels precede spatial dimensions: `[batch, channels, height, width]` (NCHW).
    ChannelsFirst = 0,
    /// Channels follow spatial dimensions: `[batch, height, width, channels]` (NHWC).
    ChannelsLast = 1,
}
86
87#[derive(Debug, Clone, Copy, PartialEq)]
92pub struct FeatureMap<'a, T> {
93 layout: Layout,
94 inner: ArrayView4<'a, T>,
95 shape: FeatureMapShape,
96}
97
98impl<'a, T: ConvElement> FeatureMap<'a, T> {
99 pub fn nchw(array: impl Into<ArrayView4<'a, T>>) -> Self {
105 let array = array.into();
106 Self {
107 layout: Layout::ChannelsFirst,
108 shape: FeatureMapShape::from_nchw_slice(array.shape()),
109 inner: array,
110 }
111 }
112
113 pub fn nhwc(array: impl Into<ArrayView4<'a, T>>) -> Self {
119 let array = array.into();
120 Self {
121 layout: Layout::ChannelsLast,
122 shape: FeatureMapShape::from_nhwc_slice(array.shape()),
123 inner: array,
124 }
125 }
126
127 pub fn layout(self) -> Layout {
129 self.layout
130 }
131
132 pub fn shape(self) -> FeatureMapShape {
134 self.shape
135 }
136
137 fn to_nhwc(self) -> ArrayView4<'a, T> {
138 match self.layout {
139 Layout::ChannelsFirst => self.inner.permuted_axes([0, 2, 3, 1]),
140 Layout::ChannelsLast => self.inner,
141 }
142 }
143}
144
145#[derive(Debug, Clone)]
147pub(crate) struct Filters<T: ConvElement> {
148 inner: Buffer<T>,
149 biases: Option<Buffer<T::Acc>>,
150 filter_count: u32,
151 channel_count: u32,
152}
153
154impl<T: ConvElement> Filters<T> {
155 pub fn filter_count(&self) -> u32 {
156 self.filter_count
157 }
158
159 pub fn channel_count(&self) -> u32 {
160 self.channel_count
161 }
162
163 pub fn new<U: WithParams>(
164 filters: ArrayView4<'_, T>,
165 biases: Option<&[T::Acc]>,
166 conv: &Base<U>,
167 ) -> ocl::Result<Self> {
168 assert!(
169 filters.shape()[1] == conv.size() as usize
170 && filters.shape()[2] == conv.size() as usize,
171 "Invalid filter shape: expected {0}x{0}, got {1}x{2}",
172 conv.size(),
173 filters.shape()[1],
174 filters.shape()[2]
175 );
176 if let Some(biases) = biases {
177 assert_eq!(
178 filters.shape()[0],
179 biases.len(),
180 "Number of filter biases does not agree with the number of filters"
181 );
182 }
183
184 let filters_slice = filters.as_slice().map_or_else(
185 || Cow::Owned(filters.iter().copied().collect()),
186 Cow::Borrowed,
187 );
188 let filters_buffer = Buffer::builder()
189 .queue(conv.queue().clone())
190 .len(filters.shape().iter().product::<usize>())
191 .flags(flags::MEM_READ_ONLY)
192 .copy_host_slice(filters_slice.as_ref())
193 .build()?;
194
195 let filter_biases = biases
196 .map(|biases| {
197 Buffer::builder()
198 .queue(conv.queue().clone())
199 .len(biases.len())
200 .flags(flags::MEM_READ_ONLY)
201 .copy_host_slice(biases)
202 .build()
203 })
204 .transpose()?;
205
206 conv.kernel().set_arg("filters", &filters_buffer)?;
207 conv.kernel()
208 .set_arg("filter_biases", filter_biases.as_ref())?;
209
210 Ok(Self {
211 inner: filters_buffer,
212 biases: filter_biases,
213 filter_count: u32::try_from(filters.shape()[0])
214 .expect("Cannot convert filter count to `u32`"),
215 channel_count: u32::try_from(filters.shape()[3])
216 .expect("Cannot convert channel count to `u32`"),
217 })
218 }
219
220 pub fn pass_as_arguments(&self, kernel: &Kernel) -> ocl::Result<()> {
221 kernel.set_arg("filters", &self.inner)?;
222 if let Some(ref biases) = self.biases {
223 kernel.set_arg("filter_biases", biases)?;
224 }
225 Ok(())
226 }
227}
228
229#[derive(Debug, Clone)]
231pub(crate) struct InputAndOutput<T: ConvElement> {
232 signal_buffer: Buffer<T>,
233 signal_dims: Uint3,
234 output_buffer: Buffer<T>,
235 output_shape: FeatureMapShape,
236}
237
238impl<T: ConvElement> InputAndOutput<T> {
239 pub fn new<U: WithParams>(
240 signal_shape: FeatureMapShape,
241 filter_count: u32,
242 conv: &Base<U>,
243 ) -> ocl::Result<Self> {
244 let Params {
245 pads,
246 strides,
247 dilation,
248 ..
249 } = conv.params().into();
250 let effective_kernel_h = conv.size() + (dilation[0] - 1) * (conv.size() - 1);
251 let out_h = (signal_shape.height - effective_kernel_h + pads[0] + pads[2]) / strides[0] + 1;
252 let effective_kernel_w = conv.size() + (dilation[1] - 1) * (conv.size() - 1);
253 let out_w = (signal_shape.width - effective_kernel_w + pads[1] + pads[3]) / strides[1] + 1;
254 let output_shape = FeatureMapShape {
255 height: out_h,
256 width: out_w,
257 channels: filter_count,
258 ..signal_shape
259 };
260
261 let signal_buffer = Buffer::builder()
262 .queue(conv.queue().clone())
263 .len(signal_shape.buffer_len())
264 .flags(flags::MEM_READ_ONLY)
265 .build()?;
266 let output_buffer = Buffer::builder()
267 .queue(conv.queue().clone())
268 .len(output_shape.buffer_len())
269 .flags(flags::MEM_HOST_READ_ONLY | flags::MEM_WRITE_ONLY)
270 .build()?;
271
272 let signal_dims = Uint3::new(
273 signal_shape.height,
274 signal_shape.width,
275 signal_shape.channels,
276 );
277 Ok(InputAndOutput {
278 signal_buffer,
279 signal_dims,
280 output_buffer,
281 output_shape,
282 })
283 }
284
285 pub fn write_signal(&self, signal: FeatureMap<'_, T>) -> ocl::Result<()> {
286 let signal = signal.to_nhwc();
287 let signal_slice = signal.as_slice().map_or_else(
288 || Cow::Owned(signal.iter().copied().collect()),
289 Cow::Borrowed,
290 );
291 self.signal_buffer.write(signal_slice.as_ref()).enq()
292 }
293
294 pub fn pass_as_arguments(&self, kernel: &Kernel) -> ocl::Result<()> {
295 kernel.set_arg("signal_dims", self.signal_dims)
296 }
297
298 pub fn execute(&self, kernel: &Kernel, out_layout: Layout) -> ocl::Result<Array4<T>> {
299 let s = self.output_shape;
300 kernel.set_arg(
301 "out_params",
302 OutputParams {
303 batch_size: s.batch_size,
304 layout: out_layout,
305 },
306 )?;
307 kernel.set_arg("output", &self.output_buffer)?;
308 kernel.set_arg("signal", &self.signal_buffer)?;
309
310 let command = kernel.cmd().global_work_size([
311 s.height as usize * s.batch_size as usize,
312 s.width as usize,
313 s.channels as usize,
314 ]);
315 unsafe {
316 command.enq()?;
317 }
318
319 let mut output_data = vec![T::default(); self.output_buffer.len()];
320 self.output_buffer.read(&mut output_data).enq()?;
321 let output =
322 Array4::from_shape_vec(self.output_shape.as_array(out_layout), output_data).unwrap();
323 Ok(output)
324 }
325}
326
327#[derive(Debug, Clone)]
329pub(crate) struct Pinned<T: ConvElement> {
330 pub io: InputAndOutput<T>,
331 pub signal_shape: FeatureMapShape,
332}