1use std::sync::Arc;
79use std::time::Duration;
80
81use azure_core::credentials::TokenCredential;
82use azure_identity::ClientCertificateCredential;
83
84use crate::AzureAdAuth;
85use crate::error::AuthError;
86
/// Scope string passed to the credential on every token request: the Azure SQL
/// Database resource (`https://database.windows.net`) with the `/.default` suffix.
const AZURE_SQL_SCOPE: &str = "https://database.windows.net/.default";
89
/// Azure AD authentication for Azure SQL using a service principal's client certificate.
///
/// Holds an `azure_identity::ClientCertificateCredential` behind an [`Arc`], so clones of
/// this type share the same underlying credential.
pub struct CertificateAuth {
    /// Credential used to request tokens for the Azure SQL scope.
    credential: Arc<ClientCertificateCredential>,
}
105
106impl CertificateAuth {
107 pub fn new(
136 tenant_id: impl AsRef<str>,
137 client_id: impl Into<String>,
138 certificate: impl AsRef<[u8]>,
139 password: Option<&str>,
140 ) -> Result<Self, AuthError> {
141 use azure_core::credentials::Secret;
142 use azure_identity::ClientCertificateCredentialOptions;
143 use base64::Engine;
144
145 let cert_bytes = certificate.as_ref();
148 let cert_b64 = if cert_bytes.starts_with(b"MII") || is_base64(cert_bytes) {
149 String::from_utf8_lossy(cert_bytes).to_string()
151 } else {
152 base64::engine::general_purpose::STANDARD.encode(cert_bytes)
154 };
155
156 let cert_secret = azure_core::credentials::SecretBytes::new(cert_b64.into_bytes());
157
158 let options = if let Some(pwd) = password {
161 ClientCertificateCredentialOptions {
162 password: Some(Secret::new(pwd.to_string())),
163 ..Default::default()
164 }
165 } else {
166 ClientCertificateCredentialOptions::default()
167 };
168
169 let credential = ClientCertificateCredential::new(
170 tenant_id.as_ref().to_string(),
171 client_id.into(),
172 cert_secret,
173 Some(options),
174 )
175 .map_err(|e| {
176 AuthError::Certificate(format!("Failed to create certificate credential: {e}"))
177 })?;
178
179 Ok(Self { credential })
180 }
181
182 pub fn from_pem(
221 tenant_id: impl AsRef<str>,
222 client_id: impl Into<String>,
223 cert_pem: impl AsRef<[u8]>,
224 key_pem: impl AsRef<[u8]>,
225 password: Option<&str>,
226 ) -> Result<Self, AuthError> {
227 use std::io::BufReader;
228
229 let cert_pem_bytes = cert_pem.as_ref();
231 let mut cert_reader = BufReader::new(cert_pem_bytes);
232 let certs: Vec<_> = rustls_pemfile::certs(&mut cert_reader)
233 .collect::<Result<Vec<_>, _>>()
234 .map_err(|e| AuthError::Certificate(format!("Failed to parse certificate PEM: {e}")))?;
235
236 let cert_der = certs
237 .first()
238 .ok_or_else(|| AuthError::Certificate("No certificate found in PEM data".into()))?;
239
240 let key_pem_bytes = key_pem.as_ref();
242 let mut key_reader = BufReader::new(key_pem_bytes);
243 let key_der = rustls_pemfile::private_key(&mut key_reader)
244 .map_err(|e| AuthError::Certificate(format!("Failed to parse private key PEM: {e}")))?
245 .ok_or_else(|| AuthError::Certificate("No private key found in PEM data".into()))?;
246
247 let pkcs12_password = password.unwrap_or("");
249 let pfx = p12::PFX::new(
250 cert_der.as_ref(),
251 key_der.secret_der(),
252 None, pkcs12_password,
254 "cert",
255 )
256 .ok_or_else(|| AuthError::Certificate("Failed to create PKCS#12 from PEM data".into()))?;
257
258 let pkcs12_bytes = pfx.to_der();
259
260 Self::new(tenant_id, client_id, pkcs12_bytes, password)
262 }
263
264 pub async fn get_token(&self) -> Result<String, AuthError> {
271 let token = self
272 .credential
273 .get_token(&[AZURE_SQL_SCOPE], None)
274 .await
275 .map_err(|e| AuthError::Certificate(format!("Failed to acquire token: {e}")))?;
276 Ok(token.token.secret().to_string())
277 }
278
279 pub async fn get_token_with_expiry(&self) -> Result<(String, Option<Duration>), AuthError> {
285 let token = self
286 .credential
287 .get_token(&[AZURE_SQL_SCOPE], None)
288 .await
289 .map_err(|e| AuthError::Certificate(format!("Failed to acquire token: {e}")))?;
290
291 let now = time::OffsetDateTime::now_utc();
293 let expires_in = if token.expires_on > now {
294 let diff = token.expires_on - now;
295 Some(Duration::from_secs(diff.whole_seconds().max(0) as u64))
296 } else {
297 None
298 };
299
300 Ok((token.token.secret().to_string(), expires_in))
301 }
302
303 pub async fn to_azure_ad_auth(&self) -> Result<AzureAdAuth, AuthError> {
312 let (token, expires_in) = self.get_token_with_expiry().await?;
313 match expires_in {
314 Some(duration) => Ok(AzureAdAuth::with_token_expiring(token, duration)),
315 None => Ok(AzureAdAuth::with_token(token)),
316 }
317 }
318}
319
/// Cheap heuristic for "this looks like base64 text".
///
/// Returns `true` when every byte is one of the following:
/// - part of the standard base64 alphabet (`A-Z`, `a-z`, `0-9`, `+`, `/`);
/// - the `=` padding character;
/// - a CR or LF line break.
///
/// It does not check length or padding placement, and empty input yields `true`.
fn is_base64(data: &[u8]) -> bool {
    let allowed = |byte: u8| {
        byte.is_ascii_alphanumeric() || matches!(byte, b'+' | b'/' | b'=' | b'\n' | b'\r')
    };
    data.iter().copied().all(allowed)
}
328
329impl Clone for CertificateAuth {
330 fn clone(&self) -> Self {
331 Self {
332 credential: Arc::clone(&self.credential),
333 }
334 }
335}
336
337impl std::fmt::Debug for CertificateAuth {
338 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339 f.debug_struct("CertificateAuth")
340 .field("credential", &"[REDACTED]")
341 .finish()
342 }
343}
344
345impl crate::provider::AsyncAuthProvider for CertificateAuth {
346 fn method(&self) -> crate::provider::AuthMethod {
347 crate::provider::AuthMethod::AzureAd
348 }
349
350 async fn authenticate_async(&self) -> Result<crate::provider::AuthData, AuthError> {
351 let token = self.get_token().await?;
352 Ok(crate::provider::AuthData::FedAuth { token, nonce: None })
353 }
354
355 fn needs_refresh(&self) -> bool {
356 false
358 }
359}
360
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// PEM-framed certificate block whose payload is not a real certificate.
    const FAKE_CERT_PEM: &[u8] =
        b"-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAKHBfpE=\n-----END CERTIFICATE-----";

    /// PEM-framed private-key block with a placeholder payload.
    const FAKE_KEY_PEM: &[u8] = b"-----BEGIN PRIVATE KEY-----\nMIIE=\n-----END PRIVATE KEY-----";

    #[test]
    fn test_is_base64() {
        // Raw binary must be rejected...
        assert!(!is_base64(&[0x00, 0x01, 0x02, 0x03]));
        // ...while plain and DER-looking base64 text is accepted.
        assert!(is_base64(b"SGVsbG8gV29ybGQ="));
        assert!(is_base64(b"MIIC+jCCAeKgAwIBAgIJAL"));
    }

    #[tokio::test]
    #[ignore = "Requires Azure Service Principal with certificate"]
    async fn test_certificate_auth() {
        let tenant_id = std::env::var("AZURE_TENANT_ID").expect("AZURE_TENANT_ID not set");
        let client_id = std::env::var("AZURE_CLIENT_ID").expect("AZURE_CLIENT_ID not set");
        let cert_path = std::env::var("AZURE_CLIENT_CERTIFICATE_PATH")
            .expect("AZURE_CLIENT_CERTIFICATE_PATH not set");
        let password = std::env::var("AZURE_CLIENT_CERTIFICATE_PASSWORD").ok();

        let bundle = std::fs::read(&cert_path).expect("Failed to read certificate");
        let auth = CertificateAuth::new(tenant_id, client_id, bundle, password.as_deref())
            .expect("Failed to create CertificateAuth");

        assert!(!auth.get_token().await.expect("Failed to get token").is_empty());
    }

    #[test]
    fn test_from_pem_invalid_certificate() {
        let result = CertificateAuth::from_pem(
            "tenant-id",
            "client-id",
            b"not a valid PEM certificate",
            FAKE_KEY_PEM,
            None,
        );

        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(
            message.contains("No certificate found"),
            "Expected 'No certificate found' error, got: {message}"
        );
    }

    #[test]
    fn test_from_pem_invalid_private_key() {
        let result = CertificateAuth::from_pem(
            "tenant-id",
            "client-id",
            FAKE_CERT_PEM,
            b"not a valid PEM private key",
            None,
        );

        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(
            message.contains("No private key found"),
            "Expected 'No private key found' error, got: {message}"
        );
    }

    #[test]
    fn test_from_pem_empty_certificate() {
        let outcome = CertificateAuth::from_pem("tenant-id", "client-id", b"", FAKE_KEY_PEM, None);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_from_pem_empty_private_key() {
        let outcome = CertificateAuth::from_pem("tenant-id", "client-id", FAKE_CERT_PEM, b"", None);
        assert!(outcome.is_err());
    }

    #[tokio::test]
    #[ignore = "Requires Azure Service Principal with PEM certificate"]
    async fn test_certificate_auth_from_pem() {
        let tenant_id = std::env::var("AZURE_TENANT_ID").expect("AZURE_TENANT_ID not set");
        let client_id = std::env::var("AZURE_CLIENT_ID").expect("AZURE_CLIENT_ID not set");
        let cert_path = std::env::var("AZURE_CLIENT_CERTIFICATE_PEM")
            .expect("AZURE_CLIENT_CERTIFICATE_PEM not set");
        let key_path = std::env::var("AZURE_CLIENT_PRIVATE_KEY_PEM")
            .expect("AZURE_CLIENT_PRIVATE_KEY_PEM not set");

        let cert_text = std::fs::read(&cert_path).expect("Failed to read certificate PEM");
        let key_text = std::fs::read(&key_path).expect("Failed to read private key PEM");

        let auth = CertificateAuth::from_pem(tenant_id, client_id, &cert_text, &key_text, None)
            .expect("Failed to create CertificateAuth from PEM");

        assert!(!auth.get_token().await.expect("Failed to get token").is_empty());
    }
}