use bytes::Bytes;
use crate::error::AuthError;
/// Authentication mechanism selected for a connection.
///
/// Use [`AuthMethod::is_federated`], [`AuthMethod::is_sspi`] and
/// [`AuthMethod::uses_login7_credentials`] to classify a variant rather than
/// matching on it directly; the enum is `#[non_exhaustive]`, so downstream
/// matches need a wildcard arm anyway.
///
/// `Hash` is derived (alongside `Eq`) so the method can be used as a key in
/// hash-based collections, e.g. per-method configuration or metrics maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum AuthMethod {
    /// SQL Server login with a username and password.
    SqlServer,
    /// Azure Active Directory authentication; the only variant classified as
    /// federated.
    AzureAd,
    /// Integrated authentication; the only variant classified as SSPI.
    Integrated,
    /// Certificate-based authentication. NOTE(review): how the certificate is
    /// presented is not established in this module — confirm against the
    /// provider implementations.
    Certificate,
}
impl AuthMethod {
    /// Returns `true` when this method uses federated authentication.
    ///
    /// Only [`AuthMethod::AzureAd`] is federated. Declared `const` so the
    /// classification is also usable in constant contexts.
    #[must_use]
    pub const fn is_federated(&self) -> bool {
        matches!(self, Self::AzureAd)
    }

    /// Returns `true` when this method authenticates through SSPI.
    ///
    /// Only [`AuthMethod::Integrated`] is SSPI-based.
    #[must_use]
    pub const fn is_sspi(&self) -> bool {
        matches!(self, Self::Integrated)
    }

    /// Returns `true` when the credentials are carried directly in the
    /// LOGIN7 packet (username/password).
    ///
    /// Only [`AuthMethod::SqlServer`] does this.
    #[must_use]
    pub const fn uses_login7_credentials(&self) -> bool {
        matches!(self, Self::SqlServer)
    }
}
/// Credential material produced by an [`AuthProvider`] for a connection attempt.
///
/// `Debug` is implemented by hand so that secret fields (`password_bytes`,
/// `token`, `blob`) are printed as `<redacted>`. A derived `Debug` would dump
/// them verbatim into any log line or panic message that formats this value
/// with `{:?}`.
#[derive(Clone)]
#[non_exhaustive]
pub enum AuthData {
    /// Username/password credentials for a SQL Server login.
    SqlServer {
        /// Login name (not treated as secret; shown by `Debug`).
        username: String,
        /// Password as raw bytes. NOTE(review): the byte encoding is decided by
        /// whoever builds this value — confirm against the producer.
        password_bytes: Vec<u8>,
    },
    /// Federated-authentication token.
    FedAuth {
        /// Access token (secret; redacted by `Debug`).
        token: String,
        /// Optional nonce sent alongside the token. NOTE(review): presumably
        /// the server-supplied FedAuth nonce — confirm. Shown by `Debug`.
        nonce: Option<Bytes>,
    },
    /// Opaque SSPI security blob (secret; redacted by `Debug`).
    Sspi {
        /// Raw SSPI token bytes.
        blob: Vec<u8>,
    },
    /// No credential payload.
    None,
}

/// Stand-in value that formats as `<redacted>`, used by `AuthData`'s `Debug`.
struct RedactedSecret;

impl std::fmt::Debug for RedactedSecret {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("<redacted>")
    }
}

// Hand-written instead of derived: keeps the same struct-like layout as the
// derive would produce, but never prints credential bytes or tokens.
impl std::fmt::Debug for AuthData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SqlServer { username, .. } => f
                .debug_struct("SqlServer")
                .field("username", username)
                .field("password_bytes", &RedactedSecret)
                .finish(),
            Self::FedAuth { nonce, .. } => f
                .debug_struct("FedAuth")
                .field("token", &RedactedSecret)
                .field("nonce", nonce)
                .finish(),
            Self::Sspi { .. } => f
                .debug_struct("Sspi")
                .field("blob", &RedactedSecret)
                .finish(),
            Self::None => f.write_str("None"),
        }
    }
}
/// Synchronous source of credentials for establishing a connection.
///
/// Implementors must be `Send + Sync`, so a provider can be shared across
/// threads.
pub trait AuthProvider: Send + Sync {
    /// The [`AuthMethod`] this provider implements.
    fn method(&self) -> AuthMethod;

    /// Produces the credential payload for a connection attempt.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] when credentials cannot be produced.
    fn authenticate(&self) -> Result<AuthData, AuthError>;

    /// Optional extra bytes supplied by the provider; the default supplies none.
    ///
    /// NOTE(review): presumably the payload for a LOGIN7 feature-extension
    /// entry — confirm against the caller that consumes it.
    fn feature_extension_data(&self) -> Option<Bytes> {
        None
    }

