use super::stream::FromRowStream;
use super::{starts_with_keyword, strip_sql_prefix};
use crate::client::{GenericClient, RowStream, StreamingClient};
use crate::error::{OrmError, OrmResult};
use crate::row::FromRow;
use std::sync::Arc;
use tokio_postgres::Row;
use tokio_postgres::types::{FromSql, ToSql};
/// A raw SQL statement bundled with its positionally bound parameters and an
/// optional tag.
///
/// Build one with [`Query::new`], then chain [`Query::bind`] / [`Query::tag`].
/// A `Query` does nothing until handed to one of its executor methods, which is
/// why dropping it unused is flagged by `#[must_use]`.
#[must_use]
pub struct Query {
    /// SQL text, kept exactly as supplied.
    sql: String,
    /// Optional tag; when set, the untagged executors route through the
    /// client's `*_tagged` methods.
    tag: Option<String>,
    /// Bound parameter values, in the order they were bound.
    params: Vec<Arc<dyn ToSql + Send + Sync>>,
}
/// Builder methods and executors for [`Query`].
///
/// Executors borrow `self`, so one `Query` can be run more than once. Executors
/// without an explicit `tag` argument use the tag set via [`Query::tag`] (if any)
/// and call the client's `*_tagged` method; otherwise they call the untagged one.
/// The `*_tagged` executors always use their `tag` argument and ignore the stored tag.
impl Query {
    /// Creates a query from raw SQL with no bound parameters and no tag.
    pub fn new(sql: impl Into<String>) -> Self {
        Self {
            sql: sql.into(),
            params: Vec::new(),
            tag: None,
        }
    }

    /// Sets the tag used by the untagged executors (`fetch_all`, `execute`, ...).
    /// Calling this again replaces the previous tag.
    pub fn tag(mut self, tag: impl Into<String>) -> Self {
        self.tag = Some(tag.into());
        self
    }

    /// Appends `value` as the next positional parameter.
    ///
    /// Parameters are handed to the client in bind order (so, under PostgreSQL's
    /// positional placeholders, the first `bind` presumably feeds `$1`).
    pub fn bind<T>(mut self, value: T) -> Self
    where
        T: ToSql + Sync + Send + 'static,
    {
        self.params.push(Arc::new(value));
        self
    }

    /// Returns the SQL text exactly as given to [`Query::new`].
    pub fn sql(&self) -> &str {
        &self.sql
    }

    /// Borrows the bound parameters in the element type the client methods take.
    ///
    /// Allocates a new `Vec` on every call; the references borrow from `self`.
    pub fn params_ref(&self) -> Vec<&(dyn ToSql + Sync)> {
        self.params
            .iter()
            .map(|p| p.as_ref() as &(dyn ToSql + Sync))
            .collect()
    }

    /// Runs the query and returns all result rows.
    ///
    /// # Errors
    /// Propagates any error from the client.
    pub async fn fetch_all(&self, conn: &impl GenericClient) -> OrmResult<Vec<Row>> {
        let params = self.params_ref();
        match self.tag.as_deref() {
            Some(tag) => conn.query_tagged(tag, &self.sql, &params).await,
            None => conn.query(&self.sql, &params).await,
        }
    }

    /// Runs the query through a [`StreamingClient`] and returns its [`RowStream`]
    /// rather than a collected `Vec`.
    ///
    /// # Errors
    /// Propagates any error from the client while starting the stream.
    pub async fn stream(&self, conn: &impl StreamingClient) -> OrmResult<RowStream> {
        let params = self.params_ref();
        match self.tag.as_deref() {
            Some(tag) => conn.query_stream_tagged(tag, &self.sql, &params).await,
            None => conn.query_stream(&self.sql, &params).await,
        }
    }

    /// Like [`Query::stream`], but wraps the stream in a [`FromRowStream`] that
    /// maps each row to `T`.
    ///
    /// # Errors
    /// Propagates any error from [`Query::stream`].
    pub async fn stream_as<T: FromRow>(
        &self,
        conn: &impl StreamingClient,
    ) -> OrmResult<FromRowStream<T>> {
        let stream = self.stream(conn).await?;
        Ok(FromRowStream::new(stream))
    }

    /// Runs the query and converts every row with [`FromRow::from_row`].
    ///
    /// # Errors
    /// Client errors, or the first row conversion error.
    pub async fn fetch_all_as<T: FromRow>(&self, conn: &impl GenericClient) -> OrmResult<Vec<T>> {
        let rows = self.fetch_all(conn).await?;
        rows.iter().map(T::from_row).collect()
    }

    /// Runs the query via the client's `query_one` / `query_one_tagged`.
    ///
    /// How zero or multiple result rows are treated is decided by that client
    /// method (TODO confirm against `GenericClient`).
    ///
    /// # Errors
    /// Propagates any error from the client.
    pub async fn fetch_one(&self, conn: &impl GenericClient) -> OrmResult<Row> {
        let params = self.params_ref();
        match self.tag.as_deref() {
            Some(tag) => conn.query_one_tagged(tag, &self.sql, &params).await,
            None => conn.query_one(&self.sql, &params).await,
        }
    }

