use redis::{AsyncCommands, Script};
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use tokio::sync::Mutex;
#[cfg(feature = "rbac")]
use crate::models::RTokenInfo;
#[cfg(feature = "actix")]
use actix_web::web;
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Falls back to `0` if the clock reports a time before the epoch, because a
/// negative timestamp cannot be represented as `u64`.
fn now_ms_u64() -> u64 {
    let millis = chrono::Utc::now().timestamp_millis();
    match u64::try_from(millis) {
        Ok(ms) => ms,
        Err(_) => 0,
    }
}
/// Returns the absolute expiry instant `now_ms + ttl_seconds * 1000`, in milliseconds.
///
/// Saturates at `u64::MAX` instead of overflowing when either the conversion to
/// milliseconds or the addition would exceed the `u64` range.
fn add_ttl_ms(now_ms: u64, ttl_seconds: u64) -> u64 {
    ttl_seconds
        .checked_mul(1000)
        .and_then(|ttl_ms| now_ms.checked_add(ttl_ms))
        .unwrap_or(u64::MAX)
}
/// Redis-backed token store.
///
/// Each token is stored under the Redis key `prefix + token`. Cloning the
/// manager is cheap: clones share the same connection pool and round-robin
/// counter through `Arc`.
#[derive(Clone)]
pub struct RTokenRedisManager {
    /// Key namespace prepended to every token; the constructors make sure it
    /// ends with `':'`.
    prefix: String,
    /// Pool of Redis connection managers, each behind an async mutex so one
    /// command sequence has exclusive use of a connection while it runs.
    connections: Arc<Vec<Mutex<redis::aio::ConnectionManager>>>,
    /// Counter bumped on every checkout; taken modulo the pool size to pick the
    /// next connection round-robin.
    next_index: Arc<AtomicUsize>,
}
/// Authenticated user resolved from a request whose token is present in Redis.
#[cfg(any(feature = "actix", feature = "axum"))]
#[derive(Debug)]
pub struct RRedisUser {
    /// User id stored for the token.
    pub id: String,
    /// The raw token extracted from the request.
    pub token: String,
    /// Roles attached to the token (only with the `rbac` feature).
    #[cfg(feature = "rbac")]
    pub roles: Vec<String>,
}
#[cfg(feature = "actix")]
impl actix_web::FromRequest for RRedisUser {
    type Error = actix_web::Error;
    type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self, Self::Error>>>>;

