use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use proc_macro2::Ident;
use proc_macro2::{LineColumn, Span, TokenStream as TokenStream2};
use quote::{format_ident, quote, ToTokens};
use std::collections::HashSet;
use std::path::PathBuf;
use std::{env, fs};
use syn::{
parse_file, parse_macro_input, parse_quote, AngleBracketedGenericArguments, GenericArgument,
GenericParam, ItemFn, ItemImpl, ItemMod, ItemStruct, ItemTrait, Path,
};
use crate::error::{Error, SpannedError};
use crate::kernel_launcher_generator::generate_kernel_launcher;
use crate::rank_instantiation::*;
use crate::shadow_dispatch::{
desugar_variadic_trait_decl, desugar_variadic_trait_impl, emit_shadow_dispatch,
};
use crate::validate_dsl_syntax::validate_entry_point_parameters;
use cutile_compiler::kernel_naming::KernelNaming;
use cutile_compiler::syn_utils::*;
/// Converts a 1-based line / 0-based character column position into a byte
/// offset within `source`.
///
/// Lines are split on `'\n'`, with the terminator belonging to the line it
/// ends. The line directly after the final line is also accepted; it is
/// treated as an empty line starting at `source.len()`.
///
/// Returns `None` when `loc.line` is 0, when the line does not exist, or when
/// the column lies past the end of that line.
fn line_column_to_offset(source: &str, loc: LineColumn) -> Option<usize> {
    // Zero-based index of the requested line; line 0 is never valid.
    let lines_to_skip = loc.line.checked_sub(1)?;
    let mut segments = source.split_inclusive('\n');
    let mut line_start = 0usize;
    for _ in 0..lines_to_skip {
        line_start += segments.next()?.len();
    }
    // Running out of segments exactly here means the caller asked for the
    // (empty) line that follows the last one.
    let line = segments.next().unwrap_or("");
    byte_offset_for_char_column(line, loc.column).map(|column_offset| line_start + column_offset)
}
/// Maps a 0-based character column within `line` to a byte offset.
///
/// Columns `0..=line.chars().count()` are valid. The column just past the
/// last character maps to `line.len()`, and any larger column yields `None`.
fn byte_offset_for_char_column(line: &str, column: usize) -> Option<usize> {
    // Byte offset of every character start, followed by the end-of-line
    // offset, so index `column` lands on the requested char boundary.
    line.char_indices()
        .map(|(byte_idx, _)| byte_idx)
        .chain(std::iter::once(line.len()))
        .nth(column)
}
/// Reads the file at `path` and returns the text from `start` up to (but not
/// including) `end`.
///
/// Returns `None` if the file cannot be read, if either position cannot be
/// mapped to a byte offset, or if the resulting byte range is not a valid
/// slice of the file (for example when `start` lies after `end`).
fn source_slice_from_file(path: &str, start: LineColumn, end: LineColumn) -> Option<String> {
    let contents = fs::read_to_string(path).ok()?;
    let from = line_column_to_offset(&contents, start)?;
    let to = line_column_to_offset(&contents, end)?;
    let slice = contents.get(from..to)?;
    Some(slice.to_owned())
}
pub fn get_ast_path(tile_rust_crate_root: &Ident) -> Path {
let s = format!("{tile_rust_crate_root}::cutile_compiler::ast");
syn::parse::<Path>(s.parse().unwrap()).unwrap()
}
pub fn get_registry_path(tile_rust_crate_root: &Ident) -> Path {
let s = format!("{tile_rust_crate_root}::cutile_compiler::registry");
syn::parse::<Path>(s.parse().unwrap()).unwrap()
}
pub fn get_self_ast_ident() -> Ident {
Ident::new("__module_ast_self", Span::call_site())
}
/// Entry point of the `#[cutile::module]` attribute macro.
///
/// The attribute arguments are parsed as a `SingleMetaList`. Its boolean
/// `tile_rust_crate` flag selects the crate root used in generated paths:
/// `crate` when true, otherwise `cutile` (also the default when the flag is
/// absent or unparsable).
///
/// The module's token text is captured before parsing so it can serve as a
/// fallback source text. The parsed module's attributes are replaced by the
/// macro's attribute list before expansion. Expansion errors are returned via
/// `Error::to_compile_error`.
pub fn module(attributes: TokenStream, item: TokenStream) -> TokenStream {
    let attrs = parse_macro_input!(attributes as SingleMetaList);
    let root_name = if attrs.parse_bool("tile_rust_crate").unwrap_or(false) {
        "crate"
    } else {
        "cutile"
    };
    let tile_rust_crate_root = Ident::new(root_name, Span::call_site());

    // Must be captured before `item` is consumed by `parse_macro_input!`.
    let raw_item_source = item.to_string();
    let mut module_item = parse_macro_input!(item as ItemMod);
    module_item.attrs = attrs.into();

    match module_inner(&module_item, &tile_rust_crate_root, raw_item_source) {
        Ok(expanded) => expanded.into(),
        Err(error) => error.to_compile_error().into(),
    }
}
/// Lowers the items of a `#[cutile::module]` body, recursing into inline
/// submodules.
///
/// Returns `(concrete_items, entry_functions)`:
/// * `concrete_items` holds the lowered form of every item, in source order.
/// * `entry_functions` holds the output of `kernel_launcher` for each function
///   carrying the `<root> :: entry` attribute. `parent_name` is passed to it
///   as the owning module.
///
/// Inline submodules are re-emitted with their own attributes and
/// visibility, and their launchers are placed inside them.
///
/// # Errors
/// File-backed submodules (`mod foo;`) and item kinds not matched below are
/// rejected. Errors from lowering individual items are propagated.
fn process_items(
    items: &[syn::Item],
    parent_name: &Ident,
    tile_rust_crate_root: &Ident,
) -> Result<(Vec<TokenStream2>, Vec<TokenStream2>), Error> {
    let entry_path = format!("{} :: entry", tile_rust_crate_root);
    let mut concrete_items: Vec<TokenStream2> = Vec::new();
    let mut entry_functions: Vec<TokenStream2> = Vec::new();

    for item in items {
        let lowered: TokenStream2 = match item {
            // Items passed through verbatim.
            syn::Item::Use(use_item) => use_item.to_token_stream(),
            syn::Item::Type(type_item) => type_item.to_token_stream(),
            syn::Item::Macro(macro_item) => macro_item.to_token_stream(),
            syn::Item::Const(const_item) => const_item.to_token_stream(),
            syn::Item::Static(static_item) => static_item.to_token_stream(),
            syn::Item::Fn(fn_item) => {
                // The launcher is generated from the un-lowered function.
                if get_meta_list(entry_path.as_str(), &fn_item.attrs).is_some() {
                    entry_functions.push(kernel_launcher(parent_name, fn_item)?);
                }
                function(fn_item.clone(), tile_rust_crate_root)?
            }
            syn::Item::Struct(struct_item) => structure(struct_item.clone())?.into(),
            syn::Item::Trait(trait_item) => trait_(trait_item.clone())?.into(),
            syn::Item::Impl(impl_item) => implementation(impl_item.clone())?.into(),
            syn::Item::Mod(submod) => {
                let Some((_, sub_items)) = &submod.content else {
                    return submod.err(
                        "Submodule inside `#[cutile::module]` must have an inline body \
                         (`mod foo { ... }`); file-loaded submodules (`mod foo;`) are \
                         not supported because the macro needs the body at expansion time.",
                    );
                };
                let (sub_concrete, sub_entries) =
                    process_items(sub_items, &submod.ident, tile_rust_crate_root)?;
                let (sub_attrs, sub_vis, sub_name) = (&submod.attrs, &submod.vis, &submod.ident);
                quote! {
                    #(#sub_attrs)*
                    #sub_vis mod #sub_name {
                        #(#sub_concrete)*
                        #(#sub_entries)*
                    }
                }
            }
            other => return other.err("Unsupported item type in module."),
        };
        concrete_items.push(lowered);
    }

    Ok((concrete_items, entry_functions))
}
fn module_inner(
module_item: &ItemMod,
tile_rust_crate_root: &Ident,
raw_item_source: String,
) -> Result<TokenStream2, Error> {
let Some(content) = &module_item.content else {
return module_item.err("Non-empty module expected.");
};
let name = &module_item.ident;
let (concrete_items, entry_functions) = process_items(&content.1, name, tile_rust_crate_root)?;
let ast_path = get_ast_path(tile_rust_crate_root);
let ast_module_item: ItemMod = module_item.clone();
let ast_module_tokens = emit_module_ast_self_and_registry_entry(
ast_module_item,
tile_rust_crate_root,
raw_item_source,
);
let res = if entry_functions.is_empty() {
quote! {
pub mod #name {
#![allow(nonstandard_style)]
#![allow(dead_code)]
#![allow(unused_variables)]
use #ast_path;
#ast_module_tokens
#(#concrete_items)*
}
}
} else {
quote! {
pub mod #name {
#![allow(dead_code)]
use std::{iter::zip, future::{Future, IntoFuture}, collections::HashMap, sync::Arc};
use #tile_rust_crate_root::error::{*};
use #tile_rust_crate_root::DType;
use #tile_rust_crate_root::{tensor};
use #tile_rust_crate_root::tensor::{KernelInput, KernelInputStored, KernelOutput, KernelOutputStored, SpecializationBits};
use #tile_rust_crate_root::tile_kernel::{*};
use #tile_rust_crate_root::cuda_async::error::{*};
use #tile_rust_crate_root::cuda_async::scheduling_policies::SchedulingPolicy;
use #tile_rust_crate_root::cuda_core::{Device, Function, Module, Stream, DriverError, LaunchConfig};
use #ast_path;
#ast_module_tokens
#(#concrete_items)*
#(#entry_functions)*
}
}
};
Ok(res)
}
pub fn trait_(mut item: ItemTrait) -> Result<TokenStream, Error> {
let attributes = get_meta_list("cuda_tile :: variadic_trait", &item.attrs);
let is_unchecked = get_meta_list("cuda_tile :: unchecked", &item.attrs);
if is_unchecked.is_some() {
return Ok(quote! {}.into());
}
clear_attributes(
HashSet::from(["cuda_tile :: variadic_trait", "cuda_tile :: ty"]),
&mut item.attrs,
);
let res = match attributes {
Some(attributes)
if attributes.name_as_str().as_deref() == Some("cuda_tile :: variadic_trait") =>
{
desugar_variadic_trait_decl(&item)?
}
_ => quote! { #item },
};
Ok(res.into())
}
pub fn implementation(mut item: ItemImpl) -> Result<TokenStream, Error> {
let attributes = get_meta_list("cuda_tile :: variadic_impl", &item.attrs);
let is_unchecked = get_meta_list("cuda_tile :: unchecked", &item.attrs);
if is_unchecked.is_some() {
return Ok(quote! {}.into());
}
let is_variadic_trait_impl =
get_meta_list("cuda_tile :: variadic_trait_impl", &item.attrs).is_some();
clear_attributes(
HashSet::from([
"cuda_tile :: variadic_trait_impl",
"cuda_tile :: variadic_impl",
"cuda_tile :: ty",
]),
&mut item.attrs,
);
let res = if is_variadic_trait_impl {
desugar_variadic_trait_impl(&item)?
} else {
match attributes {
Some(attributes) => {
let items = variadic_impl(&attributes, item)?;
quote! {
#(#items)*
}
}
None => {
let item = instantiate_impl_for_rank(&item)?;
quote! { #item }
}
}
};
Ok(res.into())
}
/// Lowers a struct declaration found inside a `#[cutile::module]`.
///
/// The `cuda_tile` `variadic_struct` / `ty` marker attributes are stripped
/// first.
/// * A `#[cuda_tile::variadic_struct]` struct is expanded by
///   `variadic_struct` into its generated structs, followed by whichever
///   accompanying impls it produced.
/// * Every other struct goes through `instantiate_struct_for_rank`.
///
/// If the located attribute does not report the expected name (including
/// reporting no name at all), the struct is treated as non-variadic instead
/// of panicking. This matches how `trait_` checks its attribute.
pub fn structure(mut item: ItemStruct) -> Result<TokenStream, Error> {
    const VARIADIC_STRUCT: &str = "cuda_tile :: variadic_struct";
    let attributes = get_meta_list(VARIADIC_STRUCT, &item.attrs);
    clear_attributes(
        HashSet::from([VARIADIC_STRUCT, "cuda_tile :: ty"]),
        &mut item.attrs,
    );
    let res = match attributes {
        Some(attributes) if attributes.name_as_str().as_deref() == Some(VARIADIC_STRUCT) => {
            let items = variadic_struct(&attributes, item)?;
            // Each generated entry pairs a struct with an optional impl.
            let structs = items.iter().map(|entry| entry.0.clone());
            let impls = items.iter().filter_map(|entry| entry.1.clone());
            quote! {
                #(#structs)*
                #(#impls)*
            }
        }
        _ => {
            let item = instantiate_struct_for_rank(&item)?;
            quote! { #item }
        }
    };
    Ok(res.into())
}
/// Lowers a free function found inside a `#[cutile::module]`.
///
/// * A function with an attribute whose last path segment is `entry` has its
///   parameters checked by `validate_entry_point_parameters`. It is then
///   renamed to `KernelNaming::user_impl_name` of its original name.
/// * A `#[cuda_tile::variadic_op]` function is not emitted itself. Instead,
///   the function exactly as written is passed to `emit_shadow_dispatch`,
///   together with the optional `method` / `trait_name` overrides read from
///   the attribute.
/// * Every other function goes through `instantiate_function_for_rank`.
///
/// The `cuda_tile` `variadic_op` / `op` / `compiler_op` attributes and the
/// `<root> :: entry` attribute are stripped from the lowered function.
pub fn function(mut item: ItemFn, tile_rust_crate_root: &Ident) -> Result<TokenStream2, Error> {
    let is_entry = get_meta_list_by_last_segment("entry", &item.attrs).is_some();
    if is_entry {
        validate_entry_point_parameters(&item)?;
    }

    let variadic_op = get_meta_list("cuda_tile :: variadic_op", &item.attrs);
    // Reads an optional string key of the `variadic_op` attribute as an identifier.
    let override_ident = |key: &'static str| -> Option<Ident> {
        let value = variadic_op.as_ref()?.parse_string(key)?;
        Some(Ident::new(&value, Span::call_site()))
    };
    let method_override = override_ident("method");
    let trait_name_override = override_ident("trait_name");

    // The shadow dispatcher must see the function as written, so snapshot it
    // before any attribute stripping or renaming.
    let shadow_source = variadic_op.is_some().then(|| item.clone());

    clear_attributes(
        HashSet::from([
            "cuda_tile :: variadic_op",
            "cuda_tile :: op",
            "cuda_tile :: compiler_op",
        ]),
        &mut item.attrs,
    );
    let entry_path = format!("{} :: entry", tile_rust_crate_root);
    clear_attributes(HashSet::from([entry_path.as_str()]), &mut item.attrs);

    if is_entry {
        let kernel_naming = KernelNaming::new(item.sig.ident.to_string().as_str());
        let internal_name = kernel_naming.user_impl_name();
        item.sig.ident = Ident::new(internal_name.as_str(), item.sig.ident.span());
    }

    let concrete_item = if variadic_op.is_some() {
        None
    } else {
        Some(instantiate_function_for_rank(&item)?)
    };
    let shadow_dispatch_tokens = match shadow_source {
        Some(original) => emit_shadow_dispatch(&original, method_override, trait_name_override)?,
        None => TokenStream2::new(),
    };

    Ok(quote! {
        #concrete_item
        #shadow_dispatch_tokens
    })
}
pub fn kernel_launcher(module_ident: &Ident, item: &ItemFn) -> Result<TokenStream2, Error> {
let module_name = module_ident.to_string();
let function_name = item.sig.ident.to_string();
let kernel_naming = KernelNaming::new(function_name.as_str());
let function_entry_name = kernel_naming.entry_name();
let launcher_name = function_name.to_case(Case::UpperCamel).to_string();
let launcher_args_name = format!("{}Args", launcher_name);
let unsafety = item.sig.unsafety;
let (
required_generics,
(stored_args_type, returned_args_type),
device_op_impl,
kernel_input_info,
) = generate_kernel_launcher(
item,
&module_name,
&function_name,
function_entry_name.as_str(),
&launcher_name,
&launcher_args_name,
)?;
let launcher_ident = Ident::new(launcher_name.as_str(), Span::call_site());
let generic_params = required_generics.get_required_generics();
let generic_args = required_generics.get_generic_args();
let mut struct_generics = generic_params.clone();
for (ki_idx, ki_name) in kernel_input_info.type_param_names.iter().enumerate() {
let elem = &kernel_input_info.element_type_names[ki_idx];
struct_generics.params.push(
syn::parse_str::<GenericParam>(&format!("{ki_name}: KernelInput<{elem}>")).unwrap(),
);
}
for (ko_idx, ko_name) in kernel_input_info.ko_type_param_names.iter().enumerate() {
let elem = &kernel_input_info.ko_element_type_names[ko_idx];
struct_generics.params.push(
syn::parse_str::<GenericParam>(&format!("{ko_name}: KernelOutput<{elem}>")).unwrap(),
);
}
let device_op_param: GenericParam = parse_quote! { DI: DeviceOp<Output=#stored_args_type> };
struct_generics.params.push(device_op_param.clone());
let mut struct_args = generic_args.clone();
for ki_name in &kernel_input_info.type_param_names {
struct_args
.args
.push(syn::parse_str::<GenericArgument>(ki_name).unwrap());
}
for ko_name in &kernel_input_info.ko_type_param_names {
struct_args
.args
.push(syn::parse_str::<GenericArgument>(ko_name).unwrap());
}
let device_op_arg: GenericArgument = parse_quote! { DI };
struct_args.args.push(device_op_arg.clone());
let tile_kernel_impl_type_params = struct_generics.clone();
let tile_kernel_type_args: AngleBracketedGenericArguments =
parse_quote! { <#returned_args_type, #device_op_arg, #stored_args_type> };
let into_future_impl_type_params = struct_generics.clone();
let mut phantom_types: Vec<syn::Type> = vec![];
for ki_name in &kernel_input_info.type_param_names {
phantom_types.push(syn::parse_str::<syn::Type>(ki_name.as_str()).unwrap());
}
for ko_name in &kernel_input_info.ko_type_param_names {
phantom_types.push(syn::parse_str::<syn::Type>(ko_name.as_str()).unwrap());
}
for param in &generic_params.params {
if let syn::GenericParam::Type(tp) = param {
phantom_types.push(syn::parse_str::<syn::Type>(&tp.ident.to_string()).unwrap());
}
}
let ki_phantom_types = phantom_types;
let result = quote! {
pub struct #launcher_ident #struct_generics {
_const_grid: bool,
_grid: (u32, u32, u32),
input: Option<DI>,
function_generics: Option<Vec<String>>,
_phantom: std::marker::PhantomData<( #(#ki_phantom_types,)* )>,
_compile_options: CompileOptions,
}
impl #tile_kernel_impl_type_params #launcher_ident #struct_args {
pub #unsafety fn launch(input: DI) -> Self {
Self {
_const_grid: false,
_grid: (0, 0, 0),
input: Some(input),
function_generics: None,
_phantom: std::marker::PhantomData,
_compile_options: CompileOptions::default(),
}
}
}
impl #tile_kernel_impl_type_params TileKernel #tile_kernel_type_args for #launcher_ident #struct_args {
fn grid(mut self, grid: (u32, u32, u32)) -> Self {
self._grid = grid;
self._const_grid = false;
self
}
fn const_grid(mut self, grid: (u32, u32, u32)) -> Self {
self._grid = grid;
self._const_grid = true;
self
}
fn get_launch_grid(&self) -> (u32, u32, u32) {
self._grid
}
fn generics(mut self, generics: Vec<String>) -> Self {
self.function_generics = Some(generics);
self
}
fn compile_options(mut self, options: CompileOptions) -> Self {
self._compile_options = options;
self
}
}
impl #into_future_impl_type_params IntoFuture for #launcher_ident #struct_args {
type Output = Result<#returned_args_type, DeviceError>;
type IntoFuture = DeviceFuture<#returned_args_type, #launcher_ident #struct_args>;
fn into_future(self) -> Self::IntoFuture {
match with_default_device_policy(|policy| { let stream = policy.next_stream()?; Ok(DeviceFuture::scheduled(self, ExecutionContext::new(stream))) }) {
Ok(Ok(future)) => future,
Ok(Err(e)) => DeviceFuture::failed(e),
Err(e) => DeviceFuture::failed(e),
}
}
}
#device_op_impl
};
let Some(_entry_attrs) = get_meta_list_by_last_segment("entry", &item.attrs) else {
return item.sig.ident.err(&format!(
"Unexpected entry point {function_name}: Missing entry annotation."
));
};
if let Ok(dir) = env::var("DUMP_KERNEL_LAUNCHER_DIR") {
let file = parse_file(&result.to_string()).expect("Failed to parse file.");
let filename = format!("{module_name}_{function_name}_launcher.rs");
let path = PathBuf::from(dir).join(filename);
let contents = file_item_string_pretty(&file);
fs::write(path.clone(), contents).unwrap_or_else(|_| panic!("Failed to write {path:?}"));
}
Ok(result)
}
/// Emits `__module_ast_self()` and the module-registry entry for `item`.
///
/// The module's source text is recovered in this order of preference:
/// 1. the text of the joined span from the item start to the closing brace
///    (`Span::source_text`);
/// 2. a slice of the file reported by `Span::file`;
/// 3. `raw_item_source`.
///
/// The item start is its visibility keyword, or `mod` when it has none. Its
/// line and column are passed on as the base position for span bookkeeping.
pub fn emit_module_ast_self_and_registry_entry(
    item: ItemMod,
    tile_rust_crate_root: &Ident,
    raw_item_source: String,
) -> TokenStream2 {
    let item_start_span = match &item.vis {
        syn::Visibility::Public(vis_pub) => vis_pub.span,
        syn::Visibility::Restricted(vis_r) => vis_r.pub_token.span,
        syn::Visibility::Inherited => item.mod_token.span,
    };
    let source_file = item_start_span.file().to_string();
    let start = item_start_span.start();
    let base_line = start.line;
    let base_col = start.column;

    // Span of the module's closing brace; `None` when there is no inline body.
    let close_span = item
        .content
        .as_ref()
        .map(|(brace, _)| brace.span.close());
    // The file fallback is evaluated lazily, so the source file is only read
    // from disk when the span's own source text is unavailable.
    let source_text = close_span
        .and_then(|close| item_start_span.join(close))
        .and_then(|full_span| full_span.source_text())
        .or_else(|| {
            close_span.and_then(|close| source_slice_from_file(&source_file, start, close.end()))
        })
        .unwrap_or(raw_item_source);

    emit_ast_self_and_registry(
        &item.ident,
        &source_text,
        &source_file,
        base_line,
        base_col,
        tile_rust_crate_root,
    )
}
/// Emits the `__module_ast_self()` constructor and the `linkme` registry
/// entry for module `name`.
///
/// The generated function does the following when called:
/// * re-parses `source_text` as a `syn::ItemMod`, panicking if that fails;
/// * wraps it in a `Module` named `name` with a `SpanBase` built from
///   `source_file` / `base_line` / `base_col`;
/// * records `module_path!()` as the module's absolute path.
///
/// The generated static adds a `CutileModuleEntry` holding
/// `module_path!()` and the constructor to the `CUTILE_MODULES` distributed
/// slice.
fn emit_ast_self_and_registry(
    name: &Ident,
    source_text: &str,
    source_file: &str,
    base_line: usize,
    base_col: usize,
    tile_rust_crate_root: &Ident,
) -> TokenStream2 {
    let ast_path = get_ast_path(tile_rust_crate_root);
    let registry_path = get_registry_path(tile_rust_crate_root);
    let self_ast_ident = get_self_ast_ident();
    let name_string = name.to_string();
    // A raw identifier such as `r#type` stringifies with its `r#` prefix.
    // Embedding that in the middle of another identifier would make
    // `format_ident!` panic, so only the bare name is used for the static.
    let static_suffix = name_string
        .strip_prefix("r#")
        .unwrap_or(name_string.as_str())
        .to_uppercase();
    let registry_static_ident = format_ident!("__CUTILE_MODULE_ENTRY_{}", static_suffix);
    quote! {
        pub fn #self_ast_ident() -> #ast_path::Module {
            use #ast_path::syn;
            let source_text: &str = #source_text;
            let parsed_mod: syn::ItemMod = syn::parse_str(source_text)
                .expect("__module_ast_self: failed to re-parse captured source text");
            let span_base = #ast_path::SpanBase::new(
                #source_file.to_string(),
                #base_line,
                #base_col,
            );
            let mut this_ast = #ast_path::Module::with_span_base(
                #name_string,
                parsed_mod,
                span_base,
            );
            this_ast.set_absolute_path(module_path!().to_string());
            this_ast
        }
        #[#registry_path::linkme::distributed_slice(#registry_path::CUTILE_MODULES)]
        #[linkme(crate = #registry_path::linkme)]
        static #registry_static_ident: #registry_path::CutileModuleEntry =
            #registry_path::CutileModuleEntry {
                absolute_path: ::std::module_path!(),
                build: #self_ast_ident,
            };
    }
}