1use crate::{apis::jwks_api::JwksKey, validators::jwks::JwksProvider};
2use jsonwebtoken::{decode, decode_header, errors::Error as jwtError, Algorithm, DecodingKey, Header, Validation};
3use serde_json::{Map, Value};
4use std::{error::Error, fmt, sync::Arc};
5
/// Organization context carried in a Clerk session token.
///
/// The fields are read from top-level token claims (`org_id`, `org_slug`, `org_role`,
/// `org_permissions`) because `ClerkJwt` embeds this struct with `#[serde(flatten)]`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ActiveOrganization {
    /// Organization identifier, from the `org_id` claim.
    #[serde(rename = "org_id")]
    pub id: String,
    /// Organization slug, from the `org_slug` claim.
    #[serde(rename = "org_slug")]
    pub slug: String,
    /// The user's role in the organization, from the `org_role` claim.
    #[serde(rename = "org_role")]
    pub role: String,
    /// The user's organization permissions, from the `org_permissions` claim.
    #[serde(rename = "org_permissions")]
    pub permissions: Vec<String>,
}
17
18impl ActiveOrganization {
19 pub fn has_permission(&self, permission: &str) -> bool {
21 self.permissions.contains(&permission.to_string())
22 }
23
24 pub fn has_role(&self, role: &str) -> bool {
29 self.role == role
30 }
31}
32
/// Contents of the token's `act` claim.
///
/// NOTE(review): presumably identifies an actor operating on behalf of the token's `sub`
/// (e.g. impersonation) — confirm against Clerk's token documentation.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Actor {
    /// Optional issuer of the actor.
    pub iss: Option<String>,
    /// Optional session id of the actor.
    pub sid: Option<String>,
    /// Subject (identifier) of the actor; required when `act` is present.
    pub sub: String,
}
39
/// Claims of a validated Clerk session token.
///
/// Known claims are mapped to named fields. Organization claims are pulled into `org`.
/// Every remaining claim is kept in `other`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ClerkJwt {
    /// Optional `azp` claim.
    pub azp: Option<String>,
    /// Expiration time (`exp`), Unix seconds.
    /// NOTE(review): stored as `i32`, so timestamps past `i32::MAX` (January 2038) will not
    /// deserialize; widening to `i64` would be a breaking change for field readers.
    pub exp: i32,
    /// Issued-at time (`iat`), Unix seconds.
    pub iat: i32,
    /// Token issuer (`iss`).
    pub iss: String,
    /// Not-before time (`nbf`), Unix seconds.
    pub nbf: i32,
    /// Optional session id (`sid`).
    pub sid: Option<String>,
    /// Subject (`sub`).
    pub sub: String,
    /// Optional `act` claim.
    pub act: Option<Actor>,
    /// Organization claims (`org_id`, `org_slug`, `org_role`, `org_permissions`) read from the
    /// top level of the payload.
    #[serde(flatten)]
    pub org: Option<ActiveOrganization>,
    /// All remaining claims not captured by the fields above.
    #[serde(flatten)]
    pub other: Map<String, Value>,
}
57
/// Abstraction over an incoming HTTP request, so `ClerkAuthorizer` can read credentials from
/// any web framework.
pub trait ClerkRequest {
    /// Returns the value of header `key`, or `None` if it is absent.
    /// `ClerkAuthorizer::authorize` calls this with `"Authorization"`.
    fn get_header(&self, key: &str) -> Option<String>;
    /// Returns the value of cookie `key`, or `None` if it is absent.
    /// `ClerkAuthorizer::authorize` calls this with `"__session"`.
    fn get_cookie(&self, key: &str) -> Option<String>;
}
62
/// Errors produced while authorizing a request.
#[derive(Clone, Debug)]
pub enum ClerkError {
    /// The request carried no token, or the token failed validation (bad header, unknown key,
    /// bad signature, expired, ...).
    Unauthorized(String),
    /// The verification key itself is unusable (unsupported algorithm or invalid RSA components).
    InternalServerError(String),
}
68
69impl fmt::Display for ClerkError {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 match self {
72 ClerkError::Unauthorized(msg) => write!(f, "Unauthorized: {}", msg),
73 ClerkError::InternalServerError(msg) => write!(f, "Internal Server Error: {}", msg),
74 }
75 }
76}
77
// Marker impl so `ClerkError` can be used wherever a `std::error::Error` is expected; the
// default `source()` applies because the variants hold only a message `String`.
impl Error for ClerkError {}
79
/// Authorizes requests by validating the Clerk session token they carry.
pub struct ClerkAuthorizer<J> {
    // Shared behind `Arc` so clones of the authorizer reuse the same provider
    // (the manual `Clone` impl only clones the `Arc`).
    jwks_provider: Arc<J>,
    // When `true`, fall back to the `__session` cookie if no `Authorization` header is present.
    validate_session_cookie: bool,
}
84
85impl<J: JwksProvider> ClerkAuthorizer<J> {
86 pub fn new(jwks_provider: J, validate_session_cookie: bool) -> Self {
88 Self {
89 jwks_provider: Arc::new(jwks_provider),
90 validate_session_cookie,
91 }
92 }
93
94 pub fn jwks_provider(&self) -> &Arc<J> {
96 &self.jwks_provider
97 }
98
99 pub async fn authorize<T>(&self, request: &T) -> Result<ClerkJwt, ClerkError>
101 where
102 T: ClerkRequest,
103 {
104 let access_token: String = match request.get_header("Authorization") {
106 Some(val) => val.to_string().replace("Bearer ", ""),
107 None => match self.validate_session_cookie {
108 true => match request.get_cookie("__session") {
109 Some(cookie) => cookie.to_string(),
110 None => {
111 return Err(ClerkError::Unauthorized(String::from(
112 "Error: No Authorization header or session cookie found on the request payload!",
113 )))
114 }
115 },
116 false => {
117 return Err(ClerkError::Unauthorized(String::from(
118 "Error: No Authorization header found on the request payload!",
119 )))
120 }
121 },
122 };
123
124 validate_jwt(&access_token, self.jwks_provider.clone()).await
125 }
126}
127
128impl<J> Clone for ClerkAuthorizer<J> {
129 fn clone(&self) -> Self {
130 Self {
131 jwks_provider: self.jwks_provider.clone(),
132 validate_session_cookie: self.validate_session_cookie,
133 }
134 }
135}
136
137pub async fn validate_jwt<J: JwksProvider>(token: &str, jwks: Arc<J>) -> Result<ClerkJwt, ClerkError> {
141 let kid = match get_token_header(token).map(|h| h.kid) {
143 Ok(Some(kid)) => kid,
144 _ => {
145 return Err(ClerkError::Unauthorized(String::from("Error: Invalid JWT!")));
147 }
148 };
149
150 let Ok(key) = jwks.get_key(&kid).await else {
152 return Err(ClerkError::Unauthorized(String::from("Error: Invalid JWT!")));
154 };
155
156 validate_jwt_with_key(token, &key)
157}
158
159pub fn validate_jwt_with_key(token: &str, key: &JwksKey) -> Result<ClerkJwt, ClerkError> {
163 match key.alg.as_str() {
164 "RS256" => {
166 let decoding_key = DecodingKey::from_rsa_components(&key.n, &key.e)
167 .map_err(|_| ClerkError::InternalServerError(String::from("Error: Invalid decoding key")))?;
168
169 let mut validation = Validation::new(Algorithm::RS256);
170 validation.validate_exp = true;
171 validation.validate_nbf = true;
172
173 match decode::<ClerkJwt>(token, &decoding_key, &validation) {
174 Ok(token) => Ok(token.claims),
175 Err(err) => Err(ClerkError::Unauthorized(format!("Error: Invalid JWT! cause: {}", err))),
176 }
177 }
178 _ => Err(ClerkError::InternalServerError(String::from("Error: Unsupported key algorithm"))),
179 }
180}
181
182fn get_token_header(token: &str) -> Result<Header, jwtError> {
184 let header = decode_header(&token);
185 header
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{apis::jwks_api::JwksKey, validators::jwks::tests::StaticJwksProvider};
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::prelude::*;
    use jsonwebtoken::{encode, errors::ErrorKind, Algorithm, EncodingKey, Header};
    use rsa::{pkcs1::EncodeRsaPrivateKey, traits::PublicKeyParts, RsaPrivateKey};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Key id used by most tests; any stable string works.
    const KID: &str = "bc63c2e9-5d1c-4e32-9b62-178f60409abd";

    #[derive(Debug, serde::Serialize)]
    struct CustomFields {
        custom_attribute: String,
    }

    /// Payload of generated test tokens: Clerk-style claims plus two custom claims that must
    /// end up in `ClerkJwt::other`.
    #[derive(Debug, serde::Serialize)]
    struct Claims {
        sub: String,
        iat: usize,
        nbf: usize,
        exp: usize,
        azp: String,
        iss: String,
        sid: String,
        act: Actor,
        org_id: String,
        org_slug: String,
        org_role: String,
        org_permissions: Vec<String>,
        custom_key: String,
        custom_map: CustomFields,
    }

    /// Owns a freshly generated RSA key pair used to sign test tokens.
    struct Helper {
        private_key: RsaPrivateKey,
    }

    impl Helper {
        pub fn new() -> Self {
            let mut rng = rand::thread_rng();

            Self {
                private_key: RsaPrivateKey::new(&mut rng, 2048).unwrap(),
            }
        }

        /// Signs an RS256 token with this helper's key.
        ///
        /// `current_time` defaults to now and sets `iat`/`nbf`, and `exp` is 1000 s later.
        /// `expired` shifts the whole window 5000 s into the past.
        pub fn generate_jwt_token(&self, kid: Option<&str>, current_time: Option<usize>, expired: bool) -> String {
            let pem = self.private_key.to_pkcs1_pem(rsa::pkcs8::LineEnding::LF).unwrap();
            let encoding_key = EncodingKey::from_rsa_pem(pem.as_bytes()).expect("Failed to load encoding key");

            let mut current_time = current_time.unwrap_or_else(current_unix_time);
            if expired {
                current_time -= 5000;
            }
            let expiration = current_time + 1000;

            let claims = Claims {
                azp: "client_id".to_string(),
                sub: "user".to_string(),
                iat: current_time,
                exp: expiration,
                iss: "issuer".to_string(),
                nbf: current_time,
                sid: "session_id".to_string(),
                org_id: "org_id".to_string(),
                org_slug: "org_slug".to_string(),
                org_role: "org_role".to_string(),
                org_permissions: vec!["org_permission".to_string()],
                act: Actor {
                    iss: Some("actor_iss".to_string()),
                    sid: Some("actor_sid".to_string()),
                    sub: "actor_sub".to_string(),
                },
                custom_key: "custom_value".to_string(),
                custom_map: CustomFields {
                    custom_attribute: "custom_attribute".to_string(),
                },
            };

            let mut header = Header::new(Algorithm::RS256);
            if let Some(kid_value) = kid {
                header.kid = Some(kid_value.to_string());
            }

            encode(&header, &claims, &encoding_key).expect("Failed to create jwt token")
        }

        /// Returns the base64url (unpadded) modulus and public exponent, as published in a JWKS.
        pub fn get_modulus_and_public_exponent(&self) -> (String, String) {
            let encoded_modulus = URL_SAFE_NO_PAD.encode(self.private_key.n().to_bytes_be().as_slice());
            let encoded_exponent = URL_SAFE_NO_PAD.encode(self.private_key.e().to_bytes_be().as_slice());
            (encoded_modulus, encoded_exponent)
        }

        /// Builds a JWKS entry exposing this helper's public key under `kid` with algorithm `alg`.
        fn jwks_key(&self, kid: &str, alg: &str) -> JwksKey {
            let (n, e) = self.get_modulus_and_public_exponent();
            JwksKey {
                use_key: String::new(),
                kty: String::new(),
                kid: kid.to_string(),
                alg: alg.to_string(),
                n,
                e,
            }
        }
    }

    /// Current Unix time in whole seconds.
    fn current_unix_time() -> usize {
        SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as usize
    }

    /// Claims that validation should yield for a token made by
    /// `Helper::generate_jwt_token(_, Some(current_time), false)`.
    fn expected_claims(current_time: usize) -> ClerkJwt {
        let mut custom_map = Map::new();
        custom_map.insert("custom_attribute".to_string(), Value::String("custom_attribute".to_string()));

        let mut other = Map::new();
        other.insert("custom_key".to_string(), Value::String("custom_value".to_string()));
        other.insert("custom_map".to_string(), Value::Object(custom_map));

        ClerkJwt {
            azp: Some("client_id".to_string()),
            sub: "user".to_string(),
            iat: current_time as i32,
            exp: (current_time + 1000) as i32,
            iss: "issuer".to_string(),
            nbf: current_time as i32,
            sid: Some("session_id".to_string()),
            act: Some(Actor {
                iss: Some("actor_iss".to_string()),
                sid: Some("actor_sid".to_string()),
                sub: "actor_sub".to_string(),
            }),
            org: Some(ActiveOrganization {
                id: "org_id".to_string(),
                slug: "org_slug".to_string(),
                role: "org_role".to_string(),
                permissions: vec!["org_permission".to_string()],
            }),
            other,
        }
    }

    /// Minimal `ClerkRequest` exposing only an `Authorization` header and a `__session` cookie.
    struct MockRequest {
        authorization: Option<String>,
        session_cookie: Option<String>,
    }

    impl ClerkRequest for MockRequest {
        fn get_header(&self, key: &str) -> Option<String> {
            if key == "Authorization" {
                self.authorization.clone()
            } else {
                None
            }
        }

        fn get_cookie(&self, key: &str) -> Option<String> {
            if key == "__session" {
                self.session_cookie.clone()
            } else {
                None
            }
        }
    }

    #[test]
    fn test_validate_jwt_with_key_success() {
        let helper = Helper::new();
        let jwks_key = helper.jwks_key(KID, "RS256");

        let current_time = current_unix_time();
        let token = helper.generate_jwt_token(Some(KID), Some(current_time), false);

        assert_eq!(
            validate_jwt_with_key(token.as_str(), &jwks_key).expect("should be valid"),
            expected_claims(current_time)
        );
    }

    #[test]
    fn test_validate_jwt_with_key_unexpected_key_algorithm() {
        let helper = Helper::new();
        let jwks_key = helper.jwks_key(KID, "INVALIDALGORITHM");

        let token = helper.generate_jwt_token(Some(KID), None, false);

        assert!(matches!(
            validate_jwt_with_key(&token, &jwks_key),
            Err(ClerkError::InternalServerError(_))
        ))
    }

    #[test]
    fn test_validate_jwt_with_key_invalid_decoding_key() {
        let helper = Helper::new();

        let jwks_key = JwksKey {
            use_key: String::new(),
            kty: String::new(),
            kid: KID.to_string(),
            alg: String::from("RS256"),
            n: String::from("INVALIDMODULUS"),
            e: String::from("INVALIDEXPONENT"),
        };

        let token = helper.generate_jwt_token(Some(KID), None, false);

        assert!(matches!(
            validate_jwt_with_key(&token, &jwks_key),
            Err(ClerkError::InternalServerError(_))
        ))
    }

    #[test]
    fn test_validate_jwt_with_key_invalid_sig() {
        let helper1 = Helper::new();
        let helper2 = Helper::new();

        // The key comes from helper1 but the token is signed by helper2.
        let jwks_key = helper1.jwks_key(KID, "RS256");
        let token = helper2.generate_jwt_token(None, None, false);

        let res = validate_jwt_with_key(&token, &jwks_key);
        assert!(matches!(res, Err(ClerkError::Unauthorized(_))));
    }

    #[test]
    fn test_validate_jwt_with_key_expired() {
        let helper = Helper::new();
        let jwks_key = helper.jwks_key(KID, "RS256");

        let token = helper.generate_jwt_token(Some(KID), None, true);

        let res = validate_jwt_with_key(&token, &jwks_key);
        assert!(matches!(res, Err(ClerkError::Unauthorized(_))))
    }

    #[tokio::test]
    async fn test_validate_jwt_success() {
        let helper = Helper::new();
        let jwks = Arc::new(StaticJwksProvider::from_key(helper.jwks_key(KID, "RS256")));

        let current_time = current_unix_time();
        let token = helper.generate_jwt_token(Some(KID), Some(current_time), false);

        assert_eq!(
            validate_jwt(token.as_str(), jwks).await.expect("should be valid"),
            expected_claims(current_time)
        );
    }

    #[tokio::test]
    async fn test_validate_jwt_invalid_token() {
        let helper = Helper::new();
        let jwks = Arc::new(StaticJwksProvider::from_key(helper.jwks_key(KID, "RS256")));

        assert!(matches!(validate_jwt("invalid_token", jwks).await, Err(ClerkError::Unauthorized(_))))
    }

    #[tokio::test]
    async fn test_validate_jwt_missing_kid() {
        let helper = Helper::new();
        let jwks = Arc::new(StaticJwksProvider::from_key(helper.jwks_key(KID, "RS256")));

        let token = helper.generate_jwt_token(None, None, false);

        assert!(matches!(validate_jwt(&token, jwks).await, Err(ClerkError::Unauthorized(_))))
    }

    #[tokio::test]
    async fn test_validate_jwt_unknown_key() {
        let helper = Helper::new();
        // The provider only knows a different kid than the one in the token header.
        let jwks_key = helper.jwks_key("a288cbf5-fec1-41e3-ae83-5b0d122bf925", "RS256");
        let jwks = Arc::new(StaticJwksProvider::from_key(jwks_key));

        let token = helper.generate_jwt_token(Some(KID), None, false);

        assert!(matches!(validate_jwt(&token, jwks).await, Err(ClerkError::Unauthorized(_))))
    }

    #[tokio::test]
    async fn test_authorize_with_bearer_header() {
        let helper = Helper::new();
        let authorizer = ClerkAuthorizer::new(StaticJwksProvider::from_key(helper.jwks_key(KID, "RS256")), false);

        let current_time = current_unix_time();
        let token = helper.generate_jwt_token(Some(KID), Some(current_time), false);
        let request = MockRequest {
            authorization: Some(format!("Bearer {}", token)),
            session_cookie: None,
        };

        assert_eq!(
            authorizer.authorize(&request).await.expect("should be valid"),
            expected_claims(current_time)
        );
    }

    #[tokio::test]
    async fn test_authorize_with_session_cookie() {
        let helper = Helper::new();
        let authorizer = ClerkAuthorizer::new(StaticJwksProvider::from_key(helper.jwks_key(KID, "RS256")), true);

        let current_time = current_unix_time();
        let token = helper.generate_jwt_token(Some(KID), Some(current_time), false);
        let request = MockRequest {
            authorization: None,
            session_cookie: Some(token),
        };

        assert_eq!(
            authorizer.authorize(&request).await.expect("should be valid"),
            expected_claims(current_time)
        );
    }

    #[tokio::test]
    async fn test_authorize_ignores_cookie_when_disabled() {
        let helper = Helper::new();
        let authorizer = ClerkAuthorizer::new(StaticJwksProvider::from_key(helper.jwks_key(KID, "RS256")), false);

        let token = helper.generate_jwt_token(Some(KID), None, false);
        let request = MockRequest {
            authorization: None,
            session_cookie: Some(token),
        };

        assert!(matches!(authorizer.authorize(&request).await, Err(ClerkError::Unauthorized(_))));
    }

    #[test]
    fn test_helper_generate_token_header() {
        let helper = Helper::new();

        let token = helper.generate_jwt_token(None, None, false);
        let expected = Header::new(Algorithm::RS256);

        assert_eq!(get_token_header(&token).expect("should be valid"), expected);
    }

    #[test]
    fn test_helper_generate_token_header_with_kid() {
        let helper = Helper::new();

        let kid = KID.to_string();

        let token = helper.generate_jwt_token(Some(&kid), None, false);
        let mut expected = Header::new(Algorithm::RS256);
        expected.kid = Some(kid);

        assert_eq!(get_token_header(&token).expect("should be valid"), expected);
    }

    #[test]
    fn test_helper_generate_token_header_error() {
        let token = "invalid_jwt_token";

        let err = get_token_header(token).expect_err("should be invalid");
        assert_eq!(err.kind().to_owned(), ErrorKind::InvalidToken);
    }
}