1use serde::{Deserialize, Deserializer, Serialize, Serializer};
37use sha2::digest::{Digest, Output};
38
39use core::fmt;
40
41use crate::{
42 alg::SecretBytes,
43 alloc::{Cow, String, ToString, Vec},
44};
45
/// Type of a [`JsonWebKey`], i.e. the value of its `kty` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum KeyType {
    /// RSA key (`kty: "RSA"`).
    Rsa,
    /// Elliptic-curve key (`kty: "EC"`).
    EllipticCurve,
    /// Symmetric key (`kty: "oct"`).
    Symmetric,
    /// Octet key pair (`kty: "OKP"`), e.g. an Ed25519 key.
    KeyPair,
}
60
61impl fmt::Display for KeyType {
62 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
63 formatter.write_str(match self {
64 Self::Rsa => "RSA",
65 Self::EllipticCurve => "EC",
66 Self::Symmetric => "oct",
67 Self::KeyPair => "OKP",
68 })
69 }
70}
71
/// Errors produced when interpreting a [`JsonWebKey`], e.g. when converting it into
/// an algorithm-specific key via `TryFrom`.
#[derive(Debug)]
#[non_exhaustive]
pub enum JwkError {
    /// A required field is absent from the JWK; holds the field name.
    NoField(String),
    /// The JWK has a key type other than the expected one.
    UnexpectedKeyType {
        /// Expected key type.
        expected: KeyType,
        /// Actual key type of the JWK.
        actual: KeyType,
    },
    /// A JWK field has an unexpected value.
    UnexpectedValue {
        /// Name of the field.
        field: String,
        /// Expected value of the field.
        expected: String,
        /// Actual value of the field.
        actual: String,
    },
    /// A JWK field has an unexpected byte length.
    UnexpectedLen {
        /// Name of the field.
        field: String,
        /// Expected length in bytes.
        expected: usize,
        /// Actual length in bytes.
        actual: usize,
    },
    /// Private and public key parts encoded in the JWK do not match.
    MismatchedKeys,
    /// Any other error, wrapped into an `anyhow::Error`.
    Custom(anyhow::Error),
}
109
110impl fmt::Display for JwkError {
111 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
112 match self {
113 Self::UnexpectedKeyType { expected, actual } => {
114 write!(formatter, "unexpected key type: {actual} (expected {expected})")
115 },
116 Self::NoField(field) => write!(formatter, "field `{field}` is absent from JWK"),
117 Self::UnexpectedValue { field, expected, actual } => {
118 write!(formatter, "field `{field}` has unexpected value (expected: {expected}, got: {actual})")
119 },
120 Self::UnexpectedLen { field, expected, actual } => {
121 write!(formatter, "field `{field}` has unexpected length (expected: {expected}, got: {actual})")
122 },
123 Self::MismatchedKeys => formatter.write_str("private and public keys encoded in JWK do not match"),
124 Self::Custom(err) => fmt::Display::fmt(err, formatter),
125 }
126 }
127}
128
#[cfg(feature = "std")]
impl std::error::Error for JwkError {
    /// Returns the lower-level cause of this error, if any.
    ///
    /// `JwkError::Custom` is a transparent wrapper: its `Display` output is the wrapped
    /// error's own message. The `std::error::Error` guidelines say an error should either
    /// expose a cause via `source()` or include that cause's message in its `Display`
    /// output, but never both. Doing both makes error reporters (e.g. `anyhow`'s `{:#}`)
    /// print the same message twice. Therefore the wrapped error's *own* `source()` is
    /// forwarded here, rather than the wrapped error itself.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Custom(err) => {
                // View the `anyhow::Error` as the underlying trait object, then skip one level.
                let inner: &(dyn std::error::Error + 'static) = err.as_ref();
                std::error::Error::source(inner)
            }
            // Listed explicitly so that adding a variant forces a decision here.
            Self::NoField(_)
            | Self::UnexpectedKeyType { .. }
            | Self::UnexpectedValue { .. }
            | Self::UnexpectedLen { .. }
            | Self::MismatchedKeys => None,
        }
    }
}
138
139impl JwkError {
140 pub fn custom(err: impl Into<anyhow::Error>) -> Self {
142 Self::Custom(err.into())
143 }
144
145 pub(crate) fn key_type(jwk: &JsonWebKey<'_>, expected: KeyType) -> Self {
146 let actual = jwk.key_type();
147 debug_assert_ne!(actual, expected);
148 Self::UnexpectedKeyType { actual, expected }
149 }
150}
151
152impl Serialize for SecretBytes<'_> {
153 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
154 base64url::serialize(self.as_ref(), serializer)
155 }
156}
157
158impl<'de> Deserialize<'de> for SecretBytes<'_> {
159 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
160 base64url::deserialize(deserializer).map(SecretBytes::new)
161 }
162}
163
/// Basic [JWK] functionality: (de)serialization of cryptographic keys.
///
/// The key type is stored in the `kty` tag field. Byte fields are base64url-encoded (without
/// padding) with human-readable formats such as JSON, and written as raw bytes with binary
/// formats (e.g., CBOR).
///
/// [JWK]: https://tools.ietf.org/html/rfc7517.html
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kty")]
#[non_exhaustive]
pub enum JsonWebKey<'a> {
    /// Public or private RSA key (`"kty": "RSA"`).
    #[serde(rename = "RSA")]
    Rsa {
        /// Key modulus (`n`).
        #[serde(rename = "n", with = "base64url")]
        modulus: Cow<'a, [u8]>,
        /// Public exponent (`e`).
        #[serde(rename = "e", with = "base64url")]
        public_exponent: Cow<'a, [u8]>,
        /// Private RSA parameters, flattened into the same object; `None` for a public key.
        #[serde(flatten)]
        private_parts: Option<RsaPrivateParts<'a>>,
    },
    /// Public or private elliptic-curve key (`"kty": "EC"`).
    #[serde(rename = "EC")]
    EllipticCurve {
        /// Curve name (`crv`).
        #[serde(rename = "crv")]
        curve: Cow<'a, str>,
        /// `x` coordinate of the public key.
        #[serde(with = "base64url")]
        x: Cow<'a, [u8]>,
        /// `y` coordinate of the public key.
        #[serde(with = "base64url")]
        y: Cow<'a, [u8]>,
        /// Secret scalar (`d`); `None` for a public key and omitted from output in that case.
        #[serde(rename = "d", default, skip_serializing_if = "Option::is_none")]
        secret: Option<SecretBytes<'a>>,
    },
    /// Symmetric key (`"kty": "oct"`).
    #[serde(rename = "oct")]
    Symmetric {
        /// Secret key bytes (`k`).
        #[serde(rename = "k")]
        secret: SecretBytes<'a>,
    },
    /// Octet key pair (`"kty": "OKP"`), e.g. an Ed25519 key.
    #[serde(rename = "OKP")]
    KeyPair {
        /// Curve name (`crv`), e.g. `Ed25519`.
        #[serde(rename = "crv")]
        curve: Cow<'a, str>,
        /// Public key bytes (`x`).
        #[serde(with = "base64url")]
        x: Cow<'a, [u8]>,
        /// Secret key (`d`); `None` for a public key and omitted from output in that case.
        #[serde(rename = "d", default, skip_serializing_if = "Option::is_none")]
        secret: Option<SecretBytes<'a>>,
    },
}
258
259impl JsonWebKey<'_> {
260 pub const fn key_type(&self) -> KeyType {
262 match self {
263 Self::Rsa { .. } => KeyType::Rsa,
264 Self::EllipticCurve { .. } => KeyType::EllipticCurve,
265 Self::Symmetric { .. } => KeyType::Symmetric,
266 Self::KeyPair { .. } => KeyType::KeyPair,
267 }
268 }
269
270 pub const fn is_signing_key(&self) -> bool {
272 match self {
273 Self::Rsa { private_parts, .. } => private_parts.is_some(),
274 Self::EllipticCurve { secret, .. } | Self::KeyPair { secret, .. } => secret.is_some(),
275 Self::Symmetric { .. } => true,
276 }
277 }
278
279 #[must_use]
281 pub fn to_verifying_key(&self) -> Self {
282 match self {
283 Self::Rsa { modulus, public_exponent, .. } => {
284 Self::Rsa { modulus: modulus.clone(), public_exponent: public_exponent.clone(), private_parts: None }
285 },
286
287 Self::EllipticCurve { curve, x, y, .. } => {
288 Self::EllipticCurve { curve: curve.clone(), x: x.clone(), y: y.clone(), secret: None }
289 },
290
291 Self::Symmetric { secret } => Self::Symmetric { secret: secret.clone() },
292
293 Self::KeyPair { curve, x, .. } => Self::KeyPair { curve: curve.clone(), x: x.clone(), secret: None },
294 }
295 }
296
297 pub fn thumbprint<D: Digest>(&self) -> Output<D> {
302 let hashed_key = if self.is_signing_key() { Cow::Owned(self.to_verifying_key()) } else { Cow::Borrowed(self) };
303 D::digest(hashed_key.to_string().as_bytes())
304 }
305}
306
307impl fmt::Display for JsonWebKey<'_> {
308 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
310 let json_value = serde_json::to_value(self).expect("Cannot convert JsonWebKey to JSON");
311 let json_value = json_value.as_object().unwrap();
312 let mut json_entries: Vec<_> = json_value.iter().collect();
315 json_entries.sort_unstable_by(|(x, _), (y, _)| x.cmp(y));
316
317 formatter.write_str("{")?;
318 let field_count = json_entries.len();
319 for (i, (name, value)) in json_entries.into_iter().enumerate() {
320 write!(formatter, "\"{name}\":{value}")?;
321 if i + 1 < field_count {
322 formatter.write_str(",")?;
323 }
324 }
325 formatter.write_str("}")
326 }
327}
328
/// Private parts of an RSA key, flattened into [`JsonWebKey::Rsa`] when present.
///
/// Fields use their JWK names (`d`, `p`, `q`, `dp`, `dq`, `qi`, `oth`). Optional fields are
/// omitted from serialized output when absent.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RsaPrivateParts<'a> {
    /// Private exponent (`d`).
    #[serde(rename = "d")]
    pub private_exponent: SecretBytes<'a>,
    /// First prime factor (`p`).
    #[serde(rename = "p")]
    pub prime_factor_p: SecretBytes<'a>,
    /// Second prime factor (`q`).
    #[serde(rename = "q")]
    pub prime_factor_q: SecretBytes<'a>,
    /// CRT exponent of the first factor (`dp`), if present.
    #[serde(rename = "dp", default, skip_serializing_if = "Option::is_none")]
    pub p_crt_exponent: Option<SecretBytes<'a>>,
    /// CRT exponent of the second factor (`dq`), if present.
    #[serde(rename = "dq", default, skip_serializing_if = "Option::is_none")]
    pub q_crt_exponent: Option<SecretBytes<'a>>,
    /// CRT coefficient of the second factor (`qi`), if present.
    #[serde(rename = "qi", default, skip_serializing_if = "Option::is_none")]
    pub q_crt_coefficient: Option<SecretBytes<'a>>,
    /// Additional prime factors (`oth`) for multi-prime keys; omitted from output when empty.
    #[serde(rename = "oth", default, skip_serializing_if = "Vec::is_empty")]
    pub other_prime_factors: Vec<RsaPrimeFactor<'a>>,
}
360
/// Additional RSA prime factor, as listed in the `oth` field of [`RsaPrivateParts`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RsaPrimeFactor<'a> {
    /// Prime factor (`r`).
    #[serde(rename = "r")]
    pub factor: SecretBytes<'a>,
    /// CRT exponent of this factor (`d`), if present.
    #[serde(rename = "d", default, skip_serializing_if = "Option::is_none")]
    pub crt_exponent: Option<SecretBytes<'a>>,
    /// CRT coefficient of this factor (`t`), if present.
    #[serde(rename = "t", default, skip_serializing_if = "Option::is_none")]
    pub crt_coefficient: Option<SecretBytes<'a>>,
}
380
#[cfg(any(
    feature = "es256k",
    feature = "k256",
    feature = "exonum-crypto",
    feature = "ed25519-dalek",
    feature = "ed25519-compact"
))]
mod helpers {
    use super::{JsonWebKey, JwkError};
    use crate::{alg::SigningKey, alloc::ToOwned, Algorithm};

