Skip to main content

mssql_auth/
azure_ad.rs

1//! Azure AD / Entra ID authentication implementation.
2//!
3//! This module provides token handling for Azure AD federated authentication,
4//! supporting both pre-acquired tokens and (with feature flags) token acquisition.
5//!
6//! ## Authentication Flow
7//!
8//! Azure AD authentication uses the TDS FEDAUTH feature extension
9//! (MS-TDS §2.2.6.4). Two workflows exist:
10//!
11//! - **SecurityToken** (implemented, #155 Phase 1): the client acquires a
12//!   token *before* login and sends it inside the LOGIN7 FEDAUTH feature
13//!   extension ([`build_security_token_feature_data`]). No FEDAUTHINFO
14//!   round-trip occurs.
15//! - **ADAL/MSAL** (pending, #155 Phase 2): the client declares intent in
16//!   LOGIN7, the server responds with FEDAUTHINFO (STS URL + SPN), and the
17//!   client acquires a token and sends it in a separate FEDAUTH message.
18//!
19//! ## Token Sources (Tier 1 - Core)
20//!
21//! - Pre-acquired access token (user provides token directly)
22//!
23//! ## Token Sources (Tier 2 - azure-identity feature)
24//!
25//! These require the `azure-identity` feature flag:
26//!
27//! - `ManagedIdentityAuth` - Azure VM/Container identity
28//! - `ServicePrincipalAuth` - Client ID + Secret
29//!
30//! ## Token Sources (Tier 3 - cert-auth feature)
31//!
32//! - `CertificateAuth` - X.509 client certificate
33
34use std::borrow::Cow;
35use std::time::{Duration, Instant};
36
37use bytes::{BufMut, Bytes, BytesMut};
38
39use crate::credentials::Credentials;
40use crate::error::AuthError;
41use crate::provider::{AuthData, AuthMethod, AuthProvider};
42
/// FEDAUTH library identifiers for the LOGIN7 FEDAUTH feature extension.
///
/// Per MS-TDS §2.2.6.4, `bFedAuthLibrary` is a 7-bit value that occupies the
/// high 7 bits of the feature data's Options byte (the low bit is
/// `fFedAuthEcho`). Live ID Compact Token (0x00) is legacy and not supported;
/// 0x7F is reserved. Neither has a variant here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum FedAuthLibrary {
    /// Security token: a token acquired by the client before login and sent
    /// inside LOGIN7 (the workflow used for pre-acquired Azure AD tokens).
    SecurityToken = 0x01,
    /// ADAL: the client declares intent and acquires the token after the
    /// server's FEDAUTHINFO response. MSAL (ADAL's successor) uses this same
    /// wire identifier.
    Adal = 0x02,
}

