Skip to main content

dns_update/
jwt.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12//! Generic JWT utility for providers that need JWT authentication.
13
14use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use serde::{Deserialize, Serialize};
16use std::time::{SystemTime, UNIX_EPOCH};
17
18#[cfg(feature = "ring")]
19use ring::{
20    rand::SystemRandom,
21    signature::{RSA_PKCS1_SHA256, RSA_PKCS1_SHA512, RSA_PSS_SHA256, RsaKeyPair},
22};
23
24#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
25use aws_lc_rs::{
26    rand::SystemRandom,
27    signature::{RSA_PKCS1_SHA256, RSA_PKCS1_SHA512, RSA_PSS_SHA256, RsaKeyPair},
28};
29
/// RSA signature scheme used by [`sign_jwt`] to sign the JWT signing input.
///
/// Variant names follow the JWS `alg` identifiers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JwtSignAlgorithm {
    /// `RS256`: RSASSA-PKCS1-v1_5 with SHA-256.
    Rs256,
    /// `PS256`: RSASSA-PSS with SHA-256.
    Ps256,
}
35
36/// Service account JSON fields needed for JWT creation.
37#[derive(Debug, Deserialize)]
38pub struct ServiceAccount {
39    pub client_email: String,
40    pub private_key: String,
41    pub token_uri: String,
42    // other fields are ignored
43}
44
45/// Claims for Google OAuth2 JWT.
46#[derive(Debug, Serialize)]
47struct JwtClaims {
48    iss: String,
49    scope: String,
50    aud: String,
51    exp: u64,
52    iat: u64,
53}
54
55/// Encode a byte slice as base64url without padding.
56fn base64_url_encode(input: &[u8]) -> String {
57    URL_SAFE_NO_PAD.encode(input)
58}
59
60/// Decode a PEM-encoded RSA private key (PKCS#8 `BEGIN PRIVATE KEY` or PKCS#1
61/// `BEGIN RSA PRIVATE KEY`) and return PKCS#8 DER bytes.
62pub fn rsa_private_key_pem_to_pkcs8_der(pem: &str) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
63    if pem.contains("ENCRYPTED PRIVATE KEY") {
64        return Err("encrypted PEM private keys are not supported".into());
65    }
66    let is_pkcs1 = pem.contains("BEGIN RSA PRIVATE KEY");
67    let body = strip_pem_armor(pem);
68    let der_bytes = base64::engine::general_purpose::STANDARD
69        .decode(&body)
70        .map_err(|e| format!("Invalid base64 in private key: {}", e))?;
71    if is_pkcs1 {
72        Ok(wrap_pkcs1_in_pkcs8(&der_bytes))
73    } else {
74        Ok(der_bytes)
75    }
76}
77
/// Extract the base64 payload from PEM text, dropping armor and whitespace.
///
/// - When the input contains at least one `-----BEGIN` line, only lines between
///   a `-----BEGIN …-----` boundary and the next `-----END …-----` boundary are
///   kept. Explanatory text before or after the armor is ignored, as RFC 7468
///   requires. Payloads of several consecutive blocks are concatenated.
/// - Lines containing `:` are skipped. These are RFC 1421-style header lines
///   such as `Proc-Type:` / `DEK-Info:`, and `:` never occurs in base64.
/// - Input with no `-----BEGIN` line is treated as a bare base64 body.
/// - All other `-----`-prefixed lines are dropped, and ASCII whitespace inside
///   kept lines is removed.
fn strip_pem_armor(pem: &str) -> String {
    let armored = pem
        .lines()
        .any(|line| line.trim().starts_with("-----BEGIN"));
    // Without armor every line is payload; with armor we wait for BEGIN.
    let mut inside = !armored;
    let mut out = String::with_capacity(pem.len());
    for line in pem.lines() {
        let trimmed = line.trim();
        if trimmed.starts_with("-----") {
            if armored {
                if trimmed.starts_with("-----BEGIN") {
                    inside = true;
                } else if trimmed.starts_with("-----END") {
                    inside = false;
                }
            }
            continue;
        }
        if !inside || trimmed.contains(':') {
            continue;
        }
        out.extend(trimmed.chars().filter(|c| !c.is_ascii_whitespace()));
    }
    out
}
93
94/// Wrap a PKCS#1 `RSAPrivateKey` DER blob in a PKCS#8 `PrivateKeyInfo` envelope.
95fn wrap_pkcs1_in_pkcs8(pkcs1_der: &[u8]) -> Vec<u8> {
96    const RSA_ENCRYPTION_ALG_ID: &[u8] = &[
97        0x30, 0x0D, 0x06, 0x09, 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01, 0x05, 0x00,
98    ];
99    let mut octet_string = Vec::with_capacity(pkcs1_der.len() + 4);
100    octet_string.push(0x04);
101    write_der_length(&mut octet_string, pkcs1_der.len());
102    octet_string.extend_from_slice(pkcs1_der);
103
104    let mut inner = Vec::with_capacity(3 + RSA_ENCRYPTION_ALG_ID.len() + octet_string.len());
105    inner.extend_from_slice(&[0x02, 0x01, 0x00]);
106    inner.extend_from_slice(RSA_ENCRYPTION_ALG_ID);
107    inner.extend_from_slice(&octet_string);
108
109    let mut out = Vec::with_capacity(inner.len() + 4);
110    out.push(0x30);
111    write_der_length(&mut out, inner.len());
112    out.extend_from_slice(&inner);
113    out
114}
115
/// Append the DER encoding of a content length `len` to `out`.
///
/// Lengths below 0x80 use the one-byte short form. Larger lengths use the long
/// form: `0x80 | n` followed by the `n` big-endian bytes of `len`, with no
/// leading zero bytes. Every `usize` value is handled; nothing is truncated.
fn write_der_length(out: &mut Vec<u8>, len: usize) {
    if len < 0x80 {
        out.push(len as u8);
        return;
    }
    let bytes = len.to_be_bytes();
    let leading_zeros = bytes.iter().take_while(|&&b| b == 0).count();
    let significant = &bytes[leading_zeros..];
    // `significant.len()` is at most size_of::<usize>() (<= 16), so it fits in the low 7 bits.
    out.push(0x80 | significant.len() as u8);
    out.extend_from_slice(significant);
}
135
136/// Parse an RSA private key PEM (PKCS#1 or PKCS#8) into an `RsaKeyPair`.
137pub fn parse_rsa_pkcs8_pem(pem: &str) -> Result<RsaKeyPair, Box<dyn std::error::Error>> {
138    let der_bytes = rsa_private_key_pem_to_pkcs8_der(pem)?;
139    RsaKeyPair::from_pkcs8(&der_bytes).map_err(|e| format!("Invalid PKCS#8 RSA key: {}", e).into())
140}
141
142/// Sign `data` with RSA-SHA256 using a key pair, returning the raw signature bytes.
143pub fn rsa_sha256_sign(
144    key_pair: &RsaKeyPair,
145    data: &[u8],
146) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
147    let mut signature = vec![0u8; signature_len(key_pair)];
148    let rng = SystemRandom::new();
149    key_pair.sign(&RSA_PKCS1_SHA256, &rng, data, &mut signature)?;
150    Ok(signature)
151}
152
153/// Create a signed JWT using the service account private key.
154/// Returns the JWT as a compact string.
155pub fn create_jwt(sa: &ServiceAccount, scopes: &str) -> Result<String, Box<dyn std::error::Error>> {
156    let header = serde_json::json!({"alg": "RS256", "typ": "JWT"});
157    let header_b64 = base64_url_encode(serde_json::to_string(&header)?.as_bytes());
158
159    let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
160    let exp = now + 3600;
161
162    let claims = JwtClaims {
163        iss: sa.client_email.clone(),
164        scope: scopes.to_string(),
165        aud: sa.token_uri.clone(),
166        exp,
167        iat: now,
168    };
169    let claims_b64 = base64_url_encode(serde_json::to_string(&claims)?.as_bytes());
170
171    let signing_input = format!("{}.{}", header_b64, claims_b64);
172
173    let key_pair = parse_rsa_pkcs8_pem(&sa.private_key)?;
174    let signature = rsa_sha256_sign(&key_pair, signing_input.as_bytes())?;
175    let signature_b64 = base64_url_encode(&signature);
176
177    Ok(format!("{}.{}", signing_input, signature_b64))
178}
179
180/// Exchange a JWT for an OAuth2 access token.
181pub async fn exchange_jwt_for_token(
182    token_uri: &str,
183    jwt: &str,
184) -> Result<String, Box<dyn std::error::Error>> {
185    let client = reqwest::Client::new();
186    let params = [
187        ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
188        ("assertion", jwt),
189    ];
190    let body = serde_urlencoded::to_string(params).map_err(|e| e.to_string())?;
191    let resp: serde_json::Value = client
192        .post(token_uri)
193        .header("Content-Type", "application/x-www-form-urlencoded")
194        .body(body)
195        .send()
196        .await?
197        .json()
198        .await?;
199    if let Some(token) = resp.get("access_token") {
200        Ok(token.as_str().unwrap_or_default().to_string())
201    } else {
202        Err("Failed to obtain access token".into())
203    }
204}
205
/// Size in bytes of an RSA signature made with `key_pair`, i.e. its modulus length.
#[cfg(feature = "ring")]
fn signature_len(key_pair: &RsaKeyPair) -> usize {
    let public_key = key_pair.public();
    public_key.modulus_len()
}

