Skip to main content

reinhardt_db/backends/dialect/
postgres.rs

1//! PostgreSQL dialect implementation
2
3use async_trait::async_trait;
4use sqlx::{Column, PgPool, Postgres, Transaction, postgres::PgRow};
5use std::sync::Arc;
6use uuid::Uuid;
7
8use crate::backends::{
9	backend::DatabaseBackend,
10	error::{DatabaseError, Result},
11	types::{
12		DatabaseType, IsolationLevel, QueryResult, QueryValue, Row, Savepoint, TransactionExecutor,
13	},
14};
15
16/// PostgreSQL database backend
17pub struct PostgresBackend {
18	pool: Arc<PgPool>,
19}
20
21impl PostgresBackend {
22	/// Creates a new instance.
23	pub fn new(pool: PgPool) -> Self {
24		Self {
25			pool: Arc::new(pool),
26		}
27	}
28
29	/// Performs the pool operation.
30	pub fn pool(&self) -> &PgPool {
31		&self.pool
32	}
33
34	fn bind_value<'q>(
35		query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
36		value: &'q QueryValue,
37	) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
38		match value {
39			QueryValue::Null => query.bind(None::<i32>),
40			QueryValue::Bool(b) => query.bind(b),
41			QueryValue::Int(i) => query.bind(i),
42			QueryValue::Float(f) => query.bind(f),
43			QueryValue::String(s) => query.bind(s),
44			QueryValue::Bytes(b) => query.bind(b),
45			QueryValue::Timestamp(dt) => query.bind(dt),
46			QueryValue::Uuid(u) => query.bind(u),
47			QueryValue::Now => {
48				// PostgreSQL uses NOW() function, which should be part of SQL string
49				// For binding, we use current UTC time
50				query.bind(chrono::Utc::now())
51			}
52		}
53	}
54
55	fn convert_row(pg_row: PgRow) -> Result<Row> {
56		Self::convert_row_internal(pg_row)
57	}
58}
59
60#[async_trait]
61impl DatabaseBackend for PostgresBackend {
62	fn database_type(&self) -> DatabaseType {
63		DatabaseType::Postgres
64	}
65
66	fn placeholder(&self, index: usize) -> String {
67		format!("${}", index)
68	}
69
70	fn supports_returning(&self) -> bool {
71		true
72	}
73
74	fn supports_on_conflict(&self) -> bool {
75		true
76	}
77
78	async fn execute(&self, sql: &str, params: Vec<QueryValue>) -> Result<QueryResult> {
79		let mut query = sqlx::query(sql);
80		for param in &params {
81			query = Self::bind_value(query, param);
82		}
83		let result = query.execute(self.pool.as_ref()).await?;
84		Ok(QueryResult {
85			rows_affected: result.rows_affected(),
86		})
87	}
88
89	async fn fetch_one(&self, sql: &str, params: Vec<QueryValue>) -> Result<Row> {
90		let mut query = sqlx::query(sql);
91		for param in &params {
92			query = Self::bind_value(query, param);
93		}
94		let row = query.fetch_one(self.pool.as_ref()).await?;
95		Self::convert_row(row)
96	}
97
98	async fn fetch_all(&self, sql: &str, params: Vec<QueryValue>) -> Result<Vec<Row>> {
99		let mut query = sqlx::query(sql);
100		for param in &params {
101			query = Self::bind_value(query, param);
102		}
103		let rows = query.fetch_all(self.pool.as_ref()).await?;
104		rows.into_iter().map(Self::convert_row).collect()
105	}
106
107	async fn fetch_optional(&self, sql: &str, params: Vec<QueryValue>) -> Result<Option<Row>> {
108		let mut query = sqlx::query(sql);
109		for param in &params {
110			query = Self::bind_value(query, param);
111		}
112		let row = query.fetch_optional(self.pool.as_ref()).await?;
113		row.map(Self::convert_row).transpose()
114	}
115
116	async fn begin(&self) -> Result<Box<dyn TransactionExecutor>> {
117		let tx = self.pool.begin().await?;
118		Ok(Box::new(PgTransactionExecutor::new(tx)))
119	}
120
121	async fn begin_with_isolation(
122		&self,
123		isolation_level: IsolationLevel,
124	) -> Result<Box<dyn TransactionExecutor>> {
125		// PostgreSQL supports setting isolation level at transaction start
126		let mut tx = self.pool.begin().await?;
127
128		// Set the isolation level using PostgreSQL's SET TRANSACTION command
129		let sql = format!(
130			"SET TRANSACTION ISOLATION LEVEL {}",
131			isolation_level.to_sql(DatabaseType::Postgres)
132		);
133		sqlx::query(&sql).execute(&mut *tx).await?;
134
135		Ok(Box::new(PgTransactionExecutor::new(tx)))
136	}
137
138	fn as_any(&self) -> &dyn std::any::Any {
139		self
140	}
141}
142
143/// PostgreSQL transaction executor
144///
145/// This struct wraps a SQLx `Transaction` to ensure all queries
146/// within a transaction run on the same physical database connection.
147pub struct PgTransactionExecutor {
148	tx: Option<Transaction<'static, Postgres>>,
149}
150
151impl PgTransactionExecutor {
152	/// Creates a new instance.
153	pub fn new(tx: Transaction<'static, Postgres>) -> Self {
154		Self { tx: Some(tx) }
155	}
156
157	fn bind_value<'q>(
158		query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
159		value: &'q QueryValue,
160	) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
161		match value {
162			QueryValue::Null => query.bind(None::<i32>),
163			QueryValue::Bool(b) => query.bind(b),
164			QueryValue::Int(i) => query.bind(i),
165			QueryValue::Float(f) => query.bind(f),
166			QueryValue::String(s) => query.bind(s),
167			QueryValue::Bytes(b) => query.bind(b),
168			QueryValue::Timestamp(dt) => query.bind(dt),
169			QueryValue::Uuid(u) => query.bind(u),
170			QueryValue::Now => query.bind(chrono::Utc::now()),
171		}
172	}
173
174	fn convert_row(pg_row: PgRow) -> Result<Row> {
175		PostgresBackend::convert_row_internal(pg_row)
176	}
177}
178
179impl PostgresBackend {
180	/// Internal row conversion method shared between backend and transaction executor
181	pub(crate) fn convert_row_internal(pg_row: PgRow) -> Result<Row> {
182		use rust_decimal::prelude::ToPrimitive;
183		use sqlx::Row as SqlxRow;
184
185		let mut row = Row::new();
186		for column in pg_row.columns() {
187			let column_name = column.name();
188
189			if let Ok(value) = pg_row.try_get::<Uuid, _>(column_name) {
190				row.insert(column_name.to_string(), QueryValue::Uuid(value));
191			} else if let Ok(value) = pg_row.try_get::<bool, _>(column_name) {
192				row.insert(column_name.to_string(), QueryValue::Bool(value));
193			} else if let Ok(value) = pg_row.try_get::<i64, _>(column_name) {
194				row.insert(column_name.to_string(), QueryValue::Int(value));
195			} else if let Ok(value) = pg_row.try_get::<i32, _>(column_name) {
196				row.insert(column_name.to_string(), QueryValue::Int(value as i64));
197			} else if let Ok(value) = pg_row.try_get::<rust_decimal::Decimal, _>(column_name) {
198				// Convert DECIMAL/NUMERIC to f64 for Float storage
199				if let Some(f) = value.to_f64() {
200					row.insert(column_name.to_string(), QueryValue::Float(f));
201				} else {
202					return Err(Self::decimal_conversion_error(&value, column_name));
203				}
204			} else if let Ok(value) = pg_row.try_get::<f64, _>(column_name) {
205				row.insert(column_name.to_string(), QueryValue::Float(value));
206			} else if let Ok(value) = pg_row.try_get::<String, _>(column_name) {
207				row.insert(column_name.to_string(), QueryValue::String(value));
208			} else if let Ok(value) = pg_row.try_get::<Vec<u8>, _>(column_name) {
209				row.insert(column_name.to_string(), QueryValue::Bytes(value));
210			} else if let Ok(value) = pg_row.try_get::<chrono::NaiveDateTime, _>(column_name) {
211				row.insert(
212					column_name.to_string(),
213					QueryValue::Timestamp(chrono::DateTime::from_naive_utc_and_offset(
214						value,
215						chrono::Utc,
216					)),
217				);
218			} else if let Ok(value) =
219				pg_row.try_get::<chrono::DateTime<chrono::Utc>, _>(column_name)
220			{
221				row.insert(column_name.to_string(), QueryValue::Timestamp(value));
222			} else if pg_row.try_get::<Option<i32>, _>(column_name).is_ok() {
223				row.insert(column_name.to_string(), QueryValue::Null);
224			}
225		}
226		Ok(row)
227	}
228
229	/// Build a TypeError for failed Decimal-to-f64 conversion
230	fn decimal_conversion_error(value: &rust_decimal::Decimal, column_name: &str) -> DatabaseError {
231		DatabaseError::TypeError(format!(
232			"Failed to convert Decimal value '{}' to f64 for column '{}'",
233			value, column_name
234		))
235	}
236}
237
238#[async_trait]
239impl TransactionExecutor for PgTransactionExecutor {
240	async fn execute(&mut self, sql: &str, params: Vec<QueryValue>) -> Result<QueryResult> {
241		let tx = self.tx.as_mut().ok_or_else(|| {
242			crate::backends::error::DatabaseError::TransactionError(
243				"Transaction already consumed".to_string(),
244			)
245		})?;
246
247		let mut query = sqlx::query(sql);
248		for param in &params {
249			query = Self::bind_value(query, param);
250		}
251		let result = query.execute(&mut **tx).await?;
252		Ok(QueryResult {
253			rows_affected: result.rows_affected(),
254		})
255	}
256
257	async fn fetch_one(&mut self, sql: &str, params: Vec<QueryValue>) -> Result<Row> {
258		let tx = self.tx.as_mut().ok_or_else(|| {
259			crate::backends::error::DatabaseError::TransactionError(
260				"Transaction already consumed".to_string(),
261			)
262		})?;
263
264		let mut query = sqlx::query(sql);
265		for param in &params {
266			query = Self::bind_value(query, param);
267		}
268		let row = query.fetch_one(&mut **tx).await?;
269		Self::convert_row(row)
270	}
271
272	async fn fetch_all(&mut self, sql: &str, params: Vec<QueryValue>) -> Result<Vec<Row>> {
273		let tx = self.tx.as_mut().ok_or_else(|| {
274			crate::backends::error::DatabaseError::TransactionError(
275				"Transaction already consumed".to_string(),
276			)
277		})?;
278
279		let mut query = sqlx::query(sql);
280		for param in &params {
281			query = Self::bind_value(query, param);
282		}
283		let rows = query.fetch_all(&mut **tx).await?;
284		rows.into_iter().map(Self::convert_row).collect()
285	}
286
287	async fn fetch_optional(&mut self, sql: &str, params: Vec<QueryValue>) -> Result<Option<Row>> {
288		let tx = self.tx.as_mut().ok_or_else(|| {
289			crate::backends::error::DatabaseError::TransactionError(
290				"Transaction already consumed".to_string(),
291			)
292		})?;
293
294		let mut query = sqlx::query(sql);
295		for param in &params {
296			query = Self::bind_value(query, param);
297		}
298		let row = query.fetch_optional(&mut **tx).await?;
299		row.map(Self::convert_row).transpose()
300	}
301
302	async fn commit(mut self: Box<Self>) -> Result<()> {
303		let tx = self.tx.take().ok_or_else(|| {
304			crate::backends::error::DatabaseError::TransactionError(
305				"Transaction already consumed".to_string(),
306			)
307		})?;
308		tx.commit().await?;
309		Ok(())
310	}
311
312	async fn rollback(mut self: Box<Self>) -> Result<()> {
313		let tx = self.tx.take().ok_or_else(|| {
314			crate::backends::error::DatabaseError::TransactionError(
315				"Transaction already consumed".to_string(),
316			)
317		})?;
318		tx.rollback().await?;
319		Ok(())
320	}
321
322	async fn savepoint(&mut self, name: &str) -> Result<()> {
323		let tx = self.tx.as_mut().ok_or_else(|| {
324			crate::backends::error::DatabaseError::TransactionError(
325				"Transaction already consumed".to_string(),
326			)
327		})?;
328
329		let sp = Savepoint::new(name);
330		sqlx::query(&sp.to_sql()).execute(&mut **tx).await?;
331		Ok(())
332	}
333
334	async fn release_savepoint(&mut self, name: &str) -> Result<()> {
335		let tx = self.tx.as_mut().ok_or_else(|| {
336			crate::backends::error::DatabaseError::TransactionError(
337				"Transaction already consumed".to_string(),
338			)
339		})?;
340
341		let sp = Savepoint::new(name);
342		sqlx::query(&sp.release_sql()).execute(&mut **tx).await?;
343		Ok(())
344	}
345
346	async fn rollback_to_savepoint(&mut self, name: &str) -> Result<()> {
347		let tx = self.tx.as_mut().ok_or_else(|| {
348			crate::backends::error::DatabaseError::TransactionError(
349				"Transaction already consumed".to_string(),
350			)
351		})?;
352
353		let sp = Savepoint::new(name);
354		sqlx::query(&sp.rollback_sql()).execute(&mut **tx).await?;
355		Ok(())
356	}
357}
358
#[cfg(test)]
mod tests {
	use rstest::rstest;
	use rust_decimal::prelude::ToPrimitive;

