1use crate::error::ConnectionError;
7use serde::{Deserialize, Serialize};
8use std::fmt;
9use std::sync::Arc;
10
11#[derive(Clone)]
15pub struct Credentials {
16 username: String,
17 password: Arc<SecureString>,
18}
19
20impl Credentials {
21 pub fn new(username: String, password: String) -> Self {
23 Self {
24 username,
25 password: Arc::new(SecureString::new(password)),
26 }
27 }
28
29 pub fn username(&self) -> &str {
31 &self.username
32 }
33
34 pub(crate) fn password(&self) -> &str {
36 self.password.as_str()
37 }
38}
39
40impl fmt::Debug for Credentials {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 f.debug_struct("Credentials")
43 .field("username", &self.username)
44 .field("password", &"<redacted>")
45 .finish()
46 }
47}
48
49impl fmt::Display for Credentials {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 write!(f, "Credentials(username: {})", self.username)
52 }
53}
54
/// Heap buffer holding a secret string that is overwritten with zeros on drop.
///
/// Invariant: `data` always holds valid UTF-8. It is only ever produced by
/// `String::into_bytes`, and the only mutation (`wipe`) writes `0x00` bytes,
/// which are themselves valid UTF-8.
struct SecureString {
    data: Vec<u8>,
}

impl SecureString {
    /// Takes ownership of `s` without copying its bytes.
    fn new(s: String) -> Self {
        Self {
            data: s.into_bytes(),
        }
    }

    /// Borrows the secret as a `&str`.
    fn as_str(&self) -> &str {
        // SAFETY: `data` comes from `String::into_bytes` and is otherwise only
        // modified by `wipe`, which writes `0x00` bytes; both keep it valid UTF-8.
        unsafe { std::str::from_utf8_unchecked(&self.data) }
    }

    /// Overwrites the entire allocation with zeros, covering both the live
    /// bytes and any spare capacity, which may still hold stale secret bytes.
    ///
    /// Volatile writes plus a compiler fence stop the optimizer from removing
    /// the stores as "dead" when the buffer is freed right afterwards.
    fn wipe(&mut self) {
        let cap = self.data.capacity();
        let ptr = self.data.as_mut_ptr();
        for offset in 0..cap {
            // SAFETY: `ptr` points to the start of an allocation of `cap` bytes
            // owned by `self.data`, so `ptr.add(offset)` is in bounds for every
            // `offset < cap`. Writing a `u8` is valid for any byte of that
            // allocation, initialised or not. When `cap == 0` the loop body
            // never runs.
            unsafe { std::ptr::write_volatile(ptr.add(offset), 0u8) };
        }
        std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
    }
}

impl Drop for SecureString {
    fn drop(&mut self) {
        self.wipe();
    }
}

