axum_rails_cookie/
lib.rs

1//!
2#![doc = include_str!("../README.md")]
3
4use axum::{
5  RequestPartsExt,
6  extract::{Extension, FromRequestParts},
7  http::request::Parts,
8};
9use axum_extra::extract::CookieJar;
10use message_verifier::{AesGcmEncryptor, AesHmacEncryptor, DerivedKeyParams, Encryptor};
11use std::convert::Infallible;
12
/// Salt used when deriving the encryption key for encrypted cookies.
const ENCRYPTION_SALT: &str = "encrypted cookie";
/// Salt used when deriving the signing key for signed encrypted cookies.
const SIGNING_SALT: &str = "signed encrypted cookie";
17
18/// Represents different errors that can occur during cookie retrieval.
19#[derive(thiserror::Error, Debug)]
20pub enum RailsCookieError {
21  /// Error retrieving CookieConfig
22  #[error("Failed to extract Config")]
23  Config,
24
25  /// Error retrieving cookie jar
26  #[error("Failed to get cookie jar")]
27  CookieJar,
28
29  /// Error retrieving cookie
30  #[error("Failed to get cookie")]
31  CookieRetrieval,
32
33  /// Error creating cookie decryptor
34  #[error("Failed to create decryptor")]
35  DecryptorCreation,
36
37  /// Error decrypting cookie
38  #[error("Failed to decrypt cookie data")]
39  Decryption,
40
41  /// Error parsing decrypted cookie
42  #[error("Failed to parse valid utf8 from cookie data")]
43  CookieParse,
44}
45
/// Selects which encryptor is used to decrypt and verify the cookie value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CookieAlgorithm {
  /// The cookie is handled by the AES + HMAC encryptor (separate encryption and signing salts).
  AesHmac,
  /// The cookie is handled by the AES-GCM encryptor (encryption salt only).
  AesGcm,
}
52
53/// Represents values used during cookie retrieval.
54///
55/// # Example
56///
57/// You can create a `CookieConfig` using the following code:
58///
59/// ```
60/// use axum_rails_cookie::{CookieConfig, CookieAlgorithm};
61///
62/// let name = "_my_app_session_id";
63/// let secret = "3b53beba93922c29b3c335051f79e41c63fe626834d5a4a7ce96ebd189010063";
64/// let algorithm = CookieAlgorithm::AesHmac;
65/// let encryptor = CookieConfig::new(name, secret, algorithm);
66/// ```
67#[derive(Clone)]
68pub struct CookieConfig {
69  name: &'static str,
70  secret: &'static str,
71  algorithm: CookieAlgorithm,
72}
73
74impl CookieConfig {
75  pub fn new(name: &'static str, secret: &'static str, algorithm: CookieAlgorithm) -> Self {
76    CookieConfig {
77      name,
78      secret,
79      algorithm,
80    }
81  }
82}
83
84impl std::fmt::Debug for CookieConfig {
85  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86    write!(f, "CookieConfig {{}}")
87  }
88}
89
90impl std::fmt::Display for CookieConfig {
91  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92    write!(f, "{:?}", self)
93  }
94}
95
96/// Represents the success or failure of retrieving a rails cookie.
97#[derive(Debug)]
98pub enum RailsCookie {
99  Ok(String),
100  Err(RailsCookieError),
101}
102
103impl<S> FromRequestParts<S> for RailsCookie
104where
105  S: Send + Sync + 'static,
106{
107  type Rejection = Infallible;
108
109  async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
110    let Ok(Extension(config)): Result<Extension<CookieConfig>, _> =
111      Extension::from_request_parts(parts, state).await
112    else {
113      return Ok(RailsCookie::Err(RailsCookieError::Config));
114    };
115
116    let Ok(cookie_jar) = parts.extract::<CookieJar>().await;
117
118    let Some(cookie) = cookie_jar.get(config.name) else {
119      return Ok(RailsCookie::Err(RailsCookieError::CookieRetrieval));
120    };
121
122    let dkp = DerivedKeyParams::default();
123
124    let encryptor: Box<dyn Encryptor> = match config.algorithm {
125      CookieAlgorithm::AesHmac => {
126        let encryptor: Result<AesHmacEncryptor, _> =
127          AesHmacEncryptor::new(config.secret, ENCRYPTION_SALT, SIGNING_SALT, dkp);
128
129        match encryptor {
130          Ok(value) => Box::new(value),
131          Err(_) => return Ok(RailsCookie::Err(RailsCookieError::DecryptorCreation)),
132        }
133      }
134      CookieAlgorithm::AesGcm => {
135        let encryptor: Result<AesGcmEncryptor, _> =
136          AesGcmEncryptor::new(config.secret, ENCRYPTION_SALT, dkp);
137
138        match encryptor {
139          Ok(value) => Box::new(value),
140          Err(_) => return Ok(RailsCookie::Err(RailsCookieError::DecryptorCreation)),
141        }
142      }
143    };
144
145    let Ok(decrypted_value) = encryptor.decrypt_and_verify(cookie.value()) else {
146      return Ok(RailsCookie::Err(RailsCookieError::Decryption));
147    };
148
149    let Ok(decrypted_value) = String::from_utf8(decrypted_value) else {
150      return Ok(RailsCookie::Err(RailsCookieError::CookieParse));
151    };
152
153    Ok(RailsCookie::Ok(decrypted_value))
154  }
155}
156
157#[cfg(test)]
158mod tests {
159  use crate::{CookieAlgorithm, CookieConfig, ENCRYPTION_SALT, RailsCookie, SIGNING_SALT};
160  use axum::{
161    Router,
162    body::{Body, to_bytes},
163    extract::Extension,
164    http::{HeaderMap, Request, header},
165    response::Response,
166    routing::get,
167  };
168  use axum_extra::extract::cookie::Cookie;
169  use axum_macros::debug_handler;
170  use message_verifier::{AesGcmEncryptor, AesHmacEncryptor, DerivedKeyParams, Encryptor};
171  use tower::ServiceExt;
172
173  /// Used to configure generate_app_and_request behavior in tests
174  struct AppRequestConfig {
175    cookie_name: String,
176    include_config: bool,
177    include_cookie_header: bool,
178    encryption_salt: String,
179    cookie_algorithm: CookieAlgorithm,
180  }
181
182  /// Axum handler function to simulate usage of RailsCookie extractor.
183  #[debug_handler]
184  async fn test_handler(rails_cookie: RailsCookie) -> Result<String, String> {
185    match rails_cookie {
186      RailsCookie::Ok(cookie) => Ok(cookie),
187      RailsCookie::Err(err) => Err(err.to_string()),
188    }
189  }
190
191  /// Generate an encrypted cookie as a string.
192  ///
193  /// # Arguments
194  /// * `message` - Cookie data to be encrypted
195  /// * `secret` - Secret to derive encryption key from
196  /// * `encryption_salt` - Secret to derive signing key from
197  /// * `cookie_alg` - Algorithm to use for cookie encryption
198  fn generate_encrypted_cookie(
199    message: &str,
200    secret: &str,
201    encryption_salt: &str,
202    cookie_alg: CookieAlgorithm,
203  ) -> String {
204    let dkp = DerivedKeyParams::default();
205    let encryptor: Box<dyn Encryptor> = match cookie_alg {
206      CookieAlgorithm::AesHmac => {
207        let encryptor = AesHmacEncryptor::new(secret, encryption_salt, SIGNING_SALT, dkp).unwrap();
208        Box::new(encryptor)
209      }
210      CookieAlgorithm::AesGcm => {
211        let encryptor = AesGcmEncryptor::new(secret, encryption_salt, dkp).unwrap();
212        Box::new(encryptor)
213      }
214    };
215
216    encryptor.encrypt_and_sign(message).unwrap()
217  }
218
219  /// Generates an axum router and request based on the AppRequestConfig param.
220  ///
221  /// # Arguments
222  /// * `app_request_config` - Configuration that describes axum router and request behavior
223  fn generate_app_and_request(app_request_config: AppRequestConfig) -> (Router, Request<Body>) {
224    let secret_key_base = "3b53beba93922c29b3c335051f79e41c63fe626834d5a4a7ce96ebd189010063";
225    let encrypted_signed_cookie = generate_encrypted_cookie(
226      "hello world",
227      secret_key_base,
228      &app_request_config.encryption_salt,
229      app_request_config.cookie_algorithm.clone(),
230    );
231    let config = CookieConfig::new(
232      "test_cookie",
233      secret_key_base,
234      app_request_config.cookie_algorithm,
235    );
236
237    let app = if app_request_config.include_config {
238      Router::new()
239        .route("/", get(test_handler))
240        .layer(Extension(config))
241    } else {
242      Router::new().route("/", get(test_handler))
243    };
244
245    let request = if app_request_config.include_cookie_header {
246      let cookie = Cookie::new(app_request_config.cookie_name, encrypted_signed_cookie);
247      let mut headers = HeaderMap::new();
248      headers.insert(header::COOKIE, cookie.to_string().parse().unwrap());
249
250      Request::builder()
251        .uri("/")
252        .header(header::COOKIE, cookie.to_string())
253        .body(Body::empty())
254        .unwrap()
255    } else {
256      Request::builder().uri("/").body(Body::empty()).unwrap()
257    };
258
259    (app, request)
260  }
261
262  /// Gets the request body of an http request as a UTF8 string.
263  ///
264  /// # Arguments
265  /// * `response` - Response body from an axum request
266  async fn get_response_body(response: Response<Body>) -> String {
267    let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
268    String::from_utf8(body.to_vec()).unwrap()
269  }
270
271  mod cookie_config {
272    use crate::{CookieAlgorithm, CookieConfig};
273
274    #[test]
275    fn test_config_does_not_expose_details_when_printed() {
276      let cookie_config = CookieConfig::new(
277        "_some_app_session_id",
278        "3b53beba93922c29b3c335051f79e41c63fe626834d5a4a7ce96ebd189010063",
279        CookieAlgorithm::AesHmac,
280      );
281
282      let result = format!("{}", cookie_config);
283      let expected = "CookieConfig {}".to_string();
284
285      assert_eq!(expected, result);
286    }
287  }
288
289  mod aes_hmac {
290    use super::*;
291
292    #[tokio::test]
293    async fn test_valid_cookie_extraction() {
294      let (app, request) = generate_app_and_request(AppRequestConfig {
295        cookie_name: "test_cookie".into(),
296        include_config: true,
297        include_cookie_header: true,
298        encryption_salt: ENCRYPTION_SALT.to_string(),
299        cookie_algorithm: CookieAlgorithm::AesHmac,
300      });
301      let response = app.oneshot(request).await.unwrap();
302
303      assert_eq!(response.status(), 200);
304
305      let body = get_response_body(response).await;
306
307      assert_eq!(body, "hello world");
308    }
309
310    #[tokio::test]
311    async fn test_invalid_cookie_extraction() {
312      let (app, request) = generate_app_and_request(AppRequestConfig {
313        cookie_name: "does_not_exist".into(),
314        include_config: true,
315        include_cookie_header: true,
316        encryption_salt: ENCRYPTION_SALT.to_string(),
317        cookie_algorithm: CookieAlgorithm::AesHmac,
318      });
319      let response = app.oneshot(request).await.unwrap();
320
321      assert_eq!(response.status(), 200);
322
323      let body = get_response_body(response).await;
324
325      assert_eq!(body, "Failed to get cookie");
326    }
327
328    #[tokio::test]
329    async fn test_missing_app_config_extension_extraction() {
330      let (app, request) = generate_app_and_request(AppRequestConfig {
331        cookie_name: "test_cookie".into(),
332        include_config: false,
333        include_cookie_header: true,
334        encryption_salt: ENCRYPTION_SALT.to_string(),
335        cookie_algorithm: CookieAlgorithm::AesHmac,
336      });
337      let response = app.oneshot(request).await.unwrap();
338
339      assert_eq!(response.status(), 200);
340
341      let body = get_response_body(response).await;
342
343      assert_eq!(body, "Failed to extract Config");
344    }
345
346    #[tokio::test]
347    async fn test_missing_cookie_header_extraction() {
348      let (app, request) = generate_app_and_request(AppRequestConfig {
349        cookie_name: "test_cookie".into(),
350        include_config: true,
351        include_cookie_header: false,
352        encryption_salt: ENCRYPTION_SALT.to_string(),
353        cookie_algorithm: CookieAlgorithm::AesHmac,
354      });
355      let response = app.oneshot(request).await.unwrap();
356
357      assert_eq!(response.status(), 200);
358
359      let body = get_response_body(response).await;
360
361      assert_eq!(body, "Failed to get cookie");
362    }
363
364    #[tokio::test]
365    async fn test_incorrect_encryption_salt_extraction() {
366      let (app, request) = generate_app_and_request(AppRequestConfig {
367        cookie_name: "test_cookie".into(),
368        include_config: true,
369        include_cookie_header: true,
370        encryption_salt: "".to_string(),
371        cookie_algorithm: CookieAlgorithm::AesHmac,
372      });
373      let response = app.oneshot(request).await.unwrap();
374
375      assert_eq!(response.status(), 200);
376
377      let body = get_response_body(response).await;
378
379      assert_eq!(body, "Failed to decrypt cookie data");
380    }
381  }
382
383  mod aes_gcm {
384    use super::*;
385
386    #[tokio::test]
387    async fn test_valid_cookie_extraction() {
388      let (app, request) = generate_app_and_request(AppRequestConfig {
389        cookie_name: "test_cookie".into(),
390        include_config: true,
391        include_cookie_header: true,
392        encryption_salt: ENCRYPTION_SALT.to_string(),
393        cookie_algorithm: CookieAlgorithm::AesGcm,
394      });
395      let response = app.oneshot(request).await.unwrap();
396
397      assert_eq!(response.status(), 200);
398
399      let body = get_response_body(response).await;
400
401      assert_eq!(body, "hello world");
402    }
403
404    #[tokio::test]
405    async fn test_invalid_cookie_extraction() {
406      let (app, request) = generate_app_and_request(AppRequestConfig {
407        cookie_name: "does_not_exist".into(),
408        include_config: true,
409        include_cookie_header: true,
410        encryption_salt: ENCRYPTION_SALT.to_string(),
411        cookie_algorithm: CookieAlgorithm::AesGcm,
412      });
413      let response = app.oneshot(request).await.unwrap();
414
415      assert_eq!(response.status(), 200);
416
417      let body = get_response_body(response).await;
418
419      assert_eq!(body, "Failed to get cookie");
420    }
421
422    #[tokio::test]
423    async fn test_missing_app_config_extension_extraction() {
424      let (app, request) = generate_app_and_request(AppRequestConfig {
425        cookie_name: "test_cookie".into(),
426        include_config: false,
427        include_cookie_header: true,
428        encryption_salt: ENCRYPTION_SALT.to_string(),
429        cookie_algorithm: CookieAlgorithm::AesGcm,
430      });
431      let response = app.oneshot(request).await.unwrap();
432
433      assert_eq!(response.status(), 200);
434
435      let body = get_response_body(response).await;
436
437      assert_eq!(body, "Failed to extract Config");
438    }
439
440    #[tokio::test]
441    async fn test_missing_cookie_header_extraction() {
442      let (app, request) = generate_app_and_request(AppRequestConfig {
443        cookie_name: "test_cookie".into(),
444        include_config: true,
445        include_cookie_header: false,
446        encryption_salt: ENCRYPTION_SALT.to_string(),
447        cookie_algorithm: CookieAlgorithm::AesGcm,
448      });
449      let response = app.oneshot(request).await.unwrap();
450
451      assert_eq!(response.status(), 200);
452
453      let body = get_response_body(response).await;
454
455      assert_eq!(body, "Failed to get cookie");
456    }
457
458    #[tokio::test]
459    async fn test_incorrect_encryption_salt_extraction() {
460      let (app, request) = generate_app_and_request(AppRequestConfig {
461        cookie_name: "test_cookie".into(),
462        include_config: true,
463        include_cookie_header: true,
464        encryption_salt: "".to_string(),
465        cookie_algorithm: CookieAlgorithm::AesGcm,
466      });
467      let response = app.oneshot(request).await.unwrap();
468
469      assert_eq!(response.status(), 200);
470
471      let body = get_response_body(response).await;
472
473      assert_eq!(body, "Failed to decrypt cookie data");
474    }
475  }
476}