use crate::tree_preds::*;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{FnArg, Ident, ItemFn};
/// Builder for a single milter callback.
///
/// `generate` emits a `pub(crate) unsafe extern "C"` function that converts its
/// raw C arguments, invokes the user's handler inside `catch_unwind`, and maps
/// the outcome to a `::milter::sfsistat`. The handler function itself is
/// re-emitted unchanged next to the callback.
#[derive(Clone)]
pub struct CallbackFn {
    /// Name of the generated `extern "C"` callback.
    name: Ident,
    /// The user's handler function, re-emitted verbatim by `generate`.
    handler_fn: ItemFn,
    /// Position of the next handler parameter to be filled by `input` or
    /// `extra_arg`.
    handler_arg_index: usize,
    /// Argument specifications, in the order they were added.
    args: Vec<ArgSpec>,
    /// Match arms replacing the default
    /// `Ok(status) => status as ::milter::sfsistat,` arm, if set.
    ok_result_arms: Option<TokenStream>,
}

impl CallbackFn {
    /// Creates a builder for a callback named `name` that dispatches to
    /// `handler_fn`, with no arguments registered yet.
    pub fn new(name: Ident, handler_fn: ItemFn) -> Self {
        Self {
            name,
            handler_fn,
            handler_arg_index: 0,
            args: vec![],
            ok_result_arms: None,
        }
    }

    /// Adds a parameter `name: c_type` to the C signature that is *not*
    /// forwarded to the handler.
    pub fn input_unbound(&mut self, name: Ident, c_type: TokenStream) -> &mut Self {
        self.args.push(ArgSpec::Input(InputArg::new(name, c_type)));
        self
    }

    /// Adds a parameter `name: c_type` to the C signature. The parameter is
    /// converted according to `arg_binding` and passed as the handler's next
    /// parameter.
    ///
    /// # Panics
    ///
    /// Panics if the handler declares fewer parameters than the callback binds.
    /// The compiler reports this as a proc-macro error.
    pub fn input(&mut self, name: Ident, c_type: TokenStream, arg_binding: Binding) -> &mut Self {
        let index = self.handler_arg_index;
        // Use a descriptive panic rather than an opaque index-out-of-bounds
        // message when the handler signature is too short.
        let handler_arg = self
            .handler_fn
            .sig
            .inputs
            .iter()
            .nth(index)
            .unwrap_or_else(|| {
                panic!(
                    "handler function `{}` takes too few parameters: expected at least {}",
                    self.handler_fn.sig.ident,
                    index + 1,
                )
            });
        let generator = arg_binding.to_gen(handler_arg);
        self.args.push(ArgSpec::Bound(BoundArg::new(name, c_type, generator)));
        self.handler_arg_index += 1;
        self
    }

    /// Passes `handler_arg_expr` verbatim as the handler's next parameter,
    /// without adding anything to the C signature.
    pub fn extra_arg(&mut self, handler_arg_expr: TokenStream) -> &mut Self {
        self.args.push(ArgSpec::Extra(handler_arg_expr));
        self.handler_arg_index += 1;
        self
    }

    /// Replaces the default success arm(s) used to map the handler's return
    /// value to a `::milter::sfsistat`.
    pub fn ok_result_arms(&mut self, ok_result_arms: TokenStream) -> &mut Self {
        self.ok_result_arms = Some(ok_result_arms);
        self
    }

