use core::{ops::Deref, sync::atomic::Ordering};
use std::sync::Arc;
use super::{
client::ClientBorrow,
column::Column,
driver::codec::AsParams,
types::{ToSql, Type},
};
/// A [`Statement`] tied to a borrowed client; when the guard is dropped it sends the
/// statement's cancel request through that client.
///
/// Use [`StatementGuarded::leak`] to take the statement out without sending the cancel.
pub struct StatementGuarded<'a, C>
where
    C: ClientBorrow,
{
    // Always `Some` while the guard is usable: it is only taken by `leak` (which consumes the
    // guard) or by `drop` (end of life), so `Deref` can rely on it being present.
    stmt: Option<Statement>,
    cli: &'a C,
}

impl<C> AsRef<Statement> for StatementGuarded<'_, C>
where
    C: ClientBorrow,
{
    #[inline]
    fn as_ref(&self) -> &Statement {
        // Deref coercion through the `Deref` impl below.
        self
    }
}

impl<C> Deref for StatementGuarded<'_, C>
where
    C: ClientBorrow,
{
    type Target = Statement;

    fn deref(&self) -> &Self::Target {
        self.stmt
            .as_ref()
            .expect("StatementGuarded holds its statement until it is leaked or dropped")
    }
}

impl<C> Drop for StatementGuarded<'_, C>
where
    C: ClientBorrow,
{
    fn drop(&mut self) {
        // `None` can only occur when `leak` already took the statement; nothing to cancel then.
        if let Some(stmt) = self.stmt.take() {
            // Best effort: a destructor has no way to report the failure to the caller.
            let _ = self.cli.borrow_cli_ref().query_raw(stmt.cancel());
        }
    }
}

impl<C> StatementGuarded<'_, C>
where
    C: ClientBorrow,
{
    /// Disarms the guard and returns the statement without sending its cancel request.
    ///
    /// Cleaning up the statement afterwards becomes the caller's responsibility.
    pub fn leak(mut self) -> Statement {
        self.stmt
            .take()
            .expect("StatementGuarded holds its statement until it is leaked or dropped")
    }
}
/// A prepared statement: the name it was prepared under plus its parameter types and
/// result-column descriptions.
///
/// All fields are reference counted, so [`Statement::duplicate`] only bumps counts. To have
/// the statement cancelled on drop, wrap it with [`Statement::into_guarded`] or
/// `StatementGuardedOwned::new`.
#[derive(Default)]
pub struct Statement {
    name: Arc<str>,
    params: Arc<[Type]>,
    columns: Arc<[Column]>,
}

impl Statement {
    /// Creates a statement from its name, parameter types and result columns.
    pub(crate) fn new(name: String, params: Vec<Type>, columns: Vec<Column>) -> Self {
        Self {
            name: name.into(),
            params: params.into(),
            columns: columns.into(),
        }
    }

    /// Returns another handle to the same statement, sharing (not copying) its name and
    /// metadata.
    ///
    /// The extra strong references are observed by `StatementGuardedOwned`'s `Drop`, which
    /// only cancels when its handle holds the last reference to `name`.
    pub(crate) fn duplicate(&self) -> Self {
        Self {
            name: Arc::clone(&self.name),
            params: Arc::clone(&self.params),
            columns: Arc::clone(&self.columns),
        }
    }

    /// Name the statement was prepared under.
    pub(crate) fn name(&self) -> &str {
        &self.name
    }

    /// Shared handle to the result-column descriptions, for values that must outlive the
    /// borrow of `self`.
    pub(crate) fn columns_owned(&self) -> Arc<[Column]> {
        Arc::clone(&self.columns)
    }

