use std::error::Error as StdError;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use bytes::BytesMut;
use deadpool_postgres::{Manager, ManagerConfig, Object, Pool, RecyclingMethod};
use tokio_postgres::types::{to_sql_checked, IsNull, ToSql, Type};
use tokio_postgres::NoTls;
use crate::driver::ExecuteResult;
use crate::error::OrmError;
use crate::row::Row;
use crate::value::Value;
/// Boxed, thread-safe error type required by the `ToSql::to_sql` signature.
type BoxError = Box<dyn StdError + Send + Sync>;
/// Cloneable handle to a deadpool-managed PostgreSQL connection pool.
///
/// `statements` is an `Arc`, so every clone (and every connection pinned from
/// this pool) increments one shared statement counter.
#[derive(Clone)]
pub struct PostgresPool {
    /// Count of statements issued through this pool or its pinned connections.
    statements: Arc<AtomicU64>,
    /// Underlying deadpool connection pool.
    pool: Pool,
}
impl PostgresPool {
pub fn new(url: &str, max_connections: u32) -> crate::Result<Self> {
if max_connections == 0 {
return Err(OrmError::configuration("max_connections must be at least 1"));
}
let pg_config: tokio_postgres::Config = url
.parse()
.map_err(|e| OrmError::configuration("invalid PostgreSQL url").with_source(e))?;
let manager = Manager::from_config(
pg_config,
NoTls,
ManagerConfig { recycling_method: RecyclingMethod::Fast },
);
let pool = Pool::builder(manager)
.max_size(max_connections as usize)
.build()
.map_err(|e| OrmError::configuration("cannot build PostgreSQL pool").with_source(e))?;
Ok(Self { pool, statements: Arc::new(AtomicU64::new(0)) })
}
pub async fn fetch_all(&self, sql: String, params: Vec<Value>) -> crate::Result<Vec<Row>> {
self.statements.fetch_add(1, Ordering::Relaxed);
let client = self.get().await?;
query(&client, &sql, ¶ms).await
}
pub async fn execute(&self, sql: String, params: Vec<Value>) -> crate::Result<ExecuteResult> {
self.statements.fetch_add(1, Ordering::Relaxed);
let client = self.get().await?;
execute(&client, &sql, ¶ms).await
}
pub async fn execute_batch(&self, sql: String) -> crate::Result<()> {
self.statements.fetch_add(1, Ordering::Relaxed);
let client = self.get().await?;
client
.batch_execute(&sql)
.await
.map_err(|e| OrmError::query("statement batch failed").with_source(e))
}
pub fn statement_count(&self) -> u64 {
self.statements.load(Ordering::Relaxed)
}
pub async fn close(&self) {
self.pool.close();
}
pub(crate) async fn acquire_pinned(&self) -> crate::Result<PinnedPostgres> {
let object = self.get().await?;
Ok(PinnedPostgres {
object: Mutex::new(Some(object)),
statements: Arc::clone(&self.statements),
})
}
async fn get(&self) -> crate::Result<Object> {
self.pool
.get()
.await
.map_err(|e| OrmError::connection("cannot acquire a PostgreSQL connection").with_source(e))
}
}
/// Borrows each bound value as a `ToSql` trait object, preserving order, for
/// `tokio_postgres`'s positional-parameter APIs.
fn bind_params(params: &[Value]) -> Vec<&(dyn ToSql + Sync)> {
    params.iter().map(|v| v as &(dyn ToSql + Sync)).collect()
}

/// Runs `sql` with positional `params` on `client` and decodes every returned row.
///
/// # Errors
///
/// A query error if the driver call fails; conversion errors from `read_value`
/// if a column cannot be decoded.
async fn query(client: &Object, sql: &str, params: &[Value]) -> crate::Result<Vec<Row>> {
    let bound = bind_params(params);
    let rows = client
        .query(sql, &bound)
        .await
        .map_err(|e| OrmError::query("query execution failed").with_source(e))?;
    let Some(first) = rows.first() else {
        return Ok(Vec::new());
    };
    // All rows of one result set share their columns: build the name list once
    // and hand out cheap `Arc` clones instead of re-allocating it per row.
    let columns = column_names(first);
    rows.iter()
        .map(|row| read_row(row, Arc::clone(&columns)))
        .collect()
}

