Skip to main content

mssql_auth/
provider.rs

1//! Authentication provider traits.
2//!
3//! This module defines the `AuthProvider` trait for implementing
4//! authentication strategies, as specified in ARCHITECTURE.md.
5
6use bytes::Bytes;
7
8use crate::error::AuthError;
9
/// Authentication method enumeration.
///
/// This indicates which authentication flow to use during connection.
///
/// `AuthMethod` is a small fieldless `Copy` enum that implements `Eq` and
/// `Hash`, so it can be passed by value and used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum AuthMethod {
    /// SQL Server authentication (username/password in Login7).
    SqlServer,
    /// Azure AD / Entra ID federated authentication.
    AzureAd,
    /// Integrated Windows authentication (SSPI/Kerberos).
    Integrated,
    /// Certificate-based authentication.
    Certificate,
}

impl AuthMethod {
    /// Returns `true` if this method uses federated authentication.
    ///
    /// Only [`AuthMethod::AzureAd`] is federated. Being a `const fn`, this can
    /// also be evaluated in constant contexts.
    #[must_use]
    pub const fn is_federated(&self) -> bool {
        matches!(self, Self::AzureAd)
    }

    /// Returns `true` if this method uses SSPI.
    ///
    /// Only [`AuthMethod::Integrated`] uses SSPI.
    #[must_use]
    pub const fn is_sspi(&self) -> bool {
        matches!(self, Self::Integrated)
    }

