1#![doc = include_str!("../README.md")]
3
4use axum::{
5 RequestPartsExt,
6 extract::{Extension, FromRequestParts},
7 http::request::Parts,
8};
9use axum_extra::extract::CookieJar;
10use message_verifier::{AesGcmEncryptor, AesHmacEncryptor, DerivedKeyParams, Encryptor};
11use std::convert::Infallible;
12
/// Salt passed to both `AesHmacEncryptor::new` and `AesGcmEncryptor::new` when the
/// extractor builds its decryptor.
///
/// NOTE(review): stock Rails is believed to use a different salt
/// ("authenticated encrypted cookie") for AES-GCM cookies — confirm against the
/// Rails application this crate is meant to interoperate with.
const ENCRYPTION_SALT: &str = "encrypted cookie";
/// Second salt passed to `AesHmacEncryptor::new`; not used when the algorithm is AES-GCM.
const SIGNING_SALT: &str = "signed encrypted cookie";
17
/// Reasons the [`RailsCookie`] extractor can fail to produce a decrypted cookie value.
///
/// Each variant corresponds to one step of the extraction; the `#[error]` text is
/// what `to_string()` yields for that variant.
#[derive(thiserror::Error, Debug)]
pub enum RailsCookieError {
    /// The `Extension<CookieConfig>` could not be extracted from the request.
    #[error("Failed to extract Config")]
    Config,

    /// Intended for cookie-jar extraction failures.
    ///
    /// NOTE(review): the extractor in this file never constructs this variant,
    /// because it matches the `CookieJar` extraction result irrefutably.
    #[error("Failed to get cookie jar")]
    CookieJar,

    /// The cookie jar holds no cookie with the configured name.
    #[error("Failed to get cookie")]
    CookieRetrieval,

    /// The encryptor constructor for the configured algorithm returned an error.
    #[error("Failed to create decryptor")]
    DecryptorCreation,

    /// `decrypt_and_verify` rejected the cookie value.
    #[error("Failed to decrypt cookie data")]
    Decryption,

