onnx_graph 0.1.0

ONNX graph parser and execution engine for deep neural networks
Documentation
use std::{any::Any, cell::RefCell, collections::HashMap, str::FromStr};

use crate::{
    nodes::{hash_trait::FromHashMap, node::Node, unique_ids::UniqueId},
    tensor_map::TensorMap,
    typed_array::TypedArray,
};
use anyhow::Context;
use anyhow::Result;
use ndarray::Ix4;
use ndarray::{ArrayView4, ArrayViewMut4};
use onnx_extractor::AttributeValue;
use rayon::iter::IndexedParallelIterator;
use rayon::iter::ParallelIterator;
use rayon::prelude::ParallelSliceMut;

/// ONNX `auto_pad` setting for pooling operators.
///
/// The variants correspond to the ONNX attribute strings `NOTSET`, `SAME_UPPER`,
/// `SAME_LOWER` and `VALID`. Parsing is done by the `FromStr` impl below, which maps
/// any unrecognised string to `NOTSET`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AutoPad {
    /// ONNX `NOTSET` (the default when the attribute is absent).
    #[default]
    NOTSET,
    /// ONNX `SAME_UPPER`.
    SameUpper,
    /// ONNX `SAME_LOWER`.
    SameLower,
    /// ONNX `VALID`.
    VALID,
}

impl FromStr for AutoPad {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        Ok(match s {
            "SAME_UPPER" => Self::SameUpper,
            "SAME_LOWER" => Self::SameLower,
            "VALID" => Self::VALID,
            _ => Self::NOTSET,
        })
    }
}

/// Graph node for the ONNX `MaxPool` operator.
///
/// Stores the operator attributes, the names of its input (`x`) and output (`o`)
/// tensors (looked up in a [`TensorMap`] when the node executes), and its successor
/// nodes in the graph.
#[derive(Default)]
pub struct MaxPoolNode<T: Default> {
    /// Name of the input tensor (set via `add_input_strings`).
    x: String,

    /// Name of the output tensor (set via `add_output_strings`).
    o: String,

    /// Operator tag; `UniqueId::MaxPool` for nodes built by `new` / `from_hashmap`.
    unique_id: UniqueId,

    /// Parsed `auto_pad` attribute.
    /// NOTE(review): not read by `execute` in this file; only explicit `pads` are used.
    auto_pad: AutoPad,
    /// `ceil_mode` attribute; nonzero is forwarded to `max_pool` as `true`.
    ceil_mode: i64,
    /// Pooling window size per spatial axis (empty if the attribute was absent).
    kernel_shape: Vec<usize>,
    /// Per-axis dilations; empty if absent (the fast path treats missing entries as 1).
    dilations: Vec<usize>,
    /// Per-axis strides; empty if absent (the fast path treats missing entries as 1).
    strides: Vec<usize>,
    /// Padding amounts, presumably in ONNX order `[begin..., end...]` — TODO confirm
    /// against `TypedArray::max_pool`. Empty if absent.
    pads: Vec<usize>,
    /// `storage_order` attribute. NOTE(review): not read by `execute` in this file.
    storage_order: usize,
    /// Downstream nodes, or `None` for a leaf.
    next_node: Option<Vec<Box<dyn Node<T>>>>,
}

impl<T: Default> FromHashMap for MaxPoolNode<T> {
    fn from_hashmap(
        attrs: &std::collections::HashMap<String, AttributeValue>,
    ) -> anyhow::Result<Self> {
        Ok(Self {
            x: String::new(),
            o: String::new(),
            auto_pad: {
                match attrs.get("auto_pad") {
                    Some(av) => {
                        let pad = av.as_string().unwrap();
                        AutoPad::from_str(pad).unwrap()
                    }
                    None => AutoPad::NOTSET,
                }
            },
            kernel_shape: {
                match attrs.get("kernel_shape") {
                    Some(av) => av
                        .as_ints()
                        .unwrap()
                        .iter()
                        .map(|&val| val as usize)
                        .collect(),
                    None => vec![],
                }
            },
            pads: {
                match attrs.get("pads") {
                    Some(av) => av
                        .as_ints()
                        .unwrap()
                        .iter()
                        .map(|&val| val as usize)
                        .collect(),
                    None => vec![],
                }
            },
            strides: {
                match attrs.get("strides") {
                    Some(av) => av
                        .as_ints()
                        .unwrap()
                        .iter()
                        .map(|&val| val as usize)
                        .collect(),
                    None => vec![],
                }
            },
            dilations: {
                match attrs.get("dilations") {
                    Some(av) => av
                        .as_ints()
                        .unwrap()
                        .iter()
                        .map(|&val| val as usize)
                        .collect(),
                    None => vec![],
                }
            },
            ceil_mode: {
                match attrs.get("ceil_mode") {
                    Some(av) => av.as_int().unwrap(),
                    None => 0,
                }
            },
            storage_order: {
                match attrs.get("storage_order") {
                    Some(av) => av.as_int().unwrap().to_owned() as usize,
                    None => 0,
                }
            },
            unique_id: UniqueId::MaxPool,
            next_node: None,
        })
    }
}

