use crate::priority::call_graph::FunctionId;
use anyhow::Result;
use im::{HashMap, HashSet, Vector};
use std::path::{Path, PathBuf};
use syn::spanned::Spanned;
use syn::visit::Visit;
use syn::{File, ItemFn, ItemMod, ItemUse, Path as SynPath, UseTree, Visibility};
/// Describes one module discovered during workspace analysis: where it lives,
/// how it relates to other modules, and what it exposes.
#[derive(Debug, Clone)]
pub struct ModuleBoundary {
/// Module path inferred from the file location (`crate` or `crate::…`).
pub module_path: String,
/// Source file this boundary was built from.
pub file_path: PathBuf,
/// Path of the enclosing module; `None` only for the crate root.
pub parent_module: Option<String>,
/// Paths of the `mod` items encountered while visiting the file.
pub submodules: HashSet<String>,
/// Items collected by the module visitor; in this file only non-private
/// free functions are recorded.
pub public_exports: Vector<PublicExport>,
}
/// A single item that a module makes visible outside of itself.
#[derive(Debug, Clone)]
pub struct PublicExport {
/// Identifier of the exported item.
pub name: String,
/// Kind of item; the visitor in this file only produces `ExportType::Function`.
pub export_type: ExportType,
/// Call-graph identifier for function exports; `None` for items that are not functions.
pub function_id: Option<FunctionId>,
/// Visibility declared on the item.
pub visibility: VisibilityLevel,
/// Line of the item's identifier as reported by its span.
pub line: usize,
}
/// Kind of item a module can export.
///
/// `Hash` is derived in addition to the comparison traits so the kind can be
/// used as a key in hash-based collections (for example when grouping exports).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExportType {
    /// A free function.
    Function,
    /// A type definition (struct, enum, alias, ...).
    Type,
    /// A `const` or `static` item.
    Constant,
    /// A nested module.
    Module,
    /// A macro.
    Macro,
    /// A trait definition.
    Trait,
}
/// How far an item is visible, mirroring Rust's visibility modifiers.
///
/// `Hash` is derived in addition to the comparison traits so the level can be
/// used as a key in hash-based collections (for example when bucketing APIs).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum VisibilityLevel {
    /// No visibility modifier: visible only inside the defining module.
    Private,
    /// `pub(crate)`: visible throughout the crate.
    Crate,
    /// `pub`: visible wherever the defining module is reachable.
    Public,
    /// `pub(super)`: visible in the parent module.
    PublicSuper,
    /// `pub(in path)`: visible inside the named module path.
    PublicIn(String),
}
/// Summary of a function that is visible outside its defining module.
#[derive(Debug, Clone)]
pub struct PublicApiInfo {
/// Call-graph identifier of the function.
pub function_id: FunctionId,
/// Path of the module whose boundary lists this function.
pub defining_module: String,
/// Visibility declared on the function.
pub visibility: VisibilityLevel,
/// Whether the function is re-exported elsewhere.
/// NOTE(review): the tracker in this file always stores `false` here.
pub is_reexported: bool,
/// Modules that import the function.
/// NOTE(review): the tracker in this file always stores an empty set here.
pub importing_modules: HashSet<String>,
}
/// A call expression whose callee is written as a qualified path (`a::b::f(...)`).
#[derive(Debug, Clone)]
pub struct CrossModuleCall {
/// Function that contains the call expression.
/// NOTE(review): the call visitor in this file builds this id with an empty
/// file path, so it may not compare equal to ids built with the real path.
pub caller: FunctionId,
/// Everything before the last `::` of the callee path, exactly as written.
pub module_path: String,
/// Last segment of the callee path.
pub function_name: String,
/// Line of the callee path as reported by its span.
pub line: usize,
/// Whether the call goes through a `use` import.
/// NOTE(review): the call visitor in this file always stores `false` here.
pub through_import: bool,
}
/// One imported name (or glob) produced from a `use` declaration.
#[derive(Debug, Clone)]
// NOTE(review): presumably the fields are recorded for future use and not yet
// read anywhere in the crate, hence the lint suppression - confirm and drop
// the attribute once they are consumed.
#[allow(dead_code)]
pub struct ModuleImport {
/// Module that contains the `use` declaration.
pub importing_module: String,
/// Module the items are imported from; `"unknown"` when it was not determined.
pub imported_module: String,
/// Imported identifiers; empty for glob imports.
pub imported_items: Vector<String>,
/// `true` when the import is a `*` glob.
pub is_glob_import: bool,
/// Line of the `use` keyword as reported by its span.
pub line: usize,
}
/// Accumulates module boundaries, public APIs, qualified calls and imports
/// for a set of parsed source files.
#[derive(Debug, Clone)]
pub struct CrossModuleTracker {
// Boundary information keyed by module path.
module_boundaries: HashMap<String, ModuleBoundary>,
// Non-private functions keyed by their call-graph id.
public_apis: HashMap<FunctionId, PublicApiInfo>,
// Every qualified call found in the analyzed files.
cross_module_calls: Vector<CrossModuleCall>,
// Every import found in the analyzed files.
module_imports: Vector<ModuleImport>,
// Maps each analyzed file to the module path inferred for it.
file_to_module: HashMap<PathBuf, String>,
// Facade module path -> modules it re-exports; consulted when resolving calls.
// NOTE(review): nothing in this file populates it outside of tests.
reexports: HashMap<String, Vector<String>>,
}
impl CrossModuleTracker {
pub fn new() -> Self {
Self {
module_boundaries: HashMap::new(),
public_apis: HashMap::new(),
cross_module_calls: Vector::new(),
module_imports: Vector::new(),
file_to_module: HashMap::new(),
reexports: HashMap::new(),
}
}
pub fn analyze_workspace(&mut self, workspace_files: &[(PathBuf, File)]) -> Result<()> {
for (file_path, ast) in workspace_files {
let module_path = self.infer_module_path(file_path);
self.file_to_module
.insert(file_path.clone(), module_path.clone());
let mut visitor = ModuleVisitor::new(file_path.clone(), module_path.clone());
visitor.visit_file(ast);
let boundary = ModuleBoundary {
module_path: module_path.clone(),
file_path: file_path.clone(),
parent_module: self.infer_parent_module(&module_path),
submodules: visitor.submodules.into_iter().collect(),
public_exports: visitor.public_exports.into_iter().collect(),
};
self.module_boundaries.insert(module_path, boundary);
}
for (file_path, ast) in workspace_files {
let module_path = self.file_to_module.get(file_path).unwrap().clone();
let mut call_visitor = CrossModuleCallVisitor::new(module_path.clone());
call_visitor.visit_file(ast);
for call in call_visitor.cross_module_calls {
self.cross_module_calls.push_back(call);
}
for import in call_visitor.module_imports {
self.module_imports.push_back(import);
}
}
self.build_public_api_mappings();
Ok(())
}
pub fn get_cross_module_calls(&self) -> Vector<CrossModuleCall> {
self.cross_module_calls.clone()
}
pub fn get_public_apis(&self) -> Vec<PublicApiInfo> {
self.public_apis.values().cloned().collect()
}
pub fn is_public_api(&self, func_id: &FunctionId) -> bool {
self.public_apis.contains_key(func_id)
}
fn find_function_export(boundary: &ModuleBoundary, function_name: &str) -> Option<FunctionId> {
boundary
.public_exports
.iter()
.find(|export| {
export.name == function_name && export.export_type == ExportType::Function
})
.and_then(|export| export.function_id.clone())
}
fn resolve_through_reexports(
&self,
reexported_modules: &Vector<String>,
function_name: &str,
) -> Option<FunctionId> {
reexported_modules
.iter()
.find_map(|module| self.resolve_module_call(module, function_name))
}
pub fn resolve_module_call(
&self,
module_path: &str,
function_name: &str,
) -> Option<FunctionId> {
if let Some(boundary) = self.module_boundaries.get(module_path) {
if let Some(func_id) = Self::find_function_export(boundary, function_name) {
return Some(func_id);
}
}
if let Some(reexported_modules) = self.reexports.get(module_path) {
return self.resolve_through_reexports(reexported_modules, function_name);
}
None
}
pub fn get_statistics(&self) -> CrossModuleStatistics {
let total_modules = self.module_boundaries.len();
let total_public_apis = self.public_apis.len();
let total_cross_module_calls = self.cross_module_calls.len();
let total_imports = self.module_imports.len();
let public_functions = self
.public_apis
.values()
.filter(|api| matches!(api.visibility, VisibilityLevel::Public))
.count();
let crate_functions = self
.public_apis
.values()
.filter(|api| matches!(api.visibility, VisibilityLevel::Crate))
.count();
CrossModuleStatistics {
total_modules,
total_public_apis,
total_cross_module_calls,
total_imports,
public_functions,
crate_functions,
}
}
pub fn get_public_exclusions(&self) -> HashSet<FunctionId> {
self.public_apis
.iter()
.filter(|(_, api)| matches!(api.visibility, VisibilityLevel::Public))
.map(|(func_id, _)| func_id.clone())
.collect()
}
pub fn get_crate_visible_functions(&self) -> HashSet<FunctionId> {
self.public_apis
.iter()
.filter(|(_, api)| !matches!(api.visibility, VisibilityLevel::Private))
.map(|(func_id, _)| func_id.clone())
.collect()
}
fn infer_module_path(&self, file_path: &Path) -> String {
let path_str = file_path.to_string_lossy();
let relative_path = path_str
.strip_prefix("src/")
.unwrap_or(&path_str)
.strip_suffix(".rs")
.unwrap_or(&path_str);
let module_path = relative_path.replace('/', "::");
match module_path.as_str() {
"lib" => "crate".to_string(),
"main" => "crate".to_string(),
_ => format!("crate::{module_path}"),
}
}
fn infer_parent_module(&self, module_path: &str) -> Option<String> {
if module_path == "crate" {
None
} else {
let parts: Vec<&str> = module_path.split("::").collect();
if parts.len() > 1 {
Some(parts[..parts.len() - 1].join("::"))
} else {
Some("crate".to_string())
}
}
}
fn build_public_api_mappings(&mut self) {
for (module_path, boundary) in &self.module_boundaries {
for export in &boundary.public_exports {
if export.export_type == ExportType::Function {
if let Some(func_id) = &export.function_id {
let api_info = PublicApiInfo {
function_id: func_id.clone(),
defining_module: module_path.clone(),
visibility: export.visibility.clone(),
is_reexported: false, importing_modules: HashSet::new(), };
self.public_apis.insert(func_id.clone(), api_info);
}
}
}
}
}
}
/// Plain counters summarizing what a cross-module analysis recorded.
///
/// Besides `Debug`/`Clone`, the struct derives `PartialEq`/`Eq` so snapshots
/// can be compared directly and `Default` so an all-zero value is available.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct CrossModuleStatistics {
    /// Number of module boundaries recorded.
    pub total_modules: usize,
    /// Number of entries in the public-API index.
    pub total_public_apis: usize,
    /// Number of qualified calls recorded.
    pub total_cross_module_calls: usize,
    /// Number of imports recorded.
    pub total_imports: usize,
    /// Number of indexed functions whose visibility is `pub`.
    pub public_functions: usize,
    /// Number of indexed functions whose visibility is `pub(crate)`.
    pub crate_functions: usize,
}
/// Syntax-tree visitor that collects the submodules and the non-private
/// functions declared in one file.
struct ModuleVisitor {
// File being visited; used to build function ids.
file_path: PathBuf,
// Module path of the file; used as the prefix for submodule paths.
module_path: String,
// Paths of the `mod` items encountered so far.
submodules: Vec<String>,
// Non-private functions encountered so far.
public_exports: Vec<PublicExport>,
}
impl ModuleVisitor {
    /// Creates a visitor for the file at `file_path`, which defines `module_path`.
    fn new(file_path: PathBuf, module_path: String) -> Self {
        Self {
            file_path,
            module_path,
            submodules: Vec::new(),
            public_exports: Vec::new(),
        }
    }