/// Size in bytes of an RSA signature made with `key_pair`, i.e. its modulus length.
#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
fn signature_len(key_pair: &RsaKeyPair) -> usize {
    RsaKeyPair::public_modulus_len(key_pair)
}
215
216pub fn rsa_sha512_sign(
217    private_key_pem: &str,
218    data: &[u8],
219) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
220    let key_pair = parse_rsa_pkcs8_pem(private_key_pem)?;
221    let mut signature = vec![0u8; signature_len(&key_pair)];
222    let rng = SystemRandom::new();
223    key_pair.sign(&RSA_PKCS1_SHA512, &rng, data, &mut signature)?;
224    Ok(signature)
225}
226
227pub fn sign_jwt(
228    header: &serde_json::Value,
229    claims: &serde_json::Value,
230    private_key_pem: &str,
231    algorithm: JwtSignAlgorithm,
232) -> Result<String, Box<dyn std::error::Error>> {
233    let header_b64 = base64_url_encode(serde_json::to_string(header)?.as_bytes());
234    let claims_b64 = base64_url_encode(serde_json::to_string(claims)?.as_bytes());
235    let signing_input = format!("{}.{}", header_b64, claims_b64);
236
237    let key_pair = parse_rsa_pkcs8_pem(private_key_pem)?;
238    let mut signature = vec![0u8; signature_len(&key_pair)];
239    let rng = SystemRandom::new();
240    match algorithm {
241        JwtSignAlgorithm::Rs256 => {
242            key_pair.sign(
243                &RSA_PKCS1_SHA256,
244                &rng,
245                signing_input.as_bytes(),
246                &mut signature,
247            )?;
248        }
249        JwtSignAlgorithm::Ps256 => {
250            key_pair.sign(
251                &RSA_PSS_SHA256,
252                &rng,
253                signing_input.as_bytes(),
254                &mut signature,
255            )?;
256        }
257    }
258    let signature_b64 = base64_url_encode(&signature);
259    Ok(format!("{}.{}", signing_input, signature_b64))
260}