1use std::io::{Read, Write};
15use std::sync::Arc;
16
17use crate::DriverError;
18use crate::arena::Arena;
19use crate::auth;
20use crate::codec::Encode;
21use crate::proto::{self, BackendMessage};
22use crate::stmt_cache::{StmtCache, StmtInfo, build_bind_template, make_stmt_name};
23use crate::sync_io::Stream;
24use crate::types::{
25 ColumnDesc, Config, Notification, PgDataRow, PrepareResult, QueryResult, SimpleRow, SslMode,
26 StartupAction,
27};
28
29use std::cell::RefCell;
37
thread_local! {
    // Per-thread stack of spare response buffers. `acquire_resp_buf` pops from
    // it and `release_resp_buf` pushes onto it, so a buffer handed back on a
    // thread can be picked up again by a later acquisition on that same thread.
    // `RefCell` supplies the interior mutability a thread-local static needs.
    static RESP_BUF_POOL: RefCell<Vec<Vec<u8>>> = const { RefCell::new(Vec::new()) };
}
41
42fn acquire_resp_buf() -> Vec<u8> {
43 RESP_BUF_POOL
44 .with(|pool| pool.borrow_mut().pop())
45 .unwrap_or_default()
46}
47
48pub fn release_resp_buf(buf: Vec<u8>) {
50 RESP_BUF_POOL.with(|pool| {
51 let mut pool = pool.borrow_mut();
52 if pool.len() < 4 {
53 pool.push(buf);
54 }
55 });
56}
57
/// A single synchronous connection to a PostgreSQL server, together with its
/// I/O buffers and per-connection prepared-statement cache.
pub struct Connection {
    /// Underlying transport: a Unix socket, plain TCP, or TLS stream.
    stream: Stream,
    /// Payload of the message most recently read through the buffered
    /// message readers (parsed in place by `proto::parse_backend_message`).
    read_buf: Vec<u8>,
    /// Fixed-size receive buffer (allocated as 64 KiB at connect time) that
    /// the query/execute loops parse backend messages out of directly.
    stream_buf: Vec<u8>,
    /// Index of the first unconsumed byte in `stream_buf`.
    stream_buf_pos: usize,
    /// Index one past the last valid byte in `stream_buf`.
    stream_buf_end: usize,
    /// Staging buffer for outgoing frontend messages.
    write_buf: Vec<u8>,
    /// Prepared statements known to this connection, keyed by SQL hash.
    stmts: StmtCache,
    /// Server parameters reported via `ParameterStatus` during startup,
    /// stored as `(name, value)` pairs.
    params: Vec<(Box<str>, Box<str>)>,
    /// Backend process id from `BackendKeyData`.
    pid: i32,
    /// Backend secret key from `BackendKeyData`.
    secret: i32,
    /// Status byte from the latest `ReadyForQuery`; starts as `b'I'`.
    tx_status: u8,
    /// Initialised to the connect time; presumably refreshed on use — the
    /// updates are not visible in this part of the file.
    last_used: std::time::Instant,
    /// Starts as `false`; NOTE(review): its writers/readers are outside this
    /// part of the file — confirm semantics there.
    streaming_active: bool,
    /// Instant at which this connection object was created.
    created_at: std::time::Instant,
    /// Starts empty; presumably holds notifications received but not yet
    /// delivered to the caller — TODO confirm against the code that fills it.
    pending_notifications: Vec<Notification>,
    /// Initialised to 256; presumably the cap enforced when caching
    /// statements — enforcement is not visible in this part of the file.
    max_stmt_cache_size: usize,
    /// Monotonic counter bumped on each statement use; its value is stamped
    /// into `StmtInfo::last_used`.
    query_counter: u64,
    /// The configuration this connection was opened with.
    connect_config: Arc<Config>,
    /// Server certificate hash captured during the TLS upgrade, if any; used
    /// to decide on and perform SCRAM-SHA-256-PLUS authentication.
    tls_server_cert_hash: Option<[u8; 32]>,
}
111
112impl std::fmt::Debug for Connection {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct("Connection")
115 .field("pid", &self.pid)
116 .field("tx_status", &(self.tx_status as char))
117 .field("stmt_cache_len", &self.stmts.len())
118 .finish()
119 }
120}
121
122impl Connection {
123 pub fn connect(config: &Config) -> Result<Self, DriverError> {
135 Self::connect_arc(Arc::new(config.clone()))
136 }
137
138 pub fn connect_arc(config: Arc<Config>) -> Result<Self, DriverError> {
143 config.validate()?;
144
145 #[allow(unused_mut)]
148 let mut tls_cert_hash: Option<[u8; 32]> = None;
149
150 let stream = if config.host_is_uds() {
151 #[cfg(unix)]
153 {
154 let path = config.uds_path();
155 let unix =
156 std::os::unix::net::UnixStream::connect(&path).map_err(DriverError::Io)?;
157 Stream::Unix(unix)
158 }
159 #[cfg(not(unix))]
160 {
161 return Err(DriverError::Protocol(
162 "Unix domain sockets are not supported on this platform".into(),
163 ));
164 }
165 } else {
166 let addr = format!("{}:{}", config.host, config.port);
168 let tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
169
170 match config.ssl {
171 SslMode::Disable => {
172 tcp.set_nodelay(true).map_err(DriverError::Io)?;
173 let stream = Stream::Tcp(tcp);
174 stream.set_keepalive()?;
175 stream
176 }
177 SslMode::Prefer | SslMode::Require => {
178 #[cfg(feature = "tls")]
179 {
180 match crate::tls_sync::try_upgrade(
181 tcp,
182 &config.host,
183 config.ssl == SslMode::Require,
184 ) {
185 Ok(result) => {
186 tls_cert_hash = result.server_cert_hash;
187 let stream = Stream::Tls(Box::new(result.stream));
188 stream.set_nodelay()?;
189 stream.set_keepalive()?;
190 stream
191 }
192 Err(e) => {
193 if config.ssl == SslMode::Require {
194 return Err(e);
195 }
196 let tcp =
198 std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
199 tcp.set_nodelay(true).map_err(DriverError::Io)?;
200 let stream = Stream::Tcp(tcp);
201 stream.set_keepalive()?;
202 stream
203 }
204 }
205 }
206 #[cfg(not(feature = "tls"))]
207 {
208 if config.ssl == SslMode::Require {
209 return Err(DriverError::Protocol(
210 "sslmode=require but bsql was compiled without the 'tls' feature"
211 .into(),
212 ));
213 }
214 tcp.set_nodelay(true).map_err(DriverError::Io)?;
215 let stream = Stream::Tcp(tcp);
216 stream.set_keepalive()?;
217 stream
218 }
219 }
220 }
221 };
222
223 let mut conn = Self {
224 stream,
225 read_buf: Vec::with_capacity(8192),
226 stream_buf: vec![0u8; 65536],
227 stream_buf_pos: 0,
228 stream_buf_end: 0,
229 write_buf: Vec::with_capacity(4096),
230 stmts: StmtCache::default(),
231 params: Vec::new(),
232 pid: 0,
233 secret: 0,
234 tx_status: b'I',
235 last_used: std::time::Instant::now(),
236 streaming_active: false,
237 created_at: std::time::Instant::now(),
238 pending_notifications: Vec::new(),
239 max_stmt_cache_size: 256,
240 query_counter: 0,
241 connect_config: config.clone(),
242 tls_server_cert_hash: tls_cert_hash,
243 };
244
245 conn.startup(&config)?;
246 conn.validate_server_params()?;
247
248 if config.statement_timeout_secs > 0 {
249 conn.simple_query(&format!(
250 "SET statement_timeout = '{}s'",
251 config.statement_timeout_secs
252 ))?;
253 }
254
255 Ok(conn)
256 }
257
258 fn startup(&mut self, config: &Config) -> Result<(), DriverError> {
261 self.write_buf.clear();
262 proto::write_startup(&mut self.write_buf, &config.user, &config.database);
263 self.flush_write()?;
264
265 loop {
266 let action = self.read_startup_action()?;
267 match action {
268 StartupAction::AuthOk => {}
269 StartupAction::AuthCleartext => {
270 self.write_buf.clear();
271 let mut pw = config.password.as_bytes().to_vec();
272 pw.push(0);
273 proto::write_password(&mut self.write_buf, &pw);
274 self.flush_write()?;
275 }
276 StartupAction::AuthMd5(salt) => {
277 self.write_buf.clear();
278 let hash = auth::md5_password(&config.user, &config.password, &salt);
279 proto::write_password(&mut self.write_buf, &hash);
280 self.flush_write()?;
281 }
282 StartupAction::AuthSasl(mechanisms_data) => {
283 self.handle_scram(config, &mechanisms_data)?;
284 }
285 StartupAction::ParameterStatus(name, value) => {
286 if let Some(entry) = self.params.iter_mut().find(|(k, _)| *k == name) {
287 entry.1 = value;
288 } else {
289 self.params.push((name, value));
290 }
291 }
292 StartupAction::BackendKeyData(pid, secret) => {
293 self.pid = pid;
294 self.secret = secret;
295 }
296 StartupAction::ReadyForQuery(status) => {
297 self.tx_status = status;
298 return Ok(());
299 }
300 StartupAction::Error(msg) => {
301 return Err(DriverError::Auth(msg));
302 }
303 StartupAction::Notice => {}
304 }
305 }
306 }
307
308 fn read_startup_action(&mut self) -> Result<StartupAction, DriverError> {
309 let (msg_type, _) = self.read_message_buffered()?;
310 let payload = &self.read_buf;
311 let msg = proto::parse_backend_message(msg_type, payload)?;
312 match msg {
313 BackendMessage::AuthOk => Ok(StartupAction::AuthOk),
314 BackendMessage::AuthCleartext => Ok(StartupAction::AuthCleartext),
315 BackendMessage::AuthMd5 { salt } => Ok(StartupAction::AuthMd5(salt)),
316 BackendMessage::AuthSasl { mechanisms } => {
317 Ok(StartupAction::AuthSasl(mechanisms.to_vec()))
318 }
319 BackendMessage::ParameterStatus { name, value } => {
320 Ok(StartupAction::ParameterStatus(name.into(), value.into()))
321 }
322 BackendMessage::BackendKeyData { pid, secret } => {
323 Ok(StartupAction::BackendKeyData(pid, secret))
324 }
325 BackendMessage::ReadyForQuery { status } => Ok(StartupAction::ReadyForQuery(status)),
326 BackendMessage::ErrorResponse { data } => {
327 let fields = proto::parse_error_response(data);
328 Ok(StartupAction::Error(fields.to_string()))
329 }
330 BackendMessage::NoticeResponse { .. } => Ok(StartupAction::Notice),
331 other => Err(DriverError::Protocol(format!(
332 "unexpected message during startup: {other:?}"
333 ))),
334 }
335 }
336
337 fn handle_scram(&mut self, config: &Config, mechanisms_data: &[u8]) -> Result<(), DriverError> {
338 let mechs = auth::parse_sasl_mechanisms(mechanisms_data);
339
340 let use_plus = self.tls_server_cert_hash.is_some() && mechs.contains(&"SCRAM-SHA-256-PLUS");
343 let mechanism = if use_plus {
344 "SCRAM-SHA-256-PLUS"
345 } else {
346 "SCRAM-SHA-256"
347 };
348
349 if !mechs.contains(&mechanism) && !mechs.contains(&"SCRAM-SHA-256") {
350 return Err(DriverError::Auth(format!(
351 "server requires unsupported SASL mechanism(s): {mechs:?}"
352 )));
353 }
354
355 let cert_hash = if use_plus {
356 self.tls_server_cert_hash.as_ref()
357 } else {
358 None
359 };
360 let mut scram = auth::ScramClient::new(&config.user, &config.password, cert_hash)?;
361
362 let client_first = scram.client_first_message();
364 self.write_buf.clear();
365 proto::write_sasl_initial(&mut self.write_buf, mechanism, &client_first);
366 self.flush_write()?;
367
368 let (msg_type, _) = self.read_message_buffered()?;
370 let server_first = {
371 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
372 match msg {
373 BackendMessage::AuthSaslContinue { data } => data.to_vec(),
374 BackendMessage::ErrorResponse { data } => {
375 let fields = proto::parse_error_response(data);
376 return Err(DriverError::Auth(fields.to_string()));
377 }
378 other => {
379 return Err(DriverError::Protocol(format!(
380 "expected AuthSaslContinue, got: {other:?}"
381 )));
382 }
383 }
384 };
385
386 scram.process_server_first(&server_first)?;
387
388 let client_final = scram.client_final_message()?;
390 self.write_buf.clear();
391 proto::write_sasl_response(&mut self.write_buf, &client_final);
392 self.flush_write()?;
393
394 let (msg_type, _) = self.read_message_buffered()?;
396 {
397 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
398 match msg {
399 BackendMessage::AuthSaslFinal { data } => {
400 let data_owned = data.to_vec();
401 scram.verify_server_final(&data_owned)?;
402 }
403 BackendMessage::ErrorResponse { data } => {
404 let fields = proto::parse_error_response(data);
405 return Err(DriverError::Auth(fields.to_string()));
406 }
407 other => {
408 return Err(DriverError::Protocol(format!(
409 "expected AuthSaslFinal, got: {other:?}"
410 )));
411 }
412 }
413 }
414
415 let (msg_type, _) = self.read_message_buffered()?;
417 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
418 match msg {
419 BackendMessage::AuthOk => Ok(()),
420 BackendMessage::ErrorResponse { data } => {
421 let fields = proto::parse_error_response(data);
422 Err(DriverError::Auth(fields.to_string()))
423 }
424 other => Err(DriverError::Protocol(format!(
425 "expected AuthOk after SCRAM, got: {other:?}"
426 ))),
427 }
428 }
429
430 fn validate_server_params(&self) -> Result<(), DriverError> {
431 if let Some(encoding) = self.parameter("server_encoding") {
432 let normalized = encoding.to_uppercase();
433 if normalized != "UTF8" && normalized != "UTF-8" {
434 return Err(DriverError::Protocol(format!(
435 "server_encoding is '{encoding}', but bsql requires UTF-8."
436 )));
437 }
438 }
439 if let Some(encoding) = self.parameter("client_encoding") {
440 let normalized = encoding.to_uppercase();
441 if normalized != "UTF8" && normalized != "UTF-8" {
442 return Err(DriverError::Protocol(format!(
443 "client_encoding is '{encoding}', but bsql requires UTF-8."
444 )));
445 }
446 }
447 if let Some(idt) = self.parameter("integer_datetimes") {
448 if idt != "on" {
449 return Err(DriverError::Protocol(format!(
450 "integer_datetimes is '{idt}', but bsql requires 'on'."
451 )));
452 }
453 }
454 Ok(())
455 }
456
457 pub fn prepare_only(&mut self, sql: &str, sql_hash: u64) -> Result<(), DriverError> {
463 if self.stmts.contains_key(&sql_hash) {
464 return Ok(());
465 }
466 let name = make_stmt_name(sql_hash);
467 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
468 self.write_buf.clear();
469 proto::write_parse(&mut self.write_buf, name_s, sql, &[]);
470 proto::write_describe(&mut self.write_buf, b'S', name_s);
471 proto::write_sync(&mut self.write_buf);
472 self.flush_write()?;
473
474 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
475 let columns = self.read_column_description()?;
476 self.expect_ready()?;
477
478 self.query_counter += 1;
479 self.cache_stmt(
480 sql_hash,
481 StmtInfo {
482 name,
483 columns,
484 last_used: self.query_counter,
485 bind_template: None,
486 },
487 );
488 Ok(())
489 }
490
    /// Runs `sql` with `params` and collects every returned row into one
    /// [`QueryResult`].
    ///
    /// The request is issued through `send_pipeline`, which also yields the
    /// column descriptions. Row bytes are accumulated into a buffer taken
    /// from the thread-local response-buffer pool; `all_col_offsets` receives
    /// one `(usize, i32)` entry per column value from
    /// `parse_data_row_into_buf` (presumably offset and length — confirm
    /// against that helper).
    ///
    /// Responses are parsed straight out of `stream_buf` when a complete
    /// message is available there. A message larger than the whole buffer is
    /// instead read through `read_one_message`.
    ///
    /// # Errors
    /// * `DriverError::Protocol` for a message length below 4 or an
    ///   unexpected message type.
    /// * The server's error (after draining to ReadyForQuery) when an
    ///   `ErrorResponse` arrives.
    /// * Any I/O error raised while reading.
    #[inline]
    pub fn query(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<QueryResult, DriverError> {
        // The fourth argument asks send_pipeline for column descriptions, so
        // `None` here would be an internal bug.
        let columns = self
            .send_pipeline(sql, sql_hash, params, true, true)?
            .expect("send_pipeline(need_columns=true) must return Some");

        let num_cols = columns.len();
        // Pre-size for a handful of rows; `.max(1)` avoids a zero capacity
        // for statements that return no columns.
        let mut all_col_offsets: Vec<(usize, i32)> = Vec::with_capacity(num_cols.max(1) * 8);
        let mut affected_rows: u64 = 0;

        // A pooled buffer may still hold bytes from its previous use.
        let mut resp_buf = acquire_resp_buf();
        resp_buf.clear();

        'outer: loop {
            // Inner loop: consume every complete message currently buffered.
            loop {
                let avail = self.stream_buf_end - self.stream_buf_pos;
                if avail < 5 {
                    // Not even a full header (1 type byte + 4 length bytes).
                    break;
                }

                let msg_type = self.stream_buf[self.stream_buf_pos];
                let raw_len = i32::from_be_bytes([
                    self.stream_buf[self.stream_buf_pos + 1],
                    self.stream_buf[self.stream_buf_pos + 2],
                    self.stream_buf[self.stream_buf_pos + 3],
                    self.stream_buf[self.stream_buf_pos + 4],
                ]);

                // The length field counts itself, so anything under 4 is invalid.
                if raw_len < 4 {
                    return Err(DriverError::Protocol(format!(
                        "invalid message length {raw_len} for type '{}'",
                        msg_type as char
                    )));
                }

                let payload_len = (raw_len - 4) as usize;
                let total_msg_len = 5 + payload_len;

                if avail < total_msg_len {
                    if total_msg_len > self.stream_buf.len() {
                        // The message can never fit in stream_buf; read it
                        // through the general-purpose reader instead.
                        // NOTE(review): this relies on read_one_message
                        // consuming the bytes already sitting in stream_buf —
                        // confirm against its implementation.
                        let msg = self.read_one_message()?;
                        match msg {
                            BackendMessage::BindComplete => continue,
                            BackendMessage::DataRow { data } => {
                                parse_data_row_into_buf(data, &mut resp_buf, &mut all_col_offsets)?;
                                continue;
                            }
                            BackendMessage::CommandComplete { tag } => {
                                affected_rows = proto::parse_command_tag(tag);
                                continue;
                            }
                            BackendMessage::EmptyQuery => continue,
                            BackendMessage::ReadyForQuery { status } => {
                                self.tx_status = status;
                                break 'outer;
                            }
                            BackendMessage::NoticeResponse { .. } => continue,
                            BackendMessage::ErrorResponse { data } => {
                                let fields = proto::parse_error_response(data);
                                self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                                // Consume the rest of the response so the
                                // connection is left at ReadyForQuery.
                                self.drain_to_ready()?;
                                return Err(self.make_server_error(fields));
                            }
                            other => {
                                return Err(DriverError::Protocol(format!(
                                    "unexpected message during query: {other:?}"
                                )));
                            }
                        }
                    }
                    // Message fits in the buffer but is incomplete: refill.
                    break;
                }

                let payload_start = self.stream_buf_pos + 5;
                let payload_end = payload_start + payload_len;

                if msg_type == b'D' {
                    // DataRow: append the row bytes and record column offsets.
                    parse_data_row_into_buf(
                        &self.stream_buf[payload_start..payload_end],
                        &mut resp_buf,
                        &mut all_col_offsets,
                    )?;
                } else if msg_type == b'Z' {
                    // ReadyForQuery: record the status byte and finish.
                    if payload_len >= 1 {
                        self.tx_status = self.stream_buf[payload_start];
                    }
                    self.stream_buf_pos += total_msg_len;
                    break 'outer;
                } else {
                    self.handle_non_datarow_query(
                        msg_type,
                        payload_start,
                        payload_end,
                        sql_hash,
                        &mut affected_rows,
                    )?;
                }

                self.stream_buf_pos += total_msg_len;
            }

            self.refill_stream_buf()?;
        }

        self.shrink_buffers();

        Ok(QueryResult::from_parts_with_buf(
            all_col_offsets,
            num_cols,
            columns,
            affected_rows,
            resp_buf,
        ))
    }
634
635 #[inline]
646 pub fn execute_monolithic(
647 &mut self,
648 sql: &str,
649 sql_hash: u64,
650 params: &[&(dyn Encode + Sync)],
651 ) -> Result<u64, DriverError> {
652 self.write_buf.clear();
654
655 let info = match self.stmts.get_mut(&sql_hash) {
657 Some(info) => {
658 self.query_counter += 1;
659 info.last_used = self.query_counter;
660 info
661 }
662 None => {
663 return self.execute_with_prepare(sql, sql_hash, params);
665 }
666 };
667
668 let can_use_template = info
670 .bind_template
671 .as_ref()
672 .is_some_and(|t| t.param_slots.len() == params.len());
673
674 let mut has_exec_sync = false;
675
676 if can_use_template {
677 let tmpl = info
679 .bind_template
680 .as_ref()
681 .expect("guarded by can_use_template");
682 self.write_buf.extend_from_slice(&tmpl.bytes);
683
684 let mut template_ok = true;
685 for (i, param) in params.iter().enumerate() {
686 let (data_offset, old_len) = tmpl.param_slots[i];
687 if param.is_null() {
688 let len_offset = data_offset - 4;
689 self.write_buf[len_offset..len_offset + 4]
690 .copy_from_slice(&(-1i32).to_be_bytes());
691 } else if old_len >= 0 {
692 let end = data_offset + old_len as usize;
693 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
694 template_ok = false;
695 break;
696 }
697 } else {
698 template_ok = false;
700 break;
701 }
702 }
703
704 if template_ok {
705 has_exec_sync = true;
706 } else {
707 self.write_buf.clear();
708 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
709 info.bind_template = None;
710 }
711 } else {
712 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
713 }
714
715 if info.bind_template.is_none() && !self.write_buf.is_empty() {
717 info.bind_template = build_bind_template(&self.write_buf, params.len());
718 }
719
720 if !has_exec_sync {
721 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
722 }
723
724 self.stream
726 .write_all(&self.write_buf)
727 .map_err(DriverError::Io)?;
728
729 let mut affected_rows: u64 = 0;
731
732 'outer: loop {
733 loop {
734 let avail = self.stream_buf_end - self.stream_buf_pos;
735 if avail < 5 {
736 break; }
738
739 let msg_type = self.stream_buf[self.stream_buf_pos];
740 let raw_len = i32::from_be_bytes([
741 self.stream_buf[self.stream_buf_pos + 1],
742 self.stream_buf[self.stream_buf_pos + 2],
743 self.stream_buf[self.stream_buf_pos + 3],
744 self.stream_buf[self.stream_buf_pos + 4],
745 ]);
746
747 if raw_len < 4 {
748 return Err(DriverError::Protocol(format!(
749 "invalid message length {raw_len} for type '{}'",
750 msg_type as char
751 )));
752 }
753
754 let payload_len = (raw_len - 4) as usize;
755 let total_msg_len = 5 + payload_len;
756
757 if avail < total_msg_len {
758 if total_msg_len > self.stream_buf.len() {
759 let msg = self.read_one_message()?;
760 match msg {
761 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
762 continue;
763 }
764 BackendMessage::CommandComplete { tag } => {
765 affected_rows = proto::parse_command_tag(tag);
766 continue;
767 }
768 BackendMessage::EmptyQuery => continue,
769 BackendMessage::ReadyForQuery { status } => {
770 self.tx_status = status;
771 break 'outer;
772 }
773 BackendMessage::NoticeResponse { .. } => continue,
774 BackendMessage::ErrorResponse { data } => {
775 let fields = proto::parse_error_response(data);
776 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
777 self.drain_to_ready()?;
778 return Err(self.make_server_error(fields));
779 }
780 other => {
781 return Err(DriverError::Protocol(format!(
782 "unexpected message during execute: {other:?}"
783 )));
784 }
785 }
786 }
787 break; }
789
790 let payload_start = self.stream_buf_pos + 5;
795 let payload_end = payload_start + payload_len;
796
797 if msg_type == b'2' {
798 self.stream_buf_pos += total_msg_len;
800 continue;
801 } else if msg_type == b'C' {
802 affected_rows = proto::parse_command_tag_bytes(
804 &self.stream_buf[payload_start..payload_end],
805 );
806 } else if msg_type == b'Z' {
807 if payload_len >= 1 {
809 self.tx_status = self.stream_buf[payload_start];
810 }
811 self.stream_buf_pos += total_msg_len;
812 break 'outer;
813 } else if msg_type == b'D' || msg_type == b'I' {
814 } else {
816 self.handle_non_datarow_execute(
817 msg_type,
818 payload_start,
819 payload_end,
820 sql_hash,
821 )?;
822 }
823
824 self.stream_buf_pos += total_msg_len;
825 }
826
827 let remaining = self.stream_buf_end - self.stream_buf_pos;
829 debug_assert!(
830 remaining == 0 || self.stream_buf_pos > 0,
831 "compact called with pos=0 and remaining data"
832 );
833 if remaining > 0 {
834 self.stream_buf
835 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
836 }
837 self.stream_buf_pos = 0;
838 self.stream_buf_end = remaining;
839 let n = self
840 .stream
841 .read(&mut self.stream_buf[remaining..])
842 .map_err(DriverError::Io)?;
843 if n == 0 {
844 return Err(DriverError::Io(std::io::Error::new(
845 std::io::ErrorKind::UnexpectedEof,
846 "connection closed",
847 )));
848 }
849 self.stream_buf_end = remaining + n;
850 }
851
852 if self.query_counter & 63 == 0 {
854 if self.read_buf.capacity() > 64 * 1024 {
855 self.read_buf.clear();
856 self.read_buf.shrink_to(8192);
857 }
858 if self.write_buf.capacity() > 16 * 1024 {
859 self.write_buf.clear();
860 self.write_buf.shrink_to(8192);
861 }
862 }
863
864 Ok(affected_rows)
865 }
866
867 #[cold]
869 #[inline(never)]
870 fn execute_with_prepare(
871 &mut self,
872 sql: &str,
873 sql_hash: u64,
874 params: &[&(dyn Encode + Sync)],
875 ) -> Result<u64, DriverError> {
876 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
877
878 if params.len() > i16::MAX as usize {
879 return Err(DriverError::Protocol(format!(
880 "parameter count {} exceeds maximum {}",
881 params.len(),
882 i16::MAX
883 )));
884 }
885
886 let name = make_stmt_name(sql_hash);
887 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
888 let param_oids: smallvec::SmallVec<[u32; 8]> =
889 params.iter().map(|p| p.type_oid()).collect();
890
891 self.write_buf.clear();
892 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
893 proto::write_describe(&mut self.write_buf, b'S', name_s);
894 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
895 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
896 self.stream
897 .write_all(&self.write_buf)
898 .map_err(DriverError::Io)?;
899
900 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
901 let columns = self.read_column_description()?;
902 self.query_counter += 1;
903 self.cache_stmt(
904 sql_hash,
905 StmtInfo {
906 name,
907 columns,
908 last_used: self.query_counter,
909 bind_template: None,
910 },
911 );
912
913 let mut affected_rows: u64 = 0;
915 'outer: loop {
916 loop {
917 let avail = self.stream_buf_end - self.stream_buf_pos;
918 if avail < 5 {
919 break;
920 }
921
922 let msg_type = self.stream_buf[self.stream_buf_pos];
923 let raw_len = i32::from_be_bytes([
924 self.stream_buf[self.stream_buf_pos + 1],
925 self.stream_buf[self.stream_buf_pos + 2],
926 self.stream_buf[self.stream_buf_pos + 3],
927 self.stream_buf[self.stream_buf_pos + 4],
928 ]);
929
930 if raw_len < 4 {
931 return Err(DriverError::Protocol(format!(
932 "invalid message length {raw_len} for type '{}'",
933 msg_type as char
934 )));
935 }
936
937 let payload_len = (raw_len - 4) as usize;
938 let total_msg_len = 5 + payload_len;
939
940 if avail < total_msg_len {
941 if total_msg_len > self.stream_buf.len() {
942 let msg = self.read_one_message()?;
943 match msg {
944 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
945 continue;
946 }
947 BackendMessage::CommandComplete { tag } => {
948 affected_rows = proto::parse_command_tag(tag);
949 continue;
950 }
951 BackendMessage::EmptyQuery => continue,
952 BackendMessage::ReadyForQuery { status } => {
953 self.tx_status = status;
954 break 'outer;
955 }
956 BackendMessage::NoticeResponse { .. } => continue,
957 BackendMessage::ErrorResponse { data } => {
958 let fields = proto::parse_error_response(data);
959 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
960 self.drain_to_ready()?;
961 return Err(self.make_server_error(fields));
962 }
963 other => {
964 return Err(DriverError::Protocol(format!(
965 "unexpected message during execute: {other:?}"
966 )));
967 }
968 }
969 }
970 break;
971 }
972
973 let payload_start = self.stream_buf_pos + 5;
974 let payload_end = payload_start + payload_len;
975
976 if msg_type == b'2' || msg_type == b'D' || msg_type == b'I' {
977 } else if msg_type == b'C' {
979 affected_rows = proto::parse_command_tag_bytes(
980 &self.stream_buf[payload_start..payload_end],
981 );
982 } else if msg_type == b'Z' {
983 if payload_len >= 1 {
984 self.tx_status = self.stream_buf[payload_start];
985 }
986 self.stream_buf_pos += total_msg_len;
987 break 'outer;
988 } else {
989 self.handle_non_datarow_execute(
990 msg_type,
991 payload_start,
992 payload_end,
993 sql_hash,
994 )?;
995 }
996
997 self.stream_buf_pos += total_msg_len;
998 }
999
1000 self.refill_stream_buf()?;
1001 }
1002
1003 Ok(affected_rows)
1004 }
1005
1006 #[inline]
1011 pub fn execute(
1012 &mut self,
1013 sql: &str,
1014 sql_hash: u64,
1015 params: &[&(dyn Encode + Sync)],
1016 ) -> Result<u64, DriverError> {
1017 self.execute_monolithic(sql, sql_hash, params)
1018 }
1019
1020 pub fn execute_pipeline(
1032 &mut self,
1033 sql: &str,
1034 sql_hash: u64,
1035 param_sets: &[&[&(dyn Encode + Sync)]],
1036 ) -> Result<Vec<u64>, DriverError> {
1037 if param_sets.is_empty() {
1038 return Ok(Vec::new());
1039 }
1040
1041 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1042
1043 self.write_buf.clear();
1044
1045 if !self.stmts.contains_key(&sql_hash) {
1047 let name = make_stmt_name(sql_hash);
1048 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
1049 let first_params = param_sets[0];
1050 if first_params.len() > i16::MAX as usize {
1051 return Err(DriverError::Protocol(format!(
1052 "parameter count {} exceeds maximum {}",
1053 first_params.len(),
1054 i16::MAX
1055 )));
1056 }
1057 let param_oids: smallvec::SmallVec<[u32; 8]> =
1058 first_params.iter().map(|p| p.type_oid()).collect();
1059 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
1060 proto::write_describe(&mut self.write_buf, b'S', name_s);
1061 proto::write_sync(&mut self.write_buf);
1062 self.flush_write()?;
1063
1064 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1065 let columns = self.read_column_description()?;
1066 self.expect_ready()?;
1067
1068 self.query_counter += 1;
1069 self.cache_stmt(
1070 sql_hash,
1071 StmtInfo {
1072 name,
1073 columns,
1074 last_used: self.query_counter,
1075 bind_template: None,
1076 },
1077 );
1078
1079 self.write_buf.clear();
1080 }
1081
1082 let stmt_name = self
1084 .stmts
1085 .get(&sql_hash)
1086 .expect("BUG: stmt just cached but not found")
1087 .name_str()
1088 .to_owned();
1089 let count = param_sets.len();
1090
1091 for params in param_sets {
1092 if params.len() > i16::MAX as usize {
1093 return Err(DriverError::Protocol(format!(
1094 "parameter count {} exceeds maximum {}",
1095 params.len(),
1096 i16::MAX
1097 )));
1098 }
1099 proto::write_bind_params(&mut self.write_buf, "", &stmt_name, params);
1100 self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
1101 }
1102
1103 self.write_buf.extend_from_slice(proto::SYNC_ONLY);
1104 self.flush_write()?;
1105
1106 let mut results = Vec::with_capacity(count);
1108 for _ in 0..count {
1109 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1110
1111 let mut affected_rows: u64 = 0;
1112 loop {
1113 let msg = self.read_one_message()?;
1114 match msg {
1115 BackendMessage::DataRow { .. } => {}
1116 BackendMessage::CommandComplete { tag } => {
1117 affected_rows = proto::parse_command_tag(tag);
1118 break;
1119 }
1120 BackendMessage::EmptyQuery => break,
1121 BackendMessage::NoticeResponse { .. } => {}
1122 BackendMessage::ErrorResponse { data } => {
1123 let fields = proto::parse_error_response(data);
1124 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1125 self.drain_to_ready()?;
1126 return Err(self.make_server_error(fields));
1127 }
1128 other => {
1129 return Err(DriverError::Protocol(format!(
1130 "unexpected message during execute_pipeline: {other:?}"
1131 )));
1132 }
1133 }
1134 }
1135 results.push(affected_rows);
1136 }
1137
1138 self.expect_ready()?;
1139 self.shrink_buffers();
1140 Ok(results)
1141 }
1142
1143 pub(crate) fn ensure_stmt_prepared(
1149 &mut self,
1150 sql: &str,
1151 sql_hash: u64,
1152 params: &[&(dyn Encode + Sync)],
1153 ) -> Result<[u8; 18], DriverError> {
1154 if let Some(info) = self.stmts.get(&sql_hash) {
1155 return Ok(info.name);
1156 }
1157
1158 let name = make_stmt_name(sql_hash);
1159 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
1160 if params.len() > i16::MAX as usize {
1161 return Err(DriverError::Protocol(format!(
1162 "parameter count {} exceeds maximum {}",
1163 params.len(),
1164 i16::MAX
1165 )));
1166 }
1167 let param_oids: smallvec::SmallVec<[u32; 8]> =
1168 params.iter().map(|p| p.type_oid()).collect();
1169
1170 self.write_buf.clear();
1171 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
1172 proto::write_describe(&mut self.write_buf, b'S', name_s);
1173 proto::write_sync(&mut self.write_buf);
1174 self.flush_write()?;
1175
1176 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1177 let columns = self.read_column_description()?;
1178 self.expect_ready()?;
1179
1180 self.query_counter += 1;
1181 self.cache_stmt(
1182 sql_hash,
1183 StmtInfo {
1184 name,
1185 columns,
1186 last_used: self.query_counter,
1187 bind_template: None,
1188 },
1189 );
1190
1191 Ok(name)
1192 }
1193
1194 pub(crate) fn write_deferred_bind_execute(
1197 &self,
1198 sql_hash: u64,
1199 params: &[&(dyn Encode + Sync)],
1200 buf: &mut Vec<u8>,
1201 ) {
1202 let stmt_name = self
1203 .stmts
1204 .get(&sql_hash)
1205 .expect("BUG: stmt just cached but not found")
1206 .name_str();
1207 proto::write_bind_params(buf, "", stmt_name, params);
1208 buf.extend_from_slice(proto::EXECUTE_ONLY);
1209 }
1210
1211 pub(crate) fn flush_deferred_pipeline(
1216 &mut self,
1217 buf: &mut Vec<u8>,
1218 count: usize,
1219 ) -> Result<Vec<u64>, DriverError> {
1220 if count == 0 {
1221 buf.clear();
1222 return Ok(Vec::new());
1223 }
1224
1225 buf.extend_from_slice(proto::SYNC_ONLY);
1226
1227 self.stream.write_all(buf).map_err(DriverError::Io)?;
1228 buf.clear();
1229
1230 let mut results = Vec::with_capacity(count);
1231 for _ in 0..count {
1232 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1233
1234 let mut affected_rows: u64 = 0;
1235 loop {
1236 let msg = self.read_one_message()?;
1237 match msg {
1238 BackendMessage::DataRow { .. } => {}
1239 BackendMessage::CommandComplete { tag } => {
1240 affected_rows = proto::parse_command_tag(tag);
1241 break;
1242 }
1243 BackendMessage::EmptyQuery => break,
1244 BackendMessage::NoticeResponse { .. } => {}
1245 BackendMessage::ErrorResponse { data } => {
1246 let fields = proto::parse_error_response(data);
1247 self.drain_to_ready()?;
1248 return Err(self.make_server_error(fields));
1249 }
1250 other => {
1251 return Err(DriverError::Protocol(format!(
1252 "unexpected message during flush_deferred_pipeline: {other:?}"
1253 )));
1254 }
1255 }
1256 }
1257 results.push(affected_rows);
1258 }
1259
1260 self.expect_ready()?;
1261 self.shrink_buffers();
1262 Ok(results)
1263 }
1264
    /// Runs `sql` with `params` and invokes `f` once per returned row,
    /// without collecting the rows.
    ///
    /// Each row handed to `f` borrows the connection's receive buffer, so it
    /// is only valid for the duration of that call. Responses are parsed
    /// straight out of `stream_buf` when a complete message is available
    /// there; a message larger than the whole buffer is read through
    /// `read_one_message` instead.
    ///
    /// # Errors
    /// * Any error returned by `f` or by `PgDataRow::new`.
    /// * `DriverError::Protocol` for a message length below 4 or an
    ///   unexpected message type.
    /// * The server's error (after draining to ReadyForQuery) when an
    ///   `ErrorResponse` arrives.
    ///
    /// NOTE(review): when `f` (or `PgDataRow::new`) fails, this returns
    /// immediately without draining to ReadyForQuery, so the rest of the
    /// response is left unread. Confirm that callers discard or resynchronise
    /// the connection in that case.
    pub fn for_each<F>(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        mut f: F,
    ) -> Result<(), DriverError>
    where
        F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
    {
        // Column descriptions are not requested here (fourth argument is
        // `false`), so the returned value is ignored.
        let _ = self.send_pipeline(sql, sql_hash, params, false, true)?;

        'outer: loop {
            // Inner loop: consume every complete message currently buffered.
            loop {
                let avail = self.stream_buf_end - self.stream_buf_pos;
                if avail < 5 {
                    // Not even a full header (1 type byte + 4 length bytes).
                    break;
                }

                let msg_type = self.stream_buf[self.stream_buf_pos];
                let raw_len = i32::from_be_bytes([
                    self.stream_buf[self.stream_buf_pos + 1],
                    self.stream_buf[self.stream_buf_pos + 2],
                    self.stream_buf[self.stream_buf_pos + 3],
                    self.stream_buf[self.stream_buf_pos + 4],
                ]);

                // The length field counts itself, so anything under 4 is invalid.
                if raw_len < 4 {
                    return Err(DriverError::Protocol(format!(
                        "invalid message length {raw_len} for type '{}'",
                        msg_type as char
                    )));
                }

                let payload_len = (raw_len - 4) as usize;
                let total_msg_len = 5 + payload_len;

                if avail < total_msg_len {
                    if total_msg_len > self.stream_buf.len() {
                        // The message can never fit in stream_buf; read it
                        // through the general-purpose reader instead.
                        let msg = self.read_one_message()?;
                        match msg {
                            BackendMessage::BindComplete => continue,
                            BackendMessage::DataRow { data } => {
                                let row = PgDataRow::new(data)?;
                                f(row)?;
                                continue;
                            }
                            BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
                                continue;
                            }
                            BackendMessage::ReadyForQuery { status } => {
                                self.tx_status = status;
                                break 'outer;
                            }
                            BackendMessage::NoticeResponse { .. } => continue,
                            BackendMessage::ErrorResponse { data } => {
                                let fields = proto::parse_error_response(data);
                                self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                                // Consume the rest of the response so the
                                // connection is left at ReadyForQuery.
                                self.drain_to_ready()?;
                                return Err(self.make_server_error(fields));
                            }
                            other => {
                                return Err(DriverError::Protocol(format!(
                                    "unexpected message during for_each: {other:?}"
                                )));
                            }
                        }
                    }
                    // Message fits in the buffer but is incomplete: refill.
                    break;
                }

                let payload_start = self.stream_buf_pos + 5;
                let payload_end = payload_start + payload_len;

                if msg_type == b'D' {
                    // DataRow: hand a view over the buffered bytes to the callback.
                    let row = PgDataRow::new(&self.stream_buf[payload_start..payload_end])?;
                    f(row)?;
                } else if msg_type == b'Z' {
                    // ReadyForQuery: record the status byte and finish.
                    if payload_len >= 1 {
                        self.tx_status = self.stream_buf[payload_start];
                    }
                    self.stream_buf_pos += total_msg_len;
                    break 'outer;
                } else {
                    self.handle_non_datarow_execute(
                        msg_type,
                        payload_start,
                        payload_end,
                        sql_hash,
                    )?;
                }

                self.stream_buf_pos += total_msg_len;
            }

            self.refill_stream_buf()?;
        }

        self.shrink_buffers();
        Ok(())
    }
