onnx_graph 0.1.2

ONNX graph parser and execution engine for deep neural networks
Documentation
use std::any::Any;

use crate::{
    nodes::{node::Node, onnx_operation_trait::FromOnnxOperation, unique_ids::UniqueId},
    tensor_map::TensorMap,
    typed_array::TypedArray,
};

use anyhow::{Ok, Result};
use ndarray::{ArrayD, IxDyn};
use onnx_extractor::OnnxOperation;

#[derive(Default)]
pub struct GatherNode<T: Default> {
    data: String,
    indices: String,
    o: String,

    axis: i64,

    unique_id: UniqueId,
    next_node: Option<Vec<Box<dyn Node<T>>>>,
}

impl<T: Default> FromOnnxOperation for GatherNode<T> {
    fn from_onnx_operation(elem: &OnnxOperation) -> Result<Self> {
        let attrs = &elem.attributes;
        let mut gather = Self {
            data: String::new(),
            indices: String::new(),
            o: String::new(),
            axis: attrs.get("axis").and_then(|v| v.as_int()).unwrap_or(0),
            unique_id: UniqueId::Gather,
            next_node: None,
        };
        gather.add_input_strings(elem.inputs[0].clone(), elem.inputs[1].clone());
        gather.add_output_strings(elem.outputs[0].clone());
        Ok(gather)
    }
}

impl<T: Default> GatherNode<T> {
    pub fn add_input_strings(&mut self, data: String, indices: String) {
        self.data = data;
        self.indices = indices;
    }

    pub fn add_output_strings(&mut self, o: String) {
        self.o = o;
    }
}

impl<T: Default + 'static> Node<T> for GatherNode<T> {
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    fn get_unique_id(&self) -> UniqueId {
        self.unique_id
    }

    fn get_unique_id_mut(&mut self) -> UniqueId {
        self.unique_id
    }

    fn take_next(&mut self) -> Option<Vec<Box<dyn Node<T>>>> {
        self.next_node.take()
    }

    fn get_next_mut(&mut self) -> Option<&mut Vec<Box<dyn Node<T>>>> {
        self.next_node.as_mut()
    }

    fn set_next(&mut self, next: Option<Vec<Box<dyn Node<T>>>>) {
        self.next_node = next;
    }

    fn input_names(&self) -> Vec<String> {
        vec![self.data.clone(), self.indices.clone()]
    }

    fn output_names(&self) -> Vec<String> {
        vec![self.o.clone()]
    }

    fn get_next(&self) -> Option<&Vec<Box<dyn Node<T>>>> {
        self.next_node.as_ref()
    }

    fn execute(&self, omap: &mut TensorMap) {
        let [data, indices, o] = omap.get_disjoint_mut([&self.data, &self.indices, &self.o]);
        let data = &*data.unwrap();
        let indices = &*indices.unwrap();

        match o {
            Some(result) => {
                TypedArray::gather(data, indices, self.axis, result).unwrap();
            }
            _ => panic!("GatherNode: missing output {}", self.o),
        }
    }

    fn print(&self) {
        if let Some(list) = &self.next_node {
            print!("{}-", list.len());
        }
        println!(
            "gather-{},{},{} axis={}",
            self.data, self.indices, self.o, self.axis
        );
        if let Some(next) = &self.next_node {
            next.iter().for_each(|v| v.print());
        }
    }

    fn determine_output_shape(&mut self, omap: &mut TensorMap) {
        let [data, indices, o] = omap.get_disjoint_mut([&self.data, &self.indices, &self.o]);
        let data = data.map(|arr| &*arr);
        let indices = indices.map(|arr| &*arr);

        if let (Some(data), Some(indices), Some(o)) = (data, indices, o)
            && let (Some(data_shape), Some(idx_shape)) = (data.shape(), indices.shape())
        {
            let ndim = data_shape.len() as i64;
            let axis = if self.axis < 0 {
                (ndim + self.axis) as usize
            } else {
                self.axis as usize
            };

            let is_scalar_index = idx_shape.is_empty() || idx_shape == [1];

            let mut out_shape: Vec<usize> = Vec::new();
            for i in data_shape.iter().take(axis) {
                out_shape.push(*i);
            }
            if !is_scalar_index {
                for &s in idx_shape {
                    out_shape.push(s);
                }
            }
            for i in data_shape.iter().skip(axis + 1) {
                out_shape.push(*i);
            }

            *o = TypedArray::empty_with_others_type(data, &out_shape);
        }

        if let Some(list) = &mut self.next_node {
            for next in list {
                next.determine_output_shape(omap);
            }
        }
    }
}

