1use std::sync::Arc;
79use std::time::Duration;
80
81use azure_core::credentials::TokenCredential;
82use azure_identity::ClientCertificateCredential;
83
84use crate::AzureAdAuth;
85use crate::error::AuthError;
86
/// OAuth 2.0 scope requested for every token acquired by [`CertificateAuth`]:
/// the `.default` scope of the `https://database.windows.net` resource.
const AZURE_SQL_SCOPE: &str = "https://database.windows.net/.default";

/// Azure AD service-principal authentication backed by a client certificate.
///
/// Wraps an `azure_identity::ClientCertificateCredential` and requests tokens
/// for [`AZURE_SQL_SCOPE`]. Clones share the same credential through the
/// `Arc`, so cloning does not copy certificate material.
pub struct CertificateAuth {
    // Shared credential used for every token request.
    credential: Arc<ClientCertificateCredential>,
}
105
106impl CertificateAuth {
107 pub fn new(
136 tenant_id: impl AsRef<str>,
137 client_id: impl Into<String>,
138 certificate: impl AsRef<[u8]>,
139 password: Option<&str>,
140 ) -> Result<Self, AuthError> {
141 use azure_core::credentials::Secret;
142 use azure_identity::ClientCertificateCredentialOptions;
143 use base64::Engine;
144
145 let cert_bytes = certificate.as_ref();
149 let der: Vec<u8> = if is_base64(cert_bytes) {
150 base64::engine::general_purpose::STANDARD
151 .decode(cert_bytes)
152 .map_err(|e| {
153 AuthError::Certificate(format!("base64-decoding the certificate failed: {e}"))
154 })?
155 } else {
156 cert_bytes.to_vec()
157 };
158
159 let cert_secret = azure_core::credentials::SecretBytes::new(der);
160
161 let options = if let Some(pwd) = password {
164 ClientCertificateCredentialOptions {
165 password: Some(Secret::new(pwd.to_string())),
166 ..Default::default()
167 }
168 } else {
169 ClientCertificateCredentialOptions::default()
170 };
171
172 let credential = ClientCertificateCredential::new(
173 tenant_id.as_ref().to_string(),
174 client_id.into(),
175 cert_secret,
176 Some(options),
177 )
178 .map_err(|e| {
179 AuthError::Certificate(format!("Failed to create certificate credential: {e}"))
180 })?;
181
182 Ok(Self { credential })
183 }
184
185 pub fn from_pem(
224 tenant_id: impl AsRef<str>,
225 client_id: impl Into<String>,
226 cert_pem: impl AsRef<[u8]>,
227 key_pem: impl AsRef<[u8]>,
228 password: Option<&str>,
229 ) -> Result<Self, AuthError> {
230 use std::io::BufReader;
231
232 let cert_pem_bytes = cert_pem.as_ref();
234 let mut cert_reader = BufReader::new(cert_pem_bytes);
235 let certs: Vec<_> = rustls_pemfile::certs(&mut cert_reader)
236 .collect::<Result<Vec<_>, _>>()
237 .map_err(|e| AuthError::Certificate(format!("Failed to parse certificate PEM: {e}")))?;
238
239 let cert_der = certs
240 .first()
241 .ok_or_else(|| AuthError::Certificate("No certificate found in PEM data".into()))?;
242
243 let key_pem_bytes = key_pem.as_ref();
245 let mut key_reader = BufReader::new(key_pem_bytes);
246 let key_der = rustls_pemfile::private_key(&mut key_reader)
247 .map_err(|e| AuthError::Certificate(format!("Failed to parse private key PEM: {e}")))?
248 .ok_or_else(|| AuthError::Certificate("No private key found in PEM data".into()))?;
249
250 let pkcs12_password = password.unwrap_or("");
252 let pfx = p12::PFX::new(
253 cert_der.as_ref(),
254 key_der.secret_der(),
255 None, pkcs12_password,
257 "cert",
258 )
259 .ok_or_else(|| AuthError::Certificate("Failed to create PKCS#12 from PEM data".into()))?;
260
261 let pkcs12_bytes = pfx.to_der();
262
263 Self::new(tenant_id, client_id, pkcs12_bytes, password)
265 }
266
267 pub async fn get_token(&self) -> Result<String, AuthError> {
274 let token = self
275 .credential
276 .get_token(&[AZURE_SQL_SCOPE], None)
277 .await
278 .map_err(|e| AuthError::Certificate(format!("Failed to acquire token: {e}")))?;
279 Ok(token.token.secret().to_string())
280 }
281
282 pub async fn get_token_with_expiry(&self) -> Result<(String, Option<Duration>), AuthError> {
288 let token = self
289 .credential
290 .get_token(&[AZURE_SQL_SCOPE], None)
291 .await
292 .map_err(|e| AuthError::Certificate(format!("Failed to acquire token: {e}")))?;
293
294 let now = time::OffsetDateTime::now_utc();
296 let expires_in = if token.expires_on > now {
297 let diff = token.expires_on - now;
298 Some(Duration::from_secs(diff.whole_seconds().max(0) as u64))
299 } else {
300 None
301 };
302
303 Ok((token.token.secret().to_string(), expires_in))
304 }
305
306 pub async fn to_azure_ad_auth(&self) -> Result<AzureAdAuth, AuthError> {
315 let (token, expires_in) = self.get_token_with_expiry().await?;
316 match expires_in {
317 Some(duration) => Ok(AzureAdAuth::with_token_expiring(token, duration)),
318 None => Ok(AzureAdAuth::with_token(token)),
319 }
320 }
321}
322
/// Heuristically reports whether `data` looks like standard-alphabet base64 text.
///
/// Every byte must be an ASCII letter or digit, `+`, `/`, `=`, or a CR/LF line
/// break. Any other byte (spaces, `-`, non-ASCII, ...) makes the result
/// `false`. Length and padding placement are not validated, and empty input
/// yields `true`.
fn is_base64(data: &[u8]) -> bool {
    let is_allowed = |byte: &u8| {
        matches!(
            *byte,
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=' | b'\r' | b'\n'
        )
    };
    data.iter().all(is_allowed)
}
331
332impl Clone for CertificateAuth {
333 fn clone(&self) -> Self {
334 Self {
335 credential: Arc::clone(&self.credential),
336 }
337 }
338}
339
340impl std::fmt::Debug for CertificateAuth {
341 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342 f.debug_struct("CertificateAuth")
343 .field("credential", &"[REDACTED]")
344 .finish()
345 }
346}
347
348impl crate::provider::AsyncAuthProvider for CertificateAuth {
349 fn method(&self) -> crate::provider::AuthMethod {
350 crate::provider::AuthMethod::AzureAd
351 }
352
353 async fn authenticate_async(&self) -> Result<crate::provider::AuthData, AuthError> {
354 let token = self.get_token().await?;
355 Ok(crate::provider::AuthData::FedAuth { token, nonce: None })
356 }
357
358 fn needs_refresh(&self) -> bool {
359 false
361 }
362}
363
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_is_base64() {
        assert!(is_base64(b"SGVsbG8gV29ybGQ="));
        assert!(is_base64(b"MIIC+jCCAeKgAwIBAgIJAL"));
        assert!(!is_base64(&[0x00, 0x01, 0x02, 0x03]));
    }

    #[test]
    fn test_is_base64_line_wrapped() {
        // Encoders such as `base64` wrap their output and end it with a
        // newline; CR/LF must not prevent base64 detection.
        assert!(is_base64(b"SGVsbG8g\nV29ybGQ=\n"));
        assert!(is_base64(b"SGVsbG8g\r\nV29ybGQ=\r\n"));
    }

    #[test]
    fn test_is_base64_rejects_pem_and_binary_der() {
        assert!(!is_base64(b"-----BEGIN CERTIFICATE-----"));
        // DER SEQUENCE header with a long-form length byte (non-ASCII).
        assert!(!is_base64(&[0x30, 0x82, 0x01, 0x0a]));
    }

    #[test]
    fn test_new_rejects_malformed_base64() {
        // A single base64 symbol cannot encode a whole byte, so decoding fails
        // before any credential is constructed.
        let result = CertificateAuth::new("tenant-id", "client-id", b"A", None);

        let err = result.expect_err("malformed base64 must be rejected");
        assert!(
            err.to_string().contains("base64-decoding the certificate failed"),
            "Expected base64 decoding error, got: {err}"
        );
    }

    #[tokio::test]
    #[ignore = "Requires Azure Service Principal with certificate"]
    async fn test_certificate_auth() {
        let tenant_id = std::env::var("AZURE_TENANT_ID").expect("AZURE_TENANT_ID not set");
        let client_id = std::env::var("AZURE_CLIENT_ID").expect("AZURE_CLIENT_ID not set");
        let cert_path = std::env::var("AZURE_CLIENT_CERTIFICATE_PATH")
            .expect("AZURE_CLIENT_CERTIFICATE_PATH not set");
        let cert_password = std::env::var("AZURE_CLIENT_CERTIFICATE_PASSWORD").ok();

        let cert_bytes = std::fs::read(&cert_path).expect("Failed to read certificate");
        let auth = CertificateAuth::new(tenant_id, client_id, cert_bytes, cert_password.as_deref())
            .expect("Failed to create CertificateAuth");

        let token = auth.get_token().await.expect("Failed to get token");
        assert!(!token.is_empty());
    }

    #[test]
    fn test_from_pem_invalid_certificate() {
        let invalid_cert = b"not a valid PEM certificate";
        let valid_key_format = b"-----BEGIN PRIVATE KEY-----\nMIIE=\n-----END PRIVATE KEY-----";

        let result = CertificateAuth::from_pem(
            "tenant-id",
            "client-id",
            invalid_cert,
            valid_key_format,
            None,
        );

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("No certificate found"),
            "Expected 'No certificate found' error, got: {err}"
        );
    }

    #[test]
    fn test_from_pem_invalid_private_key() {
        let cert_pem =
            b"-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAKHBfpE=\n-----END CERTIFICATE-----";
        let invalid_key = b"not a valid PEM private key";

        let result =
            CertificateAuth::from_pem("tenant-id", "client-id", cert_pem, invalid_key, None);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("No private key found"),
            "Expected 'No private key found' error, got: {err}"
        );
    }

    #[test]
    fn test_from_pem_empty_certificate() {
        let empty_cert = b"";
        let key_pem = b"-----BEGIN PRIVATE KEY-----\nMIIE=\n-----END PRIVATE KEY-----";

        let result = CertificateAuth::from_pem("tenant-id", "client-id", empty_cert, key_pem, None);

        assert!(result.is_err());
    }

    #[test]
    fn test_from_pem_empty_private_key() {
        let cert_pem =
            b"-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAKHBfpE=\n-----END CERTIFICATE-----";
        let empty_key = b"";

        let result = CertificateAuth::from_pem("tenant-id", "client-id", cert_pem, empty_key, None);

        assert!(result.is_err());
    }

    #[tokio::test]
    #[ignore = "Requires Azure Service Principal with PEM certificate"]
    async fn test_certificate_auth_from_pem() {
        let tenant_id = std::env::var("AZURE_TENANT_ID").expect("AZURE_TENANT_ID not set");
        let client_id = std::env::var("AZURE_CLIENT_ID").expect("AZURE_CLIENT_ID not set");
        let cert_path = std::env::var("AZURE_CLIENT_CERTIFICATE_PEM")
            .expect("AZURE_CLIENT_CERTIFICATE_PEM not set");
        let key_path = std::env::var("AZURE_CLIENT_PRIVATE_KEY_PEM")
            .expect("AZURE_CLIENT_PRIVATE_KEY_PEM not set");

        let cert_pem = std::fs::read(&cert_path).expect("Failed to read certificate PEM");
        let key_pem = std::fs::read(&key_path).expect("Failed to read private key PEM");

        let auth = CertificateAuth::from_pem(tenant_id, client_id, &cert_pem, &key_pem, None)
            .expect("Failed to create CertificateAuth from PEM");

        let token = auth.get_token().await.expect("Failed to get token");
        assert!(!token.is_empty());
    }
}