1375
1376 #[inline]
1387 pub fn for_each_raw_monolithic<F>(
1388 &mut self,
1389 sql: &str,
1390 sql_hash: u64,
1391 params: &[&(dyn Encode + Sync)],
1392 mut f: F,
1393 ) -> Result<(), DriverError>
1394 where
1395 F: FnMut(&[u8]) -> Result<(), DriverError>,
1396 {
1397 self.write_buf.clear();
1399
1400 let info = match self.stmts.get_mut(&sql_hash) {
1402 Some(info) => {
1403 self.query_counter += 1;
1404 info.last_used = self.query_counter;
1405 info
1406 }
1407 None => {
1408 return self.for_each_raw_with_prepare(sql, sql_hash, params, f);
1410 }
1411 };
1412
1413 let can_use_template = info
1415 .bind_template
1416 .as_ref()
1417 .is_some_and(|t| t.param_slots.len() == params.len());
1418
1419 let mut has_exec_sync = false;
1420
1421 if can_use_template {
1422 let tmpl = info
1424 .bind_template
1425 .as_ref()
1426 .expect("guarded by can_use_template");
1427 self.write_buf.extend_from_slice(&tmpl.bytes);
1428
1429 let mut template_ok = true;
1430 for (i, param) in params.iter().enumerate() {
1431 let (data_offset, old_len) = tmpl.param_slots[i];
1432 if param.is_null() {
1433 let len_offset = data_offset - 4;
1434 self.write_buf[len_offset..len_offset + 4]
1435 .copy_from_slice(&(-1i32).to_be_bytes());
1436 } else if old_len >= 0 {
1437 let end = data_offset + old_len as usize;
1438 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
1439 template_ok = false;
1440 break;
1441 }
1442 } else {
1443 template_ok = false;
1444 break;
1445 }
1446 }
1447
1448 if template_ok {
1449 has_exec_sync = true;
1450 } else {
1451 self.write_buf.clear();
1452 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
1453 info.bind_template = None;
1454 }
1455 } else {
1456 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
1457 }
1458
1459 if info.bind_template.is_none() && !self.write_buf.is_empty() {
1461 info.bind_template = build_bind_template(&self.write_buf, params.len());
1462 }
1463
1464 if !has_exec_sync {
1465 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1466 }
1467
1468 self.stream
1470 .write_all(&self.write_buf)
1471 .map_err(DriverError::Io)?;
1472
1473 loop {
1477 let avail = self.stream_buf_end - self.stream_buf_pos;
1478 if avail >= 5 {
1479 let bc_type = self.stream_buf[self.stream_buf_pos];
1480 match bc_type {
1481 b'2' => {
1482 self.stream_buf_pos += 5;
1483 break;
1484 }
1485 b'E' => {
1486 let msg = self.read_one_message()?;
1487 if let BackendMessage::ErrorResponse { data } = msg {
1488 let fields = proto::parse_error_response(data);
1489 self.drain_to_ready()?;
1490 return Err(self.make_server_error(fields));
1491 }
1492 }
1493 b'N' | b'S' => {
1494 let raw_len = i32::from_be_bytes([
1495 self.stream_buf[self.stream_buf_pos + 1],
1496 self.stream_buf[self.stream_buf_pos + 2],
1497 self.stream_buf[self.stream_buf_pos + 3],
1498 self.stream_buf[self.stream_buf_pos + 4],
1499 ]);
1500 let total = 1 + raw_len as usize;
1501 if avail >= total {
1502 self.stream_buf_pos += total;
1503 continue;
1504 }
1505 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1506 break;
1507 }
1508 _ => {
1509 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1510 break;
1511 }
1512 }
1513 } else {
1514 let remaining = self.stream_buf_end - self.stream_buf_pos;
1516 if remaining > 0 && self.stream_buf_pos > 0 {
1517 self.stream_buf
1518 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1519 }
1520 self.stream_buf_pos = 0;
1521 self.stream_buf_end = remaining;
1522 let n = self
1523 .stream
1524 .read(&mut self.stream_buf[remaining..])
1525 .map_err(DriverError::Io)?;
1526 if n == 0 {
1527 return Err(DriverError::Io(std::io::Error::new(
1528 std::io::ErrorKind::UnexpectedEof,
1529 "connection closed",
1530 )));
1531 }
1532 self.stream_buf_end = remaining + n;
1533 }
1534 }
1535
1536 'outer: loop {
1538 loop {
1539 let avail = self.stream_buf_end - self.stream_buf_pos;
1540 if avail < 5 {
1541 break;
1542 }
1543
1544 let msg_type = self.stream_buf[self.stream_buf_pos];
1545 let raw_len = i32::from_be_bytes([
1546 self.stream_buf[self.stream_buf_pos + 1],
1547 self.stream_buf[self.stream_buf_pos + 2],
1548 self.stream_buf[self.stream_buf_pos + 3],
1549 self.stream_buf[self.stream_buf_pos + 4],
1550 ]);
1551
1552 if raw_len < 4 {
1553 return Err(DriverError::Protocol(format!(
1554 "invalid message length {raw_len} for type '{}'",
1555 msg_type as char
1556 )));
1557 }
1558
1559 let payload_len = (raw_len - 4) as usize;
1560 let total_msg_len = 5 + payload_len;
1561
1562 if avail < total_msg_len {
1563 if total_msg_len > self.stream_buf.len() {
1564 let msg = self.read_one_message()?;
1565 match msg {
1566 BackendMessage::DataRow { data } => {
1567 f(data)?;
1568 continue;
1569 }
1570 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1571 continue;
1572 }
1573 BackendMessage::ReadyForQuery { status } => {
1574 self.tx_status = status;
1575 break 'outer;
1576 }
1577 BackendMessage::ErrorResponse { data } => {
1578 let fields = proto::parse_error_response(data);
1579 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1580 self.drain_to_ready()?;
1581 return Err(self.make_server_error(fields));
1582 }
1583 BackendMessage::NoticeResponse { .. } => continue,
1584 other => {
1585 return Err(DriverError::Protocol(format!(
1586 "unexpected message during for_each_raw: {other:?}"
1587 )));
1588 }
1589 }
1590 }
1591 break; }
1593
1594 let payload_start = self.stream_buf_pos + 5;
1596 let payload_end = payload_start + payload_len;
1597
1598 if msg_type == b'D' {
1599 f(&self.stream_buf[payload_start..payload_end])?;
1600 } else if msg_type == b'Z' {
1601 if payload_len >= 1 {
1602 self.tx_status = self.stream_buf[payload_start];
1603 }
1604 self.stream_buf_pos += total_msg_len;
1605 break 'outer;
1606 } else {
1607 self.handle_non_datarow_execute(
1608 msg_type,
1609 payload_start,
1610 payload_end,
1611 sql_hash,
1612 )?;
1613 }
1614
1615 self.stream_buf_pos += total_msg_len;
1616 }
1617
1618 let remaining = self.stream_buf_end - self.stream_buf_pos;
1620 if remaining > 0 && self.stream_buf_pos > 0 {
1621 self.stream_buf
1622 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1623 }
1624 self.stream_buf_pos = 0;
1625 self.stream_buf_end = remaining;
1626 let n = self
1627 .stream
1628 .read(&mut self.stream_buf[remaining..])
1629 .map_err(DriverError::Io)?;
1630 if n == 0 {
1631 return Err(DriverError::Io(std::io::Error::new(
1632 std::io::ErrorKind::UnexpectedEof,
1633 "connection closed",
1634 )));
1635 }
1636 self.stream_buf_end = remaining + n;
1637 }
1638
1639 if self.query_counter & 63 == 0 {
1641 if self.read_buf.capacity() > 64 * 1024 {
1642 self.read_buf.clear();
1643 self.read_buf.shrink_to(8192);
1644 }
1645 if self.write_buf.capacity() > 16 * 1024 {
1646 self.write_buf.clear();
1647 self.write_buf.shrink_to(8192);
1648 }
1649 }
1650
1651 Ok(())
1652 }
1653
1654 #[cold]
1656 #[inline(never)]
1657 fn for_each_raw_with_prepare<F>(
1658 &mut self,
1659 sql: &str,
1660 sql_hash: u64,
1661 params: &[&(dyn Encode + Sync)],
1662 mut f: F,
1663 ) -> Result<(), DriverError>
1664 where
1665 F: FnMut(&[u8]) -> Result<(), DriverError>,
1666 {
1667 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1668
1669 if params.len() > i16::MAX as usize {
1670 return Err(DriverError::Protocol(format!(
1671 "parameter count {} exceeds maximum {}",
1672 params.len(),
1673 i16::MAX
1674 )));
1675 }
1676
1677 let name = make_stmt_name(sql_hash);
1678 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
1679 let param_oids: smallvec::SmallVec<[u32; 8]> =
1680 params.iter().map(|p| p.type_oid()).collect();
1681
1682 self.write_buf.clear();
1683 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
1684 proto::write_describe(&mut self.write_buf, b'S', name_s);
1685 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
1686 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1687 self.stream
1688 .write_all(&self.write_buf)
1689 .map_err(DriverError::Io)?;
1690
1691 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1692 let columns = self.read_column_description()?;
1693 self.query_counter += 1;
1694 self.cache_stmt(
1695 sql_hash,
1696 StmtInfo {
1697 name,
1698 columns,
1699 last_used: self.query_counter,
1700 bind_template: None,
1701 },
1702 );
1703
1704 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1706
1707 'outer: loop {
1708 loop {
1709 let avail = self.stream_buf_end - self.stream_buf_pos;
1710 if avail < 5 {
1711 break;
1712 }
1713
1714 let msg_type = self.stream_buf[self.stream_buf_pos];
1715 let raw_len = i32::from_be_bytes([
1716 self.stream_buf[self.stream_buf_pos + 1],
1717 self.stream_buf[self.stream_buf_pos + 2],
1718 self.stream_buf[self.stream_buf_pos + 3],
1719 self.stream_buf[self.stream_buf_pos + 4],
1720 ]);
1721
1722 if raw_len < 4 {
1723 return Err(DriverError::Protocol(format!(
1724 "invalid message length {raw_len} for type '{}'",
1725 msg_type as char
1726 )));
1727 }
1728
1729 let payload_len = (raw_len - 4) as usize;
1730 let total_msg_len = 5 + payload_len;
1731
1732 if avail < total_msg_len {
1733 if total_msg_len > self.stream_buf.len() {
1734 let msg = self.read_one_message()?;
1735 match msg {
1736 BackendMessage::DataRow { data } => {
1737 f(data)?;
1738 continue;
1739 }
1740 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1741 continue;
1742 }
1743 BackendMessage::ReadyForQuery { status } => {
1744 self.tx_status = status;
1745 break 'outer;
1746 }
1747 BackendMessage::ErrorResponse { data } => {
1748 let fields = proto::parse_error_response(data);
1749 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1750 self.drain_to_ready()?;
1751 return Err(self.make_server_error(fields));
1752 }
1753 BackendMessage::NoticeResponse { .. } => continue,
1754 other => {
1755 return Err(DriverError::Protocol(format!(
1756 "unexpected message during for_each_raw: {other:?}"
1757 )));
1758 }
1759 }
1760 }
1761 break;
1762 }
1763
1764 let payload_start = self.stream_buf_pos + 5;
1765 let payload_end = payload_start + payload_len;
1766
1767 if msg_type == b'D' {
1768 f(&self.stream_buf[payload_start..payload_end])?;
1769 } else if msg_type == b'Z' {
1770 if payload_len >= 1 {
1771 self.tx_status = self.stream_buf[payload_start];
1772 }
1773 self.stream_buf_pos += total_msg_len;
1774 break 'outer;
1775 } else {
1776 self.handle_non_datarow_execute(
1777 msg_type,
1778 payload_start,
1779 payload_end,
1780 sql_hash,
1781 )?;
1782 }
1783
1784 self.stream_buf_pos += total_msg_len;
1785 }
1786
1787 self.refill_stream_buf()?;
1788 }
1789
1790 self.shrink_buffers();
1791 Ok(())
1792 }
1793
1794 #[inline]
1799 pub fn for_each_raw<F>(
1800 &mut self,
1801 sql: &str,
1802 sql_hash: u64,
1803 params: &[&(dyn Encode + Sync)],
1804 f: F,
1805 ) -> Result<(), DriverError>
1806 where
1807 F: FnMut(&[u8]) -> Result<(), DriverError>,
1808 {
1809 self.for_each_raw_monolithic(sql, sql_hash, params, f)
1810 }
1811
1812 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1814 self.write_buf.clear();
1815 proto::write_simple_query(&mut self.write_buf, sql);
1816 self.flush_write()?;
1817
1818 loop {
1819 let msg = self.read_one_message()?;
1820 match msg {
1821 BackendMessage::ReadyForQuery { status } => {
1822 self.tx_status = status;
1823 return Ok(());
1824 }
1825 BackendMessage::CommandComplete { .. }
1826 | BackendMessage::RowDescription { .. }
1827 | BackendMessage::DataRow { .. }
1828 | BackendMessage::EmptyQuery
1829 | BackendMessage::NoticeResponse { .. }
1830 | BackendMessage::ParameterStatus { .. }
1831 | BackendMessage::AuthOk
1835 | BackendMessage::AuthSaslFinal { .. }
1836 | BackendMessage::BackendKeyData { .. } => {}
1837 BackendMessage::ErrorResponse { data } => {
1838 let fields = proto::parse_error_response(data);
1839 self.drain_to_ready()?;
1840 return Err(self.make_server_error(fields));
1841 }
1842 other => {
1843 return Err(DriverError::Protocol(format!(
1844 "unexpected message during simple_query: {other:?}"
1845 )));
1846 }
1847 }
1848 }
1849 }
1850
1851 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
1853 self.write_buf.clear();
1854 proto::write_simple_query(&mut self.write_buf, sql);
1855 self.flush_write()?;
1856
1857 let mut rows: Vec<SimpleRow> = Vec::new();
1858 loop {
1859 let msg = self.read_one_message()?;
1860 match msg {
1861 BackendMessage::ReadyForQuery { status } => {
1862 self.tx_status = status;
1863 return Ok(rows);
1864 }
1865 BackendMessage::DataRow { data } => {
1866 rows.push(proto::parse_simple_data_row(data)?);
1867 }
1868 BackendMessage::RowDescription { .. }
1869 | BackendMessage::CommandComplete { .. }
1870 | BackendMessage::EmptyQuery
1871 | BackendMessage::NoticeResponse { .. }
1872 | BackendMessage::ParameterStatus { .. }
1873 | BackendMessage::AuthOk
1874 | BackendMessage::AuthSaslFinal { .. }
1875 | BackendMessage::BackendKeyData { .. } => {}
1876 BackendMessage::ErrorResponse { data } => {
1877 let fields = proto::parse_error_response(data);
1878 self.drain_to_ready()?;
1879 return Err(self.make_server_error(fields));
1880 }
1881 other => {
1882 return Err(DriverError::Protocol(format!(
1883 "unexpected message during simple_query_rows: {other:?}"
1884 )));
1885 }
1886 }
1887 }
1888 }
1889
1890 pub fn copy_in<'a, I>(
1912 &mut self,
1913 table: &str,
1914 columns: &[&str],
1915 rows: I,
1916 ) -> Result<u64, DriverError>
1917 where
1918 I: IntoIterator<Item = &'a str>,
1919 {
1920 let quoted_table = proto::quote_ident(table);
1922 let quoted_cols: Vec<String> = columns.iter().map(|c| proto::quote_ident(c)).collect();
1923 let sql = format!(
1924 "COPY {}({}) FROM STDIN",
1925 quoted_table,
1926 quoted_cols.join(",")
1927 );
1928
1929 self.write_buf.clear();
1931 proto::write_simple_query(&mut self.write_buf, &sql);
1932 self.flush_write()?;
1933
1934 loop {
1936 let msg = self.read_one_message()?;
1937 match msg {
1938 BackendMessage::CopyInResponse { .. } => break,
1939 BackendMessage::ErrorResponse { data } => {
1940 let fields = proto::parse_error_response(data);
1941 self.drain_to_ready()?;
1942 return Err(self.make_server_error(fields));
1943 }
1944 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
1945 other => {
1946 return Err(DriverError::Protocol(format!(
1947 "expected CopyInResponse, got: {other:?}"
1948 )));
1949 }
1950 }
1951 }
1952
1953 for row in rows {
1955 self.write_buf.clear();
1956 let mut row_bytes = Vec::with_capacity(row.len() + 1);
1958 row_bytes.extend_from_slice(row.as_bytes());
1959 row_bytes.push(b'\n');
1960 proto::write_copy_data(&mut self.write_buf, &row_bytes);
1961 self.flush_write()?;
1962 }
1963
1964 self.write_buf.clear();
1966 proto::write_copy_done(&mut self.write_buf);
1967 self.flush_write()?;
1968
1969 let mut count: u64 = 0;
1971 loop {
1972 let msg = self.read_one_message()?;
1973 match msg {
1974 BackendMessage::CommandComplete { tag } => {
1975 count = proto::parse_command_tag(tag);
1976 }
1977 BackendMessage::ReadyForQuery { status } => {
1978 self.tx_status = status;
1979 return Ok(count);
1980 }
1981 BackendMessage::ErrorResponse { data } => {
1982 let fields = proto::parse_error_response(data);
1983 self.drain_to_ready()?;
1984 return Err(self.make_server_error(fields));
1985 }
1986 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
1987 other => {
1988 return Err(DriverError::Protocol(format!(
1989 "unexpected message during copy_in completion: {other:?}"
1990 )));
1991 }
1992 }
1993 }
1994 }
1995
1996 pub fn copy_out<W: std::io::Write>(
2016 &mut self,
2017 query: &str,
2018 writer: &mut W,
2019 ) -> Result<u64, DriverError> {
2020 let sql = format!("COPY ({query}) TO STDOUT");
2022
2023 self.write_buf.clear();
2025 proto::write_simple_query(&mut self.write_buf, &sql);
2026 self.flush_write()?;
2027
2028 loop {
2030 let msg = self.read_one_message()?;
2031 match msg {
2032 BackendMessage::CopyOutResponse { .. } => break,
2033 BackendMessage::ErrorResponse { data } => {
2034 let fields = proto::parse_error_response(data);
2035 self.drain_to_ready()?;
2036 return Err(self.make_server_error(fields));
2037 }
2038 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2039 other => {
2040 return Err(DriverError::Protocol(format!(
2041 "expected CopyOutResponse, got: {other:?}"
2042 )));
2043 }
2044 }
2045 }
2046
2047 loop {
2049 let msg = self.read_one_message()?;
2050 match msg {
2051 BackendMessage::CopyData { data } => {
2052 writer.write_all(&data).map_err(DriverError::Io)?;
2053 }
2054 BackendMessage::CopyDone => break,
2055 BackendMessage::ErrorResponse { data } => {
2056 let fields = proto::parse_error_response(data);
2057 self.drain_to_ready()?;
2058 return Err(self.make_server_error(fields));
2059 }
2060 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2061 other => {
2062 return Err(DriverError::Protocol(format!(
2063 "unexpected message during copy_out data: {other:?}"
2064 )));
2065 }
2066 }
2067 }
2068
2069 let mut count: u64 = 0;
2071 loop {
2072 let msg = self.read_one_message()?;
2073 match msg {
2074 BackendMessage::CommandComplete { tag } => {
2075 count = proto::parse_command_tag(tag);
2076 }
2077 BackendMessage::ReadyForQuery { status } => {
2078 self.tx_status = status;
2079 return Ok(count);
2080 }
2081 BackendMessage::ErrorResponse { data } => {
2082 let fields = proto::parse_error_response(data);
2083 self.drain_to_ready()?;
2084 return Err(self.make_server_error(fields));
2085 }
2086 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2087 other => {
2088 return Err(DriverError::Protocol(format!(
2089 "unexpected message during copy_out completion: {other:?}"
2090 )));
2091 }
2092 }
2093 }
2094 }
2095
2096 pub fn prepare_describe(&mut self, sql: &str) -> Result<PrepareResult, DriverError> {
2101 self.write_buf.clear();
2102 proto::write_parse(&mut self.write_buf, "", sql, &[]);
2105 proto::write_describe(&mut self.write_buf, b'S', "");
2106 proto::write_sync(&mut self.write_buf);
2107 self.flush_write()?;
2108
2109 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2111
2112 let mut param_oids: Vec<u32> = Vec::new();
2114 let columns;
2115 loop {
2116 let msg = self.read_one_message()?;
2117 match msg {
2118 BackendMessage::ParameterDescription { data } => {
2119 param_oids = proto::parse_parameter_description(data)?;
2120 }
2121 BackendMessage::RowDescription { data } => {
2122 columns = proto::parse_row_description(data)?;
2123 break;
2124 }
2125 BackendMessage::NoData => {
2126 columns = Vec::new();
2127 break;
2128 }
2129 BackendMessage::NoticeResponse { .. } => {}
2130 BackendMessage::ErrorResponse { data } => {
2131 let fields = proto::parse_error_response(data);
2132 self.drain_to_ready()?;
2133 return Err(self.make_server_error(fields));
2134 }
2135 other => {
2136 return Err(DriverError::Protocol(format!(
2137 "expected ParameterDescription/RowDescription/NoData, got: {other:?}"
2138 )));
2139 }
2140 }
2141 }
2142
2143 self.expect_ready()?;
2145
2146 Ok(PrepareResult {
2147 columns,
2148 param_oids,
2149 })
2150 }
2151
2152 pub fn wait_for_notification(&mut self) -> Result<(String, String), DriverError> {
2161 loop {
2162 let (msg_type, _payload_len) = self.read_message_buffered()?;
2163 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2164 match msg {
2165 BackendMessage::NotificationResponse {
2166 channel, payload, ..
2167 } => {
2168 return Ok((channel.to_owned(), payload.to_owned()));
2169 }
2170 BackendMessage::ParameterStatus { .. } | BackendMessage::NoticeResponse { .. } => {
2171 continue;
2172 }
2173 _ => continue,
2174 }
2175 }
2176 }
2177
2178 pub fn cancel(&self) -> Result<(), DriverError> {
2184 let addr = format!("{}:{}", self.connect_config.host, self.connect_config.port);
2185 let mut tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
2186 let mut buf = Vec::with_capacity(16);
2187 proto::write_cancel_request(&mut buf, self.pid, self.secret);
2188 tcp.write_all(&buf).map_err(DriverError::Io)?;
2189 tcp.flush().map_err(DriverError::Io)?;
2190 drop(tcp);
2192 Ok(())
2193 }
2194
2195 pub fn set_read_timeout(
2200 &self,
2201 timeout: Option<std::time::Duration>,
2202 ) -> Result<(), DriverError> {
2203 self.stream
2204 .set_read_timeout(timeout)
2205 .map_err(DriverError::Io)
2206 }
2207
2208 pub fn query_streaming_start(
2222 &mut self,
2223 sql: &str,
2224 sql_hash: u64,
2225 params: &[&(dyn Encode + Sync)],
2226 chunk_size: i32,
2227 ) -> Result<(Arc<[ColumnDesc]>, bool), DriverError> {
2228 self.write_buf.clear();
2229
2230 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
2231 self.query_counter += 1;
2233 info.last_used = self.query_counter;
2234
2235 let can_use_template = info
2236 .bind_template
2237 .as_ref()
2238 .is_some_and(|t| t.param_slots.len() == params.len());
2239
2240 if can_use_template {
2241 let tmpl = info
2243 .bind_template
2244 .as_ref()
2245 .expect("guarded by can_use_template");
2246 self.write_buf
2249 .extend_from_slice(&tmpl.bytes[..tmpl.bind_end]);
2250
2251 let mut template_ok = true;
2252 for (i, param) in params.iter().enumerate() {
2253 let (data_offset, old_len) = tmpl.param_slots[i];
2254 if param.is_null() {
2255 let len_offset = data_offset - 4;
2256 self.write_buf[len_offset..len_offset + 4]
2257 .copy_from_slice(&(-1i32).to_be_bytes());
2258 } else if old_len >= 0 {
2259 let end = data_offset + old_len as usize;
2260 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2261 template_ok = false;
2262 break;
2263 }
2264 } else {
2265 template_ok = false;
2266 break;
2267 }
2268 }
2269
2270 if !template_ok {
2271 self.write_buf.clear();
2272 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2273 info.bind_template = None;
2274 }
2275 } else {
2276 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2277 }
2278
2279 let cols = info.columns.clone();
2280
2281 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2282 info.bind_template = build_bind_template(&self.write_buf, params.len());
2283 }
2284
2285 proto::write_execute(&mut self.write_buf, "", chunk_size);
2286 proto::write_flush(&mut self.write_buf);
2288 self.flush_write()?;
2289
2290 cols
2291 } else {
2292 let name = make_stmt_name(sql_hash);
2294 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
2295 let param_oids: smallvec::SmallVec<[u32; 8]> =
2296 params.iter().map(|p| p.type_oid()).collect();
2297 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
2298 proto::write_describe(&mut self.write_buf, b'S', name_s);
2299 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
2300
2301 proto::write_execute(&mut self.write_buf, "", chunk_size);
2302 proto::write_flush(&mut self.write_buf);
2303 self.flush_write()?;
2304
2305 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2306 let columns = self.read_column_description()?;
2307 self.query_counter += 1;
2308 self.cache_stmt(
2309 sql_hash,
2310 StmtInfo {
2311 name,
2312 columns: columns.clone(),
2313 last_used: self.query_counter,
2314 bind_template: None,
2315 },
2316 );
2317 columns
2318 };
2319
2320 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2322
2323 self.streaming_active = true;
2324
2325 Ok((columns, false))
2326 }
2327
2328 pub fn streaming_next_chunk(
2336 &mut self,
2337 arena: &mut Arena,
2338 all_col_offsets: &mut Vec<(usize, i32)>,
2339 ) -> Result<bool, DriverError> {
2340 all_col_offsets.clear();
2341
2342 loop {
2343 let msg = self.read_one_message()?;
2344 match msg {
2345 BackendMessage::DataRow { data } => {
2346 parse_data_row_flat(data, arena, all_col_offsets)?;
2347 }
2348 BackendMessage::PortalSuspended => {
2349 return Ok(true);
2353 }
2354 BackendMessage::CommandComplete { .. } => {
2355 self.write_buf.clear();
2358 proto::write_sync(&mut self.write_buf);
2359 self.flush_write()?;
2360 self.expect_ready()?;
2361 self.shrink_buffers();
2362
2363 self.streaming_active = false;
2364 return Ok(false);
2365 }
2366 BackendMessage::EmptyQuery => {
2367 self.write_buf.clear();
2368 proto::write_sync(&mut self.write_buf);
2369 self.flush_write()?;
2370 self.expect_ready()?;
2371
2372 self.streaming_active = false;
2373 return Ok(false);
2374 }
2375 BackendMessage::ErrorResponse { data } => {
2376 let fields = proto::parse_error_response(data);
2377 self.write_buf.clear();
2379 proto::write_sync(&mut self.write_buf);
2380 self.flush_write()?;
2381 self.drain_to_ready()?;
2382
2383 self.streaming_active = false;
2384 return Err(self.make_server_error(fields));
2385 }
2386 BackendMessage::NoticeResponse { .. } => {}
2387 other => {
2388 return Err(DriverError::Protocol(format!(
2389 "unexpected message during streaming: {other:?}"
2390 )));
2391 }
2392 }
2393 }
2394 }
2395
2396 pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
2404 self.write_buf.clear();
2405 proto::write_execute(&mut self.write_buf, "", chunk_size);
2406 proto::write_flush(&mut self.write_buf);
2407 self.flush_write()
2408 }
2409
2410 pub fn is_streaming(&self) -> bool {
2412 self.streaming_active
2413 }
2414
2415 pub fn close(mut self) -> Result<(), DriverError> {
2417 self.write_buf.clear();
2418 proto::write_terminate(&mut self.write_buf);
2419 let _ = self.flush_write();
2420 Ok(())
2421 }
2422
2423 pub fn is_idle(&self) -> bool {
2427 self.tx_status == b'I'
2428 }
2429
2430 pub fn is_in_transaction(&self) -> bool {
2432 self.tx_status == b'T'
2433 }
2434
2435 pub fn is_in_failed_transaction(&self) -> bool {
2437 self.tx_status == b'E'
2438 }
2439
2440 pub fn touch(&mut self) {
2442 self.last_used = std::time::Instant::now();
2443 }
2444
2445 pub fn idle_duration(&self) -> std::time::Duration {
2447 self.last_used.elapsed()
2448 }
2449
2450 pub fn query_counter(&self) -> u64 {
2452 self.query_counter
2453 }
2454
2455 pub fn parameter(&self, name: &str) -> Option<&str> {
2457 self.params
2458 .iter()
2459 .find(|(k, _)| &**k == name)
2460 .map(|(_, v)| &**v)
2461 }
2462
2463 pub fn server_params(&self) -> &[(Box<str>, Box<str>)] {
2465 &self.params
2466 }
2467
2468 pub fn pid(&self) -> i32 {
2470 self.pid
2471 }
2472
2473 pub fn secret_key(&self) -> i32 {
2475 self.secret
2476 }
2477
2478 pub fn drain_notifications(&mut self) -> Vec<Notification> {
2480 std::mem::take(&mut self.pending_notifications)
2481 }
2482
2483 pub fn pending_notification_count(&self) -> usize {
2485 self.pending_notifications.len()
2486 }
2487
2488 pub fn set_max_stmt_cache_size(&mut self, size: usize) {
2490 self.max_stmt_cache_size = size;
2491 }
2492
2493 pub fn stmt_cache_len(&self) -> usize {
2495 self.stmts.len()
2496 }
2497
2498 pub fn created_at(&self) -> std::time::Instant {
2500 self.created_at
2501 }
2502
2503 #[inline]
2511 fn send_pipeline(
2512 &mut self,
2513 sql: &str,
2514 sql_hash: u64,
2515 params: &[&(dyn Encode + Sync)],
2516 need_columns: bool,
2517 skip_bind_complete: bool,
2518 ) -> Result<Option<Arc<[ColumnDesc]>>, DriverError> {
2519 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
2520
2521 if params.len() > i16::MAX as usize {
2522 return Err(DriverError::Protocol(format!(
2523 "parameter count {} exceeds maximum {}",
2524 params.len(),
2525 i16::MAX
2526 )));
2527 }
2528
2529 self.write_buf.clear();
2530
2531 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
2532 self.query_counter += 1;
2534 info.last_used = self.query_counter;
2535
2536 let can_use_template = info
2537 .bind_template
2538 .as_ref()
2539 .is_some_and(|t| t.param_slots.len() == params.len());
2540
2541 let mut has_exec_sync = false;
2543
2544 if can_use_template {
2545 let tmpl = info
2549 .bind_template
2550 .as_ref()
2551 .expect("guarded by can_use_template");
2552 self.write_buf.extend_from_slice(&tmpl.bytes);
2553
2554 let mut template_ok = true;
2555 for (i, param) in params.iter().enumerate() {
2556 let (data_offset, old_len) = tmpl.param_slots[i];
2557 if param.is_null() {
2558 let len_offset = data_offset - 4;
2560 self.write_buf[len_offset..len_offset + 4]
2561 .copy_from_slice(&(-1i32).to_be_bytes());
2562 } else if old_len >= 0 {
2563 let end = data_offset + old_len as usize;
2564 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2565 template_ok = false;
2567 break;
2568 }
2569 } else {
2570 template_ok = false;
2573 break;
2574 }
2575 }
2576
2577 if template_ok {
2578 has_exec_sync = true; } else {
2580 self.write_buf.clear();
2581 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2582 info.bind_template = None;
2584 }
2585 } else {
2586 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2587 }
2588
2589 let cols = if need_columns {
2590 Some(info.columns.clone())
2591 } else {
2592 None
2593 };
2594
2595 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2599 info.bind_template = build_bind_template(&self.write_buf, params.len());
2600 }
2601
2602 if !has_exec_sync {
2603 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2604 }
2605 self.flush_write()?;
2606
2607 cols
2608 } else {
2609 let name = make_stmt_name(sql_hash);
2611 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
2612 let param_oids: smallvec::SmallVec<[u32; 8]> =
2613 params.iter().map(|p| p.type_oid()).collect();
2614 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
2615 proto::write_describe(&mut self.write_buf, b'S', name_s);
2616 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
2617
2618 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2619 self.flush_write()?;
2620
2621 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2622 let columns = self.read_column_description()?;
2623 self.query_counter += 1;
2624 self.cache_stmt(
2625 sql_hash,
2626 StmtInfo {
2627 name,
2628 columns: columns.clone(),
2629 last_used: self.query_counter,
2630 bind_template: None,
2631 },
2632 );
2633 if need_columns { Some(columns) } else { None }
2634 };
2635
2636 if !skip_bind_complete {
2637 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2638 }
2639
2640 Ok(columns)
2641 }
2642
2643 fn read_column_description(&mut self) -> Result<Arc<[ColumnDesc]>, DriverError> {
2645 loop {
2646 let msg = self.read_one_message()?;
2647 match msg {
2648 BackendMessage::RowDescription { data } => {
2649 let cols = proto::parse_row_description(data)?;
2650 return Ok(cols.into());
2651 }
2652 BackendMessage::ParameterDescription { .. } => {}
2653 BackendMessage::NoData => return Ok(Arc::from(Vec::new())),
2654 BackendMessage::NoticeResponse { .. } => {}
2655 BackendMessage::ErrorResponse { data } => {
2656 let fields = proto::parse_error_response(data);
2657 self.drain_to_ready()?;
2658 return Err(self.make_server_error(fields));
2659 }
2660 other => {
2661 return Err(DriverError::Protocol(format!(
2662 "expected RowDescription/NoData, got: {other:?}"
2663 )));
2664 }
2665 }
2666 }
2667 }
2668
2669 fn cache_stmt(&mut self, sql_hash: u64, info: StmtInfo) {
2672 if self.stmts.len() >= self.max_stmt_cache_size && !self.stmts.contains_key(&sql_hash) {
2673 if let Some((_lru_hash, evicted)) = self.stmts.evict_lru() {
2674 proto::write_close(&mut self.write_buf, b'S', evicted.name_str());
2675 }
2676 }
2677 self.stmts.insert(sql_hash, info);
2678 }
2679
2680 fn buffer_notification(&mut self, pid: i32, channel: &str, payload: &str) {
2681 if self.pending_notifications.len() < 1024 {
2682 self.pending_notifications.push(Notification {
2683 pid,
2684 channel: channel.to_owned(),
2685 payload: payload.to_owned(),
2686 });
2687 }
2688 }
2689
2690 fn shrink_buffers(&mut self) {
2691 if self.query_counter & 63 != 0 {
2695 return;
2696 }
2697 if self.read_buf.capacity() > 64 * 1024 {
2698 self.read_buf.clear();
2699 self.read_buf.shrink_to(8192);
2700 }
2701 if self.write_buf.capacity() > 16 * 1024 {
2702 self.write_buf.clear();
2703 self.write_buf.shrink_to(8192);
2704 }
2705 }
2706
2707 fn maybe_invalidate_stmt_cache(&mut self, fields: &proto::ErrorFields, sql_hash: u64) -> bool {
2708 if &*fields.code == "26000" {
2709 self.stmts.remove(&sql_hash);
2710 true
2711 } else {
2712 false
2713 }
2714 }
2715
2716 #[cold]
2717 #[inline(never)]
2718 fn make_server_error(&self, fields: proto::ErrorFields) -> DriverError {
2719 DriverError::Server {
2720 code: fields.code,
2721 message: fields.message.into_boxed_str(),
2722 detail: fields.detail.map(String::into_boxed_str),
2723 hint: fields.hint.map(String::into_boxed_str),
2724 position: fields.position,
2725 }
2726 }
2727
2728 #[cold]
2734 #[inline(never)]
2735 fn handle_non_datarow_query(
2736 &mut self,
2737 msg_type: u8,
2738 payload_start: usize,
2739 payload_end: usize,
2740 sql_hash: u64,
2741 affected_rows: &mut u64,
2742 ) -> Result<(), DriverError> {
2743 match msg_type {
2744 b'2' | b'I' => {} b'C' => {
2746 *affected_rows =
2747 proto::parse_command_tag_bytes(&self.stream_buf[payload_start..payload_end]);
2748 }
2749 b'E' => {
2750 let fields =
2751 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2752 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2753 self.drain_to_ready()?;
2754 return Err(self.make_server_error(fields));
2755 }
2756 b'A' => {
2757 let msg = proto::parse_backend_message(
2758 msg_type,
2759 &self.stream_buf[payload_start..payload_end],
2760 )?;
2761 if let BackendMessage::NotificationResponse {
2762 pid,
2763 channel,
2764 payload,
2765 } = msg
2766 {
2767 let ch = channel.to_owned();
2768 let pl = payload.to_owned();
2769 self.buffer_notification(pid, &ch, &pl);
2770 }
2771 }
2772 _ => {} }
2774 Ok(())
2775 }
2776
2777 #[cold]
2780 #[inline(never)]
2781 fn handle_non_datarow_execute(
2782 &mut self,
2783 msg_type: u8,
2784 payload_start: usize,
2785 payload_end: usize,
2786 sql_hash: u64,
2787 ) -> Result<(), DriverError> {
2788 match msg_type {
2789 b'2' | b'C' | b'I' => {} b'E' => {
2791 let fields =
2792 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2793 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2794 self.drain_to_ready()?;
2795 return Err(self.make_server_error(fields));
2796 }
2797 b'A' => {
2798 let msg = proto::parse_backend_message(
2799 msg_type,
2800 &self.stream_buf[payload_start..payload_end],
2801 )?;
2802 if let BackendMessage::NotificationResponse {
2803 pid,
2804 channel,
2805 payload,
2806 } = msg
2807 {
2808 let ch = channel.to_owned();
2809 let pl = payload.to_owned();
2810 self.buffer_notification(pid, &ch, &pl);
2811 }
2812 }
2813 _ => {} }
2815 Ok(())
2816 }
2817
2818 #[inline]
2820 fn read_one_message(&mut self) -> Result<BackendMessage<'_>, DriverError> {
2821 loop {
2822 let (msg_type, _payload_len) = self.read_message_buffered()?;
2823 if msg_type == b'A' {
2824 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2825 if let BackendMessage::NotificationResponse {
2826 pid,
2827 channel,
2828 payload,
2829 } = msg
2830 {
2831 let pid_owned = pid;
2832 let channel_owned = channel.to_owned();
2833 let payload_owned = payload.to_owned();
2834 self.buffer_notification(pid_owned, &channel_owned, &payload_owned);
2835 continue;
2836 }
2837 }
2838 return proto::parse_backend_message(msg_type, &self.read_buf);
2839 }
2840 }
2841
2842 fn expect_message(
2843 &mut self,
2844 pred: impl Fn(&BackendMessage<'_>) -> bool,
2845 ) -> Result<(), DriverError> {
2846 loop {
2847 let msg = self.read_one_message()?;
2848 if pred(&msg) {
2849 return Ok(());
2850 }
2851 match msg {
2852 BackendMessage::ErrorResponse { data } => {
2853 let fields = proto::parse_error_response(data);
2854 self.drain_to_ready()?;
2855 return Err(self.make_server_error(fields));
2856 }
2857 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2858 other => {
2859 return Err(DriverError::Protocol(format!(
2860 "unexpected message while waiting for expected type: {other:?}"
2861 )));
2862 }
2863 }
2864 }
2865 }
2866
2867 fn expect_ready(&mut self) -> Result<(), DriverError> {
2868 loop {
2869 let msg = self.read_one_message()?;
2870 match msg {
2871 BackendMessage::ReadyForQuery { status } => {
2872 self.tx_status = status;
2873 return Ok(());
2874 }
2875 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2876 BackendMessage::ErrorResponse { data } => {
2877 let fields = proto::parse_error_response(data);
2878 self.drain_to_ready()?;
2879 return Err(self.make_server_error(fields));
2880 }
2881 _ => {}
2882 }
2883 }
2884 }
2885
2886 #[inline]
2887 fn drain_to_ready(&mut self) -> Result<(), DriverError> {
2888 loop {
2889 let msg = self.read_one_message()?;
2890 if let BackendMessage::ReadyForQuery { status } = msg {
2891 self.tx_status = status;
2892 return Ok(());
2893 }
2894 }
2895 }
2896
2897 #[inline]
2901 fn flush_write(&mut self) -> Result<(), DriverError> {
2902 self.stream
2903 .write_all(&self.write_buf)
2904 .map_err(DriverError::Io)
2905 }
2906
2907 fn read_message_buffered(&mut self) -> Result<(u8, usize), DriverError> {
2911 let mut header = [0u8; 5];
2912 sync_buffered_read_exact(
2913 &mut self.stream,
2914 &mut self.stream_buf,
2915 &mut self.stream_buf_pos,
2916 &mut self.stream_buf_end,
2917 &mut header,
2918 )?;
2919
2920 let msg_type = header[0];
2921 let len = i32::from_be_bytes([header[1], header[2], header[3], header[4]]);
2922
2923 if len < 4 {
2924 return Err(DriverError::Protocol(format!(
2925 "invalid message length {len} for type '{}'",
2926 msg_type as char
2927 )));
2928 }
2929
2930 const MAX_MESSAGE_LEN: i32 = 128 * 1024 * 1024;
2931 if len > MAX_MESSAGE_LEN {
2932 return Err(DriverError::Protocol(format!(
2933 "message length {len} exceeds maximum ({MAX_MESSAGE_LEN}) for type '{}'",
2934 msg_type as char
2935 )));
2936 }
2937
2938 let payload_len = (len - 4) as usize;
2939 self.read_buf.clear();
2940 self.read_buf.resize(payload_len, 0);
2941 if payload_len > 0 {
2942 sync_buffered_read_exact(
2943 &mut self.stream,
2944 &mut self.stream_buf,
2945 &mut self.stream_buf_pos,
2946 &mut self.stream_buf_end,
2947 &mut self.read_buf[..payload_len],
2948 )?;
2949 }
2950
2951 Ok((msg_type, payload_len))
2952 }
2953
2954 #[inline]
2956 fn refill_stream_buf(&mut self) -> Result<(), DriverError> {
2957 let remaining = self.stream_buf_end - self.stream_buf_pos;
2958 if remaining > 0 && self.stream_buf_pos > 0 {
2959 self.stream_buf
2960 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
2961 }
2962 self.stream_buf_pos = 0;
2963 self.stream_buf_end = remaining;
2964
2965 let n = self
2966 .stream
2967 .read(&mut self.stream_buf[remaining..])
2968 .map_err(DriverError::Io)?;
2969 if n == 0 {
2970 return Err(DriverError::Io(std::io::Error::new(
2971 std::io::ErrorKind::UnexpectedEof,
2972 "connection closed",
2973 )));
2974 }
2975 self.stream_buf_end = remaining + n;
2976 Ok(())
2977 }
2978}
2979
2980fn sync_buffered_read_exact(
2983 stream: &mut Stream,
2984 buf: &mut [u8],
2985 pos: &mut usize,
2986 end: &mut usize,
2987 out: &mut [u8],
2988) -> Result<(), DriverError> {
2989 let mut filled = 0;
2990 while filled < out.len() {
2991 let avail = *end - *pos;
2992 if avail > 0 {
2993 let take = avail.min(out.len() - filled);
2994 out[filled..filled + take].copy_from_slice(&buf[*pos..*pos + take]);
2995 *pos += take;
2996 filled += take;
2997 } else {
2998 *pos = 0;
2999 let n = stream.read(buf).map_err(DriverError::Io)?;
3000 if n == 0 {
3001 return Err(DriverError::Io(std::io::Error::new(
3002 std::io::ErrorKind::UnexpectedEof,
3003 "connection closed",
3004 )));
3005 }
3006 *end = n;
3007 }
3008 }
3009 Ok(())
3010}
3011
3012#[inline(always)]
3022fn parse_data_row_into_buf(
3023 data: &[u8],
3024 buf: &mut Vec<u8>,
3025 out: &mut Vec<(usize, i32)>,
3026) -> Result<(), DriverError> {
3027 if data.len() < 2 {
3028 return Err(DriverError::Protocol("DataRow too short".into()));
3029 }
3030
3031 let num_cols = i16::from_be_bytes([data[0], data[1]]);
3032 if num_cols < 0 {
3033 return Err(DriverError::Protocol(
3034 "DataRow: negative column count".into(),
3035 ));
3036 }
3037 let num_cols = num_cols as usize;
3038
3039 let col_data = &data[2..];
3042 let base = buf.len();
3043 buf.extend_from_slice(col_data);
3044
3045 let mut pos: usize = 0;
3047 for _ in 0..num_cols {
3048 if pos + 4 > col_data.len() {
3049 return Err(DriverError::Protocol("DataRow truncated".into()));
3050 }
3051
3052 let col_len = i32::from_be_bytes([
3053 col_data[pos],
3054 col_data[pos + 1],
3055 col_data[pos + 2],
3056 col_data[pos + 3],
3057 ]);
3058 pos += 4;
3059
3060 if col_len < 0 {
3061 out.push((0, -1));
3062 } else {
3063 let len = col_len as usize;
3064 if pos + len > col_data.len() {
3065 return Err(DriverError::Protocol(
3066 "DataRow column data truncated".into(),
3067 ));
3068 }
3069 out.push((base + pos, col_len));
3071 pos += len;
3072 }
3073 }
3074
3075 Ok(())
3076}
3077
3078fn parse_data_row_flat(
3082 data: &[u8],
3083 arena: &mut Arena,
3084 out: &mut Vec<(usize, i32)>,
3085) -> Result<(), DriverError> {
3086 if data.len() < 2 {
3087 return Err(DriverError::Protocol("DataRow too short".into()));
3088 }
3089
3090 let num_cols_raw = i16::from_be_bytes([data[0], data[1]]);
3091 if num_cols_raw < 0 {
3092 return Err(DriverError::Protocol(
3093 "DataRow: negative column count".into(),
3094 ));
3095 }
3096 let num_cols = num_cols_raw as usize;
3097 out.reserve(num_cols);
3098
3099 let col_data = &data[2..];
3102 let base = arena.alloc_copy(col_data);
3103
3104 let mut pos: usize = 0;
3106 for _ in 0..num_cols {
3107 if pos + 4 > col_data.len() {
3108 return Err(DriverError::Protocol("DataRow truncated".into()));
3109 }
3110
3111 let col_len = i32::from_be_bytes([
3112 col_data[pos],
3113 col_data[pos + 1],
3114 col_data[pos + 2],
3115 col_data[pos + 3],
3116 ]);
3117 pos += 4;
3118
3119 if col_len < 0 {
3120 out.push((0, -1));
3121 } else {
3122 let len = col_len as usize;
3123 if pos + len > col_data.len() {
3124 return Err(DriverError::Protocol(
3125 "DataRow column data truncated".into(),
3126 ));
3127 }
3128 out.push((base + pos, col_len));
3130 pos += len;
3131 }
3132 }
3133
3134 Ok(())
3135}
3136
#[cfg(test)]
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;
    use crate::types::hash_sql;

    // ---------------------------------------------------------------
    // Shared fixtures
    // ---------------------------------------------------------------

    /// URL targeting TCP port 1 on loopback. The offline tests expect every
    /// connection attempt against it to fail.
    const UNREACHABLE_URL: &str = "postgres://user:pass@127.0.0.1:1/db";

    /// URL used by the `#[ignore]`d tests, which need a running server.
    const LIVE_URL: &str = "postgres://postgres@localhost/postgres?host=/tmp";

    /// Config for the unreachable endpoint.
    fn unreachable_config() -> Config {
        Config::from_url(UNREACHABLE_URL).unwrap()
    }

    /// Config for the live server.
    fn live_config() -> Config {
        Config::from_url(LIVE_URL).unwrap()
    }

    /// Connects to the live server, panicking if that fails.
    fn live_connection() -> Connection {
        Connection::connect(&live_config()).unwrap()
    }

    /// Asserts that connecting with `config` fails and returns the error.
    fn connect_error(config: &Config) -> DriverError {
        let attempt = Connection::connect(config);
        assert!(attempt.is_err());
        attempt.unwrap_err()
    }

    /// Starts a DataRow payload declaring `count` columns.
    fn row_with_columns(count: i16) -> Vec<u8> {
        count.to_be_bytes().to_vec()
    }

    /// Appends a NULL column (length -1) to `row`.
    fn push_null(row: &mut Vec<u8>) {
        row.extend_from_slice(&(-1i32).to_be_bytes());
    }

    /// Appends a column holding `value`, prefixed with its length.
    fn push_value(row: &mut Vec<u8>, value: &[u8]) {
        row.extend_from_slice(&(value.len() as i32).to_be_bytes());
        row.extend_from_slice(value);
    }

    /// Runs `parse_data_row_flat` on `row` with a fresh arena and offset list.
    fn parse_row(row: &[u8]) -> (Result<(), DriverError>, Arena, Vec<(usize, i32)>) {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let outcome = parse_data_row_flat(row, &mut arena, &mut out);
        (outcome, arena, out)
    }

    // ---------------------------------------------------------------
    // Offline tests
    // ---------------------------------------------------------------

    #[test]
    fn sync_config_tcp_no_longer_rejected() {
        let config = unreachable_config();
        let err = connect_error(&config).to_string();
        assert!(
            !err.contains("Unix domain socket"),
            "error should NOT mention UDS requirement: {err}"
        );
    }

    #[test]
    fn sync_data_row_parsing() {
        let mut row = row_with_columns(2);
        push_value(&mut row, &42i32.to_be_bytes());
        push_null(&mut row);

        let (outcome, _arena, out) = parse_row(&row);
        outcome.unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].1, 4);
        assert_eq!(out[1].1, -1);
    }

    #[test]
    fn sync_data_row_empty() {
        let row = row_with_columns(0);
        let (outcome, _arena, out) = parse_row(&row);
        outcome.unwrap();
        assert_eq!(out.len(), 0);
    }

    #[test]
    fn sync_data_row_too_short() {
        let (outcome, _arena, _out) = parse_row(&[0u8]);
        assert!(outcome.is_err());
    }

    #[test]
    fn sync_data_row_negative_col_count() {
        let row = row_with_columns(-1);
        let (outcome, _arena, _out) = parse_row(&row);
        assert!(outcome.is_err());
    }

    #[test]
    fn sync_data_row_truncated() {
        // Two columns declared, only one present.
        let mut row = row_with_columns(2);
        push_value(&mut row, &42i32.to_be_bytes());
        let (outcome, _arena, _out) = parse_row(&row);
        assert!(outcome.is_err());
    }

    #[test]
    fn sync_data_row_col_data_truncated() {
        // Column claims 100 bytes but only one follows.
        let mut row = row_with_columns(1);
        row.extend_from_slice(&100i32.to_be_bytes());
        row.push(0);
        let (outcome, _arena, _out) = parse_row(&row);
        assert!(outcome.is_err());
    }

    #[test]
    fn sync_connect_tcp_unreachable_port() {
        let config = unreachable_config();
        let err = connect_error(&config).to_string();
        assert!(
            !err.contains("Unix domain socket"),
            "error should NOT mention UDS: {err}"
        );
    }

    #[test]
    fn sync_connect_ip_address_attempts_tcp() {
        let config = unreachable_config();
        let _ = connect_error(&config);
    }

    #[test]
    fn sync_data_row_all_null() {
        let mut row = row_with_columns(3);
        for _ in 0..3 {
            push_null(&mut row);
        }
        let (outcome, _arena, out) = parse_row(&row);
        outcome.unwrap();
        assert_eq!(out.len(), 3);
        for (_, len) in &out {
            assert_eq!(*len, -1);
        }
    }

    #[test]
    fn sync_data_row_long_text() {
        let long_text = "a".repeat(2048);
        let text_bytes = long_text.as_bytes();
        let mut row = row_with_columns(1);
        push_value(&mut row, text_bytes);

        let (outcome, arena, out) = parse_row(&row);
        outcome.unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].1, text_bytes.len() as i32);
        let stored = arena.get(out[0].0, out[0].1 as usize);
        assert_eq!(stored, text_bytes);
    }

    #[test]
    fn sync_data_row_empty_text() {
        // A zero-length value is distinct from NULL.
        let mut row = row_with_columns(1);
        push_value(&mut row, &[]);
        let (outcome, _arena, out) = parse_row(&row);
        outcome.unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].1, 0);
    }

    #[test]
    fn sync_data_row_17_columns_exceeds_smallvec() {
        let num_cols: i16 = 20;
        let mut row = row_with_columns(num_cols);
        for i in 0..num_cols {
            push_value(&mut row, &(i as i32).to_be_bytes());
        }

        let (outcome, arena, out) = parse_row(&row);
        outcome.unwrap();
        assert_eq!(out.len(), 20);
        for (idx, (offset, len)) in out.iter().enumerate() {
            assert_eq!(*len, 4);
            let stored = arena.get(*offset, 4);
            let val = i32::from_be_bytes([stored[0], stored[1], stored[2], stored[3]]);
            assert_eq!(val, idx as i32);
        }
    }

    #[test]
    fn sync_data_row_mixed_null_and_data() {
        let mut row = row_with_columns(5);
        push_null(&mut row);
        push_value(&mut row, &42i32.to_be_bytes());
        push_null(&mut row);
        push_null(&mut row);
        push_value(&mut row, b"hello");

        let (outcome, arena, out) = parse_row(&row);
        outcome.unwrap();
        assert_eq!(out.len(), 5);
        assert_eq!(out[0].1, -1);
        assert_eq!(out[1].1, 4);
        assert_eq!(out[2].1, -1);
        assert_eq!(out[3].1, -1);
        assert_eq!(out[4].1, 5);
        let stored = arena.get(out[4].0, 5);
        assert_eq!(stored, b"hello");
    }

    // ---------------------------------------------------------------
    // Live-server tests (ignored by default)
    // ---------------------------------------------------------------

    #[test]
    #[ignore]
    fn sync_connect_uds_if_pg_available() {
        let config = live_config();
        // Unlike the other live tests, a failed connect is tolerated here.
        if let Ok(conn) = Connection::connect(&config) {
            assert!(conn.pid() != 0, "pid should be nonzero");
            assert!(conn.is_idle(), "should start idle");
            assert!(!conn.is_in_transaction(), "should not be in tx");
            assert!(
                !conn.is_in_failed_transaction(),
                "should not be in failed tx"
            );
            assert_eq!(conn.stmt_cache_len(), 0, "cache should be empty");
            let _ = conn.close();
        }
    }

    #[test]
    #[ignore]
    fn sync_simple_query_if_pg_available() {
        let mut conn = live_connection();
        conn.simple_query("SELECT 1").unwrap();
        assert!(conn.is_idle());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_query_with_params_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECT $1::int4 + $2::int4 AS sum";
        let hash = hash_sql(sql);
        let a: i32 = 10;
        let b: i32 = 20;
        let result = conn
            .query(
                sql,
                hash,
                &[&a as &(dyn Encode + Sync), &b as &(dyn Encode + Sync)],
            )
            .unwrap();
        assert_eq!(result.len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_execute_if_pg_available() {
        let mut conn = live_connection();
        conn.simple_query("CREATE TEMP TABLE _sync_test (id int)")
            .unwrap();
        let sql = "INSERT INTO _sync_test VALUES ($1::int4)";
        let hash = hash_sql(sql);
        let val: i32 = 42;
        let affected = conn
            .execute(sql, hash, &[&val as &(dyn Encode + Sync)])
            .unwrap();
        assert_eq!(affected, 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_for_each_zero_rows_if_pg_available() {
        let mut conn = live_connection();
        conn.simple_query("CREATE TEMP TABLE _sync_fe0 (id int)")
            .unwrap();
        let sql = "SELECT id FROM _sync_fe0";
        let hash = hash_sql(sql);
        let mut count = 0u32;
        conn.for_each(sql, hash, &[], |_row| {
            count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(count, 0);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_for_each_multiple_rows_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECT generate_series(1, 5)";
        let hash = hash_sql(sql);
        let mut count = 0u32;
        conn.for_each(sql, hash, &[], |_row| {
            count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(count, 5);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_prepare_only_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECT 1";
        let hash = hash_sql(sql);
        conn.prepare_only(sql, hash).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        // Preparing the same statement again must not add a second entry.
        conn.prepare_only(sql, hash).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_simple_query_rows_if_pg_available() {
        let mut conn = live_connection();
        let rows = conn.simple_query_rows("SELECT 42 AS n").unwrap();
        assert!(!rows.is_empty());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_stmt_cache_hit_miss_if_pg_available() {
        let mut conn = live_connection();

        let sql1 = "SELECT 1";
        let hash1 = hash_sql(sql1);
        conn.query(sql1, hash1, &[]).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        // Same statement again: cache size stays put.
        conn.query(sql1, hash1, &[]).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);

        // A different statement adds a second entry.
        let sql2 = "SELECT 2";
        let hash2 = hash_sql(sql2);
        conn.query(sql2, hash2, &[]).unwrap();
        assert_eq!(conn.stmt_cache_len(), 2);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_invalid_sql_error_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECTTTT INVALID GARBAGE";
        let hash = hash_sql(sql);
        let result = conn.query(sql, hash, &[]);
        assert!(result.is_err());
        // The connection must be usable again after the error.
        assert!(conn.is_idle());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_tx_state_transitions_if_pg_available() {
        let mut conn = live_connection();
        assert!(conn.is_idle());
        assert!(!conn.is_in_transaction());
        conn.simple_query("BEGIN").unwrap();
        assert!(conn.is_in_transaction());
        assert!(!conn.is_idle());
        conn.simple_query("COMMIT").unwrap();
        assert!(conn.is_idle());
        assert!(!conn.is_in_transaction());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_lru_cache_eviction_if_pg_available() {
        let mut conn = live_connection();
        conn.set_max_stmt_cache_size(3);
        for i in 0..5 {
            let sql = format!("SELECT {}", i);
            let hash = hash_sql(&sql);
            conn.query(&sql, hash, &[]).unwrap();
        }
        assert!(
            conn.stmt_cache_len() <= 3,
            "cache should be capped at 3, got {}",
            conn.stmt_cache_len()
        );
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_for_each_raw_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECT generate_series(1, 3)";
        let hash = hash_sql(sql);
        let mut raw_count = 0u32;
        conn.for_each_raw(sql, hash, &[], |_raw_data| {
            raw_count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(raw_count, 3);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_query_null_params_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECT $1::int4 IS NULL AS is_null";
        let hash = hash_sql(sql);
        let val: Option<i32> = None;
        let _result = conn
            .query(sql, hash, &[&val as &(dyn Encode + Sync)])
            .unwrap();
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_query_various_param_types_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECT $1::int4, $2::int8, $3::text, $4::bool, $5::float8";
        let hash = hash_sql(sql);
        let p1: i32 = 42;
        let p2: i64 = 9999999;
        let p3: &str = "hello";
        let p4: bool = true;
        let p5: f64 = 3.14;
        let result = conn
            .query(
                sql,
                hash,
                &[
                    &p1 as &(dyn Encode + Sync),
                    &p2 as &(dyn Encode + Sync),
                    &p3 as &(dyn Encode + Sync),
                    &p4 as &(dyn Encode + Sync),
                    &p5 as &(dyn Encode + Sync),
                ],
            )
            .unwrap();
        assert_eq!(result.len(), 1);
        let _ = conn.close();
    }

    // ---------------------------------------------------------------
    // Constant / formatting sanity checks
    // ---------------------------------------------------------------

    #[test]
    fn sync_shrink_threshold_values() {
        // NOTE(review): these are local copies of the values, not the
        // constants used by `shrink_buffers` itself.
        let shrink = 64 * 1024usize;
        let initial = 8192usize;
        assert!(
            shrink > initial,
            "shrink threshold must exceed initial size"
        );
    }

    #[test]
    fn sync_connection_debug_format() {
        // NOTE(review): formats a hand-written string; it does not exercise
        // the `Debug` impl of `Connection`.
        let fmt_str = format!(
            "Connection {{ pid: {}, tx_status: '{}', stmt_cache_len: {} }}",
            0, 'I', 0
        );
        assert!(fmt_str.contains("Connection"));
        assert!(fmt_str.contains("pid"));
        assert!(fmt_str.contains("tx_status"));
    }

    // ---------------------------------------------------------------
    // SSL mode handling against an unreachable endpoint
    // ---------------------------------------------------------------

    #[test]
    fn sync_connect_sslmode_require_without_tls_feature() {
        let mut config = unreachable_config();
        config.ssl = SslMode::Require;
        let _ = connect_error(&config);
    }

    #[test]
    fn sync_connect_sslmode_disable_attempts_tcp() {
        let mut config = unreachable_config();
        config.ssl = SslMode::Disable;
        let err = connect_error(&config);
        assert!(matches!(err, DriverError::Io(_)));
    }

    #[test]
    fn sync_connect_sslmode_prefer_attempts_tcp() {
        let mut config = unreachable_config();
        config.ssl = SslMode::Prefer;
        let _ = connect_error(&config);
    }

    // ---------------------------------------------------------------
    // More live-server tests (ignored by default)
    // ---------------------------------------------------------------

    #[test]
    #[ignore]
    fn sync_streaming_basic_if_pg_available() {
        let mut conn = live_connection();
        assert!(!conn.is_streaming());

        let sql = "SELECT generate_series(1, 10)";
        let hash = hash_sql(sql);

        let (cols, _) = conn.query_streaming_start(sql, hash, &[], 3).unwrap();
        assert!(!cols.is_empty());
        assert!(conn.is_streaming());

        let mut arena = Arena::new();
        let mut offsets = Vec::new();
        let mut total_rows = 0;

        // Pull chunks of up to three rows until the server reports completion.
        loop {
            let has_more = conn.streaming_next_chunk(&mut arena, &mut offsets).unwrap();
            total_rows += offsets.len();
            if !has_more {
                break;
            }
            conn.streaming_send_execute(3).unwrap();
        }

        assert_eq!(total_rows, 10);
        assert!(!conn.is_streaming());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_prepare_describe_if_pg_available() {
        let mut conn = live_connection();

        let result = conn
            .prepare_describe("SELECT $1::int4 + $2::int4 AS sum")
            .unwrap();
        assert_eq!(result.columns.len(), 1);
        assert_eq!(&*result.columns[0].name, "sum");
        assert_eq!(result.param_oids.len(), 2);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_wait_for_notification_if_pg_available() {
        let mut conn = live_connection();

        conn.simple_query("LISTEN test_chan").unwrap();
        conn.simple_query("NOTIFY test_chan, 'hello'").unwrap();

        // Bound the wait so a missing notification fails the test instead of hanging.
        conn.set_read_timeout(Some(std::time::Duration::from_secs(5)))
            .unwrap();

        let (channel, payload) = conn.wait_for_notification().unwrap();
        assert_eq!(channel, "test_chan");
        assert_eq!(payload, "hello");
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_cancel_if_pg_available() {
        let conn = live_connection();
        // Only checks that the call does not panic; its outcome is ignored.
        let outcome = conn.cancel();
        let _ = outcome;
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_server_params_if_pg_available() {
        let conn = live_connection();
        let params = conn.server_params();
        assert!(
            !params.is_empty(),
            "server should send parameters during startup"
        );
        assert!(
            conn.parameter("server_encoding").is_some(),
            "server_encoding should be present"
        );
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_set_read_timeout_if_pg_available() {
        let conn = live_connection();
        conn.set_read_timeout(Some(std::time::Duration::from_secs(10)))
            .unwrap();
        conn.set_read_timeout(None).unwrap();
        let _ = conn.close();
    }
}