1use std::io::{Read, Write};
15use std::sync::Arc;
16
17use crate::arena::Arena;
18use crate::auth;
19use crate::codec::Encode;
20use crate::proto::{self, BackendMessage};
21use crate::stmt_cache::{build_bind_template, make_stmt_name, StmtCache, StmtInfo};
22use crate::sync_io::Stream;
23use crate::types::{
24 ColumnDesc, Config, Notification, PgDataRow, PrepareResult, QueryResult, SimpleRow, SslMode,
25 StartupAction, StatementCacheMode,
26};
27use crate::DriverError;
28
29use std::cell::RefCell;
37
// Per-thread stack of spare response byte buffers. `acquire_resp_buf` pops
// from it and `release_resp_buf` pushes back (up to a small fixed cap), so a
// buffer's allocation can be reused by later calls on the same thread.
thread_local! {
    static RESP_BUF_POOL: RefCell<Vec<Vec<u8>>> = const { RefCell::new(Vec::new()) };
}
41
42pub(crate) fn acquire_resp_buf() -> Vec<u8> {
43 RESP_BUF_POOL
44 .with(|pool| pool.borrow_mut().pop())
45 .unwrap_or_default()
46}
47
48pub fn release_resp_buf(buf: Vec<u8>) {
50 RESP_BUF_POOL.with(|pool| {
51 let mut pool = pool.borrow_mut();
52 if pool.len() < 4 {
53 pool.push(buf);
54 }
55 });
56}
57
// Per-thread stack of spare column-offset vectors. `acquire_col_offsets` pops
// from it and `release_col_offsets` pushes back (up to a small fixed cap).
// NOTE(review): each entry looks like (byte offset, length) for one column
// value — confirm against `parse_data_row_into_buf`.
thread_local! {
    static COL_OFFSETS_POOL: RefCell<Vec<Vec<(usize, i32)>>> = const { RefCell::new(Vec::new()) };
}
61
62pub(crate) fn acquire_col_offsets() -> Vec<(usize, i32)> {
63 COL_OFFSETS_POOL
64 .with(|pool| pool.borrow_mut().pop())
65 .unwrap_or_default()
66}
67
68pub fn release_col_offsets(buf: Vec<(usize, i32)>) {
69 COL_OFFSETS_POOL.with(|pool| {
70 let mut pool = pool.borrow_mut();
71 if pool.len() < 4 {
72 pool.push(buf);
73 }
74 });
75}
76
/// A single synchronous connection to the server, together with its I/O
/// buffers, prepared-statement cache and session metadata.
pub struct Connection {
    // Index of the next unparsed byte in `stream_buf`.
    stream_buf_pos: usize,
    // End (exclusive) of the valid bytes in `stream_buf`;
    // `stream_buf_end - stream_buf_pos` is the amount of buffered input.
    stream_buf_end: usize,
    // Monotonic counter bumped whenever a statement is prepared or used;
    // stored as `StmtInfo::last_used`.
    query_counter: u64,
    // Status byte from the most recent ReadyForQuery message (starts as b'I').
    tx_status: u8,
    // Initialised to `false` on connect.
    // NOTE(review): usage is outside this part of the file — confirm meaning.
    streaming_active: bool,
    // Backend process id received in BackendKeyData.
    pid: i32,
    // Backend secret key received in BackendKeyData.
    secret: i32,
    // Initialised to 256 on connect.
    // NOTE(review): presumably the cap enforced by `cache_stmt` — confirm.
    max_stmt_cache_size: usize,
    // Copied from the connect configuration; `Disabled` makes the execute and
    // prepare paths bypass named prepared statements.
    statement_cache_mode: StatementCacheMode,
    // Underlying transport (Unix socket, plain TCP, or TLS).
    stream: Stream,
    // Staging buffer for outgoing protocol messages.
    write_buf: Vec<u8>,
    // Fixed-size (64 KiB at connect) buffer of raw incoming bytes.
    stream_buf: Vec<u8>,
    // Prepared-statement cache, looked up by SQL hash plus SQL text.
    stmts: StmtCache,
    // Payload of the message most recently read via `read_message_buffered`.
    read_buf: Vec<u8>,
    // Server parameters (name, value) collected from ParameterStatus messages
    // during startup.
    params: Vec<(Box<str>, Box<str>)>,
    // Set to the connect time initially.
    // NOTE(review): later updates are outside this part of the file.
    last_used: std::time::Instant,
    // Time at which the connection object was created.
    created_at: std::time::Instant,
    // Notifications received while processing other traffic.
    // NOTE(review): presumably filled by `buffer_notification` — confirm.
    pending_notifications: Vec<Notification>,
    // Configuration this connection was opened with.
    connect_config: Arc<Config>,
    // Hash of the server certificate reported by the TLS upgrade, if any;
    // enables the SCRAM-SHA-256-PLUS mechanism in `handle_scram`.
    tls_server_cert_hash: Option<[u8; 32]>,
}
134
135impl std::fmt::Debug for Connection {
136 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137 f.debug_struct("Connection")
138 .field("pid", &self.pid)
139 .field("tx_status", &(self.tx_status as char))
140 .field("stmt_cache_len", &self.stmts.len())
141 .finish()
142 }
143}
144
145impl Connection {
146 pub fn connect(config: &Config) -> Result<Self, DriverError> {
158 Self::connect_arc(Arc::new(config.clone()))
159 }
160
    /// Opens a connection using a shared configuration.
    ///
    /// Steps: validate the configuration, open the transport (Unix socket when
    /// `host_is_uds()`, otherwise TCP with an optional TLS upgrade according to
    /// `config.ssl`), run the startup/authentication exchange, then check the
    /// server parameters.
    ///
    /// # Errors
    /// - Any error from `config.validate()`.
    /// - `DriverError::Io` when the socket cannot be opened or configured.
    /// - `DriverError::Protocol` when Unix sockets are requested on a
    ///   non-Unix platform, or `sslmode=require` is used without the `tls`
    ///   feature.
    /// - The TLS error when `SslMode::Require` is set and the upgrade fails.
    /// - Any error from `startup` or `validate_server_params`.
    pub fn connect_arc(config: Arc<Config>) -> Result<Self, DriverError> {
        config.validate()?;

        // Only assigned on the TLS path, hence `unused_mut` when the `tls`
        // feature is compiled out.
        #[allow(unused_mut)]
        let mut tls_cert_hash: Option<[u8; 32]> = None;

        let stream = if config.host_is_uds() {
            #[cfg(unix)]
            {
                let path = config.uds_path();
                let unix =
                    std::os::unix::net::UnixStream::connect(&path).map_err(DriverError::Io)?;
                Stream::Unix(unix)
            }
            #[cfg(not(unix))]
            {
                return Err(DriverError::Protocol(
                    "Unix domain sockets are not supported on this platform".into(),
                ));
            }
        } else {
            let addr = format!("{}:{}", config.host, config.port);
            let tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;

            match config.ssl {
                SslMode::Disable => {
                    tcp.set_nodelay(true).map_err(DriverError::Io)?;
                    let stream = Stream::Tcp(tcp);
                    stream.set_keepalive()?;
                    stream
                }
                SslMode::Prefer | SslMode::Require => {
                    #[cfg(feature = "tls")]
                    {
                        match crate::tls_sync::try_upgrade(
                            tcp,
                            &config,
                            config.ssl == SslMode::Require,
                        ) {
                            Ok(result) => {
                                // Kept for SCRAM channel binding in `handle_scram`.
                                tls_cert_hash = result.server_cert_hash;
                                let stream = Stream::Tls(Box::new(result.stream));
                                stream.set_nodelay()?;
                                stream.set_keepalive()?;
                                stream
                            }
                            Err(e) => {
                                if config.ssl == SslMode::Require {
                                    return Err(e);
                                }
                                // `Prefer`: the first socket was consumed by the
                                // failed upgrade, so reconnect in plaintext.
                                let tcp =
                                    std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
                                tcp.set_nodelay(true).map_err(DriverError::Io)?;
                                let stream = Stream::Tcp(tcp);
                                stream.set_keepalive()?;
                                stream
                            }
                        }
                    }
                    #[cfg(not(feature = "tls"))]
                    {
                        if config.ssl == SslMode::Require {
                            return Err(DriverError::Protocol(
                                "sslmode=require but bsql was compiled without the 'tls' feature"
                                    .into(),
                            ));
                        }
                        // `Prefer` without TLS support: use the plain socket.
                        tcp.set_nodelay(true).map_err(DriverError::Io)?;
                        let stream = Stream::Tcp(tcp);
                        stream.set_keepalive()?;
                        stream
                    }
                }
            }
        };

        let now = std::time::Instant::now();
        let mut conn = Self {
            stream_buf_pos: 0,
            stream_buf_end: 0,
            query_counter: 0,
            // Placeholder until the first ReadyForQuery arrives.
            tx_status: b'I',
            streaming_active: false,
            // Filled in from BackendKeyData during startup.
            pid: 0,
            secret: 0,
            max_stmt_cache_size: 256,
            statement_cache_mode: config.statement_cache_mode,
            stream,
            write_buf: Vec::with_capacity(4096),
            // Allocated at full length up front; pos/end track the valid window.
            stream_buf: vec![0u8; 65536],
            stmts: StmtCache::default(),
            read_buf: Vec::with_capacity(8192),
            params: Vec::new(),
            last_used: now,
            created_at: now,
            pending_notifications: Vec::new(),
            connect_config: config.clone(),
            tls_server_cert_hash: tls_cert_hash,
        };

        conn.startup(&config)?;
        conn.validate_server_params()?;

        Ok(conn)
    }
278
279 fn startup(&mut self, config: &Config) -> Result<(), DriverError> {
282 self.write_buf.clear();
283 let timeout_str; let mut extra_params: smallvec::SmallVec<[(&str, &str); 2]> = smallvec::SmallVec::new();
287 if config.statement_timeout_secs > 0 {
288 timeout_str = format!("{}s", config.statement_timeout_secs);
289 extra_params.push(("statement_timeout", &timeout_str));
290 }
291 proto::write_startup(
292 &mut self.write_buf,
293 &config.user,
294 &config.database,
295 &extra_params,
296 );
297 self.flush_write()?;
298
299 loop {
300 let action = self.read_startup_action()?;
301 match action {
302 StartupAction::AuthOk => {}
303 StartupAction::AuthCleartext => {
304 self.write_buf.clear();
305 let mut pw = config.password.as_bytes().to_vec();
306 pw.push(0);
307 proto::write_password(&mut self.write_buf, &pw);
308 self.flush_write()?;
309 }
310 StartupAction::AuthMd5(salt) => {
311 self.write_buf.clear();
312 let hash = auth::md5_password(&config.user, &config.password, &salt);
313 proto::write_password(&mut self.write_buf, &hash);
314 self.flush_write()?;
315 }
316 StartupAction::AuthSasl(mechanisms_data) => {
317 self.handle_scram(config, &mechanisms_data)?;
318 }
319 StartupAction::ParameterStatus(name, value) => {
320 if let Some(entry) = self.params.iter_mut().find(|(k, _)| *k == name) {
321 entry.1 = value;
322 } else {
323 self.params.push((name, value));
324 }
325 }
326 StartupAction::BackendKeyData(pid, secret) => {
327 self.pid = pid;
328 self.secret = secret;
329 }
330 StartupAction::ReadyForQuery(status) => {
331 self.tx_status = status;
332 return Ok(());
333 }
334 StartupAction::Error(msg) => {
335 return Err(DriverError::Auth(msg));
336 }
337 StartupAction::Notice => {}
338 }
339 }
340 }
341
342 fn read_startup_action(&mut self) -> Result<StartupAction, DriverError> {
343 let (msg_type, _) = self.read_message_buffered()?;
344 let payload = &self.read_buf;
345 let msg = proto::parse_backend_message(msg_type, payload)?;
346 match msg {
347 BackendMessage::AuthOk => Ok(StartupAction::AuthOk),
348 BackendMessage::AuthCleartext => Ok(StartupAction::AuthCleartext),
349 BackendMessage::AuthMd5 { salt } => Ok(StartupAction::AuthMd5(salt)),
350 BackendMessage::AuthSasl { mechanisms } => {
351 Ok(StartupAction::AuthSasl(mechanisms.to_vec()))
352 }
353 BackendMessage::ParameterStatus { name, value } => {
354 Ok(StartupAction::ParameterStatus(name.into(), value.into()))
355 }
356 BackendMessage::BackendKeyData { pid, secret } => {
357 Ok(StartupAction::BackendKeyData(pid, secret))
358 }
359 BackendMessage::ReadyForQuery { status } => Ok(StartupAction::ReadyForQuery(status)),
360 BackendMessage::ErrorResponse { data } => {
361 let fields = proto::parse_error_response(data);
362 Ok(StartupAction::Error(fields.to_string()))
363 }
364 BackendMessage::NoticeResponse { .. } => Ok(StartupAction::Notice),
365 other => Err(DriverError::Protocol(format!(
366 "unexpected message during startup: {other:?}"
367 ))),
368 }
369 }
370
371 fn handle_scram(&mut self, config: &Config, mechanisms_data: &[u8]) -> Result<(), DriverError> {
372 let mechs = auth::parse_sasl_mechanisms(mechanisms_data);
373
374 let use_plus = self.tls_server_cert_hash.is_some() && mechs.contains(&"SCRAM-SHA-256-PLUS");
377 let mechanism = if use_plus {
378 "SCRAM-SHA-256-PLUS"
379 } else {
380 "SCRAM-SHA-256"
381 };
382
383 if !mechs.contains(&mechanism) && !mechs.contains(&"SCRAM-SHA-256") {
384 return Err(DriverError::Auth(format!(
385 "server requires unsupported SASL mechanism(s): {mechs:?}"
386 )));
387 }
388
389 let cert_hash = if use_plus {
390 self.tls_server_cert_hash.as_ref()
391 } else {
392 None
393 };
394 let mut scram = auth::ScramClient::new(&config.user, &config.password, cert_hash)?;
395
396 let client_first = scram.client_first_message();
398 self.write_buf.clear();
399 proto::write_sasl_initial(&mut self.write_buf, mechanism, &client_first);
400 self.flush_write()?;
401
402 let (msg_type, _) = self.read_message_buffered()?;
404 let server_first = {
405 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
406 match msg {
407 BackendMessage::AuthSaslContinue { data } => data.to_vec(),
408 BackendMessage::ErrorResponse { data } => {
409 let fields = proto::parse_error_response(data);
410 return Err(DriverError::Auth(fields.to_string()));
411 }
412 other => {
413 return Err(DriverError::Protocol(format!(
414 "expected AuthSaslContinue, got: {other:?}"
415 )));
416 }
417 }
418 };
419
420 scram.process_server_first(&server_first)?;
421
422 let client_final = scram.client_final_message()?;
424 self.write_buf.clear();
425 proto::write_sasl_response(&mut self.write_buf, &client_final);
426 self.flush_write()?;
427
428 let (msg_type, _) = self.read_message_buffered()?;
430 {
431 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
432 match msg {
433 BackendMessage::AuthSaslFinal { data } => {
434 let data_owned = data.to_vec();
435 scram.verify_server_final(&data_owned)?;
436 }
437 BackendMessage::ErrorResponse { data } => {
438 let fields = proto::parse_error_response(data);
439 return Err(DriverError::Auth(fields.to_string()));
440 }
441 other => {
442 return Err(DriverError::Protocol(format!(
443 "expected AuthSaslFinal, got: {other:?}"
444 )));
445 }
446 }
447 }
448
449 let (msg_type, _) = self.read_message_buffered()?;
451 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
452 match msg {
453 BackendMessage::AuthOk => Ok(()),
454 BackendMessage::ErrorResponse { data } => {
455 let fields = proto::parse_error_response(data);
456 Err(DriverError::Auth(fields.to_string()))
457 }
458 other => Err(DriverError::Protocol(format!(
459 "expected AuthOk after SCRAM, got: {other:?}"
460 ))),
461 }
462 }
463
464 fn validate_server_params(&self) -> Result<(), DriverError> {
465 if let Some(encoding) = self.parameter("server_encoding") {
466 if !encoding.eq_ignore_ascii_case("UTF8") && !encoding.eq_ignore_ascii_case("UTF-8") {
467 return Err(DriverError::Protocol(format!(
468 "server_encoding is '{encoding}', but bsql requires UTF-8."
469 )));
470 }
471 }
472 if let Some(encoding) = self.parameter("client_encoding") {
473 if !encoding.eq_ignore_ascii_case("UTF8") && !encoding.eq_ignore_ascii_case("UTF-8") {
474 return Err(DriverError::Protocol(format!(
475 "client_encoding is '{encoding}', but bsql requires UTF-8."
476 )));
477 }
478 }
479 if let Some(idt) = self.parameter("integer_datetimes") {
480 if idt != "on" {
481 return Err(DriverError::Protocol(format!(
482 "integer_datetimes is '{idt}', but bsql requires 'on'."
483 )));
484 }
485 }
486 Ok(())
487 }
488
489 pub fn prepare_only(&mut self, sql: &str, sql_hash: u64) -> Result<(), DriverError> {
495 if self.statement_cache_mode == StatementCacheMode::Disabled {
496 return Ok(()); }
498 if self.stmts.contains_key(&sql_hash, sql) {
499 return Ok(());
500 }
501 let name = make_stmt_name(sql_hash);
502 self.write_buf.clear();
503 proto::write_parse(&mut self.write_buf, &name, sql, &[]);
504 proto::write_describe(&mut self.write_buf, b'S', &name);
505 proto::write_sync(&mut self.write_buf);
506 self.flush_write()?;
507
508 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
509 let columns = self.read_column_description()?;
510 self.expect_ready()?;
511
512 self.query_counter += 1;
513 self.cache_stmt(
514 sql_hash,
515 StmtInfo {
516 name,
517 sql: sql.into(),
518 columns,
519 last_used: self.query_counter,
520 bind_template: None,
521 },
522 );
523 Ok(())
524 }
525
526 pub fn prepare_batch(&mut self, sqls: &[(&str, u64)]) -> Result<(), DriverError> {
534 if sqls.is_empty() || self.statement_cache_mode == StatementCacheMode::Disabled {
535 return Ok(()); }
537
538 let mut pending = 0usize;
540 self.write_buf.clear();
541 for &(sql, sql_hash) in sqls {
542 if self.stmts.contains_key(&sql_hash, sql) {
543 continue;
544 }
545 let name = make_stmt_name(sql_hash);
546 proto::write_parse(&mut self.write_buf, &name, sql, &[]);
547 proto::write_describe(&mut self.write_buf, b'S', &name);
548 pending += 1;
549 }
550
551 if pending == 0 {
552 return Ok(());
553 }
554
555 proto::write_sync(&mut self.write_buf);
556 self.flush_write()?;
557
558 for &(sql, sql_hash) in sqls {
561 if self.stmts.contains_key(&sql_hash, sql) {
562 continue;
563 }
564
565 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
566 let columns = self.read_column_description()?;
567
568 let name = make_stmt_name(sql_hash);
569 self.query_counter += 1;
570 self.cache_stmt(
571 sql_hash,
572 StmtInfo {
573 name,
574 sql: sql.into(),
575 columns,
576 last_used: self.query_counter,
577 bind_template: None,
578 },
579 );
580 }
581
582 self.expect_ready()?;
583 Ok(())
584 }
585
    /// Executes `sql` with `params` and collects every returned row.
    ///
    /// The request is sent by `send_pipeline`, which also supplies the column
    /// descriptions. Row payloads are appended to a pooled byte buffer and the
    /// per-column offsets to a pooled offsets vector; both are handed to the
    /// returned `QueryResult`.
    ///
    /// # Errors
    /// - `DriverError::Protocol` for a malformed length, an unexpected message
    ///   type, or a missing column description.
    /// - The server error (after draining to ReadyForQuery) when the backend
    ///   answers with an ErrorResponse.
    /// - Any I/O error from refilling the stream buffer.
    #[inline]
    pub fn query(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<QueryResult, DriverError> {
        // NOTE(review): the two `true` flags appear to request column metadata
        // (see the error text below) — confirm against `send_pipeline`.
        let columns = self
            .send_pipeline(sql, sql_hash, params, true, true)?
            .ok_or_else(|| {
                DriverError::Protocol("send_pipeline(need_columns=true) returned None".into())
            })?;

        let num_cols = columns.len();
        // Pooled vectors may still hold data from a previous use.
        let mut all_col_offsets = acquire_col_offsets();
        all_col_offsets.clear();
        let mut affected_rows: u64 = 0;

        let mut resp_buf = acquire_resp_buf();
        resp_buf.clear();

        // Outer loop: refill the stream buffer. Inner loop: parse every
        // complete message currently buffered.
        'outer: loop {
            loop {
                let avail = self.stream_buf_end - self.stream_buf_pos;
                // A header is 1 type byte + 4 length bytes.
                if avail < 5 {
                    break;
                }

                let msg_type = self.stream_buf[self.stream_buf_pos];
                let raw_len = i32::from_be_bytes([
                    self.stream_buf[self.stream_buf_pos + 1],
                    self.stream_buf[self.stream_buf_pos + 2],
                    self.stream_buf[self.stream_buf_pos + 3],
                    self.stream_buf[self.stream_buf_pos + 4],
                ]);

                // The length field counts itself, so anything below 4 is invalid.
                if raw_len < 4 {
                    return Err(DriverError::Protocol(format!(
                        "invalid message length {raw_len} for type '{}'",
                        msg_type as char
                    )));
                }

                let payload_len = (raw_len - 4) as usize;
                let total_msg_len = 5 + payload_len;

                if avail < total_msg_len {
                    // A message larger than the whole stream buffer can never
                    // be completed by refilling; read it through the slow path.
                    if total_msg_len > self.stream_buf.len() {
                        let msg = self.read_one_message()?;
                        match msg {
                            BackendMessage::BindComplete => continue,
                            BackendMessage::DataRow { data } => {
                                parse_data_row_into_buf(data, &mut resp_buf, &mut all_col_offsets)?;
                                continue;
                            }
                            BackendMessage::CommandComplete { tag } => {
                                affected_rows = proto::parse_command_tag(tag);
                                continue;
                            }
                            BackendMessage::EmptyQuery => continue,
                            BackendMessage::ReadyForQuery { status } => {
                                self.tx_status = status;
                                break 'outer;
                            }
                            BackendMessage::NoticeResponse { .. } => continue,
                            BackendMessage::ErrorResponse { data } => {
                                let fields = proto::parse_error_response(data);
                                self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                                // Consume the rest of the reply before reporting.
                                self.drain_to_ready()?;
                                return Err(self.make_server_error(fields));
                            }
                            other => {
                                return Err(DriverError::Protocol(format!(
                                    "unexpected message during query: {other:?}"
                                )));
                            }
                        }
                    }
                    // Partial message: go fetch more bytes.
                    break;
                }

                let payload_start = self.stream_buf_pos + 5;
                let payload_end = payload_start + payload_len;

                if msg_type == b'D' {
                    // DataRow: copied straight out of the stream buffer.
                    parse_data_row_into_buf(
                        &self.stream_buf[payload_start..payload_end],
                        &mut resp_buf,
                        &mut all_col_offsets,
                    )?;
                } else if msg_type == b'Z' {
                    // ReadyForQuery ends the reply; its payload is the status byte.
                    if payload_len >= 1 {
                        self.tx_status = self.stream_buf[payload_start];
                    }
                    self.stream_buf_pos += total_msg_len;
                    break 'outer;
                } else {
                    self.handle_non_datarow_query(
                        msg_type,
                        payload_start,
                        payload_end,
                        sql_hash,
                        &mut affected_rows,
                    )?;
                }

                self.stream_buf_pos += total_msg_len;
            }

            self.refill_stream_buf()?;
        }

        self.shrink_buffers();

        Ok(QueryResult::from_parts_with_buf(
            all_col_offsets,
            num_cols,
            columns,
            affected_rows,
            resp_buf,
        ))
    }
732
733 #[inline]
744 pub fn execute_monolithic(
745 &mut self,
746 sql: &str,
747 sql_hash: u64,
748 params: &[&(dyn Encode + Sync)],
749 ) -> Result<u64, DriverError> {
750 if self.statement_cache_mode == StatementCacheMode::Disabled {
752 return self.execute_unnamed(sql, params);
753 }
754
755 self.write_buf.clear();
757
758 let info = match self.stmts.get_mut(&sql_hash, sql) {
760 Some(info) => {
761 self.query_counter += 1;
762 info.last_used = self.query_counter;
763 info
764 }
765 None => {
766 return self.execute_with_prepare(sql, sql_hash, params);
768 }
769 };
770
771 let can_use_template = info
773 .bind_template
774 .as_ref()
775 .is_some_and(|t| t.param_slots.len() == params.len());
776
777 let mut has_exec_sync = false;
778
779 if can_use_template {
780 let tmpl = info.bind_template.as_ref().ok_or_else(|| {
782 DriverError::Protocol("bind_template missing despite can_use_template".into())
783 })?;
784 self.write_buf.extend_from_slice(&tmpl.bytes);
785
786 let mut template_ok = true;
787 for (i, param) in params.iter().enumerate() {
788 let (data_offset, old_len) = tmpl.param_slots[i];
789 if param.is_null() {
790 let len_offset = data_offset - 4;
791 self.write_buf[len_offset..len_offset + 4]
792 .copy_from_slice(&(-1i32).to_be_bytes());
793 } else if old_len >= 0 {
794 let end = data_offset + old_len as usize;
795 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
796 template_ok = false;
797 break;
798 }
799 } else {
800 template_ok = false;
802 break;
803 }
804 }
805
806 if template_ok {
807 has_exec_sync = true;
808 } else {
809 self.write_buf.clear();
810 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
811 info.bind_template = None;
812 }
813 } else {
814 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
815 }
816
817 if info.bind_template.is_none() && !self.write_buf.is_empty() {
819 info.bind_template = build_bind_template(&self.write_buf, params.len());
820 }
821
822 if !has_exec_sync {
823 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
824 }
825
826 self.stream
828 .write_all(&self.write_buf)
829 .map_err(DriverError::Io)?;
830
831 let mut affected_rows: u64 = 0;
833
834 'outer: loop {
835 loop {
836 let avail = self.stream_buf_end - self.stream_buf_pos;
837 if avail < 5 {
838 break; }
840
841 let msg_type = self.stream_buf[self.stream_buf_pos];
842 let raw_len = i32::from_be_bytes([
843 self.stream_buf[self.stream_buf_pos + 1],
844 self.stream_buf[self.stream_buf_pos + 2],
845 self.stream_buf[self.stream_buf_pos + 3],
846 self.stream_buf[self.stream_buf_pos + 4],
847 ]);
848
849 if raw_len < 4 {
850 return Err(DriverError::Protocol(format!(
851 "invalid message length {raw_len} for type '{}'",
852 msg_type as char
853 )));
854 }
855
856 let payload_len = (raw_len - 4) as usize;
857 let total_msg_len = 5 + payload_len;
858
859 if avail < total_msg_len {
860 if total_msg_len > self.stream_buf.len() {
861 let msg = self.read_one_message()?;
862 match msg {
863 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
864 continue;
865 }
866 BackendMessage::CommandComplete { tag } => {
867 affected_rows = proto::parse_command_tag(tag);
868 continue;
869 }
870 BackendMessage::EmptyQuery => continue,
871 BackendMessage::ReadyForQuery { status } => {
872 self.tx_status = status;
873 break 'outer;
874 }
875 BackendMessage::NoticeResponse { .. } => continue,
876 BackendMessage::ErrorResponse { data } => {
877 let fields = proto::parse_error_response(data);
878 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
879 self.drain_to_ready()?;
880 return Err(self.make_server_error(fields));
881 }
882 other => {
883 return Err(DriverError::Protocol(format!(
884 "unexpected message during execute: {other:?}"
885 )));
886 }
887 }
888 }
889 break; }
891
892 let payload_start = self.stream_buf_pos + 5;
897 let payload_end = payload_start + payload_len;
898
899 if msg_type == b'2' {
900 self.stream_buf_pos += total_msg_len;
902 continue;
903 } else if msg_type == b'C' {
904 affected_rows = proto::parse_command_tag_bytes(
906 &self.stream_buf[payload_start..payload_end],
907 );
908 } else if msg_type == b'Z' {
909 if payload_len >= 1 {
911 self.tx_status = self.stream_buf[payload_start];
912 }
913 self.stream_buf_pos += total_msg_len;
914 break 'outer;
915 } else if msg_type == b'D' || msg_type == b'I' {
916 } else {
918 self.handle_non_datarow_execute(
919 msg_type,
920 payload_start,
921 payload_end,
922 sql_hash,
923 )?;
924 }
925
926 self.stream_buf_pos += total_msg_len;
927 }
928
929 let remaining = self.stream_buf_end - self.stream_buf_pos;
931 debug_assert!(
932 remaining == 0 || self.stream_buf_pos > 0,
933 "compact called with pos=0 and remaining data"
934 );
935 if remaining > 0 {
936 self.stream_buf
937 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
938 }
939 self.stream_buf_pos = 0;
940 self.stream_buf_end = remaining;
941 let n = self
942 .stream
943 .read(&mut self.stream_buf[remaining..])
944 .map_err(DriverError::Io)?;
945 if n == 0 {
946 return Err(DriverError::Io(std::io::Error::new(
947 std::io::ErrorKind::UnexpectedEof,
948 "connection closed",
949 )));
950 }
951 self.stream_buf_end = remaining + n;
952 }
953
954 if self.query_counter & 63 == 0 {
956 if self.read_buf.capacity() > 64 * 1024 {
957 self.read_buf.clear();
958 self.read_buf.shrink_to(8192);
959 }
960 if self.write_buf.capacity() > 16 * 1024 {
961 self.write_buf.clear();
962 self.write_buf.shrink_to(8192);
963 }
964 }
965
966 Ok(affected_rows)
967 }
968
969 #[cold]
971 #[inline(never)]
972 fn execute_with_prepare(
973 &mut self,
974 sql: &str,
975 sql_hash: u64,
976 params: &[&(dyn Encode + Sync)],
977 ) -> Result<u64, DriverError> {
978 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
979
980 if params.len() > i16::MAX as usize {
981 return Err(DriverError::Protocol(format!(
982 "parameter count {} exceeds maximum {}",
983 params.len(),
984 i16::MAX
985 )));
986 }
987
988 let name = make_stmt_name(sql_hash);
989 let param_oids: smallvec::SmallVec<[u32; 8]> =
990 params.iter().map(|p| p.type_oid()).collect();
991
992 self.write_buf.clear();
993 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
994 proto::write_describe(&mut self.write_buf, b'S', &name);
995 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
996 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
997 self.stream
998 .write_all(&self.write_buf)
999 .map_err(DriverError::Io)?;
1000
1001 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1002 let columns = self.read_column_description()?;
1003 self.query_counter += 1;
1004 self.cache_stmt(
1005 sql_hash,
1006 StmtInfo {
1007 name,
1008 sql: sql.into(),
1009 columns,
1010 last_used: self.query_counter,
1011 bind_template: None,
1012 },
1013 );
1014
1015 let mut affected_rows: u64 = 0;
1017 'outer: loop {
1018 loop {
1019 let avail = self.stream_buf_end - self.stream_buf_pos;
1020 if avail < 5 {
1021 break;
1022 }
1023
1024 let msg_type = self.stream_buf[self.stream_buf_pos];
1025 let raw_len = i32::from_be_bytes([
1026 self.stream_buf[self.stream_buf_pos + 1],
1027 self.stream_buf[self.stream_buf_pos + 2],
1028 self.stream_buf[self.stream_buf_pos + 3],
1029 self.stream_buf[self.stream_buf_pos + 4],
1030 ]);
1031
1032 if raw_len < 4 {
1033 return Err(DriverError::Protocol(format!(
1034 "invalid message length {raw_len} for type '{}'",
1035 msg_type as char
1036 )));
1037 }
1038
1039 let payload_len = (raw_len - 4) as usize;
1040 let total_msg_len = 5 + payload_len;
1041
1042 if avail < total_msg_len {
1043 if total_msg_len > self.stream_buf.len() {
1044 let msg = self.read_one_message()?;
1045 match msg {
1046 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
1047 continue;
1048 }
1049 BackendMessage::CommandComplete { tag } => {
1050 affected_rows = proto::parse_command_tag(tag);
1051 continue;
1052 }
1053 BackendMessage::EmptyQuery => continue,
1054 BackendMessage::ReadyForQuery { status } => {
1055 self.tx_status = status;
1056 break 'outer;
1057 }
1058 BackendMessage::NoticeResponse { .. } => continue,
1059 BackendMessage::ErrorResponse { data } => {
1060 let fields = proto::parse_error_response(data);
1061 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1062 self.drain_to_ready()?;
1063 return Err(self.make_server_error(fields));
1064 }
1065 other => {
1066 return Err(DriverError::Protocol(format!(
1067 "unexpected message during execute: {other:?}"
1068 )));
1069 }
1070 }
1071 }
1072 break;
1073 }
1074
1075 let payload_start = self.stream_buf_pos + 5;
1076 let payload_end = payload_start + payload_len;
1077
1078 if msg_type == b'2' || msg_type == b'D' || msg_type == b'I' {
1079 } else if msg_type == b'C' {
1081 affected_rows = proto::parse_command_tag_bytes(
1082 &self.stream_buf[payload_start..payload_end],
1083 );
1084 } else if msg_type == b'Z' {
1085 if payload_len >= 1 {
1086 self.tx_status = self.stream_buf[payload_start];
1087 }
1088 self.stream_buf_pos += total_msg_len;
1089 break 'outer;
1090 } else {
1091 self.handle_non_datarow_execute(
1092 msg_type,
1093 payload_start,
1094 payload_end,
1095 sql_hash,
1096 )?;
1097 }
1098
1099 self.stream_buf_pos += total_msg_len;
1100 }
1101
1102 self.refill_stream_buf()?;
1103 }
1104
1105 Ok(affected_rows)
1106 }
1107
1108 fn execute_unnamed(
1112 &mut self,
1113 sql: &str,
1114 params: &[&(dyn Encode + Sync)],
1115 ) -> Result<u64, DriverError> {
1116 if params.len() > i16::MAX as usize {
1117 return Err(DriverError::Protocol(format!(
1118 "parameter count {} exceeds maximum {}",
1119 params.len(),
1120 i16::MAX
1121 )));
1122 }
1123
1124 self.write_buf.clear();
1125 let param_oids: smallvec::SmallVec<[u32; 8]> =
1126 params.iter().map(|p| p.type_oid()).collect();
1127 proto::write_parse(&mut self.write_buf, b"", sql, ¶m_oids);
1128 proto::write_bind_params(&mut self.write_buf, b"", b"", params);
1129 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1130 self.stream
1131 .write_all(&self.write_buf)
1132 .map_err(DriverError::Io)?;
1133
1134 let mut affected_rows: u64 = 0;
1136 'outer: loop {
1137 loop {
1138 let avail = self.stream_buf_end - self.stream_buf_pos;
1139 if avail < 5 {
1140 break;
1141 }
1142
1143 let msg_type = self.stream_buf[self.stream_buf_pos];
1144 let raw_len = i32::from_be_bytes([
1145 self.stream_buf[self.stream_buf_pos + 1],
1146 self.stream_buf[self.stream_buf_pos + 2],
1147 self.stream_buf[self.stream_buf_pos + 3],
1148 self.stream_buf[self.stream_buf_pos + 4],
1149 ]);
1150
1151 if raw_len < 4 {
1152 return Err(DriverError::Protocol(format!(
1153 "invalid message length {raw_len} for type '{}'",
1154 msg_type as char
1155 )));
1156 }
1157
1158 let payload_len = (raw_len - 4) as usize;
1159 let total_msg_len = 5 + payload_len;
1160
1161 if avail < total_msg_len {
1162 if total_msg_len > self.stream_buf.len() {
1163 let msg = self.read_one_message()?;
1164 match msg {
1165 BackendMessage::ParseComplete
1166 | BackendMessage::BindComplete
1167 | BackendMessage::DataRow { .. } => continue,
1168 BackendMessage::CommandComplete { tag } => {
1169 affected_rows = proto::parse_command_tag(tag);
1170 continue;
1171 }
1172 BackendMessage::EmptyQuery => continue,
1173 BackendMessage::ReadyForQuery { status } => {
1174 self.tx_status = status;
1175 break 'outer;
1176 }
1177 BackendMessage::NoticeResponse { .. } => continue,
1178 BackendMessage::ErrorResponse { data } => {
1179 let fields = proto::parse_error_response(data);
1180 self.drain_to_ready()?;
1181 return Err(self.make_server_error(fields));
1182 }
1183 other => {
1184 return Err(DriverError::Protocol(format!(
1185 "unexpected message during unnamed execute: {other:?}"
1186 )));
1187 }
1188 }
1189 }
1190 break;
1191 }
1192
1193 if msg_type == b'1' || msg_type == b'2' || msg_type == b'I' {
1195 self.stream_buf_pos += total_msg_len;
1197 continue;
1198 } else if msg_type == b'C' {
1199 let payload_start = self.stream_buf_pos + 5;
1201 let payload_end = payload_start + payload_len;
1202 affected_rows = proto::parse_command_tag_bytes(
1203 &self.stream_buf[payload_start..payload_end],
1204 );
1205 self.stream_buf_pos += total_msg_len;
1206 continue;
1207 } else if msg_type == b'Z' {
1208 let payload_start = self.stream_buf_pos + 5;
1210 let payload_end = payload_start + payload_len;
1211 if payload_end > payload_start {
1212 self.tx_status = self.stream_buf[payload_start];
1213 }
1214 self.stream_buf_pos += total_msg_len;
1215 break 'outer;
1216 } else if msg_type == b'E' {
1217 let payload_start = self.stream_buf_pos + 5;
1219 let payload_end = payload_start + payload_len;
1220 let fields =
1221 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
1222 self.stream_buf_pos += total_msg_len;
1223 self.drain_to_ready()?;
1224 return Err(self.make_server_error(fields));
1225 } else if msg_type == b'N' || msg_type == b'D' {
1226 self.stream_buf_pos += total_msg_len;
1228 continue;
1229 } else {
1230 return Err(DriverError::Protocol(format!(
1231 "unexpected message type '{}' during unnamed execute",
1232 msg_type as char
1233 )));
1234 }
1235 }
1236
1237 self.refill_stream_buf()?;
1238 }
1239
1240 Ok(affected_rows)
1241 }
1242
1243 #[inline]
1248 pub fn execute(
1249 &mut self,
1250 sql: &str,
1251 sql_hash: u64,
1252 params: &[&(dyn Encode + Sync)],
1253 ) -> Result<u64, DriverError> {
1254 self.execute_monolithic(sql, sql_hash, params)
1255 }
1256
1257 pub fn execute_pipeline(
1269 &mut self,
1270 sql: &str,
1271 sql_hash: u64,
1272 param_sets: &[&[&(dyn Encode + Sync)]],
1273 ) -> Result<Vec<u64>, DriverError> {
1274 if param_sets.is_empty() {
1275 return Ok(Vec::new());
1276 }
1277
1278 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1279
1280 if self.statement_cache_mode == StatementCacheMode::Disabled {
1282 return self.execute_pipeline_unnamed(sql, param_sets);
1283 }
1284
1285 self.write_buf.clear();
1286
1287 if !self.stmts.contains_key(&sql_hash, sql) {
1289 let name = make_stmt_name(sql_hash);
1290 let first_params = param_sets[0];
1291 if first_params.len() > i16::MAX as usize {
1292 return Err(DriverError::Protocol(format!(
1293 "parameter count {} exceeds maximum {}",
1294 first_params.len(),
1295 i16::MAX
1296 )));
1297 }
1298 let param_oids: smallvec::SmallVec<[u32; 8]> =
1299 first_params.iter().map(|p| p.type_oid()).collect();
1300 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1301 proto::write_describe(&mut self.write_buf, b'S', &name);
1302 proto::write_sync(&mut self.write_buf);
1303 self.flush_write()?;
1304
1305 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1306 let columns = self.read_column_description()?;
1307 self.expect_ready()?;
1308
1309 self.query_counter += 1;
1310 self.cache_stmt(
1311 sql_hash,
1312 StmtInfo {
1313 name,
1314 sql: sql.into(),
1315 columns,
1316 last_used: self.query_counter,
1317 bind_template: None,
1318 },
1319 );
1320
1321 self.write_buf.clear();
1322 }
1323
1324 let stmt_name = self
1326 .stmts
1327 .get(&sql_hash, sql)
1328 .ok_or_else(|| {
1329 DriverError::Protocol("stmt just cached but not found in execute_pipeline".into())
1330 })?
1331 .name;
1332 let count = param_sets.len();
1333
1334 for params in param_sets {
1335 if params.len() > i16::MAX as usize {
1336 return Err(DriverError::Protocol(format!(
1337 "parameter count {} exceeds maximum {}",
1338 params.len(),
1339 i16::MAX
1340 )));
1341 }
1342 proto::write_bind_params(&mut self.write_buf, b"", &stmt_name, params);
1343 self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
1344 }
1345
1346 self.write_buf.extend_from_slice(proto::SYNC_ONLY);
1347 self.flush_write()?;
1348
1349 let mut results = Vec::with_capacity(count);
1352
1353 'outer: loop {
1354 while let Some((msg_type, start, end, total)) = self.peek_stream_msg()? {
1355 if msg_type == b'2' {
1356 } else if msg_type == b'C' {
1358 let rows = proto::parse_command_tag_bytes(&self.stream_buf[start..end]);
1360 results.push(rows);
1361 } else if msg_type == b'Z' {
1362 if end > start {
1364 self.tx_status = self.stream_buf[start];
1365 }
1366 self.advance_stream_msg(total);
1367 break 'outer;
1368 } else if msg_type == b'I' {
1369 results.push(0);
1371 } else if msg_type == b'D' || msg_type == b'N' {
1372 } else if msg_type == b'E' {
1374 let fields = proto::parse_error_response(&self.stream_buf[start..end]);
1376 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1377 self.advance_stream_msg(total);
1378 self.drain_to_ready()?;
1379 return Err(self.make_server_error(fields));
1380 } else if msg_type == b'A' {
1381 let msg = proto::parse_backend_message(msg_type, &self.stream_buf[start..end])?;
1383 if let BackendMessage::NotificationResponse {
1384 pid,
1385 channel,
1386 payload,
1387 } = msg
1388 {
1389 let ch = channel.to_owned();
1390 let pl = payload.to_owned();
1391 self.buffer_notification(pid, &ch, &pl);
1392 }
1393 }
1394 self.advance_stream_msg(total);
1397 }
1398
1399 self.refill_stream_buf()?;
1401 }
1402
1403 self.shrink_buffers();
1404 Ok(results)
1405 }
1406
1407 fn execute_pipeline_unnamed(
1412 &mut self,
1413 sql: &str,
1414 param_sets: &[&[&(dyn Encode + Sync)]],
1415 ) -> Result<Vec<u64>, DriverError> {
1416 let count = param_sets.len();
1417 self.write_buf.clear();
1418
1419 for params in param_sets {
1420 if params.len() > i16::MAX as usize {
1421 return Err(DriverError::Protocol(format!(
1422 "parameter count {} exceeds maximum {}",
1423 params.len(),
1424 i16::MAX
1425 )));
1426 }
1427 let param_oids: smallvec::SmallVec<[u32; 8]> =
1428 params.iter().map(|p| p.type_oid()).collect();
1429 proto::write_parse(&mut self.write_buf, b"", sql, ¶m_oids);
1430 proto::write_bind_params(&mut self.write_buf, b"", b"", params);
1431 self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
1432 }
1433
1434 self.write_buf.extend_from_slice(proto::SYNC_ONLY);
1435 self.flush_write()?;
1436
1437 let mut results = Vec::with_capacity(count);
1439
1440 'outer: loop {
1441 while let Some((msg_type, start, end, total)) = self.peek_stream_msg()? {
1442 if msg_type == b'1' || msg_type == b'2' {
1443 } else if msg_type == b'C' {
1445 let rows = proto::parse_command_tag_bytes(&self.stream_buf[start..end]);
1446 results.push(rows);
1447 } else if msg_type == b'Z' {
1448 if end > start {
1449 self.tx_status = self.stream_buf[start];
1450 }
1451 self.advance_stream_msg(total);
1452 break 'outer;
1453 } else if msg_type == b'I' {
1454 results.push(0);
1455 } else if msg_type == b'D' || msg_type == b'N' {
1456 } else if msg_type == b'E' {
1458 let fields = proto::parse_error_response(&self.stream_buf[start..end]);
1459 self.advance_stream_msg(total);
1460 self.drain_to_ready()?;
1461 return Err(self.make_server_error(fields));
1462 } else if msg_type == b'A' {
1463 let msg = proto::parse_backend_message(msg_type, &self.stream_buf[start..end])?;
1464 if let BackendMessage::NotificationResponse {
1465 pid,
1466 channel,
1467 payload,
1468 } = msg
1469 {
1470 let ch = channel.to_owned();
1471 let pl = payload.to_owned();
1472 self.buffer_notification(pid, &ch, &pl);
1473 }
1474 }
1475 self.advance_stream_msg(total);
1476 }
1477
1478 self.refill_stream_buf()?;
1479 }
1480
1481 self.shrink_buffers();
1482 Ok(results)
1483 }
1484
1485 pub(crate) fn ensure_stmt_prepared(
1491 &mut self,
1492 sql: &str,
1493 sql_hash: u64,
1494 params: &[&(dyn Encode + Sync)],
1495 ) -> Result<[u8; 18], DriverError> {
1496 if self.statement_cache_mode == StatementCacheMode::Disabled {
1498 return Ok([0u8; 18]);
1499 }
1500
1501 if let Some(info) = self.stmts.get(&sql_hash, sql) {
1502 return Ok(info.name);
1503 }
1504
1505 let name = make_stmt_name(sql_hash);
1506 if params.len() > i16::MAX as usize {
1507 return Err(DriverError::Protocol(format!(
1508 "parameter count {} exceeds maximum {}",
1509 params.len(),
1510 i16::MAX
1511 )));
1512 }
1513 let param_oids: smallvec::SmallVec<[u32; 8]> =
1514 params.iter().map(|p| p.type_oid()).collect();
1515
1516 self.write_buf.clear();
1517 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1518 proto::write_describe(&mut self.write_buf, b'S', &name);
1519 proto::write_sync(&mut self.write_buf);
1520 self.flush_write()?;
1521
1522 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1523 let columns = self.read_column_description()?;
1524 self.expect_ready()?;
1525
1526 self.query_counter += 1;
1527 self.cache_stmt(
1528 sql_hash,
1529 StmtInfo {
1530 name,
1531 sql: sql.into(),
1532 columns,
1533 last_used: self.query_counter,
1534 bind_template: None,
1535 },
1536 );
1537
1538 Ok(name)
1539 }
1540
1541 pub(crate) fn write_deferred_bind_execute(
1545 &self,
1546 sql: &str,
1547 sql_hash: u64,
1548 params: &[&(dyn Encode + Sync)],
1549 buf: &mut Vec<u8>,
1550 ) -> Result<(), DriverError> {
1551 if self.statement_cache_mode == StatementCacheMode::Disabled {
1552 let param_oids: smallvec::SmallVec<[u32; 8]> =
1554 params.iter().map(|p| p.type_oid()).collect();
1555 proto::write_parse(buf, b"", sql, ¶m_oids);
1556 proto::write_bind_params(buf, b"", b"", params);
1557 buf.extend_from_slice(proto::EXECUTE_ONLY);
1558 return Ok(());
1559 }
1560
1561 let stmt_name = self
1562 .stmts
1563 .get(&sql_hash, sql)
1564 .ok_or_else(|| {
1565 DriverError::Protocol("stmt just cached but not found in write_deferred".into())
1566 })?
1567 .name;
1568 proto::write_bind_params(buf, b"", &stmt_name, params);
1569 buf.extend_from_slice(proto::EXECUTE_ONLY);
1570 Ok(())
1571 }
1572
1573 pub(crate) fn flush_deferred_pipeline(
1578 &mut self,
1579 buf: &mut Vec<u8>,
1580 count: usize,
1581 ) -> Result<Vec<u64>, DriverError> {
1582 if count == 0 {
1583 buf.clear();
1584 return Ok(Vec::new());
1585 }
1586
1587 buf.extend_from_slice(proto::SYNC_ONLY);
1588
1589 self.stream.write_all(buf).map_err(DriverError::Io)?;
1590 buf.clear();
1591
1592 let mut results = Vec::with_capacity(count);
1594
1595 'outer: loop {
1596 while let Some((msg_type, start, end, total)) = self.peek_stream_msg()? {
1597 if msg_type == b'1' || msg_type == b'2' {
1598 } else if msg_type == b'C' {
1601 let rows = proto::parse_command_tag_bytes(&self.stream_buf[start..end]);
1603 results.push(rows);
1604 } else if msg_type == b'Z' {
1605 if end > start {
1607 self.tx_status = self.stream_buf[start];
1608 }
1609 self.advance_stream_msg(total);
1610 break 'outer;
1611 } else if msg_type == b'I' {
1612 results.push(0);
1614 } else if msg_type == b'D' || msg_type == b'N' {
1615 } else if msg_type == b'E' {
1617 let fields = proto::parse_error_response(&self.stream_buf[start..end]);
1619 self.advance_stream_msg(total);
1620 self.drain_to_ready()?;
1621 return Err(self.make_server_error(fields));
1622 } else if msg_type == b'A' {
1623 let msg = proto::parse_backend_message(msg_type, &self.stream_buf[start..end])?;
1625 if let BackendMessage::NotificationResponse {
1626 pid,
1627 channel,
1628 payload,
1629 } = msg
1630 {
1631 let ch = channel.to_owned();
1632 let pl = payload.to_owned();
1633 self.buffer_notification(pid, &ch, &pl);
1634 }
1635 }
1636 self.advance_stream_msg(total);
1639 }
1640
1641 self.refill_stream_buf()?;
1643 }
1644
1645 self.shrink_buffers();
1646 Ok(results)
1647 }
1648
/// Runs a query and hands every result row to `f` as a `PgDataRow`.
///
/// The request is issued through `send_pipeline`; the response is then parsed
/// directly out of `stream_buf`, so rows that fit in the buffer are passed to
/// the callback as slices of that buffer.
///
/// # Parameters
/// * `sql`, `sql_hash`, `params` - forwarded to `send_pipeline`; `sql_hash` is
///   also passed along when handling errors / non-row messages.
/// * `f` - called once per DataRow, in arrival order.
///
/// # Errors
/// * `DriverError::Protocol` for a message length below 4 or for an unexpected
///   message on the oversized-message path.
/// * The server error built from an ErrorResponse.
/// * Any error returned by `f`, `PgDataRow::new` or the I/O helpers. Note that
///   an error from `f` returns immediately, without reading the rest of the
///   response.
pub fn for_each<F>(
    &mut self,
    sql: &str,
    sql_hash: u64,
    params: &[&(dyn Encode + Sync)],
    mut f: F,
) -> Result<(), DriverError>
where
    F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
{
    // NOTE(review): the meaning of the two boolean flags is defined by
    // `send_pipeline`, which is not visible here; its return value is unused.
    let _ = self.send_pipeline(sql, sql_hash, params, false, true)?;

    'outer: loop {
        // Inner loop: consume every complete message currently buffered.
        loop {
            let avail = self.stream_buf_end - self.stream_buf_pos;
            // A message header is 1 type byte + 4 length bytes.
            if avail < 5 {
                break;
            }

            let msg_type = self.stream_buf[self.stream_buf_pos];
            let raw_len = i32::from_be_bytes([
                self.stream_buf[self.stream_buf_pos + 1],
                self.stream_buf[self.stream_buf_pos + 2],
                self.stream_buf[self.stream_buf_pos + 3],
                self.stream_buf[self.stream_buf_pos + 4],
            ]);

            // The length field counts itself, so anything below 4 is invalid.
            if raw_len < 4 {
                return Err(DriverError::Protocol(format!(
                    "invalid message length {raw_len} for type '{}'",
                    msg_type as char
                )));
            }

            let payload_len = (raw_len - 4) as usize;
            let total_msg_len = 5 + payload_len;

            if avail < total_msg_len {
                // The message is incomplete. If it can never fit in
                // `stream_buf`, fall back to `read_one_message` and handle
                // the parsed message here.
                if total_msg_len > self.stream_buf.len() {
                    let msg = self.read_one_message()?;
                    match msg {
                        BackendMessage::BindComplete => continue,
                        BackendMessage::DataRow { data } => {
                            let row = PgDataRow::new(data)?;
                            f(row)?;
                            continue;
                        }
                        BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
                            continue;
                        }
                        BackendMessage::ReadyForQuery { status } => {
                            self.tx_status = status;
                            break 'outer;
                        }
                        BackendMessage::NoticeResponse { .. } => continue,
                        BackendMessage::ErrorResponse { data } => {
                            let fields = proto::parse_error_response(data);
                            self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                            self.drain_to_ready()?;
                            return Err(self.make_server_error(fields));
                        }
                        other => {
                            return Err(DriverError::Protocol(format!(
                                "unexpected message during for_each: {other:?}"
                            )));
                        }
                    }
                }
                // Otherwise leave the inner loop so more bytes get read.
                break;
            }

            let payload_start = self.stream_buf_pos + 5;
            let payload_end = payload_start + payload_len;

            if msg_type == b'D' {
                // DataRow: the row borrows straight from the stream buffer.
                let row = PgDataRow::new(&self.stream_buf[payload_start..payload_end])?;
                f(row)?;
            } else if msg_type == b'Z' {
                // ReadyForQuery: first payload byte is the transaction status.
                if payload_len >= 1 {
                    self.tx_status = self.stream_buf[payload_start];
                }
                self.stream_buf_pos += total_msg_len;
                break 'outer;
            } else {
                // Every other message type is delegated to the shared handler.
                self.handle_non_datarow_execute(
                    msg_type,
                    payload_start,
                    payload_end,
                    sql_hash,
                )?;
            }

            self.stream_buf_pos += total_msg_len;
        }

        self.refill_stream_buf()?;
    }

    self.shrink_buffers();
    Ok(())
}
1759
1760 #[inline]
1771 pub fn for_each_raw_monolithic<F>(
1772 &mut self,
1773 sql: &str,
1774 sql_hash: u64,
1775 params: &[&(dyn Encode + Sync)],
1776 mut f: F,
1777 ) -> Result<(), DriverError>
1778 where
1779 F: FnMut(&[u8]) -> Result<(), DriverError>,
1780 {
1781 if self.statement_cache_mode == StatementCacheMode::Disabled {
1783 return self.for_each_raw_unnamed(sql, params, f);
1784 }
1785
1786 self.write_buf.clear();
1788
1789 let info = match self.stmts.get_mut(&sql_hash, sql) {
1791 Some(info) => {
1792 self.query_counter += 1;
1793 info.last_used = self.query_counter;
1794 info
1795 }
1796 None => {
1797 return self.for_each_raw_with_prepare(sql, sql_hash, params, f);
1799 }
1800 };
1801
1802 let can_use_template = info
1804 .bind_template
1805 .as_ref()
1806 .is_some_and(|t| t.param_slots.len() == params.len());
1807
1808 let mut has_exec_sync = false;
1809
1810 if can_use_template {
1811 let tmpl = info.bind_template.as_ref().ok_or_else(|| {
1813 DriverError::Protocol("bind_template missing despite can_use_template".into())
1814 })?;
1815 self.write_buf.extend_from_slice(&tmpl.bytes);
1816
1817 let mut template_ok = true;
1818 for (i, param) in params.iter().enumerate() {
1819 let (data_offset, old_len) = tmpl.param_slots[i];
1820 if param.is_null() {
1821 let len_offset = data_offset - 4;
1822 self.write_buf[len_offset..len_offset + 4]
1823 .copy_from_slice(&(-1i32).to_be_bytes());
1824 } else if old_len >= 0 {
1825 let end = data_offset + old_len as usize;
1826 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
1827 template_ok = false;
1828 break;
1829 }
1830 } else {
1831 template_ok = false;
1832 break;
1833 }
1834 }
1835
1836 if template_ok {
1837 has_exec_sync = true;
1838 } else {
1839 self.write_buf.clear();
1840 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
1841 info.bind_template = None;
1842 }
1843 } else {
1844 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
1845 }
1846
1847 if info.bind_template.is_none() && !self.write_buf.is_empty() {
1849 info.bind_template = build_bind_template(&self.write_buf, params.len());
1850 }
1851
1852 if !has_exec_sync {
1853 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1854 }
1855
1856 self.stream
1858 .write_all(&self.write_buf)
1859 .map_err(DriverError::Io)?;
1860
1861 loop {
1865 let avail = self.stream_buf_end - self.stream_buf_pos;
1866 if avail >= 5 {
1867 let bc_type = self.stream_buf[self.stream_buf_pos];
1868 match bc_type {
1869 b'2' => {
1870 self.stream_buf_pos += 5;
1871 break;
1872 }
1873 b'E' => {
1874 let msg = self.read_one_message()?;
1875 if let BackendMessage::ErrorResponse { data } = msg {
1876 let fields = proto::parse_error_response(data);
1877 self.drain_to_ready()?;
1878 return Err(self.make_server_error(fields));
1879 }
1880 }
1881 b'N' | b'S' => {
1882 let raw_len = i32::from_be_bytes([
1883 self.stream_buf[self.stream_buf_pos + 1],
1884 self.stream_buf[self.stream_buf_pos + 2],
1885 self.stream_buf[self.stream_buf_pos + 3],
1886 self.stream_buf[self.stream_buf_pos + 4],
1887 ]);
1888 let total = 1 + raw_len as usize;
1889 if avail >= total {
1890 self.stream_buf_pos += total;
1891 continue;
1892 }
1893 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1894 break;
1895 }
1896 _ => {
1897 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1898 break;
1899 }
1900 }
1901 } else {
1902 let remaining = self.stream_buf_end - self.stream_buf_pos;
1904 if remaining > 0 && self.stream_buf_pos > 0 {
1905 self.stream_buf
1906 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1907 }
1908 self.stream_buf_pos = 0;
1909 self.stream_buf_end = remaining;
1910 let n = self
1911 .stream
1912 .read(&mut self.stream_buf[remaining..])
1913 .map_err(DriverError::Io)?;
1914 if n == 0 {
1915 return Err(DriverError::Io(std::io::Error::new(
1916 std::io::ErrorKind::UnexpectedEof,
1917 "connection closed",
1918 )));
1919 }
1920 self.stream_buf_end = remaining + n;
1921 }
1922 }
1923
1924 'outer: loop {
1926 loop {
1927 let avail = self.stream_buf_end - self.stream_buf_pos;
1928 if avail < 5 {
1929 break;
1930 }
1931
1932 let msg_type = self.stream_buf[self.stream_buf_pos];
1933 let raw_len = i32::from_be_bytes([
1934 self.stream_buf[self.stream_buf_pos + 1],
1935 self.stream_buf[self.stream_buf_pos + 2],
1936 self.stream_buf[self.stream_buf_pos + 3],
1937 self.stream_buf[self.stream_buf_pos + 4],
1938 ]);
1939
1940 if raw_len < 4 {
1941 return Err(DriverError::Protocol(format!(
1942 "invalid message length {raw_len} for type '{}'",
1943 msg_type as char
1944 )));
1945 }
1946
1947 let payload_len = (raw_len - 4) as usize;
1948 let total_msg_len = 5 + payload_len;
1949
1950 if avail < total_msg_len {
1951 if total_msg_len > self.stream_buf.len() {
1952 let msg = self.read_one_message()?;
1953 match msg {
1954 BackendMessage::DataRow { data } => {
1955 f(data)?;
1956 continue;
1957 }
1958 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1959 continue;
1960 }
1961 BackendMessage::ReadyForQuery { status } => {
1962 self.tx_status = status;
1963 break 'outer;
1964 }
1965 BackendMessage::ErrorResponse { data } => {
1966 let fields = proto::parse_error_response(data);
1967 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1968 self.drain_to_ready()?;
1969 return Err(self.make_server_error(fields));
1970 }
1971 BackendMessage::NoticeResponse { .. } => continue,
1972 other => {
1973 return Err(DriverError::Protocol(format!(
1974 "unexpected message during for_each_raw: {other:?}"
1975 )));
1976 }
1977 }
1978 }
1979 break; }
1981
1982 let payload_start = self.stream_buf_pos + 5;
1984 let payload_end = payload_start + payload_len;
1985
1986 if msg_type == b'D' {
1987 f(&self.stream_buf[payload_start..payload_end])?;
1988 } else if msg_type == b'Z' {
1989 if payload_len >= 1 {
1990 self.tx_status = self.stream_buf[payload_start];
1991 }
1992 self.stream_buf_pos += total_msg_len;
1993 break 'outer;
1994 } else {
1995 self.handle_non_datarow_execute(
1996 msg_type,
1997 payload_start,
1998 payload_end,
1999 sql_hash,
2000 )?;
2001 }
2002
2003 self.stream_buf_pos += total_msg_len;
2004 }
2005
2006 let remaining = self.stream_buf_end - self.stream_buf_pos;
2008 if remaining > 0 && self.stream_buf_pos > 0 {
2009 self.stream_buf
2010 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
2011 }
2012 self.stream_buf_pos = 0;
2013 self.stream_buf_end = remaining;
2014 let n = self
2015 .stream
2016 .read(&mut self.stream_buf[remaining..])
2017 .map_err(DriverError::Io)?;
2018 if n == 0 {
2019 return Err(DriverError::Io(std::io::Error::new(
2020 std::io::ErrorKind::UnexpectedEof,
2021 "connection closed",
2022 )));
2023 }
2024 self.stream_buf_end = remaining + n;
2025 }
2026
2027 if self.query_counter & 63 == 0 {
2029 if self.read_buf.capacity() > 64 * 1024 {
2030 self.read_buf.clear();
2031 self.read_buf.shrink_to(8192);
2032 }
2033 if self.write_buf.capacity() > 16 * 1024 {
2034 self.write_buf.clear();
2035 self.write_buf.shrink_to(8192);
2036 }
2037 }
2038
2039 Ok(())
2040 }
2041
2042 #[cold]
2044 #[inline(never)]
2045 fn for_each_raw_with_prepare<F>(
2046 &mut self,
2047 sql: &str,
2048 sql_hash: u64,
2049 params: &[&(dyn Encode + Sync)],
2050 mut f: F,
2051 ) -> Result<(), DriverError>
2052 where
2053 F: FnMut(&[u8]) -> Result<(), DriverError>,
2054 {
2055 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
2056
2057 if params.len() > i16::MAX as usize {
2058 return Err(DriverError::Protocol(format!(
2059 "parameter count {} exceeds maximum {}",
2060 params.len(),
2061 i16::MAX
2062 )));
2063 }
2064
2065 let name = make_stmt_name(sql_hash);
2066 let param_oids: smallvec::SmallVec<[u32; 8]> =
2067 params.iter().map(|p| p.type_oid()).collect();
2068
2069 self.write_buf.clear();
2070 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
2071 proto::write_describe(&mut self.write_buf, b'S', &name);
2072 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
2073 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2074 self.stream
2075 .write_all(&self.write_buf)
2076 .map_err(DriverError::Io)?;
2077
2078 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2079 let columns = self.read_column_description()?;
2080 self.query_counter += 1;
2081 self.cache_stmt(
2082 sql_hash,
2083 StmtInfo {
2084 name,
2085 sql: sql.into(),
2086 columns,
2087 last_used: self.query_counter,
2088 bind_template: None,
2089 },
2090 );
2091
2092 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2094
2095 'outer: loop {
2096 loop {
2097 let avail = self.stream_buf_end - self.stream_buf_pos;
2098 if avail < 5 {
2099 break;
2100 }
2101
2102 let msg_type = self.stream_buf[self.stream_buf_pos];
2103 let raw_len = i32::from_be_bytes([
2104 self.stream_buf[self.stream_buf_pos + 1],
2105 self.stream_buf[self.stream_buf_pos + 2],
2106 self.stream_buf[self.stream_buf_pos + 3],
2107 self.stream_buf[self.stream_buf_pos + 4],
2108 ]);
2109
2110 if raw_len < 4 {
2111 return Err(DriverError::Protocol(format!(
2112 "invalid message length {raw_len} for type '{}'",
2113 msg_type as char
2114 )));
2115 }
2116
2117 let payload_len = (raw_len - 4) as usize;
2118 let total_msg_len = 5 + payload_len;
2119
2120 if avail < total_msg_len {
2121 if total_msg_len > self.stream_buf.len() {
2122 let msg = self.read_one_message()?;
2123 match msg {
2124 BackendMessage::DataRow { data } => {
2125 f(data)?;
2126 continue;
2127 }
2128 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
2129 continue;
2130 }
2131 BackendMessage::ReadyForQuery { status } => {
2132 self.tx_status = status;
2133 break 'outer;
2134 }
2135 BackendMessage::ErrorResponse { data } => {
2136 let fields = proto::parse_error_response(data);
2137 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2138 self.drain_to_ready()?;
2139 return Err(self.make_server_error(fields));
2140 }
2141 BackendMessage::NoticeResponse { .. } => continue,
2142 other => {
2143 return Err(DriverError::Protocol(format!(
2144 "unexpected message during for_each_raw: {other:?}"
2145 )));
2146 }
2147 }
2148 }
2149 break;
2150 }
2151
2152 let payload_start = self.stream_buf_pos + 5;
2153 let payload_end = payload_start + payload_len;
2154
2155 if msg_type == b'D' {
2156 f(&self.stream_buf[payload_start..payload_end])?;
2157 } else if msg_type == b'Z' {
2158 if payload_len >= 1 {
2159 self.tx_status = self.stream_buf[payload_start];
2160 }
2161 self.stream_buf_pos += total_msg_len;
2162 break 'outer;
2163 } else {
2164 self.handle_non_datarow_execute(
2165 msg_type,
2166 payload_start,
2167 payload_end,
2168 sql_hash,
2169 )?;
2170 }
2171
2172 self.stream_buf_pos += total_msg_len;
2173 }
2174
2175 self.refill_stream_buf()?;
2176 }
2177
2178 self.shrink_buffers();
2179 Ok(())
2180 }
2181
2182 fn for_each_raw_unnamed<F>(
2184 &mut self,
2185 sql: &str,
2186 params: &[&(dyn Encode + Sync)],
2187 mut f: F,
2188 ) -> Result<(), DriverError>
2189 where
2190 F: FnMut(&[u8]) -> Result<(), DriverError>,
2191 {
2192 if params.len() > i16::MAX as usize {
2193 return Err(DriverError::Protocol(format!(
2194 "parameter count {} exceeds maximum {}",
2195 params.len(),
2196 i16::MAX
2197 )));
2198 }
2199
2200 let param_oids: smallvec::SmallVec<[u32; 8]> =
2201 params.iter().map(|p| p.type_oid()).collect();
2202
2203 self.write_buf.clear();
2204 proto::write_parse(&mut self.write_buf, b"", sql, ¶m_oids);
2205 proto::write_describe(&mut self.write_buf, b'S', b"");
2206 proto::write_bind_params(&mut self.write_buf, b"", b"", params);
2207 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2208 self.stream
2209 .write_all(&self.write_buf)
2210 .map_err(DriverError::Io)?;
2211
2212 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2213 let _columns = self.read_column_description()?;
2214 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2215
2216 'outer: loop {
2217 loop {
2218 let avail = self.stream_buf_end - self.stream_buf_pos;
2219 if avail < 5 {
2220 break;
2221 }
2222
2223 let msg_type = self.stream_buf[self.stream_buf_pos];
2224 let raw_len = i32::from_be_bytes([
2225 self.stream_buf[self.stream_buf_pos + 1],
2226 self.stream_buf[self.stream_buf_pos + 2],
2227 self.stream_buf[self.stream_buf_pos + 3],
2228 self.stream_buf[self.stream_buf_pos + 4],
2229 ]);
2230
2231 if raw_len < 4 {
2232 return Err(DriverError::Protocol(format!(
2233 "invalid message length {raw_len} for type '{}'",
2234 msg_type as char
2235 )));
2236 }
2237
2238 let payload_len = (raw_len - 4) as usize;
2239 let total_msg_len = 5 + payload_len;
2240
2241 if avail < total_msg_len {
2242 if total_msg_len > self.stream_buf.len() {
2243 let msg = self.read_one_message()?;
2244 match msg {
2245 BackendMessage::DataRow { data } => {
2246 f(data)?;
2247 continue;
2248 }
2249 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
2250 continue
2251 }
2252 BackendMessage::ReadyForQuery { status } => {
2253 self.tx_status = status;
2254 break 'outer;
2255 }
2256 BackendMessage::ErrorResponse { data } => {
2257 let fields = proto::parse_error_response(data);
2258 self.drain_to_ready()?;
2259 return Err(self.make_server_error(fields));
2260 }
2261 BackendMessage::NoticeResponse { .. } => continue,
2262 other => {
2263 return Err(DriverError::Protocol(format!(
2264 "unexpected message during for_each_raw (unnamed): {other:?}"
2265 )));
2266 }
2267 }
2268 }
2269 break;
2270 }
2271
2272 let payload_start = self.stream_buf_pos + 5;
2273 let payload_end = payload_start + payload_len;
2274
2275 if msg_type == b'D' {
2276 f(&self.stream_buf[payload_start..payload_end])?;
2277 } else if msg_type == b'Z' {
2278 if payload_len >= 1 {
2279 self.tx_status = self.stream_buf[payload_start];
2280 }
2281 self.stream_buf_pos += total_msg_len;
2282 break 'outer;
2283 } else if msg_type == b'C' || msg_type == b'I' || msg_type == b'N' {
2284 } else if msg_type == b'E' {
2286 let fields =
2287 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2288 self.stream_buf_pos += total_msg_len;
2289 self.drain_to_ready()?;
2290 return Err(self.make_server_error(fields));
2291 } else {
2292 return Err(DriverError::Protocol(format!(
2293 "unexpected message type '{}' during for_each_raw (unnamed)",
2294 msg_type as char
2295 )));
2296 }
2297
2298 self.stream_buf_pos += total_msg_len;
2299 }
2300
2301 self.refill_stream_buf()?;
2302 }
2303
2304 self.shrink_buffers();
2305 Ok(())
2306 }
2307
2308 #[inline]
2313 pub fn for_each_raw<F>(
2314 &mut self,
2315 sql: &str,
2316 sql_hash: u64,
2317 params: &[&(dyn Encode + Sync)],
2318 f: F,
2319 ) -> Result<(), DriverError>
2320 where
2321 F: FnMut(&[u8]) -> Result<(), DriverError>,
2322 {
2323 self.for_each_raw_monolithic(sql, sql_hash, params, f)
2324 }
2325
2326 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
2328 self.write_buf.clear();
2329 proto::write_simple_query(&mut self.write_buf, sql);
2330 self.flush_write()?;
2331
2332 loop {
2333 let msg = self.read_one_message()?;
2334 match msg {
2335 BackendMessage::ReadyForQuery { status } => {
2336 self.tx_status = status;
2337 return Ok(());
2338 }
2339 BackendMessage::CommandComplete { .. }
2340 | BackendMessage::RowDescription { .. }
2341 | BackendMessage::DataRow { .. }
2342 | BackendMessage::EmptyQuery
2343 | BackendMessage::NoticeResponse { .. }
2344 | BackendMessage::ParameterStatus { .. }
2345 | BackendMessage::AuthOk
2349 | BackendMessage::AuthSaslFinal { .. }
2350 | BackendMessage::BackendKeyData { .. } => {}
2351 BackendMessage::ErrorResponse { data } => {
2352 let fields = proto::parse_error_response(data);
2353 self.drain_to_ready()?;
2354 return Err(self.make_server_error(fields));
2355 }
2356 other => {
2357 return Err(DriverError::Protocol(format!(
2358 "unexpected message during simple_query: {other:?}"
2359 )));
2360 }
2361 }
2362 }
2363 }
2364
2365 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
2367 self.write_buf.clear();
2368 proto::write_simple_query(&mut self.write_buf, sql);
2369 self.flush_write()?;
2370
2371 let mut rows: Vec<SimpleRow> = Vec::new();
2372 loop {
2373 let msg = self.read_one_message()?;
2374 match msg {
2375 BackendMessage::ReadyForQuery { status } => {
2376 self.tx_status = status;
2377 return Ok(rows);
2378 }
2379 BackendMessage::DataRow { data } => {
2380 rows.push(proto::parse_simple_data_row(data)?);
2381 }
2382 BackendMessage::RowDescription { .. }
2383 | BackendMessage::CommandComplete { .. }
2384 | BackendMessage::EmptyQuery
2385 | BackendMessage::NoticeResponse { .. }
2386 | BackendMessage::ParameterStatus { .. }
2387 | BackendMessage::AuthOk
2388 | BackendMessage::AuthSaslFinal { .. }
2389 | BackendMessage::BackendKeyData { .. } => {}
2390 BackendMessage::ErrorResponse { data } => {
2391 let fields = proto::parse_error_response(data);
2392 self.drain_to_ready()?;
2393 return Err(self.make_server_error(fields));
2394 }
2395 other => {
2396 return Err(DriverError::Protocol(format!(
2397 "unexpected message during simple_query_rows: {other:?}"
2398 )));
2399 }
2400 }
2401 }
2402 }
2403
2404 pub fn copy_in<'a, I>(
2426 &mut self,
2427 table: &str,
2428 columns: &[&str],
2429 rows: I,
2430 ) -> Result<u64, DriverError>
2431 where
2432 I: IntoIterator<Item = &'a str>,
2433 {
2434 let quoted_table = proto::quote_ident(table);
2436 let quoted_cols: Vec<String> = columns.iter().map(|c| proto::quote_ident(c)).collect();
2437 let sql = format!(
2438 "COPY {}({}) FROM STDIN",
2439 quoted_table,
2440 quoted_cols.join(",")
2441 );
2442
2443 self.write_buf.clear();
2445 proto::write_simple_query(&mut self.write_buf, &sql);
2446 self.flush_write()?;
2447
2448 loop {
2450 let msg = self.read_one_message()?;
2451 match msg {
2452 BackendMessage::CopyInResponse { .. } => break,
2453 BackendMessage::ErrorResponse { data } => {
2454 let fields = proto::parse_error_response(data);
2455 self.drain_to_ready()?;
2456 return Err(self.make_server_error(fields));
2457 }
2458 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2459 other => {
2460 return Err(DriverError::Protocol(format!(
2461 "expected CopyInResponse, got: {other:?}"
2462 )));
2463 }
2464 }
2465 }
2466
2467 self.write_buf.clear();
2477 for row in rows {
2478 let row_data = row.as_bytes();
2480 let data_len = (4 + row_data.len() + 1) as i32;
2481 self.write_buf.push(b'd');
2482 self.write_buf.extend_from_slice(&data_len.to_be_bytes());
2483 self.write_buf.extend_from_slice(row_data);
2484 self.write_buf.push(b'\n');
2485 if self.write_buf.len() > 65536 {
2487 self.flush_write()?;
2488 self.write_buf.clear();
2489 }
2490 }
2491 proto::write_copy_done(&mut self.write_buf);
2494 self.flush_write()?;
2495 self.write_buf.clear();
2496
2497 let mut count: u64 = 0;
2499 loop {
2500 let msg = self.read_one_message()?;
2501 match msg {
2502 BackendMessage::CommandComplete { tag } => {
2503 count = proto::parse_command_tag(tag);
2504 }
2505 BackendMessage::ReadyForQuery { status } => {
2506 self.tx_status = status;
2507 return Ok(count);
2508 }
2509 BackendMessage::ErrorResponse { data } => {
2510 let fields = proto::parse_error_response(data);
2511 self.drain_to_ready()?;
2512 return Err(self.make_server_error(fields));
2513 }
2514 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2515 other => {
2516 return Err(DriverError::Protocol(format!(
2517 "unexpected message during copy_in completion: {other:?}"
2518 )));
2519 }
2520 }
2521 }
2522 }
2523
2524 pub fn copy_out<W: std::io::Write>(
2544 &mut self,
2545 query: &str,
2546 writer: &mut W,
2547 ) -> Result<u64, DriverError> {
2548 let sql = format!("COPY ({query}) TO STDOUT");
2550
2551 self.write_buf.clear();
2553 proto::write_simple_query(&mut self.write_buf, &sql);
2554 self.flush_write()?;
2555
2556 loop {
2558 let msg = self.read_one_message()?;
2559 match msg {
2560 BackendMessage::CopyOutResponse { .. } => break,
2561 BackendMessage::ErrorResponse { data } => {
2562 let fields = proto::parse_error_response(data);
2563 self.drain_to_ready()?;
2564 return Err(self.make_server_error(fields));
2565 }
2566 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2567 other => {
2568 return Err(DriverError::Protocol(format!(
2569 "expected CopyOutResponse, got: {other:?}"
2570 )));
2571 }
2572 }
2573 }
2574
2575 loop {
2577 let msg = self.read_one_message()?;
2578 match msg {
2579 BackendMessage::CopyData { data } => {
2580 writer.write_all(data).map_err(DriverError::Io)?;
2581 }
2582 BackendMessage::CopyDone => break,
2583 BackendMessage::ErrorResponse { data } => {
2584 let fields = proto::parse_error_response(data);
2585 self.drain_to_ready()?;
2586 return Err(self.make_server_error(fields));
2587 }
2588 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2589 other => {
2590 return Err(DriverError::Protocol(format!(
2591 "unexpected message during copy_out data: {other:?}"
2592 )));
2593 }
2594 }
2595 }
2596
2597 let mut count: u64 = 0;
2599 loop {
2600 let msg = self.read_one_message()?;
2601 match msg {
2602 BackendMessage::CommandComplete { tag } => {
2603 count = proto::parse_command_tag(tag);
2604 }
2605 BackendMessage::ReadyForQuery { status } => {
2606 self.tx_status = status;
2607 return Ok(count);
2608 }
2609 BackendMessage::ErrorResponse { data } => {
2610 let fields = proto::parse_error_response(data);
2611 self.drain_to_ready()?;
2612 return Err(self.make_server_error(fields));
2613 }
2614 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2615 other => {
2616 return Err(DriverError::Protocol(format!(
2617 "unexpected message during copy_out completion: {other:?}"
2618 )));
2619 }
2620 }
2621 }
2622 }
2623
2624 pub fn prepare_describe(&mut self, sql: &str) -> Result<PrepareResult, DriverError> {
2629 self.prepare_describe_with_oids(sql, &[])
2630 }
2631
2632 pub fn prepare_describe_with_oids(
2635 &mut self,
2636 sql: &str,
2637 param_oids: &[u32],
2638 ) -> Result<PrepareResult, DriverError> {
2639 self.write_buf.clear();
2640 proto::write_parse(&mut self.write_buf, b"", sql, param_oids);
2643 proto::write_describe(&mut self.write_buf, b'S', b"");
2644 proto::write_sync(&mut self.write_buf);
2645 self.flush_write()?;
2646
2647 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2649
2650 let mut param_oids: Vec<u32> = Vec::new();
2652 let columns;
2653 loop {
2654 let msg = self.read_one_message()?;
2655 match msg {
2656 BackendMessage::ParameterDescription { data } => {
2657 param_oids = proto::parse_parameter_description(data)?;
2658 }
2659 BackendMessage::RowDescription { data } => {
2660 columns = proto::parse_row_description(data)?;
2661 break;
2662 }
2663 BackendMessage::NoData => {
2664 columns = Vec::new();
2665 break;
2666 }
2667 BackendMessage::NoticeResponse { .. } => {}
2668 BackendMessage::ErrorResponse { data } => {
2669 let fields = proto::parse_error_response(data);
2670 self.drain_to_ready()?;
2671 return Err(self.make_server_error(fields));
2672 }
2673 other => {
2674 return Err(DriverError::Protocol(format!(
2675 "expected ParameterDescription/RowDescription/NoData, got: {other:?}"
2676 )));
2677 }
2678 }
2679 }
2680
2681 self.expect_ready()?;
2683
2684 Ok(PrepareResult {
2685 columns,
2686 param_oids,
2687 })
2688 }
2689
2690 pub fn wait_for_notification(&mut self) -> Result<(String, String), DriverError> {
2699 loop {
2700 let (msg_type, _payload_len) = self.read_message_buffered()?;
2701 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2702 match msg {
2703 BackendMessage::NotificationResponse {
2704 channel, payload, ..
2705 } => {
2706 return Ok((channel.to_owned(), payload.to_owned()));
2707 }
2708 BackendMessage::ParameterStatus { .. } | BackendMessage::NoticeResponse { .. } => {
2709 continue;
2710 }
2711 _ => continue,
2712 }
2713 }
2714 }
2715
2716 pub fn cancel(&self) -> Result<(), DriverError> {
2722 let addr = format!("{}:{}", self.connect_config.host, self.connect_config.port);
2723 let mut tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
2724 let mut buf = Vec::with_capacity(16);
2725 proto::write_cancel_request(&mut buf, self.pid, self.secret);
2726 tcp.write_all(&buf).map_err(DriverError::Io)?;
2727 tcp.flush().map_err(DriverError::Io)?;
2728 drop(tcp);
2730 Ok(())
2731 }
2732
2733 pub fn set_read_timeout(
2738 &self,
2739 timeout: Option<std::time::Duration>,
2740 ) -> Result<(), DriverError> {
2741 self.stream
2742 .set_read_timeout(timeout)
2743 .map_err(DriverError::Io)
2744 }
2745
2746 pub fn query_streaming_start(
2760 &mut self,
2761 sql: &str,
2762 sql_hash: u64,
2763 params: &[&(dyn Encode + Sync)],
2764 chunk_size: i32,
2765 ) -> Result<(Arc<[ColumnDesc]>, bool), DriverError> {
2766 self.write_buf.clear();
2767
2768 if self.statement_cache_mode == StatementCacheMode::Disabled {
2770 let param_oids: smallvec::SmallVec<[u32; 8]> =
2771 params.iter().map(|p| p.type_oid()).collect();
2772 proto::write_parse(&mut self.write_buf, b"", sql, ¶m_oids);
2773 proto::write_describe(&mut self.write_buf, b'S', b"");
2774 proto::write_bind_params(&mut self.write_buf, b"", b"", params);
2775 proto::write_execute(&mut self.write_buf, b"", chunk_size);
2776 proto::write_flush(&mut self.write_buf);
2777 self.flush_write()?;
2778
2779 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2780 let columns = self.read_column_description()?;
2781 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2782 self.streaming_active = true;
2783 return Ok((columns, false));
2784 }
2785
2786 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash, sql) {
2787 self.query_counter += 1;
2789 info.last_used = self.query_counter;
2790
2791 let can_use_template = info
2792 .bind_template
2793 .as_ref()
2794 .is_some_and(|t| t.param_slots.len() == params.len());
2795
2796 if can_use_template {
2797 let tmpl = info.bind_template.as_ref().ok_or_else(|| {
2799 DriverError::Protocol("bind_template missing despite can_use_template".into())
2800 })?;
2801 self.write_buf
2804 .extend_from_slice(&tmpl.bytes[..tmpl.bind_end]);
2805
2806 let mut template_ok = true;
2807 for (i, param) in params.iter().enumerate() {
2808 let (data_offset, old_len) = tmpl.param_slots[i];
2809 if param.is_null() {
2810 let len_offset = data_offset - 4;
2811 self.write_buf[len_offset..len_offset + 4]
2812 .copy_from_slice(&(-1i32).to_be_bytes());
2813 } else if old_len >= 0 {
2814 let end = data_offset + old_len as usize;
2815 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2816 template_ok = false;
2817 break;
2818 }
2819 } else {
2820 template_ok = false;
2821 break;
2822 }
2823 }
2824
2825 if !template_ok {
2826 self.write_buf.clear();
2827 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
2828 info.bind_template = None;
2829 }
2830 } else {
2831 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
2832 }
2833
2834 let cols = info.columns.clone();
2835
2836 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2837 info.bind_template = build_bind_template(&self.write_buf, params.len());
2838 }
2839
2840 proto::write_execute(&mut self.write_buf, b"", chunk_size);
2841 proto::write_flush(&mut self.write_buf);
2843 self.flush_write()?;
2844
2845 cols
2846 } else {
2847 let name = make_stmt_name(sql_hash);
2849 let param_oids: smallvec::SmallVec<[u32; 8]> =
2850 params.iter().map(|p| p.type_oid()).collect();
2851 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
2852 proto::write_describe(&mut self.write_buf, b'S', &name);
2853 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
2854
2855 proto::write_execute(&mut self.write_buf, b"", chunk_size);
2856 proto::write_flush(&mut self.write_buf);
2857 self.flush_write()?;
2858
2859 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2860 let columns = self.read_column_description()?;
2861 self.query_counter += 1;
2862 self.cache_stmt(
2863 sql_hash,
2864 StmtInfo {
2865 name,
2866 sql: sql.into(),
2867 columns: columns.clone(),
2868 last_used: self.query_counter,
2869 bind_template: None,
2870 },
2871 );
2872 columns
2873 };
2874
2875 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2877
2878 self.streaming_active = true;
2879
2880 Ok((columns, false))
2881 }
2882
2883 pub fn streaming_next_chunk(
2891 &mut self,
2892 arena: &mut Arena,
2893 all_col_offsets: &mut Vec<(usize, i32)>,
2894 ) -> Result<bool, DriverError> {
2895 all_col_offsets.clear();
2896
2897 loop {
2898 let msg = self.read_one_message()?;
2899 match msg {
2900 BackendMessage::DataRow { data } => {
2901 parse_data_row_flat(data, arena, all_col_offsets)?;
2902 }
2903 BackendMessage::PortalSuspended => {
2904 return Ok(true);
2908 }
2909 BackendMessage::CommandComplete { .. } => {
2910 self.write_buf.clear();
2913 proto::write_sync(&mut self.write_buf);
2914 self.flush_write()?;
2915 self.expect_ready()?;
2916 self.shrink_buffers();
2917
2918 self.streaming_active = false;
2919 return Ok(false);
2920 }
2921 BackendMessage::EmptyQuery => {
2922 self.write_buf.clear();
2923 proto::write_sync(&mut self.write_buf);
2924 self.flush_write()?;
2925 self.expect_ready()?;
2926
2927 self.streaming_active = false;
2928 return Ok(false);
2929 }
2930 BackendMessage::ErrorResponse { data } => {
2931 let fields = proto::parse_error_response(data);
2932 self.write_buf.clear();
2934 proto::write_sync(&mut self.write_buf);
2935 self.flush_write()?;
2936 self.drain_to_ready()?;
2937
2938 self.streaming_active = false;
2939 return Err(self.make_server_error(fields));
2940 }
2941 BackendMessage::NoticeResponse { .. } => {}
2942 other => {
2943 return Err(DriverError::Protocol(format!(
2944 "unexpected message during streaming: {other:?}"
2945 )));
2946 }
2947 }
2948 }
2949 }
2950
    /// Sends Execute(`chunk_size`) for the unnamed portal followed by Flush.
    ///
    /// Only writes the request; the rows are read by a subsequent
    /// `streaming_next_chunk` call.
    ///
    /// # Errors
    ///
    /// Returns the error from `flush_write` if the write fails.
    pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
        self.write_buf.clear();
        proto::write_execute(&mut self.write_buf, b"", chunk_size);
        proto::write_flush(&mut self.write_buf);
        self.flush_write()
    }
