use std::sync::Arc;
use xitca_io::bytes::BytesMut;
use crate::{
client::ClientBorrow,
column::Column,
error::{Error, InvalidParamCount},
protocol::{self, message::frontend},
statement::{
Statement, StatementCreate, StatementCreateBlocking, StatementPreparedCancel, StatementPreparedQuery,
StatementPreparedQueryOwned, StatementQuery, StatementSingleRTTQueryWithCli,
},
types::{BorrowToSql, IsNull, Type},
};
use super::{
AsParams,
response::{
IntoResponse, IntoRowStreamGuard, NoOpIntoRowStream, StatementCreateResponse, StatementCreateResponseBlocking,
},
sealed,
};
/// Serialize a request into PostgreSQL frontend protocol messages.
///
/// Implementors append their messages to the caller-provided `buf` and return an
/// [`Encode::Output`] value, which must implement [`IntoResponse`].
///
/// `sealed::Sealed` is a supertrait, so implementations are restricted
/// (presumably to this crate — the `sealed` module itself is not visible here).
///
/// # Errors
///
/// Returns [`Error`] when a message cannot be written, e.g. when the number of bound
/// parameters does not match the expected parameter types or a value fails to serialize.
#[diagnostic::on_unimplemented(
    message = "`{Self}` does not implement the `Encode` trait",
    label = "query statement argument must be a type that implements the `Encode` trait",
    note = "consider using one of the types listed below, which implement the `Encode` trait"
)]
pub trait Encode: sealed::Sealed + Sized {
    /// Value produced by a successful [`Encode::encode`] call.
    type Output: IntoResponse;

    /// Append the frontend messages describing `self` to `buf`.
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error>;
}
impl sealed::Sealed for &str {}
impl Encode for &str {
type Output = Vec<Column>;
#[inline]
fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
frontend::query(self, buf)?;
Ok(Vec::new())
}
}
impl<C> sealed::Sealed for StatementCreate<'_, '_, C> {}

/// Prepare a named statement.
///
/// The Parse / Describe / Sync messages are written by `encode_statement_create`. On success,
/// the statement name and the client are handed to [`StatementCreateResponse`].
impl<'c, C> Encode for StatementCreate<'_, 'c, C>
where
    C: ClientBorrow + Sync,
{
    type Output = StatementCreateResponse<'c, C>;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        encode_statement_create(&self.name, self.stmt, self.types, buf)?;
        Ok(StatementCreateResponse {
            name: self.name,
            cli: self.cli,
        })
    }
}
impl<C> sealed::Sealed for StatementCreateBlocking<'_, '_, C> {}

/// Blocking counterpart of the `StatementCreate` encoder.
///
/// It writes the same message sequence through `encode_statement_create`, but yields a
/// [`StatementCreateResponseBlocking`]. Unlike the non-blocking variant, the client does not
/// need to be `Sync`.
impl<'c, C> Encode for StatementCreateBlocking<'_, 'c, C>
where
    C: ClientBorrow,
{
    type Output = StatementCreateResponseBlocking<'c, C>;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        encode_statement_create(&self.name, self.stmt, self.types, buf)?;
        Ok(StatementCreateResponseBlocking {
            name: self.name,
            cli: self.cli,
        })
    }
}
/// Append Parse, Describe and Sync messages that prepare `stmt` under `name`.
///
/// `types` supplies the parameter type OIDs for the Parse message. Describe uses the `b'S'`
/// target (statement), as opposed to the `b'P'` portal target used by `PortalCancel`.
fn encode_statement_create(name: &str, stmt: &str, types: &[Type], buf: &mut BytesMut) -> Result<(), Error> {
    let param_oids = types.iter().map(Type::oid);
    frontend::parse(name, stmt, param_oids, buf)?;

    frontend::describe(b'S', name, buf)?;
    frontend::sync(buf);

    Ok(())
}
impl sealed::Sealed for StatementPreparedCancel<'_> {}

