use bytes::Bytes;
use crate::error::AuthError;
/// The authentication mechanism a connection is configured to use.
///
/// The predicate methods below sort each method into at most one category
/// (LOGIN7 credentials, federated, or SSPI). A method may fall into none of
/// them; [`AuthMethod::Certificate`] currently does.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthMethod {
    /// Classified by [`AuthMethod::uses_login7_credentials`].
    SqlServer,
    /// Classified by [`AuthMethod::is_federated`].
    AzureAd,
    /// Classified by [`AuthMethod::is_sspi`].
    Integrated,
    /// Not covered by any of the category predicates.
    Certificate,
}

impl AuthMethod {
    /// Returns `true` for federated authentication methods.
    ///
    /// Only [`AuthMethod::AzureAd`] is federated.
    #[must_use]
    pub fn is_federated(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::AzureAd => true,
            Self::SqlServer | Self::Integrated | Self::Certificate => false,
        }
    }

    /// Returns `true` for methods that go through SSPI.
    ///
    /// Only [`AuthMethod::Integrated`] does.
    #[must_use]
    pub fn is_sspi(&self) -> bool {
        match self {
            Self::Integrated => true,
            Self::SqlServer | Self::AzureAd | Self::Certificate => false,
        }
    }

    /// Returns `true` for methods whose credentials travel in the LOGIN7 packet.
    ///
    /// Only [`AuthMethod::SqlServer`] does.
    #[must_use]
    pub fn uses_login7_credentials(&self) -> bool {
        match self {
            Self::SqlServer => true,
            Self::AzureAd | Self::Integrated | Self::Certificate => false,
        }
    }
}
/// Credential material returned by [`AuthProvider::authenticate`].
///
/// `Debug` is implemented by hand rather than derived so that secrets
/// (password bytes, tokens, SSPI blobs) never end up in log or panic output.
/// Only their lengths or a redaction marker are printed.
#[derive(Clone)]
#[non_exhaustive]
pub enum AuthData {
    /// A username plus password bytes.
    SqlServer {
        /// Login name. Shown in `Debug` output.
        username: String,
        /// Password bytes. Secret: redacted from `Debug` output.
        password_bytes: Vec<u8>,
    },
    /// A federated-authentication token with an optional nonce.
    FedAuth {
        /// Access token. Secret: redacted from `Debug` output.
        token: String,
        /// Optional nonce accompanying the token. Shown in `Debug` output.
        nonce: Option<Bytes>,
    },
    /// An opaque SSPI security blob.
    Sspi {
        /// SSPI blob. Redacted from `Debug` output; only its length is shown.
        blob: Vec<u8>,
    },
    /// No credential payload.
    None,
}