    /// Builds the cancel request for this statement; the guard types send it from `Drop`.
    fn cancel(&self) -> StatementPreparedCancel<'_> {
        StatementPreparedCancel { name: self.name() }
    }

    /// Starts describing a statement from SQL text and the parameter types to declare for it.
    ///
    /// This only constructs a [`StatementNamed`]; no I/O happens here.
    #[inline]
    pub const fn named<'a>(stmt: &'a str, types: &'a [Type]) -> StatementNamed<'a> {
        StatementNamed { stmt, types }
    }

    /// Binds `params` to this statement, yielding a query that borrows the statement.
    #[inline]
    pub fn bind<P>(&self, params: P) -> StatementPreparedQuery<'_, P>
    where
        P: AsParams,
    {
        StatementPreparedQuery { stmt: self, params }
    }

    /// Binds a slice of type-erased parameters, iterated lazily rather than collected.
    #[inline]
    pub fn bind_dyn<'p, 't>(
        &self,
        params: &'p [&'t (dyn ToSql + Sync)],
    ) -> StatementPreparedQuery<'_, impl ExactSizeIterator<Item = &'t (dyn ToSql + Sync)> + Clone + 'p> {
        // Items are shared references (`Copy`), so `copied` states that each step is a plain
        // pointer copy.
        self.bind(params.iter().copied())
    }

    /// Binds an empty parameter list.
    #[inline]
    pub fn bind_none(&self) -> StatementPreparedQuery<'_, [bool; 0]> {
        self.bind([])
    }

    /// Parameter types of the statement.
    #[inline]
    pub fn params(&self) -> &[Type] {
        &self.params
    }

    /// Result-column descriptions of the statement.
    #[inline]
    pub fn columns(&self) -> &[Column] {
        &self.columns
    }

    /// Wraps the statement in a [`StatementGuarded`] that sends its cancel request through
    /// `cli` when dropped, unless [`StatementGuarded::leak`] is called first.
    #[inline]
    pub fn into_guarded<C>(self, cli: &C) -> StatementGuarded<'_, C>
    where
        C: ClientBorrow,
    {
        StatementGuarded { stmt: Some(self), cli }
    }
}
/// SQL text plus the parameter types to declare for it; it carries no statement name.
///
/// Created by [`Statement::named`]; bind parameters with [`StatementNamed::bind`] (or its
/// `_dyn` / `_none` variants) to obtain a [`StatementQuery`].
pub struct StatementNamed<'a> {
    pub(crate) stmt: &'a str,
    pub(crate) types: &'a [Type],
}

impl<'a> StatementNamed<'a> {
    /// Generates a fresh statement name of the form `s{id}` from the crate-wide `NEXT_ID`
    /// counter.
    fn name() -> String {
        // `Relaxed` is enough: an atomic read-modify-write always returns a distinct value per
        // call; no ordering with other memory is needed.
        let id = crate::NEXT_ID.fetch_add(1, Ordering::Relaxed);
        format!("s{id}")
    }

    /// Binds `params`, producing a [`StatementQuery`] over the same SQL text and types.
    #[inline]
    pub fn bind<P>(self, params: P) -> StatementQuery<'a, P> {
        StatementQuery {
            stmt: self.stmt,
            types: self.types,
            params,
        }
    }

    /// Binds a slice of type-erased parameters, iterated lazily rather than collected.
    #[inline]
    pub fn bind_dyn<'p, 't>(
        self,
        params: &'p [&'t (dyn ToSql + Sync)],
    ) -> StatementQuery<'a, impl ExactSizeIterator<Item = &'t (dyn ToSql + Sync)> + Clone + 'p> {
        // Items are shared references (`Copy`), so `copied` is the precise adaptor.
        self.bind(params.iter().copied())
    }

    /// Binds an empty parameter list.
    #[inline]
    pub fn bind_none(self) -> StatementQuery<'a, [bool; 0]> {
        // Delegate like `Statement::bind_none` does instead of re-spelling `bind`'s body.
        self.bind([])
    }
}
pub(crate) struct StatementCreate<'a, 'c, C> {
pub(crate) name: String,
pub(crate) stmt: &'a str,
pub(crate) types: &'a [Type],
pub(crate) cli: &'c C,
}
impl<'a, 'c, C> From<(StatementNamed<'a>, &'c C)> for StatementCreate<'a, 'c, C> {
fn from((stmt, cli): (StatementNamed<'a>, &'c C)) -> Self {
Self {
name: StatementNamed::name(),
stmt: stmt.stmt,
types: stmt.types,
cli,
}
}
}
pub(crate) struct StatementCreateBlocking<'a, 'c, C> {
pub(crate) name: String,
pub(crate) stmt: &'a str,
pub(crate) types: &'a [Type],
pub(crate) cli: &'c C,
}
impl<'a, 'c, C> From<(StatementNamed<'a>, &'c C)> for StatementCreateBlocking<'a, 'c, C> {
fn from((stmt, cli): (StatementNamed<'a>, &'c C)) -> Self {
Self {
name: StatementNamed::name(),
stmt: stmt.stmt,
types: stmt.types,
cli,
}
}
}
/// Cancel request for the prepared statement called `name`.
///
/// Built by `Statement::cancel` and sent with `query_raw` from the `Drop` impls of
/// `StatementGuarded` and `StatementGuardedOwned`.
pub(crate) struct StatementPreparedCancel<'a> {
    // Name of the statement to cancel, borrowed from the `Statement` that built this request.
    pub(crate) name: &'a str,
}
/// A borrowed prepared [`Statement`] together with the parameters bound to it.
///
/// Produced by [`Statement::bind`] and its `_dyn` / `_none` variants.
pub struct StatementPreparedQuery<'a, P> {
    pub(crate) stmt: &'a Statement,
    pub(crate) params: P,
}

