1use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2use crate::copy_in::CopyInReceiver;
3use crate::error::DbError;
4use crate::maybe_tls_stream::MaybeTlsStream;
5use crate::{AsyncMessage, Error, Notification};
6use bytes::BytesMut;
7use fallible_iterator::FallibleIterator;
8use futures::channel::mpsc;
9use futures::stream::FusedStream;
10use futures::{ready, Sink, Stream, StreamExt};
11use log::{info, trace};
12use postgres_protocol::message::backend::Message;
13use postgres_protocol::message::frontend;
14use std::collections::{HashMap, VecDeque};
15use std::future::Future;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use tokio::io::{AsyncRead, AsyncWrite};
19use tokio_util::codec::Framed;
20
/// The frontend messages that make up a single request to the server.
pub enum RequestMessages {
    /// A single message, handed to the socket with one `start_send`.
    Single(FrontendMessage),
    /// A stream of messages produced by a `CopyInReceiver`. The connection writes them one
    /// at a time as they become available, until the stream yields `None`.
    CopyIn(CopyInReceiver),
}
25
/// A request submitted to a `Connection` through its request channel.
pub struct Request {
    /// The messages to write to the server for this request.
    pub messages: RequestMessages,
    /// Where the connection delivers the backend messages answering this request.
    /// If this channel reports itself closed, the connection silently discards the
    /// remaining messages for the request instead of failing.
    pub sender: mpsc::Sender<BackendMessages>,
}
30
/// The connection's record of an in-flight request that is still awaiting backend messages.
///
/// These are queued in request order and consumed from the front, so each batch of backend
/// messages is routed to the oldest request that has not yet completed.
pub struct Response {
    /// Channel on which this request's backend messages are delivered.
    sender: mpsc::Sender<BackendMessages>,
}
34
/// Lifecycle of a `Connection`.
///
/// A connection starts out `Active`. When the request channel is exhausted and no responses
/// are outstanding, `poll_write` queues a Terminate message and switches to `Terminating`.
/// Once that message has been handed to the socket, the state becomes `Closing`, and from
/// then on `poll_shutdown` closes the stream.
///
/// The enum is fieldless, so it is `Copy` and has total equality (`Eq`), which makes the
/// frequent `==`/`!=` state checks trivially cheap.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum State {
    /// Normal operation: responses are read and requests are written.
    Active,
    /// A Terminate message has been queued but not yet handed to the socket.
    Terminating,
    /// The Terminate message has been written; only closing the stream remains.
    Closing,
}
41
/// A connection to a PostgreSQL database.
///
/// This is the half of a connection that performs the actual I/O with the server. It receives
/// `Request`s over a channel, writes their messages to the socket, and routes each batch of
/// backend messages to the sender of the request it answers.
///
/// It does nothing unless polled. You can poll it directly with `poll_message` to observe
/// notices and notifications, or await it as a `Future`, which resolves once the connection
/// has shut down or failed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
    /// The socket, framed with `PostgresCodec`.
    stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
    /// Runtime parameters reported by the server, keyed by parameter name.
    parameters: HashMap<String, String>,
    /// Incoming requests to send to the server.
    receiver: mpsc::UnboundedReceiver<Request>,
    /// A request whose messages were only partially written. It is resumed before any new
    /// request is taken from `receiver`.
    pending_request: Option<RequestMessages>,
    /// Backend messages that have been obtained but not yet delivered. They are handed out
    /// before anything new is read from `stream`.
    pending_responses: VecDeque<BackendMessage>,
    /// Senders for in-flight requests, oldest first.
    responses: VecDeque<Response>,
    /// Current lifecycle state.
    state: State,
}
59
impl<S, T> Connection<S, T>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Creates a connection in the `Active` state around an already-framed stream.
    ///
    /// `pending_responses` are backend messages that `poll_response` hands out before reading
    /// anything from `stream`. `parameters` seeds the table returned by
    /// [`Connection::parameter`]. `receiver` supplies the requests to send.
    pub(crate) fn new(
        stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
        pending_responses: VecDeque<BackendMessage>,
        parameters: HashMap<String, String>,
        receiver: mpsc::UnboundedReceiver<Request>,
    ) -> Connection<S, T> {
        Connection {
            stream,
            parameters,
            receiver,
            pending_request: None,
            pending_responses,
            responses: VecDeque::new(),
            state: State::Active,
        }
    }

    /// Returns the next backend message.
    ///
    /// Messages buffered in `pending_responses` come first, in FIFO order. Otherwise the
    /// framed stream is polled, and I/O errors are converted with `Error::io`. `Ready(None)`
    /// means the stream has ended.
    fn poll_response(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<BackendMessage, Error>>> {
        if let Some(message) = self.pending_responses.pop_front() {
            trace!("retrying pending response");
            return Poll::Ready(Some(Ok(message)));
        }

        Pin::new(&mut self.stream)
            .poll_next(cx)
            .map(|o| o.map(|r| r.map_err(Error::io)))
    }

    /// Reads and dispatches backend messages until one should be returned to the caller or
    /// no further progress can be made.
    ///
    /// - Notices and notifications are returned as `AsyncMessage`s.
    /// - `ParameterStatus` messages update `self.parameters` and are not returned.
    /// - All other messages go to the sender at the front of `self.responses`. That entry
    ///   stays at the front until a batch marked `request_complete` has been handled.
    ///
    /// Returns `Ok(None)` in three cases: the connection is no longer `Active`, the socket has
    /// nothing more to read yet, or the front sender cannot accept more messages yet. In the
    /// last case the message is re-queued in `pending_responses`.
    ///
    /// # Errors
    ///
    /// - `Error::closed()` if the stream ends.
    /// - An I/O error from the stream, or a parse error for a malformed message body.
    /// - `Error::db` for an `ErrorResponse`, and `Error::unexpected_message()` for any other
    ///   message, when the message arrives while no request is waiting for a response.
    fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
        // Once termination has begun, nothing more is read.
        if self.state != State::Active {
            trace!("poll_read: done");
            return Ok(None);
        }

        loop {
            let message = match self.poll_response(cx)? {
                Poll::Ready(Some(message)) => message,
                Poll::Ready(None) => return Err(Error::closed()),
                Poll::Pending => {
                    trace!("poll_read: waiting on response");
                    return Ok(None);
                }
            };

            let (mut messages, request_complete) = match message {
                BackendMessage::Async(Message::NoticeResponse(body)) => {
                    let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
                    return Ok(Some(AsyncMessage::Notice(error)));
                }
                BackendMessage::Async(Message::NotificationResponse(body)) => {
                    let notification = Notification {
                        process_id: body.process_id(),
                        channel: body.channel().map_err(Error::parse)?.to_string(),
                        payload: body.message().map_err(Error::parse)?.to_string(),
                    };
                    return Ok(Some(AsyncMessage::Notification(notification)));
                }
                BackendMessage::Async(Message::ParameterStatus(body)) => {
                    self.parameters.insert(
                        body.name().map_err(Error::parse)?.to_string(),
                        body.value().map_err(Error::parse)?.to_string(),
                    );
                    // Internal bookkeeping only; keep reading.
                    continue;
                }
                // NOTE(review): assumes `PostgresCodec` only wraps the three message kinds
                // handled above as `Async` -- confirm in the codec.
                BackendMessage::Async(_) => unreachable!(),
                BackendMessage::Normal {
                    messages,
                    request_complete,
                } => (messages, request_complete),
            };

            // Backend messages are matched to requests in FIFO order: the oldest
            // in-flight request owns this batch.
            let mut response = match self.responses.pop_front() {
                Some(response) => response,
                // No request is waiting. A server error is surfaced as a fatal error;
                // anything else is a protocol violation.
                None => match messages.next().map_err(Error::parse)? {
                    Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
                    _ => return Err(Error::unexpected_message()),
                },
            };

            match response.sender.poll_ready(cx) {
                Poll::Ready(Ok(())) => {
                    // A failed send is deliberately ignored (presumably the receiver was
                    // dropped after `poll_ready`).
                    let _ = response.sender.start_send(messages);
                    if !request_complete {
                        self.responses.push_front(response);
                    }
                }
                Poll::Ready(Err(_)) => {
                    // The receiving side has gone away. Discard this batch, but keep the
                    // response queued until its request completes. That way the request's
                    // remaining messages are consumed here rather than misrouted to the next
                    // request.
                    if !request_complete {
                        self.responses.push_front(response);
                    }
                }
                Poll::Pending => {
                    // Backpressure: restore both the response and the message exactly as
                    // they were, and retry them first on the next poll.
                    self.responses.push_front(response);
                    self.pending_responses.push_back(BackendMessage::Normal {
                        messages,
                        request_complete,
                    });
                    trace!("poll_read: waiting on sender");
                    return Ok(None);
                }
            }
        }
    }

    /// Returns the messages of the next request to write.
    ///
    /// A partially written request stored in `pending_request` (an in-progress copy-in)
    /// takes priority. When a new request is taken from the channel, its sender is appended
    /// to `self.responses` so replies can later be routed back in request order.
    /// `Ready(None)` means the request channel is closed and drained.
    fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
        if let Some(messages) = self.pending_request.take() {
            trace!("retrying pending request");
            return Poll::Ready(Some(messages));
        }

        // Don't poll a channel that has already yielded `None`.
        if self.receiver.is_terminated() {
            return Poll::Ready(None);
        }

        match self.receiver.poll_next_unpin(cx) {
            Poll::Ready(Some(request)) => {
                trace!("polled new request");
                self.responses.push_back(Response {
                    sender: request.sender,
                });
                Poll::Ready(Some(request.messages))
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Writes as many queued request messages to the socket as it will currently accept.
    ///
    /// When the request channel is exhausted and no responses are outstanding, a Terminate
    /// message is written and the state moves `Active` -> `Terminating` -> `Closing`.
    ///
    /// Returns `Ok(true)` when the caller should flush the stream. Returns `Ok(false)` when
    /// the connection is `Closing` or the socket is not ready to accept more data.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the socket fails.
    fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
        loop {
            // The Terminate message has been written; `poll_shutdown` takes over from here.
            if self.state == State::Closing {
                trace!("poll_write: done");
                return Ok(false);
            }

            if Pin::new(&mut self.stream)
                .poll_ready(cx)
                .map_err(Error::io)?
                .is_pending()
            {
                trace!("poll_write: waiting on socket");
                return Ok(false);
            }

            let request = match self.poll_request(cx) {
                Poll::Ready(Some(request)) => request,
                // No more requests will arrive and every outstanding request has been
                // answered, so it is safe to end the session.
                Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
                    trace!("poll_write: at eof, terminating");
                    self.state = State::Terminating;
                    let mut request = BytesMut::new();
                    frontend::terminate(&mut request);
                    RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
                }
                // The channel is closed but responses are still outstanding; wait for them
                // before terminating.
                Poll::Ready(None) => {
                    trace!(
                        "poll_write: at eof, pending responses {}",
                        self.responses.len()
                    );
                    return Ok(true);
                }
                Poll::Pending => {
                    trace!("poll_write: waiting on request");
                    return Ok(true);
                }
            };

            match request {
                RequestMessages::Single(request) => {
                    Pin::new(&mut self.stream)
                        .start_send(request)
                        .map_err(Error::io)?;
                    if self.state == State::Terminating {
                        trace!("poll_write: sent eof, closing");
                        self.state = State::Closing;
                    }
                }
                RequestMessages::CopyIn(mut receiver) => {
                    let message = match receiver.poll_next_unpin(cx) {
                        Poll::Ready(Some(message)) => message,
                        // The copy-in stream is finished; move on to the next request.
                        Poll::Ready(None) => {
                            trace!("poll_write: finished copy_in request");
                            continue;
                        }
                        // Stash the stream so the next call resumes this request before
                        // taking any new one.
                        Poll::Pending => {
                            trace!("poll_write: waiting on copy_in stream");
                            self.pending_request = Some(RequestMessages::CopyIn(receiver));
                            return Ok(true);
                        }
                    };
                    Pin::new(&mut self.stream)
                        .start_send(message)
                        .map_err(Error::io)?;
                    // This request stays current until its stream is exhausted.
                    self.pending_request = Some(RequestMessages::CopyIn(receiver));
                }
            }
        }
    }

    /// Flushes buffered frames to the socket.
    ///
    /// A flush that cannot complete yet is not treated as an error; it is only traced.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the socket fails.
    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
        match Pin::new(&mut self.stream)
            .poll_flush(cx)
            .map_err(Error::io)?
        {
            Poll::Ready(()) => trace!("poll_flush: flushed"),
            Poll::Pending => trace!("poll_flush: waiting on socket"),
        }
        Ok(())
    }

    /// Closes the stream once the connection has reached the `Closing` state.
    ///
    /// Before then it returns `Pending` without using `cx`. It presumably relies on the
    /// read/write polling that `poll_message` performs first to have arranged a wake-up.
    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        if self.state != State::Closing {
            return Poll::Pending;
        }

        match Pin::new(&mut self.stream)
            .poll_close(cx)
            .map_err(Error::io)?
        {
            Poll::Ready(()) => {
                trace!("poll_shutdown: complete");
                Poll::Ready(Ok(()))
            }
            Poll::Pending => {
                trace!("poll_shutdown: waiting on socket");
                Poll::Pending
            }
        }
    }

    /// Returns the value of a runtime parameter for this connection, if known.
    ///
    /// Values are seeded when the connection is constructed and updated whenever the server
    /// sends a `ParameterStatus` message.
    pub fn parameter(&self, name: &str) -> Option<&str> {
        self.parameters.get(name).map(|s| &**s)
    }

    /// Polls the connection for asynchronous messages from the server.
    ///
    /// Each call drives the connection: it reads and dispatches backend messages, writes
    /// pending requests, and flushes the socket when needed. The result is:
    ///
    /// - `Ready(Some(Ok(message)))` for a notice or notification,
    /// - `Ready(Some(Err(error)))` if reading, writing, flushing, or shutting down fails,
    /// - `Ready(None)` once the connection has been shut down,
    /// - `Pending` otherwise.
    pub fn poll_message(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<AsyncMessage, Error>>> {
        // NOTE(review): if `poll_read` returns a message and `poll_write`/`poll_flush` then
        // fails, the `?` below returns the error and that message is dropped.
        let message = self.poll_read(cx)?;
        let want_flush = self.poll_write(cx)?;
        if want_flush {
            self.poll_flush(cx)?;
        }
        match message {
            Some(message) => Poll::Ready(Some(Ok(message))),
            None => match self.poll_shutdown(cx) {
                Poll::Ready(Ok(())) => Poll::Ready(None),
                Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
                Poll::Pending => Poll::Pending,
            },
        }
    }
}
325
326impl<S, T> Future for Connection<S, T>
327where
328 S: AsyncRead + AsyncWrite + Unpin,
329 T: AsyncRead + AsyncWrite + Unpin,
330{
331 type Output = Result<(), Error>;
332
333 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
334 while let Some(message) = ready!(self.poll_message(cx)?) {
335 if let AsyncMessage::Notice(notice) = message {
336 info!("{}: {}", notice.severity(), notice.message());
337 }
338 }
339 Poll::Ready(Ok(()))
340 }
341}