use crate::ast::Module;
use std::collections::HashMap;
use syn::{ItemMod, UseTree};
const ALLOWED_EXTERNAL_PREFIXES: &[&str] = &["half"];
const STDLIB_PREFIXES: &[&str] = &["std", "core", "alloc", "proc_macro", "test"];
#[derive(Debug, Clone)]
pub struct UseImport {
pub name: Option<String>,
pub path: String,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UseClassification {
Registered { absolute_path: String },
AllowedExternal,
Other,
}
pub type UseCatalog = HashMap<String, String>;
pub fn classify_use(use_path: &str, registry: &HashMap<&str, fn() -> Module>) -> UseClassification {
let head = use_path.split("::").next().unwrap_or("");
if ALLOWED_EXTERNAL_PREFIXES.contains(&head) {
return UseClassification::AllowedExternal;
}
if let Some(registered) = find_longest_registered_prefix(use_path, registry) {
return UseClassification::Registered {
absolute_path: registered.to_string(),
};
}
UseClassification::Other
}
fn find_longest_registered_prefix<'a>(
use_path: &str,
registry: &HashMap<&'a str, fn() -> Module>,
) -> Option<&'a str> {
let segments: Vec<&str> = use_path.split("::").collect();
for end in (1..=segments.len()).rev() {
let candidate = segments[..end].join("::");
if let Some(key) = registry.keys().find(|k| **k == candidate) {
return Some(*key);
}
}
None
}
pub fn is_stdlib_path(path: &str) -> bool {
let head = path.split("::").next().unwrap_or("");
STDLIB_PREFIXES.contains(&head)
}
pub fn collect_use_imports(item_mod: &ItemMod) -> Vec<UseImport> {
let mut out = Vec::new();
if let Some((_, items)) = &item_mod.content {
for item in items {
if let syn::Item::Use(use_item) = item {
walk_use_tree(&use_item.tree, &[], &mut out);
}
}
}
out
}
fn walk_use_tree(tree: &UseTree, prefix: &[String], out: &mut Vec<UseImport>) {
match tree {
UseTree::Path(p) => {
let mut new_prefix = prefix.to_vec();
new_prefix.push(p.ident.to_string());
walk_use_tree(&p.tree, &new_prefix, out);
}
UseTree::Name(n) => {
let mut path = prefix.to_vec();
let name = n.ident.to_string();
path.push(name.clone());
out.push(UseImport {
name: Some(name),
path: path.join("::"),
});
}
UseTree::Rename(r) => {
let mut path = prefix.to_vec();
path.push(r.ident.to_string());
out.push(UseImport {
name: Some(r.rename.to_string()),
path: path.join("::"),
});
}
UseTree::Glob(_) => {
out.push(UseImport {
name: None,
path: prefix.join("::"),
});
}
UseTree::Group(g) => {
for sub in &g.items {
walk_use_tree(sub, prefix, out);
}
}
}
}
pub fn unresolved_name_hint(name: &str, catalog: &UseCatalog) -> Option<String> {
let import_path = catalog.get(name)?;
let stdlib_hint = if is_stdlib_path(import_path) {
" Rust standard-library types have no GPU representation in cuTile kernels."
} else {
""
};
Some(format!(
"`{name}` was imported from `{import_path}`, which is not supported in \
cuTile kernels.{stdlib_hint} Only types from `#[cutile::module]`-annotated \
modules and the cuTile external allowlist (`half::*`) are available in \
kernel bodies.",
))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn classify_uses_allowlist() {
let registry: HashMap<&str, fn() -> Module> = HashMap::new();
assert_eq!(
classify_use("half::f16", ®istry),
UseClassification::AllowedExternal,
);
assert_eq!(
classify_use("half::bf16", ®istry),
UseClassification::AllowedExternal,
);
}
#[test]
fn classify_other_for_unknown() {
let registry: HashMap<&str, fn() -> Module> = HashMap::new();
assert_eq!(
classify_use("std::collections::HashMap", ®istry),
UseClassification::Other,
);
assert_eq!(
classify_use("rayon::slice::ParallelSliceMut", ®istry),
UseClassification::Other,
);
assert_eq!(
classify_use("crate::utils::compute", ®istry),
UseClassification::Other,
);
}
#[test]
fn classify_registered_finds_longest_prefix() {
fn dummy_build() -> Module {
Module::new("dummy", syn::parse_quote! { pub mod dummy {} })
}
let mut registry: HashMap<&str, fn() -> Module> = HashMap::new();
registry.insert("cutile::core", dummy_build);
registry.insert("cutile::core::sub", dummy_build);
assert_eq!(
classify_use("cutile::core::Tile", ®istry),
UseClassification::Registered {
absolute_path: "cutile::core".into()
},
);
assert_eq!(
classify_use("cutile::core::sub::nested", ®istry),
UseClassification::Registered {
absolute_path: "cutile::core::sub".into()
},
);
}
#[test]
fn is_stdlib_path_recognises_stdlib() {
assert!(is_stdlib_path("std::collections::HashMap"));
assert!(is_stdlib_path("core::mem::size_of"));
assert!(is_stdlib_path("alloc::vec::Vec"));
assert!(!is_stdlib_path("half::f16"));
assert!(!is_stdlib_path("crate::foo"));
assert!(!is_stdlib_path("rayon::iter"));
}
#[test]
fn collect_use_imports_handles_all_tree_shapes() {
let m: ItemMod = syn::parse_quote! {
mod m {
use foo::Bar;
use foo::Baz as Quux;
use foo::nested::*;
use foo::group::{One, Two as RenamedTwo};
}
};
let imports = collect_use_imports(&m);
let by_path: HashMap<String, Option<String>> =
imports.into_iter().map(|i| (i.path, i.name)).collect();
assert_eq!(by_path.get("foo::Bar"), Some(&Some("Bar".into())));
assert_eq!(by_path.get("foo::Baz"), Some(&Some("Quux".into())));
assert_eq!(by_path.get("foo::nested"), Some(&None));
assert_eq!(by_path.get("foo::group::One"), Some(&Some("One".into())));
assert_eq!(
by_path.get("foo::group::Two"),
Some(&Some("RenamedTwo".into()))
);
}
#[test]
fn unresolved_hint_appends_stdlib_note_for_std_paths() {
let mut catalog: UseCatalog = HashMap::new();
catalog.insert("HashMap".into(), "std::collections::HashMap".into());
catalog.insert("compute".into(), "rayon::slice::compute".into());
let std_hint = unresolved_name_hint("HashMap", &catalog).unwrap();
assert!(std_hint.contains("std::collections::HashMap"));
assert!(std_hint.contains("standard-library"));
let other_hint = unresolved_name_hint("compute", &catalog).unwrap();
assert!(other_hint.contains("rayon::slice::compute"));
assert!(!other_hint.contains("standard-library"));
assert!(unresolved_name_hint("not_in_catalog", &catalog).is_none());
}
}