use crate::adapters::params::convert_params;
use crate::adapters::result_set::{column_count, init_result_set};
use crate::middleware::{ResultSet, RowValues, SqlMiddlewareDbError};
use crate::query_utils::extract_column_names;
use crate::types::ConversionMode;
use chrono::NaiveDateTime;
use serde_json::Value;
use tokio_postgres::{Client, Statement, Transaction, types::ToSql};
use super::params::Params as PgParams;
/// Runs the prepared `stmt` with `params` inside `transaction` and converts
/// every returned row into a [`ResultSet`].
///
/// Column names are taken from the statement's column metadata (not from the
/// returned rows). Each cell is decoded with [`postgres_extract_value`].
///
/// # Errors
///
/// Returns an error if the query fails, if the column count cannot be read
/// from the initialised result set, or if any cell fails to decode.
pub async fn build_result_set(
    stmt: &Statement,
    params: &[&(dyn ToSql + Sync)],
    transaction: &Transaction<'_>,
) -> Result<ResultSet, SqlMiddlewareDbError> {
    let rows = transaction.query(stmt, params).await?;
    let column_names = extract_column_names(stmt.columns().iter(), |col| col.name());
    let mut result_set = init_result_set(column_names, rows.len());
    // The column count is fixed once the result set is initialised, so read it
    // once here instead of re-running this fallible lookup for every row.
    let col_count = column_count(&result_set)?;
    for row in rows {
        let mut row_values = Vec::with_capacity(col_count);
        for i in 0..col_count {
            row_values.push(postgres_extract_value(&row, i)?);
        }
        result_set.add_row_values(row_values);
    }
    Ok(result_set)
}
pub fn postgres_extract_value(
row: &tokio_postgres::Row,
idx: usize,
) -> Result<RowValues, SqlMiddlewareDbError> {
let type_info = row.columns()[idx].type_();
if type_info.name() == "int2" {
let val: Option<i16> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
} else if type_info.name() == "int4" {
let val: Option<i32> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
} else if type_info.name() == "int8" {
let val: Option<i64> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Int))
} else if type_info.name() == "float4" || type_info.name() == "float8" {
let val: Option<f64> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Float))
} else if type_info.name() == "bool" {
let val: Option<bool> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Bool))
} else if type_info.name() == "timestamp" || type_info.name() == "timestamptz" {
let val: Option<NaiveDateTime> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Timestamp))
} else if type_info.name() == "json" || type_info.name() == "jsonb" {
let val: Option<Value> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::JSON))
} else if type_info.name() == "bytea" {
let val: Option<Vec<u8>> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Blob))
} else if type_info.name() == "text"
|| type_info.name() == "varchar"
|| type_info.name() == "char"
{
let val: Option<String> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Text))
} else {
let val: Option<String> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Text))
}
}
/// Builds a [`ResultSet`] from rows that were already fetched, without the
/// statement that produced them.
///
/// Column names are read from the first row's metadata. When `rows` is empty
/// there is nothing to read them from, so they are never set and the result
/// set is returned immediately.
///
/// # Errors
///
/// Propagates the first cell-decoding error from [`postgres_extract_value`].
pub(crate) fn build_result_set_from_rows(
    rows: &[tokio_postgres::Row],
) -> Result<ResultSet, SqlMiddlewareDbError> {
    let mut result_set = ResultSet::with_capacity(rows.len());
    let first_row = match rows.first() {
        Some(row) => row,
        None => return Ok(result_set),
    };
    let header = extract_column_names(first_row.columns().iter(), |column| column.name());
    result_set.set_column_names(std::sync::Arc::new(header));

    for row in rows {
        let width = row.columns().len();
        let cells = (0..width)
            .map(|position| postgres_extract_value(row, position))
            .collect::<Result<Vec<_>, _>>()?;
        result_set.add_row_values(cells);
    }
    Ok(result_set)
}
pub(crate) fn build_result_set_from_statement(
stmt: &Statement,
rows: &[tokio_postgres::Row],
) -> Result<ResultSet, SqlMiddlewareDbError> {
let column_names = extract_column_names(stmt.columns().iter(), |col| col.name());
let column_count = column_names.len();
let mut result_set = init_result_set(column_names, rows.len());
for row in rows {
let mut row_values = Vec::with_capacity(column_count);
for idx in 0..column_count {
row_values.push(postgres_extract_value(row, idx)?);
}
result_set.add_row_values(row_values);
}
Ok(result_set)
}
pub async fn execute_query_on_client(
client: &Client,
query: &str,
params: &[RowValues],
) -> Result<ResultSet, SqlMiddlewareDbError> {
let rows =
query_rows_on_client(client, query, None, params, "postgres select error").await?;
build_result_set_from_rows(&rows)
}
/// Prepares `query` on `client`, executes it with `params`, and builds a
/// [`ResultSet`] whose column names come from the prepared statement's
/// metadata rather than from the returned rows.
///
/// # Errors
///
/// - Prepare failures become `ExecutionError("postgres prepare error: …")`.
/// - Driver errors during execution are prefixed with `"postgres select error"`.
/// - Parameter-conversion and row-decoding errors are also returned.
pub(crate) async fn execute_query_prepared_on_client(
    client: &Client,
    query: &str,
    params: &[RowValues],
) -> Result<ResultSet, SqlMiddlewareDbError> {
    let prepared = match client.prepare(query).await {
        Ok(statement) => statement,
        Err(e) => {
            return Err(SqlMiddlewareDbError::ExecutionError(format!(
                "postgres prepare error: {e}"
            )));
        }
    };
    let fetched = query_rows_on_client(
        client,
        query,
        Some(&prepared),
        params,
        "postgres select error",
    )
    .await?;
    build_result_set_from_statement(&prepared, &fetched)
}
pub async fn execute_dml_on_client(
client: &Client,
query: &str,
params: &[RowValues],
err_label: &str,
) -> Result<usize, SqlMiddlewareDbError> {
let rows = execute_rows_on_client(client, query, None, params, err_label).await?;
convert_affected_rows(rows, "postgres affected rows conversion error")
}
pub(crate) async fn execute_dml_prepared_on_client(
client: &Client,
query: &str,
params: &[RowValues],
) -> Result<usize, SqlMiddlewareDbError> {
let stmt = client.prepare(query).await.map_err(|e| {
SqlMiddlewareDbError::ExecutionError(format!("postgres prepare error: {e}"))
})?;
let rows =
execute_rows_on_client(client, query, Some(&stmt), params, "postgres execute error")
.await?;
convert_affected_rows(rows, "postgres affected rows conversion error")
}
/// Narrows the driver's `u64` affected-row count to `usize`.
///
/// # Errors
///
/// Returns `ExecutionError("{label}: {conversion error}")` when `rows` exceeds
/// `usize::MAX` on the current platform.
pub(crate) fn convert_affected_rows(
    rows: u64,
    label: &str,
) -> Result<usize, SqlMiddlewareDbError> {
    match usize::try_from(rows) {
        Ok(count) => Ok(count),
        Err(err) => Err(SqlMiddlewareDbError::ExecutionError(format!("{label}: {err}"))),
    }
}
/// Converts `params` for a row-returning query, then runs either the prepared
/// `stmt` (when provided) or the raw `query` text, and returns all rows.
///
/// When `stmt` is `Some`, `query` is not used.
///
/// # Errors
///
/// Parameter-conversion errors are propagated with `?` and carry no label.
/// Driver errors are wrapped as `ExecutionError("{err_label}: {error}")`.
async fn query_rows_on_client(
    client: &Client,
    query: &str,
    stmt: Option<&Statement>,
    params: &[RowValues],
    err_label: &str,
) -> Result<Vec<tokio_postgres::Row>, SqlMiddlewareDbError> {
    let pg_params = convert_params::<PgParams>(params, ConversionMode::Query)?;
    let bound = pg_params.as_refs();
    let outcome = if let Some(prepared) = stmt {
        client.query(prepared, bound).await
    } else {
        client.query(query, bound).await
    };
    outcome.map_err(|err| SqlMiddlewareDbError::ExecutionError(format!("{err_label}: {err}")))
}
/// Converts `params` for a non-row-returning statement, then executes either
/// the prepared `stmt` (when provided) or the raw `query` text, and returns
/// the driver's affected-row count.
///
/// When `stmt` is `Some`, `query` is not used.
///
/// # Errors
///
/// Parameter-conversion errors are propagated with `?` and carry no label.
/// Driver errors are wrapped as `ExecutionError("{err_label}: {error}")`.
async fn execute_rows_on_client(
    client: &Client,
    query: &str,
    stmt: Option<&Statement>,
    params: &[RowValues],
    err_label: &str,
) -> Result<u64, SqlMiddlewareDbError> {
    let pg_params = convert_params::<PgParams>(params, ConversionMode::Execute)?;
    let bound = pg_params.as_refs();
    let outcome = if let Some(prepared) = stmt {
        client.execute(prepared, bound).await
    } else {
        client.execute(query, bound).await
    };
    outcome.map_err(|err| SqlMiddlewareDbError::ExecutionError(format!("{err_label}: {err}")))
}