openinfer-dsl 0.1.3

Rust-embedded DSL for defining OpenInfer graphs with explicit control flow and memory semantics.
Documentation
use proc_macro2::TokenStream;
use quote::quote;
use syn::Ident;

use crate::types::InitValue;

pub(crate) fn match_dtype(dtype: &Ident) -> syn::Result<TokenStream> {
    let s = dtype.to_string();
    match s.as_str() {
        "i8" => Ok(quote! { ::openinfer::DType::I8 }),
        "i16" => Ok(quote! { ::openinfer::DType::I16 }),
        "f32" => Ok(quote! { ::openinfer::DType::F32 }),
        "f64" => Ok(quote! { ::openinfer::DType::F64 }),
        "u8" => Ok(quote! { ::openinfer::DType::U8 }),
        "u16" => Ok(quote! { ::openinfer::DType::U16 }),
        "i32" => Ok(quote! { ::openinfer::DType::I32 }),
        "i64" => Ok(quote! { ::openinfer::DType::I64 }),
        "u32" => Ok(quote! { ::openinfer::DType::U32 }),
        "u64" => Ok(quote! { ::openinfer::DType::U64 }),
        "bool" => Ok(quote! { ::openinfer::DType::Bool }),
        "bitset" => Ok(quote! { ::openinfer::DType::Bitset }),
        "f16" => Ok(quote! { ::openinfer::DType::F16 }),
        "bf16" => Ok(quote! { ::openinfer::DType::BF16 }),
        "f8" => Ok(quote! { ::openinfer::DType::F8 }),
        "i4" => Ok(quote! { ::openinfer::DType::I4 }),
        "i2" => Ok(quote! { ::openinfer::DType::I2 }),
        "i1" => Ok(quote! { ::openinfer::DType::I1 }),
        "u4" => Ok(quote! { ::openinfer::DType::U4 }),
        "u2" => Ok(quote! { ::openinfer::DType::U2 }),
        "u1" => Ok(quote! { ::openinfer::DType::U1 }),
        "t2" => Ok(quote! { ::openinfer::DType::T2 }),
        "t1" => Ok(quote! { ::openinfer::DType::T1 }),
        _ => Err(syn::Error::new(dtype.span(), "unsupported dtype")),
    }
}

