1use bytes::Bytes;
7
8use crate::error::AuthError;
9
/// The kind of authentication a provider performs.
///
/// Providers report one of these values from their `method()` function.
/// The enum is `#[non_exhaustive]`, so matches written outside this crate
/// need a wildcard arm.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMethod {
    /// SQL Server authentication. The only variant for which
    /// `uses_login7_credentials()` returns `true`.
    SqlServer,
    /// Azure AD authentication. The only variant for which
    /// `is_federated()` returns `true`.
    AzureAd,
    /// Integrated authentication. The only variant for which `is_sspi()`
    /// returns `true`.
    Integrated,
    /// Certificate-based authentication. None of the three classification
    /// helpers returns `true` for it.
    Certificate,
}
25
26impl AuthMethod {
27 #[must_use]
29 pub fn is_federated(&self) -> bool {
30 matches!(self, Self::AzureAd)
31 }
32
33 #[must_use]
35 pub fn is_sspi(&self) -> bool {
36 matches!(self, Self::Integrated)
37 }
38
39 #[must_use]
41 pub fn uses_login7_credentials(&self) -> bool {
42 matches!(self, Self::SqlServer)
43 }
44}
45
46#[derive(Debug, Clone)]
54#[non_exhaustive]
55pub enum AuthData {
56 SqlServer {
58 username: String,
60 password_bytes: Vec<u8>,
62 },
63 FedAuth {
65 token: String,
67 nonce: Option<Bytes>,
69 },
70 Sspi {
72 blob: Vec<u8>,
74 },
75 None,
77}
78
79pub trait AuthProvider: Send + Sync {
94 fn method(&self) -> AuthMethod;
96
97 fn authenticate(&self) -> Result<AuthData, AuthError>;
102
103 fn feature_extension_data(&self) -> Option<Bytes> {
108 None
109 }
110
111 fn needs_refresh(&self) -> bool {
116 false
117 }
118}
119
120#[allow(async_fn_in_trait)]
125pub trait AsyncAuthProvider: Send + Sync {
126 fn method(&self) -> AuthMethod;
128
129 async fn authenticate_async(&self) -> Result<AuthData, AuthError>;
131
132 fn feature_extension_data(&self) -> Option<Bytes> {
134 None
135 }
136
137 fn needs_refresh(&self) -> bool {
139 false
140 }
141}
142
143impl<T: AsyncAuthProvider> AuthProvider for T {
146 fn method(&self) -> AuthMethod {
147 <T as AsyncAuthProvider>::method(self)
148 }
149
150 fn authenticate(&self) -> Result<AuthData, AuthError> {
151 Err(AuthError::Configuration(
154 "Async auth provider must use authenticate_async()".into(),
155 ))
156 }
157
158 fn feature_extension_data(&self) -> Option<Bytes> {
159 <T as AsyncAuthProvider>::feature_extension_data(self)
160 }
161
162 fn needs_refresh(&self) -> bool {
163 <T as AsyncAuthProvider>::needs_refresh(self)
164 }
165}
166
/// Wipes secret material when an [`AuthData`] value is dropped.
///
/// Compiled only with the `zeroize` feature. The password bytes, the token
/// and the SSPI blob are zeroized; `username` and `nonce` are left untouched.
///
/// Because `AuthData` implements `Drop` whenever this feature is on, fields
/// cannot be moved out of a value by pattern matching in that configuration;
/// borrow or clone them instead.
#[cfg(feature = "zeroize")]
impl Drop for AuthData {
    fn drop(&mut self) {
        use zeroize::Zeroize;

        match self {
            Self::None => {}
            Self::Sspi { blob } => blob.zeroize(),
            Self::FedAuth { token, .. } => token.zeroize(),
            Self::SqlServer { password_bytes, .. } => password_bytes.zeroize(),
        }
    }
}
187
#[cfg(test)]
// Tests signal failure by panicking, so `unwrap` and `panic!` are acceptable here.
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Each classification helper answers `true` for its own variant and
    /// `false` for a different one.
    #[test]
    fn test_auth_method_properties() {
        let sql = AuthMethod::SqlServer;
        let azure = AuthMethod::AzureAd;
        let integrated = AuthMethod::Integrated;

        assert!(azure.is_federated());
        assert!(!sql.is_federated());

        assert!(integrated.is_sspi());
        assert!(!sql.is_sspi());

        assert!(sql.uses_login7_credentials());
        assert!(!azure.uses_login7_credentials());
    }

    /// No variant may fall into more than one classification category.
    #[test]
    fn test_auth_method_all_variants_classified() {
        let every_method = [
            AuthMethod::SqlServer,
            AuthMethod::AzureAd,
            AuthMethod::Integrated,
            AuthMethod::Certificate,
        ];

        for method in every_method.iter() {
            let count = usize::from(method.uses_login7_credentials())
                + usize::from(method.is_federated())
                + usize::from(method.is_sspi());
            assert!(
                count <= 1,
                "{method:?} has {count} categories, expected 0 or 1"
            );
        }
    }

    /// `Certificate` belongs to none of the three categories.
    #[test]
    fn test_auth_method_certificate() {
        let cert = AuthMethod::Certificate;
        let answers = [
            cert.is_federated(),
            cert.is_sspi(),
            cert.uses_login7_credentials(),
        ];
        assert_eq!(answers, [false, false, false]);
    }

    /// A `SqlServer` value keeps the username and password bytes it was built with.
    #[test]
    fn test_auth_data_sql_server() {
        let data = AuthData::SqlServer {
            username: "sa".to_string(),
            password_bytes: vec![0xA5, 0xB6],
        };

        let AuthData::SqlServer {
            username,
            password_bytes,
        } = &data
        else {
            panic!("Expected SqlServer variant");
        };
        assert_eq!(username, "sa");
        assert_eq!(password_bytes.len(), 2);
    }

    /// A `FedAuth` value keeps its token and an absent nonce.
    #[test]
    fn test_auth_data_fed_auth() {
        let data = AuthData::FedAuth {
            token: "eyJhbGciOiJSUzI1NiJ9.test".to_string(),
            nonce: None,
        };

        let AuthData::FedAuth { token, nonce } = &data else {
            panic!("Expected FedAuth variant");
        };
        assert!(token.starts_with("eyJ"));
        assert!(nonce.is_none());
    }

    /// An `Sspi` value keeps the blob it was built with.
    #[test]
    fn test_auth_data_sspi() {
        let data = AuthData::Sspi {
            blob: vec![0x4E, 0x54, 0x4C, 0x4D],
        };

        let AuthData::Sspi { blob } = &data else {
            panic!("Expected Sspi variant");
        };
        assert_eq!(blob.len(), 4);
    }

    /// The `None` variant can be constructed and matched.
    #[test]
    fn test_auth_data_none() {
        let data = AuthData::None;
        assert!(matches!(data, AuthData::None));
    }

    /// Formatting any variant with `{:?}` must not panic.
    #[test]
    fn test_auth_data_debug_output() {
        let variants = [
            AuthData::SqlServer {
                username: "test".into(),
                password_bytes: vec![1, 2, 3],
            },
            AuthData::FedAuth {
                token: "tok".into(),
                nonce: Some(Bytes::from_static(b"nonce")),
            },
            AuthData::Sspi {
                blob: vec![0x01, 0x02],
            },
            AuthData::None,
        ];

        for variant in variants.iter() {
            let _ = format!("{variant:?}");
        }
    }

    /// Minimal synchronous provider that implements only the required
    /// methods, so the trait's defaults can be observed.
    struct MockProvider {
        method: AuthMethod,
    }

    impl AuthProvider for MockProvider {
        fn method(&self) -> AuthMethod {
            self.method
        }

        fn authenticate(&self) -> Result<AuthData, crate::error::AuthError> {
            Ok(AuthData::None)
        }
    }

    /// The default trait methods report no extension data and no refresh.
    #[test]
    fn test_auth_provider_trait_defaults() {
        let provider = MockProvider {
            method: AuthMethod::SqlServer,
        };

        assert_eq!(provider.method(), AuthMethod::SqlServer);
        assert!(provider.feature_extension_data().is_none());
        assert!(!provider.needs_refresh());

        let produced = provider.authenticate().unwrap();
        assert!(matches!(produced, AuthData::None));
    }
}