madsim_tokio_postgres/
connect_raw.rs

1use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2use crate::config::{self, Config};
3use crate::connect_tls::connect_tls;
4use crate::maybe_tls_stream::MaybeTlsStream;
5use crate::tls::{TlsConnect, TlsStream};
6use crate::{Client, Connection, Error};
7use bytes::BytesMut;
8use fallible_iterator::FallibleIterator;
9use futures::channel::mpsc;
10use futures::{ready, Sink, SinkExt, Stream, TryStreamExt};
11use postgres_protocol::authentication;
12use postgres_protocol::authentication::sasl;
13use postgres_protocol::authentication::sasl::ScramSha256;
14use postgres_protocol::message::backend::{AuthenticationSaslBody, Message};
15use postgres_protocol::message::frontend;
16use std::collections::{HashMap, VecDeque};
17use std::io;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20use tokio::io::{AsyncRead, AsyncWrite};
21use tokio_util::codec::Framed;
22
23pub struct StartupStream<S, T> {
24    inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
25    buf: BackendMessages,
26    delayed: VecDeque<BackendMessage>,
27}
28
29impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
30where
31    S: AsyncRead + AsyncWrite + Unpin,
32    T: AsyncRead + AsyncWrite + Unpin,
33{
34    type Error = io::Error;
35
36    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
37        Pin::new(&mut self.inner).poll_ready(cx)
38    }
39
40    fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
41        Pin::new(&mut self.inner).start_send(item)
42    }
43
44    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
45        Pin::new(&mut self.inner).poll_flush(cx)
46    }
47
48    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
49        Pin::new(&mut self.inner).poll_close(cx)
50    }
51}
52
53impl<S, T> Stream for StartupStream<S, T>
54where
55    S: AsyncRead + AsyncWrite + Unpin,
56    T: AsyncRead + AsyncWrite + Unpin,
57{
58    type Item = io::Result<Message>;
59
60    fn poll_next(
61        mut self: Pin<&mut Self>,
62        cx: &mut Context<'_>,
63    ) -> Poll<Option<io::Result<Message>>> {
64        loop {
65            match self.buf.next() {
66                Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
67                Ok(None) => {}
68                Err(e) => return Poll::Ready(Some(Err(e))),
69            }
70
71            match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
72                Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
73                Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
74                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
75                None => return Poll::Ready(None),
76            }
77        }
78    }
79}
80
81pub async fn connect_raw<S, T>(
82    stream: S,
83    tls: T,
84    config: &Config,
85) -> Result<(Client, Connection<S, T::Stream>), Error>
86where
87    S: AsyncRead + AsyncWrite + Unpin,
88    T: TlsConnect<S>,
89{
90    let stream = connect_tls(stream, config.ssl_mode, tls).await?;
91
92    let mut stream = StartupStream {
93        inner: Framed::new(stream, PostgresCodec),
94        buf: BackendMessages::empty(),
95        delayed: VecDeque::new(),
96    };
97
98    startup(&mut stream, config).await?;
99    authenticate(&mut stream, config).await?;
100    let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
101
102    let (sender, receiver) = mpsc::unbounded();
103    let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
104    let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
105
106    Ok((client, connection))
107}
108
109async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
110where
111    S: AsyncRead + AsyncWrite + Unpin,
112    T: AsyncRead + AsyncWrite + Unpin,
113{
114    let mut params = vec![("client_encoding", "UTF8")];
115    if let Some(user) = &config.user {
116        params.push(("user", &**user));
117    }
118    if let Some(dbname) = &config.dbname {
119        params.push(("database", &**dbname));
120    }
121    if let Some(options) = &config.options {
122        params.push(("options", &**options));
123    }
124    if let Some(application_name) = &config.application_name {
125        params.push(("application_name", &**application_name));
126    }
127
128    let mut buf = BytesMut::new();
129    frontend::startup_message(params, &mut buf).map_err(Error::encode)?;
130
131    stream
132        .send(FrontendMessage::Raw(buf.freeze()))
133        .await
134        .map_err(Error::io)
135}
136
137async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
138where
139    S: AsyncRead + AsyncWrite + Unpin,
140    T: TlsStream + Unpin,
141{
142    match stream.try_next().await.map_err(Error::io)? {
143        Some(Message::AuthenticationOk) => {
144            can_skip_channel_binding(config)?;
145            return Ok(());
146        }
147        Some(Message::AuthenticationCleartextPassword) => {
148            can_skip_channel_binding(config)?;
149
150            let pass = config
151                .password
152                .as_ref()
153                .ok_or_else(|| Error::config("password missing".into()))?;
154
155            authenticate_password(stream, pass).await?;
156        }
157        Some(Message::AuthenticationMd5Password(body)) => {
158            can_skip_channel_binding(config)?;
159
160            let user = config
161                .user
162                .as_ref()
163                .ok_or_else(|| Error::config("user missing".into()))?;
164            let pass = config
165                .password
166                .as_ref()
167                .ok_or_else(|| Error::config("password missing".into()))?;
168
169            let output = authentication::md5_hash(user.as_bytes(), pass, body.salt());
170            authenticate_password(stream, output.as_bytes()).await?;
171        }
172        Some(Message::AuthenticationSasl(body)) => {
173            authenticate_sasl(stream, body, config).await?;
174        }
175        Some(Message::AuthenticationKerberosV5)
176        | Some(Message::AuthenticationScmCredential)
177        | Some(Message::AuthenticationGss)
178        | Some(Message::AuthenticationSspi) => {
179            return Err(Error::authentication(
180                "unsupported authentication method".into(),
181            ))
182        }
183        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
184        Some(_) => return Err(Error::unexpected_message()),
185        None => return Err(Error::closed()),
186    }
187
188    match stream.try_next().await.map_err(Error::io)? {
189        Some(Message::AuthenticationOk) => Ok(()),
190        Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
191        Some(_) => Err(Error::unexpected_message()),
192        None => Err(Error::closed()),
193    }
194}
195
196fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
197    match config.channel_binding {
198        config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
199        config::ChannelBinding::Require => Err(Error::authentication(
200            "server did not use channel binding".into(),
201        )),
202    }
203}
204
205async fn authenticate_password<S, T>(
206    stream: &mut StartupStream<S, T>,
207    password: &[u8],
208) -> Result<(), Error>
209where
210    S: AsyncRead + AsyncWrite + Unpin,
211    T: AsyncRead + AsyncWrite + Unpin,
212{
213    let mut buf = BytesMut::new();
214    frontend::password_message(password, &mut buf).map_err(Error::encode)?;
215
216    stream
217        .send(FrontendMessage::Raw(buf.freeze()))
218        .await
219        .map_err(Error::io)
220}
221
222async fn authenticate_sasl<S, T>(
223    stream: &mut StartupStream<S, T>,
224    body: AuthenticationSaslBody,
225    config: &Config,
226) -> Result<(), Error>
227where
228    S: AsyncRead + AsyncWrite + Unpin,
229    T: TlsStream + Unpin,
230{
231    let password = config
232        .password
233        .as_ref()
234        .ok_or_else(|| Error::config("password missing".into()))?;
235
236    let mut has_scram = false;
237    let mut has_scram_plus = false;
238    let mut mechanisms = body.mechanisms();
239    while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
240        match mechanism {
241            sasl::SCRAM_SHA_256 => has_scram = true,
242            sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
243            _ => {}
244        }
245    }
246
247    let channel_binding = stream
248        .inner
249        .get_ref()
250        .channel_binding()
251        .tls_server_end_point
252        .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
253        .map(sasl::ChannelBinding::tls_server_end_point);
254
255    let (channel_binding, mechanism) = if has_scram_plus {
256        match channel_binding {
257            Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
258            None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
259        }
260    } else if has_scram {
261        match channel_binding {
262            Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
263            None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
264        }
265    } else {
266        return Err(Error::authentication("unsupported SASL mechanism".into()));
267    };
268
269    if mechanism != sasl::SCRAM_SHA_256_PLUS {
270        can_skip_channel_binding(config)?;
271    }
272
273    let mut scram = ScramSha256::new(password, channel_binding);
274
275    let mut buf = BytesMut::new();
276    frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
277    stream
278        .send(FrontendMessage::Raw(buf.freeze()))
279        .await
280        .map_err(Error::io)?;
281
282    let body = match stream.try_next().await.map_err(Error::io)? {
283        Some(Message::AuthenticationSaslContinue(body)) => body,
284        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
285        Some(_) => return Err(Error::unexpected_message()),
286        None => return Err(Error::closed()),
287    };
288
289    scram
290        .update(body.data())
291        .map_err(|e| Error::authentication(e.into()))?;
292
293    let mut buf = BytesMut::new();
294    frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
295    stream
296        .send(FrontendMessage::Raw(buf.freeze()))
297        .await
298        .map_err(Error::io)?;
299
300    let body = match stream.try_next().await.map_err(Error::io)? {
301        Some(Message::AuthenticationSaslFinal(body)) => body,
302        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
303        Some(_) => return Err(Error::unexpected_message()),
304        None => return Err(Error::closed()),
305    };
306
307    scram
308        .finish(body.data())
309        .map_err(|e| Error::authentication(e.into()))?;
310
311    Ok(())
312}
313
314async fn read_info<S, T>(
315    stream: &mut StartupStream<S, T>,
316) -> Result<(i32, i32, HashMap<String, String>), Error>
317where
318    S: AsyncRead + AsyncWrite + Unpin,
319    T: AsyncRead + AsyncWrite + Unpin,
320{
321    let mut process_id = 0;
322    let mut secret_key = 0;
323    let mut parameters = HashMap::new();
324
325    loop {
326        match stream.try_next().await.map_err(Error::io)? {
327            Some(Message::BackendKeyData(body)) => {
328                process_id = body.process_id();
329                secret_key = body.secret_key();
330            }
331            Some(Message::ParameterStatus(body)) => {
332                parameters.insert(
333                    body.name().map_err(Error::parse)?.to_string(),
334                    body.value().map_err(Error::parse)?.to_string(),
335                );
336            }
337            Some(msg @ Message::NoticeResponse(_)) => {
338                stream.delayed.push_back(BackendMessage::Async(msg))
339            }
340            Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
341            Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
342            Some(_) => return Err(Error::unexpected_message()),
343            None => return Err(Error::closed()),
344        }
345    }
346}