1use super::stream::PgStream;
16use super::{PgError, PgResult};
17use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
18use bytes::BytesMut;
19use lru::LruCache;
20use std::collections::HashMap;
21use std::num::NonZeroUsize;
22use std::sync::Arc;
23use tokio::io::AsyncWriteExt;
24use tokio::net::TcpStream;
25
/// Initial capacity (64 KiB) of each connection's read and write buffers.
pub(crate) const BUFFER_CAPACITY: usize = 64 * 1024;

/// SSLRequest packet: big-endian Int32 length `8`, then the request code
/// `80877103` (`0x04D2_162F`).
const SSL_REQUEST: [u8; 8] = [0x00, 0x00, 0x00, 0x08, 0x04, 0xD2, 0x16, 0x2F];

/// Request code carried by a CancelRequest packet (`0x04D2_162E`).
pub(crate) const CANCEL_REQUEST_CODE: i32 = 80_877_102;
34
/// Client-side key material for mutual-TLS connections, held as raw PEM bytes.
///
/// Consumed by `PgConnection::connect_mtls`, which parses the PEM data when a
/// connection is opened.
#[derive(Clone)]
pub struct TlsConfig {
    /// PEM-encoded client certificate chain presented to the server.
    pub client_cert_pem: Vec<u8>,
    /// PEM-encoded private key belonging to `client_cert_pem`.
    pub client_key_pem: Vec<u8>,
    /// Optional PEM-encoded CA bundle used to verify the server. When `None`,
    /// `connect_mtls` falls back to the platform's native root certificates.
    pub ca_cert_pem: Option<Vec<u8>>,
}

impl TlsConfig {
    /// Reads the client certificate, the client key and (optionally) a CA bundle
    /// from disk.
    ///
    /// The files are read verbatim; no PEM validation happens here.
    ///
    /// # Errors
    /// Returns the first I/O error encountered. The error keeps its original
    /// [`std::io::ErrorKind`], but its message is prefixed with the path that
    /// failed, so callers can tell which of the three files was the problem.
    pub fn from_files(
        cert_path: impl AsRef<std::path::Path>,
        key_path: impl AsRef<std::path::Path>,
        ca_path: Option<impl AsRef<std::path::Path>>,
    ) -> std::io::Result<Self> {
        Ok(Self {
            client_cert_pem: Self::read_file_with_context(cert_path.as_ref())?,
            client_key_pem: Self::read_file_with_context(key_path.as_ref())?,
            ca_cert_pem: ca_path
                .map(|p| Self::read_file_with_context(p.as_ref()))
                .transpose()?,
        })
    }