    /// Returns the line on which `span` starts.
    fn get_line_number(&self, span: proc_macro2::Span) -> usize {
        span.start().line
    }

    /// Maps a `syn` visibility modifier to a [`VisibilityLevel`].
    fn extract_visibility(&self, vis: &Visibility) -> VisibilityLevel {
        match vis {
            Visibility::Inherited => VisibilityLevel::Private,
            Visibility::Public(_) => VisibilityLevel::Public,
            Visibility::Restricted(restricted) => Self::classify_restricted_visibility(restricted),
        }
    }

    /// Classifies `pub(...)` modifiers: the `pub(in path)` form is handled
    /// separately from the keyword forms `pub(crate)`, `pub(super)` and `pub(self)`.
    fn classify_restricted_visibility(restricted: &syn::VisRestricted) -> VisibilityLevel {
        match restricted.in_token {
            Some(_) => Self::extract_public_in_visibility(restricted),
            None => Self::extract_scope_visibility(restricted),
        }
    }

    /// Handles `pub(in path)`.
    ///
    /// The whole path is preserved (segments joined with `::`), so
    /// `pub(in crate::a::b)` yields `PublicIn("crate::a::b")` rather than
    /// silently degrading to `Crate` as a single-identifier lookup would.
    /// An empty path falls back to `Crate`.
    fn extract_public_in_visibility(restricted: &syn::VisRestricted) -> VisibilityLevel {
        let scope = restricted
            .path
            .segments
            .iter()
            .map(|segment| segment.ident.to_string())
            .collect::<Vec<_>>()
            .join("::");

        if scope.is_empty() {
            VisibilityLevel::Crate
        } else {
            VisibilityLevel::PublicIn(scope)
        }
    }