    /// Returns `true` if this method sends credentials inside the Login7 packet.
    ///
    /// Only [`AuthMethod::SqlServer`] does.
    #[must_use]
    pub const fn uses_login7_credentials(&self) -> bool {
        matches!(self, Self::SqlServer)
    }
}
45
46/// Authentication data produced by an auth provider.
47///
48/// This contains the data needed to authenticate with SQL Server,
49/// depending on the authentication method being used.
50///
51/// Sensitive fields (password bytes, tokens, SSPI blobs) are securely zeroized
52/// on drop when the `zeroize` feature is enabled.
53#[derive(Debug, Clone)]
54#[non_exhaustive]
55pub enum AuthData {
56    /// SQL Server credentials for Login7 packet.
57    SqlServer {
58        /// Username.
59        username: String,
60        /// Obfuscated password bytes (XOR + bit rotation).
61        password_bytes: Vec<u8>,
62    },
63    /// Federated authentication token for FEDAUTH feature.
64    FedAuth {
65        /// The access token.
66        token: String,
67        /// Token nonce (optional, for certain flows).
68        nonce: Option<Bytes>,
69    },
70    /// SSPI blob for integrated authentication.
71    Sspi {
72        /// The SSPI authentication blob.
73        blob: Vec<u8>,
74    },
75    /// No additional authentication data needed.
76    None,
77}
78
79/// Trait for authentication providers.
80///
81/// Authentication providers are responsible for producing the authentication
82/// data needed for the TDS connection. Different providers support different
83/// authentication methods (SQL auth, Azure AD, integrated, etc.).
84///
85/// # Example
86///
87/// ```rust,ignore
88/// use mssql_auth::{AuthProvider, SqlServerAuth};
89///
90/// let provider = SqlServerAuth::new("username", "password");
91/// let auth_data = provider.authenticate().await?;
92/// ```
93pub trait AuthProvider: Send + Sync {
94    /// Get the authentication method this provider uses.
95    fn method(&self) -> AuthMethod;
96
97    /// Authenticate and produce authentication data.
98    ///
99    /// This may involve network calls (e.g., for Azure AD token acquisition)
100    /// so it returns a future in async implementations.
101    fn authenticate(&self) -> Result<AuthData, AuthError>;
102
103    /// Get additional feature extension data for Login7.
104    ///
105    /// Some authentication methods (like Azure AD) require feature extensions
106    /// in the Login7 packet. This returns the raw feature data if needed.
107    fn feature_extension_data(&self) -> Option<Bytes> {
108        None
109    }
110
111    /// Check if this provider needs to refresh its authentication.
112    ///
113    /// For token-based authentication, this can check if the token is expired
114    /// or about to expire.
115    fn needs_refresh(&self) -> bool {
116        false
117    }
118}
119
120/// Async authentication provider trait.
121///
122/// This is for authentication methods that require async operations,
123/// such as acquiring tokens from Azure AD endpoints.
124#[allow(async_fn_in_trait)]
125pub trait AsyncAuthProvider: Send + Sync {
126    /// Get the authentication method this provider uses.
127    fn method(&self) -> AuthMethod;
128
129    /// Authenticate asynchronously and produce authentication data.
130    async fn authenticate_async(&self) -> Result<AuthData, AuthError>;
131
132    /// Get additional feature extension data for Login7.
133    fn feature_extension_data(&self) -> Option<Bytes> {
134        None
135    }
136
137    /// Check if this provider needs to refresh its authentication.
138    fn needs_refresh(&self) -> bool {
139        false
140    }
141}
142
143// Implement AuthProvider for any AsyncAuthProvider by blocking
144// (for use in synchronous contexts when needed)
145impl<T: AsyncAuthProvider> AuthProvider for T {
146    fn method(&self) -> AuthMethod {
147        <T as AsyncAuthProvider>::method(self)
148    }
149
150    fn authenticate(&self) -> Result<AuthData, AuthError> {
151        // This is a fallback - in practice, async providers should be used
152        // with authenticate_async(). This implementation is for compatibility.
153        Err(AuthError::Configuration(
154            "Async auth provider must use authenticate_async()".into(),
155        ))
156    }
157
158    fn feature_extension_data(&self) -> Option<Bytes> {
159        <T as AsyncAuthProvider>::feature_extension_data(self)
160    }
161
162    fn needs_refresh(&self) -> bool {
163        <T as AsyncAuthProvider>::needs_refresh(self)
164    }
165}
166
// With the `zeroize` feature enabled, wipe the secret buffer held by each
// `AuthData` variant before its memory is released. Without the feature,
// `AuthData` has no custom `Drop` at all.
//
// NOTE(review): `username` and the FedAuth `nonce` are not wiped here.
#[cfg(feature = "zeroize")]
impl Drop for AuthData {
    fn drop(&mut self) {
        use zeroize::Zeroize;

        match self {
            Self::SqlServer {
                password_bytes: secret,
                ..
            } => secret.zeroize(),
            Self::FedAuth { token: secret, .. } => secret.zeroize(),
            Self::Sspi { blob: secret } => secret.zeroize(),
            Self::None => {}
        }
    }
}
187
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_method_properties() {
        assert!(AuthMethod::AzureAd.is_federated());
        assert!(!AuthMethod::SqlServer.is_federated());

        assert!(AuthMethod::Integrated.is_sspi());
        assert!(!AuthMethod::SqlServer.is_sspi());

        assert!(AuthMethod::SqlServer.uses_login7_credentials());
        assert!(!AuthMethod::AzureAd.uses_login7_credentials());
    }

    /// Checks every predicate against every current `AuthMethod` variant,
    /// including `Certificate`, which the test above never exercises.
    #[test]
    fn test_auth_method_predicates_cover_all_variants() {
        // (method, is_federated, is_sspi, uses_login7_credentials)
        let cases = [
            (AuthMethod::SqlServer, false, false, true),
            (AuthMethod::AzureAd, true, false, false),
            (AuthMethod::Integrated, false, true, false),
            (AuthMethod::Certificate, false, false, false),
        ];

        for &(method, federated, sspi, login7) in &cases {
            assert_eq!(method.is_federated(), federated, "is_federated({:?})", method);
            assert_eq!(method.is_sspi(), sspi, "is_sspi({:?})", method);
            assert_eq!(
                method.uses_login7_credentials(),
                login7,
                "uses_login7_credentials({:?})",
                method
            );
        }
    }
}
204}