1use std::io::{Read, Write};
15use std::sync::Arc;
16
17use crate::DriverError;
18use crate::arena::Arena;
19use crate::auth;
20use crate::codec::Encode;
21use crate::proto::{self, BackendMessage};
22use crate::stmt_cache::{StmtCache, StmtInfo, build_bind_template, make_stmt_name};
23use crate::sync_io::Stream;
24use crate::types::{
25 ColumnDesc, Config, Notification, PgDataRow, PrepareResult, QueryResult, SimpleRow, SslMode,
26 StartupAction,
27};
28
29use std::cell::RefCell;
37
/// Maximum number of idle response buffers retained per thread.
const MAX_POOLED_RESP_BUFS: usize = 4;

thread_local! {
    /// Per-thread free list of response buffers, so that repeated queries can
    /// reuse an allocation instead of growing a fresh `Vec` every time.
    static RESP_BUF_POOL: RefCell<Vec<Vec<u8>>> = const { RefCell::new(Vec::new()) };
}

/// Takes a buffer from this thread's pool, or returns a new empty `Vec` when
/// the pool is empty.
///
/// The returned buffer is always empty (length 0) but may carry capacity from
/// an earlier use. If the thread-local pool is no longer accessible (thread
/// teardown), a fresh buffer is returned instead of panicking.
pub(crate) fn acquire_resp_buf() -> Vec<u8> {
    RESP_BUF_POOL
        .try_with(|pool| pool.borrow_mut().pop())
        .ok()
        .flatten()
        .unwrap_or_default()
}

