steelix 0.1.0

Your one-stop CLI for ONNX model analysis, featuring graph visualization, FLOP counts, memory metrics, and more!
Documentation
use crate::ir::OpError;
use crate::prelude::*;
use std::borrow::Cow;
use steelix_onnx::onnx_pb;

/// Operator state for `Concat`.
///
/// Holds the single piece of configuration the operator needs: the axis that
/// is forwarded to `Tensor::stack_tensors` when the output shape is computed.
#[derive(Clone, Debug)]
pub struct Concat {
    /// Axis value as read from the node's `axis` attribute (signed, unmodified).
    axis: i64,
}

impl Concat {
    /// Computes the shape that results from stacking `providers` along this
    /// operator's axis.
    ///
    /// The work is delegated to `Tensor::stack_tensors`; only the `shape` of the
    /// tensor it returns is kept.
    ///
    /// # Errors
    ///
    /// Propagates any `OpError` produced by `Tensor::stack_tensors`.
    pub fn concat(&self, providers: &PVec) -> Result<Shape, OpError> {
        // NOTE(review): `as usize` wraps negative values to very large indices.
        // ONNX allows negative axes for Concat (counted from the back) — confirm
        // whether they are normalized elsewhere or need handling here.
        let axis = self.axis as usize;
        let stacked = Tensor::stack_tensors(axis, providers)?;
        Ok(stacked.shape)
    }
}

impl Op for Concat {
    fn name(&self) -> Cow<str> {
        "Concat".into()
    }

    fn op_group(&self) -> OpGroup {
        OpGroup::Shape
    }

    fn realize(&self, providers: PVec) -> anyhow::Result<RealizedOp> {
        validate_providers(&providers, 1, 2, &self.name())?;
        let new_shape = self.concat(&providers)?;

        Ok(RealizedOp::zero_cost(pvec!(Tensor::new(
            providers[0].dt,
            new_shape
        )
        .into_arc_tensor())))
    }
}

/// Builds a boxed [`Concat`] operator from an ONNX node.
///
/// Reads the node's `"axis"` attribute (passing `None` as the second argument
/// to `get_attribute`) and stores it unchanged in the operator.
///
/// # Errors
///
/// Returns the error from `get_attribute` if the attribute lookup fails.
pub fn build_concat(proto: &onnx_pb::NodeProto) -> Result<BoxOp, anyhow::Error> {
    let axis: i64 = proto.get_attribute("axis", None)?;
    let op = Concat { axis };
    Ok(Box::new(op) as BoxOp)
}