use std::fmt;
use std::os::raw::c_int;
use std::thread::sleep;
use valkey_module::alloc::ValkeyAlloc;
use valkey_module::logging::log_notice;
use valkey_module::{
valkey_module, Context, Status, ValkeyError, ValkeyString, AUTH_HANDLED, AUTH_NOT_HANDLED,
};
/// Decision produced by the credential-matching step of the blocking auth
/// flow and consumed later by the auth reply callbacks.
#[derive(Clone, Copy, Debug)]
enum AuthResult {
    /// The reply callback authenticates the client as the ACL user.
    Allow,
    /// The reply callback answers with a `DENIED` error.
    Deny,
    /// The reply callback returns `AUTH_NOT_HANDLED`.
    Next,
}
/// Private data attached to a blocked client by the background auth thread
/// and read back by the auth reply callbacks.
#[derive(Debug)]
struct AuthPrivData {
/// The decision computed for the blocked client's credentials.
result: AuthResult,
}
impl fmt::Display for AuthResult {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}
/// Shared body of the non-blocking auth callbacks.
///
/// Compares the string form of `username` / `password` against
/// `expected_user` / `expected_pass` and logs each step, prefixing the log
/// lines with `callback_name`.
///
/// # Returns
/// * `Ok(AUTH_NOT_HANDLED)` when `username` is not `expected_user`.
/// * `Ok(AUTH_HANDLED)` when both values match and
///   `authenticate_client_with_acl_user` reports `Status::Ok`.
///
/// # Errors
/// * A `DENIED: ...` error when the user matches but the password does not.
///   The text is chosen by whether `callback_name` contains `auth_one` or
///   `auth_two`, with a generic fallback otherwise.
/// * `"Failed to authenticate with ACL"` when the ACL call reports
///   `Status::Err`.
fn authenticate_user_common(
    ctx: &Context,
    username: &ValkeyString,
    password: &ValkeyString,
    expected_user: &str,
    expected_pass: &str,
    callback_name: &str,
) -> Result<c_int, ValkeyError> {
    ctx.log_notice(&format!(
        "{} auth attempt for user: {}",
        callback_name, username
    ));

    // Guard 1: a different user is not this callback's concern.
    if username.to_string() != expected_user {
        ctx.log_notice(&format!(
            "{}: auth not handled for user: {}",
            callback_name, username
        ));
        return Ok(AUTH_NOT_HANDLED);
    }

    // Guard 2: known user, wrong password -> explicit denial.
    if password.to_string() != expected_pass {
        ctx.log_notice(&format!(
            "{}: Auth explicitly denied for user: {}",
            callback_name, username
        ));
        let denial = if callback_name.contains("auth_one") {
            "DENIED: Authentication credentials mismatch in auth_one"
        } else if callback_name.contains("auth_two") {
            "DENIED: Authentication credentials mismatch in auth_two"
        } else {
            "DENIED: Authentication credentials mismatch for the user"
        };
        return Err(ValkeyError::Str(denial));
    }

    ctx.log_notice(&format!(
        "{}: Matched user: {} with credentials",
        callback_name, username
    ));

    let acl_status = ctx.authenticate_client_with_acl_user(username);
    match acl_status {
        Status::Err => {
            ctx.log_warning(&format!(
                "{}: Failed to authenticate user: {} with ACL",
                callback_name, username
            ));
            Err(ValkeyError::Str("Failed to authenticate with ACL"))
        }
        Status::Ok => {
            ctx.log_notice(&format!(
                "{}: Successfully authenticated user: {}",
                callback_name, username
            ));
            Ok(AUTH_HANDLED)
        }
    }
}
/// Non-blocking auth callback for `user1` / `module_pass1`.
///
/// Delegates to `authenticate_user_common` under the callback name
/// `auth_one` and returns its result unchanged.
fn auth_callback_one(
    ctx: &Context,
    username: ValkeyString,
    password: ValkeyString,
) -> Result<c_int, ValkeyError> {
    const EXPECTED_USER: &str = "user1";
    const EXPECTED_PASS: &str = "module_pass1";
    const CALLBACK_NAME: &str = "auth_one";

    authenticate_user_common(
        ctx,
        &username,
        &password,
        EXPECTED_USER,
        EXPECTED_PASS,
        CALLBACK_NAME,
    )
}
/// Non-blocking auth callback for `user2` / `module_pass2`.
///
/// Delegates to `authenticate_user_common` under the callback name
/// `auth_two` and returns its result unchanged.
fn auth_callback_two(
    ctx: &Context,
    username: ValkeyString,
    password: ValkeyString,
) -> Result<c_int, ValkeyError> {
    const EXPECTED_USER: &str = "user2";
    const EXPECTED_PASS: &str = "module_pass2";
    const CALLBACK_NAME: &str = "auth_two";

    authenticate_user_common(
        ctx,
        &username,
        &password,
        EXPECTED_USER,
        EXPECTED_PASS,
        CALLBACK_NAME,
    )
}
/// Shared body of the blocking-auth reply callbacks.
///
/// Reads the decision stored in `priv_data` (treating a missing value as
/// `AuthResult::Next`) and acts on it, prefixing log lines with
/// `callback_name`. The password argument is not used.
///
/// # Returns
/// * `Ok(AUTH_NOT_HANDLED)` for `AuthResult::Next`.
/// * `Ok(AUTH_HANDLED)` for `AuthResult::Allow` when
///   `authenticate_client_with_acl_user` reports `Status::Ok`.
///
/// # Errors
/// * A `DENIED: ...` error for `AuthResult::Deny`. The text is chosen by
///   whether `callback_name` contains `blocked_auth_reply_one` or
///   `blocked_auth_reply_two`, with a generic fallback otherwise.
/// * `"Failed to authenticate with ACL"` for `AuthResult::Allow` when the ACL
///   call reports `Status::Err`.
fn auth_reply_callback_common(
    ctx: &Context,
    username: ValkeyString,
    _password: ValkeyString,
    priv_data: Option<&AuthPrivData>,
    callback_name: &str,
) -> Result<c_int, ValkeyError> {
    // No private data means no decision was recorded: defer to the next handler.
    let decision = priv_data.map_or(AuthResult::Next, |data| data.result);

    match decision {
        AuthResult::Next => {
            ctx.log_notice(&format!(
                "{}: Passing auth to next handler for user: {}",
                callback_name, username
            ));
            Ok(AUTH_NOT_HANDLED)
        }
        AuthResult::Deny => {
            ctx.log_notice(&format!(
                "{}: Auth explicitly denied for user: {}",
                callback_name, username
            ));
            let denial = if callback_name.contains("blocked_auth_reply_one") {
                "DENIED: Authentication credentials mismatch in blocked_auth_reply_one"
            } else if callback_name.contains("blocked_auth_reply_two") {
                "DENIED: Authentication credentials mismatch in blocked_auth_reply_two"
            } else {
                "DENIED: Authentication credentials mismatch for the user"
            };
            Err(ValkeyError::Str(denial))
        }
        AuthResult::Allow => {
            ctx.log_notice(&format!(
                "{}: Auth allowed for user: {}",
                callback_name, username
            ));
            let acl_status = ctx.authenticate_client_with_acl_user(&username);
            match acl_status {
                Status::Err => {
                    ctx.log_warning(&format!(
                        "{}: Failed to authenticate user: {} with ACL",
                        callback_name, username
                    ));
                    Err(ValkeyError::Str("Failed to authenticate with ACL"))
                }
                Status::Ok => {
                    ctx.log_notice(&format!(
                        "{}: Successfully authenticated user: {}",
                        callback_name, username
                    ));
                    Ok(AUTH_HANDLED)
                }
            }
        }
    }
}
/// Reply callback paired with `blocking_auth_callback_one`.
///
/// Forwards every argument to `auth_reply_callback_common` under the name
/// `blocked_auth_reply_one`.
fn auth_reply_callback_one(
    ctx: &Context,
    username: ValkeyString,
    password: ValkeyString,
    priv_data: Option<&AuthPrivData>,
) -> Result<c_int, ValkeyError> {
    const CALLBACK_NAME: &str = "blocked_auth_reply_one";
    auth_reply_callback_common(ctx, username, password, priv_data, CALLBACK_NAME)
}
/// Reply callback paired with `blocking_auth_callback_two`.
///
/// Forwards every argument to `auth_reply_callback_common` under the name
/// `blocked_auth_reply_two`.
fn auth_reply_callback_two(
    ctx: &Context,
    username: ValkeyString,
    password: ValkeyString,
    priv_data: Option<&AuthPrivData>,
) -> Result<c_int, ValkeyError> {
    const CALLBACK_NAME: &str = "blocked_auth_reply_two";
    auth_reply_callback_common(ctx, username, password, priv_data, CALLBACK_NAME)
}
/// Free callback paired with `blocking_auth_callback_one`.
///
/// Logs the decision held in `data`, then drops the value.
fn free_privdata_callback_one(ctx: &Context, data: AuthPrivData) {
    let message = format!(
        "free_privdata_callback_one: Cleaning up: {}",
        data.result
    );
    ctx.log_notice(&message);
    // Ownership was handed to this callback; release it explicitly.
    drop(data);
}
/// Free callback paired with `blocking_auth_callback_two`.
///
/// Logs the decision held in `data`, then drops the value.
fn free_privdata_callback_two(ctx: &Context, data: AuthPrivData) {
    let message = format!(
        "free_privdata_callback_two: Cleaning up: {}",
        data.result
    );
    ctx.log_notice(&message);
    // Ownership was handed to this callback; release it explicitly.
    drop(data);
}
/// Shared body of the blocking auth callbacks.
///
/// Steps, in order:
/// 1. A username of `default` is not handled: returns `Ok(AUTH_NOT_HANDLED)`
///    without blocking the client.
/// 2. Otherwise the client is blocked via `ctx.block_client_on_auth`, passing
///    `auth_reply_fn` and `free_callback_fn`.
/// 3. Only when `callback_name` is `blocking_auth_callback_two` and the
///    credentials are `blockAbort` / `abort`: the blocked client is aborted,
///    an error reply is sent, and `Ok(AUTH_HANDLED)` is returned.
/// 4. Otherwise a thread is spawned that evaluates `auth_patterns` on the
///    credentials and stores the outcome as `AuthPrivData` on the blocked
///    client. For `blocking_auth_callback_two` with user `blockUserDelay` it
///    sleeps two seconds first. If storing fails, the failure is logged and
///    the blocked client is aborted.
///
/// The function itself returns `Ok(AUTH_HANDLED)` right after spawning the
/// thread; it never returns `Err`.
///
/// NOTE(review): the blocked-client handle is dropped when the thread closure
/// ends; presumably that is what unblocks the client and triggers the reply
/// callback. Confirm against the `valkey_module` crate.
fn blocking_auth_common(
    ctx: &Context,
    username: ValkeyString,
    password: ValkeyString,
    auth_reply_fn: fn(
        &Context,
        ValkeyString,
        ValkeyString,
        Option<&AuthPrivData>,
    ) -> Result<c_int, ValkeyError>,
    free_callback_fn: Option<fn(&Context, AuthPrivData)>,
    auth_patterns: impl Fn(&str, &str) -> AuthResult + Send + 'static,
    callback_name: &str,
) -> Result<c_int, ValkeyError> {
    ctx.log_notice(&format!("{}: handling blocked client", callback_name));

    // Owned copies: the worker thread below must not borrow from the caller.
    let user = username.to_string();
    if user == "default" {
        ctx.log_notice(&format!(
            "{}: Default user authentication - passing to next handler",
            callback_name
        ));
        return Ok(AUTH_NOT_HANDLED);
    }
    let pass = password.to_string();
    let handler = callback_name.to_string();
    let is_second_handler = callback_name == "blocking_auth_callback_two";

    let mut blocked_client = ctx.block_client_on_auth(auth_reply_fn, free_callback_fn);

    // Abort path, exercised only by the second blocking handler.
    if is_second_handler && user == "blockAbort" && pass == "abort" {
        if let Err(e) = blocked_client.abort() {
            ctx.log_notice(&format!("Failed to abort blocked client: {:?}", e));
        }
        ctx.reply_error_string("ERR ABORT: Authentication aborted by server");
        return Ok(AUTH_HANDLED);
    }

    std::thread::spawn(move || {
        // Artificial delay for one specific user of the second handler.
        if is_second_handler && user == "blockUserDelay" {
            sleep(std::time::Duration::from_secs(2));
        }

        let result = auth_patterns(&user, &pass);
        let private_data = AuthPrivData { result };
        if let Err(e) = blocked_client.set_blocked_private_data(private_data) {
            log_notice(&format!(
                "{}: Failed to set private data: {}",
                handler, e
            ));
            if let Err(abort_err) = blocked_client.abort() {
                log_notice(&format!(
                    "{}: Failed to abort blocked client: {}",
                    handler, abort_err
                ));
            }
        }
    });

    Ok(AUTH_HANDLED)
}
fn blocking_auth_callback_one(
ctx: &Context,
username: ValkeyString,
password: ValkeyString,
) -> Result<c_int, ValkeyError> {
blocking_auth_common(
ctx,
username,
password,
auth_reply_callback_one,
Some(free_privdata_callback_one),
|user, pass| match (user, pass) {
("blockUser1", "module_blockPass1") => AuthResult::Allow,
("blockUser2", "module_blockPass2") => AuthResult::Allow,
("blockUser3", "module_blockPass3") => AuthResult::Allow,
("blockUser1", _) | ("blockUser2", _) | ("blockUser3", _) => AuthResult::Deny,
_ => AuthResult::Next,
},
"blocking_auth_callback_one",
)
}
fn blocking_auth_callback_two(
ctx: &Context,
username: ValkeyString,
password: ValkeyString,
) -> Result<c_int, ValkeyError> {
blocking_auth_common(
ctx,
username,
password,
auth_reply_callback_two,
Some(free_privdata_callback_two),
|user, pass| match (user, pass) {
("blockUser4", "module_blockPass4") => AuthResult::Allow,
("blockUser5", "module_blockPass5") => AuthResult::Allow,
("blockUser6", "module_blockPass6") => AuthResult::Allow,
("blockUserDelay", "blockPassDelay") => AuthResult::Allow,
("blockUser4", _) | ("blockUser5", _) | ("blockUser6", _) | ("blockUserDelay", _) => {
AuthResult::Deny
}
_ => AuthResult::Next,
},
"blocking_auth_callback_two",
)
}
// Module declaration: a module named "auth", version 1, using `ValkeyAlloc`
// as its allocator. It declares no data types and no commands; it only lists
// the four auth callbacks defined in this file (two blocking, two
// non-blocking).
//
// NOTE(review): the order of the `auth` list presumably determines which
// callback is offered a login attempt first; confirm against the server's
// auth-callback registration semantics before reordering.
valkey_module! {
name: "auth",
version: 1,
allocator: (ValkeyAlloc, ValkeyAlloc),
data_types: [],
auth: [
blocking_auth_callback_two,
blocking_auth_callback_one,
auth_callback_two,
auth_callback_one
],
commands: []
}