    /// Whether the provider's credentials should be re-acquired; the default
    /// never asks for a refresh.
    ///
    /// NOTE(review): refresh semantics (e.g. token expiry) are up to the
    /// implementor and the caller — not defined in this module.
    fn needs_refresh(&self) -> bool {
        false
    }
}
/// Asynchronous source of credentials; the async counterpart of [`AuthProvider`].
///
/// Every implementor also gets [`AuthProvider`] through this module's blanket
/// impl. In that impl, the synchronous `authenticate()` always returns an error,
/// so callers must use [`AsyncAuthProvider::authenticate_async`] instead.
// `async_fn_in_trait` warns because a public `async fn` in a trait gives
// callers no way to require that the returned future is `Send`. Allowed on
// purpose here. NOTE(review): if providers must be driven from a
// multi-threaded executor, consider `-> impl Future<Output = ...> + Send`.
#[allow(async_fn_in_trait)]
pub trait AsyncAuthProvider: Send + Sync {
    /// The [`AuthMethod`] this provider implements.
    fn method(&self) -> AuthMethod;

    /// Asynchronously produces the credential payload for a connection attempt.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] when credentials cannot be produced.
    async fn authenticate_async(&self) -> Result<AuthData, AuthError>;

    /// Optional extra bytes supplied by the provider; the default supplies none.
    fn feature_extension_data(&self) -> Option<Bytes> {
        None
    }

    /// Whether the provider's credentials should be re-acquired; the default
    /// never asks for a refresh.
    fn needs_refresh(&self) -> bool {
        false
    }
}
/// Blanket bridge: every [`AsyncAuthProvider`] is usable where an
/// [`AuthProvider`] is expected.
///
/// `method`, `feature_extension_data` and `needs_refresh` forward to the async
/// trait. The blocking `authenticate` cannot drive an async provider, so it
/// always returns `AuthError::Configuration`.
impl<T> AuthProvider for T
where
    T: AsyncAuthProvider,
{
    fn method(&self) -> AuthMethod {
        <Self as AsyncAuthProvider>::method(self)
    }

    fn authenticate(&self) -> Result<AuthData, AuthError> {
        // There is no executor here to poll `authenticate_async`, so report
        // a misconfiguration instead of blocking.
        let reason = "Async auth provider must use authenticate_async()";
        Err(AuthError::Configuration(reason.into()))
    }

    fn feature_extension_data(&self) -> Option<Bytes> {
        <Self as AsyncAuthProvider>::feature_extension_data(self)
    }

    fn needs_refresh(&self) -> bool {
        <Self as AsyncAuthProvider>::needs_refresh(self)
    }
}
/// With the `zeroize` feature enabled, secret buffers are overwritten when an
/// [`AuthData`] is dropped.
///
/// Wiped: `password_bytes`, `token` and the SSPI `blob`. Not wiped:
/// `username` and `nonce`.
#[cfg(feature = "zeroize")]
impl Drop for AuthData {
    fn drop(&mut self) {
        use zeroize::Zeroize;

        match self {
            Self::SqlServer {
                password_bytes: secret,
                ..
            } => secret.zeroize(),
            Self::FedAuth { token, .. } => token.zeroize(),
            Self::Sspi { blob } => blob.zeroize(),
            // Nothing to wipe.
            Self::None => {}
        }
    }
}
#[cfg(test)]
// Test code may unwrap freely; kept so future assertions can use `unwrap()`.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_method_properties() {
        assert!(AuthMethod::AzureAd.is_federated());
        assert!(!AuthMethod::SqlServer.is_federated());
        assert!(AuthMethod::Integrated.is_sspi());
        assert!(!AuthMethod::SqlServer.is_sspi());
        assert!(AuthMethod::SqlServer.uses_login7_credentials());
        assert!(!AuthMethod::AzureAd.uses_login7_credentials());
    }

    /// Pins every predicate for every variant, including `Certificate`, which
    /// the spot checks above never touch.
    #[test]
    fn test_auth_method_properties_cover_every_variant() {
        // (method, is_federated, is_sspi, uses_login7_credentials)
        let expected = [
            (AuthMethod::SqlServer, false, false, true),
            (AuthMethod::AzureAd, true, false, false),
            (AuthMethod::Integrated, false, true, false),
            (AuthMethod::Certificate, false, false, false),
        ];
        for (method, federated, sspi, login7) in expected {
            assert_eq!(method.is_federated(), federated, "is_federated for {:?}", method);
            assert_eq!(method.is_sspi(), sspi, "is_sspi for {:?}", method);
            assert_eq!(
                method.uses_login7_credentials(),
                login7,
                "uses_login7_credentials for {:?}",
                method
            );
        }
    }
}