1use jsonwebtoken::{Algorithm, Header, Validation};
2use serde::{Deserialize, Serialize};
3use std::{collections::HashMap, fmt::Debug, sync::Arc};
4
5use crate::{
6 bare_key::{BareKeyInner, SerializedKey},
7 registry::IsKey,
8 Any, BareKey, BarePrivateKey, BarePublicKey, KeyError, OpaqueValidationFailureReason,
9 SignatureError, SigningContext, ValidationContext, ValidationError, ValidationType,
10};
11
12#[derive(Clone, Copy, Debug, derive_more::Display, PartialEq, Eq)]
13pub enum KeyType {
14 RS256,
15 ES256,
16 HS256,
17}
18
19#[derive(Serialize, Deserialize)]
20struct Token {
21 #[serde(rename = "exp", default, skip_serializing_if = "Option::is_none")]
22 pub expiry: Option<usize>,
23 #[serde(rename = "iss", default, skip_serializing_if = "Option::is_none")]
24 pub issuer: Option<String>,
25 #[serde(rename = "aud", default, skip_serializing_if = "Option::is_none")]
26 pub audience: Option<String>,
27 #[serde(rename = "iat", default, skip_serializing_if = "Option::is_none")]
28 pub issued_at: Option<usize>,
29 #[serde(rename = "nbf", default, skip_serializing_if = "Option::is_none")]
30 pub not_before: Option<usize>,
31 #[serde(flatten)]
32 claims: HashMap<String, Any>,
33}
34
35pub struct PrivateKey {
37 pub(crate) kid: Option<String>,
38 pub(crate) inner: Arc<PrivateKeyInner>,
39}
40
41impl PrivateKey {
42 pub fn key_type(&self) -> KeyType {
43 self.inner.bare_key.key_type()
44 }
45
46 pub fn set_kid(&mut self, kid: Option<String>) {
47 self.kid = kid;
48 }
49
50 pub fn from_bare_private_key(
51 kid: Option<String>,
52 key: BarePrivateKey,
53 ) -> Result<Self, KeyError> {
54 let encoding_key = (&key).try_into()?;
55 let decoding_key = (&key.to_public()?).try_into()?;
56 let inner = PrivateKeyInner {
57 bare_key: key,
58 encoding_key,
59 decoding_key,
60 }
61 .into();
62 Ok(Self { kid, inner })
63 }
64
65 #[cfg(feature = "keygen")]
66 pub fn generate(kid: Option<String>, kty: KeyType) -> Result<Self, KeyError> {
67 let key = BarePrivateKey::generate(kty)?;
68 Self::from_bare_private_key(kid, key)
69 }
70
71 pub fn clone_key(&self) -> Self {
72 Self {
73 kid: self.kid.clone(),
74 inner: self.inner.clone(),
75 }
76 }
77
78 pub fn sign(
79 &self,
80 claims: HashMap<String, Any>,
81 ctx: &SigningContext,
82 ) -> Result<String, SignatureError> {
83 sign_token(
84 self.key_type(),
85 &self.inner.encoding_key,
86 self.kid.as_deref(),
87 claims,
88 ctx,
89 )
90 }
91
92 pub fn validate(
93 &self,
94 token: &str,
95 ctx: &ValidationContext,
96 ) -> Result<HashMap<String, Any>, ValidationError> {
97 validate_token(
98 self.key_type(),
99 &self.inner.decoding_key,
100 self.kid.as_deref(),
101 token,
102 ctx,
103 )
104 }
105}
106
107impl IsKey for PrivateKey {
108 type Inner = Arc<PrivateKeyInner>;
109
110 fn key_type(inner: &Self::Inner) -> KeyType {
111 inner.bare_key.key_type()
112 }
113
114 fn inner(&self) -> &Self::Inner {
115 &self.inner
116 }
117
118 fn from_inner(kid: Option<String>, inner: Self::Inner) -> Self {
119 PrivateKey { kid, inner }
120 }
121
122 fn into_inner(self) -> (Option<String>, Self::Inner) {
123 (self.kid, self.inner)
124 }
125
126 fn get_serialized_key(key: SerializedKey) -> Option<Self> {
127 match key {
128 SerializedKey::Private(kid, key) => {
129 Some(PrivateKey::from_bare_private_key(kid, key).ok()?)
130 }
131 _ => None,
132 }
133 }
134
135 fn to_serialized_key(kid: Option<&str>, key: &Self::Inner) -> SerializedKey {
136 SerializedKey::Private(kid.map(String::from), key.bare_key.clone_key())
137 }
138
139 fn from_pem(pem: &str) -> Result<Vec<Result<Self, KeyError>>, KeyError> {
140 BarePrivateKey::from_pem_multiple(pem).map(|keys| {
141 keys.into_iter()
142 .map(|k| k.and_then(|bare_key| PrivateKey::from_bare_private_key(None, bare_key)))
143 .collect()
144 })
145 }
146
147 fn to_pem(inner: &Self::Inner) -> String {
148 inner.bare_key.to_pem()
149 }
150
151 fn decoding_key(inner: &Self::Inner) -> &jsonwebtoken::DecodingKey {
152 &inner.decoding_key
153 }
154
155 fn encoding_key(inner: &Self::Inner) -> Option<&jsonwebtoken::EncodingKey> {
156 Some(&inner.encoding_key)
157 }
158}
159
160pub(crate) fn sign_token(
161 key_type: KeyType,
162 encoding_key: &jsonwebtoken::EncodingKey,
163 kid: Option<&str>,
164 claims: HashMap<String, Any>,
165 ctx: &SigningContext,
166) -> Result<String, SignatureError> {
167 let mut header = Header {
168 kid: kid.map(String::from),
169 ..Default::default()
170 };
171 match key_type {
172 KeyType::HS256 => {}
173 KeyType::ES256 => header.alg = jsonwebtoken::Algorithm::ES256,
174 KeyType::RS256 => header.alg = jsonwebtoken::Algorithm::RS256,
175 }
176
177 let now = std::time::SystemTime::now()
178 .duration_since(std::time::UNIX_EPOCH)
179 .unwrap_or_default()
180 .as_secs() as usize;
181
182 let (issued_at, not_before) = if let Some(not_before) = ctx.not_before {
183 (
184 Some(now),
185 Some(now.saturating_sub(not_before.as_secs() as usize)),
186 )
187 } else {
188 (None, None)
189 };
190
191 let expiry = ctx.expiry.map(|d| d.as_secs() as isize);
192 let expiry = if expiry == Some(0) {
193 Some(now.saturating_sub(120))
197 } else {
198 expiry.map(|d| now.saturating_add_signed(d))
199 };
200
201 let token = Token {
202 expiry,
203 issuer: ctx.issuer.clone(),
204 audience: ctx.audience.clone(),
205 issued_at,
206 not_before,
207 claims,
208 };
209
210 jsonwebtoken::encode(&header, &token, encoding_key)
211 .map_err(|e| SignatureError::SignatureError(e.to_string()))
212}
213
214pub(crate) fn validate_token(
217 key_type: KeyType,
218 decoding_key: &jsonwebtoken::DecodingKey,
219 kid: Option<&str>,
220 token: &str,
221 ctx: &ValidationContext,
222) -> Result<HashMap<String, Any>, ValidationError> {
223 let mut validation = Validation::new(match key_type {
224 KeyType::ES256 => Algorithm::ES256,
225 KeyType::HS256 => Algorithm::HS256,
226 KeyType::RS256 => Algorithm::RS256,
227 });
228
229 validation.validate_aud = false;
230
231 match ctx.expiry {
232 ValidationType::Ignore => {
233 validation.required_spec_claims.remove("exp");
234 validation.validate_exp = false;
235 }
236 ValidationType::Allow => {
237 validation.required_spec_claims.remove("exp");
238 validation.validate_exp = true;
239 }
240 ValidationType::Reject => {
241 validation.required_spec_claims.remove("exp");
242 validation.validate_exp = false;
243 }
244 ValidationType::Require => {
245 }
247 }
248
249 match ctx.not_before {
250 ValidationType::Ignore => {
251 validation.validate_nbf = false;
252 }
253 ValidationType::Allow => {
254 validation.validate_nbf = true;
255 }
256 ValidationType::Reject => {
257 validation.validate_nbf = false;
258 }
259 ValidationType::Require => {
260 validation.required_spec_claims.insert("nbf".to_string());
261 validation.validate_nbf = true;
262 }
263 }
264
265 let token = jsonwebtoken::decode::<HashMap<String, Any>>(token, decoding_key, &validation)
266 .map_err(|e| match e.kind() {
267 jsonwebtoken::errors::ErrorKind::InvalidSignature => {
268 OpaqueValidationFailureReason::InvalidSignature
269 }
270 _ => OpaqueValidationFailureReason::Failure(format!("{:?}", e.kind())),
271 })?;
272
273 if let (Some(token_kid), Some(expected_kid)) = (token.header.kid, kid) {
274 if token_kid != expected_kid {
275 return Err(OpaqueValidationFailureReason::InvalidHeader(
276 "kid".to_string(),
277 token_kid,
278 Some(expected_kid.to_string()),
279 )
280 .into());
281 }
282 }
283
284 for (claim, values) in &ctx.allow_list {
285 let value = token.claims.get(claim);
286 match value {
287 Some(Any::String(value)) => {
288 if !values.contains(value.as_ref()) {
289 return Err(OpaqueValidationFailureReason::InvalidClaimValue(
290 claim.to_string(),
291 Some(value.to_string()),
292 )
293 .into());
294 }
295 }
296 Some(Any::Array(array_values)) => {
297 for v in array_values.iter() {
298 if let Any::String(v) = v {
299 if !values.contains(v.as_ref()) {
300 return Err(OpaqueValidationFailureReason::InvalidClaimValue(
301 claim.to_string(),
302 Some(v.to_string()),
303 )
304 .into());
305 }
306 } else {
307 return Err(OpaqueValidationFailureReason::InvalidClaimValue(
308 claim.to_string(),
309 None,
310 )
311 .into());
312 }
313 }
314 }
315 _ => {
316 return Err(OpaqueValidationFailureReason::InvalidClaimValue(
317 claim.to_string(),
318 None,
319 )
320 .into());
321 }
322 }
323 }
324
325 for (claim, values) in &ctx.deny_list {
326 let value = token.claims.get(claim);
327 match value {
328 Some(Any::String(value)) => {
329 if values.contains(value.as_ref()) {
330 return Err(OpaqueValidationFailureReason::InvalidClaimValue(
331 claim.to_string(),
332 Some(value.to_string()),
333 )
334 .into());
335 }
336 }
337 Some(Any::Array(array_values)) => {
338 for v in array_values.iter() {
339 if let Any::String(v) = v {
340 if values.contains(v.as_ref()) {
341 return Err(OpaqueValidationFailureReason::InvalidClaimValue(
342 claim.to_string(),
343 Some(v.to_string()),
344 )
345 .into());
346 }
347 } else {
348 return Err(OpaqueValidationFailureReason::InvalidClaimValue(
349 claim.to_string(),
350 None,
351 )
352 .into());
353 }
354 }
355 }
356 _ => {
357 return Err(OpaqueValidationFailureReason::InvalidClaimValue(
358 claim.to_string(),
359 None,
360 )
361 .into());
362 }
363 }
364 }
365
366 let mut claims = token.claims;
369 claims.remove("exp");
370 for claim in ctx.claims.iter() {
371 claims.remove(claim.0);
372 }
373
374 if ctx.expiry == ValidationType::Reject {
375 if let Some(exp) = claims.remove("exp") {
376 return Err(OpaqueValidationFailureReason::InvalidClaimValue(
377 "exp".to_string(),
378 Some(format!("{exp:?}")),
379 )
380 .into());
381 }
382 }
383 if ctx.not_before == ValidationType::Reject {
384 if let Some(nbf) = claims.remove("nbf") {
385 return Err(OpaqueValidationFailureReason::InvalidClaimValue(
386 "nbf".to_string(),
387 Some(format!("{nbf:?}")),
388 )
389 .into());
390 }
391 }
392
393 Ok(claims)
394}
395
396pub(crate) struct PrivateKeyInner {
397 pub(crate) bare_key: BarePrivateKey,
398 pub(crate) encoding_key: jsonwebtoken::EncodingKey,
399 pub(crate) decoding_key: jsonwebtoken::DecodingKey,
400}
401
402impl std::hash::Hash for PrivateKeyInner {
403 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
404 self.bare_key.hash(state);
405 }
406}
407
408impl PartialEq for PrivateKeyInner {
409 fn eq(&self, other: &Self) -> bool {
410 self.bare_key == other.bare_key
411 }
412}
413
414impl std::fmt::Debug for PrivateKeyInner {
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 self.bare_key.fmt(f)
417 }
418}
419
420impl Eq for PrivateKeyInner {}
421
422#[derive(Clone, Debug, PartialEq, Eq)]
424pub struct PublicKey {
425 kid: Option<String>,
426 inner: Arc<PublicKeyInner>,
427}
428
429impl PublicKey {
430 pub fn key_type(&self) -> KeyType {
431 self.inner.bare_key.key_type()
432 }
433
434 pub fn set_kid(&mut self, kid: Option<String>) {
435 self.kid = kid;
436 }
437
438 pub fn from_bare_public_key(kid: Option<String>, key: BarePublicKey) -> Result<Self, KeyError> {
439 let decoding_key: jsonwebtoken::DecodingKey = (&key).try_into()?;
440 let inner = PublicKeyInner {
441 decoding_key,
442 bare_key: key,
443 }
444 .into();
445 Ok(Self { kid, inner })
446 }
447
448 pub fn validate(
449 &self,
450 token: &str,
451 ctx: &ValidationContext,
452 ) -> Result<HashMap<String, Any>, ValidationError> {
453 validate_token(
454 self.key_type(),
455 &self.inner.decoding_key,
456 self.kid.as_deref(),
457 token,
458 ctx,
459 )
460 }
461}
462
463impl IsKey for PublicKey {
464 type Inner = Arc<PublicKeyInner>;
465
466 fn key_type(inner: &Self::Inner) -> KeyType {
467 inner.bare_key.key_type()
468 }
469
470 fn inner(&self) -> &Self::Inner {
471 &self.inner
472 }
473
474 fn from_inner(kid: Option<String>, inner: Self::Inner) -> Self {
475 PublicKey { kid, inner }
476 }
477
478 fn into_inner(self) -> (Option<String>, Self::Inner) {
479 (self.kid, self.inner)
480 }
481
482 fn get_serialized_key(key: SerializedKey) -> Option<Self> {
483 match key {
484 SerializedKey::Private(kid, key) => {
485 Some(PublicKey::from_bare_public_key(kid, key.to_public().ok()?).ok()?)
486 }
487 SerializedKey::Public(kid, key) => {
488 Some(PublicKey::from_bare_public_key(kid, key).ok()?)
489 }
490 _ => None,
491 }
492 }
493
494 fn to_serialized_key(kid: Option<&str>, key: &Self::Inner) -> SerializedKey {
495 SerializedKey::Public(kid.map(String::from), key.bare_key.clone_key())
496 }
497
498 fn from_pem(pem: &str) -> Result<Vec<Result<Self, KeyError>>, KeyError> {
499 BarePublicKey::from_pem_multiple(pem).map(|keys| {
500 keys.into_iter()
501 .map(|k| k.and_then(|bare_key| PublicKey::from_bare_public_key(None, bare_key)))
502 .collect()
503 })
504 }
505
506 fn to_pem(inner: &Self::Inner) -> String {
507 inner.bare_key.to_pem()
508 }
509
510 fn decoding_key(inner: &Self::Inner) -> &jsonwebtoken::DecodingKey {
511 &inner.decoding_key
512 }
513
514 fn encoding_key(_: &Self::Inner) -> Option<&jsonwebtoken::EncodingKey> {
515 None
516 }
517}
518
519pub struct PublicKeyInner {
520 pub(crate) bare_key: BarePublicKey,
521 pub(crate) decoding_key: jsonwebtoken::DecodingKey,
522}
523
524impl std::fmt::Debug for PublicKeyInner {
525 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526 self.bare_key.fmt(f)
527 }
528}
529
530impl std::hash::Hash for PublicKeyInner {
531 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
532 self.bare_key.hash(state);
533 }
534}
535
536impl PartialEq for PublicKeyInner {
537 fn eq(&self, other: &Self) -> bool {
538 self.bare_key == other.bare_key
539 }
540}
541
542impl Eq for PublicKeyInner {}
543
544pub struct Key {
546 kid: Option<String>,
547 inner: KeyInner,
548}
549
550#[derive(Clone, Debug, PartialEq, Eq, Hash)]
551pub(crate) enum KeyInner {
552 Private(Arc<PrivateKeyInner>),
553 Public(Arc<PublicKeyInner>),
554}
555
556impl KeyInner {}
557
558impl Key {
559 pub fn key_type(&self) -> KeyType {
560 match &self.inner {
561 KeyInner::Private(inner) => inner.bare_key.key_type(),
562 KeyInner::Public(inner) => inner.bare_key.key_type(),
563 }
564 }
565
566 pub fn from_bare_key(kid: Option<String>, key: BareKey) -> Result<Self, KeyError> {
567 Ok(match key.inner {
568 BareKeyInner::Private(inner) => {
569 PrivateKey::from_bare_private_key(kid, BarePrivateKey { inner })?.into()
570 }
571 BareKeyInner::Public(inner) => {
572 PublicKey::from_bare_public_key(kid, BarePublicKey { inner })?.into()
573 }
574 })
575 }
576
577 pub fn from_bare_private_key(
578 kid: Option<String>,
579 key: BarePrivateKey,
580 ) -> Result<Self, KeyError> {
581 Ok(PrivateKey::from_bare_private_key(kid, key)?.into())
582 }
583
584 pub fn from_bare_public_key(kid: Option<String>, key: BarePublicKey) -> Result<Self, KeyError> {
585 Ok(PublicKey::from_bare_public_key(kid, key)?.into())
586 }
587}
588
589impl From<PrivateKey> for Key {
590 fn from(key: PrivateKey) -> Self {
591 Key {
592 kid: key.kid,
593 inner: KeyInner::Private(key.inner),
594 }
595 }
596}
597
598impl From<PublicKey> for Key {
599 fn from(key: PublicKey) -> Self {
600 Key {
601 kid: key.kid,
602 inner: KeyInner::Public(key.inner),
603 }
604 }
605}
606
607impl IsKey for Key {
608 type Inner = KeyInner;
609
610 fn key_type(inner: &Self::Inner) -> KeyType {
611 match inner {
612 KeyInner::Private(inner) => inner.bare_key.key_type(),
613 KeyInner::Public(inner) => inner.bare_key.key_type(),
614 }
615 }
616
617 fn inner(&self) -> &Self::Inner {
618 &self.inner
619 }
620
621 fn from_inner(kid: Option<String>, inner: Self::Inner) -> Self {
622 Key { kid, inner }
623 }
624
625 fn into_inner(self) -> (Option<String>, Self::Inner) {
626 (self.kid, self.inner)
627 }
628
629 fn get_serialized_key(key: SerializedKey) -> Option<Self> {
630 match key {
631 SerializedKey::Private(kid, key) => {
632 Some(PrivateKey::from_bare_private_key(kid, key).ok()?.into())
633 }
634 SerializedKey::Public(kid, key) => {
635 Some(PublicKey::from_bare_public_key(kid, key).ok()?.into())
636 }
637 _ => None,
638 }
639 }
640
641 fn to_serialized_key(kid: Option<&str>, key: &Self::Inner) -> SerializedKey {
642 match key {
643 KeyInner::Private(inner) => {
644 SerializedKey::Private(kid.map(String::from), inner.bare_key.clone_key())
645 }
646 KeyInner::Public(inner) => {
647 SerializedKey::Public(kid.map(String::from), inner.bare_key.clone_key())
648 }
649 }
650 }
651
652 fn from_pem(pem: &str) -> Result<Vec<Result<Self, KeyError>>, KeyError> {
653 let keys = BareKey::from_pem_multiple(pem)?;
654 let mut results = Vec::new();
655 for key in keys {
656 results.push(key.and_then(|key| Self::from_bare_key(None, key)));
657 }
658 Ok(results)
659 }
660
661 fn to_pem(inner: &Self::Inner) -> String {
662 match inner {
663 KeyInner::Private(inner) => inner.bare_key.to_pem(),
664 KeyInner::Public(inner) => inner.bare_key.to_pem(),
665 }
666 }
667
668 fn decoding_key(inner: &Self::Inner) -> &jsonwebtoken::DecodingKey {
669 match inner {
670 KeyInner::Private(inner) => &inner.decoding_key,
671 KeyInner::Public(inner) => &inner.decoding_key,
672 }
673 }
674
675 fn encoding_key(inner: &Self::Inner) -> Option<&jsonwebtoken::EncodingKey> {
676 match inner {
677 KeyInner::Private(inner) => Some(&inner.encoding_key),
678 KeyInner::Public(_) => None,
679 }
680 }
681}