use crate::func::{get_generic_argument_type, OutputBindings};
use azure_functions_shared::codegen::{bindings::TRIGGERS, last_segment_in_path};
use azure_functions_shared::util::to_camel_case;
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{FnArg, Ident, ItemFn, Pat, Type};
/// Prefix prepended to a user function's identifier to form the name of its
/// generated invoker function (e.g. `my_func` -> `__invoke_my_func`).
const INVOKER_PREFIX: &str = "__invoke_";
/// Code generator for the `__invoke_*` wrapper around a user function.
///
/// The generated wrapper takes an `::azure_functions::rpc::InvocationRequest`,
/// binds the request's trigger/input data to the user function's parameters,
/// calls the function and produces the invocation result.
pub struct Invoker<'a> {
/// The user-written function being wrapped.
pub func: &'a ItemFn,
/// When `true`, emit a durable-orchestration invoker whose result comes from
/// `::azure_functions::durable::orchestrate` instead of a directly built response.
pub is_orchestration: bool,
}
impl<'a> Invoker<'a> {
    /// Returns the name of the generated invoker function: `INVOKER_PREFIX`
    /// followed by the wrapped function's identifier.
    pub fn name(&self) -> String {
        let target = &self.func.sig.ident;
        format!("{}{}", INVOKER_PREFIX, target)
    }

    /// Peels a single `&`/`&mut` layer off an argument type.
    ///
    /// Any non-reference type is handed back untouched.
    fn deref_arg_type(ty: &Type) -> &Type {
        if let Type::Reference(reference) = ty {
            reference.elem.as_ref()
        } else {
            ty
        }
    }

