use super::{
BuiltQuery, DeleteQuery, InsertQuery, Row, RowContext, SelectQuery, SqlParam, UpdateQuery,
Value, pg_row_to_row,
};
use crate::Error;
use crate::{PgType, Schema, Table};
use tokio_postgres::Client;
use tracing::Instrument;
pub struct Db<'a> {
client: &'a Client,
schema: Schema,
}
impl<'a> Db<'a> {
pub fn new(client: &'a Client) -> Self {
Self {
client,
schema: crate::schema::collect_schema(),
}
}
pub fn schema(&self) -> &Schema {
&self.schema
}
pub fn table(&self, name: &str) -> Option<&Table> {
self.schema.tables.get(name)
}
#[allow(clippy::result_large_err)]
pub fn select(&self, table: &str) -> Result<SelectBuilder<'_>, Error> {
let table_def = self
.table(table)
.ok_or_else(|| Error::UnknownTable(table.to_string()))?;
Ok(SelectBuilder {
db: self,
table: table_def,
query: SelectQuery::new(table),
})
}
#[allow(clippy::result_large_err)]
pub fn insert(&self, table: &str) -> Result<InsertBuilder<'_>, Error> {
let table_def = self
.table(table)
.ok_or_else(|| Error::UnknownTable(table.to_string()))?;
Ok(InsertBuilder {
db: self,
table: table_def,
query: InsertQuery::new(table),
})
}
#[allow(clippy::result_large_err)]
pub fn update(&self, table: &str) -> Result<UpdateBuilder<'_>, Error> {
let table_def = self
.table(table)
.ok_or_else(|| Error::UnknownTable(table.to_string()))?;
Ok(UpdateBuilder {
db: self,
table: table_def,
query: UpdateQuery::new(table),
})
}
#[allow(clippy::result_large_err)]
pub fn delete(&self, table: &str) -> Result<DeleteBuilder<'_>, Error> {
let table_def = self
.table(table)
.ok_or_else(|| Error::UnknownTable(table.to_string()))?;
Ok(DeleteBuilder {
db: self,
table: table_def,
query: DeleteQuery::new(table),
})
}
async fn execute_select(&self, query: BuiltQuery, table: &Table) -> Result<Vec<Row>, Error> {
let params: Vec<SqlParam> = query.params.iter().map(SqlParam).collect();
let params_ref: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = params
.iter()
.map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
.collect();
let span = tracing::debug_span!(
"db.query",
sql = %query.sql,
params = params.len(),
rows = tracing::field::Empty,
);
let rows = self
.client
.query(&query.sql, ¶ms_ref)
.instrument(span.clone())
.await?;
span.record("rows", rows.len());
let columns: Vec<_> = if rows.is_empty() {
table
.columns
.iter()
.map(|c| (c.name.clone(), c.pg_type))
.collect()
} else {
rows[0]
.columns()
.iter()
.map(|pg_col| {
let name = pg_col.name().to_string();
let pg_type = table
.columns
.iter()
.find(|c| c.name == name)
.map(|c| c.pg_type)
.unwrap_or(PgType::Text); (name, pg_type)
})
.collect()
};
let ctx = RowContext {
table_name: &table.name,
};
rows.iter()
.map(|row| pg_row_to_row(row, &columns, &ctx))
.collect()
}
async fn execute_mutation(&self, query: BuiltQuery) -> Result<u64, Error> {
let params: Vec<SqlParam> = query.params.iter().map(SqlParam).collect();
let params_ref: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = params
.iter()
.map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
.collect();
let span = tracing::debug_span!(
"db.execute",
sql = %query.sql,
params = params.len(),
affected = tracing::field::Empty,
);
let affected = self
.client
.execute(&query.sql, ¶ms_ref)
.instrument(span.clone())
.await?;
span.record("affected", affected);
Ok(affected)
}
async fn execute_returning(
&self,
query: BuiltQuery,
table: &Table,
) -> Result<Option<Row>, Error> {
let params: Vec<SqlParam> = query.params.iter().map(SqlParam).collect();
let params_ref: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = params
.iter()
.map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
.collect();
let span = tracing::debug_span!(
"db.query",
sql = %query.sql,
params = params.len(),
rows = tracing::field::Empty,
);
let rows = self
.client
.query(&query.sql, ¶ms_ref)
.instrument(span.clone())
.await?;
span.record("rows", rows.len());
if rows.is_empty() {
return Ok(None);
}
let columns: Vec<_> = rows[0]
.columns()
.iter()
.map(|pg_col| {
let name = pg_col.name().to_string();
let pg_type = table
.columns
.iter()
.find(|c| c.name == name)
.map(|c| c.pg_type)
.unwrap_or(PgType::Text);
(name, pg_type)
})
.collect();
let ctx = RowContext {
table_name: &table.name,
};
Ok(Some(pg_row_to_row(&rows[0], &columns, &ctx)?))
}
}
/// Builder for a `SELECT` on one schema table; obtained from [`Db::select`].
pub struct SelectBuilder<'a> {
    /// Handle the query runs through.
    db: &'a Db<'a>,
    /// Target table definition, used to type result columns.
    table: &'a Table,
    /// Query accumulated so far.
    query: SelectQuery,
}

