atomic_lti/
jwt.rs

1use crate::{constants::ALGORITHM, errors::SecureError, stores::key_store::KeyStore};
2use jsonwebtoken::{decode_header, DecodingKey, EncodingKey, Header, Validation};
3use openssl::rsa::Rsa;
4
5/// Encode a JSON Web Token (JWT) asynchronously
6pub fn encode<T: serde::Serialize>(
7  claims: &T,
8  kid: &str,
9  rsa_key_pair: Rsa<openssl::pkey::Private>,
10) -> Result<String, SecureError> {
11  let der = rsa_key_pair.private_key_to_der()?;
12  let encoding_key: EncodingKey = EncodingKey::from_rsa_der(der.as_ref());
13
14  let mut header = Header::new(ALGORITHM);
15  header.kid = Some(kid.to_string());
16
17  let token = jsonwebtoken::encode(&header, &claims, &encoding_key)
18    .map_err(|e| SecureError::CannotEncodeJwtToken(e.to_string()))?;
19
20  Ok(token)
21}
22
23/// Encode a JWT using an async key store
24pub async fn encode_using_store<T: serde::Serialize>(
25  claims: &T,
26  key_store: &dyn KeyStore,
27) -> Result<String, SecureError> {
28  let (kid, rsa_key_pair) = key_store.get_current_key().await?;
29  encode(claims, &kid, rsa_key_pair)
30}
31
32/// Decode a JSON Web Token (JWT)
33pub fn decode<T: serde::de::DeserializeOwned + Clone>(
34  encoded_jwt: &str,
35  rsa_key_pair: Rsa<openssl::pkey::Private>,
36) -> Result<jsonwebtoken::TokenData<T>, SecureError> {
37  let public_key = rsa_key_pair.public_key_to_pem()?;
38  let decoding_key = DecodingKey::from_rsa_pem(public_key.as_ref())?;
39  let mut validation = Validation::new(ALGORITHM);
40  validation.validate_aud = false; // Don't validate audience since we don't know what to expect
41
42  jsonwebtoken::decode(encoded_jwt, &decoding_key, &validation)
43    .map_err(|e| SecureError::CannotDecodeJwtToken(e.to_string()))
44}
45
46/// Decode a JWT without validation
47/// WARNING: This is insecure and should only be used for testing or when signature validation
48/// is explicitly not required (e.g., extracting header information).
49pub fn insecure_decode<T: serde::de::DeserializeOwned + Clone>(
50  encoded_jwt: &str,
51) -> Result<jsonwebtoken::TokenData<T>, SecureError> {
52  // Use the recommended dangerous::insecure_decode function
53  jsonwebtoken::dangerous::insecure_decode(encoded_jwt)
54    .map_err(|e| SecureError::CannotDecodeJwtToken(e.to_string()))
55}
56
57/// Decode a JWT using an async key store
58pub async fn decode_using_store<T: serde::de::DeserializeOwned + Clone>(
59  encoded_jwt: &str,
60  key_store: &dyn KeyStore,
61) -> Result<jsonwebtoken::TokenData<T>, SecureError> {
62  let header =
63    decode_header(encoded_jwt).map_err(|e| SecureError::CannotDecodeJwtToken(e.to_string()))?;
64
65  if let Some(kid) = header.kid {
66    // Find the key that matches the kid in the JWT header
67    let key = key_store.get_key(&kid).await?;
68    decode(encoded_jwt, key)
69  } else {
70    Err(SecureError::InvalidKeyIdError(
71      "No kid present in JWT header".to_string(),
72    ))
73  }
74}
75
#[cfg(test)]
mod tests {
  use super::*;
  use crate::secure::generate_rsa_key_pair;
  use crate::stores::key_store::KeyStore;
  use async_trait::async_trait;
  use chrono::{Duration, Utc};
  use serde::{Deserialize, Serialize};
  use std::collections::HashMap;

  /// Passphrase used to protect, and then unlock, the generated test keys.
  const TEST_PASSPHRASE: &str = "test_passphrase";

  #[derive(Debug, Clone, Serialize, Deserialize)]
  struct TestClaims {
    sub: String,
    exp: i64,
    iat: i64,
  }

  impl TestClaims {
    /// Claims for `test_user`, issued now and expiring one hour from now.
    fn valid_for_one_hour() -> Self {
      let issued_at = Utc::now();
      Self {
        sub: "test_user".to_string(),
        exp: (issued_at + Duration::hours(1)).timestamp(),
        iat: issued_at.timestamp(),
      }
    }
  }

  /// In-memory key store keyed by `kid`; `current_kid` selects the key used for signing.
  struct MockKeyStore {
    keys: HashMap<String, Rsa<openssl::pkey::Private>>,
    current_kid: String,
  }

  #[async_trait]
  impl KeyStore for MockKeyStore {
    async fn get_current_keys(
      &self,
      _limit: i64,
    ) -> Result<HashMap<String, Rsa<openssl::pkey::Private>>, SecureError> {
      // The limit is ignored: every stored key is returned.
      Ok(self.keys.clone())
    }

    async fn get_current_key(&self) -> Result<(String, Rsa<openssl::pkey::Private>), SecureError> {
      let key = self
        .keys
        .get(&self.current_kid)
        .ok_or(SecureError::EmptyKeys)?;
      Ok((self.current_kid.clone(), key.clone()))
    }

    async fn get_key(&self, kid: &str) -> Result<Rsa<openssl::pkey::Private>, SecureError> {
      let key = self.keys.get(kid).ok_or(SecureError::InvalidKeyId)?;
      Ok(key.clone())
    }
  }

  /// Generate a fresh passphrase-protected RSA key and load it back as a private key.
  fn generate_test_key() -> Rsa<openssl::pkey::Private> {
    let (_, pem_string) = generate_rsa_key_pair(TEST_PASSPHRASE).unwrap();
    Rsa::private_key_from_pem_passphrase(pem_string.as_bytes(), TEST_PASSPHRASE.as_bytes())
      .unwrap()
  }

  #[tokio::test]
  async fn test_encode_decode_using_async_store() {
    let kid = "test_key_1";
    let key_store = MockKeyStore {
      keys: HashMap::from([(kid.to_string(), generate_test_key())]),
      current_kid: kid.to_string(),
    };
    let claims = TestClaims::valid_for_one_hour();

    // Round trip: sign with the store's current key, then verify via the header's `kid`.
    let token = encode_using_store(&claims, &key_store)
      .await
      .expect("Failed to encode token");
    let decoded = decode_using_store::<TestClaims>(&token, &key_store)
      .await
      .expect("Failed to decode token");

    assert_eq!(decoded.claims.sub, "test_user");
    assert_eq!(decoded.header.kid, Some(kid.to_string()));
  }

  #[tokio::test]
  async fn test_decode_missing_kid() {
    let rsa_key = generate_test_key();
    let key_store = MockKeyStore {
      keys: HashMap::new(),
      current_kid: "test_key".to_string(),
    };
    let claims = TestClaims::valid_for_one_hour();

    // Sign directly with a header that carries no `kid`.
    let der = rsa_key.private_key_to_der().unwrap();
    let token = jsonwebtoken::encode(
      &Header::new(ALGORITHM),
      &claims,
      &EncodingKey::from_rsa_der(&der),
    )
    .unwrap();

    // The missing `kid` must be rejected before any key lookup happens.
    let result = decode_using_store::<TestClaims>(&token, &key_store).await;
    assert!(matches!(result, Err(SecureError::InvalidKeyIdError(_))));
  }
}