macro_rules! call_gather_for_typed_array {
    ($data:expr, $axis:expr, $idx_vec:expr, $idx_shape:expr, $o:expr, [$($variant:ident),+]) => {

        match $data {
            $(
                TypedArray::$variant(arr) => gather_variant!($variant, $axis, $idx_vec, $idx_shape, arr, $o),
            )+
            TypedArray::Bool(arr) => {
                let ndim = arr.ndim() as i64;
                let axis_usize = if $axis < 0 {
                    (ndim + $axis) as usize
                } else {
                    $axis as usize
                };

                let data_shape = arr.shape();
                let axis_size = data_shape[axis_usize] as i64;

                let is_scalar_index = $idx_shape.is_empty() || $idx_shape == [1usize];

                let mut out_shape: Vec<usize> = Vec::new();
                for i in 0..axis_usize {
                    out_shape.push(data_shape[i]);
                }
                if !is_scalar_index {
                    for &s in &($idx_shape) {
                        out_shape.push(s);
                    }
                }
                for i in (axis_usize + 1)..data_shape.len() {
                    out_shape.push(data_shape[i]);
                }


                let needs_alloc = match &*($o) {
                    TypedArray::Bool(out) => out.shape() != out_shape.as_slice(),
                    _ => true,
                };

                if needs_alloc {
                    *($o) = TypedArray::Bool(ArrayD::from_elem(IxDyn(&out_shape),false));
                }

                let out_arr = match $o {
                    TypedArray::Bool(arr) => arr,
                    _ => unreachable!(),
                };

                let data_sl = arr.as_slice_memory_order().unwrap();
                let out_sl = out_arr.as_slice_memory_order_mut().unwrap();

                let outer_size: usize = data_shape[..axis_usize].iter().product();
                let inner_size: usize = data_shape[axis_usize + 1..].iter().product();
                let axis_dim = data_shape[axis_usize];

                let mut out_idx = 0;
                for outer in 0..outer_size.max(1) {
                    for &idx in &($idx_vec) {
                        let idx = if idx < 0 {
                            (axis_size + idx) as usize
                        } else {
                            idx as usize
                        };

                        let src_offset = outer * axis_dim * inner_size + idx * inner_size;
                        let len = inner_size.max(1);

                        out_sl[out_idx..out_idx + len]
                            .copy_from_slice(&data_sl[src_offset..src_offset + len]);
                        out_idx += len;
                    }
                }
            }
            _ => return Err(anyhow::anyhow!("argmax: unsupported type")),
        }
    };
}

macro_rules! gather_variant {
    ($variant:ident, $axis:expr, $idx_vec:expr, $idx_shape:expr, $arr:expr, $o:expr) => {{
        let ndim = $arr.ndim() as i64;
        let axis_usize = if $axis < 0 {
            (ndim + $axis) as usize
        } else {
            $axis as usize
        };

        let data_shape = $arr.shape();
        let axis_size = data_shape[axis_usize] as i64;

        let is_scalar_index = $idx_shape.is_empty() || $idx_shape == [1usize];

        let mut out_shape: Vec<usize> = Vec::new();
        for i in 0..axis_usize {
            out_shape.push(data_shape[i]);
        }
        if !is_scalar_index {
            for &s in &($idx_shape) {
                out_shape.push(s);
            }
        }
        for i in (axis_usize + 1)..data_shape.len() {
            out_shape.push(data_shape[i]);
        }

        let needs_alloc = match &*($o) {
            TypedArray::$variant(out) => out.shape() != out_shape.as_slice(),
            _ => true,
        };

        if needs_alloc {
            *($o) = TypedArray::$variant(ArrayD::zeros(IxDyn(&out_shape)));
        }

        let out_arr = match $o {
            TypedArray::$variant(arr) => arr,
            _ => unreachable!(),
        };

        let data_sl = $arr.as_slice_memory_order().unwrap();
        let out_sl = out_arr.as_slice_memory_order_mut().unwrap();

        let outer_size: usize = data_shape[..axis_usize].iter().product();
        let inner_size: usize = data_shape[axis_usize + 1..].iter().product();
        let axis_dim = data_shape[axis_usize];

        let mut out_idx = 0;
        for outer in 0..outer_size.max(1) {
            for &idx in &($idx_vec) {
                let idx = if idx < 0 {
                    (axis_size + idx) as usize
                } else {
                    idx as usize
                };

                let src_offset = outer * axis_dim * inner_size + idx * inner_size;
                let len = inner_size.max(1);

                out_sl[out_idx..out_idx + len]
                    .copy_from_slice(&data_sl[src_offset..src_offset + len]);
                out_idx += len;
            }
        }
    }};
}

impl TypedArray {
    pub fn gather(
        data: &TypedArray,
        indices: &TypedArray,
        axis: i64,
        o: &mut TypedArray,
    ) -> anyhow::Result<()> {
        let idx_vec: Vec<i64> = match indices {
            TypedArray::Int64(arr) => arr.iter().copied().collect(),
            TypedArray::Int32(arr) => arr.iter().map(|&v| v as i64).collect(),
            _ => return Err(anyhow::anyhow!("Gather: indices must be I32 or I64")),
        };
        let idx_shape: Vec<usize> = match indices {
            TypedArray::Int64(arr) => arr.shape().to_vec(),
            TypedArray::Int32(arr) => arr.shape().to_vec(),
            _ => unreachable!(),
        };

        call_gather_for_typed_array!(
            data,
            axis,
            idx_vec,
            idx_shape,
            o,
            [Float, Double, Int32, Int64, Uint8, Uint16, Uint32, Uint64]
        );

        Ok(())
    }
}