    /// The decrypted bytes are not valid UTF-8.
    #[error("Failed to parse valid utf8 from cookie data")]
    CookieParse,
}
45
/// Selects which encryptor the [`RailsCookie`] extractor builds to decrypt the cookie.
///
/// The type is public and every variant is matched by the extractor, so no
/// lint suppression is needed here. `PartialEq`/`Eq` are derived so callers and
/// tests can compare configured algorithms directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CookieAlgorithm {
    /// Decrypt with `AesHmacEncryptor` (uses both the encryption and the signing salt).
    AesHmac,
    /// Decrypt with `AesGcmEncryptor` (uses the encryption salt only).
    AesGcm,
}
52
/// Settings the [`RailsCookie`] extractor reads from the request extensions.
///
/// The tests in this file install it on a router with `.layer(Extension(config))`.
/// `Debug` and `Display` are implemented by hand elsewhere in this file and print
/// none of the fields.
#[derive(Clone)]
pub struct CookieConfig {
    /// Name of the cookie to look up in the request's cookie jar.
    name: &'static str,
    /// Secret handed to the encryptor constructor as its first argument.
    secret: &'static str,
    /// Which encryptor implementation to build for decryption.
    algorithm: CookieAlgorithm,
}
73
74impl CookieConfig {
75 pub fn new(name: &'static str, secret: &'static str, algorithm: CookieAlgorithm) -> Self {
76 CookieConfig {
77 name,
78 secret,
79 algorithm,
80 }
81 }
82}
83
84impl std::fmt::Debug for CookieConfig {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(f, "CookieConfig {{}}")
87 }
88}
89
90impl std::fmt::Display for CookieConfig {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 write!(f, "{:?}", self)
93 }
94}
95
/// Outcome of extracting and decrypting the configured cookie from a request.
///
/// The extractor implemented for this type uses `Infallible` as its rejection, so
/// a failure is delivered to the handler as the `Err` variant instead of
/// short-circuiting the request.
#[derive(Debug)]
pub enum RailsCookie {
    /// The decrypted cookie payload, decoded as UTF-8.
    Ok(String),
    /// The extraction step that failed.
    Err(RailsCookieError),
}
102
103impl<S> FromRequestParts<S> for RailsCookie
104where
105 S: Send + Sync + 'static,
106{
107 type Rejection = Infallible;
108
109 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
110 let Ok(Extension(config)): Result<Extension<CookieConfig>, _> =
111 Extension::from_request_parts(parts, state).await
112 else {
113 return Ok(RailsCookie::Err(RailsCookieError::Config));
114 };
115
116 let Ok(cookie_jar) = parts.extract::<CookieJar>().await;
117
118 let Some(cookie) = cookie_jar.get(config.name) else {
119 return Ok(RailsCookie::Err(RailsCookieError::CookieRetrieval));
120 };
121
122 let dkp = DerivedKeyParams::default();
123
124 let encryptor: Box<dyn Encryptor> = match config.algorithm {
125 CookieAlgorithm::AesHmac => {
126 let encryptor: Result<AesHmacEncryptor, _> =
127 AesHmacEncryptor::new(config.secret, ENCRYPTION_SALT, SIGNING_SALT, dkp);
128
129 match encryptor {
130 Ok(value) => Box::new(value),
131 Err(_) => return Ok(RailsCookie::Err(RailsCookieError::DecryptorCreation)),
132 }
133 }
134 CookieAlgorithm::AesGcm => {
135 let encryptor: Result<AesGcmEncryptor, _> =
136 AesGcmEncryptor::new(config.secret, ENCRYPTION_SALT, dkp);
137
138 match encryptor {
139 Ok(value) => Box::new(value),
140 Err(_) => return Ok(RailsCookie::Err(RailsCookieError::DecryptorCreation)),
141 }
142 }
143 };
144
145 let Ok(decrypted_value) = encryptor.decrypt_and_verify(cookie.value()) else {
146 return Ok(RailsCookie::Err(RailsCookieError::Decryption));
147 };
148
149 let Ok(decrypted_value) = String::from_utf8(decrypted_value) else {
150 return Ok(RailsCookie::Err(RailsCookieError::CookieParse));
151 };
152
153 Ok(RailsCookie::Ok(decrypted_value))
154 }
155}
156
157#[cfg(test)]
158mod tests {
159 use crate::{CookieAlgorithm, CookieConfig, ENCRYPTION_SALT, RailsCookie, SIGNING_SALT};
160 use axum::{
161 Router,
162 body::{Body, to_bytes},
163 extract::Extension,
164 http::{HeaderMap, Request, header},
165 response::Response,
166 routing::get,
167 };
168 use axum_extra::extract::cookie::Cookie;
169 use axum_macros::debug_handler;
170 use message_verifier::{AesGcmEncryptor, AesHmacEncryptor, DerivedKeyParams, Encryptor};
171 use tower::ServiceExt;
172
    /// Describes one test scenario for `generate_app_and_request`.
    struct AppRequestConfig {
        /// Name given to the cookie placed in the request's `Cookie` header.
        cookie_name: String,
        /// Whether the `CookieConfig` extension layer is attached to the router.
        include_config: bool,
        /// Whether the request carries a `Cookie` header at all.
        include_cookie_header: bool,
        /// Encryption salt used when generating the cookie value; the extractor
        /// itself always decrypts with `ENCRYPTION_SALT`.
        encryption_salt: String,
        /// Algorithm used both to generate the cookie and in the `CookieConfig`.
        cookie_algorithm: CookieAlgorithm,
    }
181
182 #[debug_handler]
184 async fn test_handler(rails_cookie: RailsCookie) -> Result<String, String> {
185 match rails_cookie {
186 RailsCookie::Ok(cookie) => Ok(cookie),
187 RailsCookie::Err(err) => Err(err.to_string()),
188 }
189 }
190
191 fn generate_encrypted_cookie(
199 message: &str,
200 secret: &str,
201 encryption_salt: &str,
202 cookie_alg: CookieAlgorithm,
203 ) -> String {
204 let dkp = DerivedKeyParams::default();
205 let encryptor: Box<dyn Encryptor> = match cookie_alg {
206 CookieAlgorithm::AesHmac => {
207 let encryptor = AesHmacEncryptor::new(secret, encryption_salt, SIGNING_SALT, dkp).unwrap();
208 Box::new(encryptor)
209 }
210 CookieAlgorithm::AesGcm => {
211 let encryptor = AesGcmEncryptor::new(secret, encryption_salt, dkp).unwrap();
212 Box::new(encryptor)
213 }
214 };
215
216 encryptor.encrypt_and_sign(message).unwrap()
217 }
218
219 fn generate_app_and_request(app_request_config: AppRequestConfig) -> (Router, Request<Body>) {
224 let secret_key_base = "3b53beba93922c29b3c335051f79e41c63fe626834d5a4a7ce96ebd189010063";
225 let encrypted_signed_cookie = generate_encrypted_cookie(
226 "hello world",
227 secret_key_base,
228 &app_request_config.encryption_salt,
229 app_request_config.cookie_algorithm.clone(),
230 );
231 let config = CookieConfig::new(
232 "test_cookie",
233 secret_key_base,
234 app_request_config.cookie_algorithm,
235 );
236
237 let app = if app_request_config.include_config {
238 Router::new()
239 .route("/", get(test_handler))
240 .layer(Extension(config))
241 } else {
242 Router::new().route("/", get(test_handler))
243 };
244
245 let request = if app_request_config.include_cookie_header {
246 let cookie = Cookie::new(app_request_config.cookie_name, encrypted_signed_cookie);
247 let mut headers = HeaderMap::new();
248 headers.insert(header::COOKIE, cookie.to_string().parse().unwrap());
249
250 Request::builder()
251 .uri("/")
252 .header(header::COOKIE, cookie.to_string())
253 .body(Body::empty())
254 .unwrap()
255 } else {
256 Request::builder().uri("/").body(Body::empty()).unwrap()
257 };
258
259 (app, request)
260 }
261
262 async fn get_response_body(response: Response<Body>) -> String {
267 let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
268 String::from_utf8(body.to_vec()).unwrap()
269 }
270
271 mod cookie_config {
272 use crate::{CookieAlgorithm, CookieConfig};
273
274 #[test]
275 fn test_config_does_not_expose_details_when_printed() {
276 let cookie_config = CookieConfig::new(
277 "_some_app_session_id",
278 "3b53beba93922c29b3c335051f79e41c63fe626834d5a4a7ce96ebd189010063",
279 CookieAlgorithm::AesHmac,
280 );
281
282 let result = format!("{}", cookie_config);
283 let expected = "CookieConfig {}".to_string();
284
285 assert_eq!(expected, result);
286 }
287 }
288
289 mod aes_hmac {
290 use super::*;
291
292 #[tokio::test]
293 async fn test_valid_cookie_extraction() {
294 let (app, request) = generate_app_and_request(AppRequestConfig {
295 cookie_name: "test_cookie".into(),
296 include_config: true,
297 include_cookie_header: true,
298 encryption_salt: ENCRYPTION_SALT.to_string(),
299 cookie_algorithm: CookieAlgorithm::AesHmac,
300 });
301 let response = app.oneshot(request).await.unwrap();
302
303 assert_eq!(response.status(), 200);
304
305 let body = get_response_body(response).await;
306
307 assert_eq!(body, "hello world");
308 }
309
310 #[tokio::test]
311 async fn test_invalid_cookie_extraction() {
312 let (app, request) = generate_app_and_request(AppRequestConfig {
313 cookie_name: "does_not_exist".into(),
314 include_config: true,
315 include_cookie_header: true,
316 encryption_salt: ENCRYPTION_SALT.to_string(),
317 cookie_algorithm: CookieAlgorithm::AesHmac,
318 });
319 let response = app.oneshot(request).await.unwrap();
320
321 assert_eq!(response.status(), 200);
322
323 let body = get_response_body(response).await;
324
325 assert_eq!(body, "Failed to get cookie");
326 }
327
328 #[tokio::test]
329 async fn test_missing_app_config_extension_extraction() {
330 let (app, request) = generate_app_and_request(AppRequestConfig {
331 cookie_name: "test_cookie".into(),
332 include_config: false,
333 include_cookie_header: true,
334 encryption_salt: ENCRYPTION_SALT.to_string(),
335 cookie_algorithm: CookieAlgorithm::AesHmac,
336 });
337 let response = app.oneshot(request).await.unwrap();
338
339 assert_eq!(response.status(), 200);
340
341 let body = get_response_body(response).await;
342
343 assert_eq!(body, "Failed to extract Config");
344 }
345
346 #[tokio::test]
347 async fn test_missing_cookie_header_extraction() {
348 let (app, request) = generate_app_and_request(AppRequestConfig {
349 cookie_name: "test_cookie".into(),
350 include_config: true,
351 include_cookie_header: false,
352 encryption_salt: ENCRYPTION_SALT.to_string(),
353 cookie_algorithm: CookieAlgorithm::AesHmac,
354 });
355 let response = app.oneshot(request).await.unwrap();
356
357 assert_eq!(response.status(), 200);
358
359 let body = get_response_body(response).await;
360
361 assert_eq!(body, "Failed to get cookie");
362 }
363
364 #[tokio::test]
365 async fn test_incorrect_encryption_salt_extraction() {
366 let (app, request) = generate_app_and_request(AppRequestConfig {
367 cookie_name: "test_cookie".into(),
368 include_config: true,
369 include_cookie_header: true,
370 encryption_salt: "".to_string(),
371 cookie_algorithm: CookieAlgorithm::AesHmac,
372 });
373 let response = app.oneshot(request).await.unwrap();
374
375 assert_eq!(response.status(), 200);
376
377 let body = get_response_body(response).await;
378
379 assert_eq!(body, "Failed to decrypt cookie data");
380 }
381 }
382
383 mod aes_gcm {
384 use super::*;
385
386 #[tokio::test]
387 async fn test_valid_cookie_extraction() {
388 let (app, request) = generate_app_and_request(AppRequestConfig {
389 cookie_name: "test_cookie".into(),
390 include_config: true,
391 include_cookie_header: true,
392 encryption_salt: ENCRYPTION_SALT.to_string(),
393 cookie_algorithm: CookieAlgorithm::AesGcm,
394 });
395 let response = app.oneshot(request).await.unwrap();
396
397 assert_eq!(response.status(), 200);
398
399 let body = get_response_body(response).await;
400
401 assert_eq!(body, "hello world");
402 }
403
404 #[tokio::test]
405 async fn test_invalid_cookie_extraction() {
406 let (app, request) = generate_app_and_request(AppRequestConfig {
407 cookie_name: "does_not_exist".into(),
408 include_config: true,
409 include_cookie_header: true,
410 encryption_salt: ENCRYPTION_SALT.to_string(),
411 cookie_algorithm: CookieAlgorithm::AesGcm,
412 });
413 let response = app.oneshot(request).await.unwrap();
414
415 assert_eq!(response.status(), 200);
416
417 let body = get_response_body(response).await;
418
419 assert_eq!(body, "Failed to get cookie");
420 }
421
422 #[tokio::test]
423 async fn test_missing_app_config_extension_extraction() {
424 let (app, request) = generate_app_and_request(AppRequestConfig {
425 cookie_name: "test_cookie".into(),
426 include_config: false,
427 include_cookie_header: true,
428 encryption_salt: ENCRYPTION_SALT.to_string(),
429 cookie_algorithm: CookieAlgorithm::AesGcm,
430 });
431 let response = app.oneshot(request).await.unwrap();
432
433 assert_eq!(response.status(), 200);
434
435 let body = get_response_body(response).await;
436
437 assert_eq!(body, "Failed to extract Config");
438 }
439
440 #[tokio::test]
441 async fn test_missing_cookie_header_extraction() {
442 let (app, request) = generate_app_and_request(AppRequestConfig {
443 cookie_name: "test_cookie".into(),
444 include_config: true,
445 include_cookie_header: false,
446 encryption_salt: ENCRYPTION_SALT.to_string(),
447 cookie_algorithm: CookieAlgorithm::AesGcm,
448 });
449 let response = app.oneshot(request).await.unwrap();
450
451 assert_eq!(response.status(), 200);
452
453 let body = get_response_body(response).await;
454
455 assert_eq!(body, "Failed to get cookie");
456 }
457
458 #[tokio::test]
459 async fn test_incorrect_encryption_salt_extraction() {
460 let (app, request) = generate_app_and_request(AppRequestConfig {
461 cookie_name: "test_cookie".into(),
462 include_config: true,
463 include_cookie_header: true,
464 encryption_salt: "".to_string(),
465 cookie_algorithm: CookieAlgorithm::AesGcm,
466 });
467 let response = app.oneshot(request).await.unwrap();
468
469 assert_eq!(response.status(), 200);
470
471 let body = get_response_body(response).await;
472
473 assert_eq!(body, "Failed to decrypt cookie data");
474 }
475 }
476}