1use crate::RnnNetworkMode;
4use crate::{DirectionMode, RnnInputMode};
5use co::frameworks::native::flatbox::FlatBox;
6use co::plugin::numeric_helpers::Float;
7use co::plugin::Error as PluginError;
8use coaster as co;
9
/// Field-less configuration type for normalization operations.
///
/// It carries no data; it only exists so that a concrete type can be named
/// where a normalization configuration is expected.
#[derive(Clone, Copy, Debug)]
#[allow(missing_docs)]
pub struct NormalizationConfig;
13
/// Parameters of a pooling operation.
///
/// NOTE(review): each vector presumably holds one entry per pooled
/// dimension — confirm against the pooling implementation.
#[derive(Clone, Debug)]
#[allow(missing_docs)]
pub struct PoolingConfig {
    /// Extent of the pooling window.
    pub window: Vec<i32>,
    /// Padding applied to the input.
    pub padding: Vec<i32>,
    /// Step between consecutive window positions.
    pub stride: Vec<i32>,
}
22
/// Parameters of a dropout operation.
#[derive(Clone, Copy, Debug)]
#[allow(missing_docs)]
pub struct DropoutConfig {
    /// Dropout probability.
    /// NOTE(review): presumably the probability of zeroing an element,
    /// expected in `[0, 1]` — confirm against the dropout implementation.
    pub probability: f32,
    /// Seed for the random number generator used by dropout.
    pub seed: u64,
}
29
/// Requests read access to `$tensor` on the device of `$backend` and calls
/// `as_slice::<$elem>()` on the returned memory.
///
/// Panics (through `unwrap`) when the read request returns an error.
macro_rules! read {
    ($tensor:ident, $elem:ident, $backend:ident) => {
        $tensor
            .read($backend.device())
            .unwrap()
            .as_slice::<$elem>()
    };
}
37
/// Requests read-write access to `$tensor` on the device of `$backend` and
/// calls `as_mut_slice::<$elem>()` on the returned memory.
///
/// Panics (through `unwrap`) when the read-write request returns an error.
macro_rules! read_write {
    ($tensor:ident, $elem:ident, $backend:ident) => {
        $tensor
            .read_write($backend.device())
            .unwrap()
            .as_mut_slice::<$elem>()
    };
}
45
/// Requests write-only access to `$tensor` on the device of `$backend` and
/// calls `as_mut_slice::<$elem>()` on the returned memory.
///
/// Panics (through `unwrap`) when the write-only request returns an error.
macro_rules! write_only {
    ($tensor:ident, $elem:ident, $backend:ident) => {
        $tensor
            .write_only($backend.device())
            .unwrap()
            .as_mut_slice::<$elem>()
    };
}
53
54pub fn write_to_memory<T: Iterator>(mem: &mut FlatBox, data: T)
56where
57 T::Item: Clone,
58{
59 let mem_buffer = mem.as_mut_slice::<T::Item>();
60 for (index, datum) in data.enumerate() {
61 mem_buffer[index] = datum;
62 }
63}
64
65pub fn sigmoid<T: Float>(x: T) -> T {
67 (T::one()) / (T::one() + (-x).exp())
68}
69
70pub fn sigmoid_grad<T: Float>(x: T, dx: T) -> T {
72 x * (T::one() - x) * dx
73}
74
75pub fn relu<T: Float>(x: T) -> T {
77 let x: T = x.clone();
78 x.max(T::zero())
79}
80
81pub fn relu_grad<T: Float>(x: T, dx: T) -> T {
83 if x > T::zero() {
84 return dx;
85 }
86 T::zero()
87}
88
89pub fn tanh<T: Float>(x: T) -> T {
91 x.tanh()
92}
93
94pub fn tanh_grad<T: Float>(x: T, dx: T) -> T {
97 (T::one() - x.powi(2)) * dx
98}
99
/// Implements `Sigmoid<$t>` and `SigmoidPointwise<$t>` for the backend type
/// `$b`, using the scalar `sigmoid` / `sigmoid_grad` helpers of this crate.
///
/// * `$t` – element type of the tensors (passed on to `read!` & friends).
/// * `$b` – backend type receiving the impls; it must offer `device()`.
///
/// The expansion refers to `Sigmoid`, `SigmoidPointwise`, `SharedTensor`,
/// `map1`, `map2`, `map1_inplace`, `map2_inplace` and the `read!`,
/// `read_write!`, `write_only!` macros by bare name, so they must be in
/// scope at the invocation site. Helper functions and the error type are
/// addressed through `$crate` so they always resolve to this crate.
///
/// Tensor access goes through `read!`/`read_write!`/`write_only!`, which
/// panic if the requested access fails.
#[macro_export]
macro_rules! impl_ops_sigmoid_for {
    ($t:ident, $b:ty) => {
        impl Sigmoid<$t> for $b {
            /// Maps `helper::sigmoid` over `x` into `result`.
            fn sigmoid(
                &self,
                x: &SharedTensor<$t>,
                result: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map1(
                    read!(x, $t, self),
                    write_only!(result, $t, self),
                    $crate::frameworks::native::helper::sigmoid,
                )
            }

            /// Combines `x` and `x_diff` with `helper::sigmoid_grad` into
            /// `result_diff`. `_result` is not read by this implementation.
            fn sigmoid_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &SharedTensor<$t>,
                _result: &SharedTensor<$t>,
                result_diff: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map2(
                    read!(x, $t, self),
                    read!(x_diff, $t, self),
                    write_only!(result_diff, $t, self),
                    $crate::frameworks::native::helper::sigmoid_grad,
                )
            }
        }

        impl SigmoidPointwise<$t> for $b {
            /// Applies `helper::sigmoid` to `x` in place.
            fn sigmoid_pointwise(
                &self,
                x: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map1_inplace(
                    read_write!(x, $t, self),
                    $crate::frameworks::native::helper::sigmoid,
                )
            }

            /// Updates `x_diff` in place from `x` and `x_diff` using
            /// `helper::sigmoid_grad`.
            fn sigmoid_pointwise_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map2_inplace(
                    read!(x, $t, self),
                    read_write!(x_diff, $t, self),
                    $crate::frameworks::native::helper::sigmoid_grad,
                )
            }
        }
    };
}
155
/// Implements `Relu<$t>` and `ReluPointwise<$t>` for the backend type `$b`,
/// using the scalar `relu` / `relu_grad` helpers of this crate.
///
/// * `$t` – element type of the tensors (passed on to `read!` & friends).
/// * `$b` – backend type receiving the impls; it must offer `device()`.
///
/// The expansion refers to `Relu`, `ReluPointwise`, `SharedTensor`, `map1`,
/// `map2`, `map1_inplace`, `map2_inplace` and the `read!`, `read_write!`,
/// `write_only!` macros by bare name, so they must be in scope at the
/// invocation site. Helper functions and the error type are addressed
/// through `$crate` so they always resolve to this crate.
///
/// Tensor access goes through `read!`/`read_write!`/`write_only!`, which
/// panic if the requested access fails.
#[macro_export]
macro_rules! impl_ops_relu_for {
    ($t:ident, $b:ty) => {
        impl Relu<$t> for $b {
            /// Maps `helper::relu` over `x` into `result`.
            fn relu(
                &self,
                x: &SharedTensor<$t>,
                result: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map1(
                    read!(x, $t, self),
                    write_only!(result, $t, self),
                    $crate::frameworks::native::helper::relu,
                )
            }

            /// Combines `x` and `x_diff` with `helper::relu_grad` into
            /// `result_diff`. `_result` is not read by this implementation.
            fn relu_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &SharedTensor<$t>,
                _result: &SharedTensor<$t>,
                result_diff: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map2(
                    read!(x, $t, self),
                    read!(x_diff, $t, self),
                    write_only!(result_diff, $t, self),
                    $crate::frameworks::native::helper::relu_grad,
                )
            }
        }

        impl ReluPointwise<$t> for $b {
            /// Applies `helper::relu` to `x` in place.
            fn relu_pointwise(
                &self,
                x: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map1_inplace(
                    read_write!(x, $t, self),
                    $crate::frameworks::native::helper::relu,
                )
            }

            /// Updates `x_diff` in place from `x` and `x_diff` using
            /// `helper::relu_grad`.
            fn relu_pointwise_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map2_inplace(
                    read!(x, $t, self),
                    read_write!(x_diff, $t, self),
                    $crate::frameworks::native::helper::relu_grad,
                )
            }
        }
    };
}
213
/// Implements `$crate::plugin::Tanh<$t>` and
/// `$crate::plugin::TanhPointwise<$t>` for the backend type `$b`, using the
/// scalar `tanh` / `tanh_grad` helpers of this crate.
///
/// * `$t` – element type of the tensors (passed on to `read!` & friends).
/// * `$b` – backend type receiving the impls; it must offer `device()`.
///
/// The expansion refers to `SharedTensor`, `map1`, `map2`, `map1_inplace`,
/// `map2_inplace` and the `read!`, `read_write!`, `write_only!` macros by
/// bare name, so they must be in scope at the invocation site. Traits,
/// helper functions and the error type are addressed through `$crate` so
/// they always resolve to this crate.
///
/// Tensor access goes through `read!`/`read_write!`/`write_only!`, which
/// panic if the requested access fails.
#[macro_export]
macro_rules! impl_ops_tanh_for {
    ($t:ident, $b:ty) => {
        impl $crate::plugin::Tanh<$t> for $b {
            /// Maps `helper::tanh` over `x` into `result`.
            fn tanh(
                &self,
                x: &SharedTensor<$t>,
                result: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map1(
                    read!(x, $t, self),
                    write_only!(result, $t, self),
                    $crate::frameworks::native::helper::tanh,
                )
            }

            /// Combines `x` and `x_diff` with `helper::tanh_grad` into
            /// `result_diff`. `_result` is not read by this implementation.
            fn tanh_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &SharedTensor<$t>,
                _result: &SharedTensor<$t>,
                result_diff: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map2(
                    read!(x, $t, self),
                    read!(x_diff, $t, self),
                    write_only!(result_diff, $t, self),
                    $crate::frameworks::native::helper::tanh_grad,
                )
            }
        }

        impl $crate::plugin::TanhPointwise<$t> for $b {
            /// Applies `helper::tanh` to `x` in place.
            fn tanh_pointwise(
                &self,
                x: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map1_inplace(
                    read_write!(x, $t, self),
                    $crate::frameworks::native::helper::tanh,
                )
            }

            /// Updates `x_diff` in place from `x` and `x_diff` using
            /// `helper::tanh_grad`.
            fn tanh_pointwise_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map2_inplace(
                    read!(x, $t, self),
                    read_write!(x_diff, $t, self),
                    $crate::frameworks::native::helper::tanh_grad,
                )
            }
        }
    };
}
271
/// Parameters of a convolution operation.
///
/// NOTE(review): `stride` and `padding` presumably hold one entry per
/// spatial dimension — confirm against the convolution implementation.
#[derive(Clone, Debug)]
#[allow(missing_docs)]
pub struct ConvolutionConfig {
    /// Shape of the convolution filter.
    pub filter_shape: Vec<usize>,
    /// Step between consecutive filter applications.
    pub stride: Vec<i32>,
    /// Padding applied to the input.
    pub padding: Vec<i32>,
}
279
280#[derive(Debug, Clone, Copy)]
281#[allow(missing_docs)]
282pub struct RnnConfig {
284 pub hidden_size: usize,
286 pub num_layers: usize,
288 pub dropout_probability: f32,
290 pub dropout_seed: u64,
292 pub rnn_type: RnnNetworkMode,
294 pub input_mode: RnnInputMode,
296 pub direction_mode: DirectionMode,
298}
299
/// Implements `$crate::plugin::Softmax<$t>` for the backend type `$b`.
///
/// * `$t` – element type of the tensors; must name a primitive float module
///   in `std` (`f32` or `f64`) because `::std::$t::NEG_INFINITY` is used.
/// * `$b` – backend type receiving the impl; it must offer `device()`.
///
/// The expansion refers to `SharedTensor`, `Error`, `map1`, `map2` and the
/// `read!` / `write_only!` macros by bare name, so they must be in scope at
/// the invocation site. Tensor access panics if the requested access fails.
///
/// The softmax is taken over the whole tensor viewed as one flat slice.
#[macro_export]
macro_rules! impl_ops_softmax_for {
    ($t:ident, $b:ty) => {
        impl $crate::plugin::Softmax<$t> for $b {
            /// Writes `exp(x_i) / Σ_j exp(x_j)` into `result`.
            fn softmax(
                &self,
                x: &SharedTensor<$t>,
                result: &mut SharedTensor<$t>,
            ) -> Result<(), Error> {
                let xs = read!(x, $t, self);
                let rs = write_only!(result, $t, self);

                // Shift every input by the maximum before exponentiating.
                // Softmax is invariant under a common shift, and with the
                // shift the largest exponent is exp(0) = 1, so large inputs
                // no longer overflow to `inf` (which would turn the
                // normalised output into NaN).
                let max_x = xs
                    .iter()
                    .fold(::std::$t::NEG_INFINITY, |acc, &v| acc.max(v));
                map1(xs, rs, |v| (v - max_x).exp())?;

                // Normalise so that the outputs sum to one.
                let sum: $t = rs.iter().fold(0.0, |acc, &r| acc + r);
                for r in rs.iter_mut() {
                    *r /= sum;
                }
                Ok(())
            }

            /// Writes `x_i * (x_diff_i - Σ_j x_j * x_diff_j)` into
            /// `result_diff`.
            ///
            /// NOTE(review): this is the softmax backward pass when `x` is
            /// the softmax *output* and `x_diff` its gradient — confirm
            /// against the callers.
            fn softmax_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &SharedTensor<$t>,
                result_diff: &mut SharedTensor<$t>,
            ) -> Result<(), Error> {
                let xs = read!(x, $t, self);
                let dxs = read!(x_diff, $t, self);
                let drs = write_only!(result_diff, $t, self);

                let mut dot: $t = 0.0;
                for (t, dt) in xs.iter().zip(dxs.iter()) {
                    dot += t * dt;
                }

                map2(xs, dxs, drs, |t, dt| t * (dt - dot))
            }
        }
    };
}
346
/// Implements `$crate::plugin::LogSoftmax<$elem>` for the backend type
/// `$backend`.
///
/// * `$elem` – element type of the tensors; must name a primitive float
///   module in `std` (`f32` or `f64`) because `::std::$elem::NEG_INFINITY`
///   is used.
/// * `$backend` – backend type receiving the impl; it must offer `device()`.
///
/// The expansion refers to `SharedTensor`, `map1`, `map2` and the `read!` /
/// `write_only!` macros by bare name, so they must be in scope at the
/// invocation site. Tensor access panics if the requested access fails.
///
/// The log-softmax is taken over the whole tensor viewed as one flat slice.
#[macro_export]
macro_rules! impl_ops_log_softmax_for {
    ($elem:ident, $backend:ty) => {
        impl $crate::plugin::LogSoftmax<$elem> for $backend {
            /// Writes `x_i - log(Σ_j exp(x_j))` into `result`, evaluating the
            /// log-sum-exp relative to the largest input.
            fn log_softmax(
                &self,
                x: &SharedTensor<$elem>,
                result: &mut SharedTensor<$elem>,
            ) -> Result<(), $crate::co::error::Error> {
                let inputs = read!(x, $elem, self);
                let outputs = write_only!(result, $elem, self);

                // Largest input; exponentials are taken relative to it.
                let largest = inputs
                    .iter()
                    .fold(::std::$elem::NEG_INFINITY, |best, &v| best.max(v));

                // log Σ exp(x_j) = largest + log Σ exp(x_j - largest)
                let shifted_total: $elem = inputs
                    .iter()
                    .fold(0.0, |acc, &v| acc + (v - largest).exp());
                let log_norm = largest + shifted_total.ln();

                map1(inputs, outputs, |v| v - log_norm)
            }

            /// Writes `x_diff_i - exp(x_i) * Σ_j x_diff_j` into
            /// `result_diff`.
            ///
            /// NOTE(review): this is the log-softmax backward pass when `x`
            /// is the log-softmax *output* — confirm against the callers.
            fn log_softmax_grad(
                &self,
                x: &SharedTensor<$elem>,
                x_diff: &SharedTensor<$elem>,
                result_diff: &mut SharedTensor<$elem>,
            ) -> Result<(), $crate::co::error::Error> {
                let outputs = read!(x, $elem, self);
                let grads = read!(x_diff, $elem, self);
                let targets = write_only!(result_diff, $elem, self);

                let grad_total: $elem = grads.iter().fold(0.0, |acc, &g| acc + g);

                map2(outputs, grads, targets, |t, dt| dt - t.exp() * grad_total)
            }
        }
    };
}
392
/// Implements `$crate::plugin::LRN<$t>` for the backend type `$b` with stub
/// methods: every method panics via `unimplemented!()`.
///
/// * `$t` – element type of the tensors.
/// * `$b` – backend type receiving the impl; it must already provide the
///   `CLRN` associated type referred to as `Self::CLRN`.
///
/// The trait and the error type are addressed through `$crate` so the macro
/// resolves them in this crate under the 2018 path rules; `SharedTensor`
/// must be in scope at the invocation site. Parameters carry a leading
/// underscore because the stubs never read them.
#[macro_export]
macro_rules! impl_ops_lrn_for {
    ($t:ident, $b:ty) => {
        impl $crate::plugin::LRN<$t> for $b {
            /// Not implemented for this backend; always panics.
            fn new_lrn_config(
                &self,
                _n: u32,
                _alpha: f64,
                _beta: f64,
                _k: f64,
            ) -> Result<Self::CLRN, $crate::co::error::Error> {
                unimplemented!()
            }

            /// Not implemented for this backend; always panics.
            fn lrn(
                &self,
                _x: &mut SharedTensor<$t>,
                _result: &mut SharedTensor<$t>,
                _config: &Self::CLRN,
            ) -> Result<(), $crate::co::error::Error> {
                unimplemented!()
            }

            /// Not implemented for this backend; always panics.
            fn lrn_plain(
                &self,
                _x: &SharedTensor<$t>,
                _result: &mut SharedTensor<$t>,
                _config: &Self::CLRN,
            ) -> Result<(), $crate::co::error::Error> {
                unimplemented!()
            }

            /// Not implemented for this backend; always panics.
            fn lrn_grad(
                &self,
                _x: &mut SharedTensor<$t>,
                _x_diff: &mut SharedTensor<$t>,
                _result: &mut SharedTensor<$t>,
                _result_diff: &mut SharedTensor<$t>,
                _config: &Self::CLRN,
            ) -> Result<(), $crate::co::error::Error> {
                unimplemented!()
            }

            /// Not implemented for this backend; always panics.
            fn lrn_grad_plain(
                &self,
                _x: &SharedTensor<$t>,
                _x_diff: &SharedTensor<$t>,
                _result: &SharedTensor<$t>,
                _result_diff: &mut SharedTensor<$t>,
                _config: &Self::CLRN,
            ) -> Result<(), $crate::co::error::Error> {
                unimplemented!()
            }
        }
    };
}