1use std::io::{Read, Write};
15use std::sync::Arc;
16
17use crate::DriverError;
18use crate::arena::Arena;
19use crate::auth;
20use crate::codec::Encode;
21use crate::proto::{self, BackendMessage};
22use crate::stmt_cache::{StmtCache, StmtInfo, build_bind_template, make_stmt_name};
23use crate::sync_io::Stream;
24use crate::types::{
25 ColumnDesc, Config, Notification, PgDataRow, PrepareResult, QueryResult, SimpleRow, SslMode,
26 StartupAction,
27};
28
29use std::cell::RefCell;
37
/// Maximum number of spare response buffers kept per thread.
const RESP_BUF_POOL_MAX: usize = 4;

/// Largest capacity (in bytes) a buffer may have and still be pooled.
///
/// Bigger buffers are freed on release so that one huge result set cannot pin
/// its allocation in thread-local storage for the lifetime of the thread.
const RESP_BUF_MAX_RETAINED_CAPACITY: usize = 1 << 20;

thread_local! {
    /// Per-thread free list of response buffers handed out by `acquire_resp_buf`.
    static RESP_BUF_POOL: RefCell<Vec<Vec<u8>>> = const { RefCell::new(Vec::new()) };
}

/// Takes a spare buffer from this thread's pool, or a fresh one if the pool is empty.
///
/// The returned buffer is always empty (`len() == 0`). It may carry spare
/// capacity from an earlier use.
pub(crate) fn acquire_resp_buf() -> Vec<u8> {
    RESP_BUF_POOL
        .with(|pool| pool.borrow_mut().pop())
        .unwrap_or_default()
}

