Skip to main content

wae_authentication/
lib.rs

1#![doc = include_str!("../readme.md")]
2#![warn(missing_docs)]
3
4pub mod jwt;
5pub mod oauth2;
6pub mod saml;
7pub mod totp;
8
9use serde::{Deserialize, Serialize};
10use std::{collections::HashMap, fmt};
11
/// Errors produced by authentication and account-management operations.
#[derive(Debug)]
pub enum AuthError {
    /// Authentication failed for the given reason.
    AuthenticationFailed(String),

    /// The supplied credentials were rejected.
    InvalidCredentials,

    /// The token is invalid; the message gives details.
    InvalidToken(String),

    /// The token has expired.
    TokenExpired,

    /// The caller lacks the required permission.
    PermissionDenied(String),

    /// No user matches the given identifier.
    UserNotFound(String),

    /// A user with the given identifier already exists.
    UserAlreadyExists(String),

    /// The password does not satisfy the password requirements.
    PasswordRequirement(String),

    /// The account is locked.
    AccountLocked,

    /// The account has not been verified.
    AccountNotVerified,

    /// The operation timed out.
    Timeout(String),

    /// An internal service error occurred.
    Internal(String),
}

impl fmt::Display for AuthError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant maps to a fixed label plus an optional detail message, so the
        // "label: detail" layout is produced in exactly one place below.
        let (label, detail) = match self {
            Self::AuthenticationFailed(msg) => ("Authentication failed", Some(msg.as_str())),
            Self::InvalidCredentials => ("Invalid credentials", None),
            Self::InvalidToken(msg) => ("Invalid token", Some(msg.as_str())),
            Self::TokenExpired => ("Token expired", None),
            Self::PermissionDenied(msg) => ("Permission denied", Some(msg.as_str())),
            Self::UserNotFound(msg) => ("User not found", Some(msg.as_str())),
            Self::UserAlreadyExists(msg) => ("User already exists", Some(msg.as_str())),
            Self::PasswordRequirement(msg) => ("Password does not meet requirements", Some(msg.as_str())),
            Self::AccountLocked => ("Account is locked", None),
            Self::AccountNotVerified => ("Account is not verified", None),
            Self::Timeout(msg) => ("Operation timeout", Some(msg.as_str())),
            Self::Internal(msg) => ("Auth internal error", Some(msg.as_str())),
        };
        match detail {
            Some(detail) => write!(f, "{label}: {detail}"),
            None => f.write_str(label),
        }
    }
}

impl std::error::Error for AuthError {}
72
73/// 认证操作结果类型
74pub type AuthResult<T> = Result<T, AuthError>;
75
/// Identifier of a user account.
pub type UserId = String;

/// Identifier of a role.
pub type RoleId = String;