    impl JsonWebKey<'_> {
        /// Checks that a `crv` value equals `expected`.
        ///
        /// # Errors
        ///
        /// Returns [`JwkError::UnexpectedValue`] for the `crv` field if the values differ.
        pub(crate) fn ensure_curve(curve: &str, expected: &str) -> Result<(), JwkError> {
            if curve != expected {
                return Err(JwkError::UnexpectedValue {
                    field: "crv".to_owned(),
                    expected: expected.to_owned(),
                    actual: curve.to_owned(),
                });
            }
            Ok(())
        }

        /// Checks that `bytes` (the contents of `field`) are exactly `expected_len` bytes long.
        ///
        /// # Errors
        ///
        /// Returns [`JwkError::UnexpectedLen`] if the length differs.
        pub(crate) fn ensure_len(field: &str, bytes: &[u8], expected_len: usize) -> Result<(), JwkError> {
            let actual = bytes.len();
            if actual != expected_len {
                return Err(JwkError::UnexpectedLen {
                    field: field.to_owned(),
                    expected: expected_len,
                    actual,
                });
            }
            Ok(())
        }

        /// Checks that the verifying key encoded in this JWK matches the one derived from
        /// `signing_key`, and hands `signing_key` back on success.
        ///
        /// # Errors
        ///
        /// Propagates the error from converting this JWK into a verifying key. Returns
        /// [`JwkError::MismatchedKeys`] if the two verifying keys differ.
        pub(crate) fn ensure_key_match<Alg, K>(&self, signing_key: K) -> Result<K, JwkError>
        where
            Alg: Algorithm<SigningKey = K>,
            K: SigningKey<Alg>,
            Alg::VerifyingKey: for<'jwk> TryFrom<&'jwk Self, Error = JwkError> + PartialEq,
        {
            let key_from_jwk = <Alg::VerifyingKey>::try_from(self)?;
            let key_from_secret = signing_key.to_verifying_key();
            if key_from_jwk == key_from_secret {
                Ok(signing_key)
            } else {
                Err(JwkError::MismatchedKeys)
            }
        }
    }
}
430
431mod base64url {
432 use base64ct::{Base64UrlUnpadded, Encoding};
433 use serde::{
434 de::{Error as DeError, Unexpected, Visitor},
435 Deserializer, Serializer,
436 };
437
438 use core::fmt;
439
440 use crate::alloc::{Cow, Vec};
441
442 pub fn serialize<S>(value: &[u8], serializer: S) -> Result<S::Ok, S::Error>
443 where
444 S: Serializer,
445 {
446 if serializer.is_human_readable() {
447 serializer.serialize_str(&Base64UrlUnpadded::encode_string(value))
448 } else {
449 serializer.serialize_bytes(value)
450 }
451 }
452
453 pub fn deserialize<'de, D>(deserializer: D) -> Result<Cow<'static, [u8]>, D::Error>
454 where
455 D: Deserializer<'de>,
456 {
457 struct Base64Visitor;
458
459 impl Visitor<'_> for Base64Visitor {
460 type Value = Vec<u8>;
461
462 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
463 formatter.write_str("base64url-encoded data")
464 }
465
466 fn visit_str<E: DeError>(self, value: &str) -> Result<Self::Value, E> {
467 Base64UrlUnpadded::decode_vec(value).map_err(|_| E::invalid_value(Unexpected::Str(value), &self))
468 }
469
470 fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
471 Ok(value.to_vec())
472 }
473
474 fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
475 Ok(value)
476 }
477 }
478
479 struct BytesVisitor;
480
481 impl<'de> Visitor<'de> for BytesVisitor {
482 type Value = Vec<u8>;
483
484 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
485 formatter.write_str("byte buffer")
486 }
487
488 fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
489 Ok(value.to_vec())
490 }
491
492 fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
493 Ok(value)
494 }
495 }
496
497 let maybe_bytes = if deserializer.is_human_readable() {
498 deserializer.deserialize_str(Base64Visitor)
499 } else {
500 deserializer.deserialize_bytes(BytesVisitor)
501 };
502 maybe_bytes.map(Cow::Owned)
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alg::Hs256Key;

    use assert_matches::assert_matches;

    /// Creates a public Ed25519 (OKP) key with a short placeholder `x` value.
    const fn create_jwk() -> JsonWebKey<'static> {
        JsonWebKey::KeyPair { curve: Cow::Borrowed("Ed25519"), x: Cow::Borrowed(b"test"), secret: None }
    }

    /// JSON serialization base64url-encodes `x`, includes the `kty` tag, and round-trips.
    #[test]
    fn serializing_jwk() {
        let jwk = create_jwk();

        let json = serde_json::to_value(&jwk).unwrap();
        // `b"test"` is `dGVzdA` in unpadded base64url.
        assert_eq!(json, serde_json::json!({ "crv": "Ed25519", "kty": "OKP", "x": "dGVzdA" }));

        let restored: JsonWebKey<'_> = serde_json::from_value(json).unwrap();
        assert_eq!(restored, jwk);
    }

    /// Deserialization errors mention a missing `kty` tag and invalid base64url input.
    #[test]
    fn jwk_deserialization_errors() {
        // Without the `kty` tag, serde cannot select an enum variant.
        let missing_field_json = r#"{"crv":"Ed25519"}"#;
        let missing_field_err = serde_json::from_str::<JsonWebKey<'_>>(missing_field_json).unwrap_err().to_string();
        assert!(missing_field_err.contains("missing field `kty`"), "{missing_field_err}");

        // `??` is not valid base64url; the error quotes the input and the expected format.
        let base64_json = r#"{"crv":"Ed25519","kty":"OKP","x":"??"}"#;
        let base64_err = serde_json::from_str::<JsonWebKey<'_>>(base64_json).unwrap_err().to_string();
        assert!(base64_err.contains("invalid value: string \"??\""), "{base64_err}");
        assert!(base64_err.contains("base64url-encoded data"), "{base64_err}");
    }

    /// `JsonWebKey` can be flattened into a larger struct that carries extra JWK fields
    /// (`kid`, `use`), and the base key still converts to and from an `Hs256Key`.
    #[test]
    fn extra_jwk_fields() {
        #[derive(Debug, Serialize, Deserialize)]
        struct ExtendedJsonWebKey<'a, T> {
            #[serde(flatten)]
            base: JsonWebKey<'a>,
            #[serde(flatten)]
            extra: T,
        }

        #[derive(Debug, Deserialize)]
        struct Extra {
            #[serde(rename = "kid")]
            key_id: String,
            #[serde(rename = "use")]
            key_use: KeyUse,
        }

        #[derive(Debug, Deserialize, PartialEq)]
        enum KeyUse {
            #[serde(rename = "sig")]
            Signature,
            #[serde(rename = "enc")]
            Encryption,
        }

        let json_str = r#"
            { "kty": "oct", "kid": "my-unique-key", "k": "dGVzdA", "use": "sig" }
        "#;
        let jwk: ExtendedJsonWebKey<'_, Extra> = serde_json::from_str(json_str).unwrap();

        assert_matches!(&jwk.base, JsonWebKey::Symmetric { secret } if secret.as_ref() == b"test");
        assert_eq!(jwk.extra.key_id, "my-unique-key");
        assert_eq!(jwk.extra.key_use, KeyUse::Signature);

        // Round-trip through a concrete HMAC key type.
        let key = Hs256Key::try_from(&jwk.base).unwrap();
        let jwk_from_key = JsonWebKey::from(&key);

        assert_matches!(
            jwk_from_key,
            JsonWebKey::Symmetric { secret } if secret.as_ref() == b"test"
        );
    }

    /// With a binary format (CBOR), key bytes are stored raw rather than base64url-encoded,
    /// and the key round-trips.
    #[test]
    #[cfg(feature = "ciborium")]
    fn jwk_with_cbor() {
        let key = JsonWebKey::KeyPair {
            curve: Cow::Borrowed("Ed25519"),
            x: Cow::Borrowed(b"public"),
            secret: Some(SecretBytes::borrowed(b"private")),
        };
        let mut bytes = vec![];
        ciborium::into_writer(&key, &mut bytes).unwrap();
        // The raw byte strings must appear verbatim in the encoded output.
        assert!(bytes.windows(6).any(|window| window == b"public"));
        assert!(bytes.windows(7).any(|window| window == b"private"));

        let restored: JsonWebKey<'_> = ciborium::from_reader(&bytes[..]).unwrap();
        assert_eq!(restored, key);
    }
}