#[cfg(feature = "dart-codegen")]
use std::{env, fs::File, io::Write as _, path::PathBuf};
use inflector::Inflector as _;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_quote, punctuated::Punctuated, spanned::Spanned as _, token};
#[cfg(feature = "dart-codegen")]
use crate::dart_codegen::{DartCodegen, FnRegistrationBuilder};
/// Expands the macro applied to a module of `extern` function declarations.
///
/// `args` is the attribute argument of the macro. With the `dart-codegen`
/// feature it's parsed as a literal holding the path (relative to
/// `CARGO_MANIFEST_DIR`) of the Dart file to generate. Otherwise it's
/// discarded.
///
/// # Errors
///
/// If `input` isn't a module accepted by [`ModExpander`], or (with the
/// `dart-codegen` feature) if `args` can't be parsed or the Dart code can't
/// be generated and written.
pub(crate) fn expand(
    args: TokenStream,
    input: TokenStream,
) -> syn::Result<TokenStream> {
    let module: syn::ItemMod = syn::parse2(input)?;
    let expander = ModExpander::try_from(module)?;

    #[cfg(feature = "dart-codegen")]
    {
        let dart_path: syn::ExprLit = syn::parse2(args)?;
        expander.generate_dart_code(&dart_path)?;
    }
    #[cfg(not(feature = "dart-codegen"))]
    drop(args);

    Ok(expander.expand())
}
/// Parsed representation of the module the macro is applied to.
#[derive(Debug)]
struct ModExpander {
    /// Visibility of the module, re-emitted in the expansion.
    vis: syn::Visibility,

    /// Name of the module. It is re-emitted in the expansion and also used
    /// as the prefix of the generated identifiers.
    ident: syn::Ident,

    /// Attributes of the module, re-emitted in the expansion.
    attrs: Vec<syn::Attribute>,

    /// `use` items of the module, re-emitted as is.
    uses: Vec<syn::ItemUse>,

    /// Expanders of every function declared in the module's `extern` blocks.
    fn_expanders: Vec<FnExpander>,

