1use chrono::{Duration, Utc};
6use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
7use serde::{Deserialize, Serialize};
8
9#[derive(Serialize, Deserialize)]
10pub struct Claims {
11 sub: String,
13 exp: u64,
15 iat: u64,
17}
18
19#[allow(unused)]
20impl Claims {
21 pub fn new(sub: String, expiration: Duration) -> Self {
23 let now = Utc::now();
24 let exp = (now + expiration).timestamp() as u64;
25 Claims {
26 sub,
27 exp,
28 iat: now.timestamp() as u64,
29 }
30 }
31
32 pub fn get_sub(&self) -> &str {
34 &self.sub
35 }
36
37 pub fn get_exp(&self) -> u64 {
39 self.exp
40 }
41
42 pub fn get_iat(&self) -> u64 {
44 self.iat
45 }
46}
47
48#[derive(Clone, Debug)]
49pub enum JwtEncodeSecret {
50 SharedSecret(EncodingKey, Algorithm),
52 RsaKey(EncodingKey, Algorithm),
54 EcdsaKey(EncodingKey, Algorithm),
56}
57
58impl JwtEncodeSecret {
59 pub fn from_shared_secret(secret: &str) -> Self {
60 JwtEncodeSecret::SharedSecret(
61 EncodingKey::from_secret(secret.as_bytes()),
62 Algorithm::HS256,
63 )
64 }
65
66 pub fn from_rsa_pem(pem: &[u8]) -> Result<Self, jsonwebtoken::errors::Error> {
67 Ok(JwtEncodeSecret::RsaKey(
68 EncodingKey::from_rsa_pem(pem)?,
69 Algorithm::RS256,
70 ))
71 }
72
73 pub fn from_ec_pem(pem: &[u8]) -> Result<Self, jsonwebtoken::errors::Error> {
74 Ok(JwtEncodeSecret::EcdsaKey(
75 EncodingKey::from_ec_pem(pem)?,
76 Algorithm::ES256,
77 ))
78 }
79}
80
81#[derive(Clone, Debug)]
82pub enum JwtDecodeSecret {
83 SharedSecret(DecodingKey, Algorithm),
85 RsaKey(DecodingKey, Algorithm),
87 EcdsaKey(DecodingKey, Algorithm),
89}
90
91impl JwtDecodeSecret {
92 pub fn from_shared_secret(secret: &str) -> Self {
93 JwtDecodeSecret::SharedSecret(
94 DecodingKey::from_secret(secret.as_bytes()),
95 Algorithm::HS256,
96 )
97 }
98
99 pub fn from_rsa_pem(pem: &[u8]) -> Result<Self, jsonwebtoken::errors::Error> {
100 Ok(JwtDecodeSecret::RsaKey(
101 DecodingKey::from_rsa_pem(pem)?,
102 Algorithm::RS256,
103 ))
104 }
105
106 pub fn from_ec_pem(pem: &[u8]) -> Result<Self, jsonwebtoken::errors::Error> {
107 Ok(JwtDecodeSecret::EcdsaKey(
108 DecodingKey::from_ec_pem(pem)?,
109 Algorithm::ES256,
110 ))
111 }
112}
113
114pub fn validate_token(
115 secret: &JwtDecodeSecret,
116 token: &str,
117) -> Result<Claims, jsonwebtoken::errors::Error> {
118 let (decoding_key, validation) = match secret {
119 JwtDecodeSecret::SharedSecret(key, alg) => (key, Validation::new(*alg)),
120 JwtDecodeSecret::RsaKey(key, alg) => (key, Validation::new(*alg)),
121 JwtDecodeSecret::EcdsaKey(key, alg) => (key, Validation::new(*alg)),
122 };
123 let token_data = jsonwebtoken::decode::<Claims>(token, decoding_key, &validation)?;
124 Ok(token_data.claims)
125}
126
127#[allow(unused)]
128pub fn generate_token(
129 secret: &JwtEncodeSecret,
130 claims: &Claims,
131) -> Result<String, jsonwebtoken::errors::Error> {
132 let (encoding_key, header) = match secret {
133 JwtEncodeSecret::SharedSecret(key, alg) => (key, Header::new(*alg)),
134 JwtEncodeSecret::RsaKey(key, alg) => (key, Header::new(*alg)),
135 JwtEncodeSecret::EcdsaKey(key, alg) => (key, Header::new(*alg)),
136 };
137 jsonwebtoken::encode(&header, &claims, encoding_key)
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::errors::ErrorKind;

    /// Claims for `test_user` that expire ten minutes from now.
    fn sample_claims() -> Claims {
        Claims::new("test_user".to_string(), Duration::minutes(10))
    }

    /// Signs sample claims with `encode_secret`, validates them with
    /// `decode_secret`, and checks that every claim survives the round trip.
    fn assert_roundtrip(encode_secret: &JwtEncodeSecret, decode_secret: &JwtDecodeSecret) {
        let claims = sample_claims();
        let token = generate_token(encode_secret, &claims).expect("Failed to generate token");
        let decoded_claims =
            validate_token(decode_secret, &token).expect("Failed to validate token");
        assert_eq!(decoded_claims.sub, claims.sub);
        assert_eq!(decoded_claims.exp, claims.exp);
        assert_eq!(decoded_claims.iat, claims.iat);
    }

    /// Validates `token` and returns the rejection error, panicking if it was accepted.
    // Matching explicitly avoids `Result::unwrap_err`, which would require
    // `Claims: Debug`.
    fn expect_rejection(decode_secret: &JwtDecodeSecret, token: &str) -> jsonwebtoken::errors::Error {
        match validate_token(decode_secret, token) {
            Ok(_) => panic!("token was unexpectedly accepted"),
            Err(err) => err,
        }
    }

    #[test]
    fn test_claims_getters() {
        let claims = sample_claims();

        assert_eq!(claims.get_sub(), "test_user");
        assert_eq!(claims.get_exp() - claims.get_iat(), 600);
    }

    #[test]
    fn test_shared_secret() {
        assert_roundtrip(
            &JwtEncodeSecret::from_shared_secret("my_secret"),
            &JwtDecodeSecret::from_shared_secret("my_secret"),
        );
    }

    #[test]
    fn test_rsa_secret() {
        let encode_secret =
            JwtEncodeSecret::from_rsa_pem(include_bytes!("../../test-data/rsa-private.pem"))
                .expect("Failed to create encode secret");
        let decode_secret =
            JwtDecodeSecret::from_rsa_pem(include_bytes!("../../test-data/rsa-public.pem"))
                .expect("Failed to create decode secret");
        assert_roundtrip(&encode_secret, &decode_secret);
    }

    #[test]
    fn test_ecdsa_secret() {
        let encode_secret =
            JwtEncodeSecret::from_ec_pem(include_bytes!("../../test-data/ecdsa-private.pem"))
                .expect("Failed to create encode secret");
        let decode_secret =
            JwtDecodeSecret::from_ec_pem(include_bytes!("../../test-data/ecdsa-public.pem"))
                .expect("Failed to create decode secret");
        assert_roundtrip(&encode_secret, &decode_secret);
    }

    #[test]
    fn test_expired_token() {
        let decode_secret = JwtDecodeSecret::from_shared_secret("my_secret");
        let encode_secret = JwtEncodeSecret::from_shared_secret("my_secret");
        // 100 s in the past: beyond the default validation leeway.
        let claims = Claims::new("test_user".to_string(), Duration::seconds(-100));
        let token = generate_token(&encode_secret, &claims).expect("Failed to generate token");
        let err = expect_rejection(&decode_secret, &token);
        assert!(
            matches!(err.kind(), ErrorKind::ExpiredSignature),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    fn test_wrong_shared_secret_is_rejected() {
        let encode_secret = JwtEncodeSecret::from_shared_secret("my_secret");
        let decode_secret = JwtDecodeSecret::from_shared_secret("other_secret");
        let token =
            generate_token(&encode_secret, &sample_claims()).expect("Failed to generate token");
        let err = expect_rejection(&decode_secret, &token);
        assert!(
            matches!(err.kind(), ErrorKind::InvalidSignature),
            "unexpected error: {:?}",
            err
        );
    }
}