1use crate::adapters::params::convert_params;
2use crate::adapters::result_set::{column_count, init_result_set};
3use crate::middleware::{ResultSet, RowValues, SqlMiddlewareDbError};
4use crate::query_utils::extract_column_names;
5use crate::types::ConversionMode;
6use chrono::NaiveDateTime;
7use serde_json::Value;
8use tokio_postgres::{Client, Statement, Transaction, types::ToSql};
9
10use super::params::Params as PgParams;
11
12pub async fn build_result_set(
17 stmt: &Statement,
18 params: &[&(dyn ToSql + Sync)],
19 transaction: &Transaction<'_>,
20) -> Result<ResultSet, SqlMiddlewareDbError> {
21 let rows = transaction.query(stmt, params).await?;
23
24 let column_names = extract_column_names(stmt.columns().iter(), |col| col.name());
25
26 let capacity = rows.len();
28 let mut result_set = init_result_set(column_names, capacity);
29
30 for row in rows {
31 let mut row_values = Vec::new();
32
33 let col_count = column_count(&result_set)?;
34
35 for i in 0..col_count {
36 let value = postgres_extract_value(&row, i)?;
37 row_values.push(value);
38 }
39
40 result_set.add_row_values(row_values);
41 }
42
43 Ok(result_set)
44}
45
46pub fn postgres_extract_value(
51 row: &tokio_postgres::Row,
52 idx: usize,
53) -> Result<RowValues, SqlMiddlewareDbError> {
54 let type_info = row.columns()[idx].type_();
56
57 if type_info.name() == "int2" {
60 let val: Option<i16> = row.try_get(idx)?;
61 Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
62 } else if type_info.name() == "int4" {
63 let val: Option<i32> = row.try_get(idx)?;
64 Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
65 } else if type_info.name() == "int8" {
66 let val: Option<i64> = row.try_get(idx)?;
67 Ok(val.map_or(RowValues::Null, RowValues::Int))
68 } else if type_info.name() == "float4" || type_info.name() == "float8" {
69 let val: Option<f64> = row.try_get(idx)?;
70 Ok(val.map_or(RowValues::Null, RowValues::Float))
71 } else if type_info.name() == "bool" {
72 let val: Option<bool> = row.try_get(idx)?;
73 Ok(val.map_or(RowValues::Null, RowValues::Bool))
74 } else if type_info.name() == "timestamp" || type_info.name() == "timestamptz" {
75 let val: Option<NaiveDateTime> = row.try_get(idx)?;
76 Ok(val.map_or(RowValues::Null, RowValues::Timestamp))
77 } else if type_info.name() == "json" || type_info.name() == "jsonb" {
78 let val: Option<Value> = row.try_get(idx)?;
79 Ok(val.map_or(RowValues::Null, RowValues::JSON))
80 } else if type_info.name() == "bytea" {
81 let val: Option<Vec<u8>> = row.try_get(idx)?;
82 Ok(val.map_or(RowValues::Null, RowValues::Blob))
83 } else if type_info.name() == "text"
84 || type_info.name() == "varchar"
85 || type_info.name() == "char"
86 {
87 let val: Option<String> = row.try_get(idx)?;
88 Ok(val.map_or(RowValues::Null, RowValues::Text))
89 } else {
90 let val: Option<String> = row.try_get(idx)?;
92 Ok(val.map_or(RowValues::Null, RowValues::Text))
93 }
94}
95
96pub(crate) fn build_result_set_from_rows(
101 rows: &[tokio_postgres::Row],
102) -> Result<ResultSet, SqlMiddlewareDbError> {
103 let mut result_set = ResultSet::with_capacity(rows.len());
104 if let Some(row) = rows.first() {
105 let cols = extract_column_names(row.columns().iter(), |col| col.name());
106 result_set.set_column_names(std::sync::Arc::new(cols));
107 }
108
109 for row in rows {
110 let col_count = row.columns().len();
111 let mut row_values = Vec::with_capacity(col_count);
112 for idx in 0..col_count {
113 row_values.push(postgres_extract_value(row, idx)?);
114 }
115 result_set.add_row_values(row_values);
116 }
117
118 Ok(result_set)
119}
120
121pub(crate) fn build_result_set_from_statement(
126 stmt: &Statement,
127 rows: &[tokio_postgres::Row],
128) -> Result<ResultSet, SqlMiddlewareDbError> {
129 let column_names = extract_column_names(stmt.columns().iter(), |col| col.name());
130 let column_count = column_names.len();
131
132 let mut result_set = init_result_set(column_names, rows.len());
133
134 for row in rows {
135 let mut row_values = Vec::with_capacity(column_count);
136 for idx in 0..column_count {
137 row_values.push(postgres_extract_value(row, idx)?);
138 }
139 result_set.add_row_values(row_values);
140 }
141
142 Ok(result_set)
143}
144
145pub async fn execute_query_on_client(
150 client: &Client,
151 query: &str,
152 params: &[RowValues],
153) -> Result<ResultSet, SqlMiddlewareDbError> {
154 let rows =
155 query_rows_on_client(client, query, None, params, "postgres select error").await?;
156 build_result_set_from_rows(&rows)
157}
158
159pub(crate) async fn execute_query_prepared_on_client(
164 client: &Client,
165 query: &str,
166 params: &[RowValues],
167) -> Result<ResultSet, SqlMiddlewareDbError> {
168 let stmt = client.prepare(query).await.map_err(|e| {
169 SqlMiddlewareDbError::ExecutionError(format!("postgres prepare error: {e}"))
170 })?;
171 let rows =
172 query_rows_on_client(client, query, Some(&stmt), params, "postgres select error").await?;
173 build_result_set_from_statement(&stmt, &rows)
174}
175
176pub async fn execute_dml_on_client(
181 client: &Client,
182 query: &str,
183 params: &[RowValues],
184 err_label: &str,
185) -> Result<usize, SqlMiddlewareDbError> {
186 let rows = execute_rows_on_client(client, query, None, params, err_label).await?;
187 convert_affected_rows(rows, "postgres affected rows conversion error")
188}
189
190pub(crate) async fn execute_dml_prepared_on_client(
195 client: &Client,
196 query: &str,
197 params: &[RowValues],
198) -> Result<usize, SqlMiddlewareDbError> {
199 let stmt = client.prepare(query).await.map_err(|e| {
200 SqlMiddlewareDbError::ExecutionError(format!("postgres prepare error: {e}"))
201 })?;
202 let rows =
203 execute_rows_on_client(client, query, Some(&stmt), params, "postgres execute error")
204 .await?;
205 convert_affected_rows(rows, "postgres affected rows conversion error")
206}
207
208pub(crate) fn convert_affected_rows(
209 rows: u64,
210 label: &str,
211) -> Result<usize, SqlMiddlewareDbError> {
212 usize::try_from(rows).map_err(|e| {
213 SqlMiddlewareDbError::ExecutionError(format!("{label}: {e}"))
214 })
215}
216
217async fn query_rows_on_client(
218 client: &Client,
219 query: &str,
220 stmt: Option<&Statement>,
221 params: &[RowValues],
222 err_label: &str,
223) -> Result<Vec<tokio_postgres::Row>, SqlMiddlewareDbError>
224{
225 let converted = convert_params::<PgParams>(params, ConversionMode::Query)?;
226 let refs = converted.as_refs();
227 let rows = match stmt {
228 Some(stmt) => client.query(stmt, refs).await,
229 None => client.query(query, refs).await,
230 };
231 rows.map_err(|e| SqlMiddlewareDbError::ExecutionError(format!("{err_label}: {e}")))
232}
233
234async fn execute_rows_on_client(
235 client: &Client,
236 query: &str,
237 stmt: Option<&Statement>,
238 params: &[RowValues],
239 err_label: &str,
240) -> Result<u64, SqlMiddlewareDbError>
241{
242 let converted = convert_params::<PgParams>(params, ConversionMode::Execute)?;
243 let refs = converted.as_refs();
244 let rows = match stmt {
245 Some(stmt) => client.execute(stmt, refs).await,
246 None => client.execute(query, refs).await,
247 };
248 rows.map_err(|e| SqlMiddlewareDbError::ExecutionError(format!("{err_label}: {e}")))
249}