1use std::marker::PhantomData;
7use std::sync::Arc;
8use std::time::Duration;
9
10use jsonwebtoken::jwk::{Jwk, JwkSet};
11use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Validation};
12use parking_lot::RwLock;
13
14use crate::errors::AuthError;
15use crate::file_watcher::FileWatcher;
16use crate::jwt::{Key, KeyData, KeyFormat, SignerJwt, StaticTokenProvider, VerifierJwt};
17use crate::metadata::MetadataMap;
18use crate::resolver::KeyResolver;
19use crate::traits::StandardClaims;
20
/// Marker types for the type state of `JwtBuilder`.
///
/// Each marker selects which configuration methods and which `build` method
/// are available on the builder.
pub mod state {
    /// Starting state (`JwtBuilder::new` / `Default`): identity claims and
    /// required claims can be configured, but no key or token source is chosen.
    pub struct Initial;

    /// Reached through `JwtBuilder::private_key`; `build` produces a `SignerJwt`.
    pub struct WithPrivateKey;

    /// Reached through `JwtBuilder::public_key` or `JwtBuilder::auto_resolve_keys`;
    /// `build` produces a `VerifierJwt`.
    pub struct WithPublicKey;

    /// Reached through `JwtBuilder::token_file`; `build` produces a
    /// `StaticTokenProvider` backed by the contents of that file.
    pub struct WithToken;
}
50
/// Type-state builder for JWT signers, verifiers and static token providers.
///
/// `S` is one of the marker types in the `state` module; it determines which
/// configuration methods and which `build` method can be called.
pub struct JwtBuilder<S = state::Initial> {
    /// Issuer (`iss`) copied into created claims and passed to
    /// `Validation::set_issuer` when building validation rules.
    issuer: Option<String>,
    /// Audiences (`aud`) copied into created claims and passed to
    /// `Validation::set_audience` when building validation rules.
    audience: Option<Vec<String>>,
    /// Subject (`sub`) copied into created claims.
    subject: Option<String>,

    /// Signing key; set only when moving to `state::WithPrivateKey`.
    private_key: Option<Key>,
    /// Verification key; set by `public_key`, left `None` by `auto_resolve_keys`.
    public_key: Option<Key>,
    /// Algorithm for the validation rules: HS256 by default, otherwise taken
    /// from the key handed to `private_key` / `public_key`.
    algorithm: Algorithm,

    /// Token lifetime handed to the signer/verifier (3600 s by default).
    token_duration: Duration,

    /// When `true`, the verifier is given a `KeyResolver` instead of a fixed key.
    auto_resolve_keys: bool,

    /// Registered claim names passed to `Validation::set_required_spec_claims`.
    required_claims: Vec<String>,

    /// Extra claims copied into every `StandardClaims` built by this builder.
    custom_claims: MetadataMap,

    /// Path of the static token file; set only when moving to `state::WithToken`.
    token_file: Option<String>,

