1use std::io::{Read, Write};
15use std::sync::Arc;
16
17use crate::DriverError;
18use crate::arena::Arena;
19use crate::auth;
20use crate::codec::Encode;
21use crate::proto::{self, BackendMessage};
22use crate::stmt_cache::{StmtCache, StmtInfo, build_bind_template, make_stmt_name};
23use crate::sync_io::Stream;
24use crate::types::{
25 ColumnDesc, Config, Notification, PgDataRow, PrepareResult, QueryResult, SimpleRow, SslMode,
26 StartupAction,
27};
28
29use std::cell::RefCell;
37
/// Upper bound on how many response buffers each thread keeps for reuse.
const RESP_BUF_POOL_MAX_LEN: usize = 4;

/// Buffers whose capacity exceeds this many bytes are freed instead of pooled,
/// so one unusually large result set does not pin that memory for the rest of
/// the thread's lifetime.
const RESP_BUF_MAX_RETAINED_CAPACITY: usize = 1 << 20;

thread_local! {
    /// Per-thread free list of response buffers handed out by `acquire_resp_buf`.
    static RESP_BUF_POOL: RefCell<Vec<Vec<u8>>> = const { RefCell::new(Vec::new()) };
}

/// Takes an empty response buffer from this thread's pool, or returns a new
/// zero-capacity `Vec` when the pool is empty.
fn acquire_resp_buf() -> Vec<u8> {
    RESP_BUF_POOL
        .with(|pool| pool.borrow_mut().pop())
        .unwrap_or_default()
}