    /// Name of the generated registration function (`register_<module>`).
    register_fn_name: syn::Ident,
}
impl TryFrom<syn::ItemMod> for ModExpander {
type Error = syn::Error;
fn try_from(module: syn::ItemMod) -> syn::Result<Self> {
use self::mod_parser as parser;
let mod_span = module.span();
let mut extern_fns: Vec<FnExpander> = Vec::new();
let mut use_items = Vec::new();
let register_prefix = &module.ident;
for item in parser::try_unwrap_mod_content(module.content)? {
match item {
syn::Item::ForeignMod(r#mod) => {
for i in r#mod.items {
extern_fns.push(FnExpander::parse(
parser::get_extern_fn(i)?,
register_prefix,
)?);
}
}
syn::Item::Use(r#use) => {
use_items.push(r#use);
}
syn::Item::Const(_)
| syn::Item::Enum(_)
| syn::Item::ExternCrate(_)
| syn::Item::Fn(_)
| syn::Item::Impl(_)
| syn::Item::Macro(_)
| syn::Item::Mod(_)
| syn::Item::Static(_)
| syn::Item::Struct(_)
| syn::Item::Trait(_)
| syn::Item::TraitAlias(_)
| syn::Item::Type(_)
| syn::Item::Union(_)
| syn::Item::Verbatim(_) => {
return Err(syn::Error::new(
item.span(),
"Module contains unsupported content",
));
}
_ => {
return Err(syn::Error::new(
item.span(),
"Module contains unknown content",
));
}
}
}
if extern_fns.is_empty() {
return Err(syn::Error::new(
mod_span,
"At least one `extern \"C\"` block is required",
));
}
Ok(Self {
register_fn_name: format_ident!("register_{}", module.ident),
vis: module.vis,
ident: module.ident,
attrs: module.attrs,
uses: use_items,
fn_expanders: extern_fns,
})
}
}
impl ModExpander {
    /// Expands this module into the final Rust code.
    ///
    /// The generated module keeps the original attributes, visibility, name
    /// and `use` items. For every declared `extern` function it contains:
    /// - a type alias of its `extern "C"` function pointer;
    /// - a static storage for the registered function pointer;
    /// - a static slot for an error;
    /// - a `#[no_mangle]` setter of that error slot;
    /// - a caller function with the original signature.
    ///
    /// A single `#[no_mangle]` registration function also accepts all the
    /// function pointers at once and stores them.
    fn expand(&self) -> TokenStream {
        let (vis, ident) = (&self.vis, &self.ident);
        let attrs = &self.attrs;
        let uses = &self.uses;
        let type_aliases =
            self.fn_expanders.iter().map(FnExpander::gen_fn_type);
        let fn_storages =
            self.fn_expanders.iter().map(FnExpander::gen_fn_storages);
        let register_fn_name = &self.register_fn_name;
        let register_fn_inputs =
            self.fn_expanders.iter().map(FnExpander::gen_register_fn_input);
        let register_fn_assigns =
            self.fn_expanders.iter().map(FnExpander::gen_register_fn_expr);
        let caller_fns =
            self.fn_expanders.iter().map(FnExpander::gen_caller_fn);
        let errors_slots =
            self.fn_expanders.iter().map(FnExpander::get_errors_slot);
        let errors_setters =
            self.fn_expanders.iter().map(FnExpander::gen_error_setter);
        quote! {
            #[automatically_derived]
            #( #attrs )*
            #vis mod #ident {
                #( #uses )*
                #( #type_aliases )*
                #( #fn_storages )*
                #( #errors_slots )*
                #[unsafe(no_mangle)]
                pub unsafe extern "C" fn #register_fn_name(
                    #( #register_fn_inputs, )*
                ) {
                    #( #register_fn_assigns; )*
                }
                #( #errors_setters )*
                #( #caller_fns )*
            }
        }
    }

    /// Generates Dart code with [`DartCodegen`] and writes it into the file
    /// at `relative_path`.
    ///
    /// The code is generated from the registration function name and the
    /// signature of every declared function. `relative_path` is resolved
    /// against `CARGO_MANIFEST_DIR`.
    ///
    /// The code is generated before the target file is created. A generation
    /// failure therefore leaves a previously generated file untouched instead
    /// of truncating it.
    ///
    /// # Errors
    ///
    /// If `CARGO_MANIFEST_DIR` isn't set, if `relative_path` isn't a string
    /// literal, if code generation fails, or if the file can't be created or
    /// written.
    #[cfg(feature = "dart-codegen")]
    fn generate_dart_code(
        &self,
        relative_path: &syn::ExprLit,
    ) -> syn::Result<()> {
        let root_path = env::var("CARGO_MANIFEST_DIR").map_err(|e| {
            syn::Error::new(
                relative_path.span(),
                format!("Cannot read `CARGO_MANIFEST_DIR` env var: {e}"),
            )
        })?;
        let path =
            PathBuf::from(root_path).join(get_path_arg(relative_path)?);

        let registerers = self
            .fn_expanders
            .iter()
            .map(|f| FnRegistrationBuilder {
                inputs: f.input_args.iter().cloned().collect(),
                output: f.ret_ok_ty.clone(),
                name: f.ident.clone(),
                error_setter_ident: f.error_setter_ident.clone(),
            })
            .collect::<Vec<_>>();
        let generated_code =
            DartCodegen::new(&self.register_fn_name, registerers)?
                .generate()
                .map_err(|e| {
                    syn::Error::new(
                        relative_path.span(),
                        format!("Failed to generate Dart code: {e}"),
                    )
                })?;

        // `File::create` truncates an existing file, so only do it once the
        // code to be written is ready.
        let mut file = File::create(path).map_err(|e| {
            syn::Error::new(
                relative_path.span(),
                format!("Failed to create file at the provided path: {e}"),
            )
        })?;
        file.write_all(generated_code.as_bytes()).map_err(|e| {
            let msg = format!(
                "Failed to write generated Dart code to the file: {e}",
            );
            syn::Error::new(relative_path.span(), msg)
        })
    }
}
#[cfg(feature = "dart-codegen")]
/// Extracts the Dart file path from the string literal passed as the macro
/// argument.
///
/// # Errors
///
/// If `arg` isn't a string literal. The error is spanned to `arg` itself, so
/// the diagnostic points at the offending argument rather than at the whole
/// macro invocation.
fn get_path_arg(arg: &syn::ExprLit) -> syn::Result<String> {
    if let syn::Lit::Str(lit) = &arg.lit {
        Ok(lit.value())
    } else {
        let msg = format!(
            "Expected a str literal with a Dart file path, got: {arg:?}",
        );
        Err(syn::Error::new(arg.span(), msg))
    }
}
mod mod_parser {
use proc_macro2::Span;
use syn::{spanned::Spanned as _, token};
pub(super) fn get_extern_fn(
item: syn::ForeignItem,
) -> syn::Result<syn::ForeignItemFn> {
if let syn::ForeignItem::Fn(func) = item {
Ok(func)
} else {
Err(syn::Error::new(item.span(), "Unsupported item"))
}
}
pub(super) fn try_unwrap_mod_content(
item: Option<(token::Brace, Vec<syn::Item>)>,
) -> syn::Result<Vec<syn::Item>> {
if let Some((_, items)) = item {
Ok(items)
} else {
Err(syn::Error::new(Span::call_site(), "Empty module provided"))
}
}
}
/// Derives the identifiers of the items generated for a single `extern`
/// function. The identifiers are namespaced by a prefix; the module expansion
/// passes the enclosing module's name as that prefix.
struct IdentGenerator<'id> {
    /// Namespace of the generated identifiers.
    prefix: &'id syn::Ident,

    /// Name of the `extern` function.
    name: &'id syn::Ident,
}
impl<'a> IdentGenerator<'a> {
    /// Creates a generator for the function `name` namespaced by `prefix`.
    const fn new(prefix: &'a syn::Ident, name: &'a syn::Ident) -> Self {
        Self { prefix, name }
    }