/// Returns a buffer to this thread's pool for later reuse.
///
/// The buffer is cleared first so that a later `acquire_resp_buf` never sees
/// bytes from a previous response. At most `MAX_POOLED_RESP_BUFS` buffers are
/// kept; any extra buffer is simply dropped. If the pool is no longer
/// accessible (thread teardown), the buffer is dropped as well.
pub fn release_resp_buf(mut buf: Vec<u8>) {
    // Clearing keeps the capacity, which is the whole point of pooling.
    buf.clear();
    // `try_with` rather than `with`: this may run from a destructor while the
    // thread's locals are being torn down, where `with` would panic.
    let _ = RESP_BUF_POOL.try_with(|pool| {
        let mut pool = pool.borrow_mut();
        if pool.len() < MAX_POOLED_RESP_BUFS {
            pool.push(buf);
        }
    });
}
57
/// A single synchronous connection to the database server.
///
/// Owns the transport, the read/write buffers and the per-connection
/// prepared-statement cache.
pub struct Connection {
    /// Underlying transport (`Stream::Unix`, `Stream::Tcp` or `Stream::Tls`).
    stream: Stream,
    /// Payload of the message most recently read by the buffered message
    /// reader; handed to `proto::parse_backend_message` for parsing.
    read_buf: Vec<u8>,
    /// Receive buffer (allocated as 64 KiB in `connect_arc`) from which the
    /// query loops parse messages directly.
    stream_buf: Vec<u8>,
    /// Offset of the next unparsed byte in `stream_buf`.
    stream_buf_pos: usize,
    /// One past the last valid byte in `stream_buf`.
    stream_buf_end: usize,
    /// Staging buffer for outgoing protocol messages.
    write_buf: Vec<u8>,
    /// Cache of statements already prepared on this connection.
    stmts: StmtCache,
    /// `(name, value)` pairs received as `ParameterStatus` during startup.
    params: Vec<(Box<str>, Box<str>)>,
    /// Backend process id from `BackendKeyData`.
    pid: i32,
    /// Backend secret key from `BackendKeyData`.
    secret: i32,
    /// Status byte from the latest `ReadyForQuery`; starts as `b'I'`.
    tx_status: u8,
    /// Set to the connect time at creation.
    /// NOTE(review): presumably refreshed on use elsewhere — not visible here.
    last_used: std::time::Instant,
    /// Starts as `false`.
    /// NOTE(review): presumably marks an in-progress streaming query — confirm.
    streaming_active: bool,
    /// When this connection object was created.
    created_at: std::time::Instant,
    /// Starts empty.
    /// NOTE(review): presumably notifications received but not yet handed to
    /// the caller — confirm against the code that fills it.
    pending_notifications: Vec<Notification>,
    /// Initialised to 256.
    /// NOTE(review): presumably the eviction threshold for `stmts` — confirm.
    max_stmt_cache_size: usize,
    /// Counter bumped on every statement use; its value is stored in
    /// `StmtInfo::last_used`.
    query_counter: u64,
    /// The configuration this connection was opened with.
    connect_config: Arc<Config>,
    /// Server certificate hash reported by the TLS upgrade, if any; enables
    /// the `SCRAM-SHA-256-PLUS` mechanism in `handle_scram`.
    tls_server_cert_hash: Option<[u8; 32]>,
}
111
112impl std::fmt::Debug for Connection {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct("Connection")
115 .field("pid", &self.pid)
116 .field("tx_status", &(self.tx_status as char))
117 .field("stmt_cache_len", &self.stmts.len())
118 .finish()
119 }
120}
121
122impl Connection {
123 pub fn connect(config: &Config) -> Result<Self, DriverError> {
135 Self::connect_arc(Arc::new(config.clone()))
136 }
137
    /// Opens a connection from a shared [`Config`].
    ///
    /// In order: validates the config, opens the transport, runs the startup
    /// and authentication exchange, checks the server-reported parameters
    /// and, when `statement_timeout_secs` is greater than zero, issues
    /// `SET statement_timeout`.
    ///
    /// Transport selection:
    /// - `config.host_is_uds()`: Unix domain socket at `config.uds_path()`
    ///   (an error on non-Unix targets).
    /// - `SslMode::Disable`: plain TCP.
    /// - `SslMode::Prefer` / `SslMode::Require` with the `tls` feature: TLS
    ///   upgrade. With `Prefer` a failed upgrade falls back to a fresh plain
    ///   TCP connection; with `Require` the upgrade error is returned.
    /// - `SslMode::Prefer` / `SslMode::Require` without the `tls` feature:
    ///   `Require` is an error, `Prefer` uses plain TCP.
    ///
    /// # Errors
    /// Errors from `config.validate()`, `DriverError::Io` for socket
    /// failures, `DriverError::Protocol` for unsupported configurations, and
    /// any error from startup, parameter validation or the timeout query.
    pub fn connect_arc(config: Arc<Config>) -> Result<Self, DriverError> {
        config.validate()?;

        // Only assigned on the TLS path, so `mut` is unused in builds
        // without the `tls` feature.
        #[allow(unused_mut)]
        let mut tls_cert_hash: Option<[u8; 32]> = None;

        let stream = if config.host_is_uds() {
            #[cfg(unix)]
            {
                let path = config.uds_path();
                let unix =
                    std::os::unix::net::UnixStream::connect(&path).map_err(DriverError::Io)?;
                Stream::Unix(unix)
            }
            #[cfg(not(unix))]
            {
                return Err(DriverError::Protocol(
                    "Unix domain sockets are not supported on this platform".into(),
                ));
            }
        } else {
            let addr = format!("{}:{}", config.host, config.port);
            let tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;

            match config.ssl {
                SslMode::Disable => {
                    tcp.set_nodelay(true).map_err(DriverError::Io)?;
                    let stream = Stream::Tcp(tcp);
                    stream.set_keepalive()?;
                    stream
                }
                SslMode::Prefer | SslMode::Require => {
                    #[cfg(feature = "tls")]
                    {
                        match crate::tls_sync::try_upgrade(
                            tcp,
                            &config.host,
                            config.ssl == SslMode::Require,
                        ) {
                            Ok(result) => {
                                // Kept for SCRAM channel binding later on.
                                tls_cert_hash = result.server_cert_hash;
                                let stream = Stream::Tls(Box::new(result.stream));
                                stream.set_nodelay()?;
                                stream.set_keepalive()?;
                                stream
                            }
                            Err(e) => {
                                if config.ssl == SslMode::Require {
                                    return Err(e);
                                }
                                // `Prefer`: the original socket was moved into
                                // `try_upgrade`, so reconnect in plain text.
                                let tcp =
                                    std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
                                tcp.set_nodelay(true).map_err(DriverError::Io)?;
                                let stream = Stream::Tcp(tcp);
                                stream.set_keepalive()?;
                                stream
                            }
                        }
                    }
                    #[cfg(not(feature = "tls"))]
                    {
                        if config.ssl == SslMode::Require {
                            return Err(DriverError::Protocol(
                                "sslmode=require but bsql was compiled without the 'tls' feature"
                                    .into(),
                            ));
                        }
                        // `Prefer` without TLS support: plain TCP.
                        tcp.set_nodelay(true).map_err(DriverError::Io)?;
                        let stream = Stream::Tcp(tcp);
                        stream.set_keepalive()?;
                        stream
                    }
                }
            }
        };

        let mut conn = Self {
            stream,
            read_buf: Vec::with_capacity(8192),
            stream_buf: vec![0u8; 65536],
            stream_buf_pos: 0,
            stream_buf_end: 0,
            write_buf: Vec::with_capacity(4096),
            stmts: StmtCache::default(),
            params: Vec::new(),
            pid: 0,
            secret: 0,
            // Replaced by the status byte of the first ReadyForQuery.
            tx_status: b'I',
            last_used: std::time::Instant::now(),
            streaming_active: false,
            created_at: std::time::Instant::now(),
            pending_notifications: Vec::new(),
            max_stmt_cache_size: 256,
            query_counter: 0,
            connect_config: config.clone(),
            tls_server_cert_hash: tls_cert_hash,
        };

        conn.startup(&config)?;
        conn.validate_server_params()?;

        if config.statement_timeout_secs > 0 {
            conn.simple_query(&format!(
                "SET statement_timeout = '{}s'",
                config.statement_timeout_secs
            ))?;
        }

        Ok(conn)
    }
257
258 fn startup(&mut self, config: &Config) -> Result<(), DriverError> {
261 self.write_buf.clear();
262 proto::write_startup(&mut self.write_buf, &config.user, &config.database);
263 self.flush_write()?;
264
265 loop {
266 let action = self.read_startup_action()?;
267 match action {
268 StartupAction::AuthOk => {}
269 StartupAction::AuthCleartext => {
270 self.write_buf.clear();
271 let mut pw = config.password.as_bytes().to_vec();
272 pw.push(0);
273 proto::write_password(&mut self.write_buf, &pw);
274 self.flush_write()?;
275 }
276 StartupAction::AuthMd5(salt) => {
277 self.write_buf.clear();
278 let hash = auth::md5_password(&config.user, &config.password, &salt);
279 proto::write_password(&mut self.write_buf, &hash);
280 self.flush_write()?;
281 }
282 StartupAction::AuthSasl(mechanisms_data) => {
283 self.handle_scram(config, &mechanisms_data)?;
284 }
285 StartupAction::ParameterStatus(name, value) => {
286 if let Some(entry) = self.params.iter_mut().find(|(k, _)| *k == name) {
287 entry.1 = value;
288 } else {
289 self.params.push((name, value));
290 }
291 }
292 StartupAction::BackendKeyData(pid, secret) => {
293 self.pid = pid;
294 self.secret = secret;
295 }
296 StartupAction::ReadyForQuery(status) => {
297 self.tx_status = status;
298 return Ok(());
299 }
300 StartupAction::Error(msg) => {
301 return Err(DriverError::Auth(msg));
302 }
303 StartupAction::Notice => {}
304 }
305 }
306 }
307
308 fn read_startup_action(&mut self) -> Result<StartupAction, DriverError> {
309 let (msg_type, _) = self.read_message_buffered()?;
310 let payload = &self.read_buf;
311 let msg = proto::parse_backend_message(msg_type, payload)?;
312 match msg {
313 BackendMessage::AuthOk => Ok(StartupAction::AuthOk),
314 BackendMessage::AuthCleartext => Ok(StartupAction::AuthCleartext),
315 BackendMessage::AuthMd5 { salt } => Ok(StartupAction::AuthMd5(salt)),
316 BackendMessage::AuthSasl { mechanisms } => {
317 Ok(StartupAction::AuthSasl(mechanisms.to_vec()))
318 }
319 BackendMessage::ParameterStatus { name, value } => {
320 Ok(StartupAction::ParameterStatus(name.into(), value.into()))
321 }
322 BackendMessage::BackendKeyData { pid, secret } => {
323 Ok(StartupAction::BackendKeyData(pid, secret))
324 }
325 BackendMessage::ReadyForQuery { status } => Ok(StartupAction::ReadyForQuery(status)),
326 BackendMessage::ErrorResponse { data } => {
327 let fields = proto::parse_error_response(data);
328 Ok(StartupAction::Error(fields.to_string()))
329 }
330 BackendMessage::NoticeResponse { .. } => Ok(StartupAction::Notice),
331 other => Err(DriverError::Protocol(format!(
332 "unexpected message during startup: {other:?}"
333 ))),
334 }
335 }
336
337 fn handle_scram(&mut self, config: &Config, mechanisms_data: &[u8]) -> Result<(), DriverError> {
338 let mechs = auth::parse_sasl_mechanisms(mechanisms_data);
339
340 let use_plus = self.tls_server_cert_hash.is_some() && mechs.contains(&"SCRAM-SHA-256-PLUS");
343 let mechanism = if use_plus {
344 "SCRAM-SHA-256-PLUS"
345 } else {
346 "SCRAM-SHA-256"
347 };
348
349 if !mechs.contains(&mechanism) && !mechs.contains(&"SCRAM-SHA-256") {
350 return Err(DriverError::Auth(format!(
351 "server requires unsupported SASL mechanism(s): {mechs:?}"
352 )));
353 }
354
355 let cert_hash = if use_plus {
356 self.tls_server_cert_hash.as_ref()
357 } else {
358 None
359 };
360 let mut scram = auth::ScramClient::new(&config.user, &config.password, cert_hash)?;
361
362 let client_first = scram.client_first_message();
364 self.write_buf.clear();
365 proto::write_sasl_initial(&mut self.write_buf, mechanism, &client_first);
366 self.flush_write()?;
367
368 let (msg_type, _) = self.read_message_buffered()?;
370 let server_first = {
371 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
372 match msg {
373 BackendMessage::AuthSaslContinue { data } => data.to_vec(),
374 BackendMessage::ErrorResponse { data } => {
375 let fields = proto::parse_error_response(data);
376 return Err(DriverError::Auth(fields.to_string()));
377 }
378 other => {
379 return Err(DriverError::Protocol(format!(
380 "expected AuthSaslContinue, got: {other:?}"
381 )));
382 }
383 }
384 };
385
386 scram.process_server_first(&server_first)?;
387
388 let client_final = scram.client_final_message()?;
390 self.write_buf.clear();
391 proto::write_sasl_response(&mut self.write_buf, &client_final);
392 self.flush_write()?;
393
394 let (msg_type, _) = self.read_message_buffered()?;
396 {
397 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
398 match msg {
399 BackendMessage::AuthSaslFinal { data } => {
400 let data_owned = data.to_vec();
401 scram.verify_server_final(&data_owned)?;
402 }
403 BackendMessage::ErrorResponse { data } => {
404 let fields = proto::parse_error_response(data);
405 return Err(DriverError::Auth(fields.to_string()));
406 }
407 other => {
408 return Err(DriverError::Protocol(format!(
409 "expected AuthSaslFinal, got: {other:?}"
410 )));
411 }
412 }
413 }
414
415 let (msg_type, _) = self.read_message_buffered()?;
417 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
418 match msg {
419 BackendMessage::AuthOk => Ok(()),
420 BackendMessage::ErrorResponse { data } => {
421 let fields = proto::parse_error_response(data);
422 Err(DriverError::Auth(fields.to_string()))
423 }
424 other => Err(DriverError::Protocol(format!(
425 "expected AuthOk after SCRAM, got: {other:?}"
426 ))),
427 }
428 }
429
430 fn validate_server_params(&self) -> Result<(), DriverError> {
431 if let Some(encoding) = self.parameter("server_encoding") {
432 let normalized = encoding.to_uppercase();
433 if normalized != "UTF8" && normalized != "UTF-8" {
434 return Err(DriverError::Protocol(format!(
435 "server_encoding is '{encoding}', but bsql requires UTF-8."
436 )));
437 }
438 }
439 if let Some(encoding) = self.parameter("client_encoding") {
440 let normalized = encoding.to_uppercase();
441 if normalized != "UTF8" && normalized != "UTF-8" {
442 return Err(DriverError::Protocol(format!(
443 "client_encoding is '{encoding}', but bsql requires UTF-8."
444 )));
445 }
446 }
447 if let Some(idt) = self.parameter("integer_datetimes") {
448 if idt != "on" {
449 return Err(DriverError::Protocol(format!(
450 "integer_datetimes is '{idt}', but bsql requires 'on'."
451 )));
452 }
453 }
454 Ok(())
455 }
456
457 pub fn prepare_only(&mut self, sql: &str, sql_hash: u64) -> Result<(), DriverError> {
463 if self.stmts.contains_key(&sql_hash, sql) {
464 return Ok(());
465 }
466 let name = make_stmt_name(sql_hash);
467 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
468 self.write_buf.clear();
469 proto::write_parse(&mut self.write_buf, name_s, sql, &[]);
470 proto::write_describe(&mut self.write_buf, b'S', name_s);
471 proto::write_sync(&mut self.write_buf);
472 self.flush_write()?;
473
474 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
475 let columns = self.read_column_description()?;
476 self.expect_ready()?;
477
478 self.query_counter += 1;
479 self.cache_stmt(
480 sql_hash,
481 StmtInfo {
482 name,
483 sql: sql.into(),
484 columns,
485 last_used: self.query_counter,
486 bind_template: None,
487 },
488 );
489 Ok(())
490 }
491
    /// Runs `sql` with `params` and collects every row into a [`QueryResult`].
    ///
    /// The request is sent by `send_pipeline`; the response is then parsed
    /// directly out of `stream_buf`. Row bytes are appended to a pooled
    /// response buffer and the per-column `(offset, length)` pairs to
    /// `all_col_offsets`.
    ///
    /// # Errors
    /// - `DriverError::Protocol` for a message with an invalid length or an
    ///   unexpected message type.
    /// - The server's error (built by `make_server_error`) when an
    ///   `ErrorResponse` arrives; the connection is drained to
    ///   `ReadyForQuery` first.
    /// - Any error from sending, refilling the buffer or row parsing.
    ///
    /// # Panics
    /// If `send_pipeline` returns `None` although columns were requested.
    #[inline]
    pub fn query(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<QueryResult, DriverError> {
        // NOTE(review): the two booleans are presumably `need_columns` and a
        // flush/sync flag — confirm against `send_pipeline`.
        let columns = self
            .send_pipeline(sql, sql_hash, params, true, true)?
            .expect("send_pipeline(need_columns=true) must return Some");

        let num_cols = columns.len();
        // Pre-size for eight rows' worth of column offsets.
        let mut all_col_offsets: Vec<(usize, i32)> = Vec::with_capacity(num_cols.max(1) * 8);
        let mut affected_rows: u64 = 0;

        let mut resp_buf = acquire_resp_buf();
        // A pooled buffer may still hold bytes from its previous use.
        resp_buf.clear();

        // Outer loop: one iteration per socket refill.
        // Inner loop: one iteration per message parsed out of `stream_buf`.
        'outer: loop {
            loop {
                let avail = self.stream_buf_end - self.stream_buf_pos;
                // A header is 1 type byte + 4 length bytes.
                if avail < 5 {
                    break;
                }

                let msg_type = self.stream_buf[self.stream_buf_pos];
                let raw_len = i32::from_be_bytes([
                    self.stream_buf[self.stream_buf_pos + 1],
                    self.stream_buf[self.stream_buf_pos + 2],
                    self.stream_buf[self.stream_buf_pos + 3],
                    self.stream_buf[self.stream_buf_pos + 4],
                ]);

                // The length field counts itself, so anything below 4 is bogus.
                if raw_len < 4 {
                    return Err(DriverError::Protocol(format!(
                        "invalid message length {raw_len} for type '{}'",
                        msg_type as char
                    )));
                }

                let payload_len = (raw_len - 4) as usize;
                let total_msg_len = 5 + payload_len;

                if avail < total_msg_len {
                    // A message larger than the whole buffer can never be
                    // completed by refilling; read it through the general
                    // message reader instead.
                    if total_msg_len > self.stream_buf.len() {
                        let msg = self.read_one_message()?;
                        match msg {
                            BackendMessage::BindComplete => continue,
                            BackendMessage::DataRow { data } => {
                                parse_data_row_into_buf(data, &mut resp_buf, &mut all_col_offsets)?;
                                continue;
                            }
                            BackendMessage::CommandComplete { tag } => {
                                affected_rows = proto::parse_command_tag(tag);
                                continue;
                            }
                            BackendMessage::EmptyQuery => continue,
                            BackendMessage::ReadyForQuery { status } => {
                                self.tx_status = status;
                                break 'outer;
                            }
                            BackendMessage::NoticeResponse { .. } => continue,
                            BackendMessage::ErrorResponse { data } => {
                                let fields = proto::parse_error_response(data);
                                self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                                self.drain_to_ready()?;
                                return Err(self.make_server_error(fields));
                            }
                            other => {
                                return Err(DriverError::Protocol(format!(
                                    "unexpected message during query: {other:?}"
                                )));
                            }
                        }
                    }
                    // Incomplete message that fits the buffer: refill.
                    break;
                }

                let payload_start = self.stream_buf_pos + 5;
                let payload_end = payload_start + payload_len;

                if msg_type == b'D' {
                    // Data row: copy column bytes and record their offsets.
                    parse_data_row_into_buf(
                        &self.stream_buf[payload_start..payload_end],
                        &mut resp_buf,
                        &mut all_col_offsets,
                    )?;
                } else if msg_type == b'Z' {
                    // Ready for query: end of this response.
                    if payload_len >= 1 {
                        self.tx_status = self.stream_buf[payload_start];
                    }
                    self.stream_buf_pos += total_msg_len;
                    break 'outer;
                } else {
                    self.handle_non_datarow_query(
                        msg_type,
                        payload_start,
                        payload_end,
                        sql_hash,
                        &mut affected_rows,
                    )?;
                }

                self.stream_buf_pos += total_msg_len;
            }

            self.refill_stream_buf()?;
        }

        self.shrink_buffers();

        Ok(QueryResult::from_parts_with_buf(
            all_col_offsets,
            num_cols,
            columns,
            affected_rows,
            resp_buf,
        ))
    }
635
636 #[inline]
647 pub fn execute_monolithic(
648 &mut self,
649 sql: &str,
650 sql_hash: u64,
651 params: &[&(dyn Encode + Sync)],
652 ) -> Result<u64, DriverError> {
653 self.write_buf.clear();
655
656 let info = match self.stmts.get_mut(&sql_hash, sql) {
658 Some(info) => {
659 self.query_counter += 1;
660 info.last_used = self.query_counter;
661 info
662 }
663 None => {
664 return self.execute_with_prepare(sql, sql_hash, params);
666 }
667 };
668
669 let can_use_template = info
671 .bind_template
672 .as_ref()
673 .is_some_and(|t| t.param_slots.len() == params.len());
674
675 let mut has_exec_sync = false;
676
677 if can_use_template {
678 let tmpl = info
680 .bind_template
681 .as_ref()
682 .expect("guarded by can_use_template");
683 self.write_buf.extend_from_slice(&tmpl.bytes);
684
685 let mut template_ok = true;
686 for (i, param) in params.iter().enumerate() {
687 let (data_offset, old_len) = tmpl.param_slots[i];
688 if param.is_null() {
689 let len_offset = data_offset - 4;
690 self.write_buf[len_offset..len_offset + 4]
691 .copy_from_slice(&(-1i32).to_be_bytes());
692 } else if old_len >= 0 {
693 let end = data_offset + old_len as usize;
694 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
695 template_ok = false;
696 break;
697 }
698 } else {
699 template_ok = false;
701 break;
702 }
703 }
704
705 if template_ok {
706 has_exec_sync = true;
707 } else {
708 self.write_buf.clear();
709 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
710 info.bind_template = None;
711 }
712 } else {
713 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
714 }
715
716 if info.bind_template.is_none() && !self.write_buf.is_empty() {
718 info.bind_template = build_bind_template(&self.write_buf, params.len());
719 }
720
721 if !has_exec_sync {
722 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
723 }
724
725 self.stream
727 .write_all(&self.write_buf)
728 .map_err(DriverError::Io)?;
729
730 let mut affected_rows: u64 = 0;
732
733 'outer: loop {
734 loop {
735 let avail = self.stream_buf_end - self.stream_buf_pos;
736 if avail < 5 {
737 break; }
739
740 let msg_type = self.stream_buf[self.stream_buf_pos];
741 let raw_len = i32::from_be_bytes([
742 self.stream_buf[self.stream_buf_pos + 1],
743 self.stream_buf[self.stream_buf_pos + 2],
744 self.stream_buf[self.stream_buf_pos + 3],
745 self.stream_buf[self.stream_buf_pos + 4],
746 ]);
747
748 if raw_len < 4 {
749 return Err(DriverError::Protocol(format!(
750 "invalid message length {raw_len} for type '{}'",
751 msg_type as char
752 )));
753 }
754
755 let payload_len = (raw_len - 4) as usize;
756 let total_msg_len = 5 + payload_len;
757
758 if avail < total_msg_len {
759 if total_msg_len > self.stream_buf.len() {
760 let msg = self.read_one_message()?;
761 match msg {
762 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
763 continue;
764 }
765 BackendMessage::CommandComplete { tag } => {
766 affected_rows = proto::parse_command_tag(tag);
767 continue;
768 }
769 BackendMessage::EmptyQuery => continue,
770 BackendMessage::ReadyForQuery { status } => {
771 self.tx_status = status;
772 break 'outer;
773 }
774 BackendMessage::NoticeResponse { .. } => continue,
775 BackendMessage::ErrorResponse { data } => {
776 let fields = proto::parse_error_response(data);
777 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
778 self.drain_to_ready()?;
779 return Err(self.make_server_error(fields));
780 }
781 other => {
782 return Err(DriverError::Protocol(format!(
783 "unexpected message during execute: {other:?}"
784 )));
785 }
786 }
787 }
788 break; }
790
791 let payload_start = self.stream_buf_pos + 5;
796 let payload_end = payload_start + payload_len;
797
798 if msg_type == b'2' {
799 self.stream_buf_pos += total_msg_len;
801 continue;
802 } else if msg_type == b'C' {
803 affected_rows = proto::parse_command_tag_bytes(
805 &self.stream_buf[payload_start..payload_end],
806 );
807 } else if msg_type == b'Z' {
808 if payload_len >= 1 {
810 self.tx_status = self.stream_buf[payload_start];
811 }
812 self.stream_buf_pos += total_msg_len;
813 break 'outer;
814 } else if msg_type == b'D' || msg_type == b'I' {
815 } else {
817 self.handle_non_datarow_execute(
818 msg_type,
819 payload_start,
820 payload_end,
821 sql_hash,
822 )?;
823 }
824
825 self.stream_buf_pos += total_msg_len;
826 }
827
828 let remaining = self.stream_buf_end - self.stream_buf_pos;
830 debug_assert!(
831 remaining == 0 || self.stream_buf_pos > 0,
832 "compact called with pos=0 and remaining data"
833 );
834 if remaining > 0 {
835 self.stream_buf
836 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
837 }
838 self.stream_buf_pos = 0;
839 self.stream_buf_end = remaining;
840 let n = self
841 .stream
842 .read(&mut self.stream_buf[remaining..])
843 .map_err(DriverError::Io)?;
844 if n == 0 {
845 return Err(DriverError::Io(std::io::Error::new(
846 std::io::ErrorKind::UnexpectedEof,
847 "connection closed",
848 )));
849 }
850 self.stream_buf_end = remaining + n;
851 }
852
853 if self.query_counter & 63 == 0 {
855 if self.read_buf.capacity() > 64 * 1024 {
856 self.read_buf.clear();
857 self.read_buf.shrink_to(8192);
858 }
859 if self.write_buf.capacity() > 16 * 1024 {
860 self.write_buf.clear();
861 self.write_buf.shrink_to(8192);
862 }
863 }
864
865 Ok(affected_rows)
866 }
867
868 #[cold]
870 #[inline(never)]
871 fn execute_with_prepare(
872 &mut self,
873 sql: &str,
874 sql_hash: u64,
875 params: &[&(dyn Encode + Sync)],
876 ) -> Result<u64, DriverError> {
877 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
878
879 if params.len() > i16::MAX as usize {
880 return Err(DriverError::Protocol(format!(
881 "parameter count {} exceeds maximum {}",
882 params.len(),
883 i16::MAX
884 )));
885 }
886
887 let name = make_stmt_name(sql_hash);
888 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
889 let param_oids: smallvec::SmallVec<[u32; 8]> =
890 params.iter().map(|p| p.type_oid()).collect();
891
892 self.write_buf.clear();
893 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
894 proto::write_describe(&mut self.write_buf, b'S', name_s);
895 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
896 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
897 self.stream
898 .write_all(&self.write_buf)
899 .map_err(DriverError::Io)?;
900
901 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
902 let columns = self.read_column_description()?;
903 self.query_counter += 1;
904 self.cache_stmt(
905 sql_hash,
906 StmtInfo {
907 name,
908 sql: sql.into(),
909 columns,
910 last_used: self.query_counter,
911 bind_template: None,
912 },
913 );
914
915 let mut affected_rows: u64 = 0;
917 'outer: loop {
918 loop {
919 let avail = self.stream_buf_end - self.stream_buf_pos;
920 if avail < 5 {
921 break;
922 }
923
924 let msg_type = self.stream_buf[self.stream_buf_pos];
925 let raw_len = i32::from_be_bytes([
926 self.stream_buf[self.stream_buf_pos + 1],
927 self.stream_buf[self.stream_buf_pos + 2],
928 self.stream_buf[self.stream_buf_pos + 3],
929 self.stream_buf[self.stream_buf_pos + 4],
930 ]);
931
932 if raw_len < 4 {
933 return Err(DriverError::Protocol(format!(
934 "invalid message length {raw_len} for type '{}'",
935 msg_type as char
936 )));
937 }
938
939 let payload_len = (raw_len - 4) as usize;
940 let total_msg_len = 5 + payload_len;
941
942 if avail < total_msg_len {
943 if total_msg_len > self.stream_buf.len() {
944 let msg = self.read_one_message()?;
945 match msg {
946 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
947 continue;
948 }
949 BackendMessage::CommandComplete { tag } => {
950 affected_rows = proto::parse_command_tag(tag);
951 continue;
952 }
953 BackendMessage::EmptyQuery => continue,
954 BackendMessage::ReadyForQuery { status } => {
955 self.tx_status = status;
956 break 'outer;
957 }
958 BackendMessage::NoticeResponse { .. } => continue,
959 BackendMessage::ErrorResponse { data } => {
960 let fields = proto::parse_error_response(data);
961 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
962 self.drain_to_ready()?;
963 return Err(self.make_server_error(fields));
964 }
965 other => {
966 return Err(DriverError::Protocol(format!(
967 "unexpected message during execute: {other:?}"
968 )));
969 }
970 }
971 }
972 break;
973 }
974
975 let payload_start = self.stream_buf_pos + 5;
976 let payload_end = payload_start + payload_len;
977
978 if msg_type == b'2' || msg_type == b'D' || msg_type == b'I' {
979 } else if msg_type == b'C' {
981 affected_rows = proto::parse_command_tag_bytes(
982 &self.stream_buf[payload_start..payload_end],
983 );
984 } else if msg_type == b'Z' {
985 if payload_len >= 1 {
986 self.tx_status = self.stream_buf[payload_start];
987 }
988 self.stream_buf_pos += total_msg_len;
989 break 'outer;
990 } else {
991 self.handle_non_datarow_execute(
992 msg_type,
993 payload_start,
994 payload_end,
995 sql_hash,
996 )?;
997 }
998
999 self.stream_buf_pos += total_msg_len;
1000 }
1001
1002 self.refill_stream_buf()?;
1003 }
1004
1005 Ok(affected_rows)
1006 }
1007
1008 #[inline]
1013 pub fn execute(
1014 &mut self,
1015 sql: &str,
1016 sql_hash: u64,
1017 params: &[&(dyn Encode + Sync)],
1018 ) -> Result<u64, DriverError> {
1019 self.execute_monolithic(sql, sql_hash, params)
1020 }
1021
1022 pub fn execute_pipeline(
1034 &mut self,
1035 sql: &str,
1036 sql_hash: u64,
1037 param_sets: &[&[&(dyn Encode + Sync)]],
1038 ) -> Result<Vec<u64>, DriverError> {
1039 if param_sets.is_empty() {
1040 return Ok(Vec::new());
1041 }
1042
1043 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1044
1045 self.write_buf.clear();
1046
1047 if !self.stmts.contains_key(&sql_hash, sql) {
1049 let name = make_stmt_name(sql_hash);
1050 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
1051 let first_params = param_sets[0];
1052 if first_params.len() > i16::MAX as usize {
1053 return Err(DriverError::Protocol(format!(
1054 "parameter count {} exceeds maximum {}",
1055 first_params.len(),
1056 i16::MAX
1057 )));
1058 }
1059 let param_oids: smallvec::SmallVec<[u32; 8]> =
1060 first_params.iter().map(|p| p.type_oid()).collect();
1061 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
1062 proto::write_describe(&mut self.write_buf, b'S', name_s);
1063 proto::write_sync(&mut self.write_buf);
1064 self.flush_write()?;
1065
1066 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1067 let columns = self.read_column_description()?;
1068 self.expect_ready()?;
1069
1070 self.query_counter += 1;
1071 self.cache_stmt(
1072 sql_hash,
1073 StmtInfo {
1074 name,
1075 sql: sql.into(),
1076 columns,
1077 last_used: self.query_counter,
1078 bind_template: None,
1079 },
1080 );
1081
1082 self.write_buf.clear();
1083 }
1084
1085 let stmt_name = self
1087 .stmts
1088 .get(&sql_hash, sql)
1089 .expect("BUG: stmt just cached but not found")
1090 .name_str()
1091 .to_owned();
1092 let count = param_sets.len();
1093
1094 for params in param_sets {
1095 if params.len() > i16::MAX as usize {
1096 return Err(DriverError::Protocol(format!(
1097 "parameter count {} exceeds maximum {}",
1098 params.len(),
1099 i16::MAX
1100 )));
1101 }
1102 proto::write_bind_params(&mut self.write_buf, "", &stmt_name, params);
1103 self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
1104 }
1105
1106 self.write_buf.extend_from_slice(proto::SYNC_ONLY);
1107 self.flush_write()?;
1108
1109 let mut results = Vec::with_capacity(count);
1111 for _ in 0..count {
1112 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1113
1114 let mut affected_rows: u64 = 0;
1115 loop {
1116 let msg = self.read_one_message()?;
1117 match msg {
1118 BackendMessage::DataRow { .. } => {}
1119 BackendMessage::CommandComplete { tag } => {
1120 affected_rows = proto::parse_command_tag(tag);
1121 break;
1122 }
1123 BackendMessage::EmptyQuery => break,
1124 BackendMessage::NoticeResponse { .. } => {}
1125 BackendMessage::ErrorResponse { data } => {
1126 let fields = proto::parse_error_response(data);
1127 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1128 self.drain_to_ready()?;
1129 return Err(self.make_server_error(fields));
1130 }
1131 other => {
1132 return Err(DriverError::Protocol(format!(
1133 "unexpected message during execute_pipeline: {other:?}"
1134 )));
1135 }
1136 }
1137 }
1138 results.push(affected_rows);
1139 }
1140
1141 self.expect_ready()?;
1142 self.shrink_buffers();
1143 Ok(results)
1144 }
1145
1146 pub(crate) fn ensure_stmt_prepared(
1152 &mut self,
1153 sql: &str,
1154 sql_hash: u64,
1155 params: &[&(dyn Encode + Sync)],
1156 ) -> Result<[u8; 18], DriverError> {
1157 if let Some(info) = self.stmts.get(&sql_hash, sql) {
1158 return Ok(info.name);
1159 }
1160
1161 let name = make_stmt_name(sql_hash);
1162 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
1163 if params.len() > i16::MAX as usize {
1164 return Err(DriverError::Protocol(format!(
1165 "parameter count {} exceeds maximum {}",
1166 params.len(),
1167 i16::MAX
1168 )));
1169 }
1170 let param_oids: smallvec::SmallVec<[u32; 8]> =
1171 params.iter().map(|p| p.type_oid()).collect();
1172
1173 self.write_buf.clear();
1174 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
1175 proto::write_describe(&mut self.write_buf, b'S', name_s);
1176 proto::write_sync(&mut self.write_buf);
1177 self.flush_write()?;
1178
1179 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1180 let columns = self.read_column_description()?;
1181 self.expect_ready()?;
1182
1183 self.query_counter += 1;
1184 self.cache_stmt(
1185 sql_hash,
1186 StmtInfo {
1187 name,
1188 sql: sql.into(),
1189 columns,
1190 last_used: self.query_counter,
1191 bind_template: None,
1192 },
1193 );
1194
1195 Ok(name)
1196 }
1197
1198 pub(crate) fn write_deferred_bind_execute(
1201 &self,
1202 sql: &str,
1203 sql_hash: u64,
1204 params: &[&(dyn Encode + Sync)],
1205 buf: &mut Vec<u8>,
1206 ) {
1207 let stmt_name = self
1208 .stmts
1209 .get(&sql_hash, sql)
1210 .expect("BUG: stmt just cached but not found")
1211 .name_str();
1212 proto::write_bind_params(buf, "", stmt_name, params);
1213 buf.extend_from_slice(proto::EXECUTE_ONLY);
1214 }
1215
1216 pub(crate) fn flush_deferred_pipeline(
1221 &mut self,
1222 buf: &mut Vec<u8>,
1223 count: usize,
1224 ) -> Result<Vec<u64>, DriverError> {
1225 if count == 0 {
1226 buf.clear();
1227 return Ok(Vec::new());
1228 }
1229
1230 buf.extend_from_slice(proto::SYNC_ONLY);
1231
1232 self.stream.write_all(buf).map_err(DriverError::Io)?;
1233 buf.clear();
1234
1235 let mut results = Vec::with_capacity(count);
1236 for _ in 0..count {
1237 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1238
1239 let mut affected_rows: u64 = 0;
1240 loop {
1241 let msg = self.read_one_message()?;
1242 match msg {
1243 BackendMessage::DataRow { .. } => {}
1244 BackendMessage::CommandComplete { tag } => {
1245 affected_rows = proto::parse_command_tag(tag);
1246 break;
1247 }
1248 BackendMessage::EmptyQuery => break,
1249 BackendMessage::NoticeResponse { .. } => {}
1250 BackendMessage::ErrorResponse { data } => {
1251 let fields = proto::parse_error_response(data);
1252 self.drain_to_ready()?;
1253 return Err(self.make_server_error(fields));
1254 }
1255 other => {
1256 return Err(DriverError::Protocol(format!(
1257 "unexpected message during flush_deferred_pipeline: {other:?}"
1258 )));
1259 }
1260 }
1261 }
1262 results.push(affected_rows);
1263 }
1264
1265 self.expect_ready()?;
1266 self.shrink_buffers();
1267 Ok(results)
1268 }
1269
    /// Runs `sql` with `params` and calls `f` once per result row, without
    /// collecting the rows.
    ///
    /// Each row handed to `f` borrows the connection's buffers and is only
    /// valid for the duration of that call.
    ///
    /// # Errors
    /// - Any error returned by `f` or by `PgDataRow::new`.
    /// - `DriverError::Protocol` for a message with an invalid length or an
    ///   unexpected message type.
    /// - The server's error (built by `make_server_error`) when an
    ///   `ErrorResponse` arrives; the connection is drained to
    ///   `ReadyForQuery` first.
    /// - Any error from sending or from refilling the buffer.
    ///
    /// NOTE(review): when `f` fails, this returns immediately without reading
    /// the rest of the response. The unread messages appear to stay on the
    /// connection — confirm that callers discard or drain it afterwards.
    pub fn for_each<F>(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        mut f: F,
    ) -> Result<(), DriverError>
    where
        F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
    {
        // The return value is discarded because no column metadata is needed.
        // NOTE(review): the two booleans are presumably `need_columns` and a
        // flush/sync flag — confirm against `send_pipeline`.
        let _ = self.send_pipeline(sql, sql_hash, params, false, true)?;

        // Outer loop: one iteration per socket refill.
        // Inner loop: one iteration per message parsed out of `stream_buf`.
        'outer: loop {
            loop {
                let avail = self.stream_buf_end - self.stream_buf_pos;
                // A header is 1 type byte + 4 length bytes.
                if avail < 5 {
                    break;
                }

                let msg_type = self.stream_buf[self.stream_buf_pos];
                let raw_len = i32::from_be_bytes([
                    self.stream_buf[self.stream_buf_pos + 1],
                    self.stream_buf[self.stream_buf_pos + 2],
                    self.stream_buf[self.stream_buf_pos + 3],
                    self.stream_buf[self.stream_buf_pos + 4],
                ]);

                // The length field counts itself, so anything below 4 is bogus.
                if raw_len < 4 {
                    return Err(DriverError::Protocol(format!(
                        "invalid message length {raw_len} for type '{}'",
                        msg_type as char
                    )));
                }

                let payload_len = (raw_len - 4) as usize;
                let total_msg_len = 5 + payload_len;

                if avail < total_msg_len {
                    // A message larger than the whole buffer can never be
                    // completed by refilling; use the general reader.
                    if total_msg_len > self.stream_buf.len() {
                        let msg = self.read_one_message()?;
                        match msg {
                            BackendMessage::BindComplete => continue,
                            BackendMessage::DataRow { data } => {
                                let row = PgDataRow::new(data)?;
                                f(row)?;
                                continue;
                            }
                            BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
                                continue;
                            }
                            BackendMessage::ReadyForQuery { status } => {
                                self.tx_status = status;
                                break 'outer;
                            }
                            BackendMessage::NoticeResponse { .. } => continue,
                            BackendMessage::ErrorResponse { data } => {
                                let fields = proto::parse_error_response(data);
                                self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                                self.drain_to_ready()?;
                                return Err(self.make_server_error(fields));
                            }
                            other => {
                                return Err(DriverError::Protocol(format!(
                                    "unexpected message during for_each: {other:?}"
                                )));
                            }
                        }
                    }
                    // Incomplete message that fits the buffer: refill.
                    break;
                }

                let payload_start = self.stream_buf_pos + 5;
                let payload_end = payload_start + payload_len;

                if msg_type == b'D' {
                    // Data row: hand a view over the buffered bytes to `f`.
                    let row = PgDataRow::new(&self.stream_buf[payload_start..payload_end])?;
                    f(row)?;
                } else if msg_type == b'Z' {
                    // Ready for query: end of this response.
                    if payload_len >= 1 {
                        self.tx_status = self.stream_buf[payload_start];
                    }
                    self.stream_buf_pos += total_msg_len;
                    break 'outer;
                } else {
                    self.handle_non_datarow_execute(
                        msg_type,
                        payload_start,
                        payload_end,
                        sql_hash,
                    )?;
                }

                self.stream_buf_pos += total_msg_len;
            }

            self.refill_stream_buf()?;
        }

        self.shrink_buffers();
        Ok(())
    }