    /// Handles the keyword forms `pub(crate)`, `pub(super)` and `pub(self)`.
    ///
    /// `pub(self)` is equivalent to having no modifier at all, so it is
    /// reported as `Private`. Anything unrecognized falls back to `Crate`.
    fn extract_scope_visibility(restricted: &syn::VisRestricted) -> VisibilityLevel {
        let keyword = restricted.path.get_ident().map(|ident| ident.to_string());
        match keyword.as_deref() {
            Some("super") => VisibilityLevel::PublicSuper,
            Some("self") => VisibilityLevel::Private,
            _ => VisibilityLevel::Crate,
        }
    }
}
impl<'ast> Visit<'ast> for ModuleVisitor {
    /// Records every free function that is not private as a function export,
    /// then continues into the function body.
    fn visit_item_fn(&mut self, item: &'ast ItemFn) {
        let visibility = self.extract_visibility(&item.vis);
        let is_exported = !matches!(visibility, VisibilityLevel::Private);

        if is_exported {
            let name = item.sig.ident.to_string();
            let line = self.get_line_number(item.sig.ident.span());
            let function_id = FunctionId::new(self.file_path.clone(), name.clone(), line);

            self.public_exports.push(PublicExport {
                name,
                export_type: ExportType::Function,
                function_id: Some(function_id),
                visibility,
                line,
            });
        }

        syn::visit::visit_item_fn(self, item);
    }

