use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use proc_macro2::Ident;
use proc_macro2::{LineColumn, Span, TokenStream as TokenStream2};
use quote::{quote, ToTokens};
use std::collections::HashSet;
use std::path::PathBuf;
use std::{env, fs};
use syn::{
parse_file, parse_macro_input, parse_quote, AngleBracketedGenericArguments, GenericArgument,
GenericParam, Item, ItemFn, ItemImpl, ItemMod, ItemStruct, ItemTrait, ItemUse, Macro, Path,
UseTree,
};
use crate::error::{Error, SpannedError};
use crate::kernel_launcher_generator::generate_kernel_launcher;
use crate::rewrite_variadics::*;
use crate::validate_dsl_syntax::validate_entry_point_parameters;
use cutile_compiler::kernel_naming::KernelNaming;
use cutile_compiler::syn_utils::*;
fn line_column_to_offset(source: &str, loc: LineColumn) -> Option<usize> {
let mut line_start = 0usize;
let mut current_line = 1usize;
for line in source.split_inclusive('\n') {
if current_line == loc.line {
let column_offset = byte_offset_for_char_column(line, loc.column)?;
return Some(line_start + column_offset);
}
line_start += line.len();
current_line += 1;
}
if current_line == loc.line {
let tail = &source[line_start..];
let column_offset = byte_offset_for_char_column(tail, loc.column)?;
return Some(line_start + column_offset);
}
None
}
fn byte_offset_for_char_column(line: &str, column: usize) -> Option<usize> {
if column == 0 {
return Some(0);
}
if column == line.chars().count() {
return Some(line.len());
}
line.char_indices().nth(column).map(|(idx, _)| idx)
}
fn source_slice_from_file(path: &str, start: LineColumn, end: LineColumn) -> Option<String> {
let source = fs::read_to_string(path).ok()?;
let start_offset = line_column_to_offset(&source, start)?;
let end_offset = line_column_to_offset(&source, end)?;
source.get(start_offset..end_offset).map(str::to_string)
}
pub fn get_ast_path(tile_rust_crate_root: &Ident) -> Path {
let s = format!("{tile_rust_crate_root}::cutile_compiler::ast");
syn::parse::<Path>(s.parse().unwrap()).unwrap()
}
pub fn get_asts_ident() -> Ident {
Ident::new("_module_asts", Span::call_site())
}
pub fn module(attributes: TokenStream, item: TokenStream) -> TokenStream {
let attrs = parse_macro_input!(attributes as SingleMetaList);
let is_core = attrs.parse_bool("core").unwrap_or(false);
let is_tile_rust_crate = attrs.parse_bool("tile_rust_crate").unwrap_or(false);
let tile_rust_crate_root = Ident::new(
if is_tile_rust_crate {
"crate"
} else {
"cutile"
},
Span::call_site(),
);
let raw_item_source = item.to_string();
let mut module_item = parse_macro_input!(item as ItemMod);
module_item.attrs = attrs.into();
match module_inner(
&module_item,
is_core,
&tile_rust_crate_root,
raw_item_source,
) {
Ok(ts) => ts.into(),
Err(err) => err.to_compile_error().into(),
}
}
fn module_inner(
module_item: &ItemMod,
is_core: bool,
tile_rust_crate_root: &Ident,
raw_item_source: String,
) -> Result<TokenStream2, Error> {
let mut ast_content: Vec<Item> = vec![];
let Some(content) = &module_item.content else {
return module_item.err("Non-empty module expected.");
};
let mut concrete_items: Vec<TokenStream2> = vec![];
let name = &module_item.ident;
let mut module_ast_calls: Vec<String> = vec![];
let mut entry_functions: Vec<TokenStream2> = vec![];
for item in &content.1 {
match item {
syn::Item::Use(use_item) => {
concrete_items.push(use_item.to_token_stream());
if !is_core {
let mut use_tree = &use_item.tree;
let mut module_ast_use_path = vec![];
while let UseTree::Path(path) = use_tree {
let path_ident_str = path.ident.to_string();
module_ast_use_path.push(path_ident_str);
use_tree = &path.tree;
}
let module_ast_call_str = format!(
"{}::{}()",
module_ast_use_path.last().unwrap(),
get_asts_ident()
);
module_ast_calls.push(module_ast_call_str);
let module_ast_use_path_str =
format!("use {};", module_ast_use_path.join("::"));
let module_ast_use_path_item =
syn::parse::<ItemUse>(module_ast_use_path_str.parse().unwrap()).unwrap();
concrete_items.push(module_ast_use_path_item.to_token_stream());
}
}
syn::Item::Fn(function_item) => {
let entry_attrs = get_meta_list(
format!("{} :: entry", tile_rust_crate_root).as_str(),
&function_item.attrs,
);
if entry_attrs.is_some() {
entry_functions.push(kernel_launcher(name, function_item)?);
};
ast_content.push(Item::Fn(function_item.clone()));
concrete_items.push(function(function_item.clone(), tile_rust_crate_root)?);
}
syn::Item::Struct(struct_item) => {
ast_content.push(Item::Struct(struct_item.clone()));
let item_clone = struct_item.clone();
concrete_items.push(structure(item_clone)?.into());
}
syn::Item::Trait(trait_item) => {
if !is_core {
return trait_item.err("Unsupported item type in non-core module: trait definitions are only allowed in core modules.");
}
ast_content.push(Item::Trait(trait_item.clone()));
let item_clone = trait_item.clone();
concrete_items.push(trait_(item_clone)?.into());
}
syn::Item::Type(type_item) => {
concrete_items.push(type_item.to_token_stream());
}
syn::Item::Impl(impl_item) => {
if !is_core {
return impl_item.err("Unsupported item type in non-core module: impl blocks are only allowed in core modules.");
}
ast_content.push(Item::Impl(impl_item.clone()));
let item_clone = impl_item.clone();
concrete_items.push(implementation(item_clone)?.into());
}
syn::Item::Macro(macro_item) => {
if !is_core {
return macro_item.err("Unsupported item type in non-core module: macro invocations are only allowed in core modules.");
}
ast_content.push(Item::Macro(macro_item.clone()));
let item_clone = macro_item.clone();
concrete_items.push(item_clone.to_token_stream());
}
other => {
return other.err("Unsupported item type in module.");
}
}
}
let ast_path = get_ast_path(tile_rust_crate_root);
let ast_module_item: ItemMod = module_item.clone();
let ast_module_tokens = module_asts(
ast_module_item,
module_ast_calls,
tile_rust_crate_root,
raw_item_source,
);
let res = if entry_functions.is_empty() {
quote! {
pub mod #name {
#![allow(nonstandard_style)]
#![allow(dead_code)]
#![allow(unused_variables)]
use #ast_path;
#ast_module_tokens
#(#concrete_items)*
}
}
} else {
quote! {
pub mod #name {
#![allow(dead_code)]
use std::{iter::zip, future::{Future, IntoFuture}, collections::HashMap, sync::Arc};
use #tile_rust_crate_root::error::{*};
use #tile_rust_crate_root::DType;
use #tile_rust_crate_root::{tensor};
use #tile_rust_crate_root::tensor::{KernelInput, KernelInputStored, KernelOutput, KernelOutputStored, SpecializationBits};
use #tile_rust_crate_root::tile_kernel::{*};
use #tile_rust_crate_root::cuda_async::error::{*};
use #tile_rust_crate_root::cuda_async::scheduling_policies::SchedulingPolicy;
use #tile_rust_crate_root::cuda_core::{CudaContext, CudaFunction, CudaModule, CudaStream, DriverError, LaunchConfig};
use #ast_path;
#ast_module_tokens
#(#concrete_items)*
#(#entry_functions)*
}
}
};
Ok(res)
}
pub fn trait_(mut item: ItemTrait) -> Result<TokenStream, Error> {
let attributes = get_meta_list("cuda_tile :: variadic_trait", &item.attrs);
let is_unchecked = get_meta_list("cuda_tile :: unchecked", &item.attrs);
if is_unchecked.is_some() {
return Ok(quote! {}.into());
}
clear_attributes(
HashSet::from(["cuda_tile :: variadic_trait", "cuda_tile :: ty"]),
&mut item.attrs,
);
let res = match attributes {
Some(attributes) => match attributes.name_as_str().unwrap().as_str() {
"cuda_tile :: variadic_trait" => {
let items = variadic_trait(&attributes, item)?;
quote! {
#(#items)*
}
}
_ => {
let item = desugar_trait_cgas(&item)?;
quote! { #item }
}
},
None => {
let item = desugar_trait_cgas(&item)?;
quote! { #item }
}
};
Ok(res.into())
}
pub fn implementation(mut item: ItemImpl) -> Result<TokenStream, Error> {
let attributes = get_meta_list("cuda_tile :: variadic_impl", &item.attrs);
let is_unchecked = get_meta_list("cuda_tile :: unchecked", &item.attrs);
if is_unchecked.is_some() {
return Ok(quote! {}.into());
}
clear_attributes(
HashSet::from([
"cuda_tile :: variadic_trait_impl",
"cuda_tile :: variadic_impl",
"cuda_tile :: ty",
]),
&mut item.attrs,
);
let res = match attributes {
Some(attributes) => match attributes.name_as_str().unwrap().as_str() {
"cuda_tile :: variadic_impl" => {
let items = variadic_impl(&attributes, item)?;
quote! {
#(#items)*
}
}
_ => {
let item = desugar_impl_cgas(&item)?;
quote! { #item }
}
},
None => {
let item = desugar_impl_cgas(&item)?;
quote! { #item }
}
};
Ok(res.into())
}
pub fn structure(mut item: ItemStruct) -> Result<TokenStream, Error> {
let attributes = get_meta_list("cuda_tile :: variadic_struct", &item.attrs);
clear_attributes(
HashSet::from(["cuda_tile :: variadic_struct", "cuda_tile :: ty"]),
&mut item.attrs,
);
let res = match attributes {
Some(attributes) => match attributes.name_as_str().unwrap().as_str() {
"cuda_tile :: variadic_struct" => {
let items = variadic_struct(&attributes, item)?;
let structs = items.iter().map(|item| item.0.clone()).collect::<Vec<_>>();
let maybe_impls = items
.iter()
.filter(|item| item.1.is_some())
.collect::<Vec<_>>();
let impls = maybe_impls
.iter()
.map(|item| item.1.clone().unwrap())
.collect::<Vec<_>>();
quote! {
#(#structs)*
#(#impls)*
}
}
_ => {
let item = desugar_structure_cgas(&item)?;
quote! { #item }
}
},
None => {
let item = desugar_structure_cgas(&item)?;
quote! { #item }
}
};
Ok(res.into())
}
pub fn function(mut item: ItemFn, tile_rust_crate_root: &Ident) -> Result<TokenStream2, Error> {
let is_entry = get_meta_list_by_last_segment("entry", &item.attrs).is_some();
if is_entry {
validate_entry_point_parameters(&item)?
}
let attributes = get_meta_list("cuda_tile :: variadic_op", &item.attrs);
clear_attributes(
HashSet::from([
"cuda_tile :: variadic_op",
"cuda_tile :: op",
"cuda_tile :: compiler_op",
]),
&mut item.attrs,
);
clear_attributes(
HashSet::from([format!("{} :: entry", tile_rust_crate_root).as_str()]),
&mut item.attrs,
);
if is_entry {
let kernel_naming = KernelNaming::new(item.sig.ident.to_string().as_str());
let internal_name = kernel_naming.user_impl_name();
item.sig.ident = Ident::new(internal_name.as_str(), item.sig.ident.span());
}
let concrete_items = match attributes {
Some(attributes) => match attributes.name_as_str().unwrap().as_str() {
"cuda_tile :: variadic_op" => variadic_op(&attributes, item.clone())?,
_ => vec![desugar_function_cgas(&item)?],
},
None => vec![desugar_function_cgas(&item)?],
};
let result = quote! {
#(#concrete_items)*
};
Ok(result)
}
pub fn kernel_launcher(module_ident: &Ident, item: &ItemFn) -> Result<TokenStream2, Error> {
let module_name = module_ident.to_string();
let function_name = item.sig.ident.to_string();
let kernel_naming = KernelNaming::new(function_name.as_str());
let function_entry_name = kernel_naming.entry_name();
let launcher_name = function_name.to_case(Case::UpperCamel).to_string();
let launcher_args_name = format!("{}Args", launcher_name);
let unsafety = item.sig.unsafety;
let (
required_generics,
(stored_args_type, returned_args_type),
device_op_impl,
kernel_input_info,
) = generate_kernel_launcher(
item,
&module_name,
&function_name,
function_entry_name.as_str(),
&launcher_name,
&launcher_args_name,
)?;
let launcher_ident = Ident::new(launcher_name.as_str(), Span::call_site());
let generic_params = required_generics.get_required_generics();
let generic_args = required_generics.get_generic_args();
let mut struct_generics = generic_params.clone();
for (ki_idx, ki_name) in kernel_input_info.type_param_names.iter().enumerate() {
let elem = &kernel_input_info.element_type_names[ki_idx];
struct_generics.params.push(
syn::parse_str::<GenericParam>(&format!("{ki_name}: KernelInput<{elem}>")).unwrap(),
);
}
for (ko_idx, ko_name) in kernel_input_info.ko_type_param_names.iter().enumerate() {
let elem = &kernel_input_info.ko_element_type_names[ko_idx];
struct_generics.params.push(
syn::parse_str::<GenericParam>(&format!("{ko_name}: KernelOutput<{elem}>")).unwrap(),
);
}
let device_op_param: GenericParam = parse_quote! { DI: DeviceOp<Output=#stored_args_type> };
struct_generics.params.push(device_op_param.clone());
let mut struct_args = generic_args.clone();
for ki_name in &kernel_input_info.type_param_names {
struct_args
.args
.push(syn::parse_str::<GenericArgument>(ki_name).unwrap());
}
for ko_name in &kernel_input_info.ko_type_param_names {
struct_args
.args
.push(syn::parse_str::<GenericArgument>(ko_name).unwrap());
}
let device_op_arg: GenericArgument = parse_quote! { DI };
struct_args.args.push(device_op_arg.clone());
let tile_kernel_impl_type_params = struct_generics.clone();
let tile_kernel_type_args: AngleBracketedGenericArguments =
parse_quote! { <#returned_args_type, #device_op_arg, #stored_args_type> };
let into_future_impl_type_params = struct_generics.clone();
let mut phantom_types: Vec<syn::Type> = vec![];
for ki_name in &kernel_input_info.type_param_names {
phantom_types.push(syn::parse_str::<syn::Type>(ki_name.as_str()).unwrap());
}
for ko_name in &kernel_input_info.ko_type_param_names {
phantom_types.push(syn::parse_str::<syn::Type>(ko_name.as_str()).unwrap());
}
for param in &generic_params.params {
if let syn::GenericParam::Type(tp) = param {
phantom_types.push(syn::parse_str::<syn::Type>(&tp.ident.to_string()).unwrap());
}
}
let ki_phantom_types = phantom_types;
let result = quote! {
pub struct #launcher_ident #struct_generics {
_const_grid: bool,
_grid: (u32, u32, u32),
input: Option<DI>,
function_generics: Option<Vec<String>>,
_phantom: std::marker::PhantomData<( #(#ki_phantom_types,)* )>,
_compile_options: CompileOptions,
}
impl #tile_kernel_impl_type_params #launcher_ident #struct_args {
pub #unsafety fn launch(input: DI) -> Self {
Self {
_const_grid: false,
_grid: (0, 0, 0),
input: Some(input),
function_generics: None,
_phantom: std::marker::PhantomData,
_compile_options: CompileOptions::default(),
}
}
}
impl #tile_kernel_impl_type_params TileKernel #tile_kernel_type_args for #launcher_ident #struct_args {
fn grid(mut self, grid: (u32, u32, u32)) -> Self {
self._grid = grid;
self._const_grid = false;
self
}
fn const_grid(mut self, grid: (u32, u32, u32)) -> Self {
self._grid = grid;
self._const_grid = true;
self
}
fn get_launch_grid(&self) -> (u32, u32, u32) {
self._grid
}
fn generics(mut self, generics: Vec<String>) -> Self {
self.function_generics = Some(generics);
self
}
fn compile_options(mut self, options: CompileOptions) -> Self {
self._compile_options = options;
self
}
}
impl #into_future_impl_type_params IntoFuture for #launcher_ident #struct_args {
type Output = Result<#returned_args_type, DeviceError>;
type IntoFuture = DeviceFuture<#returned_args_type, #launcher_ident #struct_args>;
fn into_future(self) -> Self::IntoFuture {
match with_default_device_policy(|policy| { let stream = policy.next_stream()?; Ok(DeviceFuture::scheduled(self, ExecutionContext::new(stream))) }) {
Ok(Ok(future)) => future,
Ok(Err(e)) => DeviceFuture::failed(e),
Err(e) => DeviceFuture::failed(e),
}
}
}
#device_op_impl
};
let Some(_entry_attrs) = get_meta_list_by_last_segment("entry", &item.attrs) else {
return item.sig.ident.err(&format!(
"Unexpected entry point {function_name}: Missing entry annotation."
));
};
if let Ok(dir) = env::var("DUMP_KERNEL_LAUNCHER_DIR") {
let file = parse_file(&result.to_string()).expect("Failed to parse file.");
let filename = format!("{module_name}_{function_name}_launcher.rs");
let path = PathBuf::from(dir).join(filename);
let contents = file_item_string_pretty(&file);
fs::write(path.clone(), contents).unwrap_or_else(|_| panic!("Failed to write {path:?}"));
}
Ok(result)
}
pub fn module_asts(
item: ItemMod,
module_ast_calls: Vec<String>,
tile_rust_crate_root: &Ident,
raw_item_source: String,
) -> TokenStream2 {
let ast_path = get_ast_path(tile_rust_crate_root);
let name_string = item.ident.to_string();
let vec_expr_str = format!("vec![{}]", module_ast_calls.join(","));
let vec_expr = syn::parse::<Macro>(vec_expr_str.parse().unwrap()).unwrap();
let asts_ident = get_asts_ident();
let item_start_span = match &item.vis {
syn::Visibility::Public(vis_pub) => vis_pub.span,
syn::Visibility::Restricted(vis_r) => vis_r.pub_token.span,
syn::Visibility::Inherited => item.mod_token.span,
};
let source_file = item_start_span.file();
let base_line = item_start_span.start().line;
let base_col = item_start_span.start().column;
let source_text = {
let full_span = item
.content
.as_ref()
.and_then(|(brace, _)| item_start_span.join(brace.span.close()));
let file_slice = item.content.as_ref().and_then(|(brace, _)| {
source_slice_from_file(
&source_file.to_string(),
item_start_span.start(),
brace.span.close().end(),
)
});
full_span
.and_then(|sp| sp.source_text())
.or(file_slice)
.unwrap_or(raw_item_source)
};
let result = quote! {
pub fn #asts_ident() -> Vec<#ast_path::Module> {
use #ast_path::syn;
let source_text: &str = #source_text;
let parsed_mod: syn::ItemMod = syn::parse_str(source_text)
.expect("module_asts: failed to re-parse captured source text");
let span_base = #ast_path::SpanBase::new(
#source_file.to_string(),
#base_line,
#base_col,
);
let mut this_ast = #ast_path::Module::with_span_base(
#name_string,
parsed_mod,
span_base,
);
this_ast.set_absolute_path(module_path!().to_string());
let mut module_asts: Vec<#ast_path::Module> = vec![];
let mut seen_paths: std::collections::HashSet<String> = std::collections::HashSet::new();
let other_module_asts_asts: Vec<Vec<#ast_path::Module>> = #vec_expr;
for other_module_asts in other_module_asts_asts {
for module_ast in other_module_asts {
if seen_paths.insert(module_ast.absolute_path().to_string()) {
module_asts.push(module_ast);
}
}
}
if seen_paths.insert(this_ast.absolute_path().to_string()) {
module_asts.push(this_ast);
}
module_asts
}
};
result
}