saola-quaint 0.2.0-alpha.14

An abstraction layer for SQL databases (PostgreSQL, MySQL, SQLite, MSSQL)
use std::{
    io::{Error as IoError, ErrorKind as IoErrorKind},
    pin::Pin,
    str::FromStr,
    task::{Context, Poll, ready},
};

use bytes::Bytes;
use futures::{FutureExt, Sink, SinkExt, Stream};
use pin_project::pin_project;
use postgres_native_tls::TlsConnector;
use tokio::{
    io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf},
    net::TcpStream,
    task::JoinHandle,
};
use tokio_postgres::{Client, Config};
use tokio_tungstenite::{
    MaybeTlsStream, WebSocketStream, connect_async,
    tungstenite::{
        self, Error as TungsteniteError, Message,
        client::IntoClientRequest,
        http::{HeaderMap, HeaderValue, StatusCode},
    },
};
use tokio_util::io::StreamReader;

use crate::{
    connector::PostgresWebSocketUrl,
    error::{self, Error, ErrorKind, Name, NativeErrorKind},
};

const CONNECTION_PARAMS_HEADER: &str = "Prisma-Connection-Parameters";
const HOST_HEADER: &str = "Prisma-Db-Host";

pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::Result<(Client, JoinHandle<()>)> {
    let db_name = url.overriden_db_name().map(ToOwned::to_owned);
    let (ws_stream, response) = connect_async(url).await?;

    let connection_params = require_header_value(response.headers(), CONNECTION_PARAMS_HEADER)?;
    let db_host = require_header_value(response.headers(), HOST_HEADER)?;

    let mut config = Config::from_str(connection_params)?;
    if let Some(db_name) = db_name {
        config.dbname(&db_name);
    }
    let ws_byte_stream = WsTunnel::new(ws_stream);

    let tls = TlsConnector::new(native_tls::TlsConnector::new()?, db_host);
    let (client, connection) = config.connect_raw(ws_byte_stream, tls).await?;

    let task = connection.map(move |result| {
        if let Err(err) = result {
            tracing::error!("Error in PostgreSQL WebSocket connection: {err:?}");
        }
    });

    #[cfg(feature = "telemetry")]
    let task = tracing_futures::WithSubscriber::with_current_subscriber(task);
    #[cfg(not(feature = "telemetry"))]
    let task = crate::WithSubscriber::with_current_subscriber(task);

    #[cfg(feature = "metrics")]
    let task = prisma_metrics::WithMetricsInstrumentation::with_current_recorder(task);
    #[cfg(not(feature = "metrics"))]
    let task = crate::WithMetricsInstrumentation::with_current_recorder(task);

    let handle = tokio::spawn(task);

    Ok((client, handle))
}

fn require_header_value<'a>(headers: &'a HeaderMap, name: &str) -> crate::Result<&'a str> {
    let Some(header) = headers.get(name) else {
        let message = format!("Missing response header {name}");
        let error = Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(message.into()))).build();
        return Err(error);
    };

    let value = header.to_str().map_err(|inner| {
        Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(inner)))).build()
    })?;

    Ok(value)
}

impl IntoClientRequest for PostgresWebSocketUrl {
    fn into_client_request(self) -> tungstenite::Result<tungstenite::handshake::client::Request> {
        let mut request = self.url.to_string().into_client_request()?;
        let bearer = format!("Bearer {}", self.api_key());
        let auth_header = HeaderValue::from_str(&bearer)?;
        request.headers_mut().insert("Authorization", auth_header);
        Ok(request)
    }
}

impl From<TungsteniteError> for error::Error {
    fn from(value: TungsteniteError) -> Self {
        let builder = match value {
            TungsteniteError::Tls(tls_error) => Error::builder(ErrorKind::Native(NativeErrorKind::TlsError {
                message: tls_error.to_string(),
            })),

            TungsteniteError::Http(response) if response.status() == StatusCode::UNAUTHORIZED => {
                Error::builder(ErrorKind::DatabaseAccessDenied {
                    db_name: Name::Unavailable,
                })
            }

            _ => Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(value)))),
        };

        builder.build()
    }
}

#[pin_project]
struct WsTunnel {
    #[pin]
    inner: StreamReader<WsBytesStream, Bytes>,
    write_state: WriteState,
}

enum WriteState {
    Free,
    Writing(usize, usize),
}

#[pin_project]
struct WsBytesStream(#[pin] WebSocketStream<MaybeTlsStream<TcpStream>>);

impl WsTunnel {
    fn new(stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
        WsTunnel {
            inner: StreamReader::new(WsBytesStream(stream)),
            write_state: WriteState::Free,
        }
    }
}

impl WsBytesStream {
    fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut WebSocketStream<MaybeTlsStream<TcpStream>>> {
        self.project().0
    }
}

impl AsyncRead for WsTunnel {
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
        self.project().inner.poll_read(cx, buf)
    }
}

impl AsyncBufRead for WsTunnel {
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
        self.project().inner.poll_fill_buf(cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        self.project().inner.consume(amt)
    }
}

impl AsyncWrite for WsTunnel {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        let sink = &mut this.inner.get_mut().0;
        let to_io_err = |err| IoError::other(err);

        match this.write_state {
            WriteState::Free => {
                ready!(sink.poll_ready_unpin(cx)).map_err(to_io_err)?;
                sink.start_send_unpin(Message::Binary(Bytes::copy_from_slice(buf)))
                    .map_err(to_io_err)?;
                this.write_state = WriteState::Writing(buf.as_ptr() as usize, buf.len());
                cx.waker().wake_by_ref();
                Poll::Pending
            }

            WriteState::Writing(addr, len) => {
                if (buf.as_ptr() as usize, buf.len()) != (addr, len) {
                    return Poll::Ready(Err(IoError::new(
                        IoErrorKind::ResourceBusy,
                        "concurrent writes to the WebSocket tunnel are not allowed",
                    )));
                }
                ready!(sink.poll_flush_unpin(cx)).map_err(to_io_err)?;
                this.write_state = WriteState::Free;
                Poll::Ready(Ok(len))
            }
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .get_pin_mut()
            .get_pin_mut()
            .poll_flush(cx)
            .map_err(IoError::other)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .get_pin_mut()
            .get_pin_mut()
            .poll_close(cx)
            .map_err(IoError::other)
    }
}

impl Stream for WsBytesStream {
    type Item = Result<Bytes, IoError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.get_pin_mut().poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Ready(Some(Ok(msg))) => match msg {
                Message::Binary(data) => Poll::Ready(Some(Ok(data))),
                Message::Close(_) => Poll::Ready(None),
                Message::Text(data) => {
                    tracing::warn!(%data, "unexpected text frame in a WebSocket tunnel");
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                Message::Ping(_) | Message::Pong(_) => {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                Message::Frame(_) => Poll::Ready(Some(Err(IoError::other("unexpected raw frame")))),
            },
            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(IoError::other(err)))),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}