1use crate::error::TransportError;
8use async_trait::async_trait;
9
10use super::messages::{DataType, ResultData, ResultSetHandle, SessionInfo};
11
/// Parameters describing how to reach a database server.
///
/// Create with [`ConnectionParams::new`] (secure defaults) and adjust with the
/// consuming `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ConnectionParams {
    /// Server host name or IP address. IPv6 literals may be given with or
    /// without surrounding brackets.
    pub host: String,
    /// Server port.
    pub port: u16,
    /// Use TLS (`wss://`) instead of plain WebSocket (`ws://`).
    pub use_tls: bool,
    /// Whether the server's TLS certificate should be validated.
    pub validate_server_certificate: bool,
    /// Timeout in milliseconds. Which operations it bounds is up to the
    /// transport implementation.
    pub timeout_ms: u64,
}

impl ConnectionParams {
    /// Creates parameters for `host:port` with secure defaults: TLS enabled,
    /// certificate validation enabled, and a 30 second (`30_000` ms) timeout.
    pub fn new(host: String, port: u16) -> Self {
        Self {
            host,
            port,
            use_tls: true,
            validate_server_certificate: true,
            timeout_ms: 30_000,
        }
    }

    /// Enables or disables TLS.
    #[must_use]
    pub fn with_tls(mut self, use_tls: bool) -> Self {
        self.use_tls = use_tls;
        self
    }

    /// Enables or disables validation of the server's TLS certificate.
    ///
    /// Disabling validation makes the connection vulnerable to
    /// man-in-the-middle attacks; reserve it for testing or controlled
    /// environments (e.g. self-signed certificates).
    #[must_use]
    pub fn with_validate_server_certificate(mut self, validate: bool) -> Self {
        self.validate_server_certificate = validate;
        self
    }

    /// Sets the timeout in milliseconds.
    #[must_use]
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = timeout_ms;
        self
    }

    /// Builds the WebSocket URL, e.g. `wss://localhost:8563`.
    ///
    /// Bare IPv6 literals are wrapped in brackets (`wss://[::1]:8563`) as
    /// required by RFC 3986 §3.2.2; already-bracketed hosts are left as-is.
    pub fn to_websocket_url(&self) -> String {
        let scheme = if self.use_tls { "wss" } else { "ws" };
        // Without brackets the colons of an IPv6 address are ambiguous with
        // the `:port` separator and the URL fails to parse.
        let needs_brackets = self.host.contains(':') && !self.host.starts_with('[');
        if needs_brackets {
            format!("{}://[{}]:{}", scheme, self.host, self.port)
        } else {
            format!("{}://{}:{}", scheme, self.host, self.port)
        }
    }
}
69
/// Username/password pair used to authenticate a session.
///
/// `Debug` output redacts the password so values can be logged safely, and
/// the password bytes are overwritten before the buffer is freed on drop.
#[derive(Clone)]
pub struct Credentials {
    /// Database user name.
    pub username: String,
    /// Plain-text password. Never printed by `Debug`; scrubbed on drop.
    pub password: String,
}

impl Credentials {
    /// Creates a new credential pair.
    pub fn new(username: String, password: String) -> Self {
        Self { username, password }
    }
}

impl std::fmt::Debug for Credentials {
    /// Formats the credentials with the password replaced by `<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}

impl Drop for Credentials {
    /// Best-effort scrub of the password before its allocation is released.
    fn drop(&mut self) {
        // `String::clear` only resets the length and leaves the secret bytes
        // in the heap buffer, so overwrite them explicitly. `into_bytes` reuses
        // the same allocation (no copy). `black_box` discourages the optimizer
        // from eliding stores to memory that is about to be freed. It is
        // best-effort only; a dedicated crate such as `zeroize` gives a
        // guaranteed wipe. `self.password` is left empty, as before.
        let mut bytes = std::mem::take(&mut self.password).into_bytes();
        bytes.fill(0);
        std::hint::black_box(&bytes);
    }
}
94
95#[derive(Debug, Clone)]
97pub struct PreparedStatementHandle {
98 pub handle: i32,
100 pub num_params: i32,
102 pub parameter_types: Vec<DataType>,
104}
105
106impl PreparedStatementHandle {
107 pub fn new(handle: i32, num_params: i32, parameter_types: Vec<DataType>) -> Self {
109 Self {
110 handle,
111 num_params,
112 parameter_types,
113 }
114 }
115}
116
/// Abstraction over the protocol used to talk to the database server.
///
/// Implementors must be `Send + Sync`. Methods are async (via `async_trait`)
/// and report failures as `TransportError`.
///
/// NOTE(review): an expected call order (connect → authenticate → queries →
/// close) is implied by the method names but not enforced by this trait —
/// confirm against implementations.
#[async_trait]
pub trait TransportProtocol: Send + Sync {
    /// Opens the connection described by `params`.
    async fn connect(&mut self, params: &ConnectionParams) -> Result<(), TransportError>;