2964
    /// Returns the `streaming_active` flag: set by a successful
    /// `query_streaming_start`, cleared by `streaming_next_chunk` when the
    /// query completes or the server reports an error.
    pub fn is_streaming(&self) -> bool {
        self.streaming_active
    }
2969
2970 pub fn close(mut self) -> Result<(), DriverError> {
2972 self.write_buf.clear();
2973 proto::write_terminate(&mut self.write_buf);
2974 let _ = self.flush_write();
2975 Ok(())
2976 }
2977
2978 pub fn is_idle(&self) -> bool {
2982 self.tx_status == b'I'
2983 }
2984
2985 pub fn is_in_transaction(&self) -> bool {
2987 self.tx_status == b'T'
2988 }
2989
2990 pub fn is_in_failed_transaction(&self) -> bool {
2992 self.tx_status == b'E'
2993 }
2994
    /// Resets `last_used` to the current instant, restarting the clock that
    /// `idle_duration` measures.
    pub fn touch(&mut self) {
        self.last_used = std::time::Instant::now();
    }
2999
    /// Returns the time elapsed since `last_used` was last set (see `touch`).
    pub fn idle_duration(&self) -> std::time::Duration {
        self.last_used.elapsed()
    }
3004
    /// Returns the current value of the internal `query_counter`, which the
    /// statement-cache paths increment and store as each statement's
    /// `last_used` stamp.
    pub fn query_counter(&self) -> u64 {
        self.query_counter
    }
