clerk_rs/validators/authorizer.rs

1use crate::{apis::jwks_api::JwksKey, validators::jwks::JwksProvider};
2use jsonwebtoken::{decode, decode_header, errors::Error as jwtError, Algorithm, DecodingKey, Header, Validation};
3use serde_json::{Map, Value};
4use std::{error::Error, fmt, sync::Arc};
5
6#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
7pub struct ActiveOrganization {
8	#[serde(rename = "org_id")]
9	pub id: String,
10	#[serde(rename = "org_slug")]
11	pub slug: String,
12	#[serde(rename = "org_role")]
13	pub role: String,
14	#[serde(rename = "org_permissions")]
15	pub permissions: Vec<String>,
16}
17
18impl ActiveOrganization {
19	/// Checks if the user has the specific permission in their session claims.
20	pub fn has_permission(&self, permission: &str) -> bool {
21		self.permissions.contains(&permission.to_string())
22	}
23
24	/// Checks if the user has the specific role in their session claims.
25	/// Performing role checks is not considered a best-practice and developers
26	/// should avoid it as much as possible. Usually, complex role checks can be
27	/// refactored with a single permission check.
28	pub fn has_role(&self, role: &str) -> bool {
29		self.role == role
30	}
31}
32
33#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
34pub struct Actor {
35	pub iss: Option<String>,
36	pub sid: Option<String>,
37	pub sub: String,
38}
39
40#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
41pub struct ClerkJwt {
42	pub azp: Option<String>,
43	pub exp: i32,
44	pub iat: i32,
45	pub iss: String,
46	pub nbf: i32,
47	pub sid: Option<String>,
48	pub sub: String,
49	pub act: Option<Actor>,
50	#[serde(flatten)]
51	pub org: Option<ActiveOrganization>,
52	/// Catch-all for any other attributes that may be present in the JWT. This
53	/// is useful for custom templates that may have additional fields
54	#[serde(flatten)]
55	pub other: Map<String, Value>,
56}
57
/// Access to an incoming request's headers and cookies.
///
/// [`ClerkAuthorizer::authorize`] uses it to find the session token: first the
/// `Authorization` header, then (optionally) the `__session` cookie.
pub trait ClerkRequest {
	/// Returns the value of the header named `key`, or `None` if it is absent.
	fn get_header(&self, key: &str) -> Option<String>;

	/// Returns the value of the cookie named `key`, or `None` if it is absent.
	fn get_cookie(&self, key: &str) -> Option<String>;
}
62
/// Errors produced while authorizing a request.
#[derive(Clone, Debug)]
pub enum ClerkError {
	/// No token was found on the request, or the token failed validation.
	Unauthorized(String),
	/// Validation could not be performed, e.g. the key's algorithm is
	/// unsupported or its components cannot form a decoding key.
	InternalServerError(String),
}

impl fmt::Display for ClerkError {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		// Pick the human-readable label, then render "<label>: <detail>".
		let (label, detail) = match self {
			Self::Unauthorized(detail) => ("Unauthorized", detail),
			Self::InternalServerError(detail) => ("Internal Server Error", detail),
		};
		write!(f, "{}: {}", label, detail)
	}
}

