xitca-postgres 0.4.0

An async PostgreSQL client.
Documentation
use std::sync::Arc;

use xitca_io::bytes::BytesMut;

use crate::{
    client::ClientBorrow,
    column::Column,
    error::{Error, InvalidParamCount},
    protocol::{self, message::frontend},
    statement::{
        Statement, StatementCreate, StatementCreateBlocking, StatementPreparedCancel, StatementPreparedQuery,
        StatementPreparedQueryOwned, StatementQuery, StatementSingleRTTQueryWithCli,
    },
    types::{BorrowToSql, IsNull, Type},
};

use super::{
    AsParams,
    response::{
        IntoResponse, IntoRowStreamGuard, NoOpIntoRowStream, StatementCreateResponse, StatementCreateResponseBlocking,
    },
    sealed,
};

/// Trait for being generic over how a query is encoded into the outgoing message buffer.
///
/// This trait is sealed: it cannot be implemented by library users.
#[diagnostic::on_unimplemented(
    message = "`{Self}` does not impl Encode trait",
    label = "query statement argument must be a type implementing the Encode trait",
    note = "consider using one of the types listed below that implement the Encode trait"
)]
pub trait Encode: sealed::Sealed + Sized {
    /// Output type that defines how a potential async row-streaming type should be constructed.
    /// Certain state from the encoded type may need to be passed on for constructing the stream.
    type Output: IntoResponse;

    /// Encode `self` into `buf` and return the state needed to build the response.
    ///
    /// The implementations in this module append frontend protocol messages to `buf`.
    ///
    /// # Errors
    /// Returns an [`Error`] when a message or one of its parameters cannot be encoded.
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error>;
}

impl sealed::Sealed for &str {}

impl Encode for &str {
    type Output = Vec<Column>;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        frontend::query(self, buf)?;
        Ok(Vec::new())
    }
}

impl<C> sealed::Sealed for StatementCreate<'_, '_, C> {}

impl<'c, C> Encode for StatementCreate<'_, 'c, C>
where
    C: ClientBorrow + Sync,
{
    type Output = StatementCreateResponse<'c, C>;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        let Self { name, stmt, types, cli } = self;
        encode_statement_create(&name, stmt, types, buf).map(|_| StatementCreateResponse { name, cli })
    }
}

impl<C> sealed::Sealed for StatementCreateBlocking<'_, '_, C> {}

impl<'c, C> Encode for StatementCreateBlocking<'_, 'c, C>
where
    C: ClientBorrow,
{
    type Output = StatementCreateResponseBlocking<'c, C>;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        let Self { name, stmt, types, cli } = self;
        encode_statement_create(&name, stmt, types, buf).map(|_| StatementCreateResponseBlocking { name, cli })
    }
}

/// Append Parse, Describe and Sync messages for the statement `name`.
///
/// Parse creates the prepared statement `name` from the SQL text `stmt`, declaring the OIDs
/// of `types` as its parameter types.
fn encode_statement_create(name: &str, stmt: &str, types: &[Type], buf: &mut BytesMut) -> Result<(), Error> {
    let param_oids = types.iter().map(Type::oid);
    frontend::parse(name, stmt, param_oids, buf)?;
    // b'S' targets a prepared statement; portals use b'P'
    frontend::describe(b'S', name, buf)?;
    frontend::sync(buf);
    Ok(())
}

impl sealed::Sealed for StatementPreparedCancel<'_> {}

/// Encodes closing a prepared statement by name (Close + Sync).
impl Encode for StatementPreparedCancel<'_> {
    type Output = NoOpIntoRowStream;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        // b'S' closes a prepared statement rather than a portal
        frontend::close(b'S', self.name, buf)?;
        frontend::sync(buf);
        Ok(NoOpIntoRowStream)
    }
}

impl<P> sealed::Sealed for StatementPreparedQuery<'_, P> {}

impl<'s, P> Encode for StatementPreparedQuery<'s, P>
where
    P: AsParams,
{
    type Output = &'s [Column];

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        let Self { stmt, params } = self;
        encode_stmt_query(stmt, params, buf).map(|_| stmt.columns())
    }
}

impl<P> sealed::Sealed for StatementPreparedQueryOwned<'_, P> {}

impl<'s, P> Encode for StatementPreparedQueryOwned<'s, P>
where
    P: AsParams,
{
    type Output = Arc<[Column]>;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        let Self { stmt, params } = self;
        encode_stmt_query(stmt, params, buf).map(|_| stmt.columns_owned())
    }
}