    /// Records a `mod` item under its full path and descends into it.
    ///
    /// While the module's contents are visited, `module_path` temporarily
    /// points at the nested module so that deeper inline modules are recorded
    /// as `outer::inner::deepest` instead of being attached directly to the
    /// file's module. The previous path is restored afterwards.
    fn visit_item_mod(&mut self, item: &'ast ItemMod) {
        let nested_path = format!("{}::{}", self.module_path, item.ident);
        self.submodules.push(nested_path.clone());

        let enclosing_path = std::mem::replace(&mut self.module_path, nested_path);
        syn::visit::visit_item_mod(self, item);
        self.module_path = enclosing_path;
    }
}
/// Syntax-tree visitor that collects qualified calls and `use` imports
/// found in one file.
struct CrossModuleCallVisitor {
// Module path of the file being visited; stored on every recorded import.
current_module: String,
// Qualified calls found inside function bodies.
cross_module_calls: Vec<CrossModuleCall>,
// Imports found in `use` declarations.
module_imports: Vec<ModuleImport>,
// Function whose body is currently being visited, if any; calls outside
// of a function are not recorded.
current_function: Option<FunctionId>,
}
impl CrossModuleCallVisitor {
    /// Creates a visitor for a file that defines `current_module`.
    fn new(current_module: String) -> Self {
        Self {
            current_module,
            cross_module_calls: Vec::new(),
            module_imports: Vec::new(),
            current_function: None,
        }
    }