/// Runs one statement with positional `params` and returns the affected row count.
///
/// PostgreSQL has no implicit row id, so `last_insert_rowid` is always 0.
///
/// # Errors
///
/// A query error if the driver call fails.
async fn execute(client: &Object, sql: &str, params: &[Value]) -> crate::Result<ExecuteResult> {
    let bound = bind_params(params);
    let affected = client
        .execute(sql, &bound)
        .await
        .map_err(|e| OrmError::query("statement execution failed").with_source(e))?;
    Ok(ExecuteResult {
        rows_affected: affected,
        last_insert_rowid: 0,
    })
}

/// Collects the column names of `row`'s result set into a shareable list.
fn column_names(row: &tokio_postgres::Row) -> Arc<[String]> {
    row.columns()
        .iter()
        .map(|column| column.name().to_string())
        .collect()
}

/// Decodes every value of `row` and pairs them with the shared `columns` list.
///
/// # Errors
///
/// The first conversion error returned by `read_value`.
fn read_row(row: &tokio_postgres::Row, columns: Arc<[String]>) -> crate::Result<Row> {
    let values = (0..row.len())
        .map(|index| read_value(row, index))
        .collect::<crate::Result<Vec<_>>>()?;
    Ok(Row::with_columns(columns, values))
}
fn read_value(row: &tokio_postgres::Row, index: usize) -> crate::Result<Value> {
let ty = row.columns()[index].type_();
if let tokio_postgres::types::Kind::Array(element) = ty.kind() {
return read_array(row, index, element);
}
let value = if *ty == Type::BOOL {
get_opt::<bool>(row, index)?.map_or(Value::Null, Value::Bool)
} else if *ty == Type::INT2 {
get_opt::<i16>(row, index)?.map_or(Value::Null, |n| Value::Int(i64::from(n)))
} else if *ty == Type::INT4 {
get_opt::<i32>(row, index)?.map_or(Value::Null, |n| Value::Int(i64::from(n)))
} else if *ty == Type::INT8 {
get_opt::<i64>(row, index)?.map_or(Value::Null, Value::Int)
} else if *ty == Type::FLOAT4 {
get_opt::<f32>(row, index)?.map_or(Value::Null, |n| Value::Real(f64::from(n)))
} else if *ty == Type::FLOAT8 {
get_opt::<f64>(row, index)?.map_or(Value::Null, Value::Real)
} else if *ty == Type::BYTEA {
get_opt::<Vec<u8>>(row, index)?.map_or(Value::Null, Value::Blob)
} else if *ty == Type::TIMESTAMPTZ {
get_opt::<time::OffsetDateTime>(row, index)?.map_or(Value::Null, Value::Timestamp)
} else if *ty == Type::JSON || *ty == Type::JSONB {
get_opt::<serde_json::Value>(row, index)?.map_or(Value::Null, Value::Json)
} else if *ty == Type::UUID {
get_opt::<uuid::Uuid>(row, index)?.map_or(Value::Null, Value::Uuid)
} else {
get_opt::<String>(row, index)?.map_or(Value::Null, Value::Text)
};
Ok(value)
}
/// Decodes the one-dimensional array column `index` of `row` into `Value::Array`.
///
/// `element` is the array's element type. The supported element types mirror
/// `read_value`'s scalar support:
/// - bool, integers, floats;
/// - bytea, timestamptz, json/jsonb, uuid;
/// - anything else is read as text.
///
/// A NULL array becomes `Value::Null`, and NULL elements become `Value::Null`
/// entries.
///
/// # Errors
///
/// A conversion error if the driver cannot decode the column as the chosen
/// element type.
fn read_array(row: &tokio_postgres::Row, index: usize, element: &Type) -> crate::Result<Value> {
    // Reads the column as `Option<Vec<Option<$t>>>` and wraps each present
    // element with `$wrap`.
    macro_rules! read_elements {
        ($t:ty, $wrap:expr) => {{
            let column: Option<Vec<Option<$t>>> = row
                .try_get(index)
                .map_err(|e| OrmError::conversion(format!("cannot read array column {index}")).with_source(e))?;
            match column {
                None => Value::Null,
                Some(items) => Value::Array(
                    items
                        .into_iter()
                        .map(|item| item.map_or(Value::Null, $wrap))
                        .collect(),
                ),
            }
        }};
    }
    let value = if *element == Type::BOOL {
        read_elements!(bool, Value::Bool)
    } else if *element == Type::INT2 {
        read_elements!(i16, |n| Value::Int(i64::from(n)))
    } else if *element == Type::INT4 {
        read_elements!(i32, |n| Value::Int(i64::from(n)))
    } else if *element == Type::INT8 {
        read_elements!(i64, Value::Int)
    } else if *element == Type::FLOAT4 {
        read_elements!(f32, |n| Value::Real(f64::from(n)))
    } else if *element == Type::FLOAT8 {
        read_elements!(f64, Value::Real)
    } else if *element == Type::BYTEA {
        // Previously fell through to the `String` fallback, which cannot decode bytea.
        read_elements!(Vec<u8>, Value::Blob)
    } else if *element == Type::TIMESTAMPTZ {
        read_elements!(time::OffsetDateTime, Value::Timestamp)
    } else if *element == Type::JSON || *element == Type::JSONB {
        read_elements!(serde_json::Value, Value::Json)
    } else if *element == Type::UUID {
        read_elements!(uuid::Uuid, Value::Uuid)
    } else {
        read_elements!(String, Value::Text)
    };
    Ok(value)
}
/// Reads column `index` of `row` as a nullable `T`.
///
/// SQL NULL yields `Ok(None)`. Any error from `Row::try_get` is wrapped in an
/// `OrmError::conversion` that names the column index.
fn get_opt<'a, T>(row: &'a tokio_postgres::Row, index: usize) -> crate::Result<Option<T>>
where
    T: tokio_postgres::types::FromSql<'a>,
{
    let decoded: Result<Option<T>, _> = row.try_get(index);
    decoded.map_err(|source| {
        let message = format!("cannot read column {index}");
        OrmError::conversion(message).with_source(source)
    })
}
impl ToSql for Value {
    /// Encodes the value in the binary format of the parameter type `ty`.
    ///
    /// Handling per variant:
    /// - Integers are narrowed to `int2`/`int4` (erroring if out of range) or
    ///   converted to `float4`/`float8` when the server expects those types.
    /// - Reals are narrowed to `float4` when required.
    /// - Arrays are type-checked against `ty` so that a non-array parameter
    ///   produces an error instead of a panic inside the driver.
    fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, BoxError> {
        match self {
            Value::Null => Ok(IsNull::Yes),
            Value::Bool(b) => b.to_sql(ty, out),
            Value::Int(i) => {
                if *ty == Type::INT2 {
                    i16::try_from(*i)?.to_sql(ty, out)
                } else if *ty == Type::INT4 {
                    i32::try_from(*i)?.to_sql(ty, out)
                } else if *ty == Type::FLOAT4 {
                    // Writing raw int8 bytes here would be misread by the server;
                    // convert instead (lossy for large magnitudes, like SQL's cast).
                    (*i as f32).to_sql(ty, out)
                } else if *ty == Type::FLOAT8 {
                    // Exact for |i| <= 2^53, rounded beyond that.
                    (*i as f64).to_sql(ty, out)
                } else {
                    i.to_sql(ty, out)
                }
            }
            Value::Real(r) => {
                if *ty == Type::FLOAT4 {
                    (*r as f32).to_sql(ty, out)
                } else {
                    r.to_sql(ty, out)
                }
            }
            Value::Text(s) => s.to_sql(ty, out),
            Value::Blob(b) => b.to_sql(ty, out),
            Value::Timestamp(t) => t.to_sql(ty, out),
            Value::Json(j) => j.to_sql(ty, out),
            Value::Uuid(u) => u.to_sql(ty, out),
            // The unchecked array encoder panics when `ty` is not an array type;
            // the checked form returns a `WrongType` error instead.
            Value::Array(items) => items.to_sql_checked(ty, out),
        }
    }

