use std::{
collections::BTreeMap,
path::{Path, PathBuf},
};
use syn::{Item, ItemStruct};
pub struct CrateContext {
modules: BTreeMap<String, ParsedModule>,
}
impl CrateContext {
#[allow(dead_code)]
pub fn empty() -> Self {
CrateContext {
modules: BTreeMap::new(),
}
}
pub fn parse_from_manifest() -> syn::Result<Self> {
let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").map_err(|_| {
syn::Error::new(
proc_macro2::Span::call_site(),
"CARGO_MANIFEST_DIR not set - cannot parse crate context",
)
})?;
let src_dir = PathBuf::from(&manifest_dir).join("src");
let root_file = if src_dir.join("lib.rs").exists() {
src_dir.join("lib.rs")
} else if src_dir.join("main.rs").exists() {
src_dir.join("main.rs")
} else {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
format!("Could not find lib.rs or main.rs in {:?}", src_dir),
));
};
Self::parse(&root_file)
}
pub fn parse(root: &Path) -> syn::Result<Self> {
let modules = ParsedModule::parse_recursive(root, "crate")?;
Ok(CrateContext { modules })
}
pub fn structs(&self) -> impl Iterator<Item = &ItemStruct> {
self.modules.values().flat_map(|module| module.structs())
}
pub fn structs_with_derive_and_path(&self, derive_name: &str) -> Vec<(&str, &ItemStruct)> {
self.modules
.iter()
.flat_map(|(path, module)| {
module
.structs()
.filter(|s| has_derive_attribute(&s.attrs, derive_name))
.map(move |s| (path.as_str(), s))
})
.collect()
}
pub fn find_const_module_path(&self, const_name: &str) -> Option<&str> {
for (path, module) in &self.modules {
for item in &module.items {
if let Item::Const(item_const) = item {
if item_const.ident == const_name {
return Some(path.as_str());
}
}
}
}
None
}
pub fn find_fn_module_path(&self, fn_name: &str) -> Option<&str> {
for (path, module) in &self.modules {
for item in &module.items {
if let Item::Fn(item_fn) = item {
if item_fn.sig.ident == fn_name {
return Some(path.as_str());
}
}
}
}
None
}
pub fn is_module_path_public(&self, module_path: &str) -> bool {
if module_path == "crate" {
return true;
}
let segments: Vec<&str> = module_path.split("::").collect();
for i in 1..segments.len() {
let parent_path = segments[..i].join("::");
let child_name = segments[i];
if let Some(parent_module) = self.modules.get(&parent_path) {
let is_pub = parent_module.items.iter().any(|item| {
if let Item::Mod(item_mod) = item {
item_mod.ident == child_name
&& matches!(item_mod.vis, syn::Visibility::Public(_))
} else {
false
}
});
if !is_pub {
return false;
}
} else {
return false;
}
}
true
}
pub fn get_struct_fields(&self, type_name: &syn::Type) -> Option<Vec<String>> {
let struct_name = match type_name {
syn::Type::Path(type_path) => type_path.path.segments.last()?.ident.to_string(),
_ => return None,
};
for item_struct in self.structs() {
if item_struct.ident == struct_name {
if let syn::Fields::Named(named_fields) = &item_struct.fields {
let field_names: Vec<String> = named_fields
.named
.iter()
.filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
.collect();
return Some(field_names);
}
}
}
None
}
}
/// The items of a single module, whether it came from its own file or from an inline
/// `mod name { ... }` block.
pub struct ParsedModule {
    items: Vec<Item>,
}

impl ParsedModule {
    /// Parses the module file at `root` and registers it under `module_path`.
    ///
    /// It then walks every child module the file declares. Both inline (`mod a { .. }`, at
    /// any nesting depth) and out-of-line (`mod a;`) declarations are followed.
    ///
    /// Returns a map from fully qualified module path (e.g. `crate::a::b`) to that
    /// module's items. Out-of-line modules whose file cannot be located are skipped
    /// rather than reported.
    ///
    /// # Errors
    /// Returns a call-site `syn::Error` if any reachable file cannot be read or parsed.
    fn parse_recursive(
        root: &Path,
        module_path: &str,
    ) -> syn::Result<BTreeMap<String, ParsedModule>> {
        let content = std::fs::read_to_string(root).map_err(|e| {
            syn::Error::new(
                proc_macro2::Span::call_site(),
                format!("Failed to read {:?}: {}", root, e),
            )
        })?;
        let file: syn::File = syn::parse_str(&content).map_err(|e| {
            syn::Error::new(
                proc_macro2::Span::call_site(),
                format!("Failed to parse {:?}: {}", root, e),
            )
        })?;
        let root_dir = root.parent().unwrap_or(Path::new("."));
        let root_name = root.file_stem().and_then(|s| s.to_str()).unwrap_or("root");
        let mut modules = BTreeMap::new();
        // `file` is owned here, so its items are moved into the map rather than cloned.
        Self::register(file.items, module_path, root_dir, root_name, &mut modules)?;
        Ok(modules)
    }