    /// Authenticates with `credentials`, returning the resulting `SessionInfo`.
    async fn authenticate(
        &mut self,
        credentials: &Credentials,
    ) -> Result<SessionInfo, TransportError>;

    /// Executes `sql`; the result is either a result set or a row count
    /// (see `QueryResult`).
    async fn execute_query(&mut self, sql: &str) -> Result<QueryResult, TransportError>;

    /// Fetches data for the result set identified by `handle`.
    ///
    /// NOTE(review): whether this returns the next batch or all remaining
    /// rows is implementation-defined here — confirm.
    async fn fetch_results(
        &mut self,
        handle: ResultSetHandle,
    ) -> Result<ResultData, TransportError>;

    /// Closes the result set identified by `handle`.
    async fn close_result_set(&mut self, handle: ResultSetHandle) -> Result<(), TransportError>;

    /// Prepares `sql`, returning a handle that carries its parameter types.
    async fn create_prepared_statement(
        &mut self,
        sql: &str,
    ) -> Result<PreparedStatementHandle, TransportError>;

    /// Executes the prepared statement `handle`, optionally binding `parameters`.
    ///
    /// NOTE(review): the nesting order of `parameters` (row-major vs.
    /// column-major) is not visible here — confirm with implementations.
    async fn execute_prepared_statement(
        &mut self,
        handle: &PreparedStatementHandle,
        parameters: Option<Vec<Vec<serde_json::Value>>>,
    ) -> Result<QueryResult, TransportError>;

    /// Closes the prepared statement identified by `handle`.
    async fn close_prepared_statement(
        &mut self,
        handle: &PreparedStatementHandle,
    ) -> Result<(), TransportError>;

    /// Closes the connection.
    async fn close(&mut self) -> Result<(), TransportError>;