impl<'a, P> StatementPreparedQuery<'a, P> {
    /// Re-wraps the same statement reference and parameters as a
    /// [`StatementPreparedQueryOwned`].
    #[inline]
    pub fn into_owned(self) -> StatementPreparedQueryOwned<'a, P> {
        let Self { stmt, params } = self;
        StatementPreparedQueryOwned { stmt, params }
    }
}

/// Same contents as [`StatementPreparedQuery`], as a distinct type.
///
/// NOTE(review): presumably selects an execution path that yields owned (non-borrowing)
/// results — confirm against its `Execute` impl.
pub struct StatementPreparedQueryOwned<'a, P> {
    pub(crate) stmt: &'a Statement,
    pub(crate) params: P,
}
/// SQL text with declared parameter types and bound parameters, produced by
/// [`StatementNamed::bind`] and its variants.
pub struct StatementQuery<'a, P> {
    pub(crate) stmt: &'a str,
    pub(crate) types: &'a [Type],
    pub(crate) params: P,
}

impl<'a, P> StatementQuery<'a, P> {
    /// Wraps this query in a [`StatementSingleRTTQuery`].
    ///
    /// NOTE(review): judging by the name, this selects a path that prepares and runs the
    /// statement in a single round trip — confirm against its `Execute` impl.
    pub fn into_single_rtt(self) -> StatementSingleRTTQuery<'a, P> {
        let query = self;
        StatementSingleRTTQuery { query }
    }
}
/// Marker wrapper around a [`StatementQuery`], created by [`StatementQuery::into_single_rtt`].
pub struct StatementSingleRTTQuery<'a, P> {
    query: StatementQuery<'a, P>,
}

impl<'a, P> StatementSingleRTTQuery<'a, P> {
    /// Attaches the client the wrapped query will be run with.
    pub(crate) fn into_with_cli<'c, C>(self, cli: &'c C) -> StatementSingleRTTQueryWithCli<'a, 'c, P, C> {
        let Self { query } = self;
        StatementSingleRTTQueryWithCli { query, cli }
    }
}

/// A [`StatementSingleRTTQuery`]'s inner query paired with a borrowed client.
pub(crate) struct StatementSingleRTTQueryWithCli<'a, 'c, P, C> {
    pub(crate) query: StatementQuery<'a, P>,
    pub(crate) cli: &'c C,
}
/// A [`Statement`] bundled with an owned client handle.
///
/// Clones share the same statement. The cancel request is sent through the client when a
/// handle holding the last reference to the statement's name is dropped.
pub struct StatementGuardedOwned<C>
where
    C: ClientBorrow,
{
    stmt: Statement,
    cli: C,
}