    /// Returns the line on which `span` starts.
    fn get_line_number(&self, span: proc_macro2::Span) -> usize {
        span.start().line
    }

    /// Renders a path as its identifiers joined by `::` (generic arguments
    /// and a leading `::` are not included).
    fn extract_path_string(&self, path: &SynPath) -> String {
        let segments: Vec<String> = path
            .segments
            .iter()
            .map(|segment| segment.ident.to_string())
            .collect();
        segments.join("::")
    }

    /// Builds a [`ModuleImport`] whose source module is not yet known.
    ///
    /// `imported_module` is always set to the placeholder `"unknown"`; callers
    /// that know the source module overwrite the field afterwards.
    fn create_module_import(
        importing_module: String,
        imported_items: Vec<String>,
        is_glob_import: bool,
        line: usize,
    ) -> ModuleImport {
        ModuleImport {
            importing_module,
            imported_module: "unknown".to_string(),
            imported_items: imported_items.into_iter().collect(),
            is_glob_import,
            line,
        }
    }

    /// Returns the imported identifier of a leaf `use` tree.
    ///
    /// For `a as b` the original name `a` is returned. Paths, groups and
    /// globs have no single name and yield `None`.
    fn extract_import_name(use_tree: &UseTree) -> Option<String> {
        match use_tree {
            UseTree::Name(use_name) => Some(use_name.ident.to_string()),
            UseTree::Rename(use_rename) => Some(use_rename.ident.to_string()),
            _ => None,
        }
    }

    /// Records one import per leaf (name, rename or glob) of `use_tree`.
    ///
    /// The path leading to each leaf is tracked and stored as the import's
    /// `imported_module`, exactly as written in the source (`crate`, `super`
    /// and `self` are not resolved). Leaves without any path prefix, such as
    /// `use foo;`, keep the `"unknown"` placeholder.
    fn analyze_use_tree(&mut self, use_tree: &UseTree, line: usize) {
        let mut prefix = Vec::new();
        self.collect_imports(use_tree, &mut prefix, line);
    }

    /// Recursive worker for [`Self::analyze_use_tree`]; `prefix` holds the
    /// path segments between the root of the `use` declaration and `use_tree`.
    fn collect_imports(&mut self, use_tree: &UseTree, prefix: &mut Vec<String>, line: usize) {
        match use_tree {
            UseTree::Path(use_path) => {
                prefix.push(use_path.ident.to_string());
                self.collect_imports(&use_path.tree, prefix, line);
                prefix.pop();
            }
            UseTree::Group(use_group) => {
                for item in &use_group.items {
                    self.collect_imports(item, prefix, line);
                }
            }
            UseTree::Glob(_) => {
                self.push_import(Vec::new(), true, prefix.as_slice(), line);
            }
            _ => {
                if let Some(imported_item) = Self::extract_import_name(use_tree) {
                    self.push_import(vec![imported_item], false, prefix.as_slice(), line);
                }
            }
        }
    }

