Skip to main content

exarrow_rs/connection/
auth.rs

1//! Authentication handling for Exasol connections.
2//!
3//! This module provides secure credential management and authentication
4//! protocol implementation.
5
6use crate::error::ConnectionError;
7use serde::{Deserialize, Serialize};
8use std::fmt;
9use std::sync::Arc;
10
/// Secure credentials container.
///
/// This struct ensures credentials are never accidentally logged or displayed:
/// the `Debug` and `Display` implementations in this module print the username
/// only and never the password.
///
/// Cloning is cheap for the password part: `#[derive(Clone)]` clones the `Arc`,
/// so every clone shares the same underlying [`SecureString`] buffer instead of
/// duplicating the secret.
#[derive(Clone)]
pub struct Credentials {
    /// Login name; not treated as a secret (it is printed by `Debug`/`Display`).
    username: String,
    /// Password, shared between clones through the `Arc`.
    password: Arc<SecureString>,
}
19
20impl Credentials {
21    /// Create new credentials.
22    pub fn new(username: String, password: String) -> Self {
23        Self {
24            username,
25            password: Arc::new(SecureString::new(password)),
26        }
27    }
28
29    /// Get the username.
30    pub fn username(&self) -> &str {
31        &self.username
32    }
33
34    /// Get the password (for internal use only).
35    pub(crate) fn password(&self) -> &str {
36        self.password.as_str()
37    }
38}
39
40impl fmt::Debug for Credentials {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.debug_struct("Credentials")
43            .field("username", &self.username)
44            .field("password", &"<redacted>")
45            .finish()
46    }
47}
48
49impl fmt::Display for Credentials {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        write!(f, "Credentials(username: {})", self.username)
52    }
53}
54
/// Secure string that zeros memory on drop and never displays its contents.
///
/// The text is stored as raw bytes taken from a `String`, so the buffer can be
/// overwritten in place before it is released.
struct SecureString {
    /// UTF-8 bytes of the secret. Only ever populated from a `String` (see
    /// [`SecureString::new`]) or overwritten with zero bytes (see
    /// [`SecureString::wipe`]), so it is always valid UTF-8.
    data: Vec<u8>,
}

impl SecureString {
    /// Take ownership of `s` and keep its bytes (no copy is made).
    fn new(s: String) -> Self {
        Self {
            data: s.into_bytes(),
        }
    }

    /// Borrow the stored bytes as text.
    fn as_str(&self) -> &str {
        // SAFETY: `data` is only ever filled from a `String` in `new` or
        // overwritten with NUL bytes in `wipe`; both are valid UTF-8.
        unsafe { std::str::from_utf8_unchecked(&self.data) }
    }

    /// Overwrite every stored byte with zero.
    ///
    /// Plain stores into a buffer that is about to be freed are dead stores
    /// the optimizer is allowed to remove. Volatile writes must be emitted,
    /// and the fence keeps the compiler from reordering them past the
    /// deallocation that follows in `drop`.
    fn wipe(&mut self) {
        for byte in self.data.iter_mut() {
            // SAFETY: `byte` is a valid, aligned, exclusive reference to an
            // initialized element of `data`.
            unsafe { std::ptr::write_volatile(byte, 0) };
        }
        std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
    }
}

impl Drop for SecureString {
    /// Zero out the password bytes before the buffer is released.
    fn drop(&mut self) {
        self.wipe();
    }
}