impl std::fmt::Debug for AuthData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `format_args!` renders without surrounding quotes, so redaction
        // markers read as placeholders rather than as real string values.
        match self {
            Self::SqlServer {
                username,
                password_bytes,
            } => f
                .debug_struct("SqlServer")
                .field("username", username)
                .field(
                    "password_bytes",
                    &format_args!("<redacted {} bytes>", password_bytes.len()),
                )
                .finish(),
            Self::FedAuth { token, nonce } => f
                .debug_struct("FedAuth")
                .field("token", &format_args!("<redacted {} bytes>", token.len()))
                .field("nonce", nonce)
                .finish(),
            Self::Sspi { blob } => f
                .debug_struct("Sspi")
                .field("blob", &format_args!("<redacted {} bytes>", blob.len()))
                .finish(),
            Self::None => f.write_str("None"),
        }
    }
}
/// Synchronous source of credentials for a connection.
///
/// Implementors must be `Send + Sync`.
pub trait AuthProvider: Send + Sync {
    /// Returns the authentication method this provider implements.
    fn method(&self) -> AuthMethod;
    /// Produces the credential data for a login attempt.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] if credentials cannot be produced. Note that
    /// types which implement [`AsyncAuthProvider`] get this trait through a
    /// blanket impl whose `authenticate` always returns
    /// `AuthError::Configuration`.
    fn authenticate(&self) -> Result<AuthData, AuthError>;
    /// Optional feature-extension bytes associated with this provider.
    ///
    /// The default implementation returns `None`.
    // NOTE(review): presumably placed in the login request's feature-extension
    // section by the caller; confirm against the call site.
    fn feature_extension_data(&self) -> Option<Bytes> {
        None
    }
    /// Reports whether the provider's credentials need refreshing.
    ///
    /// The default implementation returns `false`.
    fn needs_refresh(&self) -> bool {
        false
    }
}
// `async fn` in a public trait means callers cannot require the returned
// future to be `Send`, which is what the `async_fn_in_trait` lint warns about.
// It is silenced deliberately to keep implementations simple.
// NOTE(review): confirm implementors' futures are `Send` wherever they are spawned.
#[allow(async_fn_in_trait)]
/// Asynchronous source of credentials for a connection.
///
/// Every implementor also implements [`AuthProvider`] through a blanket impl.
/// That impl forwards `method`, `feature_extension_data` and `needs_refresh`,
/// but its synchronous `authenticate` always returns an error. Callers must
/// use [`AsyncAuthProvider::authenticate_async`] instead.
pub trait AsyncAuthProvider: Send + Sync {
    /// Returns the authentication method this provider implements.
    fn method(&self) -> AuthMethod;
    /// Asynchronously produces the credential data for a login attempt.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] if credentials cannot be produced.
    async fn authenticate_async(&self) -> Result<AuthData, AuthError>;
    /// Optional feature-extension bytes associated with this provider.
    ///
    /// The default implementation returns `None`.
    fn feature_extension_data(&self) -> Option<Bytes> {
        None
    }
    /// Reports whether the provider's credentials need refreshing.
    ///
    /// The default implementation returns `false`.
    fn needs_refresh(&self) -> bool {
        false
    }
}
impl<T: AsyncAuthProvider> AuthProvider for T {
fn method(&self) -> AuthMethod {
<T as AsyncAuthProvider>::method(self)
}
fn authenticate(&self) -> Result<AuthData, AuthError> {
Err(AuthError::Configuration(
"Async auth provider must use authenticate_async()".into(),
))
}
fn feature_extension_data(&self) -> Option<Bytes> {
<T as AsyncAuthProvider>::feature_extension_data(self)
}
fn needs_refresh(&self) -> bool {
<T as AsyncAuthProvider>::needs_refresh(self)
}
}
#[cfg(feature = "zeroize")]
impl Drop for AuthData {
    /// Wipes the secret-bearing field of each variant before it is freed.
    ///
    /// `username` and `nonce` are left untouched; `AuthData::None` holds nothing to wipe.
    fn drop(&mut self) {
        use zeroize::Zeroize;
        match self {
            Self::SqlServer {
                password_bytes: secret,
                ..
            } => secret.zeroize(),
            Self::FedAuth { token: secret, .. } => secret.zeroize(),
            Self::Sspi { blob: secret } => secret.zeroize(),
            Self::None => {}
        }
    }
}
#[cfg(test)]
// Tests use `unwrap`/`panic!` for terse failure reporting. These lints are
// presumably enabled at crate level, hence the local allow.
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_method_properties() {
        assert!(AuthMethod::AzureAd.is_federated());
        assert!(!AuthMethod::SqlServer.is_federated());
        assert!(AuthMethod::Integrated.is_sspi());
        assert!(!AuthMethod::SqlServer.is_sspi());
        assert!(AuthMethod::SqlServer.uses_login7_credentials());
        assert!(!AuthMethod::AzureAd.uses_login7_credentials());
    }

    #[test]
    fn test_auth_method_all_variants_classified() {
        let methods = [
            AuthMethod::SqlServer,
            AuthMethod::AzureAd,
            AuthMethod::Integrated,
            AuthMethod::Certificate,
        ];
        for method in &methods {
            let categories = [
                method.uses_login7_credentials(),
                method.is_federated(),
                method.is_sspi(),
            ];
            let count = categories.iter().filter(|&&b| b).count();
            assert!(
                count <= 1,
                "{method:?} has {count} categories, expected 0 or 1"
            );
        }
    }

    #[test]
    fn test_auth_method_certificate() {
        let cert = AuthMethod::Certificate;
        assert!(!cert.is_federated());
        assert!(!cert.is_sspi());
        assert!(!cert.uses_login7_credentials());
    }

    #[test]
    fn test_auth_data_sql_server() {
        let data = AuthData::SqlServer {
            username: "sa".to_string(),
            password_bytes: vec![0xA5, 0xB6],
        };
        match &data {
            AuthData::SqlServer {
                username,
                password_bytes,
            } => {
                assert_eq!(username, "sa");
                assert_eq!(password_bytes.len(), 2);
            }
            _ => panic!("Expected SqlServer variant"),
        }
    }

    #[test]
    fn test_auth_data_fed_auth() {
        let data = AuthData::FedAuth {
            token: "eyJhbGciOiJSUzI1NiJ9.test".to_string(),
            nonce: None,
        };
        match &data {
            AuthData::FedAuth { token, nonce } => {
                assert!(token.starts_with("eyJ"));
                assert!(nonce.is_none());
            }
            _ => panic!("Expected FedAuth variant"),
        }
    }

    #[test]
    fn test_auth_data_sspi() {
        let data = AuthData::Sspi {
            // ASCII "NTLM".
            blob: vec![0x4E, 0x54, 0x4C, 0x4D],
        };
        match &data {
            AuthData::Sspi { blob } => {
                assert_eq!(blob.len(), 4);
            }
            _ => panic!("Expected Sspi variant"),
        }
    }

    #[test]
    fn test_auth_data_none() {
        let data = AuthData::None;
        assert!(matches!(data, AuthData::None));
    }

    #[test]
    fn test_auth_data_debug_output() {
        // Each rendering must at least identify its variant. This holds whether
        // `Debug` is derived or hand-written with redaction.
        let variants: Vec<(AuthData, &str)> = vec![
            (
                AuthData::SqlServer {
                    username: "test".into(),
                    password_bytes: vec![1, 2, 3],
                },
                "SqlServer",
            ),
            (
                AuthData::FedAuth {
                    token: "tok".into(),
                    nonce: Some(Bytes::from_static(b"nonce")),
                },
                "FedAuth",
            ),
            (
                AuthData::Sspi {
                    blob: vec![0x01, 0x02],
                },
                "Sspi",
            ),
            (AuthData::None, "None"),
        ];
        for (variant, name) in &variants {
            let rendered = format!("{variant:?}");
            assert!(
                rendered.starts_with(*name),
                "Debug output {rendered:?} should start with {name}"
            );
        }
    }

    struct MockProvider {
        method: AuthMethod,
    }

    impl AuthProvider for MockProvider {
        fn method(&self) -> AuthMethod {
            self.method
        }

        fn authenticate(&self) -> Result<AuthData, crate::error::AuthError> {
            Ok(AuthData::None)
        }
    }

    #[test]
    fn test_auth_provider_trait_defaults() {
        let provider = MockProvider {
            method: AuthMethod::SqlServer,
        };
        assert_eq!(provider.method(), AuthMethod::SqlServer);
        assert!(provider.feature_extension_data().is_none());
        assert!(!provider.needs_refresh());
        let data = provider.authenticate().unwrap();
        assert!(matches!(data, AuthData::None));
    }
}