mssql_auth/
azure_ad.rs

//! Azure AD / Entra ID authentication implementation.
//!
//! This module provides Azure AD federated authentication for SQL Server,
//! supporting both pre-acquired tokens and (with feature flags) token acquisition.
//!
//! ## Authentication Flow
//!
//! Azure AD authentication uses the TDS FEDAUTH feature extension:
//!
//! 1. Client includes FEDAUTH feature in Login7 packet
//! 2. Server responds with FEDAUTHINFO containing STS URL and SPN
//! 3. Client acquires token (or uses pre-acquired token)
//! 4. Client sends FEDAUTH token packet
//! 5. Server validates token and completes authentication
//!
//! ## Token Sources (Tier 1 - Core) ✅ Implemented
//!
//! - Pre-acquired access token (user provides token directly)
//!
//! ## Token Sources (Tier 2 - azure-identity feature) ✅ Implemented
//!
//! These require the `azure-identity` feature flag:
//!
//! - `ManagedIdentityAuth` - Azure VM/Container identity
//! - `ServicePrincipalAuth` - Client ID + Secret
//!
//! ## Token Sources (Tier 3 - cert-auth feature) ✅ Implemented
//!
//! - `CertificateAuth` - X.509 client certificate

31use std::borrow::Cow;
32use std::time::{Duration, Instant};
33
34use bytes::Bytes;
35
36use crate::credentials::Credentials;
37use crate::error::AuthError;
38use crate::provider::{AuthData, AuthMethod, AuthProvider};
39
/// Token library the client advertises to the server in the FEDAUTH
/// Login7 feature extension.
///
/// Each variant's discriminant is the raw byte that is written on the wire
/// (see [`FedAuthLibrary::to_byte`]).
///
/// NOTE(review): MS-TDS appears to define `bFedAuthLibrary` as
/// SECURITYTOKEN = 0x01 and ADAL/MSAL = 0x02, packed into bits 1-7 of the
/// FEDAUTH options byte. Confirm these discriminants (and whether the Login7
/// writer shifts them) against the spec before relying on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum FedAuthLibrary {
    /// Legacy Azure Active Directory Authentication Library (ADAL).
    Adal = 0x01,
    /// A raw security token (JWT) supplied directly by the caller.
    SecurityToken = 0x02,
    /// Microsoft Authentication Library (MSAL), the current library.
    Msal = 0x03,
}

