// armature_auth/strategy.rs
use crate::{AuthError, AuthUser, Result};
use armature_jwt::JwtManager;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
7
8#[async_trait]
10pub trait AuthStrategy<T: AuthUser>: Send + Sync {
11 async fn authenticate(&self, credentials: &dyn std::any::Any) -> Result<T>;
13}
14
15pub struct LocalStrategy<T: AuthUser> {
17 _phantom: std::marker::PhantomData<T>,
18}
19
20impl<T: AuthUser> LocalStrategy<T> {
21 pub fn new() -> Self {
22 Self {
23 _phantom: std::marker::PhantomData,
24 }
25 }
26}
27
28impl<T: AuthUser> Default for LocalStrategy<T> {
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct LocalCredentials {
37 pub username: String,
38 pub password: String,
39}
40
41pub struct JwtStrategy<T: AuthUser> {
43 _jwt_manager: JwtManager,
44 _phantom: std::marker::PhantomData<T>,
45}
46
47impl<T: AuthUser> JwtStrategy<T> {
48 pub fn new(jwt_manager: JwtManager) -> Self {
49 Self {
50 _jwt_manager: jwt_manager,
51 _phantom: std::marker::PhantomData,
52 }
53 }
54
55 pub fn extract_token<'a>(&self, header: &'a str) -> Result<&'a str> {
57 header
58 .strip_prefix("Bearer ")
59 .ok_or_else(|| AuthError::InvalidToken("Invalid Bearer token format".to_string()))
60 }
61}
62
/// Credentials carrying a JWT string.
///
/// `Debug` is implemented by hand rather than derived so that the token,
/// which is a secret, never ends up in logs or panic messages via `{:?}`.
#[derive(Clone)]
pub struct JwtCredentials {
    /// Encoded token string; redacted in `Debug` output.
    pub token: String,
}

impl std::fmt::Debug for JwtCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JwtCredentials")
            // Never print the real token.
            .field("token", &"[REDACTED]")
            .finish()
    }
}
68
#[cfg(test)]
mod tests {
    use super::*;

    /// `LocalCredentials` exposes both fields exactly as they were set.
    #[test]
    fn test_local_credentials() {
        let creds = LocalCredentials {
            username: "user@example.com".to_string(),
            password: "password123".to_string(),
        };

        assert_eq!(creds.username, "user@example.com");
        assert_eq!(creds.password, "password123");
    }

    /// `extract_token` returns the part after `Bearer ` and rejects a header
    /// that lacks the scheme.
    #[test]
    fn test_jwt_token_extraction() {
        use crate::UserContext;
        use armature_jwt::JwtConfig;

        let config = JwtConfig::new("test-secret".to_string());
        let jwt_manager = JwtManager::new(config).unwrap();
        let strategy = JwtStrategy::<UserContext>::new(jwt_manager);

        // Check the extracted value itself, not just that extraction succeeded.
        let valid_header = "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...";
        assert_eq!(
            strategy.extract_token(valid_header).ok(),
            Some("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...")
        );

        let invalid_header = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...";
        assert!(strategy.extract_token(invalid_header).is_err());
    }
}