1use std::io::{Read, Write};
15use std::sync::Arc;
16
17use crate::arena::Arena;
18use crate::auth;
19use crate::codec::Encode;
20use crate::proto::{self, BackendMessage};
21use crate::stmt_cache::{build_bind_template, make_stmt_name, StmtCache, StmtInfo};
22use crate::sync_io::Stream;
23use crate::types::{
24 ColumnDesc, Config, Notification, PgDataRow, PrepareResult, QueryResult, SimpleRow, SslMode,
25 StartupAction,
26};
27use crate::DriverError;
28
29use std::cell::RefCell;
37
// Per-thread stack of reusable response byte buffers.
// `acquire_resp_buf` pops from it and `release_resp_buf` pushes back onto it
// (only while fewer than four buffers are held). Starts out empty.
thread_local! {
    static RESP_BUF_POOL: RefCell<Vec<Vec<u8>>> = const { RefCell::new(Vec::new()) };
}
41
42pub(crate) fn acquire_resp_buf() -> Vec<u8> {
43 RESP_BUF_POOL
44 .with(|pool| pool.borrow_mut().pop())
45 .unwrap_or_default()
46}
47
48pub fn release_resp_buf(buf: Vec<u8>) {
50 RESP_BUF_POOL.with(|pool| {
51 let mut pool = pool.borrow_mut();
52 if pool.len() < 4 {
53 pool.push(buf);
54 }
55 });
56}
57
// Per-thread stack of reusable `(usize, i32)` vectors.
// `acquire_col_offsets` pops from it and `release_col_offsets` pushes back onto
// it (only while fewer than four vectors are held). Starts out empty.
thread_local! {
    static COL_OFFSETS_POOL: RefCell<Vec<Vec<(usize, i32)>>> = const { RefCell::new(Vec::new()) };
}
61
62pub(crate) fn acquire_col_offsets() -> Vec<(usize, i32)> {
63 COL_OFFSETS_POOL
64 .with(|pool| pool.borrow_mut().pop())
65 .unwrap_or_default()
66}
67
68pub fn release_col_offsets(buf: Vec<(usize, i32)>) {
69 COL_OFFSETS_POOL.with(|pool| {
70 let mut pool = pool.borrow_mut();
71 if pool.len() < 4 {
72 pool.push(buf);
73 }
74 });
75}
76
/// A single synchronous connection to a PostgreSQL server.
///
/// Owns the transport, the read/write buffers, and the per-connection cache of
/// prepared statements.
pub struct Connection {
    /// Index in `stream_buf` of the first byte not yet consumed.
    stream_buf_pos: usize,
    /// Index in `stream_buf` one past the last byte received from the stream.
    stream_buf_end: usize,
    /// Monotonic counter bumped whenever a statement is used or cached; its
    /// value is stored as the statement's `last_used` stamp.
    query_counter: u64,
    /// Status byte from the most recent ReadyForQuery message (starts as `b'I'`).
    tx_status: u8,
    /// Initialised to `false`; NOTE(review): its users are outside this section.
    streaming_active: bool,
    /// Backend process id reported by BackendKeyData.
    pid: i32,
    /// Backend secret key reported by BackendKeyData.
    secret: i32,
    /// Initialised to 256; presumably the statement-cache capacity — confirm in `cache_stmt`.
    max_stmt_cache_size: usize,
    /// Underlying transport (Unix socket, plain TCP, or TLS).
    stream: Stream,
    /// Staging buffer for outgoing protocol messages.
    write_buf: Vec<u8>,
    /// Fixed-size (64 KiB at construction) receive buffer that messages are parsed from in place.
    stream_buf: Vec<u8>,
    /// Cache of prepared statements, looked up by SQL hash and text.
    stmts: StmtCache,
    /// Holds the payload of the message most recently read via `read_message_buffered`.
    read_buf: Vec<u8>,
    /// Server parameters collected from ParameterStatus messages, as `(name, value)` pairs.
    params: Vec<(Box<str>, Box<str>)>,
    /// Set to the construction time; NOTE(review): updates happen outside this section.
    last_used: std::time::Instant,
    /// Time at which this connection object was constructed.
    created_at: std::time::Instant,
    /// Notifications held for later delivery; NOTE(review): presumably filled by `buffer_notification`.
    pending_notifications: Vec<Notification>,
    /// The configuration this connection was opened with.
    connect_config: Arc<Config>,
    /// Server certificate hash produced by the TLS upgrade, if any; passed to
    /// the SCRAM client when `SCRAM-SHA-256-PLUS` is selected.
    tls_server_cert_hash: Option<[u8; 32]>,
}
133
134impl std::fmt::Debug for Connection {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 f.debug_struct("Connection")
137 .field("pid", &self.pid)
138 .field("tx_status", &(self.tx_status as char))
139 .field("stmt_cache_len", &self.stmts.len())
140 .finish()
141 }
142}
143
144impl Connection {
145 pub fn connect(config: &Config) -> Result<Self, DriverError> {
157 Self::connect_arc(Arc::new(config.clone()))
158 }
159
/// Opens a connection using a shared configuration.
///
/// Validates the configuration, establishes the transport (Unix socket, plain
/// TCP, or TLS depending on `config`), runs the startup/authentication
/// exchange, and finally checks the server parameters.
///
/// # Errors
/// Returns an error if validation, the socket connection, the TLS upgrade
/// (when required), the startup exchange, or the parameter check fails.
pub fn connect_arc(config: Arc<Config>) -> Result<Self, DriverError> {
    config.validate()?;

    // Only assigned when the `tls` feature is compiled in, hence the allow.
    #[allow(unused_mut)]
    let mut tls_cert_hash: Option<[u8; 32]> = None;

    let stream = if config.host_is_uds() {
        #[cfg(unix)]
        {
            let path = config.uds_path();
            let unix =
                std::os::unix::net::UnixStream::connect(&path).map_err(DriverError::Io)?;
            Stream::Unix(unix)
        }
        #[cfg(not(unix))]
        {
            return Err(DriverError::Protocol(
                "Unix domain sockets are not supported on this platform".into(),
            ));
        }
    } else {
        let addr = format!("{}:{}", config.host, config.port);
        let tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;

        match config.ssl {
            SslMode::Disable => {
                tcp.set_nodelay(true).map_err(DriverError::Io)?;
                let stream = Stream::Tcp(tcp);
                stream.set_keepalive()?;
                stream
            }
            SslMode::Prefer | SslMode::Require => {
                #[cfg(feature = "tls")]
                {
                    // `tcp` is moved into the upgrade attempt.
                    match crate::tls_sync::try_upgrade(
                        tcp,
                        &config.host,
                        config.ssl == SslMode::Require,
                    ) {
                        Ok(result) => {
                            tls_cert_hash = result.server_cert_hash;
                            let stream = Stream::Tls(Box::new(result.stream));
                            stream.set_nodelay()?;
                            stream.set_keepalive()?;
                            stream
                        }
                        Err(e) => {
                            if config.ssl == SslMode::Require {
                                return Err(e);
                            }
                            // `Prefer`: the original socket was consumed by the failed
                            // upgrade, so open a fresh plain-TCP connection instead.
                            let tcp =
                                std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
                            tcp.set_nodelay(true).map_err(DriverError::Io)?;
                            let stream = Stream::Tcp(tcp);
                            stream.set_keepalive()?;
                            stream
                        }
                    }
                }
                #[cfg(not(feature = "tls"))]
                {
                    // Without TLS support `Require` cannot be honoured; `Prefer`
                    // degrades to plain TCP.
                    if config.ssl == SslMode::Require {
                        return Err(DriverError::Protocol(
                            "sslmode=require but bsql was compiled without the 'tls' feature"
                                .into(),
                        ));
                    }
                    tcp.set_nodelay(true).map_err(DriverError::Io)?;
                    let stream = Stream::Tcp(tcp);
                    stream.set_keepalive()?;
                    stream
                }
            }
        }
    };

    let now = std::time::Instant::now();
    let mut conn = Self {
        stream_buf_pos: 0,
        stream_buf_end: 0,
        query_counter: 0,
        tx_status: b'I',
        streaming_active: false,
        pid: 0,
        secret: 0,
        max_stmt_cache_size: 256,
        stream,
        write_buf: Vec::with_capacity(4096),
        stream_buf: vec![0u8; 65536],
        stmts: StmtCache::default(),
        read_buf: Vec::with_capacity(8192),
        params: Vec::new(),
        last_used: now,
        created_at: now,
        pending_notifications: Vec::new(),
        connect_config: config.clone(),
        tls_server_cert_hash: tls_cert_hash,
    };

    // Authenticate and collect server parameters, then reject servers whose
    // reported parameters this driver does not accept.
    conn.startup(&config)?;
    conn.validate_server_params()?;

    Ok(conn)
}
276
277 fn startup(&mut self, config: &Config) -> Result<(), DriverError> {
280 self.write_buf.clear();
281 let timeout_str; let mut extra_params: smallvec::SmallVec<[(&str, &str); 2]> = smallvec::SmallVec::new();
285 if config.statement_timeout_secs > 0 {
286 timeout_str = format!("{}s", config.statement_timeout_secs);
287 extra_params.push(("statement_timeout", &timeout_str));
288 }
289 proto::write_startup(
290 &mut self.write_buf,
291 &config.user,
292 &config.database,
293 &extra_params,
294 );
295 self.flush_write()?;
296
297 loop {
298 let action = self.read_startup_action()?;
299 match action {
300 StartupAction::AuthOk => {}
301 StartupAction::AuthCleartext => {
302 self.write_buf.clear();
303 let mut pw = config.password.as_bytes().to_vec();
304 pw.push(0);
305 proto::write_password(&mut self.write_buf, &pw);
306 self.flush_write()?;
307 }
308 StartupAction::AuthMd5(salt) => {
309 self.write_buf.clear();
310 let hash = auth::md5_password(&config.user, &config.password, &salt);
311 proto::write_password(&mut self.write_buf, &hash);
312 self.flush_write()?;
313 }
314 StartupAction::AuthSasl(mechanisms_data) => {
315 self.handle_scram(config, &mechanisms_data)?;
316 }
317 StartupAction::ParameterStatus(name, value) => {
318 if let Some(entry) = self.params.iter_mut().find(|(k, _)| *k == name) {
319 entry.1 = value;
320 } else {
321 self.params.push((name, value));
322 }
323 }
324 StartupAction::BackendKeyData(pid, secret) => {
325 self.pid = pid;
326 self.secret = secret;
327 }
328 StartupAction::ReadyForQuery(status) => {
329 self.tx_status = status;
330 return Ok(());
331 }
332 StartupAction::Error(msg) => {
333 return Err(DriverError::Auth(msg));
334 }
335 StartupAction::Notice => {}
336 }
337 }
338 }
339
340 fn read_startup_action(&mut self) -> Result<StartupAction, DriverError> {
341 let (msg_type, _) = self.read_message_buffered()?;
342 let payload = &self.read_buf;
343 let msg = proto::parse_backend_message(msg_type, payload)?;
344 match msg {
345 BackendMessage::AuthOk => Ok(StartupAction::AuthOk),
346 BackendMessage::AuthCleartext => Ok(StartupAction::AuthCleartext),
347 BackendMessage::AuthMd5 { salt } => Ok(StartupAction::AuthMd5(salt)),
348 BackendMessage::AuthSasl { mechanisms } => {
349 Ok(StartupAction::AuthSasl(mechanisms.to_vec()))
350 }
351 BackendMessage::ParameterStatus { name, value } => {
352 Ok(StartupAction::ParameterStatus(name.into(), value.into()))
353 }
354 BackendMessage::BackendKeyData { pid, secret } => {
355 Ok(StartupAction::BackendKeyData(pid, secret))
356 }
357 BackendMessage::ReadyForQuery { status } => Ok(StartupAction::ReadyForQuery(status)),
358 BackendMessage::ErrorResponse { data } => {
359 let fields = proto::parse_error_response(data);
360 Ok(StartupAction::Error(fields.to_string()))
361 }
362 BackendMessage::NoticeResponse { .. } => Ok(StartupAction::Notice),
363 other => Err(DriverError::Protocol(format!(
364 "unexpected message during startup: {other:?}"
365 ))),
366 }
367 }
368
369 fn handle_scram(&mut self, config: &Config, mechanisms_data: &[u8]) -> Result<(), DriverError> {
370 let mechs = auth::parse_sasl_mechanisms(mechanisms_data);
371
372 let use_plus = self.tls_server_cert_hash.is_some() && mechs.contains(&"SCRAM-SHA-256-PLUS");
375 let mechanism = if use_plus {
376 "SCRAM-SHA-256-PLUS"
377 } else {
378 "SCRAM-SHA-256"
379 };
380
381 if !mechs.contains(&mechanism) && !mechs.contains(&"SCRAM-SHA-256") {
382 return Err(DriverError::Auth(format!(
383 "server requires unsupported SASL mechanism(s): {mechs:?}"
384 )));
385 }
386
387 let cert_hash = if use_plus {
388 self.tls_server_cert_hash.as_ref()
389 } else {
390 None
391 };
392 let mut scram = auth::ScramClient::new(&config.user, &config.password, cert_hash)?;
393
394 let client_first = scram.client_first_message();
396 self.write_buf.clear();
397 proto::write_sasl_initial(&mut self.write_buf, mechanism, &client_first);
398 self.flush_write()?;
399
400 let (msg_type, _) = self.read_message_buffered()?;
402 let server_first = {
403 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
404 match msg {
405 BackendMessage::AuthSaslContinue { data } => data.to_vec(),
406 BackendMessage::ErrorResponse { data } => {
407 let fields = proto::parse_error_response(data);
408 return Err(DriverError::Auth(fields.to_string()));
409 }
410 other => {
411 return Err(DriverError::Protocol(format!(
412 "expected AuthSaslContinue, got: {other:?}"
413 )));
414 }
415 }
416 };
417
418 scram.process_server_first(&server_first)?;
419
420 let client_final = scram.client_final_message()?;
422 self.write_buf.clear();
423 proto::write_sasl_response(&mut self.write_buf, &client_final);
424 self.flush_write()?;
425
426 let (msg_type, _) = self.read_message_buffered()?;
428 {
429 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
430 match msg {
431 BackendMessage::AuthSaslFinal { data } => {
432 let data_owned = data.to_vec();
433 scram.verify_server_final(&data_owned)?;
434 }
435 BackendMessage::ErrorResponse { data } => {
436 let fields = proto::parse_error_response(data);
437 return Err(DriverError::Auth(fields.to_string()));
438 }
439 other => {
440 return Err(DriverError::Protocol(format!(
441 "expected AuthSaslFinal, got: {other:?}"
442 )));
443 }
444 }
445 }
446
447 let (msg_type, _) = self.read_message_buffered()?;
449 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
450 match msg {
451 BackendMessage::AuthOk => Ok(()),
452 BackendMessage::ErrorResponse { data } => {
453 let fields = proto::parse_error_response(data);
454 Err(DriverError::Auth(fields.to_string()))
455 }
456 other => Err(DriverError::Protocol(format!(
457 "expected AuthOk after SCRAM, got: {other:?}"
458 ))),
459 }
460 }
461
462 fn validate_server_params(&self) -> Result<(), DriverError> {
463 if let Some(encoding) = self.parameter("server_encoding") {
464 if !encoding.eq_ignore_ascii_case("UTF8") && !encoding.eq_ignore_ascii_case("UTF-8") {
465 return Err(DriverError::Protocol(format!(
466 "server_encoding is '{encoding}', but bsql requires UTF-8."
467 )));
468 }
469 }
470 if let Some(encoding) = self.parameter("client_encoding") {
471 if !encoding.eq_ignore_ascii_case("UTF8") && !encoding.eq_ignore_ascii_case("UTF-8") {
472 return Err(DriverError::Protocol(format!(
473 "client_encoding is '{encoding}', but bsql requires UTF-8."
474 )));
475 }
476 }
477 if let Some(idt) = self.parameter("integer_datetimes") {
478 if idt != "on" {
479 return Err(DriverError::Protocol(format!(
480 "integer_datetimes is '{idt}', but bsql requires 'on'."
481 )));
482 }
483 }
484 Ok(())
485 }
486
487 pub fn prepare_only(&mut self, sql: &str, sql_hash: u64) -> Result<(), DriverError> {
493 if self.stmts.contains_key(&sql_hash, sql) {
494 return Ok(());
495 }
496 let name = make_stmt_name(sql_hash);
497 self.write_buf.clear();
498 proto::write_parse(&mut self.write_buf, &name, sql, &[]);
499 proto::write_describe(&mut self.write_buf, b'S', &name);
500 proto::write_sync(&mut self.write_buf);
501 self.flush_write()?;
502
503 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
504 let columns = self.read_column_description()?;
505 self.expect_ready()?;
506
507 self.query_counter += 1;
508 self.cache_stmt(
509 sql_hash,
510 StmtInfo {
511 name,
512 sql: sql.into(),
513 columns,
514 last_used: self.query_counter,
515 bind_template: None,
516 },
517 );
518 Ok(())
519 }
520
521 pub fn prepare_batch(&mut self, sqls: &[(&str, u64)]) -> Result<(), DriverError> {
529 if sqls.is_empty() {
530 return Ok(());
531 }
532
533 let mut pending = 0usize;
535 self.write_buf.clear();
536 for &(sql, sql_hash) in sqls {
537 if self.stmts.contains_key(&sql_hash, sql) {
538 continue;
539 }
540 let name = make_stmt_name(sql_hash);
541 proto::write_parse(&mut self.write_buf, &name, sql, &[]);
542 proto::write_describe(&mut self.write_buf, b'S', &name);
543 pending += 1;
544 }
545
546 if pending == 0 {
547 return Ok(());
548 }
549
550 proto::write_sync(&mut self.write_buf);
551 self.flush_write()?;
552
553 for &(sql, sql_hash) in sqls {
556 if self.stmts.contains_key(&sql_hash, sql) {
557 continue;
558 }
559
560 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
561 let columns = self.read_column_description()?;
562
563 let name = make_stmt_name(sql_hash);
564 self.query_counter += 1;
565 self.cache_stmt(
566 sql_hash,
567 StmtInfo {
568 name,
569 sql: sql.into(),
570 columns,
571 last_used: self.query_counter,
572 bind_template: None,
573 },
574 );
575 }
576
577 self.expect_ready()?;
578 Ok(())
579 }
580
/// Executes `sql` with `params` and collects every returned row.
///
/// The request is sent through `send_pipeline`, which must hand back the
/// column descriptions. Responses are then parsed directly out of
/// `stream_buf`: DataRow payloads are appended to a pooled response buffer
/// and their column offsets to a pooled offsets vector, until ReadyForQuery
/// arrives.
///
/// # Errors
/// `DriverError::Protocol` for a malformed length or an unexpected message,
/// a server error for an ErrorResponse (after draining to ReadyForQuery on
/// the oversized-message path), and propagated I/O failures.
#[inline]
pub fn query(
    &mut self,
    sql: &str,
    sql_hash: u64,
    params: &[&(dyn Encode + Sync)],
) -> Result<QueryResult, DriverError> {
    let columns = self
        .send_pipeline(sql, sql_hash, params, true, true)?
        .ok_or_else(|| {
            DriverError::Protocol("send_pipeline(need_columns=true) returned None".into())
        })?;

    let num_cols = columns.len();
    // Pooled vectors may still hold data from a previous use, so clear them.
    let mut all_col_offsets = acquire_col_offsets();
    all_col_offsets.clear();
    let mut affected_rows: u64 = 0;

    let mut resp_buf = acquire_resp_buf();
    resp_buf.clear();

    // Outer loop: refill `stream_buf`. Inner loop: consume complete messages.
    'outer: loop {
        loop {
            let avail = self.stream_buf_end - self.stream_buf_pos;
            // A header is 1 type byte + 4 length bytes; wait for more data.
            if avail < 5 {
                break;
            }

            let msg_type = self.stream_buf[self.stream_buf_pos];
            let raw_len = i32::from_be_bytes([
                self.stream_buf[self.stream_buf_pos + 1],
                self.stream_buf[self.stream_buf_pos + 2],
                self.stream_buf[self.stream_buf_pos + 3],
                self.stream_buf[self.stream_buf_pos + 4],
            ]);

            // The length field counts itself, so anything below 4 is invalid.
            if raw_len < 4 {
                return Err(DriverError::Protocol(format!(
                    "invalid message length {raw_len} for type '{}'",
                    msg_type as char
                )));
            }

            let payload_len = (raw_len - 4) as usize;
            let total_msg_len = 5 + payload_len;

            if avail < total_msg_len {
                // A message larger than the whole buffer can never be parsed in
                // place; read it through `read_one_message` instead.
                if total_msg_len > self.stream_buf.len() {
                    let msg = self.read_one_message()?;
                    match msg {
                        BackendMessage::BindComplete => continue,
                        BackendMessage::DataRow { data } => {
                            parse_data_row_into_buf(data, &mut resp_buf, &mut all_col_offsets)?;
                            continue;
                        }
                        BackendMessage::CommandComplete { tag } => {
                            affected_rows = proto::parse_command_tag(tag);
                            continue;
                        }
                        BackendMessage::EmptyQuery => continue,
                        BackendMessage::ReadyForQuery { status } => {
                            self.tx_status = status;
                            break 'outer;
                        }
                        BackendMessage::NoticeResponse { .. } => continue,
                        BackendMessage::ErrorResponse { data } => {
                            let fields = proto::parse_error_response(data);
                            self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                            self.drain_to_ready()?;
                            return Err(self.make_server_error(fields));
                        }
                        other => {
                            return Err(DriverError::Protocol(format!(
                                "unexpected message during query: {other:?}"
                            )));
                        }
                    }
                }
                // Message fits the buffer but is incomplete: go refill.
                break;
            }

            let payload_start = self.stream_buf_pos + 5;
            let payload_end = payload_start + payload_len;

            if msg_type == b'D' {
                // DataRow: copy the row into the response buffer.
                parse_data_row_into_buf(
                    &self.stream_buf[payload_start..payload_end],
                    &mut resp_buf,
                    &mut all_col_offsets,
                )?;
            } else if msg_type == b'Z' {
                // ReadyForQuery: record the status byte and finish.
                if payload_len >= 1 {
                    self.tx_status = self.stream_buf[payload_start];
                }
                self.stream_buf_pos += total_msg_len;
                break 'outer;
            } else {
                self.handle_non_datarow_query(
                    msg_type,
                    payload_start,
                    payload_end,
                    sql_hash,
                    &mut affected_rows,
                )?;
            }

            self.stream_buf_pos += total_msg_len;
        }

        self.refill_stream_buf()?;
    }

    self.shrink_buffers();

    Ok(QueryResult::from_parts_with_buf(
        all_col_offsets,
        num_cols,
        columns,
        affected_rows,
        resp_buf,
    ))
}
727
728 #[inline]
739 pub fn execute_monolithic(
740 &mut self,
741 sql: &str,
742 sql_hash: u64,
743 params: &[&(dyn Encode + Sync)],
744 ) -> Result<u64, DriverError> {
745 self.write_buf.clear();
747
748 let info = match self.stmts.get_mut(&sql_hash, sql) {
750 Some(info) => {
751 self.query_counter += 1;
752 info.last_used = self.query_counter;
753 info
754 }
755 None => {
756 return self.execute_with_prepare(sql, sql_hash, params);
758 }
759 };
760
761 let can_use_template = info
763 .bind_template
764 .as_ref()
765 .is_some_and(|t| t.param_slots.len() == params.len());
766
767 let mut has_exec_sync = false;
768
769 if can_use_template {
770 let tmpl = info.bind_template.as_ref().ok_or_else(|| {
772 DriverError::Protocol("bind_template missing despite can_use_template".into())
773 })?;
774 self.write_buf.extend_from_slice(&tmpl.bytes);
775
776 let mut template_ok = true;
777 for (i, param) in params.iter().enumerate() {
778 let (data_offset, old_len) = tmpl.param_slots[i];
779 if param.is_null() {
780 let len_offset = data_offset - 4;
781 self.write_buf[len_offset..len_offset + 4]
782 .copy_from_slice(&(-1i32).to_be_bytes());
783 } else if old_len >= 0 {
784 let end = data_offset + old_len as usize;
785 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
786 template_ok = false;
787 break;
788 }
789 } else {
790 template_ok = false;
792 break;
793 }
794 }
795
796 if template_ok {
797 has_exec_sync = true;
798 } else {
799 self.write_buf.clear();
800 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
801 info.bind_template = None;
802 }
803 } else {
804 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
805 }
806
807 if info.bind_template.is_none() && !self.write_buf.is_empty() {
809 info.bind_template = build_bind_template(&self.write_buf, params.len());
810 }
811
812 if !has_exec_sync {
813 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
814 }
815
816 self.stream
818 .write_all(&self.write_buf)
819 .map_err(DriverError::Io)?;
820
821 let mut affected_rows: u64 = 0;
823
824 'outer: loop {
825 loop {
826 let avail = self.stream_buf_end - self.stream_buf_pos;
827 if avail < 5 {
828 break; }
830
831 let msg_type = self.stream_buf[self.stream_buf_pos];
832 let raw_len = i32::from_be_bytes([
833 self.stream_buf[self.stream_buf_pos + 1],
834 self.stream_buf[self.stream_buf_pos + 2],
835 self.stream_buf[self.stream_buf_pos + 3],
836 self.stream_buf[self.stream_buf_pos + 4],
837 ]);
838
839 if raw_len < 4 {
840 return Err(DriverError::Protocol(format!(
841 "invalid message length {raw_len} for type '{}'",
842 msg_type as char
843 )));
844 }
845
846 let payload_len = (raw_len - 4) as usize;
847 let total_msg_len = 5 + payload_len;
848
849 if avail < total_msg_len {
850 if total_msg_len > self.stream_buf.len() {
851 let msg = self.read_one_message()?;
852 match msg {
853 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
854 continue;
855 }
856 BackendMessage::CommandComplete { tag } => {
857 affected_rows = proto::parse_command_tag(tag);
858 continue;
859 }
860 BackendMessage::EmptyQuery => continue,
861 BackendMessage::ReadyForQuery { status } => {
862 self.tx_status = status;
863 break 'outer;
864 }
865 BackendMessage::NoticeResponse { .. } => continue,
866 BackendMessage::ErrorResponse { data } => {
867 let fields = proto::parse_error_response(data);
868 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
869 self.drain_to_ready()?;
870 return Err(self.make_server_error(fields));
871 }
872 other => {
873 return Err(DriverError::Protocol(format!(
874 "unexpected message during execute: {other:?}"
875 )));
876 }
877 }
878 }
879 break; }
881
882 let payload_start = self.stream_buf_pos + 5;
887 let payload_end = payload_start + payload_len;
888
889 if msg_type == b'2' {
890 self.stream_buf_pos += total_msg_len;
892 continue;
893 } else if msg_type == b'C' {
894 affected_rows = proto::parse_command_tag_bytes(
896 &self.stream_buf[payload_start..payload_end],
897 );
898 } else if msg_type == b'Z' {
899 if payload_len >= 1 {
901 self.tx_status = self.stream_buf[payload_start];
902 }
903 self.stream_buf_pos += total_msg_len;
904 break 'outer;
905 } else if msg_type == b'D' || msg_type == b'I' {
906 } else {
908 self.handle_non_datarow_execute(
909 msg_type,
910 payload_start,
911 payload_end,
912 sql_hash,
913 )?;
914 }
915
916 self.stream_buf_pos += total_msg_len;
917 }
918
919 let remaining = self.stream_buf_end - self.stream_buf_pos;
921 debug_assert!(
922 remaining == 0 || self.stream_buf_pos > 0,
923 "compact called with pos=0 and remaining data"
924 );
925 if remaining > 0 {
926 self.stream_buf
927 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
928 }
929 self.stream_buf_pos = 0;
930 self.stream_buf_end = remaining;
931 let n = self
932 .stream
933 .read(&mut self.stream_buf[remaining..])
934 .map_err(DriverError::Io)?;
935 if n == 0 {
936 return Err(DriverError::Io(std::io::Error::new(
937 std::io::ErrorKind::UnexpectedEof,
938 "connection closed",
939 )));
940 }
941 self.stream_buf_end = remaining + n;
942 }
943
944 if self.query_counter & 63 == 0 {
946 if self.read_buf.capacity() > 64 * 1024 {
947 self.read_buf.clear();
948 self.read_buf.shrink_to(8192);
949 }
950 if self.write_buf.capacity() > 16 * 1024 {
951 self.write_buf.clear();
952 self.write_buf.shrink_to(8192);
953 }
954 }
955
956 Ok(affected_rows)
957 }
958
959 #[cold]
961 #[inline(never)]
962 fn execute_with_prepare(
963 &mut self,
964 sql: &str,
965 sql_hash: u64,
966 params: &[&(dyn Encode + Sync)],
967 ) -> Result<u64, DriverError> {
968 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
969
970 if params.len() > i16::MAX as usize {
971 return Err(DriverError::Protocol(format!(
972 "parameter count {} exceeds maximum {}",
973 params.len(),
974 i16::MAX
975 )));
976 }
977
978 let name = make_stmt_name(sql_hash);
979 let param_oids: smallvec::SmallVec<[u32; 8]> =
980 params.iter().map(|p| p.type_oid()).collect();
981
982 self.write_buf.clear();
983 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
984 proto::write_describe(&mut self.write_buf, b'S', &name);
985 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
986 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
987 self.stream
988 .write_all(&self.write_buf)
989 .map_err(DriverError::Io)?;
990
991 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
992 let columns = self.read_column_description()?;
993 self.query_counter += 1;
994 self.cache_stmt(
995 sql_hash,
996 StmtInfo {
997 name,
998 sql: sql.into(),
999 columns,
1000 last_used: self.query_counter,
1001 bind_template: None,
1002 },
1003 );
1004
1005 let mut affected_rows: u64 = 0;
1007 'outer: loop {
1008 loop {
1009 let avail = self.stream_buf_end - self.stream_buf_pos;
1010 if avail < 5 {
1011 break;
1012 }
1013
1014 let msg_type = self.stream_buf[self.stream_buf_pos];
1015 let raw_len = i32::from_be_bytes([
1016 self.stream_buf[self.stream_buf_pos + 1],
1017 self.stream_buf[self.stream_buf_pos + 2],
1018 self.stream_buf[self.stream_buf_pos + 3],
1019 self.stream_buf[self.stream_buf_pos + 4],
1020 ]);
1021
1022 if raw_len < 4 {
1023 return Err(DriverError::Protocol(format!(
1024 "invalid message length {raw_len} for type '{}'",
1025 msg_type as char
1026 )));
1027 }
1028
1029 let payload_len = (raw_len - 4) as usize;
1030 let total_msg_len = 5 + payload_len;
1031
1032 if avail < total_msg_len {
1033 if total_msg_len > self.stream_buf.len() {
1034 let msg = self.read_one_message()?;
1035 match msg {
1036 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
1037 continue;
1038 }
1039 BackendMessage::CommandComplete { tag } => {
1040 affected_rows = proto::parse_command_tag(tag);
1041 continue;
1042 }
1043 BackendMessage::EmptyQuery => continue,
1044 BackendMessage::ReadyForQuery { status } => {
1045 self.tx_status = status;
1046 break 'outer;
1047 }
1048 BackendMessage::NoticeResponse { .. } => continue,
1049 BackendMessage::ErrorResponse { data } => {
1050 let fields = proto::parse_error_response(data);
1051 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1052 self.drain_to_ready()?;
1053 return Err(self.make_server_error(fields));
1054 }
1055 other => {
1056 return Err(DriverError::Protocol(format!(
1057 "unexpected message during execute: {other:?}"
1058 )));
1059 }
1060 }
1061 }
1062 break;
1063 }
1064
1065 let payload_start = self.stream_buf_pos + 5;
1066 let payload_end = payload_start + payload_len;
1067
1068 if msg_type == b'2' || msg_type == b'D' || msg_type == b'I' {
1069 } else if msg_type == b'C' {
1071 affected_rows = proto::parse_command_tag_bytes(
1072 &self.stream_buf[payload_start..payload_end],
1073 );
1074 } else if msg_type == b'Z' {
1075 if payload_len >= 1 {
1076 self.tx_status = self.stream_buf[payload_start];
1077 }
1078 self.stream_buf_pos += total_msg_len;
1079 break 'outer;
1080 } else {
1081 self.handle_non_datarow_execute(
1082 msg_type,
1083 payload_start,
1084 payload_end,
1085 sql_hash,
1086 )?;
1087 }
1088
1089 self.stream_buf_pos += total_msg_len;
1090 }
1091
1092 self.refill_stream_buf()?;
1093 }
1094
1095 Ok(affected_rows)
1096 }
1097
1098 #[inline]
1103 pub fn execute(
1104 &mut self,
1105 sql: &str,
1106 sql_hash: u64,
1107 params: &[&(dyn Encode + Sync)],
1108 ) -> Result<u64, DriverError> {
1109 self.execute_monolithic(sql, sql_hash, params)
1110 }
1111
1112 pub fn execute_pipeline(
1124 &mut self,
1125 sql: &str,
1126 sql_hash: u64,
1127 param_sets: &[&[&(dyn Encode + Sync)]],
1128 ) -> Result<Vec<u64>, DriverError> {
1129 if param_sets.is_empty() {
1130 return Ok(Vec::new());
1131 }
1132
1133 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1134
1135 self.write_buf.clear();
1136
1137 if !self.stmts.contains_key(&sql_hash, sql) {
1139 let name = make_stmt_name(sql_hash);
1140 let first_params = param_sets[0];
1141 if first_params.len() > i16::MAX as usize {
1142 return Err(DriverError::Protocol(format!(
1143 "parameter count {} exceeds maximum {}",
1144 first_params.len(),
1145 i16::MAX
1146 )));
1147 }
1148 let param_oids: smallvec::SmallVec<[u32; 8]> =
1149 first_params.iter().map(|p| p.type_oid()).collect();
1150 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1151 proto::write_describe(&mut self.write_buf, b'S', &name);
1152 proto::write_sync(&mut self.write_buf);
1153 self.flush_write()?;
1154
1155 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1156 let columns = self.read_column_description()?;
1157 self.expect_ready()?;
1158
1159 self.query_counter += 1;
1160 self.cache_stmt(
1161 sql_hash,
1162 StmtInfo {
1163 name,
1164 sql: sql.into(),
1165 columns,
1166 last_used: self.query_counter,
1167 bind_template: None,
1168 },
1169 );
1170
1171 self.write_buf.clear();
1172 }
1173
1174 let stmt_name = self
1176 .stmts
1177 .get(&sql_hash, sql)
1178 .ok_or_else(|| {
1179 DriverError::Protocol("stmt just cached but not found in execute_pipeline".into())
1180 })?
1181 .name;
1182 let count = param_sets.len();
1183
1184 for params in param_sets {
1185 if params.len() > i16::MAX as usize {
1186 return Err(DriverError::Protocol(format!(
1187 "parameter count {} exceeds maximum {}",
1188 params.len(),
1189 i16::MAX
1190 )));
1191 }
1192 proto::write_bind_params(&mut self.write_buf, b"", &stmt_name, params);
1193 self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
1194 }
1195
1196 self.write_buf.extend_from_slice(proto::SYNC_ONLY);
1197 self.flush_write()?;
1198
1199 let mut results = Vec::with_capacity(count);
1202
1203 'outer: loop {
1204 while let Some((msg_type, start, end, total)) = self.peek_stream_msg()? {
1205 if msg_type == b'2' {
1206 } else if msg_type == b'C' {
1208 let rows = proto::parse_command_tag_bytes(&self.stream_buf[start..end]);
1210 results.push(rows);
1211 } else if msg_type == b'Z' {
1212 if end > start {
1214 self.tx_status = self.stream_buf[start];
1215 }
1216 self.advance_stream_msg(total);
1217 break 'outer;
1218 } else if msg_type == b'I' {
1219 results.push(0);
1221 } else if msg_type == b'D' || msg_type == b'N' {
1222 } else if msg_type == b'E' {
1224 let fields = proto::parse_error_response(&self.stream_buf[start..end]);
1226 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1227 self.advance_stream_msg(total);
1228 self.drain_to_ready()?;
1229 return Err(self.make_server_error(fields));
1230 } else if msg_type == b'A' {
1231 let msg = proto::parse_backend_message(msg_type, &self.stream_buf[start..end])?;
1233 if let BackendMessage::NotificationResponse {
1234 pid,
1235 channel,
1236 payload,
1237 } = msg
1238 {
1239 let ch = channel.to_owned();
1240 let pl = payload.to_owned();
1241 self.buffer_notification(pid, &ch, &pl);
1242 }
1243 }
1244 self.advance_stream_msg(total);
1247 }
1248
1249 self.refill_stream_buf()?;
1251 }
1252
1253 self.shrink_buffers();
1254 Ok(results)
1255 }
1256
1257 pub(crate) fn ensure_stmt_prepared(
1263 &mut self,
1264 sql: &str,
1265 sql_hash: u64,
1266 params: &[&(dyn Encode + Sync)],
1267 ) -> Result<[u8; 18], DriverError> {
1268 if let Some(info) = self.stmts.get(&sql_hash, sql) {
1269 return Ok(info.name);
1270 }
1271
1272 let name = make_stmt_name(sql_hash);
1273 if params.len() > i16::MAX as usize {
1274 return Err(DriverError::Protocol(format!(
1275 "parameter count {} exceeds maximum {}",
1276 params.len(),
1277 i16::MAX
1278 )));
1279 }
1280 let param_oids: smallvec::SmallVec<[u32; 8]> =
1281 params.iter().map(|p| p.type_oid()).collect();
1282
1283 self.write_buf.clear();
1284 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1285 proto::write_describe(&mut self.write_buf, b'S', &name);
1286 proto::write_sync(&mut self.write_buf);
1287 self.flush_write()?;
1288
1289 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1290 let columns = self.read_column_description()?;
1291 self.expect_ready()?;
1292
1293 self.query_counter += 1;
1294 self.cache_stmt(
1295 sql_hash,
1296 StmtInfo {
1297 name,
1298 sql: sql.into(),
1299 columns,
1300 last_used: self.query_counter,
1301 bind_template: None,
1302 },
1303 );
1304
1305 Ok(name)
1306 }
1307
1308 pub(crate) fn write_deferred_bind_execute(
1311 &self,
1312 sql: &str,
1313 sql_hash: u64,
1314 params: &[&(dyn Encode + Sync)],
1315 buf: &mut Vec<u8>,
1316 ) -> Result<(), DriverError> {
1317 let stmt_name = self
1318 .stmts
1319 .get(&sql_hash, sql)
1320 .ok_or_else(|| {
1321 DriverError::Protocol("stmt just cached but not found in write_deferred".into())
1322 })?
1323 .name;
1324 proto::write_bind_params(buf, b"", &stmt_name, params);
1325 buf.extend_from_slice(proto::EXECUTE_ONLY);
1326 Ok(())
1327 }
1328
1329 pub(crate) fn flush_deferred_pipeline(
1334 &mut self,
1335 buf: &mut Vec<u8>,
1336 count: usize,
1337 ) -> Result<Vec<u64>, DriverError> {
1338 if count == 0 {
1339 buf.clear();
1340 return Ok(Vec::new());
1341 }
1342
1343 buf.extend_from_slice(proto::SYNC_ONLY);
1344
1345 self.stream.write_all(buf).map_err(DriverError::Io)?;
1346 buf.clear();
1347
1348 let mut results = Vec::with_capacity(count);
1350
1351 'outer: loop {
1352 while let Some((msg_type, start, end, total)) = self.peek_stream_msg()? {
1353 if msg_type == b'2' {
1354 } else if msg_type == b'C' {
1356 let rows = proto::parse_command_tag_bytes(&self.stream_buf[start..end]);
1358 results.push(rows);
1359 } else if msg_type == b'Z' {
1360 if end > start {
1362 self.tx_status = self.stream_buf[start];
1363 }
1364 self.advance_stream_msg(total);
1365 break 'outer;
1366 } else if msg_type == b'I' {
1367 results.push(0);
1369 } else if msg_type == b'D' || msg_type == b'N' {
1370 } else if msg_type == b'E' {
1372 let fields = proto::parse_error_response(&self.stream_buf[start..end]);
1374 self.advance_stream_msg(total);
1375 self.drain_to_ready()?;
1376 return Err(self.make_server_error(fields));
1377 } else if msg_type == b'A' {
1378 let msg = proto::parse_backend_message(msg_type, &self.stream_buf[start..end])?;
1380 if let BackendMessage::NotificationResponse {
1381 pid,
1382 channel,
1383 payload,
1384 } = msg
1385 {
1386 let ch = channel.to_owned();
1387 let pl = payload.to_owned();
1388 self.buffer_notification(pid, &ch, &pl);
1389 }
1390 }
1391 self.advance_stream_msg(total);
1394 }
1395
1396 self.refill_stream_buf()?;
1398 }
1399
1400 self.shrink_buffers();
1401 Ok(results)
1402 }
1403
1404 pub fn for_each<F>(
1406 &mut self,
1407 sql: &str,
1408 sql_hash: u64,
1409 params: &[&(dyn Encode + Sync)],
1410 mut f: F,
1411 ) -> Result<(), DriverError>
1412 where
1413 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
1414 {
1415 let _ = self.send_pipeline(sql, sql_hash, params, false, true)?;
1416
1417 'outer: loop {
1419 loop {
1420 let avail = self.stream_buf_end - self.stream_buf_pos;
1421 if avail < 5 {
1422 break; }
1424
1425 let msg_type = self.stream_buf[self.stream_buf_pos];
1426 let raw_len = i32::from_be_bytes([
1427 self.stream_buf[self.stream_buf_pos + 1],
1428 self.stream_buf[self.stream_buf_pos + 2],
1429 self.stream_buf[self.stream_buf_pos + 3],
1430 self.stream_buf[self.stream_buf_pos + 4],
1431 ]);
1432
1433 if raw_len < 4 {
1434 return Err(DriverError::Protocol(format!(
1435 "invalid message length {raw_len} for type '{}'",
1436 msg_type as char
1437 )));
1438 }
1439
1440 let payload_len = (raw_len - 4) as usize;
1441 let total_msg_len = 5 + payload_len;
1442
1443 if avail < total_msg_len {
1444 if total_msg_len > self.stream_buf.len() {
1445 let msg = self.read_one_message()?;
1447 match msg {
1448 BackendMessage::BindComplete => continue,
1449 BackendMessage::DataRow { data } => {
1450 let row = PgDataRow::new(data)?;
1451 f(row)?;
1452 continue;
1453 }
1454 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1455 continue;
1456 }
1457 BackendMessage::ReadyForQuery { status } => {
1458 self.tx_status = status;
1459 break 'outer;
1460 }
1461 BackendMessage::NoticeResponse { .. } => continue,
1462 BackendMessage::ErrorResponse { data } => {
1463 let fields = proto::parse_error_response(data);
1464 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1465 self.drain_to_ready()?;
1466 return Err(self.make_server_error(fields));
1467 }
1468 other => {
1469 return Err(DriverError::Protocol(format!(
1470 "unexpected message during for_each: {other:?}"
1471 )));
1472 }
1473 }
1474 }
1475 break; }
1477
1478 let payload_start = self.stream_buf_pos + 5;
1480 let payload_end = payload_start + payload_len;
1481
1482 if msg_type == b'D' {
1485 let row = PgDataRow::new(&self.stream_buf[payload_start..payload_end])?;
1487 f(row)?;
1488 } else if msg_type == b'Z' {
1489 if payload_len >= 1 {
1491 self.tx_status = self.stream_buf[payload_start];
1492 }
1493 self.stream_buf_pos += total_msg_len;
1494 break 'outer;
1495 } else {
1496 self.handle_non_datarow_execute(
1497 msg_type,
1498 payload_start,
1499 payload_end,
1500 sql_hash,
1501 )?;
1502 }
1503
1504 self.stream_buf_pos += total_msg_len;
1505 }
1506
1507 self.refill_stream_buf()?;
1509 }
1510
1511 self.shrink_buffers();
1512 Ok(())
1513 }
1514
1515 #[inline]
1526 pub fn for_each_raw_monolithic<F>(
1527 &mut self,
1528 sql: &str,
1529 sql_hash: u64,
1530 params: &[&(dyn Encode + Sync)],
1531 mut f: F,
1532 ) -> Result<(), DriverError>
1533 where
1534 F: FnMut(&[u8]) -> Result<(), DriverError>,
1535 {
1536 self.write_buf.clear();
1538
1539 let info = match self.stmts.get_mut(&sql_hash, sql) {
1541 Some(info) => {
1542 self.query_counter += 1;
1543 info.last_used = self.query_counter;
1544 info
1545 }
1546 None => {
1547 return self.for_each_raw_with_prepare(sql, sql_hash, params, f);
1549 }
1550 };
1551
1552 let can_use_template = info
1554 .bind_template
1555 .as_ref()
1556 .is_some_and(|t| t.param_slots.len() == params.len());
1557
1558 let mut has_exec_sync = false;
1559
1560 if can_use_template {
1561 let tmpl = info.bind_template.as_ref().ok_or_else(|| {
1563 DriverError::Protocol("bind_template missing despite can_use_template".into())
1564 })?;
1565 self.write_buf.extend_from_slice(&tmpl.bytes);
1566
1567 let mut template_ok = true;
1568 for (i, param) in params.iter().enumerate() {
1569 let (data_offset, old_len) = tmpl.param_slots[i];
1570 if param.is_null() {
1571 let len_offset = data_offset - 4;
1572 self.write_buf[len_offset..len_offset + 4]
1573 .copy_from_slice(&(-1i32).to_be_bytes());
1574 } else if old_len >= 0 {
1575 let end = data_offset + old_len as usize;
1576 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
1577 template_ok = false;
1578 break;
1579 }
1580 } else {
1581 template_ok = false;
1582 break;
1583 }
1584 }
1585
1586 if template_ok {
1587 has_exec_sync = true;
1588 } else {
1589 self.write_buf.clear();
1590 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
1591 info.bind_template = None;
1592 }
1593 } else {
1594 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
1595 }
1596
1597 if info.bind_template.is_none() && !self.write_buf.is_empty() {
1599 info.bind_template = build_bind_template(&self.write_buf, params.len());
1600 }
1601
1602 if !has_exec_sync {
1603 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1604 }
1605
1606 self.stream
1608 .write_all(&self.write_buf)
1609 .map_err(DriverError::Io)?;
1610
1611 loop {
1615 let avail = self.stream_buf_end - self.stream_buf_pos;
1616 if avail >= 5 {
1617 let bc_type = self.stream_buf[self.stream_buf_pos];
1618 match bc_type {
1619 b'2' => {
1620 self.stream_buf_pos += 5;
1621 break;
1622 }
1623 b'E' => {
1624 let msg = self.read_one_message()?;
1625 if let BackendMessage::ErrorResponse { data } = msg {
1626 let fields = proto::parse_error_response(data);
1627 self.drain_to_ready()?;
1628 return Err(self.make_server_error(fields));
1629 }
1630 }
1631 b'N' | b'S' => {
1632 let raw_len = i32::from_be_bytes([
1633 self.stream_buf[self.stream_buf_pos + 1],
1634 self.stream_buf[self.stream_buf_pos + 2],
1635 self.stream_buf[self.stream_buf_pos + 3],
1636 self.stream_buf[self.stream_buf_pos + 4],
1637 ]);
1638 let total = 1 + raw_len as usize;
1639 if avail >= total {
1640 self.stream_buf_pos += total;
1641 continue;
1642 }
1643 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1644 break;
1645 }
1646 _ => {
1647 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1648 break;
1649 }
1650 }
1651 } else {
1652 let remaining = self.stream_buf_end - self.stream_buf_pos;
1654 if remaining > 0 && self.stream_buf_pos > 0 {
1655 self.stream_buf
1656 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1657 }
1658 self.stream_buf_pos = 0;
1659 self.stream_buf_end = remaining;
1660 let n = self
1661 .stream
1662 .read(&mut self.stream_buf[remaining..])
1663 .map_err(DriverError::Io)?;
1664 if n == 0 {
1665 return Err(DriverError::Io(std::io::Error::new(
1666 std::io::ErrorKind::UnexpectedEof,
1667 "connection closed",
1668 )));
1669 }
1670 self.stream_buf_end = remaining + n;
1671 }
1672 }
1673
1674 'outer: loop {
1676 loop {
1677 let avail = self.stream_buf_end - self.stream_buf_pos;
1678 if avail < 5 {
1679 break;
1680 }
1681
1682 let msg_type = self.stream_buf[self.stream_buf_pos];
1683 let raw_len = i32::from_be_bytes([
1684 self.stream_buf[self.stream_buf_pos + 1],
1685 self.stream_buf[self.stream_buf_pos + 2],
1686 self.stream_buf[self.stream_buf_pos + 3],
1687 self.stream_buf[self.stream_buf_pos + 4],
1688 ]);
1689
1690 if raw_len < 4 {
1691 return Err(DriverError::Protocol(format!(
1692 "invalid message length {raw_len} for type '{}'",
1693 msg_type as char
1694 )));
1695 }
1696
1697 let payload_len = (raw_len - 4) as usize;
1698 let total_msg_len = 5 + payload_len;
1699
1700 if avail < total_msg_len {
1701 if total_msg_len > self.stream_buf.len() {
1702 let msg = self.read_one_message()?;
1703 match msg {
1704 BackendMessage::DataRow { data } => {
1705 f(data)?;
1706 continue;
1707 }
1708 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1709 continue;
1710 }
1711 BackendMessage::ReadyForQuery { status } => {
1712 self.tx_status = status;
1713 break 'outer;
1714 }
1715 BackendMessage::ErrorResponse { data } => {
1716 let fields = proto::parse_error_response(data);
1717 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1718 self.drain_to_ready()?;
1719 return Err(self.make_server_error(fields));
1720 }
1721 BackendMessage::NoticeResponse { .. } => continue,
1722 other => {
1723 return Err(DriverError::Protocol(format!(
1724 "unexpected message during for_each_raw: {other:?}"
1725 )));
1726 }
1727 }
1728 }
1729 break; }
1731
1732 let payload_start = self.stream_buf_pos + 5;
1734 let payload_end = payload_start + payload_len;
1735
1736 if msg_type == b'D' {
1737 f(&self.stream_buf[payload_start..payload_end])?;
1738 } else if msg_type == b'Z' {
1739 if payload_len >= 1 {
1740 self.tx_status = self.stream_buf[payload_start];
1741 }
1742 self.stream_buf_pos += total_msg_len;
1743 break 'outer;
1744 } else {
1745 self.handle_non_datarow_execute(
1746 msg_type,
1747 payload_start,
1748 payload_end,
1749 sql_hash,
1750 )?;
1751 }
1752
1753 self.stream_buf_pos += total_msg_len;
1754 }
1755
1756 let remaining = self.stream_buf_end - self.stream_buf_pos;
1758 if remaining > 0 && self.stream_buf_pos > 0 {
1759 self.stream_buf
1760 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1761 }
1762 self.stream_buf_pos = 0;
1763 self.stream_buf_end = remaining;
1764 let n = self
1765 .stream
1766 .read(&mut self.stream_buf[remaining..])
1767 .map_err(DriverError::Io)?;
1768 if n == 0 {
1769 return Err(DriverError::Io(std::io::Error::new(
1770 std::io::ErrorKind::UnexpectedEof,
1771 "connection closed",
1772 )));
1773 }
1774 self.stream_buf_end = remaining + n;
1775 }
1776
1777 if self.query_counter & 63 == 0 {
1779 if self.read_buf.capacity() > 64 * 1024 {
1780 self.read_buf.clear();
1781 self.read_buf.shrink_to(8192);
1782 }
1783 if self.write_buf.capacity() > 16 * 1024 {
1784 self.write_buf.clear();
1785 self.write_buf.shrink_to(8192);
1786 }
1787 }
1788
1789 Ok(())
1790 }
1791
1792 #[cold]
1794 #[inline(never)]
1795 fn for_each_raw_with_prepare<F>(
1796 &mut self,
1797 sql: &str,
1798 sql_hash: u64,
1799 params: &[&(dyn Encode + Sync)],
1800 mut f: F,
1801 ) -> Result<(), DriverError>
1802 where
1803 F: FnMut(&[u8]) -> Result<(), DriverError>,
1804 {
1805 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1806
1807 if params.len() > i16::MAX as usize {
1808 return Err(DriverError::Protocol(format!(
1809 "parameter count {} exceeds maximum {}",
1810 params.len(),
1811 i16::MAX
1812 )));
1813 }
1814
1815 let name = make_stmt_name(sql_hash);
1816 let param_oids: smallvec::SmallVec<[u32; 8]> =
1817 params.iter().map(|p| p.type_oid()).collect();
1818
1819 self.write_buf.clear();
1820 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1821 proto::write_describe(&mut self.write_buf, b'S', &name);
1822 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
1823 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1824 self.stream
1825 .write_all(&self.write_buf)
1826 .map_err(DriverError::Io)?;
1827
1828 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1829 let columns = self.read_column_description()?;
1830 self.query_counter += 1;
1831 self.cache_stmt(
1832 sql_hash,
1833 StmtInfo {
1834 name,
1835 sql: sql.into(),
1836 columns,
1837 last_used: self.query_counter,
1838 bind_template: None,
1839 },
1840 );
1841
1842 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1844
1845 'outer: loop {
1846 loop {
1847 let avail = self.stream_buf_end - self.stream_buf_pos;
1848 if avail < 5 {
1849 break;
1850 }
1851
1852 let msg_type = self.stream_buf[self.stream_buf_pos];
1853 let raw_len = i32::from_be_bytes([
1854 self.stream_buf[self.stream_buf_pos + 1],
1855 self.stream_buf[self.stream_buf_pos + 2],
1856 self.stream_buf[self.stream_buf_pos + 3],
1857 self.stream_buf[self.stream_buf_pos + 4],
1858 ]);
1859
1860 if raw_len < 4 {
1861 return Err(DriverError::Protocol(format!(
1862 "invalid message length {raw_len} for type '{}'",
1863 msg_type as char
1864 )));
1865 }
1866
1867 let payload_len = (raw_len - 4) as usize;
1868 let total_msg_len = 5 + payload_len;
1869
1870 if avail < total_msg_len {
1871 if total_msg_len > self.stream_buf.len() {
1872 let msg = self.read_one_message()?;
1873 match msg {
1874 BackendMessage::DataRow { data } => {
1875 f(data)?;
1876 continue;
1877 }
1878 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1879 continue;
1880 }
1881 BackendMessage::ReadyForQuery { status } => {
1882 self.tx_status = status;
1883 break 'outer;
1884 }
1885 BackendMessage::ErrorResponse { data } => {
1886 let fields = proto::parse_error_response(data);
1887 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1888 self.drain_to_ready()?;
1889 return Err(self.make_server_error(fields));
1890 }
1891 BackendMessage::NoticeResponse { .. } => continue,
1892 other => {
1893 return Err(DriverError::Protocol(format!(
1894 "unexpected message during for_each_raw: {other:?}"
1895 )));
1896 }
1897 }
1898 }
1899 break;
1900 }
1901
1902 let payload_start = self.stream_buf_pos + 5;
1903 let payload_end = payload_start + payload_len;
1904
1905 if msg_type == b'D' {
1906 f(&self.stream_buf[payload_start..payload_end])?;
1907 } else if msg_type == b'Z' {
1908 if payload_len >= 1 {
1909 self.tx_status = self.stream_buf[payload_start];
1910 }
1911 self.stream_buf_pos += total_msg_len;
1912 break 'outer;
1913 } else {
1914 self.handle_non_datarow_execute(
1915 msg_type,
1916 payload_start,
1917 payload_end,
1918 sql_hash,
1919 )?;
1920 }
1921
1922 self.stream_buf_pos += total_msg_len;
1923 }
1924
1925 self.refill_stream_buf()?;
1926 }
1927
1928 self.shrink_buffers();
1929 Ok(())
1930 }
1931
/// Runs `sql` with `params` and passes each raw row payload to `f`.
///
/// Thin wrapper: forwards all arguments unchanged to
/// `for_each_raw_monolithic` and returns its result.
#[inline]
pub fn for_each_raw<F>(
    &mut self,
    sql: &str,
    sql_hash: u64,
    params: &[&(dyn Encode + Sync)],
    f: F,
) -> Result<(), DriverError>
where
    F: FnMut(&[u8]) -> Result<(), DriverError>,
{
    self.for_each_raw_monolithic(sql, sql_hash, params, f)
}
1949
1950 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1952 self.write_buf.clear();
1953 proto::write_simple_query(&mut self.write_buf, sql);
1954 self.flush_write()?;
1955
1956 loop {
1957 let msg = self.read_one_message()?;
1958 match msg {
1959 BackendMessage::ReadyForQuery { status } => {
1960 self.tx_status = status;
1961 return Ok(());
1962 }
1963 BackendMessage::CommandComplete { .. }
1964 | BackendMessage::RowDescription { .. }
1965 | BackendMessage::DataRow { .. }
1966 | BackendMessage::EmptyQuery
1967 | BackendMessage::NoticeResponse { .. }
1968 | BackendMessage::ParameterStatus { .. }
1969 | BackendMessage::AuthOk
1973 | BackendMessage::AuthSaslFinal { .. }
1974 | BackendMessage::BackendKeyData { .. } => {}
1975 BackendMessage::ErrorResponse { data } => {
1976 let fields = proto::parse_error_response(data);
1977 self.drain_to_ready()?;
1978 return Err(self.make_server_error(fields));
1979 }
1980 other => {
1981 return Err(DriverError::Protocol(format!(
1982 "unexpected message during simple_query: {other:?}"
1983 )));
1984 }
1985 }
1986 }
1987 }
1988
1989 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
1991 self.write_buf.clear();
1992 proto::write_simple_query(&mut self.write_buf, sql);
1993 self.flush_write()?;
1994
1995 let mut rows: Vec<SimpleRow> = Vec::new();
1996 loop {
1997 let msg = self.read_one_message()?;
1998 match msg {
1999 BackendMessage::ReadyForQuery { status } => {
2000 self.tx_status = status;
2001 return Ok(rows);
2002 }
2003 BackendMessage::DataRow { data } => {
2004 rows.push(proto::parse_simple_data_row(data)?);
2005 }
2006 BackendMessage::RowDescription { .. }
2007 | BackendMessage::CommandComplete { .. }
2008 | BackendMessage::EmptyQuery
2009 | BackendMessage::NoticeResponse { .. }
2010 | BackendMessage::ParameterStatus { .. }
2011 | BackendMessage::AuthOk
2012 | BackendMessage::AuthSaslFinal { .. }
2013 | BackendMessage::BackendKeyData { .. } => {}
2014 BackendMessage::ErrorResponse { data } => {
2015 let fields = proto::parse_error_response(data);
2016 self.drain_to_ready()?;
2017 return Err(self.make_server_error(fields));
2018 }
2019 other => {
2020 return Err(DriverError::Protocol(format!(
2021 "unexpected message during simple_query_rows: {other:?}"
2022 )));
2023 }
2024 }
2025 }
2026 }
2027
2028 pub fn copy_in<'a, I>(
2050 &mut self,
2051 table: &str,
2052 columns: &[&str],
2053 rows: I,
2054 ) -> Result<u64, DriverError>
2055 where
2056 I: IntoIterator<Item = &'a str>,
2057 {
2058 let quoted_table = proto::quote_ident(table);
2060 let quoted_cols: Vec<String> = columns.iter().map(|c| proto::quote_ident(c)).collect();
2061 let sql = format!(
2062 "COPY {}({}) FROM STDIN",
2063 quoted_table,
2064 quoted_cols.join(",")
2065 );
2066
2067 self.write_buf.clear();
2069 proto::write_simple_query(&mut self.write_buf, &sql);
2070 self.flush_write()?;
2071
2072 loop {
2074 let msg = self.read_one_message()?;
2075 match msg {
2076 BackendMessage::CopyInResponse { .. } => break,
2077 BackendMessage::ErrorResponse { data } => {
2078 let fields = proto::parse_error_response(data);
2079 self.drain_to_ready()?;
2080 return Err(self.make_server_error(fields));
2081 }
2082 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2083 other => {
2084 return Err(DriverError::Protocol(format!(
2085 "expected CopyInResponse, got: {other:?}"
2086 )));
2087 }
2088 }
2089 }
2090
2091 self.write_buf.clear();
2101 for row in rows {
2102 let row_data = row.as_bytes();
2104 let data_len = (4 + row_data.len() + 1) as i32;
2105 self.write_buf.push(b'd');
2106 self.write_buf.extend_from_slice(&data_len.to_be_bytes());
2107 self.write_buf.extend_from_slice(row_data);
2108 self.write_buf.push(b'\n');
2109 if self.write_buf.len() > 65536 {
2111 self.flush_write()?;
2112 self.write_buf.clear();
2113 }
2114 }
2115 proto::write_copy_done(&mut self.write_buf);
2118 self.flush_write()?;
2119 self.write_buf.clear();
2120
2121 let mut count: u64 = 0;
2123 loop {
2124 let msg = self.read_one_message()?;
2125 match msg {
2126 BackendMessage::CommandComplete { tag } => {
2127 count = proto::parse_command_tag(tag);
2128 }
2129 BackendMessage::ReadyForQuery { status } => {
2130 self.tx_status = status;
2131 return Ok(count);
2132 }
2133 BackendMessage::ErrorResponse { data } => {
2134 let fields = proto::parse_error_response(data);
2135 self.drain_to_ready()?;
2136 return Err(self.make_server_error(fields));
2137 }
2138 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2139 other => {
2140 return Err(DriverError::Protocol(format!(
2141 "unexpected message during copy_in completion: {other:?}"
2142 )));
2143 }
2144 }
2145 }
2146 }
2147
2148 pub fn copy_out<W: std::io::Write>(
2168 &mut self,
2169 query: &str,
2170 writer: &mut W,
2171 ) -> Result<u64, DriverError> {
2172 let sql = format!("COPY ({query}) TO STDOUT");
2174
2175 self.write_buf.clear();
2177 proto::write_simple_query(&mut self.write_buf, &sql);
2178 self.flush_write()?;
2179
2180 loop {
2182 let msg = self.read_one_message()?;
2183 match msg {
2184 BackendMessage::CopyOutResponse { .. } => break,
2185 BackendMessage::ErrorResponse { data } => {
2186 let fields = proto::parse_error_response(data);
2187 self.drain_to_ready()?;
2188 return Err(self.make_server_error(fields));
2189 }
2190 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2191 other => {
2192 return Err(DriverError::Protocol(format!(
2193 "expected CopyOutResponse, got: {other:?}"
2194 )));
2195 }
2196 }
2197 }
2198
2199 loop {
2201 let msg = self.read_one_message()?;
2202 match msg {
2203 BackendMessage::CopyData { data } => {
2204 writer.write_all(data).map_err(DriverError::Io)?;
2205 }
2206 BackendMessage::CopyDone => break,
2207 BackendMessage::ErrorResponse { data } => {
2208 let fields = proto::parse_error_response(data);
2209 self.drain_to_ready()?;
2210 return Err(self.make_server_error(fields));
2211 }
2212 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2213 other => {
2214 return Err(DriverError::Protocol(format!(
2215 "unexpected message during copy_out data: {other:?}"
2216 )));
2217 }
2218 }
2219 }
2220
2221 let mut count: u64 = 0;
2223 loop {
2224 let msg = self.read_one_message()?;
2225 match msg {
2226 BackendMessage::CommandComplete { tag } => {
2227 count = proto::parse_command_tag(tag);
2228 }
2229 BackendMessage::ReadyForQuery { status } => {
2230 self.tx_status = status;
2231 return Ok(count);
2232 }
2233 BackendMessage::ErrorResponse { data } => {
2234 let fields = proto::parse_error_response(data);
2235 self.drain_to_ready()?;
2236 return Err(self.make_server_error(fields));
2237 }
2238 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2239 other => {
2240 return Err(DriverError::Protocol(format!(
2241 "unexpected message during copy_out completion: {other:?}"
2242 )));
2243 }
2244 }
2245 }
2246 }
2247
2248 pub fn prepare_describe(&mut self, sql: &str) -> Result<PrepareResult, DriverError> {
2253 self.write_buf.clear();
2254 proto::write_parse(&mut self.write_buf, b"", sql, &[]);
2257 proto::write_describe(&mut self.write_buf, b'S', b"");
2258 proto::write_sync(&mut self.write_buf);
2259 self.flush_write()?;
2260
2261 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2263
2264 let mut param_oids: Vec<u32> = Vec::new();
2266 let columns;
2267 loop {
2268 let msg = self.read_one_message()?;
2269 match msg {
2270 BackendMessage::ParameterDescription { data } => {
2271 param_oids = proto::parse_parameter_description(data)?;
2272 }
2273 BackendMessage::RowDescription { data } => {
2274 columns = proto::parse_row_description(data)?;
2275 break;
2276 }
2277 BackendMessage::NoData => {
2278 columns = Vec::new();
2279 break;
2280 }
2281 BackendMessage::NoticeResponse { .. } => {}
2282 BackendMessage::ErrorResponse { data } => {
2283 let fields = proto::parse_error_response(data);
2284 self.drain_to_ready()?;
2285 return Err(self.make_server_error(fields));
2286 }
2287 other => {
2288 return Err(DriverError::Protocol(format!(
2289 "expected ParameterDescription/RowDescription/NoData, got: {other:?}"
2290 )));
2291 }
2292 }
2293 }
2294
2295 self.expect_ready()?;
2297
2298 Ok(PrepareResult {
2299 columns,
2300 param_oids,
2301 })
2302 }
2303
2304 pub fn wait_for_notification(&mut self) -> Result<(String, String), DriverError> {
2313 loop {
2314 let (msg_type, _payload_len) = self.read_message_buffered()?;
2315 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2316 match msg {
2317 BackendMessage::NotificationResponse {
2318 channel, payload, ..
2319 } => {
2320 return Ok((channel.to_owned(), payload.to_owned()));
2321 }
2322 BackendMessage::ParameterStatus { .. } | BackendMessage::NoticeResponse { .. } => {
2323 continue;
2324 }
2325 _ => continue,
2326 }
2327 }
2328 }
2329
2330 pub fn cancel(&self) -> Result<(), DriverError> {
2336 let addr = format!("{}:{}", self.connect_config.host, self.connect_config.port);
2337 let mut tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
2338 let mut buf = Vec::with_capacity(16);
2339 proto::write_cancel_request(&mut buf, self.pid, self.secret);
2340 tcp.write_all(&buf).map_err(DriverError::Io)?;
2341 tcp.flush().map_err(DriverError::Io)?;
2342 drop(tcp);
2344 Ok(())
2345 }
2346
/// Sets the read timeout of the underlying stream.
///
/// `timeout` is forwarded unchanged to `Stream::set_read_timeout`.
///
/// # Errors
///
/// [`DriverError::Io`] when the stream rejects the setting.
pub fn set_read_timeout(
    &self,
    timeout: Option<std::time::Duration>,
) -> Result<(), DriverError> {
    self.stream
        .set_read_timeout(timeout)
        .map_err(DriverError::Io)
}
2359
2360 pub fn query_streaming_start(
2374 &mut self,
2375 sql: &str,
2376 sql_hash: u64,
2377 params: &[&(dyn Encode + Sync)],
2378 chunk_size: i32,
2379 ) -> Result<(Arc<[ColumnDesc]>, bool), DriverError> {
2380 self.write_buf.clear();
2381
2382 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash, sql) {
2383 self.query_counter += 1;
2385 info.last_used = self.query_counter;
2386
2387 let can_use_template = info
2388 .bind_template
2389 .as_ref()
2390 .is_some_and(|t| t.param_slots.len() == params.len());
2391
2392 if can_use_template {
2393 let tmpl = info.bind_template.as_ref().ok_or_else(|| {
2395 DriverError::Protocol("bind_template missing despite can_use_template".into())
2396 })?;
2397 self.write_buf
2400 .extend_from_slice(&tmpl.bytes[..tmpl.bind_end]);
2401
2402 let mut template_ok = true;
2403 for (i, param) in params.iter().enumerate() {
2404 let (data_offset, old_len) = tmpl.param_slots[i];
2405 if param.is_null() {
2406 let len_offset = data_offset - 4;
2407 self.write_buf[len_offset..len_offset + 4]
2408 .copy_from_slice(&(-1i32).to_be_bytes());
2409 } else if old_len >= 0 {
2410 let end = data_offset + old_len as usize;
2411 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2412 template_ok = false;
2413 break;
2414 }
2415 } else {
2416 template_ok = false;
2417 break;
2418 }
2419 }
2420
2421 if !template_ok {
2422 self.write_buf.clear();
2423 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
2424 info.bind_template = None;
2425 }
2426 } else {
2427 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
2428 }
2429
2430 let cols = info.columns.clone();
2431
2432 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2433 info.bind_template = build_bind_template(&self.write_buf, params.len());
2434 }
2435
2436 proto::write_execute(&mut self.write_buf, b"", chunk_size);
2437 proto::write_flush(&mut self.write_buf);
2439 self.flush_write()?;
2440
2441 cols
2442 } else {
2443 let name = make_stmt_name(sql_hash);
2445 let param_oids: smallvec::SmallVec<[u32; 8]> =
2446 params.iter().map(|p| p.type_oid()).collect();
2447 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
2448 proto::write_describe(&mut self.write_buf, b'S', &name);
2449 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
2450
2451 proto::write_execute(&mut self.write_buf, b"", chunk_size);
2452 proto::write_flush(&mut self.write_buf);
2453 self.flush_write()?;
2454
2455 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2456 let columns = self.read_column_description()?;
2457 self.query_counter += 1;
2458 self.cache_stmt(
2459 sql_hash,
2460 StmtInfo {
2461 name,
2462 sql: sql.into(),
2463 columns: columns.clone(),
2464 last_used: self.query_counter,
2465 bind_template: None,
2466 },
2467 );
2468 columns
2469 };
2470
2471 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2473
2474 self.streaming_active = true;
2475
2476 Ok((columns, false))
2477 }
2478
2479 pub fn streaming_next_chunk(
2487 &mut self,
2488 arena: &mut Arena,
2489 all_col_offsets: &mut Vec<(usize, i32)>,
2490 ) -> Result<bool, DriverError> {
2491 all_col_offsets.clear();
2492
2493 loop {
2494 let msg = self.read_one_message()?;
2495 match msg {
2496 BackendMessage::DataRow { data } => {
2497 parse_data_row_flat(data, arena, all_col_offsets)?;
2498 }
2499 BackendMessage::PortalSuspended => {
2500 return Ok(true);
2504 }
2505 BackendMessage::CommandComplete { .. } => {
2506 self.write_buf.clear();
2509 proto::write_sync(&mut self.write_buf);
2510 self.flush_write()?;
2511 self.expect_ready()?;
2512 self.shrink_buffers();
2513
2514 self.streaming_active = false;
2515 return Ok(false);
2516 }
2517 BackendMessage::EmptyQuery => {
2518 self.write_buf.clear();
2519 proto::write_sync(&mut self.write_buf);
2520 self.flush_write()?;
2521 self.expect_ready()?;
2522
2523 self.streaming_active = false;
2524 return Ok(false);
2525 }
2526 BackendMessage::ErrorResponse { data } => {
2527 let fields = proto::parse_error_response(data);
2528 self.write_buf.clear();
2530 proto::write_sync(&mut self.write_buf);
2531 self.flush_write()?;
2532 self.drain_to_ready()?;
2533
2534 self.streaming_active = false;
2535 return Err(self.make_server_error(fields));
2536 }
2537 BackendMessage::NoticeResponse { .. } => {}
2538 other => {
2539 return Err(DriverError::Protocol(format!(
2540 "unexpected message during streaming: {other:?}"
2541 )));
2542 }
2543 }
2544 }
2545 }
2546
/// Requests the next chunk of the active streaming query.
///
/// Writes an Execute for the unnamed portal (`b""`) with `chunk_size` as
/// its row limit, followed by a Flush, and sends both.
///
/// # Errors
///
/// Whatever `flush_write` reports.
pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
    self.write_buf.clear();
    proto::write_execute(&mut self.write_buf, b"", chunk_size);
    proto::write_flush(&mut self.write_buf);
    self.flush_write()
}
2560
/// Returns the current value of the `streaming_active` flag.
pub fn is_streaming(&self) -> bool {
    self.streaming_active
}
2565
/// Consumes the connection after sending a Terminate message.
///
/// The flush result is deliberately ignored, so this currently always
/// returns `Ok(())` even if the message could not be written.
pub fn close(mut self) -> Result<(), DriverError> {
    self.write_buf.clear();
    proto::write_terminate(&mut self.write_buf);
    // Best effort: a failed write does not turn into an error here.
    let _ = self.flush_write();
    Ok(())
}
2573
2574 pub fn is_idle(&self) -> bool {
2578 self.tx_status == b'I'
2579 }
2580
2581 pub fn is_in_transaction(&self) -> bool {
2583 self.tx_status == b'T'
2584 }
2585
2586 pub fn is_in_failed_transaction(&self) -> bool {
2588 self.tx_status == b'E'
2589 }
2590
/// Resets the `last_used` timestamp to the current instant.
pub fn touch(&mut self) {
    self.last_used = std::time::Instant::now();
}
2595
/// Returns the time elapsed since `last_used` was last set.
pub fn idle_duration(&self) -> std::time::Duration {
    self.last_used.elapsed()
}
2600
/// Returns the current value of the per-connection query counter.
pub fn query_counter(&self) -> u64 {
    self.query_counter
}
2605
2606 pub fn parameter(&self, name: &str) -> Option<&str> {
2608 self.params
2609 .iter()
2610 .find(|(k, _)| &**k == name)
2611 .map(|(_, v)| &**v)
2612 }
2613
/// Returns all stored `(name, value)` parameter pairs in stored order.
pub fn server_params(&self) -> &[(Box<str>, Box<str>)] {
    &self.params
}
2618
/// Returns the stored `pid` value (the one sent in cancel requests).
pub fn pid(&self) -> i32 {
    self.pid
}
2623
/// Returns the stored `secret` value (the one sent in cancel requests).
pub fn secret_key(&self) -> i32 {
    self.secret
}
2628
/// Takes all buffered notifications, leaving the internal queue empty.
pub fn drain_notifications(&mut self) -> Vec<Notification> {
    std::mem::take(&mut self.pending_notifications)
}
2633
/// Returns how many notifications are currently buffered.
pub fn pending_notification_count(&self) -> usize {
    self.pending_notifications.len()
}
2638
/// Stores `size` as the new `max_stmt_cache_size`.
///
/// Only the field is updated here; existing cache entries are not touched
/// by this call.
pub fn set_max_stmt_cache_size(&mut self, size: usize) {
    self.max_stmt_cache_size = size;
}
2643
/// Returns the number of entries currently in the statement cache.
pub fn stmt_cache_len(&self) -> usize {
    self.stmts.len()
}
2648
/// Returns the stored `created_at` instant.
pub fn created_at(&self) -> std::time::Instant {
    self.created_at
}
2653
2654 #[inline]
2662 fn send_pipeline(
2663 &mut self,
2664 sql: &str,
2665 sql_hash: u64,
2666 params: &[&(dyn Encode + Sync)],
2667 need_columns: bool,
2668 skip_bind_complete: bool,
2669 ) -> Result<Option<Arc<[ColumnDesc]>>, DriverError> {
2670 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
2671
2672 if params.len() > i16::MAX as usize {
2673 return Err(DriverError::Protocol(format!(
2674 "parameter count {} exceeds maximum {}",
2675 params.len(),
2676 i16::MAX
2677 )));
2678 }
2679
2680 self.write_buf.clear();
2681
2682 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash, sql) {
2683 self.query_counter += 1;
2685 info.last_used = self.query_counter;
2686
2687 let can_use_template = info
2688 .bind_template
2689 .as_ref()
2690 .is_some_and(|t| t.param_slots.len() == params.len());
2691
2692 let mut has_exec_sync = false;
2694
2695 if can_use_template {
2696 let tmpl = info.bind_template.as_ref().ok_or_else(|| {
2700 DriverError::Protocol("bind_template missing despite can_use_template".into())
2701 })?;
2702 self.write_buf.extend_from_slice(&tmpl.bytes);
2703
2704 let mut template_ok = true;
2705 for (i, param) in params.iter().enumerate() {
2706 let (data_offset, old_len) = tmpl.param_slots[i];
2707 if param.is_null() {
2708 let len_offset = data_offset - 4;
2710 self.write_buf[len_offset..len_offset + 4]
2711 .copy_from_slice(&(-1i32).to_be_bytes());
2712 } else if old_len >= 0 {
2713 let end = data_offset + old_len as usize;
2714 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2715 template_ok = false;
2717 break;
2718 }
2719 } else {
2720 template_ok = false;
2723 break;
2724 }
2725 }
2726
2727 if template_ok {
2728 has_exec_sync = true; } else {
2730 self.write_buf.clear();
2731 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
2732 info.bind_template = None;
2734 }
2735 } else {
2736 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
2737 }
2738
2739 let cols = if need_columns {
2740 Some(info.columns.clone())
2741 } else {
2742 None
2743 };
2744
2745 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2749 info.bind_template = build_bind_template(&self.write_buf, params.len());
2750 }
2751
2752 if !has_exec_sync {
2753 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2754 }
2755 self.flush_write()?;
2756
2757 cols
2758 } else {
2759 let name = make_stmt_name(sql_hash);
2761 let param_oids: smallvec::SmallVec<[u32; 8]> =
2762 params.iter().map(|p| p.type_oid()).collect();
2763 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
2764 proto::write_describe(&mut self.write_buf, b'S', &name);
2765 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
2766
2767 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2768 self.flush_write()?;
2769
2770 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2771 let columns = self.read_column_description()?;
2772 self.query_counter += 1;
2773 self.cache_stmt(
2774 sql_hash,
2775 StmtInfo {
2776 name,
2777 sql: sql.into(),
2778 columns: columns.clone(),
2779 last_used: self.query_counter,
2780 bind_template: None,
2781 },
2782 );
2783 if need_columns {
2784 Some(columns)
2785 } else {
2786 None
2787 }
2788 };
2789
2790 if !skip_bind_complete {
2791 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2792 }
2793
2794 Ok(columns)
2795 }
2796
2797 fn read_column_description(&mut self) -> Result<Arc<[ColumnDesc]>, DriverError> {
2799 loop {
2800 let msg = self.read_one_message()?;
2801 match msg {
2802 BackendMessage::RowDescription { data } => {
2803 let cols = proto::parse_row_description(data)?;
2804 return Ok(cols.into());
2805 }
2806 BackendMessage::ParameterDescription { .. } => {}
2807 BackendMessage::NoData => return Ok(Arc::from(Vec::new())),
2808 BackendMessage::NoticeResponse { .. } => {}
2809 BackendMessage::ErrorResponse { data } => {
2810 let fields = proto::parse_error_response(data);
2811 self.drain_to_ready()?;
2812 return Err(self.make_server_error(fields));
2813 }
2814 other => {
2815 return Err(DriverError::Protocol(format!(
2816 "expected RowDescription/NoData, got: {other:?}"
2817 )));
2818 }
2819 }
2820 }
2821 }
2822
2823 fn cache_stmt(&mut self, sql_hash: u64, info: StmtInfo) {
2826 if self.stmts.len() >= self.max_stmt_cache_size
2827 && !self.stmts.contains_key(&sql_hash, &info.sql)
2828 {
2829 if let Some((_lru_hash, evicted)) = self.stmts.evict_lru() {
2830 proto::write_close(&mut self.write_buf, b'S', &evicted.name);
2831 }
2832 }
2833 self.stmts.insert(sql_hash, info);
2834 }
2835
2836 fn buffer_notification(&mut self, pid: i32, channel: &str, payload: &str) {
2837 if self.pending_notifications.len() < 1024 {
2838 self.pending_notifications.push(Notification {
2839 pid,
2840 channel: channel.to_owned(),
2841 payload: payload.to_owned(),
2842 });
2843 }
2844 }
2845
2846 fn shrink_buffers(&mut self) {
2847 if self.query_counter & 63 != 0 {
2851 return;
2852 }
2853 if self.read_buf.capacity() > 64 * 1024 {
2854 self.read_buf.clear();
2855 self.read_buf.shrink_to(8192);
2856 }
2857 if self.write_buf.capacity() > 16 * 1024 {
2858 self.write_buf.clear();
2859 self.write_buf.shrink_to(8192);
2860 }
2861 }
2862
2863 fn maybe_invalidate_stmt_cache(&mut self, fields: &proto::ErrorFields, sql_hash: u64) -> bool {
2864 if &fields.code == b"26000" {
2865 self.stmts.remove(&sql_hash);
2866 true
2867 } else {
2868 false
2869 }
2870 }
2871
2872 #[cold]
2873 #[inline(never)]
2874 fn make_server_error(&self, fields: proto::ErrorFields) -> DriverError {
2875 DriverError::Server {
2876 code: fields.code,
2877 message: fields.message.into_boxed_str(),
2878 detail: fields.detail.map(String::into_boxed_str),
2879 hint: fields.hint.map(String::into_boxed_str),
2880 position: fields.position,
2881 }
2882 }
2883
2884 #[cold]
2890 #[inline(never)]
2891 fn handle_non_datarow_query(
2892 &mut self,
2893 msg_type: u8,
2894 payload_start: usize,
2895 payload_end: usize,
2896 sql_hash: u64,
2897 affected_rows: &mut u64,
2898 ) -> Result<(), DriverError> {
2899 match msg_type {
2900 b'2' | b'I' => {} b'C' => {
2902 *affected_rows =
2903 proto::parse_command_tag_bytes(&self.stream_buf[payload_start..payload_end]);
2904 }
2905 b'E' => {
2906 let fields =
2907 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2908 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2909 self.drain_to_ready()?;
2910 return Err(self.make_server_error(fields));
2911 }
2912 b'A' => {
2913 let msg = proto::parse_backend_message(
2914 msg_type,
2915 &self.stream_buf[payload_start..payload_end],
2916 )?;
2917 if let BackendMessage::NotificationResponse {
2918 pid,
2919 channel,
2920 payload,
2921 } = msg
2922 {
2923 let ch = channel.to_owned();
2924 let pl = payload.to_owned();
2925 self.buffer_notification(pid, &ch, &pl);
2926 }
2927 }
2928 _ => {} }
2930 Ok(())
2931 }
2932
2933 #[cold]
2936 #[inline(never)]
2937 fn handle_non_datarow_execute(
2938 &mut self,
2939 msg_type: u8,
2940 payload_start: usize,
2941 payload_end: usize,
2942 sql_hash: u64,
2943 ) -> Result<(), DriverError> {
2944 match msg_type {
2945 b'2' | b'C' | b'I' => {} b'E' => {
2947 let fields =
2948 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2949 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2950 self.drain_to_ready()?;
2951 return Err(self.make_server_error(fields));
2952 }
2953 b'A' => {
2954 let msg = proto::parse_backend_message(
2955 msg_type,
2956 &self.stream_buf[payload_start..payload_end],
2957 )?;
2958 if let BackendMessage::NotificationResponse {
2959 pid,
2960 channel,
2961 payload,
2962 } = msg
2963 {
2964 let ch = channel.to_owned();
2965 let pl = payload.to_owned();
2966 self.buffer_notification(pid, &ch, &pl);
2967 }
2968 }
2969 _ => {} }
2971 Ok(())
2972 }
2973
2974 #[inline(always)]
2981 fn peek_stream_msg(&self) -> Result<Option<(u8, usize, usize, usize)>, DriverError> {
2982 let avail = self.stream_buf_end - self.stream_buf_pos;
2983 if avail < 5 {
2984 return Ok(None);
2985 }
2986
2987 let msg_type = self.stream_buf[self.stream_buf_pos];
2988 let raw_len = i32::from_be_bytes([
2989 self.stream_buf[self.stream_buf_pos + 1],
2990 self.stream_buf[self.stream_buf_pos + 2],
2991 self.stream_buf[self.stream_buf_pos + 3],
2992 self.stream_buf[self.stream_buf_pos + 4],
2993 ]);
2994
2995 if raw_len < 4 {
2996 return Err(DriverError::Protocol(format!(
2997 "invalid message length {raw_len} for type '{}'",
2998 msg_type as char
2999 )));
3000 }
3001
3002 let payload_len = (raw_len - 4) as usize;
3003 let total_msg_len = 5 + payload_len;
3004
3005 if avail < total_msg_len {
3006 return Ok(None);
3007 }
3008
3009 let payload_start = self.stream_buf_pos + 5;
3010 Ok(Some((
3011 msg_type,
3012 payload_start,
3013 payload_start + payload_len,
3014 total_msg_len,
3015 )))
3016 }
3017
3018 #[inline(always)]
3020 fn advance_stream_msg(&mut self, total_msg_len: usize) {
3021 self.stream_buf_pos += total_msg_len;
3022 }
3023
3024 #[inline]
3026 fn read_one_message(&mut self) -> Result<BackendMessage<'_>, DriverError> {
3027 loop {
3028 let (msg_type, _payload_len) = self.read_message_buffered()?;
3029 if msg_type == b'A' {
3030 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
3031 if let BackendMessage::NotificationResponse {
3032 pid,
3033 channel,
3034 payload,
3035 } = msg
3036 {
3037 let pid_owned = pid;
3038 let channel_owned = channel.to_owned();
3039 let payload_owned = payload.to_owned();
3040 self.buffer_notification(pid_owned, &channel_owned, &payload_owned);
3041 continue;
3042 }
3043 }
3044 return proto::parse_backend_message(msg_type, &self.read_buf);
3045 }
3046 }
3047
3048 fn expect_message(
3049 &mut self,
3050 pred: impl Fn(&BackendMessage<'_>) -> bool,
3051 ) -> Result<(), DriverError> {
3052 loop {
3053 let msg = self.read_one_message()?;
3054 if pred(&msg) {
3055 return Ok(());
3056 }
3057 match msg {
3058 BackendMessage::ErrorResponse { data } => {
3059 let fields = proto::parse_error_response(data);
3060 self.drain_to_ready()?;
3061 return Err(self.make_server_error(fields));
3062 }
3063 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
3064 other => {
3065 return Err(DriverError::Protocol(format!(
3066 "unexpected message while waiting for expected type: {other:?}"
3067 )));
3068 }
3069 }
3070 }
3071 }
3072
3073 fn expect_ready(&mut self) -> Result<(), DriverError> {
3074 loop {
3075 let msg = self.read_one_message()?;
3076 match msg {
3077 BackendMessage::ReadyForQuery { status } => {
3078 self.tx_status = status;
3079 return Ok(());
3080 }
3081 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
3082 BackendMessage::ErrorResponse { data } => {
3083 let fields = proto::parse_error_response(data);
3084 self.drain_to_ready()?;
3085 return Err(self.make_server_error(fields));
3086 }
3087 _ => {}
3088 }
3089 }
3090 }
3091
3092 #[inline]
3093 fn drain_to_ready(&mut self) -> Result<(), DriverError> {
3094 loop {
3095 let msg = self.read_one_message()?;
3096 if let BackendMessage::ReadyForQuery { status } = msg {
3097 self.tx_status = status;
3098 return Ok(());
3099 }
3100 }
3101 }
3102
3103 #[inline]
3107 fn flush_write(&mut self) -> Result<(), DriverError> {
3108 self.stream
3109 .write_all(&self.write_buf)
3110 .map_err(DriverError::Io)
3111 }
3112
3113 fn read_message_buffered(&mut self) -> Result<(u8, usize), DriverError> {
3117 let mut header = [0u8; 5];
3118 sync_buffered_read_exact(
3119 &mut self.stream,
3120 &mut self.stream_buf,
3121 &mut self.stream_buf_pos,
3122 &mut self.stream_buf_end,
3123 &mut header,
3124 )?;
3125
3126 let msg_type = header[0];
3127 let len = i32::from_be_bytes([header[1], header[2], header[3], header[4]]);
3128
3129 if len < 4 {
3130 return Err(DriverError::Protocol(format!(
3131 "invalid message length {len} for type '{}'",
3132 msg_type as char
3133 )));
3134 }
3135
3136 const MAX_MESSAGE_LEN: i32 = 128 * 1024 * 1024;
3137 if len > MAX_MESSAGE_LEN {
3138 return Err(DriverError::Protocol(format!(
3139 "message length {len} exceeds maximum ({MAX_MESSAGE_LEN}) for type '{}'",
3140 msg_type as char
3141 )));
3142 }
3143
3144 let payload_len = (len - 4) as usize;
3145 self.read_buf.clear();
3146 self.read_buf.resize(payload_len, 0);
3147 if payload_len > 0 {
3148 sync_buffered_read_exact(
3149 &mut self.stream,
3150 &mut self.stream_buf,
3151 &mut self.stream_buf_pos,
3152 &mut self.stream_buf_end,
3153 &mut self.read_buf[..payload_len],
3154 )?;
3155 }
3156
3157 Ok((msg_type, payload_len))
3158 }
3159
3160 #[inline]
3162 fn refill_stream_buf(&mut self) -> Result<(), DriverError> {
3163 let remaining = self.stream_buf_end - self.stream_buf_pos;
3164 if remaining > 0 && self.stream_buf_pos > 0 {
3165 self.stream_buf
3166 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
3167 }
3168 self.stream_buf_pos = 0;
3169 self.stream_buf_end = remaining;
3170
3171 let n = self
3172 .stream
3173 .read(&mut self.stream_buf[remaining..])
3174 .map_err(DriverError::Io)?;
3175 if n == 0 {
3176 return Err(DriverError::Io(std::io::Error::new(
3177 std::io::ErrorKind::UnexpectedEof,
3178 "connection closed",
3179 )));
3180 }
3181 self.stream_buf_end = remaining + n;
3182 Ok(())
3183 }
3184}
3185
3186fn sync_buffered_read_exact(
3189 stream: &mut Stream,
3190 buf: &mut [u8],
3191 pos: &mut usize,
3192 end: &mut usize,
3193 out: &mut [u8],
3194) -> Result<(), DriverError> {
3195 let mut filled = 0;
3196 while filled < out.len() {
3197 let avail = *end - *pos;
3198 if avail > 0 {
3199 let take = avail.min(out.len() - filled);
3200 out[filled..filled + take].copy_from_slice(&buf[*pos..*pos + take]);
3201 *pos += take;
3202 filled += take;
3203 } else {
3204 *pos = 0;
3205 let n = stream.read(buf).map_err(DriverError::Io)?;
3206 if n == 0 {
3207 return Err(DriverError::Io(std::io::Error::new(
3208 std::io::ErrorKind::UnexpectedEof,
3209 "connection closed",
3210 )));
3211 }
3212 *end = n;
3213 }
3214 }
3215 Ok(())
3216}
3217
3218#[inline(always)]
3228pub(crate) fn parse_data_row_into_buf(
3229 data: &[u8],
3230 buf: &mut Vec<u8>,
3231 out: &mut Vec<(usize, i32)>,
3232) -> Result<(), DriverError> {
3233 if data.len() < 2 {
3234 return Err(DriverError::Protocol("DataRow too short".into()));
3235 }
3236
3237 let num_cols = i16::from_be_bytes([data[0], data[1]]);
3238 if num_cols < 0 {
3239 return Err(DriverError::Protocol(
3240 "DataRow: negative column count".into(),
3241 ));
3242 }
3243 let num_cols = num_cols as usize;
3244
3245 let col_data = &data[2..];
3253 let base = buf.len();
3254 buf.extend_from_slice(col_data);
3255
3256 let mut pos: usize = 0;
3258 for _ in 0..num_cols {
3259 if pos + 4 > col_data.len() {
3260 return Err(DriverError::Protocol("DataRow truncated".into()));
3261 }
3262
3263 let col_len = i32::from_be_bytes([
3264 col_data[pos],
3265 col_data[pos + 1],
3266 col_data[pos + 2],
3267 col_data[pos + 3],
3268 ]);
3269 pos += 4;
3270
3271 if col_len < 0 {
3272 out.push((0, -1));
3273 } else {
3274 let len = col_len as usize;
3275 if pos + len > col_data.len() {
3276 return Err(DriverError::Protocol(
3277 "DataRow column data truncated".into(),
3278 ));
3279 }
3280 out.push((base + pos, col_len));
3282 pos += len;
3283 }
3284 }
3285
3286 Ok(())
3287}
3288
3289fn parse_data_row_flat(
3293 data: &[u8],
3294 arena: &mut Arena,
3295 out: &mut Vec<(usize, i32)>,
3296) -> Result<(), DriverError> {
3297 if data.len() < 2 {
3298 return Err(DriverError::Protocol("DataRow too short".into()));
3299 }
3300
3301 let num_cols_raw = i16::from_be_bytes([data[0], data[1]]);
3302 if num_cols_raw < 0 {
3303 return Err(DriverError::Protocol(
3304 "DataRow: negative column count".into(),
3305 ));
3306 }
3307 let num_cols = num_cols_raw as usize;
3308 out.reserve(num_cols);
3309
3310 let col_data = &data[2..];
3313 let base = arena.alloc_copy(col_data);
3314
3315 let mut pos: usize = 0;
3317 for _ in 0..num_cols {
3318 if pos + 4 > col_data.len() {
3319 return Err(DriverError::Protocol("DataRow truncated".into()));
3320 }
3321
3322 let col_len = i32::from_be_bytes([
3323 col_data[pos],
3324 col_data[pos + 1],
3325 col_data[pos + 2],
3326 col_data[pos + 3],
3327 ]);
3328 pos += 4;
3329
3330 if col_len < 0 {
3331 out.push((0, -1));
3332 } else {
3333 let len = col_len as usize;
3334 if pos + len > col_data.len() {
3335 return Err(DriverError::Protocol(
3336 "DataRow column data truncated".into(),
3337 ));
3338 }
3339 out.push((base + pos, col_len));
3341 pos += len;
3342 }
3343 }
3344
3345 Ok(())
3346}
3347
3348#[cfg(test)]
3349#[allow(clippy::approx_constant)]
3350mod tests {
3351 use super::*;
3352 use crate::types::hash_sql;
3353
3354 #[test]
3355 fn sync_config_tcp_no_longer_rejected() {
3356 let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3359 let result = Connection::connect(&config);
3360 assert!(result.is_err());
3361 let err = result.unwrap_err().to_string();
3362 assert!(
3365 !err.contains("Unix domain socket"),
3366 "error should NOT mention UDS requirement: {err}"
3367 );
3368 }
3369
3370 #[test]
3371 fn sync_data_row_parsing() {
3372 let mut arena = Arena::new();
3373 let mut out = Vec::new();
3374
3375 let mut data = Vec::new();
3376 data.extend_from_slice(&2i16.to_be_bytes());
3377 data.extend_from_slice(&4i32.to_be_bytes());
3378 data.extend_from_slice(&42i32.to_be_bytes());
3379 data.extend_from_slice(&(-1i32).to_be_bytes());
3380
3381 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3382 assert_eq!(out.len(), 2);
3383 assert_eq!(out[0].1, 4);
3384 assert_eq!(out[1].1, -1);
3385 }
3386
3387 #[test]
3388 fn sync_data_row_empty() {
3389 let mut arena = Arena::new();
3390 let mut out = Vec::new();
3391 let data = 0i16.to_be_bytes();
3392 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3393 assert_eq!(out.len(), 0);
3394 }
3395
3396 #[test]
3397 fn sync_data_row_too_short() {
3398 let mut arena = Arena::new();
3399 let mut out = Vec::new();
3400 let data = vec![0u8];
3401 assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3402 }
3403
3404 #[test]
3405 fn sync_data_row_negative_col_count() {
3406 let mut arena = Arena::new();
3407 let mut out = Vec::new();
3408 let data = (-1i16).to_be_bytes();
3409 assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3410 }
3411
3412 #[test]
3413 fn sync_data_row_truncated() {
3414 let mut arena = Arena::new();
3415 let mut out = Vec::new();
3416 let mut data = Vec::new();
3417 data.extend_from_slice(&2i16.to_be_bytes());
3418 data.extend_from_slice(&4i32.to_be_bytes());
3419 data.extend_from_slice(&42i32.to_be_bytes());
3420 assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3422 }
3423
3424 #[test]
3425 fn sync_data_row_col_data_truncated() {
3426 let mut arena = Arena::new();
3427 let mut out = Vec::new();
3428 let mut data = Vec::new();
3429 data.extend_from_slice(&1i16.to_be_bytes());
3430 data.extend_from_slice(&100i32.to_be_bytes()); data.push(0); assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3433 }
3434
3435 #[test]
3438 fn sync_connect_tcp_unreachable_port() {
3439 let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3442 let result = Connection::connect(&config);
3443 assert!(result.is_err());
3444 let err = result.unwrap_err().to_string();
3445 assert!(
3446 !err.contains("Unix domain socket"),
3447 "error should NOT mention UDS: {err}"
3448 );
3449 }
3450
3451 #[test]
3452 fn sync_connect_ip_address_attempts_tcp() {
3453 let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3456 let result = Connection::connect(&config);
3457 assert!(result.is_err());
3458 }
3459
3460 #[test]
3463 fn sync_data_row_all_null() {
3464 let mut arena = Arena::new();
3465 let mut out = Vec::new();
3466 let mut data = Vec::new();
3467 data.extend_from_slice(&3i16.to_be_bytes());
3468 data.extend_from_slice(&(-1i32).to_be_bytes());
3469 data.extend_from_slice(&(-1i32).to_be_bytes());
3470 data.extend_from_slice(&(-1i32).to_be_bytes());
3471 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3472 assert_eq!(out.len(), 3);
3473 for (_, len) in &out {
3474 assert_eq!(*len, -1);
3475 }
3476 }
3477
3478 #[test]
3479 fn sync_data_row_long_text() {
3480 let mut arena = Arena::new();
3481 let mut out = Vec::new();
3482 let long_text = "a".repeat(2048);
3483 let text_bytes = long_text.as_bytes();
3484 let mut data = Vec::new();
3485 data.extend_from_slice(&1i16.to_be_bytes());
3486 data.extend_from_slice(&(text_bytes.len() as i32).to_be_bytes());
3487 data.extend_from_slice(text_bytes);
3488 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3489 assert_eq!(out.len(), 1);
3490 assert_eq!(out[0].1, text_bytes.len() as i32);
3491 let stored = arena.get(out[0].0, out[0].1 as usize);
3492 assert_eq!(stored, text_bytes);
3493 }
3494
3495 #[test]
3496 fn sync_data_row_empty_text() {
3497 let mut arena = Arena::new();
3498 let mut out = Vec::new();
3499 let mut data = Vec::new();
3500 data.extend_from_slice(&1i16.to_be_bytes());
3501 data.extend_from_slice(&0i32.to_be_bytes()); parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3503 assert_eq!(out.len(), 1);
3504 assert_eq!(out[0].1, 0); }
3506
3507 #[test]
3508 fn sync_data_row_17_columns_exceeds_smallvec() {
3509 let mut arena = Arena::new();
3510 let mut out = Vec::new();
3511 let mut data = Vec::new();
3512 let num_cols: i16 = 20;
3513 data.extend_from_slice(&num_cols.to_be_bytes());
3514 for i in 0..num_cols {
3515 let val = (i as i32).to_be_bytes();
3516 data.extend_from_slice(&4i32.to_be_bytes());
3517 data.extend_from_slice(&val);
3518 }
3519 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3520 assert_eq!(out.len(), 20);
3521 for (idx, (offset, len)) in out.iter().enumerate() {
3522 assert_eq!(*len, 4);
3523 let stored = arena.get(*offset, 4);
3524 let val = i32::from_be_bytes([stored[0], stored[1], stored[2], stored[3]]);
3525 assert_eq!(val, idx as i32);
3526 }
3527 }
3528
3529 #[test]
3530 fn sync_data_row_mixed_null_and_data() {
3531 let mut arena = Arena::new();
3532 let mut out = Vec::new();
3533 let mut data = Vec::new();
3534 data.extend_from_slice(&5i16.to_be_bytes());
3535 data.extend_from_slice(&(-1i32).to_be_bytes());
3537 data.extend_from_slice(&4i32.to_be_bytes());
3539 data.extend_from_slice(&42i32.to_be_bytes());
3540 data.extend_from_slice(&(-1i32).to_be_bytes());
3542 data.extend_from_slice(&(-1i32).to_be_bytes());
3544 data.extend_from_slice(&5i32.to_be_bytes());
3546 data.extend_from_slice(b"hello");
3547
3548 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3549 assert_eq!(out.len(), 5);
3550 assert_eq!(out[0].1, -1);
3551 assert_eq!(out[1].1, 4);
3552 assert_eq!(out[2].1, -1);
3553 assert_eq!(out[3].1, -1);
3554 assert_eq!(out[4].1, 5);
3555 let stored = arena.get(out[4].0, 5);
3556 assert_eq!(stored, b"hello");
3557 }
3558
3559 #[test]
3562 #[ignore] fn sync_connect_uds_if_pg_available() {
3564 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3565 let result = Connection::connect(&config);
3566 if let Ok(conn) = result {
3568 assert!(conn.pid() != 0, "pid should be nonzero");
3569 assert!(conn.is_idle(), "should start idle");
3570 assert!(!conn.is_in_transaction(), "should not be in tx");
3571 assert!(
3572 !conn.is_in_failed_transaction(),
3573 "should not be in failed tx"
3574 );
3575 assert_eq!(conn.stmt_cache_len(), 0, "cache should be empty");
3576 let _ = conn.close();
3577 }
3578 }
3579
3580 #[test]
3581 #[ignore] fn sync_simple_query_if_pg_available() {
3583 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3584 let mut conn = Connection::connect(&config).unwrap();
3585 conn.simple_query("SELECT 1").unwrap();
3586 assert!(conn.is_idle());
3587 let _ = conn.close();
3588 }
3589
3590 #[test]
3591 #[ignore] fn sync_query_with_params_if_pg_available() {
3593 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3594 let mut conn = Connection::connect(&config).unwrap();
3595 let sql = "SELECT $1::int4 + $2::int4 AS sum";
3596 let hash = hash_sql(sql);
3597 let a: i32 = 10;
3598 let b: i32 = 20;
3599 let result = conn
3600 .query(
3601 sql,
3602 hash,
3603 &[&a as &(dyn Encode + Sync), &b as &(dyn Encode + Sync)],
3604 )
3605 .unwrap();
3606 assert_eq!(result.len(), 1);
3607 let _ = conn.close();
3608 }
3609
3610 #[test]
3611 #[ignore] fn sync_execute_if_pg_available() {
3613 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3614 let mut conn = Connection::connect(&config).unwrap();
3615 conn.simple_query("CREATE TEMP TABLE _sync_test (id int)")
3616 .unwrap();
3617 let sql = "INSERT INTO _sync_test VALUES ($1::int4)";
3618 let hash = hash_sql(sql);
3619 let val: i32 = 42;
3620 let affected = conn
3621 .execute(sql, hash, &[&val as &(dyn Encode + Sync)])
3622 .unwrap();
3623 assert_eq!(affected, 1);
3624 let _ = conn.close();
3625 }
3626
3627 #[test]
3628 #[ignore] fn sync_for_each_zero_rows_if_pg_available() {
3630 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3631 let mut conn = Connection::connect(&config).unwrap();
3632 conn.simple_query("CREATE TEMP TABLE _sync_fe0 (id int)")
3633 .unwrap();
3634 let sql = "SELECT id FROM _sync_fe0";
3635 let hash = hash_sql(sql);
3636 let mut count = 0u32;
3637 conn.for_each(sql, hash, &[], |_row| {
3638 count += 1;
3639 Ok(())
3640 })
3641 .unwrap();
3642 assert_eq!(count, 0);
3643 let _ = conn.close();
3644 }
3645
3646 #[test]
3647 #[ignore] fn sync_for_each_multiple_rows_if_pg_available() {
3649 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3650 let mut conn = Connection::connect(&config).unwrap();
3651 let sql = "SELECT generate_series(1, 5)";
3652 let hash = hash_sql(sql);
3653 let mut count = 0u32;
3654 conn.for_each(sql, hash, &[], |_row| {
3655 count += 1;
3656 Ok(())
3657 })
3658 .unwrap();
3659 assert_eq!(count, 5);
3660 let _ = conn.close();
3661 }
3662
3663 #[test]
3664 #[ignore] fn sync_prepare_only_if_pg_available() {
3666 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3667 let mut conn = Connection::connect(&config).unwrap();
3668 let sql = "SELECT 1";
3669 let hash = hash_sql(sql);
3670 conn.prepare_only(sql, hash).unwrap();
3671 assert_eq!(conn.stmt_cache_len(), 1);
3672 conn.prepare_only(sql, hash).unwrap();
3674 assert_eq!(conn.stmt_cache_len(), 1);
3675 let _ = conn.close();
3676 }
3677
3678 #[test]
3679 #[ignore] fn sync_simple_query_rows_if_pg_available() {
3681 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3682 let mut conn = Connection::connect(&config).unwrap();
3683 let rows = conn.simple_query_rows("SELECT 42 AS n").unwrap();
3684 assert!(!rows.is_empty());
3685 let _ = conn.close();
3686 }
3687
3688 #[test]
3689 #[ignore] fn sync_stmt_cache_hit_miss_if_pg_available() {
3691 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3692 let mut conn = Connection::connect(&config).unwrap();
3693 let sql1 = "SELECT 1";
3694 let hash1 = hash_sql(sql1);
3695 conn.query(sql1, hash1, &[]).unwrap();
3696 assert_eq!(conn.stmt_cache_len(), 1);
3697 conn.query(sql1, hash1, &[]).unwrap();
3699 assert_eq!(conn.stmt_cache_len(), 1);
3700 let sql2 = "SELECT 2";
3702 let hash2 = hash_sql(sql2);
3703 conn.query(sql2, hash2, &[]).unwrap();
3704 assert_eq!(conn.stmt_cache_len(), 2);
3705 let _ = conn.close();
3706 }
3707
3708 #[test]
3709 #[ignore] fn sync_invalid_sql_error_if_pg_available() {
3711 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3712 let mut conn = Connection::connect(&config).unwrap();
3713 let sql = "SELECTTTT INVALID GARBAGE";
3714 let hash = hash_sql(sql);
3715 let result = conn.query(sql, hash, &[]);
3716 assert!(result.is_err());
3717 assert!(conn.is_idle());
3719 let _ = conn.close();
3720 }
3721
3722 #[test]
3723 #[ignore] fn sync_tx_state_transitions_if_pg_available() {
3725 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3726 let mut conn = Connection::connect(&config).unwrap();
3727 assert!(conn.is_idle());
3728 assert!(!conn.is_in_transaction());
3729 conn.simple_query("BEGIN").unwrap();
3730 assert!(conn.is_in_transaction());
3731 assert!(!conn.is_idle());
3732 conn.simple_query("COMMIT").unwrap();
3733 assert!(conn.is_idle());
3734 assert!(!conn.is_in_transaction());
3735 let _ = conn.close();
3736 }
3737
3738 #[test]
3739 #[ignore] fn sync_lru_cache_eviction_if_pg_available() {
3741 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3742 let mut conn = Connection::connect(&config).unwrap();
3743 conn.set_max_stmt_cache_size(3);
3744 for i in 0..5 {
3745 let sql = format!("SELECT {}", i);
3746 let hash = hash_sql(&sql);
3747 conn.query(&sql, hash, &[]).unwrap();
3748 }
3749 assert!(
3751 conn.stmt_cache_len() <= 3,
3752 "cache should be capped at 3, got {}",
3753 conn.stmt_cache_len()
3754 );
3755 let _ = conn.close();
3756 }
3757
3758 #[test]
3759 #[ignore] fn sync_for_each_raw_if_pg_available() {
3761 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3762 let mut conn = Connection::connect(&config).unwrap();
3763 let sql = "SELECT generate_series(1, 3)";
3764 let hash = hash_sql(sql);
3765 let mut raw_count = 0u32;
3766 conn.for_each_raw(sql, hash, &[], |_raw_data| {
3767 raw_count += 1;
3768 Ok(())
3769 })
3770 .unwrap();
3771 assert_eq!(raw_count, 3);
3772 let _ = conn.close();
3773 }
3774
3775 #[test]
3776 #[ignore] fn sync_query_null_params_if_pg_available() {
3778 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3779 let mut conn = Connection::connect(&config).unwrap();
3780 let sql = "SELECT $1::int4 IS NULL AS is_null";
3781 let hash = hash_sql(sql);
3782 let val: Option<i32> = None;
3783 let _result = conn
3784 .query(sql, hash, &[&val as &(dyn Encode + Sync)])
3785 .unwrap();
3786 let _ = conn.close();
3787 }
3788
3789 #[test]
3790 #[ignore] fn sync_query_various_param_types_if_pg_available() {
3792 let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
3793 let mut conn = Connection::connect(&config).unwrap();
3794 let sql = "SELECT $1::int4, $2::int8, $3::text, $4::bool, $5::float8";
3795 let hash = hash_sql(sql);
3796 let p1: i32 = 42;
3797 let p2: i64 = 9999999;
3798 let p3: &str = "hello";
3799 let p4: bool = true;
3800 let p5: f64 = 3.14;
3801 let result = conn
3802 .query(
3803 sql,
3804 hash,
3805 &[
3806 &p1 as &(dyn Encode + Sync),
3807 &p2 as &(dyn Encode + Sync),
3808 &p3 as &(dyn Encode + Sync),
3809 &p4 as &(dyn Encode + Sync),
3810 &p5 as &(dyn Encode + Sync),
3811 ],
3812 )
3813 .unwrap();
3814 assert_eq!(result.len(), 1);
3815 let _ = conn.close();
3816 }
3817
3818 #[test]
3821 fn sync_shrink_threshold_values() {
3822 let shrink = 64 * 1024usize;
3831 let initial = 8192usize;
3832 assert!(
3833 shrink > initial,
3834 "shrink threshold must exceed initial size"
3835 );
3836 }
3837
3838 #[test]
3841 fn sync_connection_debug_format() {
3842 let fmt_str = format!(
3846 "Connection {{ pid: {}, tx_status: '{}', stmt_cache_len: {} }}",
3847 0, 'I', 0
3848 );
3849 assert!(fmt_str.contains("Connection"));
3850 assert!(fmt_str.contains("pid"));
3851 assert!(fmt_str.contains("tx_status"));
3852 }
3853
3854 #[test]
3857 fn sync_connect_sslmode_require_without_tls_feature() {
3858 let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3862 config.ssl = SslMode::Require;
3863 let result = Connection::connect(&config);
3864 assert!(result.is_err());
3865 }
3870
3871 #[test]
3872 fn sync_connect_sslmode_disable_attempts_tcp() {
3873 let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3874 config.ssl = SslMode::Disable;
3875 let result = Connection::connect(&config);
3876 assert!(result.is_err());
3877 assert!(matches!(result.unwrap_err(), DriverError::Io(_)));
3879 }
3880
3881 #[test]
3882 fn sync_connect_sslmode_prefer_attempts_tcp() {
3883 let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3884 config.ssl = SslMode::Prefer;
3885 let result = Connection::connect(&config);
3886 assert!(result.is_err());
3887 }
3888
/// Streams `SELECT generate_series(1, 10)` chunk by chunk and checks that the
/// offsets collected across all chunks add up to 10.
///
/// Ignored by default (run with `--ignored`): it connects using
/// `postgres://postgres@localhost/postgres?host=/tmp` and unwraps the result,
/// so it needs a reachable server.
/// NOTE(review): the literal `3` passed to `query_streaming_start` and
/// `streaming_send_execute` is presumably the per-chunk row limit — confirm.
#[test]
#[ignore]
fn sync_streaming_basic_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();
    assert!(!conn.is_streaming());

    let sql = "SELECT generate_series(1, 10)";
    let (cols, _) = conn
        .query_streaming_start(sql, hash_sql(sql), &[], 3)
        .unwrap();
    assert!(!cols.is_empty());
    assert!(conn.is_streaming());

    let mut arena = Arena::new();
    let mut offsets = Vec::new();
    let mut row_count = 0;

    loop {
        let more_pending = conn.streaming_next_chunk(&mut arena, &mut offsets).unwrap();
        // Count this chunk before deciding whether to request another one.
        row_count += offsets.len();
        if more_pending {
            conn.streaming_send_execute(3).unwrap();
        } else {
            break;
        }
    }

    assert_eq!(row_count, 10);
    // Once the final chunk has been consumed the connection must leave streaming mode.
    assert!(!conn.is_streaming());
    let _ = conn.close();
}
3923
/// Runs `prepare_describe` on `SELECT $1::int4 + $2::int4 AS sum` and checks
/// the description: exactly one column, named `sum`, and two parameter OIDs.
///
/// Ignored by default (run with `--ignored`): it unwraps a connection to
/// `postgres://postgres@localhost/postgres?host=/tmp`, so it needs a reachable server.
#[test]
#[ignore]
fn sync_prepare_describe_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    let sql = "SELECT $1::int4 + $2::int4 AS sum";
    let described = conn.prepare_describe(sql).unwrap();

    assert_eq!(described.columns.len(), 1);
    assert_eq!(&*described.columns[0].name, "sum");
    assert_eq!(described.param_oids.len(), 2);

    // Closing is best-effort; its result is deliberately discarded.
    let _ = conn.close();
}
3940
/// Issues `LISTEN test_chan` followed by `NOTIFY test_chan, 'hello'` on the
/// same connection, then expects `wait_for_notification` to return that
/// channel and payload.
///
/// Ignored by default (run with `--ignored`): it unwraps a connection to
/// `postgres://postgres@localhost/postgres?host=/tmp`, so it needs a reachable server.
/// NOTE(review): the 5-second read timeout is presumably there so a missing
/// notification fails the test instead of blocking forever — confirm.
#[test]
#[ignore]
fn sync_wait_for_notification_if_pg_available() {
    let wait_limit = std::time::Duration::from_secs(5);

    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    conn.simple_query("LISTEN test_chan").unwrap();
    conn.simple_query("NOTIFY test_chan, 'hello'").unwrap();

    // The timeout is applied only after both simple queries have completed.
    conn.set_read_timeout(Some(wait_limit)).unwrap();

    let (channel, payload) = conn.wait_for_notification().unwrap();
    assert_eq!(channel, "test_chan");
    assert_eq!(payload, "hello");

    let _ = conn.close();
}
3961
/// Smoke test for `cancel`: invokes it on a connection that has issued no
/// queries and ignores whatever it returns, then closes the connection (also
/// ignoring the result). The test passes as long as nothing panics.
///
/// Ignored by default (run with `--ignored`): it unwraps a connection to
/// `postgres://postgres@localhost/postgres?host=/tmp`, so it needs a reachable server.
#[test]
#[ignore]
fn sync_cancel_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let conn = Connection::connect(&config).unwrap();

    // Success and failure are both acceptable outcomes here.
    let _ = conn.cancel();
    let _ = conn.close();
}
3976
/// Checks that a freshly opened connection exposes server parameters:
/// `server_params` must be non-empty and `parameter("server_encoding")` must
/// return a value.
///
/// Ignored by default (run with `--ignored`): it unwraps a connection to
/// `postgres://postgres@localhost/postgres?host=/tmp`, so it needs a reachable server.
#[test]
#[ignore]
fn sync_server_params_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let conn = Connection::connect(&config).unwrap();

    let has_any_params = !conn.server_params().is_empty();
    assert!(
        has_any_params,
        "server should send parameters during startup"
    );

    let encoding = conn.parameter("server_encoding");
    assert!(encoding.is_some(), "server_encoding should be present");

    let _ = conn.close();
}
3996
/// Verifies that `set_read_timeout` succeeds both when installing a 10-second
/// timeout and when clearing it again with `None`.
///
/// Ignored by default (run with `--ignored`): it unwraps a connection to
/// `postgres://postgres@localhost/postgres?host=/tmp`, so it needs a reachable server.
#[test]
#[ignore]
fn sync_set_read_timeout_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let conn = Connection::connect(&config).unwrap();

    // First install a timeout, then remove it; both calls must succeed.
    for timeout in [Some(std::time::Duration::from_secs(10)), None] {
        conn.set_read_timeout(timeout).unwrap();
    }

    let _ = conn.close();
}
4010}