/// Code naming a single permission.
pub type PermissionCode = String;
84
85/// 用户信息
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct UserInfo {
88    /// 用户 ID
89    pub id: UserId,
90    /// 用户名
91    pub username: String,
92    /// 邮箱
93    pub email: Option<String>,
94    /// 手机号
95    pub phone: Option<String>,
96    /// 显示名称
97    pub display_name: Option<String>,
98    /// 头像 URL
99    pub avatar_url: Option<String>,
100    /// 是否已验证
101    pub verified: bool,
102    /// 是否已禁用
103    pub disabled: bool,
104    /// 自定义属性
105    pub attributes: HashMap<String, serde_json::Value>,
106    /// 创建时间
107    pub created_at: i64,
108    /// 更新时间
109    pub updated_at: i64,
110}
111
112/// 用户凭证
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct Credentials {
115    /// 用户名或邮箱
116    pub identifier: String,
117    /// 密码
118    pub password: String,
119    /// 额外参数
120    pub extra: HashMap<String, String>,
121}
122
123/// 认证 Token
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct AuthToken {
126    /// 访问令牌
127    pub access_token: String,
128    /// 刷新令牌
129    pub refresh_token: Option<String>,
130    /// 令牌类型
131    pub token_type: String,
132    /// 过期时间 (秒)
133    pub expires_in: u64,
134    /// 过期时间戳
135    pub expires_at: i64,
136}
137
138/// Token 验证结果
139#[derive(Debug, Clone)]
140pub struct TokenValidation {
141    /// 用户 ID
142    pub user_id: UserId,
143    /// 用户信息
144    pub user: Option<UserInfo>,
145    /// 角色
146    pub roles: Vec<Role>,
147    /// 权限
148    pub permissions: Vec<PermissionCode>,
149    /// Token 元数据
150    pub metadata: HashMap<String, String>,
151}
152
153/// 角色信息
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct Role {
156    /// 角色 ID
157    pub id: RoleId,
158    /// 角色名称
159    pub name: String,
160    /// 角色描述
161    pub description: Option<String>,
162    /// 权限列表
163    pub permissions: Vec<PermissionCode>,
164}
165
166/// 用户创建请求
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct CreateUserRequest {
169    /// 用户名
170    pub username: String,
171    /// 密码
172    pub password: String,
173    /// 邮箱
174    pub email: Option<String>,
175    /// 手机号
176    pub phone: Option<String>,
177    /// 显示名称
178    pub display_name: Option<String>,
179    /// 自定义属性
180    pub attributes: HashMap<String, serde_json::Value>,
181}
182
183/// 用户更新请求
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct UpdateUserRequest {
186    /// 显示名称
187    pub display_name: Option<String>,
188    /// 头像 URL
189    pub avatar_url: Option<String>,
190    /// 自定义属性
191    pub attributes: Option<HashMap<String, serde_json::Value>>,
192}
193
194/// 密码修改请求
195#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct ChangePasswordRequest {
197    /// 旧密码
198    pub old_password: String,
199    /// 新密码
200    pub new_password: String,
201}
202
/// Settings shared by authentication services.
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// Access-token lifetime in seconds.
    pub token_expires_in: u64,
    /// Refresh-token lifetime in seconds.
    pub refresh_token_expires_in: u64,
    /// Token issuer.
    pub issuer: String,
    /// Token audience.
    pub audience: String,
    /// Minimum password length.
    pub password_min_length: usize,
    /// Whether passwords must contain a digit.
    pub password_require_digit: bool,
    /// Whether passwords must contain a special character.
    pub password_require_special: bool,
    /// Number of consecutive failed logins that triggers a lockout.
    pub max_login_attempts: u32,
    /// Lockout duration in seconds.
    pub lockout_duration: u64,
}

