1use super::stream::PgStream;
16use super::{PgError, PgResult};
17use super::notification::Notification;
18use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
19use bytes::BytesMut;
20use lru::LruCache;
21use std::collections::{HashMap, VecDeque};
22use std::num::NonZeroUsize;
23use std::sync::Arc;
24use tokio::io::AsyncWriteExt;
25use tokio::net::TcpStream;
26
/// Capacity of each connection's statement LRU cache (`stmt_cache`).
const STMT_CACHE_CAPACITY: NonZeroUsize = match NonZeroUsize::new(100) {
    Some(capacity) => capacity,
    None => panic!("STMT_CACHE_CAPACITY must be non-zero"),
};

/// Initial capacity, in bytes, of a connection's read and write buffers (64 KiB).
pub(crate) const BUFFER_CAPACITY: usize = 64 * 1024;

/// PostgreSQL `SSLRequest` packet: a big-endian Int32 length (8) followed by the
/// big-endian request code 80877103, i.e. `(1234 << 16) | 5679`.
const SSL_REQUEST: [u8; 8] = {
    let len = 8i32.to_be_bytes();
    let code = ((1234i32 << 16) | 5679).to_be_bytes();
    [len[0], len[1], len[2], len[3], code[0], code[1], code[2], code[3]]
};

/// Request code of a PostgreSQL `CancelRequest` packet: `(1234 << 16) | 5678`.
pub(crate) const CANCEL_REQUEST_CODE: i32 = (1234 << 16) | 5678;

/// Time limit applied to a connection attempt (TCP connect plus startup handshake)
/// by the timeout-wrapped connect helpers.
pub(crate) const DEFAULT_CONNECT_TIMEOUT: std::time::Duration =
    std::time::Duration::from_millis(10_000);
42
/// PEM-encoded key material for mutual-TLS connections (see `PgConnection::connect_mtls`).
#[derive(Clone)]
pub struct TlsConfig {
    /// Client certificate chain, PEM-encoded.
    pub client_cert_pem: Vec<u8>,
    /// Client private key, PEM-encoded.
    pub client_key_pem: Vec<u8>,
    /// Optional CA bundle (PEM) used to verify the server. When `None`,
    /// `connect_mtls` falls back to the platform's native root certificates.
    pub ca_cert_pem: Option<Vec<u8>>,
}

