extern crate alloc;
use alloc::rc::Rc;
use burn_store::{TensorSnapshot, TensorSnapshotError};
use proc_macro2::{Ident, Span, TokenStream};
use onnx_ir::Argument;
use crate::burn::BurnImports;
/// A named struct field emitted into the generated burn model, carrying both
/// its type and the expression used to initialize it in the model constructor.
#[derive(Debug, Clone)]
pub struct Field {
/// Identifier of the field as it appears in the generated struct.
pub name: Ident,
/// Tokens for the field's Rust type (e.g. a tensor or module type).
pub ty: TokenStream,
/// Tokens for the initializer expression used when constructing the model.
pub init: TokenStream,
}
impl Field {
    /// Builds a [`Field`] from a name plus type and initializer token streams.
    ///
    /// # Panics
    /// Panics if `name` is empty, or (via [`Ident::new`]) if it is not a
    /// valid Rust identifier.
    pub fn new<S: AsRef<str>>(name: S, ty: TokenStream, init: TokenStream) -> Self {
        let name = name.as_ref();
        // Reject empty names up front: Ident::new would panic with a far less
        // actionable message.
        assert!(
            !name.is_empty(),
            "Field with type {ty:?} was passed with empty name"
        );
        let name = Ident::new(name, Span::call_site());
        Self { name, ty, init }
    }
}
/// Coarse element-kind classification of a tensor in generated burn code,
/// matching burn's `Int` / `Float` / `Bool` tensor kind markers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorKind {
/// Integer-valued tensor (covers all integer dtypes).
Int,
/// Floating-point tensor (covers all float dtypes).
Float,
/// Boolean tensor.
Bool,
}
impl From<onnx_ir::ir::DType> for TensorKind {
    /// Collapses an ONNX element dtype into the coarse burn [`TensorKind`].
    ///
    /// # Panics
    /// Panics on dtypes that have no burn tensor-kind mapping; the offending
    /// dtype is included in the message so the failure is diagnosable.
    fn from(dtype: onnx_ir::ir::DType) -> Self {
        use onnx_ir::ir::DType;
        match dtype {
            // Every supported float width maps to the single Float kind.
            DType::F32 | DType::F64 => TensorKind::Float,
            // Signed and unsigned integer widths all map to Int.
            DType::I8 | DType::U8 | DType::I32 | DType::I64 => TensorKind::Int,
            DType::Bool => TensorKind::Bool,
            // Bind the variant so the panic names what was unsupported
            // (previously the dtype was silently dropped from the message).
            other => panic!("Unsupported tensor type: {other:?}"),
        }
    }
}
impl quote::ToTokens for TensorKind {
    /// Emits the burn kind marker identifier (`Int`, `Float`, or `Bool`)
    /// for use inside generated tensor type paths.
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        match self {
            TensorKind::Int => tokens.extend(quote::quote! { Int }),
            TensorKind::Float => tokens.extend(quote::quote! { Float }),
            TensorKind::Bool => tokens.extend(quote::quote! { Bool }),
        }
    }
}
// Kept even when unused by a given feature set; conversions register lazily.
#[allow(dead_code)]
/// Conversion from a parsed ONNX IR node into a burn codegen node type.
pub trait OnnxIntoNode: Sized {
/// Builds `Self` from the given ONNX node.
fn from_onnx(node: onnx_ir::Node) -> Self;
}
/// Code-generation interface implemented by every node in the burn graph.
///
/// Implementors describe their dataflow (`inputs`/`outputs`), render their
/// forward-pass tokens, and may optionally contribute a model struct field
/// and lazily-loaded tensor snapshots for that field's weights.
pub trait NodeCodegen: std::fmt::Debug {
/// Arguments consumed by this node.
fn inputs(&self) -> &[Argument];
/// Arguments produced by this node.
fn outputs(&self) -> &[Argument];
/// Renders the tokens for this node's forward-pass expression, resolving
/// variable names through the given scope position.
fn forward(&self, scope: &mut super::scope::ScopeAtPosition<'_>) -> TokenStream;
/// Registers any `use` paths the generated code needs; default: none.
fn register_imports(&self, _imports: &mut BurnImports) {}
/// Struct field this node contributes to the generated model, if any.
fn field(&self) -> Option<Field> {
None
}
/// Tensor snapshots backing this node's field; default: none.
fn collect_snapshots(&self, _field_name: &str) -> Vec<TensorSnapshot> {
vec![]
}
}
/// Returns the constant tensor data of the input at `input_index`, or `None`
/// if the index is out of range or the argument has no materialized value.
pub fn extract_node_data(
    inputs: &[onnx_ir::Argument],
    input_index: usize,
) -> Option<burn::tensor::TensorData> {
    inputs.get(input_index).and_then(|arg| arg.value())
}
pub fn arg_to_ident(arg: &Argument) -> proc_macro2::Ident {
proc_macro2::Ident::new(&arg.name, proc_macro2::Span::call_site())
}
/// Backend used when serializing tensor data during ONNX import; f64 ndarray
/// keeps full precision on CPU without requiring any accelerator.
#[cfg(feature = "onnx")]
pub type SerializationBackend = burn_ndarray::NdArray<f64>;
/// Builds a lazily-evaluated [`TensorSnapshot`] for a tensor-typed argument.
///
/// The tensor bytes are not extracted here; a closure capturing a clone of
/// `input` defers extraction until the snapshot is actually read. Returns
/// `None` when the argument is not tensor-typed.
///
/// * `path` - dot-separated parameter path (e.g. `"conv1.weight"`).
/// * `container_type` - name of the struct that owns the parameter.
pub fn create_lazy_snapshot(
    input: &Argument,
    path: &str,
    container_type: &str,
) -> Option<TensorSnapshot> {
    use burn::module::ParamId;
    use burn::tensor::TensorData;
    use onnx_ir::ir::ArgType;

    // Only tensor arguments can back a snapshot.
    let tensor_type = match &input.ty {
        ArgType::Tensor(t) => t,
        _ => return None,
    };
    let dtype = tensor_type.dtype;
    // Unknown static shapes degrade to an empty shape vector.
    let shape = tensor_type
        .static_shape
        .as_ref()
        .map_or_else(Vec::new, |dims| dims.to_vec());

    // The closure owns its own copy of the argument so extraction can run
    // long after `input` has gone out of scope.
    let source = input.clone();
    let load = Rc::new(move || -> Result<TensorData, TensorSnapshotError> {
        match source.value() {
            Some(data) => Ok(data),
            None => Err(TensorSnapshotError::DataError(format!(
                "Failed to extract tensor data for '{}'",
                source.name
            ))),
        }
    });

    let path_stack: Vec<String> = path.split('.').map(str::to_owned).collect();
    let container_stack = vec![format!("Struct:{}", container_type)];

    Some(TensorSnapshot::from_closure(
        load,
        dtype,
        shape,
        path_stack,
        container_stack,
        ParamId::new(),
    ))
}