impl fmt::Debug for SecureString {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "SecureString(<redacted>)")
    }
}
86}
87
88#[derive(Debug, Serialize, Deserialize)]
90pub struct AuthRequest {
91 #[serde(rename = "command")]
92 pub command: String,
93
94 #[serde(rename = "username")]
95 pub username: String,
96
97 #[serde(rename = "password")]
98 pub password: String,
99
100 #[serde(rename = "useCompression", skip_serializing_if = "Option::is_none")]
101 pub use_compression: Option<bool>,
102
103 #[serde(rename = "sessionId", skip_serializing_if = "Option::is_none")]
104 pub session_id: Option<String>,
105
106 #[serde(rename = "clientName", skip_serializing_if = "Option::is_none")]
107 pub client_name: Option<String>,
108
109 #[serde(rename = "clientVersion", skip_serializing_if = "Option::is_none")]
110 pub client_version: Option<String>,
111
112 #[serde(rename = "driverName", skip_serializing_if = "Option::is_none")]
113 pub driver_name: Option<String>,
114
115 #[serde(rename = "attributes", skip_serializing_if = "Option::is_none")]
116 pub attributes: Option<serde_json::Value>,
117}
118
119impl AuthRequest {
120 pub fn new(username: String, password: String) -> Self {
122 Self {
123 command: "login".to_string(),
124 username,
125 password,
126 use_compression: Some(false),
127 session_id: None,
128 client_name: None,
129 client_version: None,
130 driver_name: Some("exarrow-rs".to_string()),
131 attributes: None,
132 }
133 }
134
135 pub fn with_client_info(mut self, name: String, version: String) -> Self {
137 self.client_name = Some(name);
138 self.client_version = Some(version);
139 self
140 }
141
142 pub fn with_session_id(mut self, session_id: String) -> Self {
144 self.session_id = Some(session_id);
145 self
146 }
147
148 pub fn with_compression(mut self, enabled: bool) -> Self {
150 self.use_compression = Some(enabled);
151 self
152 }
153
154 pub fn with_attributes(mut self, attributes: serde_json::Value) -> Self {
156 self.attributes = Some(attributes);
157 self
158 }
159}
160
161impl fmt::Display for AuthRequest {
163 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164 write!(
165 f,
166 "AuthRequest {{ username: {}, password: <redacted> }}",
167 self.username
168 )
169 }
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct AuthResponse {
175 #[serde(rename = "status")]
176 pub status: String,
177
178 #[serde(rename = "responseData", skip_serializing_if = "Option::is_none")]
179 pub response_data: Option<AuthResponseData>,
180
181 #[serde(rename = "exception", skip_serializing_if = "Option::is_none")]
182 pub exception: Option<ExceptionInfo>,
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct AuthResponseData {
188 #[serde(rename = "sessionId")]
189 pub session_id: String,
190
191 #[serde(rename = "protocolVersion")]
192 pub protocol_version: i32,
193
194 #[serde(rename = "releaseVersion")]
195 pub release_version: String,
196
197 #[serde(rename = "databaseName")]
198 pub database_name: String,
199
200 #[serde(rename = "productName")]
201 pub product_name: String,
202
203 #[serde(rename = "maxDataMessageSize")]
204 pub max_data_message_size: i64,
205
206 #[serde(rename = "maxIdentifierLength")]
207 pub max_identifier_length: i32,
208
209 #[serde(rename = "maxVarcharLength")]
210 pub max_varchar_length: i64,
211
212 #[serde(rename = "identifierQuoteString")]
213 pub identifier_quote_string: String,
214
215 #[serde(rename = "timeZone")]
216 pub time_zone: String,
217
218 #[serde(rename = "timeZoneBehavior")]
219 pub time_zone_behavior: String,
220}
221
222#[derive(Debug, Clone, Serialize, Deserialize)]
224pub struct ExceptionInfo {
225 #[serde(rename = "text")]
226 pub text: String,
227
228 #[serde(rename = "sqlCode")]
229 pub sql_code: String,
230}
231
232impl AuthResponse {
233 pub fn is_success(&self) -> bool {
235 self.status == "ok" && self.response_data.is_some()
236 }
237
238 pub fn session_id(&self) -> Option<&str> {
240 self.response_data
241 .as_ref()
242 .map(|data| data.session_id.as_str())
243 }
244
245 pub fn error_message(&self) -> Option<String> {
247 self.exception
248 .as_ref()
249 .map(|exc| format!("{} ({})", exc.text, exc.sql_code))
250 }
251}
252
253pub struct AuthenticationHandler {
255 credentials: Credentials,
256 client_name: String,
257 client_version: String,
258}
259
260impl AuthenticationHandler {
261 pub fn new(credentials: Credentials, client_name: String, client_version: String) -> Self {
263 Self {
264 credentials,
265 client_name,
266 client_version,
267 }
268 }
269
270 pub fn build_auth_request(&self) -> AuthRequest {
272 AuthRequest::new(
273 self.credentials.username().to_string(),
274 self.credentials.password().to_string(),
275 )
276 .with_client_info(self.client_name.clone(), self.client_version.clone())
277 }
278
279 pub fn build_reconnect_request(&self, session_id: String) -> AuthRequest {
281 self.build_auth_request().with_session_id(session_id)
282 }
283
284 pub fn process_auth_response(
286 &self,
287 response: AuthResponse,
288 ) -> Result<AuthResponseData, ConnectionError> {
289 if response.is_success() {
290 response.response_data.ok_or_else(|| {
291 ConnectionError::AuthenticationFailed(
292 "Server returned success but no response data".to_string(),
293 )
294 })
295 } else {
296 let error_msg = response
297 .error_message()
298 .unwrap_or_else(|| "Unknown authentication error".to_string());
299 Err(ConnectionError::AuthenticationFailed(error_msg))
300 }
301 }
302
303 pub fn username(&self) -> &str {
305 self.credentials.username()
306 }
307}
308
309impl fmt::Debug for AuthenticationHandler {
310 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
311 f.debug_struct("AuthenticationHandler")
312 .field("credentials", &self.credentials)
313 .field("client_name", &self.client_name)
314 .field("client_version", &self.client_version)
315 .finish()
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler for user "admin" with the given password and fixed client info.
    fn handler_with_password(password: &str) -> AuthenticationHandler {
        let creds = Credentials::new("admin".to_string(), password.to_string());
        AuthenticationHandler::new(creds, "test-client".to_string(), "1.0.0".to_string())
    }

    /// A representative successful `responseData` payload.
    fn sample_response_data() -> AuthResponseData {
        AuthResponseData {
            session_id: "sess123".to_string(),
            protocol_version: 3,
            release_version: "7.1.0".to_string(),
            database_name: "EXA".to_string(),
            product_name: "Exasol".to_string(),
            max_data_message_size: 4_194_304,
            max_identifier_length: 128,
            max_varchar_length: 2_000_000,
            identifier_quote_string: "\"".to_string(),
            time_zone: "UTC".to_string(),
            time_zone_behavior: "INVALID TIMESTAMP TO DOUBLE".to_string(),
        }
    }

    /// A successful login reply.
    fn ok_response() -> AuthResponse {
        AuthResponse {
            status: "ok".to_string(),
            response_data: Some(sample_response_data()),
            exception: None,
        }
    }

    /// A failed login reply carrying an "Invalid credentials" exception.
    fn error_response() -> AuthResponse {
        AuthResponse {
            status: "error".to_string(),
            response_data: None,
            exception: Some(ExceptionInfo {
                text: "Invalid credentials".to_string(),
                sql_code: "08004".to_string(),
            }),
        }
    }

    #[test]
    fn test_credentials_no_password_leak() {
        let creds = Credentials::new("admin".to_string(), "secret123".to_string());

        let debug = format!("{:?}", creds);
        assert!(!debug.contains("secret123"));
        assert!(debug.contains("admin"));
        assert!(debug.contains("redacted"));

        let display = format!("{}", creds);
        assert!(!display.contains("secret123"));
        assert!(display.contains("admin"));
    }

    #[test]
    fn test_credentials_access() {
        let creds = Credentials::new("user".to_string(), "pass".to_string());

        assert_eq!(creds.username(), "user");
        assert_eq!(creds.password(), "pass");
    }

    #[test]
    fn test_secure_string_round_trip_and_redaction() {
        // Zeroing after drop cannot be observed soundly (the memory is freed),
        // so this checks the observable contract instead.
        let secure = SecureString::new("test1234".to_string());
        assert_eq!(secure.as_str(), "test1234");
        assert_eq!(secure.data, b"test1234".to_vec());

        let debug = format!("{:?}", secure);
        assert!(!debug.contains("test1234"));
        assert!(debug.contains("redacted"));

        // Dropping must not panic.
        drop(secure);
    }

    #[test]
    fn test_auth_request_creation() {
        let req = AuthRequest::new("admin".to_string(), "secret".to_string());

        assert_eq!(req.command, "login");
        assert_eq!(req.username, "admin");
        assert_eq!(req.password, "secret");
        assert_eq!(req.use_compression, Some(false));
        assert_eq!(req.session_id, None);
        assert_eq!(req.driver_name, Some("exarrow-rs".to_string()));
    }

    #[test]
    fn test_auth_request_with_client_info() {
        let req = AuthRequest::new("admin".to_string(), "secret".to_string())
            .with_client_info("test-client".to_string(), "1.0.0".to_string());

        assert_eq!(req.client_name, Some("test-client".to_string()));
        assert_eq!(req.client_version, Some("1.0.0".to_string()));
    }

    #[test]
    fn test_auth_request_other_builders() {
        let req = AuthRequest::new("admin".to_string(), "secret".to_string())
            .with_session_id("s1".to_string())
            .with_compression(true)
            .with_attributes(serde_json::json!({ "autocommit": true }));

        assert_eq!(req.session_id, Some("s1".to_string()));
        assert_eq!(req.use_compression, Some(true));
        assert_eq!(req.attributes, Some(serde_json::json!({ "autocommit": true })));
    }

    #[test]
    fn test_auth_request_no_password_leak() {
        let req = AuthRequest::new("admin".to_string(), "secret123".to_string());

        let display = format!("{}", req);
        assert!(!display.contains("secret123"));
        assert!(display.contains("admin"));
        assert!(display.contains("redacted"));
    }

    #[test]
    fn test_auth_response_success() {
        let response = ok_response();

        assert!(response.is_success());
        assert_eq!(response.session_id(), Some("sess123"));
        assert!(response.error_message().is_none());
    }

    #[test]
    fn test_auth_response_failure() {
        let response = error_response();

        assert!(!response.is_success());
        assert!(response.session_id().is_none());
        assert_eq!(
            response.error_message(),
            Some("Invalid credentials (08004)".to_string())
        );
    }

    #[test]
    fn test_auth_handler_build_request() {
        let handler = handler_with_password("secret");

        let req = handler.build_auth_request();

        assert_eq!(req.username, "admin");
        assert_eq!(req.password, "secret");
        assert_eq!(req.client_name, Some("test-client".to_string()));
        assert_eq!(req.client_version, Some("1.0.0".to_string()));
    }

    #[test]
    fn test_auth_handler_process_success() {
        let handler = handler_with_password("secret");

        let data = handler
            .process_auth_response(ok_response())
            .expect("ok response with data must succeed");

        assert_eq!(data.session_id, "sess123");
        assert_eq!(data.protocol_version, 3);
    }

    #[test]
    fn test_auth_handler_process_failure() {
        let handler = handler_with_password("secret");

        match handler.process_auth_response(error_response()) {
            Err(ConnectionError::AuthenticationFailed(msg)) => {
                assert!(msg.contains("Invalid credentials"));
            }
            Err(_) => panic!("Expected AuthenticationFailed error"),
            Ok(_) => panic!("Expected an error for a failed login"),
        }
    }

    #[test]
    fn test_auth_handler_ok_without_data_is_error() {
        let handler = handler_with_password("secret");
        let response = AuthResponse {
            status: "ok".to_string(),
            response_data: None,
            exception: None,
        };

        assert!(matches!(
            handler.process_auth_response(response),
            Err(ConnectionError::AuthenticationFailed(_))
        ));
    }

    #[test]
    fn test_auth_handler_no_password_leak() {
        let handler = handler_with_password("super_secret");

        let debug = format!("{:?}", handler);
        assert!(!debug.contains("super_secret"));
        assert!(debug.contains("admin"));
        assert!(debug.contains("redacted"));
    }

    #[test]
    fn test_reconnect_request() {
        let handler = handler_with_password("secret");

        let req = handler.build_reconnect_request("old_session_123".to_string());

        assert_eq!(req.session_id, Some("old_session_123".to_string()));
        assert_eq!(req.username, "admin");
    }

    #[test]
    fn test_credentials_clone() {
        let creds = Credentials::new("user".to_string(), "pass".to_string());
        let creds2 = creds.clone();

        assert_eq!(creds.username(), creds2.username());
        assert_eq!(creds.password(), creds2.password());
        // Clones share one secret buffer rather than copying it.
        assert!(Arc::ptr_eq(&creds.password, &creds2.password));
    }
}