use common::AclCategory;
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
use serde::Deserialize;
use serde_syn::{config, from_stream};
use syn::{
parse,
parse::{Parse, ParseStream},
parse_macro_input, ItemFn,
};
/// Command flags that may be listed in the `flags` field of the macro arguments
/// (see `Args`).
///
/// Each variant is turned into its textual flag name (for example `Write` becomes
/// `"write"` and `DenyOOM` becomes `"deny-oom"`) by the
/// `From<&RedisCommandFlags> for &'static str` implementation in this file.
/// `redis_command` joins those names into one space-separated string.
#[derive(Debug, Deserialize)]
pub enum RedisCommandFlags {
Write,
ReadOnly,
Admin,
DenyOOM,
DenyScript,
AllowLoading,
PubSub,
Random,
AllowStale,
NoMonitor,
NoSlowlog,
Fast,
GetkeysApi,
NoCluster,
NoAuth,
MayReplicate,
NoMandatoryKeys,
Blocking,
AllowBusy,
GetchannelsApi,
}
impl From<&RedisCommandFlags> for &'static str {
fn from(value: &RedisCommandFlags) -> Self {
match value {
RedisCommandFlags::Write => "write",
RedisCommandFlags::ReadOnly => "readonly",
RedisCommandFlags::Admin => "admin",
RedisCommandFlags::DenyOOM => "deny-oom",
RedisCommandFlags::DenyScript => "deny-script",
RedisCommandFlags::AllowLoading => "allow-loading",
RedisCommandFlags::PubSub => "pubsub",
RedisCommandFlags::Random => "random",
RedisCommandFlags::AllowStale => "allow-stale",
RedisCommandFlags::NoMonitor => "no-monitor",
RedisCommandFlags::NoSlowlog => "no-slowlog",
RedisCommandFlags::Fast => "fast",
RedisCommandFlags::GetkeysApi => "getkeys-api",
RedisCommandFlags::NoCluster => "no-cluster",
RedisCommandFlags::NoAuth => "no-auth",
RedisCommandFlags::MayReplicate => "may-replicate",
RedisCommandFlags::NoMandatoryKeys => "no-mandatory-keys",
RedisCommandFlags::Blocking => "blocking",
RedisCommandFlags::AllowBusy => "allow-busy",
RedisCommandFlags::GetchannelsApi => "getchannels-api",
}
}
}
/// Flags that may be listed in the optional `enterprise_flags` field of the macro
/// arguments (see `Args`).
///
/// Converted to text by the `From<&RedisEnterpriseCommandFlags> for &'static str`
/// implementation in this file.
#[derive(Debug, Deserialize)]
pub enum RedisEnterpriseCommandFlags {
ProxyFiltered,
}
impl From<&RedisEnterpriseCommandFlags> for &'static str {
fn from(value: &RedisEnterpriseCommandFlags) -> Self {
match value {
RedisEnterpriseCommandFlags::ProxyFiltered => "_proxy-filtered",
}
}
}
/// Flags that may be listed in the `flags` field of a key specification
/// (see `KeySpecArg`).
///
/// Each variant is turned into an upper-case name (for example `ReadOnly` becomes
/// `"READ_ONLY"`) by the `From<&RedisCommandKeySpecFlags> for &'static str`
/// implementation in this file; the generated code passes that name to
/// `redis_module::commands::KeySpecFlags::try_from`.
#[derive(Debug, Deserialize)]
pub enum RedisCommandKeySpecFlags {
ReadOnly,
ReadWrite,
Overwrite,
Remove,
Access,
Update,
Insert,
Delete,
NotKey,
Incomplete,
VariableFlags,
}
impl From<&RedisCommandKeySpecFlags> for &'static str {
fn from(value: &RedisCommandKeySpecFlags) -> Self {
match value {
RedisCommandKeySpecFlags::ReadOnly => "READ_ONLY",
RedisCommandKeySpecFlags::ReadWrite => "READ_WRITE",
RedisCommandKeySpecFlags::Overwrite => "OVERWRITE",
RedisCommandKeySpecFlags::Remove => "REMOVE",
RedisCommandKeySpecFlags::Access => "ACCESS",
RedisCommandKeySpecFlags::Update => "UPDATE",
RedisCommandKeySpecFlags::Insert => "INSERT",
RedisCommandKeySpecFlags::Delete => "DELETE",
RedisCommandKeySpecFlags::NotKey => "NOT_KEY",
RedisCommandKeySpecFlags::Incomplete => "INCOMPLETE",
RedisCommandKeySpecFlags::VariableFlags => "VARIABLE_FLAGS",
}
}
}
/// Payload of `FindKeys::Range`.
///
/// `redis_command` forwards the three fields, in declaration order, to
/// `redis_module::commands::FindKeys::new_range` in the generated code.
#[derive(Debug, Deserialize)]
pub struct FindKeysRange {
// First argument of `new_range`.
last_key: i32,
// Second argument of `new_range`.
steps: i32,
// Third argument of `new_range`.
limit: i32,
}
/// Payload of `FindKeys::Keynum`.
///
/// `redis_command` forwards the three fields, in declaration order, to
/// `redis_module::commands::FindKeys::new_keys_num` in the generated code.
#[derive(Debug, Deserialize)]
pub struct FindKeysNum {
// First argument of `new_keys_num`.
key_num_idx: i32,
// Second argument of `new_keys_num`.
first_key: i32,
// Third argument of `new_keys_num`.
key_step: i32,
}
/// The `find_keys` part of a key specification (see `KeySpecArg`).
///
/// `redis_command` maps `Range` to `redis_module::commands::FindKeys::new_range`
/// and `Keynum` to `redis_module::commands::FindKeys::new_keys_num`.
#[derive(Debug, Deserialize)]
pub enum FindKeys {
Range(FindKeysRange),
Keynum(FindKeysNum),
}
/// Payload of `BeginSearch::Index`.
///
/// `redis_command` passes `index` to
/// `redis_module::commands::BeginSearch::new_index` in the generated code.
#[derive(Debug, Deserialize)]
pub struct BeginSearchIndex {
index: i32,
}
/// Payload of `BeginSearch::Keyword`.
///
/// `redis_command` passes `keyword` (as an owned `String`) and `startfrom` to
/// `redis_module::commands::BeginSearch::new_keyword` in the generated code.
#[derive(Debug, Deserialize)]
pub struct BeginSearchKeyword {
keyword: String,
startfrom: i32,
}
/// The `begin_search` part of a key specification (see `KeySpecArg`).
///
/// `redis_command` maps `Index` to `redis_module::commands::BeginSearch::new_index`
/// and `Keyword` to `redis_module::commands::BeginSearch::new_keyword`.
#[derive(Debug, Deserialize)]
pub enum BeginSearch {
Index(BeginSearchIndex),
Keyword(BeginSearchKeyword), }
/// One entry of the `key_spec` list in the macro arguments (see `Args`).
///
/// For every entry `redis_command` emits one
/// `redis_module::commands::KeySpec::new(notes, flags, begin_search, find_keys)` call.
#[derive(Debug, Deserialize)]
pub struct KeySpecArg {
// Optional free text; emitted as `Some(..)` when present and `None` otherwise.
notes: Option<String>,
// Each flag is emitted as `redis_module::commands::KeySpecFlags::try_from(<name>)?`.
flags: Vec<RedisCommandKeySpecFlags>,
begin_search: BeginSearch,
find_keys: FindKeys,
}
/// Kind of a command argument (the `arg_type` field of `CommandArg`).
///
/// The `From<CommandArgType> for u32` implementation in this file assigns each
/// variant a numeric code (`String` = 0 through `Block` = 8), and that code is
/// what `generate_command_arg` writes into the generated code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
pub enum CommandArgType {
String,
Integer,
Double,
Key,
Pattern,
UnixTime,
PureToken,
OneOf,
Block,
}
impl From<CommandArgType> for u32 {
fn from(arg_type: CommandArgType) -> Self {
match arg_type {
CommandArgType::String => 0,
CommandArgType::Integer => 1,
CommandArgType::Double => 2,
CommandArgType::Key => 3,
CommandArgType::Pattern => 4,
CommandArgType::UnixTime => 5,
CommandArgType::PureToken => 6,
CommandArgType::OneOf => 7,
CommandArgType::Block => 8,
}
}
}
/// Flags that may be listed in the optional `flags` field of a command argument
/// (see `CommandArg`).
///
/// Each variant is turned into an upper-case name (for example `MultipleToken`
/// becomes `"MULTIPLE_TOKEN"`) by the `From<&CommandArgFlags> for &'static str`
/// implementation in this file; the generated code passes that name to
/// `redis_module::commands::CommandArgFlags::try_from`.
#[derive(Debug, Clone, Deserialize)]
pub enum CommandArgFlags {
None,
Optional,
Multiple,
MultipleToken,
}
impl From<&CommandArgFlags> for &'static str {
fn from(value: &CommandArgFlags) -> Self {
match value {
CommandArgFlags::None => "NONE",
CommandArgFlags::Optional => "OPTIONAL",
CommandArgFlags::Multiple => "MULTIPLE",
CommandArgFlags::MultipleToken => "MULTIPLE_TOKEN",
}
}
}
/// Description of one command argument, as written in the optional `args` list of
/// the macro arguments (see `Args`).
///
/// `generate_command_arg` turns a value of this type into a
/// `redis_module::commands::RedisModuleCommandArg::new(..)` call, passing the
/// fields in the order they are declared here except that `subargs` is passed
/// before `display_text`... see that function for the exact argument order.
#[derive(Debug, Clone, Deserialize)]
pub struct CommandArg {
/// Argument name; emitted as an owned `String`.
pub name: String,
/// Argument kind; emitted as its `u32` code.
pub arg_type: CommandArgType,
/// Emitted as `Some(index)` when present and `None` otherwise.
pub key_spec_index: Option<u32>,
/// Emitted as `Some(String)` when present and `None` otherwise.
pub token: Option<String>,
/// Emitted as `Some(String)` when present and `None` otherwise.
pub summary: Option<String>,
/// Emitted as `Some(String)` when present and `None` otherwise.
pub since: Option<String>,
/// Argument flags; a missing list is treated like an empty one.
pub flags: Option<Vec<CommandArgFlags>>,
/// Emitted as `Some(String)` when present and `None` otherwise.
pub deprecated_since: Option<String>,
/// Nested arguments, generated recursively; `None` when absent.
pub subargs: Option<Vec<CommandArg>>,
/// Emitted as `Some(String)` when present and `None` otherwise.
pub display_text: Option<String>,
}
/// Arguments of the command attribute, deserialized from the attribute's token
/// stream (see the `Parse` implementation for `Args` in this file) and consumed by
/// `redis_command`.
#[derive(Debug, Deserialize)]
struct Args {
// Command name; `redis_command` falls back to the annotated function's
// identifier when this is absent.
name: Option<String>,
// Command flags; joined into one space-separated string.
flags: Vec<RedisCommandFlags>,
// Enterprise flags; joined like `flags`, or an empty string when absent.
enterprise_flags: Option<Vec<RedisEnterpriseCommandFlags>>,
// The next four fields are forwarded as `Some(String)` / `None`.
summary: Option<String>,
complexity: Option<String>,
since: Option<String>,
tips: Option<String>,
// Forwarded unchanged to `redis_module::commands::CommandInfo::new`.
arity: i64,
// One `redis_module::commands::KeySpec` is generated per entry.
key_spec: Vec<KeySpecArg>,
// Optional argument descriptions; a missing list yields an empty vector.
args: Option<Vec<CommandArg>>,
// Optional ACL categories; each is converted with `String::from`.
acl_categories: Option<Vec<AclCategory>>,
}
/// Lets `parse_macro_input!(attr as Args)` build an `Args` from attribute tokens.
impl Parse for Args {
/// Deserializes `input` into `Args` with `serde_syn::from_stream`, using the
/// `config::JSONY` configuration. Any error from `from_stream` is returned as is.
fn parse(input: ParseStream) -> parse::Result<Self> {
from_stream(config::JSONY, input)
}
}
/// Renders an optional string as the tokens of an `Option<String>` expression.
///
/// `Some(text)` becomes `Some("text".to_owned())` and `None` becomes `None`.
fn to_token_stream(s: Option<String>) -> proc_macro2::TokenStream {
    match s {
        Some(text) => quote! { Some(#text.to_owned()) },
        None => quote! { None },
    }
}
fn generate_command_arg(arg: &CommandArg) -> proc_macro2::TokenStream {
let name = &arg.name;
let arg_type: u32 = arg.arg_type.into();
let key_spec_index = arg
.key_spec_index
.map(|v| quote! {Some(#v)})
.unwrap_or(quote! {None});
let token = to_token_stream(arg.token.clone());
let summary = to_token_stream(arg.summary.clone());
let since = to_token_stream(arg.since.clone());
let flags: Vec<&'static str> = arg
.flags
.as_ref()
.map(|v| v.iter().map(|v| v.into()).collect())
.unwrap_or_default();
let flags = quote! {
vec![#(redis_module::commands::CommandArgFlags::try_from(#flags)?, )*]
};
let deprecated_since = to_token_stream(arg.deprecated_since.clone());
let display_text = to_token_stream(arg.display_text.clone());
let subargs = if let Some(subargs_vec) = &arg.subargs {
let subargs_tokens: Vec<_> = subargs_vec.iter().map(generate_command_arg).collect();
quote! {
Some(vec![#(#subargs_tokens),*])
}
} else {
quote! { None }
};
quote! {
redis_module::commands::RedisModuleCommandArg::new(
#name.to_owned(),
#arg_type,
#key_spec_index,
#token,
#summary,
#since,
#flags.into(),
#deprecated_since,
#subargs,
#display_text,
)
}
}
pub(crate) fn redis_command(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as Args);
let func: ItemFn = match syn::parse(item) {
Ok(res) => res,
Err(e) => return e.to_compile_error().into(),
};
let original_function_name = func.sig.ident.clone();
let c_function_name = Ident::new(&format!("_inner_{}", func.sig.ident), func.sig.ident.span());
let get_command_info_function_name = Ident::new(
&format!("_inner_get_command_info_{}", func.sig.ident),
func.sig.ident.span(),
);
let name_literal = args
.name
.unwrap_or_else(|| original_function_name.to_string());
let flags_str = args
.flags
.into_iter()
.fold(String::new(), |s, v| {
format!("{} {}", s, Into::<&'static str>::into(&v))
})
.trim()
.to_owned();
let flags_literal = quote!(#flags_str);
let enterprise_flags_str = args
.enterprise_flags
.map(|v| {
v.into_iter()
.fold(String::new(), |s, v| {
format!("{} {}", s, Into::<&'static str>::into(&v))
})
.trim()
.to_owned()
})
.unwrap_or_default();
let enterprise_flags_literal = quote!(#enterprise_flags_str);
let summary_literal = to_token_stream(args.summary);
let complexity_literal = to_token_stream(args.complexity);
let since_literal = to_token_stream(args.since);
let tips_literal = to_token_stream(args.tips);
let arity_literal = args.arity;
let key_spec_notes: Vec<_> = args
.key_spec
.iter()
.map(|v| {
v.notes
.as_ref()
.map(|v| quote! {Some(#v.to_owned())})
.unwrap_or(quote! {None})
})
.collect();
let key_spec_flags: Vec<_> = args
.key_spec
.iter()
.map(|v| {
let flags: Vec<&'static str> = v.flags.iter().map(|v| v.into()).collect();
quote! {
vec![#(redis_module::commands::KeySpecFlags::try_from(#flags)?, )*]
}
})
.collect();
let key_spec_begin_search: Vec<_> = args
.key_spec
.iter()
.map(|v| match &v.begin_search {
BeginSearch::Index(i) => {
let i = i.index;
quote! {
redis_module::commands::BeginSearch::new_index(#i)
}
}
BeginSearch::Keyword(begin_search_keyword) => {
let k = begin_search_keyword.keyword.as_str();
let i = begin_search_keyword.startfrom;
quote! {
redis_module::commands::BeginSearch::new_keyword(#k.to_owned(), #i)
}
}
})
.collect();
let key_spec_find_keys: Vec<_> = args
.key_spec
.iter()
.map(|v| match &v.find_keys {
FindKeys::Keynum(find_keys_num) => {
let keynumidx = find_keys_num.key_num_idx;
let firstkey = find_keys_num.first_key;
let keystep = find_keys_num.key_step;
quote! {
redis_module::commands::FindKeys::new_keys_num(#keynumidx, #firstkey, #keystep)
}
}
FindKeys::Range(find_keys_range) => {
let last_key = find_keys_range.last_key;
let steps = find_keys_range.steps;
let limit = find_keys_range.limit;
quote! {
redis_module::commands::FindKeys::new_range(#last_key, #steps, #limit)
}
}
})
.collect();
let command_args: Vec<_> = args
.args
.as_ref()
.map(|v| v.iter().map(generate_command_arg).collect())
.unwrap_or_default();
let acl_categories = args
.acl_categories
.map(|v| v.into_iter().map(String::from).collect::<Vec<_>>());
let acl_categories_tokens = if let Some(categories) = &acl_categories {
quote! {
Some(vec![#(#categories.to_owned()),*])
}
} else {
quote! { None }
};
let gen = quote! {
#func
extern "C" fn #c_function_name(
ctx: *mut redis_module::raw::RedisModuleCtx,
argv: *mut *mut redis_module::raw::RedisModuleString,
argc: i32,
) -> i32 {
let context = redis_module::Context::new(ctx);
let args = redis_module::decode_args(ctx, argv, argc);
let response = #original_function_name(&context, args);
context.reply(response.map(|v| v.into())) as i32
}
#[linkme::distributed_slice(redis_module::commands::COMMANDS_LIST)]
fn #get_command_info_function_name() -> Result<redis_module::commands::CommandInfo, redis_module::RedisError> {
let key_spec = vec![
#(
redis_module::commands::KeySpec::new(
#key_spec_notes,
#key_spec_flags.into(),
#key_spec_begin_search,
#key_spec_find_keys,
),
)*
];
let command_args = vec![#(#command_args),*];
Ok(redis_module::commands::CommandInfo::new(
#name_literal.to_owned(),
Some(#flags_literal.to_owned()),
Some(#enterprise_flags_literal.to_owned()),
#summary_literal,
#complexity_literal,
#since_literal,
#tips_literal,
#arity_literal,
key_spec,
#c_function_name,
command_args,
#acl_categories_tokens,
))
}
};
gen.into()
}