use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use rsa::RsaPrivateKey;
use rsa::pkcs1::DecodeRsaPrivateKey;
use rsa::pkcs1v15::SigningKey;
use rsa::pkcs8::DecodePrivateKey;
use rsa::rand_core::OsRng;
use rsa::sha2::Sha256;
use rsa::signature::{RandomizedSigner, SignatureEncoding};
use serde::Serialize;

use crate::{Error, Result};
31
32pub fn encode_rs256<H, C>(header: &H, claims: &C, private_key: &RsaPrivateKey) -> Result<String>
37where
38 H: Serialize,
39 C: Serialize,
40{
41 let encoded_header = encode_json(header)?;
42 let encoded_claims = encode_json(claims)?;
43 let signing_input = format!("{encoded_header}.{encoded_claims}");
44
45 let mut rng = OsRng;
46 let signing_key = SigningKey::<Sha256>::new(private_key.clone());
47 let signature = signing_key.sign_with_rng(&mut rng, signing_input.as_bytes());
48 let encoded_signature = URL_SAFE_NO_PAD.encode(signature.to_bytes());
49
50 Ok(format!("{signing_input}.{encoded_signature}"))
51}
52
53pub fn encode_rs256_pem<H, C>(header: &H, claims: &C, private_key_pem: &[u8]) -> Result<String>
55where
56 H: Serialize,
57 C: Serialize,
58{
59 let private_key_pem = std::str::from_utf8(private_key_pem).map_err(|e| {
60 Error::credential_invalid("RSA private key PEM is not valid UTF-8").with_source(e)
61 })?;
62 let private_key = RsaPrivateKey::from_pkcs8_pem(private_key_pem).map_err(|e| {
63 Error::credential_invalid("failed to parse PKCS#8 RSA private key PEM").with_source(e)
64 })?;
65
66 encode_rs256(header, claims, &private_key)
67}
68
69fn encode_json<T>(value: &T) -> Result<String>
70where
71 T: Serialize,
72{
73 let json = serde_json::to_vec(value)
74 .map_err(|e| Error::unexpected("failed to serialize JWT JSON").with_source(e))?;
75 Ok(URL_SAFE_NO_PAD.encode(json))
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use rsa::pkcs1v15::{Signature, VerifyingKey};
    use rsa::pkcs8::{EncodePrivateKey, LineEnding};
    use rsa::rand_core::OsRng;
    use rsa::signature::Verifier;
    use serde_json::json;

    #[derive(Serialize)]
    struct Header<'a> {
        alg: &'a str,
        typ: &'a str,
        x5t: &'a str,
    }

    #[derive(Serialize)]
    struct Claims<'a> {
        iss: &'a str,
        sub: &'a str,
        aud: &'a str,
        exp: u64,
    }

    /// Base64url-decodes one compact-JWS segment and parses it as JSON.
    fn decode_json_segment(segment: &str) -> Result<serde_json::Value> {
        let bytes = URL_SAFE_NO_PAD
            .decode(segment)
            .map_err(|e| Error::unexpected("failed to decode JWT segment").with_source(e))?;
        let value: serde_json::Value = serde_json::from_slice(&bytes)
            .map_err(|e| Error::unexpected("failed to parse JWT segment").with_source(e))?;
        Ok(value)
    }

    #[test]
    fn encode_rs256_keeps_jws_shape_and_verifiable_signature() -> Result<()> {
        // One key for the whole test: 2048-bit generation is slow in debug builds.
        let mut rng = OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, 2048)
            .map_err(|e| Error::unexpected("failed to generate test RSA key").with_source(e))?;

        let header = Header {
            alg: "RS256",
            typ: "JWT",
            x5t: "thumbprint",
        };
        let claims = Claims {
            iss: "client",
            sub: "client",
            aud: "https://example.com/token",
            exp: 1,
        };

        let jwt = encode_rs256(&header, &claims, &private_key)?;

        let parts = jwt.split('.').collect::<Vec<_>>();
        assert_eq!(parts.len(), 3);

        assert_eq!(
            decode_json_segment(parts[0])?,
            json!({
                "alg": "RS256",
                "typ": "JWT",
                "x5t": "thumbprint"
            })
        );
        assert_eq!(
            decode_json_segment(parts[1])?,
            json!({
                "iss": "client",
                "sub": "client",
                "aud": "https://example.com/token",
                "exp": 1
            })
        );

        let signature = URL_SAFE_NO_PAD
            .decode(parts[2])
            .map_err(|e| Error::unexpected("failed to decode JWT signature").with_source(e))?;
        let signature = Signature::try_from(signature.as_slice())
            .map_err(|e| Error::unexpected("failed to parse JWT signature").with_source(e))?;
        let verifying_key = VerifyingKey::<Sha256>::new(private_key.to_public_key());
        verifying_key
            .verify(format!("{}.{}", parts[0], parts[1]).as_bytes(), &signature)
            .map_err(|e| Error::unexpected("failed to verify JWT signature").with_source(e))?;

        // PKCS#1 v1.5 signatures are deterministic. Signing the same input with
        // the same key through the PEM entry point must therefore yield the
        // identical token.
        let pem = private_key.to_pkcs8_pem(LineEnding::LF).map_err(|e| {
            Error::unexpected("failed to export test RSA key as PKCS#8 PEM").with_source(e)
        })?;
        assert_eq!(encode_rs256_pem(&header, &claims, pem.as_bytes())?, jwt);

        Ok(())
    }
}