1
2use itertools::izip;
3use crate::error::BMLSError;
4use crate::error;
5
6#[inline]
13pub fn axis_add(
14 x1: &[f32],
15 x2: &[f32],
16 y: &mut [f32],
17 dim: [usize; 4],
18 axis: usize,
19) -> Result<(), BMLSError> {
20 let len = dim[0]*dim[1]*dim[2]*dim[3];
22
23 if y.len() != len {
24 return error::length_mismatch("Y", y.len(), "Dim", len)
25 }
26
27 if x1.len() != y.len() {
28 return error::length_mismatch("X1", x1.len(), "Y", y.len())
29 }
30
31 if x2.len() != dim[axis] {
32 return error::axis_mismatch(0, "X2", x2.len(), axis, "Y", dim[axis])
33 }
34
35 let cptr = y;
36 for n in 0..dim[0] {
37 for c in 0..dim[1] {
38 for h in 0..dim[2] {
39 for w in 0..dim[3] {
40 let i = n * dim[1] * dim[2] * dim[3] + c * dim[2] * dim[3] + h * dim[3] + w;
41
42 let indices = [n, c, h, w];
43
44 cptr[i] = x1[i] + x2[indices[axis]];
45 }
46 }
47 }
48 }
49
50 Ok(())
51}
52
53#[inline]
54pub fn axis_add_wrt_x1(
55 gy: &[f32],
56 g1: &mut [f32],
57) -> Result<(), BMLSError> {
58 if gy.len() != g1.len() {
59 return error::length_mismatch("GY", gy.len(), "G1", g1.len())
60 }
61
62 for (gy, g1) in izip!(gy, g1) {
63 *g1 += *gy
64 }
65
66 Ok(())
67}
68
69#[inline]
70pub fn axis_add_wrt_x2(
71 gy: &[f32],
72 g2: &mut [f32],
73 dim: [usize; 4],
74 axis: usize,
75) -> Result<(), BMLSError> {
76 if g2.len() != dim[axis] {
77 return error::axis_mismatch(0, "G2", g2.len(), axis, "Y", dim[axis])
78 }
79
80 let len = dim[0]*dim[1]*dim[2]*dim[3];
81 if gy.len() != len {
82 return error::length_mismatch("GY", gy.len(), "Dim", len)
83 }
84
85 for n in 0..dim[0] {
86 for c in 0..dim[1] {
87 for h in 0..dim[2] {
88 for w in 0..dim[3] {
89 let i = n * dim[1] * dim[2] * dim[3] + c * dim[2] * dim[3] + h * dim[3] + w;
90 let indices = [n, c, h, w];
91
92 g2[indices[axis]] += gy[i]
93 }
94 }
95 }
96 }
97
98 Ok(())
99}
100
#[cfg(test)]
mod test {
    use super::*;

    /// Adding `[1, 2, 3, 4]` along the last axis of an all-zero `[1, 4, 4, 4]`
    /// buffer must repeat that vector across every row of the output.
    #[test]
    fn test_axis_add() {
        let a = vec![0.0; 64];
        let b = (1..5).map(|i| i as f32).collect::<Vec<f32>>();
        let mut c = vec![0.0; 64];

        axis_add(&a, &b, &mut c, [1, 4, 4, 4], 3).unwrap();

        // The last axis has length 4, so the flat index modulo 4 is the
        // position along that axis.
        for (i, &value) in c.iter().enumerate() {
            assert_eq!(value, b[i % 4], "unexpected value at flat index {}", i);
        }
    }

    /// `axis_add_wrt_x1` adds `gy` into `g1` element by element and rejects
    /// buffers of different lengths.
    #[test]
    fn test_axis_add_wrt_x1() {
        let gy = vec![1.0, 2.0, 3.0];
        let mut g1 = vec![10.0, 20.0, 30.0];

        axis_add_wrt_x1(&gy, &mut g1).unwrap();
        assert_eq!(g1, vec![11.0, 22.0, 33.0]);

        let mut short = vec![0.0; 2];
        assert!(axis_add_wrt_x1(&gy, &mut short).is_err());
    }

    /// With an all-ones `[1, 4, 4, 4]` upstream buffer, each slot along the
    /// last axis receives 1 * 4 * 4 = 16 contributions of 1.0.
    #[test]
    fn test_axis_add_wrt_x2() {
        let gy = vec![1.0; 64];
        let mut g2 = vec![0.0; 4];

        axis_add_wrt_x2(&gy, &mut g2, [1, 4, 4, 4], 3).unwrap();
        assert_eq!(g2, vec![16.0; 4]);
    }
}