1use crate::{constants::ALGORITHM, errors::SecureError, stores::key_store::KeyStore};
2use jsonwebtoken::{decode_header, DecodingKey, EncodingKey, Header, Validation};
3use openssl::rsa::Rsa;
4
5pub fn encode<T: serde::Serialize>(
7 claims: &T,
8 kid: &str,
9 rsa_key_pair: Rsa<openssl::pkey::Private>,
10) -> Result<String, SecureError> {
11 let der = rsa_key_pair.private_key_to_der()?;
12 let encoding_key: EncodingKey = EncodingKey::from_rsa_der(der.as_ref());
13
14 let mut header = Header::new(ALGORITHM);
15 header.kid = Some(kid.to_string());
16
17 let token = jsonwebtoken::encode(&header, &claims, &encoding_key)
18 .map_err(|e| SecureError::CannotEncodeJwtToken(e.to_string()))?;
19
20 Ok(token)
21}
22
23pub async fn encode_using_store<T: serde::Serialize>(
25 claims: &T,
26 key_store: &dyn KeyStore,
27) -> Result<String, SecureError> {
28 let (kid, rsa_key_pair) = key_store.get_current_key().await?;
29 encode(claims, &kid, rsa_key_pair)
30}
31
32pub fn decode<T: serde::de::DeserializeOwned + Clone>(
34 encoded_jwt: &str,
35 rsa_key_pair: Rsa<openssl::pkey::Private>,
36) -> Result<jsonwebtoken::TokenData<T>, SecureError> {
37 let public_key = rsa_key_pair.public_key_to_pem()?;
38 let decoding_key = DecodingKey::from_rsa_pem(public_key.as_ref())?;
39 let mut validation = Validation::new(ALGORITHM);
40 validation.validate_aud = false; jsonwebtoken::decode(encoded_jwt, &decoding_key, &validation)
43 .map_err(|e| SecureError::CannotDecodeJwtToken(e.to_string()))
44}
45
46pub fn insecure_decode<T: serde::de::DeserializeOwned + Clone>(
50 encoded_jwt: &str,
51) -> Result<jsonwebtoken::TokenData<T>, SecureError> {
52 jsonwebtoken::dangerous::insecure_decode(encoded_jwt)
54 .map_err(|e| SecureError::CannotDecodeJwtToken(e.to_string()))
55}
56
57pub async fn decode_using_store<T: serde::de::DeserializeOwned + Clone>(
59 encoded_jwt: &str,
60 key_store: &dyn KeyStore,
61) -> Result<jsonwebtoken::TokenData<T>, SecureError> {
62 let header =
63 decode_header(encoded_jwt).map_err(|e| SecureError::CannotDecodeJwtToken(e.to_string()))?;
64
65 if let Some(kid) = header.kid {
66 let key = key_store.get_key(&kid).await?;
68 decode(encoded_jwt, key)
69 } else {
70 Err(SecureError::InvalidKeyIdError(
71 "No kid present in JWT header".to_string(),
72 ))
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use crate::secure::generate_rsa_key_pair;
    use crate::stores::key_store::KeyStore;
    use async_trait::async_trait;
    use chrono::{Duration, Utc};
    use serde::{Deserialize, Serialize};
    use std::collections::HashMap;

    /// Shorthand for the private RSA key type used throughout these tests.
    type PrivateRsa = Rsa<openssl::pkey::Private>;

    /// Passphrase used both to generate and to load the test key.
    const PASSPHRASE: &str = "test_passphrase";

    /// Minimal claim set used to round-trip tokens in the tests.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestClaims {
        sub: String,
        exp: i64,
        iat: i64,
    }

    /// In-memory `KeyStore` backed by a map of key id to key pair.
    struct MockKeyStore {
        keys: HashMap<String, PrivateRsa>,
        current_kid: String,
    }

    #[async_trait]
    impl KeyStore for MockKeyStore {
        /// Returns a copy of every stored key; `_limit` is ignored.
        async fn get_current_keys(
            &self,
            _limit: i64,
        ) -> Result<HashMap<String, PrivateRsa>, SecureError> {
            Ok(self.keys.clone())
        }

        /// Returns the key stored under `current_kid`, or `EmptyKeys` if absent.
        async fn get_current_key(&self) -> Result<(String, PrivateRsa), SecureError> {
            let current = self
                .keys
                .get(&self.current_kid)
                .ok_or(SecureError::EmptyKeys)?;
            Ok((self.current_kid.clone(), current.clone()))
        }

        /// Returns the key stored under `kid`, or `InvalidKeyId` if absent.
        async fn get_key(&self, kid: &str) -> Result<PrivateRsa, SecureError> {
            match self.keys.get(kid) {
                Some(found) => Ok(found.clone()),
                None => Err(SecureError::InvalidKeyId),
            }
        }
    }

    /// Generates a fresh key pair and loads its private key from the PEM text.
    fn load_test_key() -> PrivateRsa {
        let (_, pem_string) = generate_rsa_key_pair(PASSPHRASE).unwrap();
        Rsa::private_key_from_pem_passphrase(pem_string.as_bytes(), PASSPHRASE.as_bytes())
            .unwrap()
    }

    /// Builds claims for `test_user` issued now and expiring in one hour.
    fn one_hour_claims() -> TestClaims {
        let issued_at = Utc::now();
        TestClaims {
            sub: "test_user".to_string(),
            exp: (issued_at + Duration::hours(1)).timestamp(),
            iat: issued_at.timestamp(),
        }
    }

    #[tokio::test]
    async fn test_encode_decode_using_async_store() {
        let kid = "test_key_1";
        let key_store = MockKeyStore {
            keys: HashMap::from([(kid.to_string(), load_test_key())]),
            current_kid: kid.to_string(),
        };
        let claims = one_hour_claims();

        let token = encode_using_store(&claims, &key_store)
            .await
            .expect("Failed to encode token");

        let decoded = decode_using_store::<TestClaims>(&token, &key_store)
            .await
            .expect("Failed to decode token");

        assert_eq!(decoded.claims.sub, "test_user");
        assert_eq!(decoded.header.kid, Some(kid.to_string()));
    }

    #[tokio::test]
    async fn test_decode_missing_kid() {
        let rsa_key = load_test_key();
        let key_store = MockKeyStore {
            keys: HashMap::new(),
            current_kid: "test_key".to_string(),
        };
        let claims = one_hour_claims();

        // Sign directly with a header that never gets a `kid`, bypassing `encode`.
        let private_der = rsa_key.private_key_to_der().unwrap();
        let signing_key = EncodingKey::from_rsa_der(&private_der);
        let kidless_header = Header::new(ALGORITHM);
        let token = jsonwebtoken::encode(&kidless_header, &claims, &signing_key).unwrap();

        let outcome = decode_using_store::<TestClaims>(&token, &key_store).await;

        assert!(matches!(outcome, Err(SecureError::InvalidKeyIdError(_))));
    }
}