1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use serde::{Deserialize, Serialize};
16use std::time::{SystemTime, UNIX_EPOCH};
17
18#[cfg(feature = "ring")]
19use ring::{
20 rand::SystemRandom,
21 signature::{RSA_PKCS1_SHA256, RSA_PKCS1_SHA512, RSA_PSS_SHA256, RsaKeyPair},
22};
23
24#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
25use aws_lc_rs::{
26 rand::SystemRandom,
27 signature::{RSA_PKCS1_SHA256, RSA_PKCS1_SHA512, RSA_PSS_SHA256, RsaKeyPair},
28};
29
/// RSA signature scheme used by [`sign_jwt`] to sign a JWT.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JwtSignAlgorithm {
    /// `RS256`: RSA PKCS#1 v1.5 padding with SHA-256.
    Rs256,
    /// `PS256`: RSA-PSS padding with SHA-256.
    Ps256,
}
35
36#[derive(Debug, Deserialize)]
38pub struct ServiceAccount {
39 pub client_email: String,
40 pub private_key: String,
41 pub token_uri: String,
42 }
44
45#[derive(Debug, Serialize)]
47struct JwtClaims {
48 iss: String,
49 scope: String,
50 aud: String,
51 exp: u64,
52 iat: u64,
53}
54
55fn base64_url_encode(input: &[u8]) -> String {
57 URL_SAFE_NO_PAD.encode(input)
58}
59
60pub fn parse_rsa_pkcs8_pem(pem: &str) -> Result<RsaKeyPair, Box<dyn std::error::Error>> {
64 if pem.contains("ENCRYPTED PRIVATE KEY") {
65 return Err("encrypted PEM private keys are not supported".into());
66 }
67 if pem.contains("BEGIN RSA PRIVATE KEY") {
68 return Err(
69 "PKCS#1 (BEGIN RSA PRIVATE KEY) format is not supported, please convert to PKCS#8"
70 .into(),
71 );
72 }
73 let pem_content = pem
74 .replace("-----BEGIN PRIVATE KEY-----", "")
75 .replace("-----END PRIVATE KEY-----", "")
76 .replace("\n", "")
77 .replace("\r", "")
78 .replace(" ", "");
79 let der_bytes = base64::engine::general_purpose::STANDARD
80 .decode(pem_content.trim())
81 .map_err(|e| format!("Invalid base64 in private key: {}", e))?;
82 RsaKeyPair::from_pkcs8(&der_bytes).map_err(|e| format!("Invalid PKCS#8 RSA key: {}", e).into())
83}
84
85pub fn rsa_sha256_sign(
87 key_pair: &RsaKeyPair,
88 data: &[u8],
89) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
90 let mut signature = vec![0u8; signature_len(key_pair)];
91 let rng = SystemRandom::new();
92 key_pair.sign(&RSA_PKCS1_SHA256, &rng, data, &mut signature)?;
93 Ok(signature)
94}
95
96pub fn create_jwt(sa: &ServiceAccount, scopes: &str) -> Result<String, Box<dyn std::error::Error>> {
99 let header = serde_json::json!({"alg": "RS256", "typ": "JWT"});
100 let header_b64 = base64_url_encode(serde_json::to_string(&header)?.as_bytes());
101
102 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
103 let exp = now + 3600;
104
105 let claims = JwtClaims {
106 iss: sa.client_email.clone(),
107 scope: scopes.to_string(),
108 aud: sa.token_uri.clone(),
109 exp,
110 iat: now,
111 };
112 let claims_b64 = base64_url_encode(serde_json::to_string(&claims)?.as_bytes());
113
114 let signing_input = format!("{}.{}", header_b64, claims_b64);
115
116 let key_pair = parse_rsa_pkcs8_pem(&sa.private_key)?;
117 let signature = rsa_sha256_sign(&key_pair, signing_input.as_bytes())?;
118 let signature_b64 = base64_url_encode(&signature);
119
120 Ok(format!("{}.{}", signing_input, signature_b64))
121}
122
123pub async fn exchange_jwt_for_token(
125 token_uri: &str,
126 jwt: &str,
127) -> Result<String, Box<dyn std::error::Error>> {
128 let client = reqwest::Client::new();
129 let params = [
130 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
131 ("assertion", jwt),
132 ];
133 let body = serde_urlencoded::to_string(params).map_err(|e| e.to_string())?;
134 let resp: serde_json::Value = client
135 .post(token_uri)
136 .header("Content-Type", "application/x-www-form-urlencoded")
137 .body(body)
138 .send()
139 .await?
140 .json()
141 .await?;
142 if let Some(token) = resp.get("access_token") {
143 Ok(token.as_str().unwrap_or_default().to_string())
144 } else {
145 Err("Failed to obtain access token".into())
146 }
147}
148
/// Length of `key_pair`'s public modulus. The signers use it to size the signature buffer.
#[cfg(feature = "ring")]
fn signature_len(key_pair: &RsaKeyPair) -> usize {
    let public_key = key_pair.public();
    public_key.modulus_len()
}
153
/// Length of `key_pair`'s public modulus. The signers use it to size the signature buffer.
#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
fn signature_len(key_pair: &RsaKeyPair) -> usize {
    RsaKeyPair::public_modulus_len(key_pair)
}
158
159pub fn parse_pkcs8_pem(pem: &str) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
160 let stripped = pem
161 .replace("-----BEGIN PRIVATE KEY-----", "")
162 .replace("-----END PRIVATE KEY-----", "")
163 .replace("\n", "")
164 .replace("\r", "");
165 let der_bytes = base64::engine::general_purpose::STANDARD
166 .decode(stripped.trim())
167 .map_err(|e| format!("Invalid base64 in private key: {}", e))?;
168 Ok(der_bytes)
169}
170
171pub fn rsa_sha512_sign(
172 private_key_pem: &str,
173 data: &[u8],
174) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
175 let der_bytes = parse_pkcs8_pem(private_key_pem)?;
176 let key_pair = RsaKeyPair::from_pkcs8(&der_bytes)?;
177 let mut signature = vec![0u8; signature_len(&key_pair)];
178 let rng = SystemRandom::new();
179 key_pair.sign(&RSA_PKCS1_SHA512, &rng, data, &mut signature)?;
180 Ok(signature)
181}
182
183pub fn sign_jwt(
184 header: &serde_json::Value,
185 claims: &serde_json::Value,
186 private_key_pem: &str,
187 algorithm: JwtSignAlgorithm,
188) -> Result<String, Box<dyn std::error::Error>> {
189 let header_b64 = base64_url_encode(serde_json::to_string(header)?.as_bytes());
190 let claims_b64 = base64_url_encode(serde_json::to_string(claims)?.as_bytes());
191 let signing_input = format!("{}.{}", header_b64, claims_b64);
192
193 let key_pair = parse_rsa_pkcs8_pem(private_key_pem)?;
194 let mut signature = vec![0u8; signature_len(&key_pair)];
195 let rng = SystemRandom::new();
196 match algorithm {
197 JwtSignAlgorithm::Rs256 => {
198 key_pair.sign(
199 &RSA_PKCS1_SHA256,
200 &rng,
201 signing_input.as_bytes(),
202 &mut signature,
203 )?;
204 }
205 JwtSignAlgorithm::Ps256 => {
206 key_pair.sign(
207 &RSA_PSS_SHA256,
208 &rng,
209 signing_input.as_bytes(),
210 &mut signature,
211 )?;
212 }
213 }
214 let signature_b64 = base64_url_encode(&signature);
215 Ok(format!("{}.{}", signing_input, signature_b64))
216}
217