use std::time::Duration;
use async_trait::async_trait;
use redis::{Client, AsyncCommands, aio::ConnectionManager};
use serde::{Deserialize, Serialize};
use sa_token_adapter::storage::{SaStorage, StorageResult, StorageError};
/// Connection settings for a Redis server, (de)serializable from configuration.
///
/// Every field carries a serde default, so a partial or empty config section is accepted.
#[derive(Clone, Serialize, Deserialize)]
pub struct RedisConfig {
    /// Server host name or address (serde default: `default_host`).
    #[serde(default = "default_host")]
    pub host: String,
    /// Server TCP port (serde default: `default_port`).
    #[serde(default = "default_port")]
    pub port: u16,
    /// Optional password; when set it is embedded in the URL built by `to_url`.
    #[serde(default)]
    pub password: Option<String>,
    /// Logical database index, placed in the path of the URL built by `to_url`.
    #[serde(default)]
    pub database: u8,
    /// Requested pool size (serde default: `default_pool_size`).
    ///
    /// NOTE(review): `to_url` does not include this value and `RedisStorage` connects through
    /// a single `ConnectionManager`; confirm whether anything else consumes it.
    #[serde(default = "default_pool_size")]
    pub pool_size: u32,
}

// Written by hand instead of derived so the password never leaks into logs via `{:?}`.
impl std::fmt::Debug for RedisConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RedisConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("database", &self.database)
            .field("pool_size", &self.pool_size)
            .finish()
    }
}
impl Default for RedisConfig {
fn default() -> Self {
Self {
host: default_host(),
port: default_port(),
password: None,
database: 0,
pool_size: default_pool_size(),
}
}
}
impl RedisConfig {
    /// Builds a `redis://` connection URL from this configuration.
    ///
    /// Shape: `redis://[:PASSWORD@]HOST:PORT/DATABASE`. The password is percent-encoded so
    /// characters such as `@`, `:`, `/`, `#` or `%` cannot break the URL structure; the
    /// redis client percent-decodes it again when parsing the URL.
    pub fn to_url(&self) -> String {
        match &self.password {
            Some(password) => format!(
                "redis://:{}@{}:{}/{}",
                percent_encode_userinfo(password),
                self.host,
                self.port,
                self.database
            ),
            None => format!("redis://{}:{}/{}", self.host, self.port, self.database),
        }
    }
}