1380
1381 #[inline]
1392 pub fn for_each_raw_monolithic<F>(
1393 &mut self,
1394 sql: &str,
1395 sql_hash: u64,
1396 params: &[&(dyn Encode + Sync)],
1397 mut f: F,
1398 ) -> Result<(), DriverError>
1399 where
1400 F: FnMut(&[u8]) -> Result<(), DriverError>,
1401 {
1402 self.write_buf.clear();
1404
1405 let info = match self.stmts.get_mut(&sql_hash, sql) {
1407 Some(info) => {
1408 self.query_counter += 1;
1409 info.last_used = self.query_counter;
1410 info
1411 }
1412 None => {
1413 return self.for_each_raw_with_prepare(sql, sql_hash, params, f);
1415 }
1416 };
1417
1418 let can_use_template = info
1420 .bind_template
1421 .as_ref()
1422 .is_some_and(|t| t.param_slots.len() == params.len());
1423
1424 let mut has_exec_sync = false;
1425
1426 if can_use_template {
1427 let tmpl = info
1429 .bind_template
1430 .as_ref()
1431 .expect("guarded by can_use_template");
1432 self.write_buf.extend_from_slice(&tmpl.bytes);
1433
1434 let mut template_ok = true;
1435 for (i, param) in params.iter().enumerate() {
1436 let (data_offset, old_len) = tmpl.param_slots[i];
1437 if param.is_null() {
1438 let len_offset = data_offset - 4;
1439 self.write_buf[len_offset..len_offset + 4]
1440 .copy_from_slice(&(-1i32).to_be_bytes());
1441 } else if old_len >= 0 {
1442 let end = data_offset + old_len as usize;
1443 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
1444 template_ok = false;
1445 break;
1446 }
1447 } else {
1448 template_ok = false;
1449 break;
1450 }
1451 }
1452
1453 if template_ok {
1454 has_exec_sync = true;
1455 } else {
1456 self.write_buf.clear();
1457 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
1458 info.bind_template = None;
1459 }
1460 } else {
1461 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
1462 }
1463
1464 if info.bind_template.is_none() && !self.write_buf.is_empty() {
1466 info.bind_template = build_bind_template(&self.write_buf, params.len());
1467 }
1468
1469 if !has_exec_sync {
1470 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1471 }
1472
1473 self.stream
1475 .write_all(&self.write_buf)
1476 .map_err(DriverError::Io)?;
1477
1478 loop {
1482 let avail = self.stream_buf_end - self.stream_buf_pos;
1483 if avail >= 5 {
1484 let bc_type = self.stream_buf[self.stream_buf_pos];
1485 match bc_type {
1486 b'2' => {
1487 self.stream_buf_pos += 5;
1488 break;
1489 }
1490 b'E' => {
1491 let msg = self.read_one_message()?;
1492 if let BackendMessage::ErrorResponse { data } = msg {
1493 let fields = proto::parse_error_response(data);
1494 self.drain_to_ready()?;
1495 return Err(self.make_server_error(fields));
1496 }
1497 }
1498 b'N' | b'S' => {
1499 let raw_len = i32::from_be_bytes([
1500 self.stream_buf[self.stream_buf_pos + 1],
1501 self.stream_buf[self.stream_buf_pos + 2],
1502 self.stream_buf[self.stream_buf_pos + 3],
1503 self.stream_buf[self.stream_buf_pos + 4],
1504 ]);
1505 let total = 1 + raw_len as usize;
1506 if avail >= total {
1507 self.stream_buf_pos += total;
1508 continue;
1509 }
1510 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1511 break;
1512 }
1513 _ => {
1514 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1515 break;
1516 }
1517 }
1518 } else {
1519 let remaining = self.stream_buf_end - self.stream_buf_pos;
1521 if remaining > 0 && self.stream_buf_pos > 0 {
1522 self.stream_buf
1523 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1524 }
1525 self.stream_buf_pos = 0;
1526 self.stream_buf_end = remaining;
1527 let n = self
1528 .stream
1529 .read(&mut self.stream_buf[remaining..])
1530 .map_err(DriverError::Io)?;
1531 if n == 0 {
1532 return Err(DriverError::Io(std::io::Error::new(
1533 std::io::ErrorKind::UnexpectedEof,
1534 "connection closed",
1535 )));
1536 }
1537 self.stream_buf_end = remaining + n;
1538 }
1539 }
1540
1541 'outer: loop {
1543 loop {
1544 let avail = self.stream_buf_end - self.stream_buf_pos;
1545 if avail < 5 {
1546 break;
1547 }
1548
1549 let msg_type = self.stream_buf[self.stream_buf_pos];
1550 let raw_len = i32::from_be_bytes([
1551 self.stream_buf[self.stream_buf_pos + 1],
1552 self.stream_buf[self.stream_buf_pos + 2],
1553 self.stream_buf[self.stream_buf_pos + 3],
1554 self.stream_buf[self.stream_buf_pos + 4],
1555 ]);
1556
1557 if raw_len < 4 {
1558 return Err(DriverError::Protocol(format!(
1559 "invalid message length {raw_len} for type '{}'",
1560 msg_type as char
1561 )));
1562 }
1563
1564 let payload_len = (raw_len - 4) as usize;
1565 let total_msg_len = 5 + payload_len;
1566
1567 if avail < total_msg_len {
1568 if total_msg_len > self.stream_buf.len() {
1569 let msg = self.read_one_message()?;
1570 match msg {
1571 BackendMessage::DataRow { data } => {
1572 f(data)?;
1573 continue;
1574 }
1575 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1576 continue;
1577 }
1578 BackendMessage::ReadyForQuery { status } => {
1579 self.tx_status = status;
1580 break 'outer;
1581 }
1582 BackendMessage::ErrorResponse { data } => {
1583 let fields = proto::parse_error_response(data);
1584 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1585 self.drain_to_ready()?;
1586 return Err(self.make_server_error(fields));
1587 }
1588 BackendMessage::NoticeResponse { .. } => continue,
1589 other => {
1590 return Err(DriverError::Protocol(format!(
1591 "unexpected message during for_each_raw: {other:?}"
1592 )));
1593 }
1594 }
1595 }
1596 break; }
1598
1599 let payload_start = self.stream_buf_pos + 5;
1601 let payload_end = payload_start + payload_len;
1602
1603 if msg_type == b'D' {
1604 f(&self.stream_buf[payload_start..payload_end])?;
1605 } else if msg_type == b'Z' {
1606 if payload_len >= 1 {
1607 self.tx_status = self.stream_buf[payload_start];
1608 }
1609 self.stream_buf_pos += total_msg_len;
1610 break 'outer;
1611 } else {
1612 self.handle_non_datarow_execute(
1613 msg_type,
1614 payload_start,
1615 payload_end,
1616 sql_hash,
1617 )?;
1618 }
1619
1620 self.stream_buf_pos += total_msg_len;
1621 }
1622
1623 let remaining = self.stream_buf_end - self.stream_buf_pos;
1625 if remaining > 0 && self.stream_buf_pos > 0 {
1626 self.stream_buf
1627 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1628 }
1629 self.stream_buf_pos = 0;
1630 self.stream_buf_end = remaining;
1631 let n = self
1632 .stream
1633 .read(&mut self.stream_buf[remaining..])
1634 .map_err(DriverError::Io)?;
1635 if n == 0 {
1636 return Err(DriverError::Io(std::io::Error::new(
1637 std::io::ErrorKind::UnexpectedEof,
1638 "connection closed",
1639 )));
1640 }
1641 self.stream_buf_end = remaining + n;
1642 }
1643
1644 if self.query_counter & 63 == 0 {
1646 if self.read_buf.capacity() > 64 * 1024 {
1647 self.read_buf.clear();
1648 self.read_buf.shrink_to(8192);
1649 }
1650 if self.write_buf.capacity() > 16 * 1024 {
1651 self.write_buf.clear();
1652 self.write_buf.shrink_to(8192);
1653 }
1654 }
1655
1656 Ok(())
1657 }
1658
1659 #[cold]
1661 #[inline(never)]
1662 fn for_each_raw_with_prepare<F>(
1663 &mut self,
1664 sql: &str,
1665 sql_hash: u64,
1666 params: &[&(dyn Encode + Sync)],
1667 mut f: F,
1668 ) -> Result<(), DriverError>
1669 where
1670 F: FnMut(&[u8]) -> Result<(), DriverError>,
1671 {
1672 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1673
1674 if params.len() > i16::MAX as usize {
1675 return Err(DriverError::Protocol(format!(
1676 "parameter count {} exceeds maximum {}",
1677 params.len(),
1678 i16::MAX
1679 )));
1680 }
1681
1682 let name = make_stmt_name(sql_hash);
1683 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
1684 let param_oids: smallvec::SmallVec<[u32; 8]> =
1685 params.iter().map(|p| p.type_oid()).collect();
1686
1687 self.write_buf.clear();
1688 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
1689 proto::write_describe(&mut self.write_buf, b'S', name_s);
1690 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
1691 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1692 self.stream
1693 .write_all(&self.write_buf)
1694 .map_err(DriverError::Io)?;
1695
1696 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1697 let columns = self.read_column_description()?;
1698 self.query_counter += 1;
1699 self.cache_stmt(
1700 sql_hash,
1701 StmtInfo {
1702 name,
1703 sql: sql.into(),
1704 columns,
1705 last_used: self.query_counter,
1706 bind_template: None,
1707 },
1708 );
1709
1710 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1712
1713 'outer: loop {
1714 loop {
1715 let avail = self.stream_buf_end - self.stream_buf_pos;
1716 if avail < 5 {
1717 break;
1718 }
1719
1720 let msg_type = self.stream_buf[self.stream_buf_pos];
1721 let raw_len = i32::from_be_bytes([
1722 self.stream_buf[self.stream_buf_pos + 1],
1723 self.stream_buf[self.stream_buf_pos + 2],
1724 self.stream_buf[self.stream_buf_pos + 3],
1725 self.stream_buf[self.stream_buf_pos + 4],
1726 ]);
1727
1728 if raw_len < 4 {
1729 return Err(DriverError::Protocol(format!(
1730 "invalid message length {raw_len} for type '{}'",
1731 msg_type as char
1732 )));
1733 }
1734
1735 let payload_len = (raw_len - 4) as usize;
1736 let total_msg_len = 5 + payload_len;
1737
1738 if avail < total_msg_len {
1739 if total_msg_len > self.stream_buf.len() {
1740 let msg = self.read_one_message()?;
1741 match msg {
1742 BackendMessage::DataRow { data } => {
1743 f(data)?;
1744 continue;
1745 }
1746 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1747 continue;
1748 }
1749 BackendMessage::ReadyForQuery { status } => {
1750 self.tx_status = status;
1751 break 'outer;
1752 }
1753 BackendMessage::ErrorResponse { data } => {
1754 let fields = proto::parse_error_response(data);
1755 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1756 self.drain_to_ready()?;
1757 return Err(self.make_server_error(fields));
1758 }
1759 BackendMessage::NoticeResponse { .. } => continue,
1760 other => {
1761 return Err(DriverError::Protocol(format!(
1762 "unexpected message during for_each_raw: {other:?}"
1763 )));
1764 }
1765 }
1766 }
1767 break;
1768 }
1769
1770 let payload_start = self.stream_buf_pos + 5;
1771 let payload_end = payload_start + payload_len;
1772
1773 if msg_type == b'D' {
1774 f(&self.stream_buf[payload_start..payload_end])?;
1775 } else if msg_type == b'Z' {
1776 if payload_len >= 1 {
1777 self.tx_status = self.stream_buf[payload_start];
1778 }
1779 self.stream_buf_pos += total_msg_len;
1780 break 'outer;
1781 } else {
1782 self.handle_non_datarow_execute(
1783 msg_type,
1784 payload_start,
1785 payload_end,
1786 sql_hash,
1787 )?;
1788 }
1789
1790 self.stream_buf_pos += total_msg_len;
1791 }
1792
1793 self.refill_stream_buf()?;
1794 }
1795
1796 self.shrink_buffers();
1797 Ok(())
1798 }
1799
    /// Streams the raw `DataRow` payloads produced by `sql` to `f`.
    ///
    /// Thin public entry point: every argument is forwarded unchanged to
    /// `for_each_raw_monolithic`, whose result is returned as-is.
    #[inline]
    pub fn for_each_raw<F>(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        f: F,
    ) -> Result<(), DriverError>
    where
        F: FnMut(&[u8]) -> Result<(), DriverError>,
    {
        self.for_each_raw_monolithic(sql, sql_hash, params, f)
    }
