use crate::error::JITError;
use crate::syn_utils::*;
use std::collections::HashMap;
use syn::{ImplItem, ImplItemFn, Item, ItemFn, ItemImpl, ItemMod, ItemStruct, UseTree};
/// Identifies a definition by the module that declares it and its name.
///
/// For associated functions produced by `NameResolver::resolve_path`, `name` is
/// the method name and `module` is the module holding the `impl` block; the
/// owning type itself is not recorded.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct DefId {
    /// Module name, as given to `NameResolver::build`, that declares the item.
    pub module: String,
    /// Identifier of the item inside that module.
    pub name: String,
}
/// The sort of item a `DefId` refers to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DefKind {
    /// A free `fn`.
    Fn,
    /// A `struct` definition.
    Struct,
    /// A trait. NOTE(review): not produced by the resolver code in this file.
    Trait,
    /// An inherent method reached through a `Type::method` path.
    AssocFn,
}
/// Outcome of resolving a path.
#[derive(Clone, Debug)]
pub enum Res {
    /// The path names an indexed definition of the given kind.
    Def(DefKind, DefId),
    /// A local binding. NOTE(review): not produced by `NameResolver` in this file.
    Local(String),
    /// A primitive type name. NOTE(review): not produced by `NameResolver` in this file.
    PrimTy(String),
    /// Nothing matched.
    Err,
}
impl Res {
    /// The definition id carried by `Res::Def`; `None` for every other variant.
    pub fn def_id(&self) -> Option<&DefId> {
        if let Res::Def(_, id) = self {
            Some(id)
        } else {
            None
        }
    }