    /// [`Query::fetch_one`] followed by [`FromRow::from_row`].
    ///
    /// # Errors
    /// Client errors or the row conversion error.
    pub async fn fetch_one_as<T: FromRow>(&self, conn: &impl GenericClient) -> OrmResult<T> {
        let row = self.fetch_one(conn).await?;
        T::from_row(&row)
    }

    /// Runs the query and returns the row produced by the client's
    /// `query_opt` / `query_opt_tagged`, or `None`.
    ///
    /// # Errors
    /// Propagates any error from the client.
    pub async fn fetch_opt(&self, conn: &impl GenericClient) -> OrmResult<Option<Row>> {
        let params = self.params_ref();
        match self.tag.as_deref() {
            Some(tag) => conn.query_opt_tagged(tag, &self.sql, &params).await,
            None => conn.query_opt(&self.sql, &params).await,
        }
    }

    /// [`Query::fetch_opt`] followed by [`FromRow::from_row`] on the row, if any.
    ///
    /// # Errors
    /// Client errors or the row conversion error.
    pub async fn fetch_opt_as<T: FromRow>(
        &self,
        conn: &impl GenericClient,
    ) -> OrmResult<Option<T>> {
        let row = self.fetch_opt(conn).await?;
        row.as_ref().map(T::from_row).transpose()
    }

    /// Executes the statement and returns the count reported by the client's
    /// `execute` / `execute_tagged` (presumably the number of affected rows).
    ///
    /// # Errors
    /// Propagates any error from the client.
    pub async fn execute(&self, conn: &impl GenericClient) -> OrmResult<u64> {
        let params = self.params_ref();
        match self.tag.as_deref() {
            Some(tag) => conn.execute_tagged(tag, &self.sql, &params).await,
            None => conn.execute(&self.sql, &params).await,
        }
    }

    /// Like [`Query::fetch_all`], but always uses `tag`, ignoring any stored tag.
    ///
    /// # Errors
    /// Propagates any error from the client.
    pub async fn fetch_all_tagged(
        &self,
        conn: &impl GenericClient,
        tag: &str,
    ) -> OrmResult<Vec<Row>> {
        let params = self.params_ref();
        conn.query_tagged(tag, &self.sql, &params).await
    }

    /// [`Query::fetch_all_tagged`] followed by [`FromRow::from_row`] on each row.
    ///
    /// # Errors
    /// Client errors, or the first row conversion error.
    pub async fn fetch_all_tagged_as<T: FromRow>(
        &self,
        conn: &impl GenericClient,
        tag: &str,
    ) -> OrmResult<Vec<T>> {
        let rows = self.fetch_all_tagged(conn, tag).await?;
        rows.iter().map(T::from_row).collect()
    }

    /// Like [`Query::fetch_one`], but always uses `tag`, ignoring any stored tag.
    ///
    /// # Errors
    /// Propagates any error from the client.
    pub async fn fetch_one_tagged(&self, conn: &impl GenericClient, tag: &str) -> OrmResult<Row> {
        let params = self.params_ref();
        conn.query_one_tagged(tag, &self.sql, &params).await
    }

    /// [`Query::fetch_one_tagged`] followed by [`FromRow::from_row`].
    ///
    /// # Errors
    /// Client errors or the row conversion error.
    pub async fn fetch_one_tagged_as<T: FromRow>(
        &self,
        conn: &impl GenericClient,
        tag: &str,
    ) -> OrmResult<T> {
        let row = self.fetch_one_tagged(conn, tag).await?;
        T::from_row(&row)
    }

    /// Strict single-row fetch via the client's `query_one_strict` /
    /// `query_one_strict_tagged`.
    ///
    /// Presumably this also rejects more than one result row (TODO confirm
    /// against `GenericClient`).
    ///
    /// # Errors
    /// Propagates any error from the client.
    pub async fn fetch_one_strict(&self, conn: &impl GenericClient) -> OrmResult<Row> {
        let params = self.params_ref();
        match self.tag.as_deref() {
            Some(tag) => conn.query_one_strict_tagged(tag, &self.sql, &params).await,
            None => conn.query_one_strict(&self.sql, &params).await,
        }
    }

    /// [`Query::fetch_one_strict`] followed by [`FromRow::from_row`].
    ///
    /// # Errors
    /// Client errors or the row conversion error.
    pub async fn fetch_one_strict_as<T: FromRow>(&self, conn: &impl GenericClient) -> OrmResult<T> {
        let row = self.fetch_one_strict(conn).await?;
        T::from_row(&row)
    }

    /// Like [`Query::fetch_one_strict`], but always uses `tag`, ignoring any stored tag.
    ///
    /// # Errors
    /// Propagates any error from the client.
    pub async fn fetch_one_strict_tagged(
        &self,
        conn: &impl GenericClient,
        tag: &str,
    ) -> OrmResult<Row> {
        let params = self.params_ref();
        conn.query_one_strict_tagged(tag, &self.sql, &params).await
    }