impl Error for ClerkError {}
79
/// Authorizes requests by validating their Clerk session token with keys
/// obtained from a [`JwksProvider`].
///
/// The provider is stored behind an [`Arc`], so clones of the authorizer share it.
pub struct ClerkAuthorizer<J> {
	/// When `true`, the `__session` cookie is consulted if the request has no
	/// `Authorization` header.
	validate_session_cookie: bool,
	/// Source of the keys used to validate tokens.
	jwks_provider: Arc<J>,
}
84
85impl<J: JwksProvider> ClerkAuthorizer<J> {
86	/// Creates a Clerk authorizer
87	pub fn new(jwks_provider: J, validate_session_cookie: bool) -> Self {
88		Self {
89			jwks_provider: Arc::new(jwks_provider),
90			validate_session_cookie,
91		}
92	}
93
94	/// Returns a reference to the underlying [`JwksProvider`].
95	pub fn jwks_provider(&self) -> &Arc<J> {
96		&self.jwks_provider
97	}
98
99	/// Authorizes a service request against the Clerk auth provider
100	pub async fn authorize<T>(&self, request: &T) -> Result<ClerkJwt, ClerkError>
101	where
102		T: ClerkRequest,
103	{
104		// get the jwt from header or cookies
105		let access_token: String = match request.get_header("Authorization") {
106			Some(val) => val.to_string().replace("Bearer ", ""),
107			None => match self.validate_session_cookie {
108				true => match request.get_cookie("__session") {
109					Some(cookie) => cookie.to_string(),
110					None => {
111						return Err(ClerkError::Unauthorized(String::from(
112							"Error: No Authorization header or session cookie found on the request payload!",
113						)))
114					}
115				},
116				false => {
117					return Err(ClerkError::Unauthorized(String::from(
118						"Error: No Authorization header found on the request payload!",
119					)))
120				}
121			},
122		};
123
124		validate_jwt(&access_token, self.jwks_provider.clone()).await
125	}
126}
127
128impl<J> Clone for ClerkAuthorizer<J> {
129	fn clone(&self) -> Self {
130		Self {
131			jwks_provider: self.jwks_provider.clone(),
132			validate_session_cookie: self.validate_session_cookie,
133		}
134	}
135}
136
137/// Validates a jwt using the given [`JwksProvider`].
138///
139/// The jwt is required to have a `kid` which is used to request the matching key from the provider.
140pub async fn validate_jwt<J: JwksProvider>(token: &str, jwks: Arc<J>) -> Result<ClerkJwt, ClerkError> {
141	// parse the header to get the kid
142	let kid = match get_token_header(token).map(|h| h.kid) {
143		Ok(Some(kid)) => kid,
144		_ => {
145			// if the kid header was invalid or the kid field was unset, error
146			return Err(ClerkError::Unauthorized(String::from("Error: Invalid JWT!")));
147		}
148	};
149
150	// get the key from the provider
151	let Ok(key) = jwks.get_key(&kid).await else {
152		// In the event that a matching jwk was not found we want to output an error
153		return Err(ClerkError::Unauthorized(String::from("Error: Invalid JWT!")));
154	};
155
156	validate_jwt_with_key(token, &key)
157}
158
159/// Validates a jwt using the given jwk.
160///
161/// This function does not check that the token's kid matches the key's.
162pub fn validate_jwt_with_key(token: &str, key: &JwksKey) -> Result<ClerkJwt, ClerkError> {
163	match key.alg.as_str() {
164		// Currently, clerk only supports Rs256 by default
165		"RS256" => {
166			let decoding_key = DecodingKey::from_rsa_components(&key.n, &key.e)
167				.map_err(|_| ClerkError::InternalServerError(String::from("Error: Invalid decoding key")))?;
168
169			let mut validation = Validation::new(Algorithm::RS256);
170			validation.validate_exp = true;
171			validation.validate_nbf = true;
172
173			match decode::<ClerkJwt>(token, &decoding_key, &validation) {
174				Ok(token) => Ok(token.claims),
175				Err(err) => Err(ClerkError::Unauthorized(format!("Error: Invalid JWT! cause: {}", err))),
176			}
177		}
178		_ => Err(ClerkError::InternalServerError(String::from("Error: Unsupported key algorithm"))),
179	}
180}
181
182/// Extract the header from a jwt token
183fn get_token_header(token: &str) -> Result<Header, jwtError> {
184	let header = decode_header(&token);
185	header
186}
187
#[cfg(test)]
mod tests {
	use super::*;
	use crate::{apis::jwks_api::JwksKey, validators::jwks::tests::StaticJwksProvider};
	use base64::engine::general_purpose::URL_SAFE_NO_PAD;
	use base64::prelude::*;
	use jsonwebtoken::{encode, errors::ErrorKind, Algorithm, EncodingKey, Header};
	use rsa::{pkcs1::EncodeRsaPrivateKey, traits::PublicKeyParts, RsaPrivateKey};
	use std::time::{SystemTime, UNIX_EPOCH};

	/// Key id used in token headers and on the matching [`JwksKey`] in most tests.
	const KID: &str = "bc63c2e9-5d1c-4e32-9b62-178f60409abd";

	#[derive(Debug, serde::Serialize)]
	struct CustomFields {
		custom_attribute: String,
	}

	#[derive(Debug, serde::Serialize)]
	struct Claims {
		sub: String,
		iat: usize,
		nbf: usize,
		exp: usize,
		azp: String,
		iss: String,
		sid: String,
		act: Actor,
		org_id: String,
		org_slug: String,
		org_role: String,
		org_permissions: Vec<String>,
		custom_key: String,
		custom_map: CustomFields,
	}

	/// Current unix time in whole seconds.
	fn now() -> usize {
		SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as usize
	}

