use crate::error::TransportError;
use async_trait::async_trait;
use super::messages::{DataType, ResultData, ResultSetHandle, SessionInfo};
/// Settings used to open a transport connection to the database server.
///
/// Construct with [`ConnectionParams::new`] and refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ConnectionParams {
    /// Host name or IP address of the server. IPv6 literals may be given with or
    /// without surrounding brackets; [`ConnectionParams::to_websocket_url`] adds
    /// them when missing.
    pub host: String,
    /// Server port.
    pub port: u16,
    /// `true` selects the `wss` URL scheme, `false` the plain `ws` scheme.
    pub use_tls: bool,
    /// Whether the server certificate should be validated (defaults to `true`).
    pub validate_server_certificate: bool,
    /// Optional certificate fingerprint. How it is checked is up to the transport
    /// implementation (not visible in this struct).
    pub certificate_fingerprint: Option<String>,
    /// Timeout in milliseconds (defaults to 30 000). Which operations it bounds is
    /// decided by the transport implementation.
    pub timeout_ms: u64,
}

impl ConnectionParams {
    /// Creates parameters for `host:port` with secure defaults: TLS on, certificate
    /// validation on, no fingerprint, and a 30 000 ms timeout.
    pub fn new(host: String, port: u16) -> Self {
        Self {
            host,
            port,
            use_tls: true,
            validate_server_certificate: true,
            certificate_fingerprint: None,
            timeout_ms: 30_000,
        }
    }

    /// Enables or disables TLS (`wss` vs `ws`).
    pub fn with_tls(mut self, use_tls: bool) -> Self {
        self.use_tls = use_tls;
        self
    }

    /// Enables or disables validation of the server certificate.
    pub fn with_validate_server_certificate(mut self, validate: bool) -> Self {
        self.validate_server_certificate = validate;
        self
    }

    /// Sets the certificate fingerprint.
    pub fn with_certificate_fingerprint(mut self, fingerprint: String) -> Self {
        self.certificate_fingerprint = Some(fingerprint);
        self
    }

    /// Sets the timeout in milliseconds.
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = timeout_ms;
        self
    }

    /// Builds the WebSocket URL `scheme://host:port`, where the scheme is `wss`
    /// when TLS is enabled and `ws` otherwise.
    ///
    /// IPv6 literal hosts (e.g. `::1`) are wrapped in brackets, giving
    /// `wss://[::1]:8563`. Hosts that are already bracketed are left unchanged.
    pub fn to_websocket_url(&self) -> String {
        let scheme = if self.use_tls { "wss" } else { "ws" };
        // An IPv6 literal contains ':' and must be bracketed inside a URL authority
        // (RFC 3986 §3.2.2); otherwise its colons are confused with the port separator.
        if self.host.contains(':') && !self.host.starts_with('[') {
            format!("{}://[{}]:{}", scheme, self.host, self.port)
        } else {
            format!("{}://{}:{}", scheme, self.host, self.port)
        }
    }
}
/// Username/password pair used to authenticate a session.
///
/// The `Debug` output redacts the password. When a value is dropped, its password
/// bytes are overwritten with zeros. Each `clone()` is an independent copy that is
/// wiped when it is dropped. Copies left behind by earlier reallocations of the
/// `String`, for example while it was being built, cannot be reached and are not wiped.
#[derive(Clone)]
pub struct Credentials {
    pub username: String,
    pub password: String,
}

impl Credentials {
    /// Creates credentials from a username and password.
    pub fn new(username: String, password: String) -> Self {
        Self { username, password }
    }

    /// Overwrites the initialized bytes of `secret` with zeros, then empties it.
    ///
    /// Volatile writes are used so the compiler cannot drop the stores as dead
    /// writes to memory that is about to be freed.
    fn wipe_secret(secret: &mut String) {
        // SAFETY: only 0x00 bytes are written, and NUL is a valid one-byte UTF-8
        // sequence, so the String's UTF-8 invariant holds throughout.
        let bytes = unsafe { secret.as_mut_vec() };
        for byte in bytes.iter_mut() {
            // SAFETY: `byte` is a valid, aligned, exclusive reference into the buffer.
            unsafe { std::ptr::write_volatile(byte, 0) };
        }
        // Keep the wipes from being reordered after the subsequent deallocation.
        std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
        secret.clear();
    }
}

