Skip to main content

postgrpc/services/
transaction.rs

1use crate::{
2    extensions::FromRequest,
3    pools::{transaction, Connection, Parameter, Pool},
4};
5use futures_util::{pin_mut, StreamExt, TryStreamExt};
6use proto::transaction_server::{Transaction as GrpcService, TransactionServer};
7pub use proto::{BeginResponse, CommitRequest, RollbackRequest, TransactionQueryRequest};
8use std::{hash::Hash, sync::Arc};
9use tokio::sync::mpsc::error::SendError;
10use tokio_stream::wrappers::ReceiverStream;
11use tonic::{codegen::InterceptedService, service::Interceptor, Request, Response, Status};
12use uuid::Uuid;
13
14/// Compiled protocol buffers for the Transaction service
15#[allow(unreachable_pub, missing_docs)]
16mod proto {
17    tonic::include_proto!("transaction.v1");
18}
19
20/// Type alias representing a bubbled-up error from the transaction pool
21type Error<P> = transaction::Error<<<P as Pool>::Connection as Connection>::Error>;
22
23/// Protocol-agnostic Transaction handlers for any connection pool
24#[derive(Clone)]
25pub struct Transaction<P>
26where
27    P: Pool,
28    P::Key: Hash + Eq + Clone,
29{
30    pool: transaction::Pool<P>,
31}
32
33impl<P> Transaction<P>
34where
35    P: Pool + 'static,
36    P::Key: Hash + Eq + Send + Sync + Clone + 'static,
37    P::Connection: 'static,
38    <P::Connection as Connection>::Error: Send + Sync + 'static,
39{
40    /// Create a new Postgres transaction service from a reference-counted Pool
41    pub fn new(pool: Arc<P>) -> Self {
42        Self {
43            pool: transaction::Pool::new(pool),
44        }
45    }
46
47    /// Begin a Postgres transaction, returning a unique ID for the transaction
48    #[tracing::instrument(skip(self), err)]
49    pub async fn begin(&self, key: P::Key) -> Result<Uuid, Error<P>> {
50        tracing::debug!("Beginning transaction");
51
52        let transaction_id = self.pool.begin(key).await?;
53
54        Ok(transaction_id)
55    }
56
57    /// Query an active Postgres transaction by ID and connection pool key
58    #[tracing::instrument(skip(self, parameters), err)]
59    pub async fn query(
60        &self,
61        id: Uuid,
62        key: P::Key,
63        statement: &str,
64        parameters: &[Parameter],
65    ) -> Result<<P::Connection as Connection>::RowStream, Error<P>> {
66        tracing::info!("Querying transaction");
67
68        let transaction_key = transaction::Key::new(key, id);
69
70        let rows = self
71            .pool
72            .get_connection(transaction_key)
73            .await?
74            .query(statement, parameters)
75            .await
76            .map_err(transaction::Error::Connection)?;
77
78        Ok(rows)
79    }
80
81    /// Commit an active Postgres transaction by ID and connection pool key
82    #[tracing::instrument(skip(self), err)]
83    pub async fn commit(&self, id: Uuid, key: P::Key) -> Result<(), Error<P>> {
84        tracing::debug!("Committing transaction");
85
86        self.pool.commit(id, key).await?;
87
88        Ok(())
89    }
90
91    /// Roll back an active Postgres transaction by ID and connection pool key
92    #[tracing::instrument(skip(self), err)]
93    pub async fn rollback(&self, id: Uuid, key: P::Key) -> Result<(), Error<P>> {
94        tracing::debug!("Rolling back transaction");
95
96        self.pool.rollback(id, key).await?;
97
98        Ok(())
99    }
100}
101
102/// gRPC service implementation for Transaction service
103#[tonic::async_trait]
104impl<P> GrpcService for Transaction<P>
105where
106    P: Pool + 'static,
107    P::Key: FromRequest + Hash + Eq + Clone,
108{
109    type QueryStream = ReceiverStream<Result<pbjson_types::Struct, Status>>;
110
111    #[tracing::instrument(skip(self, request), err)]
112    async fn query(
113        &self,
114        mut request: Request<TransactionQueryRequest>,
115    ) -> Result<Response<Self::QueryStream>, Status> {
116        // derive a key from extensions to use as a connection pool key
117        let key = P::Key::from_request(&mut request).map_err(Into::<Status>::into)?;
118
119        // get the request values
120        let TransactionQueryRequest {
121            id,
122            statement,
123            values,
124        } = request.into_inner();
125
126        let id = Uuid::parse_str(&id).map_err(|_| {
127            Status::invalid_argument("Transaction ID in request had unrecognized format")
128        })?;
129
130        // convert values to valid parameters
131        let value_count = values.len();
132
133        let parameters: Vec<_> = values.into_iter().map(Parameter::from).collect();
134
135        if parameters.len() < value_count {
136            return Err(
137                Status::invalid_argument(
138                    "Invalid parameter values found. Only numbers, strings, boolean, and null values permitted"
139                )
140            );
141        }
142
143        // get the rows, converting output to proto-compatible structs and statuses
144        let rows = Transaction::query(self, id, key, &statement, &parameters)
145            .await
146            .map_err(Into::<Status>::into)?
147            .map_ok(Into::into)
148            .map_err(Into::<Status>::into);
149
150        // create the row stream transmitter and receiver
151        let (transmitter, receiver) = tokio::sync::mpsc::channel(100);
152
153        // emit the rows as a Send stream
154        tokio::spawn(async move {
155            pin_mut!(rows);
156
157            while let Some(row) = rows.next().await {
158                transmitter.send(row).await?;
159            }
160
161            Ok::<_, SendError<_>>(())
162        });
163
164        Ok(Response::new(ReceiverStream::new(receiver)))
165    }
166
167    #[tracing::instrument(skip(self, request), err)]
168    async fn begin(
169        &self,
170        mut request: Request<pbjson_types::Empty>,
171    ) -> Result<Response<BeginResponse>, Status> {
172        // derive a key from extensions to use as a connection pool key
173        let key = P::Key::from_request(&mut request).map_err(Into::<Status>::into)?;
174        let id = Transaction::begin(self, key).await?.to_string();
175
176        Ok(Response::new(BeginResponse { id }))
177    }
178
179    #[tracing::instrument(skip(self, request), err)]
180    async fn commit(
181        &self,
182        mut request: Request<CommitRequest>,
183    ) -> Result<Response<pbjson_types::Empty>, Status> {
184        // derive a key from extensions to use as a connection pool key
185        let key = P::Key::from_request(&mut request).map_err(Into::<Status>::into)?;
186
187        let CommitRequest { id } = request.get_ref();
188
189        let id = Uuid::parse_str(id).map_err(|_| {
190            Status::invalid_argument("Transaction ID in request had unrecognized format")
191        })?;
192
193        Transaction::commit(self, id, key).await?;
194
195        Ok(Response::new(pbjson_types::Empty::default()))
196    }
197
198    #[tracing::instrument(skip(self, request), err)]
199    async fn rollback(
200        &self,
201        mut request: Request<RollbackRequest>,
202    ) -> Result<Response<pbjson_types::Empty>, Status> {
203        // derive a key from extensions to use as a connection pool key
204        let key = P::Key::from_request(&mut request).map_err(Into::<Status>::into)?;
205
206        let RollbackRequest { id } = request.get_ref();
207
208        let id = Uuid::parse_str(id).map_err(|_| {
209            Status::invalid_argument("Transaction ID in request had unrecognized format")
210        })?;
211
212        Transaction::rollback(self, id, key).await?;
213
214        Ok(Response::new(pbjson_types::Empty::default()))
215    }
216}
217
218/// Create a new Transaction service from a connection pool
219pub fn new<P>(pool: Arc<P>) -> TransactionServer<Transaction<P>>
220where
221    P: Pool + 'static,
222    P::Key: FromRequest + Hash + Eq + Clone,
223{
224    TransactionServer::new(Transaction::new(pool))
225}
226
227/// Create a new Postgres service from a connection pool and an interceptor
228pub fn with_interceptor<P, I>(
229    pool: Arc<P>,
230    interceptor: I,
231) -> InterceptedService<TransactionServer<Transaction<P>>, I>
232where
233    P: Pool + 'static,
234    P::Key: FromRequest + Hash + Eq + Clone,
235    I: Interceptor,
236{
237    TransactionServer::with_interceptor(Transaction::new(pool), interceptor)
238}