    /// Reports whether the transport currently considers itself connected.
    fn is_connected(&self) -> bool;
}
258
259#[derive(Debug, Clone)]
261pub enum QueryResult {
262 ResultSet {
264 handle: Option<ResultSetHandle>,
266 data: ResultData,
268 },
269 RowCount {
271 count: i64,
273 },
274}
275
276impl QueryResult {
277 pub fn result_set(handle: Option<ResultSetHandle>, data: ResultData) -> Self {
279 Self::ResultSet { handle, data }
280 }
281
282 pub fn row_count(count: i64) -> Self {
284 Self::RowCount { count }
285 }
286
287 pub fn is_result_set(&self) -> bool {
289 matches!(self, Self::ResultSet { .. })
290 }
291
292 pub fn is_row_count(&self) -> bool {
294 matches!(self, Self::RowCount { .. })
295 }
296
297 pub fn handle(&self) -> Option<ResultSetHandle> {
300 match self {
301 Self::ResultSet { handle, .. } => *handle,
302 _ => None,
303 }
304 }
305
306 pub fn get_row_count(&self) -> Option<i64> {
308 match self {
309 Self::RowCount { count } => Some(*count),
310 _ => None,
311 }
312 }
313
314 pub fn has_more_data(&self) -> bool {
319 match self {
320 Self::ResultSet { handle, data } => {
321 let num_rows = if data.data.is_empty() {
324 0
325 } else {
326 data.data[0].len() as i64
327 };
328 handle.is_some() && num_rows < data.total_rows
329 }
330 _ => false,
331 }
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::super::messages::ColumnInfo;
    use super::*;

    /// `DECIMAL(18,0)` type shared by several fixtures.
    fn decimal_18_0() -> DataType {
        DataType {
            type_name: "DECIMAL".to_string(),
            precision: Some(18),
            scale: Some(0),
            size: None,
            character_set: None,
            with_local_time_zone: None,
            fraction: None,
        }
    }

    /// Single `id` column with no delivered rows and the given `total_rows`.
    fn id_column_without_rows(total_rows: i64) -> ResultData {
        ResultData {
            columns: vec![ColumnInfo {
                name: "id".to_string(),
                data_type: decimal_18_0(),
            }],
            data: vec![],
            total_rows,
        }
    }

    #[test]
    fn test_connection_params_default() {
        let params = ConnectionParams::new("localhost".to_string(), 8563);
        assert_eq!(params.host, "localhost");
        assert_eq!(params.port, 8563);
        assert!(params.use_tls);
        assert!(params.validate_server_certificate);
        assert_eq!(params.timeout_ms, 30_000);
    }

    #[test]
    fn test_connection_params_builder() {
        let params = ConnectionParams::new("db.example.com".to_string(), 9000)
            .with_tls(false)
            .with_timeout(60_000);

        assert_eq!(params.host, "db.example.com");
        assert_eq!(params.port, 9000);
        assert!(!params.use_tls);
        assert!(params.validate_server_certificate);
        assert_eq!(params.timeout_ms, 60_000);
    }

    #[test]
    fn test_connection_params_validate_certificate_disabled() {
        let params = ConnectionParams::new("localhost".to_string(), 8563)
            .with_tls(true)
            .with_validate_server_certificate(false);

        assert!(params.use_tls);
        assert!(!params.validate_server_certificate);
    }

    #[test]
    fn test_websocket_url_with_tls() {
        let params = ConnectionParams::new("localhost".to_string(), 8563).with_tls(true);
        assert_eq!(params.to_websocket_url(), "wss://localhost:8563");
    }

    #[test]
    fn test_websocket_url_without_tls() {
        let params = ConnectionParams::new("localhost".to_string(), 8563).with_tls(false);
        assert_eq!(params.to_websocket_url(), "ws://localhost:8563");
    }

    #[test]
    fn test_credentials_creation() {
        let creds = Credentials::new("user".to_string(), "pass".to_string());
        assert_eq!(creds.username, "user");
        assert_eq!(creds.password, "pass");
    }

    #[test]
    fn test_credentials_drop_of_clone_leaves_original_intact() {
        // The scrub performed in `Drop` cannot be observed from safe code once
        // the value is gone, so this checks what *is* observable: each clone
        // owns its own buffer, and dropping one does not affect the other.
        let original = Credentials::new("user".to_string(), "secret".to_string());
        let copy = original.clone();
        drop(copy);
        assert_eq!(original.username, "user");
        assert_eq!(original.password, "secret");
    }

    #[test]
    fn test_prepared_statement_handle_creation() {
        let param_types = vec![
            decimal_18_0(),
            DataType {
                type_name: "VARCHAR".to_string(),
                precision: None,
                scale: None,
                size: Some(100),
                character_set: Some("UTF8".to_string()),
                with_local_time_zone: None,
                fraction: None,
            },
        ];

        let handle = PreparedStatementHandle::new(42, 2, param_types);
        assert_eq!(handle.handle, 42);
        assert_eq!(handle.num_params, 2);
        assert_eq!(handle.parameter_types.len(), 2);
        assert_eq!(handle.parameter_types[0].type_name, "DECIMAL");
        assert_eq!(handle.parameter_types[1].type_name, "VARCHAR");
    }

    #[test]
    fn test_prepared_statement_handle_no_params() {
        let handle = PreparedStatementHandle::new(1, 0, vec![]);
        assert_eq!(handle.handle, 1);
        assert_eq!(handle.num_params, 0);
        assert!(handle.parameter_types.is_empty());
    }

    #[test]
    fn test_query_result_result_set() {
        let result =
            QueryResult::result_set(Some(ResultSetHandle::new(1)), id_column_without_rows(0));
        assert!(result.is_result_set());
        assert!(!result.is_row_count());
        assert_eq!(result.handle().unwrap().as_i32(), 1);
        assert!(result.get_row_count().is_none());
        assert!(!result.has_more_data());
    }

    #[test]
    fn test_query_result_row_count() {
        let result = QueryResult::row_count(42);
        assert!(!result.is_result_set());
        assert!(result.is_row_count());
        assert_eq!(result.get_row_count().unwrap(), 42);
        assert!(result.handle().is_none());
        assert!(!result.has_more_data());
    }

    #[test]
    fn test_query_result_has_more_data() {
        // Handle present and fewer rows delivered than reported: more to fetch.
        let pending =
            QueryResult::result_set(Some(ResultSetHandle::new(7)), id_column_without_rows(5));
        assert!(pending.has_more_data());

        // Without a handle nothing further can be fetched.
        let no_handle = QueryResult::result_set(None, id_column_without_rows(5));
        assert!(!no_handle.has_more_data());

        // Everything reported has been delivered.
        let complete =
            QueryResult::result_set(Some(ResultSetHandle::new(7)), id_column_without_rows(0));
        assert!(!complete.has_more_data());
    }
}
474}