/// Close the prepared statement called `name` (Close with the `b'S'` target) and append a Sync.
///
/// No rows are expected back, so the output is [`NoOpIntoRowStream`].
impl Encode for StatementPreparedCancel<'_> {
    type Output = NoOpIntoRowStream;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        frontend::close(b'S', self.name, buf)?;
        frontend::sync(buf);
        Ok(NoOpIntoRowStream)
    }
}
impl<P> sealed::Sealed for StatementPreparedQuery<'_, P> {}

/// Execute an already-prepared statement with `params`.
///
/// On success, the statement's column metadata is returned as a slice borrowed for `'s`.
impl<'s, P> Encode for StatementPreparedQuery<'s, P>
where
    P: AsParams,
{
    type Output = &'s [Column];

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        let prepared = self.stmt;
        encode_stmt_query(prepared, self.params, buf)?;
        Ok(prepared.columns())
    }
}
impl<P> sealed::Sealed for StatementPreparedQueryOwned<'_, P> {}

/// Same as the `StatementPreparedQuery` encoder, but the column metadata is returned as a
/// shared `Arc<[Column]>` (via `columns_owned`) instead of a borrowed slice.
impl<'s, P> Encode for StatementPreparedQueryOwned<'s, P>
where
    P: AsParams,
{
    type Output = Arc<[Column]>;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        let prepared = self.stmt;
        encode_stmt_query(prepared, self.params, buf)?;
        Ok(prepared.columns_owned())
    }
}
/// Bind `params` to the prepared `stmt`, then execute it and append a Sync.
///
/// Parameters are bound into the unnamed portal (`""`). That portal is then executed with a
/// row limit of `0`, which the PostgreSQL protocol defines as "no limit".
fn encode_stmt_query<P>(stmt: &Statement, params: P, buf: &mut BytesMut) -> Result<(), Error>
where
    P: AsParams,
{
    const UNNAMED_PORTAL: &str = "";

    encode_bind(stmt.name(), stmt.params(), params, UNNAMED_PORTAL, buf)?;
    frontend::execute(UNNAMED_PORTAL, 0, buf)?;
    frontend::sync(buf);

    Ok(())
}
impl<P, C> sealed::Sealed for StatementSingleRTTQueryWithCli<'_, '_, P, C> {}

/// Encode a complete query without a separately prepared statement.
///
/// The messages are appended to `buf` in this order:
/// 1. Parse the SQL into the unnamed statement.
/// 2. Bind `params` into the unnamed portal.
/// 3. Describe the unnamed statement.
/// 4. Execute with no row limit.
/// 5. Sync.
///
/// The client borrow is returned wrapped in [`IntoRowStreamGuard`].
impl<'c, C, P> Encode for StatementSingleRTTQueryWithCli<'_, 'c, P, C>
where
    C: ClientBorrow,
    P: AsParams,
{
    type Output = IntoRowStreamGuard<'c, C>;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        // The empty name selects both the unnamed statement and the unnamed portal.
        const UNNAMED: &str = "";

        let Self {
            query: StatementQuery { stmt, params, types },
            cli,
        } = self;

        frontend::parse(UNNAMED, stmt, types.iter().map(Type::oid), buf)?;
        encode_bind(UNNAMED, types, params, UNNAMED, buf)?;
        frontend::describe(b'S', UNNAMED, buf)?;
        frontend::execute(UNNAMED, 0, buf)?;
        frontend::sync(buf);

        Ok(IntoRowStreamGuard(cli))
    }
}
/// Request to bind parameters to a prepared statement into a named portal.
pub(crate) struct PortalCreate<'a, P> {
    /// Name of the portal to create (the Bind message's portal name).
    pub(crate) name: &'a str,
    /// Name of the prepared statement to bind (the Bind message's statement name).
    pub(crate) stmt: &'a str,
    /// Expected parameter types; their count must match `params`.
    pub(crate) types: &'a [Type],
    /// Parameter values to bind.
    pub(crate) params: P,
}