3009
3010 pub fn parameter(&self, name: &str) -> Option<&str> {
3012 self.params
3013 .iter()
3014 .find(|(k, _)| &**k == name)
3015 .map(|(_, v)| &**v)
3016 }
3017
3018 pub fn server_params(&self) -> &[(Box<str>, Box<str>)] {
3020 &self.params
3021 }
3022
    /// Returns the stored `pid`, the value `cancel` sends in its cancel
    /// request. Presumably the backend process ID reported at startup —
    /// confirm where the field is assigned.
    pub fn pid(&self) -> i32 {
        self.pid
    }
3027
    /// Returns the stored `secret`, the key `cancel` sends together with
    /// `pid` in its cancel request.
    pub fn secret_key(&self) -> i32 {
        self.secret
    }
3032
    /// Takes and returns all buffered notifications, leaving the internal
    /// buffer empty.
    pub fn drain_notifications(&mut self) -> Vec<Notification> {
        std::mem::take(&mut self.pending_notifications)
    }
3037
    /// Returns how many notifications are currently buffered and not yet
    /// taken by `drain_notifications`.
    pub fn pending_notification_count(&self) -> usize {
        self.pending_notifications.len()
    }
3042
    /// Sets the statement-cache size limit that `cache_stmt` checks before
    /// inserting. Only the limit is stored; entries already cached are not
    /// evicted here.
    pub fn set_max_stmt_cache_size(&mut self, size: usize) {
        self.max_stmt_cache_size = size;
    }
