Skip to main content

reqsign_core/
jwt.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! JWT encoding helpers.
19
20use base64::Engine;
21use base64::engine::general_purpose::URL_SAFE_NO_PAD;
22use rsa::RsaPrivateKey;
23use rsa::pkcs1v15::SigningKey;
24use rsa::pkcs8::DecodePrivateKey;
25use rsa::rand_core::OsRng;
26use rsa::sha2::Sha256;
27use rsa::signature::{RandomizedSigner, SignatureEncoding};
28use serde::Serialize;
29
30use crate::{Error, Result};
31
32/// Encode a JWS compact JWT using the RS256 algorithm.
33///
34/// RS256 is RSA PKCS#1 v1.5 with SHA-256. The caller owns the JWT header and
35/// claims shape so service-specific fields such as `x5t` stay service-local.
36pub fn encode_rs256<H, C>(header: &H, claims: &C, private_key: &RsaPrivateKey) -> Result<String>
37where
38    H: Serialize,
39    C: Serialize,
40{
41    let encoded_header = encode_json(header)?;
42    let encoded_claims = encode_json(claims)?;
43    let signing_input = format!("{encoded_header}.{encoded_claims}");
44
45    let mut rng = OsRng;
46    let signing_key = SigningKey::<Sha256>::new(private_key.clone());
47    let signature = signing_key.sign_with_rng(&mut rng, signing_input.as_bytes());
48    let encoded_signature = URL_SAFE_NO_PAD.encode(signature.to_bytes());
49
50    Ok(format!("{signing_input}.{encoded_signature}"))
51}
52
53/// Encode a JWS compact JWT using an RSA private key in PKCS#8 PEM format.
54pub fn encode_rs256_pem<H, C>(header: &H, claims: &C, private_key_pem: &[u8]) -> Result<String>
55where
56    H: Serialize,
57    C: Serialize,
58{
59    let private_key_pem = std::str::from_utf8(private_key_pem).map_err(|e| {
60        Error::credential_invalid("RSA private key PEM is not valid UTF-8").with_source(e)
61    })?;
62    let private_key = RsaPrivateKey::from_pkcs8_pem(private_key_pem).map_err(|e| {
63        Error::credential_invalid("failed to parse PKCS#8 RSA private key PEM").with_source(e)
64    })?;
65
66    encode_rs256(header, claims, &private_key)
67}
68
69fn encode_json<T>(value: &T) -> Result<String>
70where
71    T: Serialize,
72{
73    let json = serde_json::to_vec(value)
74        .map_err(|e| Error::unexpected("failed to serialize JWT JSON").with_source(e))?;
75    Ok(URL_SAFE_NO_PAD.encode(json))
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use rsa::pkcs1v15::{Signature, VerifyingKey};
    use rsa::rand_core::OsRng;
    use rsa::signature::Verifier;
    use serde_json::json;

    /// Test header that includes a service-specific `x5t` field.
    #[derive(Serialize)]
    struct Header<'a> {
        alg: &'a str,
        typ: &'a str,
        x5t: &'a str,
    }

    /// Test claim set shaped like a client-assertion JWT.
    #[derive(Serialize)]
    struct Claims<'a> {
        iss: &'a str,
        sub: &'a str,
        aud: &'a str,
        exp: u64,
    }

    #[test]
    fn encode_rs256_keeps_jws_shape_and_verifiable_signature() -> Result<()> {
        let mut rng = OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, 2048)
            .map_err(|e| Error::unexpected("failed to generate test RSA key").with_source(e))?;

        let jwt = encode_rs256(
            &Header {
                alg: "RS256",
                typ: "JWT",
                x5t: "thumbprint",
            },
            &Claims {
                iss: "client",
                sub: "client",
                aud: "https://example.com/token",
                exp: 1,
            },
            &private_key,
        )?;

        let parts = jwt.split('.').collect::<Vec<_>>();
        assert_eq!(parts.len(), 3);

        // Header segment must round-trip to exactly the caller's header.
        let header: serde_json::Value = serde_json::from_slice(
            &URL_SAFE_NO_PAD
                .decode(parts[0])
                .map_err(|e| Error::unexpected("failed to decode JWT header").with_source(e))?,
        )
        .map_err(|e| Error::unexpected("failed to parse JWT header").with_source(e))?;
        assert_eq!(
            header,
            json!({
                "alg": "RS256",
                "typ": "JWT",
                "x5t": "thumbprint"
            })
        );

        // Claims segment must round-trip to exactly the caller's claims.
        let claims: serde_json::Value = serde_json::from_slice(
            &URL_SAFE_NO_PAD
                .decode(parts[1])
                .map_err(|e| Error::unexpected("failed to decode JWT claims").with_source(e))?,
        )
        .map_err(|e| Error::unexpected("failed to parse JWT claims").with_source(e))?;
        assert_eq!(
            claims,
            json!({
                "iss": "client",
                "sub": "client",
                "aud": "https://example.com/token",
                "exp": 1
            })
        );

        // Signature must verify over `<header>.<claims>` with the public key.
        let signature = URL_SAFE_NO_PAD
            .decode(parts[2])
            .map_err(|e| Error::unexpected("failed to decode JWT signature").with_source(e))?;
        let signature = Signature::try_from(signature.as_slice())
            .map_err(|e| Error::unexpected("failed to parse JWT signature").with_source(e))?;
        let verifying_key = VerifyingKey::<Sha256>::new(private_key.to_public_key());
        verifying_key
            .verify(format!("{}.{}", parts[0], parts[1]).as_bytes(), &signature)
            .map_err(|e| Error::unexpected("failed to verify JWT signature").with_source(e))?;

        Ok(())
    }

    /// Invalid key material must surface as an error, not a panic; no RSA key
    /// generation is needed, so this test stays fast.
    #[test]
    fn encode_rs256_pem_rejects_invalid_key_material() {
        let header = json!({ "alg": "RS256", "typ": "JWT" });
        let claims = json!({ "iss": "client" });

        // Not valid UTF-8.
        assert!(encode_rs256_pem(&header, &claims, &[0xff, 0xfe, 0xfd]).is_err());
        // Valid UTF-8 but not a PKCS#8 PEM document.
        assert!(encode_rs256_pem(&header, &claims, b"not a pem").is_err());
        // Empty input.
        assert!(encode_rs256_pem(&header, &claims, b"").is_err());
    }
}