1817
1818 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1820 self.write_buf.clear();
1821 proto::write_simple_query(&mut self.write_buf, sql);
1822 self.flush_write()?;
1823
1824 loop {
1825 let msg = self.read_one_message()?;
1826 match msg {
1827 BackendMessage::ReadyForQuery { status } => {
1828 self.tx_status = status;
1829 return Ok(());
1830 }
1831 BackendMessage::CommandComplete { .. }
1832 | BackendMessage::RowDescription { .. }
1833 | BackendMessage::DataRow { .. }
1834 | BackendMessage::EmptyQuery
1835 | BackendMessage::NoticeResponse { .. }
1836 | BackendMessage::ParameterStatus { .. }
1837 | BackendMessage::AuthOk
1841 | BackendMessage::AuthSaslFinal { .. }
1842 | BackendMessage::BackendKeyData { .. } => {}
1843 BackendMessage::ErrorResponse { data } => {
1844 let fields = proto::parse_error_response(data);
1845 self.drain_to_ready()?;
1846 return Err(self.make_server_error(fields));
1847 }
1848 other => {
1849 return Err(DriverError::Protocol(format!(
1850 "unexpected message during simple_query: {other:?}"
1851 )));
1852 }
1853 }
1854 }
1855 }
1856
1857 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
1859 self.write_buf.clear();
1860 proto::write_simple_query(&mut self.write_buf, sql);
1861 self.flush_write()?;
1862
1863 let mut rows: Vec<SimpleRow> = Vec::new();
1864 loop {
1865 let msg = self.read_one_message()?;
1866 match msg {
1867 BackendMessage::ReadyForQuery { status } => {
1868 self.tx_status = status;
1869 return Ok(rows);
1870 }
1871 BackendMessage::DataRow { data } => {
1872 rows.push(proto::parse_simple_data_row(data)?);
1873 }
1874 BackendMessage::RowDescription { .. }
1875 | BackendMessage::CommandComplete { .. }
1876 | BackendMessage::EmptyQuery
1877 | BackendMessage::NoticeResponse { .. }
1878 | BackendMessage::ParameterStatus { .. }
1879 | BackendMessage::AuthOk
1880 | BackendMessage::AuthSaslFinal { .. }
1881 | BackendMessage::BackendKeyData { .. } => {}
1882 BackendMessage::ErrorResponse { data } => {
1883 let fields = proto::parse_error_response(data);
1884 self.drain_to_ready()?;
1885 return Err(self.make_server_error(fields));
1886 }
1887 other => {
1888 return Err(DriverError::Protocol(format!(
1889 "unexpected message during simple_query_rows: {other:?}"
1890 )));
1891 }
1892 }
1893 }
1894 }
1895
1896 pub fn copy_in<'a, I>(
1918 &mut self,
1919 table: &str,
1920 columns: &[&str],
1921 rows: I,
1922 ) -> Result<u64, DriverError>
1923 where
1924 I: IntoIterator<Item = &'a str>,
1925 {
1926 let quoted_table = proto::quote_ident(table);
1928 let quoted_cols: Vec<String> = columns.iter().map(|c| proto::quote_ident(c)).collect();
1929 let sql = format!(
1930 "COPY {}({}) FROM STDIN",
1931 quoted_table,
1932 quoted_cols.join(",")
1933 );
1934
1935 self.write_buf.clear();
1937 proto::write_simple_query(&mut self.write_buf, &sql);
1938 self.flush_write()?;
1939
1940 loop {
1942 let msg = self.read_one_message()?;
1943 match msg {
1944 BackendMessage::CopyInResponse { .. } => break,
1945 BackendMessage::ErrorResponse { data } => {
1946 let fields = proto::parse_error_response(data);
1947 self.drain_to_ready()?;
1948 return Err(self.make_server_error(fields));
1949 }
1950 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
1951 other => {
1952 return Err(DriverError::Protocol(format!(
1953 "expected CopyInResponse, got: {other:?}"
1954 )));
1955 }
1956 }
1957 }
1958
1959 for row in rows {
1966 self.write_buf.clear();
1967 let mut row_bytes = Vec::with_capacity(row.len() + 1);
1969 row_bytes.extend_from_slice(row.as_bytes());
1970 row_bytes.push(b'\n');
1971 proto::write_copy_data(&mut self.write_buf, &row_bytes);
1972 self.flush_write()?;
1973 }
1974
1975 self.write_buf.clear();
1977 proto::write_copy_done(&mut self.write_buf);
1978 self.flush_write()?;
1979
1980 let mut count: u64 = 0;
1982 loop {
1983 let msg = self.read_one_message()?;
1984 match msg {
1985 BackendMessage::CommandComplete { tag } => {
1986 count = proto::parse_command_tag(tag);
1987 }
1988 BackendMessage::ReadyForQuery { status } => {
1989 self.tx_status = status;
1990 return Ok(count);
1991 }
1992 BackendMessage::ErrorResponse { data } => {
1993 let fields = proto::parse_error_response(data);
1994 self.drain_to_ready()?;
1995 return Err(self.make_server_error(fields));
1996 }
1997 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
1998 other => {
1999 return Err(DriverError::Protocol(format!(
2000 "unexpected message during copy_in completion: {other:?}"
2001 )));
2002 }
2003 }
2004 }
2005 }
2006
2007 pub fn copy_out<W: std::io::Write>(
2027 &mut self,
2028 query: &str,
2029 writer: &mut W,
2030 ) -> Result<u64, DriverError> {
2031 let sql = format!("COPY ({query}) TO STDOUT");
2033
2034 self.write_buf.clear();
2036 proto::write_simple_query(&mut self.write_buf, &sql);
2037 self.flush_write()?;
2038
2039 loop {
2041 let msg = self.read_one_message()?;
2042 match msg {
2043 BackendMessage::CopyOutResponse { .. } => break,
2044 BackendMessage::ErrorResponse { data } => {
2045 let fields = proto::parse_error_response(data);
2046 self.drain_to_ready()?;
2047 return Err(self.make_server_error(fields));
2048 }
2049 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2050 other => {
2051 return Err(DriverError::Protocol(format!(
2052 "expected CopyOutResponse, got: {other:?}"
2053 )));
2054 }
2055 }
2056 }
2057
2058 loop {
2060 let msg = self.read_one_message()?;
2061 match msg {
2062 BackendMessage::CopyData { data } => {
2063 writer.write_all(&data).map_err(DriverError::Io)?;
2064 }
2065 BackendMessage::CopyDone => break,
2066 BackendMessage::ErrorResponse { data } => {
2067 let fields = proto::parse_error_response(data);
2068 self.drain_to_ready()?;
2069 return Err(self.make_server_error(fields));
2070 }
2071 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2072 other => {
2073 return Err(DriverError::Protocol(format!(
2074 "unexpected message during copy_out data: {other:?}"
2075 )));
2076 }
2077 }
2078 }
2079
2080 let mut count: u64 = 0;
2082 loop {
2083 let msg = self.read_one_message()?;
2084 match msg {
2085 BackendMessage::CommandComplete { tag } => {
2086 count = proto::parse_command_tag(tag);
2087 }
2088 BackendMessage::ReadyForQuery { status } => {
2089 self.tx_status = status;
2090 return Ok(count);
2091 }
2092 BackendMessage::ErrorResponse { data } => {
2093 let fields = proto::parse_error_response(data);
2094 self.drain_to_ready()?;
2095 return Err(self.make_server_error(fields));
2096 }
2097 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2098 other => {
2099 return Err(DriverError::Protocol(format!(
2100 "unexpected message during copy_out completion: {other:?}"
2101 )));
2102 }
2103 }
2104 }
2105 }
2106
2107 pub fn prepare_describe(&mut self, sql: &str) -> Result<PrepareResult, DriverError> {
2112 self.write_buf.clear();
2113 proto::write_parse(&mut self.write_buf, "", sql, &[]);
2116 proto::write_describe(&mut self.write_buf, b'S', "");
2117 proto::write_sync(&mut self.write_buf);
2118 self.flush_write()?;
2119
2120 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2122
2123 let mut param_oids: Vec<u32> = Vec::new();
2125 let columns;
2126 loop {
2127 let msg = self.read_one_message()?;
2128 match msg {
2129 BackendMessage::ParameterDescription { data } => {
2130 param_oids = proto::parse_parameter_description(data)?;
2131 }
2132 BackendMessage::RowDescription { data } => {
2133 columns = proto::parse_row_description(data)?;
2134 break;
2135 }
2136 BackendMessage::NoData => {
2137 columns = Vec::new();
2138 break;
2139 }
2140 BackendMessage::NoticeResponse { .. } => {}
2141 BackendMessage::ErrorResponse { data } => {
2142 let fields = proto::parse_error_response(data);
2143 self.drain_to_ready()?;
2144 return Err(self.make_server_error(fields));
2145 }
2146 other => {
2147 return Err(DriverError::Protocol(format!(
2148 "expected ParameterDescription/RowDescription/NoData, got: {other:?}"
2149 )));
2150 }
2151 }
2152 }
2153
2154 self.expect_ready()?;
2156
2157 Ok(PrepareResult {
2158 columns,
2159 param_oids,
2160 })
2161 }
2162
2163 pub fn wait_for_notification(&mut self) -> Result<(String, String), DriverError> {
2172 loop {
2173 let (msg_type, _payload_len) = self.read_message_buffered()?;
2174 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2175 match msg {
2176 BackendMessage::NotificationResponse {
2177 channel, payload, ..
2178 } => {
2179 return Ok((channel.to_owned(), payload.to_owned()));
2180 }
2181 BackendMessage::ParameterStatus { .. } | BackendMessage::NoticeResponse { .. } => {
2182 continue;
2183 }
2184 _ => continue,
2185 }
2186 }
2187 }
2188
2189 pub fn cancel(&self) -> Result<(), DriverError> {
2195 let addr = format!("{}:{}", self.connect_config.host, self.connect_config.port);
2196 let mut tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
2197 let mut buf = Vec::with_capacity(16);
2198 proto::write_cancel_request(&mut buf, self.pid, self.secret);
2199 tcp.write_all(&buf).map_err(DriverError::Io)?;
2200 tcp.flush().map_err(DriverError::Io)?;
2201 drop(tcp);
2203 Ok(())
2204 }
2205
    /// Sets the read timeout on the underlying stream.
    ///
    /// `timeout` is forwarded unchanged to `Stream::set_read_timeout`.
    ///
    /// # Errors
    ///
    /// Returns `DriverError::Io` if the stream rejects the setting.
    pub fn set_read_timeout(
        &self,
        timeout: Option<std::time::Duration>,
    ) -> Result<(), DriverError> {
        self.stream
            .set_read_timeout(timeout)
            .map_err(DriverError::Io)
    }