impl fmt::Debug for SecureString {
    /// Always prints a fixed placeholder; the contents are never formatted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "SecureString(<redacted>)")
    }
}
87
88/// Authentication request message for Exasol protocol.
89#[derive(Debug, Serialize, Deserialize)]
90pub struct AuthRequest {
91    #[serde(rename = "command")]
92    pub command: String,
93
94    #[serde(rename = "username")]
95    pub username: String,
96
97    #[serde(rename = "password")]
98    pub password: String,
99
100    #[serde(rename = "useCompression", skip_serializing_if = "Option::is_none")]
101    pub use_compression: Option<bool>,
102
103    #[serde(rename = "sessionId", skip_serializing_if = "Option::is_none")]
104    pub session_id: Option<String>,
105
106    #[serde(rename = "clientName", skip_serializing_if = "Option::is_none")]
107    pub client_name: Option<String>,
108
109    #[serde(rename = "clientVersion", skip_serializing_if = "Option::is_none")]
110    pub client_version: Option<String>,
111
112    #[serde(rename = "driverName", skip_serializing_if = "Option::is_none")]
113    pub driver_name: Option<String>,
114
115    #[serde(rename = "attributes", skip_serializing_if = "Option::is_none")]
116    pub attributes: Option<serde_json::Value>,
117}
118
119impl AuthRequest {
120    /// Create a new authentication request.
121    pub fn new(username: String, password: String) -> Self {
122        Self {
123            command: "login".to_string(),
124            username,
125            password,
126            use_compression: Some(false),
127            session_id: None,
128            client_name: None,
129            client_version: None,
130            driver_name: Some("exarrow-rs".to_string()),
131            attributes: None,
132        }
133    }
134
135    /// Set client information.
136    pub fn with_client_info(mut self, name: String, version: String) -> Self {
137        self.client_name = Some(name);
138        self.client_version = Some(version);
139        self
140    }
141
142    /// Set session ID for reconnection.
143    pub fn with_session_id(mut self, session_id: String) -> Self {
144        self.session_id = Some(session_id);
145        self
146    }
147
148    /// Enable compression.
149    pub fn with_compression(mut self, enabled: bool) -> Self {
150        self.use_compression = Some(enabled);
151        self
152    }
153
154    /// Add custom attributes.
155    pub fn with_attributes(mut self, attributes: serde_json::Value) -> Self {
156        self.attributes = Some(attributes);
157        self
158    }
159}
160
161// Custom Debug that doesn't leak password
162impl fmt::Display for AuthRequest {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        write!(
165            f,
166            "AuthRequest {{ username: {}, password: <redacted> }}",
167            self.username
168        )
169    }
170}
171
172/// Authentication response from Exasol server.
173#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct AuthResponse {
175    #[serde(rename = "status")]
176    pub status: String,
177
178    #[serde(rename = "responseData", skip_serializing_if = "Option::is_none")]
179    pub response_data: Option<AuthResponseData>,
180
181    #[serde(rename = "exception", skip_serializing_if = "Option::is_none")]
182    pub exception: Option<ExceptionInfo>,
183}
184
185/// Authentication response data.
186#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct AuthResponseData {
188    #[serde(rename = "sessionId")]
189    pub session_id: String,
190
191    #[serde(rename = "protocolVersion")]
192    pub protocol_version: i32,
193
194    #[serde(rename = "releaseVersion")]
195    pub release_version: String,
196
197    #[serde(rename = "databaseName")]
198    pub database_name: String,
199
200    #[serde(rename = "productName")]
201    pub product_name: String,
202
203    #[serde(rename = "maxDataMessageSize")]
204    pub max_data_message_size: i64,
205
206    #[serde(rename = "maxIdentifierLength")]
207    pub max_identifier_length: i32,
208
209    #[serde(rename = "maxVarcharLength")]
210    pub max_varchar_length: i64,
211
212    #[serde(rename = "identifierQuoteString")]
213    pub identifier_quote_string: String,
214
215    #[serde(rename = "timeZone")]
216    pub time_zone: String,
217
218    #[serde(rename = "timeZoneBehavior")]
219    pub time_zone_behavior: String,
220}
221
222/// Exception information from server.
223#[derive(Debug, Clone, Serialize, Deserialize)]
224pub struct ExceptionInfo {
225    #[serde(rename = "text")]
226    pub text: String,
227
228    #[serde(rename = "sqlCode")]
229    pub sql_code: String,
230}
231
232impl AuthResponse {
233    /// Check if authentication was successful.
234    pub fn is_success(&self) -> bool {
235        self.status == "ok" && self.response_data.is_some()
236    }
237
238    /// Get the session ID if authentication was successful.
239    pub fn session_id(&self) -> Option<&str> {
240        self.response_data
241            .as_ref()
242            .map(|data| data.session_id.as_str())
243    }
244
245    /// Get error message if authentication failed.
246    pub fn error_message(&self) -> Option<String> {
247        self.exception
248            .as_ref()
249            .map(|exc| format!("{} ({})", exc.text, exc.sql_code))
250    }
251}
252
/// Handler for authentication protocol.
///
/// Holds the credentials plus the client name/version that are attached to
/// every request it builds.
pub struct AuthenticationHandler {
    /// Username and password used to populate login requests.
    credentials: Credentials,
    /// Sent as the request's client name.
    client_name: String,
    /// Sent as the request's client version.
    client_version: String,
}
259
260impl AuthenticationHandler {
261    /// Create a new authentication handler.
262    pub fn new(credentials: Credentials, client_name: String, client_version: String) -> Self {
263        Self {
264            credentials,
265            client_name,
266            client_version,
267        }
268    }
269
270    /// Build an authentication request message.
271    pub fn build_auth_request(&self) -> AuthRequest {
272        AuthRequest::new(
273            self.credentials.username().to_string(),
274            self.credentials.password().to_string(),
275        )
276        .with_client_info(self.client_name.clone(), self.client_version.clone())
277    }
278
279    /// Build a reconnection request with session ID.
280    pub fn build_reconnect_request(&self, session_id: String) -> AuthRequest {
281        self.build_auth_request().with_session_id(session_id)
282    }
283
284    /// Process authentication response.
285    pub fn process_auth_response(
286        &self,
287        response: AuthResponse,
288    ) -> Result<AuthResponseData, ConnectionError> {
289        if response.is_success() {
290            response.response_data.ok_or_else(|| {
291                ConnectionError::AuthenticationFailed(
292                    "Server returned success but no response data".to_string(),
293                )
294            })
295        } else {
296            let error_msg = response
297                .error_message()
298                .unwrap_or_else(|| "Unknown authentication error".to_string());
299            Err(ConnectionError::AuthenticationFailed(error_msg))
300        }
301    }
302
303    /// Get the credentials username.
304    pub fn username(&self) -> &str {
305        self.credentials.username()
306    }
307}
308
309impl fmt::Debug for AuthenticationHandler {
310    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
311        f.debug_struct("AuthenticationHandler")
312            .field("credentials", &self.credentials)
313            .field("client_name", &self.client_name)
314            .field("client_version", &self.client_version)
315            .finish()
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler for user `admin` with the given password and a fixed
    /// `test-client` / `1.0.0` client identity.
    fn handler_with_password(password: &str) -> AuthenticationHandler {
        let creds = Credentials::new("admin".to_string(), password.to_string());
        AuthenticationHandler::new(creds, "test-client".to_string(), "1.0.0".to_string())
    }

    /// A fully populated successful response for session `sess123`.
    fn successful_response() -> AuthResponse {
        AuthResponse {
            status: "ok".to_string(),
            response_data: Some(AuthResponseData {
                session_id: "sess123".to_string(),
                protocol_version: 3,
                release_version: "7.1.0".to_string(),
                database_name: "EXA".to_string(),
                product_name: "Exasol".to_string(),
                max_data_message_size: 4_194_304,
                max_identifier_length: 128,
                max_varchar_length: 2_000_000,
                identifier_quote_string: "\"".to_string(),
                time_zone: "UTC".to_string(),
                time_zone_behavior: "INVALID TIMESTAMP TO DOUBLE".to_string(),
            }),
            exception: None,
        }
    }

    /// An error response carrying an "Invalid credentials" exception.
    fn rejected_response() -> AuthResponse {
        AuthResponse {
            status: "error".to_string(),
            response_data: None,
            exception: Some(ExceptionInfo {
                text: "Invalid credentials".to_string(),
                sql_code: "08004".to_string(),
            }),
        }
    }

    #[test]
    fn test_credentials_no_password_leak() {
        let creds = Credentials::new("admin".to_string(), "secret123".to_string());

        let debug = format!("{:?}", creds);
        assert!(!debug.contains("secret123"));
        assert!(debug.contains("admin"));
        assert!(debug.contains("redacted"));

        let display = format!("{}", creds);
        assert!(!display.contains("secret123"));
        assert!(display.contains("admin"));
    }

    #[test]
    fn test_credentials_access() {
        let creds = Credentials::new("user".to_string(), "pass".to_string());

        assert_eq!(creds.username(), "user");
        assert_eq!(creds.password(), "pass");
    }

    #[test]
    fn test_secure_string_zeros_on_drop() {
        // A copy taken before the drop is an independent buffer and must be
        // left untouched by the zeroing.
        let copy = {
            let secure = SecureString::new("password".to_string());
            secure.data.clone()
        };
        assert_eq!(copy, b"password".to_vec());

        // Reading the buffer after the drop would be a use-after-free, so the
        // zeroing itself cannot be observed from safe code here; this only
        // checks that dropping runs to completion without panicking.
        let secure = SecureString::new("test1234".to_string());
        assert_eq!(secure.as_str(), "test1234");
        drop(secure);
    }

    #[test]
    fn test_auth_request_creation() {
        let req = AuthRequest::new("admin".to_string(), "secret".to_string());

        assert_eq!(req.command, "login");
        assert_eq!(req.username, "admin");
        assert_eq!(req.password, "secret");
        assert_eq!(req.driver_name, Some("exarrow-rs".to_string()));
    }

    #[test]
    fn test_auth_request_with_client_info() {
        let req = AuthRequest::new("admin".to_string(), "secret".to_string())
            .with_client_info("test-client".to_string(), "1.0.0".to_string());

        assert_eq!(req.client_name, Some("test-client".to_string()));
        assert_eq!(req.client_version, Some("1.0.0".to_string()));
    }

    #[test]
    fn test_auth_request_no_password_leak() {
        let req = AuthRequest::new("admin".to_string(), "secret123".to_string());

        let display = format!("{}", req);
        assert!(!display.contains("secret123"));
        assert!(display.contains("admin"));
        assert!(display.contains("redacted"));
    }

    #[test]
    fn test_auth_response_success() {
        let response = successful_response();

        assert!(response.is_success());
        assert_eq!(response.session_id(), Some("sess123"));
        assert!(response.error_message().is_none());
    }

    #[test]
    fn test_auth_response_failure() {
        let response = rejected_response();

        assert!(!response.is_success());
        assert!(response.session_id().is_none());
        assert_eq!(
            response.error_message(),
            Some("Invalid credentials (08004)".to_string())
        );
    }

    #[test]
    fn test_auth_handler_build_request() {
        let handler = handler_with_password("secret");

        let req = handler.build_auth_request();

        assert_eq!(req.username, "admin");
        assert_eq!(req.password, "secret");
        assert_eq!(req.client_name, Some("test-client".to_string()));
        assert_eq!(req.client_version, Some("1.0.0".to_string()));
    }

    #[test]
    fn test_auth_handler_process_success() {
        let handler = handler_with_password("secret");

        let result = handler.process_auth_response(successful_response());
        assert!(result.is_ok());

        let data = result.unwrap();
        assert_eq!(data.session_id, "sess123");
        assert_eq!(data.protocol_version, 3);
    }

    #[test]
    fn test_auth_handler_process_failure() {
        let handler = handler_with_password("secret");

        let result = handler.process_auth_response(rejected_response());
        assert!(result.is_err());

        match result.unwrap_err() {
            ConnectionError::AuthenticationFailed(msg) => {
                assert!(msg.contains("Invalid credentials"));
            }
            _ => panic!("Expected AuthenticationFailed error"),
        }
    }

    #[test]
    fn test_auth_handler_no_password_leak() {
        let handler = handler_with_password("super_secret");

        let debug = format!("{:?}", handler);
        assert!(!debug.contains("super_secret"));
        assert!(debug.contains("admin"));
        assert!(debug.contains("redacted"));
    }

    #[test]
    fn test_reconnect_request() {
        let handler = handler_with_password("secret");

        let req = handler.build_reconnect_request("old_session_123".to_string());

        assert_eq!(req.session_id, Some("old_session_123".to_string()));
        assert_eq!(req.username, "admin");
    }

    #[test]
    fn test_credentials_clone() {
        let creds = Credentials::new("user".to_string(), "pass".to_string());
        let creds2 = creds.clone();

        assert_eq!(creds.username(), creds2.username());
        assert_eq!(creds.password(), creds2.password());
    }
}