use cutile_compiler::syn_utils::{get_ident_from_path, get_sig_types, get_type_ident};
use cutile_compiler::types::get_ptr_type;
use quote::ToTokens;
use syn::{ItemFn, Type};
use crate::error::{Error, SpannedError};
/// Validates the parameter types of a kernel entry point before launcher generation.
///
/// Each parameter type in `item`'s signature is checked in order, and the first
/// unsupported one is reported via `SpannedError::err` on that type. Accepted forms:
///
/// - `&Tensor<...>` and `&mut Tensor<...>`. Per the diagnostic emitted for by-value
///   tensors, `&mut Tensor` corresponds to a partitioned tensor argument and `&Tensor`
///   to a tensor reference argument.
/// - Any by-value path type other than `Tensor` (for example scalars, or
///   `MappedPartitionMut<...>`, which must be taken by value).
/// - Raw pointers whose token spelling is recognised by `get_ptr_type`.
///
/// # Errors
///
/// Returns an [`Error`] when a parameter is:
/// - a reference whose referenced type has no resolvable type identifier;
/// - a reference to `MappedPartitionMut` (it must be passed by value);
/// - a reference to anything other than `Tensor`, including a `Tensor` alias that is
///   imported rather than defined in the same `#[cutile::module]`;
/// - `Tensor` taken by value;
/// - a raw pointer not recognised by `get_ptr_type`;
/// - any other type form (tuples, arrays, `impl Trait`, ...).
pub fn validate_entry_point_parameters(item: &ItemFn) -> Result<(), Error> {
    let (input_types, _output_type) = get_sig_types(&item.sig, None);
    for ty in input_types.iter() {
        match ty {
            Type::Reference(_) => {
                // Name of the referenced type (e.g. `ImportedTensorAlias` for `&ImportedTensorAlias`).
                let Some(ident) = get_type_ident(ty) else {
                    return ty.err("Not a supported parameter type.");
                };
                match ident.to_string().as_str() {
                    // The only type the launcher accepts behind a reference.
                    "Tensor" => {}
                    "MappedPartitionMut" => {
                        return ty.err("MappedPartitionMut parameters are passed by value; use `mut z: MappedPartitionMut<...>`, not `&mut MappedPartitionMut<...>`.");
                    }
                    // Tensor aliases are only recognised when declared in the same module
                    // as the entry function, so an imported alias also ends up here.
                    other => {
                        return ty.err(&format!(
                            "References to `{}` as parameters are not supported. \
                             If this is a type alias for `Tensor`, define the alias in the same \
                             `#[cutile::module]` as the entry function; imported Tensor aliases are \
                             not supported by launcher generation.",
                            other
                        ));
                    }
                }
            }
            Type::Path(path_ty) => {
                // By-value path types are accepted, except `Tensor`, which must be borrowed
                // (`&mut` for a partitioned argument, `&` for a tensor reference).
                // Everything else, including `MappedPartitionMut`, passes.
                let type_name = get_ident_from_path(&path_ty.path).to_string();
                if type_name == "Tensor" {
                    return ty.err("Tensors cannot be moved into kernel functions. \
                                   &mut Tensor corresponds to a partitioned tensor argument (e.g. x.partition([...])), \
                                   and &Tensor corresponds to a tensor reference argument (e.g. Arc::new(x) or x.into()).");
                }
            }
            Type::Ptr(ptr_type) => {
                // `get_ptr_type` classifies the pointer from its token-stream spelling.
                let ptr_str = ptr_type.to_token_stream().to_string();
                if get_ptr_type(&ptr_str).is_none() {
                    return ty.err(&format!("{} is not a supported pointer type.", ptr_str));
                }
            }
            _ => {
                return ty.err(&format!(
                    "{} is not a supported parameter type.",
                    ty.to_token_stream()
                ));
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the validator on `item`, requires it to fail, and returns the rendered error text.
    fn rejection_message(item: &ItemFn) -> String {
        validate_entry_point_parameters(item)
            .expect_err("expected parameter rejection")
            .to_string()
    }

    #[test]
    fn imported_tensor_alias_parameter_error_mentions_same_module_aliases() {
        let item: ItemFn = syn::parse_quote! {
            fn kernel(x: &ImportedTensorAlias) {}
        };
        let message = rejection_message(&item);
        assert!(
            message.contains("define the alias in the same `#[cutile::module]`")
                && message.contains("imported Tensor aliases are not supported"),
            "{message}"
        );
    }

    #[test]
    fn supported_parameter_forms_are_accepted() {
        let item: ItemFn = syn::parse_quote! {
            fn kernel(x: &Tensor<f32>, y: &mut Tensor<f32>, mut z: MappedPartitionMut<f32>, n: i32) {}
        };
        // Avoid `expect` here so the test does not require `Error: Debug`.
        if let Err(err) = validate_entry_point_parameters(&item) {
            panic!("expected supported parameters to be accepted, got: {err}");
        }
    }

    #[test]
    fn tensor_by_value_is_rejected() {
        let item: ItemFn = syn::parse_quote! {
            fn kernel(x: Tensor<f32>) {}
        };
        let message = rejection_message(&item);
        assert!(
            message.contains("Tensors cannot be moved into kernel functions"),
            "{message}"
        );
    }

    #[test]
    fn mapped_partition_mut_by_reference_is_rejected() {
        let item: ItemFn = syn::parse_quote! {
            fn kernel(z: &mut MappedPartitionMut<f32>) {}
        };
        let message = rejection_message(&item);
        assert!(
            message.contains("MappedPartitionMut parameters are passed by value"),
            "{message}"
        );
    }

    #[test]
    fn unsupported_type_form_is_rejected() {
        let item: ItemFn = syn::parse_quote! {
            fn kernel(x: (f32, f32)) {}
        };
        let message = rejection_message(&item);
        assert!(
            message.contains("is not a supported parameter type"),
            "{message}"
        );
    }
}