mod database;
mod query_stats;
mod relay;
#[cfg(test)]
mod tests;
#[cfg(all(test, feature = "test-postgres"))]
mod integration_tests;
use std::{fmt::Write, time::Duration};
use deadpool_postgres::{Config, ManagerConfig, Pool, RecyclingMethod, Runtime};
use fraiseql_error::{FraiseQLError, Result};
use tokio_postgres::{NoTls, Row};
use super::where_generator::PostgresWhereGenerator;
use crate::{
dialect::PostgresDialect,
identifier::quote_postgres_identifier,
order_by::append_order_by,
traits::DatabaseAdapter,
types::{
DatabaseType, JsonbValue, QueryParam,
sql_hints::{OrderByClause, SqlProjectionHint},
},
where_clause::WhereClause,
};
/// Maximum pool size used by `PostgresAdapter::new`.
const DEFAULT_POOL_SIZE: usize = 25;
/// Total number of attempts `acquire_connection_with_retry` makes for non-timeout pool errors.
const MAX_CONNECTION_RETRIES: u32 = 3;
/// Base retry delay in milliseconds; after failed attempt `n` (0-based) the adapter sleeps
/// `(n + 1) * CONNECTION_RETRY_DELAY_MS` before trying again (no sleep after the last attempt).
const CONNECTION_RETRY_DELAY_MS: u64 = 50;
/// Pool sizing, timeout and pre-warm settings accepted by `PostgresAdapter::with_pool_config`.
#[derive(Debug, Clone)]
pub struct PoolPrewarmConfig {
    /// Desired number of open connections after construction (best-effort, capped at
    /// `max_size`). The start-up health-check connection counts toward this number, so
    /// values of 0 or 1 skip pre-warming entirely.
    pub min_size: usize,
    /// Maximum number of connections the pool may hold.
    pub max_size: usize,
    /// When `Some`, used (in seconds) as both the pool wait timeout and the
    /// connection-create timeout; `None` leaves the `PoolConfig::new` defaults untouched.
    pub timeout_secs: Option<u64>,
}
/// Creates a Tokio-backed `deadpool_postgres` pool for `connection_string` (no TLS).
///
/// Connections are recycled with `RecyclingMethod::Fast`. When `timeout_secs` is given,
/// the same duration is applied to both waiting for a pooled connection and creating a
/// new one.
///
/// # Errors
///
/// Returns `FraiseQLError::ConnectionPool` if deadpool rejects the configuration.
fn build_pool(connection_string: &str, max_size: usize, timeout_secs: Option<u64>) -> Result<Pool> {
    let mut pool_settings = deadpool_postgres::PoolConfig::new(max_size);
    if let Some(limit) = timeout_secs.map(Duration::from_secs) {
        pool_settings.timeouts.wait = Some(limit);
        pool_settings.timeouts.create = Some(limit);
    }

    let mut config = Config::new();
    config.url = Some(connection_string.to_owned());
    config.manager = Some(ManagerConfig {
        recycling_method: RecyclingMethod::Fast,
    });
    config.pool = Some(pool_settings);

    config
        .create_pool(Some(Runtime::Tokio1), NoTls)
        .map_err(|err| FraiseQLError::ConnectionPool {
            message: format!("Failed to create connection pool: {err}"),
        })
}
/// Escapes `key` for embedding inside a single-quoted SQL string literal by doubling
/// every `'`. All other characters are copied unchanged.
pub(super) fn escape_jsonb_key(key: &str) -> String {
    let mut escaped = String::with_capacity(key.len());
    for ch in key.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
/// PostgreSQL database adapter backed by a `deadpool_postgres` connection pool.
#[derive(Clone)]
pub struct PostgresAdapter {
    /// Pool from which every query in this adapter acquires its connection.
    pub(super) pool: Pool,
    /// Set to `true` by `with_mutation_timing`; `false` after construction.
    mutation_timing_enabled: bool,
    /// Variable name configured for mutation timing (defaults to `fraiseql.started_at`).
    /// NOTE(review): how it is consumed lives outside this file — confirm there.
    timing_variable_name: String,
}
impl std::fmt::Debug for PostgresAdapter {
    /// Prints the timing settings and a `"<Pool>"` placeholder instead of pool internals.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("PostgresAdapter");
        builder.field("mutation_timing_enabled", &self.mutation_timing_enabled);
        builder.field("timing_variable_name", &self.timing_variable_name);
        builder.field("pool", &"<Pool>");
        builder.finish()
    }
}
impl PostgresAdapter {
    /// Creates an adapter with the default pool size (`DEFAULT_POOL_SIZE`), no pool
    /// timeouts and no pre-warming.
    ///
    /// # Errors
    ///
    /// Returns `FraiseQLError::ConnectionPool` if the pool cannot be built or a connection
    /// cannot be acquired, and `FraiseQLError::Database` if the `SELECT 1` health check
    /// fails.
    pub async fn new(connection_string: &str) -> Result<Self> {
        Self::with_pool_config(
            connection_string,
            PoolPrewarmConfig {
                min_size: 0,
                max_size: DEFAULT_POOL_SIZE,
                timeout_secs: None,
            },
        )
        .await
    }

