madsim_tokio_postgres/client.rs
1use crate::codec::{BackendMessages, FrontendMessage};
2#[cfg(feature = "runtime")]
3use crate::config::Host;
4use crate::config::SslMode;
5use crate::connection::{Request, RequestMessages};
6use crate::copy_out::CopyOutStream;
7use crate::query::RowStream;
8use crate::simple_query::SimpleQueryStream;
9#[cfg(feature = "runtime")]
10use crate::tls::MakeTlsConnect;
11use crate::tls::TlsConnect;
12use crate::types::{Oid, ToSql, Type};
13#[cfg(feature = "runtime")]
14use crate::Socket;
15use crate::{
16 copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error,
17 Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
18};
19use bytes::{Buf, BytesMut};
20use fallible_iterator::FallibleIterator;
21use futures::channel::mpsc;
22use futures::{future, pin_mut, ready, StreamExt, TryStreamExt};
23use parking_lot::Mutex;
24use postgres_protocol::message::{backend::Message, frontend};
25use postgres_types::BorrowToSql;
26use std::collections::HashMap;
27use std::fmt;
28use std::sync::Arc;
29use std::task::{Context, Poll};
30#[cfg(feature = "runtime")]
31use std::time::Duration;
32use tokio::io::{AsyncRead, AsyncWrite};
33
/// Stream of backend messages sent in reply to a single request.
///
/// Created by `InnerClient::send`. Messages arrive in batches over `receiver`
/// and are handed out one at a time from `cur`.
pub struct Responses {
    /// Receiving half of the per-request channel created in `InnerClient::send`.
    receiver: mpsc::Receiver<BackendMessages>,
    /// Batch currently being drained; replaced by the next received batch once
    /// it is exhausted.
    cur: BackendMessages,
}
38
39impl Responses {
40 pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
41 loop {
42 match self.cur.next().map_err(Error::parse)? {
43 Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))),
44 Some(message) => return Poll::Ready(Ok(message)),
45 None => {}
46 }
47
48 match ready!(self.receiver.poll_next_unpin(cx)) {
49 Some(messages) => self.cur = messages,
50 None => return Poll::Ready(Err(Error::closed())),
51 }
52 }
53 }
54
55 pub async fn next(&mut self) -> Result<Message, Error> {
56 future::poll_fn(|cx| self.poll_next(cx)).await
57 }
58}
59
/// A cache of type info and prepared statements for fetching type info
/// (corresponding to the queries in the [prepare](prepare) module).
#[derive(Default)]
struct CachedTypeInfo {
    /// A statement for basic information for a type from its
    /// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its
    /// fallback).
    typeinfo: Option<Statement>,
    /// A statement for getting information for a composite type from its OID.
    /// Corresponds to
    /// [TYPEINFO_COMPOSITE_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY).
    typeinfo_composite: Option<Statement>,
    /// A statement for getting information for an enum type from its OID.
    // NOTE(review): the previous text here was a copy of the composite-type
    // description. The matching enum query constant in `prepare` (and whether
    // it has a fallback) is not visible from this file; confirm its name and
    // link it here.
    typeinfo_enum: Option<Statement>,

    /// Cache of types already looked up.
    types: HashMap<Oid, Type>,
}
79
/// Shared state behind a `Client`.
///
/// The client holds it in an `Arc` and passes it to the query, prepare and
/// copy helpers.
pub struct InnerClient {
    /// Channel on which requests are queued by `send`.
    sender: mpsc::UnboundedSender<Request>,
    /// Cached type-info statements and previously resolved types.
    cached_typeinfo: Mutex<CachedTypeInfo>,

