extern crate alloc;
use alloc::rc::Rc;
use burn::tensor::Shape;
use burn_store::{TensorSnapshot, TensorSnapshotError};
use proc_macro2::{Ident, Span, TokenStream};
use onnx_ir::Argument;
use crate::burn::BurnImports;
/// Description of a single field: an identifier plus token streams for its
/// type and its initializer.
#[derive(Clone, Debug)]
pub struct Field {
    /// Identifier used as the field's name.
    pub name: Ident,
    /// Tokens describing the field's type.
    pub ty: TokenStream,
    /// Tokens stored as the field's initializer.
    pub init: TokenStream,
}
impl Field {
    /// Builds a [`Field`] from a textual name plus type and initializer tokens.
    ///
    /// The name is turned into an [`Ident`] with a call-site span; `ty` and
    /// `init` are stored unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `name` is the empty string. Any further validation of the
    /// name is left to `Ident::new`.
    pub fn new<S: AsRef<str>>(name: S, ty: TokenStream, init: TokenStream) -> Self {
        let raw_name = name.as_ref();
        assert!(
            !raw_name.is_empty(),
            "Field with type {ty:?} was passed with empty name"
        );

        let name = Ident::new(raw_name, Span::call_site());
        Self { name, ty, init }
    }
}
/// Element category of a tensor: integer, floating point, or boolean.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TensorKind {
    /// Integer elements.
    Int,
    /// Floating-point elements.
    Float,
    /// Boolean elements.
    Bool,
}
impl From<onnx_ir::ir::DType> for TensorKind {
    /// Maps a data type onto its tensor kind.
    ///
    /// Float dtypes become [`TensorKind::Float`], signed and unsigned integer
    /// dtypes become [`TensorKind::Int`], and bool dtypes become
    /// [`TensorKind::Bool`]. The checks run in that order.
    ///
    /// # Panics
    ///
    /// Panics if `dtype` is reported as none of float, int, uint, or bool.
    fn from(dtype: onnx_ir::ir::DType) -> Self {
        if dtype.is_float() {
            return Self::Float;
        }
        if dtype.is_int() || dtype.is_uint() {
            return Self::Int;
        }
        if dtype.is_bool() {
            return Self::Bool;
        }
        panic!("Unsupported tensor type: {dtype:?}")
    }
}
impl quote::ToTokens for TensorKind {
    /// Appends the variant's bare name (`Int`, `Float`, or `Bool`) to `tokens`.
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        tokens.extend(match *self {
            Self::Int => quote::quote! { Int },
            Self::Float => quote::quote! { Float },
            Self::Bool => quote::quote! { Bool },
        });
    }
}
/// Conversion from an `onnx_ir::Node` into the implementing type.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this part
// of the file; presumably the trait is not used in every build — confirm.
#[allow(dead_code)]
pub trait OnnxIntoNode: Sized {
    /// Consumes `node` and returns a value of the implementing type.
    fn from_onnx(node: onnx_ir::Node) -> Self;
}
/// Interface a node exposes to the code generator.
///
/// Implementors must provide [`inputs`](Self::inputs),
/// [`outputs`](Self::outputs) and [`forward`](Self::forward); the remaining
/// methods have defaults that contribute nothing.
pub trait NodeCodegen: std::fmt::Debug {
    /// Returns the node's input arguments.
    fn inputs(&self) -> &[Argument];

    /// Returns the node's output arguments.
    fn outputs(&self) -> &[Argument];