    /// Creates an adapter with explicit pool sizing, timeouts and pre-warming.
    ///
    /// A `SELECT 1` health check is run first. Then, if `cfg.min_size` (capped at
    /// `cfg.max_size`) is greater than one, extra connections are opened so that about
    /// `min_size` connections exist once this returns. Pre-warming is best-effort:
    /// failures are logged and never make this constructor fail.
    ///
    /// # Errors
    ///
    /// Returns `FraiseQLError::ConnectionPool` if the pool cannot be built or the first
    /// connection cannot be acquired, and `FraiseQLError::Database` if the health check
    /// query fails.
    pub async fn with_pool_config(connection_string: &str, cfg: PoolPrewarmConfig) -> Result<Self> {
        let pool = build_pool(connection_string, cfg.max_size, cfg.timeout_secs)?;

        let health_check = pool.get().await.map_err(|e| FraiseQLError::ConnectionPool {
            message: format!("Failed to acquire connection: {e}"),
        })?;
        health_check.query("SELECT 1", &[]).await.map_err(|e| FraiseQLError::Database {
            message: format!("Failed to connect to database: {e}"),
            sql_state: e.code().map(|c| c.code().to_string()),
        })?;

        let adapter = Self {
            pool,
            mutation_timing_enabled: false,
            timing_variable_name: "fraiseql.started_at".to_string(),
        };

        // The health-check connection already counts toward `min_size`, so only
        // `min_size - 1` more are opened. It stays checked out while pre-warming:
        // if it were returned to the idle queue first, one pre-warm task would just
        // reuse it and the pool would end up one connection short.
        let warm_target = cfg.min_size.min(cfg.max_size).saturating_sub(1);
        if warm_target > 0 {
            adapter.prewarm(warm_target).await;
        }
        drop(health_check);

        Ok(adapter)
    }

    /// Creates an adapter with the given maximum pool size, no timeouts and no
    /// pre-warming.
    ///
    /// # Errors
    ///
    /// Same as [`Self::with_pool_config`].
    pub async fn with_pool_size(connection_string: &str, max_size: usize) -> Result<Self> {
        Self::with_pool_config(
            connection_string,
            PoolPrewarmConfig {
                min_size: 0,
                max_size,
                timeout_secs: None,
            },
        )
        .await
    }

    /// Opens `count` connections concurrently (best-effort, 10 s overall budget) and
    /// logs the outcome.
    ///
    /// Each spawned task returns its checked-out connection, and the results are held
    /// together until every task has finished. This forces the pool to open `count`
    /// distinct connections rather than reusing one. They return to the pool when this
    /// function ends.
    async fn prewarm(&self, count: usize) {
        use futures::future::join_all;
        use tokio::time::timeout;

        const PREWARM_TIMEOUT: Duration = Duration::from_secs(10);

        let handles: Vec<_> = (0..count)
            .map(|_| {
                let pool = self.pool.clone();
                tokio::spawn(async move { pool.get().await })
            })
            .collect();

        // On timeout the join handles are dropped; the tasks are detached rather than
        // aborted, and the number of successful connections is unknown.
        let Ok(outcomes) = timeout(PREWARM_TIMEOUT, join_all(handles)).await else {
            tracing::warn!(
                target_connections = count,
                "Pool pre-warm timed out after 10s; server will continue with partial pre-warm"
            );
            return;
        };

        // A task succeeded only if it neither panicked (outer Ok) nor failed to get a
        // connection (inner Ok).
        let succeeded = outcomes.iter().filter(|r| matches!(r, Ok(Ok(_)))).count();
        let failed = count - succeeded;

        if failed > 0 {
            tracing::warn!(
                succeeded,
                failed,
                "Pool pre-warm: some connections could not be established"
            );
        } else {
            // +1 for the health-check connection the caller keeps checked out meanwhile.
            tracing::info!(
                idle_connections = succeeded + 1,
                "PostgreSQL pool pre-warmed successfully"
            );
        }
    }

    /// Returns the underlying connection pool.
    #[must_use]
    pub const fn pool(&self) -> &Pool {
        &self.pool
    }