impl FedAuthLibrary {
    /// Returns the wire byte for this library, which is simply its
    /// `#[repr(u8)]` discriminant.
    #[must_use]
    pub fn to_byte(self) -> u8 {
        self as u8
    }
}
61
/// Federated-authentication workflows a client may use to obtain a token.
///
/// NOTE(review): this enum is not referenced anywhere else in this file;
/// presumably it is consumed by other providers in the crate — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FedAuthWorkflow {
    /// The user signs in interactively.
    Interactive,
    /// The caller supplies a token that was acquired ahead of time.
    NonInteractive,
    /// A system- or user-assigned managed identity.
    ManagedIdentity,
    /// A service principal authenticating with a client secret.
    ServicePrincipal,
}
74
75/// Azure AD authentication provider.
76///
77/// This provider supports Azure AD / Entra ID federated authentication
78/// using pre-acquired access tokens (Tier 1) or token acquisition
79/// with the `azure-identity` feature (Tier 2).
80///
81/// # Example
82///
83/// ```rust
84/// use mssql_auth::AzureAdAuth;
85///
86/// // Using a pre-acquired token
87/// let auth = AzureAdAuth::with_token("eyJ0eXAi...");
88/// ```
89#[derive(Clone)]
90pub struct AzureAdAuth {
91    /// The access token.
92    token: Cow<'static, str>,
93    /// When the token expires (if known).
94    expires_at: Option<Instant>,
95    /// The library type to report to the server.
96    library: FedAuthLibrary,
97}
98
99impl AzureAdAuth {
100    /// Create an Azure AD authenticator with a pre-acquired token.
101    ///
102    /// This is the simplest form - provide a valid access token obtained
103    /// from Azure AD / Entra ID via your preferred method.
104    ///
105    /// # Arguments
106    ///
107    /// * `token` - A valid JWT access token for Azure SQL Database
108    pub fn with_token(token: impl Into<Cow<'static, str>>) -> Self {
109        Self {
110            token: token.into(),
111            expires_at: None,
112            library: FedAuthLibrary::SecurityToken,
113        }
114    }
115
116    /// Create an Azure AD authenticator with a token and expiration.
117    ///
118    /// Providing the expiration time allows the driver to proactively
119    /// refresh tokens before they expire.
120    ///
121    /// # Arguments
122    ///
123    /// * `token` - A valid JWT access token
124    /// * `expires_in` - Duration until the token expires
125    pub fn with_token_expiring(token: impl Into<Cow<'static, str>>, expires_in: Duration) -> Self {
126        Self {
127            token: token.into(),
128            expires_at: Some(Instant::now() + expires_in),
129            library: FedAuthLibrary::SecurityToken,
130        }
131    }
132
133    /// Create from existing credentials.
134    ///
135    /// Returns an error if the credentials are not Azure AD credentials.
136    pub fn from_credentials(credentials: &Credentials) -> Result<Self, AuthError> {
137        match credentials {
138            Credentials::AzureAccessToken { token } => Ok(Self::with_token(token.to_string())),
139            _ => Err(AuthError::UnsupportedMethod(
140                "AzureAdAuth requires Azure AD credentials".into(),
141            )),
142        }
143    }
144
145    /// Set the library type to report to the server.
146    #[must_use]
147    pub fn with_library(mut self, library: FedAuthLibrary) -> Self {
148        self.library = library;
149        self
150    }
151
152    /// Check if the token is expired.
153    #[must_use]
154    pub fn is_expired(&self) -> bool {
155        self.expires_at
156            .map(|exp| Instant::now() >= exp)
157            .unwrap_or(false)
158    }
159
160    /// Check if the token is expiring soon (within the given duration).
161    #[must_use]
162    pub fn is_expiring_soon(&self, within: Duration) -> bool {
163        self.expires_at
164            .map(|exp| Instant::now() + within >= exp)
165            .unwrap_or(false)
166    }
167
168    /// Build the FEDAUTH feature extension data for Login7.
169    ///
170    /// Format:
171    /// - 1 byte: Library type (ADAL=1, SecurityToken=2, MSAL=3)
172    /// - 1 byte: Workflow (0x00 for pre-acquired token)
173    /// - 4 bytes: FedAuth token length (big-endian)
174    /// - N bytes: FedAuth token (UTF-16LE encoded)
175    #[must_use]
176    pub fn build_feature_data(&self) -> Bytes {
177        let mut data = Vec::with_capacity(6);
178
179        // Library type (1 byte)
180        data.push(self.library.to_byte());
181
182        // Workflow - 0x00 for non-interactive/pre-acquired token
183        data.push(0x00);
184
185        // For FEDAUTH, the actual token is sent in a separate FEDAUTH token packet,
186        // not in the Login7 feature extension. The feature extension just indicates
187        // that we want to use FEDAUTH.
188
189        Bytes::from(data)
190    }
191
192    /// Build the FEDAUTH token packet data.
193    ///
194    /// This is the token data sent in response to FEDAUTHINFO from the server.
195    #[must_use]
196    pub fn build_token_data(&self) -> Bytes {
197        // Token is sent as UTF-16LE
198        let token_utf16: Vec<u8> = self
199            .token
200            .encode_utf16()
201            .flat_map(|c| c.to_le_bytes())
202            .collect();
203
204        let mut data = Vec::with_capacity(4 + token_utf16.len());
205
206        // Token length (4 bytes, little-endian)
207        data.extend_from_slice(&(token_utf16.len() as u32).to_le_bytes());
208
209        // Token data (UTF-16LE)
210        data.extend_from_slice(&token_utf16);
211
212        Bytes::from(data)
213    }
214}
215
216impl AuthProvider for AzureAdAuth {
217    fn method(&self) -> AuthMethod {
218        AuthMethod::AzureAd
219    }
220
221    fn authenticate(&self) -> Result<AuthData, AuthError> {
222        if self.is_expired() {
223            return Err(AuthError::TokenExpired);
224        }
225
226        tracing::debug!("authenticating with Azure AD token");
227
228        Ok(AuthData::FedAuth {
229            token: self.token.to_string(),
230            nonce: None,
231        })
232    }
233
234    fn feature_extension_data(&self) -> Option<Bytes> {
235        Some(self.build_feature_data())
236    }
237
238    fn needs_refresh(&self) -> bool {
239        // Refresh if token expires within 5 minutes
240        self.is_expiring_soon(Duration::from_secs(300))
241    }
242}
243
244impl std::fmt::Debug for AzureAdAuth {
245    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246        f.debug_struct("AzureAdAuth")
247            .field("token", &"[REDACTED]")
248            .field("expires_at", &self.expires_at)
249            .field("library", &self.library)
250            .finish()
251    }
252}
253
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)] // unwrap/panic are the idiomatic failure mode in tests
mod tests {
    use super::*;