    /// Emits the `extern "C"` callback followed by the handler function.
    pub fn generate(&self) -> TokenStream {
        let callback_name = &self.name;
        let callback_inputs = self.gen_callback_inputs();
        let handler_call = self.gen_handler_call();
        let handler_fn = &self.handler_fn;
        quote! {
            #[doc(hidden)]
            pub(crate) unsafe extern "C" fn #callback_name(#callback_inputs) -> ::milter::sfsistat {
                #handler_call
            }
            #handler_fn
        }
    }

    /// Builds the comma-separated C parameter list from the `Input` and
    /// `Bound` arguments, in insertion order.
    fn gen_callback_inputs(&self) -> TokenStream {
        let callback_inputs = self.args.iter().filter_map(|arg| match arg {
            ArgSpec::Input(InputArg { name, c_type })
            | ArgSpec::Bound(BoundArg { name, c_type, .. }) => Some(quote! { #name: #c_type }),
            ArgSpec::Extra(_) => None,
        });
        quote! { #(#callback_inputs),* }
    }

    /// Builds the callback body.
    ///
    /// - If `::milter::internal::is_panicked()` is true, returns `Tempfail`
    ///   without calling the handler.
    /// - Otherwise converts the bound arguments and calls the handler inside
    ///   `catch_unwind`.
    /// - If the handler returns a `Result` (per `is_result_return`), an `Err`
    ///   is logged to syslog at `LOG_WARNING` and mapped to `Tempfail`.
    /// - A panic sets the panicked flag, logs at `LOG_ERR`, calls
    ///   `::milter::shutdown()` and returns `Tempfail`.
    fn gen_handler_call(&self) -> TokenStream {
        let (arg_names, arg_exprs): (Vec<_>, Vec<_>) = self
            .args
            .iter()
            .filter_map(|arg| match arg {
                ArgSpec::Bound(BoundArg { name, gen: generator, .. }) => {
                    Some((name, generator.expr(name)))
                }
                ArgSpec::Input(_) | ArgSpec::Extra(_) => None,
            })
            .unzip();
        let handler_name = &self.handler_fn.sig.ident;
        let args = self.args.iter().filter_map(|arg| match arg {
            ArgSpec::Bound(BoundArg { name, gen: generator, .. }) => Some(generator.arg_expr(name)),
            ArgSpec::Extra(arg) => Some(arg.clone()),
            ArgSpec::Input(_) => None,
        });
        let mut ok_result_arms = self.ok_result_arms.clone().unwrap_or_else(|| {
            quote! { ::std::result::Result::Ok(status) => status as ::milter::sfsistat, }
        });
        if is_result_return(&self.handler_fn.sig.output) {
            // The error text may contain arbitrary (possibly client-supplied)
            // data, so it must never be used as syslog's printf-style format
            // string. Log it through a constant "%s" format instead.
            ok_result_arms = quote! {
                ::std::result::Result::Ok(result) => match result {
                    #ok_result_arms
                    ::std::result::Result::Err(error) => {
                        let msg = ::std::format!("error in milter callback: {}\0", error);
                        ::libc::syslog(
                            ::libc::LOG_WARNING,
                            "%s\0".as_ptr() as *const ::libc::c_char,
                            msg.as_ptr() as *const ::libc::c_char,
                        );
                        ::milter::Status::Tempfail as ::milter::sfsistat
                    }
                },
            };
        }
        quote! {
            if ::milter::internal::is_panicked() {
                ::milter::Status::Tempfail as ::milter::sfsistat
            } else {
                match ::std::panic::catch_unwind(|| {
                    #( let #arg_names = #arg_exprs; )*
                    #handler_name(#(#args),*)
                }) {
                    #ok_result_arms
                    ::std::result::Result::Err(_) => {
                        ::milter::internal::set_panicked(true);
                        ::libc::syslog(::libc::LOG_ERR, "panic in milter callback, terminating\0".as_ptr() as _);
                        ::milter::shutdown();
                        ::milter::Status::Tempfail as ::milter::sfsistat
                    }
                }
            }
        }
    }
}
/// Conversion applied to a raw C callback argument before it reaches the
/// user's handler function.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Binding {
    /// Wrapped with `::milter::Context::new`.
    Context,
    /// A single C string, handed over as a `CStr` or a lossily-decoded `&str`.
    Str,
    /// A NULL-terminated array of C strings, collected into a `Vec`.
    Strs,
    /// A possibly-NULL `sockaddr`, decoded into an optional `std::net::SocketAddr`.
    SocketAddr,
    /// Flags decoded with `::milter::Actions::from_bits_truncate`.
    Actions,
    /// Flags decoded with `::milter::ProtocolOpts::from_bits_truncate`.
    ProtocolOpts,
}

impl Binding {
    /// Picks the code generator for this binding.
    ///
    /// The string bindings inspect the handler parameter `fn_arg` via
    /// `is_cstr_arg` / `is_cstrs_arg`. The resulting flag selects raw `CStr`
    /// output when true. All other bindings ignore `fn_arg`.
    fn to_gen(self, fn_arg: &FnArg) -> Gen {
        match self {
            Binding::Str => Gen::Str(is_cstr_arg(fn_arg)),
            Binding::Strs => Gen::Strs(is_cstrs_arg(fn_arg)),
            Binding::Context => Gen::Context,
            Binding::SocketAddr => Gen::SocketAddr,
            Binding::Actions => Gen::Actions,
            Binding::ProtocolOpts => Gen::ProtocolOpts,
        }
    }
}
/// One entry in a callback's argument list. Each variant says whether the
/// entry appears in the generated C signature and whether it is passed to the
/// handler function.
#[derive(Clone, Debug)]
enum ArgSpec {
/// Appears in the C signature only; never passed to the handler.
Input(InputArg),
/// Appears in the C signature, is converted by its `Gen`, and is passed to the handler.
Bound(BoundArg),
/// Absent from the C signature; this expression is passed to the handler as-is.
Extra(TokenStream),
}
/// A parameter of the generated C callback that is accepted but not forwarded
/// to the handler function.
#[derive(Clone, Debug)]
struct InputArg {
    /// Parameter name in the generated `extern "C"` signature.
    name: Ident,
    /// C-side type of the parameter, emitted verbatim into the signature.
    c_type: TokenStream,
}

impl InputArg {
    /// Pairs a callback parameter name with its C type.
    fn new(name: Ident, c_type: TokenStream) -> Self {
        InputArg { c_type, name }
    }
}
/// A parameter of the generated C callback that is converted and then passed
/// to the handler function.
#[derive(Clone, Debug)]
struct BoundArg {
    /// Parameter name in the generated `extern "C"` signature. It is also
    /// reused as the name of the converted local.
    name: Ident,
    /// C-side type of the parameter, emitted verbatim into the signature.
    c_type: TokenStream,
    /// Generator producing the conversion and the handler-argument expression.
    gen: Gen,
}

impl BoundArg {
    /// Bundles a callback parameter with the generator that converts it.
    fn new(name: Ident, c_type: TokenStream, generator: Gen) -> Self {
        BoundArg {
            gen: generator,
            c_type,
            name,
        }
    }
}
/// Code generator for one bound callback argument, produced by
/// `Binding::to_gen`.
///
/// For `Str` and `Strs`, the flag is `true` to hand the handler raw `CStr`
/// values and `false` to hand it lossily-decoded `&str` values.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum Gen {
    Context,
    Str(bool),
    Strs(bool),
    SocketAddr,
    Actions,
    ProtocolOpts,
}

