pub mod future;
pub mod request;
pub mod socket;
mod tls;
mod transport;
use std::{
fmt::Debug,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
};
use futures_core::Stream;
use futures_util::{io::BufReader, Sink, SinkExt, StreamExt};
use socket::ExaSocket;
use sqlx_core::{bytes::Bytes, common::StatementCache};
pub use tls::WithMaybeTlsExaSocket;
use transport::MaybeCompressedWebSocket;
use crate::{
connection::websocket::{
future::{CloseResultSets, ExaLogin, GetAttributes, Rollback, WebSocketFuture},
request::ExaLoginRequest,
transport::PlainWebSocket,
},
error::ToSqlxError,
responses::{ExaAttributes, PreparedStatement, SessionInfo},
SqlxError, SqlxResult,
};
/// WebSocket connection to the database server together with its
/// per-connection bookkeeping.
///
/// The type implements `Stream` (items are `SqlxResult<Bytes>`) and
/// `Sink<String>` by forwarding to [`ExaWebSocket::inner`].
#[derive(Debug)]
pub struct ExaWebSocket {
    /// Underlying transport; starts out as the plain variant and may be
    /// swapped for a compressed one once the login has completed.
    pub inner: MaybeCompressedWebSocket,
    /// Connection attributes, seeded from the login options when the socket
    /// is created.
    pub attributes: ExaAttributes,
    /// A close-result-sets request that has not been driven yet, if any.
    /// NOTE(review): presumably queued for execution before the next
    /// request — confirm against the code that consumes this field.
    pub pending_close: Option<CloseResultSets>,
    /// A rollback request that has not been driven yet, if any.
    /// NOTE(review): presumably queued for execution before the next
    /// request — confirm against the code that consumes this field.
    pub pending_rollback: Option<Rollback>,
    /// Cache of prepared statements, sized from the login options.
    pub statement_cache: StatementCache<PreparedStatement>,
    /// Initialised to `false`.
    /// NOTE(review): looks like a flag marking a request in flight — confirm
    /// where it is set and cleared.
    pub active_request: bool,
}
impl ExaWebSocket {
    /// URL scheme used when the connection is not encrypted.
    const WS_SCHEME: &str = "ws";
    /// URL scheme used when the connection is encrypted.
    const WSS_SCHEME: &str = "wss";

    /// Performs the WebSocket handshake over `socket`, logs in, and returns
    /// the ready-to-use connection along with the session information
    /// produced by the login.
    ///
    /// `with_tls` only selects the URL scheme (`wss` vs `ws`) and is recorded
    /// in the connection attributes; the socket itself is supplied by the
    /// caller.
    ///
    /// The sequence is:
    /// 1. client handshake against `{scheme}://{host}:{port}`;
    /// 2. login over the plain transport;
    /// 3. switch the transport via `maybe_compress` according to
    ///    `options.use_compression`;
    /// 4. issue a `GetAttributes` request.
    ///
    /// # Errors
    ///
    /// Returns an error if the handshake fails (converted through
    /// [`ToSqlxError`]), or if the login or the attribute request fails.
    pub async fn new(
        host: &str,
        port: u16,
        socket: ExaSocket,
        options: ExaLoginRequest<'_>,
        with_tls: bool,
    ) -> SqlxResult<(Self, SessionInfo)> {
        let url = format!(
            "{}://{host}:{port}",
            if with_tls {
                Self::WSS_SCHEME
            } else {
                Self::WS_SCHEME
            }
        );

        let handshake = async_tungstenite::client_async(url, BufReader::new(socket)).await;
        let (stream, _response) = handshake.map_err(ToSqlxError::to_sqlx_err)?;

        // Copy out what is needed later: `options` is consumed by the login
        // request further down.
        let compress = options.use_compression;
        let cache_capacity = options.statement_cache_capacity;

        let mut this = Self {
            inner: MaybeCompressedWebSocket::Plain(PlainWebSocket(stream)),
            attributes: ExaAttributes::new(compress, options.fetch_size, with_tls, cache_capacity),
            pending_close: None,
            pending_rollback: None,
            statement_cache: StatementCache::new(cache_capacity),
            active_request: false,
        };

        // The login runs while the transport is still the plain variant.
        let session_info = ExaLogin::new(options).future(&mut this).await?;

        // Only after a successful login is the transport (possibly) switched.
        this.inner = this.inner.maybe_compress(compress);

        GetAttributes::default().future(&mut this).await?;

        Ok((this, session_info))
    }

    /// Sends a ping through the underlying transport.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the transport's `ping` reports.
    pub async fn ping(&mut self) -> SqlxResult<()> {
        self.inner.ping().await
    }

    /// Returns the server address as reported by the underlying transport.
    pub fn server(&self) -> SocketAddr {
        self.inner.server()
    }
}
/// Incoming messages are read straight from the underlying transport.
impl Stream for ExaWebSocket {
    type Item = SqlxResult<Bytes>;

    /// Polls the wrapped transport for the next message.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Self { inner, .. } = self.get_mut();
        inner.poll_next_unpin(cx)
    }
}
/// Outgoing `String` messages are handed to the underlying transport; every
/// method forwards to the corresponding operation on `inner`.
impl Sink<String> for ExaWebSocket {
    type Error = SqlxError;

    /// Forwards the readiness check to the wrapped transport.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let Self { inner, .. } = self.get_mut();
        inner.poll_ready_unpin(cx)
    }

    /// Hands `item` to the wrapped transport for sending.
    fn start_send(self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
        let Self { inner, .. } = self.get_mut();
        inner.start_send_unpin(item)
    }

    /// Forwards the flush to the wrapped transport.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let Self { inner, .. } = self.get_mut();
        inner.poll_flush_unpin(cx)
    }

    /// Forwards the close to the wrapped transport.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let Self { inner, .. } = self.get_mut();
        inner.poll_close_unpin(cx)
    }
}