// milter-callback 0.2.4
//
// Attribute macros for milter callback generation.
use crate::tree_preds::*;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{FnArg, Ident, ItemFn};

/// Builder describing one generated `extern "C"` milter callback and the
/// Rust handler function it forwards to.
///
/// Arguments are registered in C parameter order; `generate` then emits the
/// callback followed by the handler function itself, unchanged.
#[derive(Clone)]
pub struct CallbackFn {
    /// Identifier of the generated `extern "C"` callback.
    name: Ident,
    /// The user's handler function; re-emitted verbatim by `generate`.
    handler_fn: ItemFn,
    /// Position of the next handler parameter to be filled; advanced by both
    /// bound inputs and extra arguments.
    handler_arg_index: usize,
    /// Registered callback/handler arguments, in registration order.
    args: Vec<ArgSpec>,
    /// Match arms replacing the default `Ok(status) => status as sfsistat`
    /// arm; `None` keeps the default.
    ok_result_arms: Option<TokenStream>,
}

impl CallbackFn {
    pub fn new(name: Ident, handler_fn: ItemFn) -> Self {
        Self {
            name,
            handler_fn,
            handler_arg_index: 0,
            args: vec![],
            ok_result_arms: None,
        }
    }

    pub fn input_unbound(&mut self, name: Ident, c_type: TokenStream) -> &mut Self {
        self.args.push(ArgSpec::Input(InputArg::new(name, c_type)));
        self
    }

    pub fn input(&mut self, name: Ident, c_type: TokenStream, arg_binding: Binding) -> &mut Self {
        let gen = arg_binding.to_gen(&self.handler_fn.sig.inputs[self.handler_arg_index]);
        self.args.push(ArgSpec::Bound(BoundArg::new(name, c_type, gen)));
        self.handler_arg_index += 1;
        self
    }

    pub fn extra_arg(&mut self, handler_arg_expr: TokenStream) -> &mut Self {
        self.args.push(ArgSpec::Extra(handler_arg_expr));
        self.handler_arg_index += 1;
        self
    }

    pub fn ok_result_arms(&mut self, ok_result_arms: TokenStream) -> &mut Self {
        self.ok_result_arms = Some(ok_result_arms);
        self
    }