    /// Reports whether an argument type names a known trigger binding.
    ///
    /// One reference layer and any parentheses are looked through. Only the last
    /// segment of a path type is compared against the `TRIGGERS` table.
    fn is_trigger_type(ty: &Type) -> bool {
        let inner = Self::deref_arg_type(ty);
        if let Type::Paren(paren) = inner {
            return Self::is_trigger_type(&paren.elem);
        }
        if let Type::Path(type_path) = inner {
            let segment = last_segment_in_path(&type_path.path);
            return TRIGGERS.contains_key(segment.ident.to_string().as_str());
        }
        false
    }
}
/// Emits the invoker prologue shared by every invoker flavor.
///
/// The prologue collects the trigger and input bindings from `__req`. For
/// orchestrations it also binds `__state`. It then calls the user function,
/// storing the result in `__ret`.
struct CommonInvokerTokens<'a> {
/// The user-written function whose parameters are bound from the request.
pub func: &'a ItemFn,
/// Whether the function is a durable orchestration (adds the `__state` binding).
pub is_orchestration: bool,
}
impl<'a> CommonInvokerTokens<'a> {
fn get_input_args(&self) -> (Vec<&'a Ident>, Vec<&'a Type>) {
self.iter_args()
.filter_map(|(name, arg_type)| {
if Invoker::is_trigger_type(arg_type) {
return None;
}
Some((name, Invoker::deref_arg_type(arg_type)))
})
.unzip()
}
fn get_input_assignments(&self) -> Vec<TokenStream> {
self.iter_args()
.filter_map(|(_, arg_type)| {
if Invoker::is_trigger_type(arg_type) {
return None;
}
if let Type::Path(tp) = Invoker::deref_arg_type(arg_type) {
if get_generic_argument_type(last_segment_in_path(&tp.path), "Vec").is_some() {
return Some(quote!(__param
.data
.expect("expected parameter binding data")
.into_vec()));
}
}
Some(quote!(__param
.data
.expect("expected parameter binding data")
.into()))
})
.collect()
}
fn get_trigger_arg(&self) -> Option<(&'a Ident, &'a Type)> {
self.iter_args()
.find(|(_, arg_type)| Invoker::is_trigger_type(arg_type))
.map(|(name, arg_type)| (name, Invoker::deref_arg_type(arg_type)))
}
fn get_state_arg(&self, trigger: &Ident) -> TokenStream {
if self.is_orchestration {
quote!(let __state = #trigger.as_ref().unwrap().state();)
} else {
TokenStream::new()
}
}
fn get_args_for_call(&self) -> Vec<TokenStream> {
self.iter_args()
.map(|(name, arg_type)| {
let name_str = name.to_string();
if let Type::Reference(tr) = arg_type {
return match tr.mutability {
Some(_) => quote!(#name.as_mut().expect(concat!("parameter binding '", #name_str, "' was not provided"))),
None => quote!(#name.as_ref().expect(concat!("parameter binding '", #name_str, "' was not provided")))
};
}
quote!(#name.expect(concat!("parameter binding '", #name_str, "' was not provided")))
})
.collect()
}
fn iter_args(&self) -> impl Iterator<Item = (&'a Ident, &'a Type)> {
self.func.sig.inputs.iter().map(|x| match x {
FnArg::Typed(arg) => (
match &*arg.pat {
Pat::Ident(name) => &name.ident,
_ => panic!("expected ident argument pattern"),
},
&*arg.ty,
),
_ => panic!("expected captured arguments"),
})
}
}
impl ToTokens for CommonInvokerTokens<'_> {
    /// Emits the shared invoker prologue.
    ///
    /// The generated code expects an owned `__req` in scope and does the following:
    /// * declares an `Option` slot for the trigger and for every input parameter;
    /// * walks `__req.input_data`, matching each binding by the camelCase form of
    ///   the parameter name;
    /// * builds the trigger with `<TriggerType>::new(data, trigger_metadata)`;
    /// * converts input data with the per-parameter assignment expressions;
    /// * panics on a binding name it does not recognize;
    /// * for orchestrations, binds `__state`;
    /// * calls the user function and stores its result in `__ret`.
    ///
    /// # Panics
    /// Panics (at macro-expansion time) if the function has no trigger parameter.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let target = &self.func.sig.ident;
        let (args, types) = self.get_input_args();
        let arg_assignments = self.get_input_assignments();
        // Request binding names are the camelCase form of the Rust parameter names.
        let arg_names: Vec<_> = args.iter().map(|x| to_camel_case(&x.to_string())).collect();
        let (trigger_arg, trigger_type) = self
            .get_trigger_arg()
            .expect("the function must have a trigger");
        let trigger_name = to_camel_case(&trigger_arg.to_string());
        let args_for_call = self.get_args_for_call();
        let state_arg = self.get_state_arg(trigger_arg);

        // `<#trigger_type>::new` (qualified-path form) is valid for any type,
        // whereas `#trigger_type::new` breaks on e.g. parenthesized types.
        // `panic!` takes a literal format string because `panic!(format!(..))`
        // is a hard error in Rust 2021 crates, whose edition governs these tokens.
        quote!(
            use azure_functions::{IntoVec, FromVec};
            let mut #trigger_arg: Option<#trigger_type> = None;
            #(let mut #args: Option<#types> = None;)*
            let mut __metadata = Some(__req.trigger_metadata);
            for __param in __req.input_data.into_iter() {
                match __param.name.as_str() {
                    #trigger_name => #trigger_arg = Some(
                        <#trigger_type>::new(
                            __param.data.expect("expected parameter binding data"),
                            __metadata.take().expect("expected only one trigger"),
                        )
                    ),
                    #(#arg_names => #args = Some(#arg_assignments),)*
                    _ => panic!("unexpected parameter binding '{}'", __param.name),
                };
            }
            #state_arg
            let __ret = #target(#(#args_for_call,)*);
        )
        .to_tokens(tokens);
    }
}
impl ToTokens for Invoker<'_> {
    /// Emits the invoker function for the wrapped user function.
    ///
    /// Three shapes are generated, all taking
    /// `__req: ::azure_functions::rpc::InvocationRequest`:
    /// * orchestrations return the result of
    ///   `::azure_functions::durable::orchestrate(invocation_id, __ret, __state)`;
    /// * `async` functions return a boxed `InvocationFuture` that, once the user's
    ///   future resolves, builds a success response and applies output bindings;
    /// * synchronous functions build a success response, apply output bindings
    ///   and return it.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        // Derive the emitted identifier from `name()` so the two cannot drift apart.
        let ident = Ident::new(&self.name(), self.func.sig.ident.span());
        let common_tokens = CommonInvokerTokens {
            func: self.func,
            is_orchestration: self.is_orchestration,
        };
        let output_bindings = OutputBindings {
            func: self.func,
            is_orchestration: self.is_orchestration,
        };

        if self.is_orchestration {
            quote!(
                #[allow(dead_code)]
                fn #ident(
                    __req: ::azure_functions::rpc::InvocationRequest,
                ) -> ::azure_functions::rpc::InvocationResponse {
                    #common_tokens
                    ::azure_functions::durable::orchestrate(
                        __req.invocation_id,
                        __ret,
                        __state,
                    )
                }
            )
            .to_tokens(tokens);
        } else if self.func.sig.asyncness.is_some() {
            let success = success_response_tokens(quote!(__id));
            // The invocation id is bound to `__id` so the `move` closure below
            // captures that local rather than `__req`.
            quote!(
                #[allow(dead_code)]
                fn #ident(
                    __req: ::azure_functions::rpc::InvocationRequest,
                ) -> ::azure_functions::codegen::InvocationFuture {
                    #common_tokens
                    use futures::future::FutureExt;
                    let __id = __req.invocation_id;
                    Box::pin(
                        __ret.then(move |__ret| {
                            let mut __res = #success;
                            #output_bindings
                            ::futures::future::ready(__res)
                        })
                    )
                }
            )
            .to_tokens(tokens);
        } else {
            let success = success_response_tokens(quote!(__req.invocation_id));
            quote!(
                #[allow(dead_code)]
                fn #ident(
                    __req: ::azure_functions::rpc::InvocationRequest,
                ) -> ::azure_functions::rpc::InvocationResponse {
                    #common_tokens
                    let mut __res = #success;
                    #output_bindings
                    __res
                }
            )
            .to_tokens(tokens);
        }
    }
}

/// Builds an `InvocationResponse` struct-literal expression with a `Success`
/// status for the given invocation-id expression. All other fields are defaulted.
fn success_response_tokens(invocation_id: TokenStream) -> TokenStream {
    quote!(::azure_functions::rpc::InvocationResponse {
        invocation_id: #invocation_id,
        result: Some(::azure_functions::rpc::StatusResult {
            status: ::azure_functions::rpc::status_result::Status::Success as i32,
            ..Default::default()
        }),
        ..Default::default()
    })
}