1use crate::{
2 extensions::FromRequest,
3 pools::{transaction, Connection, Parameter, Pool},
4};
5use futures_util::{pin_mut, StreamExt, TryStreamExt};
6use proto::transaction_server::{Transaction as GrpcService, TransactionServer};
7pub use proto::{BeginResponse, CommitRequest, RollbackRequest, TransactionQueryRequest};
8use std::{hash::Hash, sync::Arc};
9use tokio::sync::mpsc::error::SendError;
10use tokio_stream::wrappers::ReceiverStream;
11use tonic::{codegen::InterceptedService, service::Interceptor, Request, Response, Status};
12use uuid::Uuid;
13
14#[allow(unreachable_pub, missing_docs)]
16mod proto {
17 tonic::include_proto!("transaction.v1");
18}
19
20type Error<P> = transaction::Error<<<P as Pool>::Connection as Connection>::Error>;
22
23#[derive(Clone)]
25pub struct Transaction<P>
26where
27 P: Pool,
28 P::Key: Hash + Eq + Clone,
29{
30 pool: transaction::Pool<P>,
31}
32
33impl<P> Transaction<P>
34where
35 P: Pool + 'static,
36 P::Key: Hash + Eq + Send + Sync + Clone + 'static,
37 P::Connection: 'static,
38 <P::Connection as Connection>::Error: Send + Sync + 'static,
39{
40 pub fn new(pool: Arc<P>) -> Self {
42 Self {
43 pool: transaction::Pool::new(pool),
44 }
45 }
46
47 #[tracing::instrument(skip(self), err)]
49 pub async fn begin(&self, key: P::Key) -> Result<Uuid, Error<P>> {
50 tracing::debug!("Beginning transaction");
51
52 let transaction_id = self.pool.begin(key).await?;
53
54 Ok(transaction_id)
55 }
56
57 #[tracing::instrument(skip(self, parameters), err)]
59 pub async fn query(
60 &self,
61 id: Uuid,
62 key: P::Key,
63 statement: &str,
64 parameters: &[Parameter],
65 ) -> Result<<P::Connection as Connection>::RowStream, Error<P>> {
66 tracing::info!("Querying transaction");
67
68 let transaction_key = transaction::Key::new(key, id);
69
70 let rows = self
71 .pool
72 .get_connection(transaction_key)
73 .await?
74 .query(statement, parameters)
75 .await
76 .map_err(transaction::Error::Connection)?;
77
78 Ok(rows)
79 }
80
81 #[tracing::instrument(skip(self), err)]
83 pub async fn commit(&self, id: Uuid, key: P::Key) -> Result<(), Error<P>> {
84 tracing::debug!("Committing transaction");
85
86 self.pool.commit(id, key).await?;
87
88 Ok(())
89 }
90
91 #[tracing::instrument(skip(self), err)]
93 pub async fn rollback(&self, id: Uuid, key: P::Key) -> Result<(), Error<P>> {
94 tracing::debug!("Rolling back transaction");
95
96 self.pool.rollback(id, key).await?;
97
98 Ok(())
99 }
100}
101
102#[tonic::async_trait]
104impl<P> GrpcService for Transaction<P>
105where
106 P: Pool + 'static,
107 P::Key: FromRequest + Hash + Eq + Clone,
108{
109 type QueryStream = ReceiverStream<Result<pbjson_types::Struct, Status>>;
110
111 #[tracing::instrument(skip(self, request), err)]
112 async fn query(
113 &self,
114 mut request: Request<TransactionQueryRequest>,
115 ) -> Result<Response<Self::QueryStream>, Status> {
116 let key = P::Key::from_request(&mut request).map_err(Into::<Status>::into)?;
118
119 let TransactionQueryRequest {
121 id,
122 statement,
123 values,
124 } = request.into_inner();
125
126 let id = Uuid::parse_str(&id).map_err(|_| {
127 Status::invalid_argument("Transaction ID in request had unrecognized format")
128 })?;
129
130 let value_count = values.len();
132
133 let parameters: Vec<_> = values.into_iter().map(Parameter::from).collect();
134
135 if parameters.len() < value_count {
136 return Err(
137 Status::invalid_argument(
138 "Invalid parameter values found. Only numbers, strings, boolean, and null values permitted"
139 )
140 );
141 }
142
143 let rows = Transaction::query(self, id, key, &statement, ¶meters)
145 .await
146 .map_err(Into::<Status>::into)?
147 .map_ok(Into::into)
148 .map_err(Into::<Status>::into);
149
150 let (transmitter, receiver) = tokio::sync::mpsc::channel(100);
152
153 tokio::spawn(async move {
155 pin_mut!(rows);
156
157 while let Some(row) = rows.next().await {
158 transmitter.send(row).await?;
159 }
160
161 Ok::<_, SendError<_>>(())
162 });
163
164 Ok(Response::new(ReceiverStream::new(receiver)))
165 }
166
167 #[tracing::instrument(skip(self, request), err)]
168 async fn begin(
169 &self,
170 mut request: Request<pbjson_types::Empty>,
171 ) -> Result<Response<BeginResponse>, Status> {
172 let key = P::Key::from_request(&mut request).map_err(Into::<Status>::into)?;
174 let id = Transaction::begin(self, key).await?.to_string();
175
176 Ok(Response::new(BeginResponse { id }))
177 }
178
179 #[tracing::instrument(skip(self, request), err)]
180 async fn commit(
181 &self,
182 mut request: Request<CommitRequest>,
183 ) -> Result<Response<pbjson_types::Empty>, Status> {
184 let key = P::Key::from_request(&mut request).map_err(Into::<Status>::into)?;
186
187 let CommitRequest { id } = request.get_ref();
188
189 let id = Uuid::parse_str(id).map_err(|_| {
190 Status::invalid_argument("Transaction ID in request had unrecognized format")
191 })?;
192
193 Transaction::commit(self, id, key).await?;
194
195 Ok(Response::new(pbjson_types::Empty::default()))
196 }
197
198 #[tracing::instrument(skip(self, request), err)]
199 async fn rollback(
200 &self,
201 mut request: Request<RollbackRequest>,
202 ) -> Result<Response<pbjson_types::Empty>, Status> {
203 let key = P::Key::from_request(&mut request).map_err(Into::<Status>::into)?;
205
206 let RollbackRequest { id } = request.get_ref();
207
208 let id = Uuid::parse_str(id).map_err(|_| {
209 Status::invalid_argument("Transaction ID in request had unrecognized format")
210 })?;
211
212 Transaction::rollback(self, id, key).await?;
213
214 Ok(Response::new(pbjson_types::Empty::default()))
215 }
216}
217
218pub fn new<P>(pool: Arc<P>) -> TransactionServer<Transaction<P>>
220where
221 P: Pool + 'static,
222 P::Key: FromRequest + Hash + Eq + Clone,
223{
224 TransactionServer::new(Transaction::new(pool))
225}
226
227pub fn with_interceptor<P, I>(
229 pool: Arc<P>,
230 interceptor: I,
231) -> InterceptedService<TransactionServer<Transaction<P>>, I>
232where
233 P: Pool + 'static,
234 P::Key: FromRequest + Hash + Eq + Clone,
235 I: Interceptor,
236{
237 TransactionServer::with_interceptor(Transaction::new(pool), interceptor)
238}