bmls/
axis_add.rs

1
2use itertools::izip;
3use crate::error::BMLSError;
4use crate::error;
5
6/// ## Inputs
7/// - X1: Input (NCHW)
8/// - X2: Values to Add to X1 (AXIS x 1)
9/// - Y: Output (NCHW)
10/// - Dim: Dimensions of Y and X1.
11/// - Axis: Axis to iterate
12#[inline]
13pub fn axis_add(
14    x1: &[f32],
15    x2: &[f32],
16    y: &mut [f32],
17    dim: [usize; 4],
18    axis: usize,
19) -> Result<(), BMLSError> {
20    // the expected lengths of X1 and Y.
21    let len = dim[0]*dim[1]*dim[2]*dim[3];
22
23    if y.len() != len {
24        return error::length_mismatch("Y", y.len(), "Dim", len)
25    }
26
27    if x1.len() != y.len() {
28        return error::length_mismatch("X1", x1.len(), "Y", y.len())
29    }
30
31    if x2.len() != dim[axis] {
32        return error::axis_mismatch(0, "X2", x2.len(), axis, "Y", dim[axis])
33    }
34
35    let cptr = y;
36    for n in 0..dim[0] {
37        for c in 0..dim[1] {
38            for h in 0..dim[2] {
39                for w in 0..dim[3] {
40                    let i = n * dim[1] * dim[2] * dim[3] + c * dim[2] * dim[3] + h * dim[3] + w;
41
42                    let indices = [n, c, h, w];
43
44                    cptr[i] = x1[i] + x2[indices[axis]];
45                }
46            }
47        }
48    }
49
50    Ok(())
51}
52
53#[inline]
54pub fn axis_add_wrt_x1(
55    gy: &[f32],
56    g1: &mut [f32],
57) -> Result<(), BMLSError> {
58    if gy.len() != g1.len() {
59        return error::length_mismatch("GY", gy.len(), "G1", g1.len())
60    }
61
62    for (gy, g1) in izip!(gy, g1) {
63        *g1 += *gy
64    }
65
66    Ok(())
67}
68
69#[inline]
70pub fn axis_add_wrt_x2(
71    gy: &[f32],
72    g2: &mut [f32],
73    dim: [usize; 4],
74    axis: usize,
75) -> Result<(), BMLSError> {    
76    if g2.len() != dim[axis] {
77        return error::axis_mismatch(0, "G2", g2.len(), axis, "Y", dim[axis])
78    }
79
80    let len = dim[0]*dim[1]*dim[2]*dim[3];
81    if gy.len() != len {
82        return error::length_mismatch("GY", gy.len(), "Dim", len)
83    }
84
85    for n in 0..dim[0] {
86        for c in 0..dim[1] {
87            for h in 0..dim[2] {
88                for w in 0..dim[3] {
89                    let i = n * dim[1] * dim[2] * dim[3] + c * dim[2] * dim[3] + h * dim[3] + w;
90                    let indices = [n, c, h, w];
91
92                    g2[indices[axis]] += gy[i]
93                }
94            }
95        }
96    }
97
98    Ok(())
99}
100
#[cfg(test)]
mod test {

    use super::*;

    /// Flat offset of `[n, c, h, w]` in a buffer of shape `dim` whose last
    /// axis is contiguous.
    fn offset(dim: [usize; 4], at: [usize; 4]) -> usize {
        ((at[0] * dim[1] + at[1]) * dim[2] + at[2]) * dim[3] + at[3]
    }

    #[test]
    fn test_axis_add() {
        // 1 x 4 x 4 x 4, all zeros, so the output is exactly the broadcast b
        let a = vec![0.0; 64];

        // one b for each position along the last axis (axis 3)
        let b = (1..5).map(|i| i as f32).collect::<Vec<f32>>();

        let mut c = vec![0.0; 64];

        axis_add(&a, &b, &mut c, [1, 4, 4, 4], 3).unwrap();

        for cc in 0..4 {
            for ch in 0..4 {
                for cw in 0..4 {
                    assert_eq!(c[cc * 4 * 4 + ch * 4 + cw], b[cw]);
                }
            }
        }
    }

    #[test]
    fn test_axis_add_channel_axis() {
        // Unequal dimensions so that picking the wrong axis cannot go unnoticed.
        let dim = [2, 3, 2, 2];
        let a = (0..24).map(|i| i as f32).collect::<Vec<f32>>();
        let b = vec![100.0, 200.0, 300.0];
        let mut c = vec![0.0; 24];

        axis_add(&a, &b, &mut c, dim, 1).unwrap();

        for n in 0..dim[0] {
            for ch in 0..dim[1] {
                for h in 0..dim[2] {
                    for w in 0..dim[3] {
                        let i = offset(dim, [n, ch, h, w]);
                        assert_eq!(c[i], a[i] + b[ch]);
                    }
                }
            }
        }
    }

    #[test]
    fn test_axis_add_wrt_x1() {
        let gy = vec![10.0, 20.0, 30.0];
        // non-zero start values show that the gradient is accumulated
        let mut g1 = vec![1.0, 2.0, 3.0];

        axis_add_wrt_x1(&gy, &mut g1).unwrap();

        assert_eq!(g1, vec![11.0, 22.0, 33.0]);
    }

    #[test]
    fn test_axis_add_wrt_x2() {
        let dim = [2, 3, 2, 2];
        let gy = vec![1.0; 24];

        for axis in 0..4 {
            // non-zero start values show that the gradient is accumulated
            let mut g2 = vec![0.5; dim[axis]];

            axis_add_wrt_x2(&gy, &mut g2, dim, axis).unwrap();

            // every slot collects one 1.0 per element sharing its coordinate
            let per_slot = (24 / dim[axis]) as f32;
            for g in g2 {
                assert_eq!(g, 0.5 + per_slot);
            }
        }
    }
}