	use crate::backends::error::DatabaseError;

	/// Relative component of the float-comparison tolerance.
	const REL_TOL: f64 = 1e-12;
	/// Absolute component of the float-comparison tolerance.
	const ABS_TOL: f64 = 1e-12;

	/// Representable Decimal values must convert to an f64 close to the expected value.
	#[rstest]
	#[case::positive(rust_decimal::Decimal::new(12345, 2), 123.45)]
	#[case::zero(rust_decimal::Decimal::ZERO, 0.0)]
	#[case::negative(rust_decimal::Decimal::new(-999, 1), -99.9)]
	#[case::max(rust_decimal::Decimal::MAX, 7.922816251426434e28)]
	fn test_decimal_to_f64_conversion_succeeds(
		#[case] decimal: rust_decimal::Decimal,
		#[case] expected: f64,
	) {
		// Act
		let Some(actual) = decimal.to_f64() else {
			panic!("Decimal '{}' should convert to f64", decimal);
		};

		// Assert: combined relative + absolute tolerance
		let tol = expected.abs() * REL_TOL + ABS_TOL;
		let diff = (actual - expected).abs();
		assert!(
			diff <= tol,
			"Expected approximately {} (tolerance {}, diff {}), got {}",
			expected,
			tol,
			diff,
			actual
		);
	}

	/// The conversion-failure error is a TypeError naming both the column and the value.
	#[rstest]
	fn test_decimal_conversion_error_message_format() {
		// Arrange
		let column_name = "price_column";
		let value = rust_decimal::Decimal::new(12345, 2);

		// Act
		let error = super::PostgresBackend::decimal_conversion_error(&value, column_name);

		// Assert
		assert!(matches!(error, DatabaseError::TypeError(_)));
		let rendered = error.to_string();
		assert!(
			rendered.contains("price_column"),
			"Error message should contain the column name"
		);
		assert!(
			rendered.contains("123.45"),
			"Error message should contain the decimal value"
		);
	}

	/// TypeError and QueryError are distinct variants that never match each other.
	#[rstest]
	fn test_type_error_variant_distinction() {
		// Arrange & Act
		let query_error = DatabaseError::QueryError("query failed".to_string());
		let type_error = DatabaseError::TypeError("conversion failed".to_string());

		// Assert
		assert!(matches!(type_error, DatabaseError::TypeError(_)));
		assert!(!matches!(type_error, DatabaseError::QueryError(_)));
		assert!(matches!(query_error, DatabaseError::QueryError(_)));
		assert!(!matches!(query_error, DatabaseError::TypeError(_)));
	}
}