impl<T: Default> MaxPoolNode<T> {
    pub fn new(
        auto_pad: &str,
        ceil_mode: i64,
        kernel_shape: Vec<usize>,
        dilations: Vec<usize>,
        strides: Vec<usize>,
        storage_order: usize,
        pads: Vec<usize>,
    ) -> Self {
        Self {
            x: String::new(),
            o: String::new(),
            auto_pad: AutoPad::from_str(auto_pad).unwrap(),
            ceil_mode,
            kernel_shape,
            dilations,
            strides,
            pads,
            storage_order,
            unique_id: UniqueId::MaxPool,
            next_node: None,
        }
    }

    pub fn add_input_strings(&mut self, x: String) {
        self.x = x;
    }

    pub fn add_output_strings(&mut self, o: String) {
        self.o = o;
    }
}

/// Tries to run `MaxPool` with one of the specialised stride-1 "same" kernels.
///
/// The fast path is taken only when all of the following hold:
/// * the kernel is 2-D, square, and of side 3, 5, 9 or 13;
/// * strides and dilations are 1 (missing entries default to 1);
/// * every pad equals `kernel / 2`. Pads are read as `[h_begin, w_begin, h_end, w_end]`;
///   missing begin pads default to 0, and missing end pads default to the begin pads;
/// * both tensors are `f32`, 4-D, in standard (C-contiguous) layout, and of identical
///   shape (which the kernels assume).
///
/// Returns `Ok(true)` if `o` was written, or `Ok(false)` if the caller must fall back to
/// the generic implementation.
///
/// # Errors
/// Returns an error if an `f32` tensor cannot be viewed as 4-D.
pub fn maxpool_fast(
    input: &TypedArray,
    kernel: &[usize],
    strides: &[usize],
    pads: &[usize],
    dilations: &[usize],
    o: &mut TypedArray,
) -> anyhow::Result<bool> {
    // Only 2-D kernels are supported; a short or missing kernel_shape must not panic.
    let &[kh, kw] = kernel else {
        return Ok(false);
    };
    let sh = strides.first().copied().unwrap_or(1);
    let sw = strides.get(1).copied().unwrap_or(1);
    let dh = dilations.first().copied().unwrap_or(1);
    let dw = dilations.get(1).copied().unwrap_or(1);
    let ph_begin = pads.first().copied().unwrap_or(0);
    let pw_begin = pads.get(1).copied().unwrap_or(0);
    let ph_end = pads.get(2).copied().unwrap_or(ph_begin);
    let pw_end = pads.get(3).copied().unwrap_or(pw_begin);

    // The kernels implement symmetric `k / 2` padding only; asymmetric end pads
    // change the output size and must go through the generic path.
    let eligible = kh == kw
        && sh == 1
        && sw == 1
        && dh == 1
        && dw == 1
        && [ph_begin, pw_begin, ph_end, pw_end]
            .iter()
            .all(|&p| p == kh / 2);
    if !eligible {
        return Ok(false);
    }

    let (TypedArray::F32(x), TypedArray::F32(out)) = (input, &mut *o) else {
        return Ok(false);
    };

    let x4 = x.view().into_dimensionality::<Ix4>()?;
    let mut out4 = out.view_mut().into_dimensionality::<Ix4>()?;

    // The kernels index the raw buffers plane by plane in row-major NCHW order and
    // write exactly one output per input element.
    if x4.dim() != out4.dim() || !x4.is_standard_layout() || !out4.is_standard_layout() {
        return Ok(false);
    }

    match kh {
        3 => maxpool_3x3_mut(&x4, &mut out4),
        5 => maxpool_5x5_mut(&x4, &mut out4),
        9 => maxpool_9x9_mut(&x4, &mut out4),
        13 => maxpool_13x13_mut(&x4, &mut out4),
        _ => return Ok(false),
    }

    Ok(true)
}

