use std::collections::{BTreeSet, HashSet};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, quote_spanned};
use syn::parse::{Parse, ParseStream, Parser};
use syn::spanned::Spanned;
use syn::visit::Visit;
#[proc_macro]
pub fn init_static(input: TokenStream) -> TokenStream {
init_static_inner(input.into()).into()
}
/// Parses `tokens` as a sequence of zero or more `T`, consuming the whole stream.
///
/// Fails with the first parse error encountered.
fn parse_repeated<T: Parse>(tokens: TokenStream2) -> syn::Result<Vec<T>> {
    /// Repeatedly parses `U` until the stream is exhausted.
    fn many<U: Parse>(stream: ParseStream) -> syn::Result<Vec<U>> {
        let mut parsed = Vec::new();
        while !stream.is_empty() {
            let next: U = stream.parse()?;
            parsed.push(next);
        }
        Ok(parsed)
    }
    many::<T>.parse2(tokens)
}
pub(crate) fn init_static_inner(input: TokenStream2) -> TokenStream2 {
let input_items = match parse_repeated::<syn::Item>(input) {
Ok(items) => items,
Err(err) => return err.to_compile_error(),
};
let mut output = TokenStream2::new();
let mut inner = TokenStream2::new();
for item in input_items {
let syn::Item::Static(item_static) = item else {
output.extend(quote! { #item });
continue;
};
let mut is_try = false;
let mut is_async = false;
let mut free_paths = BTreeSet::new();
let mut scope = Scope {
is_try: &mut is_try,
is_async: &mut is_async,
free_paths: &mut free_paths,
locals: HashSet::new(),
};
scope.visit_item_static(&item_static);
let item_vis = &item_static.vis;
let item_ident = &item_static.ident;
let item_mut = &item_static.mutability;
let item_ty = &item_static.ty;
let item_expr = &item_static.expr;
let ty_span = item_ty.span();
let ident_span = item_ident.span();
let static_ty = quote_spanned! { ty_span =>
::init_static::InitStatic<#item_ty>
};
let static_expr = quote_spanned! { ident_span =>
::init_static::InitStatic!(#item_ident)
};
output.extend(quote! {
#[allow(clippy::type_complexity)]
#item_vis static #item_mut #item_ident: #static_ty = #static_expr;
});
let (deps_ident, deps_item) = if free_paths.is_empty() {
(quote! { ::std::vec::Vec::new }, quote! {})
} else {
let deps_ident = syn::Ident::new(&format!("DEPS_{item_ident}"), ident_span);
let deps_stmts = free_paths.iter().map(|path| {
let path = &path.path;
quote! {
(&#path).__get_symbol()
}
});
(
quote! { #deps_ident },
quote! {
#[allow(non_snake_case, clippy::needless_borrow)]
fn #deps_ident() -> ::std::vec::Vec<::std::option::Option<&'static ::init_static::Symbol>> {
use ::init_static::__private::MaybeInitStatic;
::std::vec![#(#deps_stmts),*]
}
},
)
};
let init_ident = syn::Ident::new(&format!("INIT_{item_ident}"), item_ident.span());
let (init_variant, init_item) = if is_async {
(
quote! { Async },
quote! {
#[allow(non_snake_case)]
fn #init_ident() -> ::init_static::__private::BoxFuture<::init_static::__private::anyhow::Result<()>> {
Box::pin(async {
::init_static::InitStatic::init(&#item_ident, #item_expr);
Ok(())
})
}
},
)
} else {
(
quote! { Sync },
quote! {
#[allow(non_snake_case)]
fn #init_ident() -> ::init_static::__private::anyhow::Result<()> {
::init_static::InitStatic::init(&#item_ident, #item_expr);
Ok(())
}
},
)
};
inner.extend(quote! {
#[::init_static::__private::linkme::distributed_slice(::init_static::__private::INIT)]
#[linkme(crate = ::init_static::__private::linkme)]
static #init_ident: ::init_static::__private::Init = {
#init_item
#deps_item
::init_static::__private::Init {
symbol: ::init_static::InitStatic::symbol(&#item_ident),
init: ::init_static::__private::InitFn::#init_variant(#init_ident),
deps: #deps_ident,
}
};
});
}
quote! {
#output
const _: () = {
#inner
};
}
}
/// A borrowed `syn::Path` paired with its rendered token string.
///
/// Equality and ordering are defined only on the rendered string. Two paths that
/// print identically therefore collapse into one entry of a `BTreeSet`,
/// regardless of their spans.
struct Path<'ast> {
    path: &'ast syn::Path,
    repr: String,
}

impl<'ast> Path<'ast> {
    fn new(inner: &'ast syn::Path) -> Self {
        let tokens = quote! { #inner };
        Path {
            repr: tokens.to_string(),
            path: inner,
        }
    }
}

impl ::std::cmp::Ord for Path<'_> {
    fn cmp(&self, other: &Self) -> ::std::cmp::Ordering {
        self.repr.as_str().cmp(other.repr.as_str())
    }
}

impl ::std::cmp::PartialOrd for Path<'_> {
    fn partial_cmp(&self, other: &Self) -> Option<::std::cmp::Ordering> {
        Some(::std::cmp::Ord::cmp(self, other))
    }
}

impl ::std::cmp::PartialEq for Path<'_> {
    fn eq(&self, other: &Self) -> bool {
        ::std::cmp::Ord::cmp(self, other).is_eq()
    }
}

impl ::std::cmp::Eq for Path<'_> {}
/// Traversal state used while scanning a `static` initializer.
///
/// The flags and the dependency set are held as `&mut` borrows, so nested scopes
/// built for blocks and closures all write into the same results. `locals` is
/// owned separately by each scope.
struct Scope<'a, 'ast> {
    /// Identifiers bound in the current scope (`let`/closure/pattern bindings and
    /// block-level `const`/`static` items). Such names are never reported as
    /// dependencies when used as a bare single-identifier path.
    locals: HashSet<&'ast syn::Ident>,
    /// Candidate dependency paths: expression paths whose last segment is
    /// all-uppercase and which are not shadowed by a local.
    free_paths: &'a mut BTreeSet<Path<'ast>>,
    /// Set to `true` once a `?` expression is seen.
    is_try: &'a mut bool,
    /// Set to `true` once an `.await` expression is seen.
    is_async: &'a mut bool,
}
impl<'ast> Scope<'_, 'ast> {
    /// Creates a nested scope with its own `locals`.
    ///
    /// The nested scope reports `?`/`.await` usage and dependency paths into the
    /// same shared state as `self`.
    fn child(&mut self, locals: HashSet<&'ast syn::Ident>) -> Scope<'_, 'ast> {
        Scope {
            is_try: self.is_try,
            is_async: self.is_async,
            free_paths: self.free_paths,
            locals,
        }
    }
}

impl<'i, 'ast> Visit<'ast> for Scope<'i, 'ast> {
    /// Records expression paths that may name a static.
    ///
    /// A path counts when its final segment contains no ASCII lowercase letters.
    /// It is skipped when it is a bare identifier bound locally.
    fn visit_expr_path(&mut self, expr_path: &'ast syn::ExprPath) {
        if expr_path.qself.is_none()
            && !matches!(
                expr_path.path.get_ident(),
                Some(ident) if self.locals.contains(ident)
            )
            && let Some(last_segment) = expr_path.path.segments.last()
            && !last_segment
                .ident
                .to_string()
                .bytes()
                .any(|b| b.is_ascii_lowercase())
        {
            self.free_paths.insert(Path::new(&expr_path.path));
        }
        syn::visit::visit_expr_path(self, expr_path);
    }

    fn visit_pat_ident(&mut self, pat_ident: &'ast syn::PatIdent) {
        self.locals.insert(&pat_ident.ident);
        syn::visit::visit_pat_ident(self, pat_ident);
    }

    /// Walks a block in its own scope, visiting every statement exactly once.
    ///
    /// Statements are deliberately not re-dispatched through `visit_stmt`.
    /// Walking each statement twice made nested blocks cost exponential time in
    /// their depth.
    fn visit_block(&mut self, block: &'ast syn::Block) {
        // `const`/`static` items are in scope for the whole block, even before
        // their declaration, so register them before walking any statement.
        let mut locals = self.locals.clone();
        locals.extend(block.stmts.iter().filter_map(|stmt| match stmt {
            syn::Stmt::Item(syn::Item::Const(item_const)) => Some(&item_const.ident),
            syn::Stmt::Item(syn::Item::Static(item_static)) => Some(&item_static.ident),
            _ => None,
        }));
        let mut scope = self.child(locals);
        for stmt in &block.stmts {
            match stmt {
                syn::Stmt::Local(local) => {
                    for attr in &local.attrs {
                        scope.visit_attribute(attr);
                    }
                    // Visit the initializer before the pattern. In `let X = X;`
                    // the right-hand `X` must still refer to the outer name.
                    if let Some(init) = &local.init {
                        scope.visit_local_init(init);
                    }
                    scope.visit_pat(&local.pat);
                }
                syn::Stmt::Expr(expr, _) => scope.visit_expr(expr),
                syn::Stmt::Item(item) => {
                    // Nested items may be called by the initializer, so scan them.
                    // Bindings they introduce (e.g. fn parameters) must not leak
                    // into the statements that follow.
                    let item_locals = scope.locals.clone();
                    scope.child(item_locals).visit_item(item);
                }
                syn::Stmt::Macro(stmt_macro) => scope.visit_stmt_macro(stmt_macro),
            }
        }
    }

    /// Walks a closure in its own scope, so its parameters only shadow names
    /// inside the closure's return type and body.
    fn visit_expr_closure(&mut self, expr_closure: &'ast syn::ExprClosure) {
        for attr in &expr_closure.attrs {
            self.visit_attribute(attr);
        }
        let locals = self.locals.clone();
        let mut scope = self.child(locals);
        for pat in &expr_closure.inputs {
            scope.visit_pat(pat);
        }
        scope.visit_return_type(&expr_closure.output);
        scope.visit_expr(&expr_closure.body);
    }

    fn visit_expr_try(&mut self, expr_try: &'ast syn::ExprTry) {
        *self.is_try = true;
        syn::visit::visit_expr_try(self, expr_try);
    }

    fn visit_expr_await(&mut self, expr_await: &'ast syn::ExprAwait) {
        *self.is_async = true;
        syn::visit::visit_expr_await(self, expr_await);
    }
}
#[cfg(test)]
mod test {
    use std::env::var;
    use std::fs::{create_dir_all, read_to_string, write};
    use std::path::{Path, PathBuf};
    use macro_expand::Context;
    use pretty_assertions::StrComparison;
    use prettyplease::unparse;
    use walkdir::WalkDir;
    use super::*;

