1use super::stream::PgStream;
16use super::{PgError, PgResult};
17use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
18use bytes::BytesMut;
19use lru::LruCache;
20use std::collections::HashMap;
21use std::num::NonZeroUsize;
22use std::sync::Arc;
23use tokio::io::AsyncWriteExt;
24use tokio::net::TcpStream;
25
/// Default capacity in bytes (64 KiB) used here to pre-size each connection's
/// `buffer` and `write_buf`.
pub(crate) const BUFFER_CAPACITY: usize = 64 * 1024;
28
/// Pre-encoded PostgreSQL `SSLRequest` packet: a big-endian Int32 message length (8)
/// followed by the big-endian Int32 request code 80877103.
const SSL_REQUEST: [u8; 8] = {
    let len = 8i32.to_be_bytes();
    let code = 80877103i32.to_be_bytes();
    [len[0], len[1], len[2], len[3], code[0], code[1], code[2], code[3]]
};
31
/// Request code for the PostgreSQL `CancelRequest` message, i.e. 80877102
/// (`1234` in the high 16 bits, `5678` in the low 16 bits).
pub(crate) const CANCEL_REQUEST_CODE: i32 = (1234 << 16) | 5678;
34
/// Client-side TLS material for mutual-TLS connections, held as raw PEM bytes.
///
/// The bytes are not parsed here; `PgConnection::connect_mtls` parses them when it
/// builds the TLS client configuration.
#[derive(Clone)]
pub struct TlsConfig {
    /// PEM-encoded client certificate(s) presented to the server.
    pub client_cert_pem: Vec<u8>,
    /// PEM-encoded private key for `client_cert_pem`.
    pub client_key_pem: Vec<u8>,
    /// Optional PEM-encoded CA certificate(s) used to verify the server. When `None`,
    /// the platform's native root certificates are used instead.
    pub ca_cert_pem: Option<Vec<u8>>,
}
45
46impl TlsConfig {
47 pub fn from_files(
49 cert_path: impl AsRef<std::path::Path>,
50 key_path: impl AsRef<std::path::Path>,
51 ca_path: Option<impl AsRef<std::path::Path>>,
52 ) -> std::io::Result<Self> {
53 Ok(Self {
54 client_cert_pem: std::fs::read(cert_path)?,
55 client_key_pem: std::fs::read(key_path)?,
56 ca_cert_pem: ca_path.map(|p| std::fs::read(p)).transpose()?,
57 })
58 }
59}
60
61pub struct PgConnection {
63 pub(crate) stream: PgStream,
64 pub(crate) buffer: BytesMut,
65 pub(crate) write_buf: BytesMut,
66 pub(crate) sql_buf: BytesMut,
67 pub(crate) params_buf: Vec<Option<Vec<u8>>>,
68 pub(crate) prepared_statements: HashMap<String, String>,
69 pub(crate) stmt_cache: LruCache<u64, String>,
70 pub(crate) process_id: i32,
71 pub(crate) secret_key: i32,
72}
73
74impl PgConnection {
75 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
77 Self::connect_with_password(host, port, user, database, None).await
78 }
79
80 pub async fn connect_with_password(
82 host: &str,
83 port: u16,
84 user: &str,
85 database: &str,
86 password: Option<&str>,
87 ) -> PgResult<Self> {
88 let addr = format!("{}:{}", host, port);
89 let tcp_stream = TcpStream::connect(&addr).await?;
90
91 tcp_stream.set_nodelay(true)?;
93
94 let mut conn = Self {
95 stream: PgStream::Tcp(tcp_stream),
96 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
97 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), sql_buf: BytesMut::with_capacity(512),
99 params_buf: Vec::with_capacity(16), prepared_statements: HashMap::new(),
101 stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
102 process_id: 0,
103 secret_key: 0,
104 };
105
106 conn.send(FrontendMessage::Startup {
107 user: user.to_string(),
108 database: database.to_string(),
109 })
110 .await?;
111
112 conn.handle_startup(user, password).await?;
113
114 Ok(conn)
115 }
116
117 pub async fn connect_tls(
119 host: &str,
120 port: u16,
121 user: &str,
122 database: &str,
123 password: Option<&str>,
124 ) -> PgResult<Self> {
125 use tokio::io::AsyncReadExt;
126 use tokio_rustls::TlsConnector;
127 use tokio_rustls::rustls::ClientConfig;
128 use tokio_rustls::rustls::pki_types::ServerName;
129
130 let addr = format!("{}:{}", host, port);
131 let mut tcp_stream = TcpStream::connect(&addr).await?;
132
133 tcp_stream.write_all(&SSL_REQUEST).await?;
135
136 let mut response = [0u8; 1];
138 tcp_stream.read_exact(&mut response).await?;
139
140 if response[0] != b'S' {
141 return Err(PgError::Connection(
142 "Server does not support TLS".to_string(),
143 ));
144 }
145
146 let certs = rustls_native_certs::load_native_certs();
148 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
149 for cert in certs.certs {
150 let _ = root_cert_store.add(cert);
151 }
152
153 let config = ClientConfig::builder()
154 .with_root_certificates(root_cert_store)
155 .with_no_client_auth();
156
157 let connector = TlsConnector::from(Arc::new(config));
158 let server_name = ServerName::try_from(host.to_string())
159 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
160
161 let tls_stream = connector
162 .connect(server_name, tcp_stream)
163 .await
164 .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
165
166 let mut conn = Self {
167 stream: PgStream::Tls(tls_stream),
168 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
169 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
170 sql_buf: BytesMut::with_capacity(512),
171 params_buf: Vec::with_capacity(16),
172 prepared_statements: HashMap::new(),
173 stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
174 process_id: 0,
175 secret_key: 0,
176 };
177
178 conn.send(FrontendMessage::Startup {
179 user: user.to_string(),
180 database: database.to_string(),
181 })
182 .await?;
183
184 conn.handle_startup(user, password).await?;
185
186 Ok(conn)
187 }
188
189 pub async fn connect_mtls(
206 host: &str,
207 port: u16,
208 user: &str,
209 database: &str,
210 config: TlsConfig,
211 ) -> PgResult<Self> {
212 use tokio::io::AsyncReadExt;
213 use tokio_rustls::TlsConnector;
214 use tokio_rustls::rustls::{
215 ClientConfig,
216 pki_types::{CertificateDer, ServerName},
217 };
218
219 let addr = format!("{}:{}", host, port);
220 let mut tcp_stream = TcpStream::connect(&addr).await?;
221
222 tcp_stream.write_all(&SSL_REQUEST).await?;
224
225 let mut response = [0u8; 1];
227 tcp_stream.read_exact(&mut response).await?;
228
229 if response[0] != b'S' {
230 return Err(PgError::Connection(
231 "Server does not support TLS".to_string(),
232 ));
233 }
234
235 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
236
237 if let Some(ca_pem) = &config.ca_cert_pem {
238 let certs = rustls_pemfile::certs(&mut ca_pem.as_slice())
239 .filter_map(|r| r.ok())
240 .collect::<Vec<_>>();
241 for cert in certs {
242 let _ = root_cert_store.add(cert);
243 }
244 } else {
245 let certs = rustls_native_certs::load_native_certs();
247 for cert in certs.certs {
248 let _ = root_cert_store.add(cert);
249 }
250 }
251
252 let client_certs: Vec<CertificateDer<'static>> =
253 rustls_pemfile::certs(&mut config.client_cert_pem.as_slice())
254 .filter_map(|r| r.ok())
255 .collect();
256
257 let client_key = rustls_pemfile::private_key(&mut config.client_key_pem.as_slice())
258 .map_err(|e| PgError::Connection(format!("Invalid client key: {:?}", e)))?
259 .ok_or_else(|| PgError::Connection("No private key found in PEM".to_string()))?;
260
261 let tls_config = ClientConfig::builder()
262 .with_root_certificates(root_cert_store)
263 .with_client_auth_cert(client_certs, client_key)
264 .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
265
266 let connector = TlsConnector::from(Arc::new(tls_config));
267 let server_name = ServerName::try_from(host.to_string())
268 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
269
270 let tls_stream = connector
271 .connect(server_name, tcp_stream)
272 .await
273 .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
274
275 let mut conn = Self {
276 stream: PgStream::Tls(tls_stream),
277 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
278 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
279 sql_buf: BytesMut::with_capacity(512),
280 params_buf: Vec::with_capacity(16),
281 prepared_statements: HashMap::new(),
282 stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
283 process_id: 0,
284 secret_key: 0,
285 };
286
287 conn.send(FrontendMessage::Startup {
288 user: user.to_string(),
289 database: database.to_string(),
290 })
291 .await?;
292
293 conn.handle_startup(user, None).await?;
295
296 Ok(conn)
297 }
298
299 #[cfg(unix)]
301 pub async fn connect_unix(
302 socket_path: &str,
303 user: &str,
304 database: &str,
305 password: Option<&str>,
306 ) -> PgResult<Self> {
307 use tokio::net::UnixStream;
308
309 let unix_stream = UnixStream::connect(socket_path).await?;
310
311 let mut conn = Self {
312 stream: PgStream::Unix(unix_stream),
313 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
314 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
315 sql_buf: BytesMut::with_capacity(512),
316 params_buf: Vec::with_capacity(16),
317 prepared_statements: HashMap::new(),
318 stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
319 process_id: 0,
320 secret_key: 0,
321 };
322
323 conn.send(FrontendMessage::Startup {
324 user: user.to_string(),
325 database: database.to_string(),
326 })
327 .await?;
328
329 conn.handle_startup(user, password).await?;
330
331 Ok(conn)
332 }
333
334 async fn handle_startup(&mut self, user: &str, password: Option<&str>) -> PgResult<()> {
336 let mut scram_client: Option<ScramClient> = None;
337
338 loop {
339 let msg = self.recv().await?;
340 match msg {
341 BackendMessage::AuthenticationOk => {}
342 BackendMessage::AuthenticationMD5Password(_salt) => {
343 return Err(PgError::Auth(
344 "MD5 auth not supported. Use SCRAM-SHA-256.".to_string(),
345 ));
346 }
347 BackendMessage::AuthenticationSASL(mechanisms) => {
348 let password = password.ok_or_else(|| {
349 PgError::Auth("Password required for SCRAM authentication".to_string())
350 })?;
351
352 if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
353 return Err(PgError::Auth(format!(
354 "Server doesn't support SCRAM-SHA-256. Available: {:?}",
355 mechanisms
356 )));
357 }
358
359 let client = ScramClient::new(user, password);
360 let first_message = client.client_first_message();
361
362 self.send(FrontendMessage::SASLInitialResponse {
363 mechanism: "SCRAM-SHA-256".to_string(),
364 data: first_message,
365 })
366 .await?;
367
368 scram_client = Some(client);
369 }
370 BackendMessage::AuthenticationSASLContinue(server_data) => {
371 let client = scram_client.as_mut().ok_or_else(|| {
372 PgError::Auth("Received SASL Continue without SASL init".to_string())
373 })?;
374
375 let final_message = client
376 .process_server_first(&server_data)
377 .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
378
379 self.send(FrontendMessage::SASLResponse(final_message))
380 .await?;
381 }
382 BackendMessage::AuthenticationSASLFinal(server_signature) => {
383 if let Some(client) = scram_client.as_ref() {
384 client.verify_server_final(&server_signature).map_err(|e| {
385 PgError::Auth(format!("Server verification failed: {}", e))
386 })?;
387 }
388 }
389 BackendMessage::ParameterStatus { .. } => {}
390 BackendMessage::BackendKeyData {
391 process_id,
392 secret_key,
393 } => {
394 self.process_id = process_id;
395 self.secret_key = secret_key;
396 }
397 BackendMessage::ReadyForQuery(TransactionStatus::Idle)
398 | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
399 | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
400 return Ok(());
401 }
402 BackendMessage::ErrorResponse(err) => {
403 return Err(PgError::Connection(err.message));
404 }
405 _ => {}
406 }
407 }
408 }
409
410 pub async fn close(mut self) -> PgResult<()> {
413 use crate::protocol::PgEncoder;
414
415 let terminate = PgEncoder::encode_terminate();
417 self.stream.write_all(&terminate).await?;
418 self.stream.flush().await?;
419
420 Ok(())
421 }
422}
423
424impl Drop for PgConnection {
427 fn drop(&mut self) {
428 let terminate: [u8; 5] = [b'X', 0, 0, 0, 4];
431
432 match &mut self.stream {
433 PgStream::Tcp(tcp) => {
434 let _ = tcp.try_write(&terminate);
436 }
437 PgStream::Tls(_) => {
438 }
442 #[cfg(unix)]
443 PgStream::Unix(unix) => {
444 let _ = unix.try_write(&terminate);
445 }
446 }
447 }
448}
449
/// Extracts the affected-row count from a PostgreSQL command tag.
///
/// The count is the last whitespace-separated word of the tag, e.g. `"INSERT 0 5"` → 5
/// and `"UPDATE 10"` → 10. Returns 0 when the tag is empty or its last word is not an
/// unsigned integer, e.g. `"CREATE TABLE"`.
pub(crate) fn parse_affected_rows(tag: &str) -> u64 {
    // `next_back` reads the final word directly instead of walking the whole iterator
    // like `last()` does (clippy::double_ended_iterator_last).
    tag.split_whitespace()
        .next_back()
        .and_then(|count| count.parse().ok())
        .unwrap_or(0)
}