impl Gen {
    /// Emits the expression converting the raw callback parameter `ident`
    /// into the value bound to a local of the same name.
    ///
    /// The emitted code dereferences raw pointers, so it must be spliced into
    /// an `unsafe` context.
    fn expr(&self, ident: &Ident) -> TokenStream {
        match *self {
            Gen::Str(raw) => Self::c_str_expr(ident, raw),
            Gen::Strs(raw) => Self::c_str_array_expr(ident, raw),
            Gen::SocketAddr => Self::socket_addr_expr(ident),
            Gen::Context => quote! { ::milter::Context::new(#ident) },
            Gen::Actions => quote! { ::milter::Actions::from_bits_truncate(#ident) },
            Gen::ProtocolOpts => quote! { ::milter::ProtocolOpts::from_bits_truncate(#ident) },
        }
    }

    /// Wraps the C string `ident` in a `CStr`, adding `.to_string_lossy()`
    /// unless `raw` is set.
    fn c_str_expr(ident: &Ident, raw: bool) -> TokenStream {
        let borrowed = quote! { ::std::ffi::CStr::from_ptr(#ident) };
        if !raw {
            return quote! { #borrowed.to_string_lossy() };
        }
        borrowed
    }

    /// Walks the NULL-terminated C string array `ident`, converting each
    /// element as `c_str_expr` does, and collects the results into a `Vec`.
    fn c_str_array_expr(ident: &Ident, raw: bool) -> TokenStream {
        let element_ident = format_ident!("p");
        let element = Self::c_str_expr(&element_ident, raw);
        quote! {
            (0..)
                .map(|i| *#ident.offset(i))
                .take_while(|p| !p.is_null())
                .map(|p| #element)
                .collect::<::std::vec::Vec<_>>()
        }
    }

    /// Decodes the `sockaddr` pointer `ident` as follows:
    ///
    /// - NULL → `None`.
    /// - `AF_INET` or `AF_INET6` → `Some(SocketAddr)`.
    /// - Any other address family → `None`.
    fn socket_addr_expr(ident: &Ident) -> TokenStream {
        quote! {
            if #ident.is_null() {
                ::std::option::Option::None
            } else {
                match (*#ident).sa_family as _ {
                    ::libc::AF_INET => {
                        let addr = #ident as *const ::libc::sockaddr_in;
                        let ip = ::std::net::Ipv4Addr::from(u32::from_be((*addr).sin_addr.s_addr));
                        let port = u16::from_be((*addr).sin_port);
                        ::std::option::Option::Some(::std::net::SocketAddr::from(::std::net::SocketAddrV4::new(ip, port)))
                    }
                    ::libc::AF_INET6 => {
                        let addr = #ident as *const ::libc::sockaddr_in6;
                        let ip = ::std::net::Ipv6Addr::from((*addr).sin6_addr.s6_addr);
                        let port = u16::from_be((*addr).sin6_port);
                        let flowinfo = (*addr).sin6_flowinfo;
                        let scope_id = (*addr).sin6_scope_id;
                        ::std::option::Option::Some(::std::net::SocketAddr::from(::std::net::SocketAddrV6::new(ip, port, flowinfo, scope_id)))
                    }
                    _ => ::std::option::Option::None,
                }
            }
        }
    }

    /// Emits the expression passed to the handler for the converted local
    /// `ident`:
    ///
    /// - A lossily-decoded string is borrowed as `&str`.
    /// - A lossily-decoded string array becomes a `Vec<&str>`.
    /// - Every other value is passed through unchanged.
    fn arg_expr(&self, ident: &Ident) -> TokenStream {
        match *self {
            Gen::Str(false) => quote! { &#ident as &str },
            Gen::Strs(false) => quote! {
                #ident.iter().map(|s| s as &str).collect::<::std::vec::Vec<&str>>()
            },
            _ => quote! { #ident },
        }
    }
}