custos_math/raw_ops/cpu/
correlate.rs1use custos::number::Number;
2
3pub fn correlate_valid_mut<T: Number>(
4 lhs_slice: &[T],
5 lhs_dims: (usize, usize),
6 kernel_slice: &[T],
7 kernel_dims: (usize, usize),
8 out: &mut [T],
9) {
10 let (lhs_rows, lhs_cols) = lhs_dims;
11 let (kernel_rows, kernel_cols) = kernel_dims;
12
13 let (out_rows, out_cols) = (lhs_rows - kernel_rows + 1, lhs_cols - kernel_cols + 1);
14
15 for y in 0..out_rows {
18 for x in 0..out_cols {
21 let mut sum = T::default();
22 for idx in 0..kernel_rows {
24 let index = idx * lhs_cols + x + y * lhs_cols;
25 let lhs_kernel_row = &lhs_slice[index..index + kernel_cols];
26
27 let index = idx * kernel_cols;
28 let kernel_row = &kernel_slice[index..index + kernel_cols];
29
30 for (i, value) in lhs_kernel_row.iter().enumerate() {
31 sum += *value * kernel_row[i];
32 }
33 }
34 out[y * out_cols + x] = sum;
36 }
37 }
38}
39
40#[cfg(not(feature = "no-std"))]
41pub fn add_full_padding<T: Number>(
42 lhs: &[T],
43 lhs_dims: (usize, usize),
44 kernel_dims: (usize, usize),
45) -> (Vec<T>, usize, usize) {
46 let (lhs_rows, lhs_cols) = lhs_dims;
47 let (kernel_rows, kernel_cols) = kernel_dims;
48
49 let (row_adding, col_adding) = ((kernel_rows - 1) * 2, (kernel_cols - 1) * 2);
50 let (out_rows, out_cols) = (lhs_rows + row_adding, lhs_cols + col_adding);
51
52 let mut out = vec![T::default(); out_rows * out_cols];
53
54 for row in 0..lhs_rows {
55 let idx = row * lhs_cols;
56 let lhs_row = &lhs[idx..idx + lhs_cols];
57
58 let index = (row + (kernel_rows - 1)) * (out_cols) + (kernel_cols - 1);
59 let out_row = &mut out[index..index + out_cols];
60
61 for (idx, value) in lhs_row.iter().enumerate() {
62 out_row[idx] = *value;
63 }
64 }
65 (out, out_rows, out_cols)
66}
67
68#[cfg(not(feature = "no-std"))]
69pub fn rot_kernel<T: Number>(kernel: &[T], kernel_shape: (usize, usize)) -> Vec<T> {
70 let (kernel_rows, kernel_cols) = kernel_shape;
71 let mut rotated = vec![T::default(); kernel.len()];
72
73 for (idx_rev, idx) in (0..kernel_rows).rev().zip(0..kernel_rows) {
74 let row_idx = idx_rev * kernel_cols;
75 let row = &kernel[row_idx..row_idx + kernel_cols];
76
77 for (i, value) in row.iter().rev().enumerate() {
78 rotated[idx * kernel_cols + i] = *value;
79 }
80 }
81 rotated
82}