native_neural_network 0.3.1

A no_std Rust library for native neural networks (.rnn)
Documentation
use crate::scratch::Scratch;
use crate::tensor::TensorView;

/// Bundled arguments for [`conv5d_backward`].
///
/// Axis conventions follow [`conv5d_forward`]:
/// `input`/`dinput` are indexed `(n, c_in, d, h, w)`,
/// `kernel`/`dkernel` are indexed `(c_out, c_in, k_d, k_h, k_w)`, and
/// `doutput` is indexed `(n, c_out, o_d, o_h, o_w)`.
///
/// NOTE(review): every reference shares the single lifetime `'t`, including
/// `&'t mut TensorView<'t>` and `&'t mut Scratch<'t>`. If those types are
/// invariant in their lifetime parameter, the mutable borrows are held for the
/// whole of `'t`. Splitting the lifetimes would be a breaking API change, so
/// it is left as is — confirm whether callers hit borrow-checker friction.
pub struct Conv5dBackwardArgs<'t> {
    /// Input tensor of the forward pass.
    pub input: &'t TensorView<'t>,
    /// Kernel used in the forward pass; `kernel.shape[1]` must equal `input.shape[1]`.
    pub kernel: &'t TensorView<'t>,
    /// Gradient flowing back into the forward output; its batch and channel
    /// dimensions must match `input.shape[0]` and `kernel.shape[0]`.
    pub doutput: &'t TensorView<'t>,
    /// Destination for the input gradient; must have exactly `input`'s shape.
    pub dinput: &'t mut TensorView<'t>,
    /// Destination for the kernel gradient; must have exactly `kernel`'s shape.
    pub dkernel: &'t mut TensorView<'t>,
    /// Optional destination for the bias gradient; needs at least `c_out` elements.
    pub dbias: Option<&'t mut [f32]>,
    /// Zero padding along (depth, height, width).
    pub pad: (usize, usize, usize),
    /// Stride along (depth, height, width); every component must be non-zero.
    pub stride: (usize, usize, usize),
    /// Scratch memory. The visible `conv5d_backward` only calls `base_ptr()` on it
    /// and discards the result; it is not used by the gradient computation.
    pub scratch: &'t mut Scratch<'t>,
}