/// Returns `buf` to this thread's pool so a later `acquire_resp_buf` can reuse
/// its allocation.
///
/// The buffer is cleared first, so pooled buffers never hold stale row data.
/// It is dropped instead of pooled when its capacity exceeds
/// `RESP_BUF_MAX_RETAINED_CAPACITY`, or when the pool already holds
/// `RESP_BUF_POOL_MAX_LEN` buffers.
pub fn release_resp_buf(mut buf: Vec<u8>) {
    if buf.capacity() > RESP_BUF_MAX_RETAINED_CAPACITY {
        return;
    }
    buf.clear();
    RESP_BUF_POOL.with(|pool| {
        let mut pool = pool.borrow_mut();
        if pool.len() < RESP_BUF_POOL_MAX_LEN {
            pool.push(buf);
        }
    });
}
57
58pub struct Connection {
87 stream: Stream,
88 read_buf: Vec<u8>,
89 stream_buf: Vec<u8>,
90 stream_buf_pos: usize,
91 stream_buf_end: usize,
92 write_buf: Vec<u8>,
93 stmts: StmtCache,
94 params: Vec<(Box<str>, Box<str>)>,
95 pid: i32,
96 secret: i32,
97 tx_status: u8,
98 last_used: std::time::Instant,
99 streaming_active: bool,
100 created_at: std::time::Instant,
101 pending_notifications: Vec<Notification>,
102 max_stmt_cache_size: usize,
103 query_counter: u64,
104 connect_config: Arc<Config>,
107}
108
109impl std::fmt::Debug for Connection {
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 f.debug_struct("Connection")
112 .field("pid", &self.pid)
113 .field("tx_status", &(self.tx_status as char))
114 .field("stmt_cache_len", &self.stmts.len())
115 .finish()
116 }
117}
118
119impl Connection {
120 pub fn connect(config: &Config) -> Result<Self, DriverError> {
132 Self::connect_arc(Arc::new(config.clone()))
133 }
134
135 pub fn connect_arc(config: Arc<Config>) -> Result<Self, DriverError> {
140 config.validate()?;
141
142 let stream = if config.host_is_uds() {
143 #[cfg(unix)]
145 {
146 let path = config.uds_path();
147 let unix =
148 std::os::unix::net::UnixStream::connect(&path).map_err(DriverError::Io)?;
149 Stream::Unix(unix)
150 }
151 #[cfg(not(unix))]
152 {
153 return Err(DriverError::Protocol(
154 "Unix domain sockets are not supported on this platform".into(),
155 ));
156 }
157 } else {
158 let addr = format!("{}:{}", config.host, config.port);
160 let tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
161
162 match config.ssl {
163 SslMode::Disable => {
164 tcp.set_nodelay(true).map_err(DriverError::Io)?;
165 let stream = Stream::Tcp(tcp);
166 stream.set_keepalive()?;
167 stream
168 }
169 SslMode::Prefer | SslMode::Require => {
170 #[cfg(feature = "tls")]
171 {
172 match crate::tls_sync::try_upgrade(
173 tcp,
174 &config.host,
175 config.ssl == SslMode::Require,
176 ) {
177 Ok(tls_stream) => {
178 let stream = Stream::Tls(Box::new(tls_stream));
179 stream.set_nodelay()?;
180 stream.set_keepalive()?;
181 stream
182 }
183 Err(e) => {
184 if config.ssl == SslMode::Require {
185 return Err(e);
186 }
187 let tcp =
189 std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
190 tcp.set_nodelay(true).map_err(DriverError::Io)?;
191 let stream = Stream::Tcp(tcp);
192 stream.set_keepalive()?;
193 stream
194 }
195 }
196 }
197 #[cfg(not(feature = "tls"))]
198 {
199 if config.ssl == SslMode::Require {
200 return Err(DriverError::Protocol(
201 "sslmode=require but bsql was compiled without the 'tls' feature"
202 .into(),
203 ));
204 }
205 tcp.set_nodelay(true).map_err(DriverError::Io)?;
206 let stream = Stream::Tcp(tcp);
207 stream.set_keepalive()?;
208 stream
209 }
210 }
211 }
212 };
213
214 let mut conn = Self {
215 stream,
216 read_buf: Vec::with_capacity(8192),
217 stream_buf: vec![0u8; 65536],
218 stream_buf_pos: 0,
219 stream_buf_end: 0,
220 write_buf: Vec::with_capacity(4096),
221 stmts: StmtCache::default(),
222 params: Vec::new(),
223 pid: 0,
224 secret: 0,
225 tx_status: b'I',
226 last_used: std::time::Instant::now(),
227 streaming_active: false,
228 created_at: std::time::Instant::now(),
229 pending_notifications: Vec::new(),
230 max_stmt_cache_size: 256,
231 query_counter: 0,
232 connect_config: config.clone(),
233 };
234
235 conn.startup(&config)?;
236 conn.validate_server_params()?;
237
238 if config.statement_timeout_secs > 0 {
239 conn.simple_query(&format!(
240 "SET statement_timeout = '{}s'",
241 config.statement_timeout_secs
242 ))?;
243 }
244
245 Ok(conn)
246 }
247
248 fn startup(&mut self, config: &Config) -> Result<(), DriverError> {
251 self.write_buf.clear();
252 proto::write_startup(&mut self.write_buf, &config.user, &config.database);
253 self.flush_write()?;
254
255 loop {
256 let action = self.read_startup_action()?;
257 match action {
258 StartupAction::AuthOk => {}
259 StartupAction::AuthCleartext => {
260 self.write_buf.clear();
261 let mut pw = config.password.as_bytes().to_vec();
262 pw.push(0);
263 proto::write_password(&mut self.write_buf, &pw);
264 self.flush_write()?;
265 }
266 StartupAction::AuthMd5(salt) => {
267 self.write_buf.clear();
268 let hash = auth::md5_password(&config.user, &config.password, &salt);
269 proto::write_password(&mut self.write_buf, &hash);
270 self.flush_write()?;
271 }
272 StartupAction::AuthSasl(mechanisms_data) => {
273 self.handle_scram(config, &mechanisms_data)?;
274 }
275 StartupAction::ParameterStatus(name, value) => {
276 if let Some(entry) = self.params.iter_mut().find(|(k, _)| *k == name) {
277 entry.1 = value;
278 } else {
279 self.params.push((name, value));
280 }
281 }
282 StartupAction::BackendKeyData(pid, secret) => {
283 self.pid = pid;
284 self.secret = secret;
285 }
286 StartupAction::ReadyForQuery(status) => {
287 self.tx_status = status;
288 return Ok(());
289 }
290 StartupAction::Error(msg) => {
291 return Err(DriverError::Auth(msg));
292 }
293 StartupAction::Notice => {}
294 }
295 }
296 }
297
298 fn read_startup_action(&mut self) -> Result<StartupAction, DriverError> {
299 let (msg_type, _) = self.read_message_buffered()?;
300 let payload = &self.read_buf;
301 let msg = proto::parse_backend_message(msg_type, payload)?;
302 match msg {
303 BackendMessage::AuthOk => Ok(StartupAction::AuthOk),
304 BackendMessage::AuthCleartext => Ok(StartupAction::AuthCleartext),
305 BackendMessage::AuthMd5 { salt } => Ok(StartupAction::AuthMd5(salt)),
306 BackendMessage::AuthSasl { mechanisms } => {
307 Ok(StartupAction::AuthSasl(mechanisms.to_vec()))
308 }
309 BackendMessage::ParameterStatus { name, value } => {
310 Ok(StartupAction::ParameterStatus(name.into(), value.into()))
311 }
312 BackendMessage::BackendKeyData { pid, secret } => {
313 Ok(StartupAction::BackendKeyData(pid, secret))
314 }
315 BackendMessage::ReadyForQuery { status } => Ok(StartupAction::ReadyForQuery(status)),
316 BackendMessage::ErrorResponse { data } => {
317 let fields = proto::parse_error_response(data);
318 Ok(StartupAction::Error(fields.to_string()))
319 }
320 BackendMessage::NoticeResponse { .. } => Ok(StartupAction::Notice),
321 other => Err(DriverError::Protocol(format!(
322 "unexpected message during startup: {other:?}"
323 ))),
324 }
325 }
326
327 fn handle_scram(&mut self, config: &Config, mechanisms_data: &[u8]) -> Result<(), DriverError> {
328 let mechs = auth::parse_sasl_mechanisms(mechanisms_data);
329 if !mechs.contains(&"SCRAM-SHA-256") {
330 return Err(DriverError::Auth(format!(
331 "server requires unsupported SASL mechanism(s): {mechs:?}"
332 )));
333 }
334
335 let mut scram = auth::ScramClient::new(&config.user, &config.password)?;
336
337 let client_first = scram.client_first_message();
339 self.write_buf.clear();
340 proto::write_sasl_initial(&mut self.write_buf, "SCRAM-SHA-256", &client_first);
341 self.flush_write()?;
342
343 let (msg_type, _) = self.read_message_buffered()?;
345 let server_first = {
346 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
347 match msg {
348 BackendMessage::AuthSaslContinue { data } => data.to_vec(),
349 BackendMessage::ErrorResponse { data } => {
350 let fields = proto::parse_error_response(data);
351 return Err(DriverError::Auth(fields.to_string()));
352 }
353 other => {
354 return Err(DriverError::Protocol(format!(
355 "expected AuthSaslContinue, got: {other:?}"
356 )));
357 }
358 }
359 };
360
361 scram.process_server_first(&server_first)?;
362
363 let client_final = scram.client_final_message()?;
365 self.write_buf.clear();
366 proto::write_sasl_response(&mut self.write_buf, &client_final);
367 self.flush_write()?;
368
369 let (msg_type, _) = self.read_message_buffered()?;
371 {
372 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
373 match msg {
374 BackendMessage::AuthSaslFinal { data } => {
375 let data_owned = data.to_vec();
376 scram.verify_server_final(&data_owned)?;
377 }
378 BackendMessage::ErrorResponse { data } => {
379 let fields = proto::parse_error_response(data);
380 return Err(DriverError::Auth(fields.to_string()));
381 }
382 other => {
383 return Err(DriverError::Protocol(format!(
384 "expected AuthSaslFinal, got: {other:?}"
385 )));
386 }
387 }
388 }
389
390 let (msg_type, _) = self.read_message_buffered()?;
392 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
393 match msg {
394 BackendMessage::AuthOk => Ok(()),
395 BackendMessage::ErrorResponse { data } => {
396 let fields = proto::parse_error_response(data);
397 Err(DriverError::Auth(fields.to_string()))
398 }
399 other => Err(DriverError::Protocol(format!(
400 "expected AuthOk after SCRAM, got: {other:?}"
401 ))),
402 }
403 }
404
405 fn validate_server_params(&self) -> Result<(), DriverError> {
406 if let Some(encoding) = self.parameter("server_encoding") {
407 let normalized = encoding.to_uppercase();
408 if normalized != "UTF8" && normalized != "UTF-8" {
409 return Err(DriverError::Protocol(format!(
410 "server_encoding is '{encoding}', but bsql requires UTF-8."
411 )));
412 }
413 }
414 if let Some(encoding) = self.parameter("client_encoding") {
415 let normalized = encoding.to_uppercase();
416 if normalized != "UTF8" && normalized != "UTF-8" {
417 return Err(DriverError::Protocol(format!(
418 "client_encoding is '{encoding}', but bsql requires UTF-8."
419 )));
420 }
421 }
422 if let Some(idt) = self.parameter("integer_datetimes") {
423 if idt != "on" {
424 return Err(DriverError::Protocol(format!(
425 "integer_datetimes is '{idt}', but bsql requires 'on'."
426 )));
427 }
428 }
429 Ok(())
430 }
431
432 pub fn prepare_only(&mut self, sql: &str, sql_hash: u64) -> Result<(), DriverError> {
438 if self.stmts.contains_key(&sql_hash) {
439 return Ok(());
440 }
441 let name = make_stmt_name(sql_hash);
442 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
443 self.write_buf.clear();
444 proto::write_parse(&mut self.write_buf, name_s, sql, &[]);
445 proto::write_describe(&mut self.write_buf, b'S', name_s);
446 proto::write_sync(&mut self.write_buf);
447 self.flush_write()?;
448
449 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
450 let columns = self.read_column_description()?;
451 self.expect_ready()?;
452
453 self.query_counter += 1;
454 self.cache_stmt(
455 sql_hash,
456 StmtInfo {
457 name,
458 columns,
459 last_used: self.query_counter,
460 bind_template: None,
461 },
462 );
463 Ok(())
464 }
465
466 #[inline]
477 pub fn query(
478 &mut self,
479 sql: &str,
480 sql_hash: u64,
481 params: &[&(dyn Encode + Sync)],
482 ) -> Result<QueryResult, DriverError> {
483 let columns = self
484 .send_pipeline(sql, sql_hash, params, true, true)?
485 .expect("send_pipeline(need_columns=true) must return Some");
486
487 let num_cols = columns.len();
488 let mut all_col_offsets: Vec<(usize, i32)> = Vec::with_capacity(num_cols.max(1) * 8);
489 let mut affected_rows: u64 = 0;
490
491 let mut resp_buf = acquire_resp_buf();
500 resp_buf.clear();
501
502 'outer: loop {
504 loop {
505 let avail = self.stream_buf_end - self.stream_buf_pos;
506 if avail < 5 {
507 break; }
509
510 let msg_type = self.stream_buf[self.stream_buf_pos];
511 let raw_len = i32::from_be_bytes([
512 self.stream_buf[self.stream_buf_pos + 1],
513 self.stream_buf[self.stream_buf_pos + 2],
514 self.stream_buf[self.stream_buf_pos + 3],
515 self.stream_buf[self.stream_buf_pos + 4],
516 ]);
517
518 if raw_len < 4 {
519 return Err(DriverError::Protocol(format!(
520 "invalid message length {raw_len} for type '{}'",
521 msg_type as char
522 )));
523 }
524
525 let payload_len = (raw_len - 4) as usize;
526 let total_msg_len = 5 + payload_len;
527
528 if avail < total_msg_len {
529 if total_msg_len > self.stream_buf.len() {
530 let msg = self.read_one_message()?;
532 match msg {
533 BackendMessage::BindComplete => continue,
534 BackendMessage::DataRow { data } => {
535 parse_data_row_into_buf(data, &mut resp_buf, &mut all_col_offsets)?;
536 continue;
537 }
538 BackendMessage::CommandComplete { tag } => {
539 affected_rows = proto::parse_command_tag(tag);
540 continue;
541 }
542 BackendMessage::EmptyQuery => continue,
543 BackendMessage::ReadyForQuery { status } => {
544 self.tx_status = status;
545 break 'outer;
546 }
547 BackendMessage::NoticeResponse { .. } => continue,
548 BackendMessage::ErrorResponse { data } => {
549 let fields = proto::parse_error_response(data);
550 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
551 self.drain_to_ready()?;
552 return Err(self.make_server_error(fields));
553 }
554 other => {
555 return Err(DriverError::Protocol(format!(
556 "unexpected message during query: {other:?}"
557 )));
558 }
559 }
560 }
561 break; }
563
564 let payload_start = self.stream_buf_pos + 5;
566 let payload_end = payload_start + payload_len;
567
568 if msg_type == b'D' {
569 parse_data_row_into_buf(
571 &self.stream_buf[payload_start..payload_end],
572 &mut resp_buf,
573 &mut all_col_offsets,
574 )?;
575 } else if msg_type == b'Z' {
576 if payload_len >= 1 {
577 self.tx_status = self.stream_buf[payload_start];
578 }
579 self.stream_buf_pos += total_msg_len;
580 break 'outer;
581 } else {
582 self.handle_non_datarow_query(
583 msg_type,
584 payload_start,
585 payload_end,
586 sql_hash,
587 &mut affected_rows,
588 )?;
589 }
590
591 self.stream_buf_pos += total_msg_len;
592 }
593
594 self.refill_stream_buf()?;
595 }
596
597 self.shrink_buffers();
598
599 Ok(QueryResult::from_parts_with_buf(
602 all_col_offsets,
603 num_cols,
604 columns,
605 affected_rows,
606 resp_buf,
607 ))
608 }
609
610 #[inline]
621 pub fn execute_monolithic(
622 &mut self,
623 sql: &str,
624 sql_hash: u64,
625 params: &[&(dyn Encode + Sync)],
626 ) -> Result<u64, DriverError> {
627 self.write_buf.clear();
629
630 let info = match self.stmts.get_mut(&sql_hash) {
632 Some(info) => {
633 self.query_counter += 1;
634 info.last_used = self.query_counter;
635 info
636 }
637 None => {
638 return self.execute_with_prepare(sql, sql_hash, params);
640 }
641 };
642
643 let can_use_template = info
645 .bind_template
646 .as_ref()
647 .is_some_and(|t| t.param_slots.len() == params.len());
648
649 let mut has_exec_sync = false;
650
651 if can_use_template {
652 let tmpl = info
654 .bind_template
655 .as_ref()
656 .expect("guarded by can_use_template");
657 self.write_buf.extend_from_slice(&tmpl.bytes);
658
659 let mut template_ok = true;
660 for (i, param) in params.iter().enumerate() {
661 let (data_offset, old_len) = tmpl.param_slots[i];
662 if param.is_null() {
663 let len_offset = data_offset - 4;
664 self.write_buf[len_offset..len_offset + 4]
665 .copy_from_slice(&(-1i32).to_be_bytes());
666 } else if old_len >= 0 {
667 let end = data_offset + old_len as usize;
668 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
669 template_ok = false;
670 break;
671 }
672 } else {
673 template_ok = false;
675 break;
676 }
677 }
678
679 if template_ok {
680 has_exec_sync = true;
681 } else {
682 self.write_buf.clear();
683 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
684 info.bind_template = None;
685 }
686 } else {
687 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
688 }
689
690 if info.bind_template.is_none() && !self.write_buf.is_empty() {
692 info.bind_template = build_bind_template(&self.write_buf, params.len());
693 }
694
695 if !has_exec_sync {
696 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
697 }
698
699 self.stream
701 .write_all(&self.write_buf)
702 .map_err(DriverError::Io)?;
703
704 let mut affected_rows: u64 = 0;
706
707 'outer: loop {
708 loop {
709 let avail = self.stream_buf_end - self.stream_buf_pos;
710 if avail < 5 {
711 break; }
713
714 let msg_type = self.stream_buf[self.stream_buf_pos];
715 let raw_len = i32::from_be_bytes([
716 self.stream_buf[self.stream_buf_pos + 1],
717 self.stream_buf[self.stream_buf_pos + 2],
718 self.stream_buf[self.stream_buf_pos + 3],
719 self.stream_buf[self.stream_buf_pos + 4],
720 ]);
721
722 if raw_len < 4 {
723 return Err(DriverError::Protocol(format!(
724 "invalid message length {raw_len} for type '{}'",
725 msg_type as char
726 )));
727 }
728
729 let payload_len = (raw_len - 4) as usize;
730 let total_msg_len = 5 + payload_len;
731
732 if avail < total_msg_len {
733 if total_msg_len > self.stream_buf.len() {
734 let msg = self.read_one_message()?;
735 match msg {
736 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
737 continue;
738 }
739 BackendMessage::CommandComplete { tag } => {
740 affected_rows = proto::parse_command_tag(tag);
741 continue;
742 }
743 BackendMessage::EmptyQuery => continue,
744 BackendMessage::ReadyForQuery { status } => {
745 self.tx_status = status;
746 break 'outer;
747 }
748 BackendMessage::NoticeResponse { .. } => continue,
749 BackendMessage::ErrorResponse { data } => {
750 let fields = proto::parse_error_response(data);
751 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
752 self.drain_to_ready()?;
753 return Err(self.make_server_error(fields));
754 }
755 other => {
756 return Err(DriverError::Protocol(format!(
757 "unexpected message during execute: {other:?}"
758 )));
759 }
760 }
761 }
762 break; }
764
765 let payload_start = self.stream_buf_pos + 5;
770 let payload_end = payload_start + payload_len;
771
772 if msg_type == b'2' {
773 self.stream_buf_pos += total_msg_len;
775 continue;
776 } else if msg_type == b'C' {
777 affected_rows = proto::parse_command_tag_bytes(
779 &self.stream_buf[payload_start..payload_end],
780 );
781 } else if msg_type == b'Z' {
782 if payload_len >= 1 {
784 self.tx_status = self.stream_buf[payload_start];
785 }
786 self.stream_buf_pos += total_msg_len;
787 break 'outer;
788 } else if msg_type == b'D' || msg_type == b'I' {
789 } else {
791 self.handle_non_datarow_execute(
792 msg_type,
793 payload_start,
794 payload_end,
795 sql_hash,
796 )?;
797 }
798
799 self.stream_buf_pos += total_msg_len;
800 }
801
802 let remaining = self.stream_buf_end - self.stream_buf_pos;
804 debug_assert!(
805 remaining == 0 || self.stream_buf_pos > 0,
806 "compact called with pos=0 and remaining data"
807 );
808 if remaining > 0 {
809 self.stream_buf
810 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
811 }
812 self.stream_buf_pos = 0;
813 self.stream_buf_end = remaining;
814 let n = self
815 .stream
816 .read(&mut self.stream_buf[remaining..])
817 .map_err(DriverError::Io)?;
818 if n == 0 {
819 return Err(DriverError::Io(std::io::Error::new(
820 std::io::ErrorKind::UnexpectedEof,
821 "connection closed",
822 )));
823 }
824 self.stream_buf_end = remaining + n;
825 }
826
827 if self.query_counter & 63 == 0 {
829 if self.read_buf.capacity() > 64 * 1024 {
830 self.read_buf.clear();
831 self.read_buf.shrink_to(8192);
832 }
833 if self.write_buf.capacity() > 16 * 1024 {
834 self.write_buf.clear();
835 self.write_buf.shrink_to(8192);
836 }
837 }
838
839 Ok(affected_rows)
840 }
841
842 #[cold]
844 #[inline(never)]
845 fn execute_with_prepare(
846 &mut self,
847 sql: &str,
848 sql_hash: u64,
849 params: &[&(dyn Encode + Sync)],
850 ) -> Result<u64, DriverError> {
851 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
852
853 if params.len() > i16::MAX as usize {
854 return Err(DriverError::Protocol(format!(
855 "parameter count {} exceeds maximum {}",
856 params.len(),
857 i16::MAX
858 )));
859 }
860
861 let name = make_stmt_name(sql_hash);
862 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
863 let param_oids: smallvec::SmallVec<[u32; 8]> =
864 params.iter().map(|p| p.type_oid()).collect();
865
866 self.write_buf.clear();
867 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
868 proto::write_describe(&mut self.write_buf, b'S', name_s);
869 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
870 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
871 self.stream
872 .write_all(&self.write_buf)
873 .map_err(DriverError::Io)?;
874
875 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
876 let columns = self.read_column_description()?;
877 self.query_counter += 1;
878 self.cache_stmt(
879 sql_hash,
880 StmtInfo {
881 name,
882 columns,
883 last_used: self.query_counter,
884 bind_template: None,
885 },
886 );
887
888 let mut affected_rows: u64 = 0;
890 'outer: loop {
891 loop {
892 let avail = self.stream_buf_end - self.stream_buf_pos;
893 if avail < 5 {
894 break;
895 }
896
897 let msg_type = self.stream_buf[self.stream_buf_pos];
898 let raw_len = i32::from_be_bytes([
899 self.stream_buf[self.stream_buf_pos + 1],
900 self.stream_buf[self.stream_buf_pos + 2],
901 self.stream_buf[self.stream_buf_pos + 3],
902 self.stream_buf[self.stream_buf_pos + 4],
903 ]);
904
905 if raw_len < 4 {
906 return Err(DriverError::Protocol(format!(
907 "invalid message length {raw_len} for type '{}'",
908 msg_type as char
909 )));
910 }
911
912 let payload_len = (raw_len - 4) as usize;
913 let total_msg_len = 5 + payload_len;
914
915 if avail < total_msg_len {
916 if total_msg_len > self.stream_buf.len() {
917 let msg = self.read_one_message()?;
918 match msg {
919 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
920 continue;
921 }
922 BackendMessage::CommandComplete { tag } => {
923 affected_rows = proto::parse_command_tag(tag);
924 continue;
925 }
926 BackendMessage::EmptyQuery => continue,
927 BackendMessage::ReadyForQuery { status } => {
928 self.tx_status = status;
929 break 'outer;
930 }
931 BackendMessage::NoticeResponse { .. } => continue,
932 BackendMessage::ErrorResponse { data } => {
933 let fields = proto::parse_error_response(data);
934 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
935 self.drain_to_ready()?;
936 return Err(self.make_server_error(fields));
937 }
938 other => {
939 return Err(DriverError::Protocol(format!(
940 "unexpected message during execute: {other:?}"
941 )));
942 }
943 }
944 }
945 break;
946 }
947
948 let payload_start = self.stream_buf_pos + 5;
949 let payload_end = payload_start + payload_len;
950
951 if msg_type == b'2' || msg_type == b'D' || msg_type == b'I' {
952 } else if msg_type == b'C' {
954 affected_rows = proto::parse_command_tag_bytes(
955 &self.stream_buf[payload_start..payload_end],
956 );
957 } else if msg_type == b'Z' {
958 if payload_len >= 1 {
959 self.tx_status = self.stream_buf[payload_start];
960 }
961 self.stream_buf_pos += total_msg_len;
962 break 'outer;
963 } else {
964 self.handle_non_datarow_execute(
965 msg_type,
966 payload_start,
967 payload_end,
968 sql_hash,
969 )?;
970 }
971
972 self.stream_buf_pos += total_msg_len;
973 }
974
975 self.refill_stream_buf()?;
976 }
977
978 Ok(affected_rows)
979 }
980
981 #[inline]
986 pub fn execute(
987 &mut self,
988 sql: &str,
989 sql_hash: u64,
990 params: &[&(dyn Encode + Sync)],
991 ) -> Result<u64, DriverError> {
992 self.execute_monolithic(sql, sql_hash, params)
993 }
994
995 pub fn execute_pipeline(
1007 &mut self,
1008 sql: &str,
1009 sql_hash: u64,
1010 param_sets: &[&[&(dyn Encode + Sync)]],
1011 ) -> Result<Vec<u64>, DriverError> {
1012 if param_sets.is_empty() {
1013 return Ok(Vec::new());
1014 }
1015
1016 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1017
1018 self.write_buf.clear();
1019
1020 if !self.stmts.contains_key(&sql_hash) {
1022 let name = make_stmt_name(sql_hash);
1023 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
1024 let first_params = param_sets[0];
1025 if first_params.len() > i16::MAX as usize {
1026 return Err(DriverError::Protocol(format!(
1027 "parameter count {} exceeds maximum {}",
1028 first_params.len(),
1029 i16::MAX
1030 )));
1031 }
1032 let param_oids: smallvec::SmallVec<[u32; 8]> =
1033 first_params.iter().map(|p| p.type_oid()).collect();
1034 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
1035 proto::write_describe(&mut self.write_buf, b'S', name_s);
1036 proto::write_sync(&mut self.write_buf);
1037 self.flush_write()?;
1038
1039 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1040 let columns = self.read_column_description()?;
1041 self.expect_ready()?;
1042
1043 self.query_counter += 1;
1044 self.cache_stmt(
1045 sql_hash,
1046 StmtInfo {
1047 name,
1048 columns,
1049 last_used: self.query_counter,
1050 bind_template: None,
1051 },
1052 );
1053
1054 self.write_buf.clear();
1055 }
1056
1057 let stmt_name = self
1059 .stmts
1060 .get(&sql_hash)
1061 .expect("BUG: stmt just cached but not found")
1062 .name_str()
1063 .to_owned();
1064 let count = param_sets.len();
1065
1066 for params in param_sets {
1067 if params.len() > i16::MAX as usize {
1068 return Err(DriverError::Protocol(format!(
1069 "parameter count {} exceeds maximum {}",
1070 params.len(),
1071 i16::MAX
1072 )));
1073 }
1074 proto::write_bind_params(&mut self.write_buf, "", &stmt_name, params);
1075 self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
1076 }
1077
1078 self.write_buf.extend_from_slice(proto::SYNC_ONLY);
1079 self.flush_write()?;
1080
1081 let mut results = Vec::with_capacity(count);
1083 for _ in 0..count {
1084 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1085
1086 let mut affected_rows: u64 = 0;
1087 loop {
1088 let msg = self.read_one_message()?;
1089 match msg {
1090 BackendMessage::DataRow { .. } => {}
1091 BackendMessage::CommandComplete { tag } => {
1092 affected_rows = proto::parse_command_tag(tag);
1093 break;
1094 }
1095 BackendMessage::EmptyQuery => break,
1096 BackendMessage::NoticeResponse { .. } => {}
1097 BackendMessage::ErrorResponse { data } => {
1098 let fields = proto::parse_error_response(data);
1099 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1100 self.drain_to_ready()?;
1101 return Err(self.make_server_error(fields));
1102 }
1103 other => {
1104 return Err(DriverError::Protocol(format!(
1105 "unexpected message during execute_pipeline: {other:?}"
1106 )));
1107 }
1108 }
1109 }
1110 results.push(affected_rows);
1111 }
1112
1113 self.expect_ready()?;
1114 self.shrink_buffers();
1115 Ok(results)
1116 }
1117
1118 pub(crate) fn ensure_stmt_prepared(
1124 &mut self,
1125 sql: &str,
1126 sql_hash: u64,
1127 params: &[&(dyn Encode + Sync)],
1128 ) -> Result<[u8; 18], DriverError> {
1129 if let Some(info) = self.stmts.get(&sql_hash) {
1130 return Ok(info.name);
1131 }
1132
1133 let name = make_stmt_name(sql_hash);
1134 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
1135 if params.len() > i16::MAX as usize {
1136 return Err(DriverError::Protocol(format!(
1137 "parameter count {} exceeds maximum {}",
1138 params.len(),
1139 i16::MAX
1140 )));
1141 }
1142 let param_oids: smallvec::SmallVec<[u32; 8]> =
1143 params.iter().map(|p| p.type_oid()).collect();
1144
1145 self.write_buf.clear();
1146 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
1147 proto::write_describe(&mut self.write_buf, b'S', name_s);
1148 proto::write_sync(&mut self.write_buf);
1149 self.flush_write()?;
1150
1151 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1152 let columns = self.read_column_description()?;
1153 self.expect_ready()?;
1154
1155 self.query_counter += 1;
1156 self.cache_stmt(
1157 sql_hash,
1158 StmtInfo {
1159 name,
1160 columns,
1161 last_used: self.query_counter,
1162 bind_template: None,
1163 },
1164 );
1165
1166 Ok(name)
1167 }
1168
1169 pub(crate) fn write_deferred_bind_execute(
1172 &self,
1173 sql_hash: u64,
1174 params: &[&(dyn Encode + Sync)],
1175 buf: &mut Vec<u8>,
1176 ) {
1177 let stmt_name = self
1178 .stmts
1179 .get(&sql_hash)
1180 .expect("BUG: stmt just cached but not found")
1181 .name_str();
1182 proto::write_bind_params(buf, "", stmt_name, params);
1183 buf.extend_from_slice(proto::EXECUTE_ONLY);
1184 }
1185
1186 pub(crate) fn flush_deferred_pipeline(
1191 &mut self,
1192 buf: &mut Vec<u8>,
1193 count: usize,
1194 ) -> Result<Vec<u64>, DriverError> {
1195 if count == 0 {
1196 buf.clear();
1197 return Ok(Vec::new());
1198 }
1199
1200 buf.extend_from_slice(proto::SYNC_ONLY);
1201
1202 self.stream.write_all(buf).map_err(DriverError::Io)?;
1203 buf.clear();
1204
1205 let mut results = Vec::with_capacity(count);
1206 for _ in 0..count {
1207 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1208
1209 let mut affected_rows: u64 = 0;
1210 loop {
1211 let msg = self.read_one_message()?;
1212 match msg {
1213 BackendMessage::DataRow { .. } => {}
1214 BackendMessage::CommandComplete { tag } => {
1215 affected_rows = proto::parse_command_tag(tag);
1216 break;
1217 }
1218 BackendMessage::EmptyQuery => break,
1219 BackendMessage::NoticeResponse { .. } => {}
1220 BackendMessage::ErrorResponse { data } => {
1221 let fields = proto::parse_error_response(data);
1222 self.drain_to_ready()?;
1223 return Err(self.make_server_error(fields));
1224 }
1225 other => {
1226 return Err(DriverError::Protocol(format!(
1227 "unexpected message during flush_deferred_pipeline: {other:?}"
1228 )));
1229 }
1230 }
1231 }
1232 results.push(affected_rows);
1233 }
1234
1235 self.expect_ready()?;
1236 self.shrink_buffers();
1237 Ok(results)
1238 }
1239
1240 pub fn for_each<F>(
1242 &mut self,
1243 sql: &str,
1244 sql_hash: u64,
1245 params: &[&(dyn Encode + Sync)],
1246 mut f: F,
1247 ) -> Result<(), DriverError>
1248 where
1249 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
1250 {
1251 let _ = self.send_pipeline(sql, sql_hash, params, false, true)?;
1252
1253 'outer: loop {
1255 loop {
1256 let avail = self.stream_buf_end - self.stream_buf_pos;
1257 if avail < 5 {
1258 break; }
1260
1261 let msg_type = self.stream_buf[self.stream_buf_pos];
1262 let raw_len = i32::from_be_bytes([
1263 self.stream_buf[self.stream_buf_pos + 1],
1264 self.stream_buf[self.stream_buf_pos + 2],
1265 self.stream_buf[self.stream_buf_pos + 3],
1266 self.stream_buf[self.stream_buf_pos + 4],
1267 ]);
1268
1269 if raw_len < 4 {
1270 return Err(DriverError::Protocol(format!(
1271 "invalid message length {raw_len} for type '{}'",
1272 msg_type as char
1273 )));
1274 }
1275
1276 let payload_len = (raw_len - 4) as usize;
1277 let total_msg_len = 5 + payload_len;
1278
1279 if avail < total_msg_len {
1280 if total_msg_len > self.stream_buf.len() {
1281 let msg = self.read_one_message()?;
1283 match msg {
1284 BackendMessage::BindComplete => continue,
1285 BackendMessage::DataRow { data } => {
1286 let row = PgDataRow::new(data)?;
1287 f(row)?;
1288 continue;
1289 }
1290 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1291 continue;
1292 }
1293 BackendMessage::ReadyForQuery { status } => {
1294 self.tx_status = status;
1295 break 'outer;
1296 }
1297 BackendMessage::NoticeResponse { .. } => continue,
1298 BackendMessage::ErrorResponse { data } => {
1299 let fields = proto::parse_error_response(data);
1300 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1301 self.drain_to_ready()?;
1302 return Err(self.make_server_error(fields));
1303 }
1304 other => {
1305 return Err(DriverError::Protocol(format!(
1306 "unexpected message during for_each: {other:?}"
1307 )));
1308 }
1309 }
1310 }
1311 break; }
1313
1314 let payload_start = self.stream_buf_pos + 5;
1316 let payload_end = payload_start + payload_len;
1317
1318 if msg_type == b'D' {
1321 let row = PgDataRow::new(&self.stream_buf[payload_start..payload_end])?;
1323 f(row)?;
1324 } else if msg_type == b'Z' {
1325 if payload_len >= 1 {
1327 self.tx_status = self.stream_buf[payload_start];
1328 }
1329 self.stream_buf_pos += total_msg_len;
1330 break 'outer;
1331 } else {
1332 self.handle_non_datarow_execute(
1333 msg_type,
1334 payload_start,
1335 payload_end,
1336 sql_hash,
1337 )?;
1338 }
1339
1340 self.stream_buf_pos += total_msg_len;
1341 }
1342
1343 self.refill_stream_buf()?;
1345 }
1346
1347 self.shrink_buffers();
1348 Ok(())
1349 }
1350
1351 #[inline]
1362 pub fn for_each_raw_monolithic<F>(
1363 &mut self,
1364 sql: &str,
1365 sql_hash: u64,
1366 params: &[&(dyn Encode + Sync)],
1367 mut f: F,
1368 ) -> Result<(), DriverError>
1369 where
1370 F: FnMut(&[u8]) -> Result<(), DriverError>,
1371 {
1372 self.write_buf.clear();
1374
1375 let info = match self.stmts.get_mut(&sql_hash) {
1377 Some(info) => {
1378 self.query_counter += 1;
1379 info.last_used = self.query_counter;
1380 info
1381 }
1382 None => {
1383 return self.for_each_raw_with_prepare(sql, sql_hash, params, f);
1385 }
1386 };
1387
1388 let can_use_template = info
1390 .bind_template
1391 .as_ref()
1392 .is_some_and(|t| t.param_slots.len() == params.len());
1393
1394 let mut has_exec_sync = false;
1395
1396 if can_use_template {
1397 let tmpl = info
1399 .bind_template
1400 .as_ref()
1401 .expect("guarded by can_use_template");
1402 self.write_buf.extend_from_slice(&tmpl.bytes);
1403
1404 let mut template_ok = true;
1405 for (i, param) in params.iter().enumerate() {
1406 let (data_offset, old_len) = tmpl.param_slots[i];
1407 if param.is_null() {
1408 let len_offset = data_offset - 4;
1409 self.write_buf[len_offset..len_offset + 4]
1410 .copy_from_slice(&(-1i32).to_be_bytes());
1411 } else if old_len >= 0 {
1412 let end = data_offset + old_len as usize;
1413 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
1414 template_ok = false;
1415 break;
1416 }
1417 } else {
1418 template_ok = false;
1419 break;
1420 }
1421 }
1422
1423 if template_ok {
1424 has_exec_sync = true;
1425 } else {
1426 self.write_buf.clear();
1427 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
1428 info.bind_template = None;
1429 }
1430 } else {
1431 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
1432 }
1433
1434 if info.bind_template.is_none() && !self.write_buf.is_empty() {
1436 info.bind_template = build_bind_template(&self.write_buf, params.len());
1437 }
1438
1439 if !has_exec_sync {
1440 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1441 }
1442
1443 self.stream
1445 .write_all(&self.write_buf)
1446 .map_err(DriverError::Io)?;
1447
1448 loop {
1452 let avail = self.stream_buf_end - self.stream_buf_pos;
1453 if avail >= 5 {
1454 let bc_type = self.stream_buf[self.stream_buf_pos];
1455 match bc_type {
1456 b'2' => {
1457 self.stream_buf_pos += 5;
1458 break;
1459 }
1460 b'E' => {
1461 let msg = self.read_one_message()?;
1462 if let BackendMessage::ErrorResponse { data } = msg {
1463 let fields = proto::parse_error_response(data);
1464 self.drain_to_ready()?;
1465 return Err(self.make_server_error(fields));
1466 }
1467 }
1468 b'N' | b'S' => {
1469 let raw_len = i32::from_be_bytes([
1470 self.stream_buf[self.stream_buf_pos + 1],
1471 self.stream_buf[self.stream_buf_pos + 2],
1472 self.stream_buf[self.stream_buf_pos + 3],
1473 self.stream_buf[self.stream_buf_pos + 4],
1474 ]);
1475 let total = 1 + raw_len as usize;
1476 if avail >= total {
1477 self.stream_buf_pos += total;
1478 continue;
1479 }
1480 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1481 break;
1482 }
1483 _ => {
1484 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1485 break;
1486 }
1487 }
1488 } else {
1489 let remaining = self.stream_buf_end - self.stream_buf_pos;
1491 if remaining > 0 && self.stream_buf_pos > 0 {
1492 self.stream_buf
1493 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1494 }
1495 self.stream_buf_pos = 0;
1496 self.stream_buf_end = remaining;
1497 let n = self
1498 .stream
1499 .read(&mut self.stream_buf[remaining..])
1500 .map_err(DriverError::Io)?;
1501 if n == 0 {
1502 return Err(DriverError::Io(std::io::Error::new(
1503 std::io::ErrorKind::UnexpectedEof,
1504 "connection closed",
1505 )));
1506 }
1507 self.stream_buf_end = remaining + n;
1508 }
1509 }
1510
1511 'outer: loop {
1513 loop {
1514 let avail = self.stream_buf_end - self.stream_buf_pos;
1515 if avail < 5 {
1516 break;
1517 }
1518
1519 let msg_type = self.stream_buf[self.stream_buf_pos];
1520 let raw_len = i32::from_be_bytes([
1521 self.stream_buf[self.stream_buf_pos + 1],
1522 self.stream_buf[self.stream_buf_pos + 2],
1523 self.stream_buf[self.stream_buf_pos + 3],
1524 self.stream_buf[self.stream_buf_pos + 4],
1525 ]);
1526
1527 if raw_len < 4 {
1528 return Err(DriverError::Protocol(format!(
1529 "invalid message length {raw_len} for type '{}'",
1530 msg_type as char
1531 )));
1532 }
1533
1534 let payload_len = (raw_len - 4) as usize;
1535 let total_msg_len = 5 + payload_len;
1536
1537 if avail < total_msg_len {
1538 if total_msg_len > self.stream_buf.len() {
1539 let msg = self.read_one_message()?;
1540 match msg {
1541 BackendMessage::DataRow { data } => {
1542 f(data)?;
1543 continue;
1544 }
1545 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1546 continue;
1547 }
1548 BackendMessage::ReadyForQuery { status } => {
1549 self.tx_status = status;
1550 break 'outer;
1551 }
1552 BackendMessage::ErrorResponse { data } => {
1553 let fields = proto::parse_error_response(data);
1554 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1555 self.drain_to_ready()?;
1556 return Err(self.make_server_error(fields));
1557 }
1558 BackendMessage::NoticeResponse { .. } => continue,
1559 other => {
1560 return Err(DriverError::Protocol(format!(
1561 "unexpected message during for_each_raw: {other:?}"
1562 )));
1563 }
1564 }
1565 }
1566 break; }
1568
1569 let payload_start = self.stream_buf_pos + 5;
1571 let payload_end = payload_start + payload_len;
1572
1573 if msg_type == b'D' {
1574 f(&self.stream_buf[payload_start..payload_end])?;
1575 } else if msg_type == b'Z' {
1576 if payload_len >= 1 {
1577 self.tx_status = self.stream_buf[payload_start];
1578 }
1579 self.stream_buf_pos += total_msg_len;
1580 break 'outer;
1581 } else {
1582 self.handle_non_datarow_execute(
1583 msg_type,
1584 payload_start,
1585 payload_end,
1586 sql_hash,
1587 )?;
1588 }
1589
1590 self.stream_buf_pos += total_msg_len;
1591 }
1592
1593 let remaining = self.stream_buf_end - self.stream_buf_pos;
1595 if remaining > 0 && self.stream_buf_pos > 0 {
1596 self.stream_buf
1597 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1598 }
1599 self.stream_buf_pos = 0;
1600 self.stream_buf_end = remaining;
1601 let n = self
1602 .stream
1603 .read(&mut self.stream_buf[remaining..])
1604 .map_err(DriverError::Io)?;
1605 if n == 0 {
1606 return Err(DriverError::Io(std::io::Error::new(
1607 std::io::ErrorKind::UnexpectedEof,
1608 "connection closed",
1609 )));
1610 }
1611 self.stream_buf_end = remaining + n;
1612 }
1613
1614 if self.query_counter & 63 == 0 {
1616 if self.read_buf.capacity() > 64 * 1024 {
1617 self.read_buf.clear();
1618 self.read_buf.shrink_to(8192);
1619 }
1620 if self.write_buf.capacity() > 16 * 1024 {
1621 self.write_buf.clear();
1622 self.write_buf.shrink_to(8192);
1623 }
1624 }
1625
1626 Ok(())
1627 }
1628
1629 #[cold]
1631 #[inline(never)]
1632 fn for_each_raw_with_prepare<F>(
1633 &mut self,
1634 sql: &str,
1635 sql_hash: u64,
1636 params: &[&(dyn Encode + Sync)],
1637 mut f: F,
1638 ) -> Result<(), DriverError>
1639 where
1640 F: FnMut(&[u8]) -> Result<(), DriverError>,
1641 {
1642 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
1643
1644 if params.len() > i16::MAX as usize {
1645 return Err(DriverError::Protocol(format!(
1646 "parameter count {} exceeds maximum {}",
1647 params.len(),
1648 i16::MAX
1649 )));
1650 }
1651
1652 let name = make_stmt_name(sql_hash);
1653 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
1654 let param_oids: smallvec::SmallVec<[u32; 8]> =
1655 params.iter().map(|p| p.type_oid()).collect();
1656
1657 self.write_buf.clear();
1658 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
1659 proto::write_describe(&mut self.write_buf, b'S', name_s);
1660 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
1661 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1662 self.stream
1663 .write_all(&self.write_buf)
1664 .map_err(DriverError::Io)?;
1665
1666 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1667 let columns = self.read_column_description()?;
1668 self.query_counter += 1;
1669 self.cache_stmt(
1670 sql_hash,
1671 StmtInfo {
1672 name,
1673 columns,
1674 last_used: self.query_counter,
1675 bind_template: None,
1676 },
1677 );
1678
1679 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1681
1682 'outer: loop {
1683 loop {
1684 let avail = self.stream_buf_end - self.stream_buf_pos;
1685 if avail < 5 {
1686 break;
1687 }
1688
1689 let msg_type = self.stream_buf[self.stream_buf_pos];
1690 let raw_len = i32::from_be_bytes([
1691 self.stream_buf[self.stream_buf_pos + 1],
1692 self.stream_buf[self.stream_buf_pos + 2],
1693 self.stream_buf[self.stream_buf_pos + 3],
1694 self.stream_buf[self.stream_buf_pos + 4],
1695 ]);
1696
1697 if raw_len < 4 {
1698 return Err(DriverError::Protocol(format!(
1699 "invalid message length {raw_len} for type '{}'",
1700 msg_type as char
1701 )));
1702 }
1703
1704 let payload_len = (raw_len - 4) as usize;
1705 let total_msg_len = 5 + payload_len;
1706
1707 if avail < total_msg_len {
1708 if total_msg_len > self.stream_buf.len() {
1709 let msg = self.read_one_message()?;
1710 match msg {
1711 BackendMessage::DataRow { data } => {
1712 f(data)?;
1713 continue;
1714 }
1715 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1716 continue;
1717 }
1718 BackendMessage::ReadyForQuery { status } => {
1719 self.tx_status = status;
1720 break 'outer;
1721 }
1722 BackendMessage::ErrorResponse { data } => {
1723 let fields = proto::parse_error_response(data);
1724 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1725 self.drain_to_ready()?;
1726 return Err(self.make_server_error(fields));
1727 }
1728 BackendMessage::NoticeResponse { .. } => continue,
1729 other => {
1730 return Err(DriverError::Protocol(format!(
1731 "unexpected message during for_each_raw: {other:?}"
1732 )));
1733 }
1734 }
1735 }
1736 break;
1737 }
1738
1739 let payload_start = self.stream_buf_pos + 5;
1740 let payload_end = payload_start + payload_len;
1741
1742 if msg_type == b'D' {
1743 f(&self.stream_buf[payload_start..payload_end])?;
1744 } else if msg_type == b'Z' {
1745 if payload_len >= 1 {
1746 self.tx_status = self.stream_buf[payload_start];
1747 }
1748 self.stream_buf_pos += total_msg_len;
1749 break 'outer;
1750 } else {
1751 self.handle_non_datarow_execute(
1752 msg_type,
1753 payload_start,
1754 payload_end,
1755 sql_hash,
1756 )?;
1757 }
1758
1759 self.stream_buf_pos += total_msg_len;
1760 }
1761
1762 self.refill_stream_buf()?;
1763 }
1764
1765 self.shrink_buffers();
1766 Ok(())
1767 }
1768
1769 #[inline]
1774 pub fn for_each_raw<F>(
1775 &mut self,
1776 sql: &str,
1777 sql_hash: u64,
1778 params: &[&(dyn Encode + Sync)],
1779 f: F,
1780 ) -> Result<(), DriverError>
1781 where
1782 F: FnMut(&[u8]) -> Result<(), DriverError>,
1783 {
1784 self.for_each_raw_monolithic(sql, sql_hash, params, f)
1785 }
1786
1787 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1789 self.write_buf.clear();
1790 proto::write_simple_query(&mut self.write_buf, sql);
1791 self.flush_write()?;
1792
1793 loop {
1794 let msg = self.read_one_message()?;
1795 match msg {
1796 BackendMessage::ReadyForQuery { status } => {
1797 self.tx_status = status;
1798 return Ok(());
1799 }
1800 BackendMessage::CommandComplete { .. }
1801 | BackendMessage::RowDescription { .. }
1802 | BackendMessage::DataRow { .. }
1803 | BackendMessage::EmptyQuery
1804 | BackendMessage::NoticeResponse { .. }
1805 | BackendMessage::ParameterStatus { .. }
1806 | BackendMessage::AuthOk
1810 | BackendMessage::AuthSaslFinal { .. }
1811 | BackendMessage::BackendKeyData { .. } => {}
1812 BackendMessage::ErrorResponse { data } => {
1813 let fields = proto::parse_error_response(data);
1814 self.drain_to_ready()?;
1815 return Err(self.make_server_error(fields));
1816 }
1817 other => {
1818 return Err(DriverError::Protocol(format!(
1819 "unexpected message during simple_query: {other:?}"
1820 )));
1821 }
1822 }
1823 }
1824 }
1825
1826 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
1828 self.write_buf.clear();
1829 proto::write_simple_query(&mut self.write_buf, sql);
1830 self.flush_write()?;
1831
1832 let mut rows: Vec<SimpleRow> = Vec::new();
1833 loop {
1834 let msg = self.read_one_message()?;
1835 match msg {
1836 BackendMessage::ReadyForQuery { status } => {
1837 self.tx_status = status;
1838 return Ok(rows);
1839 }
1840 BackendMessage::DataRow { data } => {
1841 rows.push(proto::parse_simple_data_row(data)?);
1842 }
1843 BackendMessage::RowDescription { .. }
1844 | BackendMessage::CommandComplete { .. }
1845 | BackendMessage::EmptyQuery
1846 | BackendMessage::NoticeResponse { .. }
1847 | BackendMessage::ParameterStatus { .. }
1848 | BackendMessage::AuthOk
1849 | BackendMessage::AuthSaslFinal { .. }
1850 | BackendMessage::BackendKeyData { .. } => {}
1851 BackendMessage::ErrorResponse { data } => {
1852 let fields = proto::parse_error_response(data);
1853 self.drain_to_ready()?;
1854 return Err(self.make_server_error(fields));
1855 }
1856 other => {
1857 return Err(DriverError::Protocol(format!(
1858 "unexpected message during simple_query_rows: {other:?}"
1859 )));
1860 }
1861 }
1862 }
1863 }
1864
1865 pub fn prepare_describe(&mut self, sql: &str) -> Result<PrepareResult, DriverError> {
1870 self.write_buf.clear();
1871 proto::write_parse(&mut self.write_buf, "", sql, &[]);
1874 proto::write_describe(&mut self.write_buf, b'S', "");
1875 proto::write_sync(&mut self.write_buf);
1876 self.flush_write()?;
1877
1878 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1880
1881 let mut param_oids: Vec<u32> = Vec::new();
1883 let columns;
1884 loop {
1885 let msg = self.read_one_message()?;
1886 match msg {
1887 BackendMessage::ParameterDescription { data } => {
1888 param_oids = proto::parse_parameter_description(data)?;
1889 }
1890 BackendMessage::RowDescription { data } => {
1891 columns = proto::parse_row_description(data)?;
1892 break;
1893 }
1894 BackendMessage::NoData => {
1895 columns = Vec::new();
1896 break;
1897 }
1898 BackendMessage::NoticeResponse { .. } => {}
1899 BackendMessage::ErrorResponse { data } => {
1900 let fields = proto::parse_error_response(data);
1901 self.drain_to_ready()?;
1902 return Err(self.make_server_error(fields));
1903 }
1904 other => {
1905 return Err(DriverError::Protocol(format!(
1906 "expected ParameterDescription/RowDescription/NoData, got: {other:?}"
1907 )));
1908 }
1909 }
1910 }
1911
1912 self.expect_ready()?;
1914
1915 Ok(PrepareResult {
1916 columns,
1917 param_oids,
1918 })
1919 }
1920
1921 pub fn wait_for_notification(&mut self) -> Result<(String, String), DriverError> {
1930 loop {
1931 let (msg_type, _payload_len) = self.read_message_buffered()?;
1932 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
1933 match msg {
1934 BackendMessage::NotificationResponse {
1935 channel, payload, ..
1936 } => {
1937 return Ok((channel.to_owned(), payload.to_owned()));
1938 }
1939 BackendMessage::ParameterStatus { .. } | BackendMessage::NoticeResponse { .. } => {
1940 continue;
1941 }
1942 _ => continue,
1943 }
1944 }
1945 }
1946
1947 pub fn cancel(&self) -> Result<(), DriverError> {
1953 let addr = format!("{}:{}", self.connect_config.host, self.connect_config.port);
1954 let mut tcp = std::net::TcpStream::connect(&addr).map_err(DriverError::Io)?;
1955 let mut buf = Vec::with_capacity(16);
1956 proto::write_cancel_request(&mut buf, self.pid, self.secret);
1957 tcp.write_all(&buf).map_err(DriverError::Io)?;
1958 tcp.flush().map_err(DriverError::Io)?;
1959 drop(tcp);
1961 Ok(())
1962 }
1963
1964 pub fn set_read_timeout(
1969 &self,
1970 timeout: Option<std::time::Duration>,
1971 ) -> Result<(), DriverError> {
1972 self.stream
1973 .set_read_timeout(timeout)
1974 .map_err(DriverError::Io)
1975 }
1976
1977 pub fn query_streaming_start(
1991 &mut self,
1992 sql: &str,
1993 sql_hash: u64,
1994 params: &[&(dyn Encode + Sync)],
1995 chunk_size: i32,
1996 ) -> Result<(Arc<[ColumnDesc]>, bool), DriverError> {
1997 self.write_buf.clear();
1998
1999 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
2000 self.query_counter += 1;
2002 info.last_used = self.query_counter;
2003
2004 let can_use_template = info
2005 .bind_template
2006 .as_ref()
2007 .is_some_and(|t| t.param_slots.len() == params.len());
2008
2009 if can_use_template {
2010 let tmpl = info
2012 .bind_template
2013 .as_ref()
2014 .expect("guarded by can_use_template");
2015 self.write_buf
2018 .extend_from_slice(&tmpl.bytes[..tmpl.bind_end]);
2019
2020 let mut template_ok = true;
2021 for (i, param) in params.iter().enumerate() {
2022 let (data_offset, old_len) = tmpl.param_slots[i];
2023 if param.is_null() {
2024 let len_offset = data_offset - 4;
2025 self.write_buf[len_offset..len_offset + 4]
2026 .copy_from_slice(&(-1i32).to_be_bytes());
2027 } else if old_len >= 0 {
2028 let end = data_offset + old_len as usize;
2029 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2030 template_ok = false;
2031 break;
2032 }
2033 } else {
2034 template_ok = false;
2035 break;
2036 }
2037 }
2038
2039 if !template_ok {
2040 self.write_buf.clear();
2041 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2042 info.bind_template = None;
2043 }
2044 } else {
2045 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2046 }
2047
2048 let cols = info.columns.clone();
2049
2050 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2051 info.bind_template = build_bind_template(&self.write_buf, params.len());
2052 }
2053
2054 proto::write_execute(&mut self.write_buf, "", chunk_size);
2055 proto::write_flush(&mut self.write_buf);
2057 self.flush_write()?;
2058
2059 cols
2060 } else {
2061 let name = make_stmt_name(sql_hash);
2063 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
2064 let param_oids: smallvec::SmallVec<[u32; 8]> =
2065 params.iter().map(|p| p.type_oid()).collect();
2066 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
2067 proto::write_describe(&mut self.write_buf, b'S', name_s);
2068 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
2069
2070 proto::write_execute(&mut self.write_buf, "", chunk_size);
2071 proto::write_flush(&mut self.write_buf);
2072 self.flush_write()?;
2073
2074 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2075 let columns = self.read_column_description()?;
2076 self.query_counter += 1;
2077 self.cache_stmt(
2078 sql_hash,
2079 StmtInfo {
2080 name,
2081 columns: columns.clone(),
2082 last_used: self.query_counter,
2083 bind_template: None,
2084 },
2085 );
2086 columns
2087 };
2088
2089 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2091
2092 self.streaming_active = true;
2093
2094 Ok((columns, false))
2095 }
2096
2097 pub fn streaming_next_chunk(
2105 &mut self,
2106 arena: &mut Arena,
2107 all_col_offsets: &mut Vec<(usize, i32)>,
2108 ) -> Result<bool, DriverError> {
2109 all_col_offsets.clear();
2110
2111 loop {
2112 let msg = self.read_one_message()?;
2113 match msg {
2114 BackendMessage::DataRow { data } => {
2115 parse_data_row_flat(data, arena, all_col_offsets)?;
2116 }
2117 BackendMessage::PortalSuspended => {
2118 return Ok(true);
2122 }
2123 BackendMessage::CommandComplete { .. } => {
2124 self.write_buf.clear();
2127 proto::write_sync(&mut self.write_buf);
2128 self.flush_write()?;
2129 self.expect_ready()?;
2130 self.shrink_buffers();
2131
2132 self.streaming_active = false;
2133 return Ok(false);
2134 }
2135 BackendMessage::EmptyQuery => {
2136 self.write_buf.clear();
2137 proto::write_sync(&mut self.write_buf);
2138 self.flush_write()?;
2139 self.expect_ready()?;
2140
2141 self.streaming_active = false;
2142 return Ok(false);
2143 }
2144 BackendMessage::ErrorResponse { data } => {
2145 let fields = proto::parse_error_response(data);
2146 self.write_buf.clear();
2148 proto::write_sync(&mut self.write_buf);
2149 self.flush_write()?;
2150 self.drain_to_ready()?;
2151
2152 self.streaming_active = false;
2153 return Err(self.make_server_error(fields));
2154 }
2155 BackendMessage::NoticeResponse { .. } => {}
2156 other => {
2157 return Err(DriverError::Protocol(format!(
2158 "unexpected message during streaming: {other:?}"
2159 )));
2160 }
2161 }
2162 }
2163 }
2164
2165 pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
2173 self.write_buf.clear();
2174 proto::write_execute(&mut self.write_buf, "", chunk_size);
2175 proto::write_flush(&mut self.write_buf);
2176 self.flush_write()
2177 }
2178
2179 pub fn is_streaming(&self) -> bool {
2181 self.streaming_active
2182 }
2183
2184 pub fn close(mut self) -> Result<(), DriverError> {
2186 self.write_buf.clear();
2187 proto::write_terminate(&mut self.write_buf);
2188 let _ = self.flush_write();
2189 Ok(())
2190 }
2191
2192 pub fn is_idle(&self) -> bool {
2196 self.tx_status == b'I'
2197 }
2198
2199 pub fn is_in_transaction(&self) -> bool {
2201 self.tx_status == b'T'
2202 }
2203
2204 pub fn is_in_failed_transaction(&self) -> bool {
2206 self.tx_status == b'E'
2207 }
2208
2209 pub fn touch(&mut self) {
2211 self.last_used = std::time::Instant::now();
2212 }
2213
2214 pub fn idle_duration(&self) -> std::time::Duration {
2216 self.last_used.elapsed()
2217 }
2218
2219 pub fn query_counter(&self) -> u64 {
2221 self.query_counter
2222 }
2223
2224 pub fn parameter(&self, name: &str) -> Option<&str> {
2226 self.params
2227 .iter()
2228 .find(|(k, _)| &**k == name)
2229 .map(|(_, v)| &**v)
2230 }
2231
2232 pub fn server_params(&self) -> &[(Box<str>, Box<str>)] {
2234 &self.params
2235 }
2236
2237 pub fn pid(&self) -> i32 {
2239 self.pid
2240 }
2241
2242 pub fn secret_key(&self) -> i32 {
2244 self.secret
2245 }
2246
2247 pub fn drain_notifications(&mut self) -> Vec<Notification> {
2249 std::mem::take(&mut self.pending_notifications)
2250 }
2251
2252 pub fn pending_notification_count(&self) -> usize {
2254 self.pending_notifications.len()
2255 }
2256
2257 pub fn set_max_stmt_cache_size(&mut self, size: usize) {
2259 self.max_stmt_cache_size = size;
2260 }
2261
2262 pub fn stmt_cache_len(&self) -> usize {
2264 self.stmts.len()
2265 }
2266
2267 pub fn created_at(&self) -> std::time::Instant {
2269 self.created_at
2270 }
2271
2272 #[inline]
2280 fn send_pipeline(
2281 &mut self,
2282 sql: &str,
2283 sql_hash: u64,
2284 params: &[&(dyn Encode + Sync)],
2285 need_columns: bool,
2286 skip_bind_complete: bool,
2287 ) -> Result<Option<Arc<[ColumnDesc]>>, DriverError> {
2288 debug_assert_eq!(crate::types::hash_sql(sql), sql_hash, "sql_hash mismatch");
2289
2290 if params.len() > i16::MAX as usize {
2291 return Err(DriverError::Protocol(format!(
2292 "parameter count {} exceeds maximum {}",
2293 params.len(),
2294 i16::MAX
2295 )));
2296 }
2297
2298 self.write_buf.clear();
2299
2300 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
2301 self.query_counter += 1;
2303 info.last_used = self.query_counter;
2304
2305 let can_use_template = info
2306 .bind_template
2307 .as_ref()
2308 .is_some_and(|t| t.param_slots.len() == params.len());
2309
2310 let mut has_exec_sync = false;
2312
2313 if can_use_template {
2314 let tmpl = info
2318 .bind_template
2319 .as_ref()
2320 .expect("guarded by can_use_template");
2321 self.write_buf.extend_from_slice(&tmpl.bytes);
2322
2323 let mut template_ok = true;
2324 for (i, param) in params.iter().enumerate() {
2325 let (data_offset, old_len) = tmpl.param_slots[i];
2326 if param.is_null() {
2327 let len_offset = data_offset - 4;
2329 self.write_buf[len_offset..len_offset + 4]
2330 .copy_from_slice(&(-1i32).to_be_bytes());
2331 } else if old_len >= 0 {
2332 let end = data_offset + old_len as usize;
2333 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2334 template_ok = false;
2336 break;
2337 }
2338 } else {
2339 template_ok = false;
2342 break;
2343 }
2344 }
2345
2346 if template_ok {
2347 has_exec_sync = true; } else {
2349 self.write_buf.clear();
2350 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2351 info.bind_template = None;
2353 }
2354 } else {
2355 proto::write_bind_params(&mut self.write_buf, "", info.name_str(), params);
2356 }
2357
2358 let cols = if need_columns {
2359 Some(info.columns.clone())
2360 } else {
2361 None
2362 };
2363
2364 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2368 info.bind_template = build_bind_template(&self.write_buf, params.len());
2369 }
2370
2371 if !has_exec_sync {
2372 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2373 }
2374 self.flush_write()?;
2375
2376 cols
2377 } else {
2378 let name = make_stmt_name(sql_hash);
2380 let name_s: &str = std::str::from_utf8(&name).expect("ASCII");
2381 let param_oids: smallvec::SmallVec<[u32; 8]> =
2382 params.iter().map(|p| p.type_oid()).collect();
2383 proto::write_parse(&mut self.write_buf, name_s, sql, ¶m_oids);
2384 proto::write_describe(&mut self.write_buf, b'S', name_s);
2385 proto::write_bind_params(&mut self.write_buf, "", name_s, params);
2386
2387 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2388 self.flush_write()?;
2389
2390 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2391 let columns = self.read_column_description()?;
2392 self.query_counter += 1;
2393 self.cache_stmt(
2394 sql_hash,
2395 StmtInfo {
2396 name,
2397 columns: columns.clone(),
2398 last_used: self.query_counter,
2399 bind_template: None,
2400 },
2401 );
2402 if need_columns { Some(columns) } else { None }
2403 };
2404
2405 if !skip_bind_complete {
2406 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2407 }
2408
2409 Ok(columns)
2410 }
2411
2412 fn read_column_description(&mut self) -> Result<Arc<[ColumnDesc]>, DriverError> {
2414 loop {
2415 let msg = self.read_one_message()?;
2416 match msg {
2417 BackendMessage::RowDescription { data } => {
2418 let cols = proto::parse_row_description(data)?;
2419 return Ok(cols.into());
2420 }
2421 BackendMessage::ParameterDescription { .. } => {}
2422 BackendMessage::NoData => return Ok(Arc::from(Vec::new())),
2423 BackendMessage::NoticeResponse { .. } => {}
2424 BackendMessage::ErrorResponse { data } => {
2425 let fields = proto::parse_error_response(data);
2426 self.drain_to_ready()?;
2427 return Err(self.make_server_error(fields));
2428 }
2429 other => {
2430 return Err(DriverError::Protocol(format!(
2431 "expected RowDescription/NoData, got: {other:?}"
2432 )));
2433 }
2434 }
2435 }
2436 }
2437
2438 fn cache_stmt(&mut self, sql_hash: u64, info: StmtInfo) {
2441 if self.stmts.len() >= self.max_stmt_cache_size && !self.stmts.contains_key(&sql_hash) {
2442 if let Some((_lru_hash, evicted)) = self.stmts.evict_lru() {
2443 proto::write_close(&mut self.write_buf, b'S', evicted.name_str());
2444 }
2445 }
2446 self.stmts.insert(sql_hash, info);
2447 }
2448
2449 fn buffer_notification(&mut self, pid: i32, channel: &str, payload: &str) {
2450 if self.pending_notifications.len() < 1024 {
2451 self.pending_notifications.push(Notification {
2452 pid,
2453 channel: channel.to_owned(),
2454 payload: payload.to_owned(),
2455 });
2456 }
2457 }
2458
2459 fn shrink_buffers(&mut self) {
2460 if self.query_counter & 63 != 0 {
2464 return;
2465 }
2466 if self.read_buf.capacity() > 64 * 1024 {
2467 self.read_buf.clear();
2468 self.read_buf.shrink_to(8192);
2469 }
2470 if self.write_buf.capacity() > 16 * 1024 {
2471 self.write_buf.clear();
2472 self.write_buf.shrink_to(8192);
2473 }
2474 }
2475
2476 fn maybe_invalidate_stmt_cache(&mut self, fields: &proto::ErrorFields, sql_hash: u64) -> bool {
2477 if &*fields.code == "26000" {
2478 self.stmts.remove(&sql_hash);
2479 true
2480 } else {
2481 false
2482 }
2483 }
2484
2485 #[cold]
2486 #[inline(never)]
2487 fn make_server_error(&self, fields: proto::ErrorFields) -> DriverError {
2488 DriverError::Server {
2489 code: fields.code,
2490 message: fields.message.into_boxed_str(),
2491 detail: fields.detail.map(String::into_boxed_str),
2492 hint: fields.hint.map(String::into_boxed_str),
2493 position: fields.position,
2494 }
2495 }
2496
2497 #[cold]
2503 #[inline(never)]
2504 fn handle_non_datarow_query(
2505 &mut self,
2506 msg_type: u8,
2507 payload_start: usize,
2508 payload_end: usize,
2509 sql_hash: u64,
2510 affected_rows: &mut u64,
2511 ) -> Result<(), DriverError> {
2512 match msg_type {
2513 b'2' | b'I' => {} b'C' => {
2515 *affected_rows =
2516 proto::parse_command_tag_bytes(&self.stream_buf[payload_start..payload_end]);
2517 }
2518 b'E' => {
2519 let fields =
2520 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2521 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2522 self.drain_to_ready()?;
2523 return Err(self.make_server_error(fields));
2524 }
2525 b'A' => {
2526 let msg = proto::parse_backend_message(
2527 msg_type,
2528 &self.stream_buf[payload_start..payload_end],
2529 )?;
2530 if let BackendMessage::NotificationResponse {
2531 pid,
2532 channel,
2533 payload,
2534 } = msg
2535 {
2536 let ch = channel.to_owned();
2537 let pl = payload.to_owned();
2538 self.buffer_notification(pid, &ch, &pl);
2539 }
2540 }
2541 _ => {} }
2543 Ok(())
2544 }
2545
2546 #[cold]
2549 #[inline(never)]
2550 fn handle_non_datarow_execute(
2551 &mut self,
2552 msg_type: u8,
2553 payload_start: usize,
2554 payload_end: usize,
2555 sql_hash: u64,
2556 ) -> Result<(), DriverError> {
2557 match msg_type {
2558 b'2' | b'C' | b'I' => {} b'E' => {
2560 let fields =
2561 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2562 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2563 self.drain_to_ready()?;
2564 return Err(self.make_server_error(fields));
2565 }
2566 b'A' => {
2567 let msg = proto::parse_backend_message(
2568 msg_type,
2569 &self.stream_buf[payload_start..payload_end],
2570 )?;
2571 if let BackendMessage::NotificationResponse {
2572 pid,
2573 channel,
2574 payload,
2575 } = msg
2576 {
2577 let ch = channel.to_owned();
2578 let pl = payload.to_owned();
2579 self.buffer_notification(pid, &ch, &pl);
2580 }
2581 }
2582 _ => {} }
2584 Ok(())
2585 }
2586
2587 #[inline]
2589 fn read_one_message(&mut self) -> Result<BackendMessage<'_>, DriverError> {
2590 loop {
2591 let (msg_type, _payload_len) = self.read_message_buffered()?;
2592 if msg_type == b'A' {
2593 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2594 if let BackendMessage::NotificationResponse {
2595 pid,
2596 channel,
2597 payload,
2598 } = msg
2599 {
2600 let pid_owned = pid;
2601 let channel_owned = channel.to_owned();
2602 let payload_owned = payload.to_owned();
2603 self.buffer_notification(pid_owned, &channel_owned, &payload_owned);
2604 continue;
2605 }
2606 }
2607 return proto::parse_backend_message(msg_type, &self.read_buf);
2608 }
2609 }
2610
2611 fn expect_message(
2612 &mut self,
2613 pred: impl Fn(&BackendMessage<'_>) -> bool,
2614 ) -> Result<(), DriverError> {
2615 loop {
2616 let msg = self.read_one_message()?;
2617 if pred(&msg) {
2618 return Ok(());
2619 }
2620 match msg {
2621 BackendMessage::ErrorResponse { data } => {
2622 let fields = proto::parse_error_response(data);
2623 self.drain_to_ready()?;
2624 return Err(self.make_server_error(fields));
2625 }
2626 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2627 other => {
2628 return Err(DriverError::Protocol(format!(
2629 "unexpected message while waiting for expected type: {other:?}"
2630 )));
2631 }
2632 }
2633 }
2634 }
2635
2636 fn expect_ready(&mut self) -> Result<(), DriverError> {
2637 loop {
2638 let msg = self.read_one_message()?;
2639 match msg {
2640 BackendMessage::ReadyForQuery { status } => {
2641 self.tx_status = status;
2642 return Ok(());
2643 }
2644 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2645 BackendMessage::ErrorResponse { data } => {
2646 let fields = proto::parse_error_response(data);
2647 self.drain_to_ready()?;
2648 return Err(self.make_server_error(fields));
2649 }
2650 _ => {}
2651 }
2652 }
2653 }
2654
2655 #[inline]
2656 fn drain_to_ready(&mut self) -> Result<(), DriverError> {
2657 loop {
2658 let msg = self.read_one_message()?;
2659 if let BackendMessage::ReadyForQuery { status } = msg {
2660 self.tx_status = status;
2661 return Ok(());
2662 }
2663 }
2664 }
2665
2666 #[inline]
2670 fn flush_write(&mut self) -> Result<(), DriverError> {
2671 self.stream
2672 .write_all(&self.write_buf)
2673 .map_err(DriverError::Io)
2674 }
2675
2676 fn read_message_buffered(&mut self) -> Result<(u8, usize), DriverError> {
2680 let mut header = [0u8; 5];
2681 sync_buffered_read_exact(
2682 &mut self.stream,
2683 &mut self.stream_buf,
2684 &mut self.stream_buf_pos,
2685 &mut self.stream_buf_end,
2686 &mut header,
2687 )?;
2688
2689 let msg_type = header[0];
2690 let len = i32::from_be_bytes([header[1], header[2], header[3], header[4]]);
2691
2692 if len < 4 {
2693 return Err(DriverError::Protocol(format!(
2694 "invalid message length {len} for type '{}'",
2695 msg_type as char
2696 )));
2697 }
2698
2699 const MAX_MESSAGE_LEN: i32 = 128 * 1024 * 1024;
2700 if len > MAX_MESSAGE_LEN {
2701 return Err(DriverError::Protocol(format!(
2702 "message length {len} exceeds maximum ({MAX_MESSAGE_LEN}) for type '{}'",
2703 msg_type as char
2704 )));
2705 }
2706
2707 let payload_len = (len - 4) as usize;
2708 self.read_buf.clear();
2709 self.read_buf.resize(payload_len, 0);
2710 if payload_len > 0 {
2711 sync_buffered_read_exact(
2712 &mut self.stream,
2713 &mut self.stream_buf,
2714 &mut self.stream_buf_pos,
2715 &mut self.stream_buf_end,
2716 &mut self.read_buf[..payload_len],
2717 )?;
2718 }
2719
2720 Ok((msg_type, payload_len))
2721 }
2722
2723 #[inline]
2725 fn refill_stream_buf(&mut self) -> Result<(), DriverError> {
2726 let remaining = self.stream_buf_end - self.stream_buf_pos;
2727 if remaining > 0 && self.stream_buf_pos > 0 {
2728 self.stream_buf
2729 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
2730 }
2731 self.stream_buf_pos = 0;
2732 self.stream_buf_end = remaining;
2733
2734 let n = self
2735 .stream
2736 .read(&mut self.stream_buf[remaining..])
2737 .map_err(DriverError::Io)?;
2738 if n == 0 {
2739 return Err(DriverError::Io(std::io::Error::new(
2740 std::io::ErrorKind::UnexpectedEof,
2741 "connection closed",
2742 )));
2743 }
2744 self.stream_buf_end = remaining + n;
2745 Ok(())
2746 }
2747}
2748
2749fn sync_buffered_read_exact(
2752 stream: &mut Stream,
2753 buf: &mut [u8],
2754 pos: &mut usize,
2755 end: &mut usize,
2756 out: &mut [u8],
2757) -> Result<(), DriverError> {
2758 let mut filled = 0;
2759 while filled < out.len() {
2760 let avail = *end - *pos;
2761 if avail > 0 {
2762 let take = avail.min(out.len() - filled);
2763 out[filled..filled + take].copy_from_slice(&buf[*pos..*pos + take]);
2764 *pos += take;
2765 filled += take;
2766 } else {
2767 *pos = 0;
2768 let n = stream.read(buf).map_err(DriverError::Io)?;
2769 if n == 0 {
2770 return Err(DriverError::Io(std::io::Error::new(
2771 std::io::ErrorKind::UnexpectedEof,
2772 "connection closed",
2773 )));
2774 }
2775 *end = n;
2776 }
2777 }
2778 Ok(())
2779}
2780
2781#[inline(always)]
2791fn parse_data_row_into_buf(
2792 data: &[u8],
2793 buf: &mut Vec<u8>,
2794 out: &mut Vec<(usize, i32)>,
2795) -> Result<(), DriverError> {
2796 if data.len() < 2 {
2797 return Err(DriverError::Protocol("DataRow too short".into()));
2798 }
2799
2800 let num_cols = i16::from_be_bytes([data[0], data[1]]);
2801 if num_cols < 0 {
2802 return Err(DriverError::Protocol(
2803 "DataRow: negative column count".into(),
2804 ));
2805 }
2806 let num_cols = num_cols as usize;
2807
2808 let col_data = &data[2..];
2811 let base = buf.len();
2812 buf.extend_from_slice(col_data);
2813
2814 let mut pos: usize = 0;
2816 for _ in 0..num_cols {
2817 if pos + 4 > col_data.len() {
2818 return Err(DriverError::Protocol("DataRow truncated".into()));
2819 }
2820
2821 let col_len = i32::from_be_bytes([
2822 col_data[pos],
2823 col_data[pos + 1],
2824 col_data[pos + 2],
2825 col_data[pos + 3],
2826 ]);
2827 pos += 4;
2828
2829 if col_len < 0 {
2830 out.push((0, -1));
2831 } else {
2832 let len = col_len as usize;
2833 if pos + len > col_data.len() {
2834 return Err(DriverError::Protocol(
2835 "DataRow column data truncated".into(),
2836 ));
2837 }
2838 out.push((base + pos, col_len));
2840 pos += len;
2841 }
2842 }
2843
2844 Ok(())
2845}
2846
2847fn parse_data_row_flat(
2851 data: &[u8],
2852 arena: &mut Arena,
2853 out: &mut Vec<(usize, i32)>,
2854) -> Result<(), DriverError> {
2855 if data.len() < 2 {
2856 return Err(DriverError::Protocol("DataRow too short".into()));
2857 }
2858
2859 let num_cols_raw = i16::from_be_bytes([data[0], data[1]]);
2860 if num_cols_raw < 0 {
2861 return Err(DriverError::Protocol(
2862 "DataRow: negative column count".into(),
2863 ));
2864 }
2865 let num_cols = num_cols_raw as usize;
2866 out.reserve(num_cols);
2867
2868 let col_data = &data[2..];
2871 let base = arena.alloc_copy(col_data);
2872
2873 let mut pos: usize = 0;
2875 for _ in 0..num_cols {
2876 if pos + 4 > col_data.len() {
2877 return Err(DriverError::Protocol("DataRow truncated".into()));
2878 }
2879
2880 let col_len = i32::from_be_bytes([
2881 col_data[pos],
2882 col_data[pos + 1],
2883 col_data[pos + 2],
2884 col_data[pos + 3],
2885 ]);
2886 pos += 4;
2887
2888 if col_len < 0 {
2889 out.push((0, -1));
2890 } else {
2891 let len = col_len as usize;
2892 if pos + len > col_data.len() {
2893 return Err(DriverError::Protocol(
2894 "DataRow column data truncated".into(),
2895 ));
2896 }
2897 out.push((base + pos, col_len));
2899 pos += len;
2900 }
2901 }
2902
2903 Ok(())
2904}
2905
#[cfg(test)]
// `3.14` below is a literal float8 test parameter, not an approximation of PI.
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;
    use crate::types::hash_sql;

    /// TCP port 1 is assumed to have no listener, so connecting fails without a server.
    const UNREACHABLE_URL: &str = "postgres://user:pass@127.0.0.1:1/db";

    /// Server used by the `#[ignore]`d live tests (run them with `cargo test -- --ignored`).
    const LIVE_URL: &str = "postgres://postgres@localhost/postgres?host=/tmp";

    fn live_config() -> Config {
        Config::from_url(LIVE_URL).unwrap()
    }

    // ---------------------------------------------------------------------------------
    // Offline tests: no server required.
    // ---------------------------------------------------------------------------------

    #[test]
    fn sync_config_tcp_no_longer_rejected() {
        let config = Config::from_url(UNREACHABLE_URL).unwrap();
        let result = Connection::connect(&config);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            !err.contains("Unix domain socket"),
            "error should NOT mention UDS requirement: {err}"
        );
    }

    #[test]
    fn sync_data_row_parsing() {
        let mut arena = Arena::new();
        let mut out = Vec::new();

        let mut data = Vec::new();
        data.extend_from_slice(&2i16.to_be_bytes());
        data.extend_from_slice(&4i32.to_be_bytes());
        data.extend_from_slice(&42i32.to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());

        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].1, 4);
        assert_eq!(out[1].1, -1);
    }

    #[test]
    fn sync_data_row_empty() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = 0i16.to_be_bytes();
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 0);
    }

    #[test]
    fn sync_data_row_too_short() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = vec![0u8];
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_negative_col_count() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = (-1i16).to_be_bytes();
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_truncated() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let mut data = Vec::new();
        // Two columns declared, but only the first is present.
        data.extend_from_slice(&2i16.to_be_bytes());
        data.extend_from_slice(&4i32.to_be_bytes());
        data.extend_from_slice(&42i32.to_be_bytes());
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_col_data_truncated() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let mut data = Vec::new();
        // One column claiming 100 bytes, with only 1 byte present.
        data.extend_from_slice(&1i16.to_be_bytes());
        data.extend_from_slice(&100i32.to_be_bytes());
        data.push(0);
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_connect_tcp_unreachable_port() {
        let config = Config::from_url(UNREACHABLE_URL).unwrap();
        let result = Connection::connect(&config);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            !err.contains("Unix domain socket"),
            "error should NOT mention UDS: {err}"
        );
    }

    #[test]
    fn sync_connect_ip_address_attempts_tcp() {
        let config = Config::from_url(UNREACHABLE_URL).unwrap();
        let result = Connection::connect(&config);
        assert!(result.is_err());
    }

    #[test]
    fn sync_data_row_all_null() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let mut data = Vec::new();
        data.extend_from_slice(&3i16.to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 3);
        for (_, len) in &out {
            assert_eq!(*len, -1);
        }
    }

    #[test]
    fn sync_data_row_long_text() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let long_text = "a".repeat(2048);
        let text_bytes = long_text.as_bytes();
        let mut data = Vec::new();
        data.extend_from_slice(&1i16.to_be_bytes());
        data.extend_from_slice(&(text_bytes.len() as i32).to_be_bytes());
        data.extend_from_slice(text_bytes);
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].1, text_bytes.len() as i32);
        let stored = arena.get(out[0].0, out[0].1 as usize);
        assert_eq!(stored, text_bytes);
    }

    #[test]
    fn sync_data_row_empty_text() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let mut data = Vec::new();
        // A zero-length (empty string) column is distinct from NULL (-1).
        data.extend_from_slice(&1i16.to_be_bytes());
        data.extend_from_slice(&0i32.to_be_bytes());
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].1, 0);
    }

    #[test]
    fn sync_data_row_17_columns_exceeds_smallvec() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let mut data = Vec::new();
        let num_cols: i16 = 20;
        data.extend_from_slice(&num_cols.to_be_bytes());
        for i in 0..num_cols {
            let val = (i as i32).to_be_bytes();
            data.extend_from_slice(&4i32.to_be_bytes());
            data.extend_from_slice(&val);
        }
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 20);
        for (idx, (offset, len)) in out.iter().enumerate() {
            assert_eq!(*len, 4);
            let stored = arena.get(*offset, 4);
            let val = i32::from_be_bytes([stored[0], stored[1], stored[2], stored[3]]);
            assert_eq!(val, idx as i32);
        }
    }

    #[test]
    fn sync_data_row_mixed_null_and_data() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let mut data = Vec::new();
        data.extend_from_slice(&5i16.to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());
        data.extend_from_slice(&4i32.to_be_bytes());
        data.extend_from_slice(&42i32.to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());
        data.extend_from_slice(&5i32.to_be_bytes());
        data.extend_from_slice(b"hello");

        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 5);
        assert_eq!(out[0].1, -1);
        assert_eq!(out[1].1, 4);
        assert_eq!(out[2].1, -1);
        assert_eq!(out[3].1, -1);
        assert_eq!(out[4].1, 5);
        let stored = arena.get(out[4].0, 5);
        assert_eq!(stored, b"hello");
    }

    #[test]
    fn sync_data_row_into_buf_offsets_are_absolute() {
        // Bytes already in `buf` must survive, and offsets must account for them.
        let mut buf = b"xyz".to_vec();
        let mut out = Vec::new();
        let mut data = Vec::new();
        data.extend_from_slice(&2i16.to_be_bytes());
        data.extend_from_slice(&4i32.to_be_bytes());
        data.extend_from_slice(&42i32.to_be_bytes());
        data.extend_from_slice(&(-1i32).to_be_bytes());

        parse_data_row_into_buf(&data, &mut buf, &mut out).unwrap();
        assert_eq!(out, vec![(3 + 4, 4), (0, -1)]);
        assert_eq!(&buf[..3], b"xyz");
        let (off, len) = out[0];
        assert_eq!(&buf[off..off + len as usize], &42i32.to_be_bytes());
    }

    #[test]
    fn sync_data_row_into_buf_rejects_malformed() {
        let mut buf = Vec::new();
        let mut out = Vec::new();
        assert!(parse_data_row_into_buf(&[0u8], &mut buf, &mut out).is_err());
        assert!(parse_data_row_into_buf(&(-1i16).to_be_bytes(), &mut buf, &mut out).is_err());

        let mut data = Vec::new();
        data.extend_from_slice(&1i16.to_be_bytes());
        data.extend_from_slice(&100i32.to_be_bytes());
        data.push(0);
        assert!(parse_data_row_into_buf(&data, &mut buf, &mut out).is_err());
    }

    #[test]
    fn sync_shrink_threshold_values() {
        // NOTE(review): this compares local literals only; it does not read the
        // connection's real buffer thresholds.
        let shrink = 64 * 1024usize;
        let initial = 8192usize;
        assert!(
            shrink > initial,
            "shrink threshold must exceed initial size"
        );
    }

    #[test]
    fn sync_connect_sslmode_require_without_tls_feature() {
        let mut config = Config::from_url(UNREACHABLE_URL).unwrap();
        config.ssl = SslMode::Require;
        let result = Connection::connect(&config);
        assert!(result.is_err());
    }

    #[test]
    fn sync_connect_sslmode_disable_attempts_tcp() {
        let mut config = Config::from_url(UNREACHABLE_URL).unwrap();
        config.ssl = SslMode::Disable;
        let result = Connection::connect(&config);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), DriverError::Io(_)));
    }

    #[test]
    fn sync_connect_sslmode_prefer_attempts_tcp() {
        let mut config = Config::from_url(UNREACHABLE_URL).unwrap();
        config.ssl = SslMode::Prefer;
        let result = Connection::connect(&config);
        assert!(result.is_err());
    }

    // ---------------------------------------------------------------------------------
    // Live tests: require a PostgreSQL server at `LIVE_URL`; ignored by default.
    // ---------------------------------------------------------------------------------

    #[test]
    #[ignore]
    fn sync_connect_uds_if_pg_available() {
        let config = live_config();
        let result = Connection::connect(&config);
        if let Ok(conn) = result {
            assert!(conn.pid() != 0, "pid should be nonzero");
            assert!(conn.is_idle(), "should start idle");
            assert!(!conn.is_in_transaction(), "should not be in tx");
            assert!(
                !conn.is_in_failed_transaction(),
                "should not be in failed tx"
            );
            assert_eq!(conn.stmt_cache_len(), 0, "cache should be empty");
            let _ = conn.close();
        }
    }

    #[test]
    #[ignore]
    fn sync_connection_debug_format_if_pg_available() {
        // Exercises the real `Debug` impl (pid, tx_status, stmt_cache_len).
        let conn = Connection::connect(&live_config()).unwrap();
        let fmt_str = format!("{conn:?}");
        assert!(fmt_str.contains("Connection"));
        assert!(fmt_str.contains("pid"));
        assert!(fmt_str.contains("tx_status: 'I'"));
        assert!(fmt_str.contains("stmt_cache_len: 0"));
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_simple_query_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        conn.simple_query("SELECT 1").unwrap();
        assert!(conn.is_idle());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_query_with_params_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        let sql = "SELECT $1::int4 + $2::int4 AS sum";
        let hash = hash_sql(sql);
        let a: i32 = 10;
        let b: i32 = 20;
        let result = conn
            .query(
                sql,
                hash,
                &[&a as &(dyn Encode + Sync), &b as &(dyn Encode + Sync)],
            )
            .unwrap();
        assert_eq!(result.len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_execute_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        conn.simple_query("CREATE TEMP TABLE _sync_test (id int)")
            .unwrap();
        let sql = "INSERT INTO _sync_test VALUES ($1::int4)";
        let hash = hash_sql(sql);
        let val: i32 = 42;
        let affected = conn
            .execute(sql, hash, &[&val as &(dyn Encode + Sync)])
            .unwrap();
        assert_eq!(affected, 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_for_each_zero_rows_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        conn.simple_query("CREATE TEMP TABLE _sync_fe0 (id int)")
            .unwrap();
        let sql = "SELECT id FROM _sync_fe0";
        let hash = hash_sql(sql);
        let mut count = 0u32;
        conn.for_each(sql, hash, &[], |_row| {
            count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(count, 0);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_for_each_multiple_rows_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        let sql = "SELECT generate_series(1, 5)";
        let hash = hash_sql(sql);
        let mut count = 0u32;
        conn.for_each(sql, hash, &[], |_row| {
            count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(count, 5);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_prepare_only_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        let sql = "SELECT 1";
        let hash = hash_sql(sql);
        conn.prepare_only(sql, hash).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        // Preparing the same statement again must reuse the cache entry.
        conn.prepare_only(sql, hash).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_simple_query_rows_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        let rows = conn.simple_query_rows("SELECT 42 AS n").unwrap();
        assert!(!rows.is_empty());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_stmt_cache_hit_miss_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        let sql1 = "SELECT 1";
        let hash1 = hash_sql(sql1);
        conn.query(sql1, hash1, &[]).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        // Cache hit: no new entry.
        conn.query(sql1, hash1, &[]).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        // Cache miss: a second entry.
        let sql2 = "SELECT 2";
        let hash2 = hash_sql(sql2);
        conn.query(sql2, hash2, &[]).unwrap();
        assert_eq!(conn.stmt_cache_len(), 2);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_invalid_sql_error_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        let sql = "SELECTTTT INVALID GARBAGE";
        let hash = hash_sql(sql);
        let result = conn.query(sql, hash, &[]);
        assert!(result.is_err());
        // The connection must have resynchronised to ReadyForQuery after the error.
        assert!(conn.is_idle());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_tx_state_transitions_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        assert!(conn.is_idle());
        assert!(!conn.is_in_transaction());
        conn.simple_query("BEGIN").unwrap();
        assert!(conn.is_in_transaction());
        assert!(!conn.is_idle());
        conn.simple_query("COMMIT").unwrap();
        assert!(conn.is_idle());
        assert!(!conn.is_in_transaction());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_lru_cache_eviction_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        conn.set_max_stmt_cache_size(3);
        for i in 0..5 {
            let sql = format!("SELECT {}", i);
            let hash = hash_sql(&sql);
            conn.query(&sql, hash, &[]).unwrap();
        }
        assert!(
            conn.stmt_cache_len() <= 3,
            "cache should be capped at 3, got {}",
            conn.stmt_cache_len()
        );
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_for_each_raw_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        let sql = "SELECT generate_series(1, 3)";
        let hash = hash_sql(sql);
        let mut raw_count = 0u32;
        conn.for_each_raw(sql, hash, &[], |_raw_data| {
            raw_count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(raw_count, 3);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_query_null_params_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        let sql = "SELECT $1::int4 IS NULL AS is_null";
        let hash = hash_sql(sql);
        let val: Option<i32> = None;
        let _result = conn
            .query(sql, hash, &[&val as &(dyn Encode + Sync)])
            .unwrap();
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_query_various_param_types_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        let sql = "SELECT $1::int4, $2::int8, $3::text, $4::bool, $5::float8";
        let hash = hash_sql(sql);
        let p1: i32 = 42;
        let p2: i64 = 9999999;
        let p3: &str = "hello";
        let p4: bool = true;
        let p5: f64 = 3.14;
        let result = conn
            .query(
                sql,
                hash,
                &[
                    &p1 as &(dyn Encode + Sync),
                    &p2 as &(dyn Encode + Sync),
                    &p3 as &(dyn Encode + Sync),
                    &p4 as &(dyn Encode + Sync),
                    &p5 as &(dyn Encode + Sync),
                ],
            )
            .unwrap();
        assert_eq!(result.len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_streaming_basic_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();
        assert!(!conn.is_streaming());

        let sql = "SELECT generate_series(1, 10)";
        let hash = hash_sql(sql);

        let (cols, _) = conn.query_streaming_start(sql, hash, &[], 3).unwrap();
        assert!(!cols.is_empty());
        assert!(conn.is_streaming());

        let mut arena = Arena::new();
        let mut offsets = Vec::new();
        let mut total_rows = 0;

        loop {
            let has_more = conn.streaming_next_chunk(&mut arena, &mut offsets).unwrap();
            total_rows += offsets.len();
            if !has_more {
                break;
            }
            conn.streaming_send_execute(3).unwrap();
        }

        assert_eq!(total_rows, 10);
        assert!(!conn.is_streaming());
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_prepare_describe_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();

        let result = conn
            .prepare_describe("SELECT $1::int4 + $2::int4 AS sum")
            .unwrap();
        assert_eq!(result.columns.len(), 1);
        assert_eq!(&*result.columns[0].name, "sum");
        assert_eq!(result.param_oids.len(), 2);
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_wait_for_notification_if_pg_available() {
        let mut conn = Connection::connect(&live_config()).unwrap();

        conn.simple_query("LISTEN test_chan").unwrap();
        conn.simple_query("NOTIFY test_chan, 'hello'").unwrap();

        conn.set_read_timeout(Some(std::time::Duration::from_secs(5)))
            .unwrap();

        let (channel, payload) = conn.wait_for_notification().unwrap();
        assert_eq!(channel, "test_chan");
        assert_eq!(payload, "hello");
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_cancel_if_pg_available() {
        let conn = Connection::connect(&live_config()).unwrap();
        // No query is in flight; the outcome of the cancel request is deliberately not asserted.
        let result = conn.cancel();
        let _ = result;
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_server_params_if_pg_available() {
        let conn = Connection::connect(&live_config()).unwrap();
        let params = conn.server_params();
        assert!(
            !params.is_empty(),
            "server should send parameters during startup"
        );
        assert!(
            conn.parameter("server_encoding").is_some(),
            "server_encoding should be present"
        );
        let _ = conn.close();
    }

    #[test]
    #[ignore]
    fn sync_set_read_timeout_if_pg_available() {
        let conn = Connection::connect(&live_config()).unwrap();
        conn.set_read_timeout(Some(std::time::Duration::from_secs(10)))
            .unwrap();
        conn.set_read_timeout(None).unwrap();
        let _ = conn.close();
    }
}