impl<P> sealed::Sealed for PortalCreate<'_, P> {}

/// Write a Bind into the portal `name`, followed by a Sync.
///
/// Only the portal is created here; no Execute is written, so the output is
/// [`NoOpIntoRowStream`].
impl<P> Encode for PortalCreate<'_, P>
where
    P: AsParams,
{
    type Output = NoOpIntoRowStream;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        encode_bind(self.stmt, self.types, self.params, self.name, buf)?;
        frontend::sync(buf);
        Ok(NoOpIntoRowStream)
    }
}
/// Request to close a named portal.
pub(crate) struct PortalCancel<'a> {
    /// Name of the portal to close.
    pub(crate) name: &'a str,
}

impl sealed::Sealed for PortalCancel<'_> {}

/// Write a Close for the portal (the `b'P'` target) followed by a Sync.
///
/// No rows are expected back, so the output is [`NoOpIntoRowStream`].
impl Encode for PortalCancel<'_> {
    type Output = NoOpIntoRowStream;

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        let portal = self.name;
        frontend::close(b'P', portal, buf)?;
        frontend::sync(buf);
        Ok(NoOpIntoRowStream)
    }
}
/// Request to execute an existing named portal.
pub struct PortalQuery<'a> {
    /// Name of the portal to execute.
    pub(crate) name: &'a str,
    /// Column metadata returned as the encode output.
    pub(crate) columns: &'a [Column],
    /// Row limit written into the Execute message (0 means "no limit" in the PostgreSQL protocol).
    pub(crate) max_rows: i32,
}

impl sealed::Sealed for PortalQuery<'_> {}

/// Write an Execute for the portal (limited to `max_rows`) followed by a Sync.
///
/// The stored column slice is passed through unchanged as the output.
impl<'s> Encode for PortalQuery<'s> {
    type Output = &'s [Column];

    #[inline]
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
        frontend::execute(self.name, self.max_rows, buf)?;
        frontend::sync(buf);
        Ok(self.columns)
    }
}
/// Append a Bind message that binds `params` to the prepared statement `stmt` into `portal`.
///
/// Each parameter is paired with its entry in `types`:
/// - its wire format is chosen by `encode_format(ty)`;
/// - its value is serialized with `to_sql_checked(ty, ..)`.
///
/// Result columns are requested with format code `1` (binary in the PostgreSQL protocol).
///
/// # Errors
///
/// - [`InvalidParamCount`] when the number of `params` differs from `types.len()`. This check
///   runs before anything is written to `buf`.
/// - Conversion or serialization failures reported by `frontend::bind`, converted into [`Error`].
fn encode_bind<P>(stmt: &str, types: &[Type], params: P, portal: &str, buf: &mut BytesMut) -> Result<(), Error>
where
    P: AsParams,
{
    let values = params.into_iter();

    let expected = types.len();
    let given = values.len();
    if given != expected {
        return Err(Error::from(InvalidParamCount {
            expected,
            params: given,
        }));
    }

    // The zipped iterator is walked twice: once for the format codes, once for the values.
    let paired = values.zip(types);

    let result = frontend::bind(
        portal,
        stmt,
        paired.clone().map(|(p, ty)| p.borrow_to_sql().encode_format(ty) as _),
        paired,
        |(p, ty), buf| {
            p.borrow_to_sql().to_sql_checked(ty, buf).map(|is_null| match is_null {
                IsNull::No => protocol::IsNull::No,
                IsNull::Yes => protocol::IsNull::Yes,
            })
        },
        Some(1),
        buf,
    );

    result.map_err(|err| match err {
        frontend::BindError::Conversion(e) => Error::from(e),
        frontend::BindError::Serialization(e) => Error::from(e),
    })
}