// bmls/sub.rs

2use itertools::izip;
3use crate::error::BMLSError;
4use crate::error;
5
6/// # Subtraction Operation
7/// - X1: Left Operand
8/// - X2: Right Operand
9/// - Y: Output
10#[inline]
11pub fn sub(
12    x1: &[f32],
13    x2: &[f32],
14    y: &mut [f32],
15) -> Result<(), BMLSError> {
16    if x1.len() != x2.len() {
17        return error::length_mismatch("X2", x1.len(), "X2", x2.len())
18    }
19
20    if x2.len() != y.len() {
21        return error::length_mismatch("X2", x2.len(), "Y", y.len())
22    }
23
24    for (x1, x2, y) in izip!(x1, x2, y) {
25        *y = *x1 - *x2; 
26    }
27
28    Ok(())
29}
30
31/// # Subtraction W.r.t. X1
32/// - GY: Gradient w.r.t. Output Y
33/// - G1: Gradient W.r.t. Input X1 
34#[inline]
35pub fn sub_wrt_x1(
36    gy: &[f32],
37    g1: &mut [f32]
38) -> Result<(), BMLSError> {
39    if gy.len() != g1.len() {
40        return error::length_mismatch("GY", gy.len(), "G1", g1.len())
41    }
42
43    for (gy, g1) in izip!(gy, g1) {
44        *g1 += *gy;
45    }
46
47    Ok(())
48}
49
50/// # Subtraction w.r.t. X2
51/// - GY: Gradient w.r.t. Output Y
52/// - G2: Gradient w.r.t. Input X2 
53#[inline]
54pub fn sub_wrt_x2(
55    gy: &[f32],
56    g2: &mut [f32]
57) -> Result<(), BMLSError> {
58    if gy.len() != g2.len() {
59        return error::length_mismatch("GY", gy.len(), "G2", g2.len())
60    }
61
62    for (gy, g2) in izip!(gy, g2) {
63        *g2 -= *gy;
64    }
65
66    Ok(())
67}