use std::sync::Arc;
use async_trait::async_trait;
use aws_lc_rs::rand::{SecureRandom, SystemRandom};
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use futures_util::{SinkExt, StreamExt};
use num_bigint::BigUint;
use rustls::pki_types::CertificateDer;
use tokio::net::TcpStream;
use tokio_tungstenite::{
connect_async_tls_with_config,
tungstenite::{protocol::WebSocketConfig, Message},
Connector, MaybeTlsStream, WebSocketStream,
};
use crate::error::TransportError;
use super::messages::{
AuthRequest, ClosePreparedStatementRequest, ClosePreparedStatementResponse,
CloseResultSetRequest, CloseResultSetResponse, CreatePreparedStatementRequest,
CreatePreparedStatementResponse, DisconnectRequest, DisconnectResponse,
ExecutePreparedStatementRequest, ExecuteRequest, ExecuteResponse, FetchRequest, FetchResponse,
LoginInitRequest, LoginResponse, PublicKeyResponse, ResultData, ResultPayload, ResultSetHandle,
SessionInfo, SetAttributesRequest, SetAttributesResponse,
};
use super::protocol::{
ConnectionParams, Credentials, PreparedStatementHandle, QueryResult, TransportProtocol,
};
/// WebSocket-backed implementation of [`TransportProtocol`].
///
/// Bundles the (optional) open WebSocket stream, the session metadata obtained
/// at login, and the lifecycle state that gates every operation.
pub struct WebSocketTransport {
/// Open WebSocket stream; set by `connect` and taken (left `None`) by `close`.
ws_stream: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
/// Session details stored by a successful `authenticate`; cleared by `close`.
session_info: Option<SessionInfo>,
/// Current lifecycle state; checked at the top of each transport operation.
state: ConnectionState,
}
/// Lifecycle of a [`WebSocketTransport`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConnectionState {
/// Initial state produced by `WebSocketTransport::new`; the only state `connect` accepts.
Disconnected,
/// The WebSocket handshake succeeded; `authenticate` requires this state.
Connected,
/// Login succeeded; queries, fetches and statement operations require this state.
Authenticated,
/// Set by `close` after the stream was shut down. `connect` rejects this state,
/// so a closed transport is not reused.
Closed,
}
impl WebSocketTransport {
/// Creates a transport that has no stream and no session yet.
///
/// The returned value starts in [`ConnectionState::Disconnected`].
pub fn new() -> Self {
    let state = ConnectionState::Disconnected;
    Self {
        state,
        session_info: None,
        ws_stream: None,
    }
}
/// Sends `request` as a JSON text frame and waits for the matching JSON reply.
///
/// The request is serialized with `serde_json` and written to the open
/// stream. Incoming messages are then read until one looks like a JSON object
/// (its text starts with `{`); that message is deserialized into `R`.
///
/// # Errors
///
/// * `ProtocolError("Not connected")` when no stream is open.
/// * `SendError` / `ReceiveError` when writing or reading the stream fails, or
///   the stream ends before a reply arrives.
/// * `ProtocolError("Invalid message format: ...")` when a data frame is not
///   valid UTF-8.
/// * `InvalidResponse` (with a 200-character preview) when the reply is not
///   valid JSON for `R`.
async fn send_receive<T, R>(&mut self, request: &T) -> Result<R, TransportError>
where
    T: serde::Serialize,
    R: serde::de::DeserializeOwned,
{
    let request_json = serde_json::to_string(request)?;
    let ws_stream = self
        .ws_stream
        .as_mut()
        .ok_or_else(|| TransportError::ProtocolError("Not connected".to_string()))?;
    ws_stream
        .send(Message::Text(request_json.into()))
        .await
        .map_err(|e| TransportError::SendError(e.to_string()))?;
    loop {
        let response_msg = ws_stream
            .next()
            .await
            .ok_or_else(|| TransportError::ReceiveError("Connection closed".to_string()))?
            .map_err(|e| TransportError::ReceiveError(e.to_string()))?;
        // Only data frames can carry a reply. Ping/Pong/Close/raw frames are
        // skipped *before* any text decoding so an arbitrary (possibly
        // non-UTF-8) control payload can neither abort the request nor be
        // mistaken for the JSON response. After a Close frame the next read
        // reports the end of the stream, which is turned into an error above.
        if !matches!(response_msg, Message::Text(_) | Message::Binary(_)) {
            continue;
        }
        let response_text = response_msg.to_text().map_err(|e| {
            TransportError::ProtocolError(format!("Invalid message format: {}", e))
        })?;
        // Ignore data frames that are not JSON objects.
        if !response_text.starts_with('{') {
            continue;
        }
        return serde_json::from_str(response_text).map_err(|e| {
            // Keep the preview short: responses can be very large.
            let preview: String = response_text.chars().take(200).collect();
            TransportError::InvalidResponse(format!(
                "JSON parse error: {}. Response preview: '{}'",
                e, preview
            ))
        });
    }
}
/// Converts a response `status`/`exception` pair into a `Result`.
///
/// Returns `Ok(())` when `status` is exactly `"ok"`. Any other status yields a
/// `ProtocolError` built from the exception text and SQL code, or
/// `"Unknown error"` when the response carried no exception.
fn check_status(
    &self,
    status: &str,
    exception: &Option<super::messages::ExceptionInfo>,
) -> Result<(), TransportError> {
    if status == "ok" {
        return Ok(());
    }
    let error_msg = match exception {
        Some(info) => format!(
            "{} (SQL code: {})",
            info.text,
            info.sql_code.as_deref().unwrap_or("unknown")
        ),
        None => "Unknown error".to_string(),
    };
    Err(TransportError::ProtocolError(error_msg))
}
/// Encrypts `password` with the RSA public key in `public_key_pem` and returns
/// the ciphertext as standard base64.
///
/// The key is read as a PKCS#1 `RSA PUBLIC KEY` PEM, the password is padded
/// with PKCS#1 v1.5 encryption padding, and the ciphertext is left-padded with
/// zeros to the modulus length in bytes.
///
/// # Errors
///
/// Returns `ProtocolError` when the PEM or DER cannot be parsed, when the
/// modulus is outside 1024..=8192 bits, or when padding fails.
fn encrypt_password(password: &str, public_key_pem: &str) -> Result<String, TransportError> {
    let der = Self::pem_to_pkcs1_der(public_key_pem).map_err(|reason| {
        TransportError::ProtocolError(format!("Failed to parse RSA public key PEM: {}", reason))
    })?;
    let (modulus, exponent) = Self::parse_rsa_public_key(&der).map_err(|reason| {
        TransportError::ProtocolError(format!("Failed to parse RSA public key: {}", reason))
    })?;

    let modulus_bits = modulus.bits() as usize;
    match modulus_bits {
        1024..=8192 => {}
        _ => {
            return Err(TransportError::ProtocolError(format!(
                "Unsupported RSA key size: {} bits (expected 1024..=8192)",
                modulus_bits
            )));
        }
    }

    // Length of the modulus in whole bytes; also the ciphertext length.
    let modulus_len = modulus_bits.div_ceil(8);
    let padded = Self::pkcs1_v15_pad(password.as_bytes(), modulus_len).map_err(|reason| {
        TransportError::ProtocolError(format!("Failed to pad password: {}", reason))
    })?;

    // Raw RSA: c = m^e mod n.
    let cipher_bytes = BigUint::from_bytes_be(&padded)
        .modpow(&exponent, &modulus)
        .to_bytes_be();

    // `to_bytes_be` drops leading zeros; restore them so the output is
    // always exactly `modulus_len` bytes.
    let mut ciphertext = vec![0u8; modulus_len];
    let offset = modulus_len - cipher_bytes.len();
    ciphertext[offset..].copy_from_slice(&cipher_bytes);
    Ok(STANDARD.encode(&ciphertext))
}
/// Extracts the DER bytes from a PKCS#1 `RSA PUBLIC KEY` PEM block.
///
/// Everything between the BEGIN and END markers is base64-decoded after all
/// whitespace has been removed. Text before the BEGIN marker and after the END
/// marker is ignored.
///
/// # Errors
///
/// Returns a static description when the BEGIN marker is missing, when no END
/// marker follows it, or when the enclosed text is not valid base64.
fn pem_to_pkcs1_der(pem: &str) -> Result<Vec<u8>, &'static str> {
    let start_marker = "-----BEGIN RSA PUBLIC KEY-----";
    let end_marker = "-----END RSA PUBLIC KEY-----";
    let body_start =
        pem.find(start_marker).ok_or("Missing PEM start marker")? + start_marker.len();
    // Look for the END marker only *after* the BEGIN marker. Searching the
    // whole input could find an END marker that precedes BEGIN, and slicing
    // `pem[start..end]` with `end < start` would panic on server-supplied data.
    let body = &pem[body_start..];
    let body_end = body.find(end_marker).ok_or("Missing PEM end marker")?;
    let base64_content: String = body[..body_end]
        .chars()
        .filter(|c| !c.is_whitespace())
        .collect();
    STANDARD
        .decode(&base64_content)
        .map_err(|_| "Invalid base64 in PEM")
}
/// Parses a DER-encoded PKCS#1 `RSAPublicKey` and returns `(modulus, exponent)`.
///
/// The input must start with a SEQUENCE whose content is exactly two INTEGERs.
/// Bytes following the SEQUENCE itself are not inspected.
///
/// # Errors
///
/// Returns a static description when the SEQUENCE tag is missing, the encoding
/// is truncated, an INTEGER is malformed, or the SEQUENCE holds extra bytes.
fn parse_rsa_public_key(pkcs1_der: &[u8]) -> Result<(BigUint, BigUint), &'static str> {
    match pkcs1_der.first() {
        Some(&0x30) => {}
        _ => return Err("Invalid DER: expected SEQUENCE"),
    }
    let (seq_len, len_bytes) = Self::read_der_length(&pkcs1_der[1..])?;
    let body_start = 1 + len_bytes;
    let body_end = body_start + seq_len;
    // `get` fails exactly when the declared SEQUENCE runs past the input.
    let body = pkcs1_der
        .get(body_start..body_end)
        .ok_or("Invalid DER: truncated")?;

    let (modulus_bytes, modulus_used) = Self::read_der_integer(body)?;
    let (exponent_bytes, exponent_used) = Self::read_der_integer(&body[modulus_used..])?;
    if modulus_used + exponent_used != body.len() {
        return Err("Invalid DER: trailing bytes after SEQUENCE");
    }

    let modulus = BigUint::from_bytes_be(modulus_bytes);
    let exponent = BigUint::from_bytes_be(exponent_bytes);
    Ok((modulus, exponent))
}
/// Decodes a DER length field at the start of `data`.
///
/// Returns `(length, bytes_consumed)`. The short form (first byte below
/// `0x80`) and long forms with one to three length bytes are accepted; an
/// empty input, the indefinite form, longer length fields, and inputs too
/// short for the announced length bytes are all rejected with
/// `"Invalid DER: truncated"`.
fn read_der_length(data: &[u8]) -> Result<(usize, usize), &'static str> {
    const TRUNCATED: &str = "Invalid DER: truncated";

    let (&first, rest) = data.split_first().ok_or(TRUNCATED)?;
    if first & 0x80 == 0 {
        // Short form: the byte is the length itself.
        return Ok((usize::from(first), 1));
    }

    // Long form: the low seven bits give the number of length bytes.
    let num_bytes = usize::from(first & 0x7F);
    if !(1..=3).contains(&num_bytes) {
        return Err(TRUNCATED);
    }
    let length_octets = rest.get(..num_bytes).ok_or(TRUNCATED)?;
    let len = length_octets
        .iter()
        .fold(0usize, |acc, &octet| (acc << 8) | usize::from(octet));
    Ok((len, 1 + num_bytes))
}
/// Reads one DER INTEGER from the start of `data`.
///
/// Returns the integer's content bytes (with a single leading `0x00` removed
/// when more bytes follow it) and the total number of bytes the element
/// occupies, header included.
///
/// # Errors
///
/// Returns a static description when the INTEGER tag (`0x02`) is missing, the
/// length field is invalid, or the content runs past the end of `data`.
fn read_der_integer(data: &[u8]) -> Result<(&[u8], usize), &'static str> {
    let Some((&0x02, after_tag)) = data.split_first() else {
        return Err("Invalid DER: expected INTEGER");
    };
    let (int_len, len_bytes) = Self::read_der_length(after_tag)?;
    let header_len = 1 + len_bytes;
    let total_len = header_len + int_len;
    let content = data
        .get(header_len..total_len)
        .ok_or("Invalid DER: truncated")?;
    // A leading zero only marks the value as non-negative; drop it unless it
    // is the whole value.
    let value = match content {
        [0x00, rest @ ..] if !rest.is_empty() => rest,
        _ => content,
    };
    Ok((value, total_len))
}
/// Builds a PKCS#1 v1.5 encryption block of exactly `k` bytes:
/// `0x00 || 0x02 || PS || 0x00 || message`, where `PS` is random non-zero
/// padding.
///
/// # Errors
///
/// * `"RSA modulus too small for PKCS#1 v1.5 padding"` when `k < 11`, i.e.
///   there is no room for the three framing bytes plus eight padding bytes.
/// * `"Message too long for RSA modulus"` when `message` exceeds `k - 11` bytes.
/// * `"Failed to generate random bytes"` when the system RNG fails.
fn pkcs1_v15_pad(message: &[u8], k: usize) -> Result<Vec<u8>, &'static str> {
    // Three framing bytes plus the mandatory minimum of eight padding bytes.
    const MIN_OVERHEAD: usize = 11;

    // Reject tiny `k` explicitly: with an empty message the old
    // `saturating_sub` check passed and `k - len - 3` either underflowed or
    // produced fewer than eight padding bytes.
    let max_message_len = k
        .checked_sub(MIN_OVERHEAD)
        .ok_or("RSA modulus too small for PKCS#1 v1.5 padding")?;
    if message.len() > max_message_len {
        return Err("Message too long for RSA modulus");
    }
    let ps_len = k - message.len() - 3;

    let rng = SystemRandom::new();
    // Zero-initialised, so em[0] and the separator after PS are already 0x00.
    let mut em = vec![0u8; k];
    em[1] = 0x02;

    // Fill the whole padding in one RNG call, then re-draw only the bytes that
    // came out as zero (padding bytes must be non-zero).
    let padding = &mut em[2..2 + ps_len];
    rng.fill(padding)
        .map_err(|_| "Failed to generate random bytes")?;
    for byte in padding.iter_mut() {
        while *byte == 0 {
            let mut replacement = [0u8; 1];
            rng.fill(&mut replacement)
                .map_err(|_| "Failed to generate random bytes")?;
            *byte = replacement[0];
        }
    }

    em[3 + ps_len..].copy_from_slice(message);
    Ok(em)
}
}
impl Default for WebSocketTransport {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl TransportProtocol for WebSocketTransport {
async fn connect(&mut self, params: &ConnectionParams) -> Result<(), TransportError> {
if self.state != ConnectionState::Disconnected {
return Err(TransportError::ProtocolError(
"Already connected".to_string(),
));
}
let url = params.to_websocket_url();
let connector = if params.use_tls {
let tls_connector = if let Some(ref fingerprint) = params.certificate_fingerprint {
let config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(FingerprintVerifier {
expected_fingerprint: fingerprint.clone(),
}))
.with_no_client_auth();
Connector::Rustls(Arc::new(config))
} else if params.validate_server_certificate {
let mut root_store = rustls::RootCertStore::empty();
let certs = rustls_native_certs::load_native_certs();
for cert in certs.certs {
let _ = root_store.add(cert);
}
let config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
Connector::Rustls(Arc::new(config))
} else {
let config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoVerifier))
.with_no_client_auth();
Connector::Rustls(Arc::new(config))
};
Some(tls_connector)
} else {
None
};
let mut ws_config = WebSocketConfig::default();
ws_config.max_frame_size = None;
ws_config.max_message_size = None;
let connect_future = connect_async_tls_with_config(
&url,
Some(ws_config),
false, connector, );
let (ws_stream, _) = tokio::time::timeout(
tokio::time::Duration::from_millis(params.timeout_ms),
connect_future,
)
.await
.map_err(|_| {
TransportError::IoError(format!("Connection timeout after {}ms", params.timeout_ms))
})?
.map_err(|e| TransportError::WebSocketError(e.to_string()))?;
self.ws_stream = Some(ws_stream);
self.state = ConnectionState::Connected;
Ok(())
}
/// Logs in with `credentials` and returns the resulting session information.
///
/// The exchange has two round trips: a login-init request that yields the
/// server's RSA public key, then an auth request carrying the user name, the
/// password encrypted with that key, and the client name `"exarrow-rs"`.
/// On success the session info is stored and the state becomes `Authenticated`.
///
/// # Errors
///
/// * `ProtocolError` when the transport is not in the `Connected` state, when
///   the server reports a non-`ok` status, or when password encryption fails.
/// * `InvalidResponse` when either response lacks its data section.
/// * Any error from the underlying send/receive.
async fn authenticate(
    &mut self,
    credentials: &Credentials,
) -> Result<SessionInfo, TransportError> {
    if self.state != ConnectionState::Connected {
        return Err(TransportError::ProtocolError(
            "Must connect before authenticating".to_string(),
        ));
    }

    // Round trip 1: obtain the server's public key.
    let key_response: PublicKeyResponse = self.send_receive(&LoginInitRequest::new()).await?;
    self.check_status(&key_response.status, &key_response.exception)?;
    let Some(key_data) = key_response.response_data else {
        return Err(TransportError::InvalidResponse(
            "Missing public key data in response".to_string(),
        ));
    };

    // Round trip 2: send the encrypted credentials.
    let encrypted_password =
        Self::encrypt_password(&credentials.password, &key_data.public_key_pem)?;
    let auth_request = AuthRequest::new(
        credentials.username.clone(),
        encrypted_password,
        "exarrow-rs".to_string(),
    );
    let login_response: LoginResponse = self.send_receive(&auth_request).await?;
    self.check_status(&login_response.status, &login_response.exception)?;

    let session_info: SessionInfo = match login_response.response_data {
        Some(session_data) => session_data.into(),
        None => {
            return Err(TransportError::InvalidResponse(
                "Missing response data".to_string(),
            ));
        }
    };
    self.state = ConnectionState::Authenticated;
    self.session_info = Some(session_info.clone());
    Ok(session_info)
}
/// Executes `sql` and converts the first result of the response.
///
/// A `"resultSet"` result becomes `QueryResult::result_set` (with an optional
/// server-side handle for further fetching); a `"rowCount"` result becomes
/// `QueryResult::row_count`. Only the first entry of the response's result
/// list is used.
///
/// # Errors
///
/// * `ProtocolError` when the transport is not authenticated or the server
///   reports a non-`ok` status.
/// * `InvalidResponse` when the response has no data, no results, an unknown
///   result type, or a result set without its payload or columns.
/// * Any error from the underlying send/receive.
async fn execute_query(&mut self, sql: &str) -> Result<QueryResult, TransportError> {
    if self.state != ConnectionState::Authenticated {
        return Err(TransportError::ProtocolError(
            "Must authenticate before executing queries".to_string(),
        ));
    }
    let request = ExecuteRequest::new(sql.to_string());
    let response: ExecuteResponse = self.send_receive(&request).await?;
    self.check_status(&response.status, &response.exception)?;
    let response_data = response
        .response_data
        .ok_or_else(|| TransportError::InvalidResponse("Missing response data".to_string()))?;

    // Take the first result by value: the response is owned here, so the
    // columns and (potentially very large) row payload can be moved into the
    // returned value instead of being cloned.
    let mut result = response_data
        .results
        .into_iter()
        .next()
        .ok_or_else(|| TransportError::InvalidResponse("No results returned".to_string()))?;

    match result.result_type.as_str() {
        "resultSet" => {
            let Some(result_set) = result.result_set.take() else {
                return Err(TransportError::InvalidResponse(format!(
                    "Missing result_set data. Result: {:?}",
                    result
                )));
            };
            let columns = result_set.columns.ok_or_else(|| {
                TransportError::InvalidResponse("Missing columns".to_string())
            })?;
            let data = ResultData {
                columns,
                data: ResultPayload::Json(result_set.data.unwrap_or_default()),
                total_rows: result_set.num_rows.unwrap_or(0),
            };
            let handle = result_set.result_set_handle.map(ResultSetHandle::new);
            Ok(QueryResult::result_set(handle, data))
        }
        "rowCount" => Ok(QueryResult::row_count(result.row_count.unwrap_or(0))),
        other => Err(TransportError::InvalidResponse(format!(
            "Unknown result type: {}",
            other
        ))),
    }
}
/// Fetches rows for the server-side result set identified by `handle`.
///
/// The request asks for data starting at position 0 and limits the reply to
/// the session's `max_data_message_size` (1 MiB when no session info is
/// stored). The returned `ResultData` has an empty column list.
///
/// NOTE(review): the start position is always 0, and nothing visible here
/// advances it between calls — confirm whether callers rely on the server to
/// track the read position.
///
/// # Errors
///
/// * `ProtocolError` when the transport is not authenticated or the server
///   reports a non-`ok` status.
/// * `InvalidResponse` when the response has no data section.
/// * Any error from the underlying send/receive.
async fn fetch_results(
    &mut self,
    handle: ResultSetHandle,
) -> Result<ResultData, TransportError> {
    if self.state != ConnectionState::Authenticated {
        return Err(TransportError::ProtocolError(
            "Must authenticate before fetching results".to_string(),
        ));
    }
    let max_bytes = match self.session_info.as_ref() {
        Some(session) => session.max_data_message_size,
        None => 1024 * 1024,
    };
    let response: FetchResponse = self
        .send_receive(&FetchRequest::new(handle.as_i32(), 0, max_bytes))
        .await?;
    self.check_status(&response.status, &response.exception)?;
    let Some(fetch_data) = response.response_data else {
        return Err(TransportError::InvalidResponse(
            "Missing response data".to_string(),
        ));
    };
    Ok(ResultData {
        columns: Vec::new(),
        data: ResultPayload::Json(fetch_data.data),
        total_rows: fetch_data.num_rows,
    })
}
/// Asks the server to release the result set identified by `handle`.
///
/// # Errors
///
/// `ProtocolError` when the transport is not authenticated or the server
/// reports a non-`ok` status; any error from the underlying send/receive.
async fn close_result_set(&mut self, handle: ResultSetHandle) -> Result<(), TransportError> {
    if self.state != ConnectionState::Authenticated {
        return Err(TransportError::ProtocolError(
            "Must authenticate before closing result sets".to_string(),
        ));
    }
    let handles = vec![handle.as_i32()];
    let response: CloseResultSetResponse = self
        .send_receive(&CloseResultSetRequest::new(handles))
        .await?;
    self.check_status(&response.status, &response.exception)
}
/// Prepares `sql` on the server and returns a handle describing its parameters.
///
/// When the response carries parameter metadata, the handle receives the
/// reported column count plus each parameter's data type and name; otherwise
/// it is created with zero parameters.
///
/// # Errors
///
/// * `ProtocolError` when the transport is not authenticated or the server
///   reports a non-`ok` status.
/// * `InvalidResponse` when the response has no data section.
/// * Any error from the underlying send/receive.
async fn create_prepared_statement(
    &mut self,
    sql: &str,
) -> Result<PreparedStatementHandle, TransportError> {
    if self.state != ConnectionState::Authenticated {
        return Err(TransportError::ProtocolError(
            "Must authenticate before creating prepared statements".to_string(),
        ));
    }
    let response: CreatePreparedStatementResponse = self
        .send_receive(&CreatePreparedStatementRequest::new(sql))
        .await?;
    self.check_status(&response.status, &response.exception)?;
    let Some(response_data) = response.response_data else {
        return Err(TransportError::InvalidResponse(
            "Missing response data".to_string(),
        ));
    };

    let (num_params, parameter_types, parameter_names) = match response_data.parameter_data {
        None => (0, Vec::new(), Vec::new()),
        Some(param_data) => {
            let num_columns = param_data.num_columns;
            // Split the column descriptions into parallel type/name lists.
            let (types, names): (Vec<_>, Vec<_>) = param_data
                .columns
                .into_iter()
                .map(|column| (column.data_type, column.name))
                .unzip();
            (num_columns, types, names)
        }
    };

    Ok(PreparedStatementHandle::new(
        response_data.statement_handle,
        num_params,
        parameter_types,
        parameter_names,
    ))
}
async fn execute_prepared_statement(
&mut self,
handle: &PreparedStatementHandle,
parameters: Option<Vec<Vec<serde_json::Value>>>,
) -> Result<QueryResult, TransportError> {
if self.state != ConnectionState::Authenticated {
return Err(TransportError::ProtocolError(
"Must authenticate before executing prepared statements".to_string(),
));
}
let mut request = ExecutePreparedStatementRequest::new(handle.handle);
if let Some(ref data) = parameters {
let columns: Vec<_> = if handle.parameter_types.is_empty() {
data.iter()
.enumerate()
.map(|(i, col_values)| {
let data_type = col_values
.first()
.map(super::messages::DataType::infer_from_json)
.unwrap_or_else(|| super::messages::DataType::varchar(2_000_000));
super::messages::ColumnInfo {
name: format!("param{}", i),
data_type,
}
})
.collect()
} else {
handle
.parameter_types
.iter()
.enumerate()
.map(|(i, dt)| super::messages::ColumnInfo {
name: format!("param{}", i),
data_type: dt.clone(),
})
.collect()
};
request = request.with_data(columns, data.clone());
}
let response: ExecuteResponse = self.send_receive(&request).await?;
self.check_status(&response.status, &response.exception)?;
let response_data = response
.response_data
.ok_or_else(|| TransportError::InvalidResponse("Missing response data".to_string()))?;
if response_data.results.is_empty() {
return Err(TransportError::InvalidResponse(
"No results returned".to_string(),
));
}
let result = &response_data.results[0];
match result.result_type.as_str() {
"resultSet" => {
let result_set = result.result_set.as_ref().ok_or_else(|| {
TransportError::InvalidResponse(format!(
"Missing result_set data. Result: {:?}",
result
))
})?;
let columns = result_set.columns.clone().ok_or_else(|| {
TransportError::InvalidResponse("Missing columns".to_string())
})?;
let data_values = result_set.data.clone().unwrap_or_default();
let total_rows = result_set.num_rows.unwrap_or(0);
let data = ResultData {
columns,
data: ResultPayload::Json(data_values),
total_rows,
};
let handle = result_set.result_set_handle.map(ResultSetHandle::new);
Ok(QueryResult::result_set(handle, data))
}
"rowCount" => {
let count = result.row_count.unwrap_or(0);
Ok(QueryResult::row_count(count))
}
other => Err(TransportError::InvalidResponse(format!(
"Unknown result type: {}",
other
))),
}
}
/// Asks the server to release the prepared statement behind `handle`.
///
/// # Errors
///
/// `ProtocolError` when the transport is not authenticated or the server
/// reports a non-`ok` status; any error from the underlying send/receive.
async fn close_prepared_statement(
    &mut self,
    handle: &PreparedStatementHandle,
) -> Result<(), TransportError> {
    if self.state != ConnectionState::Authenticated {
        return Err(TransportError::ProtocolError(
            "Must authenticate before closing prepared statements".to_string(),
        ));
    }
    let response: ClosePreparedStatementResponse = self
        .send_receive(&ClosePreparedStatementRequest::new(handle.handle))
        .await?;
    self.check_status(&response.status, &response.exception)
}
/// Shuts the connection down; calling it on a transport that is already
/// `Disconnected` or `Closed` does nothing.
///
/// An authenticated session first sends a disconnect request (its outcome is
/// ignored). The stream is then closed with a normal close frame and drained
/// until it ends, errors, or stays silent for 100 ms. Afterwards the state is
/// `Closed` and the stored session info is cleared. Always returns `Ok(())`.
async fn close(&mut self) -> Result<(), TransportError> {
    if matches!(
        self.state,
        ConnectionState::Disconnected | ConnectionState::Closed
    ) {
        return Ok(());
    }

    if self.state == ConnectionState::Authenticated {
        // Best effort: the connection is going away regardless of the reply.
        let _ = self
            .send_receive::<_, DisconnectResponse>(&DisconnectRequest::new())
            .await;
    }

    if let Some(mut ws_stream) = self.ws_stream.take() {
        use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
        use tokio_tungstenite::tungstenite::protocol::CloseFrame;

        let farewell = CloseFrame {
            code: CloseCode::Normal,
            reason: "Client closing connection".into(),
        };
        let _ = ws_stream.close(Some(farewell)).await;

        // Drain whatever the peer still sends so the close handshake can finish.
        let drain_window = std::time::Duration::from_millis(100);
        while matches!(
            tokio::time::timeout(drain_window, ws_stream.next()).await,
            Ok(Some(_))
        ) {}
    }

    self.session_info = None;
    self.state = ConnectionState::Closed;
    Ok(())
}
/// Reports whether the transport is in the `Connected` or `Authenticated` state.
fn is_connected(&self) -> bool {
    match self.state {
        ConnectionState::Connected | ConnectionState::Authenticated => true,
        ConnectionState::Disconnected | ConnectionState::Closed => false,
    }
}
/// Enables or disables autocommit for the current session.
///
/// # Errors
///
/// `ProtocolError` when the transport is not authenticated or the server
/// reports a non-`ok` status; any error from the underlying send/receive.
async fn set_autocommit(&mut self, enabled: bool) -> Result<(), TransportError> {
    if self.state != ConnectionState::Authenticated {
        return Err(TransportError::ProtocolError(
            "Must authenticate before setting attributes".to_string(),
        ));
    }
    let response: SetAttributesResponse = self
        .send_receive(&SetAttributesRequest::autocommit(enabled))
        .await?;
    self.check_status(&response.status, &response.exception)
}
}
fn all_supported_verify_schemes() -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
]
}
/// Certificate verifier that accepts every server certificate and handshake
/// signature without checking anything.
///
/// `connect` installs it only when TLS is enabled, no fingerprint is
/// configured, and `validate_server_certificate` is false. It provides no
/// protection against an impersonated server.
#[derive(Debug)]
struct NoVerifier;
/// Accept-everything verifier: each check reports success unconditionally.
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
    /// Accepts any certificate chain for any server name.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        use rustls::client::danger::ServerCertVerified;
        Ok(ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        use rustls::client::danger::HandshakeSignatureValid;
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        use rustls::client::danger::HandshakeSignatureValid;
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Advertises the shared list of signature schemes.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        all_supported_verify_schemes()
    }
}
/// Certificate verifier that pins the server's end-entity certificate by its
/// SHA-256 fingerprint instead of validating a certificate chain.
#[derive(Debug)]
struct FingerprintVerifier {
/// Hex-encoded SHA-256 digest that the DER bytes of the server's end-entity
/// certificate must hash to.
expected_fingerprint: String,
}
impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
use aws_lc_rs::digest;
let fingerprint = digest::digest(&digest::SHA256, end_entity.as_ref());
let actual: String = fingerprint
.as_ref()
.iter()
.map(|b| format!("{:02x}", b))
.collect();
if actual == self.expected_fingerprint {
Ok(rustls::client::danger::ServerCertVerified::assertion())
} else {
Err(rustls::Error::General(format!(
"Certificate fingerprint mismatch: expected {}, got {}",
self.expected_fingerprint, actual
)))
}
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
all_supported_verify_schemes()
}
}
#[cfg(test)]
mod tests {
use super::*;
/// A freshly constructed transport is disconnected.
#[test]
fn test_websocket_transport_new() {
    let fresh = WebSocketTransport::new();
    assert_eq!(fresh.state, ConnectionState::Disconnected);
    assert!(!fresh.is_connected());
}
/// `Default` produces a transport that is not connected.
#[test]
fn test_websocket_transport_default() {
    assert!(!WebSocketTransport::default().is_connected());
}
/// `is_connected` is true exactly for the `Connected` and `Authenticated` states.
#[test]
fn test_connection_state_transitions() {
    let mut transport = WebSocketTransport::new();
    assert_eq!(transport.state, ConnectionState::Disconnected);

    let expectations = [
        (ConnectionState::Connected, true),
        (ConnectionState::Authenticated, true),
        (ConnectionState::Closed, false),
    ];
    for (state, connected) in expectations {
        transport.state = state;
        assert_eq!(transport.is_connected(), connected);
    }
}
#[tokio::test]
async fn test_connect_requires_disconnected_state() {
let mut transport = WebSocketTransport::new();
transport.state = ConnectionState::Connected;
let params = ConnectionParams::new("localhost".to_string(), 8563);
let result = transport.connect(¶ms).await;
assert!(result.is_err());
if let Err(TransportError::ProtocolError(msg)) = result {
assert!(msg.contains("Already connected"));
} else {
panic!("Expected ProtocolError");
}
}
/// `authenticate` is rejected while the transport is still disconnected.
#[tokio::test]
async fn test_authenticate_requires_connected_state() {
    let mut transport = WebSocketTransport::new();
    assert_eq!(transport.state, ConnectionState::Disconnected);
    let credentials = Credentials::new("user".to_string(), "pass".to_string());
    let outcome = transport.authenticate(&credentials).await;
    assert!(outcome.is_err());
    match outcome {
        Err(TransportError::ProtocolError(msg)) => {
            assert!(msg.contains("Must connect before authenticating"));
        }
        _ => panic!("Expected ProtocolError"),
    }
}
/// `execute_query` is rejected before authentication.
#[tokio::test]
async fn test_execute_query_requires_authenticated_state() {
    let mut transport = WebSocketTransport::new();
    let outcome = transport.execute_query("SELECT 1").await;
    assert!(outcome.is_err());
    match outcome {
        Err(TransportError::ProtocolError(msg)) => {
            assert!(msg.contains("Must authenticate before executing queries"));
        }
        _ => panic!("Expected ProtocolError"),
    }
}
/// `create_prepared_statement` is rejected before authentication.
#[tokio::test]
async fn test_create_prepared_statement_requires_authenticated_state() {
    let mut transport = WebSocketTransport::new();
    let sql = "SELECT * FROM test WHERE id = ?";
    let outcome = transport.create_prepared_statement(sql).await;
    assert!(outcome.is_err());
    match outcome {
        Err(TransportError::ProtocolError(msg)) => {
            assert!(msg.contains("Must authenticate before creating prepared statements"));
        }
        _ => panic!("Expected ProtocolError"),
    }
}
/// `execute_prepared_statement` is rejected before authentication.
#[tokio::test]
async fn test_execute_prepared_statement_requires_authenticated_state() {
    let mut transport = WebSocketTransport::new();
    let statement = PreparedStatementHandle::new(1, 0, vec![], vec![]);
    let outcome = transport.execute_prepared_statement(&statement, None).await;
    assert!(outcome.is_err());
    match outcome {
        Err(TransportError::ProtocolError(msg)) => {
            assert!(msg.contains("Must authenticate before executing prepared statements"));
        }
        _ => panic!("Expected ProtocolError"),
    }
}
/// `close_prepared_statement` is rejected before authentication.
#[tokio::test]
async fn test_close_prepared_statement_requires_authenticated_state() {
    let mut transport = WebSocketTransport::new();
    let statement = PreparedStatementHandle::new(1, 0, vec![], vec![]);
    let outcome = transport.close_prepared_statement(&statement).await;
    assert!(outcome.is_err());
    match outcome {
        Err(TransportError::ProtocolError(msg)) => {
            assert!(msg.contains("Must authenticate before closing prepared statements"));
        }
        _ => panic!("Expected ProtocolError"),
    }
}
/// `close` succeeds without side effects on disconnected and closed transports.
#[tokio::test]
async fn test_close_idempotent() {
    let mut transport = WebSocketTransport::new();

    // Closing a never-connected transport leaves its state untouched.
    assert!(transport.close().await.is_ok());
    assert_eq!(transport.state, ConnectionState::Disconnected);

    // Closing an already closed transport also succeeds.
    transport.state = ConnectionState::Closed;
    assert!(transport.close().await.is_ok());
}
/// An `"ok"` status without exception passes the status check.
#[test]
fn test_check_status_ok() {
    let verdict = WebSocketTransport::new().check_status("ok", &None);
    assert!(verdict.is_ok());
}
/// A non-`ok` status surfaces the exception text and SQL code in the error.
#[test]
fn test_check_status_error() {
    use super::super::messages::ExceptionInfo;

    let exception = Some(ExceptionInfo {
        sql_code: Some("42000".to_string()),
        text: "Syntax error".to_string(),
    });
    let verdict = WebSocketTransport::new().check_status("error", &exception);
    assert!(verdict.is_err());
    match verdict {
        Err(TransportError::ProtocolError(msg)) => {
            assert!(msg.contains("Syntax error"));
            assert!(msg.contains("42000"));
        }
        _ => panic!("Expected ProtocolError"),
    }
}
/// Encrypting with a 2048-bit key yields 256 bytes of base64-encoded
/// ciphertext, and the key parses to the expected modulus size and exponent.
#[test]
fn test_encrypt_password_with_valid_key() {
    let test_public_key_pem = r#"-----BEGIN RSA PUBLIC KEY-----
MIIBCgKCAQEAulMKxKfPd02qNEVCU1M6hG/Vc9xz+0u+N47Qqa1Y0E2A5bDiz3XA
aCg2d65C7DyuTL38zwmOtjagvvIAgRj9yDf0v1/v9e1X4l5XE6UiaKKqdcXNy6lJ
QspqkOBUptlz+2h/G8Z12++xUo/4AGAGz9ZkrRRvcTGW1GJhCROizeJhTpGMpc/v
o1G53uy2eTHwnz5S3YgJF7nfX60wjJ99ifQuQ9BhDIYLNqzwHTzExMN63v0UOBIL
vJ+yVUqh0/T2f5e9E1lDNuIqLyXe8VwwUsS72A1EGtg0s77+xUQ7KiGRbHD4bsBo
A74EI7MHQ7163wVPT0VWFRvUmmv+UO7W8wIDAQAB
-----END RSA PUBLIC KEY-----"#;

    // Ciphertext shape.
    let result = WebSocketTransport::encrypt_password("test_password", test_public_key_pem);
    assert!(
        result.is_ok(),
        "encrypt_password failed: {:?}",
        result.err()
    );
    let encrypted = result.unwrap();
    assert!(!encrypted.is_empty());
    let decoded = STANDARD.decode(&encrypted);
    assert!(decoded.is_ok());
    assert_eq!(decoded.unwrap().len(), 256);

    // Key parameters.
    let der = WebSocketTransport::pem_to_pkcs1_der(test_public_key_pem).unwrap();
    let (modulus, exponent) = WebSocketTransport::parse_rsa_public_key(&der).unwrap();
    assert!(
        modulus.bits() > 2040 && modulus.bits() <= 2048,
        "Expected ~2048-bit modulus, got {} bits",
        modulus.bits()
    );
    assert_eq!(
        exponent,
        BigUint::from(65537u32),
        "Expected standard RSA exponent 65537"
    );
}
/// Input that is not a PEM key is reported as a key parsing failure.
#[test]
fn test_encrypt_password_with_invalid_key() {
    let outcome = WebSocketTransport::encrypt_password("password", "not a valid PEM key");
    assert!(outcome.is_err());
    match outcome {
        Err(TransportError::ProtocolError(msg)) => {
            assert!(msg.contains("Failed to parse RSA public key"));
        }
        _ => panic!("Expected ProtocolError"),
    }
}
/// A 1024-bit key yields 128-byte ciphertexts that differ between passwords
/// and between repeated encryptions of the same password.
#[test]
fn test_encrypt_password_with_1024_bit_key() {
    let pem_1024 = r#"-----BEGIN RSA PUBLIC KEY-----
MIGJAoGBAMh0gXUltxbJYmwQUIvXPLl9Y8bGaGFN/urgclF3Czd7viHaAMuanebQ
62s1mLV0vXaYSZk8Zsrat2T/i7jbPE0XKpVUgmnlT/CHXv6gPdTpOr3JTpo/lop0
t6J/6xJBNQDp6OrFMtTTq2M3zxSfcomlT4Q759uuGkEdM9crb8A9AgMBAAE=
-----END RSA PUBLIC KEY-----"#;

    let first = WebSocketTransport::encrypt_password("hunter2", pem_1024);
    assert!(
        first.is_ok(),
        "encrypt_password with 1024-bit key failed: {:?}",
        first.err()
    );
    let encrypted = first.unwrap();
    let decoded = STANDARD.decode(&encrypted).expect("valid base64");
    assert_eq!(
        decoded.len(),
        128,
        "1024-bit RSA ciphertext must be 128 bytes, got {}",
        decoded.len()
    );

    let other_password = WebSocketTransport::encrypt_password("different_password", pem_1024);
    assert!(other_password.is_ok());
    let encrypted2 = other_password.unwrap();
    assert_ne!(
        encrypted, encrypted2,
        "Different passwords must produce different ciphertexts"
    );

    let same_password = WebSocketTransport::encrypt_password("hunter2", pem_1024);
    assert!(same_password.is_ok());
    let encrypted3 = same_password.unwrap();
    assert_ne!(
        encrypted, encrypted3,
        "Same password with random padding should produce different ciphertexts"
    );
}
/// A 512-bit key is rejected with an error naming the unsupported size.
#[test]
fn test_encrypt_password_rejects_too_small_key() {
    let pem_512 = r#"-----BEGIN RSA PUBLIC KEY-----
MEgCQQCrHBlqjk3p2boRBFTAZdqcWNU+g5LjKicoOX5UIyIanBCV5fgbtoRCCvBr
++vdlAaIAcJx5iKBMp1obShMOPwVAgMBAAE=
-----END RSA PUBLIC KEY-----"#;

    let outcome = WebSocketTransport::encrypt_password("pw", pem_512);
    assert!(outcome.is_err(), "512-bit key should be rejected");
    match outcome {
        Err(TransportError::ProtocolError(msg)) => {
            assert!(
                msg.contains("Unsupported RSA key size"),
                "Error should mention unsupported key size, got: {}",
                msg
            );
            assert!(
                msg.contains("512"),
                "Error should mention 512 bits, got: {}",
                msg
            );
        }
        _ => panic!("Expected ProtocolError"),
    }
}
/// Parsing a 2048-bit key yields exponent 65537 and a full 256-byte modulus
/// whose top bit is set.
#[test]
fn test_parse_rsa_public_key_roundtrip() {
    let test_public_key_pem = r#"-----BEGIN RSA PUBLIC KEY-----
MIIBCgKCAQEAulMKxKfPd02qNEVCU1M6hG/Vc9xz+0u+N47Qqa1Y0E2A5bDiz3XA
aCg2d65C7DyuTL38zwmOtjagvvIAgRj9yDf0v1/v9e1X4l5XE6UiaKKqdcXNy6lJ
QspqkOBUptlz+2h/G8Z12++xUo/4AGAGz9ZkrRRvcTGW1GJhCROizeJhTpGMpc/v
o1G53uy2eTHwnz5S3YgJF7nfX60wjJ99ifQuQ9BhDIYLNqzwHTzExMN63v0UOBIL
vJ+yVUqh0/T2f5e9E1lDNuIqLyXe8VwwUsS72A1EGtg0s77+xUQ7KiGRbHD4bsBo
A74EI7MHQ7163wVPT0VWFRvUmmv+UO7W8wIDAQAB
-----END RSA PUBLIC KEY-----"#;

    let der = WebSocketTransport::pem_to_pkcs1_der(test_public_key_pem).unwrap();
    let (modulus, exponent) = WebSocketTransport::parse_rsa_public_key(&der).unwrap();
    assert_eq!(
        exponent,
        BigUint::from(65537u32),
        "Expected standard RSA exponent 65537"
    );

    let modulus_bytes = modulus.to_bytes_be();
    assert_eq!(
        modulus_bytes.len(),
        256,
        "2048-bit modulus should be 256 bytes in big-endian, got {}",
        modulus_bytes.len()
    );
    let high_byte = modulus_bytes[0];
    assert!(
        high_byte & 0x80 != 0,
        "High byte of 2048-bit modulus should have top bit set, got 0x{:02x}",
        high_byte
    );
    assert!(
        modulus.bits() > 2040 && modulus.bits() <= 2048,
        "Expected ~2048-bit modulus, got {} bits",
        modulus.bits()
    );
}
/// End-to-end check of `encrypt_password`: the ciphertext is decrypted with the
/// matching private key using raw RSA, and the recovered block is inspected
/// for the PKCS#1 v1.5 layout `0x00 || 0x02 || PS || 0x00 || message`.
#[test]
fn test_encrypt_password_with_1024_bit_key_decrypts_correctly() {
// Public half of the 1024-bit test key pair.
let pub_pem = r#"-----BEGIN RSA PUBLIC KEY-----
MIGJAoGBAM5D4w7VjlVI8hUmJj8MXx4glxP3dqroATH1hito7CzfGGD/Ss73lp/n
0JLwI4oT6g0wBiMlLkPY6C+hfTI2x/UhJE8gKhz6URld6uQ29d0PAq4OvsZW5Rhz
nuIWmHv9WKvS/5DVcNbBtkiTJ9BDnsSQW/YqA4DQGmzph4ThqEkJAgMBAAE=
-----END RSA PUBLIC KEY-----"#;
// Private half, used only to decrypt inside this test.
let priv_pem = r#"-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQDOQ+MO1Y5VSPIVJiY/DF8eIJcT93aq6AEx9YYraOws3xhg/0rO
95af59CS8COKE+oNMAYjJS5D2OgvoX0yNsf1ISRPICoc+lEZXerkNvXdDwKuDr7G
VuUYc57iFph7/Vir0v+Q1XDWwbZIkyfQQ57EkFv2KgOA0Bps6YeE4ahJCQIDAQAB
AoGAVYDAu+J86Q+fAnNZAWPAfj2mQumfMIOSE0KjBpWs6YDlmzfYq+jocIro5DBV
myRcLnFM6f68qfVdcnkv68PXqSA7acyTtAKSIJAgN1xiYELuRVWMk/+UVgGhRpcH
rY4sTIwPM5b9r6JA++6PX13b8qqybPijf/Lz5urEbU3oPu0CQQDmZrbz623uijm4
ifEhk6f+Gq4spF5tUHVwY/GlfdtaDPr/wSTBkeKweZIAd9LJh8iv7Il2tYNKzzot
BTMut3VnAkEA5S6sKNaeLQevYy7N2zKm2SdSKKo1Sh48+oLYd2CZQGG5jDRom6p/
e+4hol1jGDBwvx0sMrFCFsoQXRiP2etYDwJBANSDj2LjF837Xwww5+IhkMVXlKoG
njZUDU6yUPRlZwrjiCyY2S9WQXKnX5zg6OMMRHbIRW7iM4ywIafe8Pu5KicCQQCE
rB8nyQ5qfP9wSGENWuYx4cxzFA2jaZvdXa/Yc8hj9+7FFnXUX8BLSxCXgL5j+27Z
hBbZBbp/nNwaOKTV/6LLAkAjNOwM7VST6Medyd2lNpsCjmYAoF0eUf8Z8ZJPaZeo
El6NrMeFybqeqwjPHPG1oCwg4YIeaT8ZB2qUW143brUB
-----END RSA PRIVATE KEY-----"#;
let password = "exasol_secret_123";
let encrypted = WebSocketTransport::encrypt_password(password, pub_pem)
.expect("encryption should succeed");
let ciphertext = STANDARD.decode(&encrypted).expect("valid base64");
assert_eq!(
ciphertext.len(),
128,
"1024-bit RSA ciphertext must be 128 bytes"
);
// Raw RSA decryption: m = c^d mod n.
let (n, d) = parse_pkcs1_private_key_n_d(priv_pem);
let c = BigUint::from_bytes_be(&ciphertext);
let m = c.modpow(&d, &n);
// Left-pad the decrypted integer back to the 128-byte block, since
// `to_bytes_be` drops leading zero bytes.
let k = 128; let mut em = vec![0u8; k];
let m_bytes = m.to_bytes_be();
em[k - m_bytes.len()..].copy_from_slice(&m_bytes);
assert_eq!(em[0], 0x00, "EM must start with 0x00");
assert_eq!(em[1], 0x02, "EM block type must be 0x02");
// The first zero byte after the two header bytes ends the padding string.
let separator_pos = em[2..]
.iter()
.position(|&b| b == 0x00)
.expect("Must find 0x00 separator after PS padding");
let message_start = 2 + separator_pos + 1;
for (i, &b) in em[2..2 + separator_pos].iter().enumerate() {
assert_ne!(b, 0x00, "PS byte at position {} must be non-zero", i);
}
assert!(
separator_pos >= 8,
"PS must be at least 8 bytes, got {}",
separator_pos
);
// Everything after the separator is the original message.
let decrypted = &em[message_start..];
assert_eq!(
decrypted,
password.as_bytes(),
"Decrypted message must match original password"
);
}
/// Extracts the modulus `n` and private exponent `d` from a PEM-encoded
/// `RSA PRIVATE KEY` block.
///
/// Test-only helper: any malformed input (missing markers, bad base64,
/// unexpected DER tags, truncated data) results in a panic.
fn parse_pkcs1_private_key_n_d(pem: &str) -> (BigUint, BigUint) {
    const BEGIN_MARKER: &str = "-----BEGIN RSA PRIVATE KEY-----";
    const END_MARKER: &str = "-----END RSA PRIVATE KEY-----";

    // Isolate the base64 payload between the markers and drop all whitespace.
    let body_start = pem.find(BEGIN_MARKER).unwrap() + BEGIN_MARKER.len();
    let body_end = pem.find(END_MARKER).unwrap();
    let mut encoded = String::new();
    for ch in pem[body_start..body_end].chars() {
        if !ch.is_whitespace() {
            encoded.push(ch);
        }
    }
    let der = STANDARD.decode(&encoded).unwrap();

    // Outer SEQUENCE header: tag byte followed by its length field.
    assert_eq!(der[0], 0x30);
    let (_, seq_len_field) = read_test_der_length(&der[1..]);
    let mut rest = &der[1 + seq_len_field..];

    // The sequence starts with four INTEGERs; the second is kept as `n` and
    // the fourth as `d` (version, modulus, publicExponent, privateExponent in
    // the PKCS#1 RSAPrivateKey layout).
    let (_, used) = read_test_der_integer(rest);
    rest = &rest[used..];
    let (n_bytes, used) = read_test_der_integer(rest);
    rest = &rest[used..];
    let (_, used) = read_test_der_integer(rest);
    rest = &rest[used..];
    let (d_bytes, _) = read_test_der_integer(rest);

    (
        BigUint::from_bytes_be(n_bytes),
        BigUint::from_bytes_be(d_bytes),
    )
}
/// Decodes the DER length field at the start of `data`.
///
/// Returns `(length, bytes_consumed)`, where `bytes_consumed` is the size of
/// the length field itself (1 for the short form, `1 + k` for the long form
/// with `k` length bytes).
///
/// # Panics
///
/// Test-only helper, so every malformed input panics:
/// - `data` is empty or shorter than the announced long-form length field;
/// - the first byte is `0x80` (indefinite length, which DER does not allow);
/// - the long form announces more bytes than fit in a `usize`.
fn read_test_der_length(data: &[u8]) -> (usize, usize) {
    let (&first, tail) = data.split_first().expect("DER length field is empty");

    // Short form: the byte itself is the length.
    if first < 0x80 {
        return (usize::from(first), 1);
    }

    // Long form: the low 7 bits give the number of big-endian length bytes.
    let num_bytes = usize::from(first & 0x7F);
    assert!(
        num_bytes != 0,
        "indefinite length (0x80) is not valid DER"
    );
    // Without this guard the shift below would silently drop high-order bits.
    assert!(
        num_bytes <= std::mem::size_of::<usize>(),
        "DER length field of {} bytes does not fit in usize",
        num_bytes
    );

    let len = tail[..num_bytes]
        .iter()
        .fold(0usize, |acc, &b| (acc << 8) | usize::from(b));
    (len, 1 + num_bytes)
}
/// Reads one DER INTEGER from the start of `data`.
///
/// Returns the integer's content bytes (with a single leading `0x00` byte
/// removed when more bytes follow it) together with the total number of bytes
/// the element occupies, header included.
///
/// Panics if the first byte is not the INTEGER tag (`0x02`) or if `data` is
/// too short for the announced length.
fn read_test_der_integer(data: &[u8]) -> (&[u8], usize) {
    assert_eq!(data[0], 0x02, "Expected INTEGER tag");

    let (content_len, len_field_size) = read_test_der_length(&data[1..]);
    let content_start = 1 + len_field_size;
    let content = &data[content_start..content_start + content_len];

    // Strip one leading zero byte, but never reduce the value to nothing.
    let magnitude = match content {
        [0x00, rest @ ..] if !rest.is_empty() => rest,
        _ => content,
    };

    (magnitude, content_start + content_len)
}
/// An `"error"` status with no exception details must yield a
/// `ProtocolError` carrying the message `"Unknown error"`.
#[test]
fn test_check_status_error_with_none_exception() {
    let transport = WebSocketTransport::new();

    let outcome = transport.check_status("error", &None);

    assert!(outcome.is_err());
    match outcome {
        Err(TransportError::ProtocolError(message)) => {
            assert_eq!(message, "Unknown error");
        }
        _ => panic!("Expected ProtocolError with 'Unknown error' message"),
    }
}
/// An `"error"` status whose exception has text but no SQL code must yield a
/// `ProtocolError` whose message contains both the exception text and the
/// word `unknown`.
#[test]
fn test_check_status_error_with_exception_missing_sql_code() {
    use super::super::messages::ExceptionInfo;

    let transport = WebSocketTransport::new();
    let details = ExceptionInfo {
        sql_code: None,
        text: String::from("Some database error"),
    };
    let exception = Some(details);

    let outcome = transport.check_status("error", &exception);

    assert!(outcome.is_err());
    match outcome {
        Err(TransportError::ProtocolError(message)) => {
            assert!(message.contains("Some database error"));
            assert!(message.contains("unknown"));
        }
        _ => panic!("Expected ProtocolError"),
    }
}
/// Fetching results on a freshly constructed transport must be rejected with
/// a `ProtocolError` that asks for authentication first.
#[tokio::test]
async fn test_fetch_results_requires_authenticated_state() {
    use super::super::messages::ResultSetHandle;

    let mut transport = WebSocketTransport::new();

    let outcome = transport.fetch_results(ResultSetHandle::new(1)).await;

    assert!(outcome.is_err());
    match outcome {
        Err(TransportError::ProtocolError(message)) => {
            assert!(message.contains("Must authenticate before fetching results"));
        }
        _ => panic!("Expected ProtocolError"),
    }
}
/// Being merely `Connected` is not enough: fetching results must still be
/// rejected with a `ProtocolError` that asks for authentication first.
#[tokio::test]
async fn test_fetch_results_requires_authenticated_state_from_connected() {
    use super::super::messages::ResultSetHandle;

    let mut transport = WebSocketTransport::new();
    transport.state = ConnectionState::Connected;

    let outcome = transport.fetch_results(ResultSetHandle::new(1)).await;

    assert!(outcome.is_err());
    match outcome {
        Err(TransportError::ProtocolError(message)) => {
            assert!(message.contains("Must authenticate before fetching results"));
        }
        _ => panic!("Expected ProtocolError"),
    }
}
/// Closing a result set on a freshly constructed transport must be rejected
/// with a `ProtocolError` that asks for authentication first.
#[tokio::test]
async fn test_close_result_set_requires_authenticated_state() {
    use super::super::messages::ResultSetHandle;

    let mut transport = WebSocketTransport::new();

    let outcome = transport.close_result_set(ResultSetHandle::new(1)).await;

    assert!(outcome.is_err());
    match outcome {
        Err(TransportError::ProtocolError(message)) => {
            assert!(message.contains("Must authenticate before closing result sets"));
        }
        _ => panic!("Expected ProtocolError"),
    }
}
/// Being merely `Connected` is not enough: closing a result set must still be
/// rejected with a `ProtocolError` that asks for authentication first.
#[tokio::test]
async fn test_close_result_set_requires_authenticated_state_from_connected() {
    use super::super::messages::ResultSetHandle;

    let mut transport = WebSocketTransport::new();
    transport.state = ConnectionState::Connected;

    let outcome = transport.close_result_set(ResultSetHandle::new(1)).await;

    assert!(outcome.is_err());
    match outcome {
        Err(TransportError::ProtocolError(message)) => {
            assert!(message.contains("Must authenticate before closing result sets"));
        }
        _ => panic!("Expected ProtocolError"),
    }
}
/// Closing a transport that is marked `Connected` (but was never given a
/// socket) must succeed and leave it in the `Closed` state.
#[tokio::test]
async fn test_close_from_connected_state_succeeds() {
    let mut transport = WebSocketTransport::new();
    transport.state = ConnectionState::Connected;

    assert!(transport.close().await.is_ok());

    assert_eq!(transport.state, ConnectionState::Closed);
}
/// Closing the transport must succeed and drop any stored session
/// information.
#[tokio::test]
async fn test_close_clears_session_info() {
    use super::super::messages::SessionInfo;

    // Arbitrary but fully populated session, so there is something to clear.
    let session = SessionInfo {
        session_id: "12345".to_string(),
        protocol_version: 3,
        release_version: "7.1.0".to_string(),
        database_name: "test_db".to_string(),
        product_name: "EXASolution".to_string(),
        max_data_message_size: 1024 * 1024,
        time_zone: Some("UTC".to_string()),
    };

    let mut transport = WebSocketTransport::new();
    transport.state = ConnectionState::Connected;
    transport.session_info = Some(session);

    assert!(transport.close().await.is_ok());

    assert!(transport.session_info.is_none());
}
}