impl<C> Clone for StatementGuardedOwned<C>
where
    C: ClientBorrow + Clone,
{
    /// Shares the statement (refcount bump via `Statement::duplicate`) and clones the client.
    fn clone(&self) -> Self {
        let stmt = self.stmt.duplicate();
        let cli = self.cli.clone();
        Self { stmt, cli }
    }
}
impl<C> Drop for StatementGuardedOwned<C>
where
    C: ClientBorrow,
{
    /// Sends the statement's cancel request when this is the last live handle to it.
    ///
    /// "Last handle" is decided by the strong count of the statement's `name` `Arc`. That
    /// `Arc` is shared by every clone of this guard and by any raw copy made with
    /// `Statement::duplicate`. Any error returned by `query_raw` is discarded.
    ///
    /// NOTE(review): the count is read before the fields are dropped (check-then-act). Two
    /// clones dropped concurrently on different threads can both observe a count of 2, and
    /// then neither sends the cancel. An atomic "am I last" test (e.g. `Arc::into_inner` on a
    /// sized shared marker) would close this gap.
    ///
    /// NOTE(review): `Statement::columns_owned` hands out extra clones of the `columns` `Arc`.
    /// If such a clone outlives the last handle, the second `debug_assert_eq!` below panics
    /// inside `drop` in debug builds — confirm callers cannot hit this.
    fn drop(&mut self) {
        if Arc::strong_count(&self.stmt.name) == 1 {
            debug_assert_eq!(Arc::strong_count(&self.stmt.params), 1);
            debug_assert_eq!(Arc::strong_count(&self.stmt.columns), 1);
            let _ = self.cli.borrow_cli_ref().query_raw(self.stmt.cancel());
        }
    }
}
impl<C> Deref for StatementGuardedOwned<C>
where
    C: ClientBorrow,
{
    type Target = Statement;

    /// Exposes the wrapped statement.
    fn deref(&self) -> &Statement {
        &self.stmt
    }
}

impl<C> AsRef<Statement> for StatementGuardedOwned<C>
where
    C: ClientBorrow,
{
    /// Exposes the wrapped statement (same as `Deref`).
    fn as_ref(&self) -> &Statement {
        let Self { stmt, .. } = self;
        stmt
    }
}

impl<C> StatementGuardedOwned<C>
where
    C: ClientBorrow,
{
    /// Takes ownership of `stmt` and the client it will be cancelled through.
    pub fn new(stmt: Statement, cli: C) -> Self {
        Self { cli, stmt }
    }

    /// The client handle held by this guard.
    pub fn client(&self) -> &C {
        let Self { cli, .. } = self;
        cli
    }
}
#[cfg(test)]
mod test {
    use core::future::IntoFuture;
    use crate::{
        Postgres,
        error::{DbError, SqlState},
        execute::Execute,
        iter::AsyncLendingIterator,
        statement::Statement,
    };
    /// Checks that dropping the statement returned by `execute` invalidates it on the server.
    /// A raw duplicate that outlives it fails with `INVALID_SQL_STATEMENT_NAME`.
    ///
    /// Integration test: requires a PostgreSQL server at `localhost:5432` that accepts the
    /// credentials in the URL below, plus the `./samples/test.sql` fixture (presumably it
    /// creates table `foo` — confirm).
    #[tokio::test]
    async fn cancel_statement() {
        let (cli, drv) = Postgres::new("postgres://postgres:postgres@localhost:5432")
            .connect()
            .await
            .unwrap();
        // Run the connection driver in the background; presumably the client cannot make
        // progress unless the driver future is polled.
        tokio::task::spawn(drv.into_future());
        std::path::Path::new("./samples/test.sql").execute(&cli).await.unwrap();
        let stmt = Statement::named("SELECT id, name FROM foo ORDER BY id", &[])
            .execute(&cli)
            .await
            .unwrap();
        // Second handle sharing the statement's name (assumes `execute` returned a
        // `StatementGuarded`, whose drop sends the cancel — confirm).
        let stmt_raw = stmt.duplicate();
        // Order matters: the guard must be dropped before the duplicate is used.
        drop(stmt);
        let mut stream = stmt_raw.query(&cli).await.unwrap();
        // The server-side error only surfaces when the first row is fetched.
        let e = stream.try_next().await.err().unwrap();
        let e = e.downcast_ref::<DbError>().unwrap();
        assert_eq!(e.code(), &SqlState::INVALID_SQL_STATEMENT_NAME);
    }
}