use crate::*;
/// Expands `input` (parsed as an `ItemFn`) so that the tokens produced by
/// `before_fn` appear before the first statement of the function body.
///
/// `before_fn` is called with the identifiers of the context and stream
/// parameters located in the signature. The context parameter is resolved
/// first; if that fails its error is emitted, otherwise a failure to resolve
/// the stream parameter is emitted. Errors are returned as compile-error
/// tokens rather than panics.
fn inject_at_start(
    input: TokenStream,
    before_fn: impl FnOnce(&Ident, &Ident) -> TokenStream2,
) -> TokenStream {
    let mut item: ItemFn = parse_macro_input!(input as ItemFn);
    let signature: &mut Signature = &mut item.sig;
    // Guard clauses: bail out with the first resolution error encountered.
    let context: Ident = match parse_context_from_signature(signature) {
        Ok(ident) => ident,
        Err(error) => return error.to_compile_error().into(),
    };
    let stream: Ident = match parse_stream_from_signature(signature) {
        Ok(ident) => ident,
        Err(error) => return error.to_compile_error().into(),
    };
    let prologue: TokenStream2 = before_fn(&context, &stream);
    let attributes: &Vec<Attribute> = &item.attrs;
    let visibility: &Visibility = &item.vis;
    let body: &Vec<Stmt> = &item.block.stmts;
    let expanded: TokenStream2 = quote! {
        #(#attributes)*
        #visibility #signature {
            #prologue
            #(#body)*
        }
    };
    expanded.into()
}
fn inject_at_end(
input: TokenStream,
after_fn: impl FnOnce(&Ident, &Ident) -> TokenStream2,
) -> TokenStream {
let mut input_fn: ItemFn = parse_macro_input!(input as ItemFn);
let vis: &Visibility = &input_fn.vis;
let sig: &mut Signature = &mut input_fn.sig;
let block: &Block = &input_fn.block;
let attrs: &Vec<Attribute> = &input_fn.attrs;
match parse_context_from_signature(sig) {
Ok(context) => match parse_stream_from_signature(sig) {
Ok(stream) => {
let after_code: TokenStream2 = after_fn(&context, &stream);
let stmts: &Vec<Stmt> = &block.stmts;
let (leading_stmts, tail_expr) = if let Some((last, leading)) = stmts.split_last() {
match last {
Stmt::Expr(expr, None) => (leading, Some(quote! { #expr })),
_ => (stmts.as_slice(), None),
}
} else {
(stmts.as_slice(), None)
};
let normalized_leading: Vec<TokenStream2> = leading_stmts
.iter()
.map(|stmt| match stmt {
Stmt::Expr(expr, None) => quote! { #expr; },
_ => quote! { #stmt },
})
.collect();
let gen_code: TokenStream2 = match tail_expr {
Some(expr) => quote! {
#(#attrs)*
#vis #sig {
#(#normalized_leading)*
#after_code
#expr
}
},
None => quote! {
#(#attrs)*
#vis #sig {
#(#normalized_leading)*
#after_code
}
},
};
gen_code.into()
}
Err(err) => err.to_compile_error().into(),
},
Err(err) => err.to_compile_error().into(),
}
}
/// Routes `hook` to the epilogue or prologue injector selected by `position`.
///
/// `hook` receives the context and stream parameter identifiers found in the
/// annotated function's signature and returns the tokens to splice in:
/// at the end of the body for `Position::Epilogue`, at the start for
/// `Position::Prologue`.
pub(crate) fn inject(
    position: Position,
    input: TokenStream,
    hook: impl FnOnce(&Ident, &Ident) -> TokenStream2,
) -> TokenStream {
    match position {
        Position::Epilogue => inject_at_end(input, hook),
        Position::Prologue => inject_at_start(input, hook),
    }
}
/// Returns `true` when `ty` is a reference to a path type naming the
/// framework context: either the bare single-segment `Context`, or a path
/// whose last two segments are `hyperlane::Context` (e.g. `&::hyperlane::Context`).
///
/// Any other multi-segment path ending in `Context` (e.g. `&crate::Context`)
/// is rejected. Invisible delimiters (`Type::Group`, produced when a type
/// passes through a `macro_rules!` `$t:ty` fragment) and parentheses are
/// looked through, since they do not change the type.
fn is_context_type(ty: &Type) -> bool {
    // Peels `Type::Group` / `Type::Paren` wrappers off a type.
    fn strip_delimiters(mut ty: &Type) -> &Type {
        loop {
            match ty {
                Type::Group(group) => ty = &*group.elem,
                Type::Paren(paren) => ty = &*paren.elem,
                _ => return ty,
            }
        }
    }
    let Type::Reference(type_ref) = strip_delimiters(ty) else {
        return false;
    };
    let Type::Path(type_path) = strip_delimiters(&type_ref.elem) else {
        return false;
    };
    let segments = &type_path.path.segments;
    match segments.len() {
        0 => false,
        1 => segments[0].ident == "Context",
        len => segments[len - 2].ident == "hyperlane" && segments[len - 1].ident == "Context",
    }
}
/// Returns `true` when `ty` is a reference to a path type naming the
/// framework stream: either the bare single-segment `Stream`, or a path
/// whose last two segments are `hyperlane::Stream` (e.g. `&::hyperlane::Stream`).
///
/// Any other multi-segment path ending in `Stream` is rejected. Invisible
/// delimiters (`Type::Group`, produced when a type passes through a
/// `macro_rules!` `$t:ty` fragment) and parentheses are looked through, since
/// they do not change the type.
fn is_stream_type(ty: &Type) -> bool {
    // Peels `Type::Group` / `Type::Paren` wrappers off a type.
    fn strip_delimiters(mut ty: &Type) -> &Type {
        loop {
            match ty {
                Type::Group(group) => ty = &*group.elem,
                Type::Paren(paren) => ty = &*paren.elem,
                _ => return ty,
            }
        }
    }
    let Type::Reference(type_ref) = strip_delimiters(ty) else {
        return false;
    };
    let Type::Path(type_path) = strip_delimiters(&type_ref.elem) else {
        return false;
    };
    let segments = &type_path.path.segments;
    match segments.len() {
        0 => false,
        1 => segments[0].ident == "Stream",
        len => segments[len - 2].ident == "hyperlane" && segments[len - 1].ident == "Stream",
    }
}
pub(crate) fn parse_context_from_signature(sig: &mut Signature) -> syn::Result<Ident> {
for arg in sig.inputs.iter() {
if let FnArg::Typed(pat_type) = arg
&& is_context_type(&pat_type.ty)
{
match &*pat_type.pat {
Pat::Ident(pat_ident) => return Ok(pat_ident.ident.clone()),
Pat::Wild(_) => {
return Err(syn::Error::new_spanned(
&pat_type.pat,
"anonymous `_` parameter is not allowed for context; please use a named identifier like `ctx`",
));
}
_ => {
return Err(syn::Error::new_spanned(
&pat_type.pat,
"expected identifier for context parameter",
));
}
}
}
}
Err(syn::Error::new_spanned(
&sig.inputs,
"expected at least one parameter of type &::hyperlane::Context",
))
}
pub(crate) fn parse_stream_from_signature(sig: &mut Signature) -> syn::Result<Ident> {
for arg in sig.inputs.iter() {
if let FnArg::Typed(pat_type) = arg
&& is_stream_type(&pat_type.ty)
{
match &*pat_type.pat {
Pat::Ident(pat_ident) => return Ok(pat_ident.ident.clone()),
Pat::Wild(_) => {
return Err(syn::Error::new_spanned(
&pat_type.pat,
"anonymous `_` parameter is not allowed for stream; please use a named identifier like `stream`",
));
}
_ => {
return Err(syn::Error::new_spanned(
&pat_type.pat,
"expected identifier for stream parameter",
));
}
}
}
}
Err(syn::Error::new_spanned(
&sig.inputs,
"expected at least one parameter of type &::hyperlane::Stream",
))
}
pub(crate) fn expr_to_isize(opt_expr: &Option<Expr>) -> TokenStream2 {
match opt_expr {
Some(expr) => match expr {
Expr::Lit(ExprLit {
lit: Lit::Int(lit_int),
..
}) => {
let value: isize = lit_int.base10_parse::<isize>().unwrap();
quote! { Some(#value) }
}
Expr::Lit(ExprLit {
lit: Lit::Str(lit_str),
..
}) => {
let value: isize = lit_str.value().parse().expect("Cannot parse to isize");
quote! { Some(#value) }
}
_ => quote! { None },
},
None => quote! { None },
}
}
/// Builds the expression `unsafe { <context>.leak_mut() }` for the given
/// context identifier.
///
/// NOTE(review): the soundness contract of `leak_mut` is defined outside this
/// file (presumably on hyperlane's `Context`). The generated `unsafe` block
/// carries no justification of its own, so confirm that the call sites uphold it.
pub(crate) fn leak_mut_context(context: &Ident) -> TokenStream2 {
    let receiver: &Ident = context;
    quote!(unsafe { #receiver.leak_mut() })
}
/// Builds the expression `unsafe { <context>.leak() }` for the given context
/// identifier.
///
/// NOTE(review): the soundness contract of `leak` is defined outside this
/// file (presumably on hyperlane's `Context`). The generated `unsafe` block
/// carries no justification of its own, so confirm that the call sites uphold it.
pub(crate) fn leak_context(context: &Ident) -> TokenStream2 {
    let receiver: &Ident = context;
    quote!(unsafe { #receiver.leak() })
}