Skip to main content

openauth_plugins/jwt/
keys.rs

1use openauth_core::error::OpenAuthError;
2use rand::RngCore;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use time::OffsetDateTime;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
8pub enum JwkAlgorithm {
9    #[serde(rename = "EdDSA")]
10    EdDsa,
11    #[serde(rename = "ES256")]
12    Es256,
13    #[serde(rename = "ES512")]
14    Es512,
15    #[serde(rename = "RS256")]
16    Rs256,
17    #[serde(rename = "PS256")]
18    Ps256,
19}
20
21impl JwkAlgorithm {
22    pub fn as_str(self) -> &'static str {
23        match self {
24            Self::EdDsa => "EdDSA",
25            Self::Es256 => "ES256",
26            Self::Es512 => "ES512",
27            Self::Rs256 => "RS256",
28            Self::Ps256 => "PS256",
29        }
30    }
31}
32
33#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
34pub struct Jwk {
35    pub id: String,
36    pub public_key: String,
37    pub private_key: String,
38    pub created_at: OffsetDateTime,
39    pub expires_at: Option<OffsetDateTime>,
40    pub alg: Option<JwkAlgorithm>,
41    pub crv: Option<String>,
42}
43
44#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
45pub struct Jwks {
46    pub keys: Vec<Value>,
47}
48
49pub(crate) fn generate_jwk(options: &super::JwtOptions) -> Result<Jwk, OpenAuthError> {
50    let algorithm = options.algorithm();
51    let mut private = match algorithm {
52        JwkAlgorithm::EdDsa => {
53            use josekit::jwk::alg::ed::EdCurve::Ed25519;
54            josekit::jwk::Jwk::generate_ed_key(Ed25519)
55        }
56        JwkAlgorithm::Es256 => {
57            use josekit::jwk::alg::ec::EcCurve::P256;
58            josekit::jwk::Jwk::generate_ec_key(P256)
59        }
60        JwkAlgorithm::Es512 => {
61            use josekit::jwk::alg::ec::EcCurve::P521;
62            josekit::jwk::Jwk::generate_ec_key(P521)
63        }
64        JwkAlgorithm::Rs256 | JwkAlgorithm::Ps256 => {
65            josekit::jwk::Jwk::generate_rsa_key(options.jwks.rsa_modulus_length.unwrap_or(2048))
66        }
67    }
68    .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
69    private.set_algorithm(algorithm.as_str());
70    private.set_key_use("sig");
71    private.set_key_operations(vec!["sign"]);
72
73    let id = random_id();
74    private.set_key_id(&id);
75    let mut public = private
76        .to_public_key()
77        .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
78    public.set_algorithm(algorithm.as_str());
79    public.set_key_use("sig");
80    public.set_key_operations(vec!["verify"]);
81    public.set_key_id(&id);
82
83    let now = OffsetDateTime::now_utc();
84    let expires_at = options
85        .jwks
86        .rotation_interval
87        .map(|seconds| now + time::Duration::seconds(seconds));
88
89    Ok(Jwk {
90        id,
91        public_key: serde_json::to_string(&public)
92            .map_err(|error| OpenAuthError::Crypto(error.to_string()))?,
93        private_key: serde_json::to_string(&private)
94            .map_err(|error| OpenAuthError::Crypto(error.to_string()))?,
95        created_at: now,
96        expires_at,
97        alg: Some(algorithm),
98        crv: public.curve().map(str::to_owned),
99    })
100}
101
102pub(crate) fn public_jwk_value(
103    key: &Jwk,
104    options: &super::JwtOptions,
105) -> Result<Value, OpenAuthError> {
106    let mut value: Value = serde_json::from_str(&key.public_key)
107        .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
108    let Value::Object(map) = &mut value else {
109        return Err(OpenAuthError::Crypto(
110            "public JWK must be an object".to_owned(),
111        ));
112    };
113    map.insert("kid".to_owned(), Value::String(key.id.clone()));
114    map.insert(
115        "alg".to_owned(),
116        Value::String(
117            key.alg
118                .unwrap_or_else(|| options.algorithm())
119                .as_str()
120                .to_owned(),
121        ),
122    );
123    if let Some(crv) = &key.crv {
124        map.entry("crv".to_owned())
125            .or_insert_with(|| Value::String(crv.clone()));
126    }
127    map.remove("d");
128    Ok(value)
129}
130
131fn random_id() -> String {
132    let mut bytes = [0_u8; 16];
133    rand::rngs::OsRng.fill_bytes(&mut bytes);
134    bytes[6] = (bytes[6] & 0x0f) | 0x40;
135    bytes[8] = (bytes[8] & 0x3f) | 0x80;
136    format!(
137        "{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
138        u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]),
139        u16::from_be_bytes([bytes[4], bytes[5]]),
140        u16::from_be_bytes([bytes[6], bytes[7]]),
141        u16::from_be_bytes([bytes[8], bytes[9]]),
142        u64::from_be_bytes([
143            0, 0, bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
144        ])
145    )
146}