3047
    /// Returns the number of entries currently in the statement cache.
    pub fn stmt_cache_len(&self) -> usize {
        self.stmts.len()
    }
3052
    /// Returns the stored `created_at` instant.
    pub fn created_at(&self) -> std::time::Instant {
        self.created_at
    }
3057
3058 #[inline]
3066 fn send_pipeline(
3067 &mut self,
3068 sql: &str,
3069 sql_hash: u64,
3070 params: &[&(dyn Encode + Sync)],
3071 need_columns: bool,
3072 skip_bind_complete: bool,
3073 ) -> Result<Option<Arc<[ColumnDesc]>>, DriverError> {
3074 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
3075
3076 if params.len() > i16::MAX as usize {
3077 return Err(DriverError::Protocol(format!(
3078 "parameter count {} exceeds maximum {}",
3079 params.len(),
3080 i16::MAX
3081 )));
3082 }
3083
3084 self.write_buf.clear();
3085
3086 if self.statement_cache_mode == StatementCacheMode::Disabled {
3088 let param_oids: smallvec::SmallVec<[u32; 8]> =
3089 params.iter().map(|p| p.type_oid()).collect();
3090 proto::write_parse(&mut self.write_buf, b"", sql, ¶m_oids);
3091 if need_columns {
3092 proto::write_describe(&mut self.write_buf, b'S', b"");
3093 }
3094 proto::write_bind_params(&mut self.write_buf, b"", b"", params);
3095 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
3096 self.flush_write()?;
3097
3098 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
3099 let columns = if need_columns {
3100 Some(self.read_column_description()?)
3101 } else {
3102 None
3103 };
3104 if !skip_bind_complete {
3105 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
3106 }
3107 return Ok(columns);
3108 }
3109
3110 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash, sql) {
3111 self.query_counter += 1;
3113 info.last_used = self.query_counter;
3114
3115 let can_use_template = info
3116 .bind_template
3117 .as_ref()
3118 .is_some_and(|t| t.param_slots.len() == params.len());
3119
3120 let mut has_exec_sync = false;
3122
3123 if can_use_template {
3124 let tmpl = info.bind_template.as_ref().ok_or_else(|| {
3128 DriverError::Protocol("bind_template missing despite can_use_template".into())
3129 })?;
3130 self.write_buf.extend_from_slice(&tmpl.bytes);
3131
3132 let mut template_ok = true;
3133 for (i, param) in params.iter().enumerate() {
3134 let (data_offset, old_len) = tmpl.param_slots[i];
3135 if param.is_null() {
3136 let len_offset = data_offset - 4;
3138 self.write_buf[len_offset..len_offset + 4]
3139 .copy_from_slice(&(-1i32).to_be_bytes());
3140 } else if old_len >= 0 {
3141 let end = data_offset + old_len as usize;
3142 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
3143 template_ok = false;
3145 break;
3146 }
3147 } else {
3148 template_ok = false;
3151 break;
3152 }
3153 }
3154
3155 if template_ok {
3156 has_exec_sync = true; } else {
3158 self.write_buf.clear();
3159 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
3160 info.bind_template = None;
3162 }
3163 } else {
3164 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
3165 }
3166
3167 let cols = if need_columns {
3168 Some(info.columns.clone())
3169 } else {
3170 None
3171 };
3172
3173 if info.bind_template.is_none() && !self.write_buf.is_empty() {
3177 info.bind_template = build_bind_template(&self.write_buf, params.len());
3178 }
3179
3180 if !has_exec_sync {
3181 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
3182 }
3183 self.flush_write()?;
3184
3185 cols
3186 } else {
3187 let name = make_stmt_name(sql_hash);
3189 let param_oids: smallvec::SmallVec<[u32; 8]> =
3190 params.iter().map(|p| p.type_oid()).collect();
3191 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
3192 proto::write_describe(&mut self.write_buf, b'S', &name);
3193 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
3194
3195 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
3196 self.flush_write()?;
3197
3198 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
3199 let columns = self.read_column_description()?;
3200 self.query_counter += 1;
3201 self.cache_stmt(
3202 sql_hash,
3203 StmtInfo {
3204 name,
3205 sql: sql.into(),
3206 columns: columns.clone(),
3207 last_used: self.query_counter,
3208 bind_template: None,
3209 },
3210 );
3211 if need_columns {
3212 Some(columns)
3213 } else {
3214 None
3215 }
3216 };
3217
3218 if !skip_bind_complete {
3219 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
3220 }
3221
3222 Ok(columns)
3223 }
3224
3225 fn read_column_description(&mut self) -> Result<Arc<[ColumnDesc]>, DriverError> {
3227 loop {
3228 let msg = self.read_one_message()?;
3229 match msg {
3230 BackendMessage::RowDescription { data } => {
3231 let cols = proto::parse_row_description(data)?;
3232 return Ok(cols.into());
3233 }
3234 BackendMessage::ParameterDescription { .. } => {}
3235 BackendMessage::NoData => return Ok(Arc::from(Vec::new())),
3236 BackendMessage::NoticeResponse { .. } => {}
3237 BackendMessage::ErrorResponse { data } => {
3238 let fields = proto::parse_error_response(data);
3239 self.drain_to_ready()?;
3240 return Err(self.make_server_error(fields));
3241 }
3242 other => {
3243 return Err(DriverError::Protocol(format!(
3244 "expected RowDescription/NoData, got: {other:?}"
3245 )));
3246 }
3247 }
3248 }
3249 }
3250
3251 fn cache_stmt(&mut self, sql_hash: u64, info: StmtInfo) {
3254 if self.stmts.len() >= self.max_stmt_cache_size
3255 && !self.stmts.contains_key(&sql_hash, &info.sql)
3256 {
3257 if let Some((_lru_hash, evicted)) = self.stmts.evict_lru() {
3258 proto::write_close(&mut self.write_buf, b'S', &evicted.name);
3259 }
3260 }
3261 self.stmts.insert(sql_hash, info);
3262 }
3263
3264 fn buffer_notification(&mut self, pid: i32, channel: &str, payload: &str) {
3265 if self.pending_notifications.len() < 1024 {
3266 self.pending_notifications.push(Notification {
3267 pid,
3268 channel: channel.to_owned(),
3269 payload: payload.to_owned(),
3270 });
3271 }
3272 }
3273
3274 fn shrink_buffers(&mut self) {
3275 if self.query_counter & 63 != 0 {
3279 return;
3280 }
3281 if self.read_buf.capacity() > 64 * 1024 {
3282 self.read_buf.clear();
3283 self.read_buf.shrink_to(8192);
3284 }
3285 if self.write_buf.capacity() > 16 * 1024 {
3286 self.write_buf.clear();
3287 self.write_buf.shrink_to(8192);
3288 }
3289 }
3290
3291 fn maybe_invalidate_stmt_cache(&mut self, fields: &proto::ErrorFields, sql_hash: u64) -> bool {
3292 if &fields.code == b"26000" {
3293 self.stmts.remove(&sql_hash);
3294 true
3295 } else {
3296 false
3297 }
3298 }
3299
3300 #[cold]
3301 #[inline(never)]
3302 fn make_server_error(&self, fields: proto::ErrorFields) -> DriverError {
3303 DriverError::Server {
3304 code: fields.code,
3305 message: fields.message.into_boxed_str(),
3306 detail: fields.detail.map(String::into_boxed_str),
3307 hint: fields.hint.map(String::into_boxed_str),
3308 position: fields.position,
3309 }
3310 }
3311
3312 #[cold]
3318 #[inline(never)]
3319 fn handle_non_datarow_query(
3320 &mut self,
3321 msg_type: u8,
3322 payload_start: usize,
3323 payload_end: usize,
3324 sql_hash: u64,
3325 affected_rows: &mut u64,
3326 ) -> Result<(), DriverError> {
3327 match msg_type {
3328 b'2' | b'I' => {} b'C' => {
3330 *affected_rows =
3331 proto::parse_command_tag_bytes(&self.stream_buf[payload_start..payload_end]);
3332 }
3333 b'E' => {
3334 let fields =
3335 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
3336 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
3337 self.drain_to_ready()?;
3338 return Err(self.make_server_error(fields));
3339 }
3340 b'A' => {
3341 let msg = proto::parse_backend_message(
3342 msg_type,
3343 &self.stream_buf[payload_start..payload_end],
3344 )?;
3345 if let BackendMessage::NotificationResponse {
3346 pid,
3347 channel,
3348 payload,
3349 } = msg
3350 {
3351 let ch = channel.to_owned();
3352 let pl = payload.to_owned();
3353 self.buffer_notification(pid, &ch, &pl);
3354 }
3355 }
3356 _ => {} }
3358 Ok(())
3359 }
3360
3361 #[cold]
3364 #[inline(never)]
3365 fn handle_non_datarow_execute(
3366 &mut self,
3367 msg_type: u8,
3368 payload_start: usize,
3369 payload_end: usize,
3370 sql_hash: u64,
3371 ) -> Result<(), DriverError> {
3372 match msg_type {
3373 b'2' | b'C' | b'I' => {} b'E' => {
3375 let fields =
3376 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
3377 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
3378 self.drain_to_ready()?;
3379 return Err(self.make_server_error(fields));
3380 }
3381 b'A' => {
3382 let msg = proto::parse_backend_message(
3383 msg_type,
3384 &self.stream_buf[payload_start..payload_end],
3385 )?;
3386 if let BackendMessage::NotificationResponse {
3387 pid,
3388 channel,
3389 payload,
3390 } = msg
3391 {
3392 let ch = channel.to_owned();
3393 let pl = payload.to_owned();
3394 self.buffer_notification(pid, &ch, &pl);
3395 }
3396 }
3397 _ => {} }
3399 Ok(())
3400 }
3401
3402 #[inline(always)]
3409 fn peek_stream_msg(&self) -> Result<Option<(u8, usize, usize, usize)>, DriverError> {
3410 let avail = self.stream_buf_end - self.stream_buf_pos;
3411 if avail < 5 {
3412 return Ok(None);
3413 }
3414
3415 let msg_type = self.stream_buf[self.stream_buf_pos];
3416 let raw_len = i32::from_be_bytes([
3417 self.stream_buf[self.stream_buf_pos + 1],
3418 self.stream_buf[self.stream_buf_pos + 2],
3419 self.stream_buf[self.stream_buf_pos + 3],
3420 self.stream_buf[self.stream_buf_pos + 4],
3421 ]);
3422
3423 if raw_len < 4 {
3424 return Err(DriverError::Protocol(format!(
3425 "invalid message length {raw_len} for type '{}'",
3426 msg_type as char
3427 )));
3428 }
3429
3430 let payload_len = (raw_len - 4) as usize;
3431 let total_msg_len = 5 + payload_len;
3432
3433 if avail < total_msg_len {
3434 return Ok(None);
3435 }
3436
3437 let payload_start = self.stream_buf_pos + 5;
3438 Ok(Some((
3439 msg_type,
3440 payload_start,
3441 payload_start + payload_len,
3442 total_msg_len,
3443 )))
3444 }
3445
    /// Consumes a message from `stream_buf` by moving the read position
    /// forward by `total_msg_len` bytes (the last element of the tuple
    /// returned by `peek_stream_msg`).
    #[inline(always)]
    fn advance_stream_msg(&mut self, total_msg_len: usize) {
        self.stream_buf_pos += total_msg_len;
    }
