webnn-onnx-utils 0.1.1

Shared utilities for ONNX <-> WebNN conversion (types, op names, attrs, tensor data, etc.)
Documentation
use serde_json::Value as JsonValue;

use crate::error::{ConversionError, Result};
use crate::protos::onnx::{AttributeProto, attribute_proto};

/// An attribute value: a scalar or list of integers or floats, or a string.
///
/// `PartialEq` is derived so values can be compared directly (for example with
/// `assert_eq!`). Float payloads compare with IEEE-754 semantics, so a `NaN`
/// never equals itself.
#[derive(Debug, Clone, PartialEq)]
pub enum AttrValue {
    /// A single signed 64-bit integer.
    Int(i64),
    /// A single 32-bit float.
    Float(f32),
    /// An owned UTF-8 string.
    String(String),
    /// A list of signed 64-bit integers.
    Ints(Vec<i64>),
    /// A list of 32-bit floats.
    Floats(Vec<f32>),
}

/// Typed, read-only lookup over a borrowed slice of `AttributeProto` values.
///
/// The parser only holds a reference; the attributes stay owned by the caller.
pub struct AttrParser<'a> {
    /// The attributes searched by the `get_*` accessors, in the order supplied.
    attrs: &'a [AttributeProto],
}

impl<'a> AttrParser<'a> {
    pub fn new(attrs: &'a [AttributeProto]) -> Self {
        Self { attrs }
    }

    pub fn get_int(&self, name: &str) -> Option<i64> {
        let a = self.attrs.iter().find(|a| a.name == name)?;
        if a.r#type == attribute_proto::AttributeType::Int as i32 {
            Some(a.i)
        } else {
            None
        }
    }

    pub fn get_ints(&self, name: &str) -> Option<Vec<i64>> {
        let a = self.attrs.iter().find(|a| a.name == name)?;
        if a.r#type == attribute_proto::AttributeType::Ints as i32 && !a.ints.is_empty() {
            Some(a.ints.clone())
        } else {
            None
        }
    }

    pub fn get_float(&self, name: &str) -> Option<f32> {
        let a = self.attrs.iter().find(|a| a.name == name)?;
        if a.r#type == attribute_proto::AttributeType::Float as i32 {
            Some(a.f)
        } else {
            None
        }
    }

    pub fn get_floats(&self, name: &str) -> Option<Vec<f32>> {
        let a = self.attrs.iter().find(|a| a.name == name)?;
        if a.r#type == attribute_proto::AttributeType::Floats as i32 && !a.floats.is_empty() {
            Some(a.floats.clone())
        } else {
            None
        }
    }

    pub fn get_string(&self, name: &str) -> Option<String> {
        let a = self.attrs.iter().find(|a| a.name == name)?;
        if a.r#type == attribute_proto::AttributeType::String as i32 {
            Some(String::from_utf8_lossy(&a.s).to_string())
        } else {
            None
        }
    }
}

/// Collects `AttributeProto` values through chained `add_*` calls.
///
/// The derived `Default` produces a builder with no attributes.
#[derive(Default)]
pub struct AttrBuilder {
    /// Attributes added so far, in insertion order.
    attrs: Vec<AttributeProto>,
}

impl AttrBuilder {
    pub fn new() -> Self {
        Self { attrs: vec![] }
    }

    pub fn add_int(mut self, name: &str, value: i64) -> Self {
        let a = AttributeProto {
            name: name.to_string(),
            r#type: attribute_proto::AttributeType::Int as i32,
            i: value,
            ..Default::default()
        };
        self.attrs.push(a);
        self
    }

    pub fn add_ints(mut self, name: &str, values: Vec<i64>) -> Self {
        let a = AttributeProto {
            name: name.to_string(),
            r#type: attribute_proto::AttributeType::Ints as i32,
            ints: values,
            ..Default::default()
        };
        self.attrs.push(a);
        self
    }

    pub fn add_float(mut self, name: &str, value: f32) -> Self {
        let a = AttributeProto {
            name: name.to_string(),
            r#type: attribute_proto::AttributeType::Float as i32,
            f: value,
            ..Default::default()
        };
        self.attrs.push(a);
        self
    }

    pub fn add_floats(mut self, name: &str, values: Vec<f32>) -> Self {
        let a = AttributeProto {
            name: name.to_string(),
            r#type: attribute_proto::AttributeType::Floats as i32,
            floats: values,
            ..Default::default()
        };
        self.attrs.push(a);
        self
    }

    pub fn add_string(mut self, name: &str, value: String) -> Self {
        let a = AttributeProto {
            name: name.to_string(),
            r#type: attribute_proto::AttributeType::String as i32,
            s: value.into_bytes(),
            ..Default::default()
        };
        self.attrs.push(a);
        self
    }

    pub fn build(self) -> Vec<AttributeProto> {
        self.attrs
    }
}

pub fn parse_json_ints(json: &JsonValue, key: &str) -> Option<Vec<i64>> {
    json.get(key)?
        .as_array()
        .map(|arr| arr.iter().filter_map(|v| v.as_i64()).collect::<Vec<_>>())
}

/// Reads `json[key]` as a list of 32-bit floats.
///
/// Returns `None` when `key` is absent, when the value is not a JSON array, or
/// when any element cannot be read through `as_f64` (i.e. is not numeric). An
/// empty array yields `Some(vec![])`.
///
/// The whole array is rejected on the first bad element instead of skipping it:
/// silently dropping an entry would shorten the list and shift every later
/// value to the wrong position.
pub fn parse_json_floats(json: &JsonValue, key: &str) -> Option<Vec<f32>> {
    json.get(key)?
        .as_array()?
        .iter()
        // Values are read as `f64` and deliberately narrowed to `f32`, so
        // precision beyond what `f32` can hold is lost.
        .map(|v| v.as_f64().map(|f| f as f32))
        // Collecting into `Option<Vec<_>>` short-circuits to `None` on the first `None`.
        .collect()
}

/// Unwraps an optional attribute value, treating `None` as a missing attribute.
///
/// # Errors
///
/// Returns `ConversionError::InvalidAttribute` with the message
/// `missing attribute: <name>` when `v` is `None`.
pub fn require_attr<T>(name: &str, v: Option<T>) -> Result<T> {
    match v {
        Some(value) => Ok(value),
        None => Err(ConversionError::InvalidAttribute(format!(
            "missing attribute: {name}"
        ))),
    }
}