ocl_convolution/
params.rs

1//! Convolution parameters.
2
3use ocl::{
4    prm::{Uint2, Uint4},
5    OclPrm,
6};
7
8use std::marker::PhantomData;
9
10use crate::{
11    buffers::{Filters, Layout, Pinned},
12    ConvElement,
13};
14
/// Parameters shared by every convolution flavor.
///
/// The fields map onto the attributes of the [`Conv` ONNX operator][onnx-conv].
///
/// [onnx-conv]: https://github.com/onnx/onnx/blob/master/docs/Operators.md#conv
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Params {
    /// Strides along the two spatial dimensions.
    pub strides: [u32; 2],
    /// Padding along the spatial dimensions. The first 2 values are the pads at the beginning
    /// of rows / columns, the last 2 values are the pads at the end.
    pub pads: [u32; 4],
    /// Number of groups in the convolution. Each group of filters is applied to
    /// a subset of the input channels.
    pub groups: u32,
    /// Dilation of the signal along the two spatial dimensions.
    pub dilation: [u32; 2],
}
33
34impl Default for Params {
35    fn default() -> Self {
36        Self {
37            strides: [1, 1],
38            pads: [0; 4],
39            groups: 1,
40            dilation: [1, 1],
41        }
42    }
43}
44
45#[derive(Debug, Default, Clone, Copy, PartialEq)]
46#[repr(C, packed)]
47pub struct ClParams {
48    strides: Uint2,
49    pads: Uint4,
50    groups: u32,
51    dilation: Uint2,
52}
53
54impl From<Params> for ClParams {
55    fn from(value: Params) -> Self {
56        ClParams {
57            strides: Uint2::from(value.strides),
58            pads: Uint4::from(value.pads),
59            groups: value.groups,
60            dilation: Uint2::from(value.dilation),
61        }
62    }
63}
64
65// Safety ensured by the same alignment here and in OCL code.
66unsafe impl OclPrm for ClParams {}
67
68/// Params for the quantized convolution.
69///
70/// See [`Convolution`] docs for details how to set these parameters.
71///
72/// [`Convolution`]: crate::Convolution#connection-to-real-value-convolution
73#[derive(Debug, Clone, Copy)]
74pub struct I8Params {
75    /// Common parameters.
76    pub common: Params,
77    /// Upscale bit shift.
78    pub bit_shift: u8,
79    /// Fixed-point scale of the post-convolution transform.
80    pub scale: i32,
81    /// Bias for the post-convolution transform.
82    pub output_bias: i32,
83    /// Bias for the signal.
84    pub signal_bias: i32,
85    /// Bias for the filters.
86    pub filter_bias: i32,
87}
88
89impl From<I8Params> for Params {
90    fn from(value: I8Params) -> Self {
91        value.common
92    }
93}
94
95impl I8Params {
96    /// Converts `scale` to fixed-point presentation. The resulting value can be used
97    /// as the `scale` field.
98    ///
99    /// # Panics
100    ///
101    /// - Panics if the converted value is outside the `i32` bounds.
102    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
103    // ^-- precision loss is OK in general, and we perform `i32` range check before
104    // using `_ as i32` on an `f32` value.
105    pub fn convert_scale(bit_shift: u8, scale: f32) -> i32 {
106        let scale = (2.0_f32.powi(i32::from(bit_shift)) * scale).round();
107        assert!(
108            scale >= i32::MIN as f32 && scale <= i32::MAX as f32,
109            "Scale is out of `i32` bounds"
110        );
111        scale as i32
112    }
113}
114
115#[derive(Debug, Default, Clone, Copy, PartialEq)]
116#[repr(C, packed)]
117pub struct ClI8Params {
118    strides: Uint2,
119    pads: Uint4,
120    group: u32,
121    dilation: Uint2,
122    bit_shift: i32,
123    scale: i32,
124    output_bias: i32,
125    signal_bias: i32,
126    filter_bias: i32,
127}
128
129impl From<I8Params> for ClI8Params {
130    fn from(value: I8Params) -> Self {
131        let common_params = ClParams::from(value.common);
132        ClI8Params {
133            strides: common_params.strides,
134            pads: common_params.pads,
135            group: common_params.groups,
136            dilation: common_params.dilation,
137            bit_shift: i32::from(value.bit_shift),
138            scale: value.scale,
139            output_bias: value.output_bias,
140            signal_bias: value.signal_bias,
141            filter_bias: value.filter_bias,
142        }
143    }
144}
145
146// Safety ensured by the same alignment here and in OCL code.
147unsafe impl OclPrm for ClI8Params {}
148
149#[derive(Debug, Clone, Copy, PartialEq, Hash)]
150#[repr(C, packed)]
151pub(crate) struct OutputParams {
152    pub batch_size: u32,
153    pub layout: Layout,
154}
155
156unsafe impl OclPrm for OutputParams {}
157
158impl Default for OutputParams {
159    fn default() -> Self {
160        Self {
161            batch_size: 0,
162            layout: Layout::ChannelsLast,
163        }
164    }
165}
166
167pub(crate) trait WithParams {
168    type Params: Copy + Into<Params> + Into<Self::ClParams>;
169    type ClParams: OclPrm;
170}
171
172impl<T: ConvElement> WithParams for PhantomData<T> {
173    type Params = T::Params;
174    type ClParams = T::ClParams;
175}
176
177impl<T: ConvElement> WithParams for Filters<T> {
178    type Params = T::Params;
179    type ClParams = T::ClParams;
180}
181
182impl<T: ConvElement> WithParams for Pinned<T> {
183    type Params = T::Params;
184    type ClParams = T::ClParams;
185}