reinhardt_db/backends/dialect/
postgres.rs1use async_trait::async_trait;
4use sqlx::{Column, PgPool, Postgres, Transaction, postgres::PgRow};
5use std::sync::Arc;
6use uuid::Uuid;
7
8use crate::backends::{
9 backend::DatabaseBackend,
10 error::{DatabaseError, Result},
11 types::{
12 DatabaseType, IsolationLevel, QueryResult, QueryValue, Row, Savepoint, TransactionExecutor,
13 },
14};
15
16pub struct PostgresBackend {
18 pool: Arc<PgPool>,
19}
20
21impl PostgresBackend {
22 pub fn new(pool: PgPool) -> Self {
24 Self {
25 pool: Arc::new(pool),
26 }
27 }
28
29 pub fn pool(&self) -> &PgPool {
31 &self.pool
32 }
33
34 fn bind_value<'q>(
35 query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
36 value: &'q QueryValue,
37 ) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
38 match value {
39 QueryValue::Null => query.bind(None::<i32>),
40 QueryValue::Bool(b) => query.bind(b),
41 QueryValue::Int(i) => query.bind(i),
42 QueryValue::Float(f) => query.bind(f),
43 QueryValue::String(s) => query.bind(s),
44 QueryValue::Bytes(b) => query.bind(b),
45 QueryValue::Timestamp(dt) => query.bind(dt),
46 QueryValue::Uuid(u) => query.bind(u),
47 QueryValue::Now => {
48 query.bind(chrono::Utc::now())
51 }
52 }
53 }
54
55 fn convert_row(pg_row: PgRow) -> Result<Row> {
56 Self::convert_row_internal(pg_row)
57 }
58}
59
60#[async_trait]
61impl DatabaseBackend for PostgresBackend {
62 fn database_type(&self) -> DatabaseType {
63 DatabaseType::Postgres
64 }
65
66 fn placeholder(&self, index: usize) -> String {
67 format!("${}", index)
68 }
69
70 fn supports_returning(&self) -> bool {
71 true
72 }
73
74 fn supports_on_conflict(&self) -> bool {
75 true
76 }
77
78 async fn execute(&self, sql: &str, params: Vec<QueryValue>) -> Result<QueryResult> {
79 let mut query = sqlx::query(sql);
80 for param in ¶ms {
81 query = Self::bind_value(query, param);
82 }
83 let result = query.execute(self.pool.as_ref()).await?;
84 Ok(QueryResult {
85 rows_affected: result.rows_affected(),
86 })
87 }
88
89 async fn fetch_one(&self, sql: &str, params: Vec<QueryValue>) -> Result<Row> {
90 let mut query = sqlx::query(sql);
91 for param in ¶ms {
92 query = Self::bind_value(query, param);
93 }
94 let row = query.fetch_one(self.pool.as_ref()).await?;
95 Self::convert_row(row)
96 }
97
98 async fn fetch_all(&self, sql: &str, params: Vec<QueryValue>) -> Result<Vec<Row>> {
99 let mut query = sqlx::query(sql);
100 for param in ¶ms {
101 query = Self::bind_value(query, param);
102 }
103 let rows = query.fetch_all(self.pool.as_ref()).await?;
104 rows.into_iter().map(Self::convert_row).collect()
105 }
106
107 async fn fetch_optional(&self, sql: &str, params: Vec<QueryValue>) -> Result<Option<Row>> {
108 let mut query = sqlx::query(sql);
109 for param in ¶ms {
110 query = Self::bind_value(query, param);
111 }
112 let row = query.fetch_optional(self.pool.as_ref()).await?;
113 row.map(Self::convert_row).transpose()
114 }
115
116 async fn begin(&self) -> Result<Box<dyn TransactionExecutor>> {
117 let tx = self.pool.begin().await?;
118 Ok(Box::new(PgTransactionExecutor::new(tx)))
119 }
120
121 async fn begin_with_isolation(
122 &self,
123 isolation_level: IsolationLevel,
124 ) -> Result<Box<dyn TransactionExecutor>> {
125 let mut tx = self.pool.begin().await?;
127
128 let sql = format!(
130 "SET TRANSACTION ISOLATION LEVEL {}",
131 isolation_level.to_sql(DatabaseType::Postgres)
132 );
133 sqlx::query(&sql).execute(&mut *tx).await?;
134
135 Ok(Box::new(PgTransactionExecutor::new(tx)))
136 }
137
138 fn as_any(&self) -> &dyn std::any::Any {
139 self
140 }
141}
142
143pub struct PgTransactionExecutor {
148 tx: Option<Transaction<'static, Postgres>>,
149}
150
151impl PgTransactionExecutor {
152 pub fn new(tx: Transaction<'static, Postgres>) -> Self {
154 Self { tx: Some(tx) }
155 }
156
157 fn bind_value<'q>(
158 query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
159 value: &'q QueryValue,
160 ) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
161 match value {
162 QueryValue::Null => query.bind(None::<i32>),
163 QueryValue::Bool(b) => query.bind(b),
164 QueryValue::Int(i) => query.bind(i),
165 QueryValue::Float(f) => query.bind(f),
166 QueryValue::String(s) => query.bind(s),
167 QueryValue::Bytes(b) => query.bind(b),
168 QueryValue::Timestamp(dt) => query.bind(dt),
169 QueryValue::Uuid(u) => query.bind(u),
170 QueryValue::Now => query.bind(chrono::Utc::now()),
171 }
172 }
173
174 fn convert_row(pg_row: PgRow) -> Result<Row> {
175 PostgresBackend::convert_row_internal(pg_row)
176 }
177}
178
179impl PostgresBackend {
180 pub(crate) fn convert_row_internal(pg_row: PgRow) -> Result<Row> {
182 use rust_decimal::prelude::ToPrimitive;
183 use sqlx::Row as SqlxRow;
184
185 let mut row = Row::new();
186 for column in pg_row.columns() {
187 let column_name = column.name();
188
189 if let Ok(value) = pg_row.try_get::<Uuid, _>(column_name) {
190 row.insert(column_name.to_string(), QueryValue::Uuid(value));
191 } else if let Ok(value) = pg_row.try_get::<bool, _>(column_name) {
192 row.insert(column_name.to_string(), QueryValue::Bool(value));
193 } else if let Ok(value) = pg_row.try_get::<i64, _>(column_name) {
194 row.insert(column_name.to_string(), QueryValue::Int(value));
195 } else if let Ok(value) = pg_row.try_get::<i32, _>(column_name) {
196 row.insert(column_name.to_string(), QueryValue::Int(value as i64));
197 } else if let Ok(value) = pg_row.try_get::<rust_decimal::Decimal, _>(column_name) {
198 if let Some(f) = value.to_f64() {
200 row.insert(column_name.to_string(), QueryValue::Float(f));
201 } else {
202 return Err(Self::decimal_conversion_error(&value, column_name));
203 }
204 } else if let Ok(value) = pg_row.try_get::<f64, _>(column_name) {
205 row.insert(column_name.to_string(), QueryValue::Float(value));
206 } else if let Ok(value) = pg_row.try_get::<String, _>(column_name) {
207 row.insert(column_name.to_string(), QueryValue::String(value));
208 } else if let Ok(value) = pg_row.try_get::<Vec<u8>, _>(column_name) {
209 row.insert(column_name.to_string(), QueryValue::Bytes(value));
210 } else if let Ok(value) = pg_row.try_get::<chrono::NaiveDateTime, _>(column_name) {
211 row.insert(
212 column_name.to_string(),
213 QueryValue::Timestamp(chrono::DateTime::from_naive_utc_and_offset(
214 value,
215 chrono::Utc,
216 )),
217 );
218 } else if let Ok(value) =
219 pg_row.try_get::<chrono::DateTime<chrono::Utc>, _>(column_name)
220 {
221 row.insert(column_name.to_string(), QueryValue::Timestamp(value));
222 } else if pg_row.try_get::<Option<i32>, _>(column_name).is_ok() {
223 row.insert(column_name.to_string(), QueryValue::Null);
224 }
225 }
226 Ok(row)
227 }
228
229 fn decimal_conversion_error(value: &rust_decimal::Decimal, column_name: &str) -> DatabaseError {
231 DatabaseError::TypeError(format!(
232 "Failed to convert Decimal value '{}' to f64 for column '{}'",
233 value, column_name
234 ))
235 }
236}
237
238#[async_trait]
239impl TransactionExecutor for PgTransactionExecutor {
240 async fn execute(&mut self, sql: &str, params: Vec<QueryValue>) -> Result<QueryResult> {
241 let tx = self.tx.as_mut().ok_or_else(|| {
242 crate::backends::error::DatabaseError::TransactionError(
243 "Transaction already consumed".to_string(),
244 )
245 })?;
246
247 let mut query = sqlx::query(sql);
248 for param in ¶ms {
249 query = Self::bind_value(query, param);
250 }
251 let result = query.execute(&mut **tx).await?;
252 Ok(QueryResult {
253 rows_affected: result.rows_affected(),
254 })
255 }
256
257 async fn fetch_one(&mut self, sql: &str, params: Vec<QueryValue>) -> Result<Row> {
258 let tx = self.tx.as_mut().ok_or_else(|| {
259 crate::backends::error::DatabaseError::TransactionError(
260 "Transaction already consumed".to_string(),
261 )
262 })?;
263
264 let mut query = sqlx::query(sql);
265 for param in ¶ms {
266 query = Self::bind_value(query, param);
267 }
268 let row = query.fetch_one(&mut **tx).await?;
269 Self::convert_row(row)
270 }
271
272 async fn fetch_all(&mut self, sql: &str, params: Vec<QueryValue>) -> Result<Vec<Row>> {
273 let tx = self.tx.as_mut().ok_or_else(|| {
274 crate::backends::error::DatabaseError::TransactionError(
275 "Transaction already consumed".to_string(),
276 )
277 })?;
278
279 let mut query = sqlx::query(sql);
280 for param in ¶ms {
281 query = Self::bind_value(query, param);
282 }
283 let rows = query.fetch_all(&mut **tx).await?;
284 rows.into_iter().map(Self::convert_row).collect()
285 }
286
287 async fn fetch_optional(&mut self, sql: &str, params: Vec<QueryValue>) -> Result<Option<Row>> {
288 let tx = self.tx.as_mut().ok_or_else(|| {
289 crate::backends::error::DatabaseError::TransactionError(
290 "Transaction already consumed".to_string(),
291 )
292 })?;
293
294 let mut query = sqlx::query(sql);
295 for param in ¶ms {
296 query = Self::bind_value(query, param);
297 }
298 let row = query.fetch_optional(&mut **tx).await?;
299 row.map(Self::convert_row).transpose()
300 }
301
302 async fn commit(mut self: Box<Self>) -> Result<()> {
303 let tx = self.tx.take().ok_or_else(|| {
304 crate::backends::error::DatabaseError::TransactionError(
305 "Transaction already consumed".to_string(),
306 )
307 })?;
308 tx.commit().await?;
309 Ok(())
310 }
311
312 async fn rollback(mut self: Box<Self>) -> Result<()> {
313 let tx = self.tx.take().ok_or_else(|| {
314 crate::backends::error::DatabaseError::TransactionError(
315 "Transaction already consumed".to_string(),
316 )
317 })?;
318 tx.rollback().await?;
319 Ok(())
320 }
321
322 async fn savepoint(&mut self, name: &str) -> Result<()> {
323 let tx = self.tx.as_mut().ok_or_else(|| {
324 crate::backends::error::DatabaseError::TransactionError(
325 "Transaction already consumed".to_string(),
326 )
327 })?;
328
329 let sp = Savepoint::new(name);
330 sqlx::query(&sp.to_sql()).execute(&mut **tx).await?;
331 Ok(())
332 }
333
334 async fn release_savepoint(&mut self, name: &str) -> Result<()> {
335 let tx = self.tx.as_mut().ok_or_else(|| {
336 crate::backends::error::DatabaseError::TransactionError(
337 "Transaction already consumed".to_string(),
338 )
339 })?;
340
341 let sp = Savepoint::new(name);
342 sqlx::query(&sp.release_sql()).execute(&mut **tx).await?;
343 Ok(())
344 }
345
346 async fn rollback_to_savepoint(&mut self, name: &str) -> Result<()> {
347 let tx = self.tx.as_mut().ok_or_else(|| {
348 crate::backends::error::DatabaseError::TransactionError(
349 "Transaction already consumed".to_string(),
350 )
351 })?;
352
353 let sp = Savepoint::new(name);
354 sqlx::query(&sp.rollback_sql()).execute(&mut **tx).await?;
355 Ok(())
356 }
357}
358
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rust_decimal::prelude::ToPrimitive;

    /// `Decimal::to_f64` (the conversion `convert_row_internal` relies on) must succeed
    /// for ordinary and extreme values and land within a mixed relative/absolute tolerance.
    #[rstest]
    #[case::positive(rust_decimal::Decimal::new(12345, 2), 123.45)]
    #[case::zero(rust_decimal::Decimal::ZERO, 0.0)]
    #[case::negative(rust_decimal::Decimal::new(-999, 1), -99.9)]
    #[case::max(rust_decimal::Decimal::MAX, 7.922816251426434e28)]
    fn test_decimal_to_f64_conversion_succeeds(
        #[case] decimal: rust_decimal::Decimal,
        #[case] expected: f64,
    ) {
        // Relative term scales with magnitude; absolute term keeps the zero case meaningful.
        const REL_TOL: f64 = 1e-12;
        const ABS_TOL: f64 = 1e-12;

        let converted = decimal.to_f64();
        assert!(
            converted.is_some(),
            "Decimal '{}' should convert to f64",
            decimal
        );
        let actual = converted.unwrap();

        let tolerance = expected.abs() * REL_TOL + ABS_TOL;
        let deviation = (actual - expected).abs();
        assert!(
            deviation <= tolerance,
            "Expected approximately {} (tolerance {}, diff {}), got {}",
            expected,
            tolerance,
            deviation,
            actual
        );
    }

    /// The decimal conversion error is a `TypeError` naming both the column and the value.
    #[rstest]
    fn test_decimal_conversion_error_message_format() {
        use crate::backends::error::DatabaseError;

        let column_name = "price_column";
        let value = rust_decimal::Decimal::new(12345, 2);

        let error = super::PostgresBackend::decimal_conversion_error(&value, column_name);
        assert!(matches!(error, DatabaseError::TypeError(_)));

        let rendered = error.to_string();
        assert!(
            rendered.contains("price_column"),
            "Error message should contain the column name"
        );
        assert!(
            rendered.contains("123.45"),
            "Error message should contain the decimal value"
        );
    }

    /// `TypeError` and `QueryError` are distinct variants that do not match each other.
    #[rstest]
    fn test_type_error_variant_distinction() {
        use crate::backends::error::DatabaseError;

        let query_error = DatabaseError::QueryError("query failed".to_string());
        let type_error = DatabaseError::TypeError("conversion failed".to_string());

        assert!(matches!(type_error, DatabaseError::TypeError(_)));
        assert!(!matches!(type_error, DatabaseError::QueryError(_)));
        assert!(matches!(query_error, DatabaseError::QueryError(_)));
        assert!(!matches!(query_error, DatabaseError::TypeError(_)));
    }
}