	/// The [`ClerkJwt`] a token from [`Helper::generate_jwt_token`] issued at
	/// `current_time` (and not expired) is expected to decode to.
	fn expected_claims(current_time: usize) -> ClerkJwt {
		let mut custom_map = Map::new();
		custom_map.insert("custom_attribute".to_string(), Value::String("custom_attribute".to_string()));

		let mut other = Map::new();
		other.insert("custom_key".to_string(), Value::String("custom_value".to_string()));
		other.insert("custom_map".to_string(), Value::Object(custom_map));

		ClerkJwt {
			azp: Some("client_id".to_string()),
			sub: "user".to_string(),
			iat: current_time as i32,
			exp: (current_time + 1000) as i32,
			iss: "issuer".to_string(),
			nbf: current_time as i32,
			sid: Some("session_id".to_string()),
			act: Some(Actor {
				iss: Some("actor_iss".to_string()),
				sid: Some("actor_sid".to_string()),
				sub: "actor_sub".to_string(),
			}),
			org: Some(ActiveOrganization {
				id: "org_id".to_string(),
				slug: "org_slug".to_string(),
				role: "org_role".to_string(),
				permissions: vec!["org_permission".to_string()],
			}),
			other,
		}
	}

	/// Signs test tokens with a freshly generated 2048-bit RSA key.
	struct Helper {
		private_key: RsaPrivateKey,
	}

	impl Helper {
		pub fn new() -> Self {
			let mut rng = rand::thread_rng();

			Self {
				private_key: RsaPrivateKey::new(&mut rng, 2048).unwrap(),
			}
		}

		/// Generates an RS256 token with the fixed test claims.
		///
		/// `current_time` defaults to now. When `expired` is set, the token is
		/// issued 5000s earlier, so its 1000s lifetime has already elapsed.
		pub fn generate_jwt_token(&self, kid: Option<&str>, current_time: Option<usize>, expired: bool) -> String {
			let pem = self.private_key.to_pkcs1_pem(rsa::pkcs8::LineEnding::LF).unwrap();
			let encoding_key = EncodingKey::from_rsa_pem(pem.as_bytes()).expect("Failed to load encoding key");

			let mut current_time = current_time.unwrap_or_else(now);

			if expired {
				// issue the token some time in the past so that it's expired now
				current_time -= 5000;
			}

			// expire after 1000 secs
			let expiration = current_time + 1000;

			let claims = Claims {
				azp: "client_id".to_string(),
				sub: "user".to_string(),
				iat: current_time,
				exp: expiration,
				iss: "issuer".to_string(),
				nbf: current_time,
				sid: "session_id".to_string(),
				org_id: "org_id".to_string(),
				org_slug: "org_slug".to_string(),
				org_role: "org_role".to_string(),
				org_permissions: vec!["org_permission".to_string()],
				act: Actor {
					iss: Some("actor_iss".to_string()),
					sid: Some("actor_sid".to_string()),
					sub: "actor_sub".to_string(),
				},
				custom_key: "custom_value".to_string(),
				custom_map: CustomFields {
					custom_attribute: "custom_attribute".to_string(),
				},
			};

			let mut header = Header::new(Algorithm::RS256);
			header.kid = kid.map(String::from);

			encode(&header, &claims, &encoding_key).expect("Failed to create jwt token")
		}

		/// Returns the base64url (unpadded) encoded RSA modulus and public exponent.
		pub fn get_modulus_and_public_exponent(&self) -> (String, String) {
			let encoded_modulus = URL_SAFE_NO_PAD.encode(self.private_key.n().to_bytes_be().as_slice());
			let encoded_exponent = URL_SAFE_NO_PAD.encode(self.private_key.e().to_bytes_be().as_slice());
			(encoded_modulus, encoded_exponent)
		}

		/// Builds a [`JwksKey`] for this helper's public key with the given `kid` and `alg`.
		pub fn jwks_key(&self, kid: &str, alg: &str) -> JwksKey {
			let (n, e) = self.get_modulus_and_public_exponent();
			JwksKey {
				use_key: String::new(),
				kty: String::new(),
				kid: kid.to_string(),
				alg: alg.to_string(),
				n,
				e,
			}
		}
	}

	#[test]
	fn test_validate_jwt_with_key_success() {
		let helper = Helper::new();
		let jwks_key = helper.jwks_key(KID, "RS256");

		let current_time = now();
		let token = helper.generate_jwt_token(Some(KID), Some(current_time), false);

		assert_eq!(
			validate_jwt_with_key(token.as_str(), &jwks_key).expect("should be valid"),
			expected_claims(current_time)
		);
	}

