1use crate::copy_out::CopyOutStream;
2use crate::query::RowStream;
3#[cfg(feature = "runtime")]
4use crate::tls::MakeTlsConnect;
5use crate::tls::TlsConnect;
6use crate::types::{BorrowToSql, ToSql, Type};
7#[cfg(feature = "runtime")]
8use crate::Socket;
9use crate::{
10 bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, Row,
11 SimpleQueryMessage, Statement, ToStatement,
12};
13use bytes::Buf;
14use futures_util::TryStreamExt;
15use tokio::io::{AsyncRead, AsyncWrite};
16
17pub struct Transaction<'a> {
22 client: &'a mut Client,
23 savepoint: Option<Savepoint>,
24 done: bool,
25}
26
/// Bookkeeping for a nested transaction created with SQL `SAVEPOINT`.
struct Savepoint {
    /// Savepoint identifier, spliced verbatim into the savepoint SQL statements.
    name: String,
    /// Nesting level: 1 for a savepoint opened directly inside the outermost
    /// transaction, incremented by one for each further level.
    depth: u32,
}
32
33impl Drop for Transaction<'_> {
34 fn drop(&mut self) {
35 if self.done {
36 return;
37 }
38
39 let name = self.savepoint.as_ref().map(|sp| sp.name.as_str());
40 self.client.__private_api_rollback(name);
41 }
42}
43
44impl<'a> Transaction<'a> {
45 pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
46 Transaction {
47 client,
48 savepoint: None,
49 done: false,
50 }
51 }
52
53 pub async fn commit(mut self) -> Result<(), Error> {
55 self.done = true;
56 let query = if let Some(sp) = self.savepoint.as_ref() {
57 format!("RELEASE {}", sp.name)
58 } else {
59 "COMMIT".to_string()
60 };
61 self.client.batch_execute(&query).await
62 }
63
64 pub async fn rollback(mut self) -> Result<(), Error> {
68 self.done = true;
69 let query = if let Some(sp) = self.savepoint.as_ref() {
70 format!("ROLLBACK TO {}", sp.name)
71 } else {
72 "ROLLBACK".to_string()
73 };
74 self.client.batch_execute(&query).await
75 }
76
77 pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
79 self.client.prepare(query).await
80 }
81
82 pub async fn prepare_typed(
84 &self,
85 query: &str,
86 parameter_types: &[Type],
87 ) -> Result<Statement, Error> {
88 self.client.prepare_typed(query, parameter_types).await
89 }
90
91 pub async fn query<T>(
93 &self,
94 statement: &T,
95 params: &[&(dyn ToSql + Sync)],
96 ) -> Result<Vec<Row>, Error>
97 where
98 T: ?Sized + ToStatement,
99 {
100 self.client.query(statement, params).await
101 }
102
103 pub async fn query_one<T>(
105 &self,
106 statement: &T,
107 params: &[&(dyn ToSql + Sync)],
108 ) -> Result<Row, Error>
109 where
110 T: ?Sized + ToStatement,
111 {
112 self.client.query_one(statement, params).await
113 }
114
115 pub async fn query_opt<T>(
117 &self,
118 statement: &T,
119 params: &[&(dyn ToSql + Sync)],
120 ) -> Result<Option<Row>, Error>
121 where
122 T: ?Sized + ToStatement,
123 {
124 self.client.query_opt(statement, params).await
125 }
126
127 pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
129 where
130 T: ?Sized + ToStatement,
131 P: BorrowToSql,
132 I: IntoIterator<Item = P>,
133 I::IntoIter: ExactSizeIterator,
134 {
135 self.client.query_raw(statement, params).await
136 }
137
138 pub async fn query_typed(
140 &self,
141 statement: &str,
142 params: &[(&(dyn ToSql + Sync), Type)],
143 ) -> Result<Vec<Row>, Error> {
144 self.client.query_typed(statement, params).await
145 }
146
147 pub async fn query_typed_raw<P, I>(&self, query: &str, params: I) -> Result<RowStream, Error>
149 where
150 P: BorrowToSql,
151 I: IntoIterator<Item = (P, Type)>,
152 {
153 self.client.query_typed_raw(query, params).await
154 }
155
156 pub async fn execute<T>(
158 &self,
159 statement: &T,
160 params: &[&(dyn ToSql + Sync)],
161 ) -> Result<u64, Error>
162 where
163 T: ?Sized + ToStatement,
164 {
165 self.client.execute(statement, params).await
166 }
167
168 pub async fn execute_raw<P, I, T>(&self, statement: &T, params: I) -> Result<u64, Error>
170 where
171 T: ?Sized + ToStatement,
172 P: BorrowToSql,
173 I: IntoIterator<Item = P>,
174 I::IntoIter: ExactSizeIterator,
175 {
176 self.client.execute_raw(statement, params).await
177 }
178
179 pub async fn bind<T>(
188 &self,
189 statement: &T,
190 params: &[&(dyn ToSql + Sync)],
191 ) -> Result<Portal, Error>
192 where
193 T: ?Sized + ToStatement,
194 {
195 self.bind_raw(statement, slice_iter(params)).await
196 }
197
198 pub async fn bind_raw<P, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error>
202 where
203 T: ?Sized + ToStatement,
204 P: BorrowToSql,
205 I: IntoIterator<Item = P>,
206 I::IntoIter: ExactSizeIterator,
207 {
208 let statement = statement
209 .__convert()
210 .into_statement(self.client.inner())
211 .await?;
212 bind::bind(self.client.inner(), statement, params).await
213 }
214
215 pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
220 self.query_portal_raw(portal, max_rows)
221 .await?
222 .try_collect()
223 .await
224 }
225
226 pub async fn query_portal_raw(
230 &self,
231 portal: &Portal,
232 max_rows: i32,
233 ) -> Result<RowStream, Error> {
234 query::query_portal(self.client.inner(), portal, max_rows).await
235 }
236
237 pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
239 where
240 T: ?Sized + ToStatement,
241 U: Buf + 'static + Send,
242 {
243 self.client.copy_in(statement).await
244 }
245
246 pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, Error>
248 where
249 T: ?Sized + ToStatement,
250 {
251 self.client.copy_out(statement).await
252 }
253
254 pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
256 self.client.simple_query(query).await
257 }
258
259 pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
261 self.client.batch_execute(query).await
262 }
263
264 pub fn cancel_token(&self) -> CancelToken {
266 self.client.cancel_token()
267 }
268
269 #[cfg(feature = "runtime")]
271 #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
272 pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
273 where
274 T: MakeTlsConnect<Socket>,
275 {
276 #[allow(deprecated)]
277 self.client.cancel_query(tls).await
278 }
279
280 #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
282 pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
283 where
284 S: AsyncRead + AsyncWrite + Unpin,
285 T: TlsConnect<S>,
286 {
287 #[allow(deprecated)]
288 self.client.cancel_query_raw(stream, tls).await
289 }
290
291 pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
293 self._savepoint(None).await
294 }
295
296 pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
298 where
299 I: Into<String>,
300 {
301 self._savepoint(Some(name.into())).await
302 }
303
304 async fn _savepoint(&mut self, name: Option<String>) -> Result<Transaction<'_>, Error> {
305 let depth = self.savepoint.as_ref().map_or(0, |sp| sp.depth) + 1;
306 let name = name.unwrap_or_else(|| format!("sp_{depth}"));
307 let query = format!("SAVEPOINT {name}");
308 self.batch_execute(&query).await?;
309
310 Ok(Transaction {
311 client: self.client,
312 savepoint: Some(Savepoint { name, depth }),
313 done: false,
314 })
315 }
316
317 pub fn client(&self) -> &Client {
319 self.client
320 }
321}