1use std::borrow::Cow;
32use std::time::{Duration, Instant};
33
34use bytes::Bytes;
35
36use crate::credentials::Credentials;
37use crate::error::AuthError;
38use crate::provider::{AuthData, AuthMethod, AuthProvider};
39
/// Federated-authentication library selector carried in the feature-extension data.
///
/// NOTE(review): MS-TDS assigns the on-the-wire `bFedAuthLibrary` ids 0x01 (security token)
/// and 0x02 (ADAL, reused for MSAL) and packs them into the upper seven bits of the options
/// byte. The discriminants below do not follow that table, so `to_byte()` is not a wire
/// value by itself — confirm every caller before sending it unshifted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum FedAuthLibrary {
    /// Active Directory Authentication Library workflow.
    Adal = 0x01,
    /// A pre-acquired security (access) token supplied directly by the client.
    SecurityToken = 0x02,
    /// Microsoft Authentication Library workflow.
    Msal = 0x03,
}

impl FedAuthLibrary {
    /// Returns this variant's `repr(u8)` discriminant.
    #[must_use]
    pub fn to_byte(self) -> u8 {
        self as u8
    }
}
61
/// Token-acquisition flow for Azure AD federated authentication.
///
/// NOTE(review): nothing in the code alongside this type reads it; presumably it is meant
/// to select the ADAL/MSAL workflow of the login feature data — confirm before relying on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FedAuthWorkflow {
    /// The user signs in interactively.
    Interactive,
    /// The token is acquired without user interaction.
    NonInteractive,
    /// The token is issued to an Azure managed identity.
    ManagedIdentity,
    /// The token is obtained with service-principal (application) credentials.
    ServicePrincipal,
}
74
75#[derive(Clone)]
90pub struct AzureAdAuth {
91 token: Cow<'static, str>,
93 expires_at: Option<Instant>,
95 library: FedAuthLibrary,
97}
98
99impl AzureAdAuth {
100 pub fn with_token(token: impl Into<Cow<'static, str>>) -> Self {
109 Self {
110 token: token.into(),
111 expires_at: None,
112 library: FedAuthLibrary::SecurityToken,
113 }
114 }
115
116 pub fn with_token_expiring(token: impl Into<Cow<'static, str>>, expires_in: Duration) -> Self {
126 Self {
127 token: token.into(),
128 expires_at: Some(Instant::now() + expires_in),
129 library: FedAuthLibrary::SecurityToken,
130 }
131 }
132
133 pub fn from_credentials(credentials: &Credentials) -> Result<Self, AuthError> {
137 match credentials {
138 Credentials::AzureAccessToken { token } => Ok(Self::with_token(token.to_string())),
139 _ => Err(AuthError::UnsupportedMethod(
140 "AzureAdAuth requires Azure AD credentials".into(),
141 )),
142 }
143 }
144
145 #[must_use]
147 pub fn with_library(mut self, library: FedAuthLibrary) -> Self {
148 self.library = library;
149 self
150 }
151
152 #[must_use]
154 pub fn is_expired(&self) -> bool {
155 self.expires_at
156 .map(|exp| Instant::now() >= exp)
157 .unwrap_or(false)
158 }
159
160 #[must_use]
162 pub fn is_expiring_soon(&self, within: Duration) -> bool {
163 self.expires_at
164 .map(|exp| Instant::now() + within >= exp)
165 .unwrap_or(false)
166 }
167
168 #[must_use]
176 pub fn build_feature_data(&self) -> Bytes {
177 let mut data = Vec::with_capacity(6);
178
179 data.push(self.library.to_byte());
181
182 data.push(0x00);
184
185 Bytes::from(data)
190 }
191
192 #[must_use]
196 pub fn build_token_data(&self) -> Bytes {
197 let token_utf16: Vec<u8> = self
199 .token
200 .encode_utf16()
201 .flat_map(|c| c.to_le_bytes())
202 .collect();
203
204 let mut data = Vec::with_capacity(4 + token_utf16.len());
205
206 data.extend_from_slice(&(token_utf16.len() as u32).to_le_bytes());
208
209 data.extend_from_slice(&token_utf16);
211
212 Bytes::from(data)
213 }
214}
215
216impl AuthProvider for AzureAdAuth {
217 fn method(&self) -> AuthMethod {
218 AuthMethod::AzureAd
219 }
220
221 fn authenticate(&self) -> Result<AuthData, AuthError> {
222 if self.is_expired() {
223 return Err(AuthError::TokenExpired);
224 }
225
226 tracing::debug!("authenticating with Azure AD token");
227
228 Ok(AuthData::FedAuth {
229 token: self.token.to_string(),
230 nonce: None,
231 })
232 }
233
234 fn feature_extension_data(&self) -> Option<Bytes> {
235 Some(self.build_feature_data())
236 }
237
238 fn needs_refresh(&self) -> bool {
239 self.is_expiring_soon(Duration::from_secs(300))
241 }
242}
243
244impl std::fmt::Debug for AzureAdAuth {
245 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246 f.debug_struct("AzureAdAuth")
247 .field("token", &"[REDACTED]")
248 .field("expires_at", &self.expires_at)
249 .field("library", &self.library)
250 .finish()
251 }
252}
253
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_azure_ad_with_token() {
        let auth = AzureAdAuth::with_token("test_token");
        assert_eq!(auth.method(), AuthMethod::AzureAd);
        assert!(!auth.is_expired());
    }

    #[test]
    fn test_azure_ad_with_expiring_token() {
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(3600));
        assert!(!auth.is_expired());
        assert!(!auth.is_expiring_soon(Duration::from_secs(60)));
    }

    #[test]
    fn test_azure_ad_expired_token() {
        let auth = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(0));
        // Let the clock move past the zero-length deadline.
        std::thread::sleep(Duration::from_millis(10));
        assert!(auth.is_expired());

        let result = auth.authenticate();
        assert!(matches!(result, Err(AuthError::TokenExpired)));
    }

    #[test]
    fn test_needs_refresh_near_expiry() {
        // `needs_refresh` uses a 300-second margin.
        let soon = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(60));
        assert!(soon.needs_refresh());

        let later = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(3600));
        assert!(!later.needs_refresh());

        assert!(!AzureAdAuth::with_token("test_token").needs_refresh());
    }

    #[test]
    fn test_azure_ad_feature_data() {
        let auth = AzureAdAuth::with_token("test_token");
        let data = auth.build_feature_data();

        assert!(!data.is_empty());
        assert_eq!(data[0], FedAuthLibrary::SecurityToken.to_byte());
        // The trait hook must expose exactly the inherent builder's payload.
        assert_eq!(auth.feature_extension_data(), Some(auth.build_feature_data()));
    }

    #[test]
    fn test_azure_ad_token_data() {
        let auth = AzureAdAuth::with_token("AB");
        let data = auth.build_token_data();

        // 4-byte little-endian byte count (2 UTF-16 units = 4 bytes), then UTF-16LE payload.
        assert_eq!(data.len(), 8);
        assert_eq!(&data[0..4], &[4, 0, 0, 0]);
        assert_eq!(&data[4..], &[b'A', 0, b'B', 0]);
    }

    #[test]
    fn test_from_credentials() {
        let creds = Credentials::azure_token("my_token");
        let auth = AzureAdAuth::from_credentials(&creds).unwrap();

        let data = auth.authenticate().unwrap();
        match &data {
            AuthData::FedAuth { token, nonce } => {
                assert_eq!(token, "my_token");
                assert!(nonce.is_none());
            }
            _ => panic!("Expected FedAuth data"),
        }
    }

    #[test]
    fn test_from_credentials_wrong_type() {
        let creds = Credentials::sql_server("user", "pass");
        let result = AzureAdAuth::from_credentials(&creds);
        assert!(matches!(result, Err(AuthError::UnsupportedMethod(_))));
    }

    #[test]
    fn test_debug_redacts_token() {
        let auth = AzureAdAuth::with_token("secret_token");
        let debug = format!("{:?}", auth);
        assert!(!debug.contains("secret_token"));
        assert!(debug.contains("[REDACTED]"));
    }
}