2218
2219 pub fn query_streaming_start(
2233 &mut self,
2234 sql: &str,
2235 sql_hash: u64,
2236 params: &[&(dyn Encode + Sync)],
2237 chunk_size: i32,
2238 ) -> Result<(Arc<[ColumnDesc]>, bool), DriverError> {
2239 self.write_buf.clear();
2240
2241 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash, sql) {
2242 self.query_counter += 1;
2244 info.last_used = self.query_counter;
2245
2246 let can_use_template = info
2247 .bind_template
2248 .as_ref()
2249 .is_some_and(|t| t.param_slots.len() == params.len());
2250
2251 if can_use_template {
2252 let tmpl = info
2254 .bind_template
2255 .as_ref()
2256 .expect("guarded by can_use_template");
2257 self.write_buf
2260 .extend_from_slice(&tmpl.bytes[..tmpl.bind_end]);
2261
2262 let mut template_ok = true;
2263 for (i, param) in params.iter().enumerate() {
2264 let (data_offset, old_len) = tmpl.param_slots[i];
2265 if param.is_null() {
2266 let len_offset = data_offset - 4;
2267 self.write_buf[len_offset..len_offset + 4]
2268 .copy_from_slice(&(-1i32).to_be_bytes());
2269 } else if old_len >= 0 {
2270 let end = data_offset + old_len as usize;
2271 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2272 template_ok = false;
2273 break;
2274 }
2275 } else {
2276 template_ok = false;
2277 break;
2278 }
2279 }
2280
2281 if !template_ok {
2282 self.write_buf.clear();
2283 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2284 info.bind_template = None;
2285 }
2286 } else {
2287 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2288 }
2289
2290 let cols = info.columns.clone();
2291
2292 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2293 info.bind_template = build_bind_template(&self.write_buf, params.len());
2294 }
2295
2296 proto::write_execute(&mut self.write_buf, "", chunk_size);
2297 proto::write_flush(&mut self.write_buf);
2299 self.flush_write()?;
2300
2301 cols
2302 } else {
2303 let name = make_stmt_name(sql_hash);
2305 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
2306 let param_oids: smallvec::SmallVec<[u32; 8]> =
2307 params.iter().map(|p| p.type_oid()).collect();
2308 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
2309 proto::write_describe(&mut self.write_buf, b'S', name_s);
2310 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
2311
2312 proto::write_execute(&mut self.write_buf, "", chunk_size);
2313 proto::write_flush(&mut self.write_buf);
2314 self.flush_write()?;
2315
2316 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2317 let columns = self.read_column_description()?;
2318 self.query_counter += 1;
2319 self.cache_stmt(
2320 sql_hash,
2321 StmtInfo {
2322 name,
2323 sql: sql.into(),
2324 columns: columns.clone(),
2325 last_used: self.query_counter,
2326 bind_template: None,
2327 },
2328 );
2329 columns
2330 };
2331
2332 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2334
2335 self.streaming_active = true;
2336
2337 Ok((columns, false))
2338 }
2339
2340 pub fn streaming_next_chunk(
2348 &mut self,
2349 arena: &mut Arena,
2350 all_col_offsets: &mut Vec<(usize, i32)>,
2351 ) -> Result<bool, DriverError> {
2352 all_col_offsets.clear();
2353
2354 loop {
2355 let msg = self.read_one_message()?;
2356 match msg {
2357 BackendMessage::DataRow { data } => {
2358 parse_data_row_flat(data, arena, all_col_offsets)?;
2359 }
2360 BackendMessage::PortalSuspended => {
2361 return Ok(true);
2365 }
2366 BackendMessage::CommandComplete { .. } => {
2367 self.write_buf.clear();
2370 proto::write_sync(&mut self.write_buf);
2371 self.flush_write()?;
2372 self.expect_ready()?;
2373 self.shrink_buffers();
2374
2375 self.streaming_active = false;
2376 return Ok(false);
2377 }
2378 BackendMessage::EmptyQuery => {
2379 self.write_buf.clear();
2380 proto::write_sync(&mut self.write_buf);
2381 self.flush_write()?;
2382 self.expect_ready()?;
2383
2384 self.streaming_active = false;
2385 return Ok(false);
2386 }
2387 BackendMessage::ErrorResponse { data } => {
2388 let fields = proto::parse_error_response(data);
2389 self.write_buf.clear();
2391 proto::write_sync(&mut self.write_buf);
2392 self.flush_write()?;
2393 self.drain_to_ready()?;
2394
2395 self.streaming_active = false;
2396 return Err(self.make_server_error(fields));
2397 }
2398 BackendMessage::NoticeResponse { .. } => {}
2399 other => {
2400 return Err(DriverError::Protocol(format!(
2401 "unexpected message during streaming: {other:?}"
2402 )));
2403 }
2404 }
2405 }
2406 }
2407
    /// Requests the next chunk of a streaming query.
    ///
    /// Writes an `Execute` for the unnamed portal limited to `chunk_size`
    /// rows, followed by `Flush`, and sends both. No reply is read here.
    ///
    /// # Errors
    ///
    /// Returns whatever `flush_write` reports.
    pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
        self.write_buf.clear();
        proto::write_execute(&mut self.write_buf, "", chunk_size);
        proto::write_flush(&mut self.write_buf);
        self.flush_write()
    }
