1use std::io::{Read, Write};
15use std::sync::Arc;
16
17use crate::arena::Arena;
18use crate::auth;
19use crate::codec::Encode;
20use crate::proto::{self, BackendMessage};
21use crate::stmt_cache::{build_bind_template, make_stmt_name, StmtCache, StmtInfo};
22use crate::sync_io::Stream;
23use crate::types::{
24 ColumnDesc, Config, Notification, PgDataRow, PrepareResult, QueryResult, SimpleRow, SslMode,
25 StartupAction, StatementCacheMode,
26};
27use crate::DriverError;
28
29use std::cell::RefCell;
37
thread_local! {
    /// Per-thread free list of response byte buffers, reused across queries so row storage
    /// does not have to be reallocated for every result.
    static RESP_BUF_POOL: RefCell<Vec<Vec<u8>>> = const { RefCell::new(Vec::new()) };
}

/// Maximum number of spare response buffers retained per thread.
const RESP_BUF_POOL_CAP: usize = 4;

/// Takes a response buffer from this thread's pool, or a new empty `Vec` if the pool is empty.
///
/// The returned buffer is always empty (`len() == 0`), but it may carry capacity from an
/// earlier use.
pub(crate) fn acquire_resp_buf() -> Vec<u8> {
    RESP_BUF_POOL
        .with(|pool| pool.borrow_mut().pop())
        .unwrap_or_default()
}

/// Returns a response buffer to this thread's pool for reuse.
///
/// The buffer is cleared first so stale row bytes are never handed to the next acquirer.
/// Its capacity is kept. When the pool already holds `RESP_BUF_POOL_CAP` buffers, `buf` is
/// dropped instead.
pub fn release_resp_buf(mut buf: Vec<u8>) {
    buf.clear();
    RESP_BUF_POOL.with(|pool| {
        let mut pool = pool.borrow_mut();
        if pool.len() < RESP_BUF_POOL_CAP {
            pool.push(buf);
        }
    });
}
57
thread_local! {
    /// Per-thread free list of column-offset vectors used while decoding DataRow messages.
    static COL_OFFSETS_POOL: RefCell<Vec<Vec<(usize, i32)>>> = const { RefCell::new(Vec::new()) };
}

/// Maximum number of spare column-offset vectors retained per thread.
const COL_OFFSETS_POOL_CAP: usize = 4;

/// Takes a column-offset vector from this thread's pool, or a new empty `Vec` if none is
/// pooled.
///
/// The returned vector is always empty, but it may carry capacity from an earlier use.
/// Entries are presumably `(offset into the response buffer, value length)`; confirm against
/// `parse_data_row_into_buf`.
pub(crate) fn acquire_col_offsets() -> Vec<(usize, i32)> {
    COL_OFFSETS_POOL
        .with(|pool| pool.borrow_mut().pop())
        .unwrap_or_default()
}