/// Direct (nested-loop) 3-D convolution over 5-D tensors.
///
/// `input` is indexed `(n, c_in, d, h, w)`, `kernel` `(c_out, c_in, k_d, k_h, k_w)`
/// and `output` `(n, c_out, o_d, o_h, o_w)`. Each output element is assigned
///
/// `bias[co] + Σ_{ci,kd,kh,kw} input[n, ci, od*s_d + kd - p_d, oh*s_h + kh - p_h, ow*s_w + kw - p_w] * kernel[co, ci, kd, kh, kw]`
///
/// where input positions that fall outside the tensor (the padding region) contribute
/// nothing, and positions for which `idx_linear` yields `None` are skipped. When `bias`
/// is `None` the bias term is `0.0`.
///
/// The output spatial extent is taken from `output.shape` as given; it is not checked
/// against the size implied by `pad` and `stride`.
///
/// Returns early, leaving `output` untouched, when any stride component is zero, a tensor
/// reports an invalid layout, the batch/channel dimensions disagree, or `bias` holds fewer
/// than `c_out` values.
pub fn conv5d_forward(
    input: &TensorView<'_>,
    kernel: &TensorView<'_>,
    bias: Option<&[f32]>,
    output: &mut TensorView<'_>,
    pad: (usize, usize, usize),
    stride: (usize, usize, usize),
) {
    /// Input coordinate read by output position `o` at kernel offset `k`, or `None`
    /// when it lies in the padding (before index 0 or at/after `len`).
    fn input_coord(o: usize, stride: usize, k: usize, pad: usize, len: usize) -> Option<usize> {
        (o * stride + k).checked_sub(pad).filter(|&i| i < len)
    }

    if stride.0 == 0 || stride.1 == 0 || stride.2 == 0 {
        return;
    }
    if !input.is_valid_layout() || !kernel.is_valid_layout() || !output.is_valid_layout() {
        return;
    }

    let batch = input.shape[0];
    let cin = input.shape[1];
    let (d, h, w) = (input.shape[2], input.shape[3], input.shape[4]);
    let cout = kernel.shape[0];
    if kernel.shape[1] != cin || output.shape[0] != batch || output.shape[1] != cout {
        return;
    }
    if bias.is_some_and(|b| b.len() < cout) {
        return;
    }

    let (kd_len, kh_len, kw_len) = (kernel.shape[2], kernel.shape[3], kernel.shape[4]);
    let (od_len, oh_len, ow_len) = (output.shape[2], output.shape[3], output.shape[4]);

    for ni in 0..batch {
        for co in 0..cout {
            // In bounds: `bias.len() >= cout` was checked above.
            let bias_co = bias.map_or(0.0, |b| b[co]);
            for od in 0..od_len {
                for oh in 0..oh_len {
                    for ow in 0..ow_len {
                        let mut acc = 0f32;
                        for ci in 0..cin {
                            for kd in 0..kd_len {
                                // The depth coordinate does not depend on kh/kw, so a padded
                                // depth slice is skipped once instead of kh*kw times. The
                                // surviving terms are accumulated in the same order as before.
                                let Some(id) = input_coord(od, stride.0, kd, pad.0, d) else {
                                    continue;
                                };
                                for kh in 0..kh_len {
                                    let Some(ih) = input_coord(oh, stride.1, kh, pad.1, h)
                                    else {
                                        continue;
                                    };
                                    for kw in 0..kw_len {
                                        let Some(iw) = input_coord(ow, stride.2, kw, pad.2, w)
                                        else {
                                            continue;
                                        };
                                        let Some(i_idx) = input.idx_linear(ni, ci, id, ih, iw)
                                        else {
                                            continue;
                                        };
                                        let Some(k_idx) = kernel.idx_linear(co, ci, kd, kh, kw)
                                        else {
                                            continue;
                                        };
                                        acc += input.data[i_idx] * kernel.data[k_idx];
                                    }
                                }
                            }
                        }
                        if let Some(out_i) = output.idx_linear(ni, co, od, oh, ow) {
                            output.data[out_i] = acc + bias_co;
                        }
                    }
                }
            }
        }
    }
}

