Skip to main content

exarrow_rs/transport/
protocol.rs

1//! Transport protocol abstraction trait.
2//!
3//! This module defines the `TransportProtocol` trait that abstracts the underlying
4//! communication mechanism for connecting to Exasol. This allows for different
5//! protocol implementations (WebSocket in Phase 1, gRPC in Phase 2).
6
7use crate::error::TransportError;
8use async_trait::async_trait;
9
10use super::messages::{DataType, ResultData, ResultSetHandle, SessionInfo};
11
/// Connection parameters for establishing a transport connection.
#[derive(Debug, Clone)]
pub struct ConnectionParams {
    /// Database host: a hostname or an IP address. A bare IPv6 literal such as
    /// `::1` is bracketed automatically when the URL is built.
    pub host: String,
    /// Database port
    pub port: u16,
    /// Use TLS/SSL (default: true)
    pub use_tls: bool,
    /// Validate server certificate (default: true)
    pub validate_server_certificate: bool,
    /// Connection timeout in milliseconds (default: 30 000)
    pub timeout_ms: u64,
}

impl ConnectionParams {
    /// Create new connection parameters.
    ///
    /// Defaults: TLS enabled, server certificate validation enabled and a
    /// 30 second connection timeout.
    pub fn new(host: String, port: u16) -> Self {
        Self {
            host,
            port,
            use_tls: true,
            validate_server_certificate: true,
            timeout_ms: 30_000, // 30 seconds default
        }
    }

    /// Set whether to use TLS.
    #[must_use]
    pub fn with_tls(mut self, use_tls: bool) -> Self {
        self.use_tls = use_tls;
        self
    }

    /// Set whether to validate the server certificate.
    ///
    /// # Security Warning
    ///
    /// Disabling certificate validation makes the connection vulnerable to
    /// man-in-the-middle attacks. Only disable in development environments
    /// with self-signed certificates.
    #[must_use]
    pub fn with_validate_server_certificate(mut self, validate: bool) -> Self {
        self.validate_server_certificate = validate;
        self
    }

    /// Set connection timeout in milliseconds.
    #[must_use]
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = timeout_ms;
        self
    }

    /// Build the WebSocket URL from parameters.
    ///
    /// The scheme is `wss` when TLS is enabled and `ws` otherwise.
    ///
    /// A bare IPv6 literal host (e.g. `::1`) is wrapped in brackets, as RFC 3986
    /// §3.2.2 requires, so that its colons cannot be confused with the port
    /// separator. Already-bracketed hosts, hostnames and IPv4 addresses are
    /// used verbatim.
    pub fn to_websocket_url(&self) -> String {
        let scheme = if self.use_tls { "wss" } else { "ws" };
        if self.host.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("{}://[{}]:{}", scheme, self.host, self.port)
        } else {
            format!("{}://{}:{}", scheme, self.host, self.port)
        }
    }
}
69
/// User credentials for authentication.
///
/// The password is redacted from `Debug` output so it cannot leak through
/// logging. When the value is dropped, the password is overwritten with zero
/// bytes and cleared.
#[derive(Clone)]
pub struct Credentials {
    /// Username
    pub username: String,
    /// Password (overwritten with zeros when the credentials are dropped)
    pub password: String,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the secret itself; only show that one is present.
        f.debug_struct("Credentials")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}

impl Credentials {
    /// Create new credentials.
    pub fn new(username: String, password: String) -> Self {
        Self { username, password }
    }
}

/// Overwrite every byte of `secret` with `0x00`, keeping its length.
///
/// `String::clear` alone only resets the length and leaves the bytes in the
/// heap buffer. Volatile writes stop the compiler from removing the overwrite
/// as a dead store. This is best-effort only: copies made by `Clone`, or left
/// behind by earlier reallocations, are not reachable from here. For stronger
/// guarantees consider the `zeroize` crate.
fn overwrite_with_zeros(secret: &mut String) {
    // SAFETY: only 0x00 bytes are written. NUL is a valid one-byte UTF-8
    // sequence, so the string is still valid UTF-8 when this borrow ends.
    let bytes = unsafe { secret.as_bytes_mut() };
    for byte in bytes.iter_mut() {
        // SAFETY: `byte` comes from a live `&mut [u8]`, so it is valid,
        // aligned and exclusively borrowed for this write.
        unsafe { std::ptr::write_volatile(byte, 0) };
    }
    // Prevent the volatile writes from being reordered past later code
    // (e.g. the buffer's deallocation).
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}