    /// Enables mutation timing and sets the variable name used for it.
    #[must_use]
    pub fn with_mutation_timing(mut self, variable_name: &str) -> Self {
        self.mutation_timing_enabled = true;
        self.timing_variable_name = variable_name.to_string();
        self
    }

    /// Returns whether mutation timing has been enabled via [`Self::with_mutation_timing`].
    #[must_use]
    pub const fn mutation_timing_enabled(&self) -> bool {
        self.mutation_timing_enabled
    }

    /// Runs `sql` with `params` and decodes column 0 of every row as JSON.
    ///
    /// # Errors
    ///
    /// Returns the connection-acquisition error from
    /// [`Self::acquire_connection_with_retry`]. Returns `FraiseQLError::Database` if
    /// the query fails, or if any row's first column cannot be decoded as JSON
    /// (for example, it is NULL or not a JSON type).
    pub(super) async fn execute_raw(
        &self,
        sql: &str,
        params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
    ) -> Result<Vec<JsonbValue>> {
        let client = self.acquire_connection_with_retry().await?;
        let rows: Vec<Row> =
            client.query(sql, params).await.map_err(|e| FraiseQLError::Database {
                message: format!("Query execution failed: {e}"),
                sql_state: e.code().map(|c| c.code().to_string()),
            })?;

        // `try_get` instead of `get`: a NULL or non-JSON column must surface as an
        // error, not panic the calling task.
        rows.into_iter()
            .map(|row| {
                row.try_get::<_, serde_json::Value>(0)
                    .map(JsonbValue::new)
                    .map_err(|e| FraiseQLError::Database {
                        message: format!("Failed to decode result row as JSON: {e}"),
                        sql_state: e.code().map(|c| c.code().to_string()),
                    })
            })
            .collect()
    }

    /// Acquires a pooled connection, retrying transient failures.
    ///
    /// Pool timeouts are not retried; they fail immediately with a message describing
    /// pool saturation. Other pool errors are retried for up to `MAX_CONNECTION_RETRIES`
    /// attempts in total, sleeping `CONNECTION_RETRY_DELAY_MS * (attempt + 1)` ms between
    /// attempts.
    ///
    /// # Errors
    ///
    /// Returns `FraiseQLError::ConnectionPool` on a pool timeout or when all attempts fail.
    pub(super) async fn acquire_connection_with_retry(&self) -> Result<deadpool_postgres::Client> {
        use deadpool_postgres::PoolError;

        let mut last_error = None;

        for attempt in 0..MAX_CONNECTION_RETRIES {
            match self.pool.get().await {
                Ok(client) => {
                    if attempt > 0 {
                        tracing::info!(attempt, "Successfully acquired connection after retries");
                    }
                    return Ok(client);
                },
                Err(PoolError::Timeout(_)) => {
                    let metrics = self.pool_metrics();
                    tracing::error!(
                        available = metrics.idle_connections,
                        active = metrics.active_connections,
                        max = metrics.total_connections,
                        "Connection pool timeout: all connections busy"
                    );
                    return Err(FraiseQLError::ConnectionPool {
                        message: format!(
                            "Connection pool timeout: {}/{} connections busy. \
                             Increase pool_max_size or reduce concurrent load.",
                            metrics.active_connections, metrics.total_connections,
                        ),
                    });
                },
                Err(e) => {
                    last_error = Some(e);
                    if attempt < MAX_CONNECTION_RETRIES - 1 {
                        // Linear back-off: 1x, 2x, ... the base delay.
                        let delay = CONNECTION_RETRY_DELAY_MS * (u64::from(attempt) + 1);
                        tracing::warn!(
                            attempt = attempt + 1,
                            total = MAX_CONNECTION_RETRIES,
                            delay_ms = delay,
                            "Transient connection error, retrying"
                        );
                        tokio::time::sleep(Duration::from_millis(delay)).await;
                    }
                },
            }
        }

        let pool_metrics = self.pool_metrics();
        tracing::error!(
            retries = MAX_CONNECTION_RETRIES,
            available = pool_metrics.idle_connections,
            active = pool_metrics.active_connections,
            max = pool_metrics.total_connections,
            "Failed to acquire connection after all retries"
        );
        Err(FraiseQLError::ConnectionPool {
            message: format!(
                "Failed to acquire connection after {} retries: {}. \
                 Pool state: idle={}, active={}, max={}",
                MAX_CONNECTION_RETRIES,
                last_error.expect("last_error is set on every retry iteration"),
                pool_metrics.idle_connections,
                pool_metrics.active_connections,
                pool_metrics.total_connections,
            ),
        })
    }

