use super::tile_rust_type::TileRustType;
use cutile_ir::ir::{PointerType, ScalarType, TileElementType, TileType, Type};
use quote::ToTokens;
use syn::FnArg;
/// Looks up the Tile IR [`ScalarType`] for an element-type name.
///
/// Accepts the Tile IR spellings (`i1`, `tf32`, `f8e4m3fn` / `f8E4M3FN`, `f8e5m2` /
/// `f8E5M2`, ...) as well as Rust-side aliases: `bool` maps to `I1`, and each unsigned
/// integer name maps to the integer `ScalarType` of the same bit width (`u32` -> `I32`).
///
/// Returns `None` for any name not listed here.
pub fn scalar_from_name(name: &str) -> Option<ScalarType> {
    let scalar = match name {
        // Integers: signed and unsigned names of one width share a ScalarType.
        "i1" | "bool" => ScalarType::I1,
        "i8" | "u8" => ScalarType::I8,
        "i16" | "u16" => ScalarType::I16,
        "i32" | "u32" => ScalarType::I32,
        "i64" | "u64" => ScalarType::I64,
        // Floating point.
        "f16" => ScalarType::F16,
        "bf16" => ScalarType::BF16,
        "f32" => ScalarType::F32,
        "tf32" => ScalarType::TF32,
        "f64" => ScalarType::F64,
        // 8-bit floats accept both lower-case and mixed-case spellings.
        "f8e4m3fn" | "f8E4M3FN" => ScalarType::F8E4M3FN,
        "f8e5m2" | "f8E5M2" => ScalarType::F8E5M2,
        _ => return None,
    };
    Some(scalar)
}
pub fn convert_type(old: &TileRustType) -> Option<Type> {
if let Some(ty) = &old.tile_ir_ty {
return Some(ty.clone());
}
let name = old.cuda_tile_name.as_deref()?;
scalar_from_name(name).map(Type::Scalar)
}
/// Builds a `Type::Tile` with the given `shape` and a scalar element named by
/// `element_name`.
///
/// The shape slice is copied into the returned type. Returns `None` when
/// [`scalar_from_name`] does not recognize `element_name`.
pub fn make_tile_type(element_name: &str, shape: &[i64]) -> Option<Type> {
    scalar_from_name(element_name).map(|scalar| {
        Type::Tile(TileType {
            shape: shape.to_vec(),
            element_type: TileElementType::Scalar(scalar),
        })
    })
}
/// Builds a `Type::TensorView` whose element type is named by `element_name`.
///
/// `shape` and `strides` are copied verbatim into the view. No consistency check is
/// made between them here. NOTE(review): callers presumably pass slices of equal
/// length; confirm.
///
/// Returns `None` when [`scalar_from_name`] does not recognize `element_name`.
pub fn make_tensor_view_type(element_name: &str, shape: &[i64], strides: &[i64]) -> Option<Type> {
    let element_type = scalar_from_name(element_name)?;
    let view = cutile_ir::ir::TensorViewType {
        element_type,
        shape: Vec::from(shape),
        strides: Vec::from(strides),
    };
    Some(Type::TensorView(view))
}
/// Builds a tile type with an empty shape (rank 0) whose element is named by
/// `element_name`.
///
/// Returns `None` when the element name is not recognized.
pub fn make_scalar_tile_type(element_name: &str) -> Option<Type> {
    const RANK_ZERO: &[i64] = &[];
    make_tile_type(element_name, RANK_ZERO)
}
/// Derives the Tile IR type of an entry-function parameter from its Rust syntax.
///
/// The result is always a tile with an empty shape:
/// - If the parameter's type tokens contain `PointerTile`, the pointee name is pulled
///   out of the `*mut` pointer inside it. The element becomes a pointer to that
///   scalar.
/// - Otherwise the type text itself must be one of the scalar names accepted by
///   `rust_scalar_type`. The element becomes that scalar.
///
/// Returns `None` for receiver (non-typed) parameters, for unrecognized types, and
/// when a `PointerTile`'s pointee cannot be resolved.
pub fn compile_entry_param_type(param: &FnArg) -> Option<Type> {
    let typed = match param {
        FnArg::Typed(typed) => typed,
        _ => return None,
    };
    let rendered = typed.ty.to_token_stream().to_string();

    // Detection is purely textual on the token-stream rendering of the type.
    let element_type = if rendered.contains("PointerTile") {
        let pointee_name = extract_pointer_element_type(&rendered)?;
        let pointee = scalar_from_name(&pointee_name)?;
        TileElementType::Pointer(Box::new(PointerType { pointee }))
    } else {
        TileElementType::Scalar(rust_scalar_type(rendered.trim())?)
    };

    Some(Type::Tile(TileType {
        shape: Vec::new(),
        element_type,
    }))
}
fn rust_scalar_type(name: &str) -> Option<ScalarType> {
match name {
"bool" => Some(ScalarType::I1),
"i8" | "u8" => Some(ScalarType::I8),
"i16" | "u16" => Some(ScalarType::I16),
"i32" | "u32" => Some(ScalarType::I32),
"i64" | "u64" => Some(ScalarType::I64),
"f16" => Some(ScalarType::F16),
"bf16" => Some(ScalarType::BF16),
"f32" => Some(ScalarType::F32),
"f64" => Some(ScalarType::F64),
_ => None,
}
}
/// Extracts the pointee name from the first `*mut` pointer in a rendered type string.
///
/// Works on both token-stream renderings (`PointerTile < * mut f32 >`) and compact
/// text (`PointerTile<*mut f32>`).
///
/// Behavior:
/// - Only a `*` followed (after optional whitespace) by the whole keyword `mut`
///   counts. Occurrences of `mut` inside identifiers such as `mutex` or `mutable` are
///   ignored.
/// - The returned name is the identifier that follows, ending at the first
///   non-identifier character or at the end of the string.
///
/// Returns `None` in three cases:
/// - there is no `*mut`;
/// - the first `*mut` is not followed by an identifier (e.g. `*mut *mut f32`);
/// - the name is empty.
fn extract_pointer_element_type(ty_str: &str) -> Option<String> {
    // Locate the text right after the first `*` + `mut` keyword pair.
    let after_mut = ty_str.match_indices('*').find_map(|(star, _)| {
        let rest = ty_str[star + 1..].trim_start().strip_prefix("mut")?;
        // `mut` must be a standalone keyword, not the prefix of a longer identifier.
        rest.starts_with(char::is_whitespace).then(|| rest.trim_start())
    })?;

    // The pointee name runs until the first non-identifier character (or the end).
    let end = after_mut
        .find(|c: char| !(c.is_alphanumeric() || c == '_'))
        .unwrap_or(after_mut.len());
    (end > 0).then(|| after_mut[..end].to_string())
}