1use super::stream::PgStream;
16use super::{PgError, PgResult};
17use super::notification::Notification;
18use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
19use bytes::BytesMut;
20use lru::LruCache;
21use std::collections::{HashMap, VecDeque};
22use std::num::NonZeroUsize;
23use std::sync::Arc;
24use tokio::io::AsyncWriteExt;
25use tokio::net::TcpStream;
26
/// Initial capacity, in bytes (64 KiB), of each connection's `buffer` and `write_buf`.
pub(crate) const BUFFER_CAPACITY: usize = 64 * 1024;
29
/// SSLRequest packet: a big-endian length of 8 followed by the request code
/// 80877103 (`1234 << 16 | 5679`). The server answers with a single byte.
const SSL_REQUEST: [u8; 8] = {
    let len = 8i32.to_be_bytes();
    let code = 80877103i32.to_be_bytes();
    [len[0], len[1], len[2], len[3], code[0], code[1], code[2], code[3]]
};
32
/// Request code for a CancelRequest packet (`1234 << 16 | 5678` = 80877102).
pub(crate) const CANCEL_REQUEST_CODE: i32 = (1234 << 16) | 5678;
35
/// Upper bound (10 s) that the timed `connect*` constructors allow for
/// establishing the transport plus completing the startup handshake.
pub(crate) const DEFAULT_CONNECT_TIMEOUT: std::time::Duration =
    std::time::Duration::from_millis(10_000);
39
/// Client-side material for mutual-TLS connections, held as raw PEM bytes.
///
/// `Debug` is implemented by hand so that formatting a config (e.g. in logs)
/// never prints the private key.
#[derive(Clone)]
pub struct TlsConfig {
    /// PEM-encoded client certificate chain presented to the server.
    pub client_cert_pem: Vec<u8>,
    /// PEM-encoded private key matching `client_cert_pem`.
    pub client_key_pem: Vec<u8>,
    /// Optional PEM-encoded CA bundle used to verify the server; when `None`
    /// the platform trust store is expected to be used instead.
    pub ca_cert_pem: Option<Vec<u8>>,
}