	#[test]
	fn test_validate_jwt_with_key_unexpected_key_algorithm() {
		let helper = Helper::new();
		let jwks_key = helper.jwks_key(KID, "INVALIDALGORITHM");

		let token = helper.generate_jwt_token(Some(KID), None, false);

		assert!(matches!(
			validate_jwt_with_key(&token, &jwks_key),
			Err(ClerkError::InternalServerError(_))
		))
	}

	#[test]
	fn test_validate_jwt_with_key_invalid_decoding_key() {
		let helper = Helper::new();

		let jwks_key = JwksKey {
			use_key: String::new(),
			kty: String::new(),
			kid: KID.to_string(),
			alg: String::from("RS256"),
			n: String::from("INVALIDMODULUS"),
			e: String::from("INVALIDEXPONENT"),
		};

		let token = helper.generate_jwt_token(Some(KID), None, false);

		assert!(matches!(
			validate_jwt_with_key(&token, &jwks_key),
			Err(ClerkError::InternalServerError(_))
		))
	}

	#[test]
	fn test_validate_jwt_with_key_invalid_sig() {
		let key_owner = Helper::new();
		let signer = Helper::new();

		// The key belongs to one helper while the token is signed by another.
		let jwks_key = key_owner.jwks_key(KID, "RS256");
		let token = signer.generate_jwt_token(None, None, false);

		let res = validate_jwt_with_key(&token, &jwks_key);
		assert!(matches!(res, Err(ClerkError::Unauthorized(_))));
	}

	#[test]
	fn test_validate_jwt_with_key_expired() {
		let helper = Helper::new();
		let jwks_key = helper.jwks_key(KID, "RS256");

		let token = helper.generate_jwt_token(Some(KID), None, true);

		let res = validate_jwt_with_key(&token, &jwks_key);
		assert!(matches!(res, Err(ClerkError::Unauthorized(_))))
	}

	#[tokio::test]
	async fn test_validate_jwt_success() {
		let helper = Helper::new();
		let jwks = Arc::new(StaticJwksProvider::from_key(helper.jwks_key(KID, "RS256")));

		let current_time = now();
		let token = helper.generate_jwt_token(Some(KID), Some(current_time), false);

		assert_eq!(
			validate_jwt(token.as_str(), jwks).await.expect("should be valid"),
			expected_claims(current_time)
		);
	}

	#[tokio::test]
	async fn test_validate_jwt_invalid_token() {
		let helper = Helper::new();
		let jwks = Arc::new(StaticJwksProvider::from_key(helper.jwks_key(KID, "RS256")));

		assert!(matches!(validate_jwt("invalid_token", jwks).await, Err(ClerkError::Unauthorized(_))))
	}

	#[tokio::test]
	async fn test_validate_jwt_missing_kid() {
		let helper = Helper::new();
		let jwks = Arc::new(StaticJwksProvider::from_key(helper.jwks_key(KID, "RS256")));

		let token = helper.generate_jwt_token(None, None, false);

		assert!(matches!(validate_jwt(&token, jwks).await, Err(ClerkError::Unauthorized(_))))
	}

	#[tokio::test]
	async fn test_validate_jwt_unknown_key() {
		let helper = Helper::new();
		// The provider only knows a key under a different kid than the token's.
		let jwks_key = helper.jwks_key("a288cbf5-fec1-41e3-ae83-5b0d122bf925", "RS256");
		let jwks = Arc::new(StaticJwksProvider::from_key(jwks_key));

		let token = helper.generate_jwt_token(Some(KID), None, false);

		assert!(matches!(validate_jwt(&token, jwks).await, Err(ClerkError::Unauthorized(_))))
	}

	#[test]
	fn test_helper_generate_token_header() {
		let helper = Helper::new();

		let token = helper.generate_jwt_token(None, None, false);
		let expected = Header::new(Algorithm::RS256);

		assert_eq!(get_token_header(&token).expect("should be valid"), expected);
	}

	#[test]
	fn test_helper_generate_token_header_with_kid() {
		let helper = Helper::new();

		let token = helper.generate_jwt_token(Some(KID), None, false);
		let mut expected = Header::new(Algorithm::RS256);
		expected.kid = Some(KID.to_string());

		assert_eq!(get_token_header(&token).expect("should be valid"), expected);
	}

	#[test]
	fn test_helper_generate_token_header_error() {
		let token = "invalid_jwt_token";

		let err = get_token_header(token).expect_err("should be invalid");
		assert_eq!(err.kind().to_owned(), ErrorKind::InvalidToken);
	}
}