faucet_common_snowflake/
lib.rs1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3use faucet_core::FaucetError;
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23
/// Credentials used to authenticate requests to Snowflake.
///
/// Serialized as an adjacently tagged enum: the variant name goes in `"type"`
/// and the variant's fields in `"config"`, e.g.
/// `{"type":"key_pair","config":{"user":"u","private_key_pem":"..."}}` or
/// `{"type":"oauth","config":{"token":"..."}}`.
///
/// `Debug` is implemented by hand (not derived) so that secrets are masked.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", content = "config", rename_all = "snake_case")]
pub enum SnowflakeAuth {
    /// Key-pair authentication: a JWT signed with `private_key_pem` is sent as
    /// a bearer token.
    KeyPair {
        /// Snowflake user name; upper-cased when placed in the JWT claims.
        user: String,
        /// RSA private key in PEM form, either PKCS#8 (`BEGIN PRIVATE KEY`) or
        /// PKCS#1 (`BEGIN RSA PRIVATE KEY`). Passphrase-encrypted keys are not
        /// parsed.
        private_key_pem: String,
    },
    /// OAuth access token, sent as `Snowflake Token="<token>"`.
    // Explicit rename: `rename_all = "snake_case"` would otherwise turn `OAuth`
    // into "o_auth".
    #[serde(rename = "oauth")]
    OAuth {
        /// The OAuth access token.
        token: String,
    },
}
48
49impl std::fmt::Debug for SnowflakeAuth {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 match self {
52 Self::KeyPair { user, .. } => f
53 .debug_struct("KeyPair")
54 .field("user", user)
55 .field("private_key_pem", &"***")
56 .finish(),
57 Self::OAuth { .. } => f.debug_struct("OAuth").field("token", &"***").finish(),
58 }
59 }
60}
61
62pub fn authorization_header(auth: &SnowflakeAuth, account: &str) -> Result<String, FaucetError> {
73 match auth {
74 SnowflakeAuth::KeyPair {
75 user,
76 private_key_pem,
77 } => {
78 let account_upper = account.to_uppercase();
79 let user_upper = user.to_uppercase();
80 let qualified_user = format!("{account_upper}.{user_upper}");
81
82 let fingerprint = public_key_fingerprint(private_key_pem)?;
89 let issuer = format!("{qualified_user}.{fingerprint}");
90
91 let now = jsonwebtoken::get_current_timestamp();
92 let claims = serde_json::json!({
93 "iss": issuer,
94 "sub": qualified_user,
95 "iat": now,
96 "exp": now + 3600,
97 });
98
99 let key = jsonwebtoken::EncodingKey::from_rsa_pem(private_key_pem.as_bytes())
100 .map_err(|e| FaucetError::Auth(format!("invalid RSA key: {e}")))?;
101
102 let token = jsonwebtoken::encode(
103 &jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256),
104 &claims,
105 &key,
106 )
107 .map_err(|e| FaucetError::Auth(format!("JWT generation failed: {e}")))?;
108
109 Ok(format!("Bearer {token}"))
110 }
111 SnowflakeAuth::OAuth { token } => Ok(format!("Snowflake Token=\"{token}\"")),
112 }
113}
114
115fn public_key_fingerprint(private_key_pem: &str) -> Result<String, FaucetError> {
123 use base64::Engine as _;
124 use rsa::pkcs1::DecodeRsaPrivateKey;
125 use rsa::pkcs8::{DecodePrivateKey, EncodePublicKey};
126 use rsa::{RsaPrivateKey, RsaPublicKey};
127 use sha2::{Digest, Sha256};
128
129 let private = RsaPrivateKey::from_pkcs8_pem(private_key_pem)
132 .or_else(|_| RsaPrivateKey::from_pkcs1_pem(private_key_pem))
133 .map_err(|e| FaucetError::Auth(format!("invalid RSA private key: {e}")))?;
134
135 let public = RsaPublicKey::from(&private);
136 let der = public
137 .to_public_key_der()
138 .map_err(|e| FaucetError::Auth(format!("failed to DER-encode public key: {e}")))?;
139
140 let digest = Sha256::digest(der.as_bytes());
141 let b64 = base64::engine::general_purpose::STANDARD.encode(digest);
142 Ok(format!("SHA256:{b64}"))
143}
144
145pub fn snowflake_token_type(auth: &SnowflakeAuth) -> &'static str {
148 match auth {
149 SnowflakeAuth::KeyPair { .. } => "KEYPAIR_JWT",
150 SnowflakeAuth::OAuth { .. } => "OAUTH",
151 }
152}
153
154pub fn credential_to_auth(cred: faucet_core::Credential) -> Result<SnowflakeAuth, FaucetError> {
168 match cred {
169 faucet_core::Credential::Bearer(token) | faucet_core::Credential::Token(token) => {
170 Ok(SnowflakeAuth::OAuth { token })
171 }
172 other => Err(FaucetError::Auth(format!(
173 "Snowflake auth provider must yield a bearer/token credential, got {other:?}"
174 ))),
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    // Debug output must show the user but never the private key material.
    #[test]
    fn debug_masks_key_pair_private_key() {
        let auth = SnowflakeAuth::KeyPair {
            user: "alice".into(),
            private_key_pem: "PRIVATE-KEY-DATA".into(),
        };
        let debug = format!("{auth:?}");
        assert!(debug.contains("alice"));
        assert!(debug.contains("***"));
        assert!(!debug.contains("PRIVATE-KEY-DATA"));
    }

    // Debug output must never show the OAuth token.
    #[test]
    fn debug_masks_oauth_token() {
        let auth = SnowflakeAuth::OAuth {
            token: "my-token".into(),
        };
        let debug = format!("{auth:?}");
        assert!(debug.contains("***"));
        assert!(!debug.contains("my-token"));
    }

    // Pins the adjacently tagged wire format, including the explicit "oauth" tag.
    #[test]
    fn serde_round_trip_oauth() {
        let auth = SnowflakeAuth::OAuth { token: "t".into() };
        let json = serde_json::to_string(&auth).unwrap();
        assert_eq!(json, r#"{"type":"oauth","config":{"token":"t"}}"#);
        let parsed: SnowflakeAuth = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, SnowflakeAuth::OAuth { .. }));
    }

    // The key-pair variant is tagged "key_pair" (snake_case of `KeyPair`).
    #[test]
    fn serde_round_trip_key_pair() {
        let json = r#"{"type":"key_pair","config":{"user":"u","private_key_pem":"k"}}"#;
        let parsed: SnowflakeAuth = serde_json::from_str(json).unwrap();
        match parsed {
            SnowflakeAuth::KeyPair {
                user,
                private_key_pem,
            } => {
                assert_eq!(user, "u");
                assert_eq!(private_key_pem, "k");
            }
            _ => panic!("expected KeyPair"),
        }
    }

    #[test]
    fn oauth_authorization_header_uses_snowflake_token_scheme() {
        let auth = SnowflakeAuth::OAuth {
            token: "my-token".into(),
        };
        let header = authorization_header(&auth, "acct").unwrap();
        assert_eq!(header, "Snowflake Token=\"my-token\"");
    }

    // A malformed key must surface as FaucetError::Auth rather than panicking.
    #[test]
    fn key_pair_with_invalid_pem_surfaces_auth_error() {
        let auth = SnowflakeAuth::KeyPair {
            user: "u".into(),
            private_key_pem: "not-a-pem".into(),
        };
        let err = authorization_header(&auth, "acct").unwrap_err();
        match err {
            FaucetError::Auth(msg) => assert!(msg.contains("invalid RSA"), "{msg}"),
            other => panic!("expected Auth error, got {other:?}"),
        }
    }

    // Test-only RSA key in PKCS#8 PEM form. The fingerprint expectations below
    // depend on its exact bytes, so do not reformat or re-indent it.
    const TEST_RSA_PKCS8_PEM: &str = "-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDDmeSF5jD5LMGw
INB1hExU2Ux9qEQ9DXNUeWxrDv7K3QHA+UkCbdUpHDZdFSbIr/bvwlNn16Hqhqi9
b8WywAzjagZNg0cReXuQ7nKIr5c9zYl2EJe+RZTo2z2LE21HrSKRhTAmlOk3XJ1N
xc7ahYcKyw8lchuTcZaYWaNTYvronOpHUAGS0XpT0y8Oggzp1DvZNYOeZbJCPZwf
mpGGCSilnODNYnwT02Pc4aXXBzJP7TP57+ve/ZzqvsKCBiNJUMLsjUZcGWnqQHnR
A+8B87ug7CyhhEiYnskp0d1ZlWT/kU7rIZv58KMbMJidAdizA47jRjelsWeoedRf
JmiA99ZhAgMBAAECggEAAOrybwxm82xZ1k05HSwLPaStXrOQ6mZrQZy2PQRbfrEt
xm2FAa1pQCGhQauNPIjS1EopoQWafWK3XPguyclr5g9Dy05P4Y2b3lC4GdsVDxWt
TPAD/kEOU09gCQyEyT7PODaTRMMTGw7ksA47C7xvp0XPouHXrkfsqHdXNFd1DO1Z
dBCzkX4dg4Y4ffh5tt/ILeSsNlmqqpUQmHQZ/X3JHkP9/+NpAe6i4k9QKsqmLDGD
7+br/snVYbECBgmN1QIofTSvnlmmRiKgoG9wbZLmGvCiW9xVjbY+ryJs/lsLoM7w
W1TUuOlk3apoIzQ7OIGznyZzE5RumdQq11rNKB7aaQKBgQDowsceEQz2kLb93f8J
QaBDcebqbbGTJE6+hq2k8D/GzvZAdBHGuEt7NiDAFKy/GItwzJSGGdjK24iRtZ7G
2gIloZShu+7mmxX6Ojuxun8EMRZKZzTedMJWQJMwA1Hk1fwzsEM0+9+yZdTcylP9
wYDMFKbvw+av7sDcySENNEhshQKBgQDXIVX+Zvlf2PoLkRx11mk1CBtPfjqPTMcs
QVjISwvkgGSi8ihq+mwsIWLXhOZX38+L4iGfdIgqSSnwqB/fgTbjwQsa0Dqkygt6
IBfb3QmWr7196c+xss5h8eUTFiCMWw/EAa9R+jkWH0cVpJVbyTK7cBJlaXxPcXx3
xprI10qnLQKBgBl/NKajgYME6Ta3+bb+3FpnAL+PUpNmt8WBJUZbFvFlPG5lCIl3
KLWPgVjpKt8oBiZOErr529ik4bnsZj8sJG4Q3CI3Xv0d4fNuK5nVbxJ7ehCea5ku
uxcNrdHlmzPxCNZ0qXgFW0TEiOPCuh6i8sPoQz0ifYOqKLBGy/sRThmtAoGAGTd9
Hv7vCD8kwCpYTa++UUsL+HtxXc7AIf3e7Etvr28lXLxJ5JBKEbowHdckMPS5HUp6
anh8ZYiB9AWhBs/coUHFjXUPCrXsNnqAkXMNZq5e5d18TPYKnwx9r4kOc6VQ6cbQ
yCkue9tat7y9DS8+VR5D6cM9oQpKbrfG+PfTdlkCgYBf/pUWO94VgZvpV5Ui7MHb
6ZoH11q0gIhmT72FQ+2Erw977qghzs1+C7HO4Q7kNfC8sA9uVS4WiA1EzE6QeJWt
+FklEinW+AR2azgC/+gEUBvZSWU1v4meYdAQcNEek8L4VtBuGc4ZwbVbho3hiLmx
68Y3qeoKxOyBKo6j2NiZzg==
-----END PRIVATE KEY-----
";

    // Decodes the JWT payload segment (signature is not verified) and checks
    // that `sub` is ACCOUNT.USER and `iss` appends the key fingerprint.
    #[test]
    fn key_pair_jwt_iss_includes_public_key_fingerprint() {
        use base64::Engine as _;
        let auth = SnowflakeAuth::KeyPair {
            user: "u".into(),
            private_key_pem: TEST_RSA_PKCS8_PEM.into(),
        };
        let header = authorization_header(&auth, "acct").unwrap();
        let jwt = header.strip_prefix("Bearer ").expect("Bearer token");
        let payload_b64 = jwt.split('.').nth(1).expect("jwt payload segment");
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(payload_b64)
            .expect("base64url payload");
        let claims: serde_json::Value = serde_json::from_slice(&payload).unwrap();
        assert_eq!(claims["sub"], "ACCT.U");
        assert_eq!(
            claims["iss"],
            "ACCT.U.SHA256:NiQ5G+9Hr4ZBmdBscIoTOgx2SM6aWPG0/Q9Y6NuFtpI="
        );
    }

    // NOTE(review): expected value presumably produced with openssl (per the
    // test name): SHA-256 over the DER public key, standard base64.
    #[test]
    fn public_key_fingerprint_matches_openssl() {
        let fp = public_key_fingerprint(TEST_RSA_PKCS8_PEM).unwrap();
        assert_eq!(fp, "SHA256:NiQ5G+9Hr4ZBmdBscIoTOgx2SM6aWPG0/Q9Y6NuFtpI=");
    }

    #[test]
    fn token_type_matches_variant() {
        assert_eq!(
            snowflake_token_type(&SnowflakeAuth::OAuth { token: "t".into() }),
            "OAUTH"
        );
        assert_eq!(
            snowflake_token_type(&SnowflakeAuth::KeyPair {
                user: "u".into(),
                private_key_pem: "k".into()
            }),
            "KEYPAIR_JWT"
        );
    }
}