    /// A fixture whose expansion did not match its snapshot.
    struct TestDiff {
        path: PathBuf,
        /// Snapshot contents, or `None` when the snapshot file could not be read.
        expect: Option<String>,
        actual: String,
    }

    /// Snapshot test.
    ///
    /// Expands every `*.rs` file under `fixtures/input` and compares the
    /// pretty-printed result with the same relative path under `fixtures/output`.
    /// Set `EMIT` to a non-empty value to (re)write the snapshots instead of
    /// failing.
    #[test]
    fn fixtures() {
        let input_dir = "fixtures/input";
        let output_dir = "fixtures/output";
        let mut diffs = vec![];
        let will_emit = var("EMIT").is_ok_and(|v| !v.is_empty());
        for entry in WalkDir::new(input_dir).into_iter().filter_map(Result::ok) {
            let input_path = entry.path();
            if !input_path.is_file() || input_path.extension() != Some("rs".as_ref()) {
                continue;
            }
            let path = input_path.strip_prefix(input_dir).unwrap();
            let output_path = Path::new(output_dir).join(path);
            let input = read_to_string(input_path).unwrap().parse().unwrap();
            let mut ctx = Context::new();
            ctx.register_proc_macro("init_static".into(), init_static_inner);
            let actual = unparse(&syn::parse2(ctx.transform(input)).unwrap());
            // A missing or unreadable snapshot also counts as a failure. Otherwise
            // a new fixture would pass silently before its output is emitted.
            let expect = read_to_string(&output_path).ok();
            if expect.as_deref() == Some(actual.as_str()) {
                continue;
            }
            if will_emit {
                create_dir_all(output_path.parent().unwrap()).unwrap();
                write(&output_path, &actual).unwrap();
            }
            diffs.push(TestDiff {
                path: path.to_path_buf(),
                expect,
                actual,
            });
        }
        let len = diffs.len();
        for diff in diffs {
            eprintln!("diff {}", diff.path.display());
            match &diff.expect {
                Some(expect) => eprintln!("{}", StrComparison::new(expect, &diff.actual)),
                None => eprintln!("no expected output found; rerun with EMIT=1 to create it"),
            }
        }
        if len > 0 && !will_emit {
            panic!("Some tests failed");
        }
    }
}