madsim_tokio_postgres/
client.rs

1use crate::codec::{BackendMessages, FrontendMessage};
2#[cfg(feature = "runtime")]
3use crate::config::Host;
4use crate::config::SslMode;
5use crate::connection::{Request, RequestMessages};
6use crate::copy_out::CopyOutStream;
7use crate::query::RowStream;
8use crate::simple_query::SimpleQueryStream;
9#[cfg(feature = "runtime")]
10use crate::tls::MakeTlsConnect;
11use crate::tls::TlsConnect;
12use crate::types::{Oid, ToSql, Type};
13#[cfg(feature = "runtime")]
14use crate::Socket;
15use crate::{
16    copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error,
17    Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
18};
19use bytes::{Buf, BytesMut};
20use fallible_iterator::FallibleIterator;
21use futures::channel::mpsc;
22use futures::{future, pin_mut, ready, StreamExt, TryStreamExt};
23use parking_lot::Mutex;
24use postgres_protocol::message::{backend::Message, frontend};
25use postgres_types::BorrowToSql;
26use std::collections::HashMap;
27use std::fmt;
28use std::sync::Arc;
29use std::task::{Context, Poll};
30#[cfg(feature = "runtime")]
31use std::time::Duration;
32use tokio::io::{AsyncRead, AsyncWrite};
33
34pub struct Responses {
35    receiver: mpsc::Receiver<BackendMessages>,
36    cur: BackendMessages,
37}
38
39impl Responses {
40    pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
41        loop {
42            match self.cur.next().map_err(Error::parse)? {
43                Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))),
44                Some(message) => return Poll::Ready(Ok(message)),
45                None => {}
46            }
47
48            match ready!(self.receiver.poll_next_unpin(cx)) {
49                Some(messages) => self.cur = messages,
50                None => return Poll::Ready(Err(Error::closed())),
51            }
52        }
53    }
54
55    pub async fn next(&mut self) -> Result<Message, Error> {
56        future::poll_fn(|cx| self.poll_next(cx)).await
57    }
58}
59
60/// A cache of type info and prepared statements for fetching type info
61/// (corresponding to the queries in the [prepare](prepare) module).
62#[derive(Default)]
63struct CachedTypeInfo {
64    /// A statement for basic information for a type from its
65    /// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its
66    /// fallback).
67    typeinfo: Option<Statement>,
68    /// A statement for getting information for a composite type from its OID.
69    /// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY).
70    typeinfo_composite: Option<Statement>,
71    /// A statement for getting information for a composite type from its OID.
72    /// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY) (or
73    /// its fallback).
74    typeinfo_enum: Option<Statement>,
75
76    /// Cache of types already looked up.
77    types: HashMap<Oid, Type>,
78}
79
80pub struct InnerClient {
81    sender: mpsc::UnboundedSender<Request>,
82    cached_typeinfo: Mutex<CachedTypeInfo>,
83
84    /// A buffer to use when writing out postgres commands.
85    buffer: Mutex<BytesMut>,
86}
87
88impl InnerClient {
89    pub fn send(&self, messages: RequestMessages) -> Result<Responses, Error> {
90        let (sender, receiver) = mpsc::channel(1);
91        let request = Request { messages, sender };
92        self.sender
93            .unbounded_send(request)
94            .map_err(|_| Error::closed())?;
95
96        Ok(Responses {
97            receiver,
98            cur: BackendMessages::empty(),
99        })
100    }
101
102    pub fn typeinfo(&self) -> Option<Statement> {
103        self.cached_typeinfo.lock().typeinfo.clone()
104    }
105
106    pub fn set_typeinfo(&self, statement: &Statement) {
107        self.cached_typeinfo.lock().typeinfo = Some(statement.clone());
108    }
109
110    pub fn typeinfo_composite(&self) -> Option<Statement> {
111        self.cached_typeinfo.lock().typeinfo_composite.clone()
112    }
113
114    pub fn set_typeinfo_composite(&self, statement: &Statement) {
115        self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone());
116    }
117
118    pub fn typeinfo_enum(&self) -> Option<Statement> {
119        self.cached_typeinfo.lock().typeinfo_enum.clone()
120    }
121
122    pub fn set_typeinfo_enum(&self, statement: &Statement) {
123        self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone());
124    }
125
126    pub fn type_(&self, oid: Oid) -> Option<Type> {
127        self.cached_typeinfo.lock().types.get(&oid).cloned()
128    }
129
130    pub fn set_type(&self, oid: Oid, type_: &Type) {
131        self.cached_typeinfo.lock().types.insert(oid, type_.clone());
132    }
133
134    pub fn clear_type_cache(&self) {
135        self.cached_typeinfo.lock().types.clear();
136    }
137
138    /// Call the given function with a buffer to be used when writing out
139    /// postgres commands.
140    pub fn with_buf<F, R>(&self, f: F) -> R
141    where
142        F: FnOnce(&mut BytesMut) -> R,
143    {
144        let mut buffer = self.buffer.lock();
145        let r = f(&mut buffer);
146        buffer.clear();
147        r
148    }
149}
150
/// Connection parameters kept by a `Client`.
///
/// They are cloned into each `CancelToken`, presumably so a cancel request can open its
/// own socket to the same server (TODO confirm against `CancelToken::cancel_query`).
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub(crate) struct SocketConfig {
    /// Host the client connected to.
    pub host: Host,
    /// Port the client connected to.
    pub port: u16,
    /// Optional timeout for establishing the socket.
    pub connect_timeout: Option<Duration>,
    /// Whether keepalives are enabled.
    pub keepalives: bool,
    /// Keepalive idle interval; presumably only meaningful when `keepalives` is set.
    pub keepalives_idle: Duration,
}
160
161/// An asynchronous PostgreSQL client.
162///
163/// The client is one half of what is returned when a connection is established. Users interact with the database
164/// through this client object.
165pub struct Client {
166    inner: Arc<InnerClient>,
167    #[cfg(feature = "runtime")]
168    socket_config: Option<SocketConfig>,
169    ssl_mode: SslMode,
170    process_id: i32,
171    secret_key: i32,
172}
173
174impl Client {
175    pub(crate) fn new(
176        sender: mpsc::UnboundedSender<Request>,
177        ssl_mode: SslMode,
178        process_id: i32,
179        secret_key: i32,
180    ) -> Client {
181        Client {
182            inner: Arc::new(InnerClient {
183                sender,
184                cached_typeinfo: Default::default(),
185                buffer: Default::default(),
186            }),
187            #[cfg(feature = "runtime")]
188            socket_config: None,
189            ssl_mode,
190            process_id,
191            secret_key,
192        }
193    }
194
195    pub(crate) fn inner(&self) -> &Arc<InnerClient> {
196        &self.inner
197    }
198
199    #[cfg(feature = "runtime")]
200    pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) {
201        self.socket_config = Some(socket_config);
202    }
203
204    /// Creates a new prepared statement.
205    ///
206    /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
207    /// which are set when executed. Prepared statements can only be used with the connection that created them.
208    pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
209        self.prepare_typed(query, &[]).await
210    }
211
212    /// Like `prepare`, but allows the types of query parameters to be explicitly specified.
213    ///
214    /// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be
215    /// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`.
216    pub async fn prepare_typed(
217        &self,
218        query: &str,
219        parameter_types: &[Type],
220    ) -> Result<Statement, Error> {
221        prepare::prepare(&self.inner, query, parameter_types).await
222    }
223
224    /// Executes a statement, returning a vector of the resulting rows.
225    ///
226    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
227    /// provided, 1-indexed.
228    ///
229    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
230    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
231    /// with the `prepare` method.
232    ///
233    /// # Panics
234    ///
235    /// Panics if the number of parameters provided does not match the number expected.
236    pub async fn query<T>(
237        &self,
238        statement: &T,
239        params: &[&(dyn ToSql + Sync)],
240    ) -> Result<Vec<Row>, Error>
241    where
242        T: ?Sized + ToStatement,
243    {
244        self.query_raw(statement, slice_iter(params))
245            .await?
246            .try_collect()
247            .await
248    }
249
250    /// Executes a statement which returns a single row, returning it.
251    ///
252    /// Returns an error if the query does not return exactly one row.
253    ///
254    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
255    /// provided, 1-indexed.
256    ///
257    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
258    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
259    /// with the `prepare` method.
260    ///
261    /// # Panics
262    ///
263    /// Panics if the number of parameters provided does not match the number expected.
264    pub async fn query_one<T>(
265        &self,
266        statement: &T,
267        params: &[&(dyn ToSql + Sync)],
268    ) -> Result<Row, Error>
269    where
270        T: ?Sized + ToStatement,
271    {
272        let stream = self.query_raw(statement, slice_iter(params)).await?;
273        pin_mut!(stream);
274
275        let row = match stream.try_next().await? {
276            Some(row) => row,
277            None => return Err(Error::row_count()),
278        };
279
280        if stream.try_next().await?.is_some() {
281            return Err(Error::row_count());
282        }
283
284        Ok(row)
285    }
286
287    /// Executes a statements which returns zero or one rows, returning it.
288    ///
289    /// Returns an error if the query returns more than one row.
290    ///
291    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
292    /// provided, 1-indexed.
293    ///
294    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
295    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
296    /// with the `prepare` method.
297    ///
298    /// # Panics
299    ///
300    /// Panics if the number of parameters provided does not match the number expected.
301    pub async fn query_opt<T>(
302        &self,
303        statement: &T,
304        params: &[&(dyn ToSql + Sync)],
305    ) -> Result<Option<Row>, Error>
306    where
307        T: ?Sized + ToStatement,
308    {
309        let stream = self.query_raw(statement, slice_iter(params)).await?;
310        pin_mut!(stream);
311
312        let row = match stream.try_next().await? {
313            Some(row) => row,
314            None => return Ok(None),
315        };
316
317        if stream.try_next().await?.is_some() {
318            return Err(Error::row_count());
319        }
320
321        Ok(Some(row))
322    }
323
324    /// The maximally flexible version of [`query`].
325    ///
326    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
327    /// provided, 1-indexed.
328    ///
329    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
330    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
331    /// with the `prepare` method.
332    ///
333    /// # Panics
334    ///
335    /// Panics if the number of parameters provided does not match the number expected.
336    ///
337    /// [`query`]: #method.query
338    ///
339    /// # Examples
340    ///
341    /// ```no_run
342    /// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
343    /// use tokio_postgres::types::ToSql;
344    /// use futures::{pin_mut, TryStreamExt};
345    ///
346    /// let params: Vec<String> = vec![
347    ///     "first param".into(),
348    ///     "second param".into(),
349    /// ];
350    /// let mut it = client.query_raw(
351    ///     "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
352    ///     params,
353    /// ).await?;
354    ///
355    /// pin_mut!(it);
356    /// while let Some(row) = it.try_next().await? {
357    ///     let foo: i32 = row.get("foo");
358    ///     println!("foo: {}", foo);
359    /// }
360    /// # Ok(())
361    /// # }
362    /// ```
363    pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
364    where
365        T: ?Sized + ToStatement,
366        P: BorrowToSql,
367        I: IntoIterator<Item = P>,
368        I::IntoIter: ExactSizeIterator,
369    {
370        let statement = statement.__convert().into_statement(self).await?;
371        query::query(&self.inner, statement, params).await
372    }
373
374    /// Executes a statement, returning the number of rows modified.
375    ///
376    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
377    /// provided, 1-indexed.
378    ///
379    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
380    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
381    /// with the `prepare` method.
382    ///
383    /// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
384    ///
385    /// # Panics
386    ///
387    /// Panics if the number of parameters provided does not match the number expected.
388    pub async fn execute<T>(
389        &self,
390        statement: &T,
391        params: &[&(dyn ToSql + Sync)],
392    ) -> Result<u64, Error>
393    where
394        T: ?Sized + ToStatement,
395    {
396        self.execute_raw(statement, slice_iter(params)).await
397    }
398
399    /// The maximally flexible version of [`execute`].
400    ///
401    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
402    /// provided, 1-indexed.
403    ///
404    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
405    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
406    /// with the `prepare` method.
407    ///
408    /// # Panics
409    ///
410    /// Panics if the number of parameters provided does not match the number expected.
411    ///
412    /// [`execute`]: #method.execute
413    pub async fn execute_raw<T, P, I>(&self, statement: &T, params: I) -> Result<u64, Error>
414    where
415        T: ?Sized + ToStatement,
416        P: BorrowToSql,
417        I: IntoIterator<Item = P>,
418        I::IntoIter: ExactSizeIterator,
419    {
420        let statement = statement.__convert().into_statement(self).await?;
421        query::execute(self.inner(), statement, params).await
422    }
423
424    /// Executes a `COPY FROM STDIN` statement, returning a sink used to write the copy data.
425    ///
426    /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any. The copy *must*
427    /// be explicitly completed via the `Sink::close` or `finish` methods. If it is not, the copy will be aborted.
428    ///
429    /// # Panics
430    ///
431    /// Panics if the statement contains parameters.
432    pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
433    where
434        T: ?Sized + ToStatement,
435        U: Buf + 'static + Send,
436    {
437        let statement = statement.__convert().into_statement(self).await?;
438        copy_in::copy_in(self.inner(), statement).await
439    }
440
441    /// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data.
442    ///
443    /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any.
444    ///
445    /// # Panics
446    ///
447    /// Panics if the statement contains parameters.
448    pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, Error>
449    where
450        T: ?Sized + ToStatement,
451    {
452        let statement = statement.__convert().into_statement(self).await?;
453        copy_out::copy_out(self.inner(), statement).await
454    }
455
456    /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
457    ///
458    /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
459    /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
460    /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
461    /// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
462    /// or a row of data. This preserves the framing between the separate statements in the request.
463    ///
464    /// # Warning
465    ///
466    /// Prepared statements should be use for any query which contains user-specified data, as they provided the
467    /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
468    /// them to this method!
469    pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
470        self.simple_query_raw(query).await?.try_collect().await
471    }
472
473    pub(crate) async fn simple_query_raw(&self, query: &str) -> Result<SimpleQueryStream, Error> {
474        simple_query::simple_query(self.inner(), query).await
475    }
476
477    /// Executes a sequence of SQL statements using the simple query protocol.
478    ///
479    /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
480    /// point. This is intended for use when, for example, initializing a database schema.
481    ///
482    /// # Warning
483    ///
484    /// Prepared statements should be use for any query which contains user-specified data, as they provided the
485    /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
486    /// them to this method!
487    pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
488        simple_query::batch_execute(self.inner(), query).await
489    }
490
491    /// Begins a new database transaction.
492    ///
493    /// The transaction will roll back by default - use the `commit` method to commit it.
494    pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
495        struct RollbackIfNotDone<'me> {
496            client: &'me Client,
497            done: bool,
498        }
499
500        impl<'a> Drop for RollbackIfNotDone<'a> {
501            fn drop(&mut self) {
502                if self.done {
503                    return;
504                }
505
506                let buf = self.client.inner().with_buf(|buf| {
507                    frontend::query("ROLLBACK", buf).unwrap();
508                    buf.split().freeze()
509                });
510                let _ = self
511                    .client
512                    .inner()
513                    .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
514            }
515        }
516
517        // This is done, as `Future` created by this method can be dropped after
518        // `RequestMessages` is synchronously send to the `Connection` by
519        // `batch_execute()`, but before `Responses` is asynchronously polled to
520        // completion. In that case `Transaction` won't be created and thus
521        // won't be rolled back.
522        {
523            let mut cleaner = RollbackIfNotDone {
524                client: self,
525                done: false,
526            };
527            self.batch_execute("BEGIN").await?;
528            cleaner.done = true;
529        }
530
531        Ok(Transaction::new(self))
532    }
533
534    /// Returns a builder for a transaction with custom settings.
535    ///
536    /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
537    /// attributes.
538    pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
539        TransactionBuilder::new(self)
540    }
541
542    /// Constructs a cancellation token that can later be used to request cancellation of a query running on the
543    /// connection associated with this client.
544    pub fn cancel_token(&self) -> CancelToken {
545        CancelToken {
546            #[cfg(feature = "runtime")]
547            socket_config: self.socket_config.clone(),
548            ssl_mode: self.ssl_mode,
549            process_id: self.process_id,
550            secret_key: self.secret_key,
551        }
552    }
553
554    /// Attempts to cancel an in-progress query.
555    ///
556    /// The server provides no information about whether a cancellation attempt was successful or not. An error will
557    /// only be returned if the client was unable to connect to the database.
558    ///
559    /// Requires the `runtime` Cargo feature (enabled by default).
560    #[cfg(feature = "runtime")]
561    #[deprecated(since = "0.6.0", note = "use Client::cancel_token() instead")]
562    pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
563    where
564        T: MakeTlsConnect<Socket>,
565    {
566        self.cancel_token().cancel_query(tls).await
567    }
568
569    /// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
570    /// connection itself.
571    #[deprecated(since = "0.6.0", note = "use Client::cancel_token() instead")]
572    pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
573    where
574        S: AsyncRead + AsyncWrite + Unpin,
575        T: TlsConnect<S>,
576    {
577        self.cancel_token().cancel_query_raw(stream, tls).await
578    }
579
580    /// Clears the client's type information cache.
581    ///
582    /// When user-defined types are used in a query, the client loads their definitions from the database and caches
583    /// them for the lifetime of the client. If those definitions are changed in the database, this method can be used
584    /// to flush the local cache and allow the new, updated definitions to be loaded.
585    pub fn clear_type_cache(&self) {
586        self.inner().clear_type_cache();
587    }
588
589    /// Determines if the connection to the server has already closed.
590    ///
591    /// In that case, all future queries will fail.
592    pub fn is_closed(&self) -> bool {
593        self.inner.sender.is_closed()
594    }
595
596    #[doc(hidden)]
597    pub fn __private_api_close(&mut self) {
598        self.inner.sender.close_channel()
599    }
600}
601
602impl fmt::Debug for Client {
603    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
604        f.debug_struct("Client").finish()
605    }
606}