use super::client_exec::{execute_rows_on_client, query_rows_on_client};
use crate::adapters::result_set::{column_count, init_result_set};
use crate::middleware::{ResultSet, RowValues, SqlMiddlewareDbError};
use crate::query_utils::extract_column_names;
use chrono::{DateTime, NaiveDateTime, Utc};
use serde_json::Value;
use tokio_postgres::{Client, Statement, Transaction, types::ToSql};
/// Runs a prepared statement inside `transaction` and collects every returned row into a
/// [`ResultSet`].
///
/// Column names are taken from the statement's metadata (not from the returned rows) and
/// handed to `init_result_set` together with the row count as a capacity hint.
///
/// # Errors
///
/// Propagates the driver error from `transaction.query`, any error from `column_count`,
/// and the first per-value conversion error from [`postgres_extract_value`].
pub async fn build_result_set(
    stmt: &Statement,
    params: &[&(dyn ToSql + Sync)],
    transaction: &Transaction<'_>,
) -> Result<ResultSet, SqlMiddlewareDbError> {
    let rows = transaction.query(stmt, params).await?;
    let column_names = extract_column_names(stmt.columns().iter(), |col| col.name());
    let mut result_set = init_result_set(column_names, rows.len());
    // The column count is identical for every row, so look it up once rather than per row.
    let col_count = column_count(&result_set)?;
    for row in &rows {
        // Collecting into `Result<Vec<_>, _>` stops at the first conversion error, matching
        // the original per-value `?`, and sizes the vector from the range up front.
        let row_values = (0..col_count)
            .map(|i| postgres_extract_value(row, i))
            .collect::<Result<Vec<_>, _>>()?;
        result_set.add_row_values(row_values);
    }
    Ok(result_set)
}
pub fn postgres_extract_value(
row: &tokio_postgres::Row,
idx: usize,
) -> Result<RowValues, SqlMiddlewareDbError> {
let type_info = row.columns()[idx].type_();
if type_info.name() == "int2" {
let val: Option<i16> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
} else if type_info.name() == "int4" {
let val: Option<i32> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
} else if type_info.name() == "int8" {
let val: Option<i64> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Int))
} else if type_info.name() == "float4" || type_info.name() == "float8" {
let val: Option<f64> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Float))
} else if type_info.name() == "bool" {
let val: Option<bool> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Bool))
} else if type_info.name() == "timestamp" || type_info.name() == "timestamptz" {
let val: Option<NaiveDateTime> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Timestamp))
} else if type_info.name() == "json" || type_info.name() == "jsonb" {
let val: Option<Value> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::JSON))
} else if type_info.name() == "bytea" {
let val: Option<Vec<u8>> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Blob))
} else if type_info.name() == "text"
|| type_info.name() == "varchar"
|| type_info.name() == "char"
{
let val: Option<String> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Text))
} else {
let val: Option<String> = row.try_get(idx)?;
Ok(val.map_or(RowValues::Null, RowValues::Text))
}
}
/// Converts already-fetched rows into a [`ResultSet`] without statement metadata.
///
/// Column names are read from the first row's column descriptions. When `rows` is empty,
/// `set_column_names` is never called. Each row contributes one value per column it
/// reports.
///
/// # Errors
///
/// Returns the first conversion error produced by [`postgres_extract_value`].
pub(crate) fn build_result_set_from_rows(
    rows: &[tokio_postgres::Row],
) -> Result<ResultSet, SqlMiddlewareDbError> {
    let mut result_set = ResultSet::with_capacity(rows.len());

    if let Some(first_row) = rows.first() {
        let names = extract_column_names(first_row.columns().iter(), |column| column.name());
        result_set.set_column_names(std::sync::Arc::new(names));
    }

    for row in rows {
        let values = (0..row.columns().len())
            .map(|col_idx| postgres_extract_value(row, col_idx))
            .collect::<Result<Vec<_>, _>>()?;
        result_set.add_row_values(values);
    }

    Ok(result_set)
}
pub(crate) fn build_result_set_from_statement(
stmt: &Statement,
rows: &[tokio_postgres::Row],
) -> Result<ResultSet, SqlMiddlewareDbError> {
let column_names = extract_column_names(stmt.columns().iter(), |col| col.name());
let column_count = column_names.len();
let mut result_set = init_result_set(column_names, rows.len());
for row in rows {
let mut row_values = Vec::with_capacity(column_count);
for idx in 0..column_count {
row_values.push(postgres_extract_value(row, idx)?);
}
result_set.add_row_values(row_values);
}
Ok(result_set)
}
/// Runs `query` on `client` without a prepared statement and returns the rows as a
/// [`ResultSet`].
///
/// Rows are converted by `build_result_set_from_rows`, which takes column names from the
/// first returned row.
///
/// # Errors
///
/// Fails with the error from `query_rows_on_client` (labelled `"postgres select error"`)
/// or with a value-conversion error.
pub async fn execute_query_on_client(
    client: &Client,
    query: &str,
    params: &[RowValues],
) -> Result<ResultSet, SqlMiddlewareDbError> {
    let label = "postgres select error";
    let fetched = query_rows_on_client(client, query, None, params, label).await?;
    let result_set = build_result_set_from_rows(&fetched)?;
    Ok(result_set)
}
pub(crate) async fn execute_query_prepared_on_client(
client: &Client,
query: &str,
params: &[RowValues],
) -> Result<ResultSet, SqlMiddlewareDbError> {
let stmt = client.prepare(query).await.map_err(|e| {
SqlMiddlewareDbError::ExecutionError(format!("postgres prepare error: {e}"))
})?;
let rows =
query_rows_on_client(client, query, Some(&stmt), params, "postgres select error").await?;
build_result_set_from_statement(&stmt, &rows)
}
pub(crate) async fn execute_query_prepared_statement_on_client(
client: &Client,
stmt: &Statement,
params: &[RowValues],
) -> Result<ResultSet, SqlMiddlewareDbError> {
let rows =
query_rows_on_client(client, "", Some(stmt), params, "postgres select error").await?;
build_result_set_from_statement(stmt, &rows)
}
/// Executes a data-modifying `query` on `client` without preparing it and returns the
/// affected-row count as `usize`.
///
/// # Errors
///
/// Propagates the execution error (labelled with `err_label`). Fails if the count does
/// not fit in `usize`.
pub async fn execute_dml_on_client(
    client: &Client,
    query: &str,
    params: &[RowValues],
    err_label: &str,
) -> Result<usize, SqlMiddlewareDbError> {
    let affected = execute_rows_on_client(client, query, None, params, err_label).await?;
    let conversion_label = "postgres affected rows conversion error";
    convert_affected_rows(affected, conversion_label)
}
pub(crate) async fn execute_dml_prepared_on_client(
client: &Client,
query: &str,
params: &[RowValues],
) -> Result<usize, SqlMiddlewareDbError> {
let stmt = client.prepare(query).await.map_err(|e| {
SqlMiddlewareDbError::ExecutionError(format!("postgres prepare error: {e}"))
})?;
let rows = execute_rows_on_client(client, query, Some(&stmt), params, "postgres execute error")
.await?;
convert_affected_rows(rows, "postgres affected rows conversion error")
}
pub(crate) async fn execute_dml_prepared_statement_on_client(
client: &Client,
stmt: &Statement,
params: &[RowValues],
) -> Result<usize, SqlMiddlewareDbError> {
let rows =
execute_rows_on_client(client, "", Some(stmt), params, "postgres execute error").await?;
convert_affected_rows(rows, "postgres affected rows conversion error")
}
/// Narrows a driver-reported row count from `u64` to `usize`.
///
/// # Errors
///
/// Returns [`SqlMiddlewareDbError::ExecutionError`] formatted as `"{label}: {cause}"`
/// when `rows` does not fit in `usize` on this platform.
pub(crate) fn convert_affected_rows(rows: u64, label: &str) -> Result<usize, SqlMiddlewareDbError> {
    match usize::try_from(rows) {
        Ok(count) => Ok(count),
        Err(cause) => Err(SqlMiddlewareDbError::ExecutionError(format!("{label}: {cause}"))),
    }
}