2421
    /// Returns the `streaming_active` flag: `true` between a successful
    /// `query_streaming_start` and the point where `streaming_next_chunk`
    /// clears it.
    pub fn is_streaming(&self) -> bool {
        self.streaming_active
    }
2426
    /// Consumes the connection after sending a `Terminate` message.
    ///
    /// The send is best-effort: a failure from `flush_write` is deliberately
    /// ignored, so this function always returns `Ok(())`.
    pub fn close(mut self) -> Result<(), DriverError> {
        self.write_buf.clear();
        proto::write_terminate(&mut self.write_buf);
        // Ignore the result: the connection is being discarded either way.
        let _ = self.flush_write();
        Ok(())
    }
2434
    /// Returns `true` when the last recorded transaction status byte is
    /// `b'I'` (idle, not inside a transaction block).
    pub fn is_idle(&self) -> bool {
        self.tx_status == b'I'
    }
2441
    /// Returns `true` when the last recorded transaction status byte is
    /// `b'T'` (inside a transaction block).
    pub fn is_in_transaction(&self) -> bool {
        self.tx_status == b'T'
    }
2446
    /// Returns `true` when the last recorded transaction status byte is
    /// `b'E'` (inside a failed transaction block).
    pub fn is_in_failed_transaction(&self) -> bool {
        self.tx_status == b'E'
    }
2451
    /// Sets `last_used` to the current instant.
    pub fn touch(&mut self) {
        self.last_used = std::time::Instant::now();
    }
2456
    /// Returns the time elapsed since `last_used` was last set.
    pub fn idle_duration(&self) -> std::time::Duration {
        self.last_used.elapsed()
    }
2461
    /// Returns the current value of the per-connection query counter, which
    /// is also the value stamped into `StmtInfo::last_used`.
    pub fn query_counter(&self) -> u64 {
        self.query_counter
    }
2466
2467 pub fn parameter(&self, name: &str) -> Option<&str> {
2469 self.params
2470 .iter()
2471 .find(|(k, _)| &**k == name)
2472 .map(|(_, v)| &**v)
2473 }
2474
    /// Returns all stored `(name, value)` parameter pairs in storage order.
    pub fn server_params(&self) -> &[(Box<str>, Box<str>)] {
        &self.params
    }
2479
    /// Returns the backend process id stored for this connection (the same
    /// value `cancel` puts into the cancel request).
    pub fn pid(&self) -> i32 {
        self.pid
    }
2484
    /// Returns the backend secret key stored for this connection (the same
    /// value `cancel` puts into the cancel request).
    pub fn secret_key(&self) -> i32 {
        self.secret
    }
2489
    /// Takes all buffered notifications, leaving the internal buffer empty.
    pub fn drain_notifications(&mut self) -> Vec<Notification> {
        std::mem::take(&mut self.pending_notifications)
    }
2494
    /// Returns how many notifications are currently buffered.
    pub fn pending_notification_count(&self) -> usize {
        self.pending_notifications.len()
    }
2499
    /// Sets the statement-cache size limit that `cache_stmt` checks before
    /// inserting a new entry.
    ///
    /// Only the limit is stored; entries already in the cache are not
    /// evicted by this call.
    pub fn set_max_stmt_cache_size(&mut self, size: usize) {
        self.max_stmt_cache_size = size;
    }
2504
    /// Returns the number of statements currently held in the cache.
    pub fn stmt_cache_len(&self) -> usize {
        self.stmts.len()
    }
2509
    /// Returns the instant stored in `created_at`.
    pub fn created_at(&self) -> std::time::Instant {
        self.created_at
    }