/// Percent-encodes every byte outside the RFC 3986 "unreserved" set (`A-Z a-z 0-9 - . _ ~`),
/// which is always safe inside the userinfo part of a URL.
fn percent_encode_userinfo(raw: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(raw.len());
    for byte in raw.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
/// Fallback host for `RedisConfig` (used by the serde default and by `Default`).
fn default_host() -> String {
    String::from("localhost")
}
/// Fallback port for `RedisConfig`: Redis's conventional TCP port.
fn default_port() -> u16 {
    const REDIS_STANDARD_PORT: u16 = 6379;
    REDIS_STANDARD_PORT
}
/// Fallback pool size for `RedisConfig` (used by the serde default and by `Default`).
fn default_pool_size() -> u32 {
    const FALLBACK_POOL_SIZE: u32 = 10;
    FALLBACK_POOL_SIZE
}
/// Redis-backed storage that namespaces every key with a fixed prefix.
///
/// Each storage operation works on its own clone of the connection handle.
#[derive(Clone)]
pub struct RedisStorage {
    /// Prefix prepended to every caller-supplied key before it reaches Redis.
    key_prefix: String,
    /// Connection handle shared (by cloning) across operations.
    client: ConnectionManager,
}
impl RedisStorage {
pub async fn new(redis_url: &str, key_prefix: impl Into<String>) -> StorageResult<Self> {
let client = Client::open(redis_url)
.map_err(|e| StorageError::ConnectionError(e.to_string()))?;
let connection_manager = ConnectionManager::new(client).await
.map_err(|e| StorageError::ConnectionError(e.to_string()))?;
Ok(Self {
client: connection_manager,
key_prefix: key_prefix.into(),
})
}
pub async fn from_config(config: RedisConfig, key_prefix: impl Into<String>) -> StorageResult<Self> {
let redis_url = config.to_url();
Self::new(&redis_url, key_prefix).await
}
pub fn builder() -> RedisStorageBuilder {
RedisStorageBuilder::default()
}
fn full_key(&self, key: &str) -> String {
format!("{}{}", self.key_prefix, key)
}
}
/// Fluent builder for `RedisStorage`; starts from `RedisConfig::default()`.
#[derive(Default)]
pub struct RedisStorageBuilder {
    /// Key namespace; `build` fails if it was never supplied.
    key_prefix: Option<String>,
    /// Connection settings accumulated by the setter methods.
    config: RedisConfig,
}
impl RedisStorageBuilder {
pub fn host(mut self, host: impl Into<String>) -> Self {
self.config.host = host.into();
self
}
pub fn port(mut self, port: u16) -> Self {
self.config.port = port;
self
}
pub fn password(mut self, password: impl Into<String>) -> Self {
self.config.password = Some(password.into());
self
}
pub fn database(mut self, database: u8) -> Self {
self.config.database = database;
self
}
pub fn pool_size(mut self, size: u32) -> Self {
self.config.pool_size = size;
self
}
pub fn key_prefix(mut self, prefix: impl Into<String>) -> Self {
self.key_prefix = Some(prefix.into());
self
}
pub async fn build(self) -> StorageResult<RedisStorage> {
let key_prefix = self.key_prefix
.expect("key_prefix must be set before building RedisStorage");
RedisStorage::from_config(self.config, key_prefix).await
}
}
#[async_trait]
impl SaStorage for RedisStorage {
async fn get(&self, key: &str) -> StorageResult<Option<String>> {
let mut conn = self.client.clone();
let full_key = self.full_key(key);
conn.get(&full_key).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))
}
async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> StorageResult<()> {
let mut conn = self.client.clone();
let full_key = self.full_key(key);
if let Some(ttl) = ttl {
conn.set_ex(&full_key, value, ttl.as_secs()).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))
} else {
conn.set(&full_key, value).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))
}
}
async fn delete(&self, key: &str) -> StorageResult<()> {
let mut conn = self.client.clone();
let full_key = self.full_key(key);
conn.del(&full_key).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))
}
async fn exists(&self, key: &str) -> StorageResult<bool> {
let mut conn = self.client.clone();
let full_key = self.full_key(key);
conn.exists(&full_key).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))
}
async fn expire(&self, key: &str, ttl: Duration) -> StorageResult<()> {
let mut conn = self.client.clone();
let full_key = self.full_key(key);
conn.expire(&full_key, ttl.as_secs() as i64).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))
}
async fn ttl(&self, key: &str) -> StorageResult<Option<Duration>> {
let mut conn = self.client.clone();
let full_key = self.full_key(key);
let ttl_secs: i64 = conn.ttl(&full_key).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))?;
match ttl_secs {
-2 => Ok(None), -1 => Ok(None), secs if secs > 0 => Ok(Some(Duration::from_secs(secs as u64))),
_ => Ok(Some(Duration::from_secs(0))),
}
}
async fn mget(&self, keys: &[&str]) -> StorageResult<Vec<Option<String>>> {
let mut conn = self.client.clone();
let full_keys: Vec<String> = keys.iter().map(|k| self.full_key(k)).collect();
conn.mget(&full_keys).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))
}
async fn mset(&self, items: &[(&str, &str)], ttl: Option<Duration>) -> StorageResult<()> {
let mut conn = self.client.clone();
let full_items: Vec<(String, &str)> = items.iter()
.map(|(k, v)| (self.full_key(k), *v))
.collect();
let mut pipe = redis::pipe();
for (key, value) in &full_items {
if let Some(ttl) = ttl {
pipe.set_ex(key, *value, ttl.as_secs());
} else {
pipe.set(key, *value);
}
}
pipe.query_async(&mut conn).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))
}
async fn mdel(&self, keys: &[&str]) -> StorageResult<()> {
let mut conn = self.client.clone();
let full_keys: Vec<String> = keys.iter().map(|k| self.full_key(k)).collect();
conn.del(&full_keys).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))
}
async fn incr(&self, key: &str) -> StorageResult<i64> {
let mut conn = self.client.clone();
let full_key = self.full_key(key);
conn.incr(&full_key, 1).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))
}
async fn decr(&self, key: &str) -> StorageResult<i64> {
let mut conn = self.client.clone();
let full_key = self.full_key(key);
conn.decr(&full_key, 1).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))
}
async fn clear(&self) -> StorageResult<()> {
let mut conn = self.client.clone();
let pattern = format!("{}*", self.key_prefix);
let keys: Vec<String> = conn.keys(&pattern).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))?;
if !keys.is_empty() {
conn.del::<_, ()>(&keys).await
.map_err(|e| StorageError::OperationFailed(e.to_string()))?;
}
Ok(())
}
}