use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use proc_macro2::Ident;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{quote, ToTokens};
use std::collections::HashSet;
use std::path::PathBuf;
use std::{env, fs};
use syn::{
parse_file, parse_macro_input, parse_quote, AngleBracketedGenericArguments, GenericArgument,
GenericParam, Item, ItemFn, ItemImpl, ItemMod, ItemStruct, ItemTrait, ItemUse, Macro, Path,
UseTree,
};
use crate::error::{Error, SpannedError};
use crate::kernel_launcher_generator::generate_kernel_launcher;
use crate::rewrite_variadics::*;
use crate::validate_dsl_syntax::validate_entry_point_parameters;
use cutile_compiler::syn_utils::*;
pub fn get_ast_path(tile_rust_crate_root: &Ident) -> Path {
let s = format!("{tile_rust_crate_root}::cutile_compiler::ast");
syn::parse::<Path>(s.parse().unwrap()).unwrap()
}
/// Identifier of the generated per-module function that returns the module's ASTs.
pub fn get_asts_ident() -> Ident {
    const ASTS_FN_NAME: &str = "_module_asts";
    let span = Span::call_site();
    Ident::new(ASTS_FN_NAME, span)
}
/// Entry point of the `#[module]` attribute macro.
///
/// Flags read from `attributes` (both default to `false`):
/// * `core` — the module may contain trait definitions, impl blocks and macro invocations, and
///   its `use` items are not treated as imports of other tile modules.
/// * `tile_rust_crate` — generated paths are rooted at `crate` instead of `cutile`.
///
/// Expansion errors are turned into compile errors rather than panics.
pub fn module(attributes: TokenStream, item: TokenStream) -> TokenStream {
    let attrs = parse_macro_input!(attributes as SingleMetaList);
    let is_core = attrs.parse_bool("core").unwrap_or(false);
    let root_name = if attrs.parse_bool("tile_rust_crate").unwrap_or(false) {
        "crate"
    } else {
        "cutile"
    };
    let tile_rust_crate_root = Ident::new(root_name, Span::call_site());

    // Stringify first: `parse_macro_input!` consumes `item`.
    let raw_item_source = item.to_string();
    let mut module_item = parse_macro_input!(item as ItemMod);
    // The module's attribute list is replaced by the one given to the macro.
    module_item.attrs = attrs.into();

    let expansion = module_inner(&module_item, is_core, &tile_rust_crate_root, raw_item_source);
    match expansion {
        Err(err) => err.to_compile_error().into(),
        Ok(tokens) => tokens.into(),
    }
}
fn module_inner(
module_item: &ItemMod,
is_core: bool,
tile_rust_crate_root: &Ident,
raw_item_source: String,
) -> Result<TokenStream2, Error> {
let mut ast_content: Vec<Item> = vec![];
let Some(content) = &module_item.content else {
return module_item.err("Non-empty module expected.");
};
let mut concrete_items: Vec<TokenStream2> = vec![];
let name = &module_item.ident;
let mut module_ast_calls: Vec<String> = vec![];
let mut entry_functions: Vec<TokenStream2> = vec![];
for item in &content.1 {
match item {
syn::Item::Use(use_item) => {
concrete_items.push(use_item.to_token_stream().into());
if !is_core {
let mut use_tree = &use_item.tree;
let mut module_ast_use_path = vec![];
loop {
match use_tree {
UseTree::Path(path) => {
let path_ident_str = path.ident.to_string();
module_ast_use_path.push(path_ident_str);
use_tree = &path.tree;
}
_ => break,
}
}
let module_ast_call_str = format!(
"{}::{}()",
module_ast_use_path.last().unwrap(),
get_asts_ident().to_string()
);
module_ast_calls.push(module_ast_call_str);
let module_ast_use_path_str =
format!("use {};", module_ast_use_path.join("::"));
let module_ast_use_path_item =
syn::parse::<ItemUse>(module_ast_use_path_str.parse().unwrap()).unwrap();
concrete_items.push(module_ast_use_path_item.to_token_stream().into());
}
}
syn::Item::Fn(function_item) => {
let entry_attrs = get_meta_list(
format!("{} :: entry", tile_rust_crate_root.to_string()).as_str(),
&function_item.attrs,
);
if entry_attrs.is_some() {
entry_functions.push(kernel_launcher(name, &function_item)?);
};
ast_content.push(Item::Fn(function_item.clone()));
concrete_items.push(function(function_item.clone(), &tile_rust_crate_root)?);
}
syn::Item::Struct(struct_item) => {
ast_content.push(Item::Struct(struct_item.clone()));
let item_clone = struct_item.clone();
concrete_items.push(structure(item_clone)?.into());
}
syn::Item::Trait(trait_item) => {
if !is_core {
return trait_item.err("Unsupported item type in non-core module: trait definitions are only allowed in core modules.");
}
ast_content.push(Item::Trait(trait_item.clone()));
let item_clone = trait_item.clone();
concrete_items.push(trait_(item_clone)?.into());
}
syn::Item::Type(type_item) => {
concrete_items.push(type_item.to_token_stream().into());
}
syn::Item::Impl(impl_item) => {
if !is_core {
return impl_item.err("Unsupported item type in non-core module: impl blocks are only allowed in core modules.");
}
ast_content.push(Item::Impl(impl_item.clone()));
let item_clone = impl_item.clone();
concrete_items.push(implementation(item_clone)?.into());
}
syn::Item::Macro(macro_item) => {
if !is_core {
return macro_item.err("Unsupported item type in non-core module: macro invocations are only allowed in core modules.");
}
ast_content.push(Item::Macro(macro_item.clone()));
let item_clone = macro_item.clone();
concrete_items.push(item_clone.to_token_stream().into());
}
other => {
return other.err("Unsupported item type in module.");
}
}
}
let ast_path = get_ast_path(&tile_rust_crate_root);
let ast_module_item: ItemMod = module_item.clone();
let ast_module_tokens = module_asts(
ast_module_item,
module_ast_calls,
&tile_rust_crate_root,
raw_item_source,
);
let res = if entry_functions.len() == 0 {
quote! {
pub mod #name {
#![allow(nonstandard_style)]
#![allow(dead_code)]
#![allow(unused_variables)]
use #ast_path;
#ast_module_tokens
#(#concrete_items)*
}
}
} else {
quote! {
pub mod #name {
#![allow(dead_code)]
use std::{iter::zip, future::{Future, IntoFuture}, collections::HashMap, sync::Arc};
use #tile_rust_crate_root::error::{*};
use #tile_rust_crate_root::WithDType;
use #tile_rust_crate_root::{tensor};
use #tile_rust_crate_root::tile_kernel::{*};
use #tile_rust_crate_root::cuda_async::error::{*};
use #tile_rust_crate_root::cuda_async::scheduling_policies::SchedulingPolicy;
use #tile_rust_crate_root::cuda_core::{CudaContext, CudaFunction, CudaModule, CudaStream, DriverError, LaunchConfig};
use #ast_path;
#ast_module_tokens
#(#concrete_items)*
#(#entry_functions)*
}
}
};
Ok(res)
}
pub fn trait_(mut item: ItemTrait) -> Result<TokenStream, Error> {
let attributes = get_meta_list("cuda_tile :: variadic_trait", &item.attrs);
let is_unchecked = get_meta_list("cuda_tile :: unchecked", &item.attrs);
if is_unchecked.is_some() {
return Ok(quote! {}.into());
}
clear_attributes(
HashSet::from(["cuda_tile :: variadic_trait", "cuda_tile :: ty"]),
&mut item.attrs,
);
let res = match attributes {
Some(attributes) => match attributes.name_as_str().unwrap().as_str() {
"cuda_tile :: variadic_trait" => {
let items = variadic_trait(&attributes, item)?;
quote! {
#(#items)*
}
}
_ => {
let item = desugar_trait_cgas(&item)?;
quote! { #item }
}
},
None => {
let item = desugar_trait_cgas(&item)?;
quote! { #item }
}
};
Ok(res.into())
}
pub fn implementation(mut item: ItemImpl) -> Result<TokenStream, Error> {
let attributes = get_meta_list("cuda_tile :: variadic_impl", &item.attrs);
let is_unchecked = get_meta_list("cuda_tile :: unchecked", &item.attrs);
if is_unchecked.is_some() {
return Ok(quote! {}.into());
}
clear_attributes(
HashSet::from([
"cuda_tile :: variadic_trait_impl",
"cuda_tile :: variadic_impl",
"cuda_tile :: ty",
]),
&mut item.attrs,
);
let res = match attributes {
Some(attributes) => match attributes.name_as_str().unwrap().as_str() {
"cuda_tile :: variadic_impl" => {
let items = variadic_impl(&attributes, item)?;
quote! {
#(#items)*
}
}
_ => {
let item = desugar_impl_cgas(&item)?;
quote! { #item }
}
},
None => {
let item = desugar_impl_cgas(&item)?;
quote! { #item }
}
};
Ok(res.into())
}
/// Lowers a struct definition.
///
/// The `cuda_tile :: variadic_struct` and `cuda_tile :: ty` attributes are stripped first.
/// A variadic struct is then expanded via `variadic_struct` into a list of (struct, optional
/// impl) pairs; all structs are emitted, followed by the impls that are present. Any other
/// struct is desugared with `desugar_structure_cgas`.
///
/// # Errors
/// Propagates errors from variadic expansion or CGA desugaring.
pub fn structure(mut item: ItemStruct) -> Result<TokenStream, Error> {
    let variadic = get_meta_list("cuda_tile :: variadic_struct", &item.attrs);
    clear_attributes(
        HashSet::from(["cuda_tile :: variadic_struct", "cuda_tile :: ty"]),
        &mut item.attrs,
    );
    let res = match variadic {
        Some(meta) if meta.name_as_str().unwrap().as_str() == "cuda_tile :: variadic_struct" => {
            let expansions = variadic_struct(&meta, item)?;
            // Borrow instead of cloning, and skip missing impls with `filter_map`. The old code
            // collected, filtered on `is_some` and then `clone().unwrap()`-ed each impl.
            let structs: Vec<_> = expansions.iter().map(|expansion| &expansion.0).collect();
            let impls: Vec<_> = expansions
                .iter()
                .filter_map(|expansion| expansion.1.as_ref())
                .collect();
            quote! {
                #(#structs)*
                #(#impls)*
            }
        }
        _ => {
            let item = desugar_structure_cgas(&item)?;
            quote! { #item }
        }
    };
    Ok(res.into())
}
/// Lowers a module-level function into concrete Rust items.
///
/// Steps:
/// 1. A function carrying an `entry` attribute (matched by last path segment) has its
///    parameters validated.
/// 2. The `cuda_tile` op attributes and the `<root> :: entry` attribute are stripped.
/// 3. A `cuda_tile :: variadic_op` function is expanded via `variadic_op`; any other function is
///    desugared with `desugar_function_cgas`.
///
/// NOTE(review): validation matches `entry` by last segment, but only the fully qualified
/// `<root> :: entry` attribute is stripped — confirm this asymmetry is intended.
///
/// # Errors
/// Propagates errors from entry validation, variadic expansion or CGA desugaring.
pub fn function(mut item: ItemFn, tile_rust_crate_root: &Ident) -> Result<TokenStream2, Error> {
    let is_entry = get_meta_list_by_last_segment("entry", &item.attrs).is_some();
    if is_entry {
        validate_entry_point_parameters(&item)?;
    }

    let variadic = get_meta_list("cuda_tile :: variadic_op", &item.attrs);
    let op_attrs = HashSet::from([
        "cuda_tile :: variadic_op",
        "cuda_tile :: op",
        "cuda_tile :: compiler_op",
    ]);
    clear_attributes(op_attrs, &mut item.attrs);
    let entry_attr_name = format!("{} :: entry", tile_rust_crate_root);
    clear_attributes(HashSet::from([entry_attr_name.as_str()]), &mut item.attrs);

    let lowered = match variadic {
        Some(meta) if meta.name_as_str().unwrap().as_str() == "cuda_tile :: variadic_op" => {
            variadic_op(&meta, item)?
        }
        _ => vec![desugar_function_cgas(&item)?],
    };
    Ok(quote! {
        #(#lowered)*
    })
}
/// Generates the host-side launcher for kernel entry point `item` in module `module_ident`.
///
/// For an entry function `foo`, this emits:
/// * the type definitions and device-operation impl returned by `generate_kernel_launcher`;
/// * a launcher struct `Foo<..., DI>` holding the grid, the input device operation `DI` and
///   optional function generics;
/// * a `launch` constructor, which carries the entry function's `unsafe` if it has one;
/// * a `TileKernel` impl;
/// * an `IntoFuture` impl that schedules the launcher on the default device policy.
///
/// If the `DUMP_KERNEL_LAUNCHER_DIR` environment variable is set, the pretty-printed launcher is
/// also written to `<dir>/<module>_<function>_launcher.rs`.
///
/// # Errors
/// Returns a spanned error if `item` has no `entry` attribute. Errors from
/// `generate_kernel_launcher` are propagated.
///
/// # Panics
/// Panics if the dump directory is set and the generated code cannot be re-parsed or written.
pub fn kernel_launcher(module_ident: &Ident, item: &ItemFn) -> Result<TokenStream2, Error> {
    let module_name = module_ident.to_string();
    let function_name = item.sig.ident.to_string();
    // Reject non-entry functions before doing any code generation.
    if get_meta_list_by_last_segment("entry", &item.attrs).is_none() {
        return item.sig.ident.err(&format!(
            "Unexpected entry point {function_name}: Missing entry annotation."
        ));
    }
    let function_entry_name = format!("{function_name}_entry");
    let launcher_name = function_name.to_case(Case::UpperCamel);
    let launcher_args_name = format!("{launcher_name}Args");
    let unsafety = item.sig.unsafety;
    let (required_generics, (launcher_args_type, launcher_type_def), device_op_impl) =
        generate_kernel_launcher(
            item,
            &module_name,
            &function_name,
            &function_entry_name,
            &launcher_name,
            &launcher_args_name,
        )?;
    let launcher_ident = Ident::new(&launcher_name, Span::call_site());
    // Launcher generics: the kernel's required generics plus the input device operation `DI`.
    let device_op_param: GenericParam =
        parse_quote! { DI: DeviceOperation<Output=#launcher_args_type> };
    let device_op_arg: GenericArgument = parse_quote! { DI };
    let mut struct_generics = required_generics.get_required_generics().clone();
    struct_generics.params.push(device_op_param);
    let mut struct_args = required_generics.get_generic_args().clone();
    struct_args.args.push(device_op_arg.clone());
    let tile_kernel_type_args: AngleBracketedGenericArguments =
        parse_quote! { <#launcher_args_type, #device_op_arg> };
    let result = quote! {
        #launcher_type_def
        #[derive(Debug)]
        pub struct #launcher_ident #struct_generics {
            _const_grid: bool,
            _grid: (u32, u32, u32),
            input: Option<DI>,
            function_generics: Option<Vec<String>>
        }
        impl #struct_generics #launcher_ident #struct_args {
            pub #unsafety fn launch(input: DI) -> Self {
                Self {
                    _const_grid: false,
                    _grid: (0, 0, 0),
                    input: Some(input),
                    function_generics: None
                }
            }
        }
        impl #struct_generics TileKernel #tile_kernel_type_args for #launcher_ident #struct_args {
            fn grid(mut self, grid: (u32, u32, u32)) -> Self {
                self._grid = grid;
                self._const_grid = false;
                self
            }
            fn const_grid(mut self, grid: (u32, u32, u32)) -> Self {
                self._grid = grid;
                self._const_grid = true;
                self
            }
            fn get_launch_grid(&self) -> (u32, u32, u32) {
                self._grid
            }
            fn generics(mut self, generics: Vec<String>) -> Self {
                self.function_generics = Some(generics);
                self
            }
        }
        impl #struct_generics IntoFuture for #launcher_ident #struct_args {
            type Output = Result<#launcher_args_type, DeviceError>;
            type IntoFuture = DeviceFuture<#launcher_args_type, #launcher_ident #struct_args>;
            fn into_future(self) -> Self::IntoFuture {
                match with_default_device_policy(|policy| policy.schedule(self)) {
                    Ok(Ok(future)) => future,
                    Ok(Err(e)) => DeviceFuture::failed(e),
                    Err(e) => DeviceFuture::failed(e),
                }
            }
        }
        #device_op_impl
    };
    // Debugging aid: dump the pretty-printed launcher when requested.
    if let Ok(dir) = env::var("DUMP_KERNEL_LAUNCHER_DIR") {
        let file = parse_file(&result.to_string()).expect("Failed to parse file.");
        let path = PathBuf::from(dir).join(format!("{module_name}_{function_name}_launcher.rs"));
        let contents = file_item_string_pretty(&file);
        fs::write(&path, contents)
            .unwrap_or_else(|err| panic!("Failed to write {path:?}: {err}"));
    }
    Ok(result)
}
/// Generates the `_module_asts()` function for module `item`.
///
/// The generated function does the following:
/// * re-parses the module's source text into a `syn::ItemMod`;
/// * wraps it in an AST `Module` anchored at the module's start location (file, line, column);
/// * returns it last, after all ASTs produced by the expressions in `module_ast_calls`. These
///   are built by `module_inner` as `<module>::_module_asts()` calls.
///
/// The source text is recovered from the spans covering the whole module when possible.
/// Otherwise `raw_item_source` is used.
pub fn module_asts(
    item: ItemMod,
    module_ast_calls: Vec<String>,
    tile_rust_crate_root: &Ident,
    raw_item_source: String,
) -> TokenStream2 {
    let ast_path = get_ast_path(tile_rust_crate_root);
    let asts_ident = get_asts_ident();
    let name_string = item.ident.to_string();

    // `vec![a::_module_asts(), b::_module_asts(), ...]`, built as text and parsed back.
    let vec_expr = {
        let vec_text = format!("vec![{}]", module_ast_calls.join(","));
        syn::parse::<Macro>(vec_text.parse().unwrap()).unwrap()
    };

    // The module starts at its visibility keyword, or at `mod` when it has none.
    let start = match &item.vis {
        syn::Visibility::Inherited => item.mod_token.span,
        syn::Visibility::Restricted(restricted) => restricted.pub_token.span,
        syn::Visibility::Public(pub_token) => pub_token.span,
    };
    let source_file = start.file();
    let start_location = start.start();
    let (base_line, base_col) = (start_location.line, start_location.column);

    // Prefer the verbatim text spanning `pub mod name { ... }`. Fall back to the stringified
    // tokens when the spans cannot be joined or carry no source text.
    let verbatim_text = match &item.content {
        Some((brace, _)) => start
            .join(brace.span.close())
            .and_then(|span| span.source_text()),
        None => None,
    };
    let source_text = verbatim_text.unwrap_or(raw_item_source);

    quote! {
        pub fn #asts_ident() -> Vec<#ast_path::Module> {
            use #ast_path::syn;
            let source_text: &str = #source_text;
            let parsed_mod: syn::ItemMod = syn::parse_str(source_text)
                .expect("module_asts: failed to re-parse captured source text");
            let span_base = #ast_path::SpanBase::new(
                #source_file.to_string(),
                #base_line,
                #base_col,
            );
            let this_ast = #ast_path::Module::with_span_base(
                #name_string,
                parsed_mod,
                span_base,
            );
            let mut module_asts: Vec<#ast_path::Module> = vec![];
            let other_module_asts_asts: Vec<Vec<#ast_path::Module>> = #vec_expr;
            for other_module_asts in other_module_asts_asts {
                for module_ast in other_module_asts {
                    module_asts.push(module_ast);
                }
            }
            module_asts.push(this_ast);
            module_asts
        }
    }
}