1use std::io::{Read, Write};
15use std::sync::Arc;
16
17use crate::arena::Arena;
18use crate::auth;
19use crate::codec::Encode;
20use crate::proto::{self, BackendMessage};
21use crate::stmt_cache::{build_bind_template, make_stmt_name, StmtCache, StmtInfo};
22use crate::sync_io::Stream;
23use crate::types::{
24 ColumnDesc, Config, Notification, PgDataRow, PrepareResult, QueryResult, SimpleRow, SslMode,
25 StartupAction,
26};
27use crate::DriverError;
28
29use std::cell::RefCell;
37
/// Maximum number of idle response buffers kept per thread.
const RESP_BUF_POOL_LIMIT: usize = 4;

thread_local! {
    /// Per-thread free list of response buffers, so successive query results
    /// can reuse an existing allocation.
    static RESP_BUF_POOL: RefCell<Vec<Vec<u8>>> = const { RefCell::new(Vec::new()) };
}

/// Takes a response buffer from this thread's pool.
///
/// Falls back to a new empty `Vec` when the pool is empty or when the
/// thread-local pool has already been torn down (thread exit). The returned
/// buffer is always empty; only its capacity is recycled.
pub(crate) fn acquire_resp_buf() -> Vec<u8> {
    RESP_BUF_POOL
        .try_with(|pool| pool.borrow_mut().pop())
        .ok()
        .flatten()
        .unwrap_or_default()
}

/// Returns a response buffer to this thread's pool for later reuse.
///
/// The buffer is cleared first, so no caller of `acquire_resp_buf` ever sees
/// stale bytes. At most `RESP_BUF_POOL_LIMIT` buffers are retained; extras are
/// dropped. If the thread-local pool has already been destroyed (e.g. the call
/// comes from another thread-local's destructor), the buffer is simply dropped
/// instead of panicking.
pub fn release_resp_buf(mut buf: Vec<u8>) {
    buf.clear();
    let _ = RESP_BUF_POOL.try_with(|pool| {
        let mut pool = pool.borrow_mut();
        if pool.len() < RESP_BUF_POOL_LIMIT {
            pool.push(buf);
        }
    });
}
57
/// Maximum number of idle column-offset buffers kept per thread.
const COL_OFFSETS_POOL_LIMIT: usize = 4;

thread_local! {
    /// Per-thread free list of `(offset, length)` column-offset buffers used
    /// when collecting query results.
    static COL_OFFSETS_POOL: RefCell<Vec<Vec<(usize, i32)>>> = const { RefCell::new(Vec::new()) };
}

/// Takes a column-offset buffer from this thread's pool.
///
/// Falls back to a new empty `Vec` when the pool is empty or already torn
/// down (thread exit). The returned buffer is always empty.
pub(crate) fn acquire_col_offsets() -> Vec<(usize, i32)> {
    COL_OFFSETS_POOL
        .try_with(|pool| pool.borrow_mut().pop())
        .ok()
        .flatten()
        .unwrap_or_default()
}