/// Returns a column-offset vector to this thread's pool for reuse.
///
/// The vector is cleared first so no stale offsets survive into the next result. Its
/// capacity is kept. When the pool already holds `COL_OFFSETS_POOL_CAP` vectors, `buf` is
/// dropped.
pub fn release_col_offsets(mut buf: Vec<(usize, i32)>) {
    buf.clear();
    COL_OFFSETS_POOL.with(|pool| {
        let mut pool = pool.borrow_mut();
        if pool.len() < COL_OFFSETS_POOL_CAP {
            pool.push(buf);
        }
    });
}
76
/// A single synchronous PostgreSQL connection speaking the frontend/backend protocol.
///
/// Incoming bytes are read into `stream_buf` and parsed in place where possible. Outgoing
/// messages are assembled in `write_buf` and written to `stream` in one call.
pub struct Connection {
    /// Offset in `stream_buf` of the first received byte not yet consumed.
    stream_buf_pos: usize,
    /// One past the last valid byte in `stream_buf`. Bytes in
    /// `stream_buf_pos..stream_buf_end` have been received but not yet parsed.
    stream_buf_end: usize,
    /// Counter bumped when a statement is prepared or reused. It is stored as
    /// `StmtInfo::last_used`, and `execute_monolithic` uses it to shrink buffers every 64 uses.
    query_counter: u64,
    /// Transaction-status byte from the most recent `ReadyForQuery` (`b'I'` until the first one).
    tx_status: u8,
    /// NOTE(review): only initialized (`false`) in the visible code; presumably marks an
    /// in-progress streaming result. TODO confirm.
    streaming_active: bool,
    /// Backend process id from `BackendKeyData`.
    pid: i32,
    /// Secret key from `BackendKeyData`.
    secret: i32,
    /// Statement-cache size limit (256 at connect time); presumably enforced by `cache_stmt`.
    max_stmt_cache_size: usize,
    /// Copied from `Config`. `Disabled` routes executes through the unnamed statement.
    statement_cache_mode: StatementCacheMode,
    /// Underlying transport: Unix socket, plain TCP, or TLS.
    stream: Stream,
    /// Scratch buffer for outgoing frontend messages.
    write_buf: Vec<u8>,
    /// Receive buffer (64 KiB at connect time).
    stream_buf: Vec<u8>,
    /// Prepared statements, looked up by SQL hash plus SQL text.
    stmts: StmtCache,
    /// Payload of the last message read by `read_message_buffered`.
    read_buf: Vec<u8>,
    /// Server parameters reported via `ParameterStatus`, as `(name, value)` pairs.
    params: Vec<(Box<str>, Box<str>)>,
    /// Set to the connect instant; presumably refreshed on use for idle tracking. TODO confirm.
    last_used: std::time::Instant,
    /// Instant at which this connection was opened.
    created_at: std::time::Instant,
    /// Notifications received alongside other replies; presumably filled by
    /// `buffer_notification`. TODO confirm.
    pending_notifications: Vec<Notification>,
    /// Configuration this connection was opened with.
    connect_config: Arc<Config>,
    /// Server certificate hash captured during the TLS upgrade, if any. It enables
    /// SCRAM-SHA-256-PLUS channel binding in `handle_scram`.
    tls_server_cert_hash: Option<[u8; 32]>,
}
134
135impl std::fmt::Debug for Connection {
136 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137 f.debug_struct("Connection")
138 .field("pid", &self.pid)
139 .field("tx_status", &(self.tx_status as char))
140 .field("stmt_cache_len", &self.stmts.len())
141 .finish()
142 }
143}
144
145impl Connection {
146 pub fn connect(config: &Config) -> Result<Self, DriverError> {
158 Self::connect_arc(Arc::new(config.clone()))
159 }
160
161 pub fn connect_arc(config: Arc<Config>) -> Result<Self, DriverError> {
166 config.validate()?;
167
168 #[allow(unused_mut)]
171 let mut tls_cert_hash: Option<[u8; 32]> = None;
172
173 let stream = if config.host_is_uds() {
174 #[cfg(unix)]
176 {
177 let path = config.uds_path();
178 let unix =
179 std::os::unix::net::UnixStream::connect(&path).map_err(DriverError::Io)?;
180 Stream::Unix(unix)
181 }
182 #[cfg(not(unix))]
183 {
184 return Err(DriverError::Protocol(
185 "Unix domain sockets are not supported on this platform".into(),
186 ));
187 }
188 } else {
189 let addr = format!("{}:{}", config.host, config.port);
191 let tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
192
193 match config.ssl {
194 SslMode::Disable => {
195 tcp.set_nodelay(true).map_err(DriverError::Io)?;
196 let stream = Stream::Tcp(tcp);
197 stream.set_keepalive()?;
198 stream
199 }
200 SslMode::Prefer | SslMode::Require => {
201 #[cfg(feature = "tls")]
202 {
203 match crate::tls_sync::try_upgrade(
204 tcp,
205 &config,
206 config.ssl == SslMode::Require,
207 ) {
208 Ok(result) => {
209 tls_cert_hash = result.server_cert_hash;
210 let stream = Stream::Tls(Box::new(result.stream));
211 stream.set_nodelay()?;
212 stream.set_keepalive()?;
213 stream
214 }
215 Err(e) => {
216 if config.ssl == SslMode::Require {
217 return Err(e);
218 }
219 let tcp =
221 std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
222 tcp.set_nodelay(true).map_err(DriverError::Io)?;
223 let stream = Stream::Tcp(tcp);
224 stream.set_keepalive()?;
225 stream
226 }
227 }
228 }
229 #[cfg(not(feature = "tls"))]
230 {
231 if config.ssl == SslMode::Require {
232 return Err(DriverError::Protocol(
233 "sslmode=require but bsql was compiled without the 'tls' feature"
234 .into(),
235 ));
236 }
237 tcp.set_nodelay(true).map_err(DriverError::Io)?;
238 let stream = Stream::Tcp(tcp);
239 stream.set_keepalive()?;
240 stream
241 }
242 }
243 }
244 };
245
246 let now = std::time::Instant::now();
247 let mut conn = Self {
248 stream_buf_pos: 0,
250 stream_buf_end: 0,
251 query_counter: 0,
252 tx_status: b'I',
253 streaming_active: false,
254 pid: 0,
255 secret: 0,
256 max_stmt_cache_size: 256,
257 statement_cache_mode: config.statement_cache_mode,
258 stream,
260 write_buf: Vec::with_capacity(4096),
261 stream_buf: vec![0u8; 65536],
262 stmts: StmtCache::default(),
263 read_buf: Vec::with_capacity(8192),
265 params: Vec::new(),
266 last_used: now,
267 created_at: now,
268 pending_notifications: Vec::new(),
269 connect_config: config.clone(),
270 tls_server_cert_hash: tls_cert_hash,
271 };
272
273 conn.startup(&config)?;
274 conn.validate_server_params()?;
275
276 Ok(conn)
277 }
278
279 fn startup(&mut self, config: &Config) -> Result<(), DriverError> {
282 self.write_buf.clear();
283 let timeout_str; let mut extra_params: smallvec::SmallVec<[(&str, &str); 2]> = smallvec::SmallVec::new();
287 if config.statement_timeout_secs > 0 {
288 timeout_str = format!("{}s", config.statement_timeout_secs);
289 extra_params.push(("statement_timeout", &timeout_str));
290 }
291 proto::write_startup(
292 &mut self.write_buf,
293 &config.user,
294 &config.database,
295 &extra_params,
296 );
297 self.flush_write()?;
298
299 loop {
300 let action = self.read_startup_action()?;
301 match action {
302 StartupAction::AuthOk => {}
303 StartupAction::AuthCleartext => {
304 self.write_buf.clear();
305 let mut pw = config.password.as_bytes().to_vec();
306 pw.push(0);
307 proto::write_password(&mut self.write_buf, &pw);
308 self.flush_write()?;
309 }
310 StartupAction::AuthMd5(salt) => {
311 self.write_buf.clear();
312 let hash = auth::md5_password(&config.user, &config.password, &salt);
313 proto::write_password(&mut self.write_buf, &hash);
314 self.flush_write()?;
315 }
316 StartupAction::AuthSasl(mechanisms_data) => {
317 self.handle_scram(config, &mechanisms_data)?;
318 }
319 StartupAction::ParameterStatus(name, value) => {
320 if let Some(entry) = self.params.iter_mut().find(|(k, _)| *k == name) {
321 entry.1 = value;
322 } else {
323 self.params.push((name, value));
324 }
325 }
326 StartupAction::BackendKeyData(pid, secret) => {
327 self.pid = pid;
328 self.secret = secret;
329 }
330 StartupAction::ReadyForQuery(status) => {
331 self.tx_status = status;
332 return Ok(());
333 }
334 StartupAction::Error(msg) => {
335 return Err(DriverError::Auth(msg));
336 }
337 StartupAction::Notice => {}
338 }
339 }
340 }
341
342 fn read_startup_action(&mut self) -> Result<StartupAction, DriverError> {
343 let (msg_type, _) = self.read_message_buffered()?;
344 let payload = &self.read_buf;
345 let msg = proto::parse_backend_message(msg_type, payload)?;
346 match msg {
347 BackendMessage::AuthOk => Ok(StartupAction::AuthOk),
348 BackendMessage::AuthCleartext => Ok(StartupAction::AuthCleartext),
349 BackendMessage::AuthMd5 { salt } => Ok(StartupAction::AuthMd5(salt)),
350 BackendMessage::AuthSasl { mechanisms } => {
351 Ok(StartupAction::AuthSasl(mechanisms.to_vec()))
352 }
353 BackendMessage::ParameterStatus { name, value } => {
354 Ok(StartupAction::ParameterStatus(name.into(), value.into()))
355 }
356 BackendMessage::BackendKeyData { pid, secret } => {
357 Ok(StartupAction::BackendKeyData(pid, secret))
358 }
359 BackendMessage::ReadyForQuery { status } => Ok(StartupAction::ReadyForQuery(status)),
360 BackendMessage::ErrorResponse { data } => {
361 let fields = proto::parse_error_response(data);
362 Ok(StartupAction::Error(fields.to_string()))
363 }
364 BackendMessage::NoticeResponse { .. } => Ok(StartupAction::Notice),
365 other => Err(DriverError::Protocol(format!(
366 "unexpected message during startup: {other:?}"
367 ))),
368 }
369 }
370
371 fn handle_scram(&mut self, config: &Config, mechanisms_data: &[u8]) -> Result<(), DriverError> {
372 let mechs = auth::parse_sasl_mechanisms(mechanisms_data);
373
374 let use_plus = self.tls_server_cert_hash.is_some() && mechs.contains(&"SCRAM-SHA-256-PLUS");
377 let mechanism = if use_plus {
378 "SCRAM-SHA-256-PLUS"
379 } else {
380 "SCRAM-SHA-256"
381 };
382
383 if !mechs.contains(&mechanism) && !mechs.contains(&"SCRAM-SHA-256") {
384 return Err(DriverError::Auth(format!(
385 "server requires unsupported SASL mechanism(s): {mechs:?}"
386 )));
387 }
388
389 let cert_hash = if use_plus {
390 self.tls_server_cert_hash.as_ref()
391 } else {
392 None
393 };
394 let mut scram = auth::ScramClient::new(&config.user, &config.password, cert_hash)?;
395
396 let client_first = scram.client_first_message();
398 self.write_buf.clear();
399 proto::write_sasl_initial(&mut self.write_buf, mechanism, &client_first);
400 self.flush_write()?;
401
402 let (msg_type, _) = self.read_message_buffered()?;
404 let server_first = {
405 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
406 match msg {
407 BackendMessage::AuthSaslContinue { data } => data.to_vec(),
408 BackendMessage::ErrorResponse { data } => {
409 let fields = proto::parse_error_response(data);
410 return Err(DriverError::Auth(fields.to_string()));
411 }
412 other => {
413 return Err(DriverError::Protocol(format!(
414 "expected AuthSaslContinue, got: {other:?}"
415 )));
416 }
417 }
418 };
419
420 scram.process_server_first(&server_first)?;
421
422 let client_final = scram.client_final_message()?;
424 self.write_buf.clear();
425 proto::write_sasl_response(&mut self.write_buf, &client_final);
426 self.flush_write()?;
427
428 let (msg_type, _) = self.read_message_buffered()?;
430 {
431 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
432 match msg {
433 BackendMessage::AuthSaslFinal { data } => {
434 let data_owned = data.to_vec();
435 scram.verify_server_final(&data_owned)?;
436 }
437 BackendMessage::ErrorResponse { data } => {
438 let fields = proto::parse_error_response(data);
439 return Err(DriverError::Auth(fields.to_string()));
440 }
441 other => {
442 return Err(DriverError::Protocol(format!(
443 "expected AuthSaslFinal, got: {other:?}"
444 )));
445 }
446 }
447 }
448
449 let (msg_type, _) = self.read_message_buffered()?;
451 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
452 match msg {
453 BackendMessage::AuthOk => Ok(()),
454 BackendMessage::ErrorResponse { data } => {
455 let fields = proto::parse_error_response(data);
456 Err(DriverError::Auth(fields.to_string()))
457 }
458 other => Err(DriverError::Protocol(format!(
459 "expected AuthOk after SCRAM, got: {other:?}"
460 ))),
461 }
462 }
463
464 fn validate_server_params(&self) -> Result<(), DriverError> {
465 if let Some(encoding) = self.parameter("server_encoding") {
466 if !encoding.eq_ignore_ascii_case("UTF8") && !encoding.eq_ignore_ascii_case("UTF-8") {
467 return Err(DriverError::Protocol(format!(
468 "server_encoding is '{encoding}', but bsql requires UTF-8."
469 )));
470 }
471 }
472 if let Some(encoding) = self.parameter("client_encoding") {
473 if !encoding.eq_ignore_ascii_case("UTF8") && !encoding.eq_ignore_ascii_case("UTF-8") {
474 return Err(DriverError::Protocol(format!(
475 "client_encoding is '{encoding}', but bsql requires UTF-8."
476 )));
477 }
478 }
479 if let Some(idt) = self.parameter("integer_datetimes") {
480 if idt != "on" {
481 return Err(DriverError::Protocol(format!(
482 "integer_datetimes is '{idt}', but bsql requires 'on'."
483 )));
484 }
485 }
486 Ok(())
487 }
488
489 pub fn prepare_only(&mut self, sql: &str, sql_hash: u64) -> Result<(), DriverError> {
495 if self.statement_cache_mode == StatementCacheMode::Disabled {
496 return Ok(()); }
498 if self.stmts.contains_key(&sql_hash, sql) {
499 return Ok(());
500 }
501 let name = make_stmt_name(sql_hash);
502 self.write_buf.clear();
503 proto::write_parse(&mut self.write_buf, &name, sql, &[]);
504 proto::write_describe(&mut self.write_buf, b'S', &name);
505 proto::write_sync(&mut self.write_buf);
506 self.flush_write()?;
507
508 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
509 let columns = self.read_column_description()?;
510 self.expect_ready()?;
511
512 self.query_counter += 1;
513 self.cache_stmt(
514 sql_hash,
515 StmtInfo {
516 name,
517 sql: sql.into(),
518 columns,
519 last_used: self.query_counter,
520 bind_template: None,
521 },
522 );
523 Ok(())
524 }
525
526 pub fn prepare_batch(&mut self, sqls: &[(&str, u64)]) -> Result<(), DriverError> {
534 if sqls.is_empty() || self.statement_cache_mode == StatementCacheMode::Disabled {
535 return Ok(()); }
537
538 let mut pending = 0usize;
540 self.write_buf.clear();
541 for &(sql, sql_hash) in sqls {
542 if self.stmts.contains_key(&sql_hash, sql) {
543 continue;
544 }
545 let name = make_stmt_name(sql_hash);
546 proto::write_parse(&mut self.write_buf, &name, sql, &[]);
547 proto::write_describe(&mut self.write_buf, b'S', &name);
548 pending += 1;
549 }
550
551 if pending == 0 {
552 return Ok(());
553 }
554
555 proto::write_sync(&mut self.write_buf);
556 self.flush_write()?;
557
558 for &(sql, sql_hash) in sqls {
561 if self.stmts.contains_key(&sql_hash, sql) {
562 continue;
563 }
564
565 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
566 let columns = self.read_column_description()?;
567
568 let name = make_stmt_name(sql_hash);
569 self.query_counter += 1;
570 self.cache_stmt(
571 sql_hash,
572 StmtInfo {
573 name,
574 sql: sql.into(),
575 columns,
576 last_used: self.query_counter,
577 bind_template: None,
578 },
579 );
580 }
581
582 self.expect_ready()?;
583 Ok(())
584 }
585
/// Runs `sql` with `params` and collects every returned row into a `QueryResult`.
///
/// Preparation, bind, and execute are delegated to `send_pipeline(.., true, true)`, which must
/// return the result column descriptions (the error message names the flag
/// `need_columns=true`). Reply handling:
/// - Replies are framed directly out of `stream_buf`.
/// - DataRow payloads are appended to a pooled response buffer (`acquire_resp_buf`), with
///   per-column entries in a pooled offsets vector (`acquire_col_offsets`).
/// - Both buffers move into the returned `QueryResult`.
///
/// # Errors
/// - `DriverError::Protocol` for a length header below 4, or an unexpected message on the
///   oversized-message path.
/// - On that path, an `ErrorResponse` is turned into the server error after the statement
///   cache check and a drain to `ReadyForQuery`.
/// - Other non-DataRow messages on the fast path go to `handle_non_datarow_query`.
///
/// NOTE(review): on error returns the pooled buffers are dropped rather than released back
/// to their pools.
#[inline]
pub fn query(
    &mut self,
    sql: &str,
    sql_hash: u64,
    params: &[&(dyn Encode + Sync)],
) -> Result<QueryResult, DriverError> {
    let columns = self
        .send_pipeline(sql, sql_hash, params, true, true)?
        .ok_or_else(|| {
            DriverError::Protocol("send_pipeline(need_columns=true) returned None".into())
        })?;

    let num_cols = columns.len();
    // Pooled buffers may hold data from an earlier use; start empty.
    let mut all_col_offsets = acquire_col_offsets();
    all_col_offsets.clear();
    let mut affected_rows: u64 = 0;

    let mut resp_buf = acquire_resp_buf();
    resp_buf.clear();

    'outer: loop {
        // Consume every complete message currently in stream_buf.
        loop {
            let avail = self.stream_buf_end - self.stream_buf_pos;
            // Need at least the 1-byte type and the 4-byte length.
            if avail < 5 {
                break;
            }

            let msg_type = self.stream_buf[self.stream_buf_pos];
            let raw_len = i32::from_be_bytes([
                self.stream_buf[self.stream_buf_pos + 1],
                self.stream_buf[self.stream_buf_pos + 2],
                self.stream_buf[self.stream_buf_pos + 3],
                self.stream_buf[self.stream_buf_pos + 4],
            ]);

            // The length field counts itself, so anything below 4 is malformed.
            if raw_len < 4 {
                return Err(DriverError::Protocol(format!(
                    "invalid message length {raw_len} for type '{}'",
                    msg_type as char
                )));
            }

            let payload_len = (raw_len - 4) as usize;
            let total_msg_len = 5 + payload_len;

            if avail < total_msg_len {
                // The message can never fit in stream_buf: fall back to read_one_message.
                if total_msg_len > self.stream_buf.len() {
                    let msg = self.read_one_message()?;
                    match msg {
                        BackendMessage::BindComplete => continue,
                        BackendMessage::DataRow { data } => {
                            parse_data_row_into_buf(data, &mut resp_buf, &mut all_col_offsets)?;
                            continue;
                        }
                        BackendMessage::CommandComplete { tag } => {
                            affected_rows = proto::parse_command_tag(tag);
                            continue;
                        }
                        BackendMessage::EmptyQuery => continue,
                        BackendMessage::ReadyForQuery { status } => {
                            self.tx_status = status;
                            break 'outer;
                        }
                        BackendMessage::NoticeResponse { .. } => continue,
                        BackendMessage::ErrorResponse { data } => {
                            let fields = proto::parse_error_response(data);
                            self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                            self.drain_to_ready()?;
                            return Err(self.make_server_error(fields));
                        }
                        other => {
                            return Err(DriverError::Protocol(format!(
                                "unexpected message during query: {other:?}"
                            )));
                        }
                    }
                }
                // Incomplete but will fit: refill below.
                break;
            }

            let payload_start = self.stream_buf_pos + 5;
            let payload_end = payload_start + payload_len;

            if msg_type == b'D' {
                // DataRow: copy column values into the response buffer.
                parse_data_row_into_buf(
                    &self.stream_buf[payload_start..payload_end],
                    &mut resp_buf,
                    &mut all_col_offsets,
                )?;
            } else if msg_type == b'Z' {
                // ReadyForQuery: first payload byte is the transaction status.
                if payload_len >= 1 {
                    self.tx_status = self.stream_buf[payload_start];
                }
                self.stream_buf_pos += total_msg_len;
                break 'outer;
            } else {
                self.handle_non_datarow_query(
                    msg_type,
                    payload_start,
                    payload_end,
                    sql_hash,
                    &mut affected_rows,
                )?;
            }

            self.stream_buf_pos += total_msg_len;
        }

        self.refill_stream_buf()?;
    }

    self.shrink_buffers();

    Ok(QueryResult::from_parts_with_buf(
        all_col_offsets,
        num_cols,
        columns,
        affected_rows,
        resp_buf,
    ))
}
732
733 #[inline]
744 pub fn execute_monolithic(
745 &mut self,
746 sql: &str,
747 sql_hash: u64,
748 params: &[&(dyn Encode + Sync)],
749 ) -> Result<u64, DriverError> {
750 if self.statement_cache_mode == StatementCacheMode::Disabled {
752 return self.execute_unnamed(sql, params);
753 }
754
755 self.write_buf.clear();
757
758 let info = match self.stmts.get_mut(&sql_hash, sql) {
760 Some(info) => {
761 self.query_counter += 1;
762 info.last_used = self.query_counter;
763 info
764 }
765 None => {
766 return self.execute_with_prepare(sql, sql_hash, params);
768 }
769 };
770
771 let can_use_template = info
773 .bind_template
774 .as_ref()
775 .is_some_and(|t| t.param_slots.len() == params.len());
776
777 let mut has_exec_sync = false;
778
779 if can_use_template {
780 let tmpl = info.bind_template.as_ref().ok_or_else(|| {
782 DriverError::Protocol("bind_template missing despite can_use_template".into())
783 })?;
784 self.write_buf.extend_from_slice(&tmpl.bytes);
785
786 let mut template_ok = true;
787 for (i, param) in params.iter().enumerate() {
788 let (data_offset, old_len) = tmpl.param_slots[i];
789 if param.is_null() {
790 let len_offset = data_offset - 4;
791 self.write_buf[len_offset..len_offset + 4]
792 .copy_from_slice(&(-1i32).to_be_bytes());
793 } else if old_len >= 0 {
794 let end = data_offset + old_len as usize;
795 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
796 template_ok = false;
797 break;
798 }
799 } else {
800 template_ok = false;
802 break;
803 }
804 }
805
806 if template_ok {
807 has_exec_sync = true;
808 } else {
809 self.write_buf.clear();
810 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
811 info.bind_template = None;
812 }
813 } else {
814 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
815 }
816
817 if info.bind_template.is_none() && !self.write_buf.is_empty() {
819 info.bind_template = build_bind_template(&self.write_buf, params.len());
820 }
821
822 if !has_exec_sync {
823 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
824 }
825
826 self.stream
828 .write_all(&self.write_buf)
829 .map_err(DriverError::Io)?;
830
831 let mut affected_rows: u64 = 0;
833
834 'outer: loop {
835 loop {
836 let avail = self.stream_buf_end - self.stream_buf_pos;
837 if avail < 5 {
838 break; }
840
841 let msg_type = self.stream_buf[self.stream_buf_pos];
842 let raw_len = i32::from_be_bytes([
843 self.stream_buf[self.stream_buf_pos + 1],
844 self.stream_buf[self.stream_buf_pos + 2],
845 self.stream_buf[self.stream_buf_pos + 3],
846 self.stream_buf[self.stream_buf_pos + 4],
847 ]);
848
849 if raw_len < 4 {
850 return Err(DriverError::Protocol(format!(
851 "invalid message length {raw_len} for type '{}'",
852 msg_type as char
853 )));
854 }
855
856 let payload_len = (raw_len - 4) as usize;
857 let total_msg_len = 5 + payload_len;
858
859 if avail < total_msg_len {
860 if total_msg_len > self.stream_buf.len() {
861 let msg = self.read_one_message()?;
862 match msg {
863 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
864 continue;
865 }
866 BackendMessage::CommandComplete { tag } => {
867 affected_rows = proto::parse_command_tag(tag);
868 continue;
869 }
870 BackendMessage::EmptyQuery => continue,
871 BackendMessage::ReadyForQuery { status } => {
872 self.tx_status = status;
873 break 'outer;
874 }
875 BackendMessage::NoticeResponse { .. } => continue,
876 BackendMessage::ErrorResponse { data } => {
877 let fields = proto::parse_error_response(data);
878 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
879 self.drain_to_ready()?;
880 return Err(self.make_server_error(fields));
881 }
882 other => {
883 return Err(DriverError::Protocol(format!(
884 "unexpected message during execute: {other:?}"
885 )));
886 }
887 }
888 }
889 break; }
891
892 let payload_start = self.stream_buf_pos + 5;
897 let payload_end = payload_start + payload_len;
898
899 if msg_type == b'2' {
900 self.stream_buf_pos += total_msg_len;
902 continue;
903 } else if msg_type == b'C' {
904 affected_rows = proto::parse_command_tag_bytes(
906 &self.stream_buf[payload_start..payload_end],
907 );
908 } else if msg_type == b'Z' {
909 if payload_len >= 1 {
911 self.tx_status = self.stream_buf[payload_start];
912 }
913 self.stream_buf_pos += total_msg_len;
914 break 'outer;
915 } else if msg_type == b'D' || msg_type == b'I' {
916 } else {
918 self.handle_non_datarow_execute(
919 msg_type,
920 payload_start,
921 payload_end,
922 sql_hash,
923 )?;
924 }
925
926 self.stream_buf_pos += total_msg_len;
927 }
928
929 let remaining = self.stream_buf_end - self.stream_buf_pos;
931 debug_assert!(
932 remaining == 0 || self.stream_buf_pos > 0,
933 "compact called with pos=0 and remaining data"
934 );
935 if remaining > 0 {
936 self.stream_buf
937 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
938 }
939 self.stream_buf_pos = 0;
940 self.stream_buf_end = remaining;
941 let n = self
942 .stream
943 .read(&mut self.stream_buf[remaining..])
944 .map_err(DriverError::Io)?;
945 if n == 0 {
946 return Err(DriverError::Io(std::io::Error::new(
947 std::io::ErrorKind::UnexpectedEof,
948 "connection closed",
949 )));
950 }
951 self.stream_buf_end = remaining + n;
952 }
953
954 if self.query_counter & 63 == 0 {
956 if self.read_buf.capacity() > 64 * 1024 {
957 self.read_buf.clear();
958 self.read_buf.shrink_to(8192);
959 }
960 if self.write_buf.capacity() > 16 * 1024 {
961 self.write_buf.clear();
962 self.write_buf.shrink_to(8192);
963 }
964 }
965
966 Ok(affected_rows)
967 }
968
969 #[cold]
971 #[inline(never)]
972 fn execute_with_prepare(
973 &mut self,
974 sql: &str,
975 sql_hash: u64,
976 params: &[&(dyn Encode + Sync)],
977 ) -> Result<u64, DriverError> {
978 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
979
980 if params.len() > i16::MAX as usize {
981 return Err(DriverError::Protocol(format!(
982 "parameter count {} exceeds maximum {}",
983 params.len(),
984 i16::MAX
985 )));
986 }
987
988 let name = make_stmt_name(sql_hash);
989 let param_oids: smallvec::SmallVec<[u32; 8]> =
990 params.iter().map(|p| p.type_oid()).collect();
991
992 self.write_buf.clear();
993 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
994 proto::write_describe(&mut self.write_buf, b'S', &name);
995 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
996 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
997 self.stream
998 .write_all(&self.write_buf)
999 .map_err(DriverError::Io)?;
1000
1001 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1002 let columns = self.read_column_description()?;
1003 self.query_counter += 1;
1004 self.cache_stmt(
1005 sql_hash,
1006 StmtInfo {
1007 name,
1008 sql: sql.into(),
1009 columns,
1010 last_used: self.query_counter,
1011 bind_template: None,
1012 },
1013 );
1014
1015 let mut affected_rows: u64 = 0;
1017 'outer: loop {
1018 loop {
1019 let avail = self.stream_buf_end - self.stream_buf_pos;
1020 if avail < 5 {
1021 break;
1022 }
1023
1024 let msg_type = self.stream_buf[self.stream_buf_pos];
1025 let raw_len = i32::from_be_bytes([
1026 self.stream_buf[self.stream_buf_pos + 1],
1027 self.stream_buf[self.stream_buf_pos + 2],
1028 self.stream_buf[self.stream_buf_pos + 3],
1029 self.stream_buf[self.stream_buf_pos + 4],
1030 ]);
1031
1032 if raw_len < 4 {
1033 return Err(DriverError::Protocol(format!(
1034 "invalid message length {raw_len} for type '{}'",
1035 msg_type as char
1036 )));
1037 }
1038
1039 let payload_len = (raw_len - 4) as usize;
1040 let total_msg_len = 5 + payload_len;
1041
1042 if avail < total_msg_len {
1043 if total_msg_len > self.stream_buf.len() {
1044 let msg = self.read_one_message()?;
1045 match msg {
1046 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
1047 continue;
1048 }
1049 BackendMessage::CommandComplete { tag } => {
1050 affected_rows = proto::parse_command_tag(tag);
1051 continue;
1052 }
1053 BackendMessage::EmptyQuery => continue,
1054 BackendMessage::ReadyForQuery { status } => {
1055 self.tx_status = status;
1056 break 'outer;
1057 }
1058 BackendMessage::NoticeResponse { .. } => continue,
1059 BackendMessage::ErrorResponse { data } => {
1060 let fields = proto::parse_error_response(data);
1061 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1062 self.drain_to_ready()?;
1063 return Err(self.make_server_error(fields));
1064 }
1065 other => {
1066 return Err(DriverError::Protocol(format!(
1067 "unexpected message during execute: {other:?}"
1068 )));
1069 }
1070 }
1071 }
1072 break;
1073 }
1074
1075 let payload_start = self.stream_buf_pos + 5;
1076 let payload_end = payload_start + payload_len;
1077
1078 if msg_type == b'2' || msg_type == b'D' || msg_type == b'I' {
1079 } else if msg_type == b'C' {
1081 affected_rows = proto::parse_command_tag_bytes(
1082 &self.stream_buf[payload_start..payload_end],
1083 );
1084 } else if msg_type == b'Z' {
1085 if payload_len >= 1 {
1086 self.tx_status = self.stream_buf[payload_start];
1087 }
1088 self.stream_buf_pos += total_msg_len;
1089 break 'outer;
1090 } else {
1091 self.handle_non_datarow_execute(
1092 msg_type,
1093 payload_start,
1094 payload_end,
1095 sql_hash,
1096 )?;
1097 }
1098
1099 self.stream_buf_pos += total_msg_len;
1100 }
1101
1102 self.refill_stream_buf()?;
1103 }
1104
1105 Ok(affected_rows)
1106 }
1107
1108 fn execute_unnamed(
1112 &mut self,
1113 sql: &str,
1114 params: &[&(dyn Encode + Sync)],
1115 ) -> Result<u64, DriverError> {
1116 if params.len() > i16::MAX as usize {
1117 return Err(DriverError::Protocol(format!(
1118 "parameter count {} exceeds maximum {}",
1119 params.len(),
1120 i16::MAX
1121 )));
1122 }
1123
1124 self.write_buf.clear();
1125 let param_oids: smallvec::SmallVec<[u32; 8]> =
1126 params.iter().map(|p| p.type_oid()).collect();
1127 proto::write_parse(&mut self.write_buf, b"", sql, ¶m_oids);
1128 proto::write_bind_params(&mut self.write_buf, b"", b"", params);
1129 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1130 self.stream
1131 .write_all(&self.write_buf)
1132 .map_err(DriverError::Io)?;
1133
1134 let mut affected_rows: u64 = 0;
1136 'outer: loop {
1137 loop {
1138 let avail = self.stream_buf_end - self.stream_buf_pos;
1139 if avail < 5 {
1140 break;
1141 }
1142
1143 let msg_type = self.stream_buf[self.stream_buf_pos];
1144 let raw_len = i32::from_be_bytes([
1145 self.stream_buf[self.stream_buf_pos + 1],
1146 self.stream_buf[self.stream_buf_pos + 2],
1147 self.stream_buf[self.stream_buf_pos + 3],
1148 self.stream_buf[self.stream_buf_pos + 4],
1149 ]);
1150
1151 if raw_len < 4 {
1152 return Err(DriverError::Protocol(format!(
1153 "invalid message length {raw_len} for type '{}'",
1154 msg_type as char
1155 )));
1156 }
1157
1158 let payload_len = (raw_len - 4) as usize;
1159 let total_msg_len = 5 + payload_len;
1160
1161 if avail < total_msg_len {
1162 if total_msg_len > self.stream_buf.len() {
1163 let msg = self.read_one_message()?;
1164 match msg {
1165 BackendMessage::ParseComplete
1166 | BackendMessage::BindComplete
1167 | BackendMessage::DataRow { .. } => continue,
1168 BackendMessage::CommandComplete { tag } => {
1169 affected_rows = proto::parse_command_tag(tag);
1170 continue;
1171 }
1172 BackendMessage::EmptyQuery => continue,
1173 BackendMessage::ReadyForQuery { status } => {
1174 self.tx_status = status;
1175 break 'outer;
1176 }
1177 BackendMessage::NoticeResponse { .. } => continue,
1178 BackendMessage::ErrorResponse { data } => {
1179 let fields = proto::parse_error_response(data);
1180 self.drain_to_ready()?;
1181 return Err(self.make_server_error(fields));
1182 }
1183 other => {
1184 return Err(DriverError::Protocol(format!(
1185 "unexpected message during unnamed execute: {other:?}"
1186 )));
1187 }
1188 }
1189 }
1190 break;
1191 }
1192
1193 if msg_type == b'1' || msg_type == b'2' || msg_type == b'I' {
1195 self.stream_buf_pos += total_msg_len;
1197 continue;
1198 } else if msg_type == b'C' {
1199 let payload_start = self.stream_buf_pos + 5;
1201 let payload_end = payload_start + payload_len;
1202 affected_rows = proto::parse_command_tag_bytes(
1203 &self.stream_buf[payload_start..payload_end],
1204 );
1205 self.stream_buf_pos += total_msg_len;
1206 continue;
1207 } else if msg_type == b'Z' {
1208 let payload_start = self.stream_buf_pos + 5;
1210 let payload_end = payload_start + payload_len;
1211 if payload_end > payload_start {
1212 self.tx_status = self.stream_buf[payload_start];
1213 }
1214 self.stream_buf_pos += total_msg_len;
1215 break 'outer;
1216 } else if msg_type == b'E' {
1217 let payload_start = self.stream_buf_pos + 5;
1219 let payload_end = payload_start + payload_len;
1220 let fields =
1221 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
1222 self.stream_buf_pos += total_msg_len;
1223 self.drain_to_ready()?;
1224 return Err(self.make_server_error(fields));
1225 } else if msg_type == b'N' || msg_type == b'D' {
1226 self.stream_buf_pos += total_msg_len;
1228 continue;
1229 } else {
1230 return Err(DriverError::Protocol(format!(
1231 "unexpected message type '{}' during unnamed execute",
1232 msg_type as char
1233 )));
1234 }
1235 }
1236
1237 self.refill_stream_buf()?;
1238 }
1239
1240 Ok(affected_rows)
1241 }
1242
1243 #[inline]
1248 pub fn execute(
1249 &mut self,
1250 sql: &str,
1251 sql_hash: u64,
1252 params: &[&(dyn Encode + Sync)],
1253 ) -> Result<u64, DriverError> {
1254 self.execute_monolithic(sql, sql_hash, params)
1255 }
1256
1257 pub fn execute_pipeline(
1269 &mut self,
1270 sql: &str,
1271 sql_hash: u64,
1272 param_sets: &[&[&(dyn Encode + Sync)]],
1273 ) -> Result<Vec<u64>, DriverError> {
1274 if param_sets.is_empty() {
1275 return Ok(Vec::new());
1276 }
1277
1278 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1279
1280 if self.statement_cache_mode == StatementCacheMode::Disabled {
1282 return self.execute_pipeline_unnamed(sql, param_sets);
1283 }
1284
1285 self.write_buf.clear();
1286
1287 if !self.stmts.contains_key(&sql_hash, sql) {
1289 let name = make_stmt_name(sql_hash);
1290 let first_params = param_sets[0];
1291 if first_params.len() > i16::MAX as usize {
1292 return Err(DriverError::Protocol(format!(
1293 "parameter count {} exceeds maximum {}",
1294 first_params.len(),
1295 i16::MAX
1296 )));
1297 }
1298 let param_oids: smallvec::SmallVec<[u32; 8]> =
1299 first_params.iter().map(|p| p.type_oid()).collect();
1300 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1301 proto::write_describe(&mut self.write_buf, b'S', &name);
1302 proto::write_sync(&mut self.write_buf);
1303 self.flush_write()?;
1304
1305 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1306 let columns = self.read_column_description()?;
1307 self.expect_ready()?;
1308
1309 self.query_counter += 1;
1310 self.cache_stmt(
1311 sql_hash,
1312 StmtInfo {
1313 name,
1314 sql: sql.into(),
1315 columns,
1316 last_used: self.query_counter,
1317 bind_template: None,
1318 },
1319 );
1320
1321 self.write_buf.clear();
1322 }
1323
1324 let stmt_name = self
1326 .stmts
1327 .get(&sql_hash, sql)
1328 .ok_or_else(|| {
1329 DriverError::Protocol("stmt just cached but not found in execute_pipeline".into())
1330 })?
1331 .name;
1332 let count = param_sets.len();
1333
1334 for params in param_sets {
1335 if params.len() > i16::MAX as usize {
1336 return Err(DriverError::Protocol(format!(
1337 "parameter count {} exceeds maximum {}",
1338 params.len(),
1339 i16::MAX
1340 )));
1341 }
1342 proto::write_bind_params(&mut self.write_buf, b"", &stmt_name, params);
1343 self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
1344 }
1345
1346 self.write_buf.extend_from_slice(proto::SYNC_ONLY);
1347 self.flush_write()?;
1348
1349 let mut results = Vec::with_capacity(count);
1352
1353 'outer: loop {
1354 while let Some((msg_type, start, end, total)) = self.peek_stream_msg()? {
1355 if msg_type == b'2' {
1356 } else if msg_type == b'C' {
1358 let rows = proto::parse_command_tag_bytes(&self.stream_buf[start..end]);
1360 results.push(rows);
1361 } else if msg_type == b'Z' {
1362 if end > start {
1364 self.tx_status = self.stream_buf[start];
1365 }
1366 self.advance_stream_msg(total);
1367 break 'outer;
1368 } else if msg_type == b'I' {
1369 results.push(0);
1371 } else if msg_type == b'D' || msg_type == b'N' {
1372 } else if msg_type == b'E' {
1374 let fields = proto::parse_error_response(&self.stream_buf[start..end]);
1376 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1377 self.advance_stream_msg(total);
1378 self.drain_to_ready()?;
1379 return Err(self.make_server_error(fields));
1380 } else if msg_type == b'A' {
1381 let msg = proto::parse_backend_message(msg_type, &self.stream_buf[start..end])?;
1383 if let BackendMessage::NotificationResponse {
1384 pid,
1385 channel,
1386 payload,
1387 } = msg
1388 {
1389 let ch = channel.to_owned();
1390 let pl = payload.to_owned();
1391 self.buffer_notification(pid, &ch, &pl);
1392 }
1393 }
1394 self.advance_stream_msg(total);
1397 }
1398
1399 self.refill_stream_buf()?;
1401 }
1402
1403 self.shrink_buffers();
1404 Ok(results)
1405 }
1406
1407 fn execute_pipeline_unnamed(
1412 &mut self,
1413 sql: &str,
1414 param_sets: &[&[&(dyn Encode + Sync)]],
1415 ) -> Result<Vec<u64>, DriverError> {
1416 let count = param_sets.len();
1417 self.write_buf.clear();
1418
1419 for params in param_sets {
1420 if params.len() > i16::MAX as usize {
1421 return Err(DriverError::Protocol(format!(
1422 "parameter count {} exceeds maximum {}",
1423 params.len(),
1424 i16::MAX
1425 )));
1426 }
1427 let param_oids: smallvec::SmallVec<[u32; 8]> =
1428 params.iter().map(|p| p.type_oid()).collect();
1429 proto::write_parse(&mut self.write_buf, b"", sql, ¶m_oids);
1430 proto::write_bind_params(&mut self.write_buf, b"", b"", params);
1431 self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
1432 }
1433
1434 self.write_buf.extend_from_slice(proto::SYNC_ONLY);
1435 self.flush_write()?;
1436
1437 let mut results = Vec::with_capacity(count);
1439
1440 'outer: loop {
1441 while let Some((msg_type, start, end, total)) = self.peek_stream_msg()? {
1442 if msg_type == b'1' || msg_type == b'2' {
1443 } else if msg_type == b'C' {
1445 let rows = proto::parse_command_tag_bytes(&self.stream_buf[start..end]);
1446 results.push(rows);
1447 } else if msg_type == b'Z' {
1448 if end > start {
1449 self.tx_status = self.stream_buf[start];
1450 }
1451 self.advance_stream_msg(total);
1452 break 'outer;
1453 } else if msg_type == b'I' {
1454 results.push(0);
1455 } else if msg_type == b'D' || msg_type == b'N' {
1456 } else if msg_type == b'E' {
1458 let fields = proto::parse_error_response(&self.stream_buf[start..end]);
1459 self.advance_stream_msg(total);
1460 self.drain_to_ready()?;
1461 return Err(self.make_server_error(fields));
1462 } else if msg_type == b'A' {
1463 let msg = proto::parse_backend_message(msg_type, &self.stream_buf[start..end])?;
1464 if let BackendMessage::NotificationResponse {
1465 pid,
1466 channel,
1467 payload,
1468 } = msg
1469 {
1470 let ch = channel.to_owned();
1471 let pl = payload.to_owned();
1472 self.buffer_notification(pid, &ch, &pl);
1473 }
1474 }
1475 self.advance_stream_msg(total);
1476 }
1477
1478 self.refill_stream_buf()?;
1479 }
1480
1481 self.shrink_buffers();
1482 Ok(results)
1483 }
1484
1485 pub(crate) fn ensure_stmt_prepared(
1491 &mut self,
1492 sql: &str,
1493 sql_hash: u64,
1494 params: &[&(dyn Encode + Sync)],
1495 ) -> Result<[u8; 18], DriverError> {
1496 if self.statement_cache_mode == StatementCacheMode::Disabled {
1498 return Ok([0u8; 18]);
1499 }
1500
1501 if let Some(info) = self.stmts.get(&sql_hash, sql) {
1502 return Ok(info.name);
1503 }
1504
1505 let name = make_stmt_name(sql_hash);
1506 if params.len() > i16::MAX as usize {
1507 return Err(DriverError::Protocol(format!(
1508 "parameter count {} exceeds maximum {}",
1509 params.len(),
1510 i16::MAX
1511 )));
1512 }
1513 let param_oids: smallvec::SmallVec<[u32; 8]> =
1514 params.iter().map(|p| p.type_oid()).collect();
1515
1516 self.write_buf.clear();
1517 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1518 proto::write_describe(&mut self.write_buf, b'S', &name);
1519 proto::write_sync(&mut self.write_buf);
1520 self.flush_write()?;
1521
1522 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1523 let columns = self.read_column_description()?;
1524 self.expect_ready()?;
1525
1526 self.query_counter += 1;
1527 self.cache_stmt(
1528 sql_hash,
1529 StmtInfo {
1530 name,
1531 sql: sql.into(),
1532 columns,
1533 last_used: self.query_counter,
1534 bind_template: None,
1535 },
1536 );
1537
1538 Ok(name)
1539 }
1540
1541 pub(crate) fn write_deferred_bind_execute(
1545 &self,
1546 sql: &str,
1547 sql_hash: u64,
1548 params: &[&(dyn Encode + Sync)],
1549 buf: &mut Vec<u8>,
1550 ) -> Result<(), DriverError> {
1551 if self.statement_cache_mode == StatementCacheMode::Disabled {
1552 let param_oids: smallvec::SmallVec<[u32; 8]> =
1554 params.iter().map(|p| p.type_oid()).collect();
1555 proto::write_parse(buf, b"", sql, ¶m_oids);
1556 proto::write_bind_params(buf, b"", b"", params);
1557 buf.extend_from_slice(proto::EXECUTE_ONLY);
1558 return Ok(());
1559 }
1560
1561 let stmt_name = self
1562 .stmts
1563 .get(&sql_hash, sql)
1564 .ok_or_else(|| {
1565 DriverError::Protocol("stmt just cached but not found in write_deferred".into())
1566 })?
1567 .name;
1568 proto::write_bind_params(buf, b"", &stmt_name, params);
1569 buf.extend_from_slice(proto::EXECUTE_ONLY);
1570 Ok(())
1571 }
1572
1573 pub(crate) fn flush_deferred_pipeline(
1578 &mut self,
1579 buf: &mut Vec<u8>,
1580 count: usize,
1581 ) -> Result<Vec<u64>, DriverError> {
1582 if count == 0 {
1583 buf.clear();
1584 return Ok(Vec::new());
1585 }
1586
1587 buf.extend_from_slice(proto::SYNC_ONLY);
1588
1589 self.stream.write_all(buf).map_err(DriverError::Io)?;
1590 buf.clear();
1591
1592 let mut results = Vec::with_capacity(count);
1594
1595 'outer: loop {
1596 while let Some((msg_type, start, end, total)) = self.peek_stream_msg()? {
1597 if msg_type == b'1' || msg_type == b'2' {
1598 } else if msg_type == b'C' {
1601 let rows = proto::parse_command_tag_bytes(&self.stream_buf[start..end]);
1603 results.push(rows);
1604 } else if msg_type == b'Z' {
1605 if end > start {
1607 self.tx_status = self.stream_buf[start];
1608 }
1609 self.advance_stream_msg(total);
1610 break 'outer;
1611 } else if msg_type == b'I' {
1612 results.push(0);
1614 } else if msg_type == b'D' || msg_type == b'N' {
1615 } else if msg_type == b'E' {
1617 let fields = proto::parse_error_response(&self.stream_buf[start..end]);
1619 self.advance_stream_msg(total);
1620 self.drain_to_ready()?;
1621 return Err(self.make_server_error(fields));
1622 } else if msg_type == b'A' {
1623 let msg = proto::parse_backend_message(msg_type, &self.stream_buf[start..end])?;
1625 if let BackendMessage::NotificationResponse {
1626 pid,
1627 channel,
1628 payload,
1629 } = msg
1630 {
1631 let ch = channel.to_owned();
1632 let pl = payload.to_owned();
1633 self.buffer_notification(pid, &ch, &pl);
1634 }
1635 }
1636 self.advance_stream_msg(total);
1639 }
1640
1641 self.refill_stream_buf()?;
1643 }
1644
1645 self.shrink_buffers();
1646 Ok(results)
1647 }
1648
1649 pub fn for_each<F>(
1651 &mut self,
1652 sql: &str,
1653 sql_hash: u64,
1654 params: &[&(dyn Encode + Sync)],
1655 mut f: F,
1656 ) -> Result<(), DriverError>
1657 where
1658 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
1659 {
1660 let _ = self.send_pipeline(sql, sql_hash, params, false, true)?;
1661
1662 'outer: loop {
1664 loop {
1665 let avail = self.stream_buf_end - self.stream_buf_pos;
1666 if avail < 5 {
1667 break; }
1669
1670 let msg_type = self.stream_buf[self.stream_buf_pos];
1671 let raw_len = i32::from_be_bytes([
1672 self.stream_buf[self.stream_buf_pos + 1],
1673 self.stream_buf[self.stream_buf_pos + 2],
1674 self.stream_buf[self.stream_buf_pos + 3],
1675 self.stream_buf[self.stream_buf_pos + 4],
1676 ]);
1677
1678 if raw_len < 4 {
1679 return Err(DriverError::Protocol(format!(
1680 "invalid message length {raw_len} for type '{}'",
1681 msg_type as char
1682 )));
1683 }
1684
1685 let payload_len = (raw_len - 4) as usize;
1686 let total_msg_len = 5 + payload_len;
1687
1688 if avail < total_msg_len {
1689 if total_msg_len > self.stream_buf.len() {
1690 let msg = self.read_one_message()?;
1692 match msg {
1693 BackendMessage::BindComplete => continue,
1694 BackendMessage::DataRow { data } => {
1695 let row = PgDataRow::new(data)?;
1696 f(row)?;
1697 continue;
1698 }
1699 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1700 continue;
1701 }
1702 BackendMessage::ReadyForQuery { status } => {
1703 self.tx_status = status;
1704 break 'outer;
1705 }
1706 BackendMessage::NoticeResponse { .. } => continue,
1707 BackendMessage::ErrorResponse { data } => {
1708 let fields = proto::parse_error_response(data);
1709 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1710 self.drain_to_ready()?;
1711 return Err(self.make_server_error(fields));
1712 }
1713 other => {
1714 return Err(DriverError::Protocol(format!(
1715 "unexpected message during for_each: {other:?}"
1716 )));
1717 }
1718 }
1719 }
1720 break; }
1722
1723 let payload_start = self.stream_buf_pos + 5;
1725 let payload_end = payload_start + payload_len;
1726
1727 if msg_type == b'D' {
1730 let row = PgDataRow::new(&self.stream_buf[payload_start..payload_end])?;
1732 f(row)?;
1733 } else if msg_type == b'Z' {
1734 if payload_len >= 1 {
1736 self.tx_status = self.stream_buf[payload_start];
1737 }
1738 self.stream_buf_pos += total_msg_len;
1739 break 'outer;
1740 } else {
1741 self.handle_non_datarow_execute(
1742 msg_type,
1743 payload_start,
1744 payload_end,
1745 sql_hash,
1746 )?;
1747 }
1748
1749 self.stream_buf_pos += total_msg_len;
1750 }
1751
1752 self.refill_stream_buf()?;
1754 }
1755
1756 self.shrink_buffers();
1757 Ok(())
1758 }
1759
1760 #[inline]
1771 pub fn for_each_raw_monolithic<F>(
1772 &mut self,
1773 sql: &str,
1774 sql_hash: u64,
1775 params: &[&(dyn Encode + Sync)],
1776 mut f: F,
1777 ) -> Result<(), DriverError>
1778 where
1779 F: FnMut(&[u8]) -> Result<(), DriverError>,
1780 {
1781 if self.statement_cache_mode == StatementCacheMode::Disabled {
1783 return self.for_each_raw_unnamed(sql, params, f);
1784 }
1785
1786 self.write_buf.clear();
1788
1789 let info = match self.stmts.get_mut(&sql_hash, sql) {
1791 Some(info) => {
1792 self.query_counter += 1;
1793 info.last_used = self.query_counter;
1794 info
1795 }
1796 None => {
1797 return self.for_each_raw_with_prepare(sql, sql_hash, params, f);
1799 }
1800 };
1801
1802 let can_use_template = info
1804 .bind_template
1805 .as_ref()
1806 .is_some_and(|t| t.param_slots.len() == params.len());
1807
1808 let mut has_exec_sync = false;
1809
1810 if can_use_template {
1811 let tmpl = info.bind_template.as_ref().ok_or_else(|| {
1813 DriverError::Protocol("bind_template missing despite can_use_template".into())
1814 })?;
1815 self.write_buf.extend_from_slice(&tmpl.bytes);
1816
1817 let mut template_ok = true;
1818 for (i, param) in params.iter().enumerate() {
1819 let (data_offset, old_len) = tmpl.param_slots[i];
1820 if param.is_null() {
1821 let len_offset = data_offset - 4;
1822 self.write_buf[len_offset..len_offset + 4]
1823 .copy_from_slice(&(-1i32).to_be_bytes());
1824 } else if old_len >= 0 {
1825 let end = data_offset + old_len as usize;
1826 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
1827 template_ok = false;
1828 break;
1829 }
1830 } else {
1831 template_ok = false;
1832 break;
1833 }
1834 }
1835
1836 if template_ok {
1837 has_exec_sync = true;
1838 } else {
1839 self.write_buf.clear();
1840 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
1841 info.bind_template = None;
1842 }
1843 } else {
1844 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
1845 }
1846
1847 if info.bind_template.is_none() && !self.write_buf.is_empty() {
1849 info.bind_template = build_bind_template(&self.write_buf, params.len());
1850 }
1851
1852 if !has_exec_sync {
1853 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1854 }
1855
1856 self.stream
1858 .write_all(&self.write_buf)
1859 .map_err(DriverError::Io)?;
1860
1861 loop {
1865 let avail = self.stream_buf_end - self.stream_buf_pos;
1866 if avail >= 5 {
1867 let bc_type = self.stream_buf[self.stream_buf_pos];
1868 match bc_type {
1869 b'2' => {
1870 self.stream_buf_pos += 5;
1871 break;
1872 }
1873 b'E' => {
1874 let msg = self.read_one_message()?;
1875 if let BackendMessage::ErrorResponse { data } = msg {
1876 let fields = proto::parse_error_response(data);
1877 self.drain_to_ready()?;
1878 return Err(self.make_server_error(fields));
1879 }
1880 }
1881 b'N' | b'S' => {
1882 let raw_len = i32::from_be_bytes([
1883 self.stream_buf[self.stream_buf_pos + 1],
1884 self.stream_buf[self.stream_buf_pos + 2],
1885 self.stream_buf[self.stream_buf_pos + 3],
1886 self.stream_buf[self.stream_buf_pos + 4],
1887 ]);
1888 let total = 1 + raw_len as usize;
1889 if avail >= total {
1890 self.stream_buf_pos += total;
1891 continue;
1892 }
1893 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1894 break;
1895 }
1896 _ => {
1897 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1898 break;
1899 }
1900 }
1901 } else {
1902 let remaining = self.stream_buf_end - self.stream_buf_pos;
1904 if remaining > 0 && self.stream_buf_pos > 0 {
1905 self.stream_buf
1906 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1907 }
1908 self.stream_buf_pos = 0;
1909 self.stream_buf_end = remaining;
1910 let n = self
1911 .stream
1912 .read(&mut self.stream_buf[remaining..])
1913 .map_err(DriverError::Io)?;
1914 if n == 0 {
1915 return Err(DriverError::Io(std::io::Error::new(
1916 std::io::ErrorKind::UnexpectedEof,
1917 "connection closed",
1918 )));
1919 }
1920 self.stream_buf_end = remaining + n;
1921 }
1922 }
1923
1924 'outer: loop {
1926 loop {
1927 let avail = self.stream_buf_end - self.stream_buf_pos;
1928 if avail < 5 {
1929 break;
1930 }
1931
1932 let msg_type = self.stream_buf[self.stream_buf_pos];
1933 let raw_len = i32::from_be_bytes([
1934 self.stream_buf[self.stream_buf_pos + 1],
1935 self.stream_buf[self.stream_buf_pos + 2],
1936 self.stream_buf[self.stream_buf_pos + 3],
1937 self.stream_buf[self.stream_buf_pos + 4],
1938 ]);
1939
1940 if raw_len < 4 {
1941 return Err(DriverError::Protocol(format!(
1942 "invalid message length {raw_len} for type '{}'",
1943 msg_type as char
1944 )));
1945 }
1946
1947 let payload_len = (raw_len - 4) as usize;
1948 let total_msg_len = 5 + payload_len;
1949
1950 if avail < total_msg_len {
1951 if total_msg_len > self.stream_buf.len() {
1952 let msg = self.read_one_message()?;
1953 match msg {
1954 BackendMessage::DataRow { data } => {
1955 f(data)?;
1956 continue;
1957 }
1958 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1959 continue;
1960 }
1961 BackendMessage::ReadyForQuery { status } => {
1962 self.tx_status = status;
1963 break 'outer;
1964 }
1965 BackendMessage::ErrorResponse { data } => {
1966 let fields = proto::parse_error_response(data);
1967 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1968 self.drain_to_ready()?;
1969 return Err(self.make_server_error(fields));
1970 }
1971 BackendMessage::NoticeResponse { .. } => continue,
1972 other => {
1973 return Err(DriverError::Protocol(format!(
1974 "unexpected message during for_each_raw: {other:?}"
1975 )));
1976 }
1977 }
1978 }
1979 break; }
1981
1982 let payload_start = self.stream_buf_pos + 5;
1984 let payload_end = payload_start + payload_len;
1985
1986 if msg_type == b'D' {
1987 f(&self.stream_buf[payload_start..payload_end])?;
1988 } else if msg_type == b'Z' {
1989 if payload_len >= 1 {
1990 self.tx_status = self.stream_buf[payload_start];
1991 }
1992 self.stream_buf_pos += total_msg_len;
1993 break 'outer;
1994 } else {
1995 self.handle_non_datarow_execute(
1996 msg_type,
1997 payload_start,
1998 payload_end,
1999 sql_hash,
2000 )?;
2001 }
2002
2003 self.stream_buf_pos += total_msg_len;
2004 }
2005
2006 let remaining = self.stream_buf_end - self.stream_buf_pos;
2008 if remaining > 0 && self.stream_buf_pos > 0 {
2009 self.stream_buf
2010 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
2011 }
2012 self.stream_buf_pos = 0;
2013 self.stream_buf_end = remaining;
2014 let n = self
2015 .stream
2016 .read(&mut self.stream_buf[remaining..])
2017 .map_err(DriverError::Io)?;
2018 if n == 0 {
2019 return Err(DriverError::Io(std::io::Error::new(
2020 std::io::ErrorKind::UnexpectedEof,
2021 "connection closed",
2022 )));
2023 }
2024 self.stream_buf_end = remaining + n;
2025 }
2026
2027 if self.query_counter & 63 == 0 {
2029 if self.read_buf.capacity() > 64 * 1024 {
2030 self.read_buf.clear();
2031 self.read_buf.shrink_to(8192);
2032 }
2033 if self.write_buf.capacity() > 16 * 1024 {
2034 self.write_buf.clear();
2035 self.write_buf.shrink_to(8192);
2036 }
2037 }
2038
2039 Ok(())
2040 }
2041
2042 #[cold]
2044 #[inline(never)]
2045 fn for_each_raw_with_prepare<F>(
2046 &mut self,
2047 sql: &str,
2048 sql_hash: u64,
2049 params: &[&(dyn Encode + Sync)],
2050 mut f: F,
2051 ) -> Result<(), DriverError>
2052 where
2053 F: FnMut(&[u8]) -> Result<(), DriverError>,
2054 {
2055 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
2056
2057 if params.len() > i16::MAX as usize {
2058 return Err(DriverError::Protocol(format!(
2059 "parameter count {} exceeds maximum {}",
2060 params.len(),
2061 i16::MAX
2062 )));
2063 }
2064
2065 let name = make_stmt_name(sql_hash);
2066 let param_oids: smallvec::SmallVec<[u32; 8]> =
2067 params.iter().map(|p| p.type_oid()).collect();
2068
2069 self.write_buf.clear();
2070 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
2071 proto::write_describe(&mut self.write_buf, b'S', &name);
2072 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
2073 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2074 self.stream
2075 .write_all(&self.write_buf)
2076 .map_err(DriverError::Io)?;
2077
2078 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2079 let columns = self.read_column_description()?;
2080 self.query_counter += 1;
2081 self.cache_stmt(
2082 sql_hash,
2083 StmtInfo {
2084 name,
2085 sql: sql.into(),
2086 columns,
2087 last_used: self.query_counter,
2088 bind_template: None,
2089 },
2090 );
2091
2092 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2094
2095 'outer: loop {
2096 loop {
2097 let avail = self.stream_buf_end - self.stream_buf_pos;
2098 if avail < 5 {
2099 break;
2100 }
2101
2102 let msg_type = self.stream_buf[self.stream_buf_pos];
2103 let raw_len = i32::from_be_bytes([
2104 self.stream_buf[self.stream_buf_pos + 1],
2105 self.stream_buf[self.stream_buf_pos + 2],
2106 self.stream_buf[self.stream_buf_pos + 3],
2107 self.stream_buf[self.stream_buf_pos + 4],
2108 ]);
2109
2110 if raw_len < 4 {
2111 return Err(DriverError::Protocol(format!(
2112 "invalid message length {raw_len} for type '{}'",
2113 msg_type as char
2114 )));
2115 }
2116
2117 let payload_len = (raw_len - 4) as usize;
2118 let total_msg_len = 5 + payload_len;
2119
2120 if avail < total_msg_len {
2121 if total_msg_len > self.stream_buf.len() {
2122 let msg = self.read_one_message()?;
2123 match msg {
2124 BackendMessage::DataRow { data } => {
2125 f(data)?;
2126 continue;
2127 }
2128 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
2129 continue;
2130 }
2131 BackendMessage::ReadyForQuery { status } => {
2132 self.tx_status = status;
2133 break 'outer;
2134 }
2135 BackendMessage::ErrorResponse { data } => {
2136 let fields = proto::parse_error_response(data);
2137 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2138 self.drain_to_ready()?;
2139 return Err(self.make_server_error(fields));
2140 }
2141 BackendMessage::NoticeResponse { .. } => continue,
2142 other => {
2143 return Err(DriverError::Protocol(format!(
2144 "unexpected message during for_each_raw: {other:?}"
2145 )));
2146 }
2147 }
2148 }
2149 break;
2150 }
2151
2152 let payload_start = self.stream_buf_pos + 5;
2153 let payload_end = payload_start + payload_len;
2154
2155 if msg_type == b'D' {
2156 f(&self.stream_buf[payload_start..payload_end])?;
2157 } else if msg_type == b'Z' {
2158 if payload_len >= 1 {
2159 self.tx_status = self.stream_buf[payload_start];
2160 }
2161 self.stream_buf_pos += total_msg_len;
2162 break 'outer;
2163 } else {
2164 self.handle_non_datarow_execute(
2165 msg_type,
2166 payload_start,
2167 payload_end,
2168 sql_hash,
2169 )?;
2170 }
2171
2172 self.stream_buf_pos += total_msg_len;
2173 }
2174
2175 self.refill_stream_buf()?;
2176 }
2177
2178 self.shrink_buffers();
2179 Ok(())
2180 }
2181
2182 fn for_each_raw_unnamed<F>(
2184 &mut self,
2185 sql: &str,
2186 params: &[&(dyn Encode + Sync)],
2187 mut f: F,
2188 ) -> Result<(), DriverError>
2189 where
2190 F: FnMut(&[u8]) -> Result<(), DriverError>,
2191 {
2192 if params.len() > i16::MAX as usize {
2193 return Err(DriverError::Protocol(format!(
2194 "parameter count {} exceeds maximum {}",
2195 params.len(),
2196 i16::MAX
2197 )));
2198 }
2199
2200 let param_oids: smallvec::SmallVec<[u32; 8]> =
2201 params.iter().map(|p| p.type_oid()).collect();
2202
2203 self.write_buf.clear();
2204 proto::write_parse(&mut self.write_buf, b"", sql, ¶m_oids);
2205 proto::write_describe(&mut self.write_buf, b'S', b"");
2206 proto::write_bind_params(&mut self.write_buf, b"", b"", params);
2207 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2208 self.stream
2209 .write_all(&self.write_buf)
2210 .map_err(DriverError::Io)?;
2211
2212 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2213 let _columns = self.read_column_description()?;
2214 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2215
2216 'outer: loop {
2217 loop {
2218 let avail = self.stream_buf_end - self.stream_buf_pos;
2219 if avail < 5 {
2220 break;
2221 }
2222
2223 let msg_type = self.stream_buf[self.stream_buf_pos];
2224 let raw_len = i32::from_be_bytes([
2225 self.stream_buf[self.stream_buf_pos + 1],
2226 self.stream_buf[self.stream_buf_pos + 2],
2227 self.stream_buf[self.stream_buf_pos + 3],
2228 self.stream_buf[self.stream_buf_pos + 4],
2229 ]);
2230
2231 if raw_len < 4 {
2232 return Err(DriverError::Protocol(format!(
2233 "invalid message length {raw_len} for type '{}'",
2234 msg_type as char
2235 )));
2236 }
2237
2238 let payload_len = (raw_len - 4) as usize;
2239 let total_msg_len = 5 + payload_len;
2240
2241 if avail < total_msg_len {
2242 if total_msg_len > self.stream_buf.len() {
2243 let msg = self.read_one_message()?;
2244 match msg {
2245 BackendMessage::DataRow { data } => {
2246 f(data)?;
2247 continue;
2248 }
2249 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
2250 continue
2251 }
2252 BackendMessage::ReadyForQuery { status } => {
2253 self.tx_status = status;
2254 break 'outer;
2255 }
2256 BackendMessage::ErrorResponse { data } => {
2257 let fields = proto::parse_error_response(data);
2258 self.drain_to_ready()?;
2259 return Err(self.make_server_error(fields));
2260 }
2261 BackendMessage::NoticeResponse { .. } => continue,
2262 other => {
2263 return Err(DriverError::Protocol(format!(
2264 "unexpected message during for_each_raw (unnamed): {other:?}"
2265 )));
2266 }
2267 }
2268 }
2269 break;
2270 }
2271
2272 let payload_start = self.stream_buf_pos + 5;
2273 let payload_end = payload_start + payload_len;
2274
2275 if msg_type == b'D' {
2276 f(&self.stream_buf[payload_start..payload_end])?;
2277 } else if msg_type == b'Z' {
2278 if payload_len >= 1 {
2279 self.tx_status = self.stream_buf[payload_start];
2280 }
2281 self.stream_buf_pos += total_msg_len;
2282 break 'outer;
2283 } else if msg_type == b'C' || msg_type == b'I' || msg_type == b'N' {
2284 } else if msg_type == b'E' {
2286 let fields =
2287 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2288 self.stream_buf_pos += total_msg_len;
2289 self.drain_to_ready()?;
2290 return Err(self.make_server_error(fields));
2291 } else {
2292 return Err(DriverError::Protocol(format!(
2293 "unexpected message type '{}' during for_each_raw (unnamed)",
2294 msg_type as char
2295 )));
2296 }
2297
2298 self.stream_buf_pos += total_msg_len;
2299 }
2300
2301 self.refill_stream_buf()?;
2302 }
2303
2304 self.shrink_buffers();
2305 Ok(())
2306 }
2307
2308 #[inline]
2313 pub fn for_each_raw<F>(
2314 &mut self,
2315 sql: &str,
2316 sql_hash: u64,
2317 params: &[&(dyn Encode + Sync)],
2318 f: F,
2319 ) -> Result<(), DriverError>
2320 where
2321 F: FnMut(&[u8]) -> Result<(), DriverError>,
2322 {
2323 self.for_each_raw_monolithic(sql, sql_hash, params, f)
2324 }
2325
2326 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
2328 self.write_buf.clear();
2329 proto::write_simple_query(&mut self.write_buf, sql);
2330 self.flush_write()?;
2331
2332 loop {
2333 let msg = self.read_one_message()?;
2334 match msg {
2335 BackendMessage::ReadyForQuery { status } => {
2336 self.tx_status = status;
2337 return Ok(());
2338 }
2339 BackendMessage::CommandComplete { .. }
2340 | BackendMessage::RowDescription { .. }
2341 | BackendMessage::DataRow { .. }
2342 | BackendMessage::EmptyQuery
2343 | BackendMessage::NoticeResponse { .. }
2344 | BackendMessage::ParameterStatus { .. }
2345 | BackendMessage::AuthOk
2349 | BackendMessage::AuthSaslFinal { .. }
2350 | BackendMessage::BackendKeyData { .. } => {}
2351 BackendMessage::ErrorResponse { data } => {
2352 let fields = proto::parse_error_response(data);
2353 self.drain_to_ready()?;
2354 return Err(self.make_server_error(fields));
2355 }
2356 other => {
2357 return Err(DriverError::Protocol(format!(
2358 "unexpected message during simple_query: {other:?}"
2359 )));
2360 }
2361 }
2362 }
2363 }
2364
2365 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
2367 self.write_buf.clear();
2368 proto::write_simple_query(&mut self.write_buf, sql);
2369 self.flush_write()?;
2370
2371 let mut rows: Vec<SimpleRow> = Vec::new();
2372 loop {
2373 let msg = self.read_one_message()?;
2374 match msg {
2375 BackendMessage::ReadyForQuery { status } => {
2376 self.tx_status = status;
2377 return Ok(rows);
2378 }
2379 BackendMessage::DataRow { data } => {
2380 rows.push(proto::parse_simple_data_row(data)?);
2381 }
2382 BackendMessage::RowDescription { .. }
2383 | BackendMessage::CommandComplete { .. }
2384 | BackendMessage::EmptyQuery
2385 | BackendMessage::NoticeResponse { .. }
2386 | BackendMessage::ParameterStatus { .. }
2387 | BackendMessage::AuthOk
2388 | BackendMessage::AuthSaslFinal { .. }
2389 | BackendMessage::BackendKeyData { .. } => {}
2390 BackendMessage::ErrorResponse { data } => {
2391 let fields = proto::parse_error_response(data);
2392 self.drain_to_ready()?;
2393 return Err(self.make_server_error(fields));
2394 }
2395 other => {
2396 return Err(DriverError::Protocol(format!(
2397 "unexpected message during simple_query_rows: {other:?}"
2398 )));
2399 }
2400 }
2401 }
2402 }
2403
2404 pub fn copy_in<'a, I>(
2426 &mut self,
2427 table: &str,
2428 columns: &[&str],
2429 rows: I,
2430 ) -> Result<u64, DriverError>
2431 where
2432 I: IntoIterator<Item = &'a str>,
2433 {
2434 let quoted_table = proto::quote_ident(table);
2436 let quoted_cols: Vec<String> = columns.iter().map(|c| proto::quote_ident(c)).collect();
2437 let sql = format!(
2438 "COPY {}({}) FROM STDIN",
2439 quoted_table,
2440 quoted_cols.join(",")
2441 );
2442
2443 self.write_buf.clear();
2445 proto::write_simple_query(&mut self.write_buf, &sql);
2446 self.flush_write()?;
2447
2448 loop {
2450 let msg = self.read_one_message()?;
2451 match msg {
2452 BackendMessage::CopyInResponse { .. } => break,
2453 BackendMessage::ErrorResponse { data } => {
2454 let fields = proto::parse_error_response(data);
2455 self.drain_to_ready()?;
2456 return Err(self.make_server_error(fields));
2457 }
2458 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2459 other => {
2460 return Err(DriverError::Protocol(format!(
2461 "expected CopyInResponse, got: {other:?}"
2462 )));
2463 }
2464 }
2465 }
2466
2467 self.write_buf.clear();
2477 for row in rows {
2478 let row_data = row.as_bytes();
2480 let data_len = (4 + row_data.len() + 1) as i32;
2481 self.write_buf.push(b'd');
2482 self.write_buf.extend_from_slice(&data_len.to_be_bytes());
2483 self.write_buf.extend_from_slice(row_data);
2484 self.write_buf.push(b'\n');
2485 if self.write_buf.len() > 65536 {
2487 self.flush_write()?;
2488 self.write_buf.clear();
2489 }
2490 }
2491 proto::write_copy_done(&mut self.write_buf);
2494 self.flush_write()?;
2495 self.write_buf.clear();
2496
2497 let mut count: u64 = 0;
2499 loop {
2500 let msg = self.read_one_message()?;
2501 match msg {
2502 BackendMessage::CommandComplete { tag } => {
2503 count = proto::parse_command_tag(tag);
2504 }
2505 BackendMessage::ReadyForQuery { status } => {
2506 self.tx_status = status;
2507 return Ok(count);
2508 }
2509 BackendMessage::ErrorResponse { data } => {
2510 let fields = proto::parse_error_response(data);
2511 self.drain_to_ready()?;
2512 return Err(self.make_server_error(fields));
2513 }
2514 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2515 other => {
2516 return Err(DriverError::Protocol(format!(
2517 "unexpected message during copy_in completion: {other:?}"
2518 )));
2519 }
2520 }
2521 }
2522 }
2523
2524 pub fn copy_out<W: std::io::Write>(
2544 &mut self,
2545 query: &str,
2546 writer: &mut W,
2547 ) -> Result<u64, DriverError> {
2548 let sql = format!("COPY ({query}) TO STDOUT");
2550
2551 self.write_buf.clear();
2553 proto::write_simple_query(&mut self.write_buf, &sql);
2554 self.flush_write()?;
2555
2556 loop {
2558 let msg = self.read_one_message()?;
2559 match msg {
2560 BackendMessage::CopyOutResponse { .. } => break,
2561 BackendMessage::ErrorResponse { data } => {
2562 let fields = proto::parse_error_response(data);
2563 self.drain_to_ready()?;
2564 return Err(self.make_server_error(fields));
2565 }
2566 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2567 other => {
2568 return Err(DriverError::Protocol(format!(
2569 "expected CopyOutResponse, got: {other:?}"
2570 )));
2571 }
2572 }
2573 }
2574
2575 loop {
2577 let msg = self.read_one_message()?;
2578 match msg {
2579 BackendMessage::CopyData { data } => {
2580 writer.write_all(data).map_err(DriverError::Io)?;
2581 }
2582 BackendMessage::CopyDone => break,
2583 BackendMessage::ErrorResponse { data } => {
2584 let fields = proto::parse_error_response(data);
2585 self.drain_to_ready()?;
2586 return Err(self.make_server_error(fields));
2587 }
2588 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2589 other => {
2590 return Err(DriverError::Protocol(format!(
2591 "unexpected message during copy_out data: {other:?}"
2592 )));
2593 }
2594 }
2595 }
2596
2597 let mut count: u64 = 0;
2599 loop {
2600 let msg = self.read_one_message()?;
2601 match msg {
2602 BackendMessage::CommandComplete { tag } => {
2603 count = proto::parse_command_tag(tag);
2604 }
2605 BackendMessage::ReadyForQuery { status } => {
2606 self.tx_status = status;
2607 return Ok(count);
2608 }
2609 BackendMessage::ErrorResponse { data } => {
2610 let fields = proto::parse_error_response(data);
2611 self.drain_to_ready()?;
2612 return Err(self.make_server_error(fields));
2613 }
2614 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2615 other => {
2616 return Err(DriverError::Protocol(format!(
2617 "unexpected message during copy_out completion: {other:?}"
2618 )));
2619 }
2620 }
2621 }
2622 }
2623
2624 pub fn prepare_describe(&mut self, sql: &str) -> Result<PrepareResult, DriverError> {
2629 self.write_buf.clear();
2630 proto::write_parse(&mut self.write_buf, b"", sql, &[]);
2633 proto::write_describe(&mut self.write_buf, b'S', b"");
2634 proto::write_sync(&mut self.write_buf);
2635 self.flush_write()?;
2636
2637 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2639
2640 let mut param_oids: Vec<u32> = Vec::new();
2642 let columns;
2643 loop {
2644 let msg = self.read_one_message()?;
2645 match msg {
2646 BackendMessage::ParameterDescription { data } => {
2647 param_oids = proto::parse_parameter_description(data)?;
2648 }
2649 BackendMessage::RowDescription { data } => {
2650 columns = proto::parse_row_description(data)?;
2651 break;
2652 }
2653 BackendMessage::NoData => {
2654 columns = Vec::new();
2655 break;
2656 }
2657 BackendMessage::NoticeResponse { .. } => {}
2658 BackendMessage::ErrorResponse { data } => {
2659 let fields = proto::parse_error_response(data);
2660 self.drain_to_ready()?;
2661 return Err(self.make_server_error(fields));
2662 }
2663 other => {
2664 return Err(DriverError::Protocol(format!(
2665 "expected ParameterDescription/RowDescription/NoData, got: {other:?}"
2666 )));
2667 }
2668 }
2669 }
2670
2671 self.expect_ready()?;
2673
2674 Ok(PrepareResult {
2675 columns,
2676 param_oids,
2677 })
2678 }
2679
2680 pub fn wait_for_notification(&mut self) -> Result<(String, String), DriverError> {
2689 loop {
2690 let (msg_type, _payload_len) = self.read_message_buffered()?;
2691 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2692 match msg {
2693 BackendMessage::NotificationResponse {
2694 channel, payload, ..
2695 } => {
2696 return Ok((channel.to_owned(), payload.to_owned()));
2697 }
2698 BackendMessage::ParameterStatus { .. } | BackendMessage::NoticeResponse { .. } => {
2699 continue;
2700 }
2701 _ => continue,
2702 }
2703 }
2704 }
2705
2706 pub fn cancel(&self) -> Result<(), DriverError> {
2712 let addr = format!("{}:{}", self.connect_config.host, self.connect_config.port);
2713 let mut tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
2714 let mut buf = Vec::with_capacity(16);
2715 proto::write_cancel_request(&mut buf, self.pid, self.secret);
2716 tcp.write_all(&buf).map_err(DriverError::Io)?;
2717 tcp.flush().map_err(DriverError::Io)?;
2718 drop(tcp);
2720 Ok(())
2721 }
2722
2723 pub fn set_read_timeout(
2728 &self,
2729 timeout: Option<std::time::Duration>,
2730 ) -> Result<(), DriverError> {
2731 self.stream
2732 .set_read_timeout(timeout)
2733 .map_err(DriverError::Io)
2734 }
2735
2736 pub fn query_streaming_start(
2750 &mut self,
2751 sql: &str,
2752 sql_hash: u64,
2753 params: &[&(dyn Encode + Sync)],
2754 chunk_size: i32,
2755 ) -> Result<(Arc<[ColumnDesc]>, bool), DriverError> {
2756 self.write_buf.clear();
2757
2758 if self.statement_cache_mode == StatementCacheMode::Disabled {
2760 let param_oids: smallvec::SmallVec<[u32; 8]> =
2761 params.iter().map(|p| p.type_oid()).collect();
2762 proto::write_parse(&mut self.write_buf, b"", sql, ¶m_oids);
2763 proto::write_describe(&mut self.write_buf, b'S', b"");
2764 proto::write_bind_params(&mut self.write_buf, b"", b"", params);
2765 proto::write_execute(&mut self.write_buf, b"", chunk_size);
2766 proto::write_flush(&mut self.write_buf);
2767 self.flush_write()?;
2768
2769 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2770 let columns = self.read_column_description()?;
2771 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2772 self.streaming_active = true;
2773 return Ok((columns, false));
2774 }
2775
2776 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash, sql) {
2777 self.query_counter += 1;
2779 info.last_used = self.query_counter;
2780
2781 let can_use_template = info
2782 .bind_template
2783 .as_ref()
2784 .is_some_and(|t| t.param_slots.len() == params.len());
2785
2786 if can_use_template {
2787 let tmpl = info.bind_template.as_ref().ok_or_else(|| {
2789 DriverError::Protocol("bind_template missing despite can_use_template".into())
2790 })?;
2791 self.write_buf
2794 .extend_from_slice(&tmpl.bytes[..tmpl.bind_end]);
2795
2796 let mut template_ok = true;
2797 for (i, param) in params.iter().enumerate() {
2798 let (data_offset, old_len) = tmpl.param_slots[i];
2799 if param.is_null() {
2800 let len_offset = data_offset - 4;
2801 self.write_buf[len_offset..len_offset + 4]
2802 .copy_from_slice(&(-1i32).to_be_bytes());
2803 } else if old_len >= 0 {
2804 let end = data_offset + old_len as usize;
2805 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2806 template_ok = false;
2807 break;
2808 }
2809 } else {
2810 template_ok = false;
2811 break;
2812 }
2813 }
2814
2815 if !template_ok {
2816 self.write_buf.clear();
2817 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
2818 info.bind_template = None;
2819 }
2820 } else {
2821 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
2822 }
2823
2824 let cols = info.columns.clone();
2825
2826 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2827 info.bind_template = build_bind_template(&self.write_buf, params.len());
2828 }
2829
2830 proto::write_execute(&mut self.write_buf, b"", chunk_size);
2831 proto::write_flush(&mut self.write_buf);
2833 self.flush_write()?;
2834
2835 cols
2836 } else {
2837 let name = make_stmt_name(sql_hash);
2839 let param_oids: smallvec::SmallVec<[u32; 8]> =
2840 params.iter().map(|p| p.type_oid()).collect();
2841 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
2842 proto::write_describe(&mut self.write_buf, b'S', &name);
2843 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
2844
2845 proto::write_execute(&mut self.write_buf, b"", chunk_size);
2846 proto::write_flush(&mut self.write_buf);
2847 self.flush_write()?;
2848
2849 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2850 let columns = self.read_column_description()?;
2851 self.query_counter += 1;
2852 self.cache_stmt(
2853 sql_hash,
2854 StmtInfo {
2855 name,
2856 sql: sql.into(),
2857 columns: columns.clone(),
2858 last_used: self.query_counter,
2859 bind_template: None,
2860 },
2861 );
2862 columns
2863 };
2864
2865 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2867
2868 self.streaming_active = true;
2869
2870 Ok((columns, false))
2871 }
2872
2873 pub fn streaming_next_chunk(
2881 &mut self,
2882 arena: &mut Arena,
2883 all_col_offsets: &mut Vec<(usize, i32)>,
2884 ) -> Result<bool, DriverError> {
2885 all_col_offsets.clear();
2886
2887 loop {
2888 let msg = self.read_one_message()?;
2889 match msg {
2890 BackendMessage::DataRow { data } => {
2891 parse_data_row_flat(data, arena, all_col_offsets)?;
2892 }
2893 BackendMessage::PortalSuspended => {
2894 return Ok(true);
2898 }
2899 BackendMessage::CommandComplete { .. } => {
2900 self.write_buf.clear();
2903 proto::write_sync(&mut self.write_buf);
2904 self.flush_write()?;
2905 self.expect_ready()?;
2906 self.shrink_buffers();
2907
2908 self.streaming_active = false;
2909 return Ok(false);
2910 }
2911 BackendMessage::EmptyQuery => {
2912 self.write_buf.clear();
2913 proto::write_sync(&mut self.write_buf);
2914 self.flush_write()?;
2915 self.expect_ready()?;
2916
2917 self.streaming_active = false;
2918 return Ok(false);
2919 }
2920 BackendMessage::ErrorResponse { data } => {
2921 let fields = proto::parse_error_response(data);
2922 self.write_buf.clear();
2924 proto::write_sync(&mut self.write_buf);
2925 self.flush_write()?;
2926 self.drain_to_ready()?;
2927
2928 self.streaming_active = false;
2929 return Err(self.make_server_error(fields));
2930 }
2931 BackendMessage::NoticeResponse { .. } => {}
2932 other => {
2933 return Err(DriverError::Protocol(format!(
2934 "unexpected message during streaming: {other:?}"
2935 )));
2936 }
2937 }
2938 }
2939 }
2940
2941 pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
2949 self.write_buf.clear();
2950 proto::write_execute(&mut self.write_buf, b"", chunk_size);
2951 proto::write_flush(&mut self.write_buf);
2952 self.flush_write()
2953 }
2954
2955 pub fn is_streaming(&self) -> bool {
2957 self.streaming_active
2958 }
2959
2960 pub fn close(mut self) -> Result<(), DriverError> {
2962 self.write_buf.clear();
2963 proto::write_terminate(&mut self.write_buf);
2964 let _ = self.flush_write();
2965 Ok(())
2966 }
2967
2968 pub fn is_idle(&self) -> bool {
2972 self.tx_status == b'I'
2973 }
2974
2975 pub fn is_in_transaction(&self) -> bool {
2977 self.tx_status == b'T'
2978 }
2979
2980 pub fn is_in_failed_transaction(&self) -> bool {
2982 self.tx_status == b'E'
2983 }
2984
2985 pub fn touch(&mut self) {
2987 self.last_used = std::time::Instant::now();
2988 }
2989
2990 pub fn idle_duration(&self) -> std::time::Duration {
2992 self.last_used.elapsed()
2993 }
2994
2995 pub fn query_counter(&self) -> u64 {
2997 self.query_counter
2998 }
2999
3000 pub fn parameter(&self, name: &str) -> Option<&str> {
3002 self.params
3003 .iter()
3004 .find(|(k, _)| &**k == name)
3005 .map(|(_, v)| &**v)
3006 }
3007
3008 pub fn server_params(&self) -> &[(Box<str>, Box<str>)] {
3010 &self.params
3011 }
3012
3013 pub fn pid(&self) -> i32 {
3015 self.pid
3016 }
3017
3018 pub fn secret_key(&self) -> i32 {
3020 self.secret
3021 }
3022
3023 pub fn drain_notifications(&mut self) -> Vec<Notification> {
3025 std::mem::take(&mut self.pending_notifications)
3026 }
3027
3028 pub fn pending_notification_count(&self) -> usize {
3030 self.pending_notifications.len()
3031 }
3032
3033 pub fn set_max_stmt_cache_size(&mut self, size: usize) {
3035 self.max_stmt_cache_size = size;
3036 }
3037
3038 pub fn stmt_cache_len(&self) -> usize {
3040 self.stmts.len()
3041 }
3042
3043 pub fn created_at(&self) -> std::time::Instant {
3045 self.created_at
3046 }
3047
3048 #[inline]
3056 fn send_pipeline(
3057 &mut self,
3058 sql: &str,
3059 sql_hash: u64,
3060 params: &[&(dyn Encode + Sync)],
3061 need_columns: bool,
3062 skip_bind_complete: bool,
3063 ) -> Result<Option<Arc<[ColumnDesc]>>, DriverError> {
3064 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
3065
3066 if params.len() > i16::MAX as usize {
3067 return Err(DriverError::Protocol(format!(
3068 "parameter count {} exceeds maximum {}",
3069 params.len(),
3070 i16::MAX
3071 )));
3072 }
3073
3074 self.write_buf.clear();
3075
3076 if self.statement_cache_mode == StatementCacheMode::Disabled {
3078 let param_oids: smallvec::SmallVec<[u32; 8]> =
3079 params.iter().map(|p| p.type_oid()).collect();
3080 proto::write_parse(&mut self.write_buf, b"", sql, ¶m_oids);
3081 if need_columns {
3082 proto::write_describe(&mut self.write_buf, b'S', b"");
3083 }
3084 proto::write_bind_params(&mut self.write_buf, b"", b"", params);
3085 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
3086 self.flush_write()?;
3087
3088 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
3089 let columns = if need_columns {
3090 Some(self.read_column_description()?)
3091 } else {
3092 None
3093 };
3094 if !skip_bind_complete {
3095 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
3096 }
3097 return Ok(columns);
3098 }
3099
3100 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash, sql) {
3101 self.query_counter += 1;
3103 info.last_used = self.query_counter;
3104
3105 let can_use_template = info
3106 .bind_template
3107 .as_ref()
3108 .is_some_and(|t| t.param_slots.len() == params.len());
3109
3110 let mut has_exec_sync = false;
3112
3113 if can_use_template {
3114 let tmpl = info.bind_template.as_ref().ok_or_else(|| {
3118 DriverError::Protocol("bind_template missing despite can_use_template".into())
3119 })?;
3120 self.write_buf.extend_from_slice(&tmpl.bytes);
3121
3122 let mut template_ok = true;
3123 for (i, param) in params.iter().enumerate() {
3124 let (data_offset, old_len) = tmpl.param_slots[i];
3125 if param.is_null() {
3126 let len_offset = data_offset - 4;
3128 self.write_buf[len_offset..len_offset + 4]
3129 .copy_from_slice(&(-1i32).to_be_bytes());
3130 } else if old_len >= 0 {
3131 let end = data_offset + old_len as usize;
3132 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
3133 template_ok = false;
3135 break;
3136 }
3137 } else {
3138 template_ok = false;
3141 break;
3142 }
3143 }
3144
3145 if template_ok {
3146 has_exec_sync = true; } else {
3148 self.write_buf.clear();
3149 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
3150 info.bind_template = None;
3152 }
3153 } else {
3154 proto::write_bind_params(&mut self.write_buf, b"", &info.name, params);
3155 }
3156
3157 let cols = if need_columns {
3158 Some(info.columns.clone())
3159 } else {
3160 None
3161 };
3162
3163 if info.bind_template.is_none() && !self.write_buf.is_empty() {
3167 info.bind_template = build_bind_template(&self.write_buf, params.len());
3168 }
3169
3170 if !has_exec_sync {
3171 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
3172 }
3173 self.flush_write()?;
3174
3175 cols
3176 } else {
3177 let name = make_stmt_name(sql_hash);
3179 let param_oids: smallvec::SmallVec<[u32; 8]> =
3180 params.iter().map(|p| p.type_oid()).collect();
3181 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
3182 proto::write_describe(&mut self.write_buf, b'S', &name);
3183 proto::write_bind_params(&mut self.write_buf, b"", &name, params);
3184
3185 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
3186 self.flush_write()?;
3187
3188 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
3189 let columns = self.read_column_description()?;
3190 self.query_counter += 1;
3191 self.cache_stmt(
3192 sql_hash,
3193 StmtInfo {
3194 name,
3195 sql: sql.into(),
3196 columns: columns.clone(),
3197 last_used: self.query_counter,
3198 bind_template: None,
3199 },
3200 );
3201 if need_columns {
3202 Some(columns)
3203 } else {
3204 None
3205 }
3206 };
3207
3208 if !skip_bind_complete {
3209 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
3210 }
3211
3212 Ok(columns)
3213 }
3214
3215 fn read_column_description(&mut self) -> Result<Arc<[ColumnDesc]>, DriverError> {
3217 loop {
3218 let msg = self.read_one_message()?;
3219 match msg {
3220 BackendMessage::RowDescription { data } => {
3221 let cols = proto::parse_row_description(data)?;
3222 return Ok(cols.into());
3223 }
3224 BackendMessage::ParameterDescription { .. } => {}
3225 BackendMessage::NoData => return Ok(Arc::from(Vec::new())),
3226 BackendMessage::NoticeResponse { .. } => {}
3227 BackendMessage::ErrorResponse { data } => {
3228 let fields = proto::parse_error_response(data);
3229 self.drain_to_ready()?;
3230 return Err(self.make_server_error(fields));
3231 }
3232 other => {
3233 return Err(DriverError::Protocol(format!(
3234 "expected RowDescription/NoData, got: {other:?}"
3235 )));
3236 }
3237 }
3238 }
3239 }
3240
3241 fn cache_stmt(&mut self, sql_hash: u64, info: StmtInfo) {
3244 if self.stmts.len() >= self.max_stmt_cache_size
3245 && !self.stmts.contains_key(&sql_hash, &info.sql)
3246 {
3247 if let Some((_lru_hash, evicted)) = self.stmts.evict_lru() {
3248 proto::write_close(&mut self.write_buf, b'S', &evicted.name);
3249 }
3250 }
3251 self.stmts.insert(sql_hash, info);
3252 }
3253
3254 fn buffer_notification(&mut self, pid: i32, channel: &str, payload: &str) {
3255 if self.pending_notifications.len() < 1024 {
3256 self.pending_notifications.push(Notification {
3257 pid,
3258 channel: channel.to_owned(),
3259 payload: payload.to_owned(),
3260 });
3261 }
3262 }
3263
3264 fn shrink_buffers(&mut self) {
3265 if self.query_counter & 63 != 0 {
3269 return;
3270 }
3271 if self.read_buf.capacity() > 64 * 1024 {
3272 self.read_buf.clear();
3273 self.read_buf.shrink_to(8192);
3274 }
3275 if self.write_buf.capacity() > 16 * 1024 {
3276 self.write_buf.clear();
3277 self.write_buf.shrink_to(8192);
3278 }
3279 }
3280
3281 fn maybe_invalidate_stmt_cache(&mut self, fields: &proto::ErrorFields, sql_hash: u64) -> bool {
3282 if &fields.code == b"26000" {
3283 self.stmts.remove(&sql_hash);
3284 true
3285 } else {
3286 false
3287 }
3288 }
3289
3290 #[cold]
3291 #[inline(never)]
3292 fn make_server_error(&self, fields: proto::ErrorFields) -> DriverError {
3293 DriverError::Server {
3294 code: fields.code,
3295 message: fields.message.into_boxed_str(),
3296 detail: fields.detail.map(String::into_boxed_str),
3297 hint: fields.hint.map(String::into_boxed_str),
3298 position: fields.position,
3299 }
3300 }
3301
3302 #[cold]
3308 #[inline(never)]
3309 fn handle_non_datarow_query(
3310 &mut self,
3311 msg_type: u8,
3312 payload_start: usize,
3313 payload_end: usize,
3314 sql_hash: u64,
3315 affected_rows: &mut u64,
3316 ) -> Result<(), DriverError> {
3317 match msg_type {
3318 b'2' | b'I' => {} b'C' => {
3320 *affected_rows =
3321 proto::parse_command_tag_bytes(&self.stream_buf[payload_start..payload_end]);
3322 }
3323 b'E' => {
3324 let fields =
3325 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
3326 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
3327 self.drain_to_ready()?;
3328 return Err(self.make_server_error(fields));
3329 }
3330 b'A' => {
3331 let msg = proto::parse_backend_message(
3332 msg_type,
3333 &self.stream_buf[payload_start..payload_end],
3334 )?;
3335 if let BackendMessage::NotificationResponse {
3336 pid,
3337 channel,
3338 payload,
3339 } = msg
3340 {
3341 let ch = channel.to_owned();
3342 let pl = payload.to_owned();
3343 self.buffer_notification(pid, &ch, &pl);
3344 }
3345 }
3346 _ => {} }
3348 Ok(())
3349 }
3350
3351 #[cold]
3354 #[inline(never)]
3355 fn handle_non_datarow_execute(
3356 &mut self,
3357 msg_type: u8,
3358 payload_start: usize,
3359 payload_end: usize,
3360 sql_hash: u64,
3361 ) -> Result<(), DriverError> {
3362 match msg_type {
3363 b'2' | b'C' | b'I' => {} b'E' => {
3365 let fields =
3366 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
3367 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
3368 self.drain_to_ready()?;
3369 return Err(self.make_server_error(fields));
3370 }
3371 b'A' => {
3372 let msg = proto::parse_backend_message(
3373 msg_type,
3374 &self.stream_buf[payload_start..payload_end],
3375 )?;
3376 if let BackendMessage::NotificationResponse {
3377 pid,
3378 channel,
3379 payload,
3380 } = msg
3381 {
3382 let ch = channel.to_owned();
3383 let pl = payload.to_owned();
3384 self.buffer_notification(pid, &ch, &pl);
3385 }
3386 }
3387 _ => {} }
3389 Ok(())
3390 }
3391
3392 #[inline(always)]
3399 fn peek_stream_msg(&self) -> Result<Option<(u8, usize, usize, usize)>, DriverError> {
3400 let avail = self.stream_buf_end - self.stream_buf_pos;
3401 if avail < 5 {
3402 return Ok(None);
3403 }
3404
3405 let msg_type = self.stream_buf[self.stream_buf_pos];
3406 let raw_len = i32::from_be_bytes([
3407 self.stream_buf[self.stream_buf_pos + 1],
3408 self.stream_buf[self.stream_buf_pos + 2],
3409 self.stream_buf[self.stream_buf_pos + 3],
3410 self.stream_buf[self.stream_buf_pos + 4],
3411 ]);
3412
3413 if raw_len < 4 {
3414 return Err(DriverError::Protocol(format!(
3415 "invalid message length {raw_len} for type '{}'",
3416 msg_type as char
3417 )));
3418 }
3419
3420 let payload_len = (raw_len - 4) as usize;
3421 let total_msg_len = 5 + payload_len;
3422
3423 if avail < total_msg_len {
3424 return Ok(None);
3425 }
3426
3427 let payload_start = self.stream_buf_pos + 5;
3428 Ok(Some((
3429 msg_type,
3430 payload_start,
3431 payload_start + payload_len,
3432 total_msg_len,
3433 )))
3434 }
3435
3436 #[inline(always)]
3438 fn advance_stream_msg(&mut self, total_msg_len: usize) {
3439 self.stream_buf_pos += total_msg_len;
3440 }
3441
3442 #[inline]
3444 fn read_one_message(&mut self) -> Result<BackendMessage<'_>, DriverError> {
3445 loop {
3446 let (msg_type, _payload_len) = self.read_message_buffered()?;
3447 if msg_type == b'A' {
3448 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
3449 if let BackendMessage::NotificationResponse {
3450 pid,
3451 channel,
3452 payload,
3453 } = msg
3454 {
3455 let pid_owned = pid;
3456 let channel_owned = channel.to_owned();
3457 let payload_owned = payload.to_owned();
3458 self.buffer_notification(pid_owned, &channel_owned, &payload_owned);
3459 continue;
3460 }
3461 }
3462 return proto::parse_backend_message(msg_type, &self.read_buf);
3463 }
3464 }
3465
3466 fn expect_message(
3467 &mut self,
3468 pred: impl Fn(&BackendMessage<'_>) -> bool,
3469 ) -> Result<(), DriverError> {
3470 loop {
3471 let msg = self.read_one_message()?;
3472 if pred(&msg) {
3473 return Ok(());
3474 }
3475 match msg {
3476 BackendMessage::ErrorResponse { data } => {
3477 let fields = proto::parse_error_response(data);
3478 self.drain_to_ready()?;
3479 return Err(self.make_server_error(fields));
3480 }
3481 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
3482 other => {
3483 return Err(DriverError::Protocol(format!(
3484 "unexpected message while waiting for expected type: {other:?}"
3485 )));
3486 }
3487 }
3488 }
3489 }
3490
3491 fn expect_ready(&mut self) -> Result<(), DriverError> {
3492 loop {
3493 let msg = self.read_one_message()?;
3494 match msg {
3495 BackendMessage::ReadyForQuery { status } => {
3496 self.tx_status = status;
3497 return Ok(());
3498 }
3499 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
3500 BackendMessage::ErrorResponse { data } => {
3501 let fields = proto::parse_error_response(data);
3502 self.drain_to_ready()?;
3503 return Err(self.make_server_error(fields));
3504 }
3505 _ => {}
3506 }
3507 }
3508 }
3509
3510 #[inline]
3511 fn drain_to_ready(&mut self) -> Result<(), DriverError> {
3512 loop {
3513 let msg = self.read_one_message()?;
3514 if let BackendMessage::ReadyForQuery { status } = msg {
3515 self.tx_status = status;
3516 return Ok(());
3517 }
3518 }
3519 }
3520
3521 #[inline]
3525 fn flush_write(&mut self) -> Result<(), DriverError> {
3526 self.stream
3527 .write_all(&self.write_buf)
3528 .map_err(DriverError::Io)
3529 }
3530
3531 fn read_message_buffered(&mut self) -> Result<(u8, usize), DriverError> {
3535 let mut header = [0u8; 5];
3536 sync_buffered_read_exact(
3537 &mut self.stream,
3538 &mut self.stream_buf,
3539 &mut self.stream_buf_pos,
3540 &mut self.stream_buf_end,
3541 &mut header,
3542 )?;
3543
3544 let msg_type = header[0];
3545 let len = i32::from_be_bytes([header[1], header[2], header[3], header[4]]);
3546
3547 if len < 4 {
3548 return Err(DriverError::Protocol(format!(
3549 "invalid message length {len} for type '{}'",
3550 msg_type as char
3551 )));
3552 }
3553
3554 const MAX_MESSAGE_LEN: i32 = 128 * 1024 * 1024;
3555 if len > MAX_MESSAGE_LEN {
3556 return Err(DriverError::Protocol(format!(
3557 "message length {len} exceeds maximum ({MAX_MESSAGE_LEN}) for type '{}'",
3558 msg_type as char
3559 )));
3560 }
3561
3562 let payload_len = (len - 4) as usize;
3563 self.read_buf.clear();
3564 self.read_buf.resize(payload_len, 0);
3565 if payload_len > 0 {
3566 sync_buffered_read_exact(
3567 &mut self.stream,
3568 &mut self.stream_buf,
3569 &mut self.stream_buf_pos,
3570 &mut self.stream_buf_end,
3571 &mut self.read_buf[..payload_len],
3572 )?;
3573 }
3574
3575 Ok((msg_type, payload_len))
3576 }
3577
3578 #[inline]
3580 fn refill_stream_buf(&mut self) -> Result<(), DriverError> {
3581 let remaining = self.stream_buf_end - self.stream_buf_pos;
3582 if remaining > 0 && self.stream_buf_pos > 0 {
3583 self.stream_buf
3584 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
3585 }
3586 self.stream_buf_pos = 0;
3587 self.stream_buf_end = remaining;
3588
3589 let n = self
3590 .stream
3591 .read(&mut self.stream_buf[remaining..])
3592 .map_err(DriverError::Io)?;
3593 if n == 0 {
3594 return Err(DriverError::Io(std::io::Error::new(
3595 std::io::ErrorKind::UnexpectedEof,
3596 "connection closed",
3597 )));
3598 }
3599 self.stream_buf_end = remaining + n;
3600 Ok(())
3601 }
3602}
3603
3604fn sync_buffered_read_exact(
3607 stream: &mut Stream,
3608 buf: &mut [u8],
3609 pos: &mut usize,
3610 end: &mut usize,
3611 out: &mut [u8],
3612) -> Result<(), DriverError> {
3613 let mut filled = 0;
3614 while filled < out.len() {
3615 let avail = *end - *pos;
3616 if avail > 0 {
3617 let take = avail.min(out.len() - filled);
3618 out[filled..filled + take].copy_from_slice(&buf[*pos..*pos + take]);
3619 *pos += take;
3620 filled += take;
3621 } else {
3622 *pos = 0;
3623 let n = stream.read(buf).map_err(DriverError::Io)?;
3624 if n == 0 {
3625 return Err(DriverError::Io(std::io::Error::new(
3626 std::io::ErrorKind::UnexpectedEof,
3627 "connection closed",
3628 )));
3629 }
3630 *end = n;
3631 }
3632 }
3633 Ok(())
3634}
3635
3636#[inline(always)]
3646pub(crate) fn parse_data_row_into_buf(
3647 data: &[u8],
3648 buf: &mut Vec<u8>,
3649 out: &mut Vec<(usize, i32)>,
3650) -> Result<(), DriverError> {
3651 if data.len() < 2 {
3652 return Err(DriverError::Protocol("DataRow too short".into()));
3653 }
3654
3655 let num_cols = i16::from_be_bytes([data[0], data[1]]);
3656 if num_cols < 0 {
3657 return Err(DriverError::Protocol(
3658 "DataRow: negative column count".into(),
3659 ));
3660 }
3661 let num_cols = num_cols as usize;
3662
3663 let col_data = &data[2..];
3671 let base = buf.len();
3672 buf.extend_from_slice(col_data);
3673
3674 let mut pos: usize = 0;
3676 for _ in 0..num_cols {
3677 if pos + 4 > col_data.len() {
3678 return Err(DriverError::Protocol("DataRow truncated".into()));
3679 }
3680
3681 let col_len = i32::from_be_bytes([
3682 col_data[pos],
3683 col_data[pos + 1],
3684 col_data[pos + 2],
3685 col_data[pos + 3],
3686 ]);
3687 pos += 4;
3688
3689 if col_len < 0 {
3690 out.push((0, -1));
3691 } else {
3692 let len = col_len as usize;
3693 if pos + len > col_data.len() {
3694 return Err(DriverError::Protocol(
3695 "DataRow column data truncated".into(),
3696 ));
3697 }
3698 out.push((base + pos, col_len));
3700 pos += len;
3701 }
3702 }
3703
3704 Ok(())
3705}
3706
3707fn parse_data_row_flat(
3711 data: &[u8],
3712 arena: &mut Arena,
3713 out: &mut Vec<(usize, i32)>,
3714) -> Result<(), DriverError> {
3715 if data.len() < 2 {
3716 return Err(DriverError::Protocol("DataRow too short".into()));
3717 }
3718
3719 let num_cols_raw = i16::from_be_bytes([data[0], data[1]]);
3720 if num_cols_raw < 0 {
3721 return Err(DriverError::Protocol(
3722 "DataRow: negative column count".into(),
3723 ));
3724 }
3725 let num_cols = num_cols_raw as usize;
3726 out.reserve(num_cols);
3727
3728 let col_data = &data[2..];
3731 let base = arena.alloc_copy(col_data);
3732
3733 let mut pos: usize = 0;
3735 for _ in 0..num_cols {
3736 if pos + 4 > col_data.len() {
3737 return Err(DriverError::Protocol("DataRow truncated".into()));
3738 }
3739
3740 let col_len = i32::from_be_bytes([
3741 col_data[pos],
3742 col_data[pos + 1],
3743 col_data[pos + 2],
3744 col_data[pos + 3],
3745 ]);
3746 pos += 4;
3747
3748 if col_len < 0 {
3749 out.push((0, -1));
3750 } else {
3751 let len = col_len as usize;
3752 if pos + len > col_data.len() {
3753 return Err(DriverError::Protocol(
3754 "DataRow column data truncated".into(),
3755 ));
3756 }
3757 out.push((base + pos, col_len));
3759 pos += len;
3760 }
3761 }
3762
3763 Ok(())
3764}
3765
3766#[cfg(test)]
3767#[allow(clippy::approx_constant)]
3768mod tests {
3769 use super::*;
3770 use crate::types::hash_sql;
3771
3772 #[test]
3773 fn sync_config_tcp_no_longer_rejected() {
3774 let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3777 let result = Connection::connect(&config);
3778 assert!(result.is_err());
3779 let err = result.unwrap_err().to_string();
3780 assert!(
3783 !err.contains("Unix domain socket"),
3784 "error should NOT mention UDS requirement: {err}"
3785 );
3786 }
3787
3788 #[test]
3789 fn sync_data_row_parsing() {
3790 let mut arena = Arena::new();
3791 let mut out = Vec::new();
3792
3793 let mut data = Vec::new();
3794 data.extend_from_slice(&2i16.to_be_bytes());
3795 data.extend_from_slice(&4i32.to_be_bytes());
3796 data.extend_from_slice(&42i32.to_be_bytes());
3797 data.extend_from_slice(&(-1i32).to_be_bytes());
3798
3799 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3800 assert_eq!(out.len(), 2);
3801 assert_eq!(out[0].1, 4);
3802 assert_eq!(out[1].1, -1);
3803 }
3804
3805 #[test]
3806 fn sync_data_row_empty() {
3807 let mut arena = Arena::new();
3808 let mut out = Vec::new();
3809 let data = 0i16.to_be_bytes();
3810 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3811 assert_eq!(out.len(), 0);
3812 }
3813
3814 #[test]
3815 fn sync_data_row_too_short() {
3816 let mut arena = Arena::new();
3817 let mut out = Vec::new();
3818 let data = vec![0u8];
3819 assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3820 }
3821
3822 #[test]
3823 fn sync_data_row_negative_col_count() {
3824 let mut arena = Arena::new();
3825 let mut out = Vec::new();
3826 let data = (-1i16).to_be_bytes();
3827 assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3828 }
3829
3830 #[test]
3831 fn sync_data_row_truncated() {
3832 let mut arena = Arena::new();
3833 let mut out = Vec::new();
3834 let mut data = Vec::new();
3835 data.extend_from_slice(&2i16.to_be_bytes());
3836 data.extend_from_slice(&4i32.to_be_bytes());
3837 data.extend_from_slice(&42i32.to_be_bytes());
3838 assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3840 }
3841
3842 #[test]
3843 fn sync_data_row_col_data_truncated() {
3844 let mut arena = Arena::new();
3845 let mut out = Vec::new();
3846 let mut data = Vec::new();
3847 data.extend_from_slice(&1i16.to_be_bytes());
3848 data.extend_from_slice(&100i32.to_be_bytes()); data.push(0); assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
3851 }
3852
3853 #[test]
3856 fn sync_connect_tcp_unreachable_port() {
3857 let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3860 let result = Connection::connect(&config);
3861 assert!(result.is_err());
3862 let err = result.unwrap_err().to_string();
3863 assert!(
3864 !err.contains("Unix domain socket"),
3865 "error should NOT mention UDS: {err}"
3866 );
3867 }
3868
3869 #[test]
3870 fn sync_connect_ip_address_attempts_tcp() {
3871 let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
3874 let result = Connection::connect(&config);
3875 assert!(result.is_err());
3876 }
3877
3878 #[test]
3881 fn sync_data_row_all_null() {
3882 let mut arena = Arena::new();
3883 let mut out = Vec::new();
3884 let mut data = Vec::new();
3885 data.extend_from_slice(&3i16.to_be_bytes());
3886 data.extend_from_slice(&(-1i32).to_be_bytes());
3887 data.extend_from_slice(&(-1i32).to_be_bytes());
3888 data.extend_from_slice(&(-1i32).to_be_bytes());
3889 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3890 assert_eq!(out.len(), 3);
3891 for (_, len) in &out {
3892 assert_eq!(*len, -1);
3893 }
3894 }
3895
/// A single 2048-byte text column is copied into the arena byte-for-byte.
#[test]
fn sync_data_row_long_text() {
    let mut arena = Arena::new();
    let mut out = Vec::new();

    let payload = "a".repeat(2048).into_bytes();
    let payload_len = payload.len() as i32;

    let mut data = Vec::new();
    data.extend_from_slice(&1i16.to_be_bytes());
    data.extend_from_slice(&payload_len.to_be_bytes());
    data.extend_from_slice(&payload);

    parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
    assert_eq!(out.len(), 1);

    let (offset, len) = out[0];
    assert_eq!(len, payload_len);
    assert_eq!(arena.get(offset, len as usize), payload.as_slice());
}
3912
/// A zero-length column decodes with length 0, which is distinct from NULL (-1).
#[test]
fn sync_data_row_empty_text() {
    let mut arena = Arena::new();
    let mut out = Vec::new();

    // Column count 1, then a length prefix of 0 and no payload.
    let data: Vec<u8> = [
        1i16.to_be_bytes().as_slice(),
        0i32.to_be_bytes().as_slice(),
    ]
    .concat();

    parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
    assert_eq!(out.len(), 1);
    assert_eq!(out[0].1, 0);
}
3924
/// Rows wider than the parser's inline column capacity must still decode fully.
///
/// The test name implies that 17 is the first width past the inline capacity
/// (NOTE(review): confirm against the parser's SmallVec size), so 17 is
/// exercised explicitly. 20 is kept from the original test as a case well
/// past the boundary. Each column carries its own index as a 4-byte
/// big-endian value, so the round-trip check also confirms column order.
#[test]
fn sync_data_row_17_columns_exceeds_smallvec() {
    for num_cols in [17i16, 20] {
        let mut arena = Arena::new();
        let mut out = Vec::new();

        let mut data = Vec::new();
        data.extend_from_slice(&num_cols.to_be_bytes());
        for i in 0..num_cols {
            data.extend_from_slice(&4i32.to_be_bytes());
            data.extend_from_slice(&i32::from(i).to_be_bytes());
        }

        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        // Derive the expectation from the input instead of hard-coding it.
        assert_eq!(out.len(), num_cols as usize, "num_cols = {num_cols}");
        for (idx, &(offset, len)) in out.iter().enumerate() {
            assert_eq!(len, 4);
            let stored = arena.get(offset, 4);
            let val = i32::from_be_bytes([stored[0], stored[1], stored[2], stored[3]]);
            assert_eq!(val, idx as i32);
        }
    }
}
3946
/// NULLs interleaved with values must not disturb the values' offsets or bytes.
///
/// Layout: col 0 NULL, col 1 int4 42, col 2 NULL, col 3 NULL, col 4 "hello".
/// Both non-NULL payloads are read back from the arena, not just their lengths.
#[test]
fn sync_data_row_mixed_null_and_data() {
    let mut arena = Arena::new();
    let mut out = Vec::new();
    let mut data = Vec::new();
    data.extend_from_slice(&5i16.to_be_bytes());
    data.extend_from_slice(&(-1i32).to_be_bytes()); // col 0: NULL
    data.extend_from_slice(&4i32.to_be_bytes()); // col 1: 4-byte value
    data.extend_from_slice(&42i32.to_be_bytes());
    data.extend_from_slice(&(-1i32).to_be_bytes()); // col 2: NULL
    data.extend_from_slice(&(-1i32).to_be_bytes()); // col 3: NULL
    data.extend_from_slice(&5i32.to_be_bytes()); // col 4: 5-byte value
    data.extend_from_slice(b"hello");

    parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
    assert_eq!(out.len(), 5);
    assert_eq!(out[0].1, -1);
    assert_eq!(out[1].1, 4);
    assert_eq!(out[2].1, -1);
    assert_eq!(out[3].1, -1);
    assert_eq!(out[4].1, 5);

    // Check that the int4 payload round-trips, not just its length.
    let int_bytes = arena.get(out[1].0, 4);
    let int_val = i32::from_be_bytes([int_bytes[0], int_bytes[1], int_bytes[2], int_bytes[3]]);
    assert_eq!(int_val, 42);

    let stored = arena.get(out[4].0, 5);
    assert_eq!(stored, b"hello");
}
3976
/// Connecting over the `/tmp` Unix socket yields a fresh, idle session.
///
/// The assertions are skipped entirely when the connection cannot be
/// established.
#[test]
#[ignore]
fn sync_connect_uds_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let conn = match Connection::connect(&config) {
        Ok(conn) => conn,
        Err(_) => return,
    };

    assert!(conn.pid() != 0, "pid should be nonzero");
    assert!(conn.is_idle(), "should start idle");
    assert!(!conn.is_in_transaction(), "should not be in tx");
    assert!(
        !conn.is_in_failed_transaction(),
        "should not be in failed tx"
    );
    assert_eq!(conn.stmt_cache_len(), 0, "cache should be empty");
    let _ = conn.close();
}
3997
/// A trivial simple query should succeed and leave the session idle.
#[test]
#[ignore]
fn sync_simple_query_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    conn.simple_query("SELECT 1").expect("SELECT 1");
    assert!(conn.is_idle());

    let _ = conn.close();
}
4007
/// A query with two int4 parameters returns exactly one row.
#[test]
#[ignore]
fn sync_query_with_params_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    let sql = "SELECT $1::int4 + $2::int4 AS sum";
    let (lhs, rhs): (i32, i32) = (10, 20);
    let params: [&(dyn Encode + Sync); 2] = [&lhs, &rhs];

    let rows = conn
        .query(sql, hash_sql(sql), &params)
        .expect("parameterised query");
    assert_eq!(rows.len(), 1);

    let _ = conn.close();
}
4027
/// `execute` reports the affected-row count of a single-row INSERT.
#[test]
#[ignore]
fn sync_execute_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    conn.simple_query("CREATE TEMP TABLE _sync_test (id int)")
        .expect("create temp table");

    let insert = "INSERT INTO _sync_test VALUES ($1::int4)";
    let id: i32 = 42;
    let rows_affected = conn
        .execute(insert, hash_sql(insert), &[&id as &(dyn Encode + Sync)])
        .expect("insert one row");
    assert_eq!(rows_affected, 1);

    let _ = conn.close();
}
4044
/// `for_each` over an empty table never invokes the row callback.
#[test]
#[ignore]
fn sync_for_each_zero_rows_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    conn.simple_query("CREATE TEMP TABLE _sync_fe0 (id int)")
        .expect("create temp table");

    let select = "SELECT id FROM _sync_fe0";
    let mut seen = 0u32;
    conn.for_each(select, hash_sql(select), &[], |_row| {
        seen += 1;
        Ok(())
    })
    .expect("for_each over empty table");
    assert_eq!(seen, 0);

    let _ = conn.close();
}
4063
/// `for_each` invokes the callback once for each of the five generated rows.
#[test]
#[ignore]
fn sync_for_each_multiple_rows_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    let select = "SELECT generate_series(1, 5)";
    let mut seen = 0u32;
    conn.for_each(select, hash_sql(select), &[], |_row| {
        seen += 1;
        Ok(())
    })
    .expect("for_each over series");
    assert_eq!(seen, 5);

    let _ = conn.close();
}
4080
/// Preparing the same SQL twice leaves exactly one statement-cache entry.
#[test]
#[ignore]
fn sync_prepare_only_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    let sql = "SELECT 1";
    let hash = hash_sql(sql);
    // The first pass adds the entry; the repeat must not add a second one.
    for _ in 0..2 {
        conn.prepare_only(sql, hash).expect("prepare_only");
        assert_eq!(conn.stmt_cache_len(), 1);
    }

    let _ = conn.close();
}
4095
/// `simple_query_rows` returns a non-empty result for a constant SELECT.
#[test]
#[ignore]
fn sync_simple_query_rows_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    let returned = conn
        .simple_query_rows("SELECT 42 AS n")
        .expect("simple_query_rows");
    assert!(!returned.is_empty());

    let _ = conn.close();
}
4105
/// Repeating a query keeps the cache size fixed; a new SQL text grows it.
#[test]
#[ignore]
fn sync_stmt_cache_hit_miss_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    let first = "SELECT 1";
    let first_hash = hash_sql(first);
    for _ in 0..2 {
        conn.query(first, first_hash, &[]).expect("query SELECT 1");
        assert_eq!(conn.stmt_cache_len(), 1);
    }

    let second = "SELECT 2";
    conn.query(second, hash_sql(second), &[])
        .expect("query SELECT 2");
    assert_eq!(conn.stmt_cache_len(), 2);

    let _ = conn.close();
}
4125
/// Malformed SQL yields an error, and the connection is left idle afterwards.
#[test]
#[ignore]
fn sync_invalid_sql_error_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    let bad_sql = "SELECTTTT INVALID GARBAGE";
    let outcome = conn.query(bad_sql, hash_sql(bad_sql), &[]);
    assert!(outcome.is_err());
    // After the server error the session should be reusable (idle).
    assert!(conn.is_idle());

    let _ = conn.close();
}
4139
/// The tracked transaction status follows BEGIN / COMMIT.
#[test]
#[ignore]
fn sync_tx_state_transitions_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    // Asserts the (idle, in-transaction) pair reported by the connection.
    let expect_state = |c: &Connection, idle: bool, in_tx: bool| {
        assert_eq!(c.is_idle(), idle);
        assert_eq!(c.is_in_transaction(), in_tx);
    };

    expect_state(&conn, true, false);
    conn.simple_query("BEGIN").expect("BEGIN");
    expect_state(&conn, false, true);
    conn.simple_query("COMMIT").expect("COMMIT");
    expect_state(&conn, true, false);

    let _ = conn.close();
}
4155
/// With the cache capped at 3, running 5 distinct queries never exceeds the cap.
#[test]
#[ignore]
fn sync_lru_cache_eviction_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    conn.set_max_stmt_cache_size(3);
    for n in 0..5 {
        let sql = format!("SELECT {n}");
        conn.query(&sql, hash_sql(&sql), &[]).expect("distinct query");
    }

    let cached = conn.stmt_cache_len();
    assert!(cached <= 3, "cache should be capped at 3, got {}", cached);

    let _ = conn.close();
}
4175
/// `for_each_raw` invokes the callback once per row of a 3-row series.
#[test]
#[ignore]
fn sync_for_each_raw_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    let select = "SELECT generate_series(1, 3)";
    let mut calls = 0u32;
    conn.for_each_raw(select, hash_sql(select), &[], |_raw_data| {
        calls += 1;
        Ok(())
    })
    .expect("for_each_raw over series");
    assert_eq!(calls, 3);

    let _ = conn.close();
}
4192
/// Binding `None` as a parameter must be accepted by the server (sent as NULL).
#[test]
#[ignore]
fn sync_query_null_params_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    let sql = "SELECT $1::int4 IS NULL AS is_null";
    let missing: Option<i32> = None;
    let params: [&(dyn Encode + Sync); 1] = [&missing];
    let _rows = conn
        .query(sql, hash_sql(sql), &params)
        .expect("query with NULL parameter");

    let _ = conn.close();
}
4206
/// Parameters of several Rust types (int4, int8, text, bool, float8) bind in one query.
#[test]
#[ignore]
fn sync_query_various_param_types_if_pg_available() {
    let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
    let mut conn = Connection::connect(&config).unwrap();
    let sql = "SELECT $1::int4, $2::int8, $3::text, $4::bool, $5::float8";
    let hash = hash_sql(sql);
    let p1: i32 = 42;
    let p2: i64 = 9999999;
    let p3: &str = "hello";
    let p4: bool = true;
    // Arbitrary, exactly representable value. A literal like 3.14 trips
    // `clippy::approx_constant` (deny-by-default), which breaks
    // `cargo clippy --tests`.
    let p5: f64 = 2.5;
    let result = conn
        .query(
            sql,
            hash,
            &[
                &p1 as &(dyn Encode + Sync),
                &p2 as &(dyn Encode + Sync),
                &p3 as &(dyn Encode + Sync),
                &p4 as &(dyn Encode + Sync),
                &p5 as &(dyn Encode + Sync),
            ],
        )
        .unwrap();
    assert_eq!(result.len(), 1);
    let _ = conn.close();
}
4235
/// The buffer shrink threshold must be larger than the initial buffer size.
///
/// NOTE(review): both sizes are literals written in this test rather than
/// references to the connection's own constants. As written, this only
/// documents the intended relationship; consider pointing it at the real
/// constants if they exist.
#[test]
fn sync_shrink_threshold_values() {
    let initial = 8192usize;
    let shrink = 64 * 1024usize;
    assert!(
        shrink > initial,
        "shrink threshold must exceed initial size"
    );
}
4255
/// Shape of the `Connection` debug string: the expected fields are named.
///
/// NOTE(review): this renders a hand-written template with placeholder
/// values; it does not format an actual `Connection`, since none can be
/// built without a server.
#[test]
fn sync_connection_debug_format() {
    let rendered = format!(
        "Connection {{ pid: {}, tx_status: '{}', stmt_cache_len: {} }}",
        0, 'I', 0
    );
    for needle in ["Connection", "pid", "tx_status"] {
        assert!(rendered.contains(needle));
    }
}
4271
/// `sslmode=require` against a port with no listener must fail to connect.
///
/// NOTE(review): despite the name, nothing here is gated on a TLS cargo
/// feature. The assertion only checks for an error, whether it comes from
/// missing TLS support or from the refused connection.
#[test]
fn sync_connect_sslmode_require_without_tls_feature() {
    let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
    config.ssl = SslMode::Require;
    assert!(Connection::connect(&config).is_err());
}
4288
/// With TLS disabled, a refused TCP connect surfaces as `DriverError::Io`.
#[test]
fn sync_connect_sslmode_disable_attempts_tcp() {
    let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
    config.ssl = SslMode::Disable;
    match Connection::connect(&config) {
        Err(DriverError::Io(_)) => {}
        Err(other) => panic!("expected DriverError::Io, got: {other}"),
        Ok(_) => panic!("connect to 127.0.0.1:1 unexpectedly succeeded"),
    }
}
4298
/// With `sslmode=prefer`, connecting to a port with no listener still fails.
#[test]
fn sync_connect_sslmode_prefer_attempts_tcp() {
    let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
    config.ssl = SslMode::Prefer;
    assert!(Connection::connect(&config).is_err());
}
4306
/// Streaming a 10-row series in chunks (chunk size 3) delivers all rows,
/// after which the connection is no longer marked as streaming.
#[test]
#[ignore]
fn sync_streaming_basic_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");
    assert!(!conn.is_streaming());

    let sql = "SELECT generate_series(1, 10)";
    let (columns, _) = conn
        .query_streaming_start(sql, hash_sql(sql), &[], 3)
        .expect("start streaming");
    assert!(!columns.is_empty());
    assert!(conn.is_streaming());

    let mut arena = Arena::new();
    let mut offsets = Vec::new();
    let mut total_rows = 0;

    // Read a chunk, count its rows, and request another chunk only while the
    // connection reports more data.
    let mut more = true;
    while more {
        more = conn
            .streaming_next_chunk(&mut arena, &mut offsets)
            .expect("next chunk");
        total_rows += offsets.len();
        if more {
            conn.streaming_send_execute(3).expect("request next chunk");
        }
    }

    assert_eq!(total_rows, 10);
    assert!(!conn.is_streaming());
    let _ = conn.close();
}
4341
/// `prepare_describe` reports one output column named `sum` and two parameter OIDs.
#[test]
#[ignore]
fn sync_prepare_describe_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    let described = conn
        .prepare_describe("SELECT $1::int4 + $2::int4 AS sum")
        .expect("prepare_describe");
    let columns = &described.columns;
    assert_eq!(columns.len(), 1);
    assert_eq!(&*columns[0].name, "sum");
    assert_eq!(described.param_oids.len(), 2);

    let _ = conn.close();
}
4358
/// A NOTIFY on a LISTENed channel is returned by `wait_for_notification`.
#[test]
#[ignore]
fn sync_wait_for_notification_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let mut conn = Connection::connect(&config).expect("connect to local server");

    conn.simple_query("LISTEN test_chan").expect("LISTEN");
    conn.simple_query("NOTIFY test_chan, 'hello'").expect("NOTIFY");

    // A 5-second read timeout bounds the wait below; presumably a missing
    // notification then surfaces as an error instead of blocking forever.
    conn.set_read_timeout(Some(std::time::Duration::from_secs(5)))
        .expect("set read timeout");

    let (channel, payload) = conn
        .wait_for_notification()
        .expect("notification delivered");
    assert_eq!(channel, "test_chan");
    assert_eq!(payload, "hello");

    let _ = conn.close();
}
4379
/// Issuing a cancel request on an idle connection must not panic.
///
/// The `cancel` result is discarded on purpose: the test asserts nothing
/// about whether the request is reported as `Ok` or `Err`.
#[test]
#[ignore]
fn sync_cancel_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let conn = Connection::connect(&config).expect("connect to local server");

    let _ = conn.cancel();
    let _ = conn.close();
}
4394
/// Startup populates the server parameter list, including `server_encoding`.
#[test]
#[ignore]
fn sync_server_params_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let conn = Connection::connect(&config).expect("connect to local server");

    let has_params = !conn.server_params().is_empty();
    assert!(
        has_params,
        "server should send parameters during startup"
    );

    let encoding = conn.parameter("server_encoding");
    assert!(
        encoding.is_some(),
        "server_encoding should be present"
    );

    let _ = conn.close();
}
4414
/// A read timeout can be set and then cleared again without error.
#[test]
#[ignore]
fn sync_set_read_timeout_if_pg_available() {
    let url = "postgres://postgres@localhost/postgres?host=/tmp";
    let config = Config::from_url(url).expect("valid test URL");
    let conn = Connection::connect(&config).expect("connect to local server");

    let ten_secs = std::time::Duration::from_secs(10);
    conn.set_read_timeout(Some(ten_secs))
        .expect("set 10s read timeout");
    conn.set_read_timeout(None).expect("clear read timeout");

    let _ = conn.close();
}
4428}