use crate::{
extensions::FromRequest,
pools::{Connection, Parameter, Pool},
};
use futures_util::{pin_mut, StreamExt, TryStreamExt};
use proto::postgres_server::{Postgres as GrpcService, PostgresServer};
pub use proto::QueryRequest;
use std::sync::Arc;
use tokio::sync::mpsc::error::SendError;
use tokio_stream::wrappers::ReceiverStream;
use tonic::{codegen::InterceptedService, service::Interceptor, Request, Response, Status};
// Protobuf/gRPC bindings for the `postgres.v1` package, pulled in via `tonic::include_proto!`
// (presumably generated by tonic-build in build.rs — confirm). The generated items are
// undocumented and `pub` inside this private module, hence the lint allowances.
#[allow(unreachable_pub, missing_docs)]
mod proto {
    tonic::include_proto!("postgres.v1");
}
/// gRPC `Postgres` service implementation backed by a shared connection pool.
pub struct Postgres<P> {
    pool: Arc<P>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add a `P: Clone`
// bound, but cloning only needs to bump the `Arc` reference count, so any pool type works.
impl<P> Clone for Postgres<P> {
    fn clone(&self) -> Self {
        Self {
            pool: Arc::clone(&self.pool),
        }
    }
}
impl<P> Postgres<P>
where
P: Pool,
{
fn new(pool: Arc<P>) -> Self {
Self { pool }
}
#[tracing::instrument(skip(self, parameters), err)]
async fn query(
&self,
key: P::Key,
statement: &str,
parameters: &[Parameter],
) -> Result<<P::Connection as Connection>::RowStream, P::Error> {
tracing::info!("Querying postgres");
let rows = self
.pool
.get_connection(key)
.await?
.query(statement, parameters)
.await?;
Ok(rows)
}
}
#[tonic::async_trait]
impl<P> GrpcService for Postgres<P>
where
    P: Pool + 'static,
    P::Key: FromRequest,
{
    /// Server-streaming response: rows are forwarded by a background task through a bounded
    /// mpsc channel.
    type QueryStream = ReceiverStream<Result<pbjson_types::Struct, Status>>;

    /// Handles a `Query` RPC.
    ///
    /// Extracts the pool key from the request, converts the request values into
    /// [`Parameter`]s, runs the statement and streams every resulting row back to the client.
    ///
    /// # Errors
    ///
    /// Returns a [`Status`] if the key cannot be extracted from the request, or if acquiring a
    /// connection / executing the statement fails. Errors raised while iterating rows are
    /// delivered to the client as items of the response stream instead.
    #[tracing::instrument(skip(self, request), err)]
    async fn query(
        &self,
        mut request: Request<QueryRequest>,
    ) -> Result<Response<Self::QueryStream>, Status> {
        // Rows that may be buffered ahead of a slow client before the producer task waits.
        const ROW_CHANNEL_CAPACITY: usize = 100;

        let key = P::Key::from_request(&mut request).map_err(Into::<Status>::into)?;
        let QueryRequest { statement, values } = request.into_inner();
        let value_count = values.len();
        // NOTE(review): `map` + `collect` into a `Vec` always yields exactly `value_count`
        // items, so the guard below can never fire as written. Rejecting unsupported values
        // would require a fallible conversion (e.g. `TryFrom`) — confirm against `Parameter`.
        let parameters: Vec<_> = values.into_iter().map(Parameter::from).collect();
        if parameters.len() < value_count {
            return Err(Status::invalid_argument(
                "Invalid parameter values found. Only numbers, strings, boolean, and null values permitted",
            ));
        }
        let rows = Postgres::query(self, key, &statement, &parameters)
            .await
            .map_err(Into::<Status>::into)?
            .map_ok(Into::into)
            .map_err(Into::<Status>::into);
        let (transmitter, receiver) = tokio::sync::mpsc::channel(ROW_CHANNEL_CAPACITY);
        tokio::spawn(async move {
            pin_mut!(rows);
            while let Some(row) = rows.next().await {
                // `send` only fails once the receiving half (the client's response stream)
                // has been dropped; stop pulling rows at that point.
                transmitter.send(row).await?;
            }
            Ok::<_, SendError<_>>(())
        });
        Ok(Response::new(ReceiverStream::new(receiver)))
    }
}
/// Builds a tonic [`PostgresServer`] whose `Query` RPCs are served with connections taken
/// from `pool`.
pub fn new<P>(pool: Arc<P>) -> PostgresServer<Postgres<P>>
where
    P: Pool + 'static,
    P::Key: FromRequest,
{
    let service = Postgres::new(pool);
    PostgresServer::new(service)
}
/// Builds the same server as [`new`], wrapped in tonic's [`InterceptedService`] so that
/// `interceptor` is applied to incoming requests.
pub fn with_interceptor<P, I>(
    pool: Arc<P>,
    interceptor: I,
) -> InterceptedService<PostgresServer<Postgres<P>>, I>
where
    P: Pool + 'static,
    P::Key: FromRequest,
    I: Interceptor,
{
    let service = Postgres::new(pool);
    PostgresServer::with_interceptor(service, interceptor)
}