    /// Accepts every parameter type: the variant is not known here, so `to_sql`
    /// is responsible for encoding against the actual `ty`.
    fn accepts(_ty: &Type) -> bool {
        true
    }

    to_sql_checked!();
}
/// One pooled PostgreSQL connection reserved for exclusive use, e.g. by a transaction.
///
/// The connection lives in a `Mutex<Option<_>>` slot. Each operation leases it
/// out for one statement and puts it back afterwards. Overlapping operations on
/// the same `PinnedPostgres` therefore fail with "pinned connection is already
/// in use" instead of interleaving.
pub(crate) struct PinnedPostgres {
    /// The pinned connection; `None` while a statement holds it or after
    /// `rollback_now` consumed it.
    object: Mutex<Option<Object>>,
    /// Statement counter shared with the originating `PostgresPool`.
    statements: Arc<AtomicU64>,
}

impl PinnedPostgres {
    /// Moves the connection out of its slot, failing if it is currently absent.
    /// A poisoned mutex is recovered because the slot holds no invariant a panic could break.
    fn take(&self) -> crate::Result<Object> {
        self.object
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .take()
            .ok_or_else(|| OrmError::query("pinned connection is already in use"))
    }

    /// Returns the connection to its slot.
    fn put(&self, object: Object) {
        *self.object.lock().unwrap_or_else(|poisoned| poisoned.into_inner()) = Some(object);
    }