    /// Records `items` as module `module_path`, then registers all of its child modules.
    ///
    /// `dir` and `name` locate this module for `find_module_file`. For a file, they are
    /// its parent directory and file stem. For an inline module, they are the directory
    /// where its parent's children live and the inline module's own name.
    fn register(
        items: Vec<Item>,
        module_path: &str,
        dir: &Path,
        name: &str,
        modules: &mut BTreeMap<String, ParsedModule>,
    ) -> syn::Result<()> {
        for item in &items {
            let item_mod = match item {
                Item::Mod(item_mod) => item_mod,
                _ => continue,
            };
            let mod_name = item_mod.ident.to_string();
            let child_path = format!("{}::{}", module_path, mod_name);
            match &item_mod.content {
                Some((_, inline_items)) => {
                    // This module's children live in `children_dir`: `dir` itself for
                    // mod.rs/lib.rs/main.rs, otherwise `dir/name`. The inline module's
                    // own `mod x;` children therefore live in `children_dir/mod_name/`.
                    let children_dir = if matches!(name, "mod" | "lib" | "main") {
                        dir.to_path_buf()
                    } else {
                        dir.join(name)
                    };
                    // The parent keeps the full `mod` item (visibility checks read it),
                    // so the inline body is cloned for the child's own entry. Recursing
                    // here picks up nested inline and out-of-line modules as well.
                    Self::register(
                        inline_items.clone(),
                        &child_path,
                        &children_dir,
                        &mod_name,
                        modules,
                    )?;
                }
                None => {
                    // Unlocatable files (e.g. cfg-gated or generated modules) are skipped.
                    if let Some(mod_file) = find_module_file(dir, name, &mod_name) {
                        modules.extend(Self::parse_recursive(&mod_file, &child_path)?);
                    }
                }
            }
        }
        modules.insert(module_path.to_string(), ParsedModule { items });
        Ok(())
    }

    /// Iterates over the structs declared directly in this module.
    fn structs(&self) -> impl Iterator<Item = &ItemStruct> {
        self.items.iter().filter_map(|item| match item {
            Item::Struct(s) => Some(s),
            _ => None,
        })
    }
}
/// Locates the source file for an out-of-line `mod <mod_name>;` declaration.
///
/// `parent_dir` is the directory containing the declaring module's file, and
/// `parent_name` is that file's stem (e.g. `lib`, `mod`, `foo`).
///
/// * `lib.rs`, `main.rs` and `mod.rs` own their directory. The child is looked up as
///   `<parent_dir>/<mod_name>.rs`, then `<parent_dir>/<mod_name>/mod.rs`.
/// * Any other file `foo.rs` keeps its children under `<parent_dir>/foo/`. Those
///   locations are tried first, matching rustc's lookup rules. The sibling locations
///   above are then tried as a lenient fallback.
///
/// Returns `None` if no candidate file exists.
fn find_module_file(parent_dir: &Path, parent_name: &str, mod_name: &str) -> Option<PathBuf> {
    let sibling_candidates = [
        parent_dir.join(format!("{}.rs", mod_name)),
        parent_dir.join(mod_name).join("mod.rs"),
    ];
    // Crate roots and `mod.rs` files resolve their children next to themselves.
    if matches!(parent_name, "mod" | "lib" | "main") {
        return sibling_candidates.iter().find(|path| path.exists()).cloned();
    }
    let nested_dir = parent_dir.join(parent_name);
    let nested_candidates = [
        nested_dir.join(format!("{}.rs", mod_name)),
        nested_dir.join(mod_name).join("mod.rs"),
    ];
    // Check the nested locations first. Checking `<parent_dir>/<mod_name>.rs` first could
    // pick an unrelated sibling module that happens to share the name.
    nested_candidates
        .iter()
        .chain(sibling_candidates.iter())
        .find(|path| path.exists())
        .cloned()
}
/// Reports whether `attrs` contains a `#[derive(...)]` that lists `derive_name`.
///
/// A derived path matches when its final segment equals `derive_name`, so both
/// `#[derive(Foo)]` and `#[derive(some_crate::Foo)]` match `"Foo"`. `derive`
/// attributes whose arguments do not parse as a comma-separated path list are skipped.
pub(crate) fn has_derive_attribute(attrs: &[syn::Attribute], derive_name: &str) -> bool {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("derive"))
        .filter_map(|attr| {
            attr.parse_args_with(
                syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
            )
            .ok()
        })
        .any(|derived_paths| {
            // A single-identifier path is its own last segment, so checking the last
            // segment alone also covers the bare `Foo` form.
            derived_paths
                .iter()
                .filter_map(|path| path.segments.last())
                .any(|segment| segment.ident == derive_name)
        })
}