1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use serde::{Deserialize, Serialize};
16use std::time::{SystemTime, UNIX_EPOCH};
17
18#[cfg(feature = "ring")]
19use ring::{
20 rand::SystemRandom,
21 signature::{RSA_PKCS1_SHA256, RsaKeyPair},
22};
23
24#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
25use aws_lc_rs::{
26 rand::SystemRandom,
27 signature::{RSA_PKCS1_SHA256, RsaKeyPair},
28};
29
30#[derive(Debug, Deserialize)]
32pub struct ServiceAccount {
33 pub client_email: String,
34 pub private_key: String,
35 pub token_uri: String,
36 }
38
39#[derive(Debug, Serialize)]
41struct JwtClaims {
42 iss: String,
43 scope: String,
44 aud: String,
45 exp: u64,
46 iat: u64,
47}
48
49fn base64_url_encode(input: &[u8]) -> String {
51 URL_SAFE_NO_PAD.encode(input)
52}
53
54pub fn create_jwt(sa: &ServiceAccount, scopes: &str) -> Result<String, Box<dyn std::error::Error>> {
57 let header = serde_json::json!({"alg": "RS256", "typ": "JWT"});
59 let header_b64 = base64_url_encode(serde_json::to_string(&header)?.as_bytes());
60
61 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
63 let exp = now + 3600; let claims = JwtClaims {
67 iss: sa.client_email.clone(),
68 scope: scopes.to_string(),
69 aud: sa.token_uri.clone(),
70 exp,
71 iat: now,
72 };
73 let claims_b64 = base64_url_encode(serde_json::to_string(&claims)?.as_bytes());
74
75 let signing_input = format!("{}.{}", header_b64, claims_b64);
76
77 let pem_content = sa
79 .private_key
80 .replace("-----BEGIN PRIVATE KEY-----", "")
81 .replace("-----END PRIVATE KEY-----", "")
82 .replace("\n", "")
83 .replace("\r", "");
84 let der_bytes = base64::engine::general_purpose::STANDARD
85 .decode(pem_content.trim())
86 .map_err(|e| format!("Invalid base64 in private key: {}", e))?;
87 let key_pair = RsaKeyPair::from_pkcs8(&der_bytes)?;
88 let mut signature = vec![0u8; signature_len(&key_pair)];
89 let rng = SystemRandom::new();
90 key_pair.sign(
91 &RSA_PKCS1_SHA256,
92 &rng,
93 signing_input.as_bytes(),
94 &mut signature,
95 )?;
96 let signature_b64 = base64_url_encode(&signature);
97
98 Ok(format!("{}.{}", signing_input, signature_b64))
99}
100
101pub async fn exchange_jwt_for_token(
103 token_uri: &str,
104 jwt: &str,
105) -> Result<String, Box<dyn std::error::Error>> {
106 let client = reqwest::Client::new();
107 let params = [
108 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
109 ("assertion", jwt),
110 ];
111 let body = serde_urlencoded::to_string(params).map_err(|e| e.to_string())?;
112 let resp: serde_json::Value = client
113 .post(token_uri)
114 .header("Content-Type", "application/x-www-form-urlencoded")
115 .body(body)
116 .send()
117 .await?
118 .json()
119 .await?;
120 if let Some(token) = resp.get("access_token") {
121 Ok(token.as_str().unwrap_or_default().to_string())
122 } else {
123 Err("Failed to obtain access token".into())
124 }
125}
126
/// Returns the public modulus length of `key_pair` (`ring` backend).
///
/// `create_jwt` uses this value to size the signature buffer.
#[cfg(feature = "ring")]
fn signature_len(key_pair: &RsaKeyPair) -> usize {
    let public_key = key_pair.public();
    public_key.modulus_len()
}
131
/// Returns the public modulus length of `key_pair` (`aws-lc-rs` backend,
/// compiled only when the `ring` feature is off).
///
/// `create_jwt` uses this value to size the signature buffer.
#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
fn signature_len(key_pair: &RsaKeyPair) -> usize {
    RsaKeyPair::public_modulus_len(key_pair)
}