mod extend;
pub mod socket;
mod tls;
#[cfg(feature = "etl")]
use std::net::IpAddr;
use std::{borrow::Cow, fmt::Debug, net::SocketAddr};
use extend::WebSocketExt;
use futures_util::{io::BufReader, Future};
use lru::LruCache;
use serde::de::{DeserializeOwned, IgnoredAny};
use socket::ExaSocket;
use sqlx_core::Error as SqlxError;
pub use tls::WithMaybeTlsExaSocket;
use self::extend::PlainWebSocket;
use super::stream::QueryResultStream;
#[cfg(feature = "etl")]
use crate::responses::Hosts;
use crate::{
command::{Command, ExaCommand},
error::{ExaProtocolError, ExaResultExt},
options::ExaConnectOptionsRef,
responses::{
DataChunk, DescribeStatement, ExaAttributes, PreparedStatement, QueryResult, Results,
SessionInfo,
},
};
/// A websocket connection to the database server, together with the session
/// attributes tracked locally for it.
#[derive(Debug)]
pub struct ExaWebSocket {
    /// Transport used to send commands and receive responses; built with
    /// `WebSocketExt::new` using the negotiated compression flag.
    pub ws: WebSocketExt,
    /// Local copy of the session attributes. It is passed along when building some
    /// commands and is updated from attributes included in server responses.
    pub attributes: ExaAttributes,
    /// When `true`, a `ROLLBACK` is issued before the next command sent via `send`.
    pub pending_rollback: bool,
}
impl ExaWebSocket {
    /// URL scheme used for unencrypted connections.
    const WS_SCHEME: &'static str = "ws";
    /// URL scheme used for TLS-encrypted connections.
    const WSS_SCHEME: &'static str = "wss";

    /// Opens a websocket over `socket`, logs in and loads the session attributes.
    ///
    /// The handshake targets `ws://host:port`, or `wss://host:port` when `with_tls` is set.
    /// Login runs on the plain websocket. Only afterwards is the stream wrapped in
    /// [`WebSocketExt`] with the compression setting taken from `options`. The local
    /// attributes are seeded from `options` and then requested from the server through
    /// [`Self::get_attributes`].
    ///
    /// Returns the connection together with the [`SessionInfo`] produced by the login.
    ///
    /// # Errors
    ///
    /// Fails if the websocket handshake, the login or the attributes request fails.
    pub(crate) async fn new(
        host: &str,
        port: u16,
        socket: ExaSocket,
        options: ExaConnectOptionsRef<'_>,
        with_tls: bool,
    ) -> Result<(Self, SessionInfo), SqlxError> {
        let scheme = if with_tls {
            Self::WSS_SCHEME
        } else {
            Self::WS_SCHEME
        };
        let url = format!("{scheme}://{host}:{port}");

        let (ws, _) = async_tungstenite::client_async(url, BufReader::new(socket))
            .await
            .to_sqlx_err()?;

        // Built before `login` consumes `options`.
        let attributes = ExaAttributes {
            compression_enabled: options.compression,
            fetch_size: options.fetch_size,
            encryption_enabled: with_tls,
            statement_cache_capacity: options.statement_cache_capacity,
            ..Default::default()
        };

        let mut plain_ws = PlainWebSocket(ws);
        let session_info = plain_ws.login(options).await?;
        let ws = WebSocketExt::new(plain_ws.0, attributes.compression_enabled);

        let mut this = Self {
            ws,
            attributes,
            pending_rollback: false,
        };

        this.get_attributes().await?;
        Ok((this, session_info))
    }

