// openauth_plugins/jwt/keys.rs
use std::fmt;

use openauth_core::error::OpenAuthError;
use rand::RngCore;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use time::OffsetDateTime;

7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
8pub enum JwkAlgorithm {
9 #[serde(rename = "EdDSA")]
10 EdDsa,
11 #[serde(rename = "ES256")]
12 Es256,
13 #[serde(rename = "ES512")]
14 Es512,
15 #[serde(rename = "RS256")]
16 Rs256,
17 #[serde(rename = "PS256")]
18 Ps256,
19}
20
21impl JwkAlgorithm {
22 pub fn as_str(self) -> &'static str {
23 match self {
24 Self::EdDsa => "EdDSA",
25 Self::Es256 => "ES256",
26 Self::Es512 => "ES512",
27 Self::Rs256 => "RS256",
28 Self::Ps256 => "PS256",
29 }
30 }
31}
32
33#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
34pub struct Jwk {
35 pub id: String,
36 pub public_key: String,
37 pub private_key: String,
38 pub created_at: OffsetDateTime,
39 pub expires_at: Option<OffsetDateTime>,
40 pub alg: Option<JwkAlgorithm>,
41 pub crv: Option<String>,
42}
43
44#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
45pub struct Jwks {
46 pub keys: Vec<Value>,
47}
48
49pub(crate) fn generate_jwk(options: &super::JwtOptions) -> Result<Jwk, OpenAuthError> {
50 let algorithm = options.algorithm();
51 let mut private = match algorithm {
52 JwkAlgorithm::EdDsa => {
53 use josekit::jwk::alg::ed::EdCurve::Ed25519;
54 josekit::jwk::Jwk::generate_ed_key(Ed25519)
55 }
56 JwkAlgorithm::Es256 => {
57 use josekit::jwk::alg::ec::EcCurve::P256;
58 josekit::jwk::Jwk::generate_ec_key(P256)
59 }
60 JwkAlgorithm::Es512 => {
61 use josekit::jwk::alg::ec::EcCurve::P521;
62 josekit::jwk::Jwk::generate_ec_key(P521)
63 }
64 JwkAlgorithm::Rs256 | JwkAlgorithm::Ps256 => {
65 josekit::jwk::Jwk::generate_rsa_key(options.jwks.rsa_modulus_length.unwrap_or(2048))
66 }
67 }
68 .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
69 private.set_algorithm(algorithm.as_str());
70 private.set_key_use("sig");
71 private.set_key_operations(vec!["sign"]);
72
73 let id = random_id();
74 private.set_key_id(&id);
75 let mut public = private
76 .to_public_key()
77 .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
78 public.set_algorithm(algorithm.as_str());
79 public.set_key_use("sig");
80 public.set_key_operations(vec!["verify"]);
81 public.set_key_id(&id);
82
83 let now = OffsetDateTime::now_utc();
84 let expires_at = options
85 .jwks
86 .rotation_interval
87 .map(|seconds| now + time::Duration::seconds(seconds));
88
89 Ok(Jwk {
90 id,
91 public_key: serde_json::to_string(&public)
92 .map_err(|error| OpenAuthError::Crypto(error.to_string()))?,
93 private_key: serde_json::to_string(&private)
94 .map_err(|error| OpenAuthError::Crypto(error.to_string()))?,
95 created_at: now,
96 expires_at,
97 alg: Some(algorithm),
98 crv: public.curve().map(str::to_owned),
99 })
100}
101
102pub(crate) fn public_jwk_value(
103 key: &Jwk,
104 options: &super::JwtOptions,
105) -> Result<Value, OpenAuthError> {
106 let mut value: Value = serde_json::from_str(&key.public_key)
107 .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
108 let Value::Object(map) = &mut value else {
109 return Err(OpenAuthError::Crypto(
110 "public JWK must be an object".to_owned(),
111 ));
112 };
113 map.insert("kid".to_owned(), Value::String(key.id.clone()));
114 map.insert(
115 "alg".to_owned(),
116 Value::String(
117 key.alg
118 .unwrap_or_else(|| options.algorithm())
119 .as_str()
120 .to_owned(),
121 ),
122 );
123 if let Some(crv) = &key.crv {
124 map.entry("crv".to_owned())
125 .or_insert_with(|| Value::String(crv.clone()));
126 }
127 map.remove("d");
128 Ok(value)
129}
130
131fn random_id() -> String {
132 let mut bytes = [0_u8; 16];
133 rand::rngs::OsRng.fill_bytes(&mut bytes);
134 bytes[6] = (bytes[6] & 0x0f) | 0x40;
135 bytes[8] = (bytes[8] & 0x3f) | 0x80;
136 format!(
137 "{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
138 u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]),
139 u16::from_be_bytes([bytes[4], bytes[5]]),
140 u16::from_be_bytes([bytes[6], bytes[7]]),
141 u16::from_be_bytes([bytes[8], bytes[9]]),
142 u64::from_be_bytes([
143 0, 0, bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
144 ])
145 )
146}