use crate::generics::GenericVars;
use proc_macro2::Ident;
use quote::ToTokens;
use std::collections::HashSet;
use std::fmt::Display;
use std::str::FromStr;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{
parse_quote, AngleBracketedGenericArguments, Attribute, ConstParam, Expr, ExprCall,
ExprClosure, ExprPath, FnArg, GenericArgument, GenericParam, Generics, Item, ItemFn, Lit, Meta,
MetaList, Pat, PathArguments, ReturnType, Signature, Token, Type,
};
#[derive(Debug)]
pub struct SingleMetaList {
name: Option<String>,
meta_list: Option<MetaList>,
variables: Vec<Meta>,
}
impl SingleMetaList {
pub fn from_attribute(attr: Attribute) -> Self {
let Meta::List(meta_list) = attr.meta else {
panic!("Unexpected attribute list {:#?}", attr.meta)
};
let tokens = proc_macro2::TokenStream::from(meta_list.tokens.clone());
let mut result = syn::parse2::<SingleMetaList>(tokens).unwrap();
result.name = Some(meta_list.path.to_token_stream().to_string());
result.meta_list = Some(meta_list);
return result;
}
pub fn name_as_str(&self) -> Option<String> {
match &self.name {
Some(s) => Some(s.clone()),
None => None,
}
}
pub fn name_as_vec(&self) -> Option<Vec<&str>> {
match &self.name {
Some(s) => Some(s.as_str().split(" :: ").collect()),
None => None,
}
}
fn get_value(&self, name: &str) -> Option<&Expr> {
for item in &self.variables {
match item {
Meta::NameValue(name_value) => {
let meta_ident = name_value.path.get_ident();
let meta_name = meta_ident.clone().unwrap().to_string();
if name == meta_name {
return Some(&name_value.value);
}
}
_ => continue,
}
}
None
}
pub fn parse_custom_expr(&self, name: &str) -> Option<&Expr> {
self.get_value(name)
}
pub fn parse_string_arr(&self, name: &str) -> Option<Vec<String>> {
let value = self.get_value(name);
match value {
Some(val) => {
let Expr::Array(ref arr) = val else {
panic!("{name} is not an array: {val:#?}")
};
let mut res = vec![];
for val in &arr.elems {
let Expr::Lit(ref lit) = val else {
panic!("{name} is not a literal: {val:#?}")
};
let Lit::Str(ref lit_str) = lit.lit else {
panic!("{name} is not a string: {lit:#?}")
};
res.push(lit_str.value().clone());
}
Some(res)
}
None => None,
}
}
pub fn parse_string(&self, name: &str) -> Option<String> {
let value = self.get_value(name);
match value {
Some(val) => {
let Expr::Lit(ref lit) = val else {
panic!("{name} is not a literal: {val:#?}")
};
let Lit::Str(ref lit_str) = lit.lit else {
panic!("{name} is not a string: {lit:#?}")
};
Some(lit_str.value().clone())
}
None => None,
}
}
pub fn parse_int(&self, name: &str) -> Option<u32> {
let value = self.get_value(name);
match value {
Some(val) => {
let Expr::Lit(ref lit) = val else {
panic!("{name} is not a literal: {val:#?}")
};
let Lit::Int(ref lit_int) = lit.lit else {
panic!("{name} is not a string: {lit:#?}")
};
Some(lit_int.base10_parse().unwrap())
}
None => None,
}
}
pub fn parse_bool(&self, name: &str) -> Option<bool> {
let value = self.get_value(name);
match value {
Some(val) => {
let Expr::Lit(ref lit) = val else {
panic!("{name} is not a literal: {val:#?}")
};
let Lit::Bool(ref lit_bool) = lit.lit else {
panic!("{name} is not a string: {lit:#?}")
};
Some(lit_bool.value)
}
None => None,
}
}
}
impl Parse for SingleMetaList {
fn parse(input: ParseStream) -> syn::Result<Self> {
let variables = Punctuated::<Meta, Token![,]>::parse_terminated(input)?
.into_iter()
.collect();
Ok(Self {
name: None,
meta_list: None,
variables,
})
}
}
impl Into<Vec<Attribute>> for SingleMetaList {
fn into(self) -> Vec<Attribute> {
let mut res = vec![];
for meta in self.variables {
let attr: Attribute = parse_quote! {
#[noname(#meta)]
};
res.push(attr);
}
res
}
}
pub fn clear_attributes(attr_names: HashSet<&str>, attrs: &mut Vec<Attribute>) -> () {
*attrs = attrs
.clone()
.into_iter()
.filter(|attr| match &attr.meta {
Meta::Path(meta_path) => {
let name = meta_path.to_token_stream().to_string();
!attr_names.contains(name.as_str())
}
Meta::List(meta_list) => {
let name = meta_list.path.to_token_stream().to_string();
!attr_names.contains(name.as_str())
}
_ => true,
})
.collect::<Vec<Attribute>>();
}
pub fn get_attribute(
lookup_str: &str,
outer_attrs: &Vec<Attribute>,
last_seg_only: bool,
) -> Option<Attribute> {
for attr in outer_attrs {
let Meta::List(meta_list) = &attr.meta else {
continue;
};
let parsed_str = if last_seg_only {
meta_list
.path
.segments
.last()
.unwrap()
.to_token_stream()
.to_string()
} else {
meta_list.path.to_token_stream().to_string()
};
if parsed_str == lookup_str {
return Some(attr.clone());
}
}
None
}
pub fn get_meta_list(attr_name: &str, outer_attrs: &Vec<Attribute>) -> Option<SingleMetaList> {
match get_attribute(attr_name, outer_attrs, false) {
Some(attr) => Some(SingleMetaList::from_attribute(attr)),
None => None,
}
}
pub fn get_meta_list_by_last_segment(
last_seg: &str,
outer_attrs: &Vec<Attribute>,
) -> Option<SingleMetaList> {
match get_attribute(last_seg, outer_attrs, true) {
Some(attr) => Some(SingleMetaList::from_attribute(attr)),
None => None,
}
}
pub fn get_cuda_tile_meta_list(outer_attrs: &Vec<Attribute>) -> Option<SingleMetaList> {
let mut found: Option<SingleMetaList> = None;
for attr in outer_attrs {
let Meta::List(meta_list) = &attr.meta else {
continue;
};
let name = meta_list.path.to_token_stream().to_string();
let name_parts = name.split(" :: ").collect::<Vec<&str>>();
if name_parts[0] == "cuda_tile" {
if found.is_some() {
panic!("Found multiple cuda_tile attributes {outer_attrs:#?}")
}
found = Some(SingleMetaList::from_attribute(attr.clone()));
}
}
found
}
#[derive(Debug, Clone)]
pub struct VarCGAParameter {
pub name: String,
pub element_type: String,
pub length_var: String,
}
impl VarCGAParameter {
pub fn instance(&self, length: u32) -> CGAParameter {
CGAParameter {
name: self.name.clone(),
element_type: self.element_type.clone(),
length,
}
}
}
impl VarCGAParameter {
pub fn from_const_param(const_param: &ConstParam) -> VarCGAParameter {
let name = const_param.ident.to_string();
let Type::Array(ty_arr) = &const_param.ty else {
panic!("Expected array type.")
};
let Type::Path(ref element_type) = *ty_arr.elem else {
panic!("Expected type path.")
};
let element_type = element_type
.path
.get_ident()
.unwrap()
.to_string()
.to_string();
match &ty_arr.len {
Expr::Path(length_expr) => {
let length_var = length_expr
.path
.get_ident()
.unwrap()
.to_string()
.to_string();
VarCGAParameter {
name,
element_type,
length_var,
}
}
_ => {
panic!("Unexpected path expression {:#?}.", &ty_arr.len)
}
}
}
pub fn is_var_cga(const_param: &ConstParam) -> bool {
let Type::Array(ty_arr) = &const_param.ty else {
panic!("Expected array type.")
};
if let Expr::Path(_length_expr) = &ty_arr.len {
return true;
};
false
}
pub fn maybe_var_cga(const_param: &ConstParam) -> Option<VarCGAParameter> {
if VarCGAParameter::is_var_cga(const_param) {
Some(VarCGAParameter::from_const_param(const_param))
} else {
None
}
}
}
#[derive(Debug, Clone)]
pub struct CGAParameter {
pub name: String,
pub element_type: String,
pub length: u32,
}
#[derive(Debug, Clone)]
pub struct ConstParameter {
pub name: String,
pub ty: String,
}
impl CGAParameter {
pub fn from_const_param(const_param: &ConstParam) -> CGAParameter {
let name = const_param.ident.to_string();
let Type::Array(ty_arr) = &const_param.ty else {
panic!("Expected array type.")
};
let Type::Path(ref element_type) = *ty_arr.elem else {
panic!("Expected type path.")
};
let element_type = element_type
.path
.get_ident()
.unwrap()
.to_string()
.to_string();
match &ty_arr.len {
Expr::Lit(expr_lit) => {
let length: u32 = expr_lit
.to_token_stream()
.to_string()
.parse::<u32>()
.unwrap();
CGAParameter {
name,
element_type,
length,
}
}
_ => {
panic!("Unexpected path expression {:#?}.", &ty_arr.len)
}
}
}
}
pub fn parse_cgas(generics: &Generics) -> (Vec<CGAParameter>, Vec<ConstParameter>) {
let mut cga_params: Vec<CGAParameter> = vec![];
let mut const_params: Vec<ConstParameter> = vec![];
for param in &generics.params {
match param {
GenericParam::Type(_type_param) => continue,
GenericParam::Const(const_param) => match &const_param.ty {
Type::Array(_ty_arr) => {
let arr_type_param = CGAParameter::from_const_param(const_param);
cga_params.push(arr_type_param);
}
Type::Path(type_path) => {
let name = const_param.ident.to_string();
let ty = type_path.to_token_stream().to_string();
const_params.push(ConstParameter { name, ty });
}
_ => continue,
},
_ => continue,
}
}
(cga_params, const_params)
}
pub fn get_fn_arg_var_name(arg: &FnArg) -> String {
match arg {
FnArg::Receiver(receiver) => receiver.self_token.to_token_stream().to_string(),
FnArg::Typed(fn_param) => match fn_param.pat.as_ref() {
syn::Pat::Ident(identifier) => identifier.ident.to_string(),
_ => panic!("Unexpected argument pattern"),
},
}
}
pub fn get_sig_param_names(sig: &Signature) -> Vec<String> {
let mut result = vec![];
for arg in &sig.inputs {
let name = get_fn_arg_var_name(arg);
result.push(name);
}
result
}
pub fn get_call_expression_generics(
call_expr: &ExprCall,
) -> Option<AngleBracketedGenericArguments> {
match &*call_expr.func {
Expr::Path(path_expr) => {
let last_seg = path_expr.path.segments.last();
match last_seg {
Some(seg) => {
if let PathArguments::AngleBracketed(path_generic_args) = &seg.arguments {
Some(path_generic_args.clone())
} else {
None
}
}
_ => None,
}
}
_ => None,
}
}
pub fn get_generic_arg_ints<T>(
generic_args: &AngleBracketedGenericArguments,
generic_vars: Option<&GenericVars>,
) -> Vec<T>
where
T: FromStr + From<i32>,
T::Err: Display,
{
let mut result = vec![];
for arg in &generic_args.args {
match arg {
GenericArgument::Const(expr) => {
let Expr::Lit(lit) = expr else {
panic!("Unexpected expression.")
};
let Lit::Int(int_expr) = &lit.lit else {
panic!("Unexpected expression.")
};
let x = int_expr.base10_parse::<T>().expect("Failed to parse int.");
result.push(x);
}
GenericArgument::Type(ty) => {
if let Type::Path(path_ty) = ty {
if let Some(const_var) = path_ty.path.get_ident() {
if let Some(generic_vars) = generic_vars {
if let Some(const_val) =
generic_vars.inst_i32.get(&const_var.to_string())
{
result.push(T::from(*const_val));
}
}
}
}
}
_ => continue,
};
}
result
}
pub fn get_sig_types(sig: &Signature, self_ty: Option<&Type>) -> (Vec<Type>, Type) {
let mut input_tys: Vec<Type> = vec![];
for input in sig.inputs.iter() {
match input {
FnArg::Typed(fn_param) => {
let _name = {
match &*fn_param.pat {
Pat::Ident(ident) => ident.ident.to_string(),
_ => panic!("Unexpected function param pattern {:#?}.", fn_param.pat),
}
};
let ty = &*fn_param.ty;
input_tys.push(ty.clone());
}
FnArg::Receiver(_fn_self) => {
assert!(
self_ty.is_some(),
"bind_parameters for impls requires self_ty."
);
let self_ty = self_ty.unwrap().clone();
input_tys.push(self_ty);
}
}
}
let ret_ty = get_sig_output_type(sig);
(input_tys, ret_ty)
}
pub fn get_sig_output_type(sig: &Signature) -> Type {
match &sig.output {
ReturnType::Type(_, return_type) => *return_type.clone(),
ReturnType::Default => syn::parse2::<Type>("()".parse().unwrap()).unwrap(),
}
}
pub fn function_returns(fn_item: &ItemFn) -> bool {
match &fn_item.sig.output {
ReturnType::Type(_, return_type) => match &**return_type {
Type::Tuple(type_tuple) => type_tuple.elems.len() > 0,
_ => true,
},
ReturnType::Default => false,
}
}
pub fn get_ident_from_path(path: &syn::Path) -> Ident {
path.segments.last().unwrap().ident.clone()
}
pub fn get_ident_from_path_expr(path_expr: &ExprPath) -> Ident {
get_ident_from_path(&path_expr.path)
}
pub fn get_ident_from_expr(expr: &Expr) -> Option<Ident> {
match expr {
Expr::Path(path_expr) => Some(get_ident_from_path(&path_expr.path)),
Expr::Reference(ref_expr) => get_ident_from_expr(&*ref_expr.expr),
_ => None,
}
}
pub fn get_type_ident(ty: &Type) -> Option<Ident> {
match ty {
Type::Path(type_path) => Some(type_path.path.segments.last().unwrap().ident.clone()),
Type::Ptr(type_ptr) => get_type_ident(&type_ptr.elem),
Type::Reference(type_ref) => get_type_ident(&type_ref.elem),
_ => None,
}
}
pub fn get_type_str(ty: &Type) -> Option<String> {
match ty {
Type::Path(type_path) => Some(type_path.path.segments.last().unwrap().ident.to_string()),
Type::Ptr(type_ptr) => Some(type_ptr.to_token_stream().to_string()),
Type::Reference(type_ref) => Some(type_ref.to_token_stream().to_string()),
_ => None,
}
}
pub fn get_ident_generic_args(ty: &Type) -> (Option<Ident>, AngleBracketedGenericArguments) {
match ty {
Type::Path(type_path) => {
let result_type = type_path.clone();
let maybe_last_seg = result_type.path.segments.last().unwrap();
let last_seg = maybe_last_seg.clone();
match last_seg.arguments {
PathArguments::AngleBracketed(type_params) => {
(Some(last_seg.ident.clone()), type_params.clone())
}
_ => panic!(
"get_ident_generic_args: Unexpected generic arguments {:#?} for {ty:#?}",
last_seg.arguments
),
}
}
Type::Reference(ref_type) => get_ident_generic_args(&ref_type.elem),
_ => panic!("get_ident_generic_args: Unexpected type {:#?}", ty),
}
}
pub fn maybe_generic_args(ty: &Type) -> Option<AngleBracketedGenericArguments> {
match ty {
Type::Path(type_path) => {
let result_type = type_path.clone();
let maybe_last_seg = result_type.path.segments.last().unwrap();
let last_seg = maybe_last_seg.clone();
match last_seg.arguments {
PathArguments::AngleBracketed(type_params) => {
Some(type_params.clone())
}
PathArguments::None => None,
_ => panic!(
"get_ident_generic_args: Unexpected generic arguments {:#?} for {ty:#?}",
last_seg.arguments
),
}
}
Type::Reference(ref_type) => maybe_generic_args(&ref_type.elem),
Type::Ptr(_ptr_type) => None,
Type::Array(_arr_type) => None,
Type::Slice(_slice_type) => None,
Type::Tuple(_) => {
None
}
_ => panic!("get_ident_generic_args: Unexpected type {:#?}", ty),
}
}
pub fn get_supported_generic_params(generics: &Generics) -> Vec<(String, Option<Type>)> {
let mut param_names: Vec<(String, Option<Type>)> = vec![];
for param in &generics.params {
match param {
GenericParam::Type(type_param) => {
let name = type_param.ident.to_string();
param_names.push((name, None));
}
GenericParam::Const(const_param) => {
let name = const_param.ident.to_string();
let ty = const_param.ty.clone();
param_names.push((name, Some(ty)));
}
GenericParam::Lifetime(_lifetime_param) => continue,
#[allow(unreachable_patterns)]
_ => panic!("Unexpected generic parameter {:#?}", param),
}
}
param_names
}
pub fn strip_generic_args_lifetimes(gen_args: &mut AngleBracketedGenericArguments) -> () {
let mut res = gen_args.args.clone();
res.clear();
for gen_arg in gen_args.args.iter() {
if let GenericArgument::Lifetime(_gen_arg_lifetime) = gen_arg {
continue;
}
res.push(gen_arg.clone());
}
gen_args.args = res;
}
pub fn strip_generics_lifetimes(generics: &mut Generics) -> () {
let mut res = generics.params.clone();
res.clear();
for gen_param in generics.params.iter() {
if let GenericParam::Lifetime(_) = gen_param {
continue;
}
res.push(gen_param.clone());
}
generics.params = res;
}
pub fn item_string_pretty(item: &Item) -> String {
let file = syn::File {
attrs: vec![],
items: vec![item.clone()],
shebang: None,
};
file_item_string_pretty(&file)
}
pub fn file_item_string_pretty(file: &syn::File) -> String {
prettyplease::unparse(file)
}
#[derive(Debug, Clone)]
pub struct ClosureParam {
pub name: String,
pub ty: Option<Type>, }
#[derive(Debug, Clone)]
pub struct ClosureInfo {
pub params: Vec<ClosureParam>,
pub body: Box<Expr>,
pub is_async: bool,
pub is_move: bool,
}
pub fn parse_closure(closure_expr: &ExprClosure) -> ClosureInfo {
let mut params = Vec::new();
for input in &closure_expr.inputs {
match input {
Pat::Ident(pat_ident) => {
params.push(ClosureParam {
name: pat_ident.ident.to_string(),
ty: None,
});
}
Pat::Type(pat_type) => {
if let Pat::Ident(pat_ident) = &*pat_type.pat {
params.push(ClosureParam {
name: pat_ident.ident.to_string(),
ty: Some((*pat_type.ty).clone()),
});
} else {
panic!("Unsupported closure parameter pattern: {:#?}", pat_type.pat);
}
}
_ => panic!("Unsupported closure parameter pattern: {:#?}", input),
}
}
ClosureInfo {
params,
body: closure_expr.body.clone(),
is_async: closure_expr.asyncness.is_some(),
is_move: closure_expr.capture.is_some(),
}
}
fn expr_references_vars(expr: &Expr, var_names: &HashSet<String>) -> Vec<String> {
let mut captured = Vec::new();
match expr {
Expr::Path(path_expr) => {
if let Some(ident) = path_expr.path.get_ident() {
let name = ident.to_string();
if var_names.contains(&name) {
captured.push(name);
}
}
}
Expr::Binary(bin_expr) => {
captured.extend(expr_references_vars(&bin_expr.left, var_names));
captured.extend(expr_references_vars(&bin_expr.right, var_names));
}
Expr::Unary(unary_expr) => {
captured.extend(expr_references_vars(&unary_expr.expr, var_names));
}
Expr::Call(call_expr) => {
captured.extend(expr_references_vars(&call_expr.func, var_names));
for arg in &call_expr.args {
captured.extend(expr_references_vars(arg, var_names));
}
}
Expr::MethodCall(method_call) => {
captured.extend(expr_references_vars(&method_call.receiver, var_names));
for arg in &method_call.args {
captured.extend(expr_references_vars(arg, var_names));
}
}
Expr::Field(field_expr) => {
captured.extend(expr_references_vars(&field_expr.base, var_names));
}
Expr::Index(index_expr) => {
captured.extend(expr_references_vars(&index_expr.expr, var_names));
captured.extend(expr_references_vars(&index_expr.index, var_names));
}
Expr::Paren(paren_expr) => {
captured.extend(expr_references_vars(&paren_expr.expr, var_names));
}
Expr::Block(block_expr) => {
for stmt in &block_expr.block.stmts {
if let syn::Stmt::Expr(expr, _) = stmt {
captured.extend(expr_references_vars(expr, var_names));
}
}
}
Expr::If(if_expr) => {
captured.extend(expr_references_vars(&if_expr.cond, var_names));
for stmt in &if_expr.then_branch.stmts {
if let syn::Stmt::Expr(expr, _) = stmt {
captured.extend(expr_references_vars(expr, var_names));
}
}
if let Some((_, else_expr)) = &if_expr.else_branch {
captured.extend(expr_references_vars(else_expr, var_names));
}
}
_ => {}
}
captured
}
pub fn get_closure_captures(
closure_expr: &ExprClosure,
outer_scope_vars: &[String],
) -> Vec<String> {
let closure_info = parse_closure(closure_expr);
let outer_vars: HashSet<String> = outer_scope_vars.iter().cloned().collect();
let param_names: HashSet<String> = closure_info.params.iter().map(|p| p.name.clone()).collect();
let vars_available_to_capture: HashSet<String> =
outer_vars.difference(¶m_names).cloned().collect();
let mut captured = expr_references_vars(&closure_info.body, &vars_available_to_capture);
captured.sort();
captured.dedup();
captured
}
pub fn is_closure(expr: &Expr) -> bool {
matches!(expr, Expr::Closure(_))
}