    /// Zero-sized marker carrying the type state `S`.
    _state: PhantomData<S>,
}
91
92fn resolve_key(key: &Key) -> String {
93 match &key.key {
94 KeyData::Data(key) => key.clone(),
95 KeyData::File(path) => std::fs::read_to_string(path).expect("error reading key file"),
96 }
97}
98
99impl Default for JwtBuilder<state::Initial> {
100 fn default() -> Self {
101 Self {
102 issuer: None,
103 audience: None,
104 subject: None,
105 private_key: None,
106 public_key: None,
107 algorithm: Algorithm::HS256, token_duration: Duration::from_secs(3600), auto_resolve_keys: false,
110 required_claims: Vec::new(),
111 custom_claims: MetadataMap::new(),
112 token_file: None,
113 _state: PhantomData,
114 }
115 }
116}
117
118impl<S> JwtBuilder<S> {
120 fn build_validation(&self) -> Validation {
121 let mut validation = Validation::new(self.algorithm);
122 if let Some(audience) = &self.audience {
123 tracing::info!(?audience, "Setting audience");
124 validation.set_audience(audience);
125 }
126 if let Some(issuer) = &self.issuer {
127 tracing::info!(%issuer, "Setting issuer");
128 validation.set_issuer(&[issuer]);
129 }
130
131 if !self.required_claims.is_empty() {
132 tracing::info!(claims = ?self.required_claims, "Setting required claims");
133 validation.set_required_spec_claims(self.required_claims.as_ref());
134 }
135
136 validation
137 }
138
139 fn build_claims(&self) -> StandardClaims {
140 StandardClaims {
141 iss: self.issuer.clone(),
142 aud: self.audience.clone(),
143 sub: self.subject.clone(),
144 exp: 0, iat: None, nbf: None, jti: None, custom_claims: self.custom_claims.clone(),
149 }
150 }
151}
152
153impl JwtBuilder<state::Initial> {
155 pub fn new() -> Self {
157 Self::default()
158 }
159
160 pub fn issuer(self, issuer: impl Into<String>) -> Self {
162 Self {
163 issuer: Some(issuer.into()),
164 ..self
165 }
166 }
167
168 pub fn audience(self, audience: &[impl Into<String> + Clone]) -> Self {
170 Self {
171 audience: Some(audience.iter().map(|a| a.clone().into()).collect()),
172 ..self
173 }
174 }
175
176 pub fn subject(self, subject: impl Into<String>) -> Self {
178 Self {
179 subject: Some(subject.into()),
180 ..self
181 }
182 }
183
184 pub fn require_exp(self) -> Self {
186 let mut required_claims = self.required_claims.clone();
187 required_claims.push("exp".to_string());
188 Self {
189 required_claims,
190 ..self
191 }
192 }
193
194 pub fn require_nbf(self) -> Self {
196 let mut required_claims = self.required_claims.clone();
197 required_claims.push("nbf".to_string());
198 Self {
199 required_claims,
200 ..self
201 }
202 }
203
204 pub fn require_aud(self) -> Self {
206 let mut required_claims = self.required_claims.clone();
207 required_claims.push("aud".to_string());
208 Self {
209 required_claims,
210 ..self
211 }
212 }
213
214 pub fn require_iss(self) -> Self {
216 let mut required_claims = self.required_claims.clone();
217 required_claims.push("iss".to_string());
218 Self {
219 required_claims,
220 ..self
221 }
222 }
223
224 pub fn require_sub(self) -> Self {
226 let mut required_claims = self.required_claims.clone();
227 required_claims.push("sub".to_string());
228 Self {
229 required_claims,
230 ..self
231 }
232 }
233
234 pub fn private_key(self, key: &Key) -> JwtBuilder<state::WithPrivateKey> {
236 JwtBuilder::<state::WithPrivateKey> {
237 issuer: self.issuer,
238 audience: self.audience,
239 subject: self.subject,
240 private_key: Some(key.clone()),
241 public_key: None,
242 algorithm: key.algorithm,
243 token_duration: self.token_duration,
244 auto_resolve_keys: self.auto_resolve_keys,
245 required_claims: self.required_claims,
246 custom_claims: self.custom_claims,
247 token_file: None,
248 _state: PhantomData,
249 }
250 }
251
252 pub fn public_key(self, key: &Key) -> JwtBuilder<state::WithPublicKey> {
254 JwtBuilder::<state::WithPublicKey> {
255 issuer: self.issuer,
256 audience: self.audience,
257 subject: self.subject,
258 private_key: None,
259 public_key: Some(key.clone()),
260 algorithm: key.algorithm,
261 token_duration: self.token_duration,
262 auto_resolve_keys: self.auto_resolve_keys,
263 required_claims: self.required_claims,
264 custom_claims: self.custom_claims,
265 token_file: None,
266 _state: PhantomData,
267 }
268 }
269
270 pub fn auto_resolve_keys(self, enable: bool) -> JwtBuilder<state::WithPublicKey> {
272 JwtBuilder::<state::WithPublicKey> {
273 issuer: self.issuer,
274 audience: self.audience,
275 subject: self.subject,
276 private_key: None,
277 public_key: None,
278 algorithm: self.algorithm,
279 token_duration: self.token_duration,
280 auto_resolve_keys: enable,
281 required_claims: self.required_claims,
282 custom_claims: self.custom_claims,
283 token_file: None,
284 _state: PhantomData,
285 }
286 }
287
288 pub fn token_file(self, token_file: impl Into<String>) -> JwtBuilder<state::WithToken> {
289 JwtBuilder::<state::WithToken> {
290 issuer: self.issuer,
291 audience: self.audience,
292 subject: self.subject,
293 private_key: self.private_key,
294 public_key: self.public_key,
295 algorithm: self.algorithm,
296 token_duration: self.token_duration,
297 auto_resolve_keys: self.auto_resolve_keys,
298 required_claims: self.required_claims,
299 custom_claims: self.custom_claims,
300 token_file: Some(token_file.into()),
301 _state: PhantomData,
302 }
303 }
304}
305
306impl JwtBuilder<state::WithPrivateKey> {
308 pub fn token_duration(self, duration: Duration) -> Self {
310 Self {
311 token_duration: duration,
312 ..self
313 }
314 }
315
316 pub fn custom_claims(self, claims: MetadataMap) -> Self {
318 Self {
319 custom_claims: claims,
320 ..self
321 }
322 }
323
324 fn build_internal(key: &Key) -> Result<EncodingKey, AuthError> {
325 if key.format == KeyFormat::Jwk {
327 return Err(AuthError::JwtJwkFormatNotSupportedForEncoding);
328 }
329
330 let key_str = resolve_key(key);
331 match key.algorithm {
332 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
333 Ok(EncodingKey::from_secret(key_str.as_bytes()))
334 }
335 Algorithm::RS256
336 | Algorithm::RS384
337 | Algorithm::RS512
338 | Algorithm::PS256
339 | Algorithm::PS384
340 | Algorithm::PS512 => {
341 {
343 let ek = EncodingKey::from_rsa_pem(key_str.as_bytes())?;
344 Ok(ek)
345 }
346 }
347 Algorithm::ES256 | Algorithm::ES384 => {
348 let ek = EncodingKey::from_ec_pem(key_str.as_bytes())?;
350 Ok(ek)
351 }
352 Algorithm::EdDSA => {
353 let ek = EncodingKey::from_ed_pem(key_str.as_bytes())?;
355 Ok(ek)
356 }
357 }
358 }
359
360 pub fn build(self) -> Result<SignerJwt, AuthError> {
362 let validation = self.build_validation();
364
365 let key = Self::build_internal(self.private_key.as_ref().unwrap())?;
367
368 let encoding_key = Arc::new(RwLock::new(key));
370
371 let signer = SignerJwt::new(self.build_claims(), self.token_duration, validation)?
373 .with_encoding_key(encoding_key.clone());
374
375 let signer = match &self.private_key.as_ref().unwrap().key {
377 KeyData::File(path) => {
378 let encoding_key_clone = encoding_key.clone();
380 let key_clone = self.private_key.clone().unwrap();
381 let mut w = FileWatcher::create_watcher(move |_file: &str| {
382 let new_key =
383 Self::build_internal(&key_clone).expect("error processing new key");
384 *encoding_key_clone.as_ref().write() = new_key;
385 });
386 w.add_file(path).expect("error adding file to the watcher");
387
388 signer.with_watcher(w)
389 }
390 _ => signer,
391 };
392
393 Ok(signer)
395 }
396}
397
/// Verification material produced from a public `Key`.
enum DecodingKeyInternal {
    /// A single key used directly for verification (PEM/secret or single-JWK input).
    DecKey(DecodingKey),
    /// A JWK set (JWKS input); handed to a `KeyResolver` instead of a fixed key.
    Jwks(JwkSet),
}
402
403impl JwtBuilder<state::WithPublicKey> {
405 fn build_internal(key: &Key) -> Result<DecodingKeyInternal, AuthError> {
406 let key_str = resolve_key(key);
407
408 match &key.format {
409 KeyFormat::Pem => {
410 match key.algorithm {
412 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => Ok(
413 DecodingKeyInternal::DecKey(DecodingKey::from_secret(key_str.as_bytes())),
414 ),
415 Algorithm::RS256
416 | Algorithm::RS384
417 | Algorithm::RS512
418 | Algorithm::PS256
419 | Algorithm::PS384
420 | Algorithm::PS512 => {
421 let ret = DecodingKey::from_rsa_pem(key_str.as_bytes())?;
423 Ok(DecodingKeyInternal::DecKey(ret))
424 }
425 Algorithm::ES256 | Algorithm::ES384 => {
426 let ret = DecodingKey::from_ec_pem(key_str.as_bytes())?;
428 Ok(DecodingKeyInternal::DecKey(ret))
429 }
430 Algorithm::EdDSA => {
431 let ret = DecodingKey::from_ed_pem(key_str.as_bytes())?;
433 Ok(DecodingKeyInternal::DecKey(ret))
434 }
435 }
436 }
437 KeyFormat::Jwk => {
438 let jwk: Jwk = serde_json::from_str(&resolve_key(key))?;
440 let ret = DecodingKey::from_jwk(&jwk)?;
441 Ok(DecodingKeyInternal::DecKey(ret))
442 }
443 KeyFormat::Jwks => {
444 let jwk_set: JwkSet = serde_json::from_str(&resolve_key(key))?;
446
447 Ok(DecodingKeyInternal::Jwks(jwk_set))
448 }
449 }
450 }
451
452 pub fn build(self) -> Result<VerifierJwt, AuthError> {
454 let validation = self.build_validation();
456
457 let verifier = VerifierJwt::new(self.build_claims(), self.token_duration, validation)?;
458
459 if self.auto_resolve_keys {
461 return Ok(verifier.with_key_resolver(KeyResolver::new()));
462 }
463
464 let key = Self::build_internal(self.public_key.as_ref().unwrap())?;
466
467 match key {
468 DecodingKeyInternal::DecKey(key) => {
469 let decoding_key = Arc::new(RwLock::new(key));
471
472 let verifier = verifier.with_decoding_key(decoding_key.clone());
474
475 let verifier = match &self.public_key.as_ref().unwrap().key {
477 KeyData::File(path) => {
478 let decoding_key_clone = decoding_key.clone();
480 let key_clone = self.public_key.clone().unwrap();
481 let mut w = FileWatcher::create_watcher(move |_file: &str| {
482 let new_key =
483 Self::build_internal(&key_clone).expect("error processing new key");
484 *decoding_key_clone.as_ref().write() = match new_key {
485 DecodingKeyInternal::DecKey(key) => key,
486 _ => panic!("Expected DecodingKey, got Jwks"),
487 };
488 });
489 w.add_file(path).expect("error adding file to the watcher");
490
491 verifier.with_watcher(w)
492 }
493 _ => verifier,
494 };
495
496 Ok(verifier)
497 }
498 DecodingKeyInternal::Jwks(jwk_set) => {
499 let resolver = KeyResolver::with_jwks(jwk_set);
501
502 Ok(verifier.with_key_resolver(resolver))
504 }
505 }
506 }
507}
508
509impl JwtBuilder<state::WithToken> {
511 pub fn build(self) -> Result<StaticTokenProvider, AuthError> {
513 let static_token = std::fs::read_to_string(self.token_file.as_ref().unwrap())
515 .expect("error reading token file");
516 let static_token = Arc::new(RwLock::new(static_token));
517
518 let token_clone = static_token.clone();
519 let mut w = FileWatcher::create_watcher(move |file: &str| {
520 let token = std::fs::read_to_string(file).expect("error reading token file");
521 *token_clone.as_ref().write() = token;
522 });
523 w.add_file(self.token_file.as_ref().unwrap())?;
524
525 Ok(SignerJwt::new(
527 self.build_claims(), std::time::Duration::from_secs(0), self.build_validation(), )?
531 .with_static_token(static_token)
532 .with_watcher(w))
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::TokenProvider;
    use crate::traits::{Signer, Verifier};
    use serde::{Deserialize, Serialize};
    use std::env;
    use std::fs;
    use std::fs::File;
    use std::io::Write;
    use std::path::Path;
    use std::path::PathBuf;
    use std::process;
    use std::time::Instant;
    use std::time::SystemTime;
    use std::time::UNIX_EPOCH;

    use slim_config::tls::provider::initialize_crypto_provider;

    /// Maximum time to wait for the file watcher to pick up a change.
    const RELOAD_TIMEOUT: Duration = Duration::from_secs(5);

    /// Delay between checks while waiting for the file watcher.
    const RELOAD_POLL_INTERVAL: Duration = Duration::from_millis(20);

    /// Returns a temp-dir path made unique by process id and a nanosecond timestamp.
    fn temp_file_path(prefix: &str) -> PathBuf {
        let mut path = env::temp_dir();
        let unique = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards")
            .as_nanos();
        path.push(format!("{prefix}_{}_{}.txt", process::id(), unique));
        path
    }

    /// Creates (or truncates) `file_path` and writes `content` to it.
    fn create_file(file_path: &Path, content: &str) -> std::io::Result<()> {
        let mut file = File::create(file_path)?;
        file.write_all(content.as_bytes())?;
        Ok(())
    }

    /// Best-effort removal of a test file; errors are ignored.
    fn delete_file(file_path: &Path) {
        let _ = fs::remove_file(file_path);
    }

    #[test]
    fn test_jwt_builder_basic() {
        let jwt = JwtBuilder::new()
            .issuer("test-issuer")
            .audience(&["test-audience"])
            .subject("test-subject")
            .private_key(&Key {
                algorithm: Algorithm::HS512,
                format: KeyFormat::Pem,
                key: KeyData::Data("test-key".to_string()),
            })
            .build()
            .unwrap();

        let claims = jwt.create_claims();

        assert_eq!(claims.iss.unwrap(), "test-issuer");
        assert_eq!(claims.aud.unwrap(), &["test-audience"]);
        assert_eq!(claims.sub.unwrap(), "test-subject");
    }

    #[tokio::test]
    async fn test_jwt_builder_basic_key_from_file() {
        let file_path = temp_file_path("key_file_builder");
        create_file(&file_path, "test-key").expect("failed to create file");
        let file_name = file_path.to_string_lossy().to_string();

        let jwt = JwtBuilder::new()
            .issuer("test-issuer")
            .audience(&["test-audience"])
            .subject("test-subject")
            .private_key(&Key {
                algorithm: Algorithm::HS512,
                format: KeyFormat::Pem,
                key: KeyData::File(file_name.to_string()),
            })
            .build()
            .unwrap();

        let claims = jwt.create_claims();

        assert_eq!(claims.iss.unwrap(), "test-issuer");
        assert_eq!(claims.aud.unwrap(), &["test-audience"]);
        assert_eq!(claims.sub.unwrap(), "test-subject");

        delete_file(&file_path);
    }

    #[tokio::test]
    async fn test_jwt_builder_sign_verify() {
        let signer = JwtBuilder::new()
            .issuer("test-issuer")
            .audience(&["test-audience"])
            .subject("test-subject")
            .private_key(&Key {
                algorithm: Algorithm::HS512,
                format: KeyFormat::Pem,
                key: KeyData::Data("test-key".to_string()),
            })
            .build()
            .unwrap();

        let verifier = JwtBuilder::new()
            .issuer("test-issuer")
            .audience(&["test-audience"])
            .subject("test-subject")
            .public_key(&Key {
                algorithm: Algorithm::HS512,
                format: KeyFormat::Pem,
                key: KeyData::Data("test-key".to_string()),
            })
            .build()
            .unwrap();

        let claims = signer.create_claims();
        let token = signer.sign(&claims).unwrap();
        let verified: crate::traits::StandardClaims = verifier.get_claims(&token).await.unwrap();

        assert_eq!(verified.iss.unwrap(), "test-issuer");
        assert_eq!(verified.aud.unwrap(), &["test-audience"]);
        assert_eq!(verified.sub.unwrap(), "test-subject");
    }

    #[tokio::test]
    async fn test_jwt_builder_custom_claims() {
        #[derive(Debug, Serialize, Deserialize, PartialEq)]
        struct CustomClaims {
            iss: String,
            aud: Vec<String>,
            sub: String,
            exp: u64,
            role: String,
        }

        let signer = JwtBuilder::new()
            .issuer("test-issuer")
            .audience(&["test-audience"])
            .subject("test-subject")
            .private_key(&Key {
                algorithm: Algorithm::HS512,
                format: KeyFormat::Pem,
                key: KeyData::Data("test-key".to_string()),
            })
            .build()
            .unwrap();

        let verifier = JwtBuilder::new()
            .issuer("test-issuer")
            .audience(&["test-audience"])
            .subject("test-subject")
            .public_key(&Key {
                algorithm: Algorithm::HS512,
                format: KeyFormat::Pem,
                key: KeyData::Data("test-key".to_string()),
            })
            .build()
            .unwrap();

        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        let custom_claims = CustomClaims {
            iss: "test-issuer".to_string(),
            aud: vec!["test-audience".to_string()],
            sub: "test-subject".to_string(),
            exp: now + 3600,
            role: "admin".to_string(),
        };

        let token = signer.sign(&custom_claims).unwrap();
        let verified: CustomClaims = verifier.get_claims(&token).await.unwrap();

        assert_eq!(verified, custom_claims);
    }

    #[test]
    fn test_jwt_builder_auto_resolve_keys() {
        initialize_crypto_provider();

        let jwt = JwtBuilder::new()
            .issuer("https://example.com")
            .audience(&["test-audience"])
            .subject("test-subject")
            .auto_resolve_keys(true)
            .build();
        assert!(jwt.is_ok());
    }

    #[tokio::test]
    async fn test_static_token_provider() {
        initialize_crypto_provider();

        let tokenvalue = "thecontent";

        let file_path = temp_file_path("token");
        create_file(&file_path, tokenvalue).unwrap();
        let file_name = file_path.to_string_lossy().to_string();

        let provider = JwtBuilder::new()
            .issuer("https://example.com")
            .audience(&["test-audience"])
            .subject("test-subject")
            .token_file(&file_name)
            .build()
            .unwrap();

        let token = provider.get_token().unwrap();
        assert_eq!(token, tokenvalue);

        let new_token_value = "thenewcontent";
        create_file(&file_path, new_token_value).unwrap();

        // The watcher reloads asynchronously, and may first observe the truncated
        // (empty) file. Poll until the new content appears, with a generous
        // deadline, instead of relying on one fixed sleep.
        let deadline = Instant::now() + RELOAD_TIMEOUT;
        let mut token = provider.get_token().unwrap();
        while token != new_token_value && Instant::now() < deadline {
            tokio::time::sleep(RELOAD_POLL_INTERVAL).await;
            token = provider.get_token().unwrap();
        }
        assert_eq!(token, new_token_value);

        delete_file(&file_path);
    }

    #[tokio::test]
    async fn initialize_static_token_provider() {
        let jwt = JwtBuilder::new()
            .issuer("test-issuer")
            .audience(&["aud"])
            .subject("sub")
            .private_key(&Key {
                algorithm: Algorithm::HS256,
                format: KeyFormat::Pem,
                key: KeyData::Data("secret-key".into()),
            })
            .build()
            .unwrap();
        let token = Arc::new(RwLock::new("header.payload.sig".to_string()));
        let mut static_provider: StaticTokenProvider = jwt.with_static_token(token);

        let _ = static_provider.initialize().await;
        let token = static_provider.get_token().unwrap();
        assert_eq!(token, "header.payload.sig");
    }
}