use std::sync::Arc;
use chrono::{DateTime, Utc, Duration};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use sa_token_adapter::storage::SaStorage;
use crate::error::{SaTokenError, SaTokenResult};
/// A registered OAuth2 client application.
///
/// `Debug` is implemented by hand so that `client_secret` never shows up in
/// `{:?}` output (log lines, panic messages, `dbg!`); every other field is
/// printed normally.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuth2Client {
    /// Public client identifier.
    pub client_id: String,
    /// Shared secret the client presents when redeeming codes or refresh tokens.
    pub client_secret: String,
    /// Redirect URIs registered for this client.
    pub redirect_uris: Vec<String>,
    /// Grant types the client registered with (e.g. `authorization_code`, `refresh_token`).
    pub grant_types: Vec<String>,
    /// Scopes this client is allowed to request.
    pub scope: Vec<String>,
}

impl std::fmt::Debug for OAuth2Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2Client")
            .field("client_id", &self.client_id)
            // Never print the credential itself.
            .field("client_secret", &"<redacted>")
            .field("redirect_uris", &self.redirect_uris)
            .field("grant_types", &self.grant_types)
            .field("scope", &self.scope)
            .finish()
    }
}
/// An authorization code issued to a client on behalf of a user, to be
/// exchanged for an access token.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AuthorizationCode {
    /// The opaque code value handed to the client.
    pub code: String,
    /// Client the code was issued to.
    pub client_id: String,
    /// User who authorized the client.
    pub user_id: String,
    /// Redirect URI the code was issued for; must match at redemption.
    pub redirect_uri: String,
    /// Scopes granted with this code.
    pub scope: Vec<String>,
    /// Issue time (UTC).
    pub created_at: DateTime<Utc>,
    /// Expiry time (UTC).
    pub expires_at: DateTime<Utc>,
}
/// Token response returned to a client after a successful grant.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AccessToken {
    /// The opaque access token value.
    pub access_token: String,
    /// Token type string (e.g. `Bearer`).
    pub token_type: String,
    /// Lifetime of `access_token`, in seconds.
    pub expires_in: i64,
    /// Refresh token issued alongside the access token, if any.
    pub refresh_token: Option<String>,
    /// Scopes granted to the access token.
    pub scope: Vec<String>,
}
/// Server-side record describing an issued access token.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OAuth2TokenInfo {
    /// The access token this record describes.
    pub access_token: String,
    /// Client the token was issued to.
    pub client_id: String,
    /// User the token acts on behalf of.
    pub user_id: String,
    /// Scopes granted to the token.
    pub scope: Vec<String>,
    /// Issue time (UTC).
    pub created_at: DateTime<Utc>,
    /// Expiry time (UTC).
    pub expires_at: DateTime<Utc>,
    /// Refresh token issued together with this access token, if any.
    pub refresh_token: Option<String>,
}
/// Issues, persists and validates OAuth2 authorization codes, access tokens
/// and refresh tokens on top of a pluggable [`SaStorage`] backend.
pub struct OAuth2Manager {
    /// Authorization-code lifetime, in seconds.
    code_ttl: i64,
    /// Access-token lifetime, in seconds.
    token_ttl: i64,
    /// Refresh-token lifetime, in seconds.
    refresh_token_ttl: i64,
    /// Backend holding clients, codes and tokens.
    storage: Arc<dyn SaStorage>,
}
impl OAuth2Manager {
pub fn new(storage: Arc<dyn SaStorage>) -> Self {
Self {
storage,
code_ttl: 600, token_ttl: 3600, refresh_token_ttl: 2592000, }
}
pub fn with_ttl(mut self, code_ttl: i64, token_ttl: i64, refresh_token_ttl: i64) -> Self {
self.code_ttl = code_ttl;
self.token_ttl = token_ttl;
self.refresh_token_ttl = refresh_token_ttl;
self
}
pub async fn register_client(&self, client: &OAuth2Client) -> SaTokenResult<()> {
let key = format!("oauth2:client:{}", client.client_id);
let value = serde_json::to_string(client)
.map_err(SaTokenError::SerializationError)?;
self.storage.set(&key, &value, None).await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?;
Ok(())
}
pub async fn get_client(&self, client_id: &str) -> SaTokenResult<OAuth2Client> {
let key = format!("oauth2:client:{}", client_id);
let value = self.storage.get(&key).await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?
.ok_or(SaTokenError::OAuth2ClientNotFound)?;
serde_json::from_str(&value)
.map_err(SaTokenError::SerializationError)
}
pub async fn verify_client(&self, client_id: &str, client_secret: &str) -> SaTokenResult<bool> {
let client = self.get_client(client_id).await?;
Ok(client.client_secret == client_secret)
}
pub fn generate_authorization_code(
&self,
client_id: String,
user_id: String,
redirect_uri: String,
scope: Vec<String>,
) -> AuthorizationCode {
let now = Utc::now();
let code = format!("code_{}", Uuid::new_v4().simple());
AuthorizationCode {
code,
client_id,
user_id,
redirect_uri,
scope,
created_at: now,
expires_at: now + Duration::seconds(self.code_ttl),
}
}
pub async fn store_authorization_code(&self, auth_code: &AuthorizationCode) -> SaTokenResult<()> {
let key = format!("oauth2:code:{}", auth_code.code);
let value = serde_json::to_string(auth_code)
.map_err(SaTokenError::SerializationError)?;
let ttl = Some(std::time::Duration::from_secs(self.code_ttl as u64));
self.storage.set(&key, &value, ttl).await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?;
Ok(())
}
pub async fn get_authorization_code(&self, code: &str) -> SaTokenResult<AuthorizationCode> {
let key = format!("oauth2:code:{}", code);
let value = self.storage.get(&key).await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?
.ok_or(SaTokenError::OAuth2CodeNotFound)?;
let auth_code: AuthorizationCode = serde_json::from_str(&value)
.map_err(SaTokenError::SerializationError)?;
if Utc::now() > auth_code.expires_at {
self.storage.delete(&key).await.ok();
return Err(SaTokenError::TokenExpired);
}
Ok(auth_code)
}
pub async fn consume_authorization_code(&self, code: &str) -> SaTokenResult<AuthorizationCode> {
let auth_code = self.get_authorization_code(code).await?;
let key = format!("oauth2:code:{}", code);
self.storage.delete(&key).await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?;
Ok(auth_code)
}
pub async fn exchange_code_for_token(
&self,
code: &str,
client_id: &str,
client_secret: &str,
redirect_uri: &str,
) -> SaTokenResult<AccessToken> {
if !self.verify_client(client_id, client_secret).await? {
return Err(SaTokenError::OAuth2InvalidCredentials);
}
let auth_code = self.consume_authorization_code(code).await?;
if auth_code.client_id != client_id {
return Err(SaTokenError::OAuth2ClientIdMismatch);
}
if auth_code.redirect_uri != redirect_uri {
return Err(SaTokenError::OAuth2RedirectUriMismatch);
}
self.generate_access_token(&auth_code.client_id, &auth_code.user_id, auth_code.scope).await
}
pub async fn generate_access_token(
&self,
client_id: &str,
user_id: &str,
scope: Vec<String>,
) -> SaTokenResult<AccessToken> {
let now = Utc::now();
let access_token = format!("at_{}", Uuid::new_v4().simple());
let refresh_token = format!("rt_{}", Uuid::new_v4().simple());
let token_info = OAuth2TokenInfo {
access_token: access_token.clone(),
client_id: client_id.to_string(),
user_id: user_id.to_string(),
scope: scope.clone(),
created_at: now,
expires_at: now + Duration::seconds(self.token_ttl),
refresh_token: Some(refresh_token.clone()),
};
let key = format!("oauth2:token:{}", access_token);
let value = serde_json::to_string(&token_info)
.map_err(SaTokenError::SerializationError)?;
let ttl = Some(std::time::Duration::from_secs(self.token_ttl as u64));
self.storage.set(&key, &value, ttl).await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?;
let refresh_key = format!("oauth2:refresh:{}", refresh_token);
let refresh_value = serde_json::json!({
"user_id": user_id,
"client_id": client_id,
"scope": scope,
}).to_string();
let refresh_ttl = Some(std::time::Duration::from_secs(self.refresh_token_ttl as u64));
self.storage.set(&refresh_key, &refresh_value, refresh_ttl).await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?;
Ok(AccessToken {
access_token,
token_type: "Bearer".to_string(),
expires_in: self.token_ttl,
refresh_token: Some(refresh_token),
scope,
})
}
pub async fn verify_access_token(&self, access_token: &str) -> SaTokenResult<OAuth2TokenInfo> {
let key = format!("oauth2:token:{}", access_token);
let value = self.storage.get(&key).await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?
.ok_or(SaTokenError::OAuth2AccessTokenNotFound)?;
let token_info: OAuth2TokenInfo = serde_json::from_str(&value)
.map_err(SaTokenError::SerializationError)?;
if Utc::now() > token_info.expires_at {
self.storage.delete(&key).await.ok();
return Err(SaTokenError::TokenExpired);
}
Ok(token_info)
}
pub async fn refresh_access_token(
&self,
refresh_token: &str,
client_id: &str,
client_secret: &str,
) -> SaTokenResult<AccessToken> {
if !self.verify_client(client_id, client_secret).await? {
return Err(SaTokenError::OAuth2InvalidCredentials);
}
let key = format!("oauth2:refresh:{}", refresh_token);
let value = self.storage.get(&key).await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?
.ok_or(SaTokenError::OAuth2RefreshTokenNotFound)?;
let data: serde_json::Value = serde_json::from_str(&value)
.map_err(SaTokenError::SerializationError)?;
let stored_client_id = data["client_id"].as_str()
.ok_or(SaTokenError::OAuth2InvalidRefreshToken)?;
if stored_client_id != client_id {
return Err(SaTokenError::OAuth2ClientIdMismatch);
}
let user_id = data["user_id"].as_str()
.ok_or(SaTokenError::OAuth2InvalidRefreshToken)?;
let scope: Vec<String> = data["scope"].as_array()
.ok_or(SaTokenError::OAuth2InvalidScope)?
.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect();
self.generate_access_token(client_id, user_id, scope).await
}
pub async fn revoke_token(&self, token: &str) -> SaTokenResult<()> {
let access_key = format!("oauth2:token:{}", token);
let refresh_key = format!("oauth2:refresh:{}", token);
self.storage.delete(&access_key).await.ok();
self.storage.delete(&refresh_key).await.ok();
Ok(())
}
pub fn validate_redirect_uri(&self, client: &OAuth2Client, redirect_uri: &str) -> bool {
client.redirect_uris.iter().any(|uri| uri == redirect_uri)
}
pub fn validate_scope(&self, client: &OAuth2Client, requested_scope: &[String]) -> bool {
requested_scope.iter().all(|s| client.scope.contains(s))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use sa_token_storage_memory::MemoryStorage;

    const CLIENT_ID: &str = "test_client";
    const CLIENT_SECRET: &str = "test_secret";
    const REDIRECT_URI: &str = "http://localhost:3000/callback";
    const USER_ID: &str = "user_123";

    /// Converts string literals into the owned `Vec<String>` the API expects.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Creates a manager over fresh in-memory storage with one registered client.
    async fn manager_with_client(grant_types: &[&str], scope: &[&str]) -> OAuth2Manager {
        let oauth2 = OAuth2Manager::new(Arc::new(MemoryStorage::new()));
        let client = OAuth2Client {
            client_id: CLIENT_ID.to_string(),
            client_secret: CLIENT_SECRET.to_string(),
            redirect_uris: strings(&[REDIRECT_URI]),
            grant_types: strings(grant_types),
            scope: strings(scope),
        };
        oauth2.register_client(&client).await.unwrap();
        oauth2
    }

    /// Generates and stores an authorization code for `USER_ID` with `read` scope.
    async fn issue_code(oauth2: &OAuth2Manager) -> AuthorizationCode {
        let auth_code = oauth2.generate_authorization_code(
            CLIENT_ID.to_string(),
            USER_ID.to_string(),
            REDIRECT_URI.to_string(),
            strings(&["read"]),
        );
        oauth2.store_authorization_code(&auth_code).await.unwrap();
        auth_code
    }

    #[tokio::test]
    async fn test_oauth2_authorization_code_flow() {
        let oauth2 = manager_with_client(&["authorization_code"], &["read", "write"]).await;
        let auth_code = issue_code(&oauth2).await;

        let token = oauth2
            .exchange_code_for_token(&auth_code.code, CLIENT_ID, CLIENT_SECRET, REDIRECT_URI)
            .await
            .unwrap();
        assert_eq!(token.token_type, "Bearer");
        assert!(token.refresh_token.is_some());

        let token_info = oauth2.verify_access_token(&token.access_token).await.unwrap();
        assert_eq!(token_info.user_id, USER_ID);
        assert_eq!(token_info.client_id, CLIENT_ID);
    }

    #[tokio::test]
    async fn test_authorization_code_is_single_use() {
        let oauth2 = manager_with_client(&["authorization_code"], &["read"]).await;
        let auth_code = issue_code(&oauth2).await;

        oauth2
            .exchange_code_for_token(&auth_code.code, CLIENT_ID, CLIENT_SECRET, REDIRECT_URI)
            .await
            .unwrap();
        let err = oauth2
            .exchange_code_for_token(&auth_code.code, CLIENT_ID, CLIENT_SECRET, REDIRECT_URI)
            .await
            .unwrap_err();
        assert!(matches!(err, SaTokenError::OAuth2CodeNotFound));
    }

    #[tokio::test]
    async fn test_exchange_rejects_wrong_secret() {
        let oauth2 = manager_with_client(&["authorization_code"], &["read"]).await;
        let auth_code = issue_code(&oauth2).await;

        let err = oauth2
            .exchange_code_for_token(&auth_code.code, CLIENT_ID, "wrong_secret", REDIRECT_URI)
            .await
            .unwrap_err();
        assert!(matches!(err, SaTokenError::OAuth2InvalidCredentials));

        // Credentials are checked before the code is consumed, so it is still redeemable.
        oauth2
            .exchange_code_for_token(&auth_code.code, CLIENT_ID, CLIENT_SECRET, REDIRECT_URI)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_exchange_rejects_redirect_uri_mismatch() {
        let oauth2 = manager_with_client(&["authorization_code"], &["read"]).await;
        let auth_code = issue_code(&oauth2).await;

        let err = oauth2
            .exchange_code_for_token(
                &auth_code.code,
                CLIENT_ID,
                CLIENT_SECRET,
                "http://evil.example/callback",
            )
            .await
            .unwrap_err();
        assert!(matches!(err, SaTokenError::OAuth2RedirectUriMismatch));
    }

    #[tokio::test]
    async fn test_verify_unknown_client() {
        let oauth2 = manager_with_client(&["authorization_code"], &["read"]).await;
        let err = oauth2
            .verify_client("no_such_client", CLIENT_SECRET)
            .await
            .unwrap_err();
        assert!(matches!(err, SaTokenError::OAuth2ClientNotFound));
    }

    #[tokio::test]
    async fn test_refresh_token() {
        let oauth2 =
            manager_with_client(&["authorization_code", "refresh_token"], &["read"]).await;

        let token = oauth2
            .generate_access_token(CLIENT_ID, USER_ID, strings(&["read"]))
            .await
            .unwrap();
        let refresh_token = token.refresh_token.as_ref().unwrap();

        let new_token = oauth2
            .refresh_access_token(refresh_token, CLIENT_ID, CLIENT_SECRET)
            .await
            .unwrap();
        assert_ne!(new_token.access_token, token.access_token);
    }

    #[tokio::test]
    async fn test_revoked_access_token_is_rejected() {
        let oauth2 = manager_with_client(&["authorization_code"], &["read"]).await;
        let token = oauth2
            .generate_access_token(CLIENT_ID, USER_ID, strings(&["read"]))
            .await
            .unwrap();

        oauth2.revoke_token(&token.access_token).await.unwrap();
        let err = oauth2
            .verify_access_token(&token.access_token)
            .await
            .unwrap_err();
        assert!(matches!(err, SaTokenError::OAuth2AccessTokenNotFound));
    }
}