steelix 0.1.0

Your one-stop CLI for ONNX model analysis, featuring graph visualization, FLOP counts, memory metrics, and more!
Documentation
use std::borrow::Cow;
use steelix_onnx::onnx_pb;

use crate::prelude::*;

/// Axis-permutation ("Transpose") operator.
///
/// Instances are created by `build_transpose`, which fills `perm` from the
/// node's `"perm"` attribute.
#[derive(Debug, Clone)]
pub struct Transpose {
    /// Axis permutation: `perm[i]` is the input axis whose extent becomes
    /// output axis `i`.
    perm: Vec<usize>,
}

impl Transpose {
    /// Computes the shape obtained by permuting the axes of `input`.
    ///
    /// Output axis `i` takes the extent of input axis `axes[i]`. The
    /// permutation may have any length; it is no longer restricted to
    /// exactly four axes.
    ///
    /// NOTE(review): the rank of `input` is not compared against
    /// `axes.len()` here, because only indexing into `input.shape` is known
    /// to be available in this file. A permutation shorter than the input
    /// rank would therefore yield a truncated shape — confirm callers always
    /// pass a permutation of matching length.
    ///
    /// # Panics
    ///
    /// Panics if `axes` is not a permutation of `0..axes.len()`, i.e. if an
    /// axis is out of range or is listed more than once.
    fn transpose(&self, input: &Tensor, axes: &[usize]) -> Vec<usize> {
        let rank = axes.len();

        // `axes` has exactly `rank` entries, so "every entry in range" plus
        // "no entry repeated" together imply each axis appears exactly once.
        let mut seen = vec![false; rank];
        for &axis in axes {
            assert!(
                axis < rank,
                "axis {} is out of range for a permutation of {} axes",
                axis,
                rank
            );
            assert!(!seen[axis], "each axis must be listed exactly once");
            seen[axis] = true;
        }

        axes.iter().map(|&axis| input.shape[axis]).collect()
    }
}

impl Op for Transpose {
    /// Returns the operator's display name, `"Transpose"`.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("Transpose")
    }

    /// Classifies this operator under `OpGroup::Shape`.
    fn op_group(&self) -> OpGroup {
        OpGroup::Shape
    }

    /// Produces the output tensor for this operator.
    ///
    /// The first provider supplies the data type and the shape to permute;
    /// the result is wrapped with `RealizedOp::zero_cost`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `validate_providers`.
    fn realize(&self, providers: PVec) -> anyhow::Result<RealizedOp> {
        // NOTE(review): the two `1`s are presumably the minimum and maximum
        // accepted provider counts — confirm against `validate_providers`.
        validate_providers(&providers, 1, 1, &self.name())?;

        let input = &providers[0];
        let permuted = self.transpose(input, &self.perm);
        let output = Tensor::new(input.dt, Shape(permuted.into())).into_arc_tensor();

        Ok(RealizedOp::zero_cost(pvec!(output)))
    }
}

pub fn build_transpose(proto: &onnx_pb::NodeProto) -> Result<BoxOp, anyhow::Error> {
    let perm: Vec<usize> = proto
        .get_attribute::<Vec<i64>>("perm", None)?
        .iter()
        .cloned()
        .map(|x| x as usize)
        .collect();
    Ok(Box::new(Transpose { perm }) as BoxOp)
}