    /// Leases the connection for one statement; the lease puts it back when dropped.
    fn lease(&self) -> crate::Result<Lease<'_>> {
        let object = self.take()?;
        Ok(Lease { owner: self, object: Some(object) })
    }

    /// Runs `sql` with positional `params` on the pinned connection and returns all rows.
    pub(crate) async fn fetch_all(
        &self,
        sql: String,
        params: Vec<Value>,
    ) -> crate::Result<Vec<Row>> {
        self.statements.fetch_add(1, Ordering::Relaxed);
        let lease = self.lease()?;
        query(lease.client(), &sql, &params).await
    }

    /// Runs one statement with positional `params` on the pinned connection.
    pub(crate) async fn execute(
        &self,
        sql: String,
        params: Vec<Value>,
    ) -> crate::Result<ExecuteResult> {
        self.statements.fetch_add(1, Ordering::Relaxed);
        let lease = self.lease()?;
        execute(lease.client(), &sql, &params).await
    }

    /// Runs a parameterless, possibly multi-statement `sql` batch on the pinned connection.
    pub(crate) async fn execute_batch(&self, sql: String) -> crate::Result<()> {
        self.statements.fetch_add(1, Ordering::Relaxed);
        let lease = self.lease()?;
        lease
            .client()
            .batch_execute(&sql)
            .await
            .map_err(|e| OrmError::query("statement batch failed").with_source(e))
    }

    /// Consumes the connection and issues `ROLLBACK` on it without blocking.
    ///
    /// If a Tokio runtime is available, the rollback runs on a spawned task and
    /// its result is ignored. If the connection is currently leased or already
    /// consumed, this does nothing.
    pub(crate) fn rollback_now(&self) {
        let Ok(object) = self.take() else { return };
        match tokio::runtime::Handle::try_current() {
            Ok(handle) => {
                handle.spawn(async move {
                    let _ = object.batch_execute("ROLLBACK").await;
                });
            }
            // NOTE(review): without a runtime the connection is dropped without a
            // ROLLBACK; under `RecyclingMethod::Fast` it may re-enter the pool still
            // inside the transaction. Consider detaching it from the pool instead.
            Err(_) => drop(object),
        }
    }
}

/// Exclusive checkout of a `PinnedPostgres` connection for one statement.
///
/// The connection is returned in `Drop`, so this also happens when the
/// statement future is cancelled mid-`await`. Without the guard, a cancelled
/// statement dropped the `Object`, which handed it back to deadpool with any
/// open transaction intact. It also left the slot empty for good, so
/// `rollback_now` could never roll back.
struct Lease<'a> {
    owner: &'a PinnedPostgres,
    object: Option<Object>,
}

impl Lease<'_> {
    /// The leased connection.
    fn client(&self) -> &Object {
        // Invariant: `object` is only emptied inside `drop`.
        self.object.as_ref().expect("lease holds its connection until dropped")
    }
}

impl Drop for Lease<'_> {
    fn drop(&mut self) {
        if let Some(object) = self.object.take() {
            self.owner.put(object);
        }
    }
}