Skip to main content

dns_update/
jwt.rs

/*
 * Copyright Stalwart Labs LLC See the COPYING
 * file at the top-level directory of this distribution.
 *
 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
 * option. This file may not be copied, modified, or distributed
 * except according to those terms.
 */
11
//! Generic JWT utilities for providers that need JWT authentication.
//!
//! Covers RS256/PS256 JWT signing and the OAuth2 JWT-bearer token exchange.
13
14use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use serde::{Deserialize, Serialize};
16use std::time::{SystemTime, UNIX_EPOCH};
17
18#[cfg(feature = "ring")]
19use ring::{
20    rand::SystemRandom,
21    signature::{RSA_PKCS1_SHA256, RSA_PKCS1_SHA512, RSA_PSS_SHA256, RsaKeyPair},
22};
23
24#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
25use aws_lc_rs::{
26    rand::SystemRandom,
27    signature::{RSA_PKCS1_SHA256, RSA_PKCS1_SHA512, RSA_PSS_SHA256, RsaKeyPair},
28};
29
/// RSA signature algorithm used by [`sign_jwt`] to sign a JWT.
///
/// Variant names follow the JWA `alg` identifiers; [`JwtSignAlgorithm::as_str`]
/// returns the exact string to place in the JWT header's `alg` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JwtSignAlgorithm {
    /// `RS256`: RSASSA-PKCS1-v1_5 with SHA-256.
    Rs256,
    /// `PS256`: RSASSA-PSS with SHA-256.
    Ps256,
}