    /// Produces a token stream for this node, given mutable access to `scope`.
    ///
    /// NOTE(review): presumably the tokens are the node's forward-pass code;
    /// confirm against the implementors.
    fn forward(&self, scope: &mut super::scope::ScopeAtPosition<'_>) -> TokenStream;

    /// Hook for adding entries to `_imports`. The default does nothing.
    fn register_imports(&self, _imports: &mut BurnImports) {}

    /// Returns the [`Field`] associated with this node, if any.
    ///
    /// The default returns `None`.
    fn field(&self) -> Option<Field> {
        None
    }

    /// Returns tensor snapshots for this node under `_field_name`.
    ///
    /// The default returns an empty vector.
    fn collect_snapshots(&self, _field_name: &str) -> Vec<TensorSnapshot> {
        Vec::new()
    }
}
/// Looks up the argument at `input_index` and returns its value.
///
/// Returns `None` when `input_index` is out of range for `inputs`, or when
/// the argument's `value()` call itself yields `None`.
pub fn extract_node_data(
    inputs: &[onnx_ir::Argument],
    input_index: usize,
) -> Option<burn::tensor::TensorData> {
    inputs.get(input_index).and_then(|arg| arg.value())
}
/// Creates an identifier from the argument's name, using a call-site span.
pub fn arg_to_ident(arg: &Argument) -> proc_macro2::Ident {
    let span = proc_macro2::Span::call_site();
    let name: &str = &arg.name;
    proc_macro2::Ident::new(name, span)
}
/// Alias for `burn_ndarray::NdArray<f64>`.
///
/// NOTE(review): the name suggests this is the backend used when serializing
/// tensors; no use of the alias is visible in this part of the file — confirm.
pub type SerializationBackend = burn_ndarray::NdArray<f64>;
/// Builds a [`TensorSnapshot`] whose data is fetched from `input` on demand.
///
/// * `input` - argument that supplies the dtype, the shape and, when the
///   snapshot's closure runs, the tensor data.
/// * `path` - dot-separated path; each segment becomes one path-stack entry.
/// * `container_type` - recorded as the single container-stack entry
///   `"Struct:<container_type>"`.
///
/// Returns `None` when `input` is neither a tensor nor a scalar argument.
///
/// For tensor arguments the snapshot shape is the statically known shape, or
/// the default (empty) shape when it is not known. Scalar arguments are given
/// the shape `[1]`.
///
/// The closure handed to the snapshot returns a
/// `TensorSnapshotError::DataError` if `input` has no value when it is called.
pub fn create_lazy_snapshot(
    input: &Argument,
    path: &str,
    container_type: &str,
) -> Option<TensorSnapshot> {
    use burn::module::ParamId;
    use burn::tensor::TensorData;
    use onnx_ir::ir::ArgType;

    let (dtype, shape, is_scalar) = match &input.ty {
        ArgType::Tensor(tensor_type) => {
            let dims = tensor_type.static_shape_known().unwrap_or_default();
            let shape: Shape = dims.into();
            (tensor_type.dtype, shape, false)
        }
        ArgType::ScalarTensor(d) | ArgType::ScalarNative(d) => (*d, Shape::from([1]), true),
        _ => return None,
    };

    // The closure owns its own copy of the argument.
    let source = input.clone();
    let loader = Rc::new(move || -> Result<TensorData, TensorSnapshotError> {
        let mut data = match source.value() {
            Some(data) => data,
            None => {
                return Err(TensorSnapshotError::DataError(format!(
                    "Failed to extract tensor data for '{}'",
                    source.name
                )));
            }
        };

        // Scalar data with an empty shape is reshaped to `[1]`, matching the
        // shape declared for scalars above.
        if is_scalar && data.shape.is_empty() {
            data.shape = Shape::from([1]);
        }
        Ok(data)
    });

    let path_stack: Vec<String> = path.split('.').map(|segment| segment.to_string()).collect();
    let container_stack = vec![format!("Struct:{container_type}")];

    let snapshot = TensorSnapshot::from_closure(
        loader,
        dtype,
        shape,
        path_stack,
        container_stack,
        ParamId::new(),
    );
    Some(snapshot)
}
#[cfg(test)]
mod tests {
    use super::*;
    use onnx_ir::ir::{BoolStore, DType};

    /// Asserts that every dtype in `dtypes` converts to `expected`.
    fn assert_all_map_to(dtypes: impl IntoIterator<Item = DType>, expected: TensorKind) {
        for dtype in dtypes {
            // Rendered before the conversion so a failure names the dtype.
            let label = format!("{dtype:?}");
            assert_eq!(TensorKind::from(dtype), expected, "dtype {label}");
        }
    }

    #[test]
    fn tensor_kind_from_dtype_float_types() {
        assert_all_map_to(
            [DType::F16, DType::BF16, DType::F32, DType::F64],
            TensorKind::Float,
        );
    }

    #[test]
    fn tensor_kind_from_dtype_signed_int_types() {
        assert_all_map_to(
            [DType::I8, DType::I16, DType::I32, DType::I64],
            TensorKind::Int,
        );
    }

    #[test]
    fn tensor_kind_from_dtype_unsigned_int_types() {
        assert_all_map_to(
            [DType::U8, DType::U16, DType::U32, DType::U64],
            TensorKind::Int,
        );
    }

    #[test]
    fn tensor_kind_from_dtype_bool() {
        assert_all_map_to([DType::Bool(BoolStore::Native)], TensorKind::Bool);
    }
}