/// Bind `params` to the prepared statement `stmt`, then append Execute and Sync.
///
/// The bind goes to the unnamed portal, and Execute runs that portal with `max_rows = 0`.
fn encode_stmt_query<P>(stmt: &Statement, params: P, buf: &mut BytesMut) -> Result<(), Error>
where
    P: AsParams,
{
    // the empty name selects the unnamed portal for both Bind and Execute
    const UNNAMED_PORTAL: &str = "";
    encode_bind(stmt.name(), stmt.params(), params, UNNAMED_PORTAL, buf)?;
    // max_rows of 0 requests all rows (no limit)
    frontend::execute(UNNAMED_PORTAL, 0, buf)?;
    frontend::sync(buf);
    Ok(())
}

impl<C, P> sealed::Sealed for StatementSingleRTTQueryWithCli<'_, '_, P, C> {}

impl<'c, C, P> Encode for StatementSingleRTTQueryWithCli<'_, 'c, P, C>
where
    C: ClientBorrow,
    P: AsParams,
{
    type Output = IntoRowStreamGuard<'c, C>;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        let Self { query, cli } = self;
        let StatementQuery { stmt, params, types } = query;
        frontend::parse("", stmt, types.iter().map(Type::oid), buf)?;
        encode_bind("", types, params, "", buf)?;
        frontend::describe(b'S', "", buf)?;
        frontend::execute("", 0, buf)?;
        frontend::sync(buf);
        Ok(IntoRowStreamGuard(cli))
    }
}

/// Request to bind `params` to a prepared statement into a named portal.
pub(crate) struct PortalCreate<'a, P> {
    /// Name of the portal being created.
    pub(crate) name: &'a str,
    /// Name of the prepared statement the portal is bound to.
    pub(crate) stmt: &'a str,
    /// Declared parameter types, used to check the param count and to encode each value.
    pub(crate) types: &'a [Type],
    /// Parameter values to bind.
    pub(crate) params: P,
}

impl<P> sealed::Sealed for PortalCreate<'_, P> {}

impl<P> Encode for PortalCreate<'_, P>
where
    P: AsParams,
{
    type Output = NoOpIntoRowStream;

    /// Append Bind + Sync. No Execute is sent; rows are fetched later through `PortalQuery`.
    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        encode_bind(self.stmt, self.types, self.params, self.name, buf)?;
        frontend::sync(buf);
        Ok(NoOpIntoRowStream)
    }
}

/// Request to close a named portal.
pub(crate) struct PortalCancel<'a> {
    /// Name of the portal to close.
    pub(crate) name: &'a str,
}

impl sealed::Sealed for PortalCancel<'_> {}

impl Encode for PortalCancel<'_> {
    type Output = NoOpIntoRowStream;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        let PortalCancel { name } = self;
        // b'P' closes a portal, as opposed to b'S' for a prepared statement
        frontend::close(b'P', name, buf)?;
        frontend::sync(buf);
        Ok(NoOpIntoRowStream)
    }
}

/// Request to execute an existing named portal and fetch up to `max_rows` rows from it.
pub struct PortalQuery<'a> {
    /// Name of the portal to execute.
    pub(crate) name: &'a str,
    /// Column metadata returned as the encode output.
    pub(crate) columns: &'a [Column],
    /// Row limit forwarded to the Execute message.
    pub(crate) max_rows: i32,
}

impl sealed::Sealed for PortalQuery<'_> {}

impl<'s> Encode for PortalQuery<'s> {
    type Output = &'s [Column];

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        frontend::execute(self.name, self.max_rows, buf)?;
        frontend::sync(buf);
        Ok(self.columns)
    }
}

fn encode_bind<P>(stmt: &str, types: &[Type], params: P, portal: &str, buf: &mut BytesMut) -> Result<(), Error>
where
    P: AsParams,
{
    let params = params.into_iter();
    if params.len() != types.len() {
        return Err(Error::from(InvalidParamCount {
            expected: types.len(),
            params: params.len(),
        }));
    }

    let params = params.zip(types);

    frontend::bind(
        portal,
        stmt,
        params.clone().map(|(p, ty)| p.borrow_to_sql().encode_format(ty) as _),
        params,
        |(p, ty), buf| {
            p.borrow_to_sql().to_sql_checked(ty, buf).map(|is_null| match is_null {
                IsNull::No => protocol::IsNull::No,
                IsNull::Yes => protocol::IsNull::Yes,
            })
        },
        Some(1),
        buf,
    )
    .map_err(|e| match e {
        frontend::BindError::Conversion(e) => Error::from(e),
        frontend::BindError::Serialization(e) => Error::from(e),
    })
}