    /// Runs `SELECT <projection> FROM <view> [WHERE …] [ORDER BY …] [LIMIT $n] [OFFSET $m]`.
    ///
    /// Without a projection hint this delegates to `execute_where_query`.
    ///
    /// # Errors
    ///
    /// Propagates WHERE/ORDER BY generation errors and any error from
    /// [`Self::execute_raw`].
    pub(super) async fn execute_with_projection_impl(
        &self,
        view: &str,
        projection: Option<&SqlProjectionHint>,
        where_clause: Option<&WhereClause>,
        limit: Option<u32>,
        offset: Option<u32>,
        order_by: Option<&[OrderByClause]>,
    ) -> Result<Vec<JsonbValue>> {
        let Some(projection) = projection else {
            return self.execute_where_query(view, where_clause, limit, offset, order_by).await;
        };

        // The projection template is interpolated verbatim (not escaped): it must be
        // trusted, pre-built SQL. Only the view name is quoted here.
        let mut sql = format!(
            "SELECT {} FROM {}",
            projection.projection_template,
            quote_postgres_identifier(view)
        );

        let mut typed_params: Vec<QueryParam> = if let Some(clause) = where_clause {
            let generator = PostgresWhereGenerator::new(PostgresDialect);
            let (where_sql, where_params) = generator.generate(clause)?;
            sql.push_str(" WHERE ");
            sql.push_str(&where_sql);
            where_params.into_iter().map(QueryParam::from).collect()
        } else {
            Vec::new()
        };

        // LIMIT/OFFSET placeholders are numbered after the WHERE parameters (assumes the
        // generator numbered its own placeholders $1..$n).
        let mut param_count = typed_params.len();
        append_order_by(&mut sql, order_by, DatabaseType::PostgreSQL)?;
        if let Some(lim) = limit {
            param_count += 1;
            write!(sql, " LIMIT ${param_count}").expect("write to String");
            typed_params.push(QueryParam::BigInt(i64::from(lim)));
        }
        if let Some(off) = offset {
            param_count += 1;
            write!(sql, " OFFSET ${param_count}").expect("write to String");
            typed_params.push(QueryParam::BigInt(i64::from(off)));
        }

        tracing::debug!("SQL with projection = {}", sql);
        tracing::debug!("typed_params = {:?}", typed_params);

        let param_refs = crate::types::as_sql_param_refs(&typed_params);
        self.execute_raw(&sql, &param_refs).await
    }

    /// Public entry point for projected queries without ordering.
    ///
    /// # Errors
    ///
    /// Same as the underlying projection query.
    pub async fn execute_with_projection(
        &self,
        view: &str,
        projection: Option<&SqlProjectionHint>,
        where_clause: Option<&WhereClause>,
        limit: Option<u32>,
        offset: Option<u32>,
    ) -> Result<Vec<JsonbValue>> {
        self.execute_with_projection_impl(view, projection, where_clause, limit, offset, None)
            .await
    }
}
pub(super) fn build_where_select_sql(
view: &str,
where_clause: Option<&WhereClause>,
limit: Option<u32>,
offset: Option<u32>,
) -> Result<(String, Vec<QueryParam>)> {
build_where_select_sql_ordered(view, where_clause, limit, offset, None)
}
/// Builds `SELECT data FROM <view> [WHERE …] [ORDER BY …] [LIMIT $n] [OFFSET $m]`.
///
/// Returns the SQL together with its positional parameters: first the WHERE parameters,
/// then LIMIT and OFFSET (each present only when requested), numbered in that order.
///
/// # Errors
///
/// Propagates errors from WHERE clause generation and ORDER BY rendering.
pub(super) fn build_where_select_sql_ordered(
    view: &str,
    where_clause: Option<&WhereClause>,
    limit: Option<u32>,
    offset: Option<u32>,
    order_by: Option<&[OrderByClause]>,
) -> Result<(String, Vec<QueryParam>)> {
    let mut sql = format!("SELECT data FROM {}", quote_postgres_identifier(view));
    let mut typed_params: Vec<QueryParam> = Vec::new();

    if let Some(clause) = where_clause {
        let (where_sql, where_params) =
            PostgresWhereGenerator::new(PostgresDialect).generate(clause)?;
        sql.push_str(" WHERE ");
        sql.push_str(&where_sql);
        typed_params.extend(where_params.into_iter().map(QueryParam::from));
    }

    append_order_by(&mut sql, order_by, DatabaseType::PostgreSQL)?;

    // Each pagination value becomes the next positional parameter; its placeholder
    // number is the parameter count right after pushing it.
    for (keyword, value) in [(" LIMIT", limit), (" OFFSET", offset)] {
        if let Some(amount) = value {
            typed_params.push(QueryParam::BigInt(i64::from(amount)));
            write!(sql, "{keyword} ${}", typed_params.len()).expect("write to String");
        }
    }

    Ok((sql, typed_params))
}