use async_trait::async_trait;
use chrono::Utc;
use redis::AsyncCommands;
use crate::core::{RefreshTokenData, TokenStore};
use crate::errors::JwtError;
/// Connection settings for [`RedisRefreshTokenStore`].
#[derive(Debug, Clone)]
pub struct RedisConfig {
/// Redis server address; the default is `redis://127.0.0.1:6379/`.
pub addr: String,
/// Optional Redis password (`None` by default).
/// NOTE(review): confirm `RedisRefreshTokenStore::new` applies it when connecting.
pub password: Option<String>,
/// Logical database index; `0` by default. Read by `RedisRefreshTokenStore::new`.
pub db: i32,
/// NOTE(review): not read by `RedisRefreshTokenStore`, which holds a single
/// `ConnectionManager`; confirm whether anything else uses this value.
pub pool_size: u32,
/// Prefix prepended to every token key (`prefix + token`); also scopes the
/// `SCAN` used by `cleanup` and `count`.
pub key_prefix: String,
/// When `true`, a `redis://` address is upgraded to `rediss://` (TLS).
pub tls: bool,
}
impl Default for RedisConfig {
fn default() -> Self {
Self {
addr: "redis://127.0.0.1:6379/".to_string(),
password: None,
db: 0,
pool_size: 10,
key_prefix: "actix-jwt:".to_string(),
tls: false,
}
}
}
/// [`TokenStore`] backed by Redis: each refresh token is stored as a JSON string
/// under the key `prefix + token`, written with `SETEX` so Redis enforces a TTL.
pub struct RedisRefreshTokenStore {
/// Redis connection handle; each operation works on its own clone of it.
conn: redis::aio::ConnectionManager,
/// Key namespace prefix, copied from `RedisConfig::key_prefix`.
prefix: String,
}
impl std::fmt::Debug for RedisRefreshTokenStore {
    /// Formats the store showing only its key prefix; the Redis connection
    /// handle is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RedisRefreshTokenStore");
        out.field("prefix", &self.prefix);
        out.finish()
    }
}
impl RedisRefreshTokenStore {
    /// Connects to Redis with the settings in `config` and verifies the link with `PING`.
    ///
    /// The connection URL is built by `connection_url` from `addr`, `tls`,
    /// `password` and `db`. `pool_size` is not used: every operation works on a
    /// clone of one `ConnectionManager`.
    ///
    /// # Errors
    /// Returns `JwtError::Internal` if the client cannot be created from the URL,
    /// the connection cannot be established, or the `PING` check fails.
    pub async fn new(config: &RedisConfig) -> Result<Self, JwtError> {
        let url = Self::connection_url(config);
        let client = redis::Client::open(url.as_str())
            .map_err(|e| JwtError::Internal(format!("Failed to create Redis client: {}", e)))?;
        let conn = redis::aio::ConnectionManager::new(client)
            .await
            .map_err(|e| JwtError::Internal(format!("Failed to connect to Redis: {}", e)))?;
        let store = Self {
            conn,
            prefix: config.key_prefix.clone(),
        };
        store.ping().await?;
        Ok(store)
    }

    /// Builds the Redis connection URL from `config`.
    ///
    /// - An `addr` without a scheme (e.g. `"127.0.0.1:6379"`) is treated as `redis://`.
    /// - `tls` upgrades `redis://` to `rediss://`.
    /// - A non-empty `password` is inserted as percent-encoded userinfo unless
    ///   `addr` already carries credentials (contains `@`).
    /// - A non-zero `db` is applied as the URL path unless `addr` already names one;
    ///   an explicit database in the URL wins.
    ///
    /// Other schemes (e.g. `unix://`) are returned unchanged apart from trimming.
    fn connection_url(config: &RedisConfig) -> String {
        let addr = config.addr.trim();
        let (scheme, rest) = addr.split_once("://").unwrap_or(("redis", addr));
        let scheme = if config.tls && scheme == "redis" {
            "rediss"
        } else {
            scheme
        };
        if scheme != "redis" && scheme != "rediss" {
            return format!("{}://{}", scheme, rest);
        }

        // `[user][:pass@]host[:port]` ends at the first path, query or fragment
        // delimiter. Splitting after "://" keeps the scheme's slashes from being
        // mistaken for a database path.
        let authority_end = rest
            .find(|c: char| matches!(c, '/' | '?' | '#'))
            .unwrap_or(rest.len());
        let (authority, tail) = rest.split_at(authority_end);
        let path_end = tail
            .find(|c: char| matches!(c, '?' | '#'))
            .unwrap_or(tail.len());
        let (path, suffix) = tail.split_at(path_end);

        let authority = match config.password.as_deref() {
            Some(password) if !password.is_empty() && !authority.contains('@') => {
                format!(":{}@{}", Self::encode_userinfo(password), authority)
            }
            _ => authority.to_string(),
        };
        let path = if config.db != 0 && (path.is_empty() || path == "/") {
            format!("/{}", config.db)
        } else {
            path.to_string()
        };
        format!("{}://{}{}{}", scheme, authority, path, suffix)
    }

