// madsim_tokio_postgres/connection.rs

1use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2use crate::copy_in::CopyInReceiver;
3use crate::error::DbError;
4use crate::maybe_tls_stream::MaybeTlsStream;
5use crate::{AsyncMessage, Error, Notification};
6use bytes::BytesMut;
7use fallible_iterator::FallibleIterator;
8use futures::channel::mpsc;
9use futures::stream::FusedStream;
10use futures::{ready, Sink, Stream, StreamExt};
11use log::{info, trace};
12use postgres_protocol::message::backend::Message;
13use postgres_protocol::message::frontend;
14use std::collections::{HashMap, VecDeque};
15use std::future::Future;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use tokio::io::{AsyncRead, AsyncWrite};
19use tokio_util::codec::Framed;
20
21pub enum RequestMessages {
22    Single(FrontendMessage),
23    CopyIn(CopyInReceiver),
24}
25
26pub struct Request {
27    pub messages: RequestMessages,
28    pub sender: mpsc::Sender<BackendMessages>,
29}
30
31pub struct Response {
32    sender: mpsc::Sender<BackendMessages>,
33}
34
/// Lifecycle of a [`Connection`].
///
/// Transitions only move forward: `Active` -> `Terminating` -> `Closing`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum State {
    /// Normal operation: requests are written and server messages are read and routed.
    Active,
    /// The request channel is exhausted with no outstanding responses; a Terminate message has
    /// been built but not yet handed to the stream.
    Terminating,
    /// The Terminate message has been sent; reading and writing stop and the stream is closed.
    Closing,
}
41
42/// A connection to a PostgreSQL database.
43///
44/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
45/// server, and should generally be spawned off onto an executor to run in the background.
46///
47/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
48/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
49#[must_use = "futures do nothing unless polled"]
50pub struct Connection<S, T> {
51    stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
52    parameters: HashMap<String, String>,
53    receiver: mpsc::UnboundedReceiver<Request>,
54    pending_request: Option<RequestMessages>,
55    pending_responses: VecDeque<BackendMessage>,
56    responses: VecDeque<Response>,
57    state: State,
58}
59
60impl<S, T> Connection<S, T>
61where
62    S: AsyncRead + AsyncWrite + Unpin,
63    T: AsyncRead + AsyncWrite + Unpin,
64{
65    pub(crate) fn new(
66        stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
67        pending_responses: VecDeque<BackendMessage>,
68        parameters: HashMap<String, String>,
69        receiver: mpsc::UnboundedReceiver<Request>,
70    ) -> Connection<S, T> {
71        Connection {
72            stream,
73            parameters,
74            receiver,
75            pending_request: None,
76            pending_responses,
77            responses: VecDeque::new(),
78            state: State::Active,
79        }
80    }
81
82    fn poll_response(
83        &mut self,
84        cx: &mut Context<'_>,
85    ) -> Poll<Option<Result<BackendMessage, Error>>> {
86        if let Some(message) = self.pending_responses.pop_front() {
87            trace!("retrying pending response");
88            return Poll::Ready(Some(Ok(message)));
89        }
90
91        Pin::new(&mut self.stream)
92            .poll_next(cx)
93            .map(|o| o.map(|r| r.map_err(Error::io)))
94    }
95
96    fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
97        if self.state != State::Active {
98            trace!("poll_read: done");
99            return Ok(None);
100        }
101
102        loop {
103            let message = match self.poll_response(cx)? {
104                Poll::Ready(Some(message)) => message,
105                Poll::Ready(None) => return Err(Error::closed()),
106                Poll::Pending => {
107                    trace!("poll_read: waiting on response");
108                    return Ok(None);
109                }
110            };
111
112            let (mut messages, request_complete) = match message {
113                BackendMessage::Async(Message::NoticeResponse(body)) => {
114                    let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
115                    return Ok(Some(AsyncMessage::Notice(error)));
116                }
117                BackendMessage::Async(Message::NotificationResponse(body)) => {
118                    let notification = Notification {
119                        process_id: body.process_id(),
120                        channel: body.channel().map_err(Error::parse)?.to_string(),
121                        payload: body.message().map_err(Error::parse)?.to_string(),
122                    };
123                    return Ok(Some(AsyncMessage::Notification(notification)));
124                }
125                BackendMessage::Async(Message::ParameterStatus(body)) => {
126                    self.parameters.insert(
127                        body.name().map_err(Error::parse)?.to_string(),
128                        body.value().map_err(Error::parse)?.to_string(),
129                    );
130                    continue;
131                }
132                BackendMessage::Async(_) => unreachable!(),
133                BackendMessage::Normal {
134                    messages,
135                    request_complete,
136                } => (messages, request_complete),
137            };
138
139            let mut response = match self.responses.pop_front() {
140                Some(response) => response,
141                None => match messages.next().map_err(Error::parse)? {
142                    Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
143                    _ => return Err(Error::unexpected_message()),
144                },
145            };
146
147            match response.sender.poll_ready(cx) {
148                Poll::Ready(Ok(())) => {
149                    let _ = response.sender.start_send(messages);
150                    if !request_complete {
151                        self.responses.push_front(response);
152                    }
153                }
154                Poll::Ready(Err(_)) => {
155                    // we need to keep paging through the rest of the messages even if the receiver's hung up
156                    if !request_complete {
157                        self.responses.push_front(response);
158                    }
159                }
160                Poll::Pending => {
161                    self.responses.push_front(response);
162                    self.pending_responses.push_back(BackendMessage::Normal {
163                        messages,
164                        request_complete,
165                    });
166                    trace!("poll_read: waiting on sender");
167                    return Ok(None);
168                }
169            }
170        }
171    }
172
173    fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
174        if let Some(messages) = self.pending_request.take() {
175            trace!("retrying pending request");
176            return Poll::Ready(Some(messages));
177        }
178
179        if self.receiver.is_terminated() {
180            return Poll::Ready(None);
181        }
182
183        match self.receiver.poll_next_unpin(cx) {
184            Poll::Ready(Some(request)) => {
185                trace!("polled new request");
186                self.responses.push_back(Response {
187                    sender: request.sender,
188                });
189                Poll::Ready(Some(request.messages))
190            }
191            Poll::Ready(None) => Poll::Ready(None),
192            Poll::Pending => Poll::Pending,
193        }
194    }
195
196    fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
197        loop {
198            if self.state == State::Closing {
199                trace!("poll_write: done");
200                return Ok(false);
201            }
202
203            if Pin::new(&mut self.stream)
204                .poll_ready(cx)
205                .map_err(Error::io)?
206                .is_pending()
207            {
208                trace!("poll_write: waiting on socket");
209                return Ok(false);
210            }
211
212            let request = match self.poll_request(cx) {
213                Poll::Ready(Some(request)) => request,
214                Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
215                    trace!("poll_write: at eof, terminating");
216                    self.state = State::Terminating;
217                    let mut request = BytesMut::new();
218                    frontend::terminate(&mut request);
219                    RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
220                }
221                Poll::Ready(None) => {
222                    trace!(
223                        "poll_write: at eof, pending responses {}",
224                        self.responses.len()
225                    );
226                    return Ok(true);
227                }
228                Poll::Pending => {
229                    trace!("poll_write: waiting on request");
230                    return Ok(true);
231                }
232            };
233
234            match request {
235                RequestMessages::Single(request) => {
236                    Pin::new(&mut self.stream)
237                        .start_send(request)
238                        .map_err(Error::io)?;
239                    if self.state == State::Terminating {
240                        trace!("poll_write: sent eof, closing");
241                        self.state = State::Closing;
242                    }
243                }
244                RequestMessages::CopyIn(mut receiver) => {
245                    let message = match receiver.poll_next_unpin(cx) {
246                        Poll::Ready(Some(message)) => message,
247                        Poll::Ready(None) => {
248                            trace!("poll_write: finished copy_in request");
249                            continue;
250                        }
251                        Poll::Pending => {
252                            trace!("poll_write: waiting on copy_in stream");
253                            self.pending_request = Some(RequestMessages::CopyIn(receiver));
254                            return Ok(true);
255                        }
256                    };
257                    Pin::new(&mut self.stream)
258                        .start_send(message)
259                        .map_err(Error::io)?;
260                    self.pending_request = Some(RequestMessages::CopyIn(receiver));
261                }
262            }
263        }
264    }
265
266    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
267        match Pin::new(&mut self.stream)
268            .poll_flush(cx)
269            .map_err(Error::io)?
270        {
271            Poll::Ready(()) => trace!("poll_flush: flushed"),
272            Poll::Pending => trace!("poll_flush: waiting on socket"),
273        }
274        Ok(())
275    }
276
277    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
278        if self.state != State::Closing {
279            return Poll::Pending;
280        }
281
282        match Pin::new(&mut self.stream)
283            .poll_close(cx)
284            .map_err(Error::io)?
285        {
286            Poll::Ready(()) => {
287                trace!("poll_shutdown: complete");
288                Poll::Ready(Ok(()))
289            }
290            Poll::Pending => {
291                trace!("poll_shutdown: waiting on socket");
292                Poll::Pending
293            }
294        }
295    }
296
297    /// Returns the value of a runtime parameter for this connection.
298    pub fn parameter(&self, name: &str) -> Option<&str> {
299        self.parameters.get(name).map(|s| &**s)
300    }
301
302    /// Polls for asynchronous messages from the server.
303    ///
304    /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
305    /// examine those messages should use this method to drive the connection rather than its `Future` implementation.
306    pub fn poll_message(
307        &mut self,
308        cx: &mut Context<'_>,
309    ) -> Poll<Option<Result<AsyncMessage, Error>>> {
310        let message = self.poll_read(cx)?;
311        let want_flush = self.poll_write(cx)?;
312        if want_flush {
313            self.poll_flush(cx)?;
314        }
315        match message {
316            Some(message) => Poll::Ready(Some(Ok(message))),
317            None => match self.poll_shutdown(cx) {
318                Poll::Ready(Ok(())) => Poll::Ready(None),
319                Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
320                Poll::Pending => Poll::Pending,
321            },
322        }
323    }
324}
325
326impl<S, T> Future for Connection<S, T>
327where
328    S: AsyncRead + AsyncWrite + Unpin,
329    T: AsyncRead + AsyncWrite + Unpin,
330{
331    type Output = Result<(), Error>;
332
333    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
334        while let Some(message) = ready!(self.poll_message(cx)?) {
335            if let AsyncMessage::Notice(notice) = message {
336                info!("{}: {}", notice.severity(), notice.message());
337            }
338        }
339        Poll::Ready(Ok(()))
340    }
341}