/// Returns a column-offset buffer to this thread's pool.
///
/// The buffer is cleared first. At most `COL_OFFSETS_POOL_LIMIT` buffers are
/// retained. If the thread-local pool has already been destroyed, the buffer
/// is dropped instead of panicking.
pub fn release_col_offsets(mut buf: Vec<(usize, i32)>) {
    buf.clear();
    let _ = COL_OFFSETS_POOL.try_with(|pool| {
        let mut pool = pool.borrow_mut();
        if pool.len() < COL_OFFSETS_POOL_LIMIT {
            pool.push(buf);
        }
    });
}
76
/// A synchronous connection to a PostgreSQL server.
///
/// Holds the transport, the send/receive buffers used by the message parsers,
/// the per-connection prepared-statement cache, and state reported by the
/// server during startup.
pub struct Connection {
    /// Index of the next unparsed byte in `stream_buf`.
    stream_buf_pos: usize,
    /// End (exclusive) of the received bytes in `stream_buf`.
    stream_buf_end: usize,
    /// Counter bumped when statements are prepared or reused; its value is
    /// stored as `StmtInfo::last_used`. `execute_monolithic` also checks it
    /// (every 64th value) to decide when to shrink oversized buffers.
    query_counter: u64,
    /// Transaction status byte from the latest ReadyForQuery (`b'I'` until
    /// the first one arrives).
    tx_status: u8,
    /// NOTE(review): not read or written in the visible code; presumably marks
    /// an in-progress streaming query — confirm.
    streaming_active: bool,
    /// Backend process id from BackendKeyData.
    pid: i32,
    /// Backend secret key from BackendKeyData.
    secret: i32,
    /// Initialised to 256 at connect; presumably the size limit enforced by
    /// `cache_stmt` — confirm.
    max_stmt_cache_size: usize,
    /// Underlying transport (Unix socket, plain TCP, or TLS).
    stream: Stream,
    /// Scratch buffer for outgoing frontend messages.
    write_buf: Vec<u8>,
    /// Receive buffer (64 KiB at connect) that the hot paths parse messages
    /// from in place, using `stream_buf_pos`/`stream_buf_end`.
    stream_buf: Vec<u8>,
    /// Prepared-statement cache keyed by (`sql_hash`, SQL text).
    stmts: StmtCache,
    /// Payload buffer filled by `read_message_buffered` (e.g. during startup
    /// and SCRAM authentication).
    read_buf: Vec<u8>,
    /// Server parameters from ParameterStatus messages; a later report for the
    /// same name replaces the earlier value.
    params: Vec<(Box<str>, Box<str>)>,
    /// Set to the connect time; presumably updated on use — confirm.
    last_used: std::time::Instant,
    /// Time the connection was created.
    created_at: std::time::Instant,
    /// Buffered asynchronous notifications; presumably filled by
    /// `buffer_notification` — confirm.
    pending_notifications: Vec<Notification>,
    /// Configuration the connection was opened with.
    connect_config: Arc<Config>,
    /// Server certificate hash captured during the TLS upgrade; enables
    /// SCRAM-SHA-256-PLUS channel binding. `None` without TLS.
    tls_server_cert_hash: Option<[u8; 32]>,
}
133
134impl std::fmt::Debug for Connection {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 f.debug_struct("Connection")
137 .field("pid", &self.pid)
138 .field("tx_status", &(self.tx_status as char))
139 .field("stmt_cache_len", &self.stmts.len())
140 .finish()
141 }
142}
143
144impl Connection {
145 pub fn connect(config: &Config) -> Result<Self, DriverError> {
157 Self::connect_arc(Arc::new(config.clone()))
158 }
159
160 pub fn connect_arc(config: Arc<Config>) -> Result<Self, DriverError> {
165 config.validate()?;
166
167 #[allow(unused_mut)]
170 let mut tls_cert_hash: Option<[u8; 32]> = None;
171
172 let stream = if config.host_is_uds() {
173 #[cfg(unix)]
175 {
176 let path = config.uds_path();
177 let unix =
178 std::os::unix::net::UnixStream::connect(&path).map_err(DriverError::Io)?;
179 Stream::Unix(unix)
180 }
181 #[cfg(not(unix))]
182 {
183 return Err(DriverError::Protocol(
184 "Unix domain sockets are not supported on this platform".into(),
185 ));
186 }
187 } else {
188 let addr = format!("{}:{}", config.host, config.port);
190 let tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
191
192 match config.ssl {
193 SslMode::Disable => {
194 tcp.set_nodelay(true).map_err(DriverError::Io)?;
195 let stream = Stream::Tcp(tcp);
196 stream.set_keepalive()?;
197 stream
198 }
199 SslMode::Prefer | SslMode::Require => {
200 #[cfg(feature = "tls")]
201 {
202 match crate::tls_sync::try_upgrade(
203 tcp,
204 &config.host,
205 config.ssl == SslMode::Require,
206 ) {
207 Ok(result) => {
208 tls_cert_hash = result.server_cert_hash;
209 let stream = Stream::Tls(Box::new(result.stream));
210 stream.set_nodelay()?;
211 stream.set_keepalive()?;
212 stream
213 }
214 Err(e) => {
215 if config.ssl == SslMode::Require {
216 return Err(e);
217 }
218 let tcp =
220 std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
221 tcp.set_nodelay(true).map_err(DriverError::Io)?;
222 let stream = Stream::Tcp(tcp);
223 stream.set_keepalive()?;
224 stream
225 }
226 }
227 }
228 #[cfg(not(feature = "tls"))]
229 {
230 if config.ssl == SslMode::Require {
231 return Err(DriverError::Protocol(
232 "sslmode=require but bsql was compiled without the 'tls' feature"
233 .into(),
234 ));
235 }
236 tcp.set_nodelay(true).map_err(DriverError::Io)?;
237 let stream = Stream::Tcp(tcp);
238 stream.set_keepalive()?;
239 stream
240 }
241 }
242 }
243 };
244
245 let now = std::time::Instant::now();
246 let mut conn = Self {
247 stream_buf_pos: 0,
249 stream_buf_end: 0,
250 query_counter: 0,
251 tx_status: b'I',
252 streaming_active: false,
253 pid: 0,
254 secret: 0,
255 max_stmt_cache_size: 256,
256 stream,
258 write_buf: Vec::with_capacity(4096),
259 stream_buf: vec![0u8; 65536],
260 stmts: StmtCache::default(),
261 read_buf: Vec::with_capacity(8192),
263 params: Vec::new(),
264 last_used: now,
265 created_at: now,
266 pending_notifications: Vec::new(),
267 connect_config: config.clone(),
268 tls_server_cert_hash: tls_cert_hash,
269 };
270
271 conn.startup(&config)?;
272 conn.validate_server_params()?;
273
274 Ok(conn)
275 }
276
277 fn startup(&mut self, config: &Config) -> Result<(), DriverError> {
280 self.write_buf.clear();
281 let timeout_str; let mut extra_params: smallvec::SmallVec<[(&str, &str); 2]> = smallvec::SmallVec::new();
285 if config.statement_timeout_secs > 0 {
286 timeout_str = format!("{}s", config.statement_timeout_secs);
287 extra_params.push(("statement_timeout", &timeout_str));
288 }
289 proto::write_startup(
290 &mut self.write_buf,
291 &config.user,
292 &config.database,
293 &extra_params,
294 );
295 self.flush_write()?;
296
297 loop {
298 let action = self.read_startup_action()?;
299 match action {
300 StartupAction::AuthOk => {}
301 StartupAction::AuthCleartext => {
302 self.write_buf.clear();
303 let mut pw = config.password.as_bytes().to_vec();
304 pw.push(0);
305 proto::write_password(&mut self.write_buf, &pw);
306 self.flush_write()?;
307 }
308 StartupAction::AuthMd5(salt) => {
309 self.write_buf.clear();
310 let hash = auth::md5_password(&config.user, &config.password, &salt);
311 proto::write_password(&mut self.write_buf, &hash);
312 self.flush_write()?;
313 }
314 StartupAction::AuthSasl(mechanisms_data) => {
315 self.handle_scram(config, &mechanisms_data)?;
316 }
317 StartupAction::ParameterStatus(name, value) => {
318 if let Some(entry) = self.params.iter_mut().find(|(k, _)| *k == name) {
319 entry.1 = value;
320 } else {
321 self.params.push((name, value));
322 }
323 }
324 StartupAction::BackendKeyData(pid, secret) => {
325 self.pid = pid;
326 self.secret = secret;
327 }
328 StartupAction::ReadyForQuery(status) => {
329 self.tx_status = status;
330 return Ok(());
331 }
332 StartupAction::Error(msg) => {
333 return Err(DriverError::Auth(msg));
334 }
335 StartupAction::Notice => {}
336 }
337 }
338 }
339
340 fn read_startup_action(&mut self) -> Result<StartupAction, DriverError> {
341 let (msg_type, _) = self.read_message_buffered()?;
342 let payload = &self.read_buf;
343 let msg = proto::parse_backend_message(msg_type, payload)?;
344 match msg {
345 BackendMessage::AuthOk => Ok(StartupAction::AuthOk),
346 BackendMessage::AuthCleartext => Ok(StartupAction::AuthCleartext),
347 BackendMessage::AuthMd5 { salt } => Ok(StartupAction::AuthMd5(salt)),
348 BackendMessage::AuthSasl { mechanisms } => {
349 Ok(StartupAction::AuthSasl(mechanisms.to_vec()))
350 }
351 BackendMessage::ParameterStatus { name, value } => {
352 Ok(StartupAction::ParameterStatus(name.into(), value.into()))
353 }
354 BackendMessage::BackendKeyData { pid, secret } => {
355 Ok(StartupAction::BackendKeyData(pid, secret))
356 }
357 BackendMessage::ReadyForQuery { status } => Ok(StartupAction::ReadyForQuery(status)),
358 BackendMessage::ErrorResponse { data } => {
359 let fields = proto::parse_error_response(data);
360 Ok(StartupAction::Error(fields.to_string()))
361 }
362 BackendMessage::NoticeResponse { .. } => Ok(StartupAction::Notice),
363 other => Err(DriverError::Protocol(format!(
364 "unexpected message during startup: {other:?}"
365 ))),
366 }
367 }
368
369 fn handle_scram(&mut self, config: &Config, mechanisms_data: &[u8]) -> Result<(), DriverError> {
370 let mechs = auth::parse_sasl_mechanisms(mechanisms_data);
371
372 let use_plus = self.tls_server_cert_hash.is_some() && mechs.contains(&"SCRAM-SHA-256-PLUS");
375 let mechanism = if use_plus {
376 "SCRAM-SHA-256-PLUS"
377 } else {
378 "SCRAM-SHA-256"
379 };
380
381 if !mechs.contains(&mechanism) && !mechs.contains(&"SCRAM-SHA-256") {
382 return Err(DriverError::Auth(format!(
383 "server requires unsupported SASL mechanism(s): {mechs:?}"
384 )));
385 }
386
387 let cert_hash = if use_plus {
388 self.tls_server_cert_hash.as_ref()
389 } else {
390 None
391 };
392 let mut scram = auth::ScramClient::new(&config.user, &config.password, cert_hash)?;
393
394 let client_first = scram.client_first_message();
396 self.write_buf.clear();
397 proto::write_sasl_initial(&mut self.write_buf, mechanism, &client_first);
398 self.flush_write()?;
399
400 let (msg_type, _) = self.read_message_buffered()?;
402 let server_first = {
403 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
404 match msg {
405 BackendMessage::AuthSaslContinue { data } => data.to_vec(),
406 BackendMessage::ErrorResponse { data } => {
407 let fields = proto::parse_error_response(data);
408 return Err(DriverError::Auth(fields.to_string()));
409 }
410 other => {
411 return Err(DriverError::Protocol(format!(
412 "expected AuthSaslContinue, got: {other:?}"
413 )));
414 }
415 }
416 };
417
418 scram.process_server_first(&server_first)?;
419
420 let client_final = scram.client_final_message()?;
422 self.write_buf.clear();
423 proto::write_sasl_response(&mut self.write_buf, &client_final);
424 self.flush_write()?;
425
426 let (msg_type, _) = self.read_message_buffered()?;
428 {
429 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
430 match msg {
431 BackendMessage::AuthSaslFinal { data } => {
432 let data_owned = data.to_vec();
433 scram.verify_server_final(&data_owned)?;
434 }
435 BackendMessage::ErrorResponse { data } => {
436 let fields = proto::parse_error_response(data);
437 return Err(DriverError::Auth(fields.to_string()));
438 }
439 other => {
440 return Err(DriverError::Protocol(format!(
441 "expected AuthSaslFinal, got: {other:?}"
442 )));
443 }
444 }
445 }
446
447 let (msg_type, _) = self.read_message_buffered()?;
449 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
450 match msg {
451 BackendMessage::AuthOk => Ok(()),
452 BackendMessage::ErrorResponse { data } => {
453 let fields = proto::parse_error_response(data);
454 Err(DriverError::Auth(fields.to_string()))
455 }
456 other => Err(DriverError::Protocol(format!(
457 "expected AuthOk after SCRAM, got: {other:?}"
458 ))),
459 }
460 }
461
462 fn validate_server_params(&self) -> Result<(), DriverError> {
463 if let Some(encoding) = self.parameter("server_encoding") {
464 if !encoding.eq_ignore_ascii_case("UTF8") && !encoding.eq_ignore_ascii_case("UTF-8") {
465 return Err(DriverError::Protocol(format!(
466 "server_encoding is '{encoding}', but bsql requires UTF-8."
467 )));
468 }
469 }
470 if let Some(encoding) = self.parameter("client_encoding") {
471 if !encoding.eq_ignore_ascii_case("UTF8") && !encoding.eq_ignore_ascii_case("UTF-8") {
472 return Err(DriverError::Protocol(format!(
473 "client_encoding is '{encoding}', but bsql requires UTF-8."
474 )));
475 }
476 }
477 if let Some(idt) = self.parameter("integer_datetimes") {
478 if idt != "on" {
479 return Err(DriverError::Protocol(format!(
480 "integer_datetimes is '{idt}', but bsql requires 'on'."
481 )));
482 }
483 }
484 Ok(())
485 }
486
487 pub fn prepare_only(&mut self, sql: &str, sql_hash: u64) -> Result<(), DriverError> {
493 if self.stmts.contains_key(&sql_hash, sql) {
494 return Ok(());
495 }
496 let name = make_stmt_name(sql_hash);
497 self.write_buf.clear();
498 proto::write_parse(&mut self.write_buf, &name, sql, &[]);
499 proto::write_describe(&mut self.write_buf, b'S', &name);
500 proto::write_sync(&mut self.write_buf);
501 self.flush_write()?;
502
503 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
504 let columns = self.read_column_description()?;
505 self.expect_ready()?;
506
507 self.query_counter += 1;
508 self.cache_stmt(
509 sql_hash,
510 StmtInfo {
511 name,
512 sql: sql.into(),
513 columns,
514 last_used: self.query_counter,
515 bind_template: None,
516 },
517 );
518 Ok(())
519 }
520
521 pub fn prepare_batch(&mut self, sqls: &[(&str, u64)]) -> Result<(), DriverError> {
529 if sqls.is_empty() {
530 return Ok(());
531 }
532
533 let mut pending = 0usize;
535 self.write_buf.clear();
536 for &(sql, sql_hash) in sqls {
537 if self.stmts.contains_key(&sql_hash, sql) {
538 continue;
539 }
540 let name = make_stmt_name(sql_hash);
541 proto::write_parse(&mut self.write_buf, &name, sql, &[]);
542 proto::write_describe(&mut self.write_buf, b'S', &name);
543 pending += 1;
544 }
545
546 if pending == 0 {
547 return Ok(());
548 }
549
550 proto::write_sync(&mut self.write_buf);
551 self.flush_write()?;
552
553 for &(sql, sql_hash) in sqls {
556 if self.stmts.contains_key(&sql_hash, sql) {
557 continue;
558 }
559
560 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
561 let columns = self.read_column_description()?;
562
563 let name = make_stmt_name(sql_hash);
564 self.query_counter += 1;
565 self.cache_stmt(
566 sql_hash,
567 StmtInfo {
568 name,
569 sql: sql.into(),
570 columns,
571 last_used: self.query_counter,
572 bind_template: None,
573 },
574 );
575 }
576
577 self.expect_ready()?;
578 Ok(())
579 }
580
    /// Runs a row-returning statement and buffers the whole result.
    ///
    /// `send_pipeline` sends the request (called with `need_columns=true`, per
    /// the `expect` message) and returns the column descriptions. Replies are
    /// then parsed in place from `stream_buf`:
    /// - DataRow payloads are appended to a pooled response buffer, with
    ///   per-column offsets;
    /// - other messages go to `handle_non_datarow_query`;
    /// - ReadyForQuery ends the loop.
    ///
    /// # Errors
    /// - `DriverError::Protocol` for an invalid message length or an unexpected
    ///   message.
    /// - An ErrorResponse read through the large-message path: the statement
    ///   cache may be invalidated and the connection is drained before the
    ///   server error is returned.
    /// - Errors from `send_pipeline`, `handle_non_datarow_query`, and buffer
    ///   refills are propagated.
    ///
    /// NOTE(review): early `return Err(..)` paths drop `resp_buf` and
    /// `all_col_offsets` instead of returning them to their thread-local pools.
    #[inline]
    pub fn query(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<QueryResult, DriverError> {
        let columns = self
            .send_pipeline(sql, sql_hash, params, true, true)?
            .expect("send_pipeline(need_columns=true) must return Some");

        let num_cols = columns.len();
        // Pooled buffers may still hold data from a previous use; start empty.
        let mut all_col_offsets = acquire_col_offsets();
        all_col_offsets.clear();
        let mut affected_rows: u64 = 0;

        let mut resp_buf = acquire_resp_buf();
        resp_buf.clear();

        'outer: loop {
            // Parse every complete message currently in stream_buf.
            loop {
                let avail = self.stream_buf_end - self.stream_buf_pos;
                // Header is 1 type byte + 4 length bytes; wait for more data.
                if avail < 5 {
                    break;
                }

                let msg_type = self.stream_buf[self.stream_buf_pos];
                // Big-endian length that counts itself but not the type byte.
                let raw_len = i32::from_be_bytes([
                    self.stream_buf[self.stream_buf_pos + 1],
                    self.stream_buf[self.stream_buf_pos + 2],
                    self.stream_buf[self.stream_buf_pos + 3],
                    self.stream_buf[self.stream_buf_pos + 4],
                ]);

                if raw_len < 4 {
                    return Err(DriverError::Protocol(format!(
                        "invalid message length {raw_len} for type '{}'",
                        msg_type as char
                    )));
                }

                let payload_len = (raw_len - 4) as usize;
                let total_msg_len = 5 + payload_len;

                if avail < total_msg_len {
                    // A message bigger than the whole buffer can never be completed in
                    // place; read it through read_one_message instead (presumably it
                    // reassembles into read_buf, starting from the buffered bytes — confirm).
                    if total_msg_len > self.stream_buf.len() {
                        let msg = self.read_one_message()?;
                        match msg {
                            BackendMessage::BindComplete => continue,
                            BackendMessage::DataRow { data } => {
                                parse_data_row_into_buf(data, &mut resp_buf, &mut all_col_offsets)?;
                                continue;
                            }
                            BackendMessage::CommandComplete { tag } => {
                                affected_rows = proto::parse_command_tag(tag);
                                continue;
                            }
                            BackendMessage::EmptyQuery => continue,
                            BackendMessage::ReadyForQuery { status } => {
                                self.tx_status = status;
                                break 'outer;
                            }
                            BackendMessage::NoticeResponse { .. } => continue,
                            BackendMessage::ErrorResponse { data } => {
                                let fields = proto::parse_error_response(data);
                                self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                                self.drain_to_ready()?;
                                return Err(self.make_server_error(fields));
                            }
                            other => {
                                return Err(DriverError::Protocol(format!(
                                    "unexpected message during query: {other:?}"
                                )));
                            }
                        }
                    }
                    // Incomplete message that will fit once more bytes arrive.
                    break;
                }

                let payload_start = self.stream_buf_pos + 5;
                let payload_end = payload_start + payload_len;

                if msg_type == b'D' {
                    // DataRow: copy column values into resp_buf and record offsets.
                    parse_data_row_into_buf(
                        &self.stream_buf[payload_start..payload_end],
                        &mut resp_buf,
                        &mut all_col_offsets,
                    )?;
                } else if msg_type == b'Z' {
                    // ReadyForQuery: the payload byte is the transaction status.
                    if payload_len >= 1 {
                        self.tx_status = self.stream_buf[payload_start];
                    }
                    self.stream_buf_pos += total_msg_len;
                    break 'outer;
                } else {
                    self.handle_non_datarow_query(
                        msg_type,
                        payload_start,
                        payload_end,
                        sql_hash,
                        &mut affected_rows,
                    )?;
                }

                self.stream_buf_pos += total_msg_len;
            }

            self.refill_stream_buf()?;
        }

        self.shrink_buffers();

        Ok(QueryResult::from_parts_with_buf(
            all_col_offsets,
            num_cols,
            columns,
            affected_rows,
            resp_buf,
        ))
    }
725
726 #[inline]
737 pub fn execute_monolithic(
738 &mut self,
739 sql: &str,
740 sql_hash: u64,
741 params: &[&(dyn Encode + Sync)],
742 ) -> Result<u64, DriverError> {
743 self.write_buf.clear();
745
746 let info = match self.stmts.get_mut(&sql_hash, sql) {
748 Some(info) => {
749 self.query_counter += 1;
750 info.last_used = self.query_counter;
751 info
752 }
753 None => {
754 return self.execute_with_prepare(sql, sql_hash, params);
756 }
757 };
758
759 let can_use_template = info
761 .bind_template
762 .as_ref()
763 .is_some_and(|t| t.param_slots.len() == params.len());
764
765 let mut has_exec_sync = false;
766
767 if can_use_template {
768 let tmpl = info
770 .bind_template
771 .as_ref()
772 .expect("guarded by can_use_template");
773 self.write_buf.extend_from_slice(&tmpl.bytes);
774
775 let mut template_ok = true;
776 for (i, param) in params.iter().enumerate() {
777 let (data_offset, old_len) = tmpl.param_slots[i];
778 if param.is_null() {
779 let len_offset = data_offset - 4;
780 self.write_buf[len_offset..len_offset + 4]
781 .copy_from_slice(&(-1i32).to_be_bytes());
782 } else if old_len >= 0 {
783 let end = data_offset + old_len as usize;
784 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
785 template_ok = false;
786 break;
787 }
788 } else {
789 template_ok = false;
791 break;
792 }
793 }
794
795 if template_ok {
796 has_exec_sync = true;
797 } else {
798 self.write_buf.clear();
799 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
800 info.bind_template = None;
801 }
802 } else {
803 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
804 }
805
806 if info.bind_template.is_none() && !self.write_buf.is_empty() {
808 info.bind_template = build_bind_template(&self.write_buf, params.len());
809 }
810
811 if !has_exec_sync {
812 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
813 }
814
815 self.stream
817 .write_all(&self.write_buf)
818 .map_err(DriverError::Io)?;
819
820 let mut affected_rows: u64 = 0;
822
823 'outer: loop {
824 loop {
825 let avail = self.stream_buf_end - self.stream_buf_pos;
826 if avail < 5 {
827 break; }
829
830 let msg_type = self.stream_buf[self.stream_buf_pos];
831 let raw_len = i32::from_be_bytes([
832 self.stream_buf[self.stream_buf_pos + 1],
833 self.stream_buf[self.stream_buf_pos + 2],
834 self.stream_buf[self.stream_buf_pos + 3],
835 self.stream_buf[self.stream_buf_pos + 4],
836 ]);
837
838 if raw_len < 4 {
839 return Err(DriverError::Protocol(format!(
840 "invalid message length {raw_len} for type '{}'",
841 msg_type as char
842 )));
843 }
844
845 let payload_len = (raw_len - 4) as usize;
846 let total_msg_len = 5 + payload_len;
847
848 if avail < total_msg_len {
849 if total_msg_len > self.stream_buf.len() {
850 let msg = self.read_one_message()?;
851 match msg {
852 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
853 continue;
854 }
855 BackendMessage::CommandComplete { tag } => {
856 affected_rows = proto::parse_command_tag(tag);
857 continue;
858 }
859 BackendMessage::EmptyQuery => continue,
860 BackendMessage::ReadyForQuery { status } => {
861 self.tx_status = status;
862 break 'outer;
863 }
864 BackendMessage::NoticeResponse { .. } => continue,
865 BackendMessage::ErrorResponse { data } => {
866 let fields = proto::parse_error_response(data);
867 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
868 self.drain_to_ready()?;
869 return Err(self.make_server_error(fields));
870 }
871 other => {
872 return Err(DriverError::Protocol(format!(
873 "unexpected message during execute: {other:?}"
874 )));
875 }
876 }
877 }
878 break; }
880
881 let payload_start = self.stream_buf_pos + 5;
886 let payload_end = payload_start + payload_len;
887
888 if msg_type == b'2' {
889 self.stream_buf_pos += total_msg_len;
891 continue;
892 } else if msg_type == b'C' {
893 affected_rows = proto::parse_command_tag_bytes(
895 &self.stream_buf[payload_start..payload_end],
896 );
897 } else if msg_type == b'Z' {
898 if payload_len >= 1 {
900 self.tx_status = self.stream_buf[payload_start];
901 }
902 self.stream_buf_pos += total_msg_len;
903 break 'outer;
904 } else if msg_type == b'D' || msg_type == b'I' {
905 } else {
907 self.handle_non_datarow_execute(
908 msg_type,
909 payload_start,
910 payload_end,
911 sql_hash,
912 )?;
913 }
914
915 self.stream_buf_pos += total_msg_len;
916 }
917
918 let remaining = self.stream_buf_end - self.stream_buf_pos;
920 debug_assert!(
921 remaining == 0 || self.stream_buf_pos > 0,
922 "compact called with pos=0 and remaining data"
923 );
924 if remaining > 0 {
925 self.stream_buf
926 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
927 }
928 self.stream_buf_pos = 0;
929 self.stream_buf_end = remaining;
930 let n = self
931 .stream
932 .read(&mut self.stream_buf[remaining..])
933 .map_err(DriverError::Io)?;
934 if n == 0 {
935 return Err(DriverError::Io(std::io::Error::new(
936 std::io::ErrorKind::UnexpectedEof,
937 "connection closed",
938 )));
939 }
940 self.stream_buf_end = remaining + n;
941 }
942
943 if self.query_counter & 63 == 0 {
945 if self.read_buf.capacity() > 64 * 1024 {
946 self.read_buf.clear();
947 self.read_buf.shrink_to(8192);
948 }
949 if self.write_buf.capacity() > 16 * 1024 {
950 self.write_buf.clear();
951 self.write_buf.shrink_to(8192);
952 }
953 }
954
955 Ok(affected_rows)
956 }
957
958 #[cold]
960 #[inline(never)]
961 fn execute_with_prepare(
962 &mut self,
963 sql: &str,
964 sql_hash: u64,
965 params: &[&(dyn Encode + Sync)],
966 ) -> Result<u64, DriverError> {
967 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
968
969 if params.len() > i16::MAX as usize {
970 return Err(DriverError::Protocol(format!(
971 "parameter count {} exceeds maximum {}",
972 params.len(),
973 i16::MAX
974 )));
975 }
976
977 let name = make_stmt_name(sql_hash);
978 let param_oids: smallvec::SmallVec<[u32; 8]> =
979 params.iter().map(|p| p.type_oid()).collect();
980
981 self.write_buf.clear();
982 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
983 proto::write_describe(&mut self.write_buf, b'S', &name);
984 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
985 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
986 self.stream
987 .write_all(&self.write_buf)
988 .map_err(DriverError::Io)?;
989
990 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
991 let columns = self.read_column_description()?;
992 self.query_counter += 1;
993 self.cache_stmt(
994 sql_hash,
995 StmtInfo {
996 name,
997 sql: sql.into(),
998 columns,
999 last_used: self.query_counter,
1000 bind_template: None,
1001 },
1002 );
1003
1004 let mut affected_rows: u64 = 0;
1006 'outer: loop {
1007 loop {
1008 let avail = self.stream_buf_end - self.stream_buf_pos;
1009 if avail < 5 {
1010 break;
1011 }
1012
1013 let msg_type = self.stream_buf[self.stream_buf_pos];
1014 let raw_len = i32::from_be_bytes([
1015 self.stream_buf[self.stream_buf_pos + 1],
1016 self.stream_buf[self.stream_buf_pos + 2],
1017 self.stream_buf[self.stream_buf_pos + 3],
1018 self.stream_buf[self.stream_buf_pos + 4],
1019 ]);
1020
1021 if raw_len < 4 {
1022 return Err(DriverError::Protocol(format!(
1023 "invalid message length {raw_len} for type '{}'",
1024 msg_type as char
1025 )));
1026 }
1027
1028 let payload_len = (raw_len - 4) as usize;
1029 let total_msg_len = 5 + payload_len;
1030
1031 if avail < total_msg_len {
1032 if total_msg_len > self.stream_buf.len() {
1033 let msg = self.read_one_message()?;
1034 match msg {
1035 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
1036 continue;
1037 }
1038 BackendMessage::CommandComplete { tag } => {
1039 affected_rows = proto::parse_command_tag(tag);
1040 continue;
1041 }
1042 BackendMessage::EmptyQuery => continue,
1043 BackendMessage::ReadyForQuery { status } => {
1044 self.tx_status = status;
1045 break 'outer;
1046 }
1047 BackendMessage::NoticeResponse { .. } => continue,
1048 BackendMessage::ErrorResponse { data } => {
1049 let fields = proto::parse_error_response(data);
1050 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1051 self.drain_to_ready()?;
1052 return Err(self.make_server_error(fields));
1053 }
1054 other => {
1055 return Err(DriverError::Protocol(format!(
1056 "unexpected message during execute: {other:?}"
1057 )));
1058 }
1059 }
1060 }
1061 break;
1062 }
1063
1064 let payload_start = self.stream_buf_pos + 5;
1065 let payload_end = payload_start + payload_len;
1066
1067 if msg_type == b'2' || msg_type == b'D' || msg_type == b'I' {
1068 } else if msg_type == b'C' {
1070 affected_rows = proto::parse_command_tag_bytes(
1071 &self.stream_buf[payload_start..payload_end],
1072 );
1073 } else if msg_type == b'Z' {
1074 if payload_len >= 1 {
1075 self.tx_status = self.stream_buf[payload_start];
1076 }
1077 self.stream_buf_pos += total_msg_len;
1078 break 'outer;
1079 } else {
1080 self.handle_non_datarow_execute(
1081 msg_type,
1082 payload_start,
1083 payload_end,
1084 sql_hash,
1085 )?;
1086 }
1087
1088 self.stream_buf_pos += total_msg_len;
1089 }
1090
1091 self.refill_stream_buf()?;
1092 }
1093
1094 Ok(affected_rows)
1095 }
1096
1097 #[inline]
1102 pub fn execute(
1103 &mut self,
1104 sql: &str,
1105 sql_hash: u64,
1106 params: &[&(dyn Encode + Sync)],
1107 ) -> Result<u64, DriverError> {
1108 self.execute_monolithic(sql, sql_hash, params)
1109 }
1110
1111 pub fn execute_pipeline(
1123 &mut self,
1124 sql: &str,
1125 sql_hash: u64,
1126 param_sets: &[&[&(dyn Encode + Sync)]],
1127 ) -> Result<Vec<u64>, DriverError> {
1128 if param_sets.is_empty() {
1129 return Ok(Vec::new());
1130 }
1131
1132 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1133
1134 self.write_buf.clear();
1135
1136 if !self.stmts.contains_key(&sql_hash, sql) {
1138 let name = make_stmt_name(sql_hash);
1139 let first_params = param_sets[0];
1140 if first_params.len() > i16::MAX as usize {
1141 return Err(DriverError::Protocol(format!(
1142 "parameter count {} exceeds maximum {}",
1143 first_params.len(),
1144 i16::MAX
1145 )));
1146 }
1147 let param_oids: smallvec::SmallVec<[u32; 8]> =
1148 first_params.iter().map(|p| p.type_oid()).collect();
1149 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1150 proto::write_describe(&mut self.write_buf, b'S', &name);
1151 proto::write_sync(&mut self.write_buf);
1152 self.flush_write()?;
1153
1154 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1155 let columns = self.read_column_description()?;
1156 self.expect_ready()?;
1157
1158 self.query_counter += 1;
1159 self.cache_stmt(
1160 sql_hash,
1161 StmtInfo {
1162 name,
1163 sql: sql.into(),
1164 columns,
1165 last_used: self.query_counter,
1166 bind_template: None,
1167 },
1168 );
1169
1170 self.write_buf.clear();
1171 }
1172
1173 let stmt_name = self
1175 .stmts
1176 .get(&sql_hash, sql)
1177 .expect("BUG: stmt just cached but not found")
1178 .name;
1179 let count = param_sets.len();
1180
1181 for params in param_sets {
1182 if params.len() > i16::MAX as usize {
1183 return Err(DriverError::Protocol(format!(
1184 "parameter count {} exceeds maximum {}",
1185 params.len(),
1186 i16::MAX
1187 )));
1188 }
1189 proto::write_bind_params(&mut self.write_buf, b"", &stmt_name, params);
1190 self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
1191 }
1192
1193 self.write_buf.extend_from_slice(proto::SYNC_ONLY);
1194 self.flush_write()?;
1195
1196 let mut results = Vec::with_capacity(count);
1199
1200 'outer: loop {
1201 loop {
1202 let Some((msg_type, start, end, total)) = self.peek_stream_msg()? else {
1203 break; };
1205
1206 if msg_type == b'2' {
1207 } else if msg_type == b'C' {
1209 let rows = proto::parse_command_tag_bytes(&self.stream_buf[start..end]);
1211 results.push(rows);
1212 } else if msg_type == b'Z' {
1213 if end > start {
1215 self.tx_status = self.stream_buf[start];
1216 }
1217 self.advance_stream_msg(total);
1218 break 'outer;
1219 } else if msg_type == b'I' {
1220 results.push(0);
1222 } else if msg_type == b'D' || msg_type == b'N' {
1223 } else if msg_type == b'E' {
1225 let fields = proto::parse_error_response(&self.stream_buf[start..end]);
1227 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1228 self.advance_stream_msg(total);
1229 self.drain_to_ready()?;
1230 return Err(self.make_server_error(fields));
1231 } else if msg_type == b'A' {
1232 let msg = proto::parse_backend_message(msg_type, &self.stream_buf[start..end])?;
1234 if let BackendMessage::NotificationResponse {
1235 pid,
1236 channel,
1237 payload,
1238 } = msg
1239 {
1240 let ch = channel.to_owned();
1241 let pl = payload.to_owned();
1242 self.buffer_notification(pid, &ch, &pl);
1243 }
1244 }
1245 self.advance_stream_msg(total);
1248 }
1249
1250 self.refill_stream_buf()?;
1252 }
1253
1254 self.shrink_buffers();
1255 Ok(results)
1256 }
1257
1258 pub(crate) fn ensure_stmt_prepared(
1264 &mut self,
1265 sql: &str,
1266 sql_hash: u64,
1267 params: &[&(dyn Encode + Sync)],
1268 ) -> Result<[u8; 18], DriverError> {
1269 if let Some(info) = self.stmts.get(&sql_hash, sql) {
1270 return Ok(info.name);
1271 }
1272
1273 let name = make_stmt_name(sql_hash);
1274 if params.len() > i16::MAX as usize {
1275 return Err(DriverError::Protocol(format!(
1276 "parameter count {} exceeds maximum {}",
1277 params.len(),
1278 i16::MAX
1279 )));
1280 }
1281 let param_oids: smallvec::SmallVec<[u32; 8]> =
1282 params.iter().map(|p| p.type_oid()).collect();
1283
1284 self.write_buf.clear();
1285 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1286 proto::write_describe(&mut self.write_buf, b'S', &name);
1287 proto::write_sync(&mut self.write_buf);
1288 self.flush_write()?;
1289
1290 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1291 let columns = self.read_column_description()?;
1292 self.expect_ready()?;
1293
1294 self.query_counter += 1;
1295 self.cache_stmt(
1296 sql_hash,
1297 StmtInfo {
1298 name,
1299 sql: sql.into(),
1300 columns,
1301 last_used: self.query_counter,
1302 bind_template: None,
1303 },
1304 );
1305
1306 Ok(name)
1307 }
1308
1309 pub(crate) fn write_deferred_bind_execute(
1312 &self,
1313 sql: &str,
1314 sql_hash: u64,
1315 params: &[&(dyn Encode + Sync)],
1316 buf: &mut Vec<u8>,
1317 ) {
1318 let stmt_name = self
1319 .stmts
1320 .get(&sql_hash, sql)
1321 .expect("BUG: stmt just cached but not found")
1322 .name;
1323 proto::write_bind_params(buf, b"", &stmt_name, params);
1324 buf.extend_from_slice(proto::EXECUTE_ONLY);
1325 }
1326
1327 pub(crate) fn flush_deferred_pipeline(
1332 &mut self,
1333 buf: &mut Vec<u8>,
1334 count: usize,
1335 ) -> Result<Vec<u64>, DriverError> {
1336 if count == 0 {
1337 buf.clear();
1338 return Ok(Vec::new());
1339 }
1340
1341 buf.extend_from_slice(proto::SYNC_ONLY);
1342
1343 self.stream.write_all(buf).map_err(DriverError::Io)?;
1344 buf.clear();
1345
1346 let mut results = Vec::with_capacity(count);
1348
1349 'outer: loop {
1350 loop {
1351 let Some((msg_type, start, end, total)) = self.peek_stream_msg()? else {
1352 break; };
1354
1355 if msg_type == b'2' {
1356 } else if msg_type == b'C' {
1358 let rows = proto::parse_command_tag_bytes(&self.stream_buf[start..end]);
1360 results.push(rows);
1361 } else if msg_type == b'Z' {
1362 if end > start {
1364 self.tx_status = self.stream_buf[start];
1365 }
1366 self.advance_stream_msg(total);
1367 break 'outer;
1368 } else if msg_type == b'I' {
1369 results.push(0);
1371 } else if msg_type == b'D' || msg_type == b'N' {
1372 } else if msg_type == b'E' {
1374 let fields = proto::parse_error_response(&self.stream_buf[start..end]);
1376 self.advance_stream_msg(total);
1377 self.drain_to_ready()?;
1378 return Err(self.make_server_error(fields));
1379 } else if msg_type == b'A' {
1380 let msg = proto::parse_backend_message(msg_type, &self.stream_buf[start..end])?;
1382 if let BackendMessage::NotificationResponse {
1383 pid,
1384 channel,
1385 payload,
1386 } = msg
1387 {
1388 let ch = channel.to_owned();
1389 let pl = payload.to_owned();
1390 self.buffer_notification(pid, &ch, &pl);
1391 }
1392 }
1393 self.advance_stream_msg(total);
1396 }
1397
1398 self.refill_stream_buf()?;
1400 }
1401
1402 self.shrink_buffers();
1403 Ok(results)
1404 }
1405
1406 pub fn for_each<F>(
1408 &mut self,
1409 sql: &str,
1410 sql_hash: u64,
1411 params: &[&(dyn Encode + Sync)],
1412 mut f: F,
1413 ) -> Result<(), DriverError>
1414 where
1415 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
1416 {
1417 let _ = self.send_pipeline(sql, sql_hash, params, false, true)?;
1418
1419 'outer: loop {
1421 loop {
1422 let avail = self.stream_buf_end - self.stream_buf_pos;
1423 if avail < 5 {
1424 break; }
1426
1427 let msg_type = self.stream_buf[self.stream_buf_pos];
1428 let raw_len = i32::from_be_bytes([
1429 self.stream_buf[self.stream_buf_pos + 1],
1430 self.stream_buf[self.stream_buf_pos + 2],
1431 self.stream_buf[self.stream_buf_pos + 3],
1432 self.stream_buf[self.stream_buf_pos + 4],
1433 ]);
1434
1435 if raw_len < 4 {
1436 return Err(DriverError::Protocol(format!(
1437 "invalid message length {raw_len} for type '{}'",
1438 msg_type as char
1439 )));
1440 }
1441
1442 let payload_len = (raw_len - 4) as usize;
1443 let total_msg_len = 5 + payload_len;
1444
1445 if avail < total_msg_len {
1446 if total_msg_len > self.stream_buf.len() {
1447 let msg = self.read_one_message()?;
1449 match msg {
1450 BackendMessage::BindComplete => continue,
1451 BackendMessage::DataRow { data } => {
1452 let row = PgDataRow::new(data)?;
1453 f(row)?;
1454 continue;
1455 }
1456 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1457 continue;
1458 }
1459 BackendMessage::ReadyForQuery { status } => {
1460 self.tx_status = status;
1461 break 'outer;
1462 }
1463 BackendMessage::NoticeResponse { .. } => continue,
1464 BackendMessage::ErrorResponse { data } => {
1465 let fields = proto::parse_error_response(data);
1466 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1467 self.drain_to_ready()?;
1468 return Err(self.make_server_error(fields));
1469 }
1470 other => {
1471 return Err(DriverError::Protocol(format!(
1472 "unexpected message during for_each: {other:?}"
1473 )));
1474 }
1475 }
1476 }
1477 break; }
1479
1480 let payload_start = self.stream_buf_pos + 5;
1482 let payload_end = payload_start + payload_len;
1483
1484 if msg_type == b'D' {
1487 let row = PgDataRow::new(&self.stream_buf[payload_start..payload_end])?;
1489 f(row)?;
1490 } else if msg_type == b'Z' {
1491 if payload_len >= 1 {
1493 self.tx_status = self.stream_buf[payload_start];
1494 }
1495 self.stream_buf_pos += total_msg_len;
1496 break 'outer;
1497 } else {
1498 self.handle_non_datarow_execute(
1499 msg_type,
1500 payload_start,
1501 payload_end,
1502 sql_hash,
1503 )?;
1504 }
1505
1506 self.stream_buf_pos += total_msg_len;
1507 }
1508
1509 self.refill_stream_buf()?;
1511 }
1512
1513 self.shrink_buffers();
1514 Ok(())
1515 }
1516
1517 #[inline]
1528 pub fn for_each_raw_monolithic<F>(
1529 &mut self,
1530 sql: &str,
1531 sql_hash: u64,
1532 params: &[&(dyn Encode + Sync)],
1533 mut f: F,
1534 ) -> Result<(), DriverError>
1535 where
1536 F: FnMut(&[u8]) -> Result<(), DriverError>,
1537 {
1538 self.write_buf.clear();
1540
1541 let info = match self.stmts.get_mut(&sql_hash, sql) {
1543 Some(info) => {
1544 self.query_counter += 1;
1545 info.last_used = self.query_counter;
1546 info
1547 }
1548 None => {
1549 return self.for_each_raw_with_prepare(sql, sql_hash, params, f);
1551 }
1552 };
1553
1554 let can_use_template = info
1556 .bind_template
1557 .as_ref()
1558 .is_some_and(|t| t.param_slots.len() == params.len());
1559
1560 let mut has_exec_sync = false;
1561
1562 if can_use_template {
1563 let tmpl = info
1565 .bind_template
1566 .as_ref()
1567 .expect("guarded by can_use_template");
1568 self.write_buf.extend_from_slice(&tmpl.bytes);
1569
1570 let mut template_ok = true;
1571 for (i, param) in params.iter().enumerate() {
1572 let (data_offset, old_len) = tmpl.param_slots[i];
1573 if param.is_null() {
1574 let len_offset = data_offset - 4;
1575 self.write_buf[len_offset..len_offset + 4]
1576 .copy_from_slice(&(-1i32).to_be_bytes());
1577 } else if old_len >= 0 {
1578 let end = data_offset + old_len as usize;
1579 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
1580 template_ok = false;
1581 break;
1582 }
1583 } else {
1584 template_ok = false;
1585 break;
1586 }
1587 }
1588
1589 if template_ok {
1590 has_exec_sync = true;
1591 } else {
1592 self.write_buf.clear();
1593 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
1594 info.bind_template = None;
1595 }
1596 } else {
1597 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
1598 }
1599
1600 if info.bind_template.is_none() && !self.write_buf.is_empty() {
1602 info.bind_template = build_bind_template(&self.write_buf, params.len());
1603 }
1604
1605 if !has_exec_sync {
1606 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1607 }
1608
1609 self.stream
1611 .write_all(&self.write_buf)
1612 .map_err(DriverError::Io)?;
1613
1614 loop {
1618 let avail = self.stream_buf_end - self.stream_buf_pos;
1619 if avail >= 5 {
1620 let bc_type = self.stream_buf[self.stream_buf_pos];
1621 match bc_type {
1622 b'2' => {
1623 self.stream_buf_pos += 5;
1624 break;
1625 }
1626 b'E' => {
1627 let msg = self.read_one_message()?;
1628 if let BackendMessage::ErrorResponse { data } = msg {
1629 let fields = proto::parse_error_response(data);
1630 self.drain_to_ready()?;
1631 return Err(self.make_server_error(fields));
1632 }
1633 }
1634 b'N' | b'S' => {
1635 let raw_len = i32::from_be_bytes([
1636 self.stream_buf[self.stream_buf_pos + 1],
1637 self.stream_buf[self.stream_buf_pos + 2],
1638 self.stream_buf[self.stream_buf_pos + 3],
1639 self.stream_buf[self.stream_buf_pos + 4],
1640 ]);
1641 let total = 1 + raw_len as usize;
1642 if avail >= total {
1643 self.stream_buf_pos += total;
1644 continue;
1645 }
1646 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1647 break;
1648 }
1649 _ => {
1650 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1651 break;
1652 }
1653 }
1654 } else {
1655 let remaining = self.stream_buf_end - self.stream_buf_pos;
1657 if remaining > 0 && self.stream_buf_pos > 0 {
1658 self.stream_buf
1659 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1660 }
1661 self.stream_buf_pos = 0;
1662 self.stream_buf_end = remaining;
1663 let n = self
1664 .stream
1665 .read(&mut self.stream_buf[remaining..])
1666 .map_err(DriverError::Io)?;
1667 if n == 0 {
1668 return Err(DriverError::Io(std::io::Error::new(
1669 std::io::ErrorKind::UnexpectedEof,
1670 "connection closed",
1671 )));
1672 }
1673 self.stream_buf_end = remaining + n;
1674 }
1675 }
1676
1677 'outer: loop {
1679 loop {
1680 let avail = self.stream_buf_end - self.stream_buf_pos;
1681 if avail < 5 {
1682 break;
1683 }
1684
1685 let msg_type = self.stream_buf[self.stream_buf_pos];
1686 let raw_len = i32::from_be_bytes([
1687 self.stream_buf[self.stream_buf_pos + 1],
1688 self.stream_buf[self.stream_buf_pos + 2],
1689 self.stream_buf[self.stream_buf_pos + 3],
1690 self.stream_buf[self.stream_buf_pos + 4],
1691 ]);
1692
1693 if raw_len < 4 {
1694 return Err(DriverError::Protocol(format!(
1695 "invalid message length {raw_len} for type '{}'",
1696 msg_type as char
1697 )));
1698 }
1699
1700 let payload_len = (raw_len - 4) as usize;
1701 let total_msg_len = 5 + payload_len;
1702
1703 if avail < total_msg_len {
1704 if total_msg_len > self.stream_buf.len() {
1705 let msg = self.read_one_message()?;
1706 match msg {
1707 BackendMessage::DataRow { data } => {
1708 f(data)?;
1709 continue;
1710 }
1711 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1712 continue;
1713 }
1714 BackendMessage::ReadyForQuery { status } => {
1715 self.tx_status = status;
1716 break 'outer;
1717 }
1718 BackendMessage::ErrorResponse { data } => {
1719 let fields = proto::parse_error_response(data);
1720 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1721 self.drain_to_ready()?;
1722 return Err(self.make_server_error(fields));
1723 }
1724 BackendMessage::NoticeResponse { .. } => continue,
1725 other => {
1726 return Err(DriverError::Protocol(format!(
1727 "unexpected message during for_each_raw: {other:?}"
1728 )));
1729 }
1730 }
1731 }
1732 break; }
1734
1735 let payload_start = self.stream_buf_pos + 5;
1737 let payload_end = payload_start + payload_len;
1738
1739 if msg_type == b'D' {
1740 f(&self.stream_buf[payload_start..payload_end])?;
1741 } else if msg_type == b'Z' {
1742 if payload_len >= 1 {
1743 self.tx_status = self.stream_buf[payload_start];
1744 }
1745 self.stream_buf_pos += total_msg_len;
1746 break 'outer;
1747 } else {
1748 self.handle_non_datarow_execute(
1749 msg_type,
1750 payload_start,
1751 payload_end,
1752 sql_hash,
1753 )?;
1754 }
1755
1756 self.stream_buf_pos += total_msg_len;
1757 }
1758
1759 let remaining = self.stream_buf_end - self.stream_buf_pos;
1761 if remaining > 0 && self.stream_buf_pos > 0 {
1762 self.stream_buf
1763 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1764 }
1765 self.stream_buf_pos = 0;
1766 self.stream_buf_end = remaining;
1767 let n = self
1768 .stream
1769 .read(&mut self.stream_buf[remaining..])
1770 .map_err(DriverError::Io)?;
1771 if n == 0 {
1772 return Err(DriverError::Io(std::io::Error::new(
1773 std::io::ErrorKind::UnexpectedEof,
1774 "connection closed",
1775 )));
1776 }
1777 self.stream_buf_end = remaining + n;
1778 }
1779
1780 if self.query_counter & 63 == 0 {
1782 if self.read_buf.capacity() > 64 * 1024 {
1783 self.read_buf.clear();
1784 self.read_buf.shrink_to(8192);
1785 }
1786 if self.write_buf.capacity() > 16 * 1024 {
1787 self.write_buf.clear();
1788 self.write_buf.shrink_to(8192);
1789 }
1790 }
1791
1792 Ok(())
1793 }
1794
1795 #[cold]
1797 #[inline(never)]
1798 fn for_each_raw_with_prepare<F>(
1799 &mut self,
1800 sql: &str,
1801 sql_hash: u64,
1802 params: &[&(dyn Encode + Sync)],
1803 mut f: F,
1804 ) -> Result<(), DriverError>
1805 where
1806 F: FnMut(&[u8]) -> Result<(), DriverError>,
1807 {
1808 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1809
1810 if params.len() > i16::MAX as usize {
1811 return Err(DriverError::Protocol(format!(
1812 "parameter count {} exceeds maximum {}",
1813 params.len(),
1814 i16::MAX
1815 )));
1816 }
1817
1818 let name = make_stmt_name(sql_hash);
1819 let param_oids: smallvec::SmallVec<[u32; 8]> =
1820 params.iter().map(|p| p.type_oid()).collect();
1821
1822 self.write_buf.clear();
1823 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1824 proto::write_describe(&mut self.write_buf, b'S', &name);
1825 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
1826 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1827 self.stream
1828 .write_all(&self.write_buf)
1829 .map_err(DriverError::Io)?;
1830
1831 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1832 let columns = self.read_column_description()?;
1833 self.query_counter += 1;
1834 self.cache_stmt(
1835 sql_hash,
1836 StmtInfo {
1837 name,
1838 sql: sql.into(),
1839 columns,
1840 last_used: self.query_counter,
1841 bind_template: None,
1842 },
1843 );
1844
1845 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1847
1848 'outer: loop {
1849 loop {
1850 let avail = self.stream_buf_end - self.stream_buf_pos;
1851 if avail < 5 {
1852 break;
1853 }
1854
1855 let msg_type = self.stream_buf[self.stream_buf_pos];
1856 let raw_len = i32::from_be_bytes([
1857 self.stream_buf[self.stream_buf_pos + 1],
1858 self.stream_buf[self.stream_buf_pos + 2],
1859 self.stream_buf[self.stream_buf_pos + 3],
1860 self.stream_buf[self.stream_buf_pos + 4],
1861 ]);
1862
1863 if raw_len < 4 {
1864 return Err(DriverError::Protocol(format!(
1865 "invalid message length {raw_len} for type '{}'",
1866 msg_type as char
1867 )));
1868 }
1869
1870 let payload_len = (raw_len - 4) as usize;
1871 let total_msg_len = 5 + payload_len;
1872
1873 if avail < total_msg_len {
1874 if total_msg_len > self.stream_buf.len() {
1875 let msg = self.read_one_message()?;
1876 match msg {
1877 BackendMessage::DataRow { data } => {
1878 f(data)?;
1879 continue;
1880 }
1881 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1882 continue;
1883 }
1884 BackendMessage::ReadyForQuery { status } => {
1885 self.tx_status = status;
1886 break 'outer;
1887 }
1888 BackendMessage::ErrorResponse { data } => {
1889 let fields = proto::parse_error_response(data);
1890 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1891 self.drain_to_ready()?;
1892 return Err(self.make_server_error(fields));
1893 }
1894 BackendMessage::NoticeResponse { .. } => continue,
1895 other => {
1896 return Err(DriverError::Protocol(format!(
1897 "unexpected message during for_each_raw: {other:?}"
1898 )));
1899 }
1900 }
1901 }
1902 break;
1903 }
1904
1905 let payload_start = self.stream_buf_pos + 5;
1906 let payload_end = payload_start + payload_len;
1907
1908 if msg_type == b'D' {
1909 f(&self.stream_buf[payload_start..payload_end])?;
1910 } else if msg_type == b'Z' {
1911 if payload_len >= 1 {
1912 self.tx_status = self.stream_buf[payload_start];
1913 }
1914 self.stream_buf_pos += total_msg_len;
1915 break 'outer;
1916 } else {
1917 self.handle_non_datarow_execute(
1918 msg_type,
1919 payload_start,
1920 payload_end,
1921 sql_hash,
1922 )?;
1923 }
1924
1925 self.stream_buf_pos += total_msg_len;
1926 }
1927
1928 self.refill_stream_buf()?;
1929 }
1930
1931 self.shrink_buffers();
1932 Ok(())
1933 }
1934
1935 #[inline]
1940 pub fn for_each_raw<F>(
1941 &mut self,
1942 sql: &str,
1943 sql_hash: u64,
1944 params: &[&(dyn Encode + Sync)],
1945 f: F,
1946 ) -> Result<(), DriverError>
1947 where
1948 F: FnMut(&[u8]) -> Result<(), DriverError>,
1949 {
1950 self.for_each_raw_monolithic(sql, sql_hash, params, f)
1951 }
1952
1953 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1955 self.write_buf.clear();
1956 proto::write_simple_query(&mut self.write_buf, sql);
1957 self.flush_write()?;
1958
1959 loop {
1960 let msg = self.read_one_message()?;
1961 match msg {
1962 BackendMessage::ReadyForQuery { status } => {
1963 self.tx_status = status;
1964 return Ok(());
1965 }
1966 BackendMessage::CommandComplete { .. }
1967 | BackendMessage::RowDescription { .. }
1968 | BackendMessage::DataRow { .. }
1969 | BackendMessage::EmptyQuery
1970 | BackendMessage::NoticeResponse { .. }
1971 | BackendMessage::ParameterStatus { .. }
1972 | BackendMessage::AuthOk
1976 | BackendMessage::AuthSaslFinal { .. }
1977 | BackendMessage::BackendKeyData { .. } => {}
1978 BackendMessage::ErrorResponse { data } => {
1979 let fields = proto::parse_error_response(data);
1980 self.drain_to_ready()?;
1981 return Err(self.make_server_error(fields));
1982 }
1983 other => {
1984 return Err(DriverError::Protocol(format!(
1985 "unexpected message during simple_query: {other:?}"
1986 )));
1987 }
1988 }
1989 }
1990 }
1991
1992 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
1994 self.write_buf.clear();
1995 proto::write_simple_query(&mut self.write_buf, sql);
1996 self.flush_write()?;
1997
1998 let mut rows: Vec<SimpleRow> = Vec::new();
1999 loop {
2000 let msg = self.read_one_message()?;
2001 match msg {
2002 BackendMessage::ReadyForQuery { status } => {
2003 self.tx_status = status;
2004 return Ok(rows);
2005 }
2006 BackendMessage::DataRow { data } => {
2007 rows.push(proto::parse_simple_data_row(data)?);
2008 }
2009 BackendMessage::RowDescription { .. }
2010 | BackendMessage::CommandComplete { .. }
2011 | BackendMessage::EmptyQuery
2012 | BackendMessage::NoticeResponse { .. }
2013 | BackendMessage::ParameterStatus { .. }
2014 | BackendMessage::AuthOk
2015 | BackendMessage::AuthSaslFinal { .. }
2016 | BackendMessage::BackendKeyData { .. } => {}
2017 BackendMessage::ErrorResponse { data } => {
2018 let fields = proto::parse_error_response(data);
2019 self.drain_to_ready()?;
2020 return Err(self.make_server_error(fields));
2021 }
2022 other => {
2023 return Err(DriverError::Protocol(format!(
2024 "unexpected message during simple_query_rows: {other:?}"
2025 )));
2026 }
2027 }
2028 }
2029 }
2030
2031 pub fn copy_in<'a, I>(
2053 &mut self,
2054 table: &str,
2055 columns: &[&str],
2056 rows: I,
2057 ) -> Result<u64, DriverError>
2058 where
2059 I: IntoIterator<Item = &'a str>,
2060 {
2061 let quoted_table = proto::quote_ident(table);
2063 let quoted_cols: Vec<String> = columns.iter().map(|c| proto::quote_ident(c)).collect();
2064 let sql = format!(
2065 "COPY {}({}) FROM STDIN",
2066 quoted_table,
2067 quoted_cols.join(",")
2068 );
2069
2070 self.write_buf.clear();
2072 proto::write_simple_query(&mut self.write_buf, &sql);
2073 self.flush_write()?;
2074
2075 loop {
2077 let msg = self.read_one_message()?;
2078 match msg {
2079 BackendMessage::CopyInResponse { .. } => break,
2080 BackendMessage::ErrorResponse { data } => {
2081 let fields = proto::parse_error_response(data);
2082 self.drain_to_ready()?;
2083 return Err(self.make_server_error(fields));
2084 }
2085 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2086 other => {
2087 return Err(DriverError::Protocol(format!(
2088 "expected CopyInResponse, got: {other:?}"
2089 )));
2090 }
2091 }
2092 }
2093
2094 self.write_buf.clear();
2104 for row in rows {
2105 let row_data = row.as_bytes();
2107 let data_len = (4 + row_data.len() + 1) as i32;
2108 self.write_buf.push(b'd');
2109 self.write_buf.extend_from_slice(&data_len.to_be_bytes());
2110 self.write_buf.extend_from_slice(row_data);
2111 self.write_buf.push(b'\n');
2112 if self.write_buf.len() > 65536 {
2114 self.flush_write()?;
2115 self.write_buf.clear();
2116 }
2117 }
2118 proto::write_copy_done(&mut self.write_buf);
2121 self.flush_write()?;
2122 self.write_buf.clear();
2123
2124 let mut count: u64 = 0;
2126 loop {
2127 let msg = self.read_one_message()?;
2128 match msg {
2129 BackendMessage::CommandComplete { tag } => {
2130 count = proto::parse_command_tag(tag);
2131 }
2132 BackendMessage::ReadyForQuery { status } => {
2133 self.tx_status = status;
2134 return Ok(count);
2135 }
2136 BackendMessage::ErrorResponse { data } => {
2137 let fields = proto::parse_error_response(data);
2138 self.drain_to_ready()?;
2139 return Err(self.make_server_error(fields));
2140 }
2141 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2142 other => {
2143 return Err(DriverError::Protocol(format!(
2144 "unexpected message during copy_in completion: {other:?}"
2145 )));
2146 }
2147 }
2148 }
2149 }
2150
2151 pub fn copy_out<W: std::io::Write>(
2171 &mut self,
2172 query: &str,
2173 writer: &mut W,
2174 ) -> Result<u64, DriverError> {
2175 let sql = format!("COPY ({query}) TO STDOUT");
2177
2178 self.write_buf.clear();
2180 proto::write_simple_query(&mut self.write_buf, &sql);
2181 self.flush_write()?;
2182
2183 loop {
2185 let msg = self.read_one_message()?;
2186 match msg {
2187 BackendMessage::CopyOutResponse { .. } => break,
2188 BackendMessage::ErrorResponse { data } => {
2189 let fields = proto::parse_error_response(data);
2190 self.drain_to_ready()?;
2191 return Err(self.make_server_error(fields));
2192 }
2193 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2194 other => {
2195 return Err(DriverError::Protocol(format!(
2196 "expected CopyOutResponse, got: {other:?}"
2197 )));
2198 }
2199 }
2200 }
2201
2202 loop {
2204 let msg = self.read_one_message()?;
2205 match msg {
2206 BackendMessage::CopyData { data } => {
2207 writer.write_all(data).map_err(DriverError::Io)?;
2208 }
2209 BackendMessage::CopyDone => break,
2210 BackendMessage::ErrorResponse { data } => {
2211 let fields = proto::parse_error_response(data);
2212 self.drain_to_ready()?;
2213 return Err(self.make_server_error(fields));
2214 }
2215 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2216 other => {
2217 return Err(DriverError::Protocol(format!(
2218 "unexpected message during copy_out data: {other:?}"
2219 )));
2220 }
2221 }
2222 }
2223
2224 let mut count: u64 = 0;
2226 loop {
2227 let msg = self.read_one_message()?;
2228 match msg {
2229 BackendMessage::CommandComplete { tag } => {
2230 count = proto::parse_command_tag(tag);
2231 }
2232 BackendMessage::ReadyForQuery { status } => {
2233 self.tx_status = status;
2234 return Ok(count);
2235 }
2236 BackendMessage::ErrorResponse { data } => {
2237 let fields = proto::parse_error_response(data);
2238 self.drain_to_ready()?;
2239 return Err(self.make_server_error(fields));
2240 }
2241 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2242 other => {
2243 return Err(DriverError::Protocol(format!(
2244 "unexpected message during copy_out completion: {other:?}"
2245 )));
2246 }
2247 }
2248 }
2249 }
2250
2251 pub fn prepare_describe(&mut self, sql: &str) -> Result<PrepareResult, DriverError> {
2256 self.write_buf.clear();
2257 proto::write_parse(&mut self.write_buf, b"", sql, &[]);
2260 proto::write_describe(&mut self.write_buf, b'S', b"");
2261 proto::write_sync(&mut self.write_buf);
2262 self.flush_write()?;
2263
2264 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2266
2267 let mut param_oids: Vec<u32> = Vec::new();
2269 let columns;
2270 loop {
2271 let msg = self.read_one_message()?;
2272 match msg {
2273 BackendMessage::ParameterDescription { data } => {
2274 param_oids = proto::parse_parameter_description(data)?;
2275 }
2276 BackendMessage::RowDescription { data } => {
2277 columns = proto::parse_row_description(data)?;
2278 break;
2279 }
2280 BackendMessage::NoData => {
2281 columns = Vec::new();
2282 break;
2283 }
2284 BackendMessage::NoticeResponse { .. } => {}
2285 BackendMessage::ErrorResponse { data } => {
2286 let fields = proto::parse_error_response(data);
2287 self.drain_to_ready()?;
2288 return Err(self.make_server_error(fields));
2289 }
2290 other => {
2291 return Err(DriverError::Protocol(format!(
2292 "expected ParameterDescription/RowDescription/NoData, got: {other:?}"
2293 )));
2294 }
2295 }
2296 }
2297
2298 self.expect_ready()?;
2300
2301 Ok(PrepareResult {
2302 columns,
2303 param_oids,
2304 })
2305 }
2306
2307 pub fn wait_for_notification(&mut self) -> Result<(String, String), DriverError> {
2316 loop {
2317 let (msg_type, _payload_len) = self.read_message_buffered()?;
2318 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2319 match msg {
2320 BackendMessage::NotificationResponse {
2321 channel, payload, ..
2322 } => {
2323 return Ok((channel.to_owned(), payload.to_owned()));
2324 }
2325 BackendMessage::ParameterStatus { .. } | BackendMessage::NoticeResponse { .. } => {
2326 continue;
2327 }
2328 _ => continue,
2329 }
2330 }
2331 }
2332
2333 pub fn cancel(&self) -> Result<(), DriverError> {
2339 let addr = format!("{}:{}", self.connect_config.host, self.connect_config.port);
2340 let mut tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
2341 let mut buf = Vec::with_capacity(16);
2342 proto::write_cancel_request(&mut buf, self.pid, self.secret);
2343 tcp.write_all(&buf).map_err(DriverError::Io)?;
2344 tcp.flush().map_err(DriverError::Io)?;
2345 drop(tcp);
2347 Ok(())
2348 }
2349
2350 pub fn set_read_timeout(
2355 &self,
2356 timeout: Option<std::time::Duration>,
2357 ) -> Result<(), DriverError> {
2358 self.stream
2359 .set_read_timeout(timeout)
2360 .map_err(DriverError::Io)
2361 }
2362
2363 pub fn query_streaming_start(
2377 &mut self,
2378 sql: &str,
2379 sql_hash: u64,
2380 params: &[&(dyn Encode + Sync)],
2381 chunk_size: i32,
2382 ) -> Result<(Arc<[ColumnDesc]>, bool), DriverError> {
2383 self.write_buf.clear();
2384
2385 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash, sql) {
2386 self.query_counter += 1;
2388 info.last_used = self.query_counter;
2389
2390 let can_use_template = info
2391 .bind_template
2392 .as_ref()
2393 .is_some_and(|t| t.param_slots.len() == params.len());
2394
2395 if can_use_template {
2396 let tmpl = info
2398 .bind_template
2399 .as_ref()
2400 .expect("guarded by can_use_template");
2401 self.write_buf
2404 .extend_from_slice(&tmpl.bytes[..tmpl.bind_end]);
2405
2406 let mut template_ok = true;
2407 for (i, param) in params.iter().enumerate() {
2408 let (data_offset, old_len) = tmpl.param_slots[i];
2409 if param.is_null() {
2410 let len_offset = data_offset - 4;
2411 self.write_buf[len_offset..len_offset + 4]
2412 .copy_from_slice(&(-1i32).to_be_bytes());
2413 } else if old_len >= 0 {
2414 let end = data_offset + old_len as usize;
2415 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2416 template_ok = false;
2417 break;
2418 }
2419 } else {
2420 template_ok = false;
2421 break;
2422 }
2423 }
2424
2425 if !template_ok {
2426 self.write_buf.clear();
2427 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
2428 info.bind_template = None;
2429 }
2430 } else {
2431 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
2432 }
2433
2434 let cols = info.columns.clone();
2435
2436 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2437 info.bind_template = build_bind_template(&self.write_buf, params.len());
2438 }
2439
2440 proto::write_execute(&mut self.write_buf, b"", chunk_size);
2441 proto::write_flush(&mut self.write_buf);
2443 self.flush_write()?;
2444
2445 cols
2446 } else {
2447 let name = make_stmt_name(sql_hash);
2449 let param_oids: smallvec::SmallVec<[u32; 8]> =
2450 params.iter().map(|p| p.type_oid()).collect();
2451 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
2452 proto::write_describe(&mut self.write_buf, b'S', &name);
2453 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
2454
2455 proto::write_execute(&mut self.write_buf, b"", chunk_size);
2456 proto::write_flush(&mut self.write_buf);
2457 self.flush_write()?;
2458
2459 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2460 let columns = self.read_column_description()?;
2461 self.query_counter += 1;
2462 self.cache_stmt(
2463 sql_hash,
2464 StmtInfo {
2465 name,
2466 sql: sql.into(),
2467 columns: columns.clone(),
2468 last_used: self.query_counter,
2469 bind_template: None,
2470 },
2471 );
2472 columns
2473 };
2474
2475 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2477
2478 self.streaming_active = true;
2479
2480 Ok((columns, false))
2481 }
2482
2483 pub fn streaming_next_chunk(
2491 &mut self,
2492 arena: &mut Arena,
2493 all_col_offsets: &mut Vec<(usize, i32)>,
2494 ) -> Result<bool, DriverError> {
2495 all_col_offsets.clear();
2496
2497 loop {
2498 let msg = self.read_one_message()?;
2499 match msg {
2500 BackendMessage::DataRow { data } => {
2501 parse_data_row_flat(data, arena, all_col_offsets)?;
2502 }
2503 BackendMessage::PortalSuspended => {
2504 return Ok(true);
2508 }
2509 BackendMessage::CommandComplete { .. } => {
2510 self.write_buf.clear();
2513 proto::write_sync(&mut self.write_buf);
2514 self.flush_write()?;
2515 self.expect_ready()?;
2516 self.shrink_buffers();
2517
2518 self.streaming_active = false;
2519 return Ok(false);
2520 }
2521 BackendMessage::EmptyQuery => {
2522 self.write_buf.clear();
2523 proto::write_sync(&mut self.write_buf);
2524 self.flush_write()?;
2525 self.expect_ready()?;
2526
2527 self.streaming_active = false;
2528 return Ok(false);
2529 }
2530 BackendMessage::ErrorResponse { data } => {
2531 let fields = proto::parse_error_response(data);
2532 self.write_buf.clear();
2534 proto::write_sync(&mut self.write_buf);
2535 self.flush_write()?;
2536 self.drain_to_ready()?;
2537
2538 self.streaming_active = false;
2539 return Err(self.make_server_error(fields));
2540 }
2541 BackendMessage::NoticeResponse { .. } => {}
2542 other => {
2543 return Err(DriverError::Protocol(format!(
2544 "unexpected message during streaming: {other:?}"
2545 )));
2546 }
2547 }
2548 }
2549 }
2550
2551 pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
2559 self.write_buf.clear();
2560 proto::write_execute(&mut self.write_buf, b"", chunk_size);
2561 proto::write_flush(&mut self.write_buf);
2562 self.flush_write()
2563 }
2564
2565 pub fn is_streaming(&self) -> bool {
2567 self.streaming_active
2568 }
2569
2570 pub fn close(mut self) -> Result<(), DriverError> {
2572 self.write_buf.clear();
2573 proto::write_terminate(&mut self.write_buf);
2574 let _ = self.flush_write();
2575 Ok(())
2576 }
2577
2578 pub fn is_idle(&self) -> bool {
2582 self.tx_status == b'I'
2583 }
2584
2585 pub fn is_in_transaction(&self) -> bool {
2587 self.tx_status == b'T'
2588 }
2589
2590 pub fn is_in_failed_transaction(&self) -> bool {
2592 self.tx_status == b'E'
2593 }
2594
2595 pub fn touch(&mut self) {
2597 self.last_used = std::time::Instant::now();
2598 }
2599
2600 pub fn idle_duration(&self) -> std::time::Duration {
2602 self.last_used.elapsed()
2603 }
2604
2605 pub fn query_counter(&self) -> u64 {
2607 self.query_counter
2608 }
2609
2610 pub fn parameter(&self, name: &str) -> Option<&str> {
2612 self.params
2613 .iter()
2614 .find(|(k, _)| &**k == name)
2615 .map(|(_, v)| &**v)
2616 }
2617
2618 pub fn server_params(&self) -> &[(Box<str>, Box<str>)] {
2620 &self.params
2621 }
2622
2623 pub fn pid(&self) -> i32 {
2625 self.pid
2626 }
2627
2628 pub fn secret_key(&self) -> i32 {
2630 self.secret
2631 }
2632
2633 pub fn drain_notifications(&mut self) -> Vec<Notification> {
2635 std::mem::take(&mut self.pending_notifications)
2636 }
2637
2638 pub fn pending_notification_count(&self) -> usize {
2640 self.pending_notifications.len()
2641 }
2642
2643 pub fn set_max_stmt_cache_size(&mut self, size: usize) {
2645 self.max_stmt_cache_size = size;
2646 }
2647
2648 pub fn stmt_cache_len(&self) -> usize {
2650 self.stmts.len()
2651 }
2652
2653 pub fn created_at(&self) -> std::time::Instant {
2655 self.created_at
2656 }
2657
2658 #[inline]
2666 fn send_pipeline(
2667 &mut self,
2668 sql: &str,
2669 sql_hash: u64,
2670 params: &[&(dyn Encode + Sync)],
2671 need_columns: bool,
2672 skip_bind_complete: bool,
2673 ) -> Result<Option<Arc<[ColumnDesc]>>, DriverError> {
2674 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
2675
2676 if params.len() > i16::MAX as usize {
2677 return Err(DriverError::Protocol(format!(
2678 "parameter count {} exceeds maximum {}",
2679 params.len(),
2680 i16::MAX
2681 )));
2682 }
2683
2684 self.write_buf.clear();
2685
2686 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash, sql) {
2687 self.query_counter += 1;
2689 info.last_used = self.query_counter;
2690
2691 let can_use_template = info
2692 .bind_template
2693 .as_ref()
2694 .is_some_and(|t| t.param_slots.len() == params.len());
2695
2696 let mut has_exec_sync = false;
2698
2699 if can_use_template {
2700 let tmpl = info
2704 .bind_template
2705 .as_ref()
2706 .expect("guarded by can_use_template");
2707 self.write_buf.extend_from_slice(&tmpl.bytes);
2708
2709 let mut template_ok = true;
2710 for (i, param) in params.iter().enumerate() {
2711 let (data_offset, old_len) = tmpl.param_slots[i];
2712 if param.is_null() {
2713 let len_offset = data_offset - 4;
2715 self.write_buf[len_offset..len_offset + 4]
2716 .copy_from_slice(&(-1i32).to_be_bytes());
2717 } else if old_len >= 0 {
2718 let end = data_offset + old_len as usize;
2719 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2720 template_ok = false;
2722 break;
2723 }
2724 } else {
2725 template_ok = false;
2728 break;
2729 }
2730 }
2731
2732 if template_ok {
2733 has_exec_sync = true; } else {
2735 self.write_buf.clear();
2736 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
2737 info.bind_template = None;
2739 }
2740 } else {
2741 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
2742 }
2743
2744 let cols = if need_columns {
2745 Some(info.columns.clone())
2746 } else {
2747 None
2748 };
2749
2750 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2754 info.bind_template = build_bind_template(&self.write_buf, params.len());
2755 }
2756
2757 if !has_exec_sync {
2758 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2759 }
2760 self.flush_write()?;
2761
2762 cols
2763 } else {
2764 let name = make_stmt_name(sql_hash);
2766 let param_oids: smallvec::SmallVec<[u32; 8]> =
2767 params.iter().map(|p| p.type_oid()).collect();
2768 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
2769 proto::write_describe(&mut self.write_buf, b'S', &name);
2770 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
2771
2772 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2773 self.flush_write()?;
2774
2775 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2776 let columns = self.read_column_description()?;
2777 self.query_counter += 1;
2778 self.cache_stmt(
2779 sql_hash,
2780 StmtInfo {
2781 name,
2782 sql: sql.into(),
2783 columns: columns.clone(),
2784 last_used: self.query_counter,
2785 bind_template: None,
2786 },
2787 );
2788 if need_columns {
2789 Some(columns)
2790 } else {
2791 None
2792 }
2793 };
2794
2795 if !skip_bind_complete {
2796 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2797 }
2798
2799 Ok(columns)
2800 }
2801
2802 fn read_column_description(&mut self) -> Result<Arc<[ColumnDesc]>, DriverError> {
2804 loop {
2805 let msg = self.read_one_message()?;
2806 match msg {
2807 BackendMessage::RowDescription { data } => {
2808 let cols = proto::parse_row_description(data)?;
2809 return Ok(cols.into());
2810 }
2811 BackendMessage::ParameterDescription { .. } => {}
2812 BackendMessage::NoData => return Ok(Arc::from(Vec::new())),
2813 BackendMessage::NoticeResponse { .. } => {}
2814 BackendMessage::ErrorResponse { data } => {
2815 let fields = proto::parse_error_response(data);
2816 self.drain_to_ready()?;
2817 return Err(self.make_server_error(fields));
2818 }
2819 other => {
2820 return Err(DriverError::Protocol(format!(
2821 "expected RowDescription/NoData, got: {other:?}"
2822 )));
2823 }
2824 }
2825 }
2826 }
2827
2828 fn cache_stmt(&mut self, sql_hash: u64, info: StmtInfo) {
2831 if self.stmts.len() >= self.max_stmt_cache_size
2832 && !self.stmts.contains_key(&sql_hash, &info.sql)
2833 {
2834 if let Some((_lru_hash, evicted)) = self.stmts.evict_lru() {
2835 proto::write_close(&mut self.write_buf, b'S', &evicted.name);
2836 }
2837 }
2838 self.stmts.insert(sql_hash, info);
2839 }
2840
2841 fn buffer_notification(&mut self, pid: i32, channel: &str, payload: &str) {
2842 if self.pending_notifications.len() < 1024 {
2843 self.pending_notifications.push(Notification {
2844 pid,
2845 channel: channel.to_owned(),
2846 payload: payload.to_owned(),
2847 });
2848 }
2849 }
2850
2851 fn shrink_buffers(&mut self) {
2852 if self.query_counter & 63 != 0 {
2856 return;
2857 }
2858 if self.read_buf.capacity() > 64 * 1024 {
2859 self.read_buf.clear();
2860 self.read_buf.shrink_to(8192);
2861 }
2862 if self.write_buf.capacity() > 16 * 1024 {
2863 self.write_buf.clear();
2864 self.write_buf.shrink_to(8192);
2865 }
2866 }
2867
2868 fn maybe_invalidate_stmt_cache(&mut self, fields: &proto::ErrorFields, sql_hash: u64) -> bool {
2869 if &fields.code == b"26000" {
2870 self.stmts.remove(&sql_hash);
2871 true
2872 } else {
2873 false
2874 }
2875 }
2876
2877 #[cold]
2878 #[inline(never)]
2879 fn make_server_error(&self, fields: proto::ErrorFields) -> DriverError {
2880 DriverError::Server {
2881 code: fields.code,
2882 message: fields.message.into_boxed_str(),
2883 detail: fields.detail.map(String::into_boxed_str),
2884 hint: fields.hint.map(String::into_boxed_str),
2885 position: fields.position,
2886 }
2887 }
2888
2889 #[cold]
2895 #[inline(never)]
2896 fn handle_non_datarow_query(
2897 &mut self,
2898 msg_type: u8,
2899 payload_start: usize,
2900 payload_end: usize,
2901 sql_hash: u64,
2902 affected_rows: &mut u64,
2903 ) -> Result<(), DriverError> {
2904 match msg_type {
2905 b'2' | b'I' => {} b'C' => {
2907 *affected_rows =
2908 proto::parse_command_tag_bytes(&self.stream_buf[payload_start..payload_end]);
2909 }
2910 b'E' => {
2911 let fields =
2912 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2913 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2914 self.drain_to_ready()?;
2915 return Err(self.make_server_error(fields));
2916 }
2917 b'A' => {
2918 let msg = proto::parse_backend_message(
2919 msg_type,
2920 &self.stream_buf[payload_start..payload_end],
2921 )?;
2922 if let BackendMessage::NotificationResponse {
2923 pid,
2924 channel,
2925 payload,
2926 } = msg
2927 {
2928 let ch = channel.to_owned();
2929 let pl = payload.to_owned();
2930 self.buffer_notification(pid, &ch, &pl);
2931 }
2932 }
2933 _ => {} }
2935 Ok(())
2936 }
2937
2938 #[cold]
2941 #[inline(never)]
2942 fn handle_non_datarow_execute(
2943 &mut self,
2944 msg_type: u8,
2945 payload_start: usize,
2946 payload_end: usize,
2947 sql_hash: u64,
2948 ) -> Result<(), DriverError> {
2949 match msg_type {
2950 b'2' | b'C' | b'I' => {} b'E' => {
2952 let fields =
2953 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2954 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2955 self.drain_to_ready()?;
2956 return Err(self.make_server_error(fields));
2957 }
2958 b'A' => {
2959 let msg = proto::parse_backend_message(
2960 msg_type,
2961 &self.stream_buf[payload_start..payload_end],
2962 )?;
2963 if let BackendMessage::NotificationResponse {
2964 pid,
2965 channel,
2966 payload,
2967 } = msg
2968 {
2969 let ch = channel.to_owned();
2970 let pl = payload.to_owned();
2971 self.buffer_notification(pid, &ch, &pl);
2972 }
2973 }
2974 _ => {} }
2976 Ok(())
2977 }
2978
2979 #[inline(always)]
2986 fn peek_stream_msg(&self) -> Result<Option<(u8, usize, usize, usize)>, DriverError> {
2987 let avail = self.stream_buf_end - self.stream_buf_pos;
2988 if avail < 5 {
2989 return Ok(None);
2990 }
2991
2992 let msg_type = self.stream_buf[self.stream_buf_pos];
2993 let raw_len = i32::from_be_bytes([
2994 self.stream_buf[self.stream_buf_pos + 1],
2995 self.stream_buf[self.stream_buf_pos + 2],
2996 self.stream_buf[self.stream_buf_pos + 3],
2997 self.stream_buf[self.stream_buf_pos + 4],
2998 ]);
2999
3000 if raw_len < 4 {
3001 return Err(DriverError::Protocol(format!(
3002 "invalid message length {raw_len} for type '{}'",
3003 msg_type as char
3004 )));
3005 }
3006
3007 let payload_len = (raw_len - 4) as usize;
3008 let total_msg_len = 5 + payload_len;
3009
3010 if avail < total_msg_len {
3011 return Ok(None);
3012 }
3013
3014 let payload_start = self.stream_buf_pos + 5;
3015 Ok(Some((
3016 msg_type,
3017 payload_start,
3018 payload_start + payload_len,
3019 total_msg_len,
3020 )))
3021 }
3022
3023 #[inline(always)]
3025 fn advance_stream_msg(&mut self, total_msg_len: usize) {
3026 self.stream_buf_pos += total_msg_len;
3027 }
3028
3029 #[inline]
3031 fn read_one_message(&mut self) -> Result<BackendMessage<'_>, DriverError> {
3032 loop {
3033 let (msg_type, _payload_len) = self.read_message_buffered()?;
3034 if msg_type == b'A' {
3035 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
3036 if let BackendMessage::NotificationResponse {
3037 pid,
3038 channel,
3039 payload,
3040 } = msg
3041 {
3042 let pid_owned = pid;
3043 let channel_owned = channel.to_owned();
3044 let payload_owned = payload.to_owned();
3045 self.buffer_notification(pid_owned, &channel_owned, &payload_owned);
3046 continue;
3047 }
3048 }
3049 return proto::parse_backend_message(msg_type, &self.read_buf);
3050 }
3051 }
3052
3053 fn expect_message(
3054 &mut self,
3055 pred: impl Fn(&BackendMessage<'_>) -> bool,
3056 ) -> Result<(), DriverError> {
3057 loop {
3058 let msg = self.read_one_message()?;
3059 if pred(&msg) {
3060 return Ok(());
3061 }
3062 match msg {
3063 BackendMessage::ErrorResponse { data } => {
3064 let fields = proto::parse_error_response(data);
3065 self.drain_to_ready()?;
3066 return Err(self.make_server_error(fields));
3067 }
3068 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
3069 other => {
3070 return Err(DriverError::Protocol(format!(
3071 "unexpected message while waiting for expected type: {other:?}"
3072 )));
3073 }
3074 }
3075 }
3076 }
3077
3078 fn expect_ready(&mut self) -> Result<(), DriverError> {
3079 loop {
3080 let msg = self.read_one_message()?;
3081 match msg {
3082 BackendMessage::ReadyForQuery { status } => {
3083 self.tx_status = status;
3084 return Ok(());
3085 }
3086 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
3087 BackendMessage::ErrorResponse { data } => {
3088 let fields = proto::parse_error_response(data);
3089 self.drain_to_ready()?;
3090 return Err(self.make_server_error(fields));
3091 }
3092 _ => {}
3093 }
3094 }
3095 }
3096
3097 #[inline]
3098 fn drain_to_ready(&mut self) -> Result<(), DriverError> {
3099 loop {
3100 let msg = self.read_one_message()?;
3101 if let BackendMessage::ReadyForQuery { status } = msg {
3102 self.tx_status = status;
3103 return Ok(());
3104 }
3105 }
3106 }
3107
3108 #[inline]
3112 fn flush_write(&mut self) -> Result<(), DriverError> {
3113 self.stream
3114 .write_all(&self.write_buf)
3115 .map_err(DriverError::Io)
3116 }
3117
3118 fn read_message_buffered(&mut self) -> Result<(u8, usize), DriverError> {
3122 let mut header = [0u8; 5];
3123 sync_buffered_read_exact(
3124 &mut self.stream,
3125 &mut self.stream_buf,
3126 &mut self.stream_buf_pos,
3127 &mut self.stream_buf_end,
3128 &mut header,
3129 )?;
3130
3131 let msg_type = header[0];
3132 let len = i32::from_be_bytes([header[1], header[2], header[3], header[4]]);
3133
3134 if len < 4 {
3135 return Err(DriverError::Protocol(format!(
3136 "invalid message length {len} for type '{}'",
3137 msg_type as char
3138 )));
3139 }
3140
3141 const MAX_MESSAGE_LEN: i32 = 128 * 1024 * 1024;
3142 if len > MAX_MESSAGE_LEN {
3143 return Err(DriverError::Protocol(format!(
3144 "message length {len} exceeds maximum ({MAX_MESSAGE_LEN}) for type '{}'",
3145 msg_type as char
3146 )));
3147 }
3148
3149 let payload_len = (len - 4) as usize;
3150 self.read_buf.clear();
3151 self.read_buf.resize(payload_len, 0);
3152 if payload_len > 0 {
3153 sync_buffered_read_exact(
3154 &mut self.stream,
3155 &mut self.stream_buf,
3156 &mut self.stream_buf_pos,
3157 &mut self.stream_buf_end,
3158 &mut self.read_buf[..payload_len],
3159 )?;
3160 }
3161
3162 Ok((msg_type, payload_len))
3163 }
3164
3165 #[inline]
3167 fn refill_stream_buf(&mut self) -> Result<(), DriverError> {
3168 let remaining = self.stream_buf_end - self.stream_buf_pos;
3169 if remaining > 0 && self.stream_buf_pos > 0 {
3170 self.stream_buf
3171 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
3172 }
3173 self.stream_buf_pos = 0;
3174 self.stream_buf_end = remaining;
3175
3176 let n = self
3177 .stream
3178 .read(&mut self.stream_buf[remaining..])
3179 .map_err(DriverError::Io)?;
3180 if n == 0 {
3181 return Err(DriverError::Io(std::io::Error::new(
3182 std::io::ErrorKind::UnexpectedEof,
3183 "connection closed",
3184 )));
3185 }
3186 self.stream_buf_end = remaining + n;
3187 Ok(())
3188 }
3189}
3190
3191fn sync_buffered_read_exact(
3194 stream: &mut Stream,
3195 buf: &mut [u8],
3196 pos: &mut usize,
3197 end: &mut usize,
3198 out: &mut [u8],
3199) -> Result<(), DriverError> {
3200 let mut filled = 0;
3201 while filled < out.len() {
3202 let avail = *end - *pos;
3203 if avail > 0 {
3204 let take = avail.min(out.len() - filled);
3205 out[filled..filled + take].copy_from_slice(&buf[*pos..*pos + take]);
3206 *pos += take;
3207 filled += take;
3208 } else {
3209 *pos = 0;
3210 let n = stream.read(buf).map_err(DriverError::Io)?;
3211 if n == 0 {
3212 return Err(DriverError::Io(std::io::Error::new(
3213 std::io::ErrorKind::UnexpectedEof,
3214 "connection closed",
3215 )));
3216 }
3217 *end = n;
3218 }
3219 }
3220 Ok(())
3221}
3222
3223#[inline(always)]
3233pub(crate) fn parse_data_row_into_buf(
3234 data: &[u8],
3235 buf: &mut Vec<u8>,
3236 out: &mut Vec<(usize, i32)>,
3237) -> Result<(), DriverError> {
3238 if data.len() < 2 {
3239 return Err(DriverError::Protocol("DataRow too short".into()));
3240 }
3241
3242 let num_cols = i16::from_be_bytes([data[0], data[1]]);
3243 if num_cols < 0 {
3244 return Err(DriverError::Protocol(
3245 "DataRow: negative column count".into(),
3246 ));
3247 }
3248 let num_cols = num_cols as usize;
3249
3250 let col_data = &data[2..];
3258 let base = buf.len();
3259 buf.extend_from_slice(col_data);
3260
3261 let mut pos: usize = 0;
3263 for _ in 0..num_cols {
3264 if pos + 4 > col_data.len() {
3265 return Err(DriverError::Protocol("DataRow truncated".into()));
3266 }
3267
3268 let col_len = i32::from_be_bytes([
3269 col_data[pos],
3270 col_data[pos + 1],
3271 col_data[pos + 2],
3272 col_data[pos + 3],
3273 ]);
3274 pos += 4;
3275
3276 if col_len < 0 {
3277 out.push((0, -1));
3278 } else {
3279 let len = col_len as usize;
3280 if pos + len > col_data.len() {
3281 return Err(DriverError::Protocol(
3282 "DataRow column data truncated".into(),
3283 ));
3284 }
3285 out.push((base + pos, col_len));
3287 pos += len;
3288 }
3289 }
3290
3291 Ok(())
3292}
3293
3294fn parse_data_row_flat(
3298 data: &[u8],
3299 arena: &mut Arena,
3300 out: &mut Vec<(usize, i32)>,
3301) -> Result<(), DriverError> {
3302 if data.len() < 2 {
3303 return Err(DriverError::Protocol("DataRow too short".into()));
3304 }
3305
3306 let num_cols_raw = i16::from_be_bytes([data[0], data[1]]);
3307 if num_cols_raw < 0 {
3308 return Err(DriverError::Protocol(
3309 "DataRow: negative column count".into(),
3310 ));
3311 }
3312 let num_cols = num_cols_raw as usize;
3313 out.reserve(num_cols);
3314
3315 let col_data = &data[2..];
3318 let base = arena.alloc_copy(col_data);
3319
3320 let mut pos: usize = 0;
3322 for _ in 0..num_cols {
3323 if pos + 4 > col_data.len() {
3324 return Err(DriverError::Protocol("DataRow truncated".into()));
3325 }
3326
3327 let col_len = i32::from_be_bytes([
3328 col_data[pos],
3329 col_data[pos + 1],
3330 col_data[pos + 2],
3331 col_data[pos + 3],
3332 ]);
3333 pos += 4;
3334
3335 if col_len < 0 {
3336 out.push((0, -1));
3337 } else {
3338 let len = col_len as usize;
3339 if pos + len > col_data.len() {
3340 return Err(DriverError::Protocol(
3341 "DataRow column data truncated".into(),
3342 ));
3343 }
3344 out.push((base + pos, col_len));
3346 pos += len;
3347 }
3348 }
3349
3350 Ok(())
3351}
3352
3353#[cfg(test)]
3354#[allow(clippy::approx_constant)]
3355mod tests {
3356 use super::*;
3357 use crate::types::hash_sql;
3358
3359 #[test]
3360 fn sync_config_tcp_no_longer_rejected() {
3361 let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3364 let result = Connection::connect(&config);
3365 assert!(result.is_err());
3366 let err = result.unwrap_err().to_string();
3367 assert!(
3370 !err.contains("Unix domain socket"),
3371 "error should NOT mention UDS requirement: {err}"
3372 );
3373 }
3374
3375 #[test]
3376 fn sync_data_row_parsing() {
3377 let mut arena = Arena::new();
3378 let mut out = Vec::new();
3379
3380 let mut data = Vec::new();
3381 data.extend_from_slice(&2i16.to_be_bytes());
3382 data.extend_from_slice(&4i32.to_be_bytes());
3383 data.extend_from_slice(&42i32.to_be_bytes());
3384 data.extend_from_slice(&(-1i32).to_be_bytes());
3385
3386 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3387 assert_eq!(out.len(), 2);
3388 assert_eq!(out[0].1, 4);
3389 assert_eq!(out[1].1, -1);
3390 }
3391
3392 #[test]
3393 fn sync_data_row_empty() {
3394 let mut arena = Arena::new();
3395 let mut out = Vec::new();
3396 let data = 0i16.to_be_bytes();
3397 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3398 assert_eq!(out.len(), 0);
3399 }
3400
3401 #[test]
3402 fn sync_data_row_too_short() {
3403 let mut arena = Arena::new();
3404 let mut out = Vec::new();
3405 let data = vec![0u8];
3406 assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3407 }
3408
3409 #[test]
3410 fn sync_data_row_negative_col_count() {
3411 let mut arena = Arena::new();
3412 let mut out = Vec::new();
3413 let data = (-1i16).to_be_bytes();
3414 assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3415 }
3416
3417 #[test]
3418 fn sync_data_row_truncated() {
3419 let mut arena = Arena::new();
3420 let mut out = Vec::new();
3421 let mut data = Vec::new();
3422 data.extend_from_slice(&2i16.to_be_bytes());
3423 data.extend_from_slice(&4i32.to_be_bytes());
3424 data.extend_from_slice(&42i32.to_be_bytes());
3425 assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3427 }
3428
3429 #[test]
3430 fn sync_data_row_col_data_truncated() {
3431 let mut arena = Arena::new();
3432 let mut out = Vec::new();
3433 let mut data = Vec::new();
3434 data.extend_from_slice(&1i16.to_be_bytes());
3435 data.extend_from_slice(&100i32.to_be_bytes()); data.push(0); assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3438 }
3439
3440 #[test]
3443 fn sync_connect_tcp_unreachable_port() {
3444 let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3447 let result = Connection::connect(&config);
3448 assert!(result.is_err());
3449 let err = result.unwrap_err().to_string();
3450 assert!(
3451 !err.contains("Unix domain socket"),
3452 "error should NOT mention UDS: {err}"
3453 );
3454 }
3455
3456 #[test]
3457 fn sync_connect_ip_address_attempts_tcp() {
3458 let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3461 let result = Connection::connect(&config);
3462 assert!(result.is_err());
3463 }
3464
3465 #[test]
3468 fn sync_data_row_all_null() {
3469 let mut arena = Arena::new();
3470 let mut out = Vec::new();
3471 let mut data = Vec::new();
3472 data.extend_from_slice(&3i16.to_be_bytes());
3473 data.extend_from_slice(&(-1i32).to_be_bytes());
3474 data.extend_from_slice(&(-1i32).to_be_bytes());
3475 data.extend_from_slice(&(-1i32).to_be_bytes());
3476 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3477 assert_eq!(out.len(), 3);
3478 for (_, len) in &out {
3479 assert_eq!(*len, -1);
3480 }
3481 }
3482
3483 #[test]
3484 fn sync_data_row_long_text() {
3485 let mut arena = Arena::new();
3486 let mut out = Vec::new();
3487 let long_text = "a".repeat(2048);
3488 let text_bytes = long_text.as_bytes();
3489 let mut data = Vec::new();
3490 data.extend_from_slice(&1i16.to_be_bytes());
3491 data.extend_from_slice(&(text_bytes.len() as i32).to_be_bytes());
3492 data.extend_from_slice(text_bytes);
3493 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3494 assert_eq!(out.len(), 1);
3495 assert_eq!(out[0].1, text_bytes.len() as i32);
3496 let stored = arena.get(out[0].0, out[0].1 as usize);
3497 assert_eq!(stored, text_bytes);
3498 }
3499
3500 #[test]
3501 fn sync_data_row_empty_text() {
3502 let mut arena = Arena::new();
3503 let mut out = Vec::new();
3504 let mut data = Vec::new();
3505 data.extend_from_slice(&1i16.to_be_bytes());
3506 data.extend_from_slice(&0i32.to_be_bytes()); parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3508 assert_eq!(out.len(), 1);
3509 assert_eq!(out[0].1, 0); }
3511
3512 #[test]
3513 fn sync_data_row_17_columns_exceeds_smallvec() {
3514 let mut arena = Arena::new();
3515 let mut out = Vec::new();
3516 let mut data = Vec::new();
3517 let num_cols: i16 = 20;
3518 data.extend_from_slice(&num_cols.to_be_bytes());
3519 for i in 0..num_cols {
3520 let val = (i as i32).to_be_bytes();
3521 data.extend_from_slice(&4i32.to_be_bytes());
3522 data.extend_from_slice(&val);
3523 }
3524 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3525 assert_eq!(out.len(), 20);
3526 for (idx, (offset, len)) in out.iter().enumerate() {
3527 assert_eq!(*len, 4);
3528 let stored = arena.get(*offset, 4);
3529 let val = i32::from_be_bytes([stored[0], stored[1], stored[2], stored[3]]);
3530 assert_eq!(val, idx as i32);
3531 }
3532 }
3533
3534 #[test]
3535 fn sync_data_row_mixed_null_and_data() {
3536 let mut arena = Arena::new();
3537 let mut out = Vec::new();
3538 let mut data = Vec::new();
3539 data.extend_from_slice(&5i16.to_be_bytes());
3540 data.extend_from_slice(&(-1i32).to_be_bytes());
3542 data.extend_from_slice(&4i32.to_be_bytes());
3544 data.extend_from_slice(&42i32.to_be_bytes());
3545 data.extend_from_slice(&(-1i32).to_be_bytes());
3547 data.extend_from_slice(&(-1i32).to_be_bytes());
3549 data.extend_from_slice(&5i32.to_be_bytes());
3551 data.extend_from_slice(b"hello");
3552
3553 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3554 assert_eq!(out.len(), 5);
3555 assert_eq!(out[0].1, -1);
3556 assert_eq!(out[1].1, 4);
3557 assert_eq!(out[2].1, -1);
3558 assert_eq!(out[3].1, -1);
3559 assert_eq!(out[4].1, 5);
3560 let stored = arena.get(out[4].0, 5);
3561 assert_eq!(stored, b"hello");
3562 }
3563
3564 #[test]
3567 #[ignore] fn sync_connect_uds_if_pg_available() {
3569 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3570 let result = Connection::connect(&config);
3571 if let Ok(conn) = result {
3573 assert!(conn.pid() != 0, "pid should be nonzero");
3574 assert!(conn.is_idle(), "should start idle");
3575 assert!(!conn.is_in_transaction(), "should not be in tx");
3576 assert!(
3577 !conn.is_in_failed_transaction(),
3578 "should not be in failed tx"
3579 );
3580 assert_eq!(conn.stmt_cache_len(), 0, "cache should be empty");
3581 let _ = conn.close();
3582 }
3583 }
3584
3585 #[test]
3586 #[ignore] fn sync_simple_query_if_pg_available() {
3588 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3589 let mut conn = Connection::connect(&config).unwrap();
3590 conn.simple_query("SELECT 1").unwrap();
3591 assert!(conn.is_idle());
3592 let _ = conn.close();
3593 }
3594
3595 #[test]
3596 #[ignore] fn sync_query_with_params_if_pg_available() {
3598 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3599 let mut conn = Connection::connect(&config).unwrap();
3600 let sql = "SELECT $1::int4 + $2::int4 AS sum";
3601 let hash = hash_sql(sql);
3602 let a: i32 = 10;
3603 let b: i32 = 20;
3604 let result = conn
3605 .query(
3606 sql,
3607 hash,
3608 &[&a as &(dyn Encode + Sync), &b as &(dyn Encode + Sync)],
3609 )
3610 .unwrap();
3611 assert_eq!(result.len(), 1);
3612 let _ = conn.close();
3613 }
3614
3615 #[test]
3616 #[ignore] fn sync_execute_if_pg_available() {
3618 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3619 let mut conn = Connection::connect(&config).unwrap();
3620 conn.simple_query("CREATE TEMP TABLE _sync_test (id int)")
3621 .unwrap();
3622 let sql = "INSERT INTO _sync_test VALUES ($1::int4)";
3623 let hash = hash_sql(sql);
3624 let val: i32 = 42;
3625 let affected = conn
3626 .execute(sql, hash, &[&val as &(dyn Encode + Sync)])
3627 .unwrap();
3628 assert_eq!(affected, 1);
3629 let _ = conn.close();
3630 }
3631
3632 #[test]
3633 #[ignore] fn sync_for_each_zero_rows_if_pg_available() {
3635 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3636 let mut conn = Connection::connect(&config).unwrap();
3637 conn.simple_query("CREATE TEMP TABLE _sync_fe0 (id int)")
3638 .unwrap();
3639 let sql = "SELECT id FROM _sync_fe0";
3640 let hash = hash_sql(sql);
3641 let mut count = 0u32;
3642 conn.for_each(sql, hash, &[], |_row| {
3643 count += 1;
3644 Ok(())
3645 })
3646 .unwrap();
3647 assert_eq!(count, 0);
3648 let _ = conn.close();
3649 }
3650
3651 #[test]
3652 #[ignore] fn sync_for_each_multiple_rows_if_pg_available() {
3654 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3655 let mut conn = Connection::connect(&config).unwrap();
3656 let sql = "SELECT generate_series(1, 5)";
3657 let hash = hash_sql(sql);
3658 let mut count = 0u32;
3659 conn.for_each(sql, hash, &[], |_row| {
3660 count += 1;
3661 Ok(())
3662 })
3663 .unwrap();
3664 assert_eq!(count, 5);
3665 let _ = conn.close();
3666 }
3667
3668 #[test]
3669 #[ignore] fn sync_prepare_only_if_pg_available() {
3671 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3672 let mut conn = Connection::connect(&config).unwrap();
3673 let sql = "SELECT 1";
3674 let hash = hash_sql(sql);
3675 conn.prepare_only(sql, hash).unwrap();
3676 assert_eq!(conn.stmt_cache_len(), 1);
3677 conn.prepare_only(sql, hash).unwrap();
3679 assert_eq!(conn.stmt_cache_len(), 1);
3680 let _ = conn.close();
3681 }
3682
3683 #[test]
3684 #[ignore] fn sync_simple_query_rows_if_pg_available() {
3686 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3687 let mut conn = Connection::connect(&config).unwrap();
3688 let rows = conn.simple_query_rows("SELECT 42 AS n").unwrap();
3689 assert!(!rows.is_empty());
3690 let _ = conn.close();
3691 }
3692
3693 #[test]
3694 #[ignore] fn sync_stmt_cache_hit_miss_if_pg_available() {
3696 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3697 let mut conn = Connection::connect(&config).unwrap();
3698 let sql1 = "SELECT 1";
3699 let hash1 = hash_sql(sql1);
3700 conn.query(sql1, hash1, &[]).unwrap();
3701 assert_eq!(conn.stmt_cache_len(), 1);
3702 conn.query(sql1, hash1, &[]).unwrap();
3704 assert_eq!(conn.stmt_cache_len(), 1);
3705 let sql2 = "SELECT 2";
3707 let hash2 = hash_sql(sql2);
3708 conn.query(sql2, hash2, &[]).unwrap();
3709 assert_eq!(conn.stmt_cache_len(), 2);
3710 let _ = conn.close();
3711 }
3712
3713 #[test]
3714 #[ignore] fn sync_invalid_sql_error_if_pg_available() {
3716 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3717 let mut conn = Connection::connect(&config).unwrap();
3718 let sql = "SELECTTTT INVALID GARBAGE";
3719 let hash = hash_sql(sql);
3720 let result = conn.query(sql, hash, &[]);
3721 assert!(result.is_err());
3722 assert!(conn.is_idle());
3724 let _ = conn.close();
3725 }
3726
3727 #[test]
3728 #[ignore] fn sync_tx_state_transitions_if_pg_available() {
3730 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3731 let mut conn = Connection::connect(&config).unwrap();
3732 assert!(conn.is_idle());
3733 assert!(!conn.is_in_transaction());
3734 conn.simple_query("BEGIN").unwrap();
3735 assert!(conn.is_in_transaction());
3736 assert!(!conn.is_idle());
3737 conn.simple_query("COMMIT").unwrap();
3738 assert!(conn.is_idle());
3739 assert!(!conn.is_in_transaction());
3740 let _ = conn.close();
3741 }
3742
3743 #[test]
3744 #[ignore] fn sync_lru_cache_eviction_if_pg_available() {
3746 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3747 let mut conn = Connection::connect(&config).unwrap();
3748 conn.set_max_stmt_cache_size(3);
3749 for i in 0..5 {
3750 let sql = format!("SELECT {}", i);
3751 let hash = hash_sql(&sql);
3752 conn.query(&sql, hash, &[]).unwrap();
3753 }
3754 assert!(
3756 conn.stmt_cache_len() <= 3,
3757 "cache should be capped at 3, got {}",
3758 conn.stmt_cache_len()
3759 );
3760 let _ = conn.close();
3761 }
3762
3763 #[test]
3764 #[ignore] fn sync_for_each_raw_if_pg_available() {
3766 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3767 let mut conn = Connection::connect(&config).unwrap();
3768 let sql = "SELECT generate_series(1, 3)";
3769 let hash = hash_sql(sql);
3770 let mut raw_count = 0u32;
3771 conn.for_each_raw(sql, hash, &[], |_raw_data| {
3772 raw_count += 1;
3773 Ok(())
3774 })
3775 .unwrap();
3776 assert_eq!(raw_count, 3);
3777 let _ = conn.close();
3778 }
3779
3780 #[test]
3781 #[ignore] fn sync_query_null_params_if_pg_available() {
3783 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3784 let mut conn = Connection::connect(&config).unwrap();
3785 let sql = "SELECT $1::int4 IS NULL AS is_null";
3786 let hash = hash_sql(sql);
3787 let val: Option<i32> = None;
3788 let _result = conn
3789 .query(sql, hash, &[&val as &(dyn Encode + Sync)])
3790 .unwrap();
3791 let _ = conn.close();
3792 }
3793
3794 #[test]
3795 #[ignore] fn sync_query_various_param_types_if_pg_available() {
3797 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3798 let mut conn = Connection::connect(&config).unwrap();
3799 let sql = "SELECT $1::int4, $2::int8, $3::text, $4::bool, $5::float8";
3800 let hash = hash_sql(sql);
3801 let p1: i32 = 42;
3802 let p2: i64 = 9999999;
3803 let p3: &str = "hello";
3804 let p4: bool = true;
3805 let p5: f64 = 3.14;
3806 let result = conn
3807 .query(
3808 sql,
3809 hash,
3810 &[
3811 &p1 as &(dyn Encode + Sync),
3812 &p2 as &(dyn Encode + Sync),
3813 &p3 as &(dyn Encode + Sync),
3814 &p4 as &(dyn Encode + Sync),
3815 &p5 as &(dyn Encode + Sync),
3816 ],
3817 )
3818 .unwrap();
3819 assert_eq!(result.len(), 1);
3820 let _ = conn.close();
3821 }
3822
3823 #[test]
3826 fn sync_shrink_threshold_values() {
3827 let shrink = 64 * 1024usize;
3836 let initial = 8192usize;
3837 assert!(
3838 shrink > initial,
3839 "shrink threshold must exceed initial size"
3840 );
3841 }
3842
3843 #[test]
3846 fn sync_connection_debug_format() {
3847 let fmt_str = format!(
3851 "Connection {{ pid: {}, tx_status: '{}', stmt_cache_len: {} }}",
3852 0, 'I', 0
3853 );
3854 assert!(fmt_str.contains("Connection"));
3855 assert!(fmt_str.contains("pid"));
3856 assert!(fmt_str.contains("tx_status"));
3857 }
3858
3859 #[test]
3862 fn sync_connect_sslmode_require_without_tls_feature() {
3863 let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3867 config.ssl = SslMode::Require;
3868 let result = Connection::connect(&config);
3869 assert!(result.is_err());
3870 }
3875
3876 #[test]
3877 fn sync_connect_sslmode_disable_attempts_tcp() {
3878 let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3879 config.ssl = SslMode::Disable;
3880 let result = Connection::connect(&config);
3881 assert!(result.is_err());
3882 assert!(matches!(result.unwrap_err(), DriverError::Io(_)));
3884 }
3885
3886 #[test]
3887 fn sync_connect_sslmode_prefer_attempts_tcp() {
3888 let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3889 config.ssl = SslMode::Prefer;
3890 let result = Connection::connect(&config);
3891 assert!(result.is_err());
3892 }
3893
3894 #[test]
3897 #[ignore] fn sync_streaming_basic_if_pg_available() {
3899 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3900 let mut conn = Connection::connect(&config).unwrap();
3901 assert!(!conn.is_streaming());
3902
3903 let sql = "SELECT generate_series(1, 10)";
3904 let hash = hash_sql(sql);
3905
3906 let (cols, _) = conn.query_streaming_start(sql, hash, &[], 3).unwrap();
3907 assert!(!cols.is_empty());
3908 assert!(conn.is_streaming());
3909
3910 let mut arena = Arena::new();
3911 let mut offsets = Vec::new();
3912 let mut total_rows = 0;
3913
3914 loop {
3916 let has_more = conn.streaming_next_chunk(&mut arena, &mut offsets).unwrap();
3917 total_rows += offsets.len();
3918 if !has_more {
3919 break;
3920 }
3921 conn.streaming_send_execute(3).unwrap();
3922 }
3923
3924 assert_eq!(total_rows, 10);
3925 assert!(!conn.is_streaming());
3926 let _ = conn.close();
3927 }
3928
/// Prepares and describes a two-parameter statement and checks the reported result
/// column and parameter count.
///
/// Ignored by default: needs a live PostgreSQL reachable with the URL below (host=/tmp).
#[test]
#[ignore]
fn sync_prepare_describe_if_pg_available() {
    let cfg = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut connection = Connection::connect(&cfg).unwrap();

    let statement = "SELECT $1::int4 + $2::int4 AS sum";
    let described = connection.prepare_describe(statement).unwrap();

    // One output column named after the alias, two bound parameters ($1, $2).
    assert_eq!(described.columns.len(), 1);
    assert_eq!(&*described.columns[0].name, "sum");
    assert_eq!(described.param_oids.len(), 2);

    let _ = connection.close();
}
3945
/// LISTENs on a channel, NOTIFYs it from the same session, then waits (bounded by a
/// 5-second read timeout) for the notification and checks its channel and payload.
///
/// Ignored by default: needs a live PostgreSQL reachable with the URL below (host=/tmp).
#[test]
#[ignore]
fn sync_wait_for_notification_if_pg_available() {
    let cfg = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut connection = Connection::connect(&cfg).unwrap();

    for stmt in ["LISTEN test_chan", "NOTIFY test_chan, 'hello'"] {
        connection.simple_query(stmt).unwrap();
    }

    // Bound the wait so a missing notification fails the test instead of hanging it.
    let limit = std::time::Duration::from_secs(5);
    connection.set_read_timeout(Some(limit)).unwrap();

    let (chan, body) = connection.wait_for_notification().unwrap();
    assert_eq!(chan, "test_chan");
    assert_eq!(body, "hello");

    let _ = connection.close();
}
3966
/// Issues a cancel request on an idle connection.
///
/// The cancel outcome is intentionally not asserted: the test only checks that sending
/// the request does not panic. The binding keeps the result alive until the end of the
/// function, as the original did.
///
/// Ignored by default: needs a live PostgreSQL reachable with the URL below (host=/tmp).
#[test]
#[ignore]
fn sync_cancel_if_pg_available() {
    let cfg = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let connection = Connection::connect(&cfg).unwrap();

    let _cancel_outcome = connection.cancel();

    let _ = connection.close();
}
3981
/// Checks that server parameters captured during startup are exposed, both as a whole
/// (`server_params`) and by name (`parameter("server_encoding")`).
///
/// Ignored by default: needs a live PostgreSQL reachable with the URL below (host=/tmp).
#[test]
#[ignore]
fn sync_server_params_if_pg_available() {
    let cfg = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let connection = Connection::connect(&cfg).unwrap();

    let startup_params = connection.server_params();
    assert!(
        !startup_params.is_empty(),
        "server should send parameters during startup"
    );

    let encoding = connection.parameter("server_encoding");
    assert!(encoding.is_some(), "server_encoding should be present");

    let _ = connection.close();
}
4001
/// Sets a 10-second read timeout and then clears it again; both calls must succeed.
///
/// Ignored by default: needs a live PostgreSQL reachable with the URL below (host=/tmp).
#[test]
#[ignore]
fn sync_set_read_timeout_if_pg_available() {
    let cfg = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let connection = Connection::connect(&cfg).unwrap();

    // Apply the timeouts in the same order as before: first a finite one, then none.
    for timeout in [Some(std::time::Duration::from_secs(10)), None] {
        connection.set_read_timeout(timeout).unwrap();
    }

    let _ = connection.close();
}
4015}