    /// Percent-encodes `raw` for use inside URL userinfo, leaving only
    /// RFC 3986 unreserved characters (`A-Z a-z 0-9 - . _ ~`) literal.
    fn encode_userinfo(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                out.push(char::from(byte));
            } else {
                out.push_str(&format!("%{:02X}", byte));
            }
        }
        out
    }

    /// Returns the Redis key for `token`: the configured prefix followed by the token.
    fn key(&self, token: &str) -> String {
        format!("{}{}", self.prefix, token)
    }

    /// Sends `PING` and checks that the server answers `PONG`.
    ///
    /// # Errors
    /// Returns `JwtError::Internal` if the command fails or the reply is not `PONG`.
    pub async fn ping(&self) -> Result<(), JwtError> {
        let mut conn = self.conn.clone();
        let pong: String = redis::cmd("PING")
            .query_async(&mut conn)
            .await
            .map_err(|e| JwtError::Internal(format!("Redis PING failed: {}", e)))?;
        if pong != "PONG" {
            return Err(JwtError::Internal(format!(
                "Unexpected PING response: {}",
                pong
            )));
        }
        Ok(())
    }

    /// Consumes the store and drops its connection handle.
    pub async fn close(self) {
        drop(self.conn);
    }

    /// Runs `FLUSHDB` on the connected database.
    ///
    /// WARNING: this deletes every key in the selected database, not only keys
    /// under this store's prefix.
    ///
    /// # Errors
    /// Returns `JwtError::Internal` if the command fails.
    pub async fn flush_db(&self) -> Result<(), JwtError> {
        let mut conn = self.conn.clone();
        redis::cmd("FLUSHDB")
            .query_async::<()>(&mut conn)
            .await
            .map_err(|e| JwtError::Internal(format!("Redis FLUSHDB failed: {}", e)))?;
        Ok(())
    }
}
#[async_trait]
impl TokenStore for RedisRefreshTokenStore {
async fn set(
&self,
token: &str,
user_data: serde_json::Value,
expiry: chrono::DateTime<Utc>,
) -> Result<(), JwtError> {
if token.is_empty() {
return Err(JwtError::TokenEmpty);
}
let ttl_secs = (expiry - Utc::now()).num_seconds();
if ttl_secs < 1 {
return Err(JwtError::ExpiryInPast);
}
let data = RefreshTokenData {
user_data,
expiry,
created: Utc::now(),
};
let serialized = serde_json::to_string(&data)
.map_err(|e| JwtError::Internal(format!("Failed to serialize token data: {}", e)))?;
let mut conn = self.conn.clone();
conn.set_ex::<_, _, ()>(self.key(token), serialized, ttl_secs as u64)
.await
.map_err(|e| JwtError::Internal(format!("Redis SETEX failed: {}", e)))?;
Ok(())
}
async fn get(&self, token: &str) -> Result<serde_json::Value, JwtError> {
if token.is_empty() {
return Err(JwtError::TokenEmpty);
}
let mut conn = self.conn.clone();
let result: Option<String> = conn
.get(self.key(token))
.await
.map_err(|e| JwtError::Internal(format!("Redis GET failed: {}", e)))?;
match result {
Some(serialized) => {
let data: RefreshTokenData = serde_json::from_str(&serialized).map_err(|e| {
JwtError::Internal(format!("Failed to deserialize token data: {}", e))
})?;
if data.is_expired() {
let mut del_conn = self.conn.clone();
let _ = del_conn.del::<_, ()>(self.key(token)).await;
return Err(JwtError::RefreshTokenNotFound);
}
Ok(data.user_data)
}
None => Err(JwtError::RefreshTokenNotFound),
}
}
async fn delete(&self, token: &str) -> Result<(), JwtError> {
if token.is_empty() {
return Ok(());
}
let mut conn = self.conn.clone();
conn.del::<_, ()>(self.key(token))
.await
.map_err(|e| JwtError::Internal(format!("Redis DEL failed: {}", e)))?;
Ok(())
}
async fn cleanup(&self) -> Result<usize, JwtError> {
let pattern = format!("{}*", self.prefix);
let mut conn = self.conn.clone();
let mut removed = 0usize;
let mut cursor: u64 = 0;
loop {
let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg(&pattern)
.arg("COUNT")
.arg(100)
.query_async(&mut conn)
.await
.map_err(|e| JwtError::Internal(format!("Redis SCAN failed: {}", e)))?;
for key in &keys {
let value: Option<String> = conn
.get(key)
.await
.map_err(|e| JwtError::Internal(format!("Redis GET failed: {}", e)))?;
if let Some(serialized) = value {
if let Ok(data) = serde_json::from_str::<RefreshTokenData>(&serialized) {
if data.is_expired() {
conn.del::<_, ()>(key).await.map_err(|e| {
JwtError::Internal(format!("Redis DEL failed: {}", e))
})?;
removed += 1;
}
}
}
}
cursor = next_cursor;
if cursor == 0 {
break;
}
}
Ok(removed)
}
async fn count(&self) -> Result<usize, JwtError> {
let pattern = format!("{}*", self.prefix);
let mut conn = self.conn.clone();
let mut total = 0usize;
let mut cursor: u64 = 0;
loop {
let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg(&pattern)
.arg("COUNT")
.arg(100)
.query_async(&mut conn)
.await
.map_err(|e| JwtError::Internal(format!("Redis SCAN failed: {}", e)))?;
total += keys.len();
cursor = next_cursor;
if cursor == 0 {
break;
}
}
Ok(total)
}
}