impl TlsConfig {
    /// Reads the client certificate, client key and (optionally) CA bundle from disk.
    ///
    /// Files are read in the order certificate, key, CA. Their contents are stored
    /// verbatim; PEM parsing happens later, at connect time.
    ///
    /// # Errors
    /// Returns the first `std::io::Error` hit while reading any of the files.
    pub fn from_files(
        cert_path: impl AsRef<std::path::Path>,
        key_path: impl AsRef<std::path::Path>,
        ca_path: Option<impl AsRef<std::path::Path>>,
    ) -> std::io::Result<Self> {
        let client_cert_pem = std::fs::read(cert_path)?;
        let client_key_pem = std::fs::read(key_path)?;
        let ca_cert_pem = match ca_path {
            Some(path) => Some(std::fs::read(path)?),
            None => None,
        };
        Ok(TlsConfig {
            client_cert_pem,
            client_key_pem,
            ca_cert_pem,
        })
    }
}
68
/// A PostgreSQL client connection over TCP, TLS, or (on Unix) a Unix-domain socket.
///
/// In this module it is constructed only by the `connect*` functions, which run the
/// startup/authentication handshake before returning the connection.
///
/// Field declaration order also determines drop order (after `Drop::drop` runs).
pub struct PgConnection {
    /// Underlying transport.
    pub(crate) stream: PgStream,
    /// Buffer pre-sized to `BUFFER_CAPACITY`; presumably accumulates bytes read from
    /// the server — NOTE(review): confirm against `recv`.
    pub(crate) buffer: BytesMut,
    /// Buffer pre-sized to `BUFFER_CAPACITY`; presumably stages outgoing messages.
    pub(crate) write_buf: BytesMut,
    /// Scratch buffer (initially 512 bytes); presumably reused for SQL text — confirm.
    pub(crate) sql_buf: BytesMut,
    /// Reusable parameter list (initially 16 slots); `None` presumably encodes SQL NULL.
    pub(crate) params_buf: Vec<Option<Vec<u8>>>,
    /// Client-side record of prepared statements, capped at `MAX_PREPARED_PER_CONN`.
    /// Eviction removes entries by statement name, so keys appear to be names — verify.
    pub(crate) prepared_statements: HashMap<String, String>,
    /// LRU of statement names, capacity `STMT_CACHE_CAPACITY`; keys are `u64`
    /// (presumably a hash of the SQL text — confirm at the insertion site).
    pub(crate) stmt_cache: LruCache<u64, String>,
    /// Cached column metadata keyed by a `u64` (presumably the same hash as `stmt_cache`).
    pub(crate) column_info_cache: HashMap<u64, Arc<super::ColumnInfo>>,
    /// Backend process id from `BackendKeyData` (0 until received).
    pub(crate) process_id: i32,
    /// Backend secret key from `BackendKeyData` (0 until received).
    pub(crate) secret_key: i32,
    /// Queue of received notifications, presumably LISTEN/NOTIFY payloads.
    pub(crate) notifications: VecDeque<Notification>,
}
88
89impl PgConnection {
90 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
99 Self::connect_with_password(host, port, user, database, None).await
100 }
101
102 pub async fn connect_with_password(
105 host: &str,
106 port: u16,
107 user: &str,
108 database: &str,
109 password: Option<&str>,
110 ) -> PgResult<Self> {
111 tokio::time::timeout(
112 DEFAULT_CONNECT_TIMEOUT,
113 Self::connect_with_password_inner(host, port, user, database, password),
114 )
115 .await
116 .map_err(|_| PgError::Connection(format!(
117 "Connection timeout after {:?} (TCP connect + handshake)",
118 DEFAULT_CONNECT_TIMEOUT
119 )))?
120 }
121
122 async fn connect_with_password_inner(
124 host: &str,
125 port: u16,
126 user: &str,
127 database: &str,
128 password: Option<&str>,
129 ) -> PgResult<Self> {
130 let addr = format!("{}:{}", host, port);
131 let tcp_stream = TcpStream::connect(&addr).await?;
132
133 tcp_stream.set_nodelay(true)?;
135
136 let mut conn = Self {
137 stream: PgStream::Tcp(tcp_stream),
138 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
139 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), sql_buf: BytesMut::with_capacity(512),
141 params_buf: Vec::with_capacity(16), prepared_statements: HashMap::new(),
143 stmt_cache: LruCache::new(STMT_CACHE_CAPACITY),
144 column_info_cache: HashMap::new(),
145 process_id: 0,
146 secret_key: 0,
147 notifications: VecDeque::new(),
148 };
149
150 conn.send(FrontendMessage::Startup {
151 user: user.to_string(),
152 database: database.to_string(),
153 })
154 .await?;
155
156 conn.handle_startup(user, password).await?;
157
158 Ok(conn)
159 }
160
161 pub async fn connect_tls(
164 host: &str,
165 port: u16,
166 user: &str,
167 database: &str,
168 password: Option<&str>,
169 ) -> PgResult<Self> {
170 tokio::time::timeout(
171 DEFAULT_CONNECT_TIMEOUT,
172 Self::connect_tls_inner(host, port, user, database, password),
173 )
174 .await
175 .map_err(|_| PgError::Connection(format!(
176 "TLS connection timeout after {:?}",
177 DEFAULT_CONNECT_TIMEOUT
178 )))?
179 }
180
181 async fn connect_tls_inner(
183 host: &str,
184 port: u16,
185 user: &str,
186 database: &str,
187 password: Option<&str>,
188 ) -> PgResult<Self> {
189 use tokio::io::AsyncReadExt;
190 use tokio_rustls::TlsConnector;
191 use tokio_rustls::rustls::ClientConfig;
192 use tokio_rustls::rustls::pki_types::ServerName;
193
194 let addr = format!("{}:{}", host, port);
195 let mut tcp_stream = TcpStream::connect(&addr).await?;
196
197 tcp_stream.write_all(&SSL_REQUEST).await?;
199
200 let mut response = [0u8; 1];
202 tcp_stream.read_exact(&mut response).await?;
203
204 if response[0] != b'S' {
205 return Err(PgError::Connection(
206 "Server does not support TLS".to_string(),
207 ));
208 }
209
210 let certs = rustls_native_certs::load_native_certs();
212 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
213 for cert in certs.certs {
214 let _ = root_cert_store.add(cert);
215 }
216
217 let config = ClientConfig::builder()
218 .with_root_certificates(root_cert_store)
219 .with_no_client_auth();
220
221 let connector = TlsConnector::from(Arc::new(config));
222 let server_name = ServerName::try_from(host.to_string())
223 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
224
225 let tls_stream = connector
226 .connect(server_name, tcp_stream)
227 .await
228 .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
229
230 let mut conn = Self {
231 stream: PgStream::Tls(tls_stream),
232 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
233 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
234 sql_buf: BytesMut::with_capacity(512),
235 params_buf: Vec::with_capacity(16),
236 prepared_statements: HashMap::new(),
237 stmt_cache: LruCache::new(STMT_CACHE_CAPACITY),
238 column_info_cache: HashMap::new(),
239 process_id: 0,
240 secret_key: 0,
241 notifications: VecDeque::new(),
242 };
243
244 conn.send(FrontendMessage::Startup {
245 user: user.to_string(),
246 database: database.to_string(),
247 })
248 .await?;
249
250 conn.handle_startup(user, password).await?;
251
252 Ok(conn)
253 }
254
255 pub async fn connect_mtls(
272 host: &str,
273 port: u16,
274 user: &str,
275 database: &str,
276 config: TlsConfig,
277 ) -> PgResult<Self> {
278 use tokio::io::AsyncReadExt;
279 use tokio_rustls::TlsConnector;
280 use tokio_rustls::rustls::{
281 ClientConfig,
282 pki_types::{CertificateDer, ServerName},
283 };
284
285 let addr = format!("{}:{}", host, port);
286 let mut tcp_stream = TcpStream::connect(&addr).await?;
287
288 tcp_stream.write_all(&SSL_REQUEST).await?;
290
291 let mut response = [0u8; 1];
293 tcp_stream.read_exact(&mut response).await?;
294
295 if response[0] != b'S' {
296 return Err(PgError::Connection(
297 "Server does not support TLS".to_string(),
298 ));
299 }
300
301 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
302
303 if let Some(ca_pem) = &config.ca_cert_pem {
304 let certs = rustls_pemfile::certs(&mut ca_pem.as_slice())
305 .filter_map(|r| r.ok())
306 .collect::<Vec<_>>();
307 for cert in certs {
308 let _ = root_cert_store.add(cert);
309 }
310 } else {
311 let certs = rustls_native_certs::load_native_certs();
313 for cert in certs.certs {
314 let _ = root_cert_store.add(cert);
315 }
316 }
317
318 let client_certs: Vec<CertificateDer<'static>> =
319 rustls_pemfile::certs(&mut config.client_cert_pem.as_slice())
320 .filter_map(|r| r.ok())
321 .collect();
322
323 let client_key = rustls_pemfile::private_key(&mut config.client_key_pem.as_slice())
324 .map_err(|e| PgError::Connection(format!("Invalid client key: {:?}", e)))?
325 .ok_or_else(|| PgError::Connection("No private key found in PEM".to_string()))?;
326
327 let tls_config = ClientConfig::builder()
328 .with_root_certificates(root_cert_store)
329 .with_client_auth_cert(client_certs, client_key)
330 .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
331
332 let connector = TlsConnector::from(Arc::new(tls_config));
333 let server_name = ServerName::try_from(host.to_string())
334 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
335
336 let tls_stream = connector
337 .connect(server_name, tcp_stream)
338 .await
339 .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
340
341 let mut conn = Self {
342 stream: PgStream::Tls(tls_stream),
343 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
344 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
345 sql_buf: BytesMut::with_capacity(512),
346 params_buf: Vec::with_capacity(16),
347 prepared_statements: HashMap::new(),
348 stmt_cache: LruCache::new(STMT_CACHE_CAPACITY),
349 column_info_cache: HashMap::new(),
350 process_id: 0,
351 secret_key: 0,
352 notifications: VecDeque::new(),
353 };
354
355 conn.send(FrontendMessage::Startup {
356 user: user.to_string(),
357 database: database.to_string(),
358 })
359 .await?;
360
361 conn.handle_startup(user, None).await?;
363
364 Ok(conn)
365 }
366
367 #[cfg(unix)]
369 pub async fn connect_unix(
370 socket_path: &str,
371 user: &str,
372 database: &str,
373 password: Option<&str>,
374 ) -> PgResult<Self> {
375 use tokio::net::UnixStream;
376
377 let unix_stream = UnixStream::connect(socket_path).await?;
378
379 let mut conn = Self {
380 stream: PgStream::Unix(unix_stream),
381 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
382 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
383 sql_buf: BytesMut::with_capacity(512),
384 params_buf: Vec::with_capacity(16),
385 prepared_statements: HashMap::new(),
386 stmt_cache: LruCache::new(STMT_CACHE_CAPACITY),
387 column_info_cache: HashMap::new(),
388 process_id: 0,
389 secret_key: 0,
390 notifications: VecDeque::new(),
391 };
392
393 conn.send(FrontendMessage::Startup {
394 user: user.to_string(),
395 database: database.to_string(),
396 })
397 .await?;
398
399 conn.handle_startup(user, password).await?;
400
401 Ok(conn)
402 }
403
404 async fn handle_startup(&mut self, user: &str, password: Option<&str>) -> PgResult<()> {
406 let mut scram_client: Option<ScramClient> = None;
407
408 loop {
409 let msg = self.recv().await?;
410 match msg {
411 BackendMessage::AuthenticationOk => {}
412 BackendMessage::AuthenticationMD5Password(_salt) => {
413 return Err(PgError::Auth(
414 "MD5 auth not supported. Use SCRAM-SHA-256.".to_string(),
415 ));
416 }
417 BackendMessage::AuthenticationSASL(mechanisms) => {
418 let password = password.ok_or_else(|| {
419 PgError::Auth("Password required for SCRAM authentication".to_string())
420 })?;
421
422 if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
423 return Err(PgError::Auth(format!(
424 "Server doesn't support SCRAM-SHA-256. Available: {:?}",
425 mechanisms
426 )));
427 }
428
429 let client = ScramClient::new(user, password);
430 let first_message = client.client_first_message();
431
432 self.send(FrontendMessage::SASLInitialResponse {
433 mechanism: "SCRAM-SHA-256".to_string(),
434 data: first_message,
435 })
436 .await?;
437
438 scram_client = Some(client);
439 }
440 BackendMessage::AuthenticationSASLContinue(server_data) => {
441 let client = scram_client.as_mut().ok_or_else(|| {
442 PgError::Auth("Received SASL Continue without SASL init".to_string())
443 })?;
444
445 let final_message = client
446 .process_server_first(&server_data)
447 .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
448
449 self.send(FrontendMessage::SASLResponse(final_message))
450 .await?;
451 }
452 BackendMessage::AuthenticationSASLFinal(server_signature) => {
453 if let Some(client) = scram_client.as_ref() {
454 client.verify_server_final(&server_signature).map_err(|e| {
455 PgError::Auth(format!("Server verification failed: {}", e))
456 })?;
457 }
458 }
459 BackendMessage::ParameterStatus { .. } => {}
460 BackendMessage::BackendKeyData {
461 process_id,
462 secret_key,
463 } => {
464 self.process_id = process_id;
465 self.secret_key = secret_key;
466 }
467 BackendMessage::ReadyForQuery(TransactionStatus::Idle)
468 | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
469 | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
470 return Ok(());
471 }
472 BackendMessage::ErrorResponse(err) => {
473 return Err(PgError::Connection(err.message));
474 }
475 _ => {}
476 }
477 }
478 }
479
480 pub async fn close(mut self) -> PgResult<()> {
483 use crate::protocol::PgEncoder;
484
485 let terminate = PgEncoder::encode_terminate();
487 self.stream.write_all(&terminate).await?;
488 self.stream.flush().await?;
489
490 Ok(())
491 }
492
493 pub(crate) const MAX_PREPARED_PER_CONN: usize = 128;
499
500 pub(crate) fn evict_prepared_if_full(&mut self) {
506 if self.prepared_statements.len() >= Self::MAX_PREPARED_PER_CONN {
507 if let Some((_hash, evicted_name)) = self.stmt_cache.pop_lru() {
509 self.prepared_statements.remove(&evicted_name);
510 } else {
511 if let Some(key) = self.prepared_statements.keys().next().cloned() {
515 self.prepared_statements.remove(&key);
516 }
517 }
518 }
519 }
520}
521
522impl Drop for PgConnection {
525 fn drop(&mut self) {
526 let terminate: [u8; 5] = [b'X', 0, 0, 0, 4];
529
530 match &mut self.stream {
531 PgStream::Tcp(tcp) => {
532 let _ = tcp.try_write(&terminate);
534 }
535 PgStream::Tls(_) => {
536 }
540 #[cfg(unix)]
541 PgStream::Unix(unix) => {
542 let _ = unix.try_write(&terminate);
543 }
544 }
545 }
546}
547
/// Extracts the row count from a `CommandComplete` tag such as `"INSERT 0 5"`
/// or `"UPDATE 3"`.
///
/// The count is the last whitespace-separated word. Returns 0 when the tag is
/// empty or its last word is not a non-negative integer (e.g. `"BEGIN"`).
pub(crate) fn parse_affected_rows(tag: &str) -> u64 {
    match tag.split_whitespace().next_back() {
        Some(last_word) => last_word.parse::<u64>().unwrap_or(0),
        None => 0,
    }
}