Skip to main content

dns_update/
jwt.rs

/*
 * Copyright Stalwart Labs LLC See the COPYING
 * file at the top-level directory of this distribution.
 *
 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
 * option. This file may not be copied, modified, or distributed
 * except according to those terms.
 */
11
//! Generic JWT utilities for providers that need JWT authentication: building
//! RS256-signed service-account JWTs and exchanging them for OAuth2 access tokens.
13
14use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use serde::{Deserialize, Serialize};
16use std::time::{SystemTime, UNIX_EPOCH};
17
18#[cfg(feature = "ring")]
19use ring::{
20    rand::SystemRandom,
21    signature::{RSA_PKCS1_SHA256, RsaKeyPair},
22};
23
24#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
25use aws_lc_rs::{
26    rand::SystemRandom,
27    signature::{RSA_PKCS1_SHA256, RsaKeyPair},
28};
29
30/// Service account JSON fields needed for JWT creation.
31#[derive(Debug, Deserialize)]
32pub struct ServiceAccount {
33    pub client_email: String,
34    pub private_key: String,
35    pub token_uri: String,
36    // other fields are ignored
37}
38
39/// Claims for Google OAuth2 JWT.
40#[derive(Debug, Serialize)]
41struct JwtClaims {
42    iss: String,
43    scope: String,
44    aud: String,
45    exp: u64,
46    iat: u64,
47}
48
49/// Encode a byte slice as base64url without padding.
50fn base64_url_encode(input: &[u8]) -> String {
51    URL_SAFE_NO_PAD.encode(input)
52}
53
54/// Create a signed JWT using the service account private key.
55/// Returns the JWT as a compact string.
56pub fn create_jwt(sa: &ServiceAccount, scopes: &str) -> Result<String, Box<dyn std::error::Error>> {
57    // Header
58    let header = serde_json::json!({"alg": "RS256", "typ": "JWT"});
59    let header_b64 = base64_url_encode(serde_json::to_string(&header)?.as_bytes());
60
61    // Timestamps
62    let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
63    let exp = now + 3600; // 1 hour validity
64
65    // Claims
66    let claims = JwtClaims {
67        iss: sa.client_email.clone(),
68        scope: scopes.to_string(),
69        aud: sa.token_uri.clone(),
70        exp,
71        iat: now,
72    };
73    let claims_b64 = base64_url_encode(serde_json::to_string(&claims)?.as_bytes());
74
75    let signing_input = format!("{}.{}", header_b64, claims_b64);
76
77    // Sign using RSA SHA256
78    let pem_content = sa
79        .private_key
80        .replace("-----BEGIN PRIVATE KEY-----", "")
81        .replace("-----END PRIVATE KEY-----", "")
82        .replace("\n", "")
83        .replace("\r", "");
84    let der_bytes = base64::engine::general_purpose::STANDARD
85        .decode(pem_content.trim())
86        .map_err(|e| format!("Invalid base64 in private key: {}", e))?;
87    let key_pair = RsaKeyPair::from_pkcs8(&der_bytes)?;
88    let mut signature = vec![0u8; signature_len(&key_pair)];
89    let rng = SystemRandom::new();
90    key_pair.sign(
91        &RSA_PKCS1_SHA256,
92        &rng,
93        signing_input.as_bytes(),
94        &mut signature,
95    )?;
96    let signature_b64 = base64_url_encode(&signature);
97
98    Ok(format!("{}.{}", signing_input, signature_b64))
99}
100
101/// Exchange a JWT for an OAuth2 access token.
102pub async fn exchange_jwt_for_token(
103    token_uri: &str,
104    jwt: &str,
105) -> Result<String, Box<dyn std::error::Error>> {
106    let client = reqwest::Client::new();
107    let params = [
108        ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
109        ("assertion", jwt),
110    ];
111    let body = serde_urlencoded::to_string(params).map_err(|e| e.to_string())?;
112    let resp: serde_json::Value = client
113        .post(token_uri)
114        .header("Content-Type", "application/x-www-form-urlencoded")
115        .body(body)
116        .send()
117        .await?
118        .json()
119        .await?;
120    if let Some(token) = resp.get("access_token") {
121        Ok(token.as_str().unwrap_or_default().to_string())
122    } else {
123        Err("Failed to obtain access token".into())
124    }
125}
126
/// Byte length of an RSA signature made with `key_pair`. This equals the byte
/// length of its public modulus and is used to size the signature buffer
/// (ring backend).
#[cfg(feature = "ring")]
fn signature_len(key_pair: &RsaKeyPair) -> usize {
    let public_key = key_pair.public();
    public_key.modulus_len()
}
131
/// Byte length of an RSA signature made with `key_pair`. This equals the byte
/// length of its public modulus and is used to size the signature buffer
/// (aws-lc-rs backend, used only when `ring` is not enabled).
#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
fn signature_len(key_pair: &RsaKeyPair) -> usize {
    RsaKeyPair::public_modulus_len(key_pair)
}