1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use serde::{Deserialize, Serialize};
16use std::time::{SystemTime, UNIX_EPOCH};
17
18#[cfg(feature = "ring")]
19use ring::{
20 rand::SystemRandom,
21 signature::{RSA_PKCS1_SHA256, RSA_PKCS1_SHA512, RSA_PSS_SHA256, RsaKeyPair},
22};
23
24#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
25use aws_lc_rs::{
26 rand::SystemRandom,
27 signature::{RSA_PKCS1_SHA256, RSA_PKCS1_SHA512, RSA_PSS_SHA256, RsaKeyPair},
28};
29
/// RSA-based JWS signature algorithms accepted by [`sign_jwt`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JwtSignAlgorithm {
    /// `RS256`: RSASSA-PKCS1-v1_5 with SHA-256.
    Rs256,
    /// `PS256`: RSASSA-PSS with SHA-256.
    Ps256,
}

impl JwtSignAlgorithm {
    /// Returns the JWS `alg` header value (RFC 7518 §3.1) for this algorithm,
    /// e.g. for building the header passed to [`sign_jwt`].
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Rs256 => "RS256",
            Self::Ps256 => "PS256",
        }
    }
}
35
36#[derive(Debug, Deserialize)]
38pub struct ServiceAccount {
39 pub client_email: String,
40 pub private_key: String,
41 pub token_uri: String,
42 }
44
/// Claim set of the self-signed JWT assertion built by `create_jwt`.
///
/// serde serializes fields in declaration order, so this order fixes the exact
/// bytes of the encoded claims segment; do not reorder fields casually.
#[derive(Debug, Serialize)]
struct JwtClaims {
    /// Issuer: the service account's `client_email`.
    iss: String,
    /// Requested OAuth scopes, exactly as passed by the caller.
    scope: String,
    /// Audience: the service account's `token_uri`.
    aud: String,
    /// Expiration time in seconds since the Unix epoch.
    exp: u64,
    /// Issued-at time in seconds since the Unix epoch.
    iat: u64,
}
54
55fn base64_url_encode(input: &[u8]) -> String {
57 URL_SAFE_NO_PAD.encode(input)
58}
59
60pub fn rsa_private_key_pem_to_pkcs8_der(pem: &str) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
63 if pem.contains("ENCRYPTED PRIVATE KEY") {
64 return Err("encrypted PEM private keys are not supported".into());
65 }
66 let is_pkcs1 = pem.contains("BEGIN RSA PRIVATE KEY");
67 let body = strip_pem_armor(pem);
68 let der_bytes = base64::engine::general_purpose::STANDARD
69 .decode(&body)
70 .map_err(|e| format!("Invalid base64 in private key: {}", e))?;
71 if is_pkcs1 {
72 Ok(wrap_pkcs1_in_pkcs8(&der_bytes))
73 } else {
74 Ok(der_bytes)
75 }
76}
77
/// Extracts the base64 payload of a PEM document as one whitespace-free string.
///
/// If the input contains a `-----BEGIN` boundary line, only lines between the
/// first `BEGIN` and the next `-----END` line are kept. RFC 7468 allows
/// explanatory text around the boundaries (OpenSSL writes `Bag Attributes`
/// there, for example), so that text is skipped.
///
/// RFC 1421-style header lines inside the block (`Proc-Type:`, `DEK-Info:`)
/// are also skipped, because base64 never contains `:`.
///
/// Input with no `BEGIN` boundary is treated as bare base64 and every line
/// not starting with `-----` is kept.
fn strip_pem_armor(pem: &str) -> String {
    let armored = pem
        .lines()
        .any(|line| line.trim().starts_with("-----BEGIN"));
    let mut out = String::with_capacity(pem.len());
    // Without armor the whole input is the body.
    let mut in_body = !armored;
    for line in pem.lines() {
        let trimmed = line.trim();
        if trimmed.starts_with("-----BEGIN") {
            in_body = true;
            continue;
        }
        if armored && trimmed.starts_with("-----END") {
            // Stop after the first block; later blocks (e.g. certificates)
            // must not be concatenated into this key's DER.
            break;
        }
        if trimmed.starts_with("-----") || !in_body || trimmed.contains(':') {
            continue;
        }
        out.extend(trimmed.chars().filter(|ch| !ch.is_ascii_whitespace()));
    }
    out
}
93
94fn wrap_pkcs1_in_pkcs8(pkcs1_der: &[u8]) -> Vec<u8> {
96 const RSA_ENCRYPTION_ALG_ID: &[u8] = &[
97 0x30, 0x0D, 0x06, 0x09, 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01, 0x05, 0x00,
98 ];
99 let mut octet_string = Vec::with_capacity(pkcs1_der.len() + 4);
100 octet_string.push(0x04);
101 write_der_length(&mut octet_string, pkcs1_der.len());
102 octet_string.extend_from_slice(pkcs1_der);
103
104 let mut inner = Vec::with_capacity(3 + RSA_ENCRYPTION_ALG_ID.len() + octet_string.len());
105 inner.extend_from_slice(&[0x02, 0x01, 0x00]);
106 inner.extend_from_slice(RSA_ENCRYPTION_ALG_ID);
107 inner.extend_from_slice(&octet_string);
108
109 let mut out = Vec::with_capacity(inner.len() + 4);
110 out.push(0x30);
111 write_der_length(&mut out, inner.len());
112 out.extend_from_slice(&inner);
113 out
114}
115
/// Appends the DER encoding of a content length to `out`.
///
/// Lengths below `0x80` use the single-byte short form. Larger lengths use the
/// long form: `0x80 | n` followed by the `n` significant big-endian bytes of
/// `len`. This works for any `usize`, with no truncation above 32 bits.
fn write_der_length(out: &mut Vec<u8>, len: usize) {
    if let Ok(short @ 0..=0x7F) = u8::try_from(len) {
        out.push(short);
        return;
    }
    let be_bytes = len.to_be_bytes();
    // Drop leading zero bytes; `len >= 0x80` guarantees at least one remains.
    let skip = (len.leading_zeros() / 8) as usize;
    let significant = &be_bytes[skip..];
    // `significant.len()` is at most size_of::<usize>(), so the cast is lossless.
    out.push(0x80 | significant.len() as u8);
    out.extend_from_slice(significant);
}
135
136pub fn parse_rsa_pkcs8_pem(pem: &str) -> Result<RsaKeyPair, Box<dyn std::error::Error>> {
138 let der_bytes = rsa_private_key_pem_to_pkcs8_der(pem)?;
139 RsaKeyPair::from_pkcs8(&der_bytes).map_err(|e| format!("Invalid PKCS#8 RSA key: {}", e).into())
140}
141
142pub fn rsa_sha256_sign(
144 key_pair: &RsaKeyPair,
145 data: &[u8],
146) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
147 let mut signature = vec![0u8; signature_len(key_pair)];
148 let rng = SystemRandom::new();
149 key_pair.sign(&RSA_PKCS1_SHA256, &rng, data, &mut signature)?;
150 Ok(signature)
151}
152
153pub fn create_jwt(sa: &ServiceAccount, scopes: &str) -> Result<String, Box<dyn std::error::Error>> {
156 let header = serde_json::json!({"alg": "RS256", "typ": "JWT"});
157 let header_b64 = base64_url_encode(serde_json::to_string(&header)?.as_bytes());
158
159 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
160 let exp = now + 3600;
161
162 let claims = JwtClaims {
163 iss: sa.client_email.clone(),
164 scope: scopes.to_string(),
165 aud: sa.token_uri.clone(),
166 exp,
167 iat: now,
168 };
169 let claims_b64 = base64_url_encode(serde_json::to_string(&claims)?.as_bytes());
170
171 let signing_input = format!("{}.{}", header_b64, claims_b64);
172
173 let key_pair = parse_rsa_pkcs8_pem(&sa.private_key)?;
174 let signature = rsa_sha256_sign(&key_pair, signing_input.as_bytes())?;
175 let signature_b64 = base64_url_encode(&signature);
176
177 Ok(format!("{}.{}", signing_input, signature_b64))
178}
179
180pub async fn exchange_jwt_for_token(
182 token_uri: &str,
183 jwt: &str,
184) -> Result<String, Box<dyn std::error::Error>> {
185 let client = reqwest::Client::new();
186 let params = [
187 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
188 ("assertion", jwt),
189 ];
190 let body = serde_urlencoded::to_string(params).map_err(|e| e.to_string())?;
191 let resp: serde_json::Value = client
192 .post(token_uri)
193 .header("Content-Type", "application/x-www-form-urlencoded")
194 .body(body)
195 .send()
196 .await?
197 .json()
198 .await?;
199 if let Some(token) = resp.get("access_token") {
200 Ok(token.as_str().unwrap_or_default().to_string())
201 } else {
202 Err("Failed to obtain access token".into())
203 }
204}
205
/// Returns the RSA signature length in bytes for `key_pair` (ring backend).
///
/// This is the public modulus length; callers use it to size signature buffers.
#[cfg(feature = "ring")]
fn signature_len(key_pair: &RsaKeyPair) -> usize {
    let public_key = key_pair.public();
    public_key.modulus_len()
}
210
/// Returns the RSA signature length in bytes for `key_pair` (aws-lc-rs backend,
/// used only when `ring` is not enabled).
///
/// This is the public modulus length; callers use it to size signature buffers.
#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
fn signature_len(key_pair: &RsaKeyPair) -> usize {
    RsaKeyPair::public_modulus_len(key_pair)
}
215
216pub fn rsa_sha512_sign(
217 private_key_pem: &str,
218 data: &[u8],
219) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
220 let key_pair = parse_rsa_pkcs8_pem(private_key_pem)?;
221 let mut signature = vec![0u8; signature_len(&key_pair)];
222 let rng = SystemRandom::new();
223 key_pair.sign(&RSA_PKCS1_SHA512, &rng, data, &mut signature)?;
224 Ok(signature)
225}
226
227pub fn sign_jwt(
228 header: &serde_json::Value,
229 claims: &serde_json::Value,
230 private_key_pem: &str,
231 algorithm: JwtSignAlgorithm,
232) -> Result<String, Box<dyn std::error::Error>> {
233 let header_b64 = base64_url_encode(serde_json::to_string(header)?.as_bytes());
234 let claims_b64 = base64_url_encode(serde_json::to_string(claims)?.as_bytes());
235 let signing_input = format!("{}.{}", header_b64, claims_b64);
236
237 let key_pair = parse_rsa_pkcs8_pem(private_key_pem)?;
238 let mut signature = vec![0u8; signature_len(&key_pair)];
239 let rng = SystemRandom::new();
240 match algorithm {
241 JwtSignAlgorithm::Rs256 => {
242 key_pair.sign(
243 &RSA_PKCS1_SHA256,
244 &rng,
245 signing_input.as_bytes(),
246 &mut signature,
247 )?;
248 }
249 JwtSignAlgorithm::Ps256 => {
250 key_pair.sign(
251 &RSA_PSS_SHA256,
252 &rng,
253 signing_input.as_bytes(),
254 &mut signature,
255 )?;
256 }
257 }
258 let signature_b64 = base64_url_encode(&signature);
259 Ok(format!("{}.{}", signing_input, signature_b64))
260}