use crate::ast::{Module, SourceLocation, SpanBase};
use crate::error::{JITError, SpannedJITError};
use crate::generics::{GenericVars, TypeInstance};
use crate::kernel_naming::KernelNaming;
use crate::passes::name_resolution::NameResolver;
use crate::syn_utils::*;
use quote::ToTokens;
use std::collections::{HashMap, HashSet};
use syn::spanned::Spanned;
use syn::{
Expr, ExprMethodCall, GenericArgument, GenericParam, ImplItem, ImplItemFn, ItemFn, ItemImpl,
ItemMod, ItemStruct, PathArguments, Type,
};
pub struct CUDATileModules {
pub(crate) name_resolver: NameResolver,
pub(crate) span_bases: HashMap<String, SpanBase>,
pub(crate) use_catalog: crate::use_classifier::UseCatalog,
}
impl CUDATileModules {
pub fn modules(&self) -> &HashMap<String, ItemMod> {
self.name_resolver.all_modules()
}
pub fn primitives(&self) -> &HashMap<(String, String), ItemImpl> {
self.name_resolver.primitives()
}
pub fn structs(&self) -> &HashMap<String, ItemStruct> {
self.name_resolver.structs()
}
pub fn struct_impls(&self) -> &HashMap<String, Vec<(String, ItemImpl)>> {
self.name_resolver.struct_impls()
}
pub fn functions(&self) -> &HashMap<String, (String, ItemFn)> {
self.name_resolver.functions()
}
pub fn unresolved_name_hint(&self, name: &str) -> Option<String> {
crate::use_classifier::unresolved_name_hint(name, &self.use_catalog)
}
}
impl CUDATileModules {
pub fn new(modules_vec: Vec<Module>) -> Result<Self, JITError> {
let mut module_asts: Vec<(String, ItemMod)> = Vec::new();
let mut span_bases: HashMap<String, SpanBase> = HashMap::new();
for module in &modules_vec {
let module_name = module.name().to_string();
let module_ast = module.ast();
if module_ast.content.is_none() {
return module
.resolve_span(&module_ast.span())
.jit_error_result(&format!(
"module `{module_name}` must have a body (non-empty content)"
));
}
module_asts.push((module_name.clone(), module_ast.clone()));
span_bases.insert(module_name, module.span_base().clone());
}
let name_resolver = NameResolver::build(&module_asts)?;
Ok(CUDATileModules {
name_resolver,
span_bases,
use_catalog: HashMap::new(),
})
}
pub fn modules_from_registry(paths: &[&str]) -> Result<Vec<Module>, JITError> {
use crate::registry::CUTILE_MODULES;
let registry: HashMap<&str, fn() -> Module> = CUTILE_MODULES
.iter()
.map(|e| (e.absolute_path, e.build))
.collect();
paths
.iter()
.map(|path| {
registry
.get(*path)
.copied()
.map(|build| build())
.ok_or_else(|| {
let mut registered: Vec<&&str> = registry.keys().collect();
registered.sort();
SourceLocation::unknown().jit_error(&format!(
"module `{path}` not found in linker registry; \
registered modules: {registered:?}",
))
})
})
.collect()
}
pub fn from_kernel(kernel: Module) -> Result<Self, JITError> {
use crate::registry::CUTILE_MODULES;
use crate::use_classifier::{
classify_use, collect_use_imports, UseCatalog, UseClassification,
};
let registry: HashMap<&str, fn() -> Module> = CUTILE_MODULES
.iter()
.map(|e| (e.absolute_path, e.build))
.collect();
let mut working_set: Vec<Module> = vec![kernel];
let kernel_path = working_set[0].absolute_path().to_string();
let mut registry_visited: HashSet<String> = HashSet::new();
let mut module_visited: HashSet<String> = working_set
.iter()
.map(|m| m.absolute_path().to_string())
.collect();
let mut use_catalog: UseCatalog = HashMap::new();
struct Pending {
name: Option<String>,
path: String,
owning_module: String,
}
let mut queue: Vec<Pending> = collect_use_imports(working_set[0].ast())
.into_iter()
.map(|imp| Pending {
name: imp.name,
path: imp.path,
owning_module: kernel_path.clone(),
})
.collect();
while let Some(pending) = queue.pop() {
let resolved =
resolve_crate_prefix(&pending.path, crate_root_of(&pending.owning_module));
let appears_in_kernel_module = pending.owning_module == kernel_path;
match classify_use(&resolved, ®istry) {
UseClassification::Registered { absolute_path } => {
if !registry_visited.insert(absolute_path.clone()) {
continue;
}
let build_fn = registry[absolute_path.as_str()];
let module = build_fn();
let owning = module.absolute_path().to_string();
if !module_visited.insert(owning.clone()) {
continue;
}
queue.extend(collect_use_imports(module.ast()).into_iter().map(|imp| {
Pending {
name: imp.name,
path: imp.path,
owning_module: owning.clone(),
}
}));
working_set.push(module);
}
UseClassification::AllowedExternal => {
}
UseClassification::Other => {
if appears_in_kernel_module {
if let Some(name) = pending.name {
use_catalog.entry(name).or_insert(pending.path);
}
}
}
}
}
let mut modules = Self::new(working_set)?;
modules.use_catalog = use_catalog;
Ok(modules)
}
}
fn crate_root_of(absolute_path: &str) -> &str {
absolute_path.split("::").next().unwrap_or("")
}
fn resolve_crate_prefix(use_path: &str, crate_root: &str) -> String {
if let Some(rest) = use_path.strip_prefix("crate::") {
if crate_root.is_empty() {
use_path.to_string()
} else {
format!("{crate_root}::{rest}")
}
} else if use_path == "crate" {
crate_root.to_string()
} else {
use_path.to_string()
}
}
impl CUDATileModules {
pub fn get_span_base(&self, module_name: &str) -> Option<&SpanBase> {
self.span_bases.get(module_name)
}
pub fn resolve_span(&self, module_name: &str, span: &proc_macro2::Span) -> SourceLocation {
match self.span_bases.get(module_name) {
Some(base) => base.resolve_span(span),
None => SourceLocation::unknown(),
}
}
pub fn get_source_file(&self, module_name: &str) -> Option<&str> {
self.span_bases.get(module_name).and_then(|sb| {
if sb.file.is_empty() {
None
} else {
Some(sb.file.as_str())
}
})
}
}
#[derive(Default)]
struct TraitMatchCtx {
type_params: HashSet<String>,
const_array_params: HashMap<String, Option<usize>>,
const_scalar_params: HashSet<String>,
caller_array_params: HashMap<String, Vec<i32>>,
type_bindings: HashMap<String, String>,
array_bindings: HashMap<String, Vec<i32>>,
const_bindings: HashMap<String, i32>,
}
fn find_impl_method<'a>(item_impl: &'a ItemImpl, method_name: &str) -> Option<&'a ImplItemFn> {
item_impl.items.iter().find_map(|item| match item {
ImplItem::Fn(method) if method.sig.ident == method_name => Some(method),
_ => None,
})
}
fn instantiate_type_for_lookup(
ty: &Type,
generic_vars: &GenericVars,
primitives: &HashMap<(String, String), ItemImpl>,
) -> Type {
generic_vars
.instantiate_type(ty, primitives)
.map(|inst| inst.get_instantiated_type().clone())
.unwrap_or_else(|_| ty.clone())
}
fn trait_impl_matches_call(
item_impl: &ItemImpl,
impl_method: &ImplItemFn,
call_arg_rust_tys: &[Type],
generic_vars: &GenericVars,
primitives: &HashMap<(String, String), ItemImpl>,
) -> bool {
let self_ty = &*item_impl.self_ty;
let (param_types, _return_type) = get_sig_types(&impl_method.sig, Some(self_ty));
if param_types.len() != call_arg_rust_tys.len() {
return false;
}
let mut ctx = TraitMatchCtx::default();
ctx.caller_array_params = generic_vars.inst_array.clone();
collect_generics_for_trait_match(&item_impl.generics, &mut ctx);
collect_generics_for_trait_match(&impl_method.sig.generics, &mut ctx);
for (param_ty, call_arg_ty) in param_types.iter().zip(call_arg_rust_tys) {
let actual_ty = instantiate_type_for_lookup(call_arg_ty, generic_vars, primitives);
if !unify_trait_type(param_ty, &actual_ty, &mut ctx) {
return false;
}
}
true
}
fn collect_generics_for_trait_match(generics: &syn::Generics, ctx: &mut TraitMatchCtx) {
for param in &generics.params {
match param {
GenericParam::Type(type_param) => {
ctx.type_params.insert(type_param.ident.to_string());
}
GenericParam::Const(const_param) => match &const_param.ty {
Type::Array(array_ty) => {
ctx.const_array_params
.insert(const_param.ident.to_string(), const_array_len(array_ty));
}
_ => {
ctx.const_scalar_params
.insert(const_param.ident.to_string());
}
},
GenericParam::Lifetime(_) => {}
}
}
}
fn const_array_len(array_ty: &syn::TypeArray) -> Option<usize> {
expr_i32(&array_ty.len).map(|len| len as usize)
}
fn strip_reference_ty(ty: &Type) -> &Type {
match ty {
Type::Reference(type_ref) => strip_reference_ty(&type_ref.elem),
_ => ty,
}
}
fn unify_trait_type(pattern: &Type, actual: &Type, ctx: &mut TraitMatchCtx) -> bool {
let pattern = strip_reference_ty(pattern);
let actual = strip_reference_ty(actual);
if let Type::Path(pattern_path) = pattern {
let pattern_segment = pattern_path.path.segments.last().unwrap();
let pattern_ident = pattern_segment.ident.to_string();
if ctx.type_params.contains(&pattern_ident)
&& matches!(pattern_segment.arguments, PathArguments::None)
{
return bind_type_param(ctx, &pattern_ident, actual);
}
}
let (Type::Path(pattern_path), Type::Path(actual_path)) = (pattern, actual) else {
return pattern.to_token_stream().to_string() == actual.to_token_stream().to_string();
};
let pattern_segment = pattern_path.path.segments.last().unwrap();
let actual_segment = actual_path.path.segments.last().unwrap();
if pattern_segment.ident != actual_segment.ident {
return false;
}
match (&pattern_segment.arguments, &actual_segment.arguments) {
(PathArguments::None, PathArguments::None) => true,
(
PathArguments::AngleBracketed(pattern_args),
PathArguments::AngleBracketed(actual_args),
) => {
let mut pattern_args = pattern_args.clone();
let mut actual_args = actual_args.clone();
strip_generic_args_lifetimes(&mut pattern_args);
strip_generic_args_lifetimes(&mut actual_args);
if pattern_args.args.len() != actual_args.args.len() {
return false;
}
pattern_args.args.iter().zip(actual_args.args.iter()).all(
|(pattern_arg, actual_arg)| unify_trait_generic_arg(pattern_arg, actual_arg, ctx),
)
}
_ => false,
}
}
fn unify_trait_generic_arg(
pattern_arg: &GenericArgument,
actual_arg: &GenericArgument,
ctx: &mut TraitMatchCtx,
) -> bool {
match (pattern_arg, actual_arg) {
(GenericArgument::Type(pattern_ty), GenericArgument::Type(actual_ty)) => {
if let Some(param_name) = const_array_param_type_name(pattern_ty, ctx) {
let Some(actual_shape) = const_array_shape_from_type(actual_ty, ctx) else {
return false;
};
return bind_array_param_checked(ctx, ¶m_name, actual_shape);
}
unify_trait_type(pattern_ty, actual_ty, ctx)
}
(GenericArgument::Type(pattern_ty), GenericArgument::Const(actual_expr)) => {
let Some(param_name) = const_array_param_type_name(pattern_ty, ctx) else {
return false;
};
let Some(actual_shape) = const_shape(actual_expr) else {
return false;
};
bind_array_param_checked(ctx, ¶m_name, actual_shape)
}
(GenericArgument::Const(pattern_expr), GenericArgument::Type(actual_ty)) => {
let Some(actual_shape) = const_array_shape_from_type(actual_ty, ctx) else {
return false;
};
unify_shape_pattern(pattern_expr, &actual_shape, ctx)
}
(GenericArgument::Const(pattern_expr), GenericArgument::Const(actual_expr)) => {
unify_trait_const(pattern_expr, actual_expr, ctx)
}
(GenericArgument::Lifetime(_), GenericArgument::Lifetime(_)) => true,
_ => false,
}
}
fn const_array_param_type_name(ty: &Type, ctx: &TraitMatchCtx) -> Option<String> {
let Type::Path(type_path) = strip_reference_ty(ty) else {
return None;
};
let segment = type_path.path.segments.last()?;
if !matches!(segment.arguments, PathArguments::None) {
return None;
}
let ident = segment.ident.to_string();
ctx.const_array_params.contains_key(&ident).then_some(ident)
}
fn const_array_shape_from_type(ty: &Type, ctx: &TraitMatchCtx) -> Option<Vec<i32>> {
let Type::Path(type_path) = strip_reference_ty(ty) else {
return None;
};
let segment = type_path.path.segments.last()?;
if !matches!(segment.arguments, PathArguments::None) {
return None;
}
let ident = segment.ident.to_string();
ctx.array_bindings
.get(&ident)
.cloned()
.or_else(|| ctx.caller_array_params.get(&ident).cloned())
}
fn bind_type_param(ctx: &mut TraitMatchCtx, name: &str, actual: &Type) -> bool {
let actual = actual.to_token_stream().to_string();
match ctx.type_bindings.get(name) {
Some(bound) => bound == &actual,
None => {
ctx.type_bindings.insert(name.to_string(), actual);
true
}
}
}
fn unify_trait_const(pattern: &Expr, actual: &Expr, ctx: &mut TraitMatchCtx) -> bool {
if let Expr::Path(pattern_path) = pattern {
let ident = get_ident_from_path_expr(pattern_path).to_string();
if let Some(expected_len) = ctx.const_array_params.get(&ident).copied().flatten() {
let Some(actual_shape) = const_shape(actual) else {
return false;
};
if actual_shape.len() != expected_len {
return false;
}
return bind_array_param_checked(ctx, &ident, actual_shape);
}
if ctx.const_array_params.contains_key(&ident) {
let Some(actual_shape) = const_shape(actual) else {
return false;
};
return bind_array_param_checked(ctx, &ident, actual_shape);
}
}
let Some(actual_shape) = const_shape(actual) else {
return pattern.to_token_stream().to_string() == actual.to_token_stream().to_string();
};
unify_shape_pattern(pattern, &actual_shape, ctx)
}
fn bind_array_param(ctx: &mut TraitMatchCtx, name: &str, actual: Vec<i32>) -> bool {
match ctx.array_bindings.get(name) {
Some(bound) => bound == &actual,
None => {
ctx.array_bindings.insert(name.to_string(), actual);
true
}
}
}
fn bind_array_param_checked(ctx: &mut TraitMatchCtx, name: &str, actual: Vec<i32>) -> bool {
if let Some(expected_len) = ctx.const_array_params.get(name).copied().flatten() {
if actual.len() != expected_len {
return false;
}
}
bind_array_param(ctx, name, actual)
}
fn unify_shape_pattern(pattern: &Expr, actual_shape: &[i32], ctx: &mut TraitMatchCtx) -> bool {
let Some(pattern_elems) = const_shape_elements(pattern) else {
return false;
};
if pattern_elems.len() != actual_shape.len() {
return false;
}
pattern_elems
.iter()
.zip(actual_shape.iter())
.all(|(pattern_elem, actual_dim)| match expr_i32(pattern_elem) {
Some(pattern_dim) => pattern_dim == *actual_dim,
None => {
let Expr::Path(path) = pattern_elem else {
return false;
};
let ident = get_ident_from_path_expr(path).to_string();
if !ctx.const_scalar_params.contains(&ident) {
return false;
}
match ctx.const_bindings.get(&ident) {
Some(bound) => *bound == *actual_dim,
None => {
ctx.const_bindings.insert(ident, *actual_dim);
true
}
}
}
})
}
fn const_shape(expr: &Expr) -> Option<Vec<i32>> {
const_shape_elements(expr).and_then(|elems| elems.iter().map(expr_i32).collect())
}
fn const_shape_elements(expr: &Expr) -> Option<Vec<Expr>> {
match expr {
Expr::Block(block_expr) => match block_expr.block.stmts.as_slice() {
[syn::Stmt::Expr(inner, _)] => const_shape_elements(inner),
_ => None,
},
Expr::Array(array_expr) => Some(array_expr.elems.iter().cloned().collect()),
Expr::Repeat(repeat_expr) => {
let elem = *repeat_expr.expr.clone();
let len = expr_i32(&repeat_expr.len)? as usize;
Some(vec![elem; len])
}
_ => None,
}
}
fn expr_i32(expr: &Expr) -> Option<i32> {
match expr {
Expr::Lit(lit) => match &lit.lit {
syn::Lit::Int(int_lit) => int_lit.base10_parse::<i32>().ok(),
_ => None,
},
Expr::Unary(unary) => {
if !matches!(unary.op, syn::UnOp::Neg(_)) {
return None;
}
expr_i32(&unary.expr).map(|value| -value)
}
Expr::Paren(paren) => expr_i32(&paren.expr),
_ => None,
}
}
impl CUDATileModules {
pub fn get_primitives_attrs(
&self,
trait_name: &str,
rust_type_name: &str,
) -> Option<SingleMetaList> {
self.name_resolver
.get_primitive_attrs(trait_name, rust_type_name)
}
pub fn get_cuda_tile_type_attrs(&self, ident: &str) -> Option<SingleMetaList> {
self.name_resolver.get_type_attrs(ident)
}
pub fn get_function_by_name(&self, function_name: &str) -> Option<&(String, ItemFn)> {
let canonical_name = KernelNaming::canonical_public_name(function_name);
self.name_resolver.functions().get(canonical_name.as_str())
}
pub fn get_cuda_tile_op_attrs(&self, ident: &str) -> Option<SingleMetaList> {
self.name_resolver.get_op_attrs(ident)
}
pub fn get_fn_item(
&self,
module_name: &str,
function_name: &str,
) -> Result<&(String, ItemFn), JITError> {
if !self.name_resolver.has_module(module_name) {
return JITError::generic(&format!("undefined module: `{module_name}`"));
}
match self.get_function_by_name(function_name) {
Some(function) => Ok(function),
None => JITError::generic(&format!("undefined function: `{function_name}`")),
}
}
pub fn get_fn_entry_attrs(&self, fn_item: &ItemFn) -> Result<SingleMetaList, JITError> {
let entry_attrs = get_meta_list_by_last_segment("entry", &fn_item.attrs);
let Some(entry_attrs) = entry_attrs else {
return JITError::generic("function is missing a required `#[entry(...)]` attribute");
};
Ok(entry_attrs)
}
pub fn get_entry_arg_bool_by_function_name(
&self,
module_name: &str,
function_name: &str,
name: &str,
) -> Result<bool, JITError> {
let (_, fn_item) = self.get_fn_item(module_name, function_name)?;
let entry_attrs = self.get_fn_entry_attrs(fn_item)?;
Ok(entry_attrs.parse_bool(name).unwrap_or(false))
}
pub fn get_entry_arg_string_by_function_name(
&self,
module_name: &str,
function_name: &str,
name: &str,
) -> Result<Option<String>, JITError> {
let (_, fn_item) = self.get_fn_item(module_name, function_name)?;
let entry_attrs = self.get_fn_entry_attrs(fn_item)?;
Ok(entry_attrs.parse_string(name))
}
pub fn get_impl_item_fn(
&self,
receiver_rust_ty: &syn::Type,
method_call_expr: &ExprMethodCall,
generic_vars: &GenericVars,
call_arg_rust_tys: &[syn::Type],
) -> Result<Option<(String, ItemImpl, ImplItemFn)>, JITError> {
let method_name = method_call_expr.method.to_string();
let receiver_lookup_ty =
instantiate_type_for_lookup(receiver_rust_ty, generic_vars, self.primitives());
let receiver_type_str = get_type_ident(&receiver_lookup_ty)
.or_else(|| get_type_ident(receiver_rust_ty))
.map(|ident| ident.to_string());
if let Some(receiver_type_str) = receiver_type_str.as_deref() {
if let Some(impls_vec) = self.name_resolver.struct_impls().get(receiver_type_str) {
for (module_name, item_impl) in impls_vec {
if let Some(impl_method) = find_impl_method(item_impl, &method_name) {
return Ok(Some((
module_name.clone(),
item_impl.clone(),
impl_method.clone(),
)));
}
}
}
}
let mut trait_candidates = Vec::new();
if matches!(
generic_vars.instantiate_type(receiver_rust_ty, self.primitives())?,
TypeInstance::ElementType(_)
) {
if let Some(impls) = self
.name_resolver
.trait_impls()
.get(&("BroadcastScalar".to_string(), "E".to_string()))
{
trait_candidates.extend(impls.iter().cloned());
}
}
if let Some(receiver_type_str) = receiver_type_str.as_deref() {
for ((_trait_name, self_name), impls) in self.name_resolver.trait_impls() {
if self_name == receiver_type_str {
trait_candidates.extend(impls.iter().cloned());
}
}
}
let mut selected: Option<(String, ItemImpl, ImplItemFn)> = None;
for (module_name, item_impl) in trait_candidates {
for item in &item_impl.items {
match item {
ImplItem::Fn(impl_item_fn) => {
let impl_item_fn_name = impl_item_fn.sig.ident.to_string();
if method_name != impl_item_fn_name {
continue;
}
let dispatch_matches = trait_impl_matches_call(
&item_impl,
impl_item_fn,
call_arg_rust_tys,
generic_vars,
self.primitives(),
);
if dispatch_matches {
if selected.is_some() {
return Err(JITError::Generic(format!(
"ambiguous trait method dispatch for `{method_name}`"
)));
}
selected = Some((
module_name.clone(),
item_impl.clone(),
impl_item_fn.clone(),
));
}
}
_ => continue,
}
}
}
Ok(selected)
}
pub fn get_struct_field_type(&self, struct_name: &str, field_name: &str) -> Option<Type> {
self.name_resolver
.get_struct_field_type(struct_name, field_name)
}
}