mssql_auth/provider.rs
1//! Authentication provider traits.
2//!
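//! This module defines the [`AuthProvider`] and [`AsyncAuthProvider`] traits
//! for implementing authentication strategies, together with the
//! [`AuthMethod`] and [`AuthData`] types they exchange, as specified in
//! ARCHITECTURE.md.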
5
6use bytes::Bytes;
7
8use crate::error::AuthError;
9
/// Selects which authentication flow is used while connecting.
///
/// Each variant corresponds to one way of proving identity to the server.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthMethod {
    /// SQL Server authentication: username and password carried in Login7.
    SqlServer,
    /// Azure AD (Entra ID) federated authentication.
    AzureAd,
    /// Integrated Windows authentication through SSPI/Kerberos.
    Integrated,
    /// Authentication based on a certificate.
    Certificate,
}

impl AuthMethod {
    /// Returns `true` only for [`AuthMethod::AzureAd`], the federated flow.
    #[must_use]
    pub fn is_federated(&self) -> bool {
        *self == Self::AzureAd
    }

    /// Returns `true` only for [`AuthMethod::Integrated`], which goes through SSPI.
    #[must_use]
    pub fn is_sspi(&self) -> bool {
        *self == Self::Integrated
    }

    /// Returns `true` only for [`AuthMethod::SqlServer`], whose credentials
    /// travel inside the Login7 packet.
    #[must_use]
    pub fn uses_login7_credentials(&self) -> bool {
        *self == Self::SqlServer
    }
}
45
46/// Authentication data produced by an auth provider.
47///
48/// This contains the data needed to authenticate with SQL Server,
49/// depending on the authentication method being used.
50///
51/// Sensitive fields (password bytes, tokens, SSPI blobs) are securely zeroized
52/// on drop when the `zeroize` feature is enabled.
53#[derive(Debug, Clone)]
54#[non_exhaustive]
55pub enum AuthData {
56 /// SQL Server credentials for Login7 packet.
57 SqlServer {
58 /// Username.
59 username: String,
60 /// Obfuscated password bytes (XOR + bit rotation).
61 password_bytes: Vec<u8>,
62 },
63 /// Federated authentication token for FEDAUTH feature.
64 FedAuth {
65 /// The access token.
66 token: String,
67 /// Token nonce (optional, for certain flows).
68 nonce: Option<Bytes>,
69 },
70 /// SSPI blob for integrated authentication.
71 Sspi {
72 /// The SSPI authentication blob.
73 blob: Vec<u8>,
74 },
75 /// No additional authentication data needed.
76 None,
77}
78
79/// Trait for authentication providers.
80///
81/// Authentication providers are responsible for producing the authentication
82/// data needed for the TDS connection. Different providers support different
83/// authentication methods (SQL auth, Azure AD, integrated, etc.).
84///
85/// # Example
86///
87/// ```rust,ignore
88/// use mssql_auth::{AuthProvider, SqlServerAuth};
89///
90/// let provider = SqlServerAuth::new("username", "password");
91/// let auth_data = provider.authenticate().await?;
92/// ```
93pub trait AuthProvider: Send + Sync {
94 /// Get the authentication method this provider uses.
95 fn method(&self) -> AuthMethod;
96
97 /// Authenticate and produce authentication data.
98 ///
99 /// This may involve network calls (e.g., for Azure AD token acquisition)
100 /// so it returns a future in async implementations.
101 fn authenticate(&self) -> Result<AuthData, AuthError>;
102
103 /// Get additional feature extension data for Login7.
104 ///
105 /// Some authentication methods (like Azure AD) require feature extensions
106 /// in the Login7 packet. This returns the raw feature data if needed.
107 fn feature_extension_data(&self) -> Option<Bytes> {
108 None
109 }
110
111 /// Check if this provider needs to refresh its authentication.
112 ///
113 /// For token-based authentication, this can check if the token is expired
114 /// or about to expire.
115 fn needs_refresh(&self) -> bool {
116 false
117 }
118}
119
120/// Async authentication provider trait.
121///
122/// This is for authentication methods that require async operations,
123/// such as acquiring tokens from Azure AD endpoints.
124#[allow(async_fn_in_trait)]
125pub trait AsyncAuthProvider: Send + Sync {
126 /// Get the authentication method this provider uses.
127 fn method(&self) -> AuthMethod;
128
129 /// Authenticate asynchronously and produce authentication data.
130 async fn authenticate_async(&self) -> Result<AuthData, AuthError>;
131
132 /// Get additional feature extension data for Login7.
133 fn feature_extension_data(&self) -> Option<Bytes> {
134 None
135 }
136
137 /// Check if this provider needs to refresh its authentication.
138 fn needs_refresh(&self) -> bool {
139 false
140 }
141}
142
143// Implement AuthProvider for any AsyncAuthProvider by blocking
144// (for use in synchronous contexts when needed)
145impl<T: AsyncAuthProvider> AuthProvider for T {
146 fn method(&self) -> AuthMethod {
147 <T as AsyncAuthProvider>::method(self)
148 }
149
150 fn authenticate(&self) -> Result<AuthData, AuthError> {
151 // This is a fallback - in practice, async providers should be used
152 // with authenticate_async(). This implementation is for compatibility.
153 Err(AuthError::Configuration(
154 "Async auth provider must use authenticate_async()".into(),
155 ))
156 }
157
158 fn feature_extension_data(&self) -> Option<Bytes> {
159 <T as AsyncAuthProvider>::feature_extension_data(self)
160 }
161
162 fn needs_refresh(&self) -> bool {
163 <T as AsyncAuthProvider>::needs_refresh(self)
164 }
165}
166
// With the `zeroize` feature, wipe credential material held by `AuthData`
// before its memory is released. `username` and the FedAuth `nonce` are not
// wiped by this impl.
#[cfg(feature = "zeroize")]
impl Drop for AuthData {
    fn drop(&mut self) {
        use zeroize::Zeroize;

        match self {
            // Both variants hold their secret as a `Vec<u8>`.
            Self::SqlServer {
                password_bytes: secret,
                ..
            }
            | Self::Sspi { blob: secret } => secret.zeroize(),
            Self::FedAuth { token, .. } => token.zeroize(),
            Self::None => {}
        }
    }
}
187
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Checks every `AuthMethod` variant against every classification
    /// predicate, so a changed predicate or an unclassified variant is caught.
    #[test]
    fn test_auth_method_properties() {
        // (method, is_federated, is_sspi, uses_login7_credentials)
        let cases = [
            (AuthMethod::SqlServer, false, false, true),
            (AuthMethod::AzureAd, true, false, false),
            (AuthMethod::Integrated, false, true, false),
            (AuthMethod::Certificate, false, false, false),
        ];

        for (method, federated, sspi, login7) in cases {
            assert_eq!(method.is_federated(), federated, "is_federated({method:?})");
            assert_eq!(method.is_sspi(), sspi, "is_sspi({method:?})");
            assert_eq!(
                method.uses_login7_credentials(),
                login7,
                "uses_login7_credentials({method:?})"
            );
        }
    }
}
204}