    pub fn generate(&self) -> TokenStream {
        let callback_name = &self.name;
        let callback_inputs = self.gen_callback_inputs();
        let handler_call = self.gen_handler_call();
        let handler_fn = &self.handler_fn;

        quote! {
            #[doc(hidden)]
            pub(crate) unsafe extern "C" fn #callback_name(#callback_inputs) -> ::milter::sfsistat {
                #handler_call
            }

            #handler_fn
        }
    }

    fn gen_callback_inputs(&self) -> TokenStream {
        let callback_inputs = self.args.iter().filter_map(|arg| match arg {
            ArgSpec::Input(InputArg { name, c_type })
            | ArgSpec::Bound(BoundArg { name, c_type, .. }) => Some(quote! { #name: #c_type }),
            _ => None,
        });

        quote! { #(#callback_inputs),* }
    }

    fn gen_handler_call(&self) -> TokenStream {
        let (arg_names, arg_exprs): (Vec<_>, Vec<_>) = self.args.iter().filter_map(|arg| match arg {
            ArgSpec::Bound(BoundArg { name, gen, .. }) => Some((name, gen.expr(name))),
            _ => None,
        })
        .unzip();

        let handler_name = &self.handler_fn.sig.ident;

        let args = self.args.iter().filter_map(|arg| match arg {
            ArgSpec::Bound(BoundArg { name, gen, .. }) => Some(gen.arg_expr(name)),
            ArgSpec::Extra(arg) => Some(arg.clone()),
            _ => None,
        });

        let mut ok_result_arms = self.ok_result_arms.as_ref().map_or_else(
            || quote! { ::std::result::Result::Ok(status) => status as ::milter::sfsistat, },
            |arms| arms.clone(),
        );

        if is_result_return(&self.handler_fn.sig.output) {
            ok_result_arms = quote! {
                ::std::result::Result::Ok(result) => match result {
                    #ok_result_arms
                    ::std::result::Result::Err(error) => {
                        let msg = ::std::format!("error in milter callback: {}\0", error);
                        ::libc::syslog(::libc::LOG_WARNING, msg.as_ptr() as _);
                        ::milter::Status::Tempfail as ::milter::sfsistat
                    }
                },
            };
        }

        quote! {
            if ::milter::internal::is_panicked() {
                ::milter::Status::Tempfail as ::milter::sfsistat
            } else {
                match ::std::panic::catch_unwind(|| {
                    #( let #arg_names = #arg_exprs; )*

                    #handler_name(#(#args),*)
                }) {
                    #ok_result_arms
                    ::std::result::Result::Err(_) => {
                        ::milter::internal::set_panicked(true);
                        ::libc::syslog(::libc::LOG_ERR, "panic in milter callback, terminating\0".as_ptr() as _);
                        ::milter::shutdown();
                        ::milter::Status::Tempfail as ::milter::sfsistat
                    }
                }
            }
        }
    }
}

/// Describes how a C callback argument is converted before it reaches the
/// handler. `to_gen` maps each binding to its code generator. For `Str` and
/// `Strs`, the handler parameter's type also selects raw or decoded strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Binding {
    /// The context pointer, wrapped via `::milter::Context::new`.
    Context,
    /// A single C string.
    Str,
    /// A null-pointer-terminated array of C strings.
    Strs,
    /// A possibly-null `sockaddr` pointer, converted to an optional socket address.
    SocketAddr,
    /// A bit mask converted via `::milter::Actions::from_bits_truncate`.
    Actions,
    /// A bit mask converted via `::milter::ProtocolOpts::from_bits_truncate`.
    ProtocolOpts,
}

impl Binding {
    fn to_gen(self, fn_arg: &FnArg) -> Gen {
        match self {
            Self::Context => Gen::Context,
            Self::Str => Gen::Str(is_cstr_arg(fn_arg)),
            Self::Strs => Gen::Strs(is_cstrs_arg(fn_arg)),
            Self::SocketAddr => Gen::SocketAddr,
            Self::Actions => Gen::Actions,
            Self::ProtocolOpts => Gen::ProtocolOpts,
        }
    }
}

/// One registered argument slot of a generated callback.
#[derive(Debug, Clone)]
enum ArgSpec {
    /// Appears in the C callback signature but is not passed to the handler.
    Input(InputArg),
    /// Appears in the C callback signature and is converted and passed to the handler.
    Bound(BoundArg),
    /// An expression passed to the handler with no matching C parameter.
    Extra(TokenStream),
}

/// A C callback parameter that the handler never receives.
#[derive(Debug, Clone)]
struct InputArg {
    /// Parameter name in the generated callback signature.
    name: Ident,
    /// Type tokens emitted for the parameter in the callback signature.
    c_type: TokenStream,
}

impl InputArg {
    /// Pairs a callback parameter name with its type tokens.
    fn new(name: Ident, c_type: TokenStream) -> Self {
        InputArg { name, c_type }
    }
}

/// A C callback parameter that is converted and forwarded to the handler.
#[derive(Debug, Clone)]
struct BoundArg {
    /// Parameter name in the generated callback signature; also used for the
    /// local variable holding the converted value.
    name: Ident,
    /// Type tokens emitted for the parameter in the callback signature.
    c_type: TokenStream,
    /// Generator for the conversion and handler-argument expressions.
    gen: Gen,
}

impl BoundArg {
    /// Bundles a callback parameter with its conversion generator.
    fn new(name: Ident, c_type: TokenStream, gen: Gen) -> Self {
        BoundArg { name, c_type, gen }
    }
}

/// Code generator for converting one bound C argument.
///
/// In `Str` and `Strs`, the `bool` is the "raw" flag. `true` keeps the `CStr`
/// values as-is; `false` decodes them with `to_string_lossy`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Gen {
    Context,
    /// A single C string; the flag selects raw `CStr` output.
    Str(bool),
    /// A null-pointer-terminated C string array; the flag selects raw elements.
    Strs(bool),
    SocketAddr,
    Actions,
    ProtocolOpts,
}