impl FedAuthLibrary {
    /// Get the 7-bit library identifier (unshifted).
    ///
    /// In the Options byte this value is shifted left by one bit to make room
    /// for `fFedAuthEcho`: `options = (library << 1) | echo`.
    ///
    /// This is a `const fn`, so the identifier (and any Options byte derived
    /// from it) can be computed in constant contexts.
    #[must_use]
    pub const fn to_byte(self) -> u8 {
        // `#[repr(u8)]` with explicit discriminants makes this cast lossless.
        self as u8
    }
}
72
73/// Build the LOGIN7 FEDAUTH feature extension data for the SecurityToken
74/// workflow (MS-TDS §2.2.6.4, `bFedAuthLibrary` = 0x01).
75///
76/// Layout:
77///
78/// ```text
79/// Options       = 1 byte: (0x01 << 1) | fFedAuthEcho
80/// FedAuthToken  = DWORD (LE) byte length + token as UTF-16LE
81/// ```
82///
83/// No trailing nonce is emitted: per spec the nonce MUST be present if and
84/// only if the server's PRELOGIN response carried a NONCE option, and this
85/// driver does not send NONCEOPT in PRELOGIN.
86///
87/// `fed_auth_echo` MUST be set if and only if the server's PRELOGIN response
88/// contained FEDAUTHREQUIRED with value 0x01 — the server validates this echo
89/// to detect tampering.
90///
91/// The token must be non-empty (the spec forbids a zero-length FedAuthToken);
92/// callers are expected to validate this before login.
93#[must_use]
94pub fn build_security_token_feature_data(token: &str, fed_auth_echo: bool) -> Bytes {
95    debug_assert!(!token.is_empty(), "FedAuthToken length MUST NOT be 0");
96
97    let token_utf16: Vec<u8> = token.encode_utf16().flat_map(|c| c.to_le_bytes()).collect();
98
99    let mut data = BytesMut::with_capacity(1 + 4 + token_utf16.len());
100    data.put_u8((FedAuthLibrary::SecurityToken.to_byte() << 1) | u8::from(fed_auth_echo));
101    data.put_u32_le(token_utf16.len() as u32);
102    data.put_slice(&token_utf16);
103    data.freeze()
104}
105
/// FEDAUTH workflow types.
///
/// The variants distinguish how the credential used for federated
/// authentication is obtained.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum FedAuthWorkflow {
    /// A user signs in interactively.
    Interactive,
    /// No user interaction: the token was acquired ahead of time.
    NonInteractive,
    /// A managed identity (system-assigned or user-assigned).
    ManagedIdentity,
    /// A service principal authenticating with a secret.
    ServicePrincipal,
}
119
/// Azure AD authentication provider.
///
/// Supports Azure AD / Entra ID federated authentication with either a
/// pre-acquired access token (Tier 1) or token acquisition through the
/// `azure-identity` feature (Tier 2).
///
/// The token is held as a `Cow<'static, str>`, so it can be either a
/// borrowed `'static` string or an owned `String`.
///
/// # Example
///
/// ```rust
/// use mssql_auth::AzureAdAuth;
///
/// // Using a pre-acquired token
/// let auth = AzureAdAuth::with_token("eyJ0eXAi...");
/// ```
#[derive(Clone)]
pub struct AzureAdAuth {
    /// Access token presented to the server.
    token: Cow<'static, str>,
    /// Instant at which the token stops being valid; `None` when unknown.
    expires_at: Option<Instant>,
}
141
142impl AzureAdAuth {
143    /// Create an Azure AD authenticator with a pre-acquired token.
144    ///
145    /// This is the simplest form - provide a valid access token obtained
146    /// from Azure AD / Entra ID via your preferred method.
147    ///
148    /// # Arguments
149    ///
150    /// * `token` - A valid JWT access token for Azure SQL Database
151    pub fn with_token(token: impl Into<Cow<'static, str>>) -> Self {
152        Self {
153            token: token.into(),
154            expires_at: None,
155        }
156    }
157
158    /// Create an Azure AD authenticator with a token and expiration.
159    ///
160    /// Providing the expiration time allows the driver to proactively
161    /// refresh tokens before they expire.
162    ///
163    /// # Arguments
164    ///
165    /// * `token` - A valid JWT access token
166    /// * `expires_in` - Duration until the token expires
167    pub fn with_token_expiring(token: impl Into<Cow<'static, str>>, expires_in: Duration) -> Self {
168        Self {
169            token: token.into(),
170            expires_at: Some(Instant::now() + expires_in),
171        }
172    }
173
174    /// Create from existing credentials.
175    ///
176    /// Returns an error if the credentials are not Azure AD credentials.
177    pub fn from_credentials(credentials: &Credentials) -> Result<Self, AuthError> {
178        match credentials {
179            Credentials::AzureAccessToken { token } => Ok(Self::with_token(token.to_string())),
180            _ => Err(AuthError::UnsupportedMethod(
181                "AzureAdAuth requires Azure AD credentials".into(),
182            )),
183        }
184    }
185
186    /// Check if the token is expired.
187    #[must_use]
188    pub fn is_expired(&self) -> bool {
189        self.expires_at
190            .map(|exp| Instant::now() >= exp)
191            .unwrap_or(false)
192    }
193
194    /// Check if the token is expiring soon (within the given duration).
195    #[must_use]
196    pub fn is_expiring_soon(&self, within: Duration) -> bool {
197        self.expires_at
198            .map(|exp| Instant::now() + within >= exp)
199            .unwrap_or(false)
200    }
201
202    /// Build the FEDAUTH feature extension data for Login7 (SecurityToken
203    /// workflow).
204    ///
205    /// See [`build_security_token_feature_data`] for the wire layout and the
206    /// `fed_auth_echo` contract.
207    #[must_use]
208    pub fn build_feature_data(&self, fed_auth_echo: bool) -> Bytes {
209        build_security_token_feature_data(&self.token, fed_auth_echo)
210    }
211
212    /// Build the FEDAUTH token packet data.
213    ///
214    /// This is the token data sent in response to FEDAUTHINFO from the server.
215    #[must_use]
216    pub fn build_token_data(&self) -> Bytes {
217        // Token is sent as UTF-16LE
218        let token_utf16: Vec<u8> = self
219            .token
220            .encode_utf16()
221            .flat_map(|c| c.to_le_bytes())
222            .collect();
223
224        let mut data = Vec::with_capacity(4 + token_utf16.len());
225
226        // Token length (4 bytes, little-endian)
227        data.extend_from_slice(&(token_utf16.len() as u32).to_le_bytes());
228
229        // Token data (UTF-16LE)
230        data.extend_from_slice(&token_utf16);
231
232        Bytes::from(data)
233    }
234}
235
236impl AuthProvider for AzureAdAuth {
237    fn method(&self) -> AuthMethod {
238        AuthMethod::AzureAd
239    }
240
241    fn authenticate(&self) -> Result<AuthData, AuthError> {
242        if self.is_expired() {
243            return Err(AuthError::TokenExpired);
244        }
245
246        tracing::debug!("authenticating with Azure AD token");
247
248        Ok(AuthData::FedAuth {
249            token: self.token.to_string(),
250            nonce: None,
251        })
252    }
253
254    // Note: `feature_extension_data` deliberately uses the trait default
255    // (`None`). The FEDAUTH feature data depends on the server's PRELOGIN
256    // FEDAUTHREQUIRED response (the echo bit), which is unknowable here;
257    // the login path builds it via `build_feature_data(echo)` instead.
258
259    fn needs_refresh(&self) -> bool {
260        // Refresh if token expires within 5 minutes
261        self.is_expiring_soon(Duration::from_secs(300))
262    }
263}
264
265impl std::fmt::Debug for AzureAdAuth {
266    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267        f.debug_struct("AzureAdAuth")
268            .field("token", &"[REDACTED]")
269            .field("expires_at", &self.expires_at)
270            .finish()
271    }
272}
273
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_azure_ad_with_token() {
        let auth = AzureAdAuth::with_token("test_token");
        assert_eq!(auth.method(), AuthMethod::AzureAd);
        assert!(!auth.is_expired());
    }

    #[test]
    fn test_azure_ad_with_expiring_token() {
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(3600));
        assert!(!auth.is_expired());
        assert!(!auth.is_expiring_soon(Duration::from_secs(60)));
    }

    #[test]
    fn test_azure_ad_expired_token() {
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(0));
        // The expiry instant is "now" at construction. `Instant` is monotonic,
        // so every later reading is >= it and the token is already expired;
        // no sleep is needed.
        assert!(auth.is_expired());
        assert!(auth.is_expiring_soon(Duration::from_secs(0)));

        let result = auth.authenticate();
        assert!(matches!(result, Err(AuthError::TokenExpired)));
    }

    /// `needs_refresh` fires within a five-minute window of a known expiry
    /// and never fires when the expiry is unknown.
    #[test]
    fn test_needs_refresh_window() {
        assert!(!AzureAdAuth::with_token("t").needs_refresh());
        assert!(AzureAdAuth::with_token_expiring("t", Duration::from_secs(60)).needs_refresh());
        assert!(!AzureAdAuth::with_token_expiring("t", Duration::from_secs(3600)).needs_refresh());
    }

    /// Wire-exact encoding of the SecurityToken FEDAUTH feature data per
    /// MS-TDS §2.2.6.4: Options byte = (0x01 << 1) | echo, then a
    /// little-endian DWORD byte length, then the token as UTF-16LE. No nonce.
    #[test]
    fn test_security_token_feature_data_wire_exact() {
        // "AB" -> UTF-16LE 41 00 42 00, length 4.
        let no_echo = build_security_token_feature_data("AB", false);
        assert_eq!(
            no_echo.as_ref(),
            &[0x02, 0x04, 0x00, 0x00, 0x00, 0x41, 0x00, 0x42, 0x00],
            "echo clear: options must be 0x02 (SecurityToken << 1)"
        );

        let echo = build_security_token_feature_data("AB", true);
        assert_eq!(
            echo.as_ref(),
            &[0x03, 0x04, 0x00, 0x00, 0x00, 0x41, 0x00, 0x42, 0x00],
            "echo set: fFedAuthEcho is the low bit of the options byte"
        );
    }

    /// Non-BMP characters must encode as UTF-16 surrogate pairs and the DWORD
    /// length must count bytes (not code units or chars).
    #[test]
    fn test_security_token_feature_data_surrogate_pair() {
        // U+1F600 -> surrogates D83D DE00 -> LE bytes 3D D8 00 DE.
        let data = build_security_token_feature_data("\u{1F600}", false);
        assert_eq!(
            data.as_ref(),
            &[0x02, 0x04, 0x00, 0x00, 0x00, 0x3D, 0xD8, 0x00, 0xDE]
        );
    }

    /// The method form delegates to the free function with the same token.
    #[test]
    fn test_azure_ad_feature_data() {
        let auth = AzureAdAuth::with_token("test_token");
        let data = auth.build_feature_data(true);

        assert_eq!(data, build_security_token_feature_data("test_token", true));
        // Library bits: options >> 1 must be the SecurityToken identifier.
        assert_eq!(data[0] >> 1, FedAuthLibrary::SecurityToken.to_byte());
        assert_eq!(data[0] & 1, 1);
    }

    /// The wire identifiers come from MS-TDS §2.2.6.4 and must never drift:
    /// SecurityToken = 0x01, ADAL (also used by MSAL) = 0x02.
    #[test]
    fn test_fed_auth_library_wire_values() {
        assert_eq!(FedAuthLibrary::SecurityToken.to_byte(), 0x01);
        assert_eq!(FedAuthLibrary::Adal.to_byte(), 0x02);
    }

    #[test]
    fn test_azure_ad_token_data() {
        let auth = AzureAdAuth::with_token("AB");
        let data = auth.build_token_data();

        // DWORD (LE) byte length 4 = 2 UTF-16 code units * 2 bytes each,
        // followed by "AB" as UTF-16LE.
        assert_eq!(
            data.as_ref(),
            &[0x04, 0x00, 0x00, 0x00, 0x41, 0x00, 0x42, 0x00]
        );
    }

    #[test]
    fn test_from_credentials() {
        let creds = Credentials::azure_token("my_token");
        let auth = AzureAdAuth::from_credentials(&creds).unwrap();

        let data = auth.authenticate().unwrap();
        match &data {
            AuthData::FedAuth { token, .. } => {
                assert_eq!(token, "my_token");
            }
            _ => panic!("Expected FedAuth data"),
        }
    }

    #[test]
    fn test_from_credentials_wrong_type() {
        let creds = Credentials::sql_server("user", "pass");
        let result = AzureAdAuth::from_credentials(&creds);
        assert!(result.is_err());
    }

    #[test]
    fn test_debug_redacts_token() {
        let auth = AzureAdAuth::with_token("secret_token");
        let debug = format!("{auth:?}");
        assert!(!debug.contains("secret_token"));
        assert!(debug.contains("[REDACTED]"));
    }
}