    /// Authenticates the request against the Redis token store.
    ///
    /// Fails with 500 when no `web::Data<RTokenRedisManager>` is registered or
    /// when Redis returns an error. Fails with 401 when the request carries no
    /// token or the token is unknown.
    fn from_request(
        req: &actix_web::HttpRequest,
        _payload: &mut actix_web::dev::Payload,
    ) -> Self::Future {
        let Some(manager) = req.app_data::<web::Data<RTokenRedisManager>>().cloned() else {
            return Box::pin(async {
                Err(actix_web::error::ErrorInternalServerError(
                    "Token manager not found",
                ))
            });
        };
        // Extract synchronously: the borrowed request cannot move into the future.
        let maybe_token = crate::extract_token_from_request(req);

        Box::pin(async move {
            let Some(token) = maybe_token else {
                return Err(actix_web::error::ErrorUnauthorized("Unauthorized"));
            };

            #[cfg(feature = "rbac")]
            let user = manager
                .validate_with_roles(&token)
                .await
                .map_err(|_| actix_web::error::ErrorInternalServerError("Redis error"))?
                .map(|(id, roles)| Self { id, token, roles });

            #[cfg(not(feature = "rbac"))]
            let user = manager
                .validate(&token)
                .await
                .map_err(|_| actix_web::error::ErrorInternalServerError("Redis error"))?
                .map(|id| Self { id, token });

            user.ok_or_else(|| actix_web::error::ErrorUnauthorized("Invalid token"))
        })
    }
}
impl RTokenRedisManager {
async fn lock_connection(
&self,
) -> Result<tokio::sync::MutexGuard<'_, redis::aio::ConnectionManager>, redis::RedisError> {
let len = self.connections.len();
if len == 0 {
return Err(redis::RedisError::from((
redis::ErrorKind::Client,
"no redis connections",
)));
}
let index = self.next_index.fetch_add(1, Ordering::Relaxed) % len;
match self.connections.get(index) {
Some(conn) => Ok(conn.lock().await),
None => Err(redis::RedisError::from((
redis::ErrorKind::Client,
"no redis connections",
))),
}
}
pub fn new(prefix: impl Into<String>, connection: redis::aio::ConnectionManager) -> Self {
let mut prefix = prefix.into();
if !prefix.ends_with(':') {
prefix.push(':');
}
Self {
prefix,
connections: Arc::new(vec![Mutex::new(connection)]),
next_index: Arc::new(AtomicUsize::new(0)),
}
}
#[cfg(feature = "rbac")]
pub async fn login_with_roles(
&self,
user_id: &str,
ttl_seconds: u64,
roles: impl Into<Vec<String>>,
) -> Result<String, redis::RedisError> {
let token = uuid::Uuid::new_v4().to_string();
let key = self.key(&token);
let mut connection = self.lock_connection().await?;
let expire_at = add_ttl_ms(now_ms_u64(), ttl_seconds);
let info = RTokenInfo {
user_id: user_id.to_string(),
expire_at,
roles: roles.into(),
};
let value = serde_json::to_string(&info).map_err(|e| {
redis::RedisError::from((
redis::ErrorKind::Client,
"serialize token info",
e.to_string(),
))
})?;
let _: () = connection.set_ex(key, value, ttl_seconds).await?;
Ok(token)
}
#[cfg(feature = "rbac")]
pub async fn get_roles(&self, token: &str) -> Result<Option<Vec<String>>, redis::RedisError> {
Ok(self
.validate_with_roles(token)
.await?
.map(|(_user_id, roles)| roles))
}
#[cfg(feature = "rbac")]
pub async fn set_roles(
&self,
token: &str,
roles: impl Into<Vec<String>>,
) -> Result<(), redis::RedisError> {
let key = self.key(token);
let mut connection = self.lock_connection().await?;
let ttl_seconds: i64 = connection.ttl(&key).await?;
if ttl_seconds == -2 {
return Ok(());
}
let value: Option<String> = connection.get(&key).await?;
let Some(value) = value else {
return Ok(());
};
let mut info = serde_json::from_str::<RTokenInfo>(&value).unwrap_or(RTokenInfo {
user_id: value,
expire_at: 0,
roles: Vec::new(),
});
info.roles = roles.into();
let new_value = serde_json::to_string(&info).map_err(|e| {
redis::RedisError::from((
redis::ErrorKind::Client,
"serialize token info",
e.to_string(),
))
})?;
match ttl_seconds {
ttl if ttl >= 0 => {
let _: () = connection.set_ex(key, new_value, ttl as u64).await?;
}
_ => {
let _: () = connection.set(key, new_value).await?;
}
}
Ok(())
}
#[cfg(feature = "rbac")]
pub async fn validate_with_roles(
&self,
token: &str,
) -> Result<Option<(String, Vec<String>)>, redis::RedisError> {
let key = self.key(token);
let mut connection = self.lock_connection().await?;
let value: Option<String> = connection.get(key).await?;
let Some(value) = value else {
return Ok(None);
};
let info = serde_json::from_str::<RTokenInfo>(&value).unwrap_or(RTokenInfo {
user_id: value,
expire_at: 0,
roles: Vec::new(),
});
Ok(Some((info.user_id, info.roles)))
}
#[cfg(feature = "rbac")]
pub async fn validate(&self, token: &str) -> Result<Option<String>, redis::RedisError> {
Ok(self
.validate_with_roles(token)
.await?
.map(|(user_id, _roles)| user_id))
}
pub async fn connect(
redis_url: &str,
prefix: impl Into<String>,
) -> Result<Self, redis::RedisError> {
let client = redis::Client::open(redis_url)?;
let mut connections = Vec::with_capacity(4);
for _ in 0..4 {
connections.push(Mutex::new(client.get_connection_manager().await?));
}
Ok(Self {
prefix: {
let mut prefix = prefix.into();
if !prefix.ends_with(':') {
prefix.push(':');
}
prefix
},
connections: Arc::new(connections),
next_index: Arc::new(AtomicUsize::new(0)),
})
}
pub async fn ttl_seconds(&self, token: &str) -> Result<Option<i64>, redis::RedisError> {
let key = self.key(token);
let mut connection = self.lock_connection().await?;
let ttl: i64 = connection.ttl(key).await?;
if ttl == -2 {
return Ok(None);
}
Ok(Some(ttl))
}
pub async fn renew(&self, token: &str, ttl_seconds: u64) -> Result<bool, redis::RedisError> {
let key = self.key(token);
let mut connection = self.lock_connection().await?;
let seconds = i64::try_from(ttl_seconds).unwrap_or(i64::MAX);
let updated: bool = connection.expire(key, seconds).await?;
Ok(updated)
}
pub async fn rotate(
&self,
token: &str,
ttl_seconds: u64,
) -> Result<Option<String>, redis::RedisError> {
let old_key = self.key(token);
let mut connection = self.lock_connection().await?;
let mut raw_value: Option<String> = connection.get(&old_key).await?;
if raw_value.is_none() {
return Ok(None);
}
let new_token = uuid::Uuid::new_v4().to_string();
let new_key = self.key(&new_token);
let script = Script::new(
r#"
local old_key = KEYS[1]
local new_key = KEYS[2]
local ttl = tonumber(ARGV[1])
local expected = ARGV[2]
local new_value = ARGV[3]
local cur = redis.call('GET', old_key)
if (not cur) or (cur ~= expected) then
return 0
end
redis.call('SETEX', new_key, ttl, new_value)
redis.call('DEL', old_key)
return 1
"#,
);
for _ in 0..2 {
let Some(current_value) = raw_value.as_ref() else {
return Ok(None);
};
#[cfg(feature = "rbac")]
let new_value = {
let expire_at = add_ttl_ms(now_ms_u64(), ttl_seconds);
match serde_json::from_str::<RTokenInfo>(current_value) {
Ok(mut info) => {
info.expire_at = expire_at;
serde_json::to_string(&info).map_err(|e| {
redis::RedisError::from((
redis::ErrorKind::Client,
"serialize token info",
e.to_string(),
))
})?
}
Err(_) => current_value.clone(),
}
};
#[cfg(not(feature = "rbac"))]
let new_value = current_value.clone();
let ok: i32 = script
.key(&old_key)
.key(&new_key)
.arg(ttl_seconds)
.arg(current_value)
.arg(&new_value)
.invoke_async(&mut *connection)
.await?;
if ok == 1 {
return Ok(Some(new_token));
}
raw_value = connection.get(&old_key).await?;
}
Ok(None)
}
fn key(&self, token: &str) -> String {
format!("{}{}", self.prefix, token)
}
pub async fn login(
&self,
user_id: &str,
ttl_seconds: u64,
) -> Result<String, redis::RedisError> {
let token = uuid::Uuid::new_v4().to_string();
let key = self.key(&token);
let mut connection = self.lock_connection().await?;
let _: () = connection.set_ex(key, user_id, ttl_seconds).await?;
Ok(token)
}
pub async fn logout(&self, token: &str) -> Result<(), redis::RedisError> {
let key = self.key(token);
let mut connection = self.lock_connection().await?;
let _: i64 = connection.del(key).await?;
Ok(())
}
#[cfg(not(feature = "rbac"))]
pub async fn validate(&self, token: &str) -> Result<Option<String>, redis::RedisError> {
let key = self.key(token);
let mut connection = self.lock_connection().await?;
let user_id: Option<String> = connection.get(key).await?;
Ok(user_id)
}
}