impl Default for AuthConfig {
    /// One-hour access tokens and one-week refresh tokens.
    ///
    /// Passwords need at least 8 characters including a digit. Five failed logins trigger a
    /// 30-minute lockout.
    fn default() -> Self {
        const MINUTE: u64 = 60;
        const HOUR: u64 = 60 * MINUTE;
        const DAY: u64 = 24 * HOUR;

        Self {
            issuer: String::from("wae-authentication"),
            audience: String::from("wae-api"),
            token_expires_in: HOUR,
            refresh_token_expires_in: 7 * DAY,
            password_min_length: 8,
            password_require_digit: true,
            password_require_special: false,
            max_login_attempts: 5,
            lockout_duration: 30 * MINUTE,
        }
    }
}
241
242/// 认证服务 trait
243
244pub trait AuthService: Send + Sync {
245    /// 用户登录
246    ///
247    /// # Arguments
248    /// * `credentials` - 用户凭证
249    ///
250    /// # Returns
251    /// 认证 Token
252    async fn login(&self, credentials: &Credentials) -> AuthResult<AuthToken>;
253
254    /// 用户登出
255    ///
256    /// # Arguments
257    /// * `token` - 访问令牌
258    async fn logout(&self, token: &str) -> AuthResult<()>;
259
260    /// 刷新 Token
261    ///
262    /// # Arguments
263    /// * `refresh_token` - 刷新令牌
264    async fn refresh_token(&self, refresh_token: &str) -> AuthResult<AuthToken>;
265
266    /// 验证 Token
267    ///
268    /// # Arguments
269    /// * `token` - 访问令牌
270    async fn validate_token(&self, token: &str) -> AuthResult<TokenValidation>;
271
272    /// 创建用户
273    ///
274    /// # Arguments
275    /// * `request` - 创建请求
276    async fn create_user(&self, request: &CreateUserRequest) -> AuthResult<UserInfo>;
277
278    /// 获取用户信息
279    ///
280    /// # Arguments
281    /// * `user_id` - 用户 ID
282    async fn get_user(&self, user_id: &str) -> AuthResult<UserInfo>;
283
284    /// 更新用户信息
285    ///
286    /// # Arguments
287    /// * `user_id` - 用户 ID
288    /// * `request` - 更新请求
289    async fn update_user(&self, user_id: &str, request: &UpdateUserRequest) -> AuthResult<UserInfo>;
290
291    /// 删除用户
292    ///
293    /// # Arguments
294    /// * `user_id` - 用户 ID
295    async fn delete_user(&self, user_id: &str) -> AuthResult<()>;
296
297    /// 修改密码
298    ///
299    /// # Arguments
300    /// * `user_id` - 用户 ID
301    /// * `request` - 密码修改请求
302    async fn change_password(&self, user_id: &str, request: &ChangePasswordRequest) -> AuthResult<()>;
303
304    /// 重置密码
305    ///
306    /// # Arguments
307    /// * `identifier` - 用户标识 (用户名/邮箱/手机号)
308    async fn reset_password(&self, identifier: &str) -> AuthResult<()>;
309
310    /// 验证用户权限
311    ///
312    /// # Arguments
313    /// * `user_id` - 用户 ID
314    /// * `permission` - 权限代码
315    async fn check_permission(&self, user_id: &str, permission: &str) -> AuthResult<bool>;
316
317    /// 获取用户角色
318    ///
319    /// # Arguments
320    /// * `user_id` - 用户 ID
321    async fn get_user_roles(&self, user_id: &str) -> AuthResult<Vec<Role>>;
322
323    /// 分配角色
324    ///
325    /// # Arguments
326    /// * `user_id` - 用户 ID
327    /// * `role_id` - 角色 ID
328    async fn assign_role(&self, user_id: &str, role_id: &str) -> AuthResult<()>;
329
330    /// 移除角色
331    ///
332    /// # Arguments
333    /// * `user_id` - 用户 ID
334    /// * `role_id` - 角色 ID
335    async fn remove_role(&self, user_id: &str, role_id: &str) -> AuthResult<()>;
336
337    /// 获取配置
338    fn config(&self) -> &AuthConfig;
339}
340
341/// API Key 认证 trait
342
343pub trait ApiKeyAuth: Send + Sync {
344    /// 验证 API Key
345    ///
346    /// # Arguments
347    /// * `api_key` - API Key
348    async fn validate_api_key(&self, api_key: &str) -> AuthResult<TokenValidation>;
349
350    /// 创建 API Key
351    ///
352    /// # Arguments
353    /// * `user_id` - 用户 ID
354    /// * `name` - Key 名称
355    /// * `expires_in` - 过期时间 (秒),None 表示永不过期
356    async fn create_api_key(&self, user_id: &str, name: &str, expires_in: Option<u64>) -> AuthResult<String>;
357
358    /// 撤销 API Key
359    ///
360    /// # Arguments
361    /// * `api_key` - API Key
362    async fn revoke_api_key(&self, api_key: &str) -> AuthResult<()>;
363
364    /// 列出用户的所有 API Key
365    ///
366    /// # Arguments
367    /// * `user_id` - 用户 ID
368    async fn list_api_keys(&self, user_id: &str) -> AuthResult<Vec<ApiKeyInfo>>;
369}
370
371/// API Key 信息
372#[derive(Debug, Clone, Serialize, Deserialize)]
373pub struct ApiKeyInfo {
374    /// Key ID
375    pub id: String,
376    /// Key 名称
377    pub name: String,
378    /// Key 前缀 (用于识别)
379    pub prefix: String,
380    /// 创建时间
381    pub created_at: i64,
382    /// 过期时间
383    pub expires_at: Option<i64>,
384    /// 最后使用时间
385    pub last_used_at: Option<i64>,
386}
387
388/// 内存认证实现
389pub mod memory {
390    use super::*;
391    use std::{collections::HashMap, sync::Arc};
392    use tokio::sync::RwLock;
393
394    /// 内存用户存储
395    struct UserRecord {
396        info: UserInfo,
397        password_hash: String,
398        roles: Vec<Role>,
399        login_attempts: u32,
400        locked_until: Option<i64>,
401    }
402
403    /// 内存认证服务
404    pub struct MemoryAuthService {
405        config: AuthConfig,
406        users: Arc<RwLock<HashMap<UserId, UserRecord>>>,
407        tokens: Arc<RwLock<HashMap<String, (UserId, i64)>>>,
408        refresh_tokens: Arc<RwLock<HashMap<String, (UserId, i64)>>>,
409    }
410
411    impl MemoryAuthService {
412        /// 创建新的内存认证服务
413        pub fn new(config: AuthConfig) -> Self {
414            Self {
415                config,
416                users: Arc::new(RwLock::new(HashMap::new())),
417                tokens: Arc::new(RwLock::new(HashMap::new())),
418                refresh_tokens: Arc::new(RwLock::new(HashMap::new())),
419            }
420        }
421
422        fn hash_password(password: &str) -> String {
423            format!("hash:{}", password)
424        }
425
426        fn verify_password(password: &str, hash: &str) -> bool {
427            hash == &Self::hash_password(password)
428        }
429
430        fn generate_token() -> String {
431            format!("token_{}", uuid::Uuid::new_v4())
432        }
433
434        fn current_timestamp() -> i64 {
435            chrono::Utc::now().timestamp()
436        }
437    }
438
439    impl Default for MemoryAuthService {
440        fn default() -> Self {
441            Self::new(AuthConfig::default())
442        }
443    }
444
445    impl AuthService for MemoryAuthService {
446        async fn login(&self, credentials: &Credentials) -> AuthResult<AuthToken> {
447            let mut users = self.users.write().await;
448
449            let user = users
450                .values_mut()
451                .find(|u| u.info.username == credentials.identifier || u.info.email.as_deref() == Some(&credentials.identifier))
452                .ok_or(AuthError::InvalidCredentials)?;
453
454            if user.locked_until.map(|t| t > Self::current_timestamp()).unwrap_or(false) {
455                return Err(AuthError::AccountLocked);
456            }
457
458            if !Self::verify_password(&credentials.password, &user.password_hash) {
459                user.login_attempts += 1;
460                if user.login_attempts >= self.config.max_login_attempts {
461                    user.locked_until = Some(Self::current_timestamp() + self.config.lockout_duration as i64);
462                    return Err(AuthError::AccountLocked);
463                }
464                return Err(AuthError::InvalidCredentials);
465            }
466
467            user.login_attempts = 0;
468            user.locked_until = None;
469
470            let access_token = Self::generate_token();
471            let refresh_token = Self::generate_token();
472            let now = Self::current_timestamp();
473
474            self.tokens
475                .write()
476                .await
477                .insert(access_token.clone(), (user.info.id.clone(), now + self.config.token_expires_in as i64));
478            self.refresh_tokens
479                .write()
480                .await
481                .insert(refresh_token.clone(), (user.info.id.clone(), now + self.config.refresh_token_expires_in as i64));
482
483            Ok(AuthToken {
484                access_token,
485                refresh_token: Some(refresh_token),
486                token_type: "Bearer".to_string(),
487                expires_in: self.config.token_expires_in,
488                expires_at: now + self.config.token_expires_in as i64,
489            })
490        }
491
492        async fn logout(&self, token: &str) -> AuthResult<()> {
493            self.tokens.write().await.remove(token);
494            Ok(())
495        }
496
497        async fn refresh_token(&self, refresh_token: &str) -> AuthResult<AuthToken> {
498            let mut refresh_tokens = self.refresh_tokens.write().await;
499            let (user_id, _) =
500                refresh_tokens.remove(refresh_token).ok_or_else(|| AuthError::InvalidToken("Invalid refresh token".into()))?;
501
502            let access_token = Self::generate_token();
503            let new_refresh_token = Self::generate_token();
504            let now = Self::current_timestamp();
505
506            self.tokens
507                .write()
508                .await
509                .insert(access_token.clone(), (user_id.clone(), now + self.config.token_expires_in as i64));
510            refresh_tokens.insert(new_refresh_token.clone(), (user_id, now + self.config.refresh_token_expires_in as i64));
511
512            Ok(AuthToken {
513                access_token,
514                refresh_token: Some(new_refresh_token),
515                token_type: "Bearer".to_string(),
516                expires_in: self.config.token_expires_in,
517                expires_at: now + self.config.token_expires_in as i64,
518            })
519        }
520
521        async fn validate_token(&self, token: &str) -> AuthResult<TokenValidation> {
522            let tokens = self.tokens.read().await;
523            let (user_id, expires_at) = tokens.get(token).ok_or_else(|| AuthError::InvalidToken("Token not found".into()))?;
524
525            if *expires_at < Self::current_timestamp() {
526                return Err(AuthError::TokenExpired);
527            }
528
529            let users = self.users.read().await;
530            let user = users.get(user_id).ok_or_else(|| AuthError::UserNotFound(user_id.clone()))?;
531
532            let permissions: Vec<PermissionCode> = user.roles.iter().flat_map(|r| r.permissions.iter().cloned()).collect();
533
534            Ok(TokenValidation {
535                user_id: user_id.clone(),
536                user: Some(user.info.clone()),
537                roles: user.roles.clone(),
538                permissions,
539                metadata: HashMap::new(),
540            })
541        }
542
543        async fn create_user(&self, request: &CreateUserRequest) -> AuthResult<UserInfo> {
544            let mut users = self.users.write().await;
545
546            if users.values().any(|u| u.info.username == request.username) {
547                return Err(AuthError::UserAlreadyExists(request.username.clone()));
548            }
549
550            let user_id = uuid::Uuid::new_v4().to_string();
551            let now = Self::current_timestamp();
552
553            let info = UserInfo {
554                id: user_id.clone(),
555                username: request.username.clone(),
556                email: request.email.clone(),
557                phone: request.phone.clone(),
558                display_name: request.display_name.clone(),
559                avatar_url: None,
560                verified: false,
561                disabled: false,
562                attributes: request.attributes.clone(),
563                created_at: now,
564                updated_at: now,
565            };
566
567            let record = UserRecord {
568                info: info.clone(),
569                password_hash: Self::hash_password(&request.password),
570                roles: vec![],
571                login_attempts: 0,
572                locked_until: None,
573            };
574
575            users.insert(user_id, record);
576            Ok(info)
577        }
578
579        async fn get_user(&self, user_id: &str) -> AuthResult<UserInfo> {
580            let users = self.users.read().await;
581            users.get(user_id).map(|u| u.info.clone()).ok_or_else(|| AuthError::UserNotFound(user_id.into()))
582        }
583
584        async fn update_user(&self, user_id: &str, request: &UpdateUserRequest) -> AuthResult<UserInfo> {
585            let mut users = self.users.write().await;
586            let user = users.get_mut(user_id).ok_or_else(|| AuthError::UserNotFound(user_id.into()))?;
587
588            if let Some(name) = &request.display_name {
589                user.info.display_name = Some(name.clone());
590            }
591            if let Some(url) = &request.avatar_url {
592                user.info.avatar_url = Some(url.clone());
593            }
594            if let Some(attrs) = &request.attributes {
595                user.info.attributes = attrs.clone();
596            }
597            user.info.updated_at = Self::current_timestamp();
598
599            Ok(user.info.clone())
600        }
601
602        async fn delete_user(&self, user_id: &str) -> AuthResult<()> {
603            let mut users = self.users.write().await;
604            users.remove(user_id).map(|_| ()).ok_or_else(|| AuthError::UserNotFound(user_id.into()))
605        }
606
607        async fn change_password(&self, user_id: &str, request: &ChangePasswordRequest) -> AuthResult<()> {
608            let mut users = self.users.write().await;
609            let user = users.get_mut(user_id).ok_or_else(|| AuthError::UserNotFound(user_id.into()))?;
610
611            if !Self::verify_password(&request.old_password, &user.password_hash) {
612                return Err(AuthError::InvalidCredentials);
613            }
614
615            user.password_hash = Self::hash_password(&request.new_password);
616            user.info.updated_at = Self::current_timestamp();
617            Ok(())
618        }
619
620        async fn reset_password(&self, identifier: &str) -> AuthResult<()> {
621            let users = self.users.read().await;
622            let user = users
623                .values()
624                .find(|u| u.info.username == identifier || u.info.email.as_deref() == Some(identifier))
625                .ok_or_else(|| AuthError::UserNotFound(identifier.into()))?;
626
627            tracing::info!("Password reset requested for user: {}", user.info.id);
628            Ok(())
629        }
630
631        async fn check_permission(&self, user_id: &str, permission: &str) -> AuthResult<bool> {
632            let users = self.users.read().await;
633            let user = users.get(user_id).ok_or_else(|| AuthError::UserNotFound(user_id.into()))?;
634
635            Ok(user.roles.iter().any(|r| r.permissions.iter().any(|p| p == permission)))
636        }
637
638        async fn get_user_roles(&self, user_id: &str) -> AuthResult<Vec<Role>> {
639            let users = self.users.read().await;
640            let user = users.get(user_id).ok_or_else(|| AuthError::UserNotFound(user_id.into()))?;
641            Ok(user.roles.clone())
642        }
643
644        async fn assign_role(&self, user_id: &str, role_id: &str) -> AuthResult<()> {
645            let mut users = self.users.write().await;
646            let user = users.get_mut(user_id).ok_or_else(|| AuthError::UserNotFound(user_id.into()))?;
647
648            if !user.roles.iter().any(|r| r.id == role_id) {
649                user.roles.push(Role { id: role_id.into(), name: role_id.into(), description: None, permissions: vec![] });
650            }
651            Ok(())
652        }
653
654        async fn remove_role(&self, user_id: &str, role_id: &str) -> AuthResult<()> {
655            let mut users = self.users.write().await;
656            let user = users.get_mut(user_id).ok_or_else(|| AuthError::UserNotFound(user_id.into()))?;
657
658            user.roles.retain(|r| r.id != role_id);
659            Ok(())
660        }
661
662        fn config(&self) -> &AuthConfig {
663            &self.config
664        }
665    }
666}
667
668/// 便捷函数:创建内存认证服务
669pub fn memory_auth_service(config: AuthConfig) -> memory::MemoryAuthService {
670    memory::MemoryAuthService::new(config)
671}