impl Gen {
    /// Returns an expression that converts the C callback argument `ident`
    /// into the value bound for the handler.
    ///
    /// The generated code dereferences raw pointers and calls
    /// `CStr::from_ptr`, so it must be spliced into an unsafe context.
    fn expr(&self, ident: &Ident) -> TokenStream {
        match self {
            Gen::Context => quote! { ::milter::Context::new(#ident) },
            Gen::Str(true) => quote! { ::std::ffi::CStr::from_ptr(#ident) },
            // It isn’t clear what to do on non-ASCII inputs, but given the
            // preference for UTF-8 in RFCs 6531 (SMTPUTF8) and 6532, this use
            // of `to_string_lossy` seems reasonable.
            Gen::Str(false) => quote! { ::std::ffi::CStr::from_ptr(#ident).to_string_lossy() },
            Gen::Strs(raw) => {
                // Each element pointer is converted exactly like a single string.
                let elem = format_ident!("p");
                let elem_expr = Gen::Str(*raw).expr(&elem);

                quote! {
                    (0..)
                        .map(|i| *#ident.offset(i))
                        .take_while(|p| !p.is_null())
                        .map(|p| #elem_expr)
                        .collect::<::std::vec::Vec<_>>()
                }
            }
            Gen::SocketAddr => Self::socket_addr_expr(ident),
            Gen::Actions => quote! { ::milter::Actions::from_bits_truncate(#ident) },
            Gen::ProtocolOpts => quote! { ::milter::ProtocolOpts::from_bits_truncate(#ident) },
        }
    }

    /// Generates the optional socket-address conversion for the `sockaddr`
    /// pointer `ident`. A null pointer, or an address family other than
    /// `AF_INET`/`AF_INET6`, yields `None`.
    ///
    /// NOTE(review): `sin6_flowinfo` and `sin6_scope_id` are copied without
    /// byte-order conversion — confirm this matches what callers expect.
    fn socket_addr_expr(ident: &Ident) -> TokenStream {
        quote! {
            if #ident.is_null() {
                ::std::option::Option::None
            } else {
                match (*#ident).sa_family as _ {
                    ::libc::AF_INET => {
                        let addr = #ident as *const ::libc::sockaddr_in;
                        let ip = ::std::net::Ipv4Addr::from(u32::from_be((*addr).sin_addr.s_addr));
                        let port = u16::from_be((*addr).sin_port);
                        ::std::option::Option::Some(::std::net::SocketAddr::from(::std::net::SocketAddrV4::new(ip, port)))
                    }
                    ::libc::AF_INET6 => {
                        let addr = #ident as *const ::libc::sockaddr_in6;
                        let ip = ::std::net::Ipv6Addr::from((*addr).sin6_addr.s6_addr);
                        let port = u16::from_be((*addr).sin6_port);
                        let flowinfo = (*addr).sin6_flowinfo;
                        let scope_id = (*addr).sin6_scope_id;
                        ::std::option::Option::Some(::std::net::SocketAddr::from(::std::net::SocketAddrV6::new(ip, port, flowinfo, scope_id)))
                    }
                    _ => ::std::option::Option::None,
                }
            }
        }
    }

    /// Returns the expression passed to the handler for the converted local
    /// variable `ident`.
    fn arg_expr(&self, ident: &Ident) -> TokenStream {
        // Decoded string values are `Cow<str>`, which can be borrowed as
        // `&str`. The target types could be inferred but are spelled out
        // because that yields clearer error messages from the macro.
        match self {
            Gen::Str(false) => quote! { &#ident as &str },
            Gen::Strs(false) => quote! {
                #ident.iter().map(|s| s as &str).collect::<::std::vec::Vec<&str>>()
            },
            Gen::Context
            | Gen::Str(true)
            | Gen::Strs(true)
            | Gen::SocketAddr
            | Gen::Actions
            | Gen::ProtocolOpts => quote! { #ident },
        }
    }
}