use cutile_compiler::ast::{Module, SpanBase};
use cutile_compiler::compiler::CUDATileModules;
use syn::parse_quote;
// Minimal module registered with `#[cutile::module]`, used as a "real" kernel by
// the tests in this file. The macro makes `__module_ast_self()` available from this
// module (it is imported below), and the tests hand its result to
// `CUDATileModules::from_kernel` / `CUDATileModules::new`.
// NOTE: only `//` comments belong inside this module. `///` doc comments would
// become `#[doc]` attribute tokens that the proc macro receives.
#[cutile::module]
mod use_classifier_test_module {
    // The kernel's only import is this glob of cuTile's core module.
    use cutile::core::*;

    // Deliberately trivial kernel: it builds a tile of `0.0f32` shaped like `out`
    // and stores it. The tests look only at the module's imports, not at what the
    // kernel computes.
    #[cutile::entry()]
    fn noop_kernel<const S: [i32; 1]>(out: &mut Tensor<f32, S>) {
        let tile: Tile<f32, S> = constant(0.0f32, out.shape());
        out.store(tile);
    }
}
use use_classifier_test_module::__module_ast_self;
/// Wraps a hand-written `syn::ItemMod` into a [`Module`] named `synthetic_kernel`.
///
/// The tests build these modules with `parse_quote!`, so there is no real source
/// file to anchor spans to. The module is given `SpanBase::unknown()` instead.
fn make_synthetic_kernel(module_ast: syn::ItemMod) -> Module {
    let span_base = SpanBase::unknown();
    Module::with_span_base("synthetic_kernel", module_ast, span_base)
}
/// A plain `use std::...` import inside a kernel is catalogued. Its hint names the
/// full import path and carries the standard-library guidance.
#[test]
fn from_kernel_catalogs_stdlib_imports() {
    let kernel_ast: syn::ItemMod = parse_quote! {
        pub mod synthetic_kernel {
            use std::collections::HashMap;
            use cutile::core::*;
        }
    };
    let kernel = make_synthetic_kernel(kernel_ast);
    let modules =
        CUDATileModules::from_kernel(kernel).expect("from_kernel should construct CUDATileModules");

    let hint = modules
        .unresolved_name_hint("HashMap")
        .expect("HashMap should be in the catalog");

    let mentions_path = hint.contains("std::collections::HashMap");
    assert!(mentions_path, "hint should reference the import path; got: {hint}");

    let mentions_stdlib = hint.contains("standard-library");
    assert!(mentions_stdlib, "stdlib imports should get the stdlib hint; got: {hint}");
}
/// A third-party crate import is catalogued under its full path. It must NOT
/// receive the standard-library hint.
#[test]
fn from_kernel_catalogs_third_party_imports() {
    let kernel_ast: syn::ItemMod = parse_quote! {
        pub mod synthetic_kernel {
            use rayon::slice::ParallelSliceMut;
            use cutile::core::*;
        }
    };
    let modules = CUDATileModules::from_kernel(make_synthetic_kernel(kernel_ast))
        .expect("from_kernel should construct CUDATileModules");
    let hint = modules
        .unresolved_name_hint("ParallelSliceMut")
        .expect("ParallelSliceMut should be in the catalog");
    // Show the actual hint on failure, matching the other path checks in this file.
    assert!(
        hint.contains("rayon::slice::ParallelSliceMut"),
        "hint should reference the import path; got: {hint}"
    );
    assert!(
        !hint.contains("standard-library"),
        "non-stdlib imports should NOT get the stdlib hint; got: {hint}"
    );
}
/// An import from a user crate/module that is not registered with
/// `#[cutile::module]` is catalogued under its full path. Because its root is not
/// the standard library, it must not get the stdlib hint either.
#[test]
fn from_kernel_catalogs_unannotated_user_module() {
    let kernel_ast: syn::ItemMod = parse_quote! {
        pub mod synthetic_kernel {
            use my_crate::utils::compute;
            use cutile::core::*;
        }
    };
    let modules = CUDATileModules::from_kernel(make_synthetic_kernel(kernel_ast))
        .expect("from_kernel should construct CUDATileModules");
    let hint = modules
        .unresolved_name_hint("compute")
        .expect("compute should be in the catalog");
    assert!(
        hint.contains("my_crate::utils::compute"),
        "hint should reference the import path; got: {hint}"
    );
    // Same classification rule as the third-party test: only std imports get the
    // stdlib hint.
    assert!(
        !hint.contains("standard-library"),
        "non-stdlib imports should NOT get the stdlib hint; got: {hint}"
    );
}
/// Names that a kernel pulls in from a registered cuTile module must never appear
/// in the unresolved-name catalog.
#[test]
fn from_kernel_does_not_catalog_registered_imports() {
    let modules =
        CUDATileModules::from_kernel(__module_ast_self()).expect("from_kernel should succeed");

    for registered_name in ["Tile", "Tensor"] {
        assert!(
            modules.unresolved_name_hint(registered_name).is_none(),
            "names from a registered cuTile module should not be catalogued"
        );
    }
}
/// Imports on the external allowlist (here `half::f16` and `half::bf16`) are
/// accepted by the compiler and must therefore stay out of the catalog.
#[test]
fn from_kernel_does_not_catalog_allowed_external_imports() {
    let kernel_ast: syn::ItemMod = parse_quote! {
        pub mod synthetic_kernel {
            use half::f16;
            use half::bf16;
            use cutile::core::*;
        }
    };
    let kernel = make_synthetic_kernel(kernel_ast);
    let modules =
        CUDATileModules::from_kernel(kernel).expect("from_kernel should construct CUDATileModules");

    for allowed_name in ["f16", "bf16"] {
        assert!(
            modules.unresolved_name_hint(allowed_name).is_none(),
            "names on the external allowlist should not be catalogued"
        );
    }
}
/// `use path::Name as Alias;` binds only `Alias` in the kernel. The catalog must key
/// the entry by the alias, while the hint still names the original path and keeps
/// the path's (stdlib) classification.
#[test]
fn from_kernel_catalogs_renamed_imports_under_alias() {
    let kernel_ast: syn::ItemMod = parse_quote! {
        pub mod synthetic_kernel {
            use std::collections::HashMap as Map;
            use cutile::core::*;
        }
    };
    let modules = CUDATileModules::from_kernel(make_synthetic_kernel(kernel_ast))
        .expect("from_kernel should construct CUDATileModules");
    assert!(
        modules.unresolved_name_hint("HashMap").is_none(),
        "the original name should NOT be in the catalog when renamed"
    );
    let hint = modules
        .unresolved_name_hint("Map")
        .expect("the alias should be catalogued");
    assert!(
        hint.contains("std::collections::HashMap"),
        "alias hint should reference the original import path; got: {hint}"
    );
    // Renaming changes the bound name, not where the item comes from.
    assert!(
        hint.contains("standard-library"),
        "a renamed stdlib import should still get the stdlib hint; got: {hint}"
    );
}
/// A braced `use` tree (`use a::b::{X, Y};`) must be flattened so that every leaf is
/// catalogued under its own name with its full import path.
#[test]
fn from_kernel_handles_grouped_imports() {
    let kernel_ast: syn::ItemMod = parse_quote! {
        pub mod synthetic_kernel {
            use std::collections::{HashMap, BTreeMap};
            use cutile::core::*;
        }
    };
    let modules = CUDATileModules::from_kernel(make_synthetic_kernel(kernel_ast))
        .expect("from_kernel should construct CUDATileModules");

    for (name, full_path) in [
        ("HashMap", "std::collections::HashMap"),
        ("BTreeMap", "std::collections::BTreeMap"),
    ] {
        let hint = modules
            .unresolved_name_hint(name)
            .unwrap_or_else(|| panic!("grouped import `{name}` should be in the catalog"));
        // Checking the path, not just presence, catches a flattener that drops or
        // mangles the shared `std::collections` prefix.
        assert!(
            hint.contains(full_path),
            "hint for `{name}` should reference `{full_path}`; got: {hint}"
        );
    }
}
/// Only the kernel's own `use` items are catalogued. Imports made inside the cuTile
/// modules it depends on must not leak into the catalog.
#[test]
fn from_kernel_does_not_catalog_dep_modules_imports() {
    // Names presumably imported by cuTile's dependency modules rather than by the
    // kernel itself. NOTE(review): keep in sync with those modules.
    const TRANSITIVE_DEP_NAMES: [&str; 10] = [
        "atomic",
        "cmp_ordering",
        "ftz",
        "rounding",
        "scope",
        "ordering",
        "tma",
        "padding",
        "predicate",
        "overflow",
    ];

    let kernel = __module_ast_self();
    let modules =
        CUDATileModules::from_kernel(kernel).expect("from_kernel should construct CUDATileModules");

    for name in TRANSITIVE_DEP_NAMES {
        let leaked = modules.unresolved_name_hint(name).is_some();
        assert!(
            !leaked,
            "transitive-dep import `{name}` should NOT be in the catalog",
        );
    }
}
/// The legacy `CUDATileModules::new` constructor does not populate the import
/// catalog. No name, whether plausibly imported or arbitrary, gets a hint.
#[test]
fn legacy_new_constructor_yields_empty_catalog() {
    let kernel = __module_ast_self();
    let modules = CUDATileModules::new(vec![kernel]).expect("new should succeed");
    // Messages name the offending entry, matching the other tests in this file.
    for name in ["HashMap", "anything"] {
        assert!(
            modules.unresolved_name_hint(name).is_none(),
            "the legacy `new` constructor should leave the catalog empty, but `{name}` has a hint"
        );
    }
}