hotload_macro 1.4.0

Zero-cost hot-update dynamic library; supports DLL and SO files.
Documentation
use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::{parse_macro_input, Expr, ExprTuple, FnArg, Lit, ReturnType};
use std::path::PathBuf;

mod gen_rust_types;

/// hotload_macro
/// generate a struct with #[no_mangle] functions and struct type and same other types from a lib.rs file
#[proc_macro]
pub fn gen_rust_no_mangle(input: TokenStream) -> TokenStream {
    // Parse the input as an expression tuple
    let input = parse_macro_input!(input as ExprTuple);
    let elems = &input.elems;

    // Ensure we have exactly two elements in the tuple
    if elems.len() < 2 {
        return syn::Error::new_spanned(
            input,
            "Expected a tuple with two elements: (StructName, FilePath, Optioned value: bool is gen rust types)",
        )
        .to_compile_error()
        .into();
    }

    // Parse the first element as a path (struct name)
    let struct_name = match &elems[0] {
        Expr::Path(expr_path) => match expr_path.path.get_ident().cloned() {
            Some(ident) => ident,
            None => {
                return syn::Error::new_spanned(&elems[0], "Expected a struct name identifier")
                    .to_compile_error()
                    .into();
            }
        },
        _ => {
            return syn::Error::new_spanned(&elems[0], "Expected a struct name")
                .to_compile_error()
                .into()
        }
    };

    // Parse the second element as a string literal (file path)
    let file_path = match &elems[1] {
        Expr::Lit(expr_lit) => match &expr_lit.lit {
            Lit::Str(lit_str) => lit_str.value(),
            _ => {
                return syn::Error::new_spanned(
                    &elems[1],
                    "Expected a string literal for the file path",
                )
                .to_compile_error()
                .into()
            }
        },
        _ => {
            return syn::Error::new_spanned(
                &elems[1],
                "Expected a string literal for the file path",
            )
            .to_compile_error()
            .into()
        }
    };

    // Parse the third element as a bool (is gen rust types)
    let is_gen_rust_types = match &elems.get(2) {
        Some(v) => match v {
            Expr::Lit(expr_lit) => match &expr_lit.lit {
                Lit::Bool(v) => v.value(),
                _ => {
                    return syn::Error::new_spanned(v, "Expected bool for gen rust types")
                        .to_compile_error()
                        .into()
                }
            },
            _ => {
                return syn::Error::new_spanned(v, "Expected bool for gen rust types")
                    .to_compile_error()
                    .into()
            }
        },
        None => false,
    };

    // Read the file content
    let lib_rs_path = resolve_input_path(&file_path);
    let lib_rs_content = match std::fs::read_to_string(&lib_rs_path) {
        Ok(content) => content,
        Err(err) => {
            return syn::Error::new_spanned(
                &elems[1],
                format!(
                    "Failed to read file: {} ({})",
                    lib_rs_path.display(),
                    err
                ),
            )
            .to_compile_error()
            .into();
        }
    };

    // Parse the file content
    let file = match syn::parse_file(&lib_rs_content) {
        Ok(file) => file,
        Err(err) => {
            return syn::Error::new_spanned(
                &elems[1],
                format!(
                    "Failed to parse file: {} ({})",
                    lib_rs_path.display(),
                    err
                ),
            )
            .to_compile_error()
            .into();
        }
    };

    // Extract #[no_mangle] functions and statics
    let mut field_defs = vec![];
    let mut field_loads = vec![];
    let mut accessors = vec![];
    let mut types = vec![];
    let mut is_gen_a = false;

    for item in file.items {
        if let syn::Item::Static(item) = item {
            // Check if the static is public
            match item.vis {
                syn::Visibility::Public(_) => {}
                _ => continue,
            };

            // Check if the static has #[no_mangle] attribute
            let has_no_mangle = item
                .attrs
                .iter()
                .any(|attr| attr.to_token_stream().to_string().contains("no_mangle"));
            if !has_no_mangle {
                continue;
            }

            #[allow(unused_mut)]
            let mut static_type = Some(&item.ty);
            let mut_val = item.mutability.clone();

            is_gen_a = true;
            let is_mut = matches!(mut_val, syn::StaticMutability::Mut(_));
            let mut_quote = match mut_val {
                _ if static_type.to_token_stream().to_string().replace(' ', "").contains("&str") => {
                    // not support static &str
                    continue;
                }
                syn::StaticMutability::Mut(_) => {
                    quote! { &'a mut }
                }
                syn::StaticMutability::None => {
                    quote! { &'a }
                }
                _ => {
                    // unknown static mutability
                    quote! { &'a }
                }
            };

            let static_name = &item.ident;
            let name_str = static_name.to_string();

            field_defs.push(quote! {
                #static_name: #mut_quote #static_type,
            });

            field_loads.push(quote! {
                #static_name: lib.symbol_cstr(
                    ::std::ffi::CStr::from_bytes_with_nul_unchecked(concat!(#name_str, "\0").as_bytes())
                )?
            });

            let const_accessor = quote! {
                pub fn #static_name(&self) -> &#static_type {
                    self.#static_name
                }
            };

            if is_mut {
                let mut_ident = syn::Ident::new(&format!("{}_mut", static_name), static_name.span());
                accessors.push(quote! {
                    #const_accessor
                    pub fn #mut_ident(&mut self) -> &mut #static_type {
                        self.#static_name
                    }
                });
            } else {
                accessors.push(const_accessor);
            }
        } else if let syn::Item::Fn(item) = item {
            // Check if the function is public
            match item.vis {
                syn::Visibility::Public(_) => {}
                _ => continue,
            };

            // Check if the function has #[no_mangle] attribute
            let has_no_mangle = item
                .attrs
                .iter()
                .any(|attr| attr.to_token_stream().to_string().contains("no_mangle"));
            if !has_no_mangle {
                continue;
            }

            let fn_name = &item.sig.ident;
            let name_str = fn_name.to_string();

            // Process function arguments (name: type pairs)
            let inputs: Result<Vec<_>, syn::Error> = item
                .sig
                .inputs
                .iter()
                .map(|arg| {
                    if let FnArg::Typed(pat_type) = arg {
                        if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
                            let name = &pat_ident.ident;
                            let ty = &pat_type.ty;
                            Ok(quote! { #name: #ty })
                        } else {
                            Err(syn::Error::new_spanned(
                                &pat_type.pat,
                                "Unsupported function argument pattern",
                            ))
                        }
                    } else {
                        Err(syn::Error::new_spanned(
                            arg,
                            "Unsupported function argument pattern",
                        ))
                    }
                })
                .collect();

            let inputs = match inputs {
                Ok(v) => v,
                Err(err) => return err.to_compile_error().into(),
            };

            // Collect argument names for function call forwarding
            let arg_names: Vec<_> = item
                .sig
                .inputs
                .iter()
                .filter_map(|arg| {
                    if let FnArg::Typed(pat_type) = arg {
                        if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
                            Some(pat_ident.ident.clone())
                        } else {
                            None
                        }
                    } else {
                        None
                    }
                })
                .collect();

            // Process return type
            let ret_type = match &item.sig.output {
                ReturnType::Type(_, ty) => quote! { -> #ty },
                ReturnType::Default => quote! {},
            };

            if ret_type.to_string().contains("'a") {
                is_gen_a = true;
            }

            field_defs.push(quote! {
                #fn_name: fn(#(#inputs),*) #ret_type,
            });

            field_loads.push(quote! {
                #fn_name: lib.symbol_cstr(
                    ::std::ffi::CStr::from_bytes_with_nul_unchecked(concat!(#name_str, "\0").as_bytes())
                )?
            });

            accessors.push(quote! {
                pub fn #fn_name(&self, #(#inputs),*) #ret_type {
                    (self.#fn_name)(#(#arg_names),*)
                }
            });
        } else if is_gen_rust_types {
            if let Some((_name, v)) = gen_rust_types::gen_rust_types(item) {
                types.push(v);
            }
        }
    }

    let gen_a = if is_gen_a {
        quote! { <'a> }
    } else {
        quote! {}
    };

    // Generate the struct with manual WrapperApi impl (no derive needed)
    // Uses hotload::dlopen2:: paths so users don't need dlopen2 as a direct dependency
    let expanded = quote! {
        #(#types)*

        pub struct #struct_name #gen_a {
            #(#field_defs)*
        }

        impl #gen_a hotload::dlopen2::wrapper::WrapperApi for #struct_name #gen_a {
            unsafe fn load(lib: &hotload::dlopen2::raw::Library) -> ::std::result::Result<Self, hotload::dlopen2::Error> {
                unsafe {
                    Ok(Self {
                        #(#field_loads),*
                    })
                }
            }
        }

        #[allow(dead_code)]
        impl #gen_a #struct_name #gen_a {
            #(#accessors)*
        }
    };

    // 注意:proc-macro 中禁止使用 println!,会干扰 rust-analyzer 的 stdout 通信管道
    // eprintln!("hotload_macro gen code:\n{}\n", expanded);

    // Return the generated code as TokenStream
    TokenStream::from(expanded)
}

/// Turns the file path given to the macro into a filesystem path.
///
/// Absolute paths are returned as-is. A relative path is joined onto `CARGO_MANIFEST_DIR`
/// when that variable is set and valid Unicode. Otherwise it is joined onto the current
/// working directory. If neither is available, the relative path is returned unchanged.
fn resolve_input_path(path: &str) -> PathBuf {
    let given = PathBuf::from(path);
    if given.is_absolute() {
        return given;
    }

    std::env::var("CARGO_MANIFEST_DIR")
        .ok()
        .map(PathBuf::from)
        .or_else(|| std::env::current_dir().ok())
        .map(|root| root.join(path))
        .unwrap_or(given)
}