impl JwtSignAlgorithm {
    /// JWA `alg` header value for this algorithm (`"RS256"` or `"PS256"`).
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Rs256 => "RS256",
            Self::Ps256 => "PS256",
        }
    }
}
35
36/// Service account JSON fields needed for JWT creation.
37#[derive(Debug, Deserialize)]
38pub struct ServiceAccount {
39    pub client_email: String,
40    pub private_key: String,
41    pub token_uri: String,
42    // other fields are ignored
43}
44
45/// Claims for Google OAuth2 JWT.
46#[derive(Debug, Serialize)]
47struct JwtClaims {
48    iss: String,
49    scope: String,
50    aud: String,
51    exp: u64,
52    iat: u64,
53}
54
55/// Encode a byte slice as base64url without padding.
56fn base64_url_encode(input: &[u8]) -> String {
57    URL_SAFE_NO_PAD.encode(input)
58}
59
60/// Parse a PKCS#8 PEM-encoded RSA private key into an `RsaKeyPair`.
61/// Supports both `BEGIN PRIVATE KEY` (PKCS#8) and `BEGIN RSA PRIVATE KEY` (PKCS#1).
62/// Encrypted PEMs (`BEGIN ENCRYPTED PRIVATE KEY`) are not supported.
63pub fn parse_rsa_pkcs8_pem(pem: &str) -> Result<RsaKeyPair, Box<dyn std::error::Error>> {
64    if pem.contains("ENCRYPTED PRIVATE KEY") {
65        return Err("encrypted PEM private keys are not supported".into());
66    }
67    if pem.contains("BEGIN RSA PRIVATE KEY") {
68        return Err(
69            "PKCS#1 (BEGIN RSA PRIVATE KEY) format is not supported, please convert to PKCS#8"
70                .into(),
71        );
72    }
73    let pem_content = pem
74        .replace("-----BEGIN PRIVATE KEY-----", "")
75        .replace("-----END PRIVATE KEY-----", "")
76        .replace("\n", "")
77        .replace("\r", "")
78        .replace(" ", "");
79    let der_bytes = base64::engine::general_purpose::STANDARD
80        .decode(pem_content.trim())
81        .map_err(|e| format!("Invalid base64 in private key: {}", e))?;
82    RsaKeyPair::from_pkcs8(&der_bytes).map_err(|e| format!("Invalid PKCS#8 RSA key: {}", e).into())
83}
84
85/// Sign `data` with RSA-SHA256 using a key pair, returning the raw signature bytes.
86pub fn rsa_sha256_sign(
87    key_pair: &RsaKeyPair,
88    data: &[u8],
89) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
90    let mut signature = vec![0u8; signature_len(key_pair)];
91    let rng = SystemRandom::new();
92    key_pair.sign(&RSA_PKCS1_SHA256, &rng, data, &mut signature)?;
93    Ok(signature)
94}
95
96/// Create a signed JWT using the service account private key.
97/// Returns the JWT as a compact string.
98pub fn create_jwt(sa: &ServiceAccount, scopes: &str) -> Result<String, Box<dyn std::error::Error>> {
99    let header = serde_json::json!({"alg": "RS256", "typ": "JWT"});
100    let header_b64 = base64_url_encode(serde_json::to_string(&header)?.as_bytes());
101
102    let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
103    let exp = now + 3600;
104
105    let claims = JwtClaims {
106        iss: sa.client_email.clone(),
107        scope: scopes.to_string(),
108        aud: sa.token_uri.clone(),
109        exp,
110        iat: now,
111    };
112    let claims_b64 = base64_url_encode(serde_json::to_string(&claims)?.as_bytes());
113
114    let signing_input = format!("{}.{}", header_b64, claims_b64);
115
116    let key_pair = parse_rsa_pkcs8_pem(&sa.private_key)?;
117    let signature = rsa_sha256_sign(&key_pair, signing_input.as_bytes())?;
118    let signature_b64 = base64_url_encode(&signature);
119
120    Ok(format!("{}.{}", signing_input, signature_b64))
121}
122
123/// Exchange a JWT for an OAuth2 access token.
124pub async fn exchange_jwt_for_token(
125    token_uri: &str,
126    jwt: &str,
127) -> Result<String, Box<dyn std::error::Error>> {
128    let client = reqwest::Client::new();
129    let params = [
130        ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
131        ("assertion", jwt),
132    ];
133    let body = serde_urlencoded::to_string(params).map_err(|e| e.to_string())?;
134    let resp: serde_json::Value = client
135        .post(token_uri)
136        .header("Content-Type", "application/x-www-form-urlencoded")
137        .body(body)
138        .send()
139        .await?
140        .json()
141        .await?;
142    if let Some(token) = resp.get("access_token") {
143        Ok(token.as_str().unwrap_or_default().to_string())
144    } else {
145        Err("Failed to obtain access token".into())
146    }
147}
148
/// Byte length of a signature produced with `key_pair`, equal to the length of its
/// public modulus (ring backend).
#[cfg(feature = "ring")]
fn signature_len(key_pair: &RsaKeyPair) -> usize {
    RsaKeyPair::public(key_pair).modulus_len()
}
153
/// Byte length of a signature produced with `key_pair`, equal to the length of its
/// public modulus (aws-lc-rs backend, compiled only when the `ring` feature is off).
#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
fn signature_len(key_pair: &RsaKeyPair) -> usize {
    RsaKeyPair::public_modulus_len(key_pair)
}
158
159pub fn parse_pkcs8_pem(pem: &str) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
160    let stripped = pem
161        .replace("-----BEGIN PRIVATE KEY-----", "")
162        .replace("-----END PRIVATE KEY-----", "")
163        .replace("\n", "")
164        .replace("\r", "");
165    let der_bytes = base64::engine::general_purpose::STANDARD
166        .decode(stripped.trim())
167        .map_err(|e| format!("Invalid base64 in private key: {}", e))?;
168    Ok(der_bytes)
169}
170
171pub fn rsa_sha512_sign(
172    private_key_pem: &str,
173    data: &[u8],
174) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
175    let der_bytes = parse_pkcs8_pem(private_key_pem)?;
176    let key_pair = RsaKeyPair::from_pkcs8(&der_bytes)?;
177    let mut signature = vec![0u8; signature_len(&key_pair)];
178    let rng = SystemRandom::new();
179    key_pair.sign(&RSA_PKCS1_SHA512, &rng, data, &mut signature)?;
180    Ok(signature)
181}
182
183pub fn sign_jwt(
184    header: &serde_json::Value,
185    claims: &serde_json::Value,
186    private_key_pem: &str,
187    algorithm: JwtSignAlgorithm,
188) -> Result<String, Box<dyn std::error::Error>> {
189    let header_b64 = base64_url_encode(serde_json::to_string(header)?.as_bytes());
190    let claims_b64 = base64_url_encode(serde_json::to_string(claims)?.as_bytes());
191    let signing_input = format!("{}.{}", header_b64, claims_b64);
192
193    let key_pair = parse_rsa_pkcs8_pem(private_key_pem)?;
194    let mut signature = vec![0u8; signature_len(&key_pair)];
195    let rng = SystemRandom::new();
196    match algorithm {
197        JwtSignAlgorithm::Rs256 => {
198            key_pair.sign(
199                &RSA_PKCS1_SHA256,
200                &rng,
201                signing_input.as_bytes(),
202                &mut signature,
203            )?;
204        }
205        JwtSignAlgorithm::Ps256 => {
206            key_pair.sign(
207                &RSA_PSS_SHA256,
208                &rng,
209                signing_input.as_bytes(),
210                &mut signature,
211            )?;
212        }
213    }
214    let signature_b64 = base64_url_encode(&signature);
215    Ok(format!("{}.{}", signing_input, signature_b64))
216}
217