thread_local! {
    /// Per-thread scratch plane for the separable max-pool kernels below. It persists
    /// between calls so each thread reuses its allocation.
    static POOL_TMP: RefCell<Vec<f32>> = const { RefCell::new(Vec::new()) };
}

/// Generates `pub fn $name(input, output)`, a stride-1, dilation-1, `$k`x`$k` max pool
/// applied to every `H`x`W` plane of a 4-D `f32` tensor, producing an output of the same
/// shape.
///
/// Window cells outside the plane are skipped, which is equivalent to padding every side
/// with `$k / 2` cells of `-inf`. The window max is computed separably (rows first, then
/// columns) through a per-thread scratch plane, and planes are processed in parallel.
macro_rules! impl_maxpool_nxn {
    ($name:ident, $k:expr) => {
        /// Stride-1 "same" max pooling with the square kernel size passed to
        /// `impl_maxpool_nxn!`; writes one output value per input value.
        ///
        /// # Panics
        /// Panics if `input` and `output` differ in shape, or if either is not in standard
        /// (C-contiguous) layout.
        pub fn $name(input: &ArrayView4<f32>, output: &mut ArrayViewMut4<f32>) {
            const K: usize = $k;
            const HALF: usize = K / 2;

            /// Running max starting at `-inf`; a NaN never wins the `>` comparison.
            fn window_max(vals: &[f32]) -> f32 {
                vals.iter()
                    .fold(f32::NEG_INFINITY, |acc, &v| if v > acc { v } else { acc })
            }

            assert_eq!(
                input.dim(),
                output.dim(),
                "maxpool: input and output shapes must match"
            );
            let (_, _, h, w) = input.dim();
            let hw = h * w;
            if hw == 0 {
                // Nothing to pool; also avoids `par_chunks_mut(0)`, which panics.
                return;
            }

            // `as_slice` (unlike `as_slice_memory_order`) only succeeds for row-major
            // buffers, which the plane-by-plane chunking below relies on.
            let in_sl = input
                .as_slice()
                .expect("maxpool: input must be in standard (C-contiguous) layout");
            let out_sl = output
                .as_slice_mut()
                .expect("maxpool: output must be in standard (C-contiguous) layout");

            out_sl
                .par_chunks_mut(hw)
                .enumerate()
                .for_each(|(plane, out_plane)| {
                    let in_plane = &in_sl[plane * hw..(plane + 1) * hw];

                    POOL_TMP.with(|cell| {
                        let mut tmp = cell.borrow_mut();
                        tmp.resize(hw, f32::NEG_INFINITY);

                        // Horizontal pass: tmp[y][x] = max(in[y][x - HALF ..= x + HALF]).
                        for (in_row, tmp_row) in
                            in_plane.chunks_exact(w).zip(tmp.chunks_exact_mut(w))
                        {
                            for (x, dst) in tmp_row.iter_mut().enumerate() {
                                let x0 = x.saturating_sub(HALF);
                                let x1 = (x + HALF).min(w - 1);
                                *dst = window_max(&in_row[x0..=x1]);
                            }
                        }

                        // Vertical pass: out[y][x] = max(tmp[y - HALF ..= y + HALF][x]),
                        // accumulated row by row in ascending order.
                        for (y, out_row) in out_plane.chunks_exact_mut(w).enumerate() {
                            let y0 = y.saturating_sub(HALF);
                            let y1 = (y + HALF).min(h - 1);
                            out_row.fill(f32::NEG_INFINITY);
                            for tmp_row in tmp[y0 * w..(y1 + 1) * w].chunks_exact(w) {
                                for (dst, &t) in out_row.iter_mut().zip(tmp_row) {
                                    if t > *dst {
                                        *dst = t;
                                    }
                                }
                            }
                        }
                    });
                });
        }
    };
}