    /// Stores an import, filling in the source module when a prefix is known.
    fn push_import(
        &mut self,
        imported_items: Vec<String>,
        is_glob_import: bool,
        prefix: &[String],
        line: usize,
    ) {
        let mut import = Self::create_module_import(
            self.current_module.clone(),
            imported_items,
            is_glob_import,
            line,
        );
        if !prefix.is_empty() {
            import.imported_module = prefix.join("::");
        }
        self.module_imports.push(import);
    }
}
impl<'ast> Visit<'ast> for CrossModuleCallVisitor {
    /// Marks `item` as the current caller while its body is visited.
    ///
    /// The previously active caller is restored afterwards, so calls that
    /// follow a nested `fn` inside an outer function are still attributed to
    /// the outer function instead of being dropped.
    ///
    /// NOTE(review): the caller id is built with an empty file path because
    /// the visitor does not know which file it is visiting; such ids may not
    /// match ids created elsewhere with the real path.
    fn visit_item_fn(&mut self, item: &'ast ItemFn) {
        let func_name = item.sig.ident.to_string();
        let line = self.get_line_number(item.sig.ident.span());
        let caller = FunctionId::new(PathBuf::new(), func_name, line);

        let enclosing = self.current_function.replace(caller);
        syn::visit::visit_item_fn(self, item);
        self.current_function = enclosing;
    }

    /// Records the imports introduced by a `use` declaration.
    fn visit_item_use(&mut self, item: &'ast ItemUse) {
        let line = self.get_line_number(item.use_token.span);
        self.analyze_use_tree(&item.tree, line);
        syn::visit::visit_item_use(self, item);
    }

