use std::sync::Arc;
use chrono::{DateTime, Utc, Duration};
use sa_token_adapter::storage::SaStorage;
use crate::error::{SaTokenError, SaTokenResult};
use crate::token::TokenValue;
use crate::token::TokenGenerator;
use crate::config::SaTokenConfig;
use uuid::Uuid;
#[derive(Clone)]
pub struct RefreshTokenManager {
storage: Arc<dyn SaStorage>,
config: Arc<SaTokenConfig>,
}
impl RefreshTokenManager {
pub fn new(storage: Arc<dyn SaStorage>, config: Arc<SaTokenConfig>) -> Self {
Self { storage, config }
}
pub fn generate(&self, login_id: &str) -> String {
format!(
"refresh_{}_{}_{}",
Utc::now().timestamp_millis(),
login_id,
Uuid::new_v4().simple()
)
}
pub async fn store(
&self,
refresh_token: &str,
access_token: &str,
login_id: &str,
) -> SaTokenResult<()> {
self.store_with_extra(refresh_token, access_token, login_id, None).await
}
pub async fn store_with_extra(
&self,
refresh_token: &str,
access_token: &str,
login_id: &str,
extra_data: Option<&serde_json::Value>,
) -> SaTokenResult<()> {
let key = format!("sa:refresh:{}", refresh_token);
let expire_time = if self.config.refresh_token_timeout > 0 {
Some(Utc::now() + Duration::seconds(self.config.refresh_token_timeout))
} else {
None
};
let mut obj = serde_json::json!({
"access_token": access_token,
"login_id": login_id,
"created_at": Utc::now().to_rfc3339(),
"expire_time": expire_time.map(|t| t.to_rfc3339()),
});
if let Some(extra) = extra_data {
obj["extra_data"] = extra.clone();
}
let value = obj.to_string();
let ttl = if self.config.refresh_token_timeout > 0 {
Some(std::time::Duration::from_secs(self.config.refresh_token_timeout as u64))
} else {
None
};
self.storage.set(&key, &value, ttl)
.await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?;
Ok(())
}
pub async fn validate(&self, refresh_token: &str) -> SaTokenResult<String> {
let key = format!("sa:refresh:{}", refresh_token);
let value_str = self.storage.get(&key)
.await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?
.ok_or(SaTokenError::RefreshTokenNotFound)?;
let value: serde_json::Value = serde_json::from_str(&value_str)
.map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
let login_id = value["login_id"].as_str()
.ok_or(SaTokenError::RefreshTokenMissingLoginId)?
.to_string();
if let Some(expire_str) = value["expire_time"].as_str() {
let expire_time = DateTime::parse_from_rfc3339(expire_str)
.map_err(|_| SaTokenError::RefreshTokenInvalidExpireTime)?
.with_timezone(&Utc);
if Utc::now() > expire_time {
self.delete(refresh_token).await?;
return Err(SaTokenError::TokenExpired);
}
}
Ok(login_id)
}
pub async fn refresh_access_token(
&self,
refresh_token: &str,
) -> SaTokenResult<(TokenValue, String)> {
let login_id = self.validate(refresh_token).await?;
let key = format!("sa:refresh:{}", refresh_token);
let value_str = self.storage.get(&key)
.await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?
.ok_or(SaTokenError::RefreshTokenNotFound)?;
let mut value: serde_json::Value = serde_json::from_str(&value_str)
.map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
let extra_data = value.get("extra_data").cloned();
let new_access_token = match &extra_data {
Some(extra) => TokenGenerator::generate_with_login_id_and_extra(&self.config, &login_id, extra),
None => TokenGenerator::generate_with_login_id(&self.config, &login_id),
};
value["access_token"] = serde_json::json!(new_access_token.as_str());
value["refreshed_at"] = serde_json::json!(Utc::now().to_rfc3339());
let ttl = if self.config.refresh_token_timeout > 0 {
Some(std::time::Duration::from_secs(self.config.refresh_token_timeout as u64))
} else {
None
};
self.storage.set(&key, &value.to_string(), ttl)
.await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?;
Ok((new_access_token, login_id))
}
pub async fn delete(&self, refresh_token: &str) -> SaTokenResult<()> {
let key = format!("sa:refresh:{}", refresh_token);
self.storage.delete(&key)
.await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?;
Ok(())
}
pub async fn get_user_refresh_tokens(&self, _login_id: &str) -> SaTokenResult<Vec<String>> {
Ok(vec![])
}
pub async fn revoke_all_for_user(&self, login_id: &str) -> SaTokenResult<()> {
let tokens = self.get_user_refresh_tokens(login_id).await?;
for token in tokens {
self.delete(&token).await?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use sa_token_storage_memory::MemoryStorage;
    use crate::config::TokenStyle;

    /// Config shared by every test: UUID token style, `timeout` 3600,
    /// refresh-token timeout 7200 seconds, refresh tokens enabled.
    fn create_test_config() -> Arc<SaTokenConfig> {
        let config = SaTokenConfig {
            token_style: TokenStyle::Uuid,
            timeout: 3600,
            refresh_token_timeout: 7200,
            enable_refresh_token: true,
            ..Default::default()
        };
        Arc::new(config)
    }

    /// Builds a manager over a fresh, empty in-memory store.
    fn fresh_manager() -> RefreshTokenManager {
        RefreshTokenManager::new(Arc::new(MemoryStorage::new()), create_test_config())
    }

    #[tokio::test]
    async fn test_refresh_token_generation() {
        let manager = fresh_manager();
        let (first, second) = (manager.generate("user_123"), manager.generate("user_123"));
        assert_ne!(first, second);
        assert!(first.starts_with("refresh_"));
    }

    #[tokio::test]
    async fn test_refresh_token_store_and_validate() {
        let manager = fresh_manager();
        let token = manager.generate("user_123");
        manager
            .store(&token, "access_token_123", "user_123")
            .await
            .unwrap();
        assert_eq!(manager.validate(&token).await.unwrap(), "user_123");
    }

    #[tokio::test]
    async fn test_refresh_access_token() {
        const OLD_ACCESS: &str = "old_access_token";
        let manager = fresh_manager();
        let token = manager.generate("user_123");
        manager.store(&token, OLD_ACCESS, "user_123").await.unwrap();
        let (new_access, login_id) = manager.refresh_access_token(&token).await.unwrap();
        assert_eq!(login_id, "user_123");
        assert_ne!(new_access.as_str(), OLD_ACCESS);
    }

    #[tokio::test]
    async fn test_delete_refresh_token() {
        let manager = fresh_manager();
        let token = manager.generate("user_123");
        manager.store(&token, "access", "user_123").await.unwrap();
        manager.delete(&token).await.unwrap();
        assert!(manager.validate(&token).await.is_err());
    }
}