custos_math/raw_ops/cpu/
correlate.rs

1use custos::number::Number;
2
3pub fn correlate_valid_mut<T: Number>(
4    lhs_slice: &[T],
5    lhs_dims: (usize, usize),
6    kernel_slice: &[T],
7    kernel_dims: (usize, usize),
8    out: &mut [T],
9) {
10    let (lhs_rows, lhs_cols) = lhs_dims;
11    let (kernel_rows, kernel_cols) = kernel_dims;
12
13    let (out_rows, out_cols) = (lhs_rows - kernel_rows + 1, lhs_cols - kernel_cols + 1);
14
15    //loop for row-axis (y)
16    //moves multiplication 1 down
17    for y in 0..out_rows {
18        //loop for col-axis (x)
19        //moves multiplication 1 to the right
20        for x in 0..out_cols {
21            let mut sum = T::default();
22            //repeat kernel rows times to use move through all kernel rows
23            for idx in 0..kernel_rows {
24                let index = idx * lhs_cols + x + y * lhs_cols;
25                let lhs_kernel_row = &lhs_slice[index..index + kernel_cols];
26
27                let index = idx * kernel_cols;
28                let kernel_row = &kernel_slice[index..index + kernel_cols];
29
30                for (i, value) in lhs_kernel_row.iter().enumerate() {
31                    sum += *value * kernel_row[i];
32                }
33            }
34            // y * final_cols + x
35            out[y * out_cols + x] = sum;
36        }
37    }
38}
39
40#[cfg(not(feature = "no-std"))]
41pub fn add_full_padding<T: Number>(
42    lhs: &[T],
43    lhs_dims: (usize, usize),
44    kernel_dims: (usize, usize),
45) -> (Vec<T>, usize, usize) {
46    let (lhs_rows, lhs_cols) = lhs_dims;
47    let (kernel_rows, kernel_cols) = kernel_dims;
48
49    let (row_adding, col_adding) = ((kernel_rows - 1) * 2, (kernel_cols - 1) * 2);
50    let (out_rows, out_cols) = (lhs_rows + row_adding, lhs_cols + col_adding);
51
52    let mut out = vec![T::default(); out_rows * out_cols];
53
54    for row in 0..lhs_rows {
55        let idx = row * lhs_cols;
56        let lhs_row = &lhs[idx..idx + lhs_cols];
57
58        let index = (row + (kernel_rows - 1)) * (out_cols) + (kernel_cols - 1);
59        let out_row = &mut out[index..index + out_cols];
60
61        for (idx, value) in lhs_row.iter().enumerate() {
62            out_row[idx] = *value;
63        }
64    }
65    (out, out_rows, out_cols)
66}
67
68#[cfg(not(feature = "no-std"))]
69pub fn rot_kernel<T: Number>(kernel: &[T], kernel_shape: (usize, usize)) -> Vec<T> {
70    let (kernel_rows, kernel_cols) = kernel_shape;
71    let mut rotated = vec![T::default(); kernel.len()];
72
73    for (idx_rev, idx) in (0..kernel_rows).rev().zip(0..kernel_rows) {
74        let row_idx = idx_rev * kernel_cols;
75        let row = &kernel[row_idx..row_idx + kernel_cols];
76
77        for (i, value) in row.iter().rev().enumerate() {
78            rotated[idx * kernel_cols + i] = *value;
79        }
80    }
81    rotated
82}