use std::borrow::Cow;
use std::time::{Duration, Instant};
use bytes::Bytes;
use crate::credentials::Credentials;
use crate::error::AuthError;
use crate::provider::{AuthData, AuthMethod, AuthProvider};
/// Identifies the federated-authentication library associated with a token.
///
/// Each variant's discriminant is the raw byte returned by
/// [`FedAuthLibrary::to_byte`].
///
/// NOTE(review): the published MS-TDS description of the LOGIN7 `FEDAUTH`
/// feature extension appears to assign `0x01` to "security token" and `0x02`
/// to ADAL, which differs from the values below — confirm against the
/// specification before relying on these bytes on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum FedAuthLibrary {
    /// ADAL (by name, the Azure Active Directory Authentication Library); byte `0x01`.
    Adal = 0x01,
    /// A pre-acquired security token; byte `0x02`.
    SecurityToken = 0x02,
    /// MSAL (by name, the Microsoft Authentication Library); byte `0x03`.
    Msal = 0x03,
}

impl FedAuthLibrary {
    /// Returns the `u8` discriminant of this variant.
    ///
    /// This is a `const fn`, so the value can also be used in constant
    /// expressions (for example to build static lookup tables).
    #[must_use]
    pub const fn to_byte(self) -> u8 {
        // `#[repr(u8)]` guarantees the cast yields the declared discriminant.
        self as u8
    }
}
/// Kinds of federated-authentication workflow.
///
/// No other item in the visible portion of this file refers to this enum, so
/// the variant descriptions below are inferred from the names alone —
/// TODO confirm against the code that consumes them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum FedAuthWorkflow {
/// Presumably a sign-in that prompts the user.
Interactive,
/// Presumably a sign-in that completes without user interaction.
NonInteractive,
/// Presumably authentication through a managed identity.
ManagedIdentity,
/// Presumably authentication as a service principal.
ServicePrincipal,
}
/// Azure AD authentication backed by an access token supplied by the caller.
///
/// `Debug` is not derived: the hand-written implementation in this file
/// prints the token as `[REDACTED]`.
#[derive(Clone)]
pub struct AzureAdAuth {
/// The access token text, either a borrowed `'static` string or an owned one.
token: Cow<'static, str>,
/// Deadline after which the token counts as expired; `None` means the token
/// is never reported as expired.
expires_at: Option<Instant>,
/// Library identifier; its byte is the first byte of the feature data.
library: FedAuthLibrary,
}
impl AzureAdAuth {
    /// Creates an authenticator from an access token with no known expiry.
    ///
    /// The resulting value is never reported as expired and uses
    /// [`FedAuthLibrary::SecurityToken`] until changed with
    /// [`AzureAdAuth::with_library`].
    pub fn with_token(token: impl Into<Cow<'static, str>>) -> Self {
        Self {
            token: token.into(),
            expires_at: None,
            library: FedAuthLibrary::SecurityToken,
        }
    }

    /// Creates an authenticator whose token expires `expires_in` from now.
    ///
    /// If `now + expires_in` cannot be represented as an [`Instant`] (for
    /// example `Duration::MAX`), the token is treated as having no expiry
    /// rather than panicking on the overflowing addition.
    pub fn with_token_expiring(token: impl Into<Cow<'static, str>>, expires_in: Duration) -> Self {
        Self {
            // `checked_add` yields `None` on overflow, which this type already
            // interprets as "never expires".
            expires_at: Instant::now().checked_add(expires_in),
            ..Self::with_token(token)
        }
    }

    /// Builds an authenticator from `credentials`.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::UnsupportedMethod`] when `credentials` is anything
    /// other than `Credentials::AzureAccessToken`.
    pub fn from_credentials(credentials: &Credentials) -> Result<Self, AuthError> {
        match credentials {
            Credentials::AzureAccessToken { token } => Ok(Self::with_token(token.to_string())),
            _ => Err(AuthError::UnsupportedMethod(
                "AzureAdAuth requires Azure AD credentials".into(),
            )),
        }
    }

    /// Returns `self` with the library identifier replaced by `library`.
    #[must_use]
    pub fn with_library(mut self, library: FedAuthLibrary) -> Self {
        self.library = library;
        self
    }

