mssql_auth/azure_ad.rs
1//! Azure AD / Entra ID authentication implementation.
2//!
3//! This module provides token handling for Azure AD federated authentication,
4//! supporting both pre-acquired tokens and (with feature flags) token acquisition.
5//!
6//! ## Authentication Flow
7//!
8//! Azure AD authentication uses the TDS FEDAUTH feature extension
9//! (MS-TDS §2.2.6.4). Two workflows exist:
10//!
11//! - **SecurityToken** (implemented, #155 Phase 1): the client acquires a
12//! token *before* login and sends it inside the LOGIN7 FEDAUTH feature
13//! extension ([`build_security_token_feature_data`]). No FEDAUTHINFO
14//! round-trip occurs.
15//! - **ADAL/MSAL** (pending, #155 Phase 2): the client declares intent in
16//! LOGIN7, the server responds with FEDAUTHINFO (STS URL + SPN), and the
17//! client acquires a token and sends it in a separate FEDAUTH message.
18//!
19//! ## Token Sources (Tier 1 - Core)
20//!
21//! - Pre-acquired access token (user provides token directly)
22//!
23//! ## Token Sources (Tier 2 - azure-identity feature)
24//!
25//! These require the `azure-identity` feature flag:
26//!
27//! - `ManagedIdentityAuth` - Azure VM/Container identity
28//! - `ServicePrincipalAuth` - Client ID + Secret
29//!
30//! ## Token Sources (Tier 3 - cert-auth feature)
31//!
32//! - `CertificateAuth` - X.509 client certificate
33
34use std::borrow::Cow;
35use std::time::{Duration, Instant};
36
37use bytes::{BufMut, Bytes, BytesMut};
38
39use crate::credentials::Credentials;
40use crate::error::AuthError;
41use crate::provider::{AuthData, AuthMethod, AuthProvider};
42
/// Wire identifiers for `bFedAuthLibrary` in the LOGIN7 FEDAUTH feature
/// extension.
///
/// MS-TDS §2.2.6.4 defines `bFedAuthLibrary` as a 7-bit field stored in the
/// upper seven bits of the feature data's Options byte; bit 0 carries
/// `fFedAuthEcho`. The legacy Live ID Compact Token value (0x00) is not
/// supported and 0x7F is reserved, so neither has a variant here.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
#[repr(u8)]
pub enum FedAuthLibrary {
    /// SecurityToken workflow: the client obtains a token before login and
    /// embeds it directly in LOGIN7. Pre-acquired Azure AD tokens use this.
    SecurityToken = 0x01,
    /// ADAL workflow: the client signals intent in LOGIN7 and obtains the
    /// token only after the server's FEDAUTHINFO reply. MSAL, ADAL's
    /// successor, shares this wire identifier.
    Adal = 0x02,
}
61
62impl FedAuthLibrary {
63 /// Get the 7-bit library identifier (unshifted).
64 ///
65 /// In the Options byte this value is shifted left by one bit to make room
66 /// for `fFedAuthEcho`: `options = (library << 1) | echo`.
67 #[must_use]
68 pub fn to_byte(self) -> u8 {
69 self as u8
70 }
71}
72
73/// Build the LOGIN7 FEDAUTH feature extension data for the SecurityToken
74/// workflow (MS-TDS §2.2.6.4, `bFedAuthLibrary` = 0x01).
75///
76/// Layout:
77///
78/// ```text
79/// Options = 1 byte: (0x01 << 1) | fFedAuthEcho
80/// FedAuthToken = DWORD (LE) byte length + token as UTF-16LE
81/// ```
82///
83/// No trailing nonce is emitted: per spec the nonce MUST be present if and
84/// only if the server's PRELOGIN response carried a NONCE option, and this
85/// driver does not send NONCEOPT in PRELOGIN.
86///
87/// `fed_auth_echo` MUST be set if and only if the server's PRELOGIN response
88/// contained FEDAUTHREQUIRED with value 0x01 — the server validates this echo
89/// to detect tampering.
90///
91/// The token must be non-empty (the spec forbids a zero-length FedAuthToken);
92/// callers are expected to validate this before login.
93#[must_use]
94pub fn build_security_token_feature_data(token: &str, fed_auth_echo: bool) -> Bytes {
95 debug_assert!(!token.is_empty(), "FedAuthToken length MUST NOT be 0");
96
97 let token_utf16: Vec<u8> = token.encode_utf16().flat_map(|c| c.to_le_bytes()).collect();
98
99 let mut data = BytesMut::with_capacity(1 + 4 + token_utf16.len());
100 data.put_u8((FedAuthLibrary::SecurityToken.to_byte() << 1) | u8::from(fed_auth_echo));
101 data.put_u32_le(token_utf16.len() as u32);
102 data.put_slice(&token_utf16);
103 data.freeze()
104}
105
/// Kinds of federated-authentication (FEDAUTH) workflow.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum FedAuthWorkflow {
    /// The user signs in interactively.
    Interactive,
    /// A token acquired ahead of time is supplied directly; no sign-in step.
    NonInteractive,
    /// A managed identity, either system-assigned or user-assigned.
    ManagedIdentity,
    /// A service principal authenticating with its client secret.
    ServicePrincipal,
}
119
/// Azure AD / Entra ID authentication provider backed by a pre-acquired
/// access token.
///
/// This type implements the Tier 1 token source: the caller obtains the
/// token and hands it over, optionally together with its remaining lifetime
/// so that expiry can be tracked. It does not acquire tokens itself; token
/// acquisition (e.g. `ManagedIdentityAuth`, `ServicePrincipalAuth`) lives
/// behind the `azure-identity` feature (Tier 2).
///
/// # Example
///
/// ```rust
/// use std::time::Duration;
/// use mssql_auth::AzureAdAuth;
///
/// // Using a pre-acquired token
/// let auth = AzureAdAuth::with_token("eyJ0eXAi...");
///
/// // Using a pre-acquired token whose lifetime is known
/// let expiring = AzureAdAuth::with_token_expiring("eyJ0eXAi...", Duration::from_secs(3600));
/// ```
#[derive(Clone)]
pub struct AzureAdAuth {
    /// The access token.
    token: Cow<'static, str>,
    /// When the token expires; `None` when the expiry is unknown.
    expires_at: Option<Instant>,
}
141
142impl AzureAdAuth {
143 /// Create an Azure AD authenticator with a pre-acquired token.
144 ///
145 /// This is the simplest form - provide a valid access token obtained
146 /// from Azure AD / Entra ID via your preferred method.
147 ///
148 /// # Arguments
149 ///
150 /// * `token` - A valid JWT access token for Azure SQL Database
151 pub fn with_token(token: impl Into<Cow<'static, str>>) -> Self {
152 Self {
153 token: token.into(),
154 expires_at: None,
155 }
156 }
157
158 /// Create an Azure AD authenticator with a token and expiration.
159 ///
160 /// Providing the expiration time allows the driver to proactively
161 /// refresh tokens before they expire.
162 ///
163 /// # Arguments
164 ///
165 /// * `token` - A valid JWT access token
166 /// * `expires_in` - Duration until the token expires
167 pub fn with_token_expiring(token: impl Into<Cow<'static, str>>, expires_in: Duration) -> Self {
168 Self {
169 token: token.into(),
170 expires_at: Some(Instant::now() + expires_in),
171 }
172 }
173
174 /// Create from existing credentials.
175 ///
176 /// Returns an error if the credentials are not Azure AD credentials.
177 pub fn from_credentials(credentials: &Credentials) -> Result<Self, AuthError> {
178 match credentials {
179 Credentials::AzureAccessToken { token } => Ok(Self::with_token(token.to_string())),
180 _ => Err(AuthError::UnsupportedMethod(
181 "AzureAdAuth requires Azure AD credentials".into(),
182 )),
183 }
184 }
185
186 /// Check if the token is expired.
187 #[must_use]
188 pub fn is_expired(&self) -> bool {
189 self.expires_at
190 .map(|exp| Instant::now() >= exp)
191 .unwrap_or(false)
192 }
193
194 /// Check if the token is expiring soon (within the given duration).
195 #[must_use]
196 pub fn is_expiring_soon(&self, within: Duration) -> bool {
197 self.expires_at
198 .map(|exp| Instant::now() + within >= exp)
199 .unwrap_or(false)
200 }
201
202 /// Build the FEDAUTH feature extension data for Login7 (SecurityToken
203 /// workflow).
204 ///
205 /// See [`build_security_token_feature_data`] for the wire layout and the
206 /// `fed_auth_echo` contract.
207 #[must_use]
208 pub fn build_feature_data(&self, fed_auth_echo: bool) -> Bytes {
209 build_security_token_feature_data(&self.token, fed_auth_echo)
210 }
211
212 /// Build the FEDAUTH token packet data.
213 ///
214 /// This is the token data sent in response to FEDAUTHINFO from the server.
215 #[must_use]
216 pub fn build_token_data(&self) -> Bytes {
217 // Token is sent as UTF-16LE
218 let token_utf16: Vec<u8> = self
219 .token
220 .encode_utf16()
221 .flat_map(|c| c.to_le_bytes())
222 .collect();
223
224 let mut data = Vec::with_capacity(4 + token_utf16.len());
225
226 // Token length (4 bytes, little-endian)
227 data.extend_from_slice(&(token_utf16.len() as u32).to_le_bytes());
228
229 // Token data (UTF-16LE)
230 data.extend_from_slice(&token_utf16);
231
232 Bytes::from(data)
233 }
234}
235
236impl AuthProvider for AzureAdAuth {
237 fn method(&self) -> AuthMethod {
238 AuthMethod::AzureAd
239 }
240
241 fn authenticate(&self) -> Result<AuthData, AuthError> {
242 if self.is_expired() {
243 return Err(AuthError::TokenExpired);
244 }
245
246 tracing::debug!("authenticating with Azure AD token");
247
248 Ok(AuthData::FedAuth {
249 token: self.token.to_string(),
250 nonce: None,
251 })
252 }
253
254 // Note: `feature_extension_data` deliberately uses the trait default
255 // (`None`). The FEDAUTH feature data depends on the server's PRELOGIN
256 // FEDAUTHREQUIRED response (the echo bit), which is unknowable here;
257 // the login path builds it via `build_feature_data(echo)` instead.
258
259 fn needs_refresh(&self) -> bool {
260 // Refresh if token expires within 5 minutes
261 self.is_expiring_soon(Duration::from_secs(300))
262 }
263}
264
265impl std::fmt::Debug for AzureAdAuth {
266 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267 f.debug_struct("AzureAdAuth")
268 .field("token", &"[REDACTED]")
269 .field("expires_at", &self.expires_at)
270 .finish()
271 }
272}
273
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_azure_ad_with_token() {
        let auth = AzureAdAuth::with_token("test_token");
        assert_eq!(auth.method(), AuthMethod::AzureAd);
        assert!(!auth.is_expired());
    }

    #[test]
    fn test_azure_ad_with_expiring_token() {
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(3600));
        assert!(!auth.is_expired());
        assert!(!auth.is_expiring_soon(Duration::from_secs(60)));
    }

    #[test]
    fn test_azure_ad_expired_token() {
        // A zero lifetime stores `now + 0`; since `Instant` never goes
        // backwards, the token is already expired and no sleep is needed.
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(0));
        assert!(auth.is_expired());

        let result = auth.authenticate();
        assert!(matches!(result, Err(AuthError::TokenExpired)));
    }

    /// A window wider than the remaining lifetime reports "expiring soon";
    /// a narrower one does not.
    #[test]
    fn test_is_expiring_soon_window() {
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(3600));
        assert!(auth.is_expiring_soon(Duration::from_secs(7200)));
        assert!(!auth.is_expiring_soon(Duration::from_secs(60)));
    }

    /// Tokens without a known expiry never expire and never need refresh.
    #[test]
    fn test_token_without_expiry_never_needs_refresh() {
        let auth = AzureAdAuth::with_token("test_token");
        assert!(!auth.is_expiring_soon(Duration::from_secs(86_400)));
        assert!(!auth.needs_refresh());
    }

    /// `needs_refresh` triggers inside the five-minute window only.
    #[test]
    fn test_needs_refresh_window() {
        let soon = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(60));
        assert!(soon.needs_refresh());

        let later = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(3600));
        assert!(!later.needs_refresh());
    }

    /// Wire-exact encoding of the SecurityToken FEDAUTH feature data per
    /// MS-TDS §2.2.6.4: Options byte = (0x01 << 1) | echo, then a
    /// little-endian DWORD byte length, then the token as UTF-16LE. No nonce.
    #[test]
    fn test_security_token_feature_data_wire_exact() {
        // "AB" -> UTF-16LE 41 00 42 00, length 4.
        let no_echo = build_security_token_feature_data("AB", false);
        assert_eq!(
            no_echo.as_ref(),
            &[0x02, 0x04, 0x00, 0x00, 0x00, 0x41, 0x00, 0x42, 0x00],
            "echo clear: options must be 0x02 (SecurityToken << 1)"
        );

        let echo = build_security_token_feature_data("AB", true);
        assert_eq!(
            echo.as_ref(),
            &[0x03, 0x04, 0x00, 0x00, 0x00, 0x41, 0x00, 0x42, 0x00],
            "echo set: fFedAuthEcho is the low bit of the options byte"
        );
    }

    /// Non-BMP characters must encode as UTF-16 surrogate pairs and the DWORD
    /// length must count bytes (not code units or chars).
    #[test]
    fn test_security_token_feature_data_surrogate_pair() {
        // U+1F600 -> surrogates D83D DE00 -> LE bytes 3D D8 00 DE.
        let data = build_security_token_feature_data("\u{1F600}", false);
        assert_eq!(
            data.as_ref(),
            &[0x02, 0x04, 0x00, 0x00, 0x00, 0x3D, 0xD8, 0x00, 0xDE]
        );
    }

    /// The method form delegates to the free function with the same token.
    #[test]
    fn test_azure_ad_feature_data() {
        let auth = AzureAdAuth::with_token("test_token");
        let data = auth.build_feature_data(true);

        assert_eq!(data, build_security_token_feature_data("test_token", true));
        // Library bits: options >> 1 must be the SecurityToken identifier.
        assert_eq!(data[0] >> 1, FedAuthLibrary::SecurityToken.to_byte());
        assert_eq!(data[0] & 1, 1);
    }

    /// The wire identifiers come from MS-TDS §2.2.6.4 and must never drift:
    /// SecurityToken = 0x01, ADAL (also used by MSAL) = 0x02.
    #[test]
    fn test_fed_auth_library_wire_values() {
        assert_eq!(FedAuthLibrary::SecurityToken.to_byte(), 0x01);
        assert_eq!(FedAuthLibrary::Adal.to_byte(), 0x02);
    }

    /// Token packet data is a LE DWORD byte length followed by UTF-16LE.
    #[test]
    fn test_azure_ad_token_data() {
        let auth = AzureAdAuth::with_token("AB");
        let data = auth.build_token_data();

        // Length 4 (2 UTF-16 code units * 2 bytes each), then "AB" as UTF-16LE.
        assert_eq!(
            data.as_ref(),
            &[0x04, 0x00, 0x00, 0x00, 0x41, 0x00, 0x42, 0x00]
        );
    }

    #[test]
    fn test_from_credentials() {
        let creds = Credentials::azure_token("my_token");
        let auth = AzureAdAuth::from_credentials(&creds).unwrap();

        let data = auth.authenticate().unwrap();
        match &data {
            AuthData::FedAuth { token, .. } => {
                assert_eq!(token, "my_token");
            }
            _ => panic!("Expected FedAuth data"),
        }
    }

    #[test]
    fn test_from_credentials_wrong_type() {
        let creds = Credentials::sql_server("user", "pass");
        let result = AzureAdAuth::from_credentials(&creds);
        assert!(result.is_err());
    }

    #[test]
    fn test_debug_redacts_token() {
        let auth = AzureAdAuth::with_token("secret_token");
        let debug = format!("{auth:?}");
        assert!(!debug.contains("secret_token"));
        assert!(debug.contains("[REDACTED]"));
    }
}