impl std::fmt::Debug for TlsConfig {
    /// Shows only PEM sizes and redacts the private key.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TlsConfig")
            .field("client_cert_pem_len", &self.client_cert_pem.len())
            .field("client_key_pem", &"<redacted>")
            .field("ca_cert_pem_len", &self.ca_cert_pem.as_ref().map(Vec::len))
            .finish()
    }
}
50
51impl TlsConfig {
52 pub fn from_files(
54 cert_path: impl AsRef<std::path::Path>,
55 key_path: impl AsRef<std::path::Path>,
56 ca_path: Option<impl AsRef<std::path::Path>>,
57 ) -> std::io::Result<Self> {
58 Ok(Self {
59 client_cert_pem: std::fs::read(cert_path)?,
60 client_key_pem: std::fs::read(key_path)?,
61 ca_cert_pem: ca_path.map(|p| std::fs::read(p)).transpose()?,
62 })
63 }
64}
65
/// A single PostgreSQL client connection over TCP, TLS, or (on Unix) a Unix
/// domain socket, together with its per-connection buffers and caches.
pub struct PgConnection {
    /// Underlying transport; the constructors create `Tcp`, `Tls`, or `Unix` variants.
    pub(crate) stream: PgStream,
    /// Pre-allocated with `BUFFER_CAPACITY` bytes; presumably the inbound read buffer.
    pub(crate) buffer: BytesMut,
    /// Pre-allocated with `BUFFER_CAPACITY` bytes; presumably staging for outgoing messages.
    pub(crate) write_buf: BytesMut,
    /// Scratch buffer (512 bytes initially); presumably reused when building SQL text.
    pub(crate) sql_buf: BytesMut,
    /// Reusable parameter list (capacity 16); `None` entries presumably encode SQL NULL.
    pub(crate) params_buf: Vec<Option<Vec<u8>>>,
    /// Client-side record of prepared statements. `evict_prepared_if_full`
    /// removes entries from it using names taken from `stmt_cache`.
    pub(crate) prepared_statements: HashMap<String, String>,
    /// LRU of hash -> statement name (capacity 100 at construction); used to
    /// pick which prepared statement to evict.
    pub(crate) stmt_cache: LruCache<u64, String>,
    /// Cached column metadata keyed by a `u64` (presumably a statement hash — confirm).
    pub(crate) column_info_cache: HashMap<u64, Arc<super::ColumnInfo>>,
    /// Backend process id from BackendKeyData; 0 until the server sends it.
    pub(crate) process_id: i32,
    /// Secret key from BackendKeyData; 0 until the server sends it.
    pub(crate) secret_key: i32,
    /// Queue of received notifications (presumably LISTEN/NOTIFY payloads).
    pub(crate) notifications: VecDeque<Notification>,
}
85
86impl PgConnection {
87 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
89 Self::connect_with_password(host, port, user, database, None).await
90 }
91
92 pub async fn connect_with_password(
95 host: &str,
96 port: u16,
97 user: &str,
98 database: &str,
99 password: Option<&str>,
100 ) -> PgResult<Self> {
101 tokio::time::timeout(
102 DEFAULT_CONNECT_TIMEOUT,
103 Self::connect_with_password_inner(host, port, user, database, password),
104 )
105 .await
106 .map_err(|_| PgError::Connection(format!(
107 "Connection timeout after {:?} (TCP connect + handshake)",
108 DEFAULT_CONNECT_TIMEOUT
109 )))?
110 }
111
112 async fn connect_with_password_inner(
114 host: &str,
115 port: u16,
116 user: &str,
117 database: &str,
118 password: Option<&str>,
119 ) -> PgResult<Self> {
120 let addr = format!("{}:{}", host, port);
121 let tcp_stream = TcpStream::connect(&addr).await?;
122
123 tcp_stream.set_nodelay(true)?;
125
126 let mut conn = Self {
127 stream: PgStream::Tcp(tcp_stream),
128 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
129 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), sql_buf: BytesMut::with_capacity(512),
131 params_buf: Vec::with_capacity(16), prepared_statements: HashMap::new(),
133 stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
134 column_info_cache: HashMap::new(),
135 process_id: 0,
136 secret_key: 0,
137 notifications: VecDeque::new(),
138 };
139
140 conn.send(FrontendMessage::Startup {
141 user: user.to_string(),
142 database: database.to_string(),
143 })
144 .await?;
145
146 conn.handle_startup(user, password).await?;
147
148 Ok(conn)
149 }
150
151 pub async fn connect_tls(
154 host: &str,
155 port: u16,
156 user: &str,
157 database: &str,
158 password: Option<&str>,
159 ) -> PgResult<Self> {
160 tokio::time::timeout(
161 DEFAULT_CONNECT_TIMEOUT,
162 Self::connect_tls_inner(host, port, user, database, password),
163 )
164 .await
165 .map_err(|_| PgError::Connection(format!(
166 "TLS connection timeout after {:?}",
167 DEFAULT_CONNECT_TIMEOUT
168 )))?
169 }
170
171 async fn connect_tls_inner(
173 host: &str,
174 port: u16,
175 user: &str,
176 database: &str,
177 password: Option<&str>,
178 ) -> PgResult<Self> {
179 use tokio::io::AsyncReadExt;
180 use tokio_rustls::TlsConnector;
181 use tokio_rustls::rustls::ClientConfig;
182 use tokio_rustls::rustls::pki_types::ServerName;
183
184 let addr = format!("{}:{}", host, port);
185 let mut tcp_stream = TcpStream::connect(&addr).await?;
186
187 tcp_stream.write_all(&SSL_REQUEST).await?;
189
190 let mut response = [0u8; 1];
192 tcp_stream.read_exact(&mut response).await?;
193
194 if response[0] != b'S' {
195 return Err(PgError::Connection(
196 "Server does not support TLS".to_string(),
197 ));
198 }
199
200 let certs = rustls_native_certs::load_native_certs();
202 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
203 for cert in certs.certs {
204 let _ = root_cert_store.add(cert);
205 }
206
207 let config = ClientConfig::builder()
208 .with_root_certificates(root_cert_store)
209 .with_no_client_auth();
210
211 let connector = TlsConnector::from(Arc::new(config));
212 let server_name = ServerName::try_from(host.to_string())
213 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
214
215 let tls_stream = connector
216 .connect(server_name, tcp_stream)
217 .await
218 .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
219
220 let mut conn = Self {
221 stream: PgStream::Tls(tls_stream),
222 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
223 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
224 sql_buf: BytesMut::with_capacity(512),
225 params_buf: Vec::with_capacity(16),
226 prepared_statements: HashMap::new(),
227 stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
228 column_info_cache: HashMap::new(),
229 process_id: 0,
230 secret_key: 0,
231 notifications: VecDeque::new(),
232 };
233
234 conn.send(FrontendMessage::Startup {
235 user: user.to_string(),
236 database: database.to_string(),
237 })
238 .await?;
239
240 conn.handle_startup(user, password).await?;
241
242 Ok(conn)
243 }
244
245 pub async fn connect_mtls(
262 host: &str,
263 port: u16,
264 user: &str,
265 database: &str,
266 config: TlsConfig,
267 ) -> PgResult<Self> {
268 use tokio::io::AsyncReadExt;
269 use tokio_rustls::TlsConnector;
270 use tokio_rustls::rustls::{
271 ClientConfig,
272 pki_types::{CertificateDer, ServerName},
273 };
274
275 let addr = format!("{}:{}", host, port);
276 let mut tcp_stream = TcpStream::connect(&addr).await?;
277
278 tcp_stream.write_all(&SSL_REQUEST).await?;
280
281 let mut response = [0u8; 1];
283 tcp_stream.read_exact(&mut response).await?;
284
285 if response[0] != b'S' {
286 return Err(PgError::Connection(
287 "Server does not support TLS".to_string(),
288 ));
289 }
290
291 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
292
293 if let Some(ca_pem) = &config.ca_cert_pem {
294 let certs = rustls_pemfile::certs(&mut ca_pem.as_slice())
295 .filter_map(|r| r.ok())
296 .collect::<Vec<_>>();
297 for cert in certs {
298 let _ = root_cert_store.add(cert);
299 }
300 } else {
301 let certs = rustls_native_certs::load_native_certs();
303 for cert in certs.certs {
304 let _ = root_cert_store.add(cert);
305 }
306 }
307
308 let client_certs: Vec<CertificateDer<'static>> =
309 rustls_pemfile::certs(&mut config.client_cert_pem.as_slice())
310 .filter_map(|r| r.ok())
311 .collect();
312
313 let client_key = rustls_pemfile::private_key(&mut config.client_key_pem.as_slice())
314 .map_err(|e| PgError::Connection(format!("Invalid client key: {:?}", e)))?
315 .ok_or_else(|| PgError::Connection("No private key found in PEM".to_string()))?;
316
317 let tls_config = ClientConfig::builder()
318 .with_root_certificates(root_cert_store)
319 .with_client_auth_cert(client_certs, client_key)
320 .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
321
322 let connector = TlsConnector::from(Arc::new(tls_config));
323 let server_name = ServerName::try_from(host.to_string())
324 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
325
326 let tls_stream = connector
327 .connect(server_name, tcp_stream)
328 .await
329 .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
330
331 let mut conn = Self {
332 stream: PgStream::Tls(tls_stream),
333 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
334 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
335 sql_buf: BytesMut::with_capacity(512),
336 params_buf: Vec::with_capacity(16),
337 prepared_statements: HashMap::new(),
338 stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
339 column_info_cache: HashMap::new(),
340 process_id: 0,
341 secret_key: 0,
342 notifications: VecDeque::new(),
343 };
344
345 conn.send(FrontendMessage::Startup {
346 user: user.to_string(),
347 database: database.to_string(),
348 })
349 .await?;
350
351 conn.handle_startup(user, None).await?;
353
354 Ok(conn)
355 }
356
357 #[cfg(unix)]
359 pub async fn connect_unix(
360 socket_path: &str,
361 user: &str,
362 database: &str,
363 password: Option<&str>,
364 ) -> PgResult<Self> {
365 use tokio::net::UnixStream;
366
367 let unix_stream = UnixStream::connect(socket_path).await?;
368
369 let mut conn = Self {
370 stream: PgStream::Unix(unix_stream),
371 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
372 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
373 sql_buf: BytesMut::with_capacity(512),
374 params_buf: Vec::with_capacity(16),
375 prepared_statements: HashMap::new(),
376 stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
377 column_info_cache: HashMap::new(),
378 process_id: 0,
379 secret_key: 0,
380 notifications: VecDeque::new(),
381 };
382
383 conn.send(FrontendMessage::Startup {
384 user: user.to_string(),
385 database: database.to_string(),
386 })
387 .await?;
388
389 conn.handle_startup(user, password).await?;
390
391 Ok(conn)
392 }
393
394 async fn handle_startup(&mut self, user: &str, password: Option<&str>) -> PgResult<()> {
396 let mut scram_client: Option<ScramClient> = None;
397
398 loop {
399 let msg = self.recv().await?;
400 match msg {
401 BackendMessage::AuthenticationOk => {}
402 BackendMessage::AuthenticationMD5Password(_salt) => {
403 return Err(PgError::Auth(
404 "MD5 auth not supported. Use SCRAM-SHA-256.".to_string(),
405 ));
406 }
407 BackendMessage::AuthenticationSASL(mechanisms) => {
408 let password = password.ok_or_else(|| {
409 PgError::Auth("Password required for SCRAM authentication".to_string())
410 })?;
411
412 if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
413 return Err(PgError::Auth(format!(
414 "Server doesn't support SCRAM-SHA-256. Available: {:?}",
415 mechanisms
416 )));
417 }
418
419 let client = ScramClient::new(user, password);
420 let first_message = client.client_first_message();
421
422 self.send(FrontendMessage::SASLInitialResponse {
423 mechanism: "SCRAM-SHA-256".to_string(),
424 data: first_message,
425 })
426 .await?;
427
428 scram_client = Some(client);
429 }
430 BackendMessage::AuthenticationSASLContinue(server_data) => {
431 let client = scram_client.as_mut().ok_or_else(|| {
432 PgError::Auth("Received SASL Continue without SASL init".to_string())
433 })?;
434
435 let final_message = client
436 .process_server_first(&server_data)
437 .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
438
439 self.send(FrontendMessage::SASLResponse(final_message))
440 .await?;
441 }
442 BackendMessage::AuthenticationSASLFinal(server_signature) => {
443 if let Some(client) = scram_client.as_ref() {
444 client.verify_server_final(&server_signature).map_err(|e| {
445 PgError::Auth(format!("Server verification failed: {}", e))
446 })?;
447 }
448 }
449 BackendMessage::ParameterStatus { .. } => {}
450 BackendMessage::BackendKeyData {
451 process_id,
452 secret_key,
453 } => {
454 self.process_id = process_id;
455 self.secret_key = secret_key;
456 }
457 BackendMessage::ReadyForQuery(TransactionStatus::Idle)
458 | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
459 | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
460 return Ok(());
461 }
462 BackendMessage::ErrorResponse(err) => {
463 return Err(PgError::Connection(err.message));
464 }
465 _ => {}
466 }
467 }
468 }
469
470 pub async fn close(mut self) -> PgResult<()> {
473 use crate::protocol::PgEncoder;
474
475 let terminate = PgEncoder::encode_terminate();
477 self.stream.write_all(&terminate).await?;
478 self.stream.flush().await?;
479
480 Ok(())
481 }
482
483 pub(crate) const MAX_PREPARED_PER_CONN: usize = 128;
489
490 pub(crate) fn evict_prepared_if_full(&mut self) {
496 if self.prepared_statements.len() >= Self::MAX_PREPARED_PER_CONN {
497 if let Some((_hash, evicted_name)) = self.stmt_cache.pop_lru() {
499 self.prepared_statements.remove(&evicted_name);
500 } else {
501 if let Some(key) = self.prepared_statements.keys().next().cloned() {
505 self.prepared_statements.remove(&key);
506 }
507 }
508 }
509 }
510}
511
512impl Drop for PgConnection {
515 fn drop(&mut self) {
516 let terminate: [u8; 5] = [b'X', 0, 0, 0, 4];
519
520 match &mut self.stream {
521 PgStream::Tcp(tcp) => {
522 let _ = tcp.try_write(&terminate);
524 }
525 PgStream::Tls(_) => {
526 }
530 #[cfg(unix)]
531 PgStream::Unix(unix) => {
532 let _ = unix.try_write(&terminate);
533 }
534 }
535 }
536}
537
/// Extracts the row count from a command-completion tag such as
/// `"INSERT 0 5"` or `"UPDATE 3"`.
///
/// The count is the last whitespace-separated word parsed as `u64`. Returns 0
/// for an empty tag or when that word is not a number (e.g. `"CREATE TABLE"`).
pub(crate) fn parse_affected_rows(tag: &str) -> u64 {
    match tag.split_whitespace().next_back() {
        Some(word) => word.parse::<u64>().unwrap_or(0),
        None => 0,
    }
}