impl std::fmt::Debug for Credentials {
    /// Formats the username but never the password, so credentials can be logged safely.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}

impl Drop for Credentials {
    /// Scrubs the password bytes before the buffer is released.
    fn drop(&mut self) {
        Self::wipe_secret(&mut self.password);
    }
}
/// Handle to a prepared statement, together with the parameter metadata that
/// accompanies it.
#[derive(Debug, Clone)]
pub struct PreparedStatementHandle {
    /// Numeric identifier of the statement.
    pub handle: i32,
    /// Number of parameters the statement takes.
    pub num_params: i32,
    /// Declared type of each parameter.
    pub parameter_types: Vec<DataType>,
    /// Optional name of each parameter.
    pub parameter_names: Vec<Option<String>>,
}

impl PreparedStatementHandle {
    /// Bundles a statement identifier with its parameter metadata.
    ///
    /// No consistency check is made between `num_params` and the lengths of the
    /// type and name vectors; the values are stored exactly as given.
    pub fn new(
        id: i32,
        param_count: i32,
        types: Vec<DataType>,
        names: Vec<Option<String>>,
    ) -> Self {
        PreparedStatementHandle {
            handle: id,
            num_params: param_count,
            parameter_types: types,
            parameter_names: names,
        }
    }
}
#[async_trait]
pub trait TransportProtocol: Send + Sync {
async fn connect(&mut self, params: &ConnectionParams) -> Result<(), TransportError>;
async fn authenticate(
&mut self,
credentials: &Credentials,
) -> Result<SessionInfo, TransportError>;
async fn execute_query(&mut self, sql: &str) -> Result<QueryResult, TransportError>;
async fn fetch_results(
&mut self,
handle: ResultSetHandle,
) -> Result<ResultData, TransportError>;
async fn close_result_set(&mut self, handle: ResultSetHandle) -> Result<(), TransportError>;
async fn create_prepared_statement(
&mut self,
sql: &str,
) -> Result<PreparedStatementHandle, TransportError>;
async fn execute_prepared_statement(
&mut self,
handle: &PreparedStatementHandle,
parameters: Option<Vec<Vec<serde_json::Value>>>,
) -> Result<QueryResult, TransportError>;
async fn close_prepared_statement(
&mut self,
handle: &PreparedStatementHandle,
) -> Result<(), TransportError>;
async fn close(&mut self) -> Result<(), TransportError>;
fn is_connected(&self) -> bool;
async fn set_autocommit(&mut self, enabled: bool) -> Result<(), TransportError>;
}
/// Outcome of executing a statement: either a result set or an affected-row count.
#[derive(Debug, Clone)]
pub enum QueryResult {
    /// The statement produced tabular data.
    ResultSet {
        /// Handle for the result set. `None` means [`QueryResult::has_more_data`]
        /// always reports `false`.
        handle: Option<ResultSetHandle>,
        /// Data returned with the response.
        data: ResultData,
    },
    /// The statement reported a row count.
    RowCount {
        /// The reported count.
        count: i64,
    },
}

impl QueryResult {
    /// Constructs the `ResultSet` variant.
    pub fn result_set(handle: Option<ResultSetHandle>, data: ResultData) -> Self {
        QueryResult::ResultSet { handle, data }
    }

    /// Constructs the `RowCount` variant.
    pub fn row_count(count: i64) -> Self {
        QueryResult::RowCount { count }
    }

    /// `true` for the `ResultSet` variant.
    pub fn is_result_set(&self) -> bool {
        match self {
            QueryResult::ResultSet { .. } => true,
            QueryResult::RowCount { .. } => false,
        }
    }

    /// `true` for the `RowCount` variant.
    pub fn is_row_count(&self) -> bool {
        match self {
            QueryResult::RowCount { .. } => true,
            QueryResult::ResultSet { .. } => false,
        }
    }

    /// The result-set handle, or `None` for a row count or a handle-less result set.
    pub fn handle(&self) -> Option<ResultSetHandle> {
        match self {
            QueryResult::ResultSet { handle, .. } => *handle,
            QueryResult::RowCount { .. } => None,
        }
    }

    /// The row count, or `None` for a result set.
    pub fn get_row_count(&self) -> Option<i64> {
        match *self {
            QueryResult::RowCount { count } => Some(count),
            QueryResult::ResultSet { .. } => None,
        }
    }

