Skip to main content

mssql_auth/
cert_auth.rs

1//! Client certificate authentication provider.
2//!
3//! This module provides Azure AD authentication using a client certificate
4//! (X.509) instead of a client secret. This is more secure than using secrets
5//! because certificates can be stored in secure hardware (HSM) and have
6//! built-in expiration.
7//!
8//! ## How It Works
9//!
10//! Certificate authentication uses an Azure AD Service Principal with an
11//! X.509 certificate. The certificate's private key is used to sign a JWT
12//! assertion, which Azure AD validates using the certificate's public key
13//! registered with the application.
14//!
15//! **Important**: This is NOT TDS-level mTLS. SQL Server/Azure SQL do not
16//! support client certificates at the TDS protocol level. Instead, the
17//! certificate authenticates to Azure AD, which issues an access token
18//! used for SQL authentication.
19//!
20//! ## Prerequisites
21//!
22//! 1. Create an Azure AD App Registration
23//! 2. Generate or upload a certificate to the app registration
24//! 3. Export the certificate (PKCS#12 or PEM format)
25//! 4. Grant the service principal access to your Azure SQL database
26//!
27//! ## Example (PKCS#12)
28//!
29//! ```rust,ignore
30//! use mssql_auth::CertificateAuth;
31//! use std::fs;
32//!
33//! // Load PKCS#12 certificate from file
34//! let cert_bytes = fs::read("service-principal.pfx")?;
35//!
36//! let auth = CertificateAuth::new(
37//!     "your-tenant-id",
38//!     "your-client-id",
39//!     cert_bytes,
40//!     Some("certificate-password"),
41//! )?;
42//!
43//! // Get access token for Azure SQL
44//! let token = auth.get_token().await?;
45//! ```
46//!
47//! ## Example (PEM)
48//!
49//! PEM certificates are common in Linux/Kubernetes environments:
50//!
51//! ```rust,ignore
52//! use mssql_auth::CertificateAuth;
53//! use std::fs;
54//!
55//! // Load PEM certificate and private key
56//! let cert_pem = fs::read("cert.pem")?;
57//! let key_pem = fs::read("key.pem")?;
58//!
59//! let auth = CertificateAuth::from_pem(
60//!     "your-tenant-id",
61//!     "your-client-id",
62//!     &cert_pem,
63//!     &key_pem,
64//!     None, // optional password
65//! )?;
66//!
67//! let token = auth.get_token().await?;
68//! ```
69//!
70//! ## Security Considerations
71//!
72//! - Store certificates in Azure Key Vault or secure hardware when possible
73//! - Use certificates with appropriate key sizes (RSA 2048+ or ECDSA P-256+)
74//! - Set reasonable certificate expiration (1-2 years)
75//! - Rotate certificates before expiration
76//! - Never commit certificates to source control
77
78use std::sync::Arc;
79use std::time::Duration;
80
81use azure_core::credentials::TokenCredential;
82use azure_identity::ClientCertificateCredential;
83
84use crate::AzureAdAuth;
85use crate::error::AuthError;
86
87/// The Azure SQL Database scope for token requests.
88const AZURE_SQL_SCOPE: &str = "https://database.windows.net/.default";
89
90/// Client certificate authentication provider.
91///
92/// Uses an X.509 certificate to authenticate as an Azure AD Service Principal,
93/// then acquires an access token for Azure SQL Database.
94///
95/// # Security
96///
97/// Certificate authentication is more secure than client secrets because:
98/// - Certificates have built-in expiration
99/// - Private keys can be stored in secure hardware (HSM/TPM)
100/// - Certificates support hardware-based attestation
101/// - Certificate rotation doesn't require application restarts
102pub struct CertificateAuth {
103    credential: Arc<ClientCertificateCredential>,
104}
105
106impl CertificateAuth {
107    /// Create a new certificate authentication provider.
108    ///
109    /// # Arguments
110    ///
111    /// * `tenant_id` - The Azure AD tenant ID
112    /// * `client_id` - The application (client) ID of the service principal
113    /// * `certificate` - The PKCS#12 (.pfx) certificate bytes (base64-encoded or raw)
114    /// * `password` - Optional password for the certificate's private key
115    ///
116    /// # Errors
117    ///
118    /// Returns an error if the certificate cannot be parsed or the credential
119    /// cannot be created.
120    ///
121    /// # Example
122    ///
123    /// ```rust,ignore
124    /// use mssql_auth::CertificateAuth;
125    /// use std::fs;
126    ///
127    /// let cert = fs::read("app.pfx")?;
128    /// let auth = CertificateAuth::new(
129    ///     "tenant-id",
130    ///     "client-id",
131    ///     cert,
132    ///     Some("cert-password"),
133    /// )?;
134    /// ```
135    pub fn new(
136        tenant_id: impl AsRef<str>,
137        client_id: impl Into<String>,
138        certificate: impl AsRef<[u8]>,
139        password: Option<&str>,
140    ) -> Result<Self, AuthError> {
141        use azure_core::credentials::Secret;
142        use azure_identity::ClientCertificateCredentialOptions;
143        use base64::Engine;
144
145        // The certificate should be base64-encoded PKCS#12
146        // If it's raw bytes, encode it first
147        let cert_bytes = certificate.as_ref();
148        let cert_b64 = if cert_bytes.starts_with(b"MII") || is_base64(cert_bytes) {
149            // Already looks like base64
150            String::from_utf8_lossy(cert_bytes).to_string()
151        } else {
152            // Raw PKCS#12 bytes - encode to base64
153            base64::engine::general_purpose::STANDARD.encode(cert_bytes)
154        };
155
156        let cert_secret = azure_core::credentials::SecretBytes::new(cert_b64.into_bytes());
157
158        // Create options with password if provided
159        // Note: send_certificate_chain is now controlled by AZURE_CLIENT_SEND_CERTIFICATE_CHAIN env var
160        let options = if let Some(pwd) = password {
161            ClientCertificateCredentialOptions {
162                password: Some(Secret::new(pwd.to_string())),
163                ..Default::default()
164            }
165        } else {
166            ClientCertificateCredentialOptions::default()
167        };
168
169        let credential = ClientCertificateCredential::new(
170            tenant_id.as_ref().to_string(),
171            client_id.into(),
172            cert_secret,
173            Some(options),
174        )
175        .map_err(|e| {
176            AuthError::Certificate(format!("Failed to create certificate credential: {e}"))
177        })?;
178
179        Ok(Self { credential })
180    }
181
182    /// Create a new certificate authentication provider from PEM-encoded files.
183    ///
184    /// This is a convenience method for users who have PEM-formatted certificates
185    /// (common in Linux/Kubernetes environments) rather than PKCS#12 format.
186    ///
187    /// # Arguments
188    ///
189    /// * `tenant_id` - The Azure AD tenant ID
190    /// * `client_id` - The application (client) ID of the service principal
191    /// * `cert_pem` - The PEM-encoded certificate (typically from a `.pem` or `.crt` file)
192    /// * `key_pem` - The PEM-encoded private key (typically from a `.key` or `.pem` file)
193    /// * `password` - Optional password for the PKCS#12 bundle (used during conversion)
194    ///
195    /// # Errors
196    ///
197    /// Returns an error if:
198    /// - The certificate PEM cannot be parsed
199    /// - The private key PEM cannot be parsed
200    /// - The PEM-to-PKCS#12 conversion fails
201    /// - The credential cannot be created
202    ///
203    /// # Example
204    ///
205    /// ```rust,ignore
206    /// use mssql_auth::CertificateAuth;
207    /// use std::fs;
208    ///
209    /// let cert_pem = fs::read("cert.pem")?;
210    /// let key_pem = fs::read("key.pem")?;
211    ///
212    /// let auth = CertificateAuth::from_pem(
213    ///     "tenant-id",
214    ///     "client-id",
215    ///     &cert_pem,
216    ///     &key_pem,
217    ///     None, // or Some("pkcs12-password")
218    /// )?;
219    /// ```
220    pub fn from_pem(
221        tenant_id: impl AsRef<str>,
222        client_id: impl Into<String>,
223        cert_pem: impl AsRef<[u8]>,
224        key_pem: impl AsRef<[u8]>,
225        password: Option<&str>,
226    ) -> Result<Self, AuthError> {
227        use std::io::BufReader;
228
229        // Parse certificate from PEM
230        let cert_pem_bytes = cert_pem.as_ref();
231        let mut cert_reader = BufReader::new(cert_pem_bytes);
232        let certs: Vec<_> = rustls_pemfile::certs(&mut cert_reader)
233            .collect::<Result<Vec<_>, _>>()
234            .map_err(|e| AuthError::Certificate(format!("Failed to parse certificate PEM: {e}")))?;
235
236        let cert_der = certs
237            .first()
238            .ok_or_else(|| AuthError::Certificate("No certificate found in PEM data".into()))?;
239
240        // Parse private key from PEM
241        let key_pem_bytes = key_pem.as_ref();
242        let mut key_reader = BufReader::new(key_pem_bytes);
243        let key_der = rustls_pemfile::private_key(&mut key_reader)
244            .map_err(|e| AuthError::Certificate(format!("Failed to parse private key PEM: {e}")))?
245            .ok_or_else(|| AuthError::Certificate("No private key found in PEM data".into()))?;
246
247        // Convert to PKCS#12 format
248        let pkcs12_password = password.unwrap_or("");
249        let pfx = p12::PFX::new(
250            cert_der.as_ref(),
251            key_der.secret_der(),
252            None, // No CA certificate
253            pkcs12_password,
254            "cert",
255        )
256        .ok_or_else(|| AuthError::Certificate("Failed to create PKCS#12 from PEM data".into()))?;
257
258        let pkcs12_bytes = pfx.to_der();
259
260        // Use existing constructor with the converted PKCS#12
261        Self::new(tenant_id, client_id, pkcs12_bytes, password)
262    }
263
264    /// Get an access token for Azure SQL Database.
265    ///
266    /// # Errors
267    ///
268    /// Returns an error if token acquisition fails (e.g., certificate invalid,
269    /// network error, insufficient permissions).
270    pub async fn get_token(&self) -> Result<String, AuthError> {
271        let token = self
272            .credential
273            .get_token(&[AZURE_SQL_SCOPE], None)
274            .await
275            .map_err(|e| AuthError::Certificate(format!("Failed to acquire token: {e}")))?;
276        Ok(token.token.secret().to_string())
277    }
278
279    /// Get an access token with expiration information.
280    ///
281    /// # Errors
282    ///
283    /// Returns an error if token acquisition fails.
284    pub async fn get_token_with_expiry(&self) -> Result<(String, Option<Duration>), AuthError> {
285        let token = self
286            .credential
287            .get_token(&[AZURE_SQL_SCOPE], None)
288            .await
289            .map_err(|e| AuthError::Certificate(format!("Failed to acquire token: {e}")))?;
290
291        // Calculate time until expiration
292        let now = time::OffsetDateTime::now_utc();
293        let expires_in = if token.expires_on > now {
294            let diff = token.expires_on - now;
295            Some(Duration::from_secs(diff.whole_seconds().max(0) as u64))
296        } else {
297            None
298        };
299
300        Ok((token.token.secret().to_string(), expires_in))
301    }
302
303    /// Convert to an `AzureAdAuth` provider with an acquired token.
304    ///
305    /// This is useful when you need to use the token with APIs that
306    /// expect `AzureAdAuth`.
307    ///
308    /// # Errors
309    ///
310    /// Returns an error if token acquisition fails.
311    pub async fn to_azure_ad_auth(&self) -> Result<AzureAdAuth, AuthError> {
312        let (token, expires_in) = self.get_token_with_expiry().await?;
313        match expires_in {
314            Some(duration) => Ok(AzureAdAuth::with_token_expiring(token, duration)),
315            None => Ok(AzureAdAuth::with_token(token)),
316        }
317    }
318}
319
/// Heuristic test for whether `data` is base64 text rather than raw binary.
///
/// Returns `true` when every byte is in the standard base64 alphabet
/// (`A-Z`, `a-z`, `0-9`, `+`, `/`), is `=` padding, or is a CR/LF line break.
/// Raw PKCS#12 data would normally contain other (binary) bytes and so returns
/// `false`. Empty input trivially returns `true`.
fn is_base64(data: &[u8]) -> bool {
    let is_allowed = |byte: u8| {
        matches!(
            byte,
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=' | b'\r' | b'\n'
        )
    };
    !data.iter().copied().any(|byte| !is_allowed(byte))
}
328
329impl Clone for CertificateAuth {
330    fn clone(&self) -> Self {
331        Self {
332            credential: Arc::clone(&self.credential),
333        }
334    }
335}
336
337impl std::fmt::Debug for CertificateAuth {
338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        f.debug_struct("CertificateAuth")
340            .field("credential", &"[REDACTED]")
341            .finish()
342    }
343}
344
345impl crate::provider::AsyncAuthProvider for CertificateAuth {
346    fn method(&self) -> crate::provider::AuthMethod {
347        crate::provider::AuthMethod::AzureAd
348    }
349
350    async fn authenticate_async(&self) -> Result<crate::provider::AuthData, AuthError> {
351        let token = self.get_token().await?;
352        Ok(crate::provider::AuthData::FedAuth { token, nonce: None })
353    }
354
355    fn needs_refresh(&self) -> bool {
356        // Certificate-based tokens are acquired fresh each time
357        false
358    }
359}
360
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    // Tests that talk to Azure require credentials and a valid certificate. They are
    // marked #[ignore] and can be run manually with:
    // cargo test --features cert-auth -- --ignored
    // All other tests are offline and run by default.

    #[test]
    fn test_is_base64() {
        assert!(is_base64(b"SGVsbG8gV29ybGQ="));
        assert!(is_base64(b"MIIC+jCCAeKgAwIBAgIJAL"));
        // Line-wrapped base64 is accepted.
        assert!(is_base64(b"SGVsbG8g\r\nV29ybGQ=\n"));
        assert!(!is_base64(&[0x00, 0x01, 0x02, 0x03])); // Binary data
        // DER data (SEQUENCE tag + long-form length byte) is rejected.
        assert!(!is_base64(&[0x30, 0x82, 0x01, 0x0a]));
        // Spaces and URL-safe characters are outside the accepted alphabet.
        assert!(!is_base64(b"SGVs bG8="));
        assert!(!is_base64(b"SGVs-bG8_"));
    }

    #[tokio::test]
    #[ignore = "Requires Azure Service Principal with certificate"]
    async fn test_certificate_auth() {
        let tenant_id = std::env::var("AZURE_TENANT_ID").expect("AZURE_TENANT_ID not set");
        let client_id = std::env::var("AZURE_CLIENT_ID").expect("AZURE_CLIENT_ID not set");
        let cert_path = std::env::var("AZURE_CLIENT_CERTIFICATE_PATH")
            .expect("AZURE_CLIENT_CERTIFICATE_PATH not set");
        let cert_password = std::env::var("AZURE_CLIENT_CERTIFICATE_PASSWORD").ok();

        let cert_bytes = std::fs::read(&cert_path).expect("Failed to read certificate");
        let auth = CertificateAuth::new(tenant_id, client_id, cert_bytes, cert_password.as_deref())
            .expect("Failed to create CertificateAuth");

        let token = auth.get_token().await.expect("Failed to get token");
        assert!(!token.is_empty());
    }

    #[test]
    fn test_from_pem_invalid_certificate() {
        let invalid_cert = b"not a valid PEM certificate";
        let valid_key_format = b"-----BEGIN PRIVATE KEY-----\nMIIE=\n-----END PRIVATE KEY-----";

        let err = CertificateAuth::from_pem(
            "tenant-id",
            "client-id",
            invalid_cert,
            valid_key_format,
            None,
        )
        .expect_err("non-PEM certificate data must be rejected");

        assert!(
            err.to_string().contains("No certificate found"),
            "Expected 'No certificate found' error, got: {err}"
        );
    }

    #[test]
    fn test_from_pem_invalid_private_key() {
        // Structurally valid certificate PEM, so parsing gets past the certificate
        // and fails on the private key.
        let cert_pem =
            b"-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAKHBfpE=\n-----END CERTIFICATE-----";
        let invalid_key = b"not a valid PEM private key";

        let err = CertificateAuth::from_pem("tenant-id", "client-id", cert_pem, invalid_key, None)
            .expect_err("non-PEM private key data must be rejected");

        assert!(
            err.to_string().contains("No private key found"),
            "Expected 'No private key found' error, got: {err}"
        );
    }

    #[test]
    fn test_from_pem_empty_certificate() {
        let empty_cert = b"";
        let key_pem = b"-----BEGIN PRIVATE KEY-----\nMIIE=\n-----END PRIVATE KEY-----";

        let result = CertificateAuth::from_pem("tenant-id", "client-id", empty_cert, key_pem, None);

        assert!(result.is_err());
    }

    #[test]
    fn test_from_pem_empty_private_key() {
        let cert_pem =
            b"-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAKHBfpE=\n-----END CERTIFICATE-----";
        let empty_key = b"";

        let result = CertificateAuth::from_pem("tenant-id", "client-id", cert_pem, empty_key, None);

        assert!(result.is_err());
    }

    #[tokio::test]
    #[ignore = "Requires Azure Service Principal with PEM certificate"]
    async fn test_certificate_auth_from_pem() {
        let tenant_id = std::env::var("AZURE_TENANT_ID").expect("AZURE_TENANT_ID not set");
        let client_id = std::env::var("AZURE_CLIENT_ID").expect("AZURE_CLIENT_ID not set");
        let cert_path = std::env::var("AZURE_CLIENT_CERTIFICATE_PEM")
            .expect("AZURE_CLIENT_CERTIFICATE_PEM not set");
        let key_path = std::env::var("AZURE_CLIENT_PRIVATE_KEY_PEM")
            .expect("AZURE_CLIENT_PRIVATE_KEY_PEM not set");

        let cert_pem = std::fs::read(&cert_path).expect("Failed to read certificate PEM");
        let key_pem = std::fs::read(&key_path).expect("Failed to read private key PEM");

        let auth = CertificateAuth::from_pem(tenant_id, client_id, &cert_pem, &key_pem, None)
            .expect("Failed to create CertificateAuth from PEM");

        let token = auth.get_token().await.expect("Failed to get token");
        assert!(!token.is_empty());
    }
}