ocl_convolution/
params.rs1use ocl::{
4 prm::{Uint2, Uint4},
5 OclPrm,
6};
7
8use std::marker::PhantomData;
9
10use crate::{
11 buffers::{Filters, Layout, Pinned},
12 ConvElement,
13};
14
/// General convolution parameters shared by all element types.
///
/// The [`Default`] value uses unit strides and dilation, zero padding and a single group.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Params {
    /// Convolution strides along the two spatial axes
    /// (axis order is not visible in this module — confirm against the kernel).
    pub strides: [u32; 2],
    /// Four padding amounts; presumably start/end padding for both spatial axes —
    /// confirm the exact ordering against the kernel.
    pub pads: [u32; 4],
    /// Number of convolution groups (presumably channel groups — confirm).
    pub groups: u32,
    /// Filter dilation along the two spatial axes.
    pub dilation: [u32; 2],
}
33
34impl Default for Params {
35 fn default() -> Self {
36 Self {
37 strides: [1, 1],
38 pads: [0; 4],
39 groups: 1,
40 dilation: [1, 1],
41 }
42 }
43}
44
45#[derive(Debug, Default, Clone, Copy, PartialEq)]
46#[repr(C, packed)]
47pub struct ClParams {
48 strides: Uint2,
49 pads: Uint4,
50 groups: u32,
51 dilation: Uint2,
52}
53
54impl From<Params> for ClParams {
55 fn from(value: Params) -> Self {
56 ClParams {
57 strides: Uint2::from(value.strides),
58 pads: Uint4::from(value.pads),
59 groups: value.groups,
60 dilation: Uint2::from(value.dilation),
61 }
62 }
63}
64
// SAFETY: `ClParams` is `#[repr(C, packed)]` and consists only of a `u32` scalar and
// `ocl::prm` vector primitives (`Uint2`, `Uint4`), so it is plain `Copy` data with a
// fixed, padding-free layout. NOTE(review): the layout presumably must also match the
// struct declared in the OpenCL kernel source — confirm when changing either side.
unsafe impl OclPrm for ClParams {}
67
68#[derive(Debug, Clone, Copy)]
74pub struct I8Params {
75 pub common: Params,
77 pub bit_shift: u8,
79 pub scale: i32,
81 pub output_bias: i32,
83 pub signal_bias: i32,
85 pub filter_bias: i32,
87}
88
89impl From<I8Params> for Params {
90 fn from(value: I8Params) -> Self {
91 value.common
92 }
93}
94
95impl I8Params {
96 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
103 pub fn convert_scale(bit_shift: u8, scale: f32) -> i32 {
106 let scale = (2.0_f32.powi(i32::from(bit_shift)) * scale).round();
107 assert!(
108 scale >= i32::MIN as f32 && scale <= i32::MAX as f32,
109 "Scale is out of `i32` bounds"
110 );
111 scale as i32
112 }
113}
114
115#[derive(Debug, Default, Clone, Copy, PartialEq)]
116#[repr(C, packed)]
117pub struct ClI8Params {
118 strides: Uint2,
119 pads: Uint4,
120 group: u32,
121 dilation: Uint2,
122 bit_shift: i32,
123 scale: i32,
124 output_bias: i32,
125 signal_bias: i32,
126 filter_bias: i32,
127}
128
129impl From<I8Params> for ClI8Params {
130 fn from(value: I8Params) -> Self {
131 let common_params = ClParams::from(value.common);
132 ClI8Params {
133 strides: common_params.strides,
134 pads: common_params.pads,
135 group: common_params.groups,
136 dilation: common_params.dilation,
137 bit_shift: i32::from(value.bit_shift),
138 scale: value.scale,
139 output_bias: value.output_bias,
140 signal_bias: value.signal_bias,
141 filter_bias: value.filter_bias,
142 }
143 }
144}
145
// SAFETY: `ClI8Params` is `#[repr(C, packed)]` and consists only of `u32`/`i32` scalars and
// `ocl::prm` vector primitives (`Uint2`, `Uint4`), so it is plain `Copy` data with a fixed,
// padding-free layout. NOTE(review): the layout presumably must also match the struct
// declared in the OpenCL kernel source — confirm when changing either side.
unsafe impl OclPrm for ClI8Params {}
148
149#[derive(Debug, Clone, Copy, PartialEq, Hash)]
150#[repr(C, packed)]
151pub(crate) struct OutputParams {
152 pub batch_size: u32,
153 pub layout: Layout,
154}
155
// SAFETY: `OutputParams` is `#[repr(C, packed)]` and holds a `u32` and a `Layout`, with no
// padding. NOTE(review): soundness relies on `Layout` being plain data with a fixed integer
// representation that the kernel reads correctly — confirm `Layout`'s `#[repr]`.
unsafe impl OclPrm for OutputParams {}
157
158impl Default for OutputParams {
159 fn default() -> Self {
160 Self {
161 batch_size: 0,
162 layout: Layout::ChannelsLast,
163 }
164 }
165}
166
167pub(crate) trait WithParams {
168 type Params: Copy + Into<Params> + Into<Self::ClParams>;
169 type ClParams: OclPrm;
170}
171
172impl<T: ConvElement> WithParams for PhantomData<T> {
173 type Params = T::Params;
174 type ClParams = T::ClParams;
175}
176
177impl<T: ConvElement> WithParams for Filters<T> {
178 type Params = T::Params;
179 type ClParams = T::ClParams;
180}
181
182impl<T: ConvElement> WithParams for Pinned<T> {
183 type Params = T::Params;
184 type ClParams = T::ClParams;
185}