impl_maxpool_nxn!(maxpool_3x3_mut, 3);
impl_maxpool_nxn!(maxpool_5x5_mut, 5);
impl_maxpool_nxn!(maxpool_9x9_mut, 9);
impl_maxpool_nxn!(maxpool_13x13_mut, 13);

impl<T: Default + 'static> Node<T> for MaxPoolNode<T> {
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    fn get_unique_id(&self) -> UniqueId {
        self.unique_id
    }
    fn get_unique_id_mut(&mut self) -> UniqueId {
        self.unique_id
    }

    fn take_next(&mut self) -> Option<Vec<Box<dyn Node<T>>>> {
        self.next_node.take()
    }
    fn get_next_mut(&mut self) -> Option<&mut Vec<Box<dyn Node<T>>>> {
        self.next_node.as_mut()
    }

    fn set_next(&mut self, next: Option<Vec<Box<dyn Node<T>>>>) {
        self.next_node = next;
    }

    fn get_next(&self) -> Option<&Vec<Box<dyn Node<T>>>> {
        self.next_node.as_ref()
    }

    /// Names of the tensors this node reads: its single input `x`.
    fn input_names(&self) -> Vec<String> {
        // The output name `o` belongs to `output_names`, not here.
        vec![self.x.clone()]
    }

    /// Max-pools tensor `x` into the existing tensor `o` in `omap`.
    ///
    /// First tries the specialised kernels in `maxpool_fast`. If they do not apply (or
    /// report an error), it falls back to `TypedArray::max_pool`.
    ///
    /// # Panics
    /// Panics if either tensor is missing from `omap`, or if `max_pool` fails.
    fn execute(&self, omap: &mut TensorMap) {
        let [x, o] = omap.get_disjoint_mut([&self.x, &self.o]);
        let x = &*x.unwrap_or_else(|| panic!("MaxPoolNode: missing input {}", self.x));
        let result = o.unwrap_or_else(|| panic!("MaxPoolNode: missing output {}", self.o));

        // The attribute vectors are borrowed directly; no per-call copies are needed.
        let handled = maxpool_fast(
            x,
            &self.kernel_shape,
            &self.strides,
            &self.pads,
            &self.dilations,
            result,
        )
        .unwrap_or(false);

        if !handled {
            x.max_pool(
                &self.kernel_shape,
                &self.strides,
                &self.pads,
                &self.dilations,
                self.ceil_mode != 0,
                result,
            )
            .expect("MaxPoolNode: max_pool failed");
        }
    }

    /// Names of the tensors this node writes: its single output `o`.
    fn output_names(&self) -> Vec<String> {
        vec![self.o.clone()]
    }

    /// Prints this node (prefixed by its child count, if it has children), then its
    /// children in order.
    fn print(&self) {
        if let Some(list) = &self.next_node {
            print!("{}-", list.len());
        }
        println!("maxpool-{},{}", self.x, self.o);

        if let Some(next) = &self.next_node {
            for child in next {
                child.print();
            }
        }
    }

    /// Returns `count` for a leaf. Otherwise returns the sum of `child.self_count(i)`,
    /// where `i` is each child's position among its siblings.
    /// NOTE(review): the incoming `count` is not forwarded to children; confirm intended.
    fn self_count(&self, count: usize) -> usize {
        match &self.next_node {
            Some(next) => next
                .iter()
                .enumerate()
                .map(|(i, child)| child.self_count(i))
                .sum(),
            None => count,
        }
    }

    /// Appends `next` below this node. If children exist, it is inserted recursively
    /// into the first child; otherwise it becomes the only child.
    fn insert(&mut self, next: Box<dyn Node<T>>) -> Result<()> {
        if let Some(children) = &mut self.next_node {
            // An empty (but present) child list would otherwise panic on `[0]`.
            match children.first_mut() {
                Some(first) => first.insert(next)?,
                None => children.push(next),
            }
        } else {
            self.next_node = Some(vec![next]);
        }
        Ok(())
    }
}