use super::{Connection, Parameter};
use deadpool_postgres::{
tokio_postgres::{error::SqlState, RowStream, Statement},
ManagerConfig, PoolConfig,
};
use futures_util::{ready, Stream};
use pin_project_lite::pin_project;
use serde::{Deserialize, Deserializer};
use std::{
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use thiserror::Error;
use tonic::{async_trait, Status};
#[cfg(feature = "ssl-native-tls")]
use {native_tls::TlsConnector, postgres_native_tls::MakeTlsConnector};
/// Errors produced while pooling connections or executing queries,
/// convertible into gRPC `Status` responses via the `From` impl below
#[derive(Debug, Error)]
pub enum Error {
    /// Caller supplied a different number of parameters than the
    /// prepared statement inferred
    #[error("Expected {expected} parameters but found {actual} instead")]
    Params {
        expected: usize,
        actual: usize,
    },
    /// Checking a connection out of the deadpool pool failed
    #[error("Error fetching connection from the pool: {0}")]
    Pool(#[from] deadpool_postgres::PoolError),
    /// The database rejected or failed the query itself
    #[error("SQL Query error: {0}")]
    Query(#[from] deadpool_postgres::tokio_postgres::Error),
    /// `SET ROLE`/`RESET ROLE` failed before the connection could be handed out
    #[error("Unable to set the ROLE of the connection before use: {0}")]
    Role(deadpool_postgres::tokio_postgres::Error),
    /// A row's `json` column did not contain a JSON object
    #[error("Unable to aggregate rows from query into valid JSON")]
    InvalidJson,
    /// Constructing the pool from configuration failed
    #[error("Error creating the connection pool: {0}")]
    Create(#[from] deadpool_postgres::CreatePoolError),
    /// TLS connector setup failed (only with the `ssl-native-tls` feature)
    #[cfg(feature = "ssl-native-tls")]
    #[error("Error setting up TLS connection: {0}")]
    Tls(#[from] native_tls::Error),
}
impl From<Error> for Status {
fn from(error: Error) -> Self {
let message = error.to_string();
match error {
Error::Params { .. } | Error::Role(..) | Error::Query(..) | Error::InvalidJson => {
Status::invalid_argument(message)
}
Error::Create(..) | Error::Pool(..) => Status::resource_exhausted(message),
#[cfg(feature = "ssl-native-tls")]
Error::Tls(..) => Status::internal(message),
}
}
}
/// Deadpool-backed connection pool plus the statement timeout applied to each
/// connection as it is checked out
pub struct Pool {
    pool: deadpool_postgres::Pool,
    // applied via `SET statement_timeout` (milliseconds) on checkout;
    // `None` leaves the database default in place
    statement_timeout: Option<Duration>,
}
#[async_trait]
impl super::Pool for Pool {
    #[cfg(feature = "role-header")]
    type Key = crate::extensions::role_header::Role;
    #[cfg(not(feature = "role-header"))]
    type Key = ();
    type Connection = Client;
    type Error = <Self::Connection as Connection>::Error;

    /// Check a connection out of the pool, optionally assuming a role (with
    /// the `role-header` feature) and applying the configured statement timeout
    ///
    /// # Errors
    ///
    /// Returns `Error::Pool` when checkout fails, `Error::Role` when the
    /// `SET ROLE`/`RESET ROLE` statement fails, and `Error::Query` when the
    /// statement-timeout statement fails
    #[tracing::instrument(skip(self))]
    async fn get_connection(&self, key: Self::Key) -> Result<Self::Connection, Self::Error> {
        tracing::trace!("Fetching connection from the pool");
        let client = self.pool.get().await?;

        #[cfg(feature = "role-header")]
        {
            // the role comes from a request header (untrusted input), so double
            // any embedded quotes per Postgres identifier-escaping rules to
            // keep it from breaking out of the quoted identifier and injecting
            // arbitrary SQL
            // NOTE(review): assumes the Role key unwraps to a string-like
            // value — confirm against crate::extensions::role_header
            let local_role_statement = match key {
                Some(role) => format!(r#"SET ROLE "{}""#, role.replace('"', r#""""#)),
                None => "RESET ROLE".to_string(),
            };

            client
                .batch_execute(&local_role_statement)
                .await
                .map_err(Error::Role)?;
        }

        // numeric interpolation only — no injection surface here
        if let Some(statement_timeout) = self.statement_timeout {
            client
                .batch_execute(&format!(
                    "SET statement_timeout={}",
                    statement_timeout.as_millis()
                ))
                .await?;
        }

        Ok(Client { client })
    }
}
pin_project! {
    /// Stream adapter that converts each JSON row from a raw `RowStream`
    /// into a protobuf `Struct`
    pub struct StructStream {
        #[pin]
        rows: RowStream,
    }
}
impl Stream for StructStream {
type Item = Result<pbjson_types::Struct, Error>;
fn poll_next(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
match ready!(this.rows.poll_next(context)?) {
Some(row) => {
if let serde_json::Value::Object(map) = row.try_get("json")? {
Poll::Ready(Some(Ok(to_proto_struct(map))))
} else {
Poll::Ready(Some(Err(Error::InvalidJson)))
}
}
None => Poll::Ready(None),
}
}
}
impl From<RowStream> for StructStream {
fn from(rows: RowStream) -> Self {
Self { rows }
}
}
/// Wrapper around a pooled deadpool-postgres client that implements the
/// project's `Connection` trait
pub struct Client {
    client: deadpool_postgres::Client,
}
#[async_trait]
impl Connection for Client {
    type Error = Error;
    type RowStream = StructStream;

    /// Prepare and execute a parameterized statement, streaming each row back
    /// as a JSON-derived protobuf `Struct`
    ///
    /// # Errors
    ///
    /// Returns `Error::Params` when the caller supplies a different number of
    /// parameters than the prepared statement inferred, and `Error::Query`
    /// for database-side failures
    #[tracing::instrument(skip(self, parameters))]
    async fn query(
        &self,
        statement: &str,
        parameters: &[Parameter],
    ) -> Result<Self::RowStream, Self::Error> {
        tracing::trace!("Querying Connection");

        // prepare through the statement cache so repeated queries skip the
        // round-trip
        let prepared_statement = self.client.prepare_cached(statement).await?;
        let inferred_types = prepared_statement.params();

        // reject parameter-count mismatches up front
        if inferred_types.len() != parameters.len() {
            return Err(Error::Params {
                expected: inferred_types.len(),
                actual: parameters.len(),
            });
        }

        let rows = match query_raw(self, statement, &prepared_statement, parameters).await {
            Err(Error::Query(error)) if error.code() == Some(&SqlState::FEATURE_NOT_SUPPORTED) => {
                tracing::warn!("Schema poisoned underneath statement cache. Retrying query");

                // evict the poisoned statement from the cache...
                self.client
                    .statement_cache
                    .remove(statement, inferred_types);

                // ...and re-prepare before retrying: reusing the evicted
                // `prepared_statement` would replay the same stale plan and
                // fail again for statements executed directly
                let prepared_statement = self.client.prepare_cached(statement).await?;

                // NOTE(review): the JSON-wrapped statement prepared inside
                // query_raw is cached under a different key and is NOT evicted
                // here — confirm whether it can be poisoned the same way
                query_raw(self, statement, &prepared_statement, parameters).await
            }
            result => result,
        }?;

        Ok(StructStream::from(rows))
    }

    /// Execute a batch of semicolon-separated statements with no result rows
    #[tracing::instrument(skip(self))]
    async fn batch(&self, query: &str) -> Result<(), Self::Error> {
        tracing::trace!("Executing batch query on Connection");
        self.client.batch_execute(query).await?;
        Ok(())
    }
}
/// Execute a prepared statement, wrapping row-producing queries in a
/// `TO_JSON` CTE so every result row comes back as a single `json` column
async fn query_raw(
    client: &Client,
    statement: &str,
    prepared_statement: &Statement,
    parameters: &[Parameter],
) -> Result<RowStream, Error> {
    // statements with no output columns (DDL, INSERT without RETURNING, etc.)
    // can run as-is — there is nothing to serialize to JSON
    if prepared_statement.columns().is_empty() {
        let rows = client
            .client
            .query_raw(prepared_statement, parameters)
            .await?;

        return Ok(rows);
    }

    // otherwise wrap the original statement so each row is aggregated into JSON
    let json_statement = format!(
        "WITH cte AS ({})
SELECT TO_JSON(__result) AS json
FROM (SELECT * FROM cte) AS __result",
        statement
    );

    let json_prepared = client.client.prepare_cached(&json_statement).await?;
    let rows = client.client.query_raw(&json_prepared, parameters).await?;

    Ok(rows)
}
/// Serde-deserialized configuration for building a [`Pool`]; field names
/// mirror libpq-style `PG*` settings
#[derive(Deserialize, Debug)]
pub struct Configuration {
    // upper bound on pooled connections; defaults to 4x physical cores
    #[serde(default = "get_max_connection_pool_size")]
    max_connection_pool_size: usize,
    // per-connection statement timeout parsed from a millisecond string;
    // an empty string means "no timeout override"
    #[serde(default, deserialize_with = "from_milliseconds_string")]
    statement_timeout: Option<Duration>,
    // how deadpool recycles connections on checkout; defaults to Clean
    #[serde(default)]
    recycling_method: RecyclingMethod,
    pgdbname: String,
    #[serde(default = "get_localhost")]
    pghost: String,
    pgpassword: String,
    #[serde(default = "get_postgres_port")]
    pgport: u16,
    pguser: String,
    #[serde(default = "get_application_name")]
    pgappname: String,
    // NOTE(review): presumably mirrors a subset of libpq's PGSSLMODE — confirm
    #[serde(default)]
    pgsslmode: Option<SslMode>,
}
impl Configuration {
    /// Build a deadpool-backed connection [`Pool`] from this configuration
    ///
    /// # Errors
    ///
    /// Returns `Error::Create` when deadpool rejects the configuration and,
    /// with the `ssl-native-tls` feature, `Error::Tls` when the TLS connector
    /// cannot be constructed
    #[tracing::instrument(
        skip(self),
        fields(
            max_connection_pool_size = self.max_connection_pool_size,
            ?statement_timeout = self.statement_timeout,
            ?recycling_method = self.recycling_method,
            pgdbname = self.pgdbname,
            pghost = self.pghost,
            pgpassword = "******",
            pgport = self.pgport,
            pgappname = self.pgappname,
            ?pgsslmode = self.pgsslmode,
        )
    )]
    pub fn create_pool(self) -> Result<Pool, Error> {
        tracing::debug!("Creating deadpool-based connection pool from configuration");

        // select the TLS transport at compile time based on enabled features
        #[cfg(feature = "ssl-native-tls")]
        let tls_connector = {
            let connector = TlsConnector::builder().build()?;
            MakeTlsConnector::new(connector)
        };
        #[cfg(not(feature = "ssl-native-tls"))]
        let tls_connector = tokio_postgres::NoTls;

        let manager = ManagerConfig {
            recycling_method: self.recycling_method.into(),
        };

        let pool = PoolConfig {
            max_size: self.max_connection_pool_size,
            ..PoolConfig::default()
        };

        let config = deadpool_postgres::Config {
            dbname: Some(self.pgdbname),
            // pghost is already an owned String — move it instead of cloning
            host: Some(self.pghost),
            password: Some(self.pgpassword),
            port: Some(self.pgport),
            user: Some(self.pguser),
            application_name: Some(self.pgappname),
            ssl_mode: self.pgsslmode.map(Into::into),
            manager: Some(manager),
            pool: Some(pool),
            ..deadpool_postgres::Config::default()
        };

        let pool = config.create_pool(None, tls_connector)?;

        Ok(Pool {
            pool,
            statement_timeout: self.statement_timeout,
        })
    }
}
/// Default Postgres host used when `pghost` is not configured
fn get_localhost() -> String {
    String::from("localhost")
}
/// Default Postgres port used when `pgport` is not configured
fn get_postgres_port() -> u16 {
    // standard Postgres wire-protocol port
    const DEFAULT_POSTGRES_PORT: u16 = 5432;
    DEFAULT_POSTGRES_PORT
}
/// Default `application_name` reported to Postgres when `pgappname` is unset
fn get_application_name() -> String {
    String::from("postgrpc")
}
/// Default pool size: four connections per physical CPU core
fn get_max_connection_pool_size() -> usize {
    let physical_cores = num_cpus::get_physical();
    physical_cores * 4
}
fn from_milliseconds_string<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
D: Deserializer<'de>,
{
let base_string = String::deserialize(deserializer)?;
if base_string.is_empty() {
Ok(None)
} else {
let parsed_millis: u64 = base_string.parse().map_err(serde::de::Error::custom)?;
let duration = Duration::from_millis(parsed_millis);
Ok(Some(duration))
}
}
/// Connection recycling strategies mirroring deadpool's `RecyclingMethod`,
/// duplicated here so they can be deserialized from configuration
#[derive(Deserialize, Debug, Default)]
#[serde(rename_all = "lowercase")]
enum RecyclingMethod {
    Fast,
    Verified,
    #[default]
    Clean,
}
impl From<RecyclingMethod> for deadpool_postgres::RecyclingMethod {
fn from(method: RecyclingMethod) -> Self {
match method {
RecyclingMethod::Fast => Self::Fast,
RecyclingMethod::Verified => Self::Verified,
RecyclingMethod::Clean => Self::Clean,
}
}
}
/// SSL modes mirroring deadpool's `SslMode`, duplicated here so they can be
/// deserialized from configuration
#[derive(Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
enum SslMode {
    Disable,
    Prefer,
    Require,
}
impl From<SslMode> for deadpool_postgres::SslMode {
fn from(mode: SslMode) -> Self {
match mode {
SslMode::Disable => Self::Disable,
SslMode::Prefer => Self::Prefer,
SslMode::Require => Self::Require,
}
}
}
fn to_proto_value(json: serde_json::Value) -> pbjson_types::Value {
let kind = match json {
serde_json::Value::Null => pbjson_types::value::Kind::NullValue(0),
serde_json::Value::Bool(boolean) => pbjson_types::value::Kind::BoolValue(boolean),
serde_json::Value::Number(number) => match number.as_f64() {
Some(number) => pbjson_types::value::Kind::NumberValue(number),
None => pbjson_types::value::Kind::StringValue(number.to_string()),
},
serde_json::Value::String(string) => pbjson_types::value::Kind::StringValue(string),
serde_json::Value::Array(array) => {
pbjson_types::value::Kind::ListValue(pbjson_types::ListValue {
values: array.into_iter().map(to_proto_value).collect(),
})
}
serde_json::Value::Object(map) => {
pbjson_types::value::Kind::StructValue(to_proto_struct(map))
}
};
pbjson_types::Value { kind: Some(kind) }
}
/// Convert a JSON object into a protobuf `Struct`, converting each value
/// recursively via `to_proto_value`
fn to_proto_struct(map: serde_json::Map<String, serde_json::Value>) -> pbjson_types::Struct {
    let fields = map
        .into_iter()
        .map(|(key, value)| (key, to_proto_value(value)))
        .collect();

    pbjson_types::Struct { fields }
}