1use std::io::{Read, Write};
26use std::os::unix::net::UnixStream;
27use std::sync::Arc;
28
29use crate::DriverError;
30use crate::arena::Arena;
31use crate::auth;
32use crate::codec::Encode;
33use crate::conn::{ColumnDesc, Config, Notification, PgDataRow, QueryResult, SimpleRow};
34use crate::proto::{self, BackendMessage};
35
36struct StmtCache {
47 entries: Vec<(u64, StmtInfo)>,
48}
49
50impl Default for StmtCache {
51 fn default() -> Self {
52 Self {
53 entries: Vec::with_capacity(16),
54 }
55 }
56}
57
58impl StmtCache {
59 #[inline]
60 fn get_mut(&mut self, hash: &u64) -> Option<&mut StmtInfo> {
61 self.entries
62 .iter_mut()
63 .find(|(h, _)| h == hash)
64 .map(|(_, info)| info)
65 }
66
67 #[inline]
68 fn get(&self, hash: &u64) -> Option<&StmtInfo> {
69 self.entries
70 .iter()
71 .find(|(h, _)| h == hash)
72 .map(|(_, info)| info)
73 }
74
75 #[inline]
76 fn contains_key(&self, hash: &u64) -> bool {
77 self.entries.iter().any(|(h, _)| h == hash)
78 }
79
80 #[inline]
81 fn insert(&mut self, hash: u64, info: StmtInfo) {
82 if let Some(entry) = self.entries.iter_mut().find(|(h, _)| *h == hash) {
83 entry.1 = info;
84 } else {
85 self.entries.push((hash, info));
86 }
87 }
88
89 #[inline]
90 fn remove(&mut self, hash: &u64) -> Option<StmtInfo> {
91 if let Some(pos) = self.entries.iter().position(|(h, _)| h == hash) {
92 Some(self.entries.swap_remove(pos).1)
93 } else {
94 None
95 }
96 }
97
98 #[inline]
99 fn len(&self) -> usize {
100 self.entries.len()
101 }
102
103 fn evict_lru(&mut self) -> Option<(u64, StmtInfo)> {
105 if self.entries.is_empty() {
106 return None;
107 }
108 let min_idx = self
109 .entries
110 .iter()
111 .enumerate()
112 .min_by_key(|(_, (_, info))| info.last_used)
113 .map(|(i, _)| i)?;
114 Some(self.entries.swap_remove(min_idx))
115 }
116}
117
118struct StmtInfo {
120 name: Box<str>,
122 columns: Arc<[ColumnDesc]>,
124 last_used: u64,
126 bind_template: Option<BindTemplate>,
135}
136
/// An already-encoded Bind message whose parameter values can be
/// overwritten in place for the next execution of the same statement.
struct BindTemplate {
    /// Encoded bytes, copied into the write buffer before patching.
    bytes: Vec<u8>,
    /// One `(data_offset, len)` pair per parameter. `data_offset` is where the
    /// value bytes begin in `bytes` (the 4-byte length prefix sits right
    /// before it). `len` is the encoded length; it is negative when the slot
    /// was NULL (presumably -1, per the wire format).
    param_slots: Vec<(usize, i32)>,
}
151
/// Builds the server-side statement name for a SQL hash.
///
/// The result is `"s_"` followed by the hash as exactly 16 zero-padded,
/// lowercase hex digits (most significant nibble first), e.g.
/// `0x0123456789abcdef` -> `"s_0123456789abcdef"`. It is always 18 ASCII bytes.
#[inline]
fn make_stmt_name(hash: u64) -> Box<str> {
    format!("s_{hash:016x}").into_boxed_str()
}
167
/// Owned summary of one backend message received during startup and
/// authentication.
///
/// Converting to owned data means handling the message does not keep
/// `self.read_buf` borrowed while the connection writes replies.
enum StartupAction {
    /// `AuthenticationOk`.
    AuthOk,
    /// The server asked for a cleartext password.
    AuthCleartext,
    /// The server asked for an MD5-hashed password; carries the 4-byte salt.
    AuthMd5([u8; 4]),
    /// The server offered SASL; carries the raw mechanism-list bytes.
    AuthSasl(Vec<u8>),
    /// A `ParameterStatus` name/value pair.
    ParameterStatus(Box<str>, Box<str>),
    /// `BackendKeyData`: process id and secret key.
    BackendKeyData(i32, i32),
    /// `ReadyForQuery` with its transaction-status byte.
    ReadyForQuery(u8),
    /// An `ErrorResponse`, already rendered to text.
    Error(String),
    /// A `NoticeResponse`; its contents are discarded.
    Notice,
}
180
181pub struct SyncConnection {
208 stream: UnixStream,
209 read_buf: Vec<u8>,
210 stream_buf: Vec<u8>,
211 stream_buf_pos: usize,
212 stream_buf_end: usize,
213 write_buf: Vec<u8>,
214 stmts: StmtCache,
215 params: Vec<(Box<str>, Box<str>)>,
216 pid: i32,
217 secret: i32,
218 tx_status: u8,
219 last_used: std::time::Instant,
220 created_at: std::time::Instant,
221 pending_notifications: Vec<Notification>,
222 max_stmt_cache_size: usize,
223 query_counter: u64,
224}
225
226impl std::fmt::Debug for SyncConnection {
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 f.debug_struct("SyncConnection")
229 .field("pid", &self.pid)
230 .field("tx_status", &(self.tx_status as char))
231 .field("stmt_cache_len", &self.stmts.len())
232 .finish()
233 }
234}
235
236impl SyncConnection {
237 pub fn connect(config: &Config) -> Result<Self, DriverError> {
247 if !config.host_is_uds() {
248 return Err(DriverError::Protocol(
249 "SyncConnection requires a Unix domain socket path (host starting with '/')".into(),
250 ));
251 }
252
253 let path = config.uds_path();
254 let stream = UnixStream::connect(&path).map_err(DriverError::Io)?;
255
256 let mut conn = Self {
257 stream,
258 read_buf: Vec::with_capacity(8192),
259 stream_buf: vec![0u8; 65536],
260 stream_buf_pos: 0,
261 stream_buf_end: 0,
262 write_buf: Vec::with_capacity(4096),
263 stmts: StmtCache::default(),
264 params: Vec::new(),
265 pid: 0,
266 secret: 0,
267 tx_status: b'I',
268 last_used: std::time::Instant::now(),
269 created_at: std::time::Instant::now(),
270 pending_notifications: Vec::new(),
271 max_stmt_cache_size: 256,
272 query_counter: 0,
273 };
274
275 conn.startup(config)?;
276 conn.validate_server_params()?;
277
278 if config.statement_timeout_secs > 0 {
279 conn.simple_query(&format!(
280 "SET statement_timeout = '{}s'",
281 config.statement_timeout_secs
282 ))?;
283 }
284
285 Ok(conn)
286 }
287
288 fn startup(&mut self, config: &Config) -> Result<(), DriverError> {
291 self.write_buf.clear();
292 proto::write_startup(&mut self.write_buf, &config.user, &config.database);
293 self.flush_write()?;
294
295 loop {
296 let action = self.read_startup_action()?;
297 match action {
298 StartupAction::AuthOk => {}
299 StartupAction::AuthCleartext => {
300 self.write_buf.clear();
301 let mut pw = config.password.as_bytes().to_vec();
302 pw.push(0);
303 proto::write_password(&mut self.write_buf, &pw);
304 self.flush_write()?;
305 }
306 StartupAction::AuthMd5(salt) => {
307 self.write_buf.clear();
308 let hash = auth::md5_password(&config.user, &config.password, &salt);
309 proto::write_password(&mut self.write_buf, &hash);
310 self.flush_write()?;
311 }
312 StartupAction::AuthSasl(mechanisms_data) => {
313 self.handle_scram(config, &mechanisms_data)?;
314 }
315 StartupAction::ParameterStatus(name, value) => {
316 if let Some(entry) = self.params.iter_mut().find(|(k, _)| *k == name) {
317 entry.1 = value;
318 } else {
319 self.params.push((name, value));
320 }
321 }
322 StartupAction::BackendKeyData(pid, secret) => {
323 self.pid = pid;
324 self.secret = secret;
325 }
326 StartupAction::ReadyForQuery(status) => {
327 self.tx_status = status;
328 return Ok(());
329 }
330 StartupAction::Error(msg) => {
331 return Err(DriverError::Auth(msg));
332 }
333 StartupAction::Notice => {}
334 }
335 }
336 }
337
338 fn read_startup_action(&mut self) -> Result<StartupAction, DriverError> {
339 let (msg_type, _) = self.read_message_buffered()?;
340 let payload = &self.read_buf;
341 let msg = proto::parse_backend_message(msg_type, payload)?;
342 match msg {
343 BackendMessage::AuthOk => Ok(StartupAction::AuthOk),
344 BackendMessage::AuthCleartext => Ok(StartupAction::AuthCleartext),
345 BackendMessage::AuthMd5 { salt } => Ok(StartupAction::AuthMd5(salt)),
346 BackendMessage::AuthSasl { mechanisms } => {
347 Ok(StartupAction::AuthSasl(mechanisms.to_vec()))
348 }
349 BackendMessage::ParameterStatus { name, value } => {
350 Ok(StartupAction::ParameterStatus(name.into(), value.into()))
351 }
352 BackendMessage::BackendKeyData { pid, secret } => {
353 Ok(StartupAction::BackendKeyData(pid, secret))
354 }
355 BackendMessage::ReadyForQuery { status } => Ok(StartupAction::ReadyForQuery(status)),
356 BackendMessage::ErrorResponse { data } => {
357 let fields = proto::parse_error_response(data);
358 Ok(StartupAction::Error(fields.to_string()))
359 }
360 BackendMessage::NoticeResponse { .. } => Ok(StartupAction::Notice),
361 other => Err(DriverError::Protocol(format!(
362 "unexpected message during startup: {other:?}"
363 ))),
364 }
365 }
366
367 fn handle_scram(&mut self, config: &Config, mechanisms_data: &[u8]) -> Result<(), DriverError> {
368 let mechs = auth::parse_sasl_mechanisms(mechanisms_data);
369 if !mechs.contains(&"SCRAM-SHA-256") {
370 return Err(DriverError::Auth(format!(
371 "server requires unsupported SASL mechanism(s): {mechs:?}"
372 )));
373 }
374
375 let mut scram = auth::ScramClient::new(&config.user, &config.password)?;
376
377 let client_first = scram.client_first_message();
379 self.write_buf.clear();
380 proto::write_sasl_initial(&mut self.write_buf, "SCRAM-SHA-256", &client_first);
381 self.flush_write()?;
382
383 let (msg_type, _) = self.read_message_buffered()?;
385 let server_first = {
386 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
387 match msg {
388 BackendMessage::AuthSaslContinue { data } => data.to_vec(),
389 BackendMessage::ErrorResponse { data } => {
390 let fields = proto::parse_error_response(data);
391 return Err(DriverError::Auth(fields.to_string()));
392 }
393 other => {
394 return Err(DriverError::Protocol(format!(
395 "expected AuthSaslContinue, got: {other:?}"
396 )));
397 }
398 }
399 };
400
401 scram.process_server_first(&server_first)?;
402
403 let client_final = scram.client_final_message()?;
405 self.write_buf.clear();
406 proto::write_sasl_response(&mut self.write_buf, &client_final);
407 self.flush_write()?;
408
409 let (msg_type, _) = self.read_message_buffered()?;
411 {
412 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
413 match msg {
414 BackendMessage::AuthSaslFinal { data } => {
415 let data_owned = data.to_vec();
416 scram.verify_server_final(&data_owned)?;
417 }
418 BackendMessage::ErrorResponse { data } => {
419 let fields = proto::parse_error_response(data);
420 return Err(DriverError::Auth(fields.to_string()));
421 }
422 other => {
423 return Err(DriverError::Protocol(format!(
424 "expected AuthSaslFinal, got: {other:?}"
425 )));
426 }
427 }
428 }
429
430 let (msg_type, _) = self.read_message_buffered()?;
432 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
433 match msg {
434 BackendMessage::AuthOk => Ok(()),
435 BackendMessage::ErrorResponse { data } => {
436 let fields = proto::parse_error_response(data);
437 Err(DriverError::Auth(fields.to_string()))
438 }
439 other => Err(DriverError::Protocol(format!(
440 "expected AuthOk after SCRAM, got: {other:?}"
441 ))),
442 }
443 }
444
445 fn validate_server_params(&self) -> Result<(), DriverError> {
446 if let Some(encoding) = self.parameter("server_encoding") {
447 let normalized = encoding.to_uppercase();
448 if normalized != "UTF8" && normalized != "UTF-8" {
449 return Err(DriverError::Protocol(format!(
450 "server_encoding is '{encoding}', but bsql requires UTF-8."
451 )));
452 }
453 }
454 if let Some(encoding) = self.parameter("client_encoding") {
455 let normalized = encoding.to_uppercase();
456 if normalized != "UTF8" && normalized != "UTF-8" {
457 return Err(DriverError::Protocol(format!(
458 "client_encoding is '{encoding}', but bsql requires UTF-8."
459 )));
460 }
461 }
462 if let Some(idt) = self.parameter("integer_datetimes") {
463 if idt != "on" {
464 return Err(DriverError::Protocol(format!(
465 "integer_datetimes is '{idt}', but bsql requires 'on'."
466 )));
467 }
468 }
469 Ok(())
470 }
471
472 pub fn prepare_only(&mut self, sql: &str, sql_hash: u64) -> Result<(), DriverError> {
478 if self.stmts.contains_key(&sql_hash) {
479 return Ok(());
480 }
481 let name = make_stmt_name(sql_hash);
482 self.write_buf.clear();
483 proto::write_parse(&mut self.write_buf, &name, sql, &[]);
484 proto::write_describe(&mut self.write_buf, b'S', &name);
485 proto::write_sync(&mut self.write_buf);
486 self.flush_write()?;
487
488 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
489 let columns = self.read_column_description()?;
490 self.expect_ready()?;
491
492 self.query_counter += 1;
493 self.cache_stmt(
494 sql_hash,
495 StmtInfo {
496 name,
497 columns,
498 last_used: self.query_counter,
499 bind_template: None,
500 },
501 );
502 Ok(())
503 }
504
    /// Runs `sql` with `params` and collects every returned row into a
    /// `QueryResult`.
    ///
    /// Each DataRow payload is passed to `parse_data_row_flat` together with
    /// `arena` and the flat `all_col_offsets` vector.
    ///
    /// Replies are parsed straight out of `stream_buf` whenever a whole message
    /// is buffered. A message longer than the entire buffer is read through
    /// `read_one_message` instead. The loop ends at `ReadyForQuery`, whose
    /// status byte is stored in `tx_status`.
    ///
    /// # Errors
    /// - Protocol errors for malformed lengths or unexpected messages.
    /// - Server errors: on the slow path these are returned after draining to
    ///   `ReadyForQuery`; on the fast path they are presumably handled inside
    ///   `handle_non_datarow_query`.
    /// - I/O errors.
    ///
    /// # Panics
    /// If `send_pipeline` returns `None` despite being asked for columns.
    #[inline]
    pub fn query(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        arena: &mut Arena,
    ) -> Result<QueryResult, DriverError> {
        let columns = self
            .send_pipeline(sql, sql_hash, params, true, true)?
            .expect("send_pipeline(need_columns=true) must return Some");

        let num_cols = columns.len();
        // One entry per column value across all rows, filled by
        // parse_data_row_flat; pre-sized for about eight rows.
        let mut all_col_offsets: Vec<(usize, i32)> = Vec::with_capacity(num_cols.max(1) * 8);
        let mut affected_rows: u64 = 0;

        'outer: loop {
            // Consume every complete message currently in stream_buf.
            loop {
                let avail = self.stream_buf_end - self.stream_buf_pos;
                if avail < 5 {
                    // Not even a full header (1-byte type + 4-byte length).
                    break;
                }

                let msg_type = self.stream_buf[self.stream_buf_pos];
                let raw_len = i32::from_be_bytes([
                    self.stream_buf[self.stream_buf_pos + 1],
                    self.stream_buf[self.stream_buf_pos + 2],
                    self.stream_buf[self.stream_buf_pos + 3],
                    self.stream_buf[self.stream_buf_pos + 4],
                ]);

                // The length field counts itself, so anything below 4 is corrupt.
                if raw_len < 4 {
                    return Err(DriverError::Protocol(format!(
                        "invalid message length {raw_len} for type '{}'",
                        msg_type as char
                    )));
                }

                let payload_len = (raw_len - 4) as usize;
                let total_msg_len = 5 + payload_len;

                if avail < total_msg_len {
                    if total_msg_len > self.stream_buf.len() {
                        // This message can never fit in stream_buf: read it
                        // through the general message reader instead.
                        let msg = self.read_one_message()?;
                        match msg {
                            BackendMessage::BindComplete => continue,
                            BackendMessage::DataRow { data } => {
                                parse_data_row_flat(data, arena, &mut all_col_offsets)?;
                                continue;
                            }
                            BackendMessage::CommandComplete { tag } => {
                                affected_rows = proto::parse_command_tag(tag);
                                continue;
                            }
                            BackendMessage::EmptyQuery => continue,
                            BackendMessage::ReadyForQuery { status } => {
                                self.tx_status = status;
                                break 'outer;
                            }
                            BackendMessage::NoticeResponse { .. } => continue,
                            BackendMessage::ErrorResponse { data } => {
                                let fields = proto::parse_error_response(data);
                                self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                                self.drain_to_ready()?;
                                return Err(self.make_server_error(fields));
                            }
                            other => {
                                return Err(DriverError::Protocol(format!(
                                    "unexpected message during query: {other:?}"
                                )));
                            }
                        }
                    }
                    // Partial message: leave it in place and refill below.
                    break;
                }

                let payload_start = self.stream_buf_pos + 5;
                let payload_end = payload_start + payload_len;

                if msg_type == b'D' {
                    // DataRow: parsed directly from the receive buffer.
                    parse_data_row_flat(
                        &self.stream_buf[payload_start..payload_end],
                        arena,
                        &mut all_col_offsets,
                    )?;
                } else if msg_type == b'Z' {
                    // ReadyForQuery: record the status and finish.
                    if payload_len >= 1 {
                        self.tx_status = self.stream_buf[payload_start];
                    }
                    self.stream_buf_pos += total_msg_len;
                    break 'outer;
                } else {
                    self.handle_non_datarow_query(
                        msg_type,
                        payload_start,
                        payload_end,
                        sql_hash,
                        &mut affected_rows,
                    )?;
                }

                self.stream_buf_pos += total_msg_len;
            }

            self.refill_stream_buf()?;
        }

        self.shrink_buffers();

        Ok(QueryResult::from_parts(
            all_col_offsets,
            num_cols,
            columns,
            affected_rows,
        ))
    }
