1use bytes::BytesMut;
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::TcpStream;
8
9use crate::error::PgWireError;
10use crate::protocol::backend;
11use crate::protocol::frontend;
12use crate::protocol::types::{BackendMsg, FrontendMsg, RawRow};
13use crate::scram::ScramClient;
14use crate::tls::{MaybeTlsStream, TlsMode};
15
16pub struct WireConn {
19 pub(crate) stream: MaybeTlsStream,
20 recv_buf: BytesMut,
21 pub(crate) pid: i32,
22 pub(crate) secret: i32,
23 pub params: std::collections::HashMap<String, String>,
37 pub(crate) auth_mechanism: &'static str,
43}
44
45impl WireConn {
46 pub fn pid(&self) -> i32 {
51 self.pid
52 }
53
54 pub fn auth_mechanism(&self) -> &'static str {
60 self.auth_mechanism
61 }
62}
63
64impl std::fmt::Debug for WireConn {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("WireConn")
67 .field("pid", &self.pid)
68 .field("params", &self.params)
69 .finish_non_exhaustive()
70 }
71}
72
73const RECV_BUF_SIZE: usize = 32 * 1024; impl WireConn {
76 #[allow(clippy::result_large_err)]
79 fn choose_scram_mechanism(
80 &self,
81 mechanisms: &[String],
82 ) -> Result<(crate::scram::ChannelBinding, &'static [u8], &'static str), PgWireError> {
83 #[cfg(feature = "tls")]
85 if let MaybeTlsStream::Tls(ref tls) = self.stream {
86 if mechanisms.iter().any(|m| m == "SCRAM-SHA-256-PLUS") {
87 if let Some(certs) = tls.get_ref().1.peer_certificates() {
88 if let Some(cert) = certs.first() {
89 let hash = crate::cert_hash::cert_signature_hash(cert.as_ref());
90 return Ok((
91 crate::scram::ChannelBinding::TlsServerEndPoint(hash),
92 b"SCRAM-SHA-256-PLUS",
93 "SCRAM-SHA-256-PLUS",
94 ));
95 }
96 }
97 }
98 }
99
100 if mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
102 Ok((
103 crate::scram::ChannelBinding::None,
104 b"SCRAM-SHA-256",
105 "SCRAM-SHA-256",
106 ))
107 } else {
108 Err(PgWireError::Protocol(format!(
109 "No supported SASL mechanism: {:?}",
110 mechanisms
111 )))
112 }
113 }
114
115 pub fn has_pending_data(&self) -> bool {
117 !self.recv_buf.is_empty()
118 }
119
120 pub async fn connect(
122 addr: &str,
123 user: &str,
124 password: &str,
125 database: &str,
126 ) -> Result<Self, PgWireError> {
127 Self::connect_with_options(addr, user, password, database, &[], TlsMode::default()).await
128 }
129
130 pub async fn connect_with_params(
145 addr: &str,
146 user: &str,
147 password: &str,
148 database: &str,
149 startup_params: &[(&str, &str)],
150 ) -> Result<Self, PgWireError> {
151 Self::connect_with_options(
152 addr,
153 user,
154 password,
155 database,
156 startup_params,
157 TlsMode::default(),
158 )
159 .await
160 }
161
162 pub async fn connect_with_options(
169 addr: &str,
170 user: &str,
171 password: &str,
172 database: &str,
173 startup_params: &[(&str, &str)],
174 tls_mode: TlsMode,
175 ) -> Result<Self, PgWireError> {
176 #[cfg(feature = "tls")]
177 {
178 Self::connect_with_tls_config(
179 addr,
180 user,
181 password,
182 database,
183 startup_params,
184 tls_mode,
185 &crate::tls::TlsConfig::default(),
186 )
187 .await
188 }
189 #[cfg(not(feature = "tls"))]
190 {
191 Self::connect_inner(addr, user, password, database, startup_params, tls_mode).await
192 }
193 }
194
195 #[cfg(feature = "tls")]
200 pub async fn connect_with_tls_config(
201 addr: &str,
202 user: &str,
203 password: &str,
204 database: &str,
205 startup_params: &[(&str, &str)],
206 tls_mode: TlsMode,
207 tls_config: &crate::tls::TlsConfig,
208 ) -> Result<Self, PgWireError> {
209 Self::connect_inner(
210 addr,
211 user,
212 password,
213 database,
214 startup_params,
215 tls_mode,
216 tls_config,
217 )
218 .await
219 }
220
221 #[cfg(feature = "tls")]
222 async fn connect_inner(
223 addr: &str,
224 user: &str,
225 password: &str,
226 database: &str,
227 startup_params: &[(&str, &str)],
228 tls_mode: TlsMode,
229 tls_config: &crate::tls::TlsConfig,
230 ) -> Result<Self, PgWireError> {
231 let stream = TcpStream::connect(addr).await?;
232 stream.set_nodelay(true)?;
233
234 let socket = socket2::SockRef::from(&stream);
235 let keepalive = socket2::TcpKeepalive::new()
236 .with_time(std::time::Duration::from_secs(60))
237 .with_interval(std::time::Duration::from_secs(15));
238 let _ = socket.set_tcp_keepalive(&keepalive);
239
240 let hostname = parse_hostname(addr);
241 let stream =
242 crate::tls::negotiate_tls_with_config(stream, &hostname, tls_config, tls_mode).await?;
243
244 Self::finish_startup(stream, user, password, database, startup_params).await
245 }
246
247 #[cfg(not(feature = "tls"))]
248 async fn connect_inner(
249 addr: &str,
250 user: &str,
251 password: &str,
252 database: &str,
253 startup_params: &[(&str, &str)],
254 tls_mode: TlsMode,
255 ) -> Result<Self, PgWireError> {
256 let stream = TcpStream::connect(addr).await?;
257 stream.set_nodelay(true)?;
258
259 let socket = socket2::SockRef::from(&stream);
260 let keepalive = socket2::TcpKeepalive::new()
261 .with_time(std::time::Duration::from_secs(60))
262 .with_interval(std::time::Duration::from_secs(15));
263 let _ = socket.set_tcp_keepalive(&keepalive);
264
265 if tls_mode == TlsMode::Require {
266 return Err(PgWireError::Protocol(
267 "sslmode=require but pg-wired was built without the `tls` feature".into(),
268 ));
269 }
270 let stream = MaybeTlsStream::Plain(stream);
271
272 Self::finish_startup(stream, user, password, database, startup_params).await
273 }
274
275 async fn finish_startup(
276 stream: MaybeTlsStream,
277 user: &str,
278 password: &str,
279 database: &str,
280 startup_params: &[(&str, &str)],
281 ) -> Result<Self, PgWireError> {
282 let mut conn = WireConn {
283 stream,
284 recv_buf: BytesMut::with_capacity(RECV_BUF_SIZE),
285 pid: 0,
286 secret: 0,
287 params: std::collections::HashMap::new(),
288 auth_mechanism: "trust",
291 };
292
293 let mut buf = BytesMut::new();
295 frontend::encode_startup_with_params(user, database, startup_params, &mut buf);
296 conn.send_raw(&buf).await?;
297
298 loop {
300 let msg = conn.recv_msg().await?;
301 match msg {
302 BackendMsg::AuthenticationOk => {}
303 BackendMsg::AuthenticationCleartextPassword => {
304 conn.auth_mechanism = "cleartext";
305 let mut buf = BytesMut::new();
306 frontend::encode_password(password.as_bytes(), &mut buf);
307 conn.send_raw(&buf).await?;
308 }
309 BackendMsg::AuthenticationMd5Password { salt } => {
310 conn.auth_mechanism = "md5";
311 let hash = frontend::md5_password(user, password, &salt);
312 let mut buf = BytesMut::new();
313 frontend::encode_password(&hash, &mut buf);
314 conn.send_raw(&buf).await?;
315 }
316 BackendMsg::AuthenticationSASL { mechanisms } => {
317 let (cb, mechanism, name) = conn.choose_scram_mechanism(&mechanisms)?;
319 conn.auth_mechanism = name;
320 let (scram, client_first) = ScramClient::new(password, cb);
321 let mut buf = BytesMut::new();
322 frontend::encode_message(
323 &FrontendMsg::SASLInitialResponse {
324 mechanism,
325 data: &client_first,
326 },
327 &mut buf,
328 );
329 conn.send_raw(&buf).await?;
330
331 let server_first = loop {
333 match conn.recv_msg().await? {
334 BackendMsg::AuthenticationSASLContinue { data } => break data,
335 BackendMsg::ErrorResponse { fields } => {
336 return Err(PgWireError::Pg(fields));
337 }
338 _ => {}
339 }
340 };
341
342 let client_final = scram
343 .process_server_first(&server_first)
344 .map_err(PgWireError::Protocol)?;
345 let mut buf = BytesMut::new();
346 frontend::encode_message(&FrontendMsg::SASLResponse(&client_final), &mut buf);
347 conn.send_raw(&buf).await?;
348
349 loop {
351 match conn.recv_msg().await? {
352 BackendMsg::AuthenticationSASLFinal { .. } => {}
353 BackendMsg::AuthenticationOk => break,
354 BackendMsg::ErrorResponse { fields } => {
355 return Err(PgWireError::Pg(fields));
356 }
357 _ => {}
358 }
359 }
360 }
361 BackendMsg::ParameterStatus { name, value } => {
362 tracing::debug!(name = %name, value = %value, "server parameter");
363 conn.params.insert(name, value);
364 }
365 BackendMsg::BackendKeyData { pid, secret } => {
366 conn.pid = pid;
367 conn.secret = secret;
368 }
369 BackendMsg::ReadyForQuery { .. } => break,
370 BackendMsg::ErrorResponse { fields } => {
371 return Err(PgWireError::Pg(fields));
372 }
373 BackendMsg::NoticeResponse { .. } => {}
374 other => {
375 tracing::debug!("Startup: ignoring {:?}", other);
376 }
377 }
378 }
379
380 Ok(conn)
381 }
382
383 pub async fn send_raw(&mut self, buf: &[u8]) -> Result<(), PgWireError> {
385 self.stream.write_all(buf).await?;
386 Ok(())
387 }
388
389 pub async fn recv_msg(&mut self) -> Result<BackendMsg, PgWireError> {
392 loop {
393 if let Some(msg) =
395 backend::parse_message(&mut self.recv_buf).map_err(PgWireError::Protocol)?
396 {
397 return Ok(msg);
398 }
399
400 let n = self.stream.read_buf(&mut self.recv_buf).await?;
402 if n == 0 {
403 if let Some(msg) =
405 backend::parse_message(&mut self.recv_buf).map_err(PgWireError::Protocol)?
406 {
407 return Ok(msg);
408 }
409 return Err(PgWireError::ConnectionClosed);
410 }
411 }
412 }
413
414 pub async fn collect_rows(&mut self) -> Result<(Vec<RawRow>, String), PgWireError> {
417 let mut rows = Vec::new();
418 let mut tag = String::new();
419
420 loop {
421 let msg = self.recv_msg().await?;
422 match msg {
423 BackendMsg::DataRow(row) => {
424 tracing::trace!("collect_rows: DataRow with {} cols", row.len());
425 rows.push(row);
426 }
427 BackendMsg::CommandComplete { tag: t } => tag = t,
428 BackendMsg::ReadyForQuery { .. } => return Ok((rows, tag)),
429 BackendMsg::ParseComplete | BackendMsg::BindComplete | BackendMsg::NoData => {}
430 BackendMsg::RowDescription { .. } => {}
431 BackendMsg::ErrorResponse { fields } => {
432 self.drain_until_ready().await?;
434 return Err(PgWireError::Pg(fields));
435 }
436 BackendMsg::NoticeResponse { .. } => {}
437 BackendMsg::EmptyQueryResponse => {}
438 _ => {}
439 }
440 }
441 }
442
443 pub async fn describe_statement(
447 &mut self,
448 sql: &str,
449 ) -> Result<(Vec<u32>, Vec<crate::protocol::types::FieldDescription>), PgWireError> {
450 use crate::protocol::frontend;
451 use crate::protocol::types::FrontendMsg;
452 let mut buf = bytes::BytesMut::with_capacity(256);
453
454 frontend::encode_message(
456 &FrontendMsg::Parse {
457 name: b"",
458 sql: sql.as_bytes(),
459 param_oids: &[],
460 },
461 &mut buf,
462 );
463 frontend::encode_message(
465 &FrontendMsg::Describe {
466 kind: b'S',
467 name: b"",
468 },
469 &mut buf,
470 );
471 frontend::encode_message(&FrontendMsg::Sync, &mut buf);
473
474 self.send_raw(&buf).await?;
475
476 let mut param_oids = Vec::new();
477 let mut fields = Vec::new();
478
479 loop {
480 let msg = self.recv_msg().await?;
481 match msg {
482 BackendMsg::ParseComplete => {}
483 BackendMsg::ParameterDescription { type_oids } => {
484 param_oids = type_oids;
485 }
486 BackendMsg::RowDescription { fields: f } => {
487 fields = f;
488 }
489 BackendMsg::NoData => {} BackendMsg::ReadyForQuery { .. } => {
491 return Ok((param_oids, fields));
492 }
493 BackendMsg::ErrorResponse { fields } => {
494 self.drain_until_ready().await?;
495 return Err(PgWireError::Pg(fields));
496 }
497 _ => {}
498 }
499 }
500 }
501
502 pub async fn drain_until_ready(&mut self) -> Result<(), PgWireError> {
506 loop {
507 let msg = self.recv_msg().await?;
508 if matches!(msg, BackendMsg::ReadyForQuery { .. }) {
509 return Ok(());
510 }
511 if let BackendMsg::ErrorResponse { ref fields } = msg {
513 tracing::warn!("Error in drain: {}: {}", fields.code, fields.message);
514 }
515 }
516 }
517}
518
/// Extracts the host part of a `host:port` address string.
///
/// For a bracketed IPv6 literal such as `[::1]:5432`, the text between the
/// brackets is returned. Otherwise the result is everything before the
/// first `:`, or the whole input if it contains no `:`.
#[cfg(any(feature = "tls", test))]
fn parse_hostname(addr: &str) -> String {
    let bracketed = addr
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'))
        .map(|(host, _after)| host);
    let host = match bracketed {
        Some(inner) => inner,
        None => addr.split_once(':').map_or(addr, |(head, _)| head),
    };
    host.to_owned()
}
532
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_hostname_ipv4() {
        assert_eq!(parse_hostname("127.0.0.1:5432"), "127.0.0.1");
    }

    #[test]
    fn test_parse_hostname_name() {
        assert_eq!(parse_hostname("localhost:5432"), "localhost");
    }

    #[test]
    fn test_parse_hostname_ipv6() {
        assert_eq!(parse_hostname("[::1]:5432"), "::1");
    }

    #[test]
    fn test_parse_hostname_ipv6_full() {
        assert_eq!(parse_hostname("[2001:db8::1]:5432"), "2001:db8::1");
    }

    #[test]
    fn test_parse_hostname_no_port() {
        assert_eq!(parse_hostname("myhost"), "myhost");
    }

    /// A bracketed IPv6 literal without a port still loses its brackets.
    #[test]
    fn test_parse_hostname_ipv6_no_port() {
        assert_eq!(parse_hostname("[::1]"), "::1");
    }

    /// Empty input yields an empty host rather than panicking.
    #[test]
    fn test_parse_hostname_empty() {
        assert_eq!(parse_hostname(""), "");
    }
}