    /// Converts both the prefix and the name with the same `convert`
    /// function, returning them in `(prefix, name)` order.
    fn converted(
        &self,
        convert: impl Fn(&str) -> String,
    ) -> (String, String) {
        let prefix = convert(&self.prefix.to_string());
        let name = convert(&self.name.to_string());
        (prefix, name)
    }

    /// `PrefixNameFunction`: the type alias of the function pointer.
    fn type_alias(&self) -> syn::Ident {
        let (prefix, name) = self.converted(|s| s.to_class_case());
        format_ident!("{}{}Function", prefix, name)
    }

    /// `PREFIX__NAME__FUNCTION`: the static storing the function pointer.
    fn fn_storage(&self) -> syn::Ident {
        let (prefix, name) =
            self.converted(|s| s.to_screaming_snake_case());
        format_ident!("{}__{}__FUNCTION", prefix, name)
    }

    /// `PREFIX__NAME__ERROR`: the static storing a pending error.
    fn error_slot_name(&self) -> syn::Ident {
        let (prefix, name) =
            self.converted(|s| s.to_screaming_snake_case());
        format_ident!("{}__{}__ERROR", prefix, name)
    }

    /// `prefix__name__set_error`: the exported error setter. The prefix and
    /// the name are lowercased, not converted to snake case.
    fn error_setter_name(&self) -> syn::Ident {
        let (prefix, name) = self.converted(|s| s.to_lowercase());
        format_ident!("{}__{}__set_error", prefix, name)
    }
}
/// Parsed `extern` function declaration, along with the identifiers of all
/// the items generated for it.
#[derive(Debug)]
struct FnExpander {
    /// Name of the declared function. It is reused for the caller function
    /// and for the registration function's parameter.
    ident: syn::Ident,

    /// Name of the function pointer type alias.
    type_alias_ident: syn::Ident,

    /// Name of the static storing the registered function pointer.
    fn_storage_ident: syn::Ident,

    /// Name of the static storing a pending error.
    error_slot_ident: syn::Ident,

    /// Name of the `#[no_mangle]` function filling the error slot.
    error_setter_ident: syn::Ident,

    /// Declared arguments, each bound to a plain identifier.
    input_args: Punctuated<syn::FnArg, token::Comma>,

    /// Declared return type, shaped like `Result<T, _>`.
    ret_ty: syn::ReturnType,

    /// The `T` of [`FnExpander::ret_ty`], returned by the function pointer.
    ret_ok_ty: syn::Type,