// Security: wipe the password bytes before the buffer is freed.
impl Drop for Credentials {
    fn drop(&mut self) {
        overwrite_with_zeros(&mut self.password);
        self.password.clear();
    }
}
94
95/// Handle to a prepared statement on the server.
96#[derive(Debug, Clone)]
97pub struct PreparedStatementHandle {
98    /// Server-side statement identifier
99    pub handle: i32,
100    /// Number of parameters expected
101    pub num_params: i32,
102    /// Parameter type information (if available)
103    pub parameter_types: Vec<DataType>,
104}
105
106impl PreparedStatementHandle {
107    /// Create a new prepared statement handle.
108    pub fn new(handle: i32, num_params: i32, parameter_types: Vec<DataType>) -> Self {
109        Self {
110            handle,
111            num_params,
112            parameter_types,
113        }
114    }
115}
116
117/// Transport protocol trait for database communication.
118///
119/// This trait abstracts the underlying transport mechanism, allowing for
120/// different implementations (WebSocket, gRPC, etc.).
121#[async_trait]
122pub trait TransportProtocol: Send + Sync {
123    /// Connect to the database server.
124    ///
125    /// # Arguments
126    ///
127    /// * `params` - Connection parameters
128    ///
129    /// # Errors
130    ///
131    /// Returns `TransportError` if connection fails.
132    async fn connect(&mut self, params: &ConnectionParams) -> Result<(), TransportError>;
133
134    /// Authenticate with the database.
135    ///
136    /// # Arguments
137    ///
138    /// * `credentials` - User credentials
139    ///
140    /// # Returns
141    ///
142    /// Session information on successful authentication.
143    ///
144    /// # Errors
145    ///
146    /// Returns `TransportError` if authentication fails.
147    async fn authenticate(
148        &mut self,
149        credentials: &Credentials,
150    ) -> Result<SessionInfo, TransportError>;
151
152    /// Execute a SQL query.
153    ///
154    /// # Arguments
155    ///
156    /// * `sql` - SQL statement to execute
157    ///
158    /// # Returns
159    ///
160    /// Result set handle for SELECT queries, or result data for other statements.
161    ///
162    /// # Errors
163    ///
164    /// Returns `TransportError` if execution fails.
165    async fn execute_query(&mut self, sql: &str) -> Result<QueryResult, TransportError>;
166
167    /// Fetch result data from a result set.
168    ///
169    /// # Arguments
170    ///
171    /// * `handle` - Result set handle from execute_query
172    ///
173    /// # Returns
174    ///
175    /// Result data containing rows and metadata.
176    ///
177    /// # Errors
178    ///
179    /// Returns `TransportError` if fetch fails.
180    async fn fetch_results(
181        &mut self,
182        handle: ResultSetHandle,
183    ) -> Result<ResultData, TransportError>;
184
185    /// Close a result set.
186    ///
187    /// # Arguments
188    ///
189    /// * `handle` - Result set handle to close
190    ///
191    /// # Errors
192    ///
193    /// Returns `TransportError` if close fails.
194    async fn close_result_set(&mut self, handle: ResultSetHandle) -> Result<(), TransportError>;
195
196    /// Create a prepared statement from SQL text.
197    ///
198    /// # Arguments
199    ///
200    /// * `sql` - SQL statement with parameter placeholders (?)
201    ///
202    /// # Returns
203    ///
204    /// A statement handle for later execution.
205    ///
206    /// # Errors
207    ///
208    /// Returns `TransportError` if statement creation fails.
209    async fn create_prepared_statement(
210        &mut self,
211        sql: &str,
212    ) -> Result<PreparedStatementHandle, TransportError>;
213
214    /// Execute a prepared statement with parameters.
215    ///
216    /// # Arguments
217    ///
218    /// * `handle` - Prepared statement handle from create_prepared_statement
219    /// * `parameters` - Parameters in column-major format (each inner Vec is a column)
220    ///
221    /// # Returns
222    ///
223    /// Query result (result set or row count).
224    ///
225    /// # Errors
226    ///
227    /// Returns `TransportError` if execution fails.
228    async fn execute_prepared_statement(
229        &mut self,
230        handle: &PreparedStatementHandle,
231        parameters: Option<Vec<Vec<serde_json::Value>>>,
232    ) -> Result<QueryResult, TransportError>;
233
234    /// Close a prepared statement and release server-side resources.
235    ///
236    /// # Arguments
237    ///
238    /// * `handle` - Prepared statement handle to close
239    ///
240    /// # Errors
241    ///
242    /// Returns `TransportError` if close fails.
243    async fn close_prepared_statement(
244        &mut self,
245        handle: &PreparedStatementHandle,
246    ) -> Result<(), TransportError>;
247
248    /// Close the connection.
249    ///
250    /// # Errors
251    ///
252    /// Returns `TransportError` if disconnect fails.
253    async fn close(&mut self) -> Result<(), TransportError>;
254
255    /// Check if the connection is still active.
256    fn is_connected(&self) -> bool;
257}
258
259/// Result of a query execution.
260#[derive(Debug, Clone)]
261pub enum QueryResult {
262    /// Result set from a SELECT query
263    ResultSet {
264        /// Handle for fetching more data (None if all data is in the response)
265        handle: Option<ResultSetHandle>,
266        /// Result data (may include first batch of rows)
267        data: ResultData,
268    },
269    /// Row count from an INSERT/UPDATE/DELETE query
270    RowCount {
271        /// Number of affected rows
272        count: i64,
273    },
274}
275
276impl QueryResult {
277    /// Create a result set query result.
278    pub fn result_set(handle: Option<ResultSetHandle>, data: ResultData) -> Self {
279        Self::ResultSet { handle, data }
280    }
281
282    /// Create a row count query result.
283    pub fn row_count(count: i64) -> Self {
284        Self::RowCount { count }
285    }
286
287    /// Check if this is a result set.
288    pub fn is_result_set(&self) -> bool {
289        matches!(self, Self::ResultSet { .. })
290    }
291
292    /// Check if this is a row count.
293    pub fn is_row_count(&self) -> bool {
294        matches!(self, Self::RowCount { .. })
295    }
296
297    /// Get the result set handle if this is a result set.
298    /// Returns None if all data was included in the initial response.
299    pub fn handle(&self) -> Option<ResultSetHandle> {
300        match self {
301            Self::ResultSet { handle, .. } => *handle,
302            _ => None,
303        }
304    }
305
306    /// Get the row count if this is a row count result.
307    pub fn get_row_count(&self) -> Option<i64> {
308        match self {
309            Self::RowCount { count } => Some(*count),
310            _ => None,
311        }
312    }
313
314    /// Check if this result has more data to fetch.
315    ///
316    /// For column-major data, the number of rows is determined by the length
317    /// of the first column (or 0 if no columns).
318    pub fn has_more_data(&self) -> bool {
319        match self {
320            Self::ResultSet { handle, data } => {
321                // Has more if there's a handle AND we have fewer rows than total
322                // For column-major data: rows = data[0].len() if data is not empty
323                let num_rows = if data.data.is_empty() {
324                    0
325                } else {
326                    data.data[0].len() as i64
327                };
328                handle.is_some() && num_rows < data.total_rows
329            }
330            _ => false,
331        }
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::super::messages::ColumnInfo;
    use super::*;

    /// `DECIMAL(18, 0)` type description shared by several tests.
    fn decimal_type() -> DataType {
        DataType {
            type_name: "DECIMAL".to_string(),
            precision: Some(18),
            scale: Some(0),
            size: None,
            character_set: None,
            with_local_time_zone: None,
            fraction: None,
        }
    }

    /// Result data with a single `id DECIMAL(18, 0)` column, no rows in the
    /// current batch and the given total row count.
    fn empty_batch(total_rows: i64) -> ResultData {
        ResultData {
            columns: vec![ColumnInfo {
                name: "id".to_string(),
                data_type: decimal_type(),
            }],
            data: vec![], // Column-major: empty means no rows in this batch
            total_rows,
        }
    }

    #[test]
    fn test_connection_params_default() {
        let params = ConnectionParams::new("localhost".to_string(), 8563);
        assert_eq!(params.host, "localhost");
        assert_eq!(params.port, 8563);
        assert!(params.use_tls);
        assert!(params.validate_server_certificate);
        assert_eq!(params.timeout_ms, 30_000);
    }

    #[test]
    fn test_connection_params_builder() {
        let params = ConnectionParams::new("db.example.com".to_string(), 9000)
            .with_tls(false)
            .with_timeout(60_000);

        assert_eq!(params.host, "db.example.com");
        assert_eq!(params.port, 9000);
        assert!(!params.use_tls);
        assert!(params.validate_server_certificate);
        assert_eq!(params.timeout_ms, 60_000);
    }

    #[test]
    fn test_connection_params_validate_certificate_disabled() {
        let params = ConnectionParams::new("localhost".to_string(), 8563)
            .with_tls(true)
            .with_validate_server_certificate(false);

        assert!(params.use_tls);
        assert!(!params.validate_server_certificate);
    }

    #[test]
    fn test_websocket_url_with_tls() {
        let params = ConnectionParams::new("localhost".to_string(), 8563).with_tls(true);
        assert_eq!(params.to_websocket_url(), "wss://localhost:8563");
    }

    #[test]
    fn test_websocket_url_without_tls() {
        let params = ConnectionParams::new("localhost".to_string(), 8563).with_tls(false);
        assert_eq!(params.to_websocket_url(), "ws://localhost:8563");
    }

    #[test]
    fn test_credentials_creation() {
        let creds = Credentials::new("user".to_string(), "pass".to_string());
        assert_eq!(creds.username, "user");
        assert_eq!(creds.password, "pass");
    }

    #[test]
    fn test_credentials_drop_clears_password() {
        let creds = Credentials::new("user".to_string(), "secret".to_string());
        assert_eq!(creds.password, "secret");
        // Memory cannot be inspected after drop; this only checks that
        // running the Drop impl does not panic.
        drop(creds);
    }

    #[test]
    fn test_prepared_statement_handle_creation() {
        let param_types = vec![
            decimal_type(),
            DataType {
                type_name: "VARCHAR".to_string(),
                precision: None,
                scale: None,
                size: Some(100),
                character_set: Some("UTF8".to_string()),
                with_local_time_zone: None,
                fraction: None,
            },
        ];

        let handle = PreparedStatementHandle::new(42, 2, param_types);
        assert_eq!(handle.handle, 42);
        assert_eq!(handle.num_params, 2);
        assert_eq!(handle.parameter_types.len(), 2);
        assert_eq!(handle.parameter_types[0].type_name, "DECIMAL");
        assert_eq!(handle.parameter_types[1].type_name, "VARCHAR");
    }

    #[test]
    fn test_prepared_statement_handle_no_params() {
        let handle = PreparedStatementHandle::new(1, 0, vec![]);
        assert_eq!(handle.handle, 1);
        assert_eq!(handle.num_params, 0);
        assert!(handle.parameter_types.is_empty());
    }

    #[test]
    fn test_query_result_result_set() {
        let result = QueryResult::result_set(Some(ResultSetHandle::new(1)), empty_batch(0));
        assert!(result.is_result_set());
        assert!(!result.is_row_count());
        assert_eq!(result.handle().unwrap().as_i32(), 1);
        assert!(result.get_row_count().is_none());
    }

    #[test]
    fn test_query_result_row_count() {
        let result = QueryResult::row_count(42);
        assert!(!result.is_result_set());
        assert!(result.is_row_count());
        assert_eq!(result.get_row_count().unwrap(), 42);
        assert!(result.handle().is_none());
    }

    #[test]
    fn test_has_more_data_when_rows_remain_on_server() {
        let result = QueryResult::result_set(Some(ResultSetHandle::new(7)), empty_batch(3));
        assert!(result.has_more_data());
    }

    #[test]
    fn test_has_more_data_false_without_handle() {
        let result = QueryResult::result_set(None, empty_batch(3));
        assert!(result.handle().is_none());
        assert!(!result.has_more_data());
    }

    #[test]
    fn test_has_more_data_false_when_all_rows_received() {
        let result = QueryResult::result_set(Some(ResultSetHandle::new(7)), empty_batch(0));
        assert!(!result.has_more_data());
    }

    #[test]
    fn test_row_count_never_has_more_data() {
        assert!(!QueryResult::row_count(5).has_more_data());
    }
}