    /// Records calls whose callee is a qualified path (`a::b::f(...)`).
    ///
    /// Only calls inside a function are recorded. The callee path is split at
    /// its last `::` into module path and function name; unqualified calls
    /// and callees that are not plain paths are ignored.
    fn visit_expr_call(&mut self, expr: &'ast syn::ExprCall) {
        if let Some(caller) = &self.current_function {
            if let syn::Expr::Path(path_expr) = &*expr.func {
                let path_string = self.extract_path_string(&path_expr.path);
                if let Some((module_path, function_name)) = path_string.rsplit_once("::") {
                    let cross_call = CrossModuleCall {
                        caller: caller.clone(),
                        module_path: module_path.to_string(),
                        function_name: function_name.to_string(),
                        line: self.get_line_number(path_expr.path.span()),
                        // Import-based resolution is not implemented yet.
                        through_import: false,
                    };
                    self.cross_module_calls.push(cross_call);
                }
            }
        }
        syn::visit::visit_expr_call(self, expr);
    }
}
impl Default for CrossModuleTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::UseGlob;

    /// Builds a `pub` export on line 1 with the given name, kind and id.
    fn create_test_export(
        name: &str,
        export_type: ExportType,
        func_id: Option<FunctionId>,
    ) -> PublicExport {
        PublicExport {
            name: name.to_string(),
            export_type,
            function_id: func_id,
            visibility: VisibilityLevel::Public,
            line: 1,
        }
    }

    /// Builds a function id located in `test.rs` at line 10.
    fn create_test_function_id(name: &str) -> FunctionId {
        FunctionId::new(PathBuf::from("test.rs"), name.to_string(), 10)
    }

    #[test]
    fn test_find_function_export_found() {
        let func_id = create_test_function_id("test_func");
        let exports = vec![
            create_test_export(
                "other_func",
                ExportType::Function,
                Some(create_test_function_id("other")),
            ),
            create_test_export("test_func", ExportType::Function, Some(func_id.clone())),
            create_test_export("test_type", ExportType::Type, None),
        ]
        .into_iter()
        .collect();
        let boundary = ModuleBoundary {
            module_path: "test::module".to_string(),
            file_path: PathBuf::from("test.rs"),
            parent_module: None,
            submodules: HashSet::new(),
            public_exports: exports,
        };
        let result = CrossModuleTracker::find_function_export(&boundary, "test_func");
        assert!(result.is_some());
        assert_eq!(result.unwrap().name, "test_func");
    }

    #[test]
    fn test_find_function_export_not_found() {
        let exports = vec![
            create_test_export(
                "func1",
                ExportType::Function,
                Some(create_test_function_id("func1")),
            ),
            create_test_export(
                "func2",
                ExportType::Function,
                Some(create_test_function_id("func2")),
            ),
        ]
        .into_iter()
        .collect();
        let boundary = ModuleBoundary {
            module_path: "test::module".to_string(),
            file_path: PathBuf::from("test.rs"),
            parent_module: None,
            submodules: HashSet::new(),
            public_exports: exports,
        };
        let result = CrossModuleTracker::find_function_export(&boundary, "non_existent");
        assert!(result.is_none());
    }

    #[test]
    fn test_find_function_export_wrong_type() {
        let exports = vec![
            create_test_export("test_name", ExportType::Type, None),
            create_test_export("test_const", ExportType::Constant, None),
        ]
        .into_iter()
        .collect();
        let boundary = ModuleBoundary {
            module_path: "test::module".to_string(),
            file_path: PathBuf::from("test.rs"),
            parent_module: None,
            submodules: HashSet::new(),
            public_exports: exports,
        };
        let result = CrossModuleTracker::find_function_export(&boundary, "test_name");
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_through_reexports_found() {
        let mut tracker = CrossModuleTracker::new();
        let func_id = create_test_function_id("target_func");
        let exports = vec![create_test_export(
            "target_func",
            ExportType::Function,
            Some(func_id.clone()),
        )]
        .into_iter()
        .collect();
        let boundary = ModuleBoundary {
            module_path: "target::module".to_string(),
            file_path: PathBuf::from("target.rs"),
            parent_module: None,
            submodules: HashSet::new(),
            public_exports: exports,
        };
        tracker
            .module_boundaries
            .insert("target::module".to_string(), boundary);
        let reexported_modules = vec!["target::module".to_string()].into_iter().collect();
        let result = tracker.resolve_through_reexports(&reexported_modules, "target_func");
        assert!(result.is_some());
        assert_eq!(result.unwrap().name, "target_func");
    }

    #[test]
    fn test_resolve_through_reexports_not_found() {
        let tracker = CrossModuleTracker::new();
        let reexported_modules = vec!["module1".to_string(), "module2".to_string()]
            .into_iter()
            .collect();
        let result = tracker.resolve_through_reexports(&reexported_modules, "non_existent");
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_through_reexports_empty_list() {
        let tracker = CrossModuleTracker::new();
        let reexported_modules = Vector::new();
        let result = tracker.resolve_through_reexports(&reexported_modules, "any_func");
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_module_call_direct() {
        let mut tracker = CrossModuleTracker::new();
        let func_id = create_test_function_id("direct_func");
        let exports = vec![create_test_export(
            "direct_func",
            ExportType::Function,
            Some(func_id.clone()),
        )]
        .into_iter()
        .collect();
        let boundary = ModuleBoundary {
            module_path: "direct::module".to_string(),
            file_path: PathBuf::from("direct.rs"),
            parent_module: None,
            submodules: HashSet::new(),
            public_exports: exports,
        };
        tracker
            .module_boundaries
            .insert("direct::module".to_string(), boundary);
        let result = tracker.resolve_module_call("direct::module", "direct_func");
        assert!(result.is_some());
        assert_eq!(result.unwrap().name, "direct_func");
    }

    #[test]
    fn test_resolve_module_call_via_reexport() {
        let mut tracker = CrossModuleTracker::new();
        let func_id = create_test_function_id("reexported_func");
        let exports = vec![create_test_export(
            "reexported_func",
            ExportType::Function,
            Some(func_id.clone()),
        )]
        .into_iter()
        .collect();
        let boundary = ModuleBoundary {
            module_path: "original::module".to_string(),
            file_path: PathBuf::from("original.rs"),
            parent_module: None,
            submodules: HashSet::new(),
            public_exports: exports,
        };
        tracker
            .module_boundaries
            .insert("original::module".to_string(), boundary);
        tracker.reexports.insert(
            "facade::module".to_string(),
            vec!["original::module".to_string()].into_iter().collect(),
        );
        let result = tracker.resolve_module_call("facade::module", "reexported_func");
        assert!(result.is_some());
        assert_eq!(result.unwrap().name, "reexported_func");
    }

    #[test]
    fn test_create_module_import_with_items() {
        let import = CrossModuleCallVisitor::create_module_import(
            "test_module".to_string(),
            vec!["func1".to_string(), "func2".to_string()],
            false,
            42,
        );
        assert_eq!(import.importing_module, "test_module");
        assert_eq!(import.imported_module, "unknown");
        assert_eq!(import.imported_items.len(), 2);
        assert!(import.imported_items.contains(&"func1".to_string()));
        assert!(import.imported_items.contains(&"func2".to_string()));
        assert!(!import.is_glob_import);
        assert_eq!(import.line, 42);
    }

    #[test]
    fn test_create_module_import_glob() {
        let import = CrossModuleCallVisitor::create_module_import(
            "glob_module".to_string(),
            vec![],
            true,
            100,
        );
        assert_eq!(import.importing_module, "glob_module");
        assert_eq!(import.imported_module, "unknown");
        assert_eq!(import.imported_items.len(), 0);
        assert!(import.is_glob_import);
        assert_eq!(import.line, 100);
    }

    #[test]
    fn test_extract_import_name_from_use_name() {
        use syn::{Ident, UseName, UseTree};
        let ident = Ident::new("test_func", proc_macro2::Span::call_site());
        let use_name = UseName { ident };
        let use_tree = UseTree::Name(use_name);
        let result = CrossModuleCallVisitor::extract_import_name(&use_tree);
        assert_eq!(result, Some("test_func".to_string()));
    }

    #[test]
    fn test_extract_import_name_from_use_rename() {
        use syn::{Ident, UseRename, UseTree};
        let ident = Ident::new("original_name", proc_macro2::Span::call_site());
        // `Token![as]` names the token type; calling it with a span constructs the token.
        let as_token = syn::Token![as](proc_macro2::Span::call_site());
        let rename = Ident::new("new_name", proc_macro2::Span::call_site());
        let use_rename = UseRename {
            ident,
            as_token,
            rename,
        };
        let use_tree = UseTree::Rename(use_rename);
        let result = CrossModuleCallVisitor::extract_import_name(&use_tree);
        // The original name is reported, not the alias.
        assert_eq!(result, Some("original_name".to_string()));
    }

    #[test]
    fn test_extract_import_name_from_glob() {
        use syn::{UseGlob, UseTree};
        let star_token = syn::Token![*](proc_macro2::Span::call_site());
        let use_glob = UseGlob { star_token };
        let use_tree = UseTree::Glob(use_glob);
        let result = CrossModuleCallVisitor::extract_import_name(&use_tree);
        assert_eq!(result, None);
    }

    #[test]
    fn test_extract_import_name_from_path() {
        use syn::{Ident, UsePath, UseTree};
        let ident = Ident::new("path", proc_macro2::Span::call_site());
        let colon2_token = syn::Token![::](proc_macro2::Span::call_site());
        let tree = Box::new(UseTree::Glob(UseGlob {
            star_token: syn::Token![*](proc_macro2::Span::call_site()),
        }));
        let use_path = UsePath {
            ident,
            colon2_token,
            tree,
        };
        let use_tree = UseTree::Path(use_path);
        let result = CrossModuleCallVisitor::extract_import_name(&use_tree);
        assert_eq!(result, None);
    }
}