impl<'a> SelectBuilder<'a> {
    /// Sets the columns to select (delegates to `SelectQuery::columns`).
    pub fn columns(mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.query = self.query.columns(cols);
        self
    }

    /// Applies `expr` as a filter (delegates to `SelectQuery::filter`).
    pub fn filter(mut self, expr: super::Expr) -> Self {
        self.query = self.query.filter(expr);
        self
    }

    /// Orders by `column` in direction `dir` (delegates to `SelectQuery::order_by`).
    pub fn order_by(mut self, column: impl Into<String>, dir: super::SortDir) -> Self {
        self.query = self.query.order_by(column, dir);
        self
    }

    /// Applies a row limit of `n` (delegates to `SelectQuery::limit`).
    pub fn limit(mut self, n: u32) -> Self {
        self.query = self.query.limit(n);
        self
    }

    /// Applies a row offset of `n` (delegates to `SelectQuery::offset`).
    pub fn offset(mut self, n: u32) -> Self {
        self.query = self.query.offset(n);
        self
    }

    /// Executes the query and returns every resulting row.
    pub async fn all(self) -> Result<Vec<Row>, Error> {
        let built = self.query.build();
        self.db.execute_select(built, self.table).await
    }

    /// Executes the query with a limit of 1 and returns that row, if any.
    pub async fn one(self) -> Result<Option<Row>, Error> {
        let mut rows = self.limit(1).all().await?;
        Ok(rows.pop())
    }

    /// Counts the rows matched by the query, using `SelectQuery::build_count`.
    ///
    /// # Errors
    ///
    /// Fails if the query errors, if it does not return exactly one row, or if its
    /// first column cannot be read as an `i64`. Previously these cases panicked.
    pub async fn count(self) -> Result<u64, Error> {
        let built = self.query.build_count();
        let params: Vec<SqlParam> = built.params.iter().map(SqlParam).collect();
        let params_ref: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = params
            .iter()
            .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
            .collect();
        let span = tracing::debug_span!(
            "db.query",
            sql = %built.sql,
            params = params.len(),
            count = tracing::field::Empty,
        );
        // `query_one` turns an unexpected empty result into an `Err` rather than the
        // out-of-bounds panic `rows[0]` would produce.
        let row = self
            .db
            .client
            .query_one(&built.sql, &params_ref)
            .instrument(span.clone())
            .await?;
        let count: i64 = row.try_get(0)?;
        span.record("count", count);
        // A row count is non-negative, so this sign-changing cast is lossless.
        Ok(count as u64)
    }
}
pub struct InsertBuilder<'a> {
db: &'a Db<'a>,
table: &'a Table,
query: InsertQuery,
}
impl<'a> InsertBuilder<'a> {
pub fn values(
mut self,
data: impl IntoIterator<Item = (impl Into<String>, impl Into<Value>)>,
) -> Self {
self.query = self.query.values(data);
self
}
pub async fn execute(self) -> Result<u64, Error> {
let built = self.query.build();
self.db.execute_mutation(built).await
}
pub async fn returning(mut self) -> Result<Option<Row>, Error> {
self.query = self.query.returning_all();
let built = self.query.build();
self.db.execute_returning(built, self.table).await
}
}
pub struct UpdateBuilder<'a> {
db: &'a Db<'a>,
table: &'a Table,
query: UpdateQuery,
}
impl<'a> UpdateBuilder<'a> {
pub fn set(
mut self,
data: impl IntoIterator<Item = (impl Into<String>, impl Into<Value>)>,
) -> Self {
self.query = self.query.set(data);
self
}
pub fn filter(mut self, expr: super::Expr) -> Self {
self.query = self.query.filter(expr);
self
}
pub async fn execute(self) -> Result<u64, Error> {
let built = self.query.build();
self.db.execute_mutation(built).await
}
pub async fn returning(mut self) -> Result<Option<Row>, Error> {
self.query = self.query.returning_all();
let built = self.query.build();
self.db.execute_returning(built, self.table).await
}
}
pub struct DeleteBuilder<'a> {
db: &'a Db<'a>,
table: &'a Table,
query: DeleteQuery,
}
impl<'a> DeleteBuilder<'a> {
pub fn filter(mut self, expr: super::Expr) -> Self {
self.query = self.query.filter(expr);
self
}
pub async fn execute(self) -> Result<u64, Error> {
let built = self.query.build();
self.db.execute_mutation(built).await
}
pub async fn returning(mut self) -> Result<Option<Row>, Error> {
self.query = self.query.returning_all();
let built = self.query.build();
self.db.execute_returning(built, self.table).await
}
}