ocl_convolution/
base.rs

1use lazy_static::lazy_static;
2use ndarray::{Array4, ArrayView4};
3use ocl::{
4    builders::KernelBuilder, prm::Uint3, Buffer, Context, Device, Kernel, Platform, ProQue,
5    Program, Queue,
6};
7
8use std::{convert::TryFrom, marker::PhantomData, sync::Mutex};
9
10use crate::{
11    buffers::{FeatureMap, FeatureMapShape, Filters, InputAndOutput, Pinned},
12    params::{OutputParams, Params, WithParams},
13    ConvElement,
14};
15
16/// Convolution builder. The same builder can be used to create multiple `Convolution`s
17/// which share the same spatial size.
18///
19/// A builder can be created using [`Convolution::f32()`](crate::Convolution::f32()) or
20/// [`Convolution::i8()`](crate::Convolution::i8()) methods.
21#[derive(Debug)]
22pub struct ConvolutionBuilder<T> {
23    program: ProQue,
24    filter_size: u32,
25    _element_type: PhantomData<T>,
26}
27
28impl<T: ConvElement> ConvolutionBuilder<T> {
29    /// Initializes a builder with a specific filter size.
30    pub(crate) fn new(
31        filter_size: u32,
32        defines: &[(&'static str, i32)],
33        source: &str,
34    ) -> ocl::Result<Self> {
35        assert_eq!(
36            filter_size % 2,
37            1,
38            "Even convolution sizes are not supported"
39        );
40
41        let mut program_builder = Program::builder();
42        program_builder.cmplr_def(
43            "FILTER_SIZE",
44            i32::try_from(filter_size).expect("Cannot convert filter size to i32"),
45        );
46        for &(name, value) in defines {
47            program_builder.cmplr_def(name, value);
48        }
49        program_builder.source(source);
50
51        // For some reason, certain OpenCL implementations (e.g., POCL) do not work well
52        // when the list of devices for a platform is queried from multiple threads.
53        // Hence, we introduce a `Mutex` to serialize these calls.
54        lazy_static! {
55            static ref MUTEX: Mutex<()> = Mutex::new(());
56        }
57        let (platform, device) = {
58            let _lock = MUTEX.lock().ok();
59            let platform = Platform::first()?;
60            (platform, Device::first(platform)?)
61        };
62
63        let context = Context::builder()
64            .platform(platform)
65            .devices(&device)
66            .build()?;
67        let program = ProQue::new(
68            context.clone(),
69            Queue::new(&context, device, None)?,
70            program_builder.build(&context)?,
71            None::<usize>,
72        );
73
74        Ok(Self {
75            program,
76            filter_size,
77            _element_type: PhantomData,
78        })
79    }
80
81    fn kernel_builder(&self) -> KernelBuilder<'_> {
82        self.program.kernel_builder("conv")
83    }
84}
85
86fn create_io<T: ConvElement, U: WithParams>(
87    signal_shape: FeatureMapShape,
88    filters: &Filters<T>,
89    conv: &Base<U>,
90) -> ocl::Result<InputAndOutput<T>> {
91    assert_eq!(
92        signal_shape.channels,
93        filters.channel_count() * Into::<Params>::into(conv.params).groups,
94        "Channel dimensionality in signal and filters must agree"
95    );
96    let io = InputAndOutput::new(signal_shape, filters.filter_count(), conv)?;
97    io.pass_as_arguments(&conv.kernel).map(|()| io)
98}
99
100#[derive(Debug)]
101pub(crate) struct Base<T: WithParams> {
102    size: u32,
103    params: T::Params,
104    kernel: Kernel,
105    buffers: T,
106    context: Context,
107}
108
109impl<T: WithParams> Base<T> {
110    pub fn kernel(&self) -> &Kernel {
111        &self.kernel
112    }
113
114    pub fn queue(&self) -> &Queue {
115        self.kernel
116            .default_queue()
117            .expect("kernel must come with a pre-configured queue")
118    }
119
120    pub fn size(&self) -> u32 {
121        self.size
122    }
123
124    pub fn params(&self) -> T::Params {
125        self.params
126    }
127
128    pub fn set_params(&mut self, params: T::Params) -> ocl::Result<()> {
129        self.params = params;
130        self.kernel
131            .set_arg("params", Into::<T::ClParams>::into(params))
132    }
133}
134
135impl<T: ConvElement> Base<PhantomData<T>> {
136    pub fn new(builder: &ConvolutionBuilder<T>, params: T::Params) -> ocl::Result<Self> {
137        let kernel = builder
138            .kernel_builder()
139            .arg_named("output", None::<&Buffer<T>>)
140            .arg_named("out_params", OutputParams::default())
141            .arg_named("signal", None::<&Buffer<T>>)
142            .arg_named("signal_dims", Uint3::new(0, 0, 0))
143            .arg_named("filters", None::<&Buffer<T>>)
144            .arg_named("filter_biases", None::<&Buffer<T::Acc>>)
145            .arg_named("params", Into::<T::ClParams>::into(params))
146            .build()?;
147        Ok(Base {
148            size: builder.filter_size,
149            params,
150            kernel,
151            buffers: PhantomData,
152            context: builder.program.context().clone(),
153        })
154    }
155
156    pub fn with_filters(
157        self,
158        filters: ArrayView4<'_, T>,
159        filter_biases: Option<&[T::Acc]>,
160    ) -> ocl::Result<Base<Filters<T>>> {
161        let filters = Filters::new(filters, filter_biases, &self)?;
162        Ok(Base {
163            buffers: filters,
164            size: self.size,
165            params: self.params,
166            kernel: self.kernel,
167            context: self.context,
168        })
169    }
170
171    pub fn compute(
172        &self,
173        signal: FeatureMap<'_, T>,
174        filters: ArrayView4<'_, T>,
175        filter_biases: Option<&[T::Acc]>,
176    ) -> ocl::Result<Array4<T>> {
177        let filter_channels =
178            u32::try_from(filters.shape()[3]).expect("Cannot convert filter dimension to `u32`");
179        assert_eq!(
180            signal.shape().channels,
181            filter_channels * Into::<Params>::into(self.params).groups,
182            "Channel dimensionality in signal and filters must agree"
183        );
184
185        let filter_count =
186            u32::try_from(filters.shape()[0]).expect("Cannot convert filter count to `u32`");
187        let filters = Filters::new(filters, filter_biases, self)?;
188        filters.pass_as_arguments(&self.kernel)?;
189        let io = InputAndOutput::new(signal.shape(), filter_count, self)?;
190        io.write_signal(signal)?;
191        io.pass_as_arguments(&self.kernel)?;
192        io.execute(&self.kernel, signal.layout())
193    }
194}
195
196impl<T: ConvElement> Base<Filters<T>> {
197    pub fn pinned(self, signal_shape: FeatureMapShape) -> ocl::Result<Base<Pinned<T>>> {
198        let io = create_io(signal_shape, &self.buffers, &self)?;
199        Ok(Base {
200            size: self.size,
201            params: self.params,
202            kernel: self.kernel,
203            buffers: Pinned { io, signal_shape },
204            context: self.context,
205        })
206    }
207
208    pub fn compute(&self, signal: FeatureMap<'_, T>) -> ocl::Result<Array4<T>> {
209        let io = create_io(signal.shape(), &self.buffers, self)?;
210        io.write_signal(signal)?;
211        io.execute(&self.kernel, signal.layout())
212    }
213}
214
215impl<T: ConvElement> Base<Pinned<T>> {
216    pub fn compute(&self, signal: FeatureMap<'_, T>) -> ocl::Result<Array4<T>> {
217        assert_eq!(
218            signal.shape(),
219            self.buffers.signal_shape,
220            "Signal dimensions differ from the ones set when pinning signal memory"
221        );
222        self.buffers.io.write_signal(signal)?;
223        self.buffers.io.execute(&self.kernel, signal.layout())
224    }
225}