    /// `true` when this is a result set with a handle and fewer rows were received
    /// than `total_rows`.
    ///
    /// The number of received rows is the length of the first inner vector of
    /// `data.data`, or 0 when it is empty. In other words, `data.data` is treated
    /// as column-major.
    pub fn has_more_data(&self) -> bool {
        match self {
            QueryResult::ResultSet { handle, data } => {
                let received = data.data.first().map_or(0, |column| column.len() as i64);
                handle.is_some() && received < data.total_rows
            }
            QueryResult::RowCount { .. } => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_params_default() {
        let params = ConnectionParams::new("localhost".to_string(), 8563);
        assert_eq!(params.host, "localhost");
        assert_eq!(params.port, 8563);
        assert!(params.use_tls);
        assert!(params.validate_server_certificate);
        assert!(params.certificate_fingerprint.is_none());
        assert_eq!(params.timeout_ms, 30_000);
    }

    #[test]
    fn test_connection_params_builder() {
        let params = ConnectionParams::new("db.example.com".to_string(), 9000)
            .with_tls(false)
            .with_timeout(60_000);
        assert_eq!(params.host, "db.example.com");
        assert_eq!(params.port, 9000);
        assert!(!params.use_tls);
        assert!(params.validate_server_certificate);
        assert_eq!(params.timeout_ms, 60_000);
    }

    #[test]
    fn test_connection_params_certificate_fingerprint() {
        let params = ConnectionParams::new("localhost".to_string(), 8563)
            .with_certificate_fingerprint("ABCDEF".to_string());
        assert_eq!(params.certificate_fingerprint.as_deref(), Some("ABCDEF"));
    }

    #[test]
    fn test_connection_params_validate_certificate_disabled() {
        let params = ConnectionParams::new("localhost".to_string(), 8563)
            .with_tls(true)
            .with_validate_server_certificate(false);
        assert!(params.use_tls);
        assert!(!params.validate_server_certificate);
    }

    #[test]
    fn test_websocket_url_with_tls() {
        let params = ConnectionParams::new("localhost".to_string(), 8563).with_tls(true);
        assert_eq!(params.to_websocket_url(), "wss://localhost:8563");
    }

    #[test]
    fn test_websocket_url_without_tls() {
        let params = ConnectionParams::new("localhost".to_string(), 8563).with_tls(false);
        assert_eq!(params.to_websocket_url(), "ws://localhost:8563");
    }

    #[test]
    fn test_credentials_creation() {
        let creds = Credentials::new("user".to_string(), "pass".to_string());
        assert_eq!(creds.username, "user");
        assert_eq!(creds.password, "pass");
    }

    /// Dropping one copy of the credentials must not affect another copy. The
    /// `Drop` impl only touches the instance being dropped.
    #[test]
    fn test_credentials_drop_is_per_instance() {
        let creds = Credentials::new("user".to_string(), "secret".to_string());
        let copy = creds.clone();
        drop(copy);
        assert_eq!(creds.username, "user");
        assert_eq!(creds.password, "secret");
        drop(creds);
    }

    #[test]
    fn test_prepared_statement_handle_creation() {
        let param_types = vec![
            DataType {
                type_name: "DECIMAL".to_string(),
                precision: Some(18),
                scale: Some(0),
                size: None,
                character_set: None,
                with_local_time_zone: None,
                fraction: None,
            },
            DataType {
                type_name: "VARCHAR".to_string(),
                precision: None,
                scale: None,
                size: Some(100),
                character_set: Some("UTF8".to_string()),
                with_local_time_zone: None,
                fraction: None,
            },
        ];
        let handle = PreparedStatementHandle::new(42, 2, param_types, vec![]);
        assert_eq!(handle.handle, 42);
        assert_eq!(handle.num_params, 2);
        assert_eq!(handle.parameter_types.len(), 2);
        assert_eq!(handle.parameter_types[0].type_name, "DECIMAL");
        assert_eq!(handle.parameter_types[1].type_name, "VARCHAR");
    }

    #[test]
    fn test_prepared_statement_handle_no_params() {
        let handle = PreparedStatementHandle::new(1, 0, vec![], vec![]);
        assert_eq!(handle.handle, 1);
        assert_eq!(handle.num_params, 0);
        assert!(handle.parameter_types.is_empty());
        assert!(handle.parameter_names.is_empty());
    }

    #[test]
    fn test_query_result_result_set() {
        use super::super::messages::{ColumnInfo, DataType, ResultData};
        let data = ResultData {
            columns: vec![ColumnInfo {
                name: "id".to_string(),
                data_type: DataType {
                    type_name: "DECIMAL".to_string(),
                    precision: Some(18),
                    scale: Some(0),
                    size: None,
                    character_set: None,
                    with_local_time_zone: None,
                    fraction: None,
                },
            }],
            data: vec![],
            total_rows: 0,
        };
        let result = QueryResult::result_set(Some(ResultSetHandle::new(1)), data);
        assert!(result.is_result_set());
        assert!(!result.is_row_count());
        assert_eq!(result.handle().unwrap().as_i32(), 1);
        assert!(result.get_row_count().is_none());
    }

    #[test]
    fn test_query_result_row_count() {
        let result = QueryResult::row_count(42);
        assert!(!result.is_result_set());
        assert!(result.is_row_count());
        assert_eq!(result.get_row_count().unwrap(), 42);
        assert!(result.handle().is_none());
    }

    #[test]
    fn test_query_result_has_more_data() {
        // No data received yet (empty `data`), so the received row count is 0.
        let pending = |total_rows| ResultData {
            columns: vec![],
            data: vec![],
            total_rows,
        };
        // A handle plus outstanding rows means more data is available.
        assert!(QueryResult::result_set(Some(ResultSetHandle::new(1)), pending(5)).has_more_data());
        // With no handle, no more data can be reported.
        assert!(!QueryResult::result_set(None, pending(5)).has_more_data());
        // Nothing outstanding.
        assert!(!QueryResult::result_set(Some(ResultSetHandle::new(1)), pending(0)).has_more_data());
        // Row counts never have more data.
        assert!(!QueryResult::row_count(5).has_more_data());
    }
}