pub(crate) fn init_expr(init: &Option<InitValue>, dtype: &Ident) -> syn::Result<TokenStream> {
    let dtype_str = dtype.to_string();
    let out = match init {
        Some(InitValue::Float { lit, negative }) => {
            let lit_expr = if *negative {
                quote! { -#lit }
            } else {
                quote! { #lit }
            };
            match dtype_str.as_str() {
                "f16" => quote! {
                    Some(::openinfer::ScalarValue::F16(::openinfer::F16::from_f32(#lit_expr as f32)))
                },
                "bf16" => quote! {
                    Some(::openinfer::ScalarValue::BF16(::openinfer::BF16::from_f32(#lit_expr as f32)))
                },
                "f8" => quote! {
                    Some(::openinfer::ScalarValue::F8(::openinfer::F8::from_f32(#lit_expr as f32)))
                },
                "f32" => quote! { Some(::openinfer::ScalarValue::F32(#lit_expr as f32)) },
                "f64" => quote! { Some(::openinfer::ScalarValue::F64(#lit_expr as f64)) },
                _ => {
                    return Err(syn::Error::new(
                        dtype.span(),
                        "float init requires f8/bf16/f16/f32/f64 dtype",
                    ))
                }
            }
        }
        Some(InitValue::Int { lit, negative }) => {
            let lit_expr = if *negative {
                quote! { -#lit }
            } else {
                quote! { #lit }
            };
            let value: i128 = lit.base10_parse()?;
            let value = if *negative { -value } else { value };
            match dtype_str.as_str() {
                "i8" => {
                    if value < i8::MIN as i128 || value > i8::MAX as i128 {
                        return Err(syn::Error::new(dtype.span(), "i8 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::I8(#lit_expr as i8)) }
                }
                "i16" => {
                    if value < i16::MIN as i128 || value > i16::MAX as i128 {
                        return Err(syn::Error::new(dtype.span(), "i16 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::I16(#lit_expr as i16)) }
                }
                "i32" => {
                    if value < i32::MIN as i128 || value > i32::MAX as i128 {
                        return Err(syn::Error::new(dtype.span(), "i32 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::I32(#lit_expr as i32)) }
                }
                "i64" => {
                    if value < i64::MIN as i128 || value > i64::MAX as i128 {
                        return Err(syn::Error::new(dtype.span(), "i64 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::I64(#lit_expr as i64)) }
                }
                "u8" => {
                    if value < 0 || value > u8::MAX as i128 {
                        return Err(syn::Error::new(dtype.span(), "u8 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::U8(#lit_expr as u8)) }
                }
                "u16" => {
                    if value < 0 || value > u16::MAX as i128 {
                        return Err(syn::Error::new(dtype.span(), "u16 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::U16(#lit_expr as u16)) }
                }
                "u32" => {
                    if value < 0 || value > u32::MAX as i128 {
                        return Err(syn::Error::new(dtype.span(), "u32 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::U32(#lit_expr as u32)) }
                }
                "u64" => {
                    if value < 0 || value > u64::MAX as i128 {
                        return Err(syn::Error::new(dtype.span(), "u64 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::U64(#lit_expr as u64)) }
                }
                "bool" => {
                    if value != 0 && value != 1 {
                        return Err(syn::Error::new(dtype.span(), "bool init must be 0 or 1"));
                    }
                    quote! { Some(::openinfer::ScalarValue::Bool(#lit_expr != 0)) }
                }
                "bitset" => {
                    if value < 0 || value > u8::MAX as i128 {
                        return Err(syn::Error::new(dtype.span(), "bitset init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::Bitset(::openinfer::Bitset { bits: #lit_expr as u8 })) }
                }
                "i4" => {
                    if value < -8 || value > 7 {
                        return Err(syn::Error::new(dtype.span(), "i4 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::I4(::openinfer::I4::from_i8(#lit_expr as i8))) }
                }
                "i2" => {
                    if value < -2 || value > 1 {
                        return Err(syn::Error::new(dtype.span(), "i2 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::I2(::openinfer::I2::from_i8(#lit_expr as i8))) }
                }
                "i1" => {
                    if value < -1 || value > 0 {
                        return Err(syn::Error::new(dtype.span(), "i1 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::I1(::openinfer::I1::from_i8(#lit_expr as i8))) }
                }
                "u4" => {
                    if value < 0 || value > 15 {
                        return Err(syn::Error::new(dtype.span(), "u4 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::U4(::openinfer::U4::from_u8(#lit_expr as u8))) }
                }
                "u2" => {
                    if value < 0 || value > 3 {
                        return Err(syn::Error::new(dtype.span(), "u2 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::U2(::openinfer::U2::from_u8(#lit_expr as u8))) }
                }
                "u1" => {
                    if value < 0 || value > 1 {
                        return Err(syn::Error::new(dtype.span(), "u1 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::U1(::openinfer::U1::from_u8(#lit_expr as u8))) }
                }
                "t2" => {
                    if value < -1 || value > 1 {
                        return Err(syn::Error::new(dtype.span(), "t2 init out of range"));
                    }
                    quote! { Some(::openinfer::ScalarValue::T2(::openinfer::T2::from_i8(#lit_expr as i8))) }
                }
                "t1" => {
                    if value != -1 && value != 1 {
                        return Err(syn::Error::new(dtype.span(), "t1 init must be -1 or 1"));
                    }
                    quote! { Some(::openinfer::ScalarValue::T1(::openinfer::T1::from_i8(#lit_expr as i8))) }
                }
                _ => {
                    return Err(syn::Error::new(
                        dtype.span(),
                        "integer init requires integer/bool/bitset dtype",
                    ))
                }
            }
        }
        Some(InitValue::Bool { lit }) => {
            match dtype_str.as_str() {
                "bool" => quote! { Some(::openinfer::ScalarValue::Bool(#lit)) },
                _ => {
                    return Err(syn::Error::new(
                        dtype.span(),
                        "bool init requires bool dtype",
                    ))
                }
            }
        }
        None => quote! { None },
    };
    Ok(out)
}