pub fn conv5d_backward<'a>(args: Conv5dBackwardArgs<'a>) {
    let Conv5dBackwardArgs {
        input,
        kernel,
        doutput,
        dinput,
        dkernel,
        mut dbias,
        pad,
        stride,
        scratch,
    } = args;

    if stride.0 == 0 || stride.1 == 0 || stride.2 == 0 {
        return;
    }
    if !input.is_valid_layout()
        || !kernel.is_valid_layout()
        || !doutput.is_valid_layout()
        || !dinput.is_valid_layout()
        || !dkernel.is_valid_layout()
    {
        return;
    }

    for v in dinput.data.iter_mut() {
        *v = 0.0;
    }
    for v in dkernel.data.iter_mut() {
        *v = 0.0;
    }
    if let Some(b) = dbias.as_mut() {
        for x in b.iter_mut() {
            *x = 0.0;
        }
    }

    let n = input.shape[0];
    let c = input.shape[1];
    let d = input.shape[2];
    let h = input.shape[3];
    let w = input.shape[4];
    let cout = kernel.shape[0];
    if kernel.shape[1] != c {
        return;
    }
    if doutput.shape[0] != n || doutput.shape[1] != cout {
        return;
    }
    if dinput.shape != input.shape {
        return;
    }
    if dkernel.shape != kernel.shape {
        return;
    }
    if let Some(b) = dbias.as_ref() {
        if b.len() < cout {
            return;
        }
    }

    let pad_d = pad.0 as isize;
    let pad_h = pad.1 as isize;
    let pad_w = pad.2 as isize;
    let d_isize = d as isize;
    let h_isize = h as isize;
    let w_isize = w as isize;

    if let Some(b) = dbias.as_mut() {
        for co in 0..cout {
            let mut sum = 0f32;
            for n in 0..n {
                for od in 0..doutput.shape[2] {
                    for oh in 0..doutput.shape[3] {
                        for ow in 0..doutput.shape[4] {
                            let idx = match doutput.idx_linear(n, co, od, oh, ow) {
                                Some(v) => v,
                                None => continue,
                            };
                            sum += doutput.data[idx];
                        }
                    }
                }
            }
            b[co] = sum;
        }
    }

    scratch.base_ptr();

    for co in 0..cout {
        for ci in 0..c {
            for kd in 0..kernel.shape[2] {
                for kh in 0..kernel.shape[3] {
                    for kw in 0..kernel.shape[4] {
                        let mut acc = 0f32;
                        for n in 0..n {
                            for od in 0..doutput.shape[2] {
                                for oh in 0..doutput.shape[3] {
                                    for ow in 0..doutput.shape[4] {
                                        let id = (od * stride.0 + kd) as isize - pad_d;
                                        let ih = (oh * stride.1 + kh) as isize - pad_h;
                                        let iw = (ow * stride.2 + kw) as isize - pad_w;
                                        if id >= 0
                                            && ih >= 0
                                            && iw >= 0
                                            && id < d_isize
                                            && ih < h_isize
                                            && iw < w_isize
                                        {
                                            let idu = id as usize;
                                            let ihu = ih as usize;
                                            let iwu = iw as usize;
                                            let in_idx =
                                                match input.idx_linear(n, ci, idu, ihu, iwu) {
                                                    Some(v) => v,
                                                    None => continue,
                                                };
                                            let dout_idx =
                                                match doutput.idx_linear(n, co, od, oh, ow) {
                                                    Some(v) => v,
                                                    None => continue,
                                                };
                                            acc += input.data[in_idx] * doutput.data[dout_idx];
                                        }
                                    }
                                }
                            }
                        }
                        if let Some(k_idx) = dkernel.idx_linear(co, ci, kd, kh, kw) {
                            dkernel.data[k_idx] = acc;
                        }
                    }
                }
            }
        }
    }

    for n in 0..n {
        for ci in 0..c {
            for id in 0..d {
                for ih in 0..h {
                    for iw in 0..w {
                        let mut acc = 0f32;
                        for co in 0..cout {
                            for kd in 0..kernel.shape[2] {
                                for kh in 0..kernel.shape[3] {
                                    for kw in 0..kernel.shape[4] {
                                        let od_num = (id + pad.0) as isize - kd as isize;
                                        let oh_num = (ih + pad.1) as isize - kh as isize;
                                        let ow_num = (iw + pad.2) as isize - kw as isize;
                                        if od_num < 0 || oh_num < 0 || ow_num < 0 {
                                            continue;
                                        }
                                        let od = od_num as usize;
                                        let oh = oh_num as usize;
                                        let ow = ow_num as usize;
                                        if !od.is_multiple_of(stride.0)
                                            || !oh.is_multiple_of(stride.1)
                                            || !ow.is_multiple_of(stride.2)
                                        {
                                            continue;
                                        }
                                        let od = od / stride.0;
                                        let oh = oh / stride.1;
                                        let ow = ow / stride.2;
                                        if od < doutput.shape[2]
                                            && oh < doutput.shape[3]
                                            && ow < doutput.shape[4]
                                        {
                                            let dout_idx =
                                                match doutput.idx_linear(n, co, od, oh, ow) {
                                                    Some(v) => v,
                                                    None => continue,
                                                };
                                            let k_idx = match kernel.idx_linear(co, ci, kd, kh, kw)
                                            {
                                                Some(v) => v,
                                                None => continue,
                                            };
                                            acc += doutput.data[dout_idx] * kernel.data[k_idx];
                                        }
                                    }
                                }
                            }
                        }
                        if let Some(i_idx) = dinput.idx_linear(n, ci, id, ih, iw) {
                            dinput.data[i_idx] = acc;
                        }
                    }
                }
            }
        }
    }
}