/// Returns `buf` to this thread's pool so a later `acquire_resp_buf` can reuse
/// its allocation.
///
/// The contents are cleared before pooling, so stale row bytes never reach the
/// next user. The buffer is simply dropped when its capacity exceeds
/// `RESP_BUF_MAX_RETAINED_CAPACITY` or when the pool already holds
/// `RESP_BUF_POOL_MAX` buffers.
pub fn release_resp_buf(mut buf: Vec<u8>) {
    if buf.capacity() > RESP_BUF_MAX_RETAINED_CAPACITY {
        return;
    }
    buf.clear();
    RESP_BUF_POOL.with(|pool| {
        let mut pool = pool.borrow_mut();
        if pool.len() < RESP_BUF_POOL_MAX {
            pool.push(buf);
        }
    });
}
57
58pub struct Connection {
87 stream: Stream,
88 read_buf: Vec<u8>,
89 stream_buf: Vec<u8>,
90 stream_buf_pos: usize,
91 stream_buf_end: usize,
92 write_buf: Vec<u8>,
93 stmts: StmtCache,
94 params: Vec<(Box<str>, Box<str>)>,
95 pid: i32,
96 secret: i32,
97 tx_status: u8,
98 last_used: std::time::Instant,
99 streaming_active: bool,
100 created_at: std::time::Instant,
101 pending_notifications: Vec<Notification>,
102 max_stmt_cache_size: usize,
103 query_counter: u64,
104 connect_config: Arc<Config>,
107 tls_server_cert_hash: Option<[u8; 32]>,
110}
111
112impl std::fmt::Debug for Connection {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct("Connection")
115 .field("pid", &self.pid)
116 .field("tx_status", &(self.tx_status as char))
117 .field("stmt_cache_len", &self.stmts.len())
118 .finish()
119 }
120}
121
122impl Connection {
123 pub fn connect(config: &Config) -> Result<Self, DriverError> {
135 Self::connect_arc(Arc::new(config.clone()))
136 }
137
138 pub fn connect_arc(config: Arc<Config>) -> Result<Self, DriverError> {
143 config.validate()?;
144
145 #[allow(unused_mut)]
148 let mut tls_cert_hash: Option<[u8; 32]> = None;
149
150 let stream = if config.host_is_uds() {
151 #[cfg(unix)]
153 {
154 let path = config.uds_path();
155 let unix =
156 std::os::unix::net::UnixStream::connect(&path).map_err(DriverError::Io)?;
157 Stream::Unix(unix)
158 }
159 #[cfg(not(unix))]
160 {
161 return Err(DriverError::Protocol(
162 "Unix domain sockets are not supported on this platform".into(),
163 ));
164 }
165 } else {
166 let addr = format!("{}:{}", config.host, config.port);
168 let tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
169
170 match config.ssl {
171 SslMode::Disable => {
172 tcp.set_nodelay(true).map_err(DriverError::Io)?;
173 let stream = Stream::Tcp(tcp);
174 stream.set_keepalive()?;
175 stream
176 }
177 SslMode::Prefer | SslMode::Require => {
178 #[cfg(feature = "tls")]
179 {
180 match crate::tls_sync::try_upgrade(
181 tcp,
182 &config.host,
183 config.ssl == SslMode::Require,
184 ) {
185 Ok(result) => {
186 tls_cert_hash = result.server_cert_hash;
187 let stream = Stream::Tls(Box::new(result.stream));
188 stream.set_nodelay()?;
189 stream.set_keepalive()?;
190 stream
191 }
192 Err(e) => {
193 if config.ssl == SslMode::Require {
194 return Err(e);
195 }
196 let tcp =
198 std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
199 tcp.set_nodelay(true).map_err(DriverError::Io)?;
200 let stream = Stream::Tcp(tcp);
201 stream.set_keepalive()?;
202 stream
203 }
204 }
205 }
206 #[cfg(not(feature = "tls"))]
207 {
208 if config.ssl == SslMode::Require {
209 return Err(DriverError::Protocol(
210 "sslmode=require but bsql was compiled without the 'tls' feature"
211 .into(),
212 ));
213 }
214 tcp.set_nodelay(true).map_err(DriverError::Io)?;
215 let stream = Stream::Tcp(tcp);
216 stream.set_keepalive()?;
217 stream
218 }
219 }
220 }
221 };
222
223 let mut conn = Self {
224 stream,
225 read_buf: Vec::with_capacity(8192),
226 stream_buf: vec![0u8; 65536],
227 stream_buf_pos: 0,
228 stream_buf_end: 0,
229 write_buf: Vec::with_capacity(4096),
230 stmts: StmtCache::default(),
231 params: Vec::new(),
232 pid: 0,
233 secret: 0,
234 tx_status: b'I',
235 last_used: std::time::Instant::now(),
236 streaming_active: false,
237 created_at: std::time::Instant::now(),
238 pending_notifications: Vec::new(),
239 max_stmt_cache_size: 256,
240 query_counter: 0,
241 connect_config: config.clone(),
242 tls_server_cert_hash: tls_cert_hash,
243 };
244
245 conn.startup(&config)?;
246 conn.validate_server_params()?;
247
248 if config.statement_timeout_secs > 0 {
249 conn.simple_query(&format!(
250 "SET statement_timeout = '{}s'",
251 config.statement_timeout_secs
252 ))?;
253 }
254
255 Ok(conn)
256 }
257
258 fn startup(&mut self, config: &Config) -> Result<(), DriverError> {
261 self.write_buf.clear();
262 proto::write_startup(&mut self.write_buf, &config.user, &config.database);
263 self.flush_write()?;
264
265 loop {
266 let action = self.read_startup_action()?;
267 match action {
268 StartupAction::AuthOk => {}
269 StartupAction::AuthCleartext => {
270 self.write_buf.clear();
271 let mut pw = config.password.as_bytes().to_vec();
272 pw.push(0);
273 proto::write_password(&mut self.write_buf, &pw);
274 self.flush_write()?;
275 }
276 StartupAction::AuthMd5(salt) => {
277 self.write_buf.clear();
278 let hash = auth::md5_password(&config.user, &config.password, &salt);
279 proto::write_password(&mut self.write_buf, &hash);
280 self.flush_write()?;
281 }
282 StartupAction::AuthSasl(mechanisms_data) => {
283 self.handle_scram(config, &mechanisms_data)?;
284 }
285 StartupAction::ParameterStatus(name, value) => {
286 if let Some(entry) = self.params.iter_mut().find(|(k, _)| *k == name) {
287 entry.1 = value;
288 } else {
289 self.params.push((name, value));
290 }
291 }
292 StartupAction::BackendKeyData(pid, secret) => {
293 self.pid = pid;
294 self.secret = secret;
295 }
296 StartupAction::ReadyForQuery(status) => {
297 self.tx_status = status;
298 return Ok(());
299 }
300 StartupAction::Error(msg) => {
301 return Err(DriverError::Auth(msg));
302 }
303 StartupAction::Notice => {}
304 }
305 }
306 }
307
308 fn read_startup_action(&mut self) -> Result<StartupAction, DriverError> {
309 let (msg_type, _) = self.read_message_buffered()?;
310 let payload = &self.read_buf;
311 let msg = proto::parse_backend_message(msg_type, payload)?;
312 match msg {
313 BackendMessage::AuthOk => Ok(StartupAction::AuthOk),
314 BackendMessage::AuthCleartext => Ok(StartupAction::AuthCleartext),
315 BackendMessage::AuthMd5 { salt } => Ok(StartupAction::AuthMd5(salt)),
316 BackendMessage::AuthSasl { mechanisms } => {
317 Ok(StartupAction::AuthSasl(mechanisms.to_vec()))
318 }
319 BackendMessage::ParameterStatus { name, value } => {
320 Ok(StartupAction::ParameterStatus(name.into(), value.into()))
321 }
322 BackendMessage::BackendKeyData { pid, secret } => {
323 Ok(StartupAction::BackendKeyData(pid, secret))
324 }
325 BackendMessage::ReadyForQuery { status } => Ok(StartupAction::ReadyForQuery(status)),
326 BackendMessage::ErrorResponse { data } => {
327 let fields = proto::parse_error_response(data);
328 Ok(StartupAction::Error(fields.to_string()))
329 }
330 BackendMessage::NoticeResponse { .. } => Ok(StartupAction::Notice),
331 other => Err(DriverError::Protocol(format!(
332 "unexpected message during startup: {other:?}"
333 ))),
334 }
335 }
336
337 fn handle_scram(&mut self, config: &Config, mechanisms_data: &[u8]) -> Result<(), DriverError> {
338 let mechs = auth::parse_sasl_mechanisms(mechanisms_data);
339
340 let use_plus = self.tls_server_cert_hash.is_some() && mechs.contains(&"SCRAM-SHA-256-PLUS");
343 let mechanism = if use_plus {
344 "SCRAM-SHA-256-PLUS"
345 } else {
346 "SCRAM-SHA-256"
347 };
348
349 if !mechs.contains(&mechanism) && !mechs.contains(&"SCRAM-SHA-256") {
350 return Err(DriverError::Auth(format!(
351 "server requires unsupported SASL mechanism(s): {mechs:?}"
352 )));
353 }
354
355 let cert_hash = if use_plus {
356 self.tls_server_cert_hash.as_ref()
357 } else {
358 None
359 };
360 let mut scram = auth::ScramClient::new(&config.user, &config.password, cert_hash)?;
361
362 let client_first = scram.client_first_message();
364 self.write_buf.clear();
365 proto::write_sasl_initial(&mut self.write_buf, mechanism, &client_first);
366 self.flush_write()?;
367
368 let (msg_type, _) = self.read_message_buffered()?;
370 let server_first = {
371 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
372 match msg {
373 BackendMessage::AuthSaslContinue { data } => data.to_vec(),
374 BackendMessage::ErrorResponse { data } => {
375 let fields = proto::parse_error_response(data);
376 return Err(DriverError::Auth(fields.to_string()));
377 }
378 other => {
379 return Err(DriverError::Protocol(format!(
380 "expected AuthSaslContinue, got: {other:?}"
381 )));
382 }
383 }
384 };
385
386 scram.process_server_first(&server_first)?;
387
388 let client_final = scram.client_final_message()?;
390 self.write_buf.clear();
391 proto::write_sasl_response(&mut self.write_buf, &client_final);
392 self.flush_write()?;
393
394 let (msg_type, _) = self.read_message_buffered()?;
396 {
397 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
398 match msg {
399 BackendMessage::AuthSaslFinal { data } => {
400 let data_owned = data.to_vec();
401 scram.verify_server_final(&data_owned)?;
402 }
403 BackendMessage::ErrorResponse { data } => {
404 let fields = proto::parse_error_response(data);
405 return Err(DriverError::Auth(fields.to_string()));
406 }
407 other => {
408 return Err(DriverError::Protocol(format!(
409 "expected AuthSaslFinal, got: {other:?}"
410 )));
411 }
412 }
413 }
414
415 let (msg_type, _) = self.read_message_buffered()?;
417 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
418 match msg {
419 BackendMessage::AuthOk => Ok(()),
420 BackendMessage::ErrorResponse { data } => {
421 let fields = proto::parse_error_response(data);
422 Err(DriverError::Auth(fields.to_string()))
423 }
424 other => Err(DriverError::Protocol(format!(
425 "expected AuthOk after SCRAM, got: {other:?}"
426 ))),
427 }
428 }
429
430 fn validate_server_params(&self) -> Result<(), DriverError> {
431 if let Some(encoding) = self.parameter("server_encoding") {
432 let normalized = encoding.to_uppercase();
433 if normalized != "UTF8" && normalized != "UTF-8" {
434 return Err(DriverError::Protocol(format!(
435 "server_encoding is '{encoding}', but bsql requires UTF-8."
436 )));
437 }
438 }
439 if let Some(encoding) = self.parameter("client_encoding") {
440 let normalized = encoding.to_uppercase();
441 if normalized != "UTF8" && normalized != "UTF-8" {
442 return Err(DriverError::Protocol(format!(
443 "client_encoding is '{encoding}', but bsql requires UTF-8."
444 )));
445 }
446 }
447 if let Some(idt) = self.parameter("integer_datetimes") {
448 if idt != "on" {
449 return Err(DriverError::Protocol(format!(
450 "integer_datetimes is '{idt}', but bsql requires 'on'."
451 )));
452 }
453 }
454 Ok(())
455 }
456
457 pub fn prepare_only(&mut self, sql: &str, sql_hash: u64) -> Result<(), DriverError> {
463 if self.stmts.contains_key(&sql_hash) {
464 return Ok(());
465 }
466 let name = make_stmt_name(sql_hash);
467 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
468 self.write_buf.clear();
469 proto::write_parse(&mut self.write_buf, name_s, sql, &[]);
470 proto::write_describe(&mut self.write_buf, b'S', name_s);
471 proto::write_sync(&mut self.write_buf);
472 self.flush_write()?;
473
474 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
475 let columns = self.read_column_description()?;
476 self.expect_ready()?;
477
478 self.query_counter += 1;
479 self.cache_stmt(
480 sql_hash,
481 StmtInfo {
482 name,
483 columns,
484 last_used: self.query_counter,
485 bind_template: None,
486 },
487 );
488 Ok(())
489 }
490
/// Runs `sql` with `params` and materializes every result row into a [`QueryResult`].
///
/// Row payloads are copied into a pooled response buffer (`acquire_resp_buf`).
/// The `(offset, len)` pairs collected in `all_col_offsets` are handed to
/// `QueryResult::from_parts_with_buf` together with that buffer.
///
/// Messages that fit in `stream_buf` are parsed in place. A message larger
/// than the whole `stream_buf` is read through `read_one_message` instead.
///
/// # Errors
/// Malformed message lengths, unexpected messages, I/O failures, and server
/// ErrorResponses. For ErrorResponses, the statement cache may be invalidated
/// and the connection is drained to ReadyForQuery before returning.
#[inline]
pub fn query(
    &mut self,
    sql: &str,
    sql_hash: u64,
    params: &[&(dyn Encode + Sync)],
) -> Result<QueryResult, DriverError> {
    // NOTE(review): per the expect message, the first `true` requests column
    // metadata; the meaning of the second flag is not visible here.
    let columns = self
        .send_pipeline(sql, sql_hash, params, true, true)?
        .expect("send_pipeline(need_columns=true) must return Some");

    let num_cols = columns.len();
    // Initial guess of room for 8 rows' worth of column offsets.
    let mut all_col_offsets: Vec<(usize, i32)> = Vec::with_capacity(num_cols.max(1) * 8);
    let mut affected_rows: u64 = 0;

    // Pooled buffers are reused across queries; start from an empty one.
    let mut resp_buf = acquire_resp_buf();
    resp_buf.clear();

    'outer: loop {
        // Consume every complete message currently buffered in stream_buf.
        loop {
            let avail = self.stream_buf_end - self.stream_buf_pos;
            // Not even a full header (1 type byte + 4 length bytes) buffered yet.
            if avail < 5 {
                break;
            }

            let msg_type = self.stream_buf[self.stream_buf_pos];
            let raw_len = i32::from_be_bytes([
                self.stream_buf[self.stream_buf_pos + 1],
                self.stream_buf[self.stream_buf_pos + 2],
                self.stream_buf[self.stream_buf_pos + 3],
                self.stream_buf[self.stream_buf_pos + 4],
            ]);

            // The length field counts itself, so anything below 4 is corrupt.
            if raw_len < 4 {
                return Err(DriverError::Protocol(format!(
                    "invalid message length {raw_len} for type '{}'",
                    msg_type as char
                )));
            }

            let payload_len = (raw_len - 4) as usize;
            let total_msg_len = 5 + payload_len;

            if avail < total_msg_len {
                // A message that can never fit in stream_buf is read through
                // the general-purpose path instead.
                if total_msg_len > self.stream_buf.len() {
                    let msg = self.read_one_message()?;
                    match msg {
                        BackendMessage::BindComplete => continue,
                        BackendMessage::DataRow { data } => {
                            parse_data_row_into_buf(data, &mut resp_buf, &mut all_col_offsets)?;
                            continue;
                        }
                        BackendMessage::CommandComplete { tag } => {
                            affected_rows = proto::parse_command_tag(tag);
                            continue;
                        }
                        BackendMessage::EmptyQuery => continue,
                        BackendMessage::ReadyForQuery { status } => {
                            self.tx_status = status;
                            break 'outer;
                        }
                        BackendMessage::NoticeResponse { .. } => continue,
                        BackendMessage::ErrorResponse { data } => {
                            let fields = proto::parse_error_response(data);
                            self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                            self.drain_to_ready()?;
                            return Err(self.make_server_error(fields));
                        }
                        other => {
                            return Err(DriverError::Protocol(format!(
                                "unexpected message during query: {other:?}"
                            )));
                        }
                    }
                }
                // Partial message: fetch more bytes.
                break;
            }

            let payload_start = self.stream_buf_pos + 5;
            let payload_end = payload_start + payload_len;

            if msg_type == b'D' {
                // DataRow: copy the row's column values into resp_buf.
                parse_data_row_into_buf(
                    &self.stream_buf[payload_start..payload_end],
                    &mut resp_buf,
                    &mut all_col_offsets,
                )?;
            } else if msg_type == b'Z' {
                // ReadyForQuery: record the transaction status and finish.
                if payload_len >= 1 {
                    self.tx_status = self.stream_buf[payload_start];
                }
                self.stream_buf_pos += total_msg_len;
                break 'outer;
            } else {
                // Every other message type (including CommandComplete, which
                // updates `affected_rows`) goes to the shared handler.
                self.handle_non_datarow_query(
                    msg_type,
                    payload_start,
                    payload_end,
                    sql_hash,
                    &mut affected_rows,
                )?;
            }

            self.stream_buf_pos += total_msg_len;
        }

        self.refill_stream_buf()?;
    }

    self.shrink_buffers();

    Ok(QueryResult::from_parts_with_buf(
        all_col_offsets,
        num_cols,
        columns,
        affected_rows,
        resp_buf,
    ))
}
634
635 #[inline]
646 pub fn execute_monolithic(
647 &mut self,
648 sql: &str,
649 sql_hash: u64,
650 params: &[&(dyn Encode + Sync)],
651 ) -> Result<u64, DriverError> {
652 self.write_buf.clear();
654
655 let info = match self.stmts.get_mut(&sql_hash) {
657 Some(info) => {
658 self.query_counter += 1;
659 info.last_used = self.query_counter;
660 info
661 }
662 None => {
663 return self.execute_with_prepare(sql, sql_hash, params);
665 }
666 };
667
668 let can_use_template = info
670 .bind_template
671 .as_ref()
672 .is_some_and(|t| t.param_slots.len() == params.len());
673
674 let mut has_exec_sync = false;
675
676 if can_use_template {
677 let tmpl = info
679 .bind_template
680 .as_ref()
681 .expect("guarded by can_use_template");
682 self.write_buf.extend_from_slice(&tmpl.bytes);
683
684 let mut template_ok = true;
685 for (i, param) in params.iter().enumerate() {
686 let (data_offset, old_len) = tmpl.param_slots[i];
687 if param.is_null() {
688 let len_offset = data_offset - 4;
689 self.write_buf[len_offset..len_offset + 4]
690 .copy_from_slice(&(-1i32).to_be_bytes());
691 } else if old_len >= 0 {
692 let end = data_offset + old_len as usize;
693 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
694 template_ok = false;
695 break;
696 }
697 } else {
698 template_ok = false;
700 break;
701 }
702 }
703
704 if template_ok {
705 has_exec_sync = true;
706 } else {
707 self.write_buf.clear();
708 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
709 info.bind_template = None;
710 }
711 } else {
712 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
713 }
714
715 if info.bind_template.is_none() && !self.write_buf.is_empty() {
717 info.bind_template = build_bind_template(&self.write_buf, params.len());
718 }
719
720 if !has_exec_sync {
721 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
722 }
723
724 self.stream
726 .write_all(&self.write_buf)
727 .map_err(DriverError::Io)?;
728
729 let mut affected_rows: u64 = 0;
731
732 'outer: loop {
733 loop {
734 let avail = self.stream_buf_end - self.stream_buf_pos;
735 if avail < 5 {
736 break; }
738
739 let msg_type = self.stream_buf[self.stream_buf_pos];
740 let raw_len = i32::from_be_bytes([
741 self.stream_buf[self.stream_buf_pos + 1],
742 self.stream_buf[self.stream_buf_pos + 2],
743 self.stream_buf[self.stream_buf_pos + 3],
744 self.stream_buf[self.stream_buf_pos + 4],
745 ]);
746
747 if raw_len < 4 {
748 return Err(DriverError::Protocol(format!(
749 "invalid message length {raw_len} for type '{}'",
750 msg_type as char
751 )));
752 }
753
754 let payload_len = (raw_len - 4) as usize;
755 let total_msg_len = 5 + payload_len;
756
757 if avail < total_msg_len {
758 if total_msg_len > self.stream_buf.len() {
759 let msg = self.read_one_message()?;
760 match msg {
761 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
762 continue;
763 }
764 BackendMessage::CommandComplete { tag } => {
765 affected_rows = proto::parse_command_tag(tag);
766 continue;
767 }
768 BackendMessage::EmptyQuery => continue,
769 BackendMessage::ReadyForQuery { status } => {
770 self.tx_status = status;
771 break 'outer;
772 }
773 BackendMessage::NoticeResponse { .. } => continue,
774 BackendMessage::ErrorResponse { data } => {
775 let fields = proto::parse_error_response(data);
776 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
777 self.drain_to_ready()?;
778 return Err(self.make_server_error(fields));
779 }
780 other => {
781 return Err(DriverError::Protocol(format!(
782 "unexpected message during execute: {other:?}"
783 )));
784 }
785 }
786 }
787 break; }
789
790 let payload_start = self.stream_buf_pos + 5;
795 let payload_end = payload_start + payload_len;
796
797 if msg_type == b'2' {
798 self.stream_buf_pos += total_msg_len;
800 continue;
801 } else if msg_type == b'C' {
802 affected_rows = proto::parse_command_tag_bytes(
804 &self.stream_buf[payload_start..payload_end],
805 );
806 } else if msg_type == b'Z' {
807 if payload_len >= 1 {
809 self.tx_status = self.stream_buf[payload_start];
810 }
811 self.stream_buf_pos += total_msg_len;
812 break 'outer;
813 } else if msg_type == b'D' || msg_type == b'I' {
814 } else {
816 self.handle_non_datarow_execute(
817 msg_type,
818 payload_start,
819 payload_end,
820 sql_hash,
821 )?;
822 }
823
824 self.stream_buf_pos += total_msg_len;
825 }
826
827 let remaining = self.stream_buf_end - self.stream_buf_pos;
829 debug_assert!(
830 remaining == 0 || self.stream_buf_pos > 0,
831 "compact called with pos=0 and remaining data"
832 );
833 if remaining > 0 {
834 self.stream_buf
835 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
836 }
837 self.stream_buf_pos = 0;
838 self.stream_buf_end = remaining;
839 let n = self
840 .stream
841 .read(&mut self.stream_buf[remaining..])
842 .map_err(DriverError::Io)?;
843 if n == 0 {
844 return Err(DriverError::Io(std::io::Error::new(
845 std::io::ErrorKind::UnexpectedEof,
846 "connection closed",
847 )));
848 }
849 self.stream_buf_end = remaining + n;
850 }
851
852 if self.query_counter & 63 == 0 {
854 if self.read_buf.capacity() > 64 * 1024 {
855 self.read_buf.clear();
856 self.read_buf.shrink_to(8192);
857 }
858 if self.write_buf.capacity() > 16 * 1024 {
859 self.write_buf.clear();
860 self.write_buf.shrink_to(8192);
861 }
862 }
863
864 Ok(affected_rows)
865 }
866
867 #[cold]
869 #[inline(never)]
870 fn execute_with_prepare(
871 &mut self,
872 sql: &str,
873 sql_hash: u64,
874 params: &[&(dyn Encode + Sync)],
875 ) -> Result<u64, DriverError> {
876 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
877
878 if params.len() > i16::MAX as usize {
879 return Err(DriverError::Protocol(format!(
880 "parameter count {} exceeds maximum {}",
881 params.len(),
882 i16::MAX
883 )));
884 }
885
886 let name = make_stmt_name(sql_hash);
887 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
888 let param_oids: smallvec::SmallVec<[u32; 8]> =
889 params.iter().map(|p| p.type_oid()).collect();
890
891 self.write_buf.clear();
892 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
893 proto::write_describe(&mut self.write_buf, b'S', name_s);
894 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
895 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
896 self.stream
897 .write_all(&self.write_buf)
898 .map_err(DriverError::Io)?;
899
900 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
901 let columns = self.read_column_description()?;
902 self.query_counter += 1;
903 self.cache_stmt(
904 sql_hash,
905 StmtInfo {
906 name,
907 columns,
908 last_used: self.query_counter,
909 bind_template: None,
910 },
911 );
912
913 let mut affected_rows: u64 = 0;
915 'outer: loop {
916 loop {
917 let avail = self.stream_buf_end - self.stream_buf_pos;
918 if avail < 5 {
919 break;
920 }
921
922 let msg_type = self.stream_buf[self.stream_buf_pos];
923 let raw_len = i32::from_be_bytes([
924 self.stream_buf[self.stream_buf_pos + 1],
925 self.stream_buf[self.stream_buf_pos + 2],
926 self.stream_buf[self.stream_buf_pos + 3],
927 self.stream_buf[self.stream_buf_pos + 4],
928 ]);
929
930 if raw_len < 4 {
931 return Err(DriverError::Protocol(format!(
932 "invalid message length {raw_len} for type '{}'",
933 msg_type as char
934 )));
935 }
936
937 let payload_len = (raw_len - 4) as usize;
938 let total_msg_len = 5 + payload_len;
939
940 if avail < total_msg_len {
941 if total_msg_len > self.stream_buf.len() {
942 let msg = self.read_one_message()?;
943 match msg {
944 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
945 continue;
946 }
947 BackendMessage::CommandComplete { tag } => {
948 affected_rows = proto::parse_command_tag(tag);
949 continue;
950 }
951 BackendMessage::EmptyQuery => continue,
952 BackendMessage::ReadyForQuery { status } => {
953 self.tx_status = status;
954 break 'outer;
955 }
956 BackendMessage::NoticeResponse { .. } => continue,
957 BackendMessage::ErrorResponse { data } => {
958 let fields = proto::parse_error_response(data);
959 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
960 self.drain_to_ready()?;
961 return Err(self.make_server_error(fields));
962 }
963 other => {
964 return Err(DriverError::Protocol(format!(
965 "unexpected message during execute: {other:?}"
966 )));
967 }
968 }
969 }
970 break;
971 }
972
973 let payload_start = self.stream_buf_pos + 5;
974 let payload_end = payload_start + payload_len;
975
976 if msg_type == b'2' || msg_type == b'D' || msg_type == b'I' {
977 } else if msg_type == b'C' {
979 affected_rows = proto::parse_command_tag_bytes(
980 &self.stream_buf[payload_start..payload_end],
981 );
982 } else if msg_type == b'Z' {
983 if payload_len >= 1 {
984 self.tx_status = self.stream_buf[payload_start];
985 }
986 self.stream_buf_pos += total_msg_len;
987 break 'outer;
988 } else {
989 self.handle_non_datarow_execute(
990 msg_type,
991 payload_start,
992 payload_end,
993 sql_hash,
994 )?;
995 }
996
997 self.stream_buf_pos += total_msg_len;
998 }
999
1000 self.refill_stream_buf()?;
1001 }
1002
1003 Ok(affected_rows)
1004 }
1005
1006 #[inline]
1011 pub fn execute(
1012 &mut self,
1013 sql: &str,
1014 sql_hash: u64,
1015 params: &[&(dyn Encode + Sync)],
1016 ) -> Result<u64, DriverError> {
1017 self.execute_monolithic(sql, sql_hash, params)
1018 }
1019
1020 pub fn execute_pipeline(
1032 &mut self,
1033 sql: &str,
1034 sql_hash: u64,
1035 param_sets: &[&[&(dyn Encode + Sync)]],
1036 ) -> Result<Vec<u64>, DriverError> {
1037 if param_sets.is_empty() {
1038 return Ok(Vec::new());
1039 }
1040
1041 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1042
1043 self.write_buf.clear();
1044
1045 if !self.stmts.contains_key(&sql_hash) {
1047 let name = make_stmt_name(sql_hash);
1048 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
1049 let first_params = param_sets[0];
1050 if first_params.len() > i16::MAX as usize {
1051 return Err(DriverError::Protocol(format!(
1052 "parameter count {} exceeds maximum {}",
1053 first_params.len(),
1054 i16::MAX
1055 )));
1056 }
1057 let param_oids: smallvec::SmallVec<[u32; 8]> =
1058 first_params.iter().map(|p| p.type_oid()).collect();
1059 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
1060 proto::write_describe(&mut self.write_buf, b'S', name_s);
1061 proto::write_sync(&mut self.write_buf);
1062 self.flush_write()?;
1063
1064 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1065 let columns = self.read_column_description()?;
1066 self.expect_ready()?;
1067
1068 self.query_counter += 1;
1069 self.cache_stmt(
1070 sql_hash,
1071 StmtInfo {
1072 name,
1073 columns,
1074 last_used: self.query_counter,
1075 bind_template: None,
1076 },
1077 );
1078
1079 self.write_buf.clear();
1080 }
1081
1082 let stmt_name = self
1084 .stmts
1085 .get(&sql_hash)
1086 .expect("BUG: stmt just cached but not found")
1087 .name_str()
1088 .to_owned();
1089 let count = param_sets.len();
1090
1091 for params in param_sets {
1092 if params.len() > i16::MAX as usize {
1093 return Err(DriverError::Protocol(format!(
1094 "parameter count {} exceeds maximum {}",
1095 params.len(),
1096 i16::MAX
1097 )));
1098 }
1099 proto::write_bind_params(&mut self.write_buf, "", &stmt_name, params);
1100 self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
1101 }
1102
1103 self.write_buf.extend_from_slice(proto::SYNC_ONLY);
1104 self.flush_write()?;
1105
1106 let mut results = Vec::with_capacity(count);
1108 for _ in 0..count {
1109 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1110
1111 let mut affected_rows: u64 = 0;
1112 loop {
1113 let msg = self.read_one_message()?;
1114 match msg {
1115 BackendMessage::DataRow { .. } => {}
1116 BackendMessage::CommandComplete { tag } => {
1117 affected_rows = proto::parse_command_tag(tag);
1118 break;
1119 }
1120 BackendMessage::EmptyQuery => break,
1121 BackendMessage::NoticeResponse { .. } => {}
1122 BackendMessage::ErrorResponse { data } => {
1123 let fields = proto::parse_error_response(data);
1124 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1125 self.drain_to_ready()?;
1126 return Err(self.make_server_error(fields));
1127 }
1128 other => {
1129 return Err(DriverError::Protocol(format!(
1130 "unexpected message during execute_pipeline: {other:?}"
1131 )));
1132 }
1133 }
1134 }
1135 results.push(affected_rows);
1136 }
1137
1138 self.expect_ready()?;
1139 self.shrink_buffers();
1140 Ok(results)
1141 }
1142
1143 pub(crate) fn ensure_stmt_prepared(
1149 &mut self,
1150 sql: &str,
1151 sql_hash: u64,
1152 params: &[&(dyn Encode + Sync)],
1153 ) -> Result<[u8; 18], DriverError> {
1154 if let Some(info) = self.stmts.get(&sql_hash) {
1155 return Ok(info.name);
1156 }
1157
1158 let name = make_stmt_name(sql_hash);
1159 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
1160 if params.len() > i16::MAX as usize {
1161 return Err(DriverError::Protocol(format!(
1162 "parameter count {} exceeds maximum {}",
1163 params.len(),
1164 i16::MAX
1165 )));
1166 }
1167 let param_oids: smallvec::SmallVec<[u32; 8]> =
1168 params.iter().map(|p| p.type_oid()).collect();
1169
1170 self.write_buf.clear();
1171 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
1172 proto::write_describe(&mut self.write_buf, b'S', name_s);
1173 proto::write_sync(&mut self.write_buf);
1174 self.flush_write()?;
1175
1176 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1177 let columns = self.read_column_description()?;
1178 self.expect_ready()?;
1179
1180 self.query_counter += 1;
1181 self.cache_stmt(
1182 sql_hash,
1183 StmtInfo {
1184 name,
1185 columns,
1186 last_used: self.query_counter,
1187 bind_template: None,
1188 },
1189 );
1190
1191 Ok(name)
1192 }
1193
1194 pub(crate) fn write_deferred_bind_execute(
1197 &self,
1198 sql_hash: u64,
1199 params: &[&(dyn Encode + Sync)],
1200 buf: &mut Vec<u8>,
1201 ) {
1202 let stmt_name = self
1203 .stmts
1204 .get(&sql_hash)
1205 .expect("BUG: stmt just cached but not found")
1206 .name_str();
1207 proto::write_bind_params(buf, "", stmt_name, params);
1208 buf.extend_from_slice(proto::EXECUTE_ONLY);
1209 }
1210
1211 pub(crate) fn flush_deferred_pipeline(
1216 &mut self,
1217 buf: &mut Vec<u8>,
1218 count: usize,
1219 ) -> Result<Vec<u64>, DriverError> {
1220 if count == 0 {
1221 buf.clear();
1222 return Ok(Vec::new());
1223 }
1224
1225 buf.extend_from_slice(proto::SYNC_ONLY);
1226
1227 self.stream.write_all(buf).map_err(DriverError::Io)?;
1228 buf.clear();
1229
1230 let mut results = Vec::with_capacity(count);
1231 for _ in 0..count {
1232 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1233
1234 let mut affected_rows: u64 = 0;
1235 loop {
1236 let msg = self.read_one_message()?;
1237 match msg {
1238 BackendMessage::DataRow { .. } => {}
1239 BackendMessage::CommandComplete { tag } => {
1240 affected_rows = proto::parse_command_tag(tag);
1241 break;
1242 }
1243 BackendMessage::EmptyQuery => break,
1244 BackendMessage::NoticeResponse { .. } => {}
1245 BackendMessage::ErrorResponse { data } => {
1246 let fields = proto::parse_error_response(data);
1247 self.drain_to_ready()?;
1248 return Err(self.make_server_error(fields));
1249 }
1250 other => {
1251 return Err(DriverError::Protocol(format!(
1252 "unexpected message during flush_deferred_pipeline: {other:?}"
1253 )));
1254 }
1255 }
1256 }
1257 results.push(affected_rows);
1258 }
1259
1260 self.expect_ready()?;
1261 self.shrink_buffers();
1262 Ok(results)
1263 }
1264
1265 pub fn for_each<F>(
1267 &mut self,
1268 sql: &str,
1269 sql_hash: u64,
1270 params: &[&(dyn Encode + Sync)],
1271 mut f: F,
1272 ) -> Result<(), DriverError>
1273 where
1274 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
1275 {
1276 let _ = self.send_pipeline(sql, sql_hash, params, false, true)?;
1277
1278 'outer: loop {
1280 loop {
1281 let avail = self.stream_buf_end - self.stream_buf_pos;
1282 if avail < 5 {
1283 break; }
1285
1286 let msg_type = self.stream_buf[self.stream_buf_pos];
1287 let raw_len = i32::from_be_bytes([
1288 self.stream_buf[self.stream_buf_pos + 1],
1289 self.stream_buf[self.stream_buf_pos + 2],
1290 self.stream_buf[self.stream_buf_pos + 3],
1291 self.stream_buf[self.stream_buf_pos + 4],
1292 ]);
1293
1294 if raw_len < 4 {
1295 return Err(DriverError::Protocol(format!(
1296 "invalid message length {raw_len} for type '{}'",
1297 msg_type as char
1298 )));
1299 }
1300
1301 let payload_len = (raw_len - 4) as usize;
1302 let total_msg_len = 5 + payload_len;
1303
1304 if avail < total_msg_len {
1305 if total_msg_len > self.stream_buf.len() {
1306 let msg = self.read_one_message()?;
1308 match msg {
1309 BackendMessage::BindComplete => continue,
1310 BackendMessage::DataRow { data } => {
1311 let row = PgDataRow::new(data)?;
1312 f(row)?;
1313 continue;
1314 }
1315 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1316 continue;
1317 }
1318 BackendMessage::ReadyForQuery { status } => {
1319 self.tx_status = status;
1320 break 'outer;
1321 }
1322 BackendMessage::NoticeResponse { .. } => continue,
1323 BackendMessage::ErrorResponse { data } => {
1324 let fields = proto::parse_error_response(data);
1325 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1326 self.drain_to_ready()?;
1327 return Err(self.make_server_error(fields));
1328 }
1329 other => {
1330 return Err(DriverError::Protocol(format!(
1331 "unexpected message during for_each: {other:?}"
1332 )));
1333 }
1334 }
1335 }
1336 break; }
1338
1339 let payload_start = self.stream_buf_pos + 5;
1341 let payload_end = payload_start + payload_len;
1342
1343 if msg_type == b'D' {
1346 let row = PgDataRow::new(&self.stream_buf[payload_start..payload_end])?;
1348 f(row)?;
1349 } else if msg_type == b'Z' {
1350 if payload_len >= 1 {
1352 self.tx_status = self.stream_buf[payload_start];
1353 }
1354 self.stream_buf_pos += total_msg_len;
1355 break 'outer;
1356 } else {
1357 self.handle_non_datarow_execute(
1358 msg_type,
1359 payload_start,
1360 payload_end,
1361 sql_hash,
1362 )?;
1363 }
1364
1365 self.stream_buf_pos += total_msg_len;
1366 }
1367
1368 self.refill_stream_buf()?;
1370 }
1371
1372 self.shrink_buffers();
1373 Ok(())
1374 }
1375
1376 #[inline]
1387 pub fn for_each_raw_monolithic<F>(
1388 &mut self,
1389 sql: &str,
1390 sql_hash: u64,
1391 params: &[&(dyn Encode + Sync)],
1392 mut f: F,
1393 ) -> Result<(), DriverError>
1394 where
1395 F: FnMut(&[u8]) -> Result<(), DriverError>,
1396 {
1397 self.write_buf.clear();
1399
1400 let info = match self.stmts.get_mut(&sql_hash) {
1402 Some(info) => {
1403 self.query_counter += 1;
1404 info.last_used = self.query_counter;
1405 info
1406 }
1407 None => {
1408 return self.for_each_raw_with_prepare(sql, sql_hash, params, f);
1410 }
1411 };
1412
1413 let can_use_template = info
1415 .bind_template
1416 .as_ref()
1417 .is_some_and(|t| t.param_slots.len() == params.len());
1418
1419 let mut has_exec_sync = false;
1420
1421 if can_use_template {
1422 let tmpl = info
1424 .bind_template
1425 .as_ref()
1426 .expect("guarded by can_use_template");
1427 self.write_buf.extend_from_slice(&tmpl.bytes);
1428
1429 let mut template_ok = true;
1430 for (i, param) in params.iter().enumerate() {
1431 let (data_offset, old_len) = tmpl.param_slots[i];
1432 if param.is_null() {
1433 let len_offset = data_offset - 4;
1434 self.write_buf[len_offset..len_offset + 4]
1435 .copy_from_slice(&(-1i32).to_be_bytes());
1436 } else if old_len >= 0 {
1437 let end = data_offset + old_len as usize;
1438 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
1439 template_ok = false;
1440 break;
1441 }
1442 } else {
1443 template_ok = false;
1444 break;
1445 }
1446 }
1447
1448 if template_ok {
1449 has_exec_sync = true;
1450 } else {
1451 self.write_buf.clear();
1452 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
1453 info.bind_template = None;
1454 }
1455 } else {
1456 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
1457 }
1458
1459 if info.bind_template.is_none() && !self.write_buf.is_empty() {
1461 info.bind_template = build_bind_template(&self.write_buf, params.len());
1462 }
1463
1464 if !has_exec_sync {
1465 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1466 }
1467
1468 self.stream
1470 .write_all(&self.write_buf)
1471 .map_err(DriverError::Io)?;
1472
1473 loop {
1477 let avail = self.stream_buf_end - self.stream_buf_pos;
1478 if avail >= 5 {
1479 let bc_type = self.stream_buf[self.stream_buf_pos];
1480 match bc_type {
1481 b'2' => {
1482 self.stream_buf_pos += 5;
1483 break;
1484 }
1485 b'E' => {
1486 let msg = self.read_one_message()?;
1487 if let BackendMessage::ErrorResponse { data } = msg {
1488 let fields = proto::parse_error_response(data);
1489 self.drain_to_ready()?;
1490 return Err(self.make_server_error(fields));
1491 }
1492 }
1493 b'N' | b'S' => {
1494 let raw_len = i32::from_be_bytes([
1495 self.stream_buf[self.stream_buf_pos + 1],
1496 self.stream_buf[self.stream_buf_pos + 2],
1497 self.stream_buf[self.stream_buf_pos + 3],
1498 self.stream_buf[self.stream_buf_pos + 4],
1499 ]);
1500 let total = 1 + raw_len as usize;
1501 if avail >= total {
1502 self.stream_buf_pos += total;
1503 continue;
1504 }
1505 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1506 break;
1507 }
1508 _ => {
1509 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1510 break;
1511 }
1512 }
1513 } else {
1514 let remaining = self.stream_buf_end - self.stream_buf_pos;
1516 if remaining > 0 && self.stream_buf_pos > 0 {
1517 self.stream_buf
1518 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1519 }
1520 self.stream_buf_pos = 0;
1521 self.stream_buf_end = remaining;
1522 let n = self
1523 .stream
1524 .read(&mut self.stream_buf[remaining..])
1525 .map_err(DriverError::Io)?;
1526 if n == 0 {
1527 return Err(DriverError::Io(std::io::Error::new(
1528 std::io::ErrorKind::UnexpectedEof,
1529 "connection closed",
1530 )));
1531 }
1532 self.stream_buf_end = remaining + n;
1533 }
1534 }
1535
1536 'outer: loop {
1538 loop {
1539 let avail = self.stream_buf_end - self.stream_buf_pos;
1540 if avail < 5 {
1541 break;
1542 }
1543
1544 let msg_type = self.stream_buf[self.stream_buf_pos];
1545 let raw_len = i32::from_be_bytes([
1546 self.stream_buf[self.stream_buf_pos + 1],
1547 self.stream_buf[self.stream_buf_pos + 2],
1548 self.stream_buf[self.stream_buf_pos + 3],
1549 self.stream_buf[self.stream_buf_pos + 4],
1550 ]);
1551
1552 if raw_len < 4 {
1553 return Err(DriverError::Protocol(format!(
1554 "invalid message length {raw_len} for type '{}'",
1555 msg_type as char
1556 )));
1557 }
1558
1559 let payload_len = (raw_len - 4) as usize;
1560 let total_msg_len = 5 + payload_len;
1561
1562 if avail < total_msg_len {
1563 if total_msg_len > self.stream_buf.len() {
1564 let msg = self.read_one_message()?;
1565 match msg {
1566 BackendMessage::DataRow { data } => {
1567 f(data)?;
1568 continue;
1569 }
1570 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1571 continue;
1572 }
1573 BackendMessage::ReadyForQuery { status } => {
1574 self.tx_status = status;
1575 break 'outer;
1576 }
1577 BackendMessage::ErrorResponse { data } => {
1578 let fields = proto::parse_error_response(data);
1579 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1580 self.drain_to_ready()?;
1581 return Err(self.make_server_error(fields));
1582 }
1583 BackendMessage::NoticeResponse { .. } => continue,
1584 other => {
1585 return Err(DriverError::Protocol(format!(
1586 "unexpected message during for_each_raw: {other:?}"
1587 )));
1588 }
1589 }
1590 }
1591 break; }
1593
1594 let payload_start = self.stream_buf_pos + 5;
1596 let payload_end = payload_start + payload_len;
1597
1598 if msg_type == b'D' {
1599 f(&self.stream_buf[payload_start..payload_end])?;
1600 } else if msg_type == b'Z' {
1601 if payload_len >= 1 {
1602 self.tx_status = self.stream_buf[payload_start];
1603 }
1604 self.stream_buf_pos += total_msg_len;
1605 break 'outer;
1606 } else {
1607 self.handle_non_datarow_execute(
1608 msg_type,
1609 payload_start,
1610 payload_end,
1611 sql_hash,
1612 )?;
1613 }
1614
1615 self.stream_buf_pos += total_msg_len;
1616 }
1617
1618 let remaining = self.stream_buf_end - self.stream_buf_pos;
1620 if remaining > 0 && self.stream_buf_pos > 0 {
1621 self.stream_buf
1622 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1623 }
1624 self.stream_buf_pos = 0;
1625 self.stream_buf_end = remaining;
1626 let n = self
1627 .stream
1628 .read(&mut self.stream_buf[remaining..])
1629 .map_err(DriverError::Io)?;
1630 if n == 0 {
1631 return Err(DriverError::Io(std::io::Error::new(
1632 std::io::ErrorKind::UnexpectedEof,
1633 "connection closed",
1634 )));
1635 }
1636 self.stream_buf_end = remaining + n;
1637 }
1638
1639 if self.query_counter & 63 == 0 {
1641 if self.read_buf.capacity() > 64 * 1024 {
1642 self.read_buf.clear();
1643 self.read_buf.shrink_to(8192);
1644 }
1645 if self.write_buf.capacity() > 16 * 1024 {
1646 self.write_buf.clear();
1647 self.write_buf.shrink_to(8192);
1648 }
1649 }
1650
1651 Ok(())
1652 }
1653
1654 #[cold]
1656 #[inline(never)]
1657 fn for_each_raw_with_prepare<F>(
1658 &mut self,
1659 sql: &str,
1660 sql_hash: u64,
1661 params: &[&(dyn Encode + Sync)],
1662 mut f: F,
1663 ) -> Result<(), DriverError>
1664 where
1665 F: FnMut(&[u8]) -> Result<(), DriverError>,
1666 {
1667 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1668
1669 if params.len() > i16::MAX as usize {
1670 return Err(DriverError::Protocol(format!(
1671 "parameter count {} exceeds maximum {}",
1672 params.len(),
1673 i16::MAX
1674 )));
1675 }
1676
1677 let name = make_stmt_name(sql_hash);
1678 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
1679 let param_oids: smallvec::SmallVec<[u32; 8]> =
1680 params.iter().map(|p| p.type_oid()).collect();
1681
1682 self.write_buf.clear();
1683 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
1684 proto::write_describe(&mut self.write_buf, b'S', name_s);
1685 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
1686 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1687 self.stream
1688 .write_all(&self.write_buf)
1689 .map_err(DriverError::Io)?;
1690
1691 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1692 let columns = self.read_column_description()?;
1693 self.query_counter += 1;
1694 self.cache_stmt(
1695 sql_hash,
1696 StmtInfo {
1697 name,
1698 columns,
1699 last_used: self.query_counter,
1700 bind_template: None,
1701 },
1702 );
1703
1704 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1706
1707 'outer: loop {
1708 loop {
1709 let avail = self.stream_buf_end - self.stream_buf_pos;
1710 if avail < 5 {
1711 break;
1712 }
1713
1714 let msg_type = self.stream_buf[self.stream_buf_pos];
1715 let raw_len = i32::from_be_bytes([
1716 self.stream_buf[self.stream_buf_pos + 1],
1717 self.stream_buf[self.stream_buf_pos + 2],
1718 self.stream_buf[self.stream_buf_pos + 3],
1719 self.stream_buf[self.stream_buf_pos + 4],
1720 ]);
1721
1722 if raw_len < 4 {
1723 return Err(DriverError::Protocol(format!(
1724 "invalid message length {raw_len} for type '{}'",
1725 msg_type as char
1726 )));
1727 }
1728
1729 let payload_len = (raw_len - 4) as usize;
1730 let total_msg_len = 5 + payload_len;
1731
1732 if avail < total_msg_len {
1733 if total_msg_len > self.stream_buf.len() {
1734 let msg = self.read_one_message()?;
1735 match msg {
1736 BackendMessage::DataRow { data } => {
1737 f(data)?;
1738 continue;
1739 }
1740 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1741 continue;
1742 }
1743 BackendMessage::ReadyForQuery { status } => {
1744 self.tx_status = status;
1745 break 'outer;
1746 }
1747 BackendMessage::ErrorResponse { data } => {
1748 let fields = proto::parse_error_response(data);
1749 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1750 self.drain_to_ready()?;
1751 return Err(self.make_server_error(fields));
1752 }
1753 BackendMessage::NoticeResponse { .. } => continue,
1754 other => {
1755 return Err(DriverError::Protocol(format!(
1756 "unexpected message during for_each_raw: {other:?}"
1757 )));
1758 }
1759 }
1760 }
1761 break;
1762 }
1763
1764 let payload_start = self.stream_buf_pos + 5;
1765 let payload_end = payload_start + payload_len;
1766
1767 if msg_type == b'D' {
1768 f(&self.stream_buf[payload_start..payload_end])?;
1769 } else if msg_type == b'Z' {
1770 if payload_len >= 1 {
1771 self.tx_status = self.stream_buf[payload_start];
1772 }
1773 self.stream_buf_pos += total_msg_len;
1774 break 'outer;
1775 } else {
1776 self.handle_non_datarow_execute(
1777 msg_type,
1778 payload_start,
1779 payload_end,
1780 sql_hash,
1781 )?;
1782 }
1783
1784 self.stream_buf_pos += total_msg_len;
1785 }
1786
1787 self.refill_stream_buf()?;
1788 }
1789
1790 self.shrink_buffers();
1791 Ok(())
1792 }
1793
1794 #[inline]
1799 pub fn for_each_raw<F>(
1800 &mut self,
1801 sql: &str,
1802 sql_hash: u64,
1803 params: &[&(dyn Encode + Sync)],
1804 f: F,
1805 ) -> Result<(), DriverError>
1806 where
1807 F: FnMut(&[u8]) -> Result<(), DriverError>,
1808 {
1809 self.for_each_raw_monolithic(sql, sql_hash, params, f)
1810 }
1811
1812 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1814 self.write_buf.clear();
1815 proto::write_simple_query(&mut self.write_buf, sql);
1816 self.flush_write()?;
1817
1818 loop {
1819 let msg = self.read_one_message()?;
1820 match msg {
1821 BackendMessage::ReadyForQuery { status } => {
1822 self.tx_status = status;
1823 return Ok(());
1824 }
1825 BackendMessage::CommandComplete { .. }
1826 | BackendMessage::RowDescription { .. }
1827 | BackendMessage::DataRow { .. }
1828 | BackendMessage::EmptyQuery
1829 | BackendMessage::NoticeResponse { .. }
1830 | BackendMessage::ParameterStatus { .. }
1831 | BackendMessage::AuthOk
1835 | BackendMessage::AuthSaslFinal { .. }
1836 | BackendMessage::BackendKeyData { .. } => {}
1837 BackendMessage::ErrorResponse { data } => {
1838 let fields = proto::parse_error_response(data);
1839 self.drain_to_ready()?;
1840 return Err(self.make_server_error(fields));
1841 }
1842 other => {
1843 return Err(DriverError::Protocol(format!(
1844 "unexpected message during simple_query: {other:?}"
1845 )));
1846 }
1847 }
1848 }
1849 }
1850
1851 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
1853 self.write_buf.clear();
1854 proto::write_simple_query(&mut self.write_buf, sql);
1855 self.flush_write()?;
1856
1857 let mut rows: Vec<SimpleRow> = Vec::new();
1858 loop {
1859 let msg = self.read_one_message()?;
1860 match msg {
1861 BackendMessage::ReadyForQuery { status } => {
1862 self.tx_status = status;
1863 return Ok(rows);
1864 }
1865 BackendMessage::DataRow { data } => {
1866 rows.push(proto::parse_simple_data_row(data)?);
1867 }
1868 BackendMessage::RowDescription { .. }
1869 | BackendMessage::CommandComplete { .. }
1870 | BackendMessage::EmptyQuery
1871 | BackendMessage::NoticeResponse { .. }
1872 | BackendMessage::ParameterStatus { .. }
1873 | BackendMessage::AuthOk
1874 | BackendMessage::AuthSaslFinal { .. }
1875 | BackendMessage::BackendKeyData { .. } => {}
1876 BackendMessage::ErrorResponse { data } => {
1877 let fields = proto::parse_error_response(data);
1878 self.drain_to_ready()?;
1879 return Err(self.make_server_error(fields));
1880 }
1881 other => {
1882 return Err(DriverError::Protocol(format!(
1883 "unexpected message during simple_query_rows: {other:?}"
1884 )));
1885 }
1886 }
1887 }
1888 }
1889
1890 pub fn copy_in<'a, I>(
1912 &mut self,
1913 table: &str,
1914 columns: &[&str],
1915 rows: I,
1916 ) -> Result<u64, DriverError>
1917 where
1918 I: IntoIterator<Item = &'a str>,
1919 {
1920 let quoted_table = proto::quote_ident(table);
1922 let quoted_cols: Vec<String> = columns.iter().map(|c| proto::quote_ident(c)).collect();
1923 let sql = format!(
1924 "COPY {}({}) FROM STDIN",
1925 quoted_table,
1926 quoted_cols.join(",")
1927 );
1928
1929 self.write_buf.clear();
1931 proto::write_simple_query(&mut self.write_buf, &sql);
1932 self.flush_write()?;
1933
1934 loop {
1936 let msg = self.read_one_message()?;
1937 match msg {
1938 BackendMessage::CopyInResponse { .. } => break,
1939 BackendMessage::ErrorResponse { data } => {
1940 let fields = proto::parse_error_response(data);
1941 self.drain_to_ready()?;
1942 return Err(self.make_server_error(fields));
1943 }
1944 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
1945 other => {
1946 return Err(DriverError::Protocol(format!(
1947 "expected CopyInResponse, got: {other:?}"
1948 )));
1949 }
1950 }
1951 }
1952
1953 for row in rows {
1960 self.write_buf.clear();
1961 let mut row_bytes = Vec::with_capacity(row.len() + 1);
1963 row_bytes.extend_from_slice(row.as_bytes());
1964 row_bytes.push(b'\n');
1965 proto::write_copy_data(&mut self.write_buf, &row_bytes);
1966 self.flush_write()?;
1967 }
1968
1969 self.write_buf.clear();
1971 proto::write_copy_done(&mut self.write_buf);
1972 self.flush_write()?;
1973
1974 let mut count: u64 = 0;
1976 loop {
1977 let msg = self.read_one_message()?;
1978 match msg {
1979 BackendMessage::CommandComplete { tag } => {
1980 count = proto::parse_command_tag(tag);
1981 }
1982 BackendMessage::ReadyForQuery { status } => {
1983 self.tx_status = status;
1984 return Ok(count);
1985 }
1986 BackendMessage::ErrorResponse { data } => {
1987 let fields = proto::parse_error_response(data);
1988 self.drain_to_ready()?;
1989 return Err(self.make_server_error(fields));
1990 }
1991 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
1992 other => {
1993 return Err(DriverError::Protocol(format!(
1994 "unexpected message during copy_in completion: {other:?}"
1995 )));
1996 }
1997 }
1998 }
1999 }
2000
2001 pub fn copy_out<W: std::io::Write>(
2021 &mut self,
2022 query: &str,
2023 writer: &mut W,
2024 ) -> Result<u64, DriverError> {
2025 let sql = format!("COPY ({query}) TO STDOUT");
2027
2028 self.write_buf.clear();
2030 proto::write_simple_query(&mut self.write_buf, &sql);
2031 self.flush_write()?;
2032
2033 loop {
2035 let msg = self.read_one_message()?;
2036 match msg {
2037 BackendMessage::CopyOutResponse { .. } => break,
2038 BackendMessage::ErrorResponse { data } => {
2039 let fields = proto::parse_error_response(data);
2040 self.drain_to_ready()?;
2041 return Err(self.make_server_error(fields));
2042 }
2043 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2044 other => {
2045 return Err(DriverError::Protocol(format!(
2046 "expected CopyOutResponse, got: {other:?}"
2047 )));
2048 }
2049 }
2050 }
2051
2052 loop {
2054 let msg = self.read_one_message()?;
2055 match msg {
2056 BackendMessage::CopyData { data } => {
2057 writer.write_all(&data).map_err(DriverError::Io)?;
2058 }
2059 BackendMessage::CopyDone => break,
2060 BackendMessage::ErrorResponse { data } => {
2061 let fields = proto::parse_error_response(data);
2062 self.drain_to_ready()?;
2063 return Err(self.make_server_error(fields));
2064 }
2065 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2066 other => {
2067 return Err(DriverError::Protocol(format!(
2068 "unexpected message during copy_out data: {other:?}"
2069 )));
2070 }
2071 }
2072 }
2073
2074 let mut count: u64 = 0;
2076 loop {
2077 let msg = self.read_one_message()?;
2078 match msg {
2079 BackendMessage::CommandComplete { tag } => {
2080 count = proto::parse_command_tag(tag);
2081 }
2082 BackendMessage::ReadyForQuery { status } => {
2083 self.tx_status = status;
2084 return Ok(count);
2085 }
2086 BackendMessage::ErrorResponse { data } => {
2087 let fields = proto::parse_error_response(data);
2088 self.drain_to_ready()?;
2089 return Err(self.make_server_error(fields));
2090 }
2091 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2092 other => {
2093 return Err(DriverError::Protocol(format!(
2094 "unexpected message during copy_out completion: {other:?}"
2095 )));
2096 }
2097 }
2098 }
2099 }
2100
2101 pub fn prepare_describe(&mut self, sql: &str) -> Result<PrepareResult, DriverError> {
2106 self.write_buf.clear();
2107 proto::write_parse(&mut self.write_buf, "", sql, &[]);
2110 proto::write_describe(&mut self.write_buf, b'S', "");
2111 proto::write_sync(&mut self.write_buf);
2112 self.flush_write()?;
2113
2114 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2116
2117 let mut param_oids: Vec<u32> = Vec::new();
2119 let columns;
2120 loop {
2121 let msg = self.read_one_message()?;
2122 match msg {
2123 BackendMessage::ParameterDescription { data } => {
2124 param_oids = proto::parse_parameter_description(data)?;
2125 }
2126 BackendMessage::RowDescription { data } => {
2127 columns = proto::parse_row_description(data)?;
2128 break;
2129 }
2130 BackendMessage::NoData => {
2131 columns = Vec::new();
2132 break;
2133 }
2134 BackendMessage::NoticeResponse { .. } => {}
2135 BackendMessage::ErrorResponse { data } => {
2136 let fields = proto::parse_error_response(data);
2137 self.drain_to_ready()?;
2138 return Err(self.make_server_error(fields));
2139 }
2140 other => {
2141 return Err(DriverError::Protocol(format!(
2142 "expected ParameterDescription/RowDescription/NoData, got: {other:?}"
2143 )));
2144 }
2145 }
2146 }
2147
2148 self.expect_ready()?;
2150
2151 Ok(PrepareResult {
2152 columns,
2153 param_oids,
2154 })
2155 }
2156
2157 pub fn wait_for_notification(&mut self) -> Result<(String, String), DriverError> {
2166 loop {
2167 let (msg_type, _payload_len) = self.read_message_buffered()?;
2168 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2169 match msg {
2170 BackendMessage::NotificationResponse {
2171 channel, payload, ..
2172 } => {
2173 return Ok((channel.to_owned(), payload.to_owned()));
2174 }
2175 BackendMessage::ParameterStatus { .. } | BackendMessage::NoticeResponse { .. } => {
2176 continue;
2177 }
2178 _ => continue,
2179 }
2180 }
2181 }
2182
2183 pub fn cancel(&self) -> Result<(), DriverError> {
2189 let addr = format!("{}:{}", self.connect_config.host, self.connect_config.port);
2190 let mut tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
2191 let mut buf = Vec::with_capacity(16);
2192 proto::write_cancel_request(&mut buf, self.pid, self.secret);
2193 tcp.write_all(&buf).map_err(DriverError::Io)?;
2194 tcp.flush().map_err(DriverError::Io)?;
2195 drop(tcp);
2197 Ok(())
2198 }
2199
2200 pub fn set_read_timeout(
2205 &self,
2206 timeout: Option<std::time::Duration>,
2207 ) -> Result<(), DriverError> {
2208 self.stream
2209 .set_read_timeout(timeout)
2210 .map_err(DriverError::Io)
2211 }
2212
2213 pub fn query_streaming_start(
2227 &mut self,
2228 sql: &str,
2229 sql_hash: u64,
2230 params: &[&(dyn Encode + Sync)],
2231 chunk_size: i32,
2232 ) -> Result<(Arc<[ColumnDesc]>, bool), DriverError> {
2233 self.write_buf.clear();
2234
2235 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
2236 self.query_counter += 1;
2238 info.last_used = self.query_counter;
2239
2240 let can_use_template = info
2241 .bind_template
2242 .as_ref()
2243 .is_some_and(|t| t.param_slots.len() == params.len());
2244
2245 if can_use_template {
2246 let tmpl = info
2248 .bind_template
2249 .as_ref()
2250 .expect("guarded by can_use_template");
2251 self.write_buf
2254 .extend_from_slice(&tmpl.bytes[..tmpl.bind_end]);
2255
2256 let mut template_ok = true;
2257 for (i, param) in params.iter().enumerate() {
2258 let (data_offset, old_len) = tmpl.param_slots[i];
2259 if param.is_null() {
2260 let len_offset = data_offset - 4;
2261 self.write_buf[len_offset..len_offset + 4]
2262 .copy_from_slice(&(-1i32).to_be_bytes());
2263 } else if old_len >= 0 {
2264 let end = data_offset + old_len as usize;
2265 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2266 template_ok = false;
2267 break;
2268 }
2269 } else {
2270 template_ok = false;
2271 break;
2272 }
2273 }
2274
2275 if !template_ok {
2276 self.write_buf.clear();
2277 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2278 info.bind_template = None;
2279 }
2280 } else {
2281 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2282 }
2283
2284 let cols = info.columns.clone();
2285
2286 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2287 info.bind_template = build_bind_template(&self.write_buf, params.len());
2288 }
2289
2290 proto::write_execute(&mut self.write_buf, "", chunk_size);
2291 proto::write_flush(&mut self.write_buf);
2293 self.flush_write()?;
2294
2295 cols
2296 } else {
2297 let name = make_stmt_name(sql_hash);
2299 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
2300 let param_oids: smallvec::SmallVec<[u32; 8]> =
2301 params.iter().map(|p| p.type_oid()).collect();
2302 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
2303 proto::write_describe(&mut self.write_buf, b'S', name_s);
2304 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
2305
2306 proto::write_execute(&mut self.write_buf, "", chunk_size);
2307 proto::write_flush(&mut self.write_buf);
2308 self.flush_write()?;
2309
2310 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2311 let columns = self.read_column_description()?;
2312 self.query_counter += 1;
2313 self.cache_stmt(
2314 sql_hash,
2315 StmtInfo {
2316 name,
2317 columns: columns.clone(),
2318 last_used: self.query_counter,
2319 bind_template: None,
2320 },
2321 );
2322 columns
2323 };
2324
2325 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2327
2328 self.streaming_active = true;
2329
2330 Ok((columns, false))
2331 }
2332
2333 pub fn streaming_next_chunk(
2341 &mut self,
2342 arena: &mut Arena,
2343 all_col_offsets: &mut Vec<(usize, i32)>,
2344 ) -> Result<bool, DriverError> {
2345 all_col_offsets.clear();
2346
2347 loop {
2348 let msg = self.read_one_message()?;
2349 match msg {
2350 BackendMessage::DataRow { data } => {
2351 parse_data_row_flat(data, arena, all_col_offsets)?;
2352 }
2353 BackendMessage::PortalSuspended => {
2354 return Ok(true);
2358 }
2359 BackendMessage::CommandComplete { .. } => {
2360 self.write_buf.clear();
2363 proto::write_sync(&mut self.write_buf);
2364 self.flush_write()?;
2365 self.expect_ready()?;
2366 self.shrink_buffers();
2367
2368 self.streaming_active = false;
2369 return Ok(false);
2370 }
2371 BackendMessage::EmptyQuery => {
2372 self.write_buf.clear();
2373 proto::write_sync(&mut self.write_buf);
2374 self.flush_write()?;
2375 self.expect_ready()?;
2376
2377 self.streaming_active = false;
2378 return Ok(false);
2379 }
2380 BackendMessage::ErrorResponse { data } => {
2381 let fields = proto::parse_error_response(data);
2382 self.write_buf.clear();
2384 proto::write_sync(&mut self.write_buf);
2385 self.flush_write()?;
2386 self.drain_to_ready()?;
2387
2388 self.streaming_active = false;
2389 return Err(self.make_server_error(fields));
2390 }
2391 BackendMessage::NoticeResponse { .. } => {}
2392 other => {
2393 return Err(DriverError::Protocol(format!(
2394 "unexpected message during streaming: {other:?}"
2395 )));
2396 }
2397 }
2398 }
2399 }
2400
2401 pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
2409 self.write_buf.clear();
2410 proto::write_execute(&mut self.write_buf, "", chunk_size);
2411 proto::write_flush(&mut self.write_buf);
2412 self.flush_write()
2413 }
2414
2415 pub fn is_streaming(&self) -> bool {
2417 self.streaming_active
2418 }
2419
2420 pub fn close(mut self) -> Result<(), DriverError> {
2422 self.write_buf.clear();
2423 proto::write_terminate(&mut self.write_buf);
2424 let _ = self.flush_write();
2425 Ok(())
2426 }
2427
2428 pub fn is_idle(&self) -> bool {
2432 self.tx_status == b'I'
2433 }
2434
2435 pub fn is_in_transaction(&self) -> bool {
2437 self.tx_status == b'T'
2438 }
2439
2440 pub fn is_in_failed_transaction(&self) -> bool {
2442 self.tx_status == b'E'
2443 }
2444
2445 pub fn touch(&mut self) {
2447 self.last_used = std::time::Instant::now();
2448 }
2449
2450 pub fn idle_duration(&self) -> std::time::Duration {
2452 self.last_used.elapsed()
2453 }
2454
2455 pub fn query_counter(&self) -> u64 {
2457 self.query_counter
2458 }
2459
2460 pub fn parameter(&self, name: &str) -> Option<&str> {
2462 self.params
2463 .iter()
2464 .find(|(k, _)| &**k == name)
2465 .map(|(_, v)| &**v)
2466 }
2467
2468 pub fn server_params(&self) -> &[(Box<str>, Box<str>)] {
2470 &self.params
2471 }
2472
2473 pub fn pid(&self) -> i32 {
2475 self.pid
2476 }
2477
2478 pub fn secret_key(&self) -> i32 {
2480 self.secret
2481 }
2482
2483 pub fn drain_notifications(&mut self) -> Vec<Notification> {
2485 std::mem::take(&mut self.pending_notifications)
2486 }
2487
2488 pub fn pending_notification_count(&self) -> usize {
2490 self.pending_notifications.len()
2491 }
2492
2493 pub fn set_max_stmt_cache_size(&mut self, size: usize) {
2495 self.max_stmt_cache_size = size;
2496 }
2497
2498 pub fn stmt_cache_len(&self) -> usize {
2500 self.stmts.len()
2501 }
2502
2503 pub fn created_at(&self) -> std::time::Instant {
2505 self.created_at
2506 }
2507
2508 #[inline]
2516 fn send_pipeline(
2517 &mut self,
2518 sql: &str,
2519 sql_hash: u64,
2520 params: &[&(dyn Encode + Sync)],
2521 need_columns: bool,
2522 skip_bind_complete: bool,
2523 ) -> Result<Option<Arc<[ColumnDesc]>>, DriverError> {
2524 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
2525
2526 if params.len() > i16::MAX as usize {
2527 return Err(DriverError::Protocol(format!(
2528 "parameter count {} exceeds maximum {}",
2529 params.len(),
2530 i16::MAX
2531 )));
2532 }
2533
2534 self.write_buf.clear();
2535
2536 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
2537 self.query_counter += 1;
2539 info.last_used = self.query_counter;
2540
2541 let can_use_template = info
2542 .bind_template
2543 .as_ref()
2544 .is_some_and(|t| t.param_slots.len() == params.len());
2545
2546 let mut has_exec_sync = false;
2548
2549 if can_use_template {
2550 let tmpl = info
2554 .bind_template
2555 .as_ref()
2556 .expect("guarded by can_use_template");
2557 self.write_buf.extend_from_slice(&tmpl.bytes);
2558
2559 let mut template_ok = true;
2560 for (i, param) in params.iter().enumerate() {
2561 let (data_offset, old_len) = tmpl.param_slots[i];
2562 if param.is_null() {
2563 let len_offset = data_offset - 4;
2565 self.write_buf[len_offset..len_offset + 4]
2566 .copy_from_slice(&(-1i32).to_be_bytes());
2567 } else if old_len >= 0 {
2568 let end = data_offset + old_len as usize;
2569 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2570 template_ok = false;
2572 break;
2573 }
2574 } else {
2575 template_ok = false;
2578 break;
2579 }
2580 }
2581
2582 if template_ok {
2583 has_exec_sync = true; } else {
2585 self.write_buf.clear();
2586 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2587 info.bind_template = None;
2589 }
2590 } else {
2591 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2592 }
2593
2594 let cols = if need_columns {
2595 Some(info.columns.clone())
2596 } else {
2597 None
2598 };
2599
2600 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2604 info.bind_template = build_bind_template(&self.write_buf, params.len());
2605 }
2606
2607 if !has_exec_sync {
2608 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2609 }
2610 self.flush_write()?;
2611
2612 cols
2613 } else {
2614 let name = make_stmt_name(sql_hash);
2616 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
2617 let param_oids: smallvec::SmallVec<[u32; 8]> =
2618 params.iter().map(|p| p.type_oid()).collect();
2619 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
2620 proto::write_describe(&mut self.write_buf, b'S', name_s);
2621 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
2622
2623 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2624 self.flush_write()?;
2625
2626 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2627 let columns = self.read_column_description()?;
2628 self.query_counter += 1;
2629 self.cache_stmt(
2630 sql_hash,
2631 StmtInfo {
2632 name,
2633 columns: columns.clone(),
2634 last_used: self.query_counter,
2635 bind_template: None,
2636 },
2637 );
2638 if need_columns { Some(columns) } else { None }
2639 };
2640
2641 if !skip_bind_complete {
2642 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2643 }
2644
2645 Ok(columns)
2646 }
2647
2648 fn read_column_description(&mut self) -> Result<Arc<[ColumnDesc]>, DriverError> {
2650 loop {
2651 let msg = self.read_one_message()?;
2652 match msg {
2653 BackendMessage::RowDescription { data } => {
2654 let cols = proto::parse_row_description(data)?;
2655 return Ok(cols.into());
2656 }
2657 BackendMessage::ParameterDescription { .. } => {}
2658 BackendMessage::NoData => return Ok(Arc::from(Vec::new())),
2659 BackendMessage::NoticeResponse { .. } => {}
2660 BackendMessage::ErrorResponse { data } => {
2661 let fields = proto::parse_error_response(data);
2662 self.drain_to_ready()?;
2663 return Err(self.make_server_error(fields));
2664 }
2665 other => {
2666 return Err(DriverError::Protocol(format!(
2667 "expected RowDescription/NoData, got: {other:?}"
2668 )));
2669 }
2670 }
2671 }
2672 }
2673
2674 fn cache_stmt(&mut self, sql_hash: u64, info: StmtInfo) {
2677 if self.stmts.len() >= self.max_stmt_cache_size && !self.stmts.contains_key(&sql_hash) {
2678 if let Some((_lru_hash, evicted)) = self.stmts.evict_lru() {
2679 proto::write_close(&mut self.write_buf, b'S', evicted.name_str());
2680 }
2681 }
2682 self.stmts.insert(sql_hash, info);
2683 }
2684
2685 fn buffer_notification(&mut self, pid: i32, channel: &str, payload: &str) {
2686 if self.pending_notifications.len() < 1024 {
2687 self.pending_notifications.push(Notification {
2688 pid,
2689 channel: channel.to_owned(),
2690 payload: payload.to_owned(),
2691 });
2692 }
2693 }
2694
2695 fn shrink_buffers(&mut self) {
2696 if self.query_counter & 63 != 0 {
2700 return;
2701 }
2702 if self.read_buf.capacity() > 64 * 1024 {
2703 self.read_buf.clear();
2704 self.read_buf.shrink_to(8192);
2705 }
2706 if self.write_buf.capacity() > 16 * 1024 {
2707 self.write_buf.clear();
2708 self.write_buf.shrink_to(8192);
2709 }
2710 }
2711
2712 fn maybe_invalidate_stmt_cache(&mut self, fields: &proto::ErrorFields, sql_hash: u64) -> bool {
2713 if &*fields.code == "26000" {
2714 self.stmts.remove(&sql_hash);
2715 true
2716 } else {
2717 false
2718 }
2719 }
2720
2721 #[cold]
2722 #[inline(never)]
2723 fn make_server_error(&self, fields: proto::ErrorFields) -> DriverError {
2724 DriverError::Server {
2725 code: fields.code,
2726 message: fields.message.into_boxed_str(),
2727 detail: fields.detail.map(String::into_boxed_str),
2728 hint: fields.hint.map(String::into_boxed_str),
2729 position: fields.position,
2730 }
2731 }
2732
2733 #[cold]
2739 #[inline(never)]
2740 fn handle_non_datarow_query(
2741 &mut self,
2742 msg_type: u8,
2743 payload_start: usize,
2744 payload_end: usize,
2745 sql_hash: u64,
2746 affected_rows: &mut u64,
2747 ) -> Result<(), DriverError> {
2748 match msg_type {
2749 b'2' | b'I' => {} b'C' => {
2751 *affected_rows =
2752 proto::parse_command_tag_bytes(&self.stream_buf[payload_start..payload_end]);
2753 }
2754 b'E' => {
2755 let fields =
2756 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2757 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2758 self.drain_to_ready()?;
2759 return Err(self.make_server_error(fields));
2760 }
2761 b'A' => {
2762 let msg = proto::parse_backend_message(
2763 msg_type,
2764 &self.stream_buf[payload_start..payload_end],
2765 )?;
2766 if let BackendMessage::NotificationResponse {
2767 pid,
2768 channel,
2769 payload,
2770 } = msg
2771 {
2772 let ch = channel.to_owned();
2773 let pl = payload.to_owned();
2774 self.buffer_notification(pid, &ch, &pl);
2775 }
2776 }
2777 _ => {} }
2779 Ok(())
2780 }
2781
2782 #[cold]
2785 #[inline(never)]
2786 fn handle_non_datarow_execute(
2787 &mut self,
2788 msg_type: u8,
2789 payload_start: usize,
2790 payload_end: usize,
2791 sql_hash: u64,
2792 ) -> Result<(), DriverError> {
2793 match msg_type {
2794 b'2' | b'C' | b'I' => {} b'E' => {
2796 let fields =
2797 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2798 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2799 self.drain_to_ready()?;
2800 return Err(self.make_server_error(fields));
2801 }
2802 b'A' => {
2803 let msg = proto::parse_backend_message(
2804 msg_type,
2805 &self.stream_buf[payload_start..payload_end],
2806 )?;
2807 if let BackendMessage::NotificationResponse {
2808 pid,
2809 channel,
2810 payload,
2811 } = msg
2812 {
2813 let ch = channel.to_owned();
2814 let pl = payload.to_owned();
2815 self.buffer_notification(pid, &ch, &pl);
2816 }
2817 }
2818 _ => {} }
2820 Ok(())
2821 }
2822
2823 #[inline]
2825 fn read_one_message(&mut self) -> Result<BackendMessage<'_>, DriverError> {
2826 loop {
2827 let (msg_type, _payload_len) = self.read_message_buffered()?;
2828 if msg_type == b'A' {
2829 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2830 if let BackendMessage::NotificationResponse {
2831 pid,
2832 channel,
2833 payload,
2834 } = msg
2835 {
2836 let pid_owned = pid;
2837 let channel_owned = channel.to_owned();
2838 let payload_owned = payload.to_owned();
2839 self.buffer_notification(pid_owned, &channel_owned, &payload_owned);
2840 continue;
2841 }
2842 }
2843 return proto::parse_backend_message(msg_type, &self.read_buf);
2844 }
2845 }
2846
2847 fn expect_message(
2848 &mut self,
2849 pred: impl Fn(&BackendMessage<'_>) -> bool,
2850 ) -> Result<(), DriverError> {
2851 loop {
2852 let msg = self.read_one_message()?;
2853 if pred(&msg) {
2854 return Ok(());
2855 }
2856 match msg {
2857 BackendMessage::ErrorResponse { data } => {
2858 let fields = proto::parse_error_response(data);
2859 self.drain_to_ready()?;
2860 return Err(self.make_server_error(fields));
2861 }
2862 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2863 other => {
2864 return Err(DriverError::Protocol(format!(
2865 "unexpected message while waiting for expected type: {other:?}"
2866 )));
2867 }
2868 }
2869 }
2870 }
2871
2872 fn expect_ready(&mut self) -> Result<(), DriverError> {
2873 loop {
2874 let msg = self.read_one_message()?;
2875 match msg {
2876 BackendMessage::ReadyForQuery { status } => {
2877 self.tx_status = status;
2878 return Ok(());
2879 }
2880 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2881 BackendMessage::ErrorResponse { data } => {
2882 let fields = proto::parse_error_response(data);
2883 self.drain_to_ready()?;
2884 return Err(self.make_server_error(fields));
2885 }
2886 _ => {}
2887 }
2888 }
2889 }
2890
2891 #[inline]
2892 fn drain_to_ready(&mut self) -> Result<(), DriverError> {
2893 loop {
2894 let msg = self.read_one_message()?;
2895 if let BackendMessage::ReadyForQuery { status } = msg {
2896 self.tx_status = status;
2897 return Ok(());
2898 }
2899 }
2900 }
2901
2902 #[inline]
2906 fn flush_write(&mut self) -> Result<(), DriverError> {
2907 self.stream
2908 .write_all(&self.write_buf)
2909 .map_err(DriverError::Io)
2910 }
2911
2912 fn read_message_buffered(&mut self) -> Result<(u8, usize), DriverError> {
2916 let mut header = [0u8; 5];
2917 sync_buffered_read_exact(
2918 &mut self.stream,
2919 &mut self.stream_buf,
2920 &mut self.stream_buf_pos,
2921 &mut self.stream_buf_end,
2922 &mut header,
2923 )?;
2924
2925 let msg_type = header[0];
2926 let len = i32::from_be_bytes([header[1], header[2], header[3], header[4]]);
2927
2928 if len < 4 {
2929 return Err(DriverError::Protocol(format!(
2930 "invalid message length {len} for type '{}'",
2931 msg_type as char
2932 )));
2933 }
2934
2935 const MAX_MESSAGE_LEN: i32 = 128 * 1024 * 1024;
2936 if len > MAX_MESSAGE_LEN {
2937 return Err(DriverError::Protocol(format!(
2938 "message length {len} exceeds maximum ({MAX_MESSAGE_LEN}) for type '{}'",
2939 msg_type as char
2940 )));
2941 }
2942
2943 let payload_len = (len - 4) as usize;
2944 self.read_buf.clear();
2945 self.read_buf.resize(payload_len, 0);
2946 if payload_len > 0 {
2947 sync_buffered_read_exact(
2948 &mut self.stream,
2949 &mut self.stream_buf,
2950 &mut self.stream_buf_pos,
2951 &mut self.stream_buf_end,
2952 &mut self.read_buf[..payload_len],
2953 )?;
2954 }
2955
2956 Ok((msg_type, payload_len))
2957 }
2958
2959 #[inline]
2961 fn refill_stream_buf(&mut self) -> Result<(), DriverError> {
2962 let remaining = self.stream_buf_end - self.stream_buf_pos;
2963 if remaining > 0 && self.stream_buf_pos > 0 {
2964 self.stream_buf
2965 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
2966 }
2967 self.stream_buf_pos = 0;
2968 self.stream_buf_end = remaining;
2969
2970 let n = self
2971 .stream
2972 .read(&mut self.stream_buf[remaining..])
2973 .map_err(DriverError::Io)?;
2974 if n == 0 {
2975 return Err(DriverError::Io(std::io::Error::new(
2976 std::io::ErrorKind::UnexpectedEof,
2977 "connection closed",
2978 )));
2979 }
2980 self.stream_buf_end = remaining + n;
2981 Ok(())
2982 }
2983}
2984
2985fn sync_buffered_read_exact(
2988 stream: &mut Stream,
2989 buf: &mut [u8],
2990 pos: &mut usize,
2991 end: &mut usize,
2992 out: &mut [u8],
2993) -> Result<(), DriverError> {
2994 let mut filled = 0;
2995 while filled < out.len() {
2996 let avail = *end - *pos;
2997 if avail > 0 {
2998 let take = avail.min(out.len() - filled);
2999 out[filled..filled + take].copy_from_slice(&buf[*pos..*pos + take]);
3000 *pos += take;
3001 filled += take;
3002 } else {
3003 *pos = 0;
3004 let n = stream.read(buf).map_err(DriverError::Io)?;
3005 if n == 0 {
3006 return Err(DriverError::Io(std::io::Error::new(
3007 std::io::ErrorKind::UnexpectedEof,
3008 "connection closed",
3009 )));
3010 }
3011 *end = n;
3012 }
3013 }
3014 Ok(())
3015}
3016
3017#[inline(always)]
3027pub(crate) fn parse_data_row_into_buf(
3028 data: &[u8],
3029 buf: &mut Vec<u8>,
3030 out: &mut Vec<(usize, i32)>,
3031) -> Result<(), DriverError> {
3032 if data.len() < 2 {
3033 return Err(DriverError::Protocol("DataRow too short".into()));
3034 }
3035
3036 let num_cols = i16::from_be_bytes([data[0], data[1]]);
3037 if num_cols < 0 {
3038 return Err(DriverError::Protocol(
3039 "DataRow: negative column count".into(),
3040 ));
3041 }
3042 let num_cols = num_cols as usize;
3043
3044 let col_data = &data[2..];
3052 let base = buf.len();
3053 buf.extend_from_slice(col_data);
3054
3055 let mut pos: usize = 0;
3057 for _ in 0..num_cols {
3058 if pos + 4 > col_data.len() {
3059 return Err(DriverError::Protocol("DataRow truncated".into()));
3060 }
3061
3062 let col_len = i32::from_be_bytes([
3063 col_data[pos],
3064 col_data[pos + 1],
3065 col_data[pos + 2],
3066 col_data[pos + 3],
3067 ]);
3068 pos += 4;
3069
3070 if col_len < 0 {
3071 out.push((0, -1));
3072 } else {
3073 let len = col_len as usize;
3074 if pos + len > col_data.len() {
3075 return Err(DriverError::Protocol(
3076 "DataRow column data truncated".into(),
3077 ));
3078 }
3079 out.push((base + pos, col_len));
3081 pos += len;
3082 }
3083 }
3084
3085 Ok(())
3086}
3087
3088fn parse_data_row_flat(
3092 data: &[u8],
3093 arena: &mut Arena,
3094 out: &mut Vec<(usize, i32)>,
3095) -> Result<(), DriverError> {
3096 if data.len() < 2 {
3097 return Err(DriverError::Protocol("DataRow too short".into()));
3098 }
3099
3100 let num_cols_raw = i16::from_be_bytes([data[0], data[1]]);
3101 if num_cols_raw < 0 {
3102 return Err(DriverError::Protocol(
3103 "DataRow: negative column count".into(),
3104 ));
3105 }
3106 let num_cols = num_cols_raw as usize;
3107 out.reserve(num_cols);
3108
3109 let col_data = &data[2..];
3112 let base = arena.alloc_copy(col_data);
3113
3114 let mut pos: usize = 0;
3116 for _ in 0..num_cols {
3117 if pos + 4 > col_data.len() {
3118 return Err(DriverError::Protocol("DataRow truncated".into()));
3119 }
3120
3121 let col_len = i32::from_be_bytes([
3122 col_data[pos],
3123 col_data[pos + 1],
3124 col_data[pos + 2],
3125 col_data[pos + 3],
3126 ]);
3127 pos += 4;
3128
3129 if col_len < 0 {
3130 out.push((0, -1));
3131 } else {
3132 let len = col_len as usize;
3133 if pos + len > col_data.len() {
3134 return Err(DriverError::Protocol(
3135 "DataRow column data truncated".into(),
3136 ));
3137 }
3138 out.push((base + pos, col_len));
3140 pos += len;
3141 }
3142 }
3143
3144 Ok(())
3145}
3146
#[cfg(test)]
// `3.14` below is an arbitrary float parameter, not an approximation of PI.
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;
    use crate::types::hash_sql;

    #[test]
    fn sync_config_tcp_no_longer_rejected() {
        // Port 1 on loopback is expected to refuse the connection; the point is
        // that a TCP URL is attempted instead of being rejected up front.
        let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
        let result = Connection::connect(&config);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            !err.contains("Unix domain socket"),
            "error should NOT mention UDS requirement: {err}"
        );
    }

    #[test]
    fn sync_data_row_parsing() {
        let mut arena = Arena::new();
        let mut out = Vec::new();

        let mut data = Vec::new();
        data.extend_from_slice(&2i16.to_be_bytes());
        data.extend_from_slice(&4i32.to_be_bytes());
        data.extend_from_slice(&42i32.to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());

        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].1, 4);
        assert_eq!(out[1].1, -1);
        // The offset must point at the actual value bytes, not just carry the length.
        assert_eq!(arena.get(out[0].0, 4), &42i32.to_be_bytes()[..]);
    }

    #[test]
    fn sync_data_row_into_buf_offsets() {
        // Offsets from parse_data_row_into_buf are absolute positions in `buf`,
        // so bytes already in the buffer must shift them.
        let mut buf = vec![0xAAu8; 3];
        let mut out = Vec::new();
        let mut data = Vec::new();
        data.extend_from_slice(&2i16.to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());
        data.extend_from_slice(&2i32.to_be_bytes());
        data.extend_from_slice(b"hi");
        parse_data_row_into_buf(&data, &mut buf, &mut out).unwrap();
        assert_eq!(out, vec![(0, -1), (11, 2)]);
        assert_eq!(&buf[11..13], b"hi");
    }

    #[test]
    fn sync_data_row_empty() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = 0i16.to_be_bytes();
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 0);
    }

    #[test]
    fn sync_data_row_too_short() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = vec![0u8];
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_negative_col_count() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = (-1i16).to_be_bytes();
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_truncated() {
        // Declares two columns but only carries one.
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let mut data = Vec::new();
        data.extend_from_slice(&2i16.to_be_bytes());
        data.extend_from_slice(&4i32.to_be_bytes());
        data.extend_from_slice(&42i32.to_be_bytes());
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_col_data_truncated() {
        // Column claims 100 bytes but only 1 follows.
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let mut data = Vec::new();
        data.extend_from_slice(&1i16.to_be_bytes());
        data.extend_from_slice(&100i32.to_be_bytes());
        data.push(0);
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_connect_tcp_unreachable_port() {
        let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
        let result = Connection::connect(&config);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            !err.contains("Unix domain socket"),
            "error should NOT mention UDS: {err}"
        );
    }

    #[test]
    fn sync_connect_ip_address_attempts_tcp() {
        let config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
        let result = Connection::connect(&config);
        assert!(result.is_err());
    }

    #[test]
    fn sync_data_row_all_null() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let mut data = Vec::new();
        data.extend_from_slice(&3i16.to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 3);
        for (_, len) in &out {
            assert_eq!(*len, -1);
        }
    }

    #[test]
    fn sync_data_row_long_text() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let long_text = "a".repeat(2048);
        let text_bytes = long_text.as_bytes();
        let mut data = Vec::new();
        data.extend_from_slice(&1i16.to_be_bytes());
        data.extend_from_slice(&(text_bytes.len() as i32).to_be_bytes());
        data.extend_from_slice(text_bytes);
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].1, text_bytes.len() as i32);
        let stored = arena.get(out[0].0, out[0].1 as usize);
        assert_eq!(stored, text_bytes);
    }

    #[test]
    fn sync_data_row_empty_text() {
        // A zero-length value is distinct from NULL (-1).
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let mut data = Vec::new();
        data.extend_from_slice(&1i16.to_be_bytes());
        data.extend_from_slice(&0i32.to_be_bytes());
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].1, 0);
    }

    #[test]
    fn sync_data_row_17_columns_exceeds_smallvec() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let mut data = Vec::new();
        let num_cols: i16 = 20;
        data.extend_from_slice(&num_cols.to_be_bytes());
        for i in 0..num_cols {
            let val = (i as i32).to_be_bytes();
            data.extend_from_slice(&4i32.to_be_bytes());
            data.extend_from_slice(&val);
        }
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 20);
        for (idx, (offset, len)) in out.iter().enumerate() {
            assert_eq!(*len, 4);
            let stored = arena.get(*offset, 4);
            let val = i32::from_be_bytes([stored[0], stored[1], stored[2], stored[3]]);
            assert_eq!(val, idx as i32);
        }
    }

    #[test]
    fn sync_data_row_mixed_null_and_data() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let mut data = Vec::new();
        data.extend_from_slice(&5i16.to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());
        data.extend_from_slice(&4i32.to_be_bytes());
        data.extend_from_slice(&42i32.to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());
        data.extend_from_slice(&5i32.to_be_bytes());
        data.extend_from_slice(b"hello");

        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 5);
        assert_eq!(out[0].1, -1);
        assert_eq!(out[1].1, 4);
        assert_eq!(out[2].1, -1);
        assert_eq!(out[3].1, -1);
        assert_eq!(out[4].1, 5);
        let stored = arena.get(out[4].0, 5);
        assert_eq!(stored, b"hello");
    }

    // The `*_if_pg_available` tests below need a local server reachable through
    // a Unix socket in /tmp; run them with `cargo test -- --ignored`.

    #[test]
    #[ignore]
    fn sync_connect_uds_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let result = Connection::connect(&config);
        if let Ok(conn) = result {
            assert!(conn.pid() != 0, "pid should be nonzero");
            assert!(conn.is_idle(), "should start idle");
            assert!(!conn.is_in_transaction(), "should not be in tx");
            assert!(
                !conn.is_in_failed_transaction(),
                "should not be in failed tx"
            );
            assert_eq!(conn.stmt_cache_len(), 0, "cache should be empty");
            let _ = conn.close();
        }
    }

    #[test]
    #[ignore]
    fn sync_simple_query_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        conn.simple_query("SELECT 1").unwrap();
        assert!(conn.is_idle());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_query_with_params_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        let sql = "SELECT $1::int4 + $2::int4 AS sum";
        let hash = hash_sql(sql);
        let a: i32 = 10;
        let b: i32 = 20;
        let result = conn
            .query(
                sql,
                hash,
                &[&a as &(dyn Encode + Sync), &b as &(dyn Encode + Sync)],
            )
            .unwrap();
        assert_eq!(result.len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_execute_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        conn.simple_query("CREATE TEMP TABLE _sync_test (id int)")
            .unwrap();
        let sql = "INSERT INTO _sync_test VALUES ($1::int4)";
        let hash = hash_sql(sql);
        let val: i32 = 42;
        let affected = conn
            .execute(sql, hash, &[&val as &(dyn Encode + Sync)])
            .unwrap();
        assert_eq!(affected, 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_for_each_zero_rows_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        conn.simple_query("CREATE TEMP TABLE _sync_fe0 (id int)")
            .unwrap();
        let sql = "SELECT id FROM _sync_fe0";
        let hash = hash_sql(sql);
        let mut count = 0u32;
        conn.for_each(sql, hash, &[], |_row| {
            count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(count, 0);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_for_each_multiple_rows_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        let sql = "SELECT generate_series(1, 5)";
        let hash = hash_sql(sql);
        let mut count = 0u32;
        conn.for_each(sql, hash, &[], |_row| {
            count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(count, 5);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_prepare_only_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        let sql = "SELECT 1";
        let hash = hash_sql(sql);
        conn.prepare_only(sql, hash).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        // Preparing the same SQL again must hit the cache, not add an entry.
        conn.prepare_only(sql, hash).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_simple_query_rows_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        let rows = conn.simple_query_rows("SELECT 42 AS n").unwrap();
        assert!(!rows.is_empty());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_stmt_cache_hit_miss_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        let sql1 = "SELECT 1";
        let hash1 = hash_sql(sql1);
        conn.query(sql1, hash1, &[]).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        conn.query(sql1, hash1, &[]).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        let sql2 = "SELECT 2";
        let hash2 = hash_sql(sql2);
        conn.query(sql2, hash2, &[]).unwrap();
        assert_eq!(conn.stmt_cache_len(), 2);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_invalid_sql_error_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        let sql = "SELECTTTT INVALID GARBAGE";
        let hash = hash_sql(sql);
        let result = conn.query(sql, hash, &[]);
        assert!(result.is_err());
        // The error path must drain to ReadyForQuery, leaving the connection usable.
        assert!(conn.is_idle());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_tx_state_transitions_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        assert!(conn.is_idle());
        assert!(!conn.is_in_transaction());
        conn.simple_query("BEGIN").unwrap();
        assert!(conn.is_in_transaction());
        assert!(!conn.is_idle());
        conn.simple_query("COMMIT").unwrap();
        assert!(conn.is_idle());
        assert!(!conn.is_in_transaction());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_lru_cache_eviction_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        conn.set_max_stmt_cache_size(3);
        for i in 0..5 {
            let sql = format!("SELECT {}", i);
            let hash = hash_sql(&sql);
            conn.query(&sql, hash, &[]).unwrap();
        }
        assert!(
            conn.stmt_cache_len() <= 3,
            "cache should be capped at 3, got {}",
            conn.stmt_cache_len()
        );
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_for_each_raw_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        let sql = "SELECT generate_series(1, 3)";
        let hash = hash_sql(sql);
        let mut raw_count = 0u32;
        conn.for_each_raw(sql, hash, &[], |_raw_data| {
            raw_count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(raw_count, 3);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_query_null_params_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        let sql = "SELECT $1::int4 IS NULL AS is_null";
        let hash = hash_sql(sql);
        let val: Option<i32> = None;
        let _result = conn
            .query(sql, hash, &[&val as &(dyn Encode + Sync)])
            .unwrap();
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_query_various_param_types_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        let sql = "SELECT $1::int4, $2::int8, $3::text, $4::bool, $5::float8";
        let hash = hash_sql(sql);
        let p1: i32 = 42;
        let p2: i64 = 9999999;
        let p3: &str = "hello";
        let p4: bool = true;
        let p5: f64 = 3.14;
        let result = conn
            .query(
                sql,
                hash,
                &[
                    &p1 as &(dyn Encode + Sync),
                    &p2 as &(dyn Encode + Sync),
                    &p3 as &(dyn Encode + Sync),
                    &p4 as &(dyn Encode + Sync),
                    &p5 as &(dyn Encode + Sync),
                ],
            )
            .unwrap();
        assert_eq!(result.len(), 1);
        let _ = conn.close();
    }

    #[test]
    fn sync_shrink_threshold_values() {
        // Mirrors the read-buffer constants in `shrink_buffers`: shrinking only
        // makes sense if the trigger threshold is above the shrink target.
        let shrink = 64 * 1024usize;
        let initial = 8192usize;
        assert!(
            shrink > initial,
            "shrink threshold must exceed initial size"
        );
    }

    #[test]
    #[ignore]
    fn sync_connection_debug_format_if_pg_available() {
        // Exercises the real `Debug` impl, which needs a live connection.
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let conn = Connection::connect(&config).unwrap();
        let fmt_str = format!("{conn:?}");
        assert!(fmt_str.contains("Connection"));
        assert!(fmt_str.contains("pid"));
        assert!(fmt_str.contains("tx_status"));
        assert!(fmt_str.contains("stmt_cache_len"));
        let _ = conn.close();
    }

    #[test]
    fn sync_connect_sslmode_require_without_tls_feature() {
        let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
        config.ssl = SslMode::Require;
        let result = Connection::connect(&config);
        assert!(result.is_err());
    }

    #[test]
    fn sync_connect_sslmode_disable_attempts_tcp() {
        let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
        config.ssl = SslMode::Disable;
        let result = Connection::connect(&config);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), DriverError::Io(_)));
    }

    #[test]
    fn sync_connect_sslmode_prefer_attempts_tcp() {
        let mut config = Config::from_url("postgres://user:pass@127.0.0.1:1/db").unwrap();
        config.ssl = SslMode::Prefer;
        let result = Connection::connect(&config);
        assert!(result.is_err());
    }

    #[test]
    #[ignore]
    fn sync_streaming_basic_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();
        assert!(!conn.is_streaming());

        let sql = "SELECT generate_series(1, 10)";
        let hash = hash_sql(sql);

        let (cols, _) = conn.query_streaming_start(sql, hash, &[], 3).unwrap();
        assert!(!cols.is_empty());
        assert!(conn.is_streaming());

        let mut arena = Arena::new();
        let mut offsets = Vec::new();
        let mut total_rows = 0;

        loop {
            let has_more = conn.streaming_next_chunk(&mut arena, &mut offsets).unwrap();
            total_rows += offsets.len();
            if !has_more {
                break;
            }
            conn.streaming_send_execute(3).unwrap();
        }

        assert_eq!(total_rows, 10);
        assert!(!conn.is_streaming());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_prepare_describe_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();

        let result = conn
            .prepare_describe("SELECT $1::int4 + $2::int4 AS sum")
            .unwrap();
        assert_eq!(result.columns.len(), 1);
        assert_eq!(&*result.columns[0].name, "sum");
        assert_eq!(result.param_oids.len(), 2);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_wait_for_notification_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let mut conn = Connection::connect(&config).unwrap();

        conn.simple_query("LISTEN test_chan").unwrap();
        conn.simple_query("NOTIFY test_chan, 'hello'").unwrap();

        conn.set_read_timeout(Some(std::time::Duration::from_secs(5)))
            .unwrap();

        let (channel, payload) = conn.wait_for_notification().unwrap();
        assert_eq!(channel, "test_chan");
        assert_eq!(payload, "hello");
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_cancel_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let conn = Connection::connect(&config).unwrap();
        // Nothing is running, so the cancel may legitimately succeed or fail;
        // this only checks that issuing it does not panic.
        let result = conn.cancel();
        let _ = result;
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_server_params_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let conn = Connection::connect(&config).unwrap();
        let params = conn.server_params();
        assert!(
            !params.is_empty(),
            "server should send parameters during startup"
        );
        assert!(
            conn.parameter("server_encoding").is_some(),
            "server_encoding should be present"
        );
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_set_read_timeout_if_pg_available() {
        let config = Config::from_url("postgres://postgres@localhost/postgres?host=/tmp").unwrap();
        let conn = Connection::connect(&config).unwrap();
        conn.set_read_timeout(Some(std::time::Duration::from_secs(10)))
            .unwrap();
        conn.set_read_timeout(None).unwrap();
        let _ = conn.close();
    }
}