Skip to main content

tokio_postgres/
transaction.rs

1#[cfg(feature = "runtime")]
2use crate::Socket;
3use crate::copy_out::CopyOutStream;
4use crate::query::RowStream;
5#[cfg(feature = "runtime")]
6use crate::tls::MakeTlsConnect;
7use crate::tls::TlsConnect;
8use crate::types::{BorrowToSql, ToSql, Type};
9use crate::{
10    CancelToken, Client, CopyInSink, Error, Portal, Row, SimpleQueryMessage, Statement,
11    ToStatement, bind, query, slice_iter,
12};
13use bytes::Buf;
14use futures_util::TryStreamExt;
15use tokio::io::{AsyncRead, AsyncWrite};
16
17/// A representation of a PostgreSQL database transaction.
18///
19/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the
20/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
21pub struct Transaction<'a> {
22    client: &'a mut Client,
23    savepoint: Option<Savepoint>,
24    done: bool,
25}
26
/// A representation of a PostgreSQL database savepoint.
struct Savepoint {
    /// Savepoint name. It is interpolated verbatim (unquoted, unescaped) into the `SAVEPOINT`,
    /// `RELEASE` and `ROLLBACK TO` statements that `Transaction` builds.
    name: String,
    /// Nesting level: 1 for a savepoint directly inside a top-level transaction, incremented by
    /// one per additional level. Also used to generate default names of the form `sp_<depth>`.
    depth: u32,
}
32
33impl Drop for Transaction<'_> {
34    fn drop(&mut self) {
35        if self.done {
36            return;
37        }
38
39        let name = self.savepoint.as_ref().map(|sp| sp.name.as_str());
40        self.client.__private_api_rollback(name);
41    }
42}
43
44impl<'a> Transaction<'a> {
45    pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
46        Transaction {
47            client,
48            savepoint: None,
49            done: false,
50        }
51    }
52
53    /// Consumes the transaction, committing all changes made within it.
54    pub async fn commit(mut self) -> Result<(), Error> {
55        self.done = true;
56        let query = if let Some(sp) = self.savepoint.as_ref() {
57            format!("RELEASE {}", sp.name)
58        } else {
59            "COMMIT".to_string()
60        };
61        self.client.batch_execute(&query).await
62    }
63
64    /// Rolls the transaction back, discarding all changes made within it.
65    ///
66    /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
67    pub async fn rollback(mut self) -> Result<(), Error> {
68        self.done = true;
69        let query = if let Some(sp) = self.savepoint.as_ref() {
70            format!("ROLLBACK TO {}", sp.name)
71        } else {
72            "ROLLBACK".to_string()
73        };
74        self.client.batch_execute(&query).await
75    }
76
77    /// Like `Client::prepare`.
78    pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
79        self.client.prepare(query).await
80    }
81
82    /// Like `Client::prepare_typed`.
83    pub async fn prepare_typed(
84        &self,
85        query: &str,
86        parameter_types: &[Type],
87    ) -> Result<Statement, Error> {
88        self.client.prepare_typed(query, parameter_types).await
89    }
90
91    /// Like `Client::query`.
92    pub async fn query<T>(
93        &self,
94        statement: &T,
95        params: &[&(dyn ToSql + Sync)],
96    ) -> Result<Vec<Row>, Error>
97    where
98        T: ?Sized + ToStatement,
99    {
100        self.client.query(statement, params).await
101    }
102
103    /// Like `Client::query_one`.
104    pub async fn query_one<T>(
105        &self,
106        statement: &T,
107        params: &[&(dyn ToSql + Sync)],
108    ) -> Result<Row, Error>
109    where
110        T: ?Sized + ToStatement,
111    {
112        self.client.query_one(statement, params).await
113    }
114
115    /// Like `Client::query_opt`.
116    pub async fn query_opt<T>(
117        &self,
118        statement: &T,
119        params: &[&(dyn ToSql + Sync)],
120    ) -> Result<Option<Row>, Error>
121    where
122        T: ?Sized + ToStatement,
123    {
124        self.client.query_opt(statement, params).await
125    }
126
127    /// Like `Client::query_raw`.
128    pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
129    where
130        T: ?Sized + ToStatement,
131        P: BorrowToSql,
132        I: IntoIterator<Item = P>,
133        I::IntoIter: ExactSizeIterator,
134    {
135        self.client.query_raw(statement, params).await
136    }
137
138    /// Like `Client::query_typed`.
139    pub async fn query_typed(
140        &self,
141        statement: &str,
142        params: &[(&(dyn ToSql + Sync), Type)],
143    ) -> Result<Vec<Row>, Error> {
144        self.client.query_typed(statement, params).await
145    }
146
147    /// Like `Client::query_typed_one`.
148    pub async fn query_typed_one(
149        &self,
150        statement: &str,
151        params: &[(&(dyn ToSql + Sync), Type)],
152    ) -> Result<Row, Error> {
153        self.client.query_typed_one(statement, params).await
154    }
155
156    /// Like `Client::query_typed_opt`.
157    pub async fn query_typed_opt(
158        &self,
159        statement: &str,
160        params: &[(&(dyn ToSql + Sync), Type)],
161    ) -> Result<Option<Row>, Error> {
162        self.client.query_typed_opt(statement, params).await
163    }
164
165    /// Like `Client::query_typed_raw`.
166    pub async fn query_typed_raw<P, I>(&self, query: &str, params: I) -> Result<RowStream, Error>
167    where
168        P: BorrowToSql,
169        I: IntoIterator<Item = (P, Type)>,
170    {
171        self.client.query_typed_raw(query, params).await
172    }
173
174    /// Like `Client::execute`.
175    pub async fn execute<T>(
176        &self,
177        statement: &T,
178        params: &[&(dyn ToSql + Sync)],
179    ) -> Result<u64, Error>
180    where
181        T: ?Sized + ToStatement,
182    {
183        self.client.execute(statement, params).await
184    }
185
186    /// Like `Client::execute_typed`.
187    pub async fn execute_typed(
188        &self,
189        statement: &str,
190        params: &[(&(dyn ToSql + Sync), Type)],
191    ) -> Result<u64, Error> {
192        self.client.execute_typed(statement, params).await
193    }
194
195    /// Like `Client::execute_iter`.
196    pub async fn execute_raw<P, I, T>(&self, statement: &T, params: I) -> Result<u64, Error>
197    where
198        T: ?Sized + ToStatement,
199        P: BorrowToSql,
200        I: IntoIterator<Item = P>,
201        I::IntoIter: ExactSizeIterator,
202    {
203        self.client.execute_raw(statement, params).await
204    }
205
206    /// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried.
207    ///
208    /// Portals only last for the duration of the transaction in which they are created, and can only be used on the
209    /// connection that created them.
210    ///
211    /// # Panics
212    ///
213    /// Panics if the number of parameters provided does not match the number expected.
214    pub async fn bind<T>(
215        &self,
216        statement: &T,
217        params: &[&(dyn ToSql + Sync)],
218    ) -> Result<Portal, Error>
219    where
220        T: ?Sized + ToStatement,
221    {
222        self.bind_raw(statement, slice_iter(params)).await
223    }
224
225    /// A maximally flexible version of [`bind`].
226    ///
227    /// [`bind`]: #method.bind
228    pub async fn bind_raw<P, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error>
229    where
230        T: ?Sized + ToStatement,
231        P: BorrowToSql,
232        I: IntoIterator<Item = P>,
233        I::IntoIter: ExactSizeIterator,
234    {
235        let statement = statement
236            .__convert()
237            .into_statement(self.client.inner())
238            .await?;
239        bind::bind(self.client.inner(), statement, params).await
240    }
241
242    /// Continues execution of a portal, returning a stream of the resulting rows.
243    ///
244    /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
245    /// `query_portal`. If the requested number is negative or 0, all rows will be returned.
246    pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
247        self.query_portal_raw(portal, max_rows)
248            .await?
249            .try_collect()
250            .await
251    }
252
253    /// The maximally flexible version of [`query_portal`].
254    ///
255    /// [`query_portal`]: #method.query_portal
256    pub async fn query_portal_raw(
257        &self,
258        portal: &Portal,
259        max_rows: i32,
260    ) -> Result<RowStream, Error> {
261        query::query_portal(self.client.inner(), portal, max_rows).await
262    }
263
264    /// Like `Client::copy_in`.
265    pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
266    where
267        T: ?Sized + ToStatement,
268        U: Buf + 'static + Send,
269    {
270        self.client.copy_in(statement).await
271    }
272
273    /// Like `Client::copy_out`.
274    pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, Error>
275    where
276        T: ?Sized + ToStatement,
277    {
278        self.client.copy_out(statement).await
279    }
280
281    /// Like `Client::simple_query`.
282    pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
283        self.client.simple_query(query).await
284    }
285
286    /// Like `Client::batch_execute`.
287    pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
288        self.client.batch_execute(query).await
289    }
290
291    /// Like `Client::cancel_token`.
292    pub fn cancel_token(&self) -> CancelToken {
293        self.client.cancel_token()
294    }
295
296    /// Like `Client::cancel_query`.
297    #[cfg(feature = "runtime")]
298    #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
299    pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
300    where
301        T: MakeTlsConnect<Socket>,
302    {
303        #[allow(deprecated)]
304        self.client.cancel_query(tls).await
305    }
306
307    /// Like `Client::cancel_query_raw`.
308    #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
309    pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
310    where
311        S: AsyncRead + AsyncWrite + Unpin,
312        T: TlsConnect<S>,
313    {
314        #[allow(deprecated)]
315        self.client.cancel_query_raw(stream, tls).await
316    }
317
318    /// Like `Client::transaction`, but creates a nested transaction via a savepoint.
319    pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
320        self._savepoint(None).await
321    }
322
323    /// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
324    pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
325    where
326        I: Into<String>,
327    {
328        self._savepoint(Some(name.into())).await
329    }
330
331    async fn _savepoint(&mut self, name: Option<String>) -> Result<Transaction<'_>, Error> {
332        let depth = self.savepoint.as_ref().map_or(0, |sp| sp.depth) + 1;
333        let name = name.unwrap_or_else(|| format!("sp_{depth}"));
334        let query = format!("SAVEPOINT {name}");
335        self.batch_execute(&query).await?;
336
337        Ok(Transaction {
338            client: self.client,
339            savepoint: Some(Savepoint { name, depth }),
340            done: false,
341        })
342    }
343
344    /// Returns a reference to the underlying `Client`.
345    pub fn client(&self) -> &Client {
346        self.client
347    }
348}