3451
3452 #[inline]
3454 fn read_one_message(&mut self) -> Result<BackendMessage<'_>, DriverError> {
3455 loop {
3456 let (msg_type, _payload_len) = self.read_message_buffered()?;
3457 if msg_type == b'A' {
3458 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
3459 if let BackendMessage::NotificationResponse {
3460 pid,
3461 channel,
3462 payload,
3463 } = msg
3464 {
3465 let pid_owned = pid;
3466 let channel_owned = channel.to_owned();
3467 let payload_owned = payload.to_owned();
3468 self.buffer_notification(pid_owned, &channel_owned, &payload_owned);
3469 continue;
3470 }
3471 }
3472 return proto::parse_backend_message(msg_type, &self.read_buf);
3473 }
3474 }
3475
3476 fn expect_message(
3477 &mut self,
3478 pred: impl Fn(&BackendMessage<'_>) -> bool,
3479 ) -> Result<(), DriverError> {
3480 loop {
3481 let msg = self.read_one_message()?;
3482 if pred(&msg) {
3483 return Ok(());
3484 }
3485 match msg {
3486 BackendMessage::ErrorResponse { data } => {
3487 let fields = proto::parse_error_response(data);
3488 self.drain_to_ready()?;
3489 return Err(self.make_server_error(fields));
3490 }
3491 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
3492 other => {
3493 return Err(DriverError::Protocol(format!(
3494 "unexpected message while waiting for expected type: {other:?}"
3495 )));
3496 }
3497 }
3498 }
3499 }
3500
3501 fn expect_ready(&mut self) -> Result<(), DriverError> {
3502 loop {
3503 let msg = self.read_one_message()?;
3504 match msg {
3505 BackendMessage::ReadyForQuery { status } => {
3506 self.tx_status = status;
3507 return Ok(());
3508 }
3509 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
3510 BackendMessage::ErrorResponse { data } => {
3511 let fields = proto::parse_error_response(data);
3512 self.drain_to_ready()?;
3513 return Err(self.make_server_error(fields));
3514 }
3515 _ => {}
3516 }
3517 }
3518 }
3519
3520 #[inline]
3521 fn drain_to_ready(&mut self) -> Result<(), DriverError> {
3522 loop {
3523 let msg = self.read_one_message()?;
3524 if let BackendMessage::ReadyForQuery { status } = msg {
3525 self.tx_status = status;
3526 return Ok(());
3527 }
3528 }
3529 }
3530
3531 #[inline]
3535 fn flush_write(&mut self) -> Result<(), DriverError> {
3536 self.stream
3537 .write_all(&self.write_buf)
3538 .map_err(DriverError::Io)
3539 }
3540
3541 fn read_message_buffered(&mut self) -> Result<(u8, usize), DriverError> {
3545 let mut header = [0u8; 5];
3546 sync_buffered_read_exact(
3547 &mut self.stream,
3548 &mut self.stream_buf,
3549 &mut self.stream_buf_pos,
3550 &mut self.stream_buf_end,
3551 &mut header,
3552 )?;
3553
3554 let msg_type = header[0];
3555 let len = i32::from_be_bytes([header[1], header[2], header[3], header[4]]);
3556
3557 if len < 4 {
3558 return Err(DriverError::Protocol(format!(
3559 "invalid message length {len} for type '{}'",
3560 msg_type as char
3561 )));
3562 }
3563
3564 const MAX_MESSAGE_LEN: i32 = 128 * 1024 * 1024;
3565 if len > MAX_MESSAGE_LEN {
3566 return Err(DriverError::Protocol(format!(
3567 "message length {len} exceeds maximum ({MAX_MESSAGE_LEN}) for type '{}'",
3568 msg_type as char
3569 )));
3570 }
3571
3572 let payload_len = (len - 4) as usize;
3573 self.read_buf.clear();
3574 self.read_buf.resize(payload_len, 0);
3575 if payload_len > 0 {
3576 sync_buffered_read_exact(
3577 &mut self.stream,
3578 &mut self.stream_buf,
3579 &mut self.stream_buf_pos,
3580 &mut self.stream_buf_end,
3581 &mut self.read_buf[..payload_len],
3582 )?;
3583 }
3584
3585 Ok((msg_type, payload_len))
3586 }
3587
3588 #[inline]
3590 fn refill_stream_buf(&mut self) -> Result<(), DriverError> {
3591 let remaining = self.stream_buf_end - self.stream_buf_pos;
3592 if remaining > 0 && self.stream_buf_pos > 0 {
3593 self.stream_buf
3594 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
3595 }
3596 self.stream_buf_pos = 0;
3597 self.stream_buf_end = remaining;
3598
3599 let n = self
3600 .stream
3601 .read(&mut self.stream_buf[remaining..])
3602 .map_err(DriverError::Io)?;
3603 if n == 0 {
3604 return Err(DriverError::Io(std::io::Error::new(
3605 std::io::ErrorKind::UnexpectedEof,
3606 "connection closed",
3607 )));
3608 }
3609 self.stream_buf_end = remaining + n;
3610 Ok(())
3611 }
3612}
3613
3614fn sync_buffered_read_exact(
3617 stream: &mut Stream,
3618 buf: &mut [u8],
3619 pos: &mut usize,
3620 end: &mut usize,
3621 out: &mut [u8],
3622) -> Result<(), DriverError> {
3623 let mut filled = 0;
3624 while filled < out.len() {
3625 let avail = *end - *pos;
3626 if avail > 0 {
3627 let take = avail.min(out.len() - filled);
3628 out[filled..filled + take].copy_from_slice(&buf[*pos..*pos + take]);
3629 *pos += take;
3630 filled += take;
3631 } else {
3632 *pos = 0;
3633 let n = stream.read(buf).map_err(DriverError::Io)?;
3634 if n == 0 {
3635 return Err(DriverError::Io(std::io::Error::new(
3636 std::io::ErrorKind::UnexpectedEof,
3637 "connection closed",
3638 )));
3639 }
3640 *end = n;
3641 }
3642 }
3643 Ok(())
3644}
3645
3646#[inline(always)]
3656pub(crate) fn parse_data_row_into_buf(
3657 data: &[u8],
3658 buf: &mut Vec<u8>,
3659 out: &mut Vec<(usize, i32)>,
3660) -> Result<(), DriverError> {
3661 if data.len() < 2 {
3662 return Err(DriverError::Protocol("DataRow too short".into()));
3663 }
3664
3665 let num_cols = i16::from_be_bytes([data[0], data[1]]);
3666 if num_cols < 0 {
3667 return Err(DriverError::Protocol(
3668 "DataRow: negative column count".into(),
3669 ));
3670 }
3671 let num_cols = num_cols as usize;
3672
3673 let col_data = &data[2..];
3681 let base = buf.len();
3682 buf.extend_from_slice(col_data);
3683
3684 let mut pos: usize = 0;
3686 for _ in 0..num_cols {
3687 if pos + 4 > col_data.len() {
3688 return Err(DriverError::Protocol("DataRow truncated".into()));
3689 }
3690
3691 let col_len = i32::from_be_bytes([
3692 col_data[pos],
3693 col_data[pos + 1],
3694 col_data[pos + 2],
3695 col_data[pos + 3],
3696 ]);
3697 pos += 4;
3698
3699 if col_len < 0 {
3700 out.push((0, -1));
3701 } else {
3702 let len = col_len as usize;
3703 if pos + len > col_data.len() {
3704 return Err(DriverError::Protocol(
3705 "DataRow column data truncated".into(),
3706 ));
3707 }
3708 out.push((base + pos, col_len));
3710 pos += len;
3711 }
3712 }
3713
3714 Ok(())
3715}
3716
3717fn parse_data_row_flat(
3721 data: &[u8],
3722 arena: &mut Arena,
3723 out: &mut Vec<(usize, i32)>,
3724) -> Result<(), DriverError> {
3725 if data.len() < 2 {
3726 return Err(DriverError::Protocol("DataRow too short".into()));
3727 }
3728
3729 let num_cols_raw = i16::from_be_bytes([data[0], data[1]]);
3730 if num_cols_raw < 0 {
3731 return Err(DriverError::Protocol(
3732 "DataRow: negative column count".into(),
3733 ));
3734 }
3735 let num_cols = num_cols_raw as usize;
3736 out.reserve(num_cols);
3737
3738 let col_data = &data[2..];
3741 let base = arena.alloc_copy(col_data);
3742
3743 let mut pos: usize = 0;
3745 for _ in 0..num_cols {
3746 if pos + 4 > col_data.len() {
3747 return Err(DriverError::Protocol("DataRow truncated".into()));
3748 }
3749
3750 let col_len = i32::from_be_bytes([
3751 col_data[pos],
3752 col_data[pos + 1],
3753 col_data[pos + 2],
3754 col_data[pos + 3],
3755 ]);
3756 pos += 4;
3757
3758 if col_len < 0 {
3759 out.push((0, -1));
3760 } else {
3761 let len = col_len as usize;
3762 if pos + len > col_data.len() {
3763 return Err(DriverError::Protocol(
3764 "DataRow column data truncated".into(),
3765 ));
3766 }
3767 out.push((base + pos, col_len));
3769 pos += len;
3770 }
3771 }
3772
3773 Ok(())
3774}
3775
3776#[cfg(test)]
3777#[allow(clippy::approx_constant)]
3778mod tests {
3779 use super::*;
3780 use crate::types::hash_sql;
3781
3782 #[test]
3783 fn sync_config_tcp_no_longer_rejected() {
3784 let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3787 let result = Connection::connect(&config);
3788 assert!(result.is_err());
3789 let err = result.unwrap_err().to_string();
3790 assert!(
3793 !err.contains("Unix domain socket"),
3794 "error should NOT mention UDS requirement: {err}"
3795 );
3796 }
3797
3798 #[test]
3799 fn sync_data_row_parsing() {
3800 let mut arena = Arena::new();
3801 let mut out = Vec::new();
3802
3803 let mut data = Vec::new();
3804 data.extend_from_slice(&2i16.to_be_bytes());
3805 data.extend_from_slice(&4i32.to_be_bytes());
3806 data.extend_from_slice(&42i32.to_be_bytes());
3807 data.extend_from_slice(&(-1i32).to_be_bytes());
3808
3809 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3810 assert_eq!(out.len(), 2);
3811 assert_eq!(out[0].1, 4);
3812 assert_eq!(out[1].1, -1);
3813 }
3814
3815 #[test]
3816 fn sync_data_row_empty() {
3817 let mut arena = Arena::new();
3818 let mut out = Vec::new();
3819 let data = 0i16.to_be_bytes();
3820 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3821 assert_eq!(out.len(), 0);
3822 }
3823
3824 #[test]
3825 fn sync_data_row_too_short() {
3826 let mut arena = Arena::new();
3827 let mut out = Vec::new();
3828 let data = vec![0u8];
3829 assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3830 }
3831
3832 #[test]
3833 fn sync_data_row_negative_col_count() {
3834 let mut arena = Arena::new();
3835 let mut out = Vec::new();
3836 let data = (-1i16).to_be_bytes();
3837 assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3838 }
3839
3840 #[test]
3841 fn sync_data_row_truncated() {
3842 let mut arena = Arena::new();
3843 let mut out = Vec::new();
3844 let mut data = Vec::new();
3845 data.extend_from_slice(&2i16.to_be_bytes());
3846 data.extend_from_slice(&4i32.to_be_bytes());
3847 data.extend_from_slice(&42i32.to_be_bytes());
3848 assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3850 }
3851
3852 #[test]
3853 fn sync_data_row_col_data_truncated() {
3854 let mut arena = Arena::new();
3855 let mut out = Vec::new();
3856 let mut data = Vec::new();
3857 data.extend_from_slice(&1i16.to_be_bytes());
3858 data.extend_from_slice(&100i32.to_be_bytes()); data.push(0); assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3861 }
3862
3863 #[test]
3866 fn sync_connect_tcp_unreachable_port() {
3867 let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3870 let result = Connection::connect(&config);
3871 assert!(result.is_err());
3872 let err = result.unwrap_err().to_string();
3873 assert!(
3874 !err.contains("Unix domain socket"),
3875 "error should NOT mention UDS: {err}"
3876 );
3877 }
3878
3879 #[test]
3880 fn sync_connect_ip_address_attempts_tcp() {
3881 let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3884 let result = Connection::connect(&config);
3885 assert!(result.is_err());
3886 }
3887
3888 #[test]
3891 fn sync_data_row_all_null() {
3892 let mut arena = Arena::new();
3893 let mut out = Vec::new();
3894 let mut data = Vec::new();
3895 data.extend_from_slice(&3i16.to_be_bytes());
3896 data.extend_from_slice(&(-1i32).to_be_bytes());
3897 data.extend_from_slice(&(-1i32).to_be_bytes());
3898 data.extend_from_slice(&(-1i32).to_be_bytes());
3899 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3900 assert_eq!(out.len(), 3);
3901 for (_, len) in &out {
3902 assert_eq!(*len, -1);
3903 }
3904 }
3905
// A single 2048-byte text column must round-trip through the arena unchanged.
#[test]
fn sync_data_row_long_text() {
    let text = "a".repeat(2048);
    let text_bytes = text.as_bytes();

    // i16 column count, i32 length, then the column bytes.
    let mut payload = Vec::with_capacity(2 + 4 + text_bytes.len());
    payload.extend_from_slice(&1i16.to_be_bytes());
    payload.extend_from_slice(&(text_bytes.len() as i32).to_be_bytes());
    payload.extend_from_slice(text_bytes);

    let mut arena = Arena::new();
    let mut out = Vec::new();
    parse_data_row_flat(&payload, &mut arena, &mut out).unwrap();
    assert_eq!(out.len(), 1);

    let (offset, len) = out[0];
    assert_eq!(len, text_bytes.len() as i32);
    let stored = arena.get(offset, len as usize);
    assert_eq!(stored, text_bytes);
}
3922
// A zero-length column is still a column: one entry with recorded length 0.
#[test]
fn sync_data_row_empty_text() {
    let payload = [&1i16.to_be_bytes()[..], &0i32.to_be_bytes()[..]].concat();

    let mut arena = Arena::new();
    let mut out = Vec::new();
    parse_data_row_flat(&payload, &mut arena, &mut out).unwrap();

    assert_eq!(out.len(), 1);
    let (_, len) = out[0];
    assert_eq!(len, 0);
}
3934
// Parses a row of 20 four-byte columns (column i holds the big-endian i32 `i`)
// and checks every column's length and stored value.
// NOTE(review): the name says 17 columns / smallvec, the body uses 20 — the
// inline-capacity it is meant to exceed is not visible here; confirm.
#[test]
fn sync_data_row_17_columns_exceeds_smallvec() {
    let num_cols: i16 = 20;

    let mut payload = Vec::new();
    payload.extend_from_slice(&num_cols.to_be_bytes());
    for value in (0..num_cols).map(i32::from) {
        payload.extend_from_slice(&4i32.to_be_bytes());
        payload.extend_from_slice(&value.to_be_bytes());
    }

    let mut arena = Arena::new();
    let mut out = Vec::new();
    parse_data_row_flat(&payload, &mut arena, &mut out).unwrap();
    assert_eq!(out.len(), 20);

    // Pair each parsed column with the value that was written for it.
    for (expected, &(offset, len)) in (0i32..).zip(out.iter()) {
        assert_eq!(len, 4);
        let stored = arena.get(offset, 4);
        let decoded = i32::from_be_bytes([stored[0], stored[1], stored[2], stored[3]]);
        assert_eq!(decoded, expected);
    }
}
3956
// Five columns: NULL, 4-byte int 42, NULL, NULL, "hello". Checks the recorded
// length of every column and the stored bytes of the last one.
#[test]
fn sync_data_row_mixed_null_and_data() {
    let answer = 42i32.to_be_bytes();
    let columns: [Option<&[u8]>; 5] = [None, Some(&answer[..]), None, None, Some(&b"hello"[..])];

    let mut payload: Vec<u8> = Vec::new();
    payload.extend_from_slice(&5i16.to_be_bytes());
    for column in columns {
        match column {
            // NULL columns carry only the -1 length marker.
            None => payload.extend_from_slice(&(-1i32).to_be_bytes()),
            Some(bytes) => {
                payload.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
                payload.extend_from_slice(bytes);
            }
        }
    }

    let mut arena = Arena::new();
    let mut out = Vec::new();
    parse_data_row_flat(&payload, &mut arena, &mut out).unwrap();
    assert_eq!(out.len(), 5);

    let expected_lens = [-1, 4, -1, -1, 5];
    for (&(_, len), want) in out.iter().zip(expected_lens) {
        assert_eq!(len, want);
    }

    let stored = arena.get(out[4].0, 5);
    assert_eq!(stored, b"hello");
}
3986
// Ignored by default. Tries to connect using a URL with `host=/tmp`; if that
// succeeds, verifies the state of a fresh connection. A failed connect is
// tolerated silently (the test then asserts nothing).
#[test]
#[ignore]
fn sync_connect_uds_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    if let Ok(conn) = Connection::connect(&config) {
        assert!(conn.pid() != 0, "pid should be nonzero");
        assert!(conn.is_idle(), "should start idle");
        assert!(!conn.is_in_transaction(), "should not be in tx");
        assert!(
            !conn.is_in_failed_transaction(),
            "should not be in failed tx"
        );
        assert_eq!(conn.stmt_cache_len(), 0, "cache should be empty");
        let _ = conn.close();
    }
}
4007
// Ignored by default (needs a live server). After a simple `SELECT 1` the
// connection must report idle.
#[test]
#[ignore]
fn sync_simple_query_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    let outcome = conn.simple_query("SELECT 1");
    outcome.unwrap();
    assert!(conn.is_idle());

    let _ = conn.close();
}
4017
// Ignored by default (needs a live server). Runs a two-parameter query and
// checks that exactly one row comes back.
#[test]
#[ignore]
fn sync_query_with_params_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    let sql = "SELECT $1::int4 + $2::int4 AS sum";
    let (lhs, rhs): (i32, i32) = (10, 20);
    let params: [&(dyn Encode + Sync); 2] = [&lhs, &rhs];

    let rows = conn.query(sql, hash_sql(sql), &params).unwrap();
    assert_eq!(rows.len(), 1);

    let _ = conn.close();
}
4037
// Ignored by default (needs a live server). Inserts one row into a temp table
// through `execute` and expects an affected-row count of 1.
#[test]
#[ignore]
fn sync_execute_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    let setup = conn.simple_query("CREATE TEMP TABLE _sync_test (id int)");
    setup.unwrap();

    let sql = "INSERT INTO _sync_test VALUES ($1::int4)";
    let val: i32 = 42;
    let params: [&(dyn Encode + Sync); 1] = [&val];
    let affected = conn.execute(sql, hash_sql(sql), &params).unwrap();
    assert_eq!(affected, 1);

    let _ = conn.close();
}
4054
// Ignored by default (needs a live server). Selecting from an empty temp table
// must never invoke the `for_each` callback.
#[test]
#[ignore]
fn sync_for_each_zero_rows_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();
    conn.simple_query("CREATE TEMP TABLE _sync_fe0 (id int)")
        .unwrap();

    let sql = "SELECT id FROM _sync_fe0";
    let mut callbacks = 0u32;
    let outcome = conn.for_each(sql, hash_sql(sql), &[], |_each| {
        callbacks += 1;
        Ok(())
    });
    outcome.unwrap();
    assert_eq!(callbacks, 0);

    let _ = conn.close();
}
4073
// Ignored by default (needs a live server). `generate_series(1, 5)` must drive
// the `for_each` callback exactly five times.
#[test]
#[ignore]
fn sync_for_each_multiple_rows_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    let sql = "SELECT generate_series(1, 5)";
    let mut callbacks = 0u32;
    let outcome = conn.for_each(sql, hash_sql(sql), &[], |_each| {
        callbacks += 1;
        Ok(())
    });
    outcome.unwrap();
    assert_eq!(callbacks, 5);

    let _ = conn.close();
}
4090
// Ignored by default (needs a live server). Preparing the same statement twice
// must leave exactly one entry in the statement cache both times.
#[test]
#[ignore]
fn sync_prepare_only_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    let sql = "SELECT 1";
    let hash = hash_sql(sql);
    // First pass populates the cache; the second pass must not add an entry.
    for _ in 0..2 {
        conn.prepare_only(sql, hash).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
    }

    let _ = conn.close();
}
4105
// Ignored by default (needs a live server). `simple_query_rows` must return a
// non-empty result for a one-row SELECT.
#[test]
#[ignore]
fn sync_simple_query_rows_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    let fetched = conn.simple_query_rows("SELECT 42 AS n").unwrap();
    let got_nothing = fetched.is_empty();
    assert!(!got_nothing);

    let _ = conn.close();
}
4115
// Ignored by default (needs a live server). Re-running the same SQL keeps the
// statement cache at one entry; a different SQL text grows it to two.
#[test]
#[ignore]
fn sync_stmt_cache_hit_miss_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    // Same statement twice: the cache length must stay at 1 after each run.
    let first = "SELECT 1";
    let first_hash = hash_sql(first);
    for _ in 0..2 {
        conn.query(first, first_hash, &[]).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
    }

    // A new statement adds a second entry.
    let second = "SELECT 2";
    conn.query(second, hash_sql(second), &[]).unwrap();
    assert_eq!(conn.stmt_cache_len(), 2);

    let _ = conn.close();
}
4135
// Ignored by default (needs a live server). Unparseable SQL must yield an
// error, and the connection must report idle afterwards.
#[test]
#[ignore]
fn sync_invalid_sql_error_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    let sql = "SELECTTTT INVALID GARBAGE";
    let outcome = conn.query(sql, hash_sql(sql), &[]);
    assert!(outcome.is_err());
    assert!(conn.is_idle());

    let _ = conn.close();
}
4149
// Ignored by default (needs a live server). Walks idle -> BEGIN -> COMMIT and
// checks `is_idle` / `is_in_transaction` at each step.
#[test]
#[ignore]
fn sync_tx_state_transitions_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    // Fresh connection: idle, no transaction.
    assert!(conn.is_idle());
    assert!(!conn.is_in_transaction());

    // Inside a transaction block.
    let begun = conn.simple_query("BEGIN");
    begun.unwrap();
    assert!(conn.is_in_transaction());
    assert!(!conn.is_idle());

    // Back to idle after commit.
    let committed = conn.simple_query("COMMIT");
    committed.unwrap();
    assert!(conn.is_idle());
    assert!(!conn.is_in_transaction());

    let _ = conn.close();
}
4165
// Ignored by default (needs a live server). With the statement cache limited
// to 3, running five distinct statements must not leave more than 3 cached.
#[test]
#[ignore]
fn sync_lru_cache_eviction_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();
    conn.set_max_stmt_cache_size(3);

    for sql in (0..5).map(|i| format!("SELECT {}", i)) {
        let hash = hash_sql(&sql);
        conn.query(&sql, hash, &[]).unwrap();
    }

    let within_cap = conn.stmt_cache_len() <= 3;
    assert!(
        within_cap,
        "cache should be capped at 3, got {}",
        conn.stmt_cache_len()
    );

    let _ = conn.close();
}
4185
// Ignored by default (needs a live server). `for_each_raw` must call back once
// per row of `generate_series(1, 3)`.
#[test]
#[ignore]
fn sync_for_each_raw_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    let sql = "SELECT generate_series(1, 3)";
    let mut callbacks = 0u32;
    let outcome = conn.for_each_raw(sql, hash_sql(sql), &[], |_bytes| {
        callbacks += 1;
        Ok(())
    });
    outcome.unwrap();
    assert_eq!(callbacks, 3);

    let _ = conn.close();
}
4202
// Ignored by default (needs a live server). Binds a `None` parameter and only
// requires that the query succeeds; the returned rows are not inspected.
#[test]
#[ignore]
fn sync_query_null_params_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    let sql = "SELECT $1::int4 IS NULL AS is_null";
    let absent: Option<i32> = None;
    let params: [&(dyn Encode + Sync); 1] = [&absent];
    let _rows = conn.query(sql, hash_sql(sql), &params).unwrap();

    let _ = conn.close();
}
4216
// Ignored by default (needs a live server). Sends one parameter of each of
// i32, i64, &str, bool and f64 in a single query and expects one row back.
#[test]
#[ignore]
fn sync_query_various_param_types_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    let sql = "SELECT $1::int4, $2::int8, $3::text, $4::bool, $5::float8";
    let hash = hash_sql(sql);

    let p1: i32 = 42;
    let p2: i64 = 9999999;
    let p3: &str = "hello";
    let p4: bool = true;
    // Any finite float works here (only the row count is asserted). The value
    // deliberately avoids `3.14`, which Clippy's `approx_constant` lint flags
    // as a sloppy spelling of PI.
    let p5: f64 = 2.5;

    let params: [&(dyn Encode + Sync); 5] = [&p1, &p2, &p3, &p4, &p5];
    let result = conn.query(sql, hash, &params).unwrap();
    assert_eq!(result.len(), 1);

    let _ = conn.close();
}
4245
// Sanity check on two buffer sizes: the shrink threshold (64 KiB) must be
// larger than the initial size (8192 bytes).
// NOTE(review): both numbers are local literals, not read from the connection
// code — presumably they mirror constants defined elsewhere; keep in sync.
#[test]
fn sync_shrink_threshold_values() {
    let (initial, shrink): (usize, usize) = (8192, 64 * 1024);
    assert!(
        shrink > initial,
        "shrink threshold must exceed initial size"
    );
}
4265
// Renders a hand-written string in the shape
// `Connection { pid: .., tx_status: '..', stmt_cache_len: .. }` and checks
// that the expected labels appear in it.
// NOTE(review): this formats a literal rather than a real `Connection`, so it
// does not exercise any `Debug` implementation — confirm that is intended.
#[test]
fn sync_connection_debug_format() {
    let rendered = format!(
        "Connection {{ pid: {}, tx_status: '{}', stmt_cache_len: {} }}",
        0, 'I', 0
    );
    for label in ["Connection", "pid", "tx_status"] {
        assert!(rendered.contains(label));
    }
}
4281
// With `SslMode::Require` and a target of 127.0.0.1:1, connecting must fail.
// The test does not distinguish which step produces the error.
#[test]
fn sync_connect_sslmode_require_without_tls_feature() {
    let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
    config.ssl = SslMode::Require;

    let attempt = Connection::connect(&config);
    assert!(attempt.is_err());
}
4298
// With `SslMode::Disable` and a target of 127.0.0.1:1, connecting must fail
// with the `DriverError::Io` variant.
#[test]
fn sync_connect_sslmode_disable_attempts_tcp() {
    let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
    config.ssl = SslMode::Disable;

    let attempt = Connection::connect(&config);
    assert!(attempt.is_err());
    let is_io_failure = matches!(attempt, Err(DriverError::Io(_)));
    assert!(is_io_failure);
}
4308
// With `SslMode::Prefer` and a target of 127.0.0.1:1, connecting must fail.
#[test]
fn sync_connect_sslmode_prefer_attempts_tcp() {
    let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
    config.ssl = SslMode::Prefer;

    let attempt = Connection::connect(&config);
    assert!(attempt.is_err());
}
4316
// Ignored by default (needs a live server). Streams `generate_series(1, 10)`
// chunk by chunk, summing `offsets.len()` after every chunk, and expects a
// total of 10 with the connection leaving streaming mode at the end.
// NOTE(review): the `3` passed to the streaming calls is presumably the
// per-chunk row limit — confirm against their signatures.
#[test]
#[ignore]
fn sync_streaming_basic_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();
    assert!(!conn.is_streaming());

    let sql = "SELECT generate_series(1, 10)";
    let hash = hash_sql(sql);
    let started = conn.query_streaming_start(sql, hash, &[], 3);
    let (cols, _) = started.unwrap();
    assert!(!cols.is_empty());
    assert!(conn.is_streaming());

    let mut arena = Arena::new();
    let mut offsets = Vec::new();
    let mut total_rows = 0;
    loop {
        let more_pending = conn.streaming_next_chunk(&mut arena, &mut offsets).unwrap();
        total_rows += offsets.len();
        if more_pending {
            // Ask for the next chunk before reading again.
            conn.streaming_send_execute(3).unwrap();
        } else {
            break;
        }
    }

    assert_eq!(total_rows, 10);
    assert!(!conn.is_streaming());
    let _ = conn.close();
}
4351
// Ignored by default (needs a live server). Describing a two-parameter SELECT
// must report one output column named "sum" and two parameter OIDs.
#[test]
#[ignore]
fn sync_prepare_describe_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    let sql = "SELECT $1::int4 + $2::int4 AS sum";
    let described = conn.prepare_describe(sql).unwrap();

    assert_eq!(described.columns.len(), 1);
    let first_column = &described.columns[0];
    assert_eq!(&*first_column.name, "sum");
    assert_eq!(described.param_oids.len(), 2);

    let _ = conn.close();
}
4368
// Ignored by default (needs a live server). The connection LISTENs on a
// channel, NOTIFYs itself, then waits (read timeout 5 s) for the notification
// and checks its channel and payload.
#[test]
#[ignore]
fn sync_wait_for_notification_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();

    for statement in ["LISTEN test_chan", "NOTIFY test_chan, 'hello'"] {
        conn.simple_query(statement).unwrap();
    }

    let wait_limit = std::time::Duration::from_secs(5);
    conn.set_read_timeout(Some(wait_limit)).unwrap();

    let received = conn.wait_for_notification().unwrap();
    assert_eq!(received.0, "test_chan");
    assert_eq!(received.1, "hello");

    let _ = conn.close();
}
4389
// Ignored by default (needs a live server). Calls `cancel` on an open
// connection; the outcome is deliberately not checked, so this only verifies
// that the call returns without panicking.
#[test]
#[ignore]
fn sync_cancel_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let conn = Connection::connect(&config).unwrap();

    // Either outcome is accepted.
    let _cancel_outcome = conn.cancel();

    let _ = conn.close();
}
4404
// Ignored by default (needs a live server). After connecting, the server
// parameter list must be non-empty and contain `server_encoding`.
#[test]
#[ignore]
fn sync_server_params_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let conn = Connection::connect(&config).unwrap();

    let has_params = !conn.server_params().is_empty();
    assert!(
        has_params,
        "server should send parameters during startup"
    );

    let encoding = conn.parameter("server_encoding");
    assert!(encoding.is_some(), "server_encoding should be present");

    let _ = conn.close();
}
4424
// Ignored by default (needs a live server). Setting a 10 s read timeout and
// then clearing it again must both succeed.
#[test]
#[ignore]
fn sync_set_read_timeout_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let conn = Connection::connect(&config).unwrap();

    let ten_seconds = std::time::Duration::from_secs(10);
    for timeout in [Some(ten_seconds), None] {
        conn.set_read_timeout(timeout).unwrap();
    }

    let _ = conn.close();
}
4438}