    /// `#[doc]` attributes copied onto the generated caller function.
    doc_attrs: Vec<syn::Attribute>,
}
impl FnExpander {
fn parse(
item: syn::ForeignItemFn,
prefix: &syn::Ident,
) -> syn::Result<Self> {
for arg in &item.sig.inputs {
match arg {
syn::FnArg::Typed(a) => {
if !matches!(&*a.pat, syn::Pat::Ident(_)) {
return Err(syn::Error::new(
a.span(),
"Incorrect argument identifier",
));
}
}
syn::FnArg::Receiver(_) => {
return Err(syn::Error::new(
arg.span(),
"`self` argument is invalid here",
));
}
}
}
let ret_ok_ty = {
let err = Err(syn::Error::new(
item.sig.output.span(),
"must return `Result<T, platform::Error>`",
));
let syn::ReturnType::Type(_, ret_ty) = item.sig.output.clone()
else {
return err;
};
let syn::Type::Path(ret_ty_path) = *ret_ty else {
return err;
};
let Some(ret_ty_args) =
ret_ty_path.path.segments.last().map(|s| s.arguments.clone())
else {
return err;
};
let syn::PathArguments::AngleBracketed(res) = &ret_ty_args else {
return err;
};
let Some(syn::GenericArgument::Type(res_ok_ty)) = res.args.first()
else {
return err;
};
res_ok_ty.clone()
};
let ident_generator = IdentGenerator::new(prefix, &item.sig.ident);
Ok(Self {
type_alias_ident: ident_generator.type_alias(),
fn_storage_ident: ident_generator.fn_storage(),
error_slot_ident: ident_generator.error_slot_name(),
error_setter_ident: ident_generator.error_setter_name(),
ident: item.sig.ident,
input_args: item.sig.inputs,
ret_ty: item.sig.output,
ret_ok_ty,
doc_attrs: item
.attrs
.into_iter()
.map(|attr| {
if attr.path().get_ident().is_some_and(|i| i == "doc") {
Ok(attr)
} else {
Err(syn::Error::new(
attr.span(),
"only #[doc] attributes supported on extern functions",
))
}
})
.collect::<syn::Result<_>>()?,
})
}
fn gen_register_fn_input(&self) -> syn::FnArg {
let ident = &self.ident;
let fn_type_alias = &self.type_alias_ident;
parse_quote! {
#ident: #fn_type_alias
}
}
fn gen_register_fn_expr(&self) -> syn::Expr {
let fn_storage_ident = &self.fn_storage_ident;
let ident = &self.ident;
parse_quote! {
*#fn_storage_ident.borrow_mut() = Some(#ident)
}
}
fn gen_fn_type(&self) -> TokenStream {
let name = &self.type_alias_ident;
let ret_ok_ty = &self.ret_ok_ty;
let args = &self.input_args;
quote! {
type #name = extern "C" fn (#args) -> #ret_ok_ty;
}
}
fn gen_fn_storages(&self) -> TokenStream {
let name = &self.fn_storage_ident;
let type_alias = &self.type_alias_ident;
quote! {
static #name: ::std::sync::LazyLock<
::send_wrapper::SendWrapper<
::std::cell::RefCell<Option<#type_alias>>>> =
::std::sync::LazyLock::new(|| {
::send_wrapper::SendWrapper::new(
::std::cell::RefCell::new(None)
)
});
}
}
fn gen_caller_fn(&self) -> TokenStream {
let doc_attrs = &self.doc_attrs;
let name = &self.ident;
let error_slot = &self.error_slot_ident;
let args = &self.input_args;
let args_idents = self.input_args.iter().filter_map(|arg| {
if let syn::FnArg::Typed(a) = arg {
if let syn::Pat::Ident(pat) = &*a.pat {
return Some(&pat.ident);
}
}
None
});
let ret_ty = &self.ret_ty;
let fn_storage_ident = &self.fn_storage_ident;
let none_message = format!("`{fn_storage_ident}` is not set");
quote! {
#( #doc_attrs )*
pub unsafe fn #name(#args) #ret_ty {
let res = (*#fn_storage_ident
.borrow()
.as_ref()
.expect(#none_message))
(#( #args_idents ),*);
if let Some(e) = #error_slot.borrow_mut().take() {
Err(e)
} else {
Ok(res)
}
}
}
}
fn get_errors_slot(&self) -> TokenStream {
let name = &self.error_slot_ident;
quote! {
static #name: ::std::sync::LazyLock<::send_wrapper::SendWrapper<
::std::cell::RefCell<Option<Error>>>> =
::std::sync::LazyLock::new(|| {
::send_wrapper::SendWrapper::new(
::std::cell::RefCell::new(None)
)
});
}
}
fn gen_error_setter(&self) -> TokenStream {
let doc = format!("Error setter for the `{}` function", self.ident);
let fn_name = &self.error_setter_ident;
let error_slot = &self.error_slot_ident;
quote! {
#[doc = #doc]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn #fn_name(err: Dart_Handle) {
_ = #error_slot.replace(Some(Error::from_handle(err)));
}
}
}
}