    /// A buffer to use when writing out postgres commands.
    buffer: Mutex<BytesMut>,
}
87
88impl InnerClient {
89 pub fn send(&self, messages: RequestMessages) -> Result<Responses, Error> {
90 let (sender, receiver) = mpsc::channel(1);
91 let request = Request { messages, sender };
92 self.sender
93 .unbounded_send(request)
94 .map_err(|_| Error::closed())?;
95
96 Ok(Responses {
97 receiver,
98 cur: BackendMessages::empty(),
99 })
100 }
101
102 pub fn typeinfo(&self) -> Option<Statement> {
103 self.cached_typeinfo.lock().typeinfo.clone()
104 }
105
106 pub fn set_typeinfo(&self, statement: &Statement) {
107 self.cached_typeinfo.lock().typeinfo = Some(statement.clone());
108 }
109
110 pub fn typeinfo_composite(&self) -> Option<Statement> {
111 self.cached_typeinfo.lock().typeinfo_composite.clone()
112 }
113
114 pub fn set_typeinfo_composite(&self, statement: &Statement) {
115 self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone());
116 }
117
118 pub fn typeinfo_enum(&self) -> Option<Statement> {
119 self.cached_typeinfo.lock().typeinfo_enum.clone()
120 }
121
122 pub fn set_typeinfo_enum(&self, statement: &Statement) {
123 self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone());
124 }
125
126 pub fn type_(&self, oid: Oid) -> Option<Type> {
127 self.cached_typeinfo.lock().types.get(&oid).cloned()
128 }
129
130 pub fn set_type(&self, oid: Oid, type_: &Type) {
131 self.cached_typeinfo.lock().types.insert(oid, type_.clone());
132 }
133
134 pub fn clear_type_cache(&self) {
135 self.cached_typeinfo.lock().types.clear();
136 }
137
138 /// Call the given function with a buffer to be used when writing out
139 /// postgres commands.
140 pub fn with_buf<F, R>(&self, f: F) -> R
141 where
142 F: FnOnce(&mut BytesMut) -> R,
143 {
144 let mut buffer = self.buffer.lock();
145 let r = f(&mut buffer);
146 buffer.clear();
147 r
148 }
149}
150
/// Socket-level connection settings recorded on a `Client`.
///
/// `Client::cancel_token` clones this into the `CancelToken` it builds.
// NOTE(review): the per-field descriptions below are inferred from the field
// names and types only; confirm against the code that fills this struct in.
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub(crate) struct SocketConfig {
    /// Host of the server.
    pub host: Host,
    /// Port of the server.
    pub port: u16,
    /// Optional connection timeout.
    pub connect_timeout: Option<Duration>,
    /// Keepalive on/off flag.
    pub keepalives: bool,
    /// Keepalive idle duration.
    pub keepalives_idle: Duration,
}
160
/// An asynchronous PostgreSQL client.
///
/// The client is one half of what is returned when a connection is established. Users interact with the database
/// through this client object.
pub struct Client {
    /// Shared state (request channel, type cache, write buffer).
    inner: Arc<InnerClient>,
    /// Socket settings, if set via `set_socket_config`; copied into cancel tokens.
    #[cfg(feature = "runtime")]
    socket_config: Option<SocketConfig>,
    /// SSL mode supplied at construction; copied into cancel tokens.
    ssl_mode: SslMode,
    /// Process id supplied at construction; copied into cancel tokens.
    // NOTE(review): presumably the backend process id reported by the server
    // during startup — confirm at the call site of `Client::new`.
    process_id: i32,
    /// Secret key supplied at construction; copied into cancel tokens.
    secret_key: i32,
}
173
174impl Client {
175 pub(crate) fn new(
176 sender: mpsc::UnboundedSender<Request>,
177 ssl_mode: SslMode,
178 process_id: i32,
179 secret_key: i32,
180 ) -> Client {
181 Client {
182 inner: Arc::new(InnerClient {
183 sender,
184 cached_typeinfo: Default::default(),
185 buffer: Default::default(),
186 }),
187 #[cfg(feature = "runtime")]
188 socket_config: None,
189 ssl_mode,
190 process_id,
191 secret_key,
192 }
193 }
194
195 pub(crate) fn inner(&self) -> &Arc<InnerClient> {
196 &self.inner
197 }
198
199 #[cfg(feature = "runtime")]
200 pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) {
201 self.socket_config = Some(socket_config);
202 }
203
204 /// Creates a new prepared statement.
205 ///
206 /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
207 /// which are set when executed. Prepared statements can only be used with the connection that created them.
208 pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
209 self.prepare_typed(query, &[]).await
210 }
211
212 /// Like `prepare`, but allows the types of query parameters to be explicitly specified.
213 ///
214 /// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be
215 /// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`.
216 pub async fn prepare_typed(
217 &self,
218 query: &str,
219 parameter_types: &[Type],
220 ) -> Result<Statement, Error> {
221 prepare::prepare(&self.inner, query, parameter_types).await
222 }
223
224 /// Executes a statement, returning a vector of the resulting rows.
225 ///
226 /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
227 /// provided, 1-indexed.
228 ///
229 /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
230 /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
231 /// with the `prepare` method.
232 ///
233 /// # Panics
234 ///
235 /// Panics if the number of parameters provided does not match the number expected.
236 pub async fn query<T>(
237 &self,
238 statement: &T,
239 params: &[&(dyn ToSql + Sync)],
240 ) -> Result<Vec<Row>, Error>
241 where
242 T: ?Sized + ToStatement,
243 {
244 self.query_raw(statement, slice_iter(params))
245 .await?
246 .try_collect()
247 .await
248 }
249
250 /// Executes a statement which returns a single row, returning it.
251 ///
252 /// Returns an error if the query does not return exactly one row.
253 ///
254 /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
255 /// provided, 1-indexed.
256 ///
257 /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
258 /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
259 /// with the `prepare` method.
260 ///
261 /// # Panics
262 ///
263 /// Panics if the number of parameters provided does not match the number expected.
264 pub async fn query_one<T>(
265 &self,
266 statement: &T,
267 params: &[&(dyn ToSql + Sync)],
268 ) -> Result<Row, Error>
269 where
270 T: ?Sized + ToStatement,
271 {
272 let stream = self.query_raw(statement, slice_iter(params)).await?;
273 pin_mut!(stream);
274
275 let row = match stream.try_next().await? {
276 Some(row) => row,
277 None => return Err(Error::row_count()),
278 };
279
280 if stream.try_next().await?.is_some() {
281 return Err(Error::row_count());
282 }
283
284 Ok(row)
285 }
286
287 /// Executes a statements which returns zero or one rows, returning it.
288 ///
289 /// Returns an error if the query returns more than one row.
290 ///
291 /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
292 /// provided, 1-indexed.
293 ///
294 /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
295 /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
296 /// with the `prepare` method.
297 ///
298 /// # Panics
299 ///
300 /// Panics if the number of parameters provided does not match the number expected.
301 pub async fn query_opt<T>(
302 &self,
303 statement: &T,
304 params: &[&(dyn ToSql + Sync)],
305 ) -> Result<Option<Row>, Error>
306 where
307 T: ?Sized + ToStatement,
308 {
309 let stream = self.query_raw(statement, slice_iter(params)).await?;
310 pin_mut!(stream);
311
312 let row = match stream.try_next().await? {
313 Some(row) => row,
314 None => return Ok(None),
315 };
316
317 if stream.try_next().await?.is_some() {
318 return Err(Error::row_count());
319 }
320
321 Ok(Some(row))
322 }
323
324 /// The maximally flexible version of [`query`].
325 ///
326 /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
327 /// provided, 1-indexed.
328 ///
329 /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
330 /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
331 /// with the `prepare` method.
332 ///
333 /// # Panics
334 ///
335 /// Panics if the number of parameters provided does not match the number expected.
336 ///
337 /// [`query`]: #method.query
338 ///
339 /// # Examples
340 ///
341 /// ```no_run
342 /// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
343 /// use tokio_postgres::types::ToSql;
344 /// use futures::{pin_mut, TryStreamExt};
345 ///
346 /// let params: Vec<String> = vec![
347 /// "first param".into(),
348 /// "second param".into(),
349 /// ];
350 /// let mut it = client.query_raw(
351 /// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
352 /// params,
353 /// ).await?;
354 ///
355 /// pin_mut!(it);
356 /// while let Some(row) = it.try_next().await? {
357 /// let foo: i32 = row.get("foo");
358 /// println!("foo: {}", foo);
359 /// }
360 /// # Ok(())
361 /// # }
362 /// ```
363 pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
364 where
365 T: ?Sized + ToStatement,
366 P: BorrowToSql,
367 I: IntoIterator<Item = P>,
368 I::IntoIter: ExactSizeIterator,
369 {
370 let statement = statement.__convert().into_statement(self).await?;
371 query::query(&self.inner, statement, params).await
372 }
373
374 /// Executes a statement, returning the number of rows modified.
375 ///
376 /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
377 /// provided, 1-indexed.
378 ///
379 /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
380 /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
381 /// with the `prepare` method.
382 ///
383 /// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
384 ///
385 /// # Panics
386 ///
387 /// Panics if the number of parameters provided does not match the number expected.
388 pub async fn execute<T>(
389 &self,
390 statement: &T,
391 params: &[&(dyn ToSql + Sync)],
392 ) -> Result<u64, Error>
393 where
394 T: ?Sized + ToStatement,
395 {
396 self.execute_raw(statement, slice_iter(params)).await
397 }
398
399 /// The maximally flexible version of [`execute`].
400 ///
401 /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
402 /// provided, 1-indexed.
403 ///
404 /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
405 /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
406 /// with the `prepare` method.
407 ///
408 /// # Panics
409 ///
410 /// Panics if the number of parameters provided does not match the number expected.
411 ///
412 /// [`execute`]: #method.execute
413 pub async fn execute_raw<T, P, I>(&self, statement: &T, params: I) -> Result<u64, Error>
414 where
415 T: ?Sized + ToStatement,
416 P: BorrowToSql,
417 I: IntoIterator<Item = P>,
418 I::IntoIter: ExactSizeIterator,
419 {
420 let statement = statement.__convert().into_statement(self).await?;
421 query::execute(self.inner(), statement, params).await
422 }
423
424 /// Executes a `COPY FROM STDIN` statement, returning a sink used to write the copy data.
425 ///
426 /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any. The copy *must*
427 /// be explicitly completed via the `Sink::close` or `finish` methods. If it is not, the copy will be aborted.
428 ///
429 /// # Panics
430 ///
431 /// Panics if the statement contains parameters.
432 pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
433 where
434 T: ?Sized + ToStatement,
435 U: Buf + 'static + Send,
436 {
437 let statement = statement.__convert().into_statement(self).await?;
438 copy_in::copy_in(self.inner(), statement).await
439 }
440
441 /// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data.
442 ///
443 /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any.
444 ///
445 /// # Panics
446 ///
447 /// Panics if the statement contains parameters.
448 pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, Error>
449 where
450 T: ?Sized + ToStatement,
451 {
452 let statement = statement.__convert().into_statement(self).await?;
453 copy_out::copy_out(self.inner(), statement).await
454 }
455
456 /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
457 ///
458 /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
459 /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
460 /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
461 /// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
462 /// or a row of data. This preserves the framing between the separate statements in the request.
463 ///
464 /// # Warning
465 ///
466 /// Prepared statements should be use for any query which contains user-specified data, as they provided the
467 /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
468 /// them to this method!
469 pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
470 self.simple_query_raw(query).await?.try_collect().await
471 }
472
473 pub(crate) async fn simple_query_raw(&self, query: &str) -> Result<SimpleQueryStream, Error> {
474 simple_query::simple_query(self.inner(), query).await
475 }
476
477 /// Executes a sequence of SQL statements using the simple query protocol.
478 ///
479 /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
480 /// point. This is intended for use when, for example, initializing a database schema.
481 ///
482 /// # Warning
483 ///
484 /// Prepared statements should be use for any query which contains user-specified data, as they provided the
485 /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
486 /// them to this method!
487 pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
488 simple_query::batch_execute(self.inner(), query).await
489 }
490
491 /// Begins a new database transaction.
492 ///
493 /// The transaction will roll back by default - use the `commit` method to commit it.
494 pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
495 struct RollbackIfNotDone<'me> {
496 client: &'me Client,
497 done: bool,
498 }
499
500 impl<'a> Drop for RollbackIfNotDone<'a> {
501 fn drop(&mut self) {
502 if self.done {
503 return;
504 }
505
506 let buf = self.client.inner().with_buf(|buf| {
507 frontend::query("ROLLBACK", buf).unwrap();
508 buf.split().freeze()
509 });
510 let _ = self
511 .client
512 .inner()
513 .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
514 }
515 }
516
517 // This is done, as `Future` created by this method can be dropped after
518 // `RequestMessages` is synchronously send to the `Connection` by
519 // `batch_execute()`, but before `Responses` is asynchronously polled to
520 // completion. In that case `Transaction` won't be created and thus
521 // won't be rolled back.
522 {
523 let mut cleaner = RollbackIfNotDone {
524 client: self,
525 done: false,
526 };
527 self.batch_execute("BEGIN").await?;
528 cleaner.done = true;
529 }
530
531 Ok(Transaction::new(self))
532 }
533
534 /// Returns a builder for a transaction with custom settings.
535 ///
536 /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
537 /// attributes.
538 pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
539 TransactionBuilder::new(self)
540 }
541
542 /// Constructs a cancellation token that can later be used to request cancellation of a query running on the
543 /// connection associated with this client.
544 pub fn cancel_token(&self) -> CancelToken {
545 CancelToken {
546 #[cfg(feature = "runtime")]
547 socket_config: self.socket_config.clone(),
548 ssl_mode: self.ssl_mode,
549 process_id: self.process_id,
550 secret_key: self.secret_key,
551 }
552 }
553
554 /// Attempts to cancel an in-progress query.
555 ///
556 /// The server provides no information about whether a cancellation attempt was successful or not. An error will
557 /// only be returned if the client was unable to connect to the database.
558 ///
559 /// Requires the `runtime` Cargo feature (enabled by default).
560 #[cfg(feature = "runtime")]
561 #[deprecated(since = "0.6.0", note = "use Client::cancel_token() instead")]
562 pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
563 where
564 T: MakeTlsConnect<Socket>,
565 {
566 self.cancel_token().cancel_query(tls).await
567 }
568
569 /// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
570 /// connection itself.
571 #[deprecated(since = "0.6.0", note = "use Client::cancel_token() instead")]
572 pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
573 where
574 S: AsyncRead + AsyncWrite + Unpin,
575 T: TlsConnect<S>,
576 {
577 self.cancel_token().cancel_query_raw(stream, tls).await
578 }
579
580 /// Clears the client's type information cache.
581 ///
582 /// When user-defined types are used in a query, the client loads their definitions from the database and caches
583 /// them for the lifetime of the client. If those definitions are changed in the database, this method can be used
584 /// to flush the local cache and allow the new, updated definitions to be loaded.
585 pub fn clear_type_cache(&self) {
586 self.inner().clear_type_cache();
587 }
588
589 /// Determines if the connection to the server has already closed.
590 ///
591 /// In that case, all future queries will fail.
592 pub fn is_closed(&self) -> bool {
593 self.inner.sender.is_closed()
594 }
595
596 #[doc(hidden)]
597 pub fn __private_api_close(&mut self) {
598 self.inner.sender.close_channel()
599 }
600}
601
602impl fmt::Debug for Client {
603 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
604 f.debug_struct("Client").finish()
605 }
606}