636
637 #[inline]
648 pub fn execute_monolithic(
649 &mut self,
650 sql: &str,
651 sql_hash: u64,
652 params: &[&(dyn Encode + Sync)],
653 ) -> Result<u64, DriverError> {
654 self.write_buf.clear();
656
657 let info = match self.stmts.get_mut(&sql_hash) {
659 Some(info) => {
660 self.query_counter += 1;
661 info.last_used = self.query_counter;
662 info
663 }
664 None => {
665 return self.execute_with_prepare(sql, sql_hash, params);
667 }
668 };
669
670 let can_use_template = info
672 .bind_template
673 .as_ref()
674 .is_some_and(|t| t.param_slots.len() == params.len());
675
676 let mut has_exec_sync = false;
677
678 if can_use_template {
679 let tmpl = info.bind_template.as_ref().unwrap();
680 self.write_buf.extend_from_slice(&tmpl.bytes);
681
682 let mut template_ok = true;
683 for (i, param) in params.iter().enumerate() {
684 let (data_offset, old_len) = tmpl.param_slots[i];
685 if param.is_null() {
686 let len_offset = data_offset - 4;
687 self.write_buf[len_offset..len_offset + 4]
688 .copy_from_slice(&(-1i32).to_be_bytes());
689 } else if old_len >= 0 {
690 let end = data_offset + old_len as usize;
691 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
692 template_ok = false;
693 break;
694 }
695 } else {
696 template_ok = false;
698 break;
699 }
700 }
701
702 if template_ok {
703 has_exec_sync = true;
704 } else {
705 self.write_buf.clear();
706 proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
707 info.bind_template = None;
708 }
709 } else {
710 proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
711 }
712
713 if info.bind_template.is_none() && !self.write_buf.is_empty() {
715 info.bind_template = build_bind_template(&self.write_buf, params.len());
716 }
717
718 if !has_exec_sync {
719 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
720 }
721
722 self.stream
724 .write_all(&self.write_buf)
725 .map_err(DriverError::Io)?;
726
727 let mut affected_rows: u64 = 0;
729
730 'outer: loop {
731 loop {
732 let avail = self.stream_buf_end - self.stream_buf_pos;
733 if avail < 5 {
734 break; }
736
737 let msg_type = self.stream_buf[self.stream_buf_pos];
738 let raw_len = i32::from_be_bytes([
739 self.stream_buf[self.stream_buf_pos + 1],
740 self.stream_buf[self.stream_buf_pos + 2],
741 self.stream_buf[self.stream_buf_pos + 3],
742 self.stream_buf[self.stream_buf_pos + 4],
743 ]);
744
745 if raw_len < 4 {
746 return Err(DriverError::Protocol(format!(
747 "invalid message length {raw_len} for type '{}'",
748 msg_type as char
749 )));
750 }
751
752 let payload_len = (raw_len - 4) as usize;
753 let total_msg_len = 5 + payload_len;
754
755 if avail < total_msg_len {
756 if total_msg_len > self.stream_buf.len() {
757 let msg = self.read_one_message()?;
758 match msg {
759 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
760 continue;
761 }
762 BackendMessage::CommandComplete { tag } => {
763 affected_rows = proto::parse_command_tag(tag);
764 continue;
765 }
766 BackendMessage::EmptyQuery => continue,
767 BackendMessage::ReadyForQuery { status } => {
768 self.tx_status = status;
769 break 'outer;
770 }
771 BackendMessage::NoticeResponse { .. } => continue,
772 BackendMessage::ErrorResponse { data } => {
773 let fields = proto::parse_error_response(data);
774 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
775 self.drain_to_ready()?;
776 return Err(self.make_server_error(fields));
777 }
778 other => {
779 return Err(DriverError::Protocol(format!(
780 "unexpected message during execute: {other:?}"
781 )));
782 }
783 }
784 }
785 break; }
787
788 let payload_start = self.stream_buf_pos + 5;
793 let payload_end = payload_start + payload_len;
794
795 if msg_type == b'2' {
796 self.stream_buf_pos += total_msg_len;
798 continue;
799 } else if msg_type == b'C' {
800 affected_rows = proto::parse_command_tag_bytes(
802 &self.stream_buf[payload_start..payload_end],
803 );
804 } else if msg_type == b'Z' {
805 if payload_len >= 1 {
807 self.tx_status = self.stream_buf[payload_start];
808 }
809 self.stream_buf_pos += total_msg_len;
810 break 'outer;
811 } else if msg_type == b'D' || msg_type == b'I' {
812 } else {
814 self.handle_non_datarow_execute(
815 msg_type,
816 payload_start,
817 payload_end,
818 sql_hash,
819 )?;
820 }
821
822 self.stream_buf_pos += total_msg_len;
823 }
824
825 let remaining = self.stream_buf_end - self.stream_buf_pos;
827 debug_assert!(
828 remaining == 0 || self.stream_buf_pos > 0,
829 "compact called with pos=0 and remaining data"
830 );
831 if remaining > 0 {
832 self.stream_buf
833 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
834 }
835 self.stream_buf_pos = 0;
836 self.stream_buf_end = remaining;
837 let n = self
838 .stream
839 .read(&mut self.stream_buf[remaining..])
840 .map_err(DriverError::Io)?;
841 if n == 0 {
842 return Err(DriverError::Io(std::io::Error::new(
843 std::io::ErrorKind::UnexpectedEof,
844 "connection closed",
845 )));
846 }
847 self.stream_buf_end = remaining + n;
848 }
849
850 if self.query_counter & 63 == 0 {
852 if self.read_buf.capacity() > 64 * 1024 {
853 self.read_buf.clear();
854 self.read_buf.shrink_to(8192);
855 }
856 if self.write_buf.capacity() > 16 * 1024 {
857 self.write_buf.clear();
858 self.write_buf.shrink_to(8192);
859 }
860 }
861
862 Ok(affected_rows)
863 }
864
865 #[cold]
867 #[inline(never)]
868 fn execute_with_prepare(
869 &mut self,
870 sql: &str,
871 sql_hash: u64,
872 params: &[&(dyn Encode + Sync)],
873 ) -> Result<u64, DriverError> {
874 debug_assert_eq!(crate::conn::hash_sql(sql), sql_hash, "sql_hash mismatch");
875
876 if params.len() > i16::MAX as usize {
877 return Err(DriverError::Protocol(format!(
878 "parameter count {} exceeds maximum {}",
879 params.len(),
880 i16::MAX
881 )));
882 }
883
884 let name = make_stmt_name(sql_hash);
885 let param_oids: smallvec::SmallVec<[u32; 8]> =
886 params.iter().map(|p| p.type_oid()).collect();
887
888 self.write_buf.clear();
889 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
890 proto::write_describe(&mut self.write_buf, b'S', &name);
891 proto::write_bind_params(&mut self.write_buf, "", &name, params);
892 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
893 self.stream
894 .write_all(&self.write_buf)
895 .map_err(DriverError::Io)?;
896
897 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
898 let columns = self.read_column_description()?;
899 self.query_counter += 1;
900 self.cache_stmt(
901 sql_hash,
902 StmtInfo {
903 name,
904 columns,
905 last_used: self.query_counter,
906 bind_template: None,
907 },
908 );
909
910 let mut affected_rows: u64 = 0;
912 'outer: loop {
913 loop {
914 let avail = self.stream_buf_end - self.stream_buf_pos;
915 if avail < 5 {
916 break;
917 }
918
919 let msg_type = self.stream_buf[self.stream_buf_pos];
920 let raw_len = i32::from_be_bytes([
921 self.stream_buf[self.stream_buf_pos + 1],
922 self.stream_buf[self.stream_buf_pos + 2],
923 self.stream_buf[self.stream_buf_pos + 3],
924 self.stream_buf[self.stream_buf_pos + 4],
925 ]);
926
927 if raw_len < 4 {
928 return Err(DriverError::Protocol(format!(
929 "invalid message length {raw_len} for type '{}'",
930 msg_type as char
931 )));
932 }
933
934 let payload_len = (raw_len - 4) as usize;
935 let total_msg_len = 5 + payload_len;
936
937 if avail < total_msg_len {
938 if total_msg_len > self.stream_buf.len() {
939 let msg = self.read_one_message()?;
940 match msg {
941 BackendMessage::BindComplete | BackendMessage::DataRow { .. } => {
942 continue;
943 }
944 BackendMessage::CommandComplete { tag } => {
945 affected_rows = proto::parse_command_tag(tag);
946 continue;
947 }
948 BackendMessage::EmptyQuery => continue,
949 BackendMessage::ReadyForQuery { status } => {
950 self.tx_status = status;
951 break 'outer;
952 }
953 BackendMessage::NoticeResponse { .. } => continue,
954 BackendMessage::ErrorResponse { data } => {
955 let fields = proto::parse_error_response(data);
956 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
957 self.drain_to_ready()?;
958 return Err(self.make_server_error(fields));
959 }
960 other => {
961 return Err(DriverError::Protocol(format!(
962 "unexpected message during execute: {other:?}"
963 )));
964 }
965 }
966 }
967 break;
968 }
969
970 let payload_start = self.stream_buf_pos + 5;
971 let payload_end = payload_start + payload_len;
972
973 if msg_type == b'2' || msg_type == b'D' || msg_type == b'I' {
974 } else if msg_type == b'C' {
976 affected_rows = proto::parse_command_tag_bytes(
977 &self.stream_buf[payload_start..payload_end],
978 );
979 } else if msg_type == b'Z' {
980 if payload_len >= 1 {
981 self.tx_status = self.stream_buf[payload_start];
982 }
983 self.stream_buf_pos += total_msg_len;
984 break 'outer;
985 } else {
986 self.handle_non_datarow_execute(
987 msg_type,
988 payload_start,
989 payload_end,
990 sql_hash,
991 )?;
992 }
993
994 self.stream_buf_pos += total_msg_len;
995 }
996
997 self.refill_stream_buf()?;
998 }
999
1000 Ok(affected_rows)
1001 }
1002
1003 #[inline]
1008 pub fn execute(
1009 &mut self,
1010 sql: &str,
1011 sql_hash: u64,
1012 params: &[&(dyn Encode + Sync)],
1013 ) -> Result<u64, DriverError> {
1014 self.execute_monolithic(sql, sql_hash, params)
1015 }
1016
1017 pub fn execute_pipeline(
1029 &mut self,
1030 sql: &str,
1031 sql_hash: u64,
1032 param_sets: &[&[&(dyn Encode + Sync)]],
1033 ) -> Result<Vec<u64>, DriverError> {
1034 if param_sets.is_empty() {
1035 return Ok(Vec::new());
1036 }
1037
1038 debug_assert_eq!(crate::conn::hash_sql(sql), sql_hash, "sql_hash mismatch");
1039
1040 self.write_buf.clear();
1041
1042 if !self.stmts.contains_key(&sql_hash) {
1044 let name = make_stmt_name(sql_hash);
1045 let first_params = param_sets[0];
1046 if first_params.len() > i16::MAX as usize {
1047 return Err(DriverError::Protocol(format!(
1048 "parameter count {} exceeds maximum {}",
1049 first_params.len(),
1050 i16::MAX
1051 )));
1052 }
1053 let param_oids: smallvec::SmallVec<[u32; 8]> =
1054 first_params.iter().map(|p| p.type_oid()).collect();
1055 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1056 proto::write_describe(&mut self.write_buf, b'S', &name);
1057 proto::write_sync(&mut self.write_buf);
1058 self.flush_write()?;
1059
1060 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1061 let columns = self.read_column_description()?;
1062 self.expect_ready()?;
1063
1064 self.query_counter += 1;
1065 self.cache_stmt(
1066 sql_hash,
1067 StmtInfo {
1068 name,
1069 columns,
1070 last_used: self.query_counter,
1071 bind_template: None,
1072 },
1073 );
1074
1075 self.write_buf.clear();
1076 }
1077
1078 let stmt_name = self
1080 .stmts
1081 .get(&sql_hash)
1082 .expect("BUG: stmt just cached but not found")
1083 .name
1084 .clone();
1085 let count = param_sets.len();
1086
1087 for params in param_sets {
1088 if params.len() > i16::MAX as usize {
1089 return Err(DriverError::Protocol(format!(
1090 "parameter count {} exceeds maximum {}",
1091 params.len(),
1092 i16::MAX
1093 )));
1094 }
1095 proto::write_bind_params(&mut self.write_buf, "", &stmt_name, params);
1096 self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
1097 }
1098
1099 self.write_buf.extend_from_slice(proto::SYNC_ONLY);
1100 self.flush_write()?;
1101
1102 let mut results = Vec::with_capacity(count);
1104 for _ in 0..count {
1105 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1106
1107 let mut affected_rows: u64 = 0;
1108 loop {
1109 let msg = self.read_one_message()?;
1110 match msg {
1111 BackendMessage::DataRow { .. } => {}
1112 BackendMessage::CommandComplete { tag } => {
1113 affected_rows = proto::parse_command_tag(tag);
1114 break;
1115 }
1116 BackendMessage::EmptyQuery => break,
1117 BackendMessage::NoticeResponse { .. } => {}
1118 BackendMessage::ErrorResponse { data } => {
1119 let fields = proto::parse_error_response(data);
1120 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1121 self.drain_to_ready()?;
1122 return Err(self.make_server_error(fields));
1123 }
1124 other => {
1125 return Err(DriverError::Protocol(format!(
1126 "unexpected message during execute_pipeline: {other:?}"
1127 )));
1128 }
1129 }
1130 }
1131 results.push(affected_rows);
1132 }
1133
1134 self.expect_ready()?;
1135 self.shrink_buffers();
1136 Ok(results)
1137 }
1138
1139 pub(crate) fn ensure_stmt_prepared(
1145 &mut self,
1146 sql: &str,
1147 sql_hash: u64,
1148 params: &[&(dyn Encode + Sync)],
1149 ) -> Result<Box<str>, DriverError> {
1150 if let Some(info) = self.stmts.get(&sql_hash) {
1151 return Ok(info.name.clone());
1152 }
1153
1154 let name = make_stmt_name(sql_hash);
1155 if params.len() > i16::MAX as usize {
1156 return Err(DriverError::Protocol(format!(
1157 "parameter count {} exceeds maximum {}",
1158 params.len(),
1159 i16::MAX
1160 )));
1161 }
1162 let param_oids: smallvec::SmallVec<[u32; 8]> =
1163 params.iter().map(|p| p.type_oid()).collect();
1164
1165 self.write_buf.clear();
1166 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1167 proto::write_describe(&mut self.write_buf, b'S', &name);
1168 proto::write_sync(&mut self.write_buf);
1169 self.flush_write()?;
1170
1171 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1172 let columns = self.read_column_description()?;
1173 self.expect_ready()?;
1174
1175 self.query_counter += 1;
1176 let stmt_name = name.clone();
1177 self.cache_stmt(
1178 sql_hash,
1179 StmtInfo {
1180 name,
1181 columns,
1182 last_used: self.query_counter,
1183 bind_template: None,
1184 },
1185 );
1186
1187 Ok(stmt_name)
1188 }
1189
1190 pub(crate) fn write_deferred_bind_execute(
1193 &self,
1194 sql_hash: u64,
1195 params: &[&(dyn Encode + Sync)],
1196 buf: &mut Vec<u8>,
1197 ) {
1198 let stmt_name = &self
1199 .stmts
1200 .get(&sql_hash)
1201 .expect("BUG: stmt just cached but not found")
1202 .name;
1203 proto::write_bind_params(buf, "", stmt_name, params);
1204 buf.extend_from_slice(proto::EXECUTE_ONLY);
1205 }
1206
1207 pub(crate) fn flush_deferred_pipeline(
1212 &mut self,
1213 buf: &mut Vec<u8>,
1214 count: usize,
1215 ) -> Result<Vec<u64>, DriverError> {
1216 if count == 0 {
1217 buf.clear();
1218 return Ok(Vec::new());
1219 }
1220
1221 buf.extend_from_slice(proto::SYNC_ONLY);
1222
1223 self.stream.write_all(buf).map_err(DriverError::Io)?;
1224 buf.clear();
1225
1226 let mut results = Vec::with_capacity(count);
1227 for _ in 0..count {
1228 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1229
1230 let mut affected_rows: u64 = 0;
1231 loop {
1232 let msg = self.read_one_message()?;
1233 match msg {
1234 BackendMessage::DataRow { .. } => {}
1235 BackendMessage::CommandComplete { tag } => {
1236 affected_rows = proto::parse_command_tag(tag);
1237 break;
1238 }
1239 BackendMessage::EmptyQuery => break,
1240 BackendMessage::NoticeResponse { .. } => {}
1241 BackendMessage::ErrorResponse { data } => {
1242 let fields = proto::parse_error_response(data);
1243 self.drain_to_ready()?;
1244 return Err(self.make_server_error(fields));
1245 }
1246 other => {
1247 return Err(DriverError::Protocol(format!(
1248 "unexpected message during flush_deferred_pipeline: {other:?}"
1249 )));
1250 }
1251 }
1252 }
1253 results.push(affected_rows);
1254 }
1255
1256 self.expect_ready()?;
1257 self.shrink_buffers();
1258 Ok(results)
1259 }
1260
1261 pub fn for_each<F>(
1263 &mut self,
1264 sql: &str,
1265 sql_hash: u64,
1266 params: &[&(dyn Encode + Sync)],
1267 mut f: F,
1268 ) -> Result<(), DriverError>
1269 where
1270 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
1271 {
1272 let _ = self.send_pipeline(sql, sql_hash, params, false, true)?;
1273
1274 'outer: loop {
1276 loop {
1277 let avail = self.stream_buf_end - self.stream_buf_pos;
1278 if avail < 5 {
1279 break; }
1281
1282 let msg_type = self.stream_buf[self.stream_buf_pos];
1283 let raw_len = i32::from_be_bytes([
1284 self.stream_buf[self.stream_buf_pos + 1],
1285 self.stream_buf[self.stream_buf_pos + 2],
1286 self.stream_buf[self.stream_buf_pos + 3],
1287 self.stream_buf[self.stream_buf_pos + 4],
1288 ]);
1289
1290 if raw_len < 4 {
1291 return Err(DriverError::Protocol(format!(
1292 "invalid message length {raw_len} for type '{}'",
1293 msg_type as char
1294 )));
1295 }
1296
1297 let payload_len = (raw_len - 4) as usize;
1298 let total_msg_len = 5 + payload_len;
1299
1300 if avail < total_msg_len {
1301 if total_msg_len > self.stream_buf.len() {
1302 let msg = self.read_one_message()?;
1304 match msg {
1305 BackendMessage::BindComplete => continue,
1306 BackendMessage::DataRow { data } => {
1307 let row = PgDataRow::new(data)?;
1308 f(row)?;
1309 continue;
1310 }
1311 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1312 continue;
1313 }
1314 BackendMessage::ReadyForQuery { status } => {
1315 self.tx_status = status;
1316 break 'outer;
1317 }
1318 BackendMessage::NoticeResponse { .. } => continue,
1319 BackendMessage::ErrorResponse { data } => {
1320 let fields = proto::parse_error_response(data);
1321 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1322 self.drain_to_ready()?;
1323 return Err(self.make_server_error(fields));
1324 }
1325 other => {
1326 return Err(DriverError::Protocol(format!(
1327 "unexpected message during for_each: {other:?}"
1328 )));
1329 }
1330 }
1331 }
1332 break; }
1334
1335 let payload_start = self.stream_buf_pos + 5;
1337 let payload_end = payload_start + payload_len;
1338
1339 if msg_type == b'D' {
1342 let row = PgDataRow::new(&self.stream_buf[payload_start..payload_end])?;
1344 f(row)?;
1345 } else if msg_type == b'Z' {
1346 if payload_len >= 1 {
1348 self.tx_status = self.stream_buf[payload_start];
1349 }
1350 self.stream_buf_pos += total_msg_len;
1351 break 'outer;
1352 } else {
1353 self.handle_non_datarow_execute(
1354 msg_type,
1355 payload_start,
1356 payload_end,
1357 sql_hash,
1358 )?;
1359 }
1360
1361 self.stream_buf_pos += total_msg_len;
1362 }
1363
1364 self.refill_stream_buf()?;
1366 }
1367
1368 self.shrink_buffers();
1369 Ok(())
1370 }
1371
1372 #[inline]
1383 pub fn for_each_raw_monolithic<F>(
1384 &mut self,
1385 sql: &str,
1386 sql_hash: u64,
1387 params: &[&(dyn Encode + Sync)],
1388 mut f: F,
1389 ) -> Result<(), DriverError>
1390 where
1391 F: FnMut(&[u8]) -> Result<(), DriverError>,
1392 {
1393 self.write_buf.clear();
1395
1396 let info = match self.stmts.get_mut(&sql_hash) {
1398 Some(info) => {
1399 self.query_counter += 1;
1400 info.last_used = self.query_counter;
1401 info
1402 }
1403 None => {
1404 return self.for_each_raw_with_prepare(sql, sql_hash, params, f);
1406 }
1407 };
1408
1409 let can_use_template = info
1411 .bind_template
1412 .as_ref()
1413 .is_some_and(|t| t.param_slots.len() == params.len());
1414
1415 let mut has_exec_sync = false;
1416
1417 if can_use_template {
1418 let tmpl = info.bind_template.as_ref().unwrap();
1419 self.write_buf.extend_from_slice(&tmpl.bytes);
1420
1421 let mut template_ok = true;
1422 for (i, param) in params.iter().enumerate() {
1423 let (data_offset, old_len) = tmpl.param_slots[i];
1424 if param.is_null() {
1425 let len_offset = data_offset - 4;
1426 self.write_buf[len_offset..len_offset + 4]
1427 .copy_from_slice(&(-1i32).to_be_bytes());
1428 } else if old_len >= 0 {
1429 let end = data_offset + old_len as usize;
1430 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
1431 template_ok = false;
1432 break;
1433 }
1434 } else {
1435 template_ok = false;
1436 break;
1437 }
1438 }
1439
1440 if template_ok {
1441 has_exec_sync = true;
1442 } else {
1443 self.write_buf.clear();
1444 proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
1445 info.bind_template = None;
1446 }
1447 } else {
1448 proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
1449 }
1450
1451 if info.bind_template.is_none() && !self.write_buf.is_empty() {
1453 info.bind_template = build_bind_template(&self.write_buf, params.len());
1454 }
1455
1456 if !has_exec_sync {
1457 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1458 }
1459
1460 self.stream
1462 .write_all(&self.write_buf)
1463 .map_err(DriverError::Io)?;
1464
1465 loop {
1469 let avail = self.stream_buf_end - self.stream_buf_pos;
1470 if avail >= 5 {
1471 let bc_type = self.stream_buf[self.stream_buf_pos];
1472 match bc_type {
1473 b'2' => {
1474 self.stream_buf_pos += 5;
1475 break;
1476 }
1477 b'E' => {
1478 let msg = self.read_one_message()?;
1479 if let BackendMessage::ErrorResponse { data } = msg {
1480 let fields = proto::parse_error_response(data);
1481 self.drain_to_ready()?;
1482 return Err(self.make_server_error(fields));
1483 }
1484 }
1485 b'N' | b'S' => {
1486 let raw_len = i32::from_be_bytes([
1487 self.stream_buf[self.stream_buf_pos + 1],
1488 self.stream_buf[self.stream_buf_pos + 2],
1489 self.stream_buf[self.stream_buf_pos + 3],
1490 self.stream_buf[self.stream_buf_pos + 4],
1491 ]);
1492 let total = 1 + raw_len as usize;
1493 if avail >= total {
1494 self.stream_buf_pos += total;
1495 continue;
1496 }
1497 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1498 break;
1499 }
1500 _ => {
1501 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1502 break;
1503 }
1504 }
1505 } else {
1506 let remaining = self.stream_buf_end - self.stream_buf_pos;
1508 if remaining > 0 && self.stream_buf_pos > 0 {
1509 self.stream_buf
1510 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1511 }
1512 self.stream_buf_pos = 0;
1513 self.stream_buf_end = remaining;
1514 let n = self
1515 .stream
1516 .read(&mut self.stream_buf[remaining..])
1517 .map_err(DriverError::Io)?;
1518 if n == 0 {
1519 return Err(DriverError::Io(std::io::Error::new(
1520 std::io::ErrorKind::UnexpectedEof,
1521 "connection closed",
1522 )));
1523 }
1524 self.stream_buf_end = remaining + n;
1525 }
1526 }
1527
1528 'outer: loop {
1530 loop {
1531 let avail = self.stream_buf_end - self.stream_buf_pos;
1532 if avail < 5 {
1533 break;
1534 }
1535
1536 let msg_type = self.stream_buf[self.stream_buf_pos];
1537 let raw_len = i32::from_be_bytes([
1538 self.stream_buf[self.stream_buf_pos + 1],
1539 self.stream_buf[self.stream_buf_pos + 2],
1540 self.stream_buf[self.stream_buf_pos + 3],
1541 self.stream_buf[self.stream_buf_pos + 4],
1542 ]);
1543
1544 if raw_len < 4 {
1545 return Err(DriverError::Protocol(format!(
1546 "invalid message length {raw_len} for type '{}'",
1547 msg_type as char
1548 )));
1549 }
1550
1551 let payload_len = (raw_len - 4) as usize;
1552 let total_msg_len = 5 + payload_len;
1553
1554 if avail < total_msg_len {
1555 if total_msg_len > self.stream_buf.len() {
1556 let msg = self.read_one_message()?;
1557 match msg {
1558 BackendMessage::DataRow { data } => {
1559 f(data)?;
1560 continue;
1561 }
1562 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1563 continue;
1564 }
1565 BackendMessage::ReadyForQuery { status } => {
1566 self.tx_status = status;
1567 break 'outer;
1568 }
1569 BackendMessage::ErrorResponse { data } => {
1570 let fields = proto::parse_error_response(data);
1571 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1572 self.drain_to_ready()?;
1573 return Err(self.make_server_error(fields));
1574 }
1575 BackendMessage::NoticeResponse { .. } => continue,
1576 other => {
1577 return Err(DriverError::Protocol(format!(
1578 "unexpected message during for_each_raw: {other:?}"
1579 )));
1580 }
1581 }
1582 }
1583 break; }
1585
1586 let payload_start = self.stream_buf_pos + 5;
1588 let payload_end = payload_start + payload_len;
1589
1590 if msg_type == b'D' {
1591 f(&self.stream_buf[payload_start..payload_end])?;
1592 } else if msg_type == b'Z' {
1593 if payload_len >= 1 {
1594 self.tx_status = self.stream_buf[payload_start];
1595 }
1596 self.stream_buf_pos += total_msg_len;
1597 break 'outer;
1598 } else {
1599 self.handle_non_datarow_execute(
1600 msg_type,
1601 payload_start,
1602 payload_end,
1603 sql_hash,
1604 )?;
1605 }
1606
1607 self.stream_buf_pos += total_msg_len;
1608 }
1609
1610 let remaining = self.stream_buf_end - self.stream_buf_pos;
1612 if remaining > 0 && self.stream_buf_pos > 0 {
1613 self.stream_buf
1614 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1615 }
1616 self.stream_buf_pos = 0;
1617 self.stream_buf_end = remaining;
1618 let n = self
1619 .stream
1620 .read(&mut self.stream_buf[remaining..])
1621 .map_err(DriverError::Io)?;
1622 if n == 0 {
1623 return Err(DriverError::Io(std::io::Error::new(
1624 std::io::ErrorKind::UnexpectedEof,
1625 "connection closed",
1626 )));
1627 }
1628 self.stream_buf_end = remaining + n;
1629 }
1630
1631 if self.query_counter & 63 == 0 {
1633 if self.read_buf.capacity() > 64 * 1024 {
1634 self.read_buf.clear();
1635 self.read_buf.shrink_to(8192);
1636 }
1637 if self.write_buf.capacity() > 16 * 1024 {
1638 self.write_buf.clear();
1639 self.write_buf.shrink_to(8192);
1640 }
1641 }
1642
1643 Ok(())
1644 }
1645
1646 #[cold]
1648 #[inline(never)]
1649 fn for_each_raw_with_prepare<F>(
1650 &mut self,
1651 sql: &str,
1652 sql_hash: u64,
1653 params: &[&(dyn Encode + Sync)],
1654 mut f: F,
1655 ) -> Result<(), DriverError>
1656 where
1657 F: FnMut(&[u8]) -> Result<(), DriverError>,
1658 {
1659 debug_assert_eq!(crate::conn::hash_sql(sql), sql_hash, "sql_hash mismatch");
1660
1661 if params.len() > i16::MAX as usize {
1662 return Err(DriverError::Protocol(format!(
1663 "parameter count {} exceeds maximum {}",
1664 params.len(),
1665 i16::MAX
1666 )));
1667 }
1668
1669 let name = make_stmt_name(sql_hash);
1670 let param_oids: smallvec::SmallVec<[u32; 8]> =
1671 params.iter().map(|p| p.type_oid()).collect();
1672
1673 self.write_buf.clear();
1674 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1675 proto::write_describe(&mut self.write_buf, b'S', &name);
1676 proto::write_bind_params(&mut self.write_buf, "", &name, params);
1677 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1678 self.stream
1679 .write_all(&self.write_buf)
1680 .map_err(DriverError::Io)?;
1681
1682 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
1683 let columns = self.read_column_description()?;
1684 self.query_counter += 1;
1685 self.cache_stmt(
1686 sql_hash,
1687 StmtInfo {
1688 name,
1689 columns,
1690 last_used: self.query_counter,
1691 bind_template: None,
1692 },
1693 );
1694
1695 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
1697
1698 'outer: loop {
1699 loop {
1700 let avail = self.stream_buf_end - self.stream_buf_pos;
1701 if avail < 5 {
1702 break;
1703 }
1704
1705 let msg_type = self.stream_buf[self.stream_buf_pos];
1706 let raw_len = i32::from_be_bytes([
1707 self.stream_buf[self.stream_buf_pos + 1],
1708 self.stream_buf[self.stream_buf_pos + 2],
1709 self.stream_buf[self.stream_buf_pos + 3],
1710 self.stream_buf[self.stream_buf_pos + 4],
1711 ]);
1712
1713 if raw_len < 4 {
1714 return Err(DriverError::Protocol(format!(
1715 "invalid message length {raw_len} for type '{}'",
1716 msg_type as char
1717 )));
1718 }
1719
1720 let payload_len = (raw_len - 4) as usize;
1721 let total_msg_len = 5 + payload_len;
1722
1723 if avail < total_msg_len {
1724 if total_msg_len > self.stream_buf.len() {
1725 let msg = self.read_one_message()?;
1726 match msg {
1727 BackendMessage::DataRow { data } => {
1728 f(data)?;
1729 continue;
1730 }
1731 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1732 continue;
1733 }
1734 BackendMessage::ReadyForQuery { status } => {
1735 self.tx_status = status;
1736 break 'outer;
1737 }
1738 BackendMessage::ErrorResponse { data } => {
1739 let fields = proto::parse_error_response(data);
1740 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1741 self.drain_to_ready()?;
1742 return Err(self.make_server_error(fields));
1743 }
1744 BackendMessage::NoticeResponse { .. } => continue,
1745 other => {
1746 return Err(DriverError::Protocol(format!(
1747 "unexpected message during for_each_raw: {other:?}"
1748 )));
1749 }
1750 }
1751 }
1752 break;
1753 }
1754
1755 let payload_start = self.stream_buf_pos + 5;
1756 let payload_end = payload_start + payload_len;
1757
1758 if msg_type == b'D' {
1759 f(&self.stream_buf[payload_start..payload_end])?;
1760 } else if msg_type == b'Z' {
1761 if payload_len >= 1 {
1762 self.tx_status = self.stream_buf[payload_start];
1763 }
1764 self.stream_buf_pos += total_msg_len;
1765 break 'outer;
1766 } else {
1767 self.handle_non_datarow_execute(
1768 msg_type,
1769 payload_start,
1770 payload_end,
1771 sql_hash,
1772 )?;
1773 }
1774
1775 self.stream_buf_pos += total_msg_len;
1776 }
1777
1778 self.refill_stream_buf()?;
1779 }
1780
1781 self.shrink_buffers();
1782 Ok(())
1783 }
1784
1785 #[inline]
1790 pub fn for_each_raw<F>(
1791 &mut self,
1792 sql: &str,
1793 sql_hash: u64,
1794 params: &[&(dyn Encode + Sync)],
1795 f: F,
1796 ) -> Result<(), DriverError>
1797 where
1798 F: FnMut(&[u8]) -> Result<(), DriverError>,
1799 {
1800 self.for_each_raw_monolithic(sql, sql_hash, params, f)
1801 }
1802
1803 pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1805 self.write_buf.clear();
1806 proto::write_simple_query(&mut self.write_buf, sql);
1807 self.flush_write()?;
1808
1809 loop {
1810 let msg = self.read_one_message()?;
1811 match msg {
1812 BackendMessage::ReadyForQuery { status } => {
1813 self.tx_status = status;
1814 return Ok(());
1815 }
1816 BackendMessage::CommandComplete { .. }
1817 | BackendMessage::RowDescription { .. }
1818 | BackendMessage::DataRow { .. }
1819 | BackendMessage::EmptyQuery
1820 | BackendMessage::NoticeResponse { .. }
1821 | BackendMessage::ParameterStatus { .. } => {}
1822 BackendMessage::ErrorResponse { data } => {
1823 let fields = proto::parse_error_response(data);
1824 self.drain_to_ready()?;
1825 return Err(self.make_server_error(fields));
1826 }
1827 other => {
1828 return Err(DriverError::Protocol(format!(
1829 "unexpected message during simple_query: {other:?}"
1830 )));
1831 }
1832 }
1833 }
1834 }
1835
1836 pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
1838 self.write_buf.clear();
1839 proto::write_simple_query(&mut self.write_buf, sql);
1840 self.flush_write()?;
1841
1842 let mut rows: Vec<SimpleRow> = Vec::new();
1843 loop {
1844 let msg = self.read_one_message()?;
1845 match msg {
1846 BackendMessage::ReadyForQuery { status } => {
1847 self.tx_status = status;
1848 return Ok(rows);
1849 }
1850 BackendMessage::DataRow { data } => {
1851 rows.push(proto::parse_simple_data_row(data)?);
1852 }
1853 BackendMessage::RowDescription { .. }
1854 | BackendMessage::CommandComplete { .. }
1855 | BackendMessage::EmptyQuery
1856 | BackendMessage::NoticeResponse { .. }
1857 | BackendMessage::ParameterStatus { .. } => {}
1858 BackendMessage::ErrorResponse { data } => {
1859 let fields = proto::parse_error_response(data);
1860 self.drain_to_ready()?;
1861 return Err(self.make_server_error(fields));
1862 }
1863 other => {
1864 return Err(DriverError::Protocol(format!(
1865 "unexpected message during simple_query_rows: {other:?}"
1866 )));
1867 }
1868 }
1869 }
1870 }
1871
1872 pub fn close(mut self) -> Result<(), DriverError> {
1874 self.write_buf.clear();
1875 proto::write_terminate(&mut self.write_buf);
1876 let _ = self.flush_write();
1877 Ok(())
1878 }
1879
1880 pub fn is_idle(&self) -> bool {
1884 self.tx_status == b'I'
1885 }
1886
1887 pub fn is_in_transaction(&self) -> bool {
1889 self.tx_status == b'T'
1890 }
1891
1892 pub fn is_in_failed_transaction(&self) -> bool {
1894 self.tx_status == b'E'
1895 }
1896
1897 pub fn touch(&mut self) {
1899 self.last_used = std::time::Instant::now();
1900 }
1901
1902 pub fn idle_duration(&self) -> std::time::Duration {
1904 self.last_used.elapsed()
1905 }
1906
1907 pub fn parameter(&self, name: &str) -> Option<&str> {
1909 self.params
1910 .iter()
1911 .find(|(k, _)| &**k == name)
1912 .map(|(_, v)| &**v)
1913 }
1914
1915 pub fn pid(&self) -> i32 {
1917 self.pid
1918 }
1919
1920 pub fn secret_key(&self) -> i32 {
1922 self.secret
1923 }
1924
1925 pub fn drain_notifications(&mut self) -> Vec<Notification> {
1927 std::mem::take(&mut self.pending_notifications)
1928 }
1929
1930 pub fn pending_notification_count(&self) -> usize {
1932 self.pending_notifications.len()
1933 }
1934
1935 pub fn set_max_stmt_cache_size(&mut self, size: usize) {
1937 self.max_stmt_cache_size = size;
1938 }
1939
1940 pub fn stmt_cache_len(&self) -> usize {
1942 self.stmts.len()
1943 }
1944
1945 pub fn created_at(&self) -> std::time::Instant {
1947 self.created_at
1948 }
1949
1950 #[inline]
1958 fn send_pipeline(
1959 &mut self,
1960 sql: &str,
1961 sql_hash: u64,
1962 params: &[&(dyn Encode + Sync)],
1963 need_columns: bool,
1964 skip_bind_complete: bool,
1965 ) -> Result<Option<Arc<[ColumnDesc]>>, DriverError> {
1966 debug_assert_eq!(crate::conn::hash_sql(sql), sql_hash, "sql_hash mismatch");
1967
1968 if params.len() > i16::MAX as usize {
1969 return Err(DriverError::Protocol(format!(
1970 "parameter count {} exceeds maximum {}",
1971 params.len(),
1972 i16::MAX
1973 )));
1974 }
1975
1976 self.write_buf.clear();
1977
1978 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
1979 self.query_counter += 1;
1981 info.last_used = self.query_counter;
1982
1983 let can_use_template = info
1984 .bind_template
1985 .as_ref()
1986 .is_some_and(|t| t.param_slots.len() == params.len());
1987
1988 let mut has_exec_sync = false;
1990
1991 if can_use_template {
1992 let tmpl = info.bind_template.as_ref().unwrap();
1995 self.write_buf.extend_from_slice(&tmpl.bytes);
1996
1997 let mut template_ok = true;
1998 for (i, param) in params.iter().enumerate() {
1999 let (data_offset, old_len) = tmpl.param_slots[i];
2000 if param.is_null() {
2001 let len_offset = data_offset - 4;
2003 self.write_buf[len_offset..len_offset + 4]
2004 .copy_from_slice(&(-1i32).to_be_bytes());
2005 } else if old_len >= 0 {
2006 let end = data_offset + old_len as usize;
2007 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
2008 template_ok = false;
2010 break;
2011 }
2012 } else {
2013 template_ok = false;
2016 break;
2017 }
2018 }
2019
2020 if template_ok {
2021 has_exec_sync = true; } else {
2023 self.write_buf.clear();
2024 proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
2025 info.bind_template = None;
2027 }
2028 } else {
2029 proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
2030 }
2031
2032 let cols = if need_columns {
2033 Some(info.columns.clone())
2034 } else {
2035 None
2036 };
2037
2038 if info.bind_template.is_none() && !self.write_buf.is_empty() {
2042 info.bind_template = build_bind_template(&self.write_buf, params.len());
2043 }
2044
2045 if !has_exec_sync {
2046 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2047 }
2048 self.flush_write()?;
2049
2050 cols
2051 } else {
2052 let name = make_stmt_name(sql_hash);
2054 let param_oids: smallvec::SmallVec<[u32; 8]> =
2055 params.iter().map(|p| p.type_oid()).collect();
2056 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
2057 proto::write_describe(&mut self.write_buf, b'S', &name);
2058 proto::write_bind_params(&mut self.write_buf, "", &name, params);
2059
2060 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
2061 self.flush_write()?;
2062
2063 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))?;
2064 let columns = self.read_column_description()?;
2065 self.query_counter += 1;
2066 self.cache_stmt(
2067 sql_hash,
2068 StmtInfo {
2069 name,
2070 columns: columns.clone(),
2071 last_used: self.query_counter,
2072 bind_template: None,
2073 },
2074 );
2075 if need_columns { Some(columns) } else { None }
2076 };
2077
2078 if !skip_bind_complete {
2079 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))?;
2080 }
2081
2082 Ok(columns)
2083 }
2084
2085 fn read_column_description(&mut self) -> Result<Arc<[ColumnDesc]>, DriverError> {
2087 loop {
2088 let msg = self.read_one_message()?;
2089 match msg {
2090 BackendMessage::RowDescription { data } => {
2091 let cols = proto::parse_row_description(data)?;
2092 return Ok(cols.into());
2093 }
2094 BackendMessage::ParameterDescription { .. } => {}
2095 BackendMessage::NoData => return Ok(Arc::from(Vec::new())),
2096 BackendMessage::NoticeResponse { .. } => {}
2097 BackendMessage::ErrorResponse { data } => {
2098 let fields = proto::parse_error_response(data);
2099 self.drain_to_ready()?;
2100 return Err(self.make_server_error(fields));
2101 }
2102 other => {
2103 return Err(DriverError::Protocol(format!(
2104 "expected RowDescription/NoData, got: {other:?}"
2105 )));
2106 }
2107 }
2108 }
2109 }
2110
2111 fn cache_stmt(&mut self, sql_hash: u64, info: StmtInfo) {
2114 if self.stmts.len() >= self.max_stmt_cache_size && !self.stmts.contains_key(&sql_hash) {
2115 if let Some((_lru_hash, evicted)) = self.stmts.evict_lru() {
2116 proto::write_close(&mut self.write_buf, b'S', &evicted.name);
2117 }
2118 }
2119 self.stmts.insert(sql_hash, info);
2120 }
2121
2122 fn buffer_notification(&mut self, pid: i32, channel: &str, payload: &str) {
2123 if self.pending_notifications.len() < 1024 {
2124 self.pending_notifications.push(Notification {
2125 pid,
2126 channel: channel.to_owned(),
2127 payload: payload.to_owned(),
2128 });
2129 }
2130 }
2131
2132 fn shrink_buffers(&mut self) {
2133 if self.query_counter & 63 != 0 {
2137 return;
2138 }
2139 if self.read_buf.capacity() > 64 * 1024 {
2140 self.read_buf.clear();
2141 self.read_buf.shrink_to(8192);
2142 }
2143 if self.write_buf.capacity() > 16 * 1024 {
2144 self.write_buf.clear();
2145 self.write_buf.shrink_to(8192);
2146 }
2147 }
2148
2149 fn maybe_invalidate_stmt_cache(&mut self, fields: &proto::ErrorFields, sql_hash: u64) -> bool {
2150 if &*fields.code == "26000" {
2151 self.stmts.remove(&sql_hash);
2152 true
2153 } else {
2154 false
2155 }
2156 }
2157
2158 #[cold]
2159 #[inline(never)]
2160 fn make_server_error(&self, fields: proto::ErrorFields) -> DriverError {
2161 DriverError::Server {
2162 code: fields.code,
2163 message: fields.message.into_boxed_str(),
2164 detail: fields.detail.map(String::into_boxed_str),
2165 hint: fields.hint.map(String::into_boxed_str),
2166 position: fields.position,
2167 }
2168 }
2169
2170 #[cold]
2176 #[inline(never)]
2177 fn handle_non_datarow_query(
2178 &mut self,
2179 msg_type: u8,
2180 payload_start: usize,
2181 payload_end: usize,
2182 sql_hash: u64,
2183 affected_rows: &mut u64,
2184 ) -> Result<(), DriverError> {
2185 match msg_type {
2186 b'2' | b'I' => {} b'C' => {
2188 *affected_rows =
2189 proto::parse_command_tag_bytes(&self.stream_buf[payload_start..payload_end]);
2190 }
2191 b'E' => {
2192 let fields =
2193 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2194 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2195 self.drain_to_ready()?;
2196 return Err(self.make_server_error(fields));
2197 }
2198 b'A' => {
2199 let msg = proto::parse_backend_message(
2200 msg_type,
2201 &self.stream_buf[payload_start..payload_end],
2202 )?;
2203 if let BackendMessage::NotificationResponse {
2204 pid,
2205 channel,
2206 payload,
2207 } = msg
2208 {
2209 let ch = channel.to_owned();
2210 let pl = payload.to_owned();
2211 self.buffer_notification(pid, &ch, &pl);
2212 }
2213 }
2214 _ => {} }
2216 Ok(())
2217 }
2218
2219 #[cold]
2222 #[inline(never)]
2223 fn handle_non_datarow_execute(
2224 &mut self,
2225 msg_type: u8,
2226 payload_start: usize,
2227 payload_end: usize,
2228 sql_hash: u64,
2229 ) -> Result<(), DriverError> {
2230 match msg_type {
2231 b'2' | b'C' | b'I' => {} b'E' => {
2233 let fields =
2234 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2235 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2236 self.drain_to_ready()?;
2237 return Err(self.make_server_error(fields));
2238 }
2239 b'A' => {
2240 let msg = proto::parse_backend_message(
2241 msg_type,
2242 &self.stream_buf[payload_start..payload_end],
2243 )?;
2244 if let BackendMessage::NotificationResponse {
2245 pid,
2246 channel,
2247 payload,
2248 } = msg
2249 {
2250 let ch = channel.to_owned();
2251 let pl = payload.to_owned();
2252 self.buffer_notification(pid, &ch, &pl);
2253 }
2254 }
2255 _ => {} }
2257 Ok(())
2258 }
2259
2260 #[inline]
2262 fn read_one_message(&mut self) -> Result<BackendMessage<'_>, DriverError> {
2263 loop {
2264 let (msg_type, _payload_len) = self.read_message_buffered()?;
2265 if msg_type == b'A' {
2266 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2267 if let BackendMessage::NotificationResponse {
2268 pid,
2269 channel,
2270 payload,
2271 } = msg
2272 {
2273 let pid_owned = pid;
2274 let channel_owned = channel.to_owned();
2275 let payload_owned = payload.to_owned();
2276 self.buffer_notification(pid_owned, &channel_owned, &payload_owned);
2277 continue;
2278 }
2279 }
2280 return proto::parse_backend_message(msg_type, &self.read_buf);
2281 }
2282 }
2283
2284 fn expect_message(
2285 &mut self,
2286 pred: impl Fn(&BackendMessage<'_>) -> bool,
2287 ) -> Result<(), DriverError> {
2288 loop {
2289 let msg = self.read_one_message()?;
2290 if pred(&msg) {
2291 return Ok(());
2292 }
2293 match msg {
2294 BackendMessage::ErrorResponse { data } => {
2295 let fields = proto::parse_error_response(data);
2296 self.drain_to_ready()?;
2297 return Err(self.make_server_error(fields));
2298 }
2299 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2300 other => {
2301 return Err(DriverError::Protocol(format!(
2302 "unexpected message while waiting for expected type: {other:?}"
2303 )));
2304 }
2305 }
2306 }
2307 }
2308
2309 fn expect_ready(&mut self) -> Result<(), DriverError> {
2310 loop {
2311 let msg = self.read_one_message()?;
2312 match msg {
2313 BackendMessage::ReadyForQuery { status } => {
2314 self.tx_status = status;
2315 return Ok(());
2316 }
2317 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2318 BackendMessage::ErrorResponse { data } => {
2319 let fields = proto::parse_error_response(data);
2320 self.drain_to_ready()?;
2321 return Err(self.make_server_error(fields));
2322 }
2323 _ => {}
2324 }
2325 }
2326 }
2327
2328 #[inline]
2329 fn drain_to_ready(&mut self) -> Result<(), DriverError> {
2330 loop {
2331 let msg = self.read_one_message()?;
2332 if let BackendMessage::ReadyForQuery { status } = msg {
2333 self.tx_status = status;
2334 return Ok(());
2335 }
2336 }
2337 }
2338
2339 #[inline]
2343 fn flush_write(&mut self) -> Result<(), DriverError> {
2344 self.stream
2345 .write_all(&self.write_buf)
2346 .map_err(DriverError::Io)
2347 }
2348
2349 fn read_message_buffered(&mut self) -> Result<(u8, usize), DriverError> {
2353 let mut header = [0u8; 5];
2354 sync_buffered_read_exact(
2355 &mut self.stream,
2356 &mut self.stream_buf,
2357 &mut self.stream_buf_pos,
2358 &mut self.stream_buf_end,
2359 &mut header,
2360 )?;
2361
2362 let msg_type = header[0];
2363 let len = i32::from_be_bytes([header[1], header[2], header[3], header[4]]);
2364
2365 if len < 4 {
2366 return Err(DriverError::Protocol(format!(
2367 "invalid message length {len} for type '{}'",
2368 msg_type as char
2369 )));
2370 }
2371
2372 const MAX_MESSAGE_LEN: i32 = 128 * 1024 * 1024;
2373 if len > MAX_MESSAGE_LEN {
2374 return Err(DriverError::Protocol(format!(
2375 "message length {len} exceeds maximum ({MAX_MESSAGE_LEN}) for type '{}'",
2376 msg_type as char
2377 )));
2378 }
2379
2380 let payload_len = (len - 4) as usize;
2381 self.read_buf.clear();
2382 self.read_buf.resize(payload_len, 0);
2383 if payload_len > 0 {
2384 sync_buffered_read_exact(
2385 &mut self.stream,
2386 &mut self.stream_buf,
2387 &mut self.stream_buf_pos,
2388 &mut self.stream_buf_end,
2389 &mut self.read_buf[..payload_len],
2390 )?;
2391 }
2392
2393 Ok((msg_type, payload_len))
2394 }
2395
2396 #[inline]
2398 fn refill_stream_buf(&mut self) -> Result<(), DriverError> {
2399 let remaining = self.stream_buf_end - self.stream_buf_pos;
2400 debug_assert!(
2404 remaining == 0 || self.stream_buf_pos > 0,
2405 "refill_stream_buf called with pos=0 and remaining data"
2406 );
2407 if remaining > 0 {
2408 self.stream_buf
2409 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
2410 }
2411 self.stream_buf_pos = 0;
2412 self.stream_buf_end = remaining;
2413
2414 let n = self
2415 .stream
2416 .read(&mut self.stream_buf[remaining..])
2417 .map_err(DriverError::Io)?;
2418 if n == 0 {
2419 return Err(DriverError::Io(std::io::Error::new(
2420 std::io::ErrorKind::UnexpectedEof,
2421 "connection closed",
2422 )));
2423 }
2424 self.stream_buf_end = remaining + n;
2425 Ok(())
2426 }
2427}
2428
2429fn sync_buffered_read_exact(
2432 stream: &mut UnixStream,
2433 buf: &mut [u8],
2434 pos: &mut usize,
2435 end: &mut usize,
2436 out: &mut [u8],
2437) -> Result<(), DriverError> {
2438 let mut filled = 0;
2439 while filled < out.len() {
2440 let avail = *end - *pos;
2441 if avail > 0 {
2442 let take = avail.min(out.len() - filled);
2443 out[filled..filled + take].copy_from_slice(&buf[*pos..*pos + take]);
2444 *pos += take;
2445 filled += take;
2446 } else {
2447 *pos = 0;
2448 let n = stream.read(buf).map_err(DriverError::Io)?;
2449 if n == 0 {
2450 return Err(DriverError::Io(std::io::Error::new(
2451 std::io::ErrorKind::UnexpectedEof,
2452 "connection closed",
2453 )));
2454 }
2455 *end = n;
2456 }
2457 }
2458 Ok(())
2459}
2460
2461fn parse_data_row_flat(
2465 data: &[u8],
2466 arena: &mut Arena,
2467 out: &mut Vec<(usize, i32)>,
2468) -> Result<(), DriverError> {
2469 if data.len() < 2 {
2470 return Err(DriverError::Protocol("DataRow too short".into()));
2471 }
2472
2473 let num_cols_raw = i16::from_be_bytes([data[0], data[1]]);
2474 if num_cols_raw < 0 {
2475 return Err(DriverError::Protocol(
2476 "DataRow: negative column count".into(),
2477 ));
2478 }
2479 let num_cols = num_cols_raw as usize;
2480 out.reserve(num_cols);
2481 let mut pos = 2;
2482
2483 for _ in 0..num_cols {
2484 if pos + 4 > data.len() {
2485 return Err(DriverError::Protocol("DataRow truncated".into()));
2486 }
2487
2488 let col_len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
2489 pos += 4;
2490
2491 if col_len < 0 {
2492 out.push((0, -1));
2493 } else {
2494 let len = col_len as usize;
2495 if pos + len > data.len() {
2496 return Err(DriverError::Protocol(
2497 "DataRow column data truncated".into(),
2498 ));
2499 }
2500 let offset = arena.alloc_copy(&data[pos..pos + len]);
2501 out.push((offset, col_len));
2502 pos += len;
2503 }
2504 }
2505
2506 Ok(())
2507}
2508
2509fn build_bind_template(write_buf: &[u8], param_count: usize) -> Option<BindTemplate> {
2518 if write_buf.is_empty() || write_buf[0] != b'B' {
2520 return None;
2521 }
2522
2523 if write_buf.len() < 5 {
2524 return None;
2525 }
2526
2527 let mut pos = 5;
2529
2530 while pos < write_buf.len() && write_buf[pos] != 0 {
2532 pos += 1;
2533 }
2534 pos += 1; while pos < write_buf.len() && write_buf[pos] != 0 {
2538 pos += 1;
2539 }
2540 pos += 1; if pos + 2 > write_buf.len() {
2544 return None;
2545 }
2546 let num_fmt_codes = i16::from_be_bytes([write_buf[pos], write_buf[pos + 1]]);
2547 pos += 2;
2548 pos += num_fmt_codes.max(0) as usize * 2; if pos + 2 > write_buf.len() {
2552 return None;
2553 }
2554 let wire_param_count = i16::from_be_bytes([write_buf[pos], write_buf[pos + 1]]) as usize;
2555 pos += 2;
2556
2557 if wire_param_count != param_count {
2558 return None;
2559 }
2560
2561 let mut param_slots = Vec::with_capacity(param_count);
2562 for _ in 0..param_count {
2563 if pos + 4 > write_buf.len() {
2564 return None;
2565 }
2566 let data_len = i32::from_be_bytes([
2567 write_buf[pos],
2568 write_buf[pos + 1],
2569 write_buf[pos + 2],
2570 write_buf[pos + 3],
2571 ]);
2572 pos += 4;
2573
2574 if data_len < 0 {
2575 param_slots.push((pos, -1));
2577 } else {
2578 let data_offset = pos;
2579 param_slots.push((data_offset, data_len));
2580 pos += data_len as usize;
2581 }
2582 }
2583
2584 let mut bytes = Vec::with_capacity(write_buf.len() + proto::EXECUTE_SYNC.len());
2586 bytes.extend_from_slice(write_buf);
2587 bytes.extend_from_slice(proto::EXECUTE_SYNC);
2588
2589 Some(BindTemplate { bytes, param_slots })
2590}
2591
#[cfg(test)]
// The `3.14` literals below are arbitrary f64 parameter values, not attempts at PI.
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;
    use crate::conn::hash_sql;

    /// Connection URL for the opt-in (`#[ignore]`d) integration tests: a local
    /// server reachable through the Unix socket directory `/tmp`.
    const LOCAL_PG_URL: &str = "postgres://postgres@localhost/postgres?host=/tmp";

    /// Opens a connection to the local integration-test server, panicking if it
    /// cannot be reached.
    fn connect_local() -> SyncConnection {
        let config = Config::from_url(LOCAL_PG_URL).unwrap();
        SyncConnection::connect(&config).unwrap()
    }

    /// Asserts that connecting to `url` fails and that the error message
    /// explains the Unix-domain-socket requirement.
    fn assert_rejected_as_non_uds(url: &str) {
        let config = Config::from_url(url).unwrap();
        let err = SyncConnection::connect(&config)
            .expect_err("connect must fail for a non-UDS host")
            .to_string();
        assert!(
            err.contains("Unix domain socket"),
            "error should mention UDS requirement: {err}"
        );
    }

    /// Builds a `StmtInfo` with no result columns and no cached bind template.
    fn stmt_info(name: &str, last_used: u64) -> StmtInfo {
        StmtInfo {
            name: name.into(),
            columns: Arc::from(Vec::new()),
            last_used,
            bind_template: None,
        }
    }

    /// Encodes a well-formed DataRow payload: an `i16` column count, then for
    /// each column an `i32` length (`-1` for NULL) followed by that many bytes.
    fn data_row(cols: &[Option<&[u8]>]) -> Vec<u8> {
        let count = i16::try_from(cols.len()).expect("test rows have fewer than 32768 columns");
        let mut data = Vec::new();
        data.extend_from_slice(&count.to_be_bytes());
        for col in cols {
            match col {
                Some(bytes) => {
                    let len = i32::try_from(bytes.len()).expect("test values fit in i32");
                    data.extend_from_slice(&len.to_be_bytes());
                    data.extend_from_slice(bytes);
                }
                None => data.extend_from_slice(&(-1i32).to_be_bytes()),
            }
        }
        data
    }

    // ---------------------------------------------------------------------
    // Statement names
    // ---------------------------------------------------------------------

    #[test]
    fn sync_make_stmt_name() {
        let name = make_stmt_name(0);
        assert_eq!(&*name, "s_0000000000000000");
        let name = make_stmt_name(0xDEADBEEF12345678);
        assert_eq!(&*name, "s_deadbeef12345678");
    }

    #[test]
    fn sync_make_stmt_name_max() {
        let name = make_stmt_name(u64::MAX);
        assert_eq!(&*name, "s_ffffffffffffffff");
    }

    #[test]
    fn sync_make_stmt_name_one() {
        let name = make_stmt_name(1);
        assert_eq!(&*name, "s_0000000000000001");
    }

    #[test]
    fn sync_make_stmt_name_powers_of_two() {
        let name = make_stmt_name(256);
        assert_eq!(&*name, "s_0000000000000100");
    }

    #[test]
    fn sync_make_stmt_name_format_always_18_chars() {
        for val in [0u64, 1, 0xFF, 0xFFFF, 0xFFFF_FFFF, u64::MAX] {
            let name = make_stmt_name(val);
            assert_eq!(name.len(), 18, "name len for {val:x}");
            assert!(name.starts_with("s_"));
            assert!(name[2..].chars().all(|c| c.is_ascii_hexdigit()));
        }
    }

    #[test]
    fn hash_sql_consistency() {
        let h = hash_sql("SELECT 1");
        assert_eq!(h, hash_sql("SELECT 1"));
        assert_ne!(h, hash_sql("SELECT 2"));
    }

    // ---------------------------------------------------------------------
    // Statement cache
    // ---------------------------------------------------------------------

    #[test]
    fn sync_stmt_cache_basic() {
        let mut cache = StmtCache::default();
        assert_eq!(cache.len(), 0);
        assert!(!cache.contains_key(&42));
        assert!(cache.get(&42).is_none());
        assert!(cache.evict_lru().is_none(), "empty cache has nothing to evict");
    }

    #[test]
    fn sync_stmt_cache_insert_get_remove() {
        let mut cache = StmtCache::default();
        cache.insert(42, stmt_info("s_test", 1));
        assert_eq!(cache.len(), 1);
        assert!(cache.contains_key(&42));
        assert!(cache.get(&42).is_some());

        // Mutations through `get_mut` must be visible to later lookups.
        cache.get_mut(&42).expect("entry was inserted").last_used = 7;
        assert_eq!(cache.get(&42).map(|info| info.last_used), Some(7));

        let removed = cache.remove(&42).expect("entry was inserted");
        assert_eq!(&*removed.name, "s_test");
        assert_eq!(cache.len(), 0);
        assert!(!cache.contains_key(&42));
        assert!(cache.remove(&42).is_none(), "second remove finds nothing");
    }

    #[test]
    fn sync_stmt_cache_evict_lru() {
        let mut cache = StmtCache::default();
        for i in 0..3u64 {
            cache.insert(i, stmt_info(&format!("s_{i}"), i + 1));
        }
        assert_eq!(cache.len(), 3);

        let (evicted_hash, evicted) = cache.evict_lru().unwrap();
        assert_eq!(evicted_hash, 0);
        assert_eq!(&*evicted.name, "s_0");
        assert_eq!(cache.len(), 2);
        assert!(!cache.contains_key(&0));
        assert!(cache.contains_key(&1) && cache.contains_key(&2));
    }

    #[test]
    fn sync_stmt_cache_evict_lru_ignores_insertion_order() {
        // Eviction is driven by `last_used`, not by position in the cache.
        let mut cache = StmtCache::default();
        cache.insert(10, stmt_info("s_a", 5));
        cache.insert(20, stmt_info("s_b", 1));
        cache.insert(30, stmt_info("s_c", 3));

        assert_eq!(cache.evict_lru().map(|(hash, _)| hash), Some(20));
        assert_eq!(cache.evict_lru().map(|(hash, _)| hash), Some(30));
        assert_eq!(cache.evict_lru().map(|(hash, _)| hash), Some(10));
        assert!(cache.evict_lru().is_none());
    }

    #[test]
    fn sync_stmt_cache_insert_overwrite() {
        let mut cache = StmtCache::default();
        cache.insert(42, stmt_info("s_a", 1));
        cache.insert(42, stmt_info("s_b", 2));
        assert_eq!(cache.len(), 1);

        let info = cache.get(&42).unwrap();
        assert_eq!(&*info.name, "s_b");
        assert_eq!(info.last_used, 2);
    }

    // ---------------------------------------------------------------------
    // Connection setup (no server needed)
    // ---------------------------------------------------------------------

    #[test]
    fn sync_config_rejects_tcp() {
        assert_rejected_as_non_uds("postgres://user:pass@localhost/db");
    }

    #[test]
    fn sync_connect_tcp_fails_with_uds_message() {
        assert_rejected_as_non_uds("postgres://user:pass@localhost:5432/db");
    }

    #[test]
    fn sync_connect_ip_address_fails() {
        let config = Config::from_url("postgres://user:pass@127.0.0.1:5432/db").unwrap();
        assert!(SyncConnection::connect(&config).is_err());
    }

    // ---------------------------------------------------------------------
    // DataRow parsing
    // ---------------------------------------------------------------------

    #[test]
    fn sync_data_row_parsing() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = data_row(&[Some(&42i32.to_be_bytes()[..]), None]);

        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].1, 4);
        assert_eq!(arena.get(out[0].0, 4), &42i32.to_be_bytes());
        assert_eq!(out[1].1, -1);
    }

    #[test]
    fn sync_data_row_empty() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = data_row(&[]);
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 0);
    }

    #[test]
    fn sync_data_row_too_short() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        // One byte cannot even hold the i16 column count.
        let data = vec![0u8];
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_negative_col_count() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = (-1i16).to_be_bytes();
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_truncated() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        // Header announces two columns but only one is present.
        let mut data = Vec::new();
        data.extend_from_slice(&2i16.to_be_bytes());
        data.extend_from_slice(&4i32.to_be_bytes());
        data.extend_from_slice(&42i32.to_be_bytes());
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_col_data_truncated() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        // Column claims 100 bytes of data but only one byte follows.
        let mut data = Vec::new();
        data.extend_from_slice(&1i16.to_be_bytes());
        data.extend_from_slice(&100i32.to_be_bytes());
        data.push(0);
        assert!(parse_data_row_flat(&data, &mut arena, &mut out).is_err());
    }

    #[test]
    fn sync_data_row_all_null() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = data_row(&[None, None, None]);
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 3);
        for (_, len) in &out {
            assert_eq!(*len, -1);
        }
    }

    #[test]
    fn sync_data_row_long_text() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let long_text = "a".repeat(2048);
        let text_bytes = long_text.as_bytes();
        let data = data_row(&[Some(text_bytes)]);

        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].1, text_bytes.len() as i32);
        let stored = arena.get(out[0].0, out[0].1 as usize);
        assert_eq!(stored, text_bytes);
    }

    #[test]
    fn sync_data_row_empty_text() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        // A zero-length value is distinct from NULL (-1).
        let data = data_row(&[Some(&b""[..])]);
        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].1, 0);
    }

    #[test]
    fn sync_data_row_17_columns_exceeds_smallvec() {
        // Keep the row at the size the test name promises.
        const NUM_COLS: usize = 17;
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let values: Vec<[u8; 4]> = (0..NUM_COLS).map(|i| (i as i32).to_be_bytes()).collect();
        let cols: Vec<Option<&[u8]>> = values.iter().map(|v| Some(&v[..])).collect();
        let data = data_row(&cols);

        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        assert_eq!(out.len(), NUM_COLS);
        for ((offset, len), expected) in out.iter().zip(&values) {
            assert_eq!(*len, 4);
            assert_eq!(arena.get(*offset, 4), expected);
        }
    }

    #[test]
    fn sync_data_row_mixed_null_and_data() {
        let mut arena = Arena::new();
        let mut out = Vec::new();
        let data = data_row(&[
            None,
            Some(&42i32.to_be_bytes()[..]),
            None,
            None,
            Some(&b"hello"[..]),
        ]);

        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
        let lens: Vec<i32> = out.iter().map(|&(_, len)| len).collect();
        assert_eq!(lens, [-1, 4, -1, -1, 5]);
        assert_eq!(arena.get(out[1].0, 4), &42i32.to_be_bytes());
        assert_eq!(arena.get(out[4].0, 5), b"hello");
    }

    // ---------------------------------------------------------------------
    // Bind templates
    // ---------------------------------------------------------------------

    #[test]
    fn build_bind_template_basic() {
        let mut buf = Vec::new();
        let val: i32 = 42;
        proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);

        let tmpl = build_bind_template(&buf, 1).expect("well-formed Bind must yield a template");
        assert_eq!(tmpl.param_slots.len(), 1);
        assert_eq!(tmpl.param_slots[0].1, 4);
    }

    #[test]
    fn build_bind_template_null_param() {
        let mut buf = Vec::new();
        let val: Option<i32> = None;
        proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);

        let tmpl = build_bind_template(&buf, 1).expect("well-formed Bind must yield a template");
        assert_eq!(tmpl.param_slots.len(), 1);
        // NULL parameters are recorded with length -1.
        assert_eq!(tmpl.param_slots[0].1, -1);
    }

    #[test]
    fn build_bind_template_multiple_params() {
        let mut buf = Vec::new();
        let id: i32 = 1;
        let name: &str = "alice";
        proto::write_bind_params(
            &mut buf,
            "",
            "s_test",
            &[&id as &(dyn Encode + Sync), &name as &(dyn Encode + Sync)],
        );

        let tmpl = build_bind_template(&buf, 2).expect("well-formed Bind must yield a template");
        assert_eq!(tmpl.param_slots.len(), 2);
        assert_eq!(tmpl.param_slots[0].1, 4);
        assert_eq!(tmpl.param_slots[1].1, 5);
    }

    #[test]
    fn build_bind_template_empty_buf() {
        assert!(build_bind_template(&[], 0).is_none());
    }

    #[test]
    fn build_bind_template_wrong_type() {
        // An 'E' (Execute) message is not a Bind.
        assert!(build_bind_template(&[b'E', 0, 0, 0, 4], 0).is_none());
    }

    #[test]
    fn build_bind_template_param_count_mismatch() {
        let mut buf = Vec::new();
        let val: i32 = 42;
        proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);

        // The message carries one parameter, so asking for two must fail.
        assert!(build_bind_template(&buf, 2).is_none());
    }

    #[test]
    fn build_bind_template_too_short_buf() {
        assert!(build_bind_template(&[b'B', 0, 0], 0).is_none());
    }

    #[test]
    fn build_bind_template_zero_params() {
        let mut buf = Vec::new();
        proto::write_bind_params(&mut buf, "", "s_test", &[]);

        let tmpl = build_bind_template(&buf, 0).expect("well-formed Bind must yield a template");
        assert_eq!(tmpl.param_slots.len(), 0);
    }

    #[test]
    fn build_bind_template_bool_param() {
        let mut buf = Vec::new();
        let val = true;
        proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);

        let tmpl = build_bind_template(&buf, 1).expect("well-formed Bind must yield a template");
        assert_eq!(tmpl.param_slots.len(), 1);
        assert_eq!(tmpl.param_slots[0].1, 1);
    }

    #[test]
    fn build_bind_template_i64_param() {
        let mut buf = Vec::new();
        let val: i64 = 123456789;
        proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);

        let tmpl = build_bind_template(&buf, 1).expect("well-formed Bind must yield a template");
        assert_eq!(tmpl.param_slots[0].1, 8);
    }

    #[test]
    fn build_bind_template_f64_param() {
        let mut buf = Vec::new();
        let val: f64 = 3.14;
        proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);

        let tmpl = build_bind_template(&buf, 1).expect("well-formed Bind must yield a template");
        assert_eq!(tmpl.param_slots[0].1, 8);
    }

    #[test]
    fn build_bind_template_str_param() {
        let mut buf = Vec::new();
        let val: &str = "hello world";
        proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);

        let tmpl = build_bind_template(&buf, 1).expect("well-formed Bind must yield a template");
        assert_eq!(tmpl.param_slots[0].1, 11);
    }

    #[test]
    fn build_bind_template_mixed_params_with_null() {
        let mut buf = Vec::new();
        let id: i32 = 1;
        let name: Option<i32> = None;
        let score: f64 = 9.9;
        proto::write_bind_params(
            &mut buf,
            "",
            "s_test",
            &[
                &id as &(dyn Encode + Sync),
                &name as &(dyn Encode + Sync),
                &score as &(dyn Encode + Sync),
            ],
        );

        let tmpl = build_bind_template(&buf, 3).expect("well-formed Bind must yield a template");
        let lens: Vec<i32> = tmpl.param_slots.iter().map(|&(_, len)| len).collect();
        assert_eq!(lens, [4, -1, 8]);
    }

    #[test]
    fn build_bind_template_preserves_bytes() {
        let mut buf = Vec::new();
        let val: i32 = 42;
        proto::write_bind_params(&mut buf, "", "s_test", &[&val as &(dyn Encode + Sync)]);
        let bind_len = buf.len();

        let tmpl = build_bind_template(&buf, 1).unwrap();
        assert_eq!(
            &tmpl.bytes[..bind_len],
            &buf[..],
            "template must start with original Bind message"
        );
        assert_eq!(
            &tmpl.bytes[bind_len..],
            proto::EXECUTE_SYNC,
            "template must end with EXECUTE_SYNC"
        );
    }

    // ---------------------------------------------------------------------
    // Integration tests: opt-in via `cargo test -- --ignored`
    // ---------------------------------------------------------------------

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_connect_uds_if_pg_available() {
        let config = Config::from_url(LOCAL_PG_URL).unwrap();
        // Unlike the other integration tests, a failed connect is tolerated here.
        if let Ok(conn) = SyncConnection::connect(&config) {
            assert_ne!(conn.pid(), 0, "pid should be nonzero");
            assert!(conn.is_idle(), "should start idle");
            assert!(!conn.is_in_transaction(), "should not be in tx");
            assert!(
                !conn.is_in_failed_transaction(),
                "should not be in failed tx"
            );
            assert_eq!(conn.stmt_cache_len(), 0, "cache should be empty");
            let _ = conn.close();
        }
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_simple_query_if_pg_available() {
        let mut conn = connect_local();
        conn.simple_query("SELECT 1").unwrap();
        assert!(conn.is_idle());
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_query_with_params_if_pg_available() {
        let mut conn = connect_local();
        let mut arena = Arena::new();
        let sql = "SELECT $1::int4 + $2::int4 AS sum";
        let hash = hash_sql(sql);
        let a: i32 = 10;
        let b: i32 = 20;
        let result = conn
            .query(
                sql,
                hash,
                &[&a as &(dyn Encode + Sync), &b as &(dyn Encode + Sync)],
                &mut arena,
            )
            .unwrap();
        assert_eq!(result.len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_execute_if_pg_available() {
        let mut conn = connect_local();
        conn.simple_query("CREATE TEMP TABLE _sync_test (id int)")
            .unwrap();
        let sql = "INSERT INTO _sync_test VALUES ($1::int4)";
        let hash = hash_sql(sql);
        let val: i32 = 42;
        let affected = conn
            .execute(sql, hash, &[&val as &(dyn Encode + Sync)])
            .unwrap();
        assert_eq!(affected, 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_for_each_zero_rows_if_pg_available() {
        let mut conn = connect_local();
        conn.simple_query("CREATE TEMP TABLE _sync_fe0 (id int)")
            .unwrap();
        let sql = "SELECT id FROM _sync_fe0";
        let hash = hash_sql(sql);
        let mut count = 0u32;
        conn.for_each(sql, hash, &[], |_row| {
            count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(count, 0);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_for_each_multiple_rows_if_pg_available() {
        let mut conn = connect_local();
        let sql = "SELECT generate_series(1, 5)";
        let hash = hash_sql(sql);
        let mut count = 0u32;
        conn.for_each(sql, hash, &[], |_row| {
            count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(count, 5);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_prepare_only_if_pg_available() {
        let mut conn = connect_local();
        let sql = "SELECT 1";
        let hash = hash_sql(sql);
        conn.prepare_only(sql, hash).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        // Preparing the same statement again must reuse the cache entry.
        conn.prepare_only(sql, hash).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_simple_query_rows_if_pg_available() {
        let mut conn = connect_local();
        let rows = conn.simple_query_rows("SELECT 42 AS n").unwrap();
        assert!(!rows.is_empty());
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_stmt_cache_hit_miss_if_pg_available() {
        let mut conn = connect_local();
        let mut arena = Arena::new();
        let sql1 = "SELECT 1";
        let hash1 = hash_sql(sql1);
        conn.query(sql1, hash1, &[], &mut arena).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);

        // Same SQL again: cache hit, no new entry.
        arena.reset();
        conn.query(sql1, hash1, &[], &mut arena).unwrap();
        assert_eq!(conn.stmt_cache_len(), 1);

        // Different SQL: cache miss, new entry.
        let sql2 = "SELECT 2";
        let hash2 = hash_sql(sql2);
        arena.reset();
        conn.query(sql2, hash2, &[], &mut arena).unwrap();
        assert_eq!(conn.stmt_cache_len(), 2);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_invalid_sql_error_if_pg_available() {
        let mut conn = connect_local();
        let mut arena = Arena::new();
        let sql = "SELECTTTT INVALID GARBAGE";
        let hash = hash_sql(sql);
        let result = conn.query(sql, hash, &[], &mut arena);
        assert!(result.is_err());
        // The connection must be usable (idle) after the server-side error.
        assert!(conn.is_idle());
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_tx_state_transitions_if_pg_available() {
        let mut conn = connect_local();
        assert!(conn.is_idle());
        assert!(!conn.is_in_transaction());
        conn.simple_query("BEGIN").unwrap();
        assert!(conn.is_in_transaction());
        assert!(!conn.is_idle());
        conn.simple_query("COMMIT").unwrap();
        assert!(conn.is_idle());
        assert!(!conn.is_in_transaction());
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_lru_cache_eviction_if_pg_available() {
        let mut conn = connect_local();
        conn.set_max_stmt_cache_size(3);
        let mut arena = Arena::new();
        for i in 0..5 {
            let sql = format!("SELECT {i}");
            let hash = hash_sql(&sql);
            arena.reset();
            conn.query(&sql, hash, &[], &mut arena).unwrap();
        }
        assert!(
            conn.stmt_cache_len() <= 3,
            "cache should be capped at 3, got {}",
            conn.stmt_cache_len()
        );
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_for_each_raw_if_pg_available() {
        let mut conn = connect_local();
        let sql = "SELECT generate_series(1, 3)";
        let hash = hash_sql(sql);
        let mut raw_count = 0u32;
        conn.for_each_raw(sql, hash, &[], |_raw_data| {
            raw_count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(raw_count, 3);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_query_null_params_if_pg_available() {
        let mut conn = connect_local();
        let mut arena = Arena::new();
        let sql = "SELECT $1::int4 IS NULL AS is_null";
        let hash = hash_sql(sql);
        let val: Option<i32> = None;
        let result = conn
            .query(sql, hash, &[&val as &(dyn Encode + Sync)], &mut arena)
            .unwrap();
        assert_eq!(result.len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_query_various_param_types_if_pg_available() {
        let mut conn = connect_local();
        let mut arena = Arena::new();
        let sql = "SELECT $1::int4, $2::int8, $3::text, $4::bool, $5::float8";
        let hash = hash_sql(sql);
        let p1: i32 = 42;
        let p2: i64 = 9999999;
        let p3: &str = "hello";
        let p4: bool = true;
        let p5: f64 = 3.14;
        let result = conn
            .query(
                sql,
                hash,
                &[
                    &p1 as &(dyn Encode + Sync),
                    &p2 as &(dyn Encode + Sync),
                    &p3 as &(dyn Encode + Sync),
                    &p4 as &(dyn Encode + Sync),
                    &p5 as &(dyn Encode + Sync),
                ],
                &mut arena,
            )
            .unwrap();
        assert_eq!(result.len(), 1);
        let _ = conn.close();
    }

    #[test]
    #[ignore = "requires a local PostgreSQL server listening on /tmp"]
    fn sync_connection_debug_format_if_pg_available() {
        // Exercises the real `Debug` impl, which needs a live connection.
        let conn = connect_local();
        let rendered = format!("{conn:?}");
        assert!(
            rendered.contains("SyncConnection"),
            "unexpected Debug output: {rendered}"
        );
        let _ = conn.close();
    }

    // ---------------------------------------------------------------------
    // Invariant sketches
    // ---------------------------------------------------------------------

    #[test]
    fn sync_shrink_threshold_values() {
        // NOTE(review): these are local literals that presumably mirror the
        // connection's buffer constants; the test does not read the real
        // constants, so it only documents the intended relationship.
        // TODO: assert against the actual constants once they are reachable here.
        let shrink = 64 * 1024usize;
        let initial = 8192usize;
        assert!(
            shrink > initial,
            "shrink threshold must exceed initial size"
        );
    }

    #[test]
    fn sync_connection_debug_format() {
        // NOTE(review): this formats a hand-written template rather than a real
        // `SyncConnection` (the only constructor used in this module, `connect`,
        // needs a live server), so it documents the expected shape without
        // exercising the impl. See `sync_connection_debug_format_if_pg_available`.
        let fmt_str = format!(
            "SyncConnection {{ pid: {}, tx_status: '{}', stmt_cache_len: {} }}",
            0, 'I', 0
        );
        assert!(fmt_str.contains("SyncConnection"));
        assert!(fmt_str.contains("pid"));
        assert!(fmt_str.contains("tx_status"));
    }
}