gel_jwt/
key.rs

1use jsonwebtoken::{Algorithm, Header, Validation};
2use serde::{Deserialize, Serialize};
3use std::{collections::HashMap, fmt::Debug, sync::Arc};
4
5use crate::{
6    bare_key::{BareKeyInner, SerializedKey},
7    registry::IsKey,
8    Any, BareKey, BarePrivateKey, BarePublicKey, KeyError, OpaqueValidationFailureReason,
9    SignatureError, SigningContext, ValidationContext, ValidationError, ValidationType,
10};
11
12#[derive(Clone, Copy, Debug, derive_more::Display, PartialEq, Eq)]
13pub enum KeyType {
14    RS256,
15    ES256,
16    HS256,
17}
18
19#[derive(Serialize, Deserialize)]
20struct Token {
21    #[serde(rename = "exp", default, skip_serializing_if = "Option::is_none")]
22    pub expiry: Option<usize>,
23    #[serde(rename = "iss", default, skip_serializing_if = "Option::is_none")]
24    pub issuer: Option<String>,
25    #[serde(rename = "aud", default, skip_serializing_if = "Option::is_none")]
26    pub audience: Option<String>,
27    #[serde(rename = "iat", default, skip_serializing_if = "Option::is_none")]
28    pub issued_at: Option<usize>,
29    #[serde(rename = "nbf", default, skip_serializing_if = "Option::is_none")]
30    pub not_before: Option<usize>,
31    #[serde(flatten)]
32    claims: HashMap<String, Any>,
33}
34
35/// A private key with key-signing capabilities.
36pub struct PrivateKey {
37    pub(crate) kid: Option<String>,
38    pub(crate) inner: Arc<PrivateKeyInner>,
39}
40
41impl PrivateKey {
42    pub fn key_type(&self) -> KeyType {
43        self.inner.bare_key.key_type()
44    }
45
46    pub fn set_kid(&mut self, kid: Option<String>) {
47        self.kid = kid;
48    }
49
50    pub fn from_bare_private_key(
51        kid: Option<String>,
52        key: BarePrivateKey,
53    ) -> Result<Self, KeyError> {
54        let encoding_key = (&key).try_into()?;
55        let decoding_key = (&key.to_public()?).try_into()?;
56        let inner = PrivateKeyInner {
57            bare_key: key,
58            encoding_key,
59            decoding_key,
60        }
61        .into();
62        Ok(Self { kid, inner })
63    }
64
65    #[cfg(feature = "keygen")]
66    pub fn generate(kid: Option<String>, kty: KeyType) -> Result<Self, KeyError> {
67        let key = BarePrivateKey::generate(kty)?;
68        Self::from_bare_private_key(kid, key)
69    }
70
71    pub fn clone_key(&self) -> Self {
72        Self {
73            kid: self.kid.clone(),
74            inner: self.inner.clone(),
75        }
76    }
77
78    pub fn sign(
79        &self,
80        claims: HashMap<String, Any>,
81        ctx: &SigningContext,
82    ) -> Result<String, SignatureError> {
83        sign_token(
84            self.key_type(),
85            &self.inner.encoding_key,
86            self.kid.as_deref(),
87            claims,
88            ctx,
89        )
90    }
91
92    pub fn validate(
93        &self,
94        token: &str,
95        ctx: &ValidationContext,
96    ) -> Result<HashMap<String, Any>, ValidationError> {
97        validate_token(
98            self.key_type(),
99            &self.inner.decoding_key,
100            self.kid.as_deref(),
101            token,
102            ctx,
103        )
104    }
105}
106
107impl IsKey for PrivateKey {
108    type Inner = Arc<PrivateKeyInner>;
109
110    fn key_type(inner: &Self::Inner) -> KeyType {
111        inner.bare_key.key_type()
112    }
113
114    fn inner(&self) -> &Self::Inner {
115        &self.inner
116    }
117
118    fn from_inner(kid: Option<String>, inner: Self::Inner) -> Self {
119        PrivateKey { kid, inner }
120    }
121
122    fn into_inner(self) -> (Option<String>, Self::Inner) {
123        (self.kid, self.inner)
124    }
125
126    fn get_serialized_key(key: SerializedKey) -> Option<Self> {
127        match key {
128            SerializedKey::Private(kid, key) => {
129                Some(PrivateKey::from_bare_private_key(kid, key).ok()?)
130            }
131            _ => None,
132        }
133    }
134
135    fn to_serialized_key(kid: Option<&str>, key: &Self::Inner) -> SerializedKey {
136        SerializedKey::Private(kid.map(String::from), key.bare_key.clone_key())
137    }
138
139    fn from_pem(pem: &str) -> Result<Vec<Result<Self, KeyError>>, KeyError> {
140        BarePrivateKey::from_pem_multiple(pem).map(|keys| {
141            keys.into_iter()
142                .map(|k| k.and_then(|bare_key| PrivateKey::from_bare_private_key(None, bare_key)))
143                .collect()
144        })
145    }
146
147    fn to_pem(inner: &Self::Inner) -> String {
148        inner.bare_key.to_pem()
149    }
150
151    fn decoding_key(inner: &Self::Inner) -> &jsonwebtoken::DecodingKey {
152        &inner.decoding_key
153    }
154
155    fn encoding_key(inner: &Self::Inner) -> Option<&jsonwebtoken::EncodingKey> {
156        Some(&inner.encoding_key)
157    }
158}
159
160pub(crate) fn sign_token(
161    key_type: KeyType,
162    encoding_key: &jsonwebtoken::EncodingKey,
163    kid: Option<&str>,
164    claims: HashMap<String, Any>,
165    ctx: &SigningContext,
166) -> Result<String, SignatureError> {
167    let mut header = Header {
168        kid: kid.map(String::from),
169        ..Default::default()
170    };
171    match key_type {
172        KeyType::HS256 => {}
173        KeyType::ES256 => header.alg = jsonwebtoken::Algorithm::ES256,
174        KeyType::RS256 => header.alg = jsonwebtoken::Algorithm::RS256,
175    }
176
177    let now = std::time::SystemTime::now()
178        .duration_since(std::time::UNIX_EPOCH)
179        .unwrap_or_default()
180        .as_secs() as usize;
181
182    let (issued_at, not_before) = if let Some(not_before) = ctx.not_before {
183        (
184            Some(now),
185            Some(now.saturating_sub(not_before.as_secs() as usize)),
186        )
187    } else {
188        (None, None)
189    };
190
191    let expiry = ctx.expiry.map(|d| d.as_secs() as isize);
192    let expiry = if expiry == Some(0) {
193        // Ensure that a token that expires now expires with enough notice for
194        // the leeway option to be ignored. This isn't a great solution, but
195        // it's challenging to test expiring tokens otherwise.
196        Some(now.saturating_sub(120))
197    } else {
198        expiry.map(|d| now.saturating_add_signed(d))
199    };
200
201    let token = Token {
202        expiry,
203        issuer: ctx.issuer.clone(),
204        audience: ctx.audience.clone(),
205        issued_at,
206        not_before,
207        claims,
208    };
209
210    jsonwebtoken::encode(&header, &token, encoding_key)
211        .map_err(|e| SignatureError::SignatureError(e.to_string()))
212}
213
214/// Returns the raw claims from the token, including those we may have added
215/// as part of the signature process.
216pub(crate) fn validate_token(
217    key_type: KeyType,
218    decoding_key: &jsonwebtoken::DecodingKey,
219    kid: Option<&str>,
220    token: &str,
221    ctx: &ValidationContext,
222) -> Result<HashMap<String, Any>, ValidationError> {
223    let mut validation = Validation::new(match key_type {
224        KeyType::ES256 => Algorithm::ES256,
225        KeyType::HS256 => Algorithm::HS256,
226        KeyType::RS256 => Algorithm::RS256,
227    });
228
229    validation.validate_aud = false;
230
231    match ctx.expiry {
232        ValidationType::Ignore => {
233            validation.required_spec_claims.remove("exp");
234            validation.validate_exp = false;
235        }
236        ValidationType::Allow => {
237            validation.required_spec_claims.remove("exp");
238            validation.validate_exp = true;
239        }
240        ValidationType::Reject => {
241            validation.required_spec_claims.remove("exp");
242            validation.validate_exp = false;
243        }
244        ValidationType::Require => {
245            // The default
246        }
247    }
248
249    match ctx.not_before {
250        ValidationType::Ignore => {
251            validation.validate_nbf = false;
252        }
253        ValidationType::Allow => {
254            validation.validate_nbf = true;
255        }
256        ValidationType::Reject => {
257            validation.validate_nbf = false;
258        }
259        ValidationType::Require => {
260            validation.required_spec_claims.insert("nbf".to_string());
261            validation.validate_nbf = true;
262        }
263    }
264
265    let token = jsonwebtoken::decode::<HashMap<String, Any>>(token, decoding_key, &validation)
266        .map_err(|e| match e.kind() {
267            jsonwebtoken::errors::ErrorKind::InvalidSignature => {
268                OpaqueValidationFailureReason::InvalidSignature
269            }
270            _ => OpaqueValidationFailureReason::Failure(format!("{:?}", e.kind())),
271        })?;
272
273    if let (Some(token_kid), Some(expected_kid)) = (token.header.kid, kid) {
274        if token_kid != expected_kid {
275            return Err(OpaqueValidationFailureReason::InvalidHeader(
276                "kid".to_string(),
277                token_kid,
278                Some(expected_kid.to_string()),
279            )
280            .into());
281        }
282    }
283
284    for (claim, values) in &ctx.allow_list {
285        let value = token.claims.get(claim);
286        match value {
287            Some(Any::String(value)) => {
288                if !values.contains(value.as_ref()) {
289                    return Err(OpaqueValidationFailureReason::InvalidClaimValue(
290                        claim.to_string(),
291                        Some(value.to_string()),
292                    )
293                    .into());
294                }
295            }
296            Some(Any::Array(array_values)) => {
297                for v in array_values.iter() {
298                    if let Any::String(v) = v {
299                        if !values.contains(v.as_ref()) {
300                            return Err(OpaqueValidationFailureReason::InvalidClaimValue(
301                                claim.to_string(),
302                                Some(v.to_string()),
303                            )
304                            .into());
305                        }
306                    } else {
307                        return Err(OpaqueValidationFailureReason::InvalidClaimValue(
308                            claim.to_string(),
309                            None,
310                        )
311                        .into());
312                    }
313                }
314            }
315            _ => {
316                return Err(OpaqueValidationFailureReason::InvalidClaimValue(
317                    claim.to_string(),
318                    None,
319                )
320                .into());
321            }
322        }
323    }
324
325    for (claim, values) in &ctx.deny_list {
326        let value = token.claims.get(claim);
327        match value {
328            Some(Any::String(value)) => {
329                if values.contains(value.as_ref()) {
330                    return Err(OpaqueValidationFailureReason::InvalidClaimValue(
331                        claim.to_string(),
332                        Some(value.to_string()),
333                    )
334                    .into());
335                }
336            }
337            Some(Any::Array(array_values)) => {
338                for v in array_values.iter() {
339                    if let Any::String(v) = v {
340                        if values.contains(v.as_ref()) {
341                            return Err(OpaqueValidationFailureReason::InvalidClaimValue(
342                                claim.to_string(),
343                                Some(v.to_string()),
344                            )
345                            .into());
346                        }
347                    } else {
348                        return Err(OpaqueValidationFailureReason::InvalidClaimValue(
349                            claim.to_string(),
350                            None,
351                        )
352                        .into());
353                    }
354                }
355            }
356            _ => {
357                return Err(OpaqueValidationFailureReason::InvalidClaimValue(
358                    claim.to_string(),
359                    None,
360                )
361                .into());
362            }
363        }
364    }
365
366    // Remove any claims that were validated automatically and reject any that should not
367    // be present.
368    let mut claims = token.claims;
369    claims.remove("exp");
370    for claim in ctx.claims.iter() {
371        claims.remove(claim.0);
372    }
373
374    if ctx.expiry == ValidationType::Reject {
375        if let Some(exp) = claims.remove("exp") {
376            return Err(OpaqueValidationFailureReason::InvalidClaimValue(
377                "exp".to_string(),
378                Some(format!("{exp:?}")),
379            )
380            .into());
381        }
382    }
383    if ctx.not_before == ValidationType::Reject {
384        if let Some(nbf) = claims.remove("nbf") {
385            return Err(OpaqueValidationFailureReason::InvalidClaimValue(
386                "nbf".to_string(),
387                Some(format!("{nbf:?}")),
388            )
389            .into());
390        }
391    }
392
393    Ok(claims)
394}
395
396pub(crate) struct PrivateKeyInner {
397    pub(crate) bare_key: BarePrivateKey,
398    pub(crate) encoding_key: jsonwebtoken::EncodingKey,
399    pub(crate) decoding_key: jsonwebtoken::DecodingKey,
400}
401
402impl std::hash::Hash for PrivateKeyInner {
403    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
404        self.bare_key.hash(state);
405    }
406}
407
408impl PartialEq for PrivateKeyInner {
409    fn eq(&self, other: &Self) -> bool {
410        self.bare_key == other.bare_key
411    }
412}
413
414impl std::fmt::Debug for PrivateKeyInner {
415    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416        self.bare_key.fmt(f)
417    }
418}
419
420impl Eq for PrivateKeyInner {}
421
422/// A public key with key-validation capabilities.
423#[derive(Clone, Debug, PartialEq, Eq)]
424pub struct PublicKey {
425    kid: Option<String>,
426    inner: Arc<PublicKeyInner>,
427}
428
429impl PublicKey {
430    pub fn key_type(&self) -> KeyType {
431        self.inner.bare_key.key_type()
432    }
433
434    pub fn set_kid(&mut self, kid: Option<String>) {
435        self.kid = kid;
436    }
437
438    pub fn from_bare_public_key(kid: Option<String>, key: BarePublicKey) -> Result<Self, KeyError> {
439        let decoding_key: jsonwebtoken::DecodingKey = (&key).try_into()?;
440        let inner = PublicKeyInner {
441            decoding_key,
442            bare_key: key,
443        }
444        .into();
445        Ok(Self { kid, inner })
446    }
447
448    pub fn validate(
449        &self,
450        token: &str,
451        ctx: &ValidationContext,
452    ) -> Result<HashMap<String, Any>, ValidationError> {
453        validate_token(
454            self.key_type(),
455            &self.inner.decoding_key,
456            self.kid.as_deref(),
457            token,
458            ctx,
459        )
460    }
461}
462
463impl IsKey for PublicKey {
464    type Inner = Arc<PublicKeyInner>;
465
466    fn key_type(inner: &Self::Inner) -> KeyType {
467        inner.bare_key.key_type()
468    }
469
470    fn inner(&self) -> &Self::Inner {
471        &self.inner
472    }
473
474    fn from_inner(kid: Option<String>, inner: Self::Inner) -> Self {
475        PublicKey { kid, inner }
476    }
477
478    fn into_inner(self) -> (Option<String>, Self::Inner) {
479        (self.kid, self.inner)
480    }
481
482    fn get_serialized_key(key: SerializedKey) -> Option<Self> {
483        match key {
484            SerializedKey::Private(kid, key) => {
485                Some(PublicKey::from_bare_public_key(kid, key.to_public().ok()?).ok()?)
486            }
487            SerializedKey::Public(kid, key) => {
488                Some(PublicKey::from_bare_public_key(kid, key).ok()?)
489            }
490            _ => None,
491        }
492    }
493
494    fn to_serialized_key(kid: Option<&str>, key: &Self::Inner) -> SerializedKey {
495        SerializedKey::Public(kid.map(String::from), key.bare_key.clone_key())
496    }
497
498    fn from_pem(pem: &str) -> Result<Vec<Result<Self, KeyError>>, KeyError> {
499        BarePublicKey::from_pem_multiple(pem).map(|keys| {
500            keys.into_iter()
501                .map(|k| k.and_then(|bare_key| PublicKey::from_bare_public_key(None, bare_key)))
502                .collect()
503        })
504    }
505
506    fn to_pem(inner: &Self::Inner) -> String {
507        inner.bare_key.to_pem()
508    }
509
510    fn decoding_key(inner: &Self::Inner) -> &jsonwebtoken::DecodingKey {
511        &inner.decoding_key
512    }
513
514    fn encoding_key(_: &Self::Inner) -> Option<&jsonwebtoken::EncodingKey> {
515        None
516    }
517}
518
519pub struct PublicKeyInner {
520    pub(crate) bare_key: BarePublicKey,
521    pub(crate) decoding_key: jsonwebtoken::DecodingKey,
522}
523
524impl std::fmt::Debug for PublicKeyInner {
525    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526        self.bare_key.fmt(f)
527    }
528}
529
530impl std::hash::Hash for PublicKeyInner {
531    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
532        self.bare_key.hash(state);
533    }
534}
535
536impl PartialEq for PublicKeyInner {
537    fn eq(&self, other: &Self) -> bool {
538        self.bare_key == other.bare_key
539    }
540}
541
542impl Eq for PublicKeyInner {}
543
544/// A key which is either a private or public key.
545pub struct Key {
546    kid: Option<String>,
547    inner: KeyInner,
548}
549
550#[derive(Clone, Debug, PartialEq, Eq, Hash)]
551pub(crate) enum KeyInner {
552    Private(Arc<PrivateKeyInner>),
553    Public(Arc<PublicKeyInner>),
554}
555
556impl KeyInner {}
557
558impl Key {
559    pub fn key_type(&self) -> KeyType {
560        match &self.inner {
561            KeyInner::Private(inner) => inner.bare_key.key_type(),
562            KeyInner::Public(inner) => inner.bare_key.key_type(),
563        }
564    }
565
566    pub fn from_bare_key(kid: Option<String>, key: BareKey) -> Result<Self, KeyError> {
567        Ok(match key.inner {
568            BareKeyInner::Private(inner) => {
569                PrivateKey::from_bare_private_key(kid, BarePrivateKey { inner })?.into()
570            }
571            BareKeyInner::Public(inner) => {
572                PublicKey::from_bare_public_key(kid, BarePublicKey { inner })?.into()
573            }
574        })
575    }
576
577    pub fn from_bare_private_key(
578        kid: Option<String>,
579        key: BarePrivateKey,
580    ) -> Result<Self, KeyError> {
581        Ok(PrivateKey::from_bare_private_key(kid, key)?.into())
582    }
583
584    pub fn from_bare_public_key(kid: Option<String>, key: BarePublicKey) -> Result<Self, KeyError> {
585        Ok(PublicKey::from_bare_public_key(kid, key)?.into())
586    }
587}
588
589impl From<PrivateKey> for Key {
590    fn from(key: PrivateKey) -> Self {
591        Key {
592            kid: key.kid,
593            inner: KeyInner::Private(key.inner),
594        }
595    }
596}
597
598impl From<PublicKey> for Key {
599    fn from(key: PublicKey) -> Self {
600        Key {
601            kid: key.kid,
602            inner: KeyInner::Public(key.inner),
603        }
604    }
605}
606
607impl IsKey for Key {
608    type Inner = KeyInner;
609
610    fn key_type(inner: &Self::Inner) -> KeyType {
611        match inner {
612            KeyInner::Private(inner) => inner.bare_key.key_type(),
613            KeyInner::Public(inner) => inner.bare_key.key_type(),
614        }
615    }
616
617    fn inner(&self) -> &Self::Inner {
618        &self.inner
619    }
620
621    fn from_inner(kid: Option<String>, inner: Self::Inner) -> Self {
622        Key { kid, inner }
623    }
624
625    fn into_inner(self) -> (Option<String>, Self::Inner) {
626        (self.kid, self.inner)
627    }
628
629    fn get_serialized_key(key: SerializedKey) -> Option<Self> {
630        match key {
631            SerializedKey::Private(kid, key) => {
632                Some(PrivateKey::from_bare_private_key(kid, key).ok()?.into())
633            }
634            SerializedKey::Public(kid, key) => {
635                Some(PublicKey::from_bare_public_key(kid, key).ok()?.into())
636            }
637            _ => None,
638        }
639    }
640
641    fn to_serialized_key(kid: Option<&str>, key: &Self::Inner) -> SerializedKey {
642        match key {
643            KeyInner::Private(inner) => {
644                SerializedKey::Private(kid.map(String::from), inner.bare_key.clone_key())
645            }
646            KeyInner::Public(inner) => {
647                SerializedKey::Public(kid.map(String::from), inner.bare_key.clone_key())
648            }
649        }
650    }
651
652    fn from_pem(pem: &str) -> Result<Vec<Result<Self, KeyError>>, KeyError> {
653        let keys = BareKey::from_pem_multiple(pem)?;
654        let mut results = Vec::new();
655        for key in keys {
656            results.push(key.and_then(|key| Self::from_bare_key(None, key)));
657        }
658        Ok(results)
659    }
660
661    fn to_pem(inner: &Self::Inner) -> String {
662        match inner {
663            KeyInner::Private(inner) => inner.bare_key.to_pem(),
664            KeyInner::Public(inner) => inner.bare_key.to_pem(),
665        }
666    }
667
668    fn decoding_key(inner: &Self::Inner) -> &jsonwebtoken::DecodingKey {
669        match inner {
670            KeyInner::Private(inner) => &inner.decoding_key,
671            KeyInner::Public(inner) => &inner.decoding_key,
672        }
673    }
674
675    fn encoding_key(inner: &Self::Inner) -> Option<&jsonwebtoken::EncodingKey> {
676        match inner {
677            KeyInner::Private(inner) => Some(&inner.encoding_key),
678            KeyInner::Public(_) => None,
679        }
680    }
681}