    /// The definition kind carried by `Res::Def`; `None` for every other variant.
    pub fn def_kind(&self) -> Option<DefKind> {
        if let Res::Def(kind, _) = self {
            Some(*kind)
        } else {
            None
        }
    }
}
/// A name-lookup namespace (types vs. values).
///
/// NOTE(review): not referenced by the code in this file.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Namespace {
    /// Type-level names.
    Type,
    /// Value-level names.
    Value,
}
/// Index of the items declared in one module, keyed by name.
///
/// Items inside nested inline `mod { ... }` blocks are indexed into the same
/// `ModuleItems` as their enclosing module.
pub struct ModuleItems {
    /// Free functions, by identifier.
    pub functions: HashMap<String, ItemFn>,
    /// Structs, by identifier.
    pub structs: HashMap<String, ItemStruct>,
    /// Inherent `impl` blocks, keyed by the self-type string from `get_type_str`.
    pub struct_impls: HashMap<String, Vec<ItemImpl>>,
    /// Trait impls without a `cuda_tile::ty` attribute, keyed by `(trait name, self type)`.
    pub trait_impls: HashMap<(String, String), Vec<ItemImpl>>,
    /// Trait impls tagged with `cuda_tile::ty`, keyed by `(trait name, self type)`.
    pub primitives: HashMap<(String, String), ItemImpl>,
}
impl ModuleItems {
fn new() -> Self {
Self {
functions: HashMap::new(),
structs: HashMap::new(),
struct_impls: HashMap::new(),
trait_impls: HashMap::new(),
primitives: HashMap::new(),
}
}
}
/// Resolves paths across a fixed set of modules indexed once by `NameResolver::build`.
pub struct NameResolver {
    /// Per-module item index, keyed by module name.
    items: HashMap<String, ModuleItems>,
    /// The module ASTs exactly as passed to `build`, keyed by module name.
    modules: HashMap<String, ItemMod>,
    /// For each module: imported name -> module it was imported from.
    imports: HashMap<String, HashMap<String, String>>,
    /// The last module (in `build` order) containing a `cuda_tile::ty`-tagged trait impl.
    core_module: Option<String>,
    /// All `cuda_tile::ty` trait impls, flattened across modules.
    cached_primitives: HashMap<(String, String), ItemImpl>,
    /// Every free function by name, with its defining module.
    cached_functions: HashMap<String, (String, ItemFn)>,
    /// Every struct by name, flattened across modules.
    cached_structs: HashMap<String, ItemStruct>,
    /// Inherent impls by self type, each with its defining module, in indexing order.
    cached_struct_impls: HashMap<String, Vec<(String, ItemImpl)>>,
    /// Non-primitive trait impls by `(trait, self type)`, with defining module, in indexing order.
    cached_trait_impls: HashMap<(String, String), Vec<(String, ItemImpl)>>,
}
impl NameResolver {
fn collect_use_imports(
items_block: &[Item],
items: &HashMap<String, ModuleItems>,
module_imports: &mut HashMap<String, String>,
) {
for item in items_block {
match item {
Item::Use(use_item) => {
Self::process_use_tree(&use_item.tree, &[], items, module_imports);
}
Item::Mod(submod) => {
if let Some((_, sub_items)) = &submod.content {
Self::collect_use_imports(sub_items, items, module_imports);
}
}
_ => {}
}
}
}
#[allow(clippy::too_many_arguments)]
fn index_items(
items_block: &[Item],
module_name: &str,
mi: &mut ModuleItems,
has_cuda_tile_ty: &mut bool,
cached_functions: &mut HashMap<String, (String, ItemFn)>,
cached_structs: &mut HashMap<String, ItemStruct>,
cached_struct_impls: &mut HashMap<String, Vec<(String, ItemImpl)>>,
cached_trait_impls: &mut HashMap<(String, String), Vec<(String, ItemImpl)>>,
cached_primitives: &mut HashMap<(String, String), ItemImpl>,
) -> Result<(), JITError> {
for item in items_block {
match item {
Item::Fn(f) => {
let name = f.sig.ident.to_string();
mi.functions.insert(name.clone(), f.clone());
if cached_functions
.insert(name.clone(), (module_name.to_string(), f.clone()))
.is_some()
{
return Err(JITError::generic_err(
&format!("duplicate functions are not supported; try renaming your function: {name}"),
));
}
}
Item::Struct(s) => {
let name = s.ident.to_string();
mi.structs.insert(name.clone(), s.clone());
cached_structs.insert(name, s.clone());
}
Item::Impl(impl_item) => {
let self_ident = get_type_str(&impl_item.self_ty);
let trait_ident = impl_item
.trait_
.as_ref()
.map(|(_, path, _)| path.segments.last().unwrap().ident.to_string());
match (&self_ident, &trait_ident) {
(Some(self_name), Some(trait_name)) => {
if get_meta_list("cuda_tile :: ty", &impl_item.attrs).is_some() {
*has_cuda_tile_ty = true;
let key = (trait_name.clone(), self_name.clone());
mi.primitives.insert(key.clone(), impl_item.clone());
cached_primitives.insert(key, impl_item.clone());
} else {
let key = (trait_name.clone(), self_name.clone());
mi.trait_impls
.entry(key.clone())
.or_default()
.push(impl_item.clone());
cached_trait_impls
.entry(key)
.or_default()
.push((module_name.to_string(), impl_item.clone()));
}
}
(Some(self_name), None) => {
mi.struct_impls
.entry(self_name.clone())
.or_default()
.push(impl_item.clone());
cached_struct_impls
.entry(self_name.clone())
.or_default()
.push((module_name.to_string(), impl_item.clone()));
}
_ => {}
}
}
Item::Mod(submod) => {
if let Some((_, sub_items)) = &submod.content {
Self::index_items(
sub_items,
module_name,
mi,
has_cuda_tile_ty,
cached_functions,
cached_structs,
cached_struct_impls,
cached_trait_impls,
cached_primitives,
)?;
}
}
_ => {}
}
}
Ok(())
}
/// Index every module in `module_asts` and collect each module's `use` imports.
///
/// Modules without an inline body get an empty index. The core module is the
/// last module, in slice order, that contains a `cuda_tile::ty`-tagged trait impl.
///
/// # Errors
/// Returns an error when a function name is defined more than once across all
/// modules, including nested inline modules.
pub fn build(module_asts: &[(String, ItemMod)]) -> Result<Self, JITError> {
    let mut resolver = NameResolver {
        items: HashMap::new(),
        modules: HashMap::new(),
        imports: HashMap::new(),
        core_module: None,
        cached_primitives: HashMap::new(),
        cached_functions: HashMap::new(),
        cached_structs: HashMap::new(),
        cached_struct_impls: HashMap::new(),
        cached_trait_impls: HashMap::new(),
    };

    // Pass 1: index definitions. Imports wait for pass 2 because a glob import
    // (`use m::*`) expands from `m`'s index, which must already exist.
    for (module_name, module_ast) in module_asts {
        resolver
            .modules
            .insert(module_name.clone(), module_ast.clone());
        let mut mi = ModuleItems::new();
        if let Some((_, body)) = &module_ast.content {
            let mut has_cuda_tile_ty = false;
            Self::index_items(
                body,
                module_name,
                &mut mi,
                &mut has_cuda_tile_ty,
                &mut resolver.cached_functions,
                &mut resolver.cached_structs,
                &mut resolver.cached_struct_impls,
                &mut resolver.cached_trait_impls,
                &mut resolver.cached_primitives,
            )?;
            if has_cuda_tile_ty {
                resolver.core_module = Some(module_name.clone());
            }
        }
        resolver.items.insert(module_name.clone(), mi);
    }

    // Pass 2: record each module's imports against the complete index.
    for (module_name, module_ast) in module_asts {
        let mut module_imports = HashMap::new();
        if let Some((_, body)) = &module_ast.content {
            Self::collect_use_imports(body, &resolver.items, &mut module_imports);
        }
        resolver.imports.insert(module_name.clone(), module_imports);
    }

    Ok(resolver)
}
/// Resolve `path`, as written inside `calling_module`, to a definition.
///
/// Only each segment's identifier is considered; generic arguments are ignored.
///
/// * `name`: unqualified lookup (calling module, its `use` imports, the core
///   module, then every module).
/// * `q::name`:
///   * `q` is tried as a module name.
///   * Then `q` is tried as a type with an inherent method `name`, which yields
///     `DefKind::AssocFn`.
///   * If `q` is a path keyword (`self`, `super`, `crate`), `name` finally falls
///     back to unqualified lookup.
/// * `q1::...::qn::name`:
///   * Every `qi` is tried as a module name, root first.
///   * Then `qn::name` is tried as `Type::method`.
///   * Finally `name` falls back to unqualified lookup.
///
/// Returns `Res::Err` when nothing matches.
pub fn resolve_path(&self, path: &syn::Path, calling_module: &str) -> Res {
    let segments: Vec<String> = path.segments.iter().map(|s| s.ident.to_string()).collect();
    let Some((name, qualifiers)) = segments.split_last() else {
        return Res::Err;
    };
    // `owner` is the segment directly in front of `name`.
    let Some(owner) = qualifiers.last() else {
        return self.resolve_unqualified(name, calling_module);
    };

    for candidate_module in qualifiers {
        if let Some(res) = self.resolve_in_module(name, candidate_module) {
            return res;
        }
    }

    // `Type::method`, possibly behind module qualifiers (`m::Type::method`).
    if let Some((module, _, _)) = self.find_method(owner, name) {
        return Res::Def(
            DefKind::AssocFn,
            DefId {
                module: module.to_string(),
                name: name.clone(),
            },
        );
    }

    // Longer paths fall back to an unqualified lookup of the final segment. A lone
    // path keyword names no indexed module, so `self::f` / `super::f` / `crate::f`
    // get the same fallback instead of failing outright.
    let falls_back =
        qualifiers.len() > 1 || matches!(owner.as_str(), "self" | "super" | "crate");
    if falls_back {
        self.resolve_unqualified(name, calling_module)
    } else {
        Res::Err
    }
}
fn resolve_unqualified(&self, name: &str, calling_module: &str) -> Res {
if let Some(res) = self.resolve_in_module(name, calling_module) {
return res;
}
if let Some(module_imports) = self.imports.get(calling_module) {
if let Some(source_module) = module_imports.get(name) {
if let Some(res) = self.resolve_in_module(name, source_module) {
return res;
}
}
}
if let Some(core) = &self.core_module {
if calling_module != core {
if let Some(res) = self.resolve_in_module(name, core) {
return res;
}
}
}
for (module_name, mi) in &self.items {
if let Some(res) = Self::lookup_in_items(name, module_name, mi) {
return res;
}
}
Res::Err
}
/// Look `name` up among the items of the module called `module`.
///
/// Returns `None` when the module is unknown or defines no function or struct
/// with that name.
fn resolve_in_module(&self, name: &str, module: &str) -> Option<Res> {
    self.items
        .get(module)
        .and_then(|mi| Self::lookup_in_items(name, module, mi))
}
/// Look `name` up in one module's index, reporting it as defined in `module`.
///
/// A function takes precedence over a struct with the same name; `None` when
/// neither exists.
fn lookup_in_items(name: &str, module: &str, mi: &ModuleItems) -> Option<Res> {
    let kind = if mi.functions.contains_key(name) {
        DefKind::Fn
    } else if mi.structs.contains_key(name) {
        DefKind::Struct
    } else {
        return None;
    };
    let def_id = DefId {
        module: module.to_owned(),
        name: name.to_owned(),
    };
    Some(Res::Def(kind, def_id))
}
pub fn get_fn(&self, def_id: &DefId) -> Option<&ItemFn> {
self.items.get(&def_id.module)?.functions.get(&def_id.name)
}
pub fn get_struct(&self, def_id: &DefId) -> Option<&ItemStruct> {
self.items.get(&def_id.module)?.structs.get(&def_id.name)
}
/// Find an inherent method `method_name` on the type named `struct_name`.
///
/// Inherent impl blocks are searched in the order `build` indexed them (module
/// order, then source order). The result is therefore deterministic even when the
/// type has inherent impls in several modules. The previous scan over a `HashMap`
/// depended on its randomized iteration order.
///
/// Returns the module that contains the matching impl, the impl block, and a
/// clone of the method. Returns `None` if no inherent impl of `struct_name`
/// defines `method_name`.
pub fn find_method(
    &self,
    struct_name: &str,
    method_name: &str,
) -> Option<(&str, &ItemImpl, ImplItemFn)> {
    self.cached_struct_impls
        .get(struct_name)?
        .iter()
        .find_map(|(module_name, impl_item)| {
            impl_item.items.iter().find_map(|item| match item {
                ImplItem::Fn(f) if f.sig.ident == method_name => {
                    Some((module_name.as_str(), impl_item, f.clone()))
                }
                _ => None,
            })
        })
}
/// The `cuda_tile::ty`-tagged impl of `trait_name` for `rust_type`, if any.
///
/// This reads the flat primitive cache instead of scanning every module. When
/// several modules register the same pair, the one indexed last by `build` wins.
/// That is the same entry `primitives()` exposes, rather than whichever module
/// `HashMap` iteration happened to visit first.
pub fn get_primitive(&self, trait_name: &str, rust_type: &str) -> Option<&ItemImpl> {
    let key = (trait_name.to_string(), rust_type.to_string());
    self.cached_primitives.get(&key)
}
/// The parsed `cuda_tile::ty` attribute of the primitive impl of `trait_name`
/// for `rust_type`, or `None` when there is no such impl.
pub fn get_primitive_attrs(&self, trait_name: &str, rust_type: &str) -> Option<SingleMetaList> {
    self.get_primitive(trait_name, rust_type)
        .and_then(|impl_item| get_meta_list("cuda_tile :: ty", &impl_item.attrs))
}
/// The first non-primitive impl of `trait_name` for `self_type`, with the module
/// that contains it.
///
/// "First" means first in `build` indexing order (module order, then source
/// order). The result is deterministic even when several modules implement the
/// same pair.
pub fn get_trait_impl(&self, trait_name: &str, self_type: &str) -> Option<(&str, &ItemImpl)> {
    let key = (trait_name.to_string(), self_type.to_string());
    let (module_name, impl_item) = self.cached_trait_impls.get(&key)?.first()?;
    Some((module_name.as_str(), impl_item))
}
/// The parsed `cuda_tile::ty` attribute of the struct named `struct_name`.
///
/// Uses the flat struct cache, so for a name defined in several modules the
/// answer always comes from the same definition `structs()` returns (the last one
/// indexed). Previously it depended on `HashMap` iteration order. Returns `None`
/// if the struct is unknown or carries no such attribute.
pub fn get_type_attrs(&self, struct_name: &str) -> Option<SingleMetaList> {
    let strukt = self.cached_structs.get(struct_name)?;
    get_meta_list("cuda_tile :: ty", &strukt.attrs)
}
/// The parsed `cuda_tile::op` attribute of the function named `fn_name`.
///
/// `build` rejects duplicate function names, so the flat function cache holds
/// exactly the function a per-module scan would find. Using it turns the lookup
/// into a single hash probe. Returns `None` if the function is unknown or carries
/// no such attribute.
pub fn get_op_attrs(&self, fn_name: &str) -> Option<SingleMetaList> {
    let (_, func) = self.cached_functions.get(fn_name)?;
    get_meta_list("cuda_tile :: op", &func.attrs)
}
/// The declared type of field `field_name` on a struct named `struct_name`.
///
/// Every module defining a struct with that name is searched until one has a
/// named field `field_name`. Tuple-struct fields (which have no identifier) never
/// match.
pub fn get_struct_field_type(&self, struct_name: &str, field_name: &str) -> Option<syn::Type> {
    self.items
        .values()
        .filter_map(|mi| mi.structs.get(struct_name))
        .flat_map(|strukt| strukt.fields.iter())
        .find(|field| {
            field
                .ident
                .as_ref()
                .map_or(false, |ident| ident == field_name)
        })
        .map(|field| field.ty.clone())
}
/// The AST of the module registered as `name`, if any.
pub fn module(&self, name: &str) -> Option<&ItemMod> {
    self.modules.get(name)
}

/// Whether a module named `name` was passed to `build`.
pub fn has_module(&self, name: &str) -> bool {
    self.module(name).is_some()
}

/// Name of the core module: the last module, in `build` order, that contains a
/// `cuda_tile::ty`-tagged trait impl.
pub fn core_module(&self) -> Option<&str> {
    self.core_module.as_deref()
}

/// Every `cuda_tile::ty`-tagged trait impl, keyed by `(trait name, self type)`.
pub fn primitives(&self) -> &HashMap<(String, String), ItemImpl> {
    &self.cached_primitives
}

/// Every free function by name, paired with the module that defines it.
pub fn functions(&self) -> &HashMap<String, (String, ItemFn)> {
    &self.cached_functions
}

/// Every struct by name; for a repeated name the last one indexed is kept.
pub fn structs(&self) -> &HashMap<String, ItemStruct> {
    &self.cached_structs
}

/// Inherent impls by self type, each paired with its module, in indexing order.
pub fn struct_impls(&self) -> &HashMap<String, Vec<(String, ItemImpl)>> {
    &self.cached_struct_impls
}

/// Non-primitive trait impls by `(trait name, self type)`, each paired with its
/// module, in indexing order.
pub fn trait_impls(&self) -> &HashMap<(String, String), Vec<(String, ItemImpl)>> {
    &self.cached_trait_impls
}

/// The module ASTs as passed to `build`, keyed by module name.
pub fn all_modules(&self) -> &HashMap<String, ItemMod> {
    &self.modules
}
/// Names of all modules that define a function or a struct called `name`.
///
/// The result is sorted ascending. It used to follow `HashMap` iteration order,
/// which is randomized per process and made diagnostics built from it unstable.
pub fn find_all_definitions(&self, name: &str) -> Vec<&str> {
    let mut defining_modules: Vec<&str> = self
        .items
        .iter()
        .filter(|(_, mi)| mi.functions.contains_key(name) || mi.structs.contains_key(name))
        .map(|(module_name, _)| module_name.as_str())
        .collect();
    defining_modules.sort_unstable();
    defining_modules
}
/// Record the names that one `use` tree brings into scope.
///
/// `path_prefix` holds the path segments walked so far. The last of them is taken
/// as the source module for any name, rename, or glob reached, and each entry maps
/// the in-scope name to that module. A glob expands to every function and struct
/// indexed for the source module in `items`. Leaves reached with an empty prefix
/// (e.g. `use foo;`) record nothing.
///
/// NOTE(review): for `use m::f as g` only `g -> m` is recorded.
/// `resolve_unqualified` then looks up `g` (not `f`) inside `m`, so a renamed
/// import only resolves when `m` also defines an item called `g`.
fn process_use_tree(
    tree: &UseTree,
    path_prefix: &[String],
    items: &HashMap<String, ModuleItems>,
    imports: &mut HashMap<String, String>,
) {
    let source = path_prefix.last();
    match tree {
        UseTree::Path(path) => {
            let mut extended = Vec::with_capacity(path_prefix.len() + 1);
            extended.extend_from_slice(path_prefix);
            extended.push(path.ident.to_string());
            Self::process_use_tree(&path.tree, &extended, items, imports);
        }
        UseTree::Group(group) => {
            for branch in &group.items {
                Self::process_use_tree(branch, path_prefix, items, imports);
            }
        }
        UseTree::Name(leaf) => {
            if let Some(module) = source {
                imports.insert(leaf.ident.to_string(), module.clone());
            }
        }
        UseTree::Rename(alias) => {
            if let Some(module) = source {
                imports.insert(alias.rename.to_string(), module.clone());
            }
        }
        UseTree::Glob(_) => {
            let Some(module) = source else {
                return;
            };
            let Some(mi) = items.get(module) else {
                return;
            };
            for name in mi.functions.keys().chain(mi.structs.keys()) {
                imports.insert(name.clone(), module.clone());
            }
        }
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Wrap `items_vec` in an inline `mod <name> { ... }` and pair it with its name.
    fn make_module(name: &str, items_vec: Vec<Item>) -> (String, ItemMod) {
        let mod_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
        let module: ItemMod = parse_quote! {
            mod #mod_ident {
                #(#items_vec)*
            }
        };
        (name.to_owned(), module)
    }

    /// Parse a path such as `a::b`; panics if `s` is not a valid path.
    fn parse_path(s: &str) -> syn::Path {
        syn::parse_str(s).unwrap()
    }

    /// Build a resolver for `modules`, panicking if indexing fails.
    fn build_ok(modules: &[(String, ItemMod)]) -> NameResolver {
        NameResolver::build(modules).unwrap()
    }

    #[test]
    fn resolve_unqualified_local() {
        let resolver = build_ok(&[make_module(
            "test_mod",
            vec![parse_quote! { fn my_func() {} }],
        )]);
        let res = resolver.resolve_path(&parse_path("my_func"), "test_mod");
        let Res::Def(DefKind::Fn, def_id) = &res else {
            panic!("expected Def(Fn, ...), got {:?}", res);
        };
        assert_eq!(def_id.module, "test_mod");
        assert_eq!(def_id.name, "my_func");
    }

    #[test]
    fn resolve_qualified_module_item() {
        let resolver = build_ok(&[
            make_module("mod_a", vec![parse_quote! { fn helper() {} }]),
            make_module("mod_b", vec![parse_quote! { fn other() {} }]),
        ]);
        let res = resolver.resolve_path(&parse_path("mod_a::helper"), "mod_b");
        if let Res::Def(DefKind::Fn, def_id) = &res {
            assert_eq!(def_id.module, "mod_a");
        } else {
            panic!("expected Def, got {:?}", res);
        }
    }

    #[test]
    fn resolve_unknown_returns_err() {
        let resolver = build_ok(&[make_module(
            "test_mod",
            vec![parse_quote! { fn my_func() {} }],
        )]);
        let res = resolver.resolve_path(&parse_path("nonexistent"), "test_mod");
        assert!(matches!(res, Res::Err));
    }

    #[test]
    fn duplicate_function_names_rejected() {
        let modules = [
            make_module("mod_a", vec![parse_quote! { fn dup() -> i32 { 1 } }]),
            make_module("mod_b", vec![parse_quote! { fn dup() -> i32 { 2 } }]),
        ];
        assert!(NameResolver::build(&modules).is_err());
    }

    #[test]
    fn cross_module_resolution() {
        let resolver = build_ok(&[
            make_module("mod_a", vec![parse_quote! { fn helper() {} }]),
            make_module("mod_b", vec![parse_quote! { fn other() {} }]),
        ]);
        // (path, module expected to define it), both resolved from inside mod_b.
        for &(path, expected_module) in &[("helper", "mod_a"), ("other", "mod_b")] {
            match resolver.resolve_path(&parse_path(path), "mod_b") {
                Res::Def(_, def_id) => assert_eq!(def_id.module, expected_module),
                _ => panic!("expected Def"),
            }
        }
    }

    #[test]
    fn resolve_struct() {
        let resolver = build_ok(&[make_module(
            "test_mod",
            vec![parse_quote! { struct Foo {} }],
        )]);
        let Res::Def(DefKind::Struct, def_id) =
            resolver.resolve_path(&parse_path("Foo"), "test_mod")
        else {
            panic!("expected Def(Struct, ...)");
        };
        assert_eq!(def_id.name, "Foo");
        assert!(resolver.get_struct(&def_id).is_some());
    }

    #[test]
    fn cached_flat_maps_populated() {
        let resolver = build_ok(&[
            make_module(
                "mod_a",
                vec![
                    parse_quote! { fn alpha() {} },
                    parse_quote! { struct Beta {} },
                ],
            ),
            make_module("mod_b", vec![parse_quote! { fn gamma() {} }]),
        ]);
        let functions = resolver.functions();
        for name in &["alpha", "gamma"] {
            assert!(functions.contains_key(*name));
        }
        assert!(resolver.structs().contains_key("Beta"));
    }

    #[test]
    fn get_fn_via_def_id() {
        let resolver = build_ok(&[make_module(
            "test_mod",
            vec![parse_quote! { fn my_func() {} }],
        )]);
        let def_id = DefId {
            module: String::from("test_mod"),
            name: String::from("my_func"),
        };
        let func = resolver.get_fn(&def_id).unwrap();
        assert_eq!(func.sig.ident, "my_func");
    }
}