    #[test]
    fn test_azure_ad_with_token() {
        let auth = AzureAdAuth::with_token("test_token");
        assert_eq!(auth.method(), AuthMethod::AzureAd);
        assert!(!auth.is_expired());
        // No expiry is recorded, so the token is never reported as expiring.
        assert!(!auth.is_expiring_soon(Duration::from_secs(3600)));
        assert!(!auth.needs_refresh());
    }

    #[test]
    fn test_azure_ad_with_expiring_token() {
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(3600));
        assert!(!auth.is_expired());
        assert!(!auth.is_expiring_soon(Duration::from_secs(60)));
        assert!(auth.is_expiring_soon(Duration::from_secs(7200)));
        assert!(!auth.needs_refresh());
    }

    #[test]
    fn test_azure_ad_needs_refresh_inside_window() {
        // 60 s remaining is inside the 5-minute refresh window.
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(60));
        assert!(!auth.is_expired());
        assert!(auth.needs_refresh());
    }

    #[test]
    fn test_azure_ad_expired_token() {
        // `Instant` is monotonic, so a zero lifetime is already expired when
        // it is checked. No sleep is needed.
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(0));
        assert!(auth.is_expired());
        assert!(auth.needs_refresh());

        let result = auth.authenticate();
        assert!(matches!(result, Err(AuthError::TokenExpired)));
    }

    #[test]
    fn test_azure_ad_feature_data() {
        let auth = AzureAdAuth::with_token("test_token");
        let data = auth.build_feature_data();
        assert_eq!(&data[..], &[FedAuthLibrary::SecurityToken.to_byte(), 0x00]);
        assert!(auth.feature_extension_data().is_some());

        let msal = AzureAdAuth::with_token("t").with_library(FedAuthLibrary::Msal);
        assert_eq!(msal.build_feature_data()[0], FedAuthLibrary::Msal.to_byte());
    }

    #[test]
    fn test_azure_ad_token_data() {
        let auth = AzureAdAuth::with_token("AB");
        let data = auth.build_token_data();

        // Length prefix (4 bytes, LE) = 4, followed by "AB" in UTF-16LE.
        assert_eq!(&data[..], &[4, 0, 0, 0, b'A', 0, b'B', 0]);
    }

    #[test]
    fn test_azure_ad_token_data_non_ascii() {
        // U+00E9 is one UTF-16 unit; U+1F600 is a surrogate pair (two units).
        let auth = AzureAdAuth::with_token("\u{e9}\u{1f600}");
        let data = auth.build_token_data();

        assert_eq!(&data[0..4], &[6, 0, 0, 0]);
        assert_eq!(data.len(), 4 + 6);
        assert_eq!(&data[4..6], &[0xE9, 0x00]);
    }

    #[test]
    fn test_from_credentials() {
        let creds = Credentials::azure_token("my_token");
        let auth = AzureAdAuth::from_credentials(&creds).unwrap();

        let data = auth.authenticate().unwrap();
        match &data {
            AuthData::FedAuth { token, .. } => {
                assert_eq!(token, "my_token");
            }
            _ => panic!("Expected FedAuth data"),
        }
    }

    #[test]
    fn test_from_credentials_wrong_type() {
        let creds = Credentials::sql_server("user", "pass");
        let result = AzureAdAuth::from_credentials(&creds);
        assert!(result.is_err());
    }

    #[test]
    fn test_debug_redacts_token() {
        let auth = AzureAdAuth::with_token("secret_token");
        let debug = format!("{auth:?}");
        assert!(!debug.contains("secret_token"));
        assert!(debug.contains("[REDACTED]"));
    }
}
332}