1use std::borrow::Cow;
32use std::time::{Duration, Instant};
33
34use bytes::Bytes;
35
36use crate::credentials::Credentials;
37use crate::error::AuthError;
38use crate::provider::{AuthData, AuthMethod, AuthProvider};
39
/// Identifies the federated-authentication library advertised to the server.
///
/// Each variant has a fixed single-byte value (see [`FedAuthLibrary::to_byte`]).
///
/// NOTE(review): the MS-TDS FEDAUTH feature extension defines its own
/// `bFedAuthLibrary` numbering; confirm that these discriminants (and whether
/// the value is expected pre-shifted) match the protocol specification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum FedAuthLibrary {
    /// ADAL library identifier (`0x01`).
    Adal = 0x01,
    /// Security-token identifier (`0x02`); the default used by `AzureAdAuth`.
    SecurityToken = 0x02,
    /// MSAL library identifier (`0x03`).
    Msal = 0x03,
}

impl FedAuthLibrary {
    /// Returns the variant's `#[repr(u8)]` discriminant.
    ///
    /// This is a `const fn`, so the value can also be used in constant
    /// expressions.
    #[must_use]
    pub const fn to_byte(self) -> u8 {
        self as u8
    }
}
62
/// Kinds of federated-authentication workflow.
///
/// This is a plain marker enum: it carries no data and no explicit
/// discriminants. It is `#[non_exhaustive]`, so downstream matches need a
/// wildcard arm.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FedAuthWorkflow {
    /// Interactive workflow.
    Interactive,
    /// Non-interactive workflow.
    NonInteractive,
    /// Managed-identity workflow.
    ManagedIdentity,
    /// Service-principal workflow.
    ServicePrincipal,
}
76
77#[derive(Clone)]
92pub struct AzureAdAuth {
93 token: Cow<'static, str>,
95 expires_at: Option<Instant>,
97 library: FedAuthLibrary,
99}
100
101impl AzureAdAuth {
102 pub fn with_token(token: impl Into<Cow<'static, str>>) -> Self {
111 Self {
112 token: token.into(),
113 expires_at: None,
114 library: FedAuthLibrary::SecurityToken,
115 }
116 }
117
118 pub fn with_token_expiring(token: impl Into<Cow<'static, str>>, expires_in: Duration) -> Self {
128 Self {
129 token: token.into(),
130 expires_at: Some(Instant::now() + expires_in),
131 library: FedAuthLibrary::SecurityToken,
132 }
133 }
134
135 pub fn from_credentials(credentials: &Credentials) -> Result<Self, AuthError> {
139 match credentials {
140 Credentials::AzureAccessToken { token } => Ok(Self::with_token(token.to_string())),
141 _ => Err(AuthError::UnsupportedMethod(
142 "AzureAdAuth requires Azure AD credentials".into(),
143 )),
144 }
145 }
146
147 #[must_use]
149 pub fn with_library(mut self, library: FedAuthLibrary) -> Self {
150 self.library = library;
151 self
152 }
153
154 #[must_use]
156 pub fn is_expired(&self) -> bool {
157 self.expires_at
158 .map(|exp| Instant::now() >= exp)
159 .unwrap_or(false)
160 }
161
162 #[must_use]
164 pub fn is_expiring_soon(&self, within: Duration) -> bool {
165 self.expires_at
166 .map(|exp| Instant::now() + within >= exp)
167 .unwrap_or(false)
168 }
169
170 #[must_use]
178 pub fn build_feature_data(&self) -> Bytes {
179 let mut data = Vec::with_capacity(6);
180
181 data.push(self.library.to_byte());
183
184 data.push(0x00);
186
187 Bytes::from(data)
192 }
193
194 #[must_use]
198 pub fn build_token_data(&self) -> Bytes {
199 let token_utf16: Vec<u8> = self
201 .token
202 .encode_utf16()
203 .flat_map(|c| c.to_le_bytes())
204 .collect();
205
206 let mut data = Vec::with_capacity(4 + token_utf16.len());
207
208 data.extend_from_slice(&(token_utf16.len() as u32).to_le_bytes());
210
211 data.extend_from_slice(&token_utf16);
213
214 Bytes::from(data)
215 }
216}
217
218impl AuthProvider for AzureAdAuth {
219 fn method(&self) -> AuthMethod {
220 AuthMethod::AzureAd
221 }
222
223 fn authenticate(&self) -> Result<AuthData, AuthError> {
224 if self.is_expired() {
225 return Err(AuthError::TokenExpired);
226 }
227
228 tracing::debug!("authenticating with Azure AD token");
229
230 Ok(AuthData::FedAuth {
231 token: self.token.to_string(),
232 nonce: None,
233 })
234 }
235
236 fn feature_extension_data(&self) -> Option<Bytes> {
237 Some(self.build_feature_data())
238 }
239
240 fn needs_refresh(&self) -> bool {
241 self.is_expiring_soon(Duration::from_secs(300))
243 }
244}
245
246impl std::fmt::Debug for AzureAdAuth {
247 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248 f.debug_struct("AzureAdAuth")
249 .field("token", &"[REDACTED]")
250 .field("expires_at", &self.expires_at)
251 .field("library", &self.library)
252 .finish()
253 }
254}
255
#[cfg(test)]
// Tests may unwrap and panic freely: a panic is the failure signal here.
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_azure_ad_with_token() {
        let auth = AzureAdAuth::with_token("test_token");
        assert_eq!(auth.method(), AuthMethod::AzureAd);
        assert!(!auth.is_expired());
        // Without a recorded expiry the token never asks for a refresh.
        assert!(!auth.needs_refresh());
    }

    #[test]
    fn test_azure_ad_with_expiring_token() {
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(3600));
        assert!(!auth.is_expired());
        assert!(!auth.is_expiring_soon(Duration::from_secs(60)));
        // One hour left is well outside the five-minute refresh window.
        assert!(!auth.needs_refresh());
    }

    #[test]
    fn test_azure_ad_expiring_soon() {
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(60));
        assert!(!auth.is_expired());
        assert!(auth.is_expiring_soon(Duration::from_secs(300)));
        assert!(auth.needs_refresh());
    }

    #[test]
    fn test_azure_ad_expired_token() {
        // A zero lifetime puts the deadline at "now"; because expiry uses `>=`
        // on a monotonic clock, the token is already expired without sleeping.
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(0));
        assert!(auth.is_expired());
        assert!(auth.is_expiring_soon(Duration::from_secs(0)));

        let result = auth.authenticate();
        assert!(matches!(result, Err(AuthError::TokenExpired)));
    }

    #[test]
    fn test_azure_ad_feature_data() {
        let auth = AzureAdAuth::with_token("test_token");
        let data = auth.build_feature_data();

        assert!(!data.is_empty());
        assert_eq!(data[0], FedAuthLibrary::SecurityToken.to_byte());
        // Pin the complete payload: library byte followed by 0x00.
        assert_eq!(&data[..], &[0x02_u8, 0x00][..]);
    }

    #[test]
    fn test_azure_ad_feature_data_with_library() {
        let auth = AzureAdAuth::with_token("test_token").with_library(FedAuthLibrary::Msal);
        let data = auth.build_feature_data();

        assert_eq!(&data[..], &[FedAuthLibrary::Msal.to_byte(), 0x00][..]);
    }

    #[test]
    fn test_azure_ad_token_data() {
        let auth = AzureAdAuth::with_token("AB");
        let data = auth.build_token_data();

        assert_eq!(data.len(), 8);
        // 4-byte little-endian length, then "AB" as UTF-16LE.
        assert_eq!(&data[0..4], &[4, 0, 0, 0]);
        assert_eq!(&data[4..], &[0x41_u8, 0x00, 0x42, 0x00][..]);
    }

    #[test]
    fn test_azure_ad_token_data_empty() {
        let auth = AzureAdAuth::with_token("");
        let data = auth.build_token_data();

        // An empty token still carries the length prefix.
        assert_eq!(&data[..], &[0_u8, 0, 0, 0][..]);
    }

    #[test]
    fn test_from_credentials() {
        let creds = Credentials::azure_token("my_token");
        let auth = AzureAdAuth::from_credentials(&creds).unwrap();

        let data = auth.authenticate().unwrap();
        match &data {
            AuthData::FedAuth { token, .. } => {
                assert_eq!(token, "my_token");
            }
            _ => panic!("Expected FedAuth data"),
        }
    }

    #[test]
    fn test_from_credentials_wrong_type() {
        let creds = Credentials::sql_server("user", "pass");
        let result = AzureAdAuth::from_credentials(&creds);
        assert!(matches!(result, Err(AuthError::UnsupportedMethod(_))));
    }

    #[test]
    fn test_debug_redacts_token() {
        let auth = AzureAdAuth::with_token("secret_token");
        let debug = format!("{auth:?}");
        assert!(!debug.contains("secret_token"));
        assert!(debug.contains("[REDACTED]"));
    }
}