    /// Returns `true` once the expiry deadline has been reached.
    ///
    /// Tokens without a deadline are never expired.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        matches!(self.expires_at, Some(deadline) if Instant::now() >= deadline)
    }

    /// Returns `true` if the token expires within `within` from now, or has
    /// already expired.
    ///
    /// Tokens without a deadline always yield `false`. Any `within` value is
    /// accepted, including `Duration::MAX`.
    #[must_use]
    pub fn is_expiring_soon(&self, within: Duration) -> bool {
        match self.expires_at {
            // Compare the remaining lifetime (zero once the deadline has
            // passed) with the window; unlike `now + within >= deadline` this
            // cannot overflow and panic for large windows.
            Some(deadline) => deadline.saturating_duration_since(Instant::now()) <= within,
            None => false,
        }
    }

    /// Builds the two-byte feature data: the library byte followed by `0x00`.
    #[must_use]
    pub fn build_feature_data(&self) -> Bytes {
        Bytes::from(vec![self.library.to_byte(), 0x00])
    }

    /// Builds the token payload: a little-endian `u32` byte length followed by
    /// the token encoded as UTF-16LE.
    #[must_use]
    pub fn build_token_data(&self) -> Bytes {
        let token_utf16: Vec<u8> = self
            .token
            .encode_utf16()
            .flat_map(|unit| unit.to_le_bytes())
            .collect();

        // NOTE(review): the cast wraps for payloads larger than `u32::MAX`
        // bytes; the signature offers no way to report that as an error.
        let length_prefix = (token_utf16.len() as u32).to_le_bytes();

        let mut data = Vec::with_capacity(length_prefix.len() + token_utf16.len());
        data.extend_from_slice(&length_prefix);
        data.extend_from_slice(&token_utf16);
        Bytes::from(data)
    }
}
impl AuthProvider for AzureAdAuth {
    /// Always reports [`AuthMethod::AzureAd`].
    fn method(&self) -> AuthMethod {
        AuthMethod::AzureAd
    }

    /// Produces federated-auth data carrying a copy of the token.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::TokenExpired`] when the token's deadline has passed.
    fn authenticate(&self) -> Result<AuthData, AuthError> {
        if self.is_expired() {
            Err(AuthError::TokenExpired)
        } else {
            tracing::debug!("authenticating with Azure AD token");
            let token = (*self.token).to_owned();
            Ok(AuthData::FedAuth { token, nonce: None })
        }
    }

    /// Always returns the feature data built by `build_feature_data`.
    fn feature_extension_data(&self) -> Option<Bytes> {
        let payload = self.build_feature_data();
        Some(payload)
    }

    /// Reports whether the token expires within the next five minutes.
    fn needs_refresh(&self) -> bool {
        const REFRESH_WINDOW: Duration = Duration::from_secs(300);
        self.is_expiring_soon(REFRESH_WINDOW)
    }
}
impl std::fmt::Debug for AzureAdAuth {
    /// Formats the value like a derived `Debug` would, except that the token
    /// is replaced by a fixed placeholder so it does not appear in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AzureAdAuth");
        out.field("token", &"[REDACTED]");
        out.field("expires_at", &self.expires_at);
        out.field("library", &self.library);
        out.finish()
    }
}
#[cfg(test)]
// Unwrapping and panicking are the intended failure modes inside tests.
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// A token without an expiry reports the Azure AD method and is not expired.
    #[test]
    fn test_azure_ad_with_token() {
        let provider = AzureAdAuth::with_token("test_token");
        assert_eq!(provider.method(), AuthMethod::AzureAd);
        assert!(!provider.is_expired());
    }

    /// A token valid for an hour is neither expired nor within a minute of expiring.
    #[test]
    fn test_azure_ad_with_expiring_token() {
        let one_hour = Duration::from_secs(3600);
        let provider = AzureAdAuth::with_token_expiring("test_token", one_hour);
        assert!(!provider.is_expired());
        assert!(!provider.is_expiring_soon(Duration::from_secs(60)));
    }

    /// A zero-lifetime token is expired and is rejected by `authenticate`.
    #[test]
    fn test_azure_ad_expired_token() {
        let provider = AzureAdAuth::with_token_expiring("test_token", Duration::from_secs(0));
        std::thread::sleep(Duration::from_millis(10));
        assert!(provider.is_expired());
        assert!(matches!(
            provider.authenticate(),
            Err(AuthError::TokenExpired)
        ));
    }

    /// The feature data is non-empty and starts with the default library byte.
    #[test]
    fn test_azure_ad_feature_data() {
        let feature = AzureAdAuth::with_token("test_token").build_feature_data();
        assert!(!feature.is_empty());
        assert_eq!(feature[0], FedAuthLibrary::SecurityToken.to_byte());
    }

    /// "AB" encodes to four UTF-16LE bytes behind a four-byte length prefix.
    #[test]
    fn test_azure_ad_token_data() {
        let payload = AzureAdAuth::with_token("AB").build_token_data();
        assert_eq!(payload.len(), 8);
        assert_eq!(&payload[..4], &[4u8, 0, 0, 0]);
    }

    /// Azure token credentials round-trip into FedAuth data with the same token.
    #[test]
    fn test_from_credentials() {
        let azure_creds = Credentials::azure_token("my_token");
        let provider = AzureAdAuth::from_credentials(&azure_creds).unwrap();
        let auth_data = provider.authenticate().unwrap();
        match &auth_data {
            AuthData::FedAuth { token, .. } => assert_eq!(token, "my_token"),
            _ => panic!("Expected FedAuth data"),
        }
    }

    /// Non-Azure credentials are rejected.
    #[test]
    fn test_from_credentials_wrong_type() {
        let sql_creds = Credentials::sql_server("user", "pass");
        assert!(AzureAdAuth::from_credentials(&sql_creds).is_err());
    }

    /// The `Debug` output hides the token behind the placeholder.
    #[test]
    fn test_debug_redacts_token() {
        let provider = AzureAdAuth::with_token("secret_token");
        let rendered = format!("{:?}", provider);
        assert!(!rendered.contains("secret_token"));
        assert!(rendered.contains("[REDACTED]"));
    }
}