    /// Executes `cmd` and wraps its result in a [`QueryResultStream`].
    ///
    /// If `rs_handle` holds the handle of a previously opened result set, that set is
    /// closed first. `rs_handle` is then overwritten with the new result's handle, which
    /// may be `None`. `future_maker` builds the futures that yield further
    /// [`DataChunk`]s for the stream.
    pub async fn get_result_stream<'a, C, F>(
        &'a mut self,
        cmd: Command,
        rs_handle: &mut Option<u16>,
        future_maker: C,
    ) -> Result<QueryResultStream<'_, C, F>, SqlxError>
    where
        C: Fn(&'a mut ExaWebSocket, u16, usize) -> Result<F, SqlxError>,
        F: Future<Output = Result<(DataChunk, &'a mut ExaWebSocket), SqlxError>> + 'a,
    {
        if let Some(handle) = rs_handle.take() {
            self.close_result_set(handle).await?;
        }

        let query_result = self.get_query_result(cmd).await?;
        *rs_handle = query_result.handle();
        QueryResultStream::new(self, query_result, future_maker)
    }

    /// Sends `cmd` and converts the [`Results`] response into a [`QueryResult`].
    pub async fn get_query_result(&mut self, cmd: Command) -> Result<QueryResult, SqlxError> {
        self.send_and_recv::<Results>(cmd).await.map(From::from)
    }

    /// Closes the server-side result set identified by `handle`.
    pub async fn close_result_set(&mut self, handle: u16) -> Result<(), SqlxError> {
        let cmd = ExaCommand::new_close_result(handle).try_into()?;
        self.send_cmd_ignore_response(cmd).await
    }

    /// Sends a create-prepared-statement command and returns the resulting statement.
    pub async fn create_prepared(&mut self, cmd: Command) -> Result<PreparedStatement, SqlxError> {
        self.send_and_recv(cmd).await
    }

    /// Sends a describe command and returns the statement description.
    pub async fn describe(&mut self, cmd: Command) -> Result<DescribeStatement, SqlxError> {
        self.send_and_recv(cmd).await
    }

    /// Closes the server-side prepared statement identified by `handle`.
    pub async fn close_prepared(&mut self, handle: u16) -> Result<(), SqlxError> {
        let cmd = ExaCommand::new_close_prepared(handle).try_into()?;
        self.send_cmd_ignore_response(cmd).await
    }

    /// Sends a fetch command and returns the received [`DataChunk`].
    pub async fn fetch_chunk(&mut self, cmd: Command) -> Result<DataChunk, SqlxError> {
        self.send_and_recv(cmd).await
    }

    /// Pushes the locally held attributes to the server.
    pub async fn set_attributes(&mut self) -> Result<(), SqlxError> {
        let cmd = ExaCommand::new_set_attributes(&self.attributes).try_into()?;
        self.send_cmd_ignore_response(cmd).await
    }

    /// Requests the database host addresses, passing the IP of the currently
    /// connected peer along with the request.
    #[cfg(feature = "etl")]
    pub async fn get_hosts(&mut self) -> Result<Vec<IpAddr>, SqlxError> {
        let host_ip = self.socket_addr().ip();
        let cmd = ExaCommand::new_get_hosts(host_ip).try_into()?;
        self.send_and_recv::<Hosts>(cmd).await.map(From::from)
    }

    /// Requests the session attributes from the server.
    ///
    /// The response body is discarded. Any attributes it carries are applied to
    /// `self.attributes` by [`Self::recv`].
    pub async fn get_attributes(&mut self) -> Result<(), SqlxError> {
        let cmd = ExaCommand::GetAttributes.try_into()?;
        self.send_cmd_ignore_response(cmd).await
    }

    /// Marks a transaction as started in the local attributes.
    ///
    /// Disables autocommit and sets `open_transaction`. No command is sent here.
    /// NOTE(review): presumably the changed attributes reach the server with the next
    /// command that carries them — confirm.
    ///
    /// # Errors
    ///
    /// Returns `ExaProtocolError::TransactionAlreadyOpen` if a transaction is already open.
    pub fn begin(&mut self) -> Result<(), SqlxError> {
        if self.attributes.open_transaction {
            return Err(ExaProtocolError::TransactionAlreadyOpen.into());
        }

        self.attributes.autocommit = false;
        self.attributes.open_transaction = true;
        Ok(())
    }

    /// Commits the open transaction.
    ///
    /// Autocommit is re-enabled locally before the `COMMIT` command is built, so the
    /// attributes handed to the command already have autocommit set. `open_transaction`
    /// is only cleared once the command succeeds.
    pub async fn commit(&mut self) -> Result<(), SqlxError> {
        self.attributes.autocommit = true;
        let cmd = ExaCommand::new_execute("COMMIT;", &self.attributes).try_into()?;
        self.send_cmd_ignore_response(cmd).await?;
        self.attributes.open_transaction = false;
        Ok(())
    }

    /// Rolls back the open transaction and resets the transaction-related attributes.
    ///
    /// Uses [`Self::raw_send`] instead of [`Self::send`] so that it does not re-enter
    /// the pending-rollback check performed by `send`. Clears `pending_rollback` on success.
    pub async fn rollback(&mut self) -> Result<(), SqlxError> {
        let cmd = ExaCommand::new_execute("ROLLBACK;", &self.attributes).try_into()?;
        self.raw_send(cmd).await?;
        self.recv::<Option<IgnoredAny>>().await?;

        self.attributes.autocommit = true;
        self.attributes.open_transaction = false;
        self.pending_rollback = false;
        Ok(())
    }

    /// Pings the server over the websocket.
    pub async fn ping(&mut self) -> Result<(), SqlxError> {
        self.ws.ping().await
    }

    /// Sends a `Disconnect` command and waits for its response.
    pub async fn disconnect(&mut self) -> Result<(), SqlxError> {
        let cmd = ExaCommand::Disconnect.try_into()?;
        self.send_cmd_ignore_response(cmd).await
    }

    /// Closes the websocket.
    pub async fn close(&mut self) -> Result<(), SqlxError> {
        self.ws.close().await
    }

    /// Returns the prepared statement for `sql`, preparing it on the server if it is
    /// not cached.
    ///
    /// When `persist` is `true`, a freshly prepared statement is stored in `cache`. Any
    /// entry evicted to make room is closed on the server, and a borrow of the cached
    /// statement is returned. Otherwise the statement is returned owned and is not cached.
    pub async fn get_or_prepare<'a>(
        &mut self,
        cache: &'a mut LruCache<String, PreparedStatement>,
        sql: &str,
        persist: bool,
    ) -> Result<Cow<'a, PreparedStatement>, SqlxError> {
        // Deliberately `contains` followed by `get` rather than a single `if let` on
        // `get`: returning that borrow from inside the `if let` would keep `cache` borrowed
        // for the rest of the function, which the borrow checker rejects.
        if cache.contains(sql) {
            let cached = cache
                .get(sql)
                .expect("entry is present: checked by `contains`");
            return Ok(Cow::Borrowed(cached));
        }

        let cmd = ExaCommand::new_create_prepared(sql).try_into()?;
        let prepared = self.create_prepared(cmd).await?;

        if persist {
            // The key was absent, so anything `push` hands back is an evicted entry
            // whose server-side statement must be released.
            if let Some((_, evicted)) = cache.push(sql.to_owned(), prepared) {
                self.close_prepared(evicted.statement_handle).await?;
            }
            let cached = cache
                .get(sql)
                .expect("statement was just inserted into the cache");
            return Ok(Cow::Borrowed(cached));
        }

        Ok(Cow::Owned(prepared))
    }

    /// Executes a `;`-separated SQL script.
    ///
    /// The script is first sent as a single batch command, with whitespace-only pieces
    /// dropped. If that fails for any reason, the statements are executed one at a time.
    ///
    /// When a sequential statement fails with a database error, its text is kept and
    /// extended up to the next `;` before retrying. This reassembles statements that
    /// contain a `;` which does not terminate them; presumably a `;` inside a string
    /// literal is the motivating case. Empty statements (e.g. from `;;`) are skipped
    /// instead of being sent to the server.
    ///
    /// NOTE(review): a statement that is genuinely invalid causes every following
    /// statement to be appended to it and retried as one. After such a failure, later
    /// statements are never run on their own.
    ///
    /// # Errors
    ///
    /// Non-database errors during sequential execution are returned immediately.
    /// Otherwise, the error returned is the first database error recorded since the last
    /// successfully executed statement. That is the error of a trailing fragment that
    /// never succeeded.
    #[cfg(feature = "migrate")]
    pub async fn execute_batch(&mut self, sql: &str) -> Result<(), SqlxError> {
        let sql = sql.trim_end();
        let sql = sql.strip_suffix(';').unwrap_or(sql);

        // Whitespace-only pieces would only make the batch fail and force the slower
        // sequential path.
        let sql_batch = sql
            .split(';')
            .filter(|stmt| !stmt.trim().is_empty())
            .collect();
        let cmd = ExaCommand::new_execute_batch(sql_batch, &self.attributes).try_into()?;

        match self.send_cmd_ignore_response(cmd).await {
            Ok(()) => return Ok(()),
            Err(e) => tracing::warn!(
                "failed to execute batch SQL: {e}; will attempt sequential execution"
            ),
        };

        let mut result = Ok(());
        let mut position = 0;
        let mut sql_start = 0;

        // Database errors are recorded (keeping the first one since the last success)
        // so the fragment can be retried with more text. Any other error aborts.
        let handle_err_fn = |err: SqlxError, result: &mut Result<(), SqlxError>| {
            let db_err = match &err {
                SqlxError::Database(_) => err,
                _ => return Err(err),
            };

            tracing::warn!("error running statement: {db_err}; perhaps it's incomplete?");

            if result.is_ok() {
                *result = Err(db_err);
            }

            Ok(())
        };

        while let Some(sql_end) = sql[position..].find(';') {
            let stmt = sql[sql_start..position + sql_end].trim();
            position += sql_end + 1;

            // Only whitespace since the last successful statement: nothing to run.
            // A pending failed fragment is never empty, so no error can be lost here.
            if stmt.is_empty() {
                sql_start = position;
                continue;
            }

            let cmd = ExaCommand::new_execute(stmt, &self.attributes).try_into()?;
            if let Err(err) = self.send_cmd_ignore_response(cmd).await {
                handle_err_fn(err, &mut result)?;
            } else {
                sql_start = position;
                result = Ok(());
            }
        }

        let stmt = sql[sql_start..].trim();
        if !stmt.is_empty() {
            let cmd = ExaCommand::new_execute(stmt, &self.attributes).try_into()?;
            if let Err(err) = self.send_cmd_ignore_response(cmd).await {
                handle_err_fn(err, &mut result)?;
            } else {
                result = Ok(());
            }
        }

        result
    }

    /// Returns the address of the peer this websocket is connected to.
    pub fn socket_addr(&self) -> SocketAddr {
        self.ws.socket_addr()
    }

    /// Sends `cmd` and discards the (possibly absent) response body.
    ///
    /// Attributes in the response are still applied by [`Self::recv`].
    async fn send_cmd_ignore_response(&mut self, cmd: Command) -> Result<(), SqlxError> {
        self.send_and_recv::<Option<IgnoredAny>>(cmd).await?;
        Ok(())
    }

    /// Sends `cmd` via [`Self::send`] and deserializes the response into `T`.
    async fn send_and_recv<T>(&mut self, cmd: Command) -> Result<T, SqlxError>
    where
        T: DeserializeOwned + Debug,
    {
        self.send(cmd).await?;
        self.recv().await
    }

    /// Sends `cmd`, first issuing a rollback if `pending_rollback` is set.
    ///
    /// `pending_rollback` is set outside this file. NOTE(review): presumably it is set
    /// when a transaction is abandoned without commit/rollback — confirm.
    pub(crate) async fn send(&mut self, cmd: Command) -> Result<(), SqlxError> {
        if self.pending_rollback {
            // `rollback` clears `pending_rollback` itself on success.
            self.rollback().await?;
        }

        self.raw_send(cmd).await
    }

    /// Sends `cmd` as-is, without the pending-rollback check.
    async fn raw_send(&mut self, cmd: Command) -> Result<(), SqlxError> {
        let cmd = cmd.into_inner();
        tracing::debug!("sending command to database: {cmd}");
        self.ws.send(cmd).await
    }

    /// Receives the next response and returns its data deserialized as `T`.
    ///
    /// The response is converted into a `Result` via `From`, and its error is
    /// propagated. If the response carries attributes, they are applied to
    /// `self.attributes`.
    pub(crate) async fn recv<T>(&mut self) -> Result<T, SqlxError>
    where
        T: DeserializeOwned + Debug,
    {
        let (response_data, attributes) = Result::from(self.ws.recv().await?)?;

        if let Some(attributes) = attributes {
            tracing::debug!("updating connection attributes using:\n{attributes:#?}");
            self.attributes.update(attributes);
        }

        tracing::trace!("database response:\n{response_data:#?}");
        Ok(response_data)
    }
}