    /// [`Query::fetch_one_strict_tagged`] followed by [`FromRow::from_row`].
    ///
    /// # Errors
    /// Client errors or the row conversion error.
    pub async fn fetch_one_strict_tagged_as<T: FromRow>(
        &self,
        conn: &impl GenericClient,
        tag: &str,
    ) -> OrmResult<T> {
        let row = self.fetch_one_strict_tagged(conn, tag).await?;
        T::from_row(&row)
    }

    /// Like [`Query::fetch_opt`], but always uses `tag`, ignoring any stored tag.
    ///
    /// # Errors
    /// Propagates any error from the client.
    pub async fn fetch_opt_tagged(
        &self,
        conn: &impl GenericClient,
        tag: &str,
    ) -> OrmResult<Option<Row>> {
        let params = self.params_ref();
        conn.query_opt_tagged(tag, &self.sql, &params).await
    }

    /// [`Query::fetch_opt_tagged`] followed by [`FromRow::from_row`] on the row, if any.
    ///
    /// # Errors
    /// Client errors or the row conversion error.
    pub async fn fetch_opt_tagged_as<T: FromRow>(
        &self,
        conn: &impl GenericClient,
        tag: &str,
    ) -> OrmResult<Option<T>> {
        let row = self.fetch_opt_tagged(conn, tag).await?;
        row.as_ref().map(T::from_row).transpose()
    }

    /// Like [`Query::execute`], but always uses `tag`, ignoring any stored tag.
    ///
    /// # Errors
    /// Propagates any error from the client.
    pub async fn execute_tagged(&self, conn: &impl GenericClient, tag: &str) -> OrmResult<u64> {
        let params = self.params_ref();
        conn.execute_tagged(tag, &self.sql, &params).await
    }

    /// Decodes column 0 of `row` as `T`. Any decode failure becomes
    /// `OrmError::decode("0", ..)`.
    fn decode_column_zero<T>(row: &Row) -> OrmResult<T>
    where
        T: for<'b> FromSql<'b>,
    {
        row.try_get(0)
            .map_err(|e| OrmError::decode("0", e.to_string()))
    }

    /// Runs [`Query::fetch_one`] and decodes the first column as `T`.
    ///
    /// # Errors
    /// Client errors, or a decode error for column `"0"`.
    pub async fn fetch_scalar_one<T>(&self, conn: &impl GenericClient) -> OrmResult<T>
    where
        T: for<'b> FromSql<'b> + Send + Sync,
    {
        let row = self.fetch_one(conn).await?;
        Self::decode_column_zero(&row)
    }

    /// Runs [`Query::fetch_opt`] and decodes the first column of the row, if any.
    ///
    /// # Errors
    /// Client errors, or a decode error for column `"0"`.
    pub async fn fetch_scalar_opt<T>(&self, conn: &impl GenericClient) -> OrmResult<Option<T>>
    where
        T: for<'b> FromSql<'b> + Send + Sync,
    {
        match self.fetch_opt(conn).await? {
            Some(row) => Self::decode_column_zero(&row).map(Some),
            None => Ok(None),
        }
    }

    /// Runs [`Query::fetch_all`] and decodes the first column of every row.
    ///
    /// # Errors
    /// Client errors, or the first decode error for column `"0"`.
    pub async fn fetch_scalar_all<T>(&self, conn: &impl GenericClient) -> OrmResult<Vec<T>>
    where
        T: for<'b> FromSql<'b> + Send + Sync,
    {
        let rows = self.fetch_all(conn).await?;
        rows.iter().map(Self::decode_column_zero).collect()
    }

    /// Reports whether the query yields at least one row. It runs
    /// `SELECT EXISTS(<sql>)` instead of fetching the rows.
    ///
    /// Trailing whitespace and semicolons are removed before wrapping, because a
    /// `;` inside the `EXISTS(...)` subquery is a syntax error.
    ///
    /// # Errors
    /// - [`OrmError::Validation`] if the statement (after whatever prefix
    ///   `strip_sql_prefix` removes) does not start with `SELECT` or `WITH`.
    /// - Otherwise, client errors or a decode error for column `"0"`.
    pub async fn exists(&self, conn: &impl GenericClient) -> OrmResult<bool> {
        // Strip every trailing `;`/whitespace run (e.g. `SELECT 1 ;; `), not just one `;`.
        let inner_sql = self
            .sql
            .trim_end_matches(|c: char| c == ';' || c.is_whitespace());
        let trimmed = strip_sql_prefix(inner_sql);
        if !starts_with_keyword(trimmed, "SELECT") && !starts_with_keyword(trimmed, "WITH") {
            return Err(OrmError::Validation(
                "exists() only works with SELECT statements (including WITH ... SELECT)"
                    .to_string(),
            ));
        }
        // Put the closing paren on its own line. Otherwise a trailing `-- comment`
        // in the caller's SQL would comment it out and break the statement.
        let wrapped_sql = format!("SELECT EXISTS({inner_sql}\n)");
        let params = self.params_ref();
        let row = match self.tag.as_deref() {
            Some(tag) => conn.query_one_tagged(tag, &wrapped_sql, &params).await?,
            None => conn.query_one(&wrapped_sql, &params).await?,
        };
        Self::decode_column_zero(&row)
    }
}