2514
2515 #[inline]
2523 fn send_pipeline(
2524 &mut self,
2525 sql: &str,
2526 sql_hash: u64,
2527 params: &[&(dyn Encode + Sync)],
2528 need_columns: bool,
2529 skip_bind_complete: bool,
2530 ) -> Result<Option<Arc<[ColumnDesc]>>, DriverError> {
2531 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
2532
2533 if params.len() > i16::MAX as usize {
2534 return Err(DriverError::Protocol(format!(
2535 "parameter count {} exceeds maximum {}",
2536 params.len(),
2537 i16::MAX
2538 )));
2539 }
2540
2541 self.write_buf.clear();
2542
2543 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash, sql) {
2544 self.query_counter += 1;
2546 info.last_used = self.query_counter;
2547
2548 let can_use_template = info
2549 .bind_template
2550 .as_ref()
2551 .is_some_and(|t| t.param_slots.len() == params.len());
2552
2553 let mut has_exec_sync = false;
2555
2556 if can_use_template {
2557 let tmpl = info
2561 .bind_template
2562 .as_ref()
2563 .expect("guarded by can_use_template");
2564 self.write_buf.extend_from_slice(&tmpl.bytes);
2565
2566 let mut template_ok = true;
2567 for (i, param) in params.iter().enumerate() {
2568 let (data_offset, old_len) = tmpl.param_slots[i];
2569 if param.is_null() {
2570 let len_offset = data_offset - 4;
2572 self.write_buf[len_offset..len_offset + 4]
2573 .copy_from_slice(&(-1i32).to_be_bytes());
2574 } else if old_len >= 0 {
2575 let end = data_offset + old_len as usize;
2576 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2577 template_ok = false;
2579 break;
2580 }
2581 } else {
2582 template_ok = false;
2585 break;
2586 }
2587 }
2588
2589 if template_ok {
2590 has_exec_sync = true; } else {
2592 self.write_buf.clear();
2593 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2594 info.bind_template = None;
2596 }
2597 } else {
2598 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2599 }
2600
2601 let cols = if need_columns {
2602 Some(info.columns.clone())
2603 } else {
2604 None
2605 };
2606
2607 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2611 info.bind_template = build_bind_template(&self.write_buf, params.len());
2612 }
2613
2614 if !has_exec_sync {
2615 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2616 }
2617 self.flush_write()?;
2618
2619 cols
2620 } else {
2621 let name = make_stmt_name(sql_hash);
2623 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
2624 let param_oids: smallvec::SmallVec<[u32; 8]> =
2625 params.iter().map(|p| p.type_oid()).collect();
2626 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
2627 proto::write_describe(&mut self.write_buf, b'S', name_s);
2628 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
2629
2630 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2631 self.flush_write()?;
2632
2633 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2634 let columns = self.read_column_description()?;
2635 self.query_counter += 1;
2636 self.cache_stmt(
2637 sql_hash,
2638 StmtInfo {
2639 name,
2640 sql: sql.into(),
2641 columns: columns.clone(),
2642 last_used: self.query_counter,
2643 bind_template: None,
2644 },
2645 );
2646 if need_columns { Some(columns) } else { None }
2647 };
2648
2649 if !skip_bind_complete {
2650 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2651 }
2652
2653 Ok(columns)
2654 }
2655
2656 fn read_column_description(&mut self) -> Result<Arc<[ColumnDesc]>, DriverError> {
2658 loop {
2659 let msg = self.read_one_message()?;
2660 match msg {
2661 BackendMessage::RowDescription { data } => {
2662 let cols = proto::parse_row_description(data)?;
2663 return Ok(cols.into());
2664 }
2665 BackendMessage::ParameterDescription { .. } => {}
2666 BackendMessage::NoData => return Ok(Arc::from(Vec::new())),
2667 BackendMessage::NoticeResponse { .. } => {}
2668 BackendMessage::ErrorResponse { data } => {
2669 let fields = proto::parse_error_response(data);
2670 self.drain_to_ready()?;
2671 return Err(self.make_server_error(fields));
2672 }
2673 other => {
2674 return Err(DriverError::Protocol(format!(
2675 "expected RowDescription/NoData, got: {other:?}"
2676 )));
2677 }
2678 }
2679 }
2680 }
2681
2682 fn cache_stmt(&mut self, sql_hash: u64, info: StmtInfo) {
2685 if self.stmts.len() >= self.max_stmt_cache_size
2686 && !self.stmts.contains_key(&sql_hash, &info.sql)
2687 {
2688 if let Some((_lru_hash, evicted)) = self.stmts.evict_lru() {
2689 proto::write_close(&mut self.write_buf, b'S', evicted.name_str());
2690 }
2691 }
2692 self.stmts.insert(sql_hash, info);
2693 }
2694
2695 fn buffer_notification(&mut self, pid: i32, channel: &str, payload: &str) {
2696 if self.pending_notifications.len() < 1024 {
2697 self.pending_notifications.push(Notification {
2698 pid,
2699 channel: channel.to_owned(),
2700 payload: payload.to_owned(),
2701 });
2702 }
2703 }
2704
2705 fn shrink_buffers(&mut self) {
2706 if self.query_counter & 63 != 0 {
2710 return;
2711 }
2712 if self.read_buf.capacity() > 64 * 1024 {
2713 self.read_buf.clear();
2714 self.read_buf.shrink_to(8192);
2715 }
2716 if self.write_buf.capacity() > 16 * 1024 {
2717 self.write_buf.clear();
2718 self.write_buf.shrink_to(8192);
2719 }
2720 }
2721
2722 fn maybe_invalidate_stmt_cache(&mut self, fields: &proto::ErrorFields, sql_hash: u64) -> bool {
2723 if &*fields.code == "26000" {
2724 self.stmts.remove(&sql_hash);
2725 true
2726 } else {
2727 false
2728 }
2729 }
2730
2731 #[cold]
2732 #[inline(never)]
2733 fn make_server_error(&self, fields: proto::ErrorFields) -> DriverError {
2734 DriverError::Server {
2735 code: fields.code,
2736 message: fields.message.into_boxed_str(),
2737 detail: fields.detail.map(String::into_boxed_str),
2738 hint: fields.hint.map(String::into_boxed_str),
2739 position: fields.position,
2740 }
2741 }
2742
2743 #[cold]
2749 #[inline(never)]
2750 fn handle_non_datarow_query(
2751 &mut self,
2752 msg_type: u8,
2753 payload_start: usize,
2754 payload_end: usize,
2755 sql_hash: u64,
2756 affected_rows: &mut u64,
2757 ) -> Result<(), DriverError> {
2758 match msg_type {
2759 b'2' | b'I' => {} b'C' => {
2761 *affected_rows =
2762 proto::parse_command_tag_bytes(&self.stream_buf[payload_start..payload_end]);
2763 }
2764 b'E' => {
2765 let fields =
2766 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2767 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2768 self.drain_to_ready()?;
2769 return Err(self.make_server_error(fields));
2770 }
2771 b'A' => {
2772 let msg = proto::parse_backend_message(
2773 msg_type,
2774 &self.stream_buf[payload_start..payload_end],
2775 )?;
2776 if let BackendMessage::NotificationResponse {
2777 pid,
2778 channel,
2779 payload,
2780 } = msg
2781 {
2782 let ch = channel.to_owned();
2783 let pl = payload.to_owned();
2784 self.buffer_notification(pid, &ch, &pl);
2785 }
2786 }
2787 _ => {} }
2789 Ok(())
2790 }
2791
2792 #[cold]
2795 #[inline(never)]
2796 fn handle_non_datarow_execute(
2797 &mut self,
2798 msg_type: u8,
2799 payload_start: usize,
2800 payload_end: usize,
2801 sql_hash: u64,
2802 ) -> Result<(), DriverError> {
2803 match msg_type {
2804 b'2' | b'C' | b'I' => {} b'E' => {
2806 let fields =
2807 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2808 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2809 self.drain_to_ready()?;
2810 return Err(self.make_server_error(fields));
2811 }
2812 b'A' => {
2813 let msg = proto::parse_backend_message(
2814 msg_type,
2815 &self.stream_buf[payload_start..payload_end],
2816 )?;
2817 if let BackendMessage::NotificationResponse {
2818 pid,
2819 channel,
2820 payload,
2821 } = msg
2822 {
2823 let ch = channel.to_owned();
2824 let pl = payload.to_owned();
2825 self.buffer_notification(pid, &ch, &pl);
2826 }
2827 }
2828 _ => {} }
2830 Ok(())
2831 }
2832
2833 #[inline]
2835 fn read_one_message(&mut self) -> Result<BackendMessage<'_>, DriverError> {
2836 loop {
2837 let (msg_type, _payload_len) = self.read_message_buffered()?;
2838 if msg_type == b'A' {
2839 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2840 if let BackendMessage::NotificationResponse {
2841 pid,
2842 channel,
2843 payload,
2844 } = msg
2845 {
2846 let pid_owned = pid;
2847 let channel_owned = channel.to_owned();
2848 let payload_owned = payload.to_owned();
2849 self.buffer_notification(pid_owned, &channel_owned, &payload_owned);
2850 continue;
2851 }
2852 }
2853 return proto::parse_backend_message(msg_type, &self.read_buf);
2854 }
2855 }
2856
2857 fn expect_message(
2858 &mut self,
2859 pred: impl Fn(&BackendMessage<'_>) -> bool,
2860 ) -> Result<(), DriverError> {
2861 loop {
2862 let msg = self.read_one_message()?;
2863 if pred(&msg) {
2864 return Ok(());
2865 }
2866 match msg {
2867 BackendMessage::ErrorResponse { data } => {
2868 let fields = proto::parse_error_response(data);
2869 self.drain_to_ready()?;
2870 return Err(self.make_server_error(fields));
2871 }
2872 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2873 other => {
2874 return Err(DriverError::Protocol(format!(
2875 "unexpected message while waiting for expected type: {other:?}"
2876 )));
2877 }
2878 }
2879 }
2880 }
2881
2882 fn expect_ready(&mut self) -> Result<(), DriverError> {
2883 loop {
2884 let msg = self.read_one_message()?;
2885 match msg {
2886 BackendMessage::ReadyForQuery { status } => {
2887 self.tx_status = status;
2888 return Ok(());
2889 }
2890 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2891 BackendMessage::ErrorResponse { data } => {
2892 let fields = proto::parse_error_response(data);
2893 self.drain_to_ready()?;
2894 return Err(self.make_server_error(fields));
2895 }
2896 _ => {}
2897 }
2898 }
2899 }
2900
2901 #[inline]
2902 fn drain_to_ready(&mut self) -> Result<(), DriverError> {
2903 loop {
2904 let msg = self.read_one_message()?;
2905 if let BackendMessage::ReadyForQuery { status } = msg {
2906 self.tx_status = status;
2907 return Ok(());
2908 }
2909 }
2910 }
2911
2912 #[inline]
2916 fn flush_write(&mut self) -> Result<(), DriverError> {
2917 self.stream
2918 .write_all(&self.write_buf)
2919 .map_err(DriverError::Io)
2920 }
2921
2922 fn read_message_buffered(&mut self) -> Result<(u8, usize), DriverError> {
2926 let mut header = [0u8; 5];
2927 sync_buffered_read_exact(
2928 &mut self.stream,
2929 &mut self.stream_buf,
2930 &mut self.stream_buf_pos,
2931 &mut self.stream_buf_end,
2932 &mut header,
2933 )?;
2934
2935 let msg_type = header[0];
2936 let len = i32::from_be_bytes([header[1], header[2], header[3], header[4]]);
2937
2938 if len < 4 {
2939 return Err(DriverError::Protocol(format!(
2940 "invalid message length {len} for type '{}'",
2941 msg_type as char
2942 )));
2943 }
2944
2945 const MAX_MESSAGE_LEN: i32 = 128 * 1024 * 1024;
2946 if len > MAX_MESSAGE_LEN {
2947 return Err(DriverError::Protocol(format!(
2948 "message length {len} exceeds maximum ({MAX_MESSAGE_LEN}) for type '{}'",
2949 msg_type as char
2950 )));
2951 }
2952
2953 let payload_len = (len - 4) as usize;
2954 self.read_buf.clear();
2955 self.read_buf.resize(payload_len, 0);
2956 if payload_len > 0 {
2957 sync_buffered_read_exact(
2958 &mut self.stream,
2959 &mut self.stream_buf,
2960 &mut self.stream_buf_pos,
2961 &mut self.stream_buf_end,
2962 &mut self.read_buf[..payload_len],
2963 )?;
2964 }
2965
2966 Ok((msg_type, payload_len))
2967 }
2968
2969 #[inline]
2971 fn refill_stream_buf(&mut self) -> Result<(), DriverError> {
2972 let remaining = self.stream_buf_end - self.stream_buf_pos;
2973 if remaining > 0 && self.stream_buf_pos > 0 {
2974 self.stream_buf
2975 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
2976 }
2977 self.stream_buf_pos = 0;
2978 self.stream_buf_end = remaining;
2979
2980 let n = self
2981 .stream
2982 .read(&mut self.stream_buf[remaining..])
2983 .map_err(DriverError::Io)?;
2984 if n == 0 {
2985 return Err(DriverError::Io(std::io::Error::new(
2986 std::io::ErrorKind::UnexpectedEof,
2987 "connection closed",
2988 )));
2989 }
2990 self.stream_buf_end = remaining + n;
2991 Ok(())
2992 }
2993}
2994
2995fn sync_buffered_read_exact(
2998 stream: &mut Stream,
2999 buf: &mut [u8],
3000 pos: &mut usize,
3001 end: &mut usize,
3002 out: &mut [u8],
3003) -> Result<(), DriverError> {
3004 let mut filled = 0;
3005 while filled < out.len() {
3006 let avail = *end - *pos;
3007 if avail > 0 {
3008 let take = avail.min(out.len() - filled);
3009 out[filled..filled + take].copy_from_slice(&buf[*pos..*pos + take]);
3010 *pos += take;
3011 filled += take;
3012 } else {
3013 *pos = 0;
3014 let n = stream.read(buf).map_err(DriverError::Io)?;
3015 if n == 0 {
3016 return Err(DriverError::Io(std::io::Error::new(
3017 std::io::ErrorKind::UnexpectedEof,
3018 "connection closed",
3019 )));
3020 }
3021 *end = n;
3022 }
3023 }
3024 Ok(())
3025}
3026
3027#[inline(always)]
3037pub(crate) fn parse_data_row_into_buf(
3038 data: &[u8],
3039 buf: &mut Vec<u8>,
3040 out: &mut Vec<(usize, i32)>,
3041) -> Result<(), DriverError> {
3042 if data.len() < 2 {
3043 return Err(DriverError::Protocol("DataRow too short".into()));
3044 }
3045
3046 let num_cols = i16::from_be_bytes([data[0], data[1]]);
3047 if num_cols < 0 {
3048 return Err(DriverError::Protocol(
3049 "DataRow: negative column count".into(),
3050 ));
3051 }
3052 let num_cols = num_cols as usize;
3053
3054 let col_data = &data[2..];
3062 let base = buf.len();
3063 buf.extend_from_slice(col_data);
3064
3065 let mut pos: usize = 0;
3067 for _ in 0..num_cols {
3068 if pos + 4 > col_data.len() {
3069 return Err(DriverError::Protocol("DataRow truncated".into()));
3070 }
3071
3072 let col_len = i32::from_be_bytes([
3073 col_data[pos],
3074 col_data[pos + 1],
3075 col_data[pos + 2],
3076 col_data[pos + 3],
3077 ]);
3078 pos += 4;
3079
3080 if col_len < 0 {
3081 out.push((0, -1));
3082 } else {
3083 let len = col_len as usize;
3084 if pos + len > col_data.len() {
3085 return Err(DriverError::Protocol(
3086 "DataRow column data truncated".into(),
3087 ));
3088 }
3089 out.push((base + pos, col_len));
3091 pos += len;
3092 }
3093 }
3094
3095 Ok(())
3096}
3097
3098fn parse_data_row_flat(
3102 data: &[u8],
3103 arena: &mut Arena,
3104 out: &mut Vec<(usize, i32)>,
3105) -> Result<(), DriverError> {
3106 if data.len() < 2 {
3107 return Err(DriverError::Protocol("DataRow too short".into()));
3108 }
3109
3110 let num_cols_raw = i16::from_be_bytes([data[0], data[1]]);
3111 if num_cols_raw < 0 {
3112 return Err(DriverError::Protocol(
3113 "DataRow: negative column count".into(),
3114 ));
3115 }
3116 let num_cols = num_cols_raw as usize;
3117 out.reserve(num_cols);
3118
3119 let col_data = &data[2..];
3122 let base = arena.alloc_copy(col_data);
3123
3124 let mut pos: usize = 0;
3126 for _ in 0..num_cols {
3127 if pos + 4 > col_data.len() {
3128 return Err(DriverError::Protocol("DataRow truncated".into()));
3129 }
3130
3131 let col_len = i32::from_be_bytes([
3132 col_data[pos],
3133 col_data[pos + 1],
3134 col_data[pos + 2],
3135 col_data[pos + 3],
3136 ]);
3137 pos += 4;
3138
3139 if col_len < 0 {
3140 out.push((0, -1));
3141 } else {
3142 let len = col_len as usize;
3143 if pos + len > col_data.len() {
3144 return Err(DriverError::Protocol(
3145 "DataRow column data truncated".into(),
3146 ));
3147 }
3148 out.push((base + pos, col_len));
3150 pos += len;
3151 }
3152 }
3153
3154 Ok(())
3155}
3156
#[cfg(test)]
// `3.14` is used below as an arbitrary float8 parameter, not as an
// approximation of pi.
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;
    use crate::types::hash_sql;

    /// TCP endpoint used by the connection-failure tests. Every test that
    /// uses it expects `connect` to fail — presumably nothing listens on
    /// port 1 of the loopback interface.
    const UNREACHABLE_URL: &str = "postgres://user:pass@127.0.0.1:1/db";

    /// Endpoint used by the `#[ignore]`d integration tests.
    const LIVE_PG_URL: &str = "postgres://postgres@localhost/postgres?host=/tmp";

    /// Builds the configuration for the unreachable TCP endpoint.
    fn unreachable_config() -> Config {
        Config::from_url(UNREACHABLE_URL).unwrap()
    }

    /// Opens a connection to the live test server, panicking on failure.
    fn live_connection() -> Connection {
        let config = Config::from_url(LIVE_PG_URL).unwrap();
        Connection::connect(&config).unwrap()
    }

    /// Encodes a well-formed DataRow payload: a big-endian `i16` column
    /// count, then per column a big-endian `i32` length (`-1` for `None`)
    /// followed by the column bytes.
    fn data_row(columns: &[Option<&[u8]>]) -> Vec<u8> {
        let mut payload = Vec::new();
        payload.extend_from_slice(&(columns.len() as i16).to_be_bytes());
        for column in columns {
            match column {
                Some(bytes) => {
                    payload.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
                    payload.extend_from_slice(bytes);
                }
                None => payload.extend_from_slice(&(-1i32).to_be_bytes()),
            }
        }
        payload
    }

    // --- connection failures against an unreachable endpoint ---

    #[test]
    fn sync_config_tcp_no_longer_rejected() {
        let result = Connection::connect(&unreachable_config());
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            !err.contains("Unix domain socket"),
            "error should NOT mention UDS requirement: {err}"
        );
    }

    #[test]
    fn sync_connect_tcp_unreachable_port() {
        let result = Connection::connect(&unreachable_config());
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            !err.contains("Unix domain socket"),
            "error should NOT mention UDS: {err}"
        );
    }

    #[test]
    fn sync_connect_ip_address_attempts_tcp() {
        let result = Connection::connect(&unreachable_config());
        assert!(result.is_err());
    }

    #[test]
    fn sync_connect_sslmode_require_without_tls_feature() {
        let mut config = unreachable_config();
        config.ssl = SslMode::Require;
        let result = Connection::connect(&config);
        assert!(result.is_err());
    }

    #[test]
    fn sync_connect_sslmode_disable_attempts_tcp() {
        let mut config = unreachable_config();
        config.ssl = SslMode::Disable;
        let result = Connection::connect(&config);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), DriverError::Io(_)));
    }

    #[test]
    fn sync_connect_sslmode_prefer_attempts_tcp() {
        let mut config = unreachable_config();
        config.ssl = SslMode::Prefer;
        let result = Connection::connect(&config);
        assert!(result.is_err());
    }

    // --- parse_data_row_flat ---

    #[test]
    fn sync_data_row_parsing() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = data_row(&[Some(&42i32.to_be_bytes()[..]), None]);

        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].1, 4);
        assert_eq!(out[1], (0, -1));

        // The recorded offset must point at the column's bytes.
        let stored = arena.get(out[0].0, 4);
        let value = i32::from_be_bytes([stored[0], stored[1], stored[2], stored[3]]);
        assert_eq!(value, 42);
    }

    #[test]
    fn sync_data_row_empty() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = data_row(&[]);
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 0);
    }

    #[test]
    fn sync_data_row_too_short() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = vec![0u8];
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_negative_col_count() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = (-1i16).to_be_bytes();
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_truncated() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        // Two columns are declared but only one is present.
        let mut data = Vec::new();
        data.extend_from_slice(&2i16.to_be_bytes());
        data.extend_from_slice(&4i32.to_be_bytes());
        data.extend_from_slice(&42i32.to_be_bytes());
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_col_data_truncated() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        // The column claims 100 bytes but only one follows.
        let mut data = Vec::new();
        data.extend_from_slice(&1i16.to_be_bytes());
        data.extend_from_slice(&100i32.to_be_bytes());
        data.push(0);
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_all_null() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = data_row(&[None, None, None]);
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 3);
        for (_, len) in &out {
            assert_eq!(*len, -1);
        }
    }

    #[test]
    fn sync_data_row_long_text() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let long_text = "a".repeat(2048);
        let text_bytes = long_text.as_bytes();
        let data = data_row(&[Some(text_bytes)]);
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].1, text_bytes.len() as i32);
        let stored = arena.get(out[0].0, out[0].1 as usize);
        assert_eq!(stored, text_bytes);
    }

    #[test]
    fn sync_data_row_empty_text() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        // A zero-length value is distinct from NULL (length -1).
        let data = data_row(&[Some(&b""[..])]);
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].1, 0);
    }

    #[test]
    fn sync_data_row_17_columns_exceeds_smallvec() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        // 20 columns, each holding its own index as a big-endian i32.
        let values: Vec<[u8; 4]> = (0..20i32).map(i32::to_be_bytes).collect();
        let columns: Vec<Option<&[u8]>> = values.iter().map(|v| Some(&v[..])).collect();
        let data = data_row(&columns);

        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 20);
        for (idx, (offset, len)) in out.iter().enumerate() {
            assert_eq!(*len, 4);
            let stored = arena.get(*offset, 4);
            let val = i32::from_be_bytes([stored[0], stored[1], stored[2], stored[3]]);
            assert_eq!(val, idx as i32);
        }
    }

    #[test]
    fn sync_data_row_mixed_null_and_data() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = data_row(&[
            None,
            Some(&42i32.to_be_bytes()[..]),
            None,
            None,
            Some(&b"hello"[..]),
        ]);

        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 5);
        assert_eq!(out[0].1, -1);
        assert_eq!(out[1].1, 4);
        assert_eq!(out[2].1, -1);
        assert_eq!(out[3].1, -1);
        assert_eq!(out[4].1, 5);
        let stored = arena.get(out[4].0, 5);
        assert_eq!(stored, b"hello");
    }

    // --- parse_data_row_into_buf ---

    #[test]
    fn sync_data_row_into_buf_offsets_include_base() {
        // Pre-existing bytes in `buf` must shift every recorded offset.
        let mut buf = vec![0xAA_u8; 3];
        let mut out = Vec::new();
        let data = data_row(&[Some(&b"hi"[..]), None]);

        parse_data_row_into_buf(&data, &mut buf, &mut out).unwrap();
        // 3 existing bytes + 4-byte length header precede the value.
        assert_eq!(out, vec![(7usize, 2i32), (0, -1)]);
        assert_eq!(&buf[7..9], b"hi");
        assert_eq!(buf.len(), 3 + data.len() - 2);
    }

    #[test]
    fn sync_data_row_into_buf_rejects_malformed() {
        let mut buf = Vec::new();
        let mut out = Vec::new();
        assert!(parse_data_row_into_buf(&[0u8], &mut buf, &mut out).is_err());
        assert!(parse_data_row_into_buf(&(-1i16).to_be_bytes(), &mut buf, &mut out).is_err());

        let mut short_value = Vec::new();
        short_value.extend_from_slice(&1i16.to_be_bytes());
        short_value.extend_from_slice(&100i32.to_be_bytes());
        short_value.push(0);
        assert!(parse_data_row_into_buf(&short_value, &mut buf, &mut out).is_err());
        assert!(out.is_empty());
    }

    // --- integration tests against a live server (run with `--ignored`) ---

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_connect_uds_if_pg_available() {
        let config = Config::from_url(LIVE_PG_URL).unwrap();
        // Tolerates an unavailable server: the checks only run on success.
        if let Ok(conn) = Connection::connect(&config) {
            assert!(conn.pid() != 0, "pid should be nonzero");
            assert!(conn.is_idle(), "should start idle");
            assert!(!conn.is_in_transaction(), "should not be in tx");
            assert!(
                !conn.is_in_failed_transaction(),
                "should not be in failed tx"
            );
            assert_eq!(conn.stmt_cache_len(), 0, "cache should be empty");
            let _ = conn.close();
        }
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_simple_query_if_pg_available() {
        let mut conn = live_connection();
        conn.simple_query("SELECT 1").unwrap();
        assert!(conn.is_idle());
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_query_with_params_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECT $1::int4 + $2::int4 AS sum";
        let hash = hash_sql(sql);
        let a: i32 = 10;
        let b: i32 = 20;
        let result = conn
            .query(
                sql,
                hash,
                &[&a as &(dyn Encode + Sync), &b as &(dyn Encode + Sync)],
            )
            .unwrap();
        assert_eq!(result.len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_execute_if_pg_available() {
        let mut conn = live_connection();
        conn.simple_query("CREATE TEMP TABLE _sync_test (id int)")
            .unwrap();
        let sql = "INSERT INTO _sync_test VALUES ($1::int4)";
        let hash = hash_sql(sql);
        let val: i32 = 42;
        let affected = conn
            .execute(sql, hash, &[&val as &(dyn Encode + Sync)])
            .unwrap();
        assert_eq!(affected, 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_for_each_zero_rows_if_pg_available() {
        let mut conn = live_connection();
        conn.simple_query("CREATE TEMP TABLE _sync_fe0 (id int)")
            .unwrap();
        let sql = "SELECT id FROM _sync_fe0";
        let hash = hash_sql(sql);
        let mut count = 0u32;
        conn.for_each(sql, hash, &[], |_row| {
            count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(count, 0);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_for_each_multiple_rows_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECT generate_series(1, 5)";
        let hash = hash_sql(sql);
        let mut count = 0u32;
        conn.for_each(sql, hash, &[], |_row| {
            count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(count, 5);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_prepare_only_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECT 1";
        let hash = hash_sql(sql);
        conn.prepare_only(sql, hash).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        // Preparing the same statement again must not add a second entry.
        conn.prepare_only(sql, hash).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_simple_query_rows_if_pg_available() {
        let mut conn = live_connection();
        let rows = conn.simple_query_rows("SELECT 42 AS n").unwrap();
        assert!(!rows.is_empty());
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_stmt_cache_hit_miss_if_pg_available() {
        let mut conn = live_connection();
        let sql1 = "SELECT 1";
        let hash1 = hash_sql(sql1);
        conn.query(sql1, hash1, &[]).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        // Same statement again: the cache size must not change.
        conn.query(sql1, hash1, &[]).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        // A different statement adds a second entry.
        let sql2 = "SELECT 2";
        let hash2 = hash_sql(sql2);
        conn.query(sql2, hash2, &[]).unwrap();
        assert_eq!(conn.stmt_cache_len(), 2);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_invalid_sql_error_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECTTTT INVALID GARBAGE";
        let hash = hash_sql(sql);
        let result = conn.query(sql, hash, &[]);
        assert!(result.is_err());
        // The connection must be usable (idle) again after the error.
        assert!(conn.is_idle());
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_tx_state_transitions_if_pg_available() {
        let mut conn = live_connection();
        assert!(conn.is_idle());
        assert!(!conn.is_in_transaction());
        conn.simple_query("BEGIN").unwrap();
        assert!(conn.is_in_transaction());
        assert!(!conn.is_idle());
        conn.simple_query("COMMIT").unwrap();
        assert!(conn.is_idle());
        assert!(!conn.is_in_transaction());
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_lru_cache_eviction_if_pg_available() {
        let mut conn = live_connection();
        conn.set_max_stmt_cache_size(3);
        for i in 0..5 {
            let sql = format!("SELECT {}", i);
            let hash = hash_sql(&sql);
            conn.query(&sql, hash, &[]).unwrap();
        }
        assert!(
            conn.stmt_cache_len() <= 3,
            "cache should be capped at 3, got {}",
            conn.stmt_cache_len()
        );
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_for_each_raw_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECT generate_series(1, 3)";
        let hash = hash_sql(sql);
        let mut raw_count = 0u32;
        conn.for_each_raw(sql, hash, &[], |_raw_data| {
            raw_count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(raw_count, 3);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_query_null_params_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECT $1::int4 IS NULL AS is_null";
        let hash = hash_sql(sql);
        let val: Option<i32> = None;
        let _result = conn
            .query(sql, hash, &[&val as &(dyn Encode + Sync)])
            .unwrap();
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_query_various_param_types_if_pg_available() {
        let mut conn = live_connection();
        let sql = "SELECT $1::int4, $2::int8, $3::text, $4::bool, $5::float8";
        let hash = hash_sql(sql);
        let p1: i32 = 42;
        let p2: i64 = 9999999;
        let p3: &str = "hello";
        let p4: bool = true;
        let p5: f64 = 3.14;
        let result = conn
            .query(
                sql,
                hash,
                &[
                    &p1 as &(dyn Encode + Sync),
                    &p2 as &(dyn Encode + Sync),
                    &p3 as &(dyn Encode + Sync),
                    &p4 as &(dyn Encode + Sync),
                    &p5 as &(dyn Encode + Sync),
                ],
            )
            .unwrap();
        assert_eq!(result.len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_streaming_basic_if_pg_available() {
        let mut conn = live_connection();
        assert!(!conn.is_streaming());

        let sql = "SELECT generate_series(1, 10)";
        let hash = hash_sql(sql);

        let (cols, _) = conn.query_streaming_start(sql, hash, &[], 3).unwrap();
        assert!(!cols.is_empty());
        assert!(conn.is_streaming());

        let mut arena = Arena::new();
        let mut offsets = Vec::new();
        let mut total_rows = 0;

        loop {
            let has_more = conn.streaming_next_chunk(&mut arena, &mut offsets).unwrap();
            total_rows += offsets.len();
            if !has_more {
                break;
            }
            conn.streaming_send_execute(3).unwrap();
        }

        assert_eq!(total_rows, 10);
        assert!(!conn.is_streaming());
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_prepare_describe_if_pg_available() {
        let mut conn = live_connection();

        let result = conn
            .prepare_describe("SELECT $1::int4 + $2::int4 AS sum")
            .unwrap();
        assert_eq!(result.columns.len(), 1);
        assert_eq!(&*result.columns[0].name, "sum");
        assert_eq!(result.param_oids.len(), 2);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_wait_for_notification_if_pg_available() {
        let mut conn = live_connection();

        conn.simple_query("LISTEN test_chan").unwrap();
        conn.simple_query("NOTIFY test_chan, 'hello'").unwrap();

        // Bound the wait so a missing notification fails instead of hanging.
        conn.set_read_timeout(Some(std::time::Duration::from_secs(5)))
            .unwrap();

        let (channel, payload) = conn.wait_for_notification().unwrap();
        assert_eq!(channel, "test_chan");
        assert_eq!(payload, "hello");
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_cancel_if_pg_available() {
        let conn = live_connection();
        // The outcome is deliberately ignored: this only checks that issuing
        // a cancel on an idle connection does not panic.
        let _ = conn.cancel();
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_server_params_if_pg_available() {
        let conn = live_connection();
        let params = conn.server_params();
        assert!(
            !params.is_empty(),
            "server should send parameters during startup"
        );
        assert!(
            conn.parameter("server_encoding").is_some(),
            "server_encoding should be present"
        );
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a live PostgreSQL server"]
    fn sync_set_read_timeout_if_pg_available() {
        let conn = live_connection();
        conn.set_read_timeout(Some(std::time::Duration::from_secs(10)))
            .unwrap();
        conn.set_read_timeout(None).unwrap();
        let _ = conn.close();
    }

    // --- constant / formatting sanity checks ---

    #[test]
    fn sync_shrink_threshold_values() {
        // NOTE(review): these are local copies of the values used in
        // `shrink_buffers`, not the values themselves — keep them in sync.
        let shrink = 64 * 1024usize;
        let initial = 8192usize;
        assert!(
            shrink > initial,
            "shrink threshold must exceed initial size"
        );
    }

    #[test]
    fn sync_connection_debug_format() {
        // NOTE(review): this formats a hand-written string rather than a real
        // `Connection`, so it does not exercise the `Debug` impl itself.
        let fmt_str = format!(
            "Connection {{ pid: {}, tx_status: '{}', stmt_cache_len: {} }}",
            0, 'I', 0
        );
        assert!(fmt_str.contains("Connection"));
        assert!(fmt_str.contains("pid"));
        assert!(fmt_str.contains("tx_status"));
    }
}