use proc_macro2::TokenStream;
use quote::quote;
use onnx_ir::ir::DType;
use onnx_ir::node::padding::{PaddingConfig1d, PaddingConfig2d, PaddingConfig3d};
fn convert_primitive<T: core::fmt::Debug>(primitive: T) -> TokenStream {
let value = format!("{primitive:?}");
value.parse().unwrap()
}
/// Renders every item yielded by `list` and wraps the results in an array literal.
///
/// Each element is followed by a comma, so the output has the shape `[a, b, c,]`.
/// A trailing comma is valid Rust, and an empty iterator yields `[]`.
fn convert_to_array<'a, I, T>(list: I) -> TokenStream
where
    I: Iterator<Item = &'a T>,
    T: ToTokens + 'a,
{
    let elements: Vec<TokenStream> = list.map(|element| element.to_tokens()).collect();
    quote! {
        [#(#elements,)*]
    }
}
/// Converts a value into the Rust source tokens that reconstruct it in generated code.
///
/// This is a crate-local trait, not `quote::ToTokens`: it returns a fresh `TokenStream`
/// instead of appending to one.
pub trait ToTokens {
    /// Returns the tokens that represent `self` as a Rust expression.
    fn to_tokens(&self) -> TokenStream;
}
/// Emits a fixed-size array as an array literal, e.g. `[1, 2, 3,]`.
///
/// Elements are only borrowed during conversion, so any `T: ToTokens` is accepted.
/// No `Copy` bound is needed, which also allows nested containers such as
/// `[Vec<usize>; 2]`.
impl<const N: usize, T: ToTokens> ToTokens for [T; N] {
    fn to_tokens(&self) -> TokenStream {
        convert_to_array(self.iter())
    }
}
/// Emits a vector's contents as an array literal, e.g. `[1, 2, 3,]`.
///
/// Elements are only borrowed during conversion, so any `T: ToTokens` is accepted.
/// No `Copy` bound is needed, which also allows nested containers such as
/// `Vec<Vec<i64>>`.
impl<T: ToTokens> ToTokens for Vec<T> {
    fn to_tokens(&self) -> TokenStream {
        convert_to_array(self.iter())
    }
}
/// Emits the value as an unsuffixed integer literal (e.g. `42`).
impl ToTokens for usize {
    fn to_tokens(&self) -> TokenStream {
        let value: usize = *self;
        convert_primitive(value)
    }
}
/// Emits the value as an unsuffixed integer literal; negative values keep their minus sign.
impl ToTokens for i64 {
    fn to_tokens(&self) -> TokenStream {
        let value: i64 = *self;
        convert_primitive(value)
    }
}
impl ToTokens for f64 {
fn to_tokens(&self) -> TokenStream {
convert_primitive(self)
}
}
impl ToTokens for f32 {
fn to_tokens(&self) -> TokenStream {
convert_primitive(self)
}
}
impl ToTokens for PaddingConfig1d {
fn to_tokens(&self) -> TokenStream {
match self {
Self::Valid => quote! { PaddingConfig1d::Valid },
Self::Explicit(padding) => {
let padding = padding.to_tokens();
quote! { PaddingConfig1d::Explicit(#padding) }
}
}
}
}
impl ToTokens for PaddingConfig2d {
fn to_tokens(&self) -> TokenStream {
match self {
Self::Valid => quote! { PaddingConfig2d::Valid },
Self::Explicit(padding1, padding2) => {
let padding1 = padding1.to_tokens();
let padding2 = padding2.to_tokens();
quote! { PaddingConfig2d::Explicit(#padding1, #padding2) }
}
}
}
}
impl ToTokens for PaddingConfig3d {
fn to_tokens(&self) -> TokenStream {
match self {
Self::Valid => quote! { PaddingConfig3d::Valid },
Self::Explicit(padding1, padding2, padding3) => {
let padding1 = padding1.to_tokens();
let padding2 = padding2.to_tokens();
let padding3 = padding3.to_tokens();
quote! { PaddingConfig3d::Explicit(#padding1, #padding2, #padding3) }
}
}
}
}
/// Emits the fully qualified `burn::tensor::DType` variant matching this ONNX dtype.
///
/// # Panics
/// Panics for any dtype not listed below (per the message, `Flex32` and `QFloat`), which
/// have no ONNX code-generation mapping.
impl ToTokens for DType {
    fn to_tokens(&self) -> TokenStream {
        // Pick only the variant name here; the shared path prefix is added once below.
        let variant = match self {
            DType::F16 => quote! { F16 },
            DType::BF16 => quote! { BF16 },
            DType::F32 => quote! { F32 },
            DType::F64 => quote! { F64 },
            DType::I8 => quote! { I8 },
            DType::I16 => quote! { I16 },
            DType::I32 => quote! { I32 },
            DType::I64 => quote! { I64 },
            DType::U8 => quote! { U8 },
            DType::U16 => quote! { U16 },
            DType::U32 => quote! { U32 },
            DType::U64 => quote! { U64 },
            DType::Bool => quote! { Bool },
            unsupported => panic!(
                "Unsupported dtype for ONNX code generation: {:?}. \
                 Flex32 and QFloat are Burn-specific runtime types.",
                unsupported
            ),
        };
        quote! { burn::tensor::DType::#variant }
    }
}