1use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2use crate::config::{self, Config};
3use crate::connect_tls::connect_tls;
4use crate::maybe_tls_stream::MaybeTlsStream;
5use crate::tls::{TlsConnect, TlsStream};
6use crate::{Client, Connection, Error};
7use bytes::BytesMut;
8use fallible_iterator::FallibleIterator;
9use futures::channel::mpsc;
10use futures::{ready, Sink, SinkExt, Stream, TryStreamExt};
11use postgres_protocol::authentication;
12use postgres_protocol::authentication::sasl;
13use postgres_protocol::authentication::sasl::ScramSha256;
14use postgres_protocol::message::backend::{AuthenticationSaslBody, Message};
15use postgres_protocol::message::frontend;
16use std::collections::{HashMap, VecDeque};
17use std::io;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20use tokio::io::{AsyncRead, AsyncWrite};
21use tokio_util::codec::Framed;
22
23pub struct StartupStream<S, T> {
24 inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
25 buf: BackendMessages,
26 delayed: VecDeque<BackendMessage>,
27}
28
29impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
30where
31 S: AsyncRead + AsyncWrite + Unpin,
32 T: AsyncRead + AsyncWrite + Unpin,
33{
34 type Error = io::Error;
35
36 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
37 Pin::new(&mut self.inner).poll_ready(cx)
38 }
39
40 fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
41 Pin::new(&mut self.inner).start_send(item)
42 }
43
44 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
45 Pin::new(&mut self.inner).poll_flush(cx)
46 }
47
48 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
49 Pin::new(&mut self.inner).poll_close(cx)
50 }
51}
52
53impl<S, T> Stream for StartupStream<S, T>
54where
55 S: AsyncRead + AsyncWrite + Unpin,
56 T: AsyncRead + AsyncWrite + Unpin,
57{
58 type Item = io::Result<Message>;
59
60 fn poll_next(
61 mut self: Pin<&mut Self>,
62 cx: &mut Context<'_>,
63 ) -> Poll<Option<io::Result<Message>>> {
64 loop {
65 match self.buf.next() {
66 Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
67 Ok(None) => {}
68 Err(e) => return Poll::Ready(Some(Err(e))),
69 }
70
71 match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
72 Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
73 Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
74 Some(Err(e)) => return Poll::Ready(Some(Err(e))),
75 None => return Poll::Ready(None),
76 }
77 }
78 }
79}
80
81pub async fn connect_raw<S, T>(
82 stream: S,
83 tls: T,
84 config: &Config,
85) -> Result<(Client, Connection<S, T::Stream>), Error>
86where
87 S: AsyncRead + AsyncWrite + Unpin,
88 T: TlsConnect<S>,
89{
90 let stream = connect_tls(stream, config.ssl_mode, tls).await?;
91
92 let mut stream = StartupStream {
93 inner: Framed::new(stream, PostgresCodec),
94 buf: BackendMessages::empty(),
95 delayed: VecDeque::new(),
96 };
97
98 startup(&mut stream, config).await?;
99 authenticate(&mut stream, config).await?;
100 let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
101
102 let (sender, receiver) = mpsc::unbounded();
103 let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
104 let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
105
106 Ok((client, connection))
107}
108
109async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
110where
111 S: AsyncRead + AsyncWrite + Unpin,
112 T: AsyncRead + AsyncWrite + Unpin,
113{
114 let mut params = vec![("client_encoding", "UTF8")];
115 if let Some(user) = &config.user {
116 params.push(("user", &**user));
117 }
118 if let Some(dbname) = &config.dbname {
119 params.push(("database", &**dbname));
120 }
121 if let Some(options) = &config.options {
122 params.push(("options", &**options));
123 }
124 if let Some(application_name) = &config.application_name {
125 params.push(("application_name", &**application_name));
126 }
127
128 let mut buf = BytesMut::new();
129 frontend::startup_message(params, &mut buf).map_err(Error::encode)?;
130
131 stream
132 .send(FrontendMessage::Raw(buf.freeze()))
133 .await
134 .map_err(Error::io)
135}
136
137async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
138where
139 S: AsyncRead + AsyncWrite + Unpin,
140 T: TlsStream + Unpin,
141{
142 match stream.try_next().await.map_err(Error::io)? {
143 Some(Message::AuthenticationOk) => {
144 can_skip_channel_binding(config)?;
145 return Ok(());
146 }
147 Some(Message::AuthenticationCleartextPassword) => {
148 can_skip_channel_binding(config)?;
149
150 let pass = config
151 .password
152 .as_ref()
153 .ok_or_else(|| Error::config("password missing".into()))?;
154
155 authenticate_password(stream, pass).await?;
156 }
157 Some(Message::AuthenticationMd5Password(body)) => {
158 can_skip_channel_binding(config)?;
159
160 let user = config
161 .user
162 .as_ref()
163 .ok_or_else(|| Error::config("user missing".into()))?;
164 let pass = config
165 .password
166 .as_ref()
167 .ok_or_else(|| Error::config("password missing".into()))?;
168
169 let output = authentication::md5_hash(user.as_bytes(), pass, body.salt());
170 authenticate_password(stream, output.as_bytes()).await?;
171 }
172 Some(Message::AuthenticationSasl(body)) => {
173 authenticate_sasl(stream, body, config).await?;
174 }
175 Some(Message::AuthenticationKerberosV5)
176 | Some(Message::AuthenticationScmCredential)
177 | Some(Message::AuthenticationGss)
178 | Some(Message::AuthenticationSspi) => {
179 return Err(Error::authentication(
180 "unsupported authentication method".into(),
181 ))
182 }
183 Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
184 Some(_) => return Err(Error::unexpected_message()),
185 None => return Err(Error::closed()),
186 }
187
188 match stream.try_next().await.map_err(Error::io)? {
189 Some(Message::AuthenticationOk) => Ok(()),
190 Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
191 Some(_) => Err(Error::unexpected_message()),
192 None => Err(Error::closed()),
193 }
194}
195
196fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
197 match config.channel_binding {
198 config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
199 config::ChannelBinding::Require => Err(Error::authentication(
200 "server did not use channel binding".into(),
201 )),
202 }
203}
204
205async fn authenticate_password<S, T>(
206 stream: &mut StartupStream<S, T>,
207 password: &[u8],
208) -> Result<(), Error>
209where
210 S: AsyncRead + AsyncWrite + Unpin,
211 T: AsyncRead + AsyncWrite + Unpin,
212{
213 let mut buf = BytesMut::new();
214 frontend::password_message(password, &mut buf).map_err(Error::encode)?;
215
216 stream
217 .send(FrontendMessage::Raw(buf.freeze()))
218 .await
219 .map_err(Error::io)
220}
221
222async fn authenticate_sasl<S, T>(
223 stream: &mut StartupStream<S, T>,
224 body: AuthenticationSaslBody,
225 config: &Config,
226) -> Result<(), Error>
227where
228 S: AsyncRead + AsyncWrite + Unpin,
229 T: TlsStream + Unpin,
230{
231 let password = config
232 .password
233 .as_ref()
234 .ok_or_else(|| Error::config("password missing".into()))?;
235
236 let mut has_scram = false;
237 let mut has_scram_plus = false;
238 let mut mechanisms = body.mechanisms();
239 while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
240 match mechanism {
241 sasl::SCRAM_SHA_256 => has_scram = true,
242 sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
243 _ => {}
244 }
245 }
246
247 let channel_binding = stream
248 .inner
249 .get_ref()
250 .channel_binding()
251 .tls_server_end_point
252 .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
253 .map(sasl::ChannelBinding::tls_server_end_point);
254
255 let (channel_binding, mechanism) = if has_scram_plus {
256 match channel_binding {
257 Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
258 None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
259 }
260 } else if has_scram {
261 match channel_binding {
262 Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
263 None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
264 }
265 } else {
266 return Err(Error::authentication("unsupported SASL mechanism".into()));
267 };
268
269 if mechanism != sasl::SCRAM_SHA_256_PLUS {
270 can_skip_channel_binding(config)?;
271 }
272
273 let mut scram = ScramSha256::new(password, channel_binding);
274
275 let mut buf = BytesMut::new();
276 frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
277 stream
278 .send(FrontendMessage::Raw(buf.freeze()))
279 .await
280 .map_err(Error::io)?;
281
282 let body = match stream.try_next().await.map_err(Error::io)? {
283 Some(Message::AuthenticationSaslContinue(body)) => body,
284 Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
285 Some(_) => return Err(Error::unexpected_message()),
286 None => return Err(Error::closed()),
287 };
288
289 scram
290 .update(body.data())
291 .map_err(|e| Error::authentication(e.into()))?;
292
293 let mut buf = BytesMut::new();
294 frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
295 stream
296 .send(FrontendMessage::Raw(buf.freeze()))
297 .await
298 .map_err(Error::io)?;
299
300 let body = match stream.try_next().await.map_err(Error::io)? {
301 Some(Message::AuthenticationSaslFinal(body)) => body,
302 Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
303 Some(_) => return Err(Error::unexpected_message()),
304 None => return Err(Error::closed()),
305 };
306
307 scram
308 .finish(body.data())
309 .map_err(|e| Error::authentication(e.into()))?;
310
311 Ok(())
312}
313
314async fn read_info<S, T>(
315 stream: &mut StartupStream<S, T>,
316) -> Result<(i32, i32, HashMap<String, String>), Error>
317where
318 S: AsyncRead + AsyncWrite + Unpin,
319 T: AsyncRead + AsyncWrite + Unpin,
320{
321 let mut process_id = 0;
322 let mut secret_key = 0;
323 let mut parameters = HashMap::new();
324
325 loop {
326 match stream.try_next().await.map_err(Error::io)? {
327 Some(Message::BackendKeyData(body)) => {
328 process_id = body.process_id();
329 secret_key = body.secret_key();
330 }
331 Some(Message::ParameterStatus(body)) => {
332 parameters.insert(
333 body.name().map_err(Error::parse)?.to_string(),
334 body.value().map_err(Error::parse)?.to_string(),
335 );
336 }
337 Some(msg @ Message::NoticeResponse(_)) => {
338 stream.delayed.push_back(BackendMessage::Async(msg))
339 }
340 Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
341 Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
342 Some(_) => return Err(Error::unexpected_message()),
343 None => return Err(Error::closed()),
344 }
345 }
346}