1use lazy_static::lazy_static;
2use ndarray::{Array4, ArrayView4};
3use ocl::{
4 builders::KernelBuilder, prm::Uint3, Buffer, Context, Device, Kernel, Platform, ProQue,
5 Program, Queue,
6};
7
8use std::{convert::TryFrom, marker::PhantomData, sync::Mutex};
9
10use crate::{
11 buffers::{FeatureMap, FeatureMapShape, Filters, InputAndOutput, Pinned},
12 params::{OutputParams, Params, WithParams},
13 ConvElement,
14};
15
16#[derive(Debug)]
22pub struct ConvolutionBuilder<T> {
23 program: ProQue,
24 filter_size: u32,
25 _element_type: PhantomData<T>,
26}
27
28impl<T: ConvElement> ConvolutionBuilder<T> {
29 pub(crate) fn new(
31 filter_size: u32,
32 defines: &[(&'static str, i32)],
33 source: &str,
34 ) -> ocl::Result<Self> {
35 assert_eq!(
36 filter_size % 2,
37 1,
38 "Even convolution sizes are not supported"
39 );
40
41 let mut program_builder = Program::builder();
42 program_builder.cmplr_def(
43 "FILTER_SIZE",
44 i32::try_from(filter_size).expect("Cannot convert filter size to i32"),
45 );
46 for &(name, value) in defines {
47 program_builder.cmplr_def(name, value);
48 }
49 program_builder.source(source);
50
51 lazy_static! {
55 static ref MUTEX: Mutex<()> = Mutex::new(());
56 }
57 let (platform, device) = {
58 let _lock = MUTEX.lock().ok();
59 let platform = Platform::first()?;
60 (platform, Device::first(platform)?)
61 };
62
63 let context = Context::builder()
64 .platform(platform)
65 .devices(&device)
66 .build()?;
67 let program = ProQue::new(
68 context.clone(),
69 Queue::new(&context, device, None)?,
70 program_builder.build(&context)?,
71 None::<usize>,
72 );
73
74 Ok(Self {
75 program,
76 filter_size,
77 _element_type: PhantomData,
78 })
79 }
80
81 fn kernel_builder(&self) -> KernelBuilder<'_> {
82 self.program.kernel_builder("conv")
83 }
84}
85
86fn create_io<T: ConvElement, U: WithParams>(
87 signal_shape: FeatureMapShape,
88 filters: &Filters<T>,
89 conv: &Base<U>,
90) -> ocl::Result<InputAndOutput<T>> {
91 assert_eq!(
92 signal_shape.channels,
93 filters.channel_count() * Into::<Params>::into(conv.params).groups,
94 "Channel dimensionality in signal and filters must agree"
95 );
96 let io = InputAndOutput::new(signal_shape, filters.filter_count(), conv)?;
97 io.pass_as_arguments(&conv.kernel).map(|()| io)
98}
99
100#[derive(Debug)]
101pub(crate) struct Base<T: WithParams> {
102 size: u32,
103 params: T::Params,
104 kernel: Kernel,
105 buffers: T,
106 context: Context,
107}
108
109impl<T: WithParams> Base<T> {
110 pub fn kernel(&self) -> &Kernel {
111 &self.kernel
112 }
113
114 pub fn queue(&self) -> &Queue {
115 self.kernel
116 .default_queue()
117 .expect("kernel must come with a pre-configured queue")
118 }
119
120 pub fn size(&self) -> u32 {
121 self.size
122 }
123
124 pub fn params(&self) -> T::Params {
125 self.params
126 }
127
128 pub fn set_params(&mut self, params: T::Params) -> ocl::Result<()> {
129 self.params = params;
130 self.kernel
131 .set_arg("params", Into::<T::ClParams>::into(params))
132 }
133}
134
135impl<T: ConvElement> Base<PhantomData<T>> {
136 pub fn new(builder: &ConvolutionBuilder<T>, params: T::Params) -> ocl::Result<Self> {
137 let kernel = builder
138 .kernel_builder()
139 .arg_named("output", None::<&Buffer<T>>)
140 .arg_named("out_params", OutputParams::default())
141 .arg_named("signal", None::<&Buffer<T>>)
142 .arg_named("signal_dims", Uint3::new(0, 0, 0))
143 .arg_named("filters", None::<&Buffer<T>>)
144 .arg_named("filter_biases", None::<&Buffer<T::Acc>>)
145 .arg_named("params", Into::<T::ClParams>::into(params))
146 .build()?;
147 Ok(Base {
148 size: builder.filter_size,
149 params,
150 kernel,
151 buffers: PhantomData,
152 context: builder.program.context().clone(),
153 })
154 }
155
156 pub fn with_filters(
157 self,
158 filters: ArrayView4<'_, T>,
159 filter_biases: Option<&[T::Acc]>,
160 ) -> ocl::Result<Base<Filters<T>>> {
161 let filters = Filters::new(filters, filter_biases, &self)?;
162 Ok(Base {
163 buffers: filters,
164 size: self.size,
165 params: self.params,
166 kernel: self.kernel,
167 context: self.context,
168 })
169 }
170
171 pub fn compute(
172 &self,
173 signal: FeatureMap<'_, T>,
174 filters: ArrayView4<'_, T>,
175 filter_biases: Option<&[T::Acc]>,
176 ) -> ocl::Result<Array4<T>> {
177 let filter_channels =
178 u32::try_from(filters.shape()[3]).expect("Cannot convert filter dimension to `u32`");
179 assert_eq!(
180 signal.shape().channels,
181 filter_channels * Into::<Params>::into(self.params).groups,
182 "Channel dimensionality in signal and filters must agree"
183 );
184
185 let filter_count =
186 u32::try_from(filters.shape()[0]).expect("Cannot convert filter count to `u32`");
187 let filters = Filters::new(filters, filter_biases, self)?;
188 filters.pass_as_arguments(&self.kernel)?;
189 let io = InputAndOutput::new(signal.shape(), filter_count, self)?;
190 io.write_signal(signal)?;
191 io.pass_as_arguments(&self.kernel)?;
192 io.execute(&self.kernel, signal.layout())
193 }
194}
195
196impl<T: ConvElement> Base<Filters<T>> {
197 pub fn pinned(self, signal_shape: FeatureMapShape) -> ocl::Result<Base<Pinned<T>>> {
198 let io = create_io(signal_shape, &self.buffers, &self)?;
199 Ok(Base {
200 size: self.size,
201 params: self.params,
202 kernel: self.kernel,
203 buffers: Pinned { io, signal_shape },
204 context: self.context,
205 })
206 }
207
208 pub fn compute(&self, signal: FeatureMap<'_, T>) -> ocl::Result<Array4<T>> {
209 let io = create_io(signal.shape(), &self.buffers, self)?;
210 io.write_signal(signal)?;
211 io.execute(&self.kernel, signal.layout())
212 }
213}
214
215impl<T: ConvElement> Base<Pinned<T>> {
216 pub fn compute(&self, signal: FeatureMap<'_, T>) -> ocl::Result<Array4<T>> {
217 assert_eq!(
218 signal.shape(),
219 self.buffers.signal_shape,
220 "Signal dimensions differ from the ones set when pinning signal memory"
221 );
222 self.buffers.io.write_signal(signal)?;
223 self.buffers.io.execute(&self.kernel, signal.layout())
224 }
225}