use crate::acl::{AclConfig, AclManager};
use crate::backend::RedisConnectionType;
use crate::base::constants::DEFAULT_QUEUE_NAME;
use crate::error::{Error, Result};
use async_trait::async_trait;
use redis::acl::Rule;
use redis::aio::MultiplexedConnection;
use redis::{AsyncTypedCommands, Client};
/// Manages tenant users through the Redis `ACL` command family.
///
/// Only single-node deployments are supported (see `RedisAclManager::new`).
/// Each operation obtains its own multiplexed async connection from the
/// wrapped [`Client`].
pub struct RedisAclManager {
    /// Client from which a connection is opened for every `ACL` operation.
    client: Client,
}
impl RedisAclManager {
pub async fn new(conn: RedisConnectionType) -> Result<Self> {
match conn {
crate::backend::RedisConnectionType::Single {
connection_info,
#[cfg(feature = "tls")]
tls_certs,
} => {
#[cfg(feature = "tls")]
let client = {
if let Some(tls) = tls_certs {
Client::build_with_tls(connection_info, tls)?
} else {
Client::open(connection_info)?
}
};
#[cfg(not(feature = "tls"))]
let client = Client::open(connection_info)?;
Ok(Self { client })
}
#[cfg(feature = "cluster")]
crate::backend::RedisConnectionType::Cluster(_) => Err(Error::not_supported(
"ACL management is not supported in cluster mode",
)),
#[cfg(feature = "sentinel")]
crate::backend::RedisConnectionType::Sentinel { .. } => Err(Error::not_supported(
"ACL management is not yet supported in sentinel mode",
)),
}
}
pub async fn from_url(url: &str) -> Result<Self> {
let conn = crate::backend::RedisConnectionType::single(url)?;
Self::new(conn).await
}
async fn get_connection(&self) -> Result<MultiplexedConnection> {
let conn = self.client.get_multiplexed_async_connection().await?;
Ok(conn)
}
fn build_acl_rules(&self, config: &AclConfig) -> Vec<Rule> {
let mut rules = vec![
Rule::On,
Rule::ResetChannels,
Rule::AllCommands,
Rule::RemoveCategory("dangerous".to_string()),
Rule::RemoveCategory("admin".to_string()),
Rule::RemoveCommand("keys".to_string()),
Rule::RemoveCommand("info".to_string()),
Rule::RemoveCommand("select".to_string()),
Rule::AddPass(config.node_config.password.clone()),
Rule::Pattern(format!("asynq:{{{}}}:*", DEFAULT_QUEUE_NAME)),
config.node_config.asynq_key_pattern(),
];
rules.extend(AclConfig::default_key_patterns(
&config.node_config.username,
));
for key in &config.write_only_keys {
let formatted_key = if key.starts_with("%W~") {
key.clone()
} else {
format!("%W~{}", key)
};
rules.push(Rule::Other(formatted_key));
}
rules
}
}
#[async_trait]
impl AclManager for RedisAclManager {
async fn create_tenant_user(&self, config: &AclConfig) -> Result<()> {
if !config.enabled {
return Err(Error::config("ACL feature is not enabled"));
}
let mut conn = self.get_connection().await?;
let rules = self.build_acl_rules(config);
Ok(
conn
.acl_setuser_rules(&config.node_config.username, rules.as_slice())
.await
.map_err(Error::Redis)?,
)
}
async fn delete_tenant_user(&self, username: &str) -> Result<()> {
let mut conn = self.get_connection().await?;
let result: redis::RedisResult<usize> = conn.acl_deluser(&[username]).await;
match result {
Ok(deleted) => {
if deleted > 0 {
Ok(())
} else {
Err(Error::other(format!("User '{}' not found", username)))
}
}
Err(e) => Err(Error::Redis(e)),
}
}
async fn list_tenant_users(&self) -> Result<Vec<String>> {
let mut conn = self.get_connection().await?;
let result: redis::RedisResult<Vec<String>> = conn.acl_list().await;
match result {
Ok(users) => Ok(users),
Err(e) => Err(Error::Redis(e)),
}
}
async fn tenant_user_exists(&self, username: &str) -> Result<bool> {
let mut conn = self.get_connection().await?;
Ok(conn.acl_getuser(username).await?.is_some())
}
async fn update_tenant_user(&self, config: &AclConfig) -> Result<()> {
let username = &config.node_config.username;
if self.tenant_user_exists(username).await? {
self.delete_tenant_user(username).await?;
}
self.create_tenant_user(config).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::acl::NodeConfig;

    /// Returns a manager whose client has been built but never connected.
    /// That is enough to exercise the pure `build_acl_rules` logic.
    fn offline_manager() -> RedisAclManager {
        RedisAclManager {
            client: Client::open("redis://127.0.0.1:6379").expect("hard-coded URL must parse"),
        }
    }

    #[test]
    fn test_build_acl_rules() {
        let node_config = NodeConfig::new("localhost:6379", "tenant1", "pass123", 0);
        let acl_config = AclConfig::new(node_config)
            .enable(true)
            .add_write_only_key("custom:result:*");
        assert!(acl_config.enabled);
        assert_eq!(acl_config.node_config.username, "tenant1");
        assert_eq!(acl_config.node_config.password, "pass123");
        assert_eq!(acl_config.node_config.db, 0);
        assert_eq!(acl_config.write_only_keys.len(), 1);

        // Actually exercise the rule builder, not just the config.
        let rules = offline_manager().build_acl_rules(&acl_config);
        assert!(rules.contains(&Rule::On));
        assert!(rules.contains(&Rule::AllCommands));
        assert!(rules.contains(&Rule::RemoveCategory("dangerous".to_string())));
        assert!(rules.contains(&Rule::RemoveCategory("admin".to_string())));
        assert!(rules.contains(&Rule::RemoveCommand("keys".to_string())));
        assert!(rules.contains(&Rule::AddPass("pass123".to_string())));
        assert!(rules.contains(&Rule::Pattern(format!(
            "asynq:{{{}}}:*",
            DEFAULT_QUEUE_NAME
        ))));
        assert!(rules.iter().any(|rule| matches!(
            rule,
            Rule::Other(key) if key.starts_with("%W~") && key.contains("custom:result:*")
        )));
    }

    #[test]
    fn test_build_acl_rules_keeps_existing_write_only_prefix() {
        let node_config = NodeConfig::new("localhost:6379", "tenant1", "pass123", 0);
        let acl_config = AclConfig::new(node_config)
            .enable(true)
            .add_write_only_key("%W~already:*");
        let rules = offline_manager().build_acl_rules(&acl_config);
        assert!(!rules
            .iter()
            .any(|rule| matches!(rule, Rule::Other(key) if key.starts_with("%W~%W~"))));
    }

    #[test]
    fn test_acl_command_generation_disabled() {
        let node_config = NodeConfig::new("localhost:6379", "tenant1", "pass123", 0);
        let acl_config = AclConfig::new(node_config).enable(false);
        assert!(!acl_config.enabled);
    }

    #[test]
    fn test_acl_config_with_custom_patterns() {
        let node_config = NodeConfig::new("localhost:6379", "tenant1", "pass123", 2);
        let acl_config = AclConfig::new(node_config)
            .enable(true)
            .add_write_only_key("myapp:results:*");
        assert_eq!(acl_config.write_only_keys.len(), 1);
        assert_eq!(acl_config.node_config.db, 2);
    }

    #[test]
    fn test_rule_types() {
        let rule_on = Rule::On;
        let rule_add_cat = Rule::AddCategory("all".to_string());
        let rule_remove_cat = Rule::RemoveCategory("dangerous".to_string());
        let rule_pattern = Rule::Pattern("asynq:*".to_string());
        let rule_pass = Rule::AddPass("password".to_string());
        assert_eq!(rule_on, Rule::On);
        assert_eq!(rule_add_cat, Rule::AddCategory("all".to_string()));
        assert_eq!(
            rule_remove_cat,
            Rule::RemoveCategory("dangerous".to_string())
        );
        assert_eq!(rule_pattern, Rule::Pattern("asynq:*".to_string()));
        assert_eq!(rule_pass, Rule::AddPass("password".to_string()));
    }
}