    /// Reads a whole file, annotating any error with the offending path.
    fn read_file_with_context(path: &std::path::Path) -> std::io::Result<Vec<u8>> {
        std::fs::read(path).map_err(|e| {
            std::io::Error::new(e.kind(), format!("{}: {}", path.display(), e))
        })
    }
}
60
/// A single PostgreSQL client connection: the transport stream, its reusable
/// scratch buffers, per-connection caches, and the cancellation key reported
/// by the server during startup.
///
/// NOTE: field declaration order is also drop order; keep `stream` first.
pub struct PgConnection {
    /// Transport: plain TCP, TLS over TCP, or (Unix only) a Unix-domain socket.
    pub(crate) stream: PgStream,
    /// Presumably the receive buffer for incoming backend messages (allocated
    /// with `BUFFER_CAPACITY`) — confirm against `recv`.
    pub(crate) buffer: BytesMut,
    /// Presumably the outgoing message buffer (allocated with
    /// `BUFFER_CAPACITY`) — confirm against `send`.
    pub(crate) write_buf: BytesMut,
    /// Scratch buffer, initially 512 bytes; presumably used to build SQL text — verify.
    pub(crate) sql_buf: BytesMut,
    /// Scratch vector for encoded bind parameters (`None` presumably = SQL NULL — verify).
    pub(crate) params_buf: Vec<Option<Vec<u8>>>,
    /// Prepared statements known to this connection; key/value meaning is not
    /// visible here (presumably SQL/name ↔ statement name) — verify with callers.
    pub(crate) prepared_statements: HashMap<String, String>,
    /// LRU statement cache; every constructor in this file caps it at 100 entries.
    /// Keys look like hashes of the SQL text — TODO confirm.
    pub(crate) stmt_cache: LruCache<u64, String>,
    /// Cached column metadata keyed by a `u64` (presumably a statement hash — verify).
    pub(crate) column_info_cache: HashMap<u64, Arc<super::ColumnInfo>>,
    /// Backend process id from `BackendKeyData`; 0 until startup completes.
    pub(crate) process_id: i32,
    /// Secret key from `BackendKeyData`; 0 until startup completes. Presumably
    /// paired with `process_id` for CancelRequest — confirm.
    pub(crate) secret_key: i32,
}
77
78impl PgConnection {
79 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
81 Self::connect_with_password(host, port, user, database, None).await
82 }
83
84 pub async fn connect_with_password(
86 host: &str,
87 port: u16,
88 user: &str,
89 database: &str,
90 password: Option<&str>,
91 ) -> PgResult<Self> {
92 let addr = format!("{}:{}", host, port);
93 let tcp_stream = TcpStream::connect(&addr).await?;
94
95 tcp_stream.set_nodelay(true)?;
97
98 let mut conn = Self {
99 stream: PgStream::Tcp(tcp_stream),
100 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
101 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), sql_buf: BytesMut::with_capacity(512),
103 params_buf: Vec::with_capacity(16), prepared_statements: HashMap::new(),
105 stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
106 column_info_cache: HashMap::new(),
107 process_id: 0,
108 secret_key: 0,
109 };
110
111 conn.send(FrontendMessage::Startup {
112 user: user.to_string(),
113 database: database.to_string(),
114 })
115 .await?;
116
117 conn.handle_startup(user, password).await?;
118
119 Ok(conn)
120 }
121
122 pub async fn connect_tls(
124 host: &str,
125 port: u16,
126 user: &str,
127 database: &str,
128 password: Option<&str>,
129 ) -> PgResult<Self> {
130 use tokio::io::AsyncReadExt;
131 use tokio_rustls::TlsConnector;
132 use tokio_rustls::rustls::ClientConfig;
133 use tokio_rustls::rustls::pki_types::ServerName;
134
135 let addr = format!("{}:{}", host, port);
136 let mut tcp_stream = TcpStream::connect(&addr).await?;
137
138 tcp_stream.write_all(&SSL_REQUEST).await?;
140
141 let mut response = [0u8; 1];
143 tcp_stream.read_exact(&mut response).await?;
144
145 if response[0] != b'S' {
146 return Err(PgError::Connection(
147 "Server does not support TLS".to_string(),
148 ));
149 }
150
151 let certs = rustls_native_certs::load_native_certs();
153 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
154 for cert in certs.certs {
155 let _ = root_cert_store.add(cert);
156 }
157
158 let config = ClientConfig::builder()
159 .with_root_certificates(root_cert_store)
160 .with_no_client_auth();
161
162 let connector = TlsConnector::from(Arc::new(config));
163 let server_name = ServerName::try_from(host.to_string())
164 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
165
166 let tls_stream = connector
167 .connect(server_name, tcp_stream)
168 .await
169 .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
170
171 let mut conn = Self {
172 stream: PgStream::Tls(tls_stream),
173 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
174 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
175 sql_buf: BytesMut::with_capacity(512),
176 params_buf: Vec::with_capacity(16),
177 prepared_statements: HashMap::new(),
178 stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
179 column_info_cache: HashMap::new(),
180 process_id: 0,
181 secret_key: 0,
182 };
183
184 conn.send(FrontendMessage::Startup {
185 user: user.to_string(),
186 database: database.to_string(),
187 })
188 .await?;
189
190 conn.handle_startup(user, password).await?;
191
192 Ok(conn)
193 }
194
195 pub async fn connect_mtls(
212 host: &str,
213 port: u16,
214 user: &str,
215 database: &str,
216 config: TlsConfig,
217 ) -> PgResult<Self> {
218 use tokio::io::AsyncReadExt;
219 use tokio_rustls::TlsConnector;
220 use tokio_rustls::rustls::{
221 ClientConfig,
222 pki_types::{CertificateDer, ServerName},
223 };
224
225 let addr = format!("{}:{}", host, port);
226 let mut tcp_stream = TcpStream::connect(&addr).await?;
227
228 tcp_stream.write_all(&SSL_REQUEST).await?;
230
231 let mut response = [0u8; 1];
233 tcp_stream.read_exact(&mut response).await?;
234
235 if response[0] != b'S' {
236 return Err(PgError::Connection(
237 "Server does not support TLS".to_string(),
238 ));
239 }
240
241 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
242
243 if let Some(ca_pem) = &config.ca_cert_pem {
244 let certs = rustls_pemfile::certs(&mut ca_pem.as_slice())
245 .filter_map(|r| r.ok())
246 .collect::<Vec<_>>();
247 for cert in certs {
248 let _ = root_cert_store.add(cert);
249 }
250 } else {
251 let certs = rustls_native_certs::load_native_certs();
253 for cert in certs.certs {
254 let _ = root_cert_store.add(cert);
255 }
256 }
257
258 let client_certs: Vec<CertificateDer<'static>> =
259 rustls_pemfile::certs(&mut config.client_cert_pem.as_slice())
260 .filter_map(|r| r.ok())
261 .collect();
262
263 let client_key = rustls_pemfile::private_key(&mut config.client_key_pem.as_slice())
264 .map_err(|e| PgError::Connection(format!("Invalid client key: {:?}", e)))?
265 .ok_or_else(|| PgError::Connection("No private key found in PEM".to_string()))?;
266
267 let tls_config = ClientConfig::builder()
268 .with_root_certificates(root_cert_store)
269 .with_client_auth_cert(client_certs, client_key)
270 .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
271
272 let connector = TlsConnector::from(Arc::new(tls_config));
273 let server_name = ServerName::try_from(host.to_string())
274 .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
275
276 let tls_stream = connector
277 .connect(server_name, tcp_stream)
278 .await
279 .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
280
281 let mut conn = Self {
282 stream: PgStream::Tls(tls_stream),
283 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
284 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
285 sql_buf: BytesMut::with_capacity(512),
286 params_buf: Vec::with_capacity(16),
287 prepared_statements: HashMap::new(),
288 stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
289 column_info_cache: HashMap::new(),
290 process_id: 0,
291 secret_key: 0,
292 };
293
294 conn.send(FrontendMessage::Startup {
295 user: user.to_string(),
296 database: database.to_string(),
297 })
298 .await?;
299
300 conn.handle_startup(user, None).await?;
302
303 Ok(conn)
304 }
305
306 #[cfg(unix)]
308 pub async fn connect_unix(
309 socket_path: &str,
310 user: &str,
311 database: &str,
312 password: Option<&str>,
313 ) -> PgResult<Self> {
314 use tokio::net::UnixStream;
315
316 let unix_stream = UnixStream::connect(socket_path).await?;
317
318 let mut conn = Self {
319 stream: PgStream::Unix(unix_stream),
320 buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
321 write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
322 sql_buf: BytesMut::with_capacity(512),
323 params_buf: Vec::with_capacity(16),
324 prepared_statements: HashMap::new(),
325 stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
326 column_info_cache: HashMap::new(),
327 process_id: 0,
328 secret_key: 0,
329 };
330
331 conn.send(FrontendMessage::Startup {
332 user: user.to_string(),
333 database: database.to_string(),
334 })
335 .await?;
336
337 conn.handle_startup(user, password).await?;
338
339 Ok(conn)
340 }
341
342 async fn handle_startup(&mut self, user: &str, password: Option<&str>) -> PgResult<()> {
344 let mut scram_client: Option<ScramClient> = None;
345
346 loop {
347 let msg = self.recv().await?;
348 match msg {
349 BackendMessage::AuthenticationOk => {}
350 BackendMessage::AuthenticationMD5Password(_salt) => {
351 return Err(PgError::Auth(
352 "MD5 auth not supported. Use SCRAM-SHA-256.".to_string(),
353 ));
354 }
355 BackendMessage::AuthenticationSASL(mechanisms) => {
356 let password = password.ok_or_else(|| {
357 PgError::Auth("Password required for SCRAM authentication".to_string())
358 })?;
359
360 if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
361 return Err(PgError::Auth(format!(
362 "Server doesn't support SCRAM-SHA-256. Available: {:?}",
363 mechanisms
364 )));
365 }
366
367 let client = ScramClient::new(user, password);
368 let first_message = client.client_first_message();
369
370 self.send(FrontendMessage::SASLInitialResponse {
371 mechanism: "SCRAM-SHA-256".to_string(),
372 data: first_message,
373 })
374 .await?;
375
376 scram_client = Some(client);
377 }
378 BackendMessage::AuthenticationSASLContinue(server_data) => {
379 let client = scram_client.as_mut().ok_or_else(|| {
380 PgError::Auth("Received SASL Continue without SASL init".to_string())
381 })?;
382
383 let final_message = client
384 .process_server_first(&server_data)
385 .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
386
387 self.send(FrontendMessage::SASLResponse(final_message))
388 .await?;
389 }
390 BackendMessage::AuthenticationSASLFinal(server_signature) => {
391 if let Some(client) = scram_client.as_ref() {
392 client.verify_server_final(&server_signature).map_err(|e| {
393 PgError::Auth(format!("Server verification failed: {}", e))
394 })?;
395 }
396 }
397 BackendMessage::ParameterStatus { .. } => {}
398 BackendMessage::BackendKeyData {
399 process_id,
400 secret_key,
401 } => {
402 self.process_id = process_id;
403 self.secret_key = secret_key;
404 }
405 BackendMessage::ReadyForQuery(TransactionStatus::Idle)
406 | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
407 | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
408 return Ok(());
409 }
410 BackendMessage::ErrorResponse(err) => {
411 return Err(PgError::Connection(err.message));
412 }
413 _ => {}
414 }
415 }
416 }
417
418 pub async fn close(mut self) -> PgResult<()> {
421 use crate::protocol::PgEncoder;
422
423 let terminate = PgEncoder::encode_terminate();
425 self.stream.write_all(&terminate).await?;
426 self.stream.flush().await?;
427
428 Ok(())
429 }
430}
431
432impl Drop for PgConnection {
435 fn drop(&mut self) {
436 let terminate: [u8; 5] = [b'X', 0, 0, 0, 4];
439
440 match &mut self.stream {
441 PgStream::Tcp(tcp) => {
442 let _ = tcp.try_write(&terminate);
444 }
445 PgStream::Tls(_) => {
446 }
450 #[cfg(unix)]
451 PgStream::Unix(unix) => {
452 let _ = unix.try_write(&terminate);
453 }
454 }
455 }
456}
457
/// Extracts the row count from a CommandComplete tag.
///
/// The count is the last whitespace-separated word, e.g. `"INSERT 0 5"` → 5 and
/// `"UPDATE 3"` → 3. Tags without a trailing number (`"CREATE TABLE"`, `"BEGIN"`),
/// empty tags, and values that do not fit in a `u64` all yield 0.
pub(crate) fn parse_affected_rows(tag: &str) -> u64 {
    match tag.split_whitespace().next_back().map(str::parse::<u64>) {
        Some(Ok(rows)) => rows,
        _ => 0,
    }
}