Skip to main content

wae_authentication/
lib.rs

1#![doc = include_str!("readme.md")]
2#![warn(missing_docs)]
3
4pub mod csrf;
5pub mod jwt;
6pub mod oauth2;
7pub mod password;
8pub mod rate_limit;
9pub mod saml;
10pub mod totp;
11
12use crate::password::{PasswordHashConfig, PasswordHasherService};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use wae_types::WaeError;
16
17/// 认证操作结果类型
18pub type AuthResult<T> = Result<T, WaeError>;
19
20/// 用户 ID 类型
21pub type UserId = String;
22
23/// 角色 ID 类型
24pub type RoleId = String;
25
26/// 权限代码类型
27pub type PermissionCode = String;
28
29/// 用户信息
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct UserInfo {
32    /// 用户 ID
33    pub id: UserId,
34    /// 用户名
35    pub username: String,
36    /// 邮箱
37    pub email: Option<String>,
38    /// 手机号
39    pub phone: Option<String>,
40    /// 显示名称
41    pub display_name: Option<String>,
42    /// 头像 URL
43    pub avatar_url: Option<String>,
44    /// 是否已验证
45    pub verified: bool,
46    /// 是否已禁用
47    pub disabled: bool,
48    /// 自定义属性
49    pub attributes: HashMap<String, serde_json::Value>,
50    /// 创建时间
51    pub created_at: i64,
52    /// 更新时间
53    pub updated_at: i64,
54}
55
56/// 用户凭证
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct Credentials {
59    /// 用户名或邮箱
60    pub identifier: String,
61    /// 密码
62    pub password: String,
63    /// 额外参数
64    pub extra: HashMap<String, String>,
65}
66
67/// 认证 Token
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct AuthToken {
70    /// 访问令牌
71    pub access_token: String,
72    /// 刷新令牌
73    pub refresh_token: Option<String>,
74    /// 令牌类型
75    pub token_type: String,
76    /// 过期时间 (秒)
77    pub expires_in: u64,
78    /// 过期时间戳
79    pub expires_at: i64,
80}
81
82/// Token 验证结果
83#[derive(Debug, Clone)]
84pub struct TokenValidation {
85    /// 用户 ID
86    pub user_id: UserId,
87    /// 用户信息
88    pub user: Option<UserInfo>,
89    /// 角色
90    pub roles: Vec<Role>,
91    /// 权限
92    pub permissions: Vec<PermissionCode>,
93    /// Token 元数据
94    pub metadata: HashMap<String, String>,
95}
96
97/// 角色信息
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct Role {
100    /// 角色 ID
101    pub id: RoleId,
102    /// 角色名称
103    pub name: String,
104    /// 角色描述
105    pub description: Option<String>,
106    /// 权限列表
107    pub permissions: Vec<PermissionCode>,
108}
109
110/// 用户创建请求
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct CreateUserRequest {
113    /// 用户名
114    pub username: String,
115    /// 密码
116    pub password: String,
117    /// 邮箱
118    pub email: Option<String>,
119    /// 手机号
120    pub phone: Option<String>,
121    /// 显示名称
122    pub display_name: Option<String>,
123    /// 自定义属性
124    pub attributes: HashMap<String, serde_json::Value>,
125}
126
127/// 用户更新请求
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct UpdateUserRequest {
130    /// 显示名称
131    pub display_name: Option<String>,
132    /// 头像 URL
133    pub avatar_url: Option<String>,
134    /// 自定义属性
135    pub attributes: Option<HashMap<String, serde_json::Value>>,
136}
137
138/// 密码修改请求
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct ChangePasswordRequest {
141    /// 旧密码
142    pub old_password: String,
143    /// 新密码
144    pub new_password: String,
145}
146
147/// 认证配置
148#[derive(Debug, Clone)]
149pub struct AuthConfig {
150    /// Token 过期时间 (秒)
151    pub token_expires_in: u64,
152    /// 刷新 Token 过期时间 (秒)
153    pub refresh_token_expires_in: u64,
154    /// Token 签发者
155    pub issuer: String,
156    /// Token 受众
157    pub audience: String,
158    /// 密码最小长度
159    pub password_min_length: usize,
160    /// 是否要求密码包含数字
161    pub password_require_digit: bool,
162    /// 是否要求密码包含特殊字符
163    pub password_require_special: bool,
164    /// 登录失败锁定阈值
165    pub max_login_attempts: u32,
166    /// 锁定时间 (秒)
167    pub lockout_duration: u64,
168    /// 密码哈希配置
169    pub password_hash_config: PasswordHashConfig,
170}
171
172impl Default for AuthConfig {
173    fn default() -> Self {
174        Self {
175            token_expires_in: 3600,
176            refresh_token_expires_in: 86400 * 7,
177            issuer: "wae-authentication".to_string(),
178            audience: "wae-api".to_string(),
179            password_min_length: 8,
180            password_require_digit: true,
181            password_require_special: false,
182            max_login_attempts: 5,
183            lockout_duration: 1800,
184            password_hash_config: PasswordHashConfig::default(),
185        }
186    }
187}
188
189/// 认证服务 trait
190
191pub trait AuthService: Send + Sync {
192    /// 用户登录
193    ///
194    /// # Arguments
195    /// * `credentials` - 用户凭证
196    ///
197    /// # Returns
198    /// 认证 Token
199    async fn login(&self, credentials: &Credentials) -> AuthResult<AuthToken>;
200
201    /// 用户登出
202    ///
203    /// # Arguments
204    /// * `token` - 访问令牌
205    async fn logout(&self, token: &str) -> AuthResult<()>;
206
207    /// 刷新 Token
208    ///
209    /// # Arguments
210    /// * `refresh_token` - 刷新令牌
211    async fn refresh_token(&self, refresh_token: &str) -> AuthResult<AuthToken>;
212
213    /// 验证 Token
214    ///
215    /// # Arguments
216    /// * `token` - 访问令牌
217    async fn validate_token(&self, token: &str) -> AuthResult<TokenValidation>;
218
219    /// 创建用户
220    ///
221    /// # Arguments
222    /// * `request` - 创建请求
223    async fn create_user(&self, request: &CreateUserRequest) -> AuthResult<UserInfo>;
224
225    /// 获取用户信息
226    ///
227    /// # Arguments
228    /// * `user_id` - 用户 ID
229    async fn get_user(&self, user_id: &str) -> AuthResult<UserInfo>;
230
231    /// 更新用户信息
232    ///
233    /// # Arguments
234    /// * `user_id` - 用户 ID
235    /// * `request` - 更新请求
236    async fn update_user(&self, user_id: &str, request: &UpdateUserRequest) -> AuthResult<UserInfo>;
237
238    /// 删除用户
239    ///
240    /// # Arguments
241    /// * `user_id` - 用户 ID
242    async fn delete_user(&self, user_id: &str) -> AuthResult<()>;
243
244    /// 修改密码
245    ///
246    /// # Arguments
247    /// * `user_id` - 用户 ID
248    /// * `request` - 密码修改请求
249    async fn change_password(&self, user_id: &str, request: &ChangePasswordRequest) -> AuthResult<()>;
250
251    /// 重置密码
252    ///
253    /// # Arguments
254    /// * `identifier` - 用户标识 (用户名/邮箱/手机号)
255    async fn reset_password(&self, identifier: &str) -> AuthResult<()>;
256
257    /// 验证用户权限
258    ///
259    /// # Arguments
260    /// * `user_id` - 用户 ID
261    /// * `permission` - 权限代码
262    async fn check_permission(&self, user_id: &str, permission: &str) -> AuthResult<bool>;
263
264    /// 获取用户角色
265    ///
266    /// # Arguments
267    /// * `user_id` - 用户 ID
268    async fn get_user_roles(&self, user_id: &str) -> AuthResult<Vec<Role>>;
269
270    /// 分配角色
271    ///
272    /// # Arguments
273    /// * `user_id` - 用户 ID
274    /// * `role_id` - 角色 ID
275    async fn assign_role(&self, user_id: &str, role_id: &str) -> AuthResult<()>;
276
277    /// 移除角色
278    ///
279    /// # Arguments
280    /// * `user_id` - 用户 ID
281    /// * `role_id` - 角色 ID
282    async fn remove_role(&self, user_id: &str, role_id: &str) -> AuthResult<()>;
283
284    /// 获取配置
285    fn config(&self) -> &AuthConfig;
286}
287
288/// API Key 认证 trait
289
290pub trait ApiKeyAuth: Send + Sync {
291    /// 验证 API Key
292    ///
293    /// # Arguments
294    /// * `api_key` - API Key
295    async fn validate_api_key(&self, api_key: &str) -> AuthResult<TokenValidation>;
296
297    /// 创建 API Key
298    ///
299    /// # Arguments
300    /// * `user_id` - 用户 ID
301    /// * `name` - Key 名称
302    /// * `expires_in` - 过期时间 (秒),None 表示永不过期
303    async fn create_api_key(&self, user_id: &str, name: &str, expires_in: Option<u64>) -> AuthResult<String>;
304
305    /// 撤销 API Key
306    ///
307    /// # Arguments
308    /// * `api_key` - API Key
309    async fn revoke_api_key(&self, api_key: &str) -> AuthResult<()>;
310
311    /// 列出用户的所有 API Key
312    ///
313    /// # Arguments
314    /// * `user_id` - 用户 ID
315    async fn list_api_keys(&self, user_id: &str) -> AuthResult<Vec<ApiKeyInfo>>;
316}
317
318/// API Key 信息
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct ApiKeyInfo {
321    /// Key ID
322    pub id: String,
323    /// Key 名称
324    pub name: String,
325    /// Key 前缀 (用于识别)
326    pub prefix: String,
327    /// 创建时间
328    pub created_at: i64,
329    /// 过期时间
330    pub expires_at: Option<i64>,
331    /// 最后使用时间
332    pub last_used_at: Option<i64>,
333}
334
335/// 内存认证实现
336pub mod memory {
337    use super::*;
338    use std::{collections::HashMap, sync::Arc};
339    use tokio::sync::RwLock;
340
341    /// 内存用户存储
342    struct UserRecord {
343        info: UserInfo,
344        password_hash: String,
345        roles: Vec<Role>,
346        login_attempts: u32,
347        locked_until: Option<i64>,
348    }
349
350    /// 内存认证服务
351    pub struct MemoryAuthService {
352        config: AuthConfig,
353        password_hasher: Arc<PasswordHasherService>,
354        users: Arc<RwLock<HashMap<UserId, UserRecord>>>,
355        tokens: Arc<RwLock<HashMap<String, (UserId, i64)>>>,
356        refresh_tokens: Arc<RwLock<HashMap<String, (UserId, i64)>>>,
357    }
358
359    impl MemoryAuthService {
360        /// 创建新的内存认证服务
361        pub fn new(config: AuthConfig) -> Self {
362            let password_hasher = PasswordHasherService::new(config.password_hash_config.clone());
363            Self {
364                config,
365                password_hasher: Arc::new(password_hasher),
366                users: Arc::new(RwLock::new(HashMap::new())),
367                tokens: Arc::new(RwLock::new(HashMap::new())),
368                refresh_tokens: Arc::new(RwLock::new(HashMap::new())),
369            }
370        }
371
372        fn hash_password(&self, password: &str) -> AuthResult<String> {
373            self.password_hasher.hash_password(password).map_err(|e| e.into())
374        }
375
376        fn verify_password(&self, password: &str, hash: &str) -> AuthResult<bool> {
377            self.password_hasher.verify_password(password, hash).map_err(|e| e.into())
378        }
379
380        fn generate_token() -> String {
381            format!("token_{}", uuid::Uuid::new_v4())
382        }
383
384        fn current_timestamp() -> i64 {
385            chrono::Utc::now().timestamp()
386        }
387    }
388
389    impl Default for MemoryAuthService {
390        fn default() -> Self {
391            Self::new(AuthConfig::default())
392        }
393    }
394
395    impl AuthService for MemoryAuthService {
396        async fn login(&self, credentials: &Credentials) -> AuthResult<AuthToken> {
397            let mut users = self.users.write().await;
398
399            let user = users
400                .values_mut()
401                .find(|u| u.info.username == credentials.identifier || u.info.email.as_deref() == Some(&credentials.identifier))
402                .ok_or(WaeError::invalid_credentials())?;
403
404            if user.locked_until.map(|t| t > Self::current_timestamp()).unwrap_or(false) {
405                return Err(WaeError::account_locked());
406            }
407
408            if !self.verify_password(&credentials.password, &user.password_hash)? {
409                user.login_attempts += 1;
410                if user.login_attempts >= self.config.max_login_attempts {
411                    user.locked_until = Some(Self::current_timestamp() + self.config.lockout_duration as i64);
412                    return Err(WaeError::account_locked());
413                }
414                return Err(WaeError::invalid_credentials());
415            }
416
417            user.login_attempts = 0;
418            user.locked_until = None;
419
420            let access_token = Self::generate_token();
421            let refresh_token = Self::generate_token();
422            let now = Self::current_timestamp();
423
424            self.tokens
425                .write()
426                .await
427                .insert(access_token.clone(), (user.info.id.clone(), now + self.config.token_expires_in as i64));
428            self.refresh_tokens
429                .write()
430                .await
431                .insert(refresh_token.clone(), (user.info.id.clone(), now + self.config.refresh_token_expires_in as i64));
432
433            Ok(AuthToken {
434                access_token,
435                refresh_token: Some(refresh_token),
436                token_type: "Bearer".to_string(),
437                expires_in: self.config.token_expires_in,
438                expires_at: now + self.config.token_expires_in as i64,
439            })
440        }
441
442        async fn logout(&self, token: &str) -> AuthResult<()> {
443            self.tokens.write().await.remove(token);
444            Ok(())
445        }
446
447        async fn refresh_token(&self, refresh_token: &str) -> AuthResult<AuthToken> {
448            let mut refresh_tokens = self.refresh_tokens.write().await;
449            let (user_id, _) =
450                refresh_tokens.remove(refresh_token).ok_or_else(|| WaeError::invalid_token("Invalid refresh token"))?;
451
452            let access_token = Self::generate_token();
453            let new_refresh_token = Self::generate_token();
454            let now = Self::current_timestamp();
455
456            self.tokens
457                .write()
458                .await
459                .insert(access_token.clone(), (user_id.clone(), now + self.config.token_expires_in as i64));
460            refresh_tokens.insert(new_refresh_token.clone(), (user_id, now + self.config.refresh_token_expires_in as i64));
461
462            Ok(AuthToken {
463                access_token,
464                refresh_token: Some(new_refresh_token),
465                token_type: "Bearer".to_string(),
466                expires_in: self.config.token_expires_in,
467                expires_at: now + self.config.token_expires_in as i64,
468            })
469        }
470
471        async fn validate_token(&self, token: &str) -> AuthResult<TokenValidation> {
472            let tokens = self.tokens.read().await;
473            let (user_id, expires_at) = tokens.get(token).ok_or_else(|| WaeError::invalid_token("Token not found"))?;
474
475            if *expires_at < Self::current_timestamp() {
476                return Err(WaeError::token_expired());
477            }
478
479            let users = self.users.read().await;
480            let user = users.get(user_id).ok_or_else(|| WaeError::user_not_found(user_id.clone()))?;
481
482            let permissions: Vec<PermissionCode> = user.roles.iter().flat_map(|r| r.permissions.iter().cloned()).collect();
483
484            Ok(TokenValidation {
485                user_id: user_id.clone(),
486                user: Some(user.info.clone()),
487                roles: user.roles.clone(),
488                permissions,
489                metadata: HashMap::new(),
490            })
491        }
492
493        async fn create_user(&self, request: &CreateUserRequest) -> AuthResult<UserInfo> {
494            let mut users = self.users.write().await;
495
496            if users.values().any(|u| u.info.username == request.username) {
497                return Err(WaeError::user_already_exists(request.username.clone()));
498            }
499
500            let user_id = uuid::Uuid::new_v4().to_string();
501            let now = Self::current_timestamp();
502            let password_hash = self.hash_password(&request.password)?;
503
504            let info = UserInfo {
505                id: user_id.clone(),
506                username: request.username.clone(),
507                email: request.email.clone(),
508                phone: request.phone.clone(),
509                display_name: request.display_name.clone(),
510                avatar_url: None,
511                verified: false,
512                disabled: false,
513                attributes: request.attributes.clone(),
514                created_at: now,
515                updated_at: now,
516            };
517
518            let record = UserRecord { info: info.clone(), password_hash, roles: vec![], login_attempts: 0, locked_until: None };
519
520            users.insert(user_id, record);
521            Ok(info)
522        }
523
524        async fn get_user(&self, user_id: &str) -> AuthResult<UserInfo> {
525            let users = self.users.read().await;
526            users.get(user_id).map(|u| u.info.clone()).ok_or_else(|| WaeError::user_not_found(user_id))
527        }
528
529        async fn update_user(&self, user_id: &str, request: &UpdateUserRequest) -> AuthResult<UserInfo> {
530            let mut users = self.users.write().await;
531            let user = users.get_mut(user_id).ok_or_else(|| WaeError::user_not_found(user_id))?;
532
533            if let Some(name) = &request.display_name {
534                user.info.display_name = Some(name.clone());
535            }
536            if let Some(url) = &request.avatar_url {
537                user.info.avatar_url = Some(url.clone());
538            }
539            if let Some(attrs) = &request.attributes {
540                user.info.attributes = attrs.clone();
541            }
542            user.info.updated_at = Self::current_timestamp();
543
544            Ok(user.info.clone())
545        }
546
547        async fn delete_user(&self, user_id: &str) -> AuthResult<()> {
548            let mut users = self.users.write().await;
549            users.remove(user_id).map(|_| ()).ok_or_else(|| WaeError::user_not_found(user_id))
550        }
551
552        async fn change_password(&self, user_id: &str, request: &ChangePasswordRequest) -> AuthResult<()> {
553            let mut users = self.users.write().await;
554            let user = users.get_mut(user_id).ok_or_else(|| WaeError::user_not_found(user_id))?;
555
556            if !self.verify_password(&request.old_password, &user.password_hash)? {
557                return Err(WaeError::invalid_credentials());
558            }
559
560            user.password_hash = self.hash_password(&request.new_password)?;
561            user.info.updated_at = Self::current_timestamp();
562            Ok(())
563        }
564
565        async fn reset_password(&self, identifier: &str) -> AuthResult<()> {
566            let users = self.users.read().await;
567            let user = users
568                .values()
569                .find(|u| u.info.username == identifier || u.info.email.as_deref() == Some(identifier))
570                .ok_or_else(|| WaeError::user_not_found(identifier))?;
571
572            tracing::info!("Password reset requested for user: {}", user.info.id);
573            Ok(())
574        }
575
576        async fn check_permission(&self, user_id: &str, permission: &str) -> AuthResult<bool> {
577            let users = self.users.read().await;
578            let user = users.get(user_id).ok_or_else(|| WaeError::user_not_found(user_id))?;
579
580            Ok(user.roles.iter().any(|r| r.permissions.iter().any(|p| p == permission)))
581        }
582
583        async fn get_user_roles(&self, user_id: &str) -> AuthResult<Vec<Role>> {
584            let users = self.users.read().await;
585            let user = users.get(user_id).ok_or_else(|| WaeError::user_not_found(user_id))?;
586            Ok(user.roles.clone())
587        }
588
589        async fn assign_role(&self, user_id: &str, role_id: &str) -> AuthResult<()> {
590            let mut users = self.users.write().await;
591            let user = users.get_mut(user_id).ok_or_else(|| WaeError::user_not_found(user_id))?;
592
593            if !user.roles.iter().any(|r| r.id == role_id) {
594                user.roles.push(Role { id: role_id.into(), name: role_id.into(), description: None, permissions: vec![] });
595            }
596            Ok(())
597        }
598
599        async fn remove_role(&self, user_id: &str, role_id: &str) -> AuthResult<()> {
600            let mut users = self.users.write().await;
601            let user = users.get_mut(user_id).ok_or_else(|| WaeError::user_not_found(user_id))?;
602
603            user.roles.retain(|r| r.id != role_id);
604            Ok(())
605        }
606
607        fn config(&self) -> &AuthConfig {
608            &self.config
609        }
610    }
611}
612
613/// 便捷函数:创建内存认证服务
614pub fn memory_auth_service(config: AuthConfig) -> memory::MemoryAuthService {
615    memory::MemoryAuthService::new(config)
616}