Skip to main content

mssql_auth/
azure_ad.rs

1//! Azure AD / Entra ID authentication implementation.
2//!
3//! This module provides Azure AD federated authentication for SQL Server,
4//! supporting both pre-acquired tokens and (with feature flags) token acquisition.
5//!
6//! ## Authentication Flow
7//!
8//! Azure AD authentication uses the TDS FEDAUTH feature extension:
9//!
10//! 1. Client includes FEDAUTH feature in Login7 packet
11//! 2. Server responds with FEDAUTHINFO containing STS URL and SPN
12//! 3. Client acquires token (or uses pre-acquired token)
13//! 4. Client sends FEDAUTH token packet
14//! 5. Server validates token and completes authentication
15//!
16//! ## Token Sources (Tier 1 - Core) ✅ Implemented
17//!
18//! - Pre-acquired access token (user provides token directly)
19//!
20//! ## Token Sources (Tier 2 - azure-identity feature) ✅ Implemented
21//!
22//! These require the `azure-identity` feature flag:
23//!
24//! - `ManagedIdentityAuth` - Azure VM/Container identity
25//! - `ServicePrincipalAuth` - Client ID + Secret
26//!
27//! ## Token Sources (Tier 3 - cert-auth feature) ✅ Implemented
28//!
29//! - `CertificateAuth` - X.509 client certificate
30
31use std::borrow::Cow;
32use std::time::{Duration, Instant};
33
34use bytes::Bytes;
35
36use crate::credentials::Credentials;
37use crate::error::AuthError;
38use crate::provider::{AuthData, AuthMethod, AuthProvider};
39
/// FEDAUTH library options for Login7.
///
/// These values indicate to the server which token library the client uses.
/// Each variant's explicit discriminant is the byte returned by
/// [`FedAuthLibrary::to_byte`].
///
/// NOTE(review): the MS-TDS specification defines its own numeric identifiers
/// for FEDAUTH libraries; confirm that these discriminants match it before
/// relying on them on the wire — TODO confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum FedAuthLibrary {
    /// ADAL (Azure Active Directory Authentication Library) - legacy.
    Adal = 0x01,
    /// Security token (raw JWT).
    SecurityToken = 0x02,
    /// MSAL (Microsoft Authentication Library) - current.
    Msal = 0x03,
}

impl FedAuthLibrary {
    /// Get the byte value for the FEDAUTH feature extension.
    ///
    /// Returns the variant's discriminant (`Adal` = `0x01`,
    /// `SecurityToken` = `0x02`, `Msal` = `0x03`). The function is `const`,
    /// so the value can also be used when defining constants and statics.
    #[must_use]
    pub const fn to_byte(self) -> u8 {
        // `#[repr(u8)]` fixes the discriminant type to `u8`, so this cast is lossless.
        self as u8
    }
}
62
/// The kinds of FEDAUTH workflow a client can follow.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum FedAuthWorkflow {
    /// The user signs in interactively.
    Interactive,
    /// No user interaction: the token was acquired ahead of time.
    NonInteractive,
    /// A managed identity, either system-assigned or user-assigned.
    ManagedIdentity,
    /// A service principal that authenticates with a secret.
    ServicePrincipal,
}
76
77/// Azure AD authentication provider.
78///
79/// This provider supports Azure AD / Entra ID federated authentication
80/// using pre-acquired access tokens (Tier 1) or token acquisition
81/// with the `azure-identity` feature (Tier 2).
82///
83/// # Example
84///
85/// ```rust
86/// use mssql_auth::AzureAdAuth;
87///
88/// // Using a pre-acquired token
89/// let auth = AzureAdAuth::with_token("eyJ0eXAi...");
90/// ```
91#[derive(Clone)]
92pub struct AzureAdAuth {
93    /// The access token.
94    token: Cow<'static, str>,
95    /// When the token expires (if known).
96    expires_at: Option<Instant>,
97    /// The library type to report to the server.
98    library: FedAuthLibrary,
99}
100
101impl AzureAdAuth {
102    /// Create an Azure AD authenticator with a pre-acquired token.
103    ///
104    /// This is the simplest form - provide a valid access token obtained
105    /// from Azure AD / Entra ID via your preferred method.
106    ///
107    /// # Arguments
108    ///
109    /// * `token` - A valid JWT access token for Azure SQL Database
110    pub fn with_token(token: impl Into<Cow<'static, str>>) -> Self {
111        Self {
112            token: token.into(),
113            expires_at: None,
114            library: FedAuthLibrary::SecurityToken,
115        }
116    }
117
118    /// Create an Azure AD authenticator with a token and expiration.
119    ///
120    /// Providing the expiration time allows the driver to proactively
121    /// refresh tokens before they expire.
122    ///
123    /// # Arguments
124    ///
125    /// * `token` - A valid JWT access token
126    /// * `expires_in` - Duration until the token expires
127    pub fn with_token_expiring(token: impl Into<Cow<'static, str>>, expires_in: Duration) -> Self {
128        Self {
129            token: token.into(),
130            expires_at: Some(Instant::now() + expires_in),
131            library: FedAuthLibrary::SecurityToken,
132        }
133    }
134
135    /// Create from existing credentials.
136    ///
137    /// Returns an error if the credentials are not Azure AD credentials.
138    pub fn from_credentials(credentials: &Credentials) -> Result<Self, AuthError> {
139        match credentials {
140            Credentials::AzureAccessToken { token } => Ok(Self::with_token(token.to_string())),
141            _ => Err(AuthError::UnsupportedMethod(
142                "AzureAdAuth requires Azure AD credentials".into(),
143            )),
144        }
145    }
146
147    /// Set the library type to report to the server.
148    #[must_use]
149    pub fn with_library(mut self, library: FedAuthLibrary) -> Self {
150        self.library = library;
151        self
152    }
153
154    /// Check if the token is expired.
155    #[must_use]
156    pub fn is_expired(&self) -> bool {
157        self.expires_at
158            .map(|exp| Instant::now() >= exp)
159            .unwrap_or(false)
160    }
161
162    /// Check if the token is expiring soon (within the given duration).
163    #[must_use]
164    pub fn is_expiring_soon(&self, within: Duration) -> bool {
165        self.expires_at
166            .map(|exp| Instant::now() + within >= exp)
167            .unwrap_or(false)
168    }
169
170    /// Build the FEDAUTH feature extension data for Login7.
171    ///
172    /// Format:
173    /// - 1 byte: Library type (ADAL=1, SecurityToken=2, MSAL=3)
174    /// - 1 byte: Workflow (0x00 for pre-acquired token)
175    /// - 4 bytes: FedAuth token length (big-endian)
176    /// - N bytes: FedAuth token (UTF-16LE encoded)
177    #[must_use]
178    pub fn build_feature_data(&self) -> Bytes {
179        let mut data = Vec::with_capacity(6);
180
181        // Library type (1 byte)
182        data.push(self.library.to_byte());
183
184        // Workflow - 0x00 for non-interactive/pre-acquired token
185        data.push(0x00);
186
187        // For FEDAUTH, the actual token is sent in a separate FEDAUTH token packet,
188        // not in the Login7 feature extension. The feature extension just indicates
189        // that we want to use FEDAUTH.
190
191        Bytes::from(data)
192    }
193
194    /// Build the FEDAUTH token packet data.
195    ///
196    /// This is the token data sent in response to FEDAUTHINFO from the server.
197    #[must_use]
198    pub fn build_token_data(&self) -> Bytes {
199        // Token is sent as UTF-16LE
200        let token_utf16: Vec<u8> = self
201            .token
202            .encode_utf16()
203            .flat_map(|c| c.to_le_bytes())
204            .collect();
205
206        let mut data = Vec::with_capacity(4 + token_utf16.len());
207
208        // Token length (4 bytes, little-endian)
209        data.extend_from_slice(&(token_utf16.len() as u32).to_le_bytes());
210
211        // Token data (UTF-16LE)
212        data.extend_from_slice(&token_utf16);
213
214        Bytes::from(data)
215    }
216}
217
218impl AuthProvider for AzureAdAuth {
219    fn method(&self) -> AuthMethod {
220        AuthMethod::AzureAd
221    }
222
223    fn authenticate(&self) -> Result<AuthData, AuthError> {
224        if self.is_expired() {
225            return Err(AuthError::TokenExpired);
226        }
227
228        tracing::debug!("authenticating with Azure AD token");
229
230        Ok(AuthData::FedAuth {
231            token: self.token.to_string(),
232            nonce: None,
233        })
234    }
235
236    fn feature_extension_data(&self) -> Option<Bytes> {
237        Some(self.build_feature_data())
238    }
239
240    fn needs_refresh(&self) -> bool {
241        // Refresh if token expires within 5 minutes
242        self.is_expiring_soon(Duration::from_secs(300))
243    }
244}
245
246impl std::fmt::Debug for AzureAdAuth {
247    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248        f.debug_struct("AzureAdAuth")
249            .field("token", &"[REDACTED]")
250            .field("expires_at", &self.expires_at)
251            .field("library", &self.library)
252            .finish()
253    }
254}
255
/// Unit tests for the Azure AD authentication provider.
#[cfg(test)]
// Tests may unwrap and panic freely; a failure is the desired outcome there.
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// A token without an expiration is never stale and never needs a refresh.
    #[test]
    fn test_azure_ad_with_token() {
        let auth = AzureAdAuth::with_token("test_token");
        assert_eq!(auth.method(), AuthMethod::AzureAd);
        assert!(!auth.is_expired());
        assert!(!auth.is_expiring_soon(Duration::from_secs(3600)));
        assert!(!auth.needs_refresh());
    }

    /// A token valid for an hour is fresh, but counts as "expiring soon"
    /// once the window is wider than its remaining lifetime.
    #[test]
    fn test_azure_ad_with_expiring_token() {
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(3600));
        assert!(!auth.is_expired());
        assert!(!auth.is_expiring_soon(Duration::from_secs(60)));
        assert!(auth.is_expiring_soon(Duration::from_secs(7200)));
        assert!(!auth.needs_refresh());
    }

    /// A token inside the five-minute refresh window needs a refresh even
    /// though it has not expired yet.
    #[test]
    fn test_azure_ad_needs_refresh_within_window() {
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(60));
        assert!(!auth.is_expired());
        assert!(auth.needs_refresh());
    }

    /// A zero lifetime means the token is expired right away.
    #[test]
    fn test_azure_ad_expired_token() {
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(0));
        // No sleep needed: `is_expired` compares with `>=` and `Instant` is
        // monotonic, so any later reading is at or past the expiry instant.
        assert!(auth.is_expired());
        assert!(auth.is_expiring_soon(Duration::from_secs(0)));

        let result = auth.authenticate();
        assert!(matches!(result, Err(AuthError::TokenExpired)));
    }

    /// The feature data is exactly the library byte followed by the workflow byte.
    #[test]
    fn test_azure_ad_feature_data() {
        let auth = AzureAdAuth::with_token("test_token");
        let data = auth.build_feature_data();

        assert_eq!(data.len(), 2);
        assert_eq!(data[0], FedAuthLibrary::SecurityToken.to_byte());
        assert_eq!(data[1], 0x00);
        assert_eq!(auth.feature_extension_data(), Some(data));
    }

    /// `with_library` changes the library byte reported in the feature data.
    #[test]
    fn test_azure_ad_with_library() {
        let auth = AzureAdAuth::with_token("test_token").with_library(FedAuthLibrary::Msal);
        let data = auth.build_feature_data();
        assert_eq!(data[0], FedAuthLibrary::Msal.to_byte());
    }

    /// The token packet is a little-endian byte length followed by UTF-16LE text.
    #[test]
    fn test_azure_ad_token_data() {
        let auth = AzureAdAuth::with_token("AB");
        let data = auth.build_token_data();

        // Length (4 bytes) + "AB" in UTF-16LE (4 bytes)
        assert_eq!(data.len(), 8);
        // Length is 4 (2 UTF-16 code units * 2 bytes each)
        assert_eq!(&data[0..4], &[4, 0, 0, 0]);
        // Payload: 'A' and 'B', each as a little-endian 16-bit code unit
        assert_eq!(&data[4..], &[b'A', 0, b'B', 0]);
    }

    /// Azure access-token credentials convert into a working authenticator.
    #[test]
    fn test_from_credentials() {
        let creds = Credentials::azure_token("my_token");
        let auth = AzureAdAuth::from_credentials(&creds).unwrap();

        let data = auth.authenticate().unwrap();
        match &data {
            AuthData::FedAuth { token, .. } => {
                assert_eq!(token, "my_token");
            }
            _ => panic!("Expected FedAuth data"),
        }
    }

    /// Non-Azure credentials are rejected with `UnsupportedMethod`.
    #[test]
    fn test_from_credentials_wrong_type() {
        let creds = Credentials::sql_server("user", "pass");
        let result = AzureAdAuth::from_credentials(&creds);
        assert!(matches!(result, Err(AuthError::UnsupportedMethod(_))));
    }

    /// The `Debug` output must never contain the token itself.
    #[test]
    fn test_debug_redacts_token() {
        let auth = AzureAdAuth::with_token("secret_token");
        let debug = format!("{auth:?}");
        assert!(!debug.contains("secret_token"));
        assert!(debug.contains("[REDACTED]"));
    }
}
334}