1use std::sync::Arc;
15
16use rapidhash::quality::RapidHasher;
17
18struct StmtCache {
29 entries: Vec<(u64, StmtInfo)>,
30}
31
32impl Default for StmtCache {
33 fn default() -> Self {
34 Self {
35 entries: Vec::with_capacity(16),
36 }
37 }
38}
39
40impl StmtCache {
41 #[inline]
42 fn get_mut(&mut self, hash: &u64) -> Option<&mut StmtInfo> {
43 self.entries
44 .iter_mut()
45 .find(|(h, _)| h == hash)
46 .map(|(_, info)| info)
47 }
48
49 #[inline]
50 fn get(&self, hash: &u64) -> Option<&StmtInfo> {
51 self.entries
52 .iter()
53 .find(|(h, _)| h == hash)
54 .map(|(_, info)| info)
55 }
56
57 #[inline]
58 fn contains_key(&self, hash: &u64) -> bool {
59 self.entries.iter().any(|(h, _)| h == hash)
60 }
61
62 #[inline]
63 fn insert(&mut self, hash: u64, info: StmtInfo) {
64 if let Some(entry) = self.entries.iter_mut().find(|(h, _)| *h == hash) {
65 entry.1 = info;
66 } else {
67 self.entries.push((hash, info));
68 }
69 }
70
71 #[inline]
72 fn remove(&mut self, hash: &u64) -> Option<StmtInfo> {
73 if let Some(pos) = self.entries.iter().position(|(h, _)| h == hash) {
74 Some(self.entries.swap_remove(pos).1)
75 } else {
76 None
77 }
78 }
79
80 #[inline]
81 fn len(&self) -> usize {
82 self.entries.len()
83 }
84
85 fn evict_lru(&mut self) -> Option<(u64, StmtInfo)> {
87 if self.entries.is_empty() {
88 return None;
89 }
90 let min_idx = self
91 .entries
92 .iter()
93 .enumerate()
94 .min_by_key(|(_, (_, info))| info.last_used)
95 .map(|(i, _)| i)?;
96 Some(self.entries.swap_remove(min_idx))
97 }
98}
99
100use tokio::io::{AsyncRead, AsyncWriteExt};
101use tokio::net::TcpStream;
102
103use crate::DriverError;
104use crate::arena::Arena;
105use crate::auth;
106use crate::codec::Encode;
107use crate::proto::{self, BackendMessage};
108
109#[cfg(feature = "tls")]
110use crate::tls;
111
112enum Stream {
116 Plain(TcpStream),
117 #[cfg(feature = "tls")]
118 Tls(Box<tokio_rustls::client::TlsStream<TcpStream>>),
119 #[cfg(unix)]
120 Unix(tokio::net::UnixStream),
121}
122
123impl Stream {
124 async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
125 match self {
126 Stream::Plain(s) => s.write_all(buf).await,
127 #[cfg(feature = "tls")]
128 Stream::Tls(s) => s.write_all(buf).await,
129 #[cfg(unix)]
130 Stream::Unix(s) => s.write_all(buf).await,
131 }
132 }
133
134 async fn flush(&mut self) -> std::io::Result<()> {
135 match self {
136 Stream::Plain(s) => s.flush().await,
137 #[cfg(feature = "tls")]
138 Stream::Tls(s) => s.flush().await,
139 #[cfg(unix)]
140 Stream::Unix(s) => s.flush().await,
141 }
142 }
143}
144
145struct StreamReader<'a>(&'a mut Stream);
147
148impl AsyncRead for StreamReader<'_> {
149 fn poll_read(
150 mut self: std::pin::Pin<&mut Self>,
151 cx: &mut std::task::Context<'_>,
152 buf: &mut tokio::io::ReadBuf<'_>,
153 ) -> std::task::Poll<std::io::Result<()>> {
154 match &mut *self.0 {
155 Stream::Plain(s) => std::pin::Pin::new(s).poll_read(cx, buf),
156 #[cfg(feature = "tls")]
157 Stream::Tls(s) => std::pin::Pin::new(s.as_mut()).poll_read(cx, buf),
158 #[cfg(unix)]
159 Stream::Unix(s) => std::pin::Pin::new(s).poll_read(cx, buf),
160 }
161 }
162}
163
164#[derive(Debug, Clone)]
173pub struct Config {
174 pub host: String,
175 pub port: u16,
176 pub user: String,
177 pub password: String,
178 pub database: String,
179 pub ssl: SslMode,
180 pub statement_timeout_secs: u32,
185}
186
187impl Drop for Config {
189 fn drop(&mut self) {
190 use zeroize::Zeroize;
191 self.password.zeroize();
192 }
193}
194
/// TLS policy for TCP connections, named after the corresponding libpq
/// `sslmode` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SslMode {
    /// Never attempt TLS.
    Disable,
    /// Attempt TLS; if the upgrade fails, reconnect in plaintext.
    Prefer,
    /// Fail the connection unless TLS is established.
    Require,
}
205
206impl Config {
207 pub fn from_url(url: &str) -> Result<Self, DriverError> {
221 let url = url
222 .strip_prefix("postgres://")
223 .or_else(|| url.strip_prefix("postgresql://"))
224 .ok_or_else(|| DriverError::Protocol("URL must start with postgres://".into()))?;
225
226 let (userinfo, rest) = url
228 .split_once('@')
229 .ok_or_else(|| DriverError::Protocol("missing @ in connection URL".into()))?;
230
231 let (user, password) = userinfo.split_once(':').unwrap_or((userinfo, ""));
232
233 let (hostport, rest) = rest.split_once('/').unwrap_or((rest, ""));
235 let (database, params) = rest.split_once('?').unwrap_or((rest, ""));
236
237 let (host, port) = if let Some((h, p)) = hostport.split_once(':') {
238 let port = p
239 .parse::<u16>()
240 .map_err(|_| DriverError::Protocol(format!("invalid port: {p}")))?;
241 (h.to_owned(), port)
242 } else {
243 (hostport.to_owned(), 5432)
244 };
245
246 let mut ssl = SslMode::Prefer;
247 let mut statement_timeout_secs: u32 = 30;
248 let mut host_override: Option<String> = None;
249 for param in params.split('&') {
250 if param.is_empty() {
251 continue;
252 }
253 if let Some(val) = param.strip_prefix("sslmode=") {
254 ssl = match val {
256 "disable" => SslMode::Disable,
257 "prefer" => SslMode::Prefer,
258 "require" => SslMode::Require,
259 _ => {
260 return Err(DriverError::Protocol(format!(
261 "unknown sslmode: '{val}' (expected: disable, prefer, require)"
262 )));
263 }
264 };
265 } else if let Some(val) = param.strip_prefix("statement_timeout=") {
266 statement_timeout_secs = val.parse::<u32>().unwrap_or(30);
267 } else if let Some(val) = param.strip_prefix("host=") {
268 host_override = Some(url_decode(val)?);
269 }
270 }
271
272 let final_host = if let Some(h) = host_override {
275 h
276 } else {
277 url_decode(&host)?
278 };
279
280 let config = Config {
281 host: final_host,
282 port,
283 user: url_decode(user)?,
284 password: url_decode(password)?,
285 database: if database.is_empty() {
286 url_decode(user)?
287 } else {
288 url_decode(database)?
289 },
290 ssl,
291 statement_timeout_secs,
292 };
293 config.validate()?;
294 Ok(config)
295 }
296
297 pub fn validate(&self) -> Result<(), DriverError> {
302 if self.host.is_empty() {
303 return Err(DriverError::Protocol("host cannot be empty".into()));
304 }
305 if self.user.is_empty() {
306 return Err(DriverError::Protocol("user cannot be empty".into()));
307 }
308 if self.database.is_empty() {
309 return Err(DriverError::Protocol("database cannot be empty".into()));
310 }
311 Ok(())
312 }
313
314 pub fn host_is_uds(&self) -> bool {
319 self.host.starts_with('/')
320 }
321
322 pub fn uds_path(&self) -> String {
326 format!("{}/.s.PGSQL.{}", self.host, self.port)
327 }
328}
329
330fn url_decode(s: &str) -> Result<String, DriverError> {
336 let mut bytes = Vec::with_capacity(s.len());
337 let input = s.as_bytes();
338 let mut i = 0;
339 while i < input.len() {
340 if input[i] == b'%' {
341 if i + 2 >= input.len() {
342 return Err(DriverError::Protocol(format!(
343 "malformed percent-encoding in URL: '{s}'"
344 )));
345 }
346 let hi = hex_val(input[i + 1]).ok_or_else(|| {
347 DriverError::Protocol(format!(
348 "invalid hex digit '{}' in URL: '{s}'",
349 input[i + 1] as char
350 ))
351 })?;
352 let lo = hex_val(input[i + 2]).ok_or_else(|| {
353 DriverError::Protocol(format!(
354 "invalid hex digit '{}' in URL: '{s}'",
355 input[i + 2] as char
356 ))
357 })?;
358 bytes.push(hi * 16 + lo);
359 i += 3;
360 } else {
361 bytes.push(input[i]);
362 i += 1;
363 }
364 }
365 String::from_utf8(bytes)
366 .map_err(|_| DriverError::Protocol(format!("invalid UTF-8 in URL: '{s}'")))
367}
368
369fn hex_val(b: u8) -> Option<u8> {
370 match b {
371 b'0'..=b'9' => Some(b - b'0'),
372 b'a'..=b'f' => Some(b - b'a' + 10),
373 b'A'..=b'F' => Some(b - b'A' + 10),
374 _ => None,
375 }
376}
377
/// One decoded step of the startup/authentication exchange.
///
/// Payloads are owned copies (see `read_startup_message_from_type`), so the
/// read buffer can be reused while the step is handled.
enum StartupAction {
    /// AuthenticationOk.
    AuthOk,
    /// Server asked for the password in cleartext.
    AuthCleartext,
    /// Server asked for an MD5 password response using this 4-byte salt.
    AuthMd5([u8; 4]),
    /// Server offered SASL; raw mechanism-list payload.
    AuthSasl(Vec<u8>),
    /// ParameterStatus report: (name, value).
    ParameterStatus(Box<str>, Box<str>),
    /// BackendKeyData: (process id, secret key).
    BackendKeyData(i32, i32),
    /// ReadyForQuery with its transaction-status byte; ends startup.
    ReadyForQuery(u8),
    /// ErrorResponse, already rendered to text.
    Error(String),
    /// NoticeResponse; ignored during startup.
    Notice,
}
390
/// Builds the server-side statement name for a SQL hash: `s_` followed by
/// the hash as 16 lowercase, zero-padded hex digits (18 ASCII bytes total).
#[inline]
fn make_stmt_name(hash: u64) -> Box<str> {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = [0u8; 18];
    out[..2].copy_from_slice(b"s_");
    // Most-significant byte first, two hex digits per byte.
    for (pair, byte) in out[2..].chunks_exact_mut(2).zip(hash.to_be_bytes()) {
        pair[0] = DIGITS[usize::from(byte >> 4)];
        pair[1] = DIGITS[usize::from(byte & 0x0f)];
    }
    std::str::from_utf8(&out)
        .expect("BUG: stmt name buffer contains only ASCII hex")
        .into()
}
414
415struct StmtInfo {
424 name: Box<str>,
426 columns: Arc<[ColumnDesc]>,
428 last_used: u64,
431 bind_template: Option<BindTemplate>,
441}
442
443struct BindTemplate {
450 bytes: Vec<u8>,
452 bind_end: usize,
455 param_slots: Vec<(usize, i32)>,
459}
460
/// Metadata for one result column, as decoded from a RowDescription message.
#[derive(Debug, Clone)]
pub struct ColumnDesc {
    /// Column name (or alias).
    pub name: Box<str>,
    /// Type OID of the column's data type.
    pub type_oid: u32,
    /// Data-type size as reported by the server.
    pub type_size: i16,
    /// OID of the source table, as reported by the server.
    pub table_oid: u32,
    /// Attribute number within the source table, as reported by the server.
    pub column_id: i16,
}

/// Output of describing a statement without executing it.
#[derive(Debug, Clone)]
pub struct PrepareResult {
    /// Result columns; empty when the statement returns no rows.
    pub columns: Vec<ColumnDesc>,
    /// Parameter type OIDs from ParameterDescription.
    pub param_oids: Vec<u32>,
}
485
/// One row from a simple-protocol query: each column as a `String`, with
/// `None` standing for SQL NULL.
pub type SimpleRow = Vec<Option<String>>;
491
/// An asynchronous notification from the server (presumably LISTEN/NOTIFY
/// traffic, given the channel/payload fields).
#[derive(Debug, Clone)]
pub struct Notification {
    /// Process id of the notifying backend.
    pub pid: i32,
    /// Channel name the notification was sent on.
    pub channel: String,
    /// Payload string attached to the notification.
    pub payload: String,
}
508
509pub struct Connection {
514 stream: Stream,
515 read_buf: Vec<u8>,
517 stream_buf: Vec<u8>,
520 stream_buf_pos: usize,
522 stream_buf_end: usize,
524 write_buf: Vec<u8>,
525 stmts: StmtCache,
526 params: Vec<(Box<str>, Box<str>)>,
527 pid: i32,
528 secret: i32,
529 tx_status: u8,
530 last_used: std::time::Instant,
534 streaming_active: bool,
538 created_at: std::time::Instant,
540 pending_notifications: Vec<Notification>,
543 max_stmt_cache_size: usize,
547 query_counter: u64,
550}
551
552impl Connection {
553 pub async fn connect(config: &Config) -> Result<Self, DriverError> {
559 #[cfg(unix)]
563 if config.host_is_uds() {
564 let path = config.uds_path();
565 let unix = tokio::net::UnixStream::connect(&path)
566 .await
567 .map_err(DriverError::Io)?;
568 let stream = Stream::Unix(unix);
569 return Self::finish_connect(stream, config).await;
570 }
571
572 let addr = format!("{}:{}", config.host, config.port);
573 let tcp = TcpStream::connect(&addr).await.map_err(DriverError::Io)?;
574
575 tcp.set_nodelay(true).map_err(DriverError::Io)?;
577
578 Self::set_keepalive(&tcp)?;
581
582 let stream = match config.ssl {
583 SslMode::Disable => Stream::Plain(tcp),
584 #[cfg(feature = "tls")]
585 SslMode::Prefer | SslMode::Require => {
586 match tls::try_upgrade(tcp, &config.host, config.ssl == SslMode::Require).await {
587 Ok(tls_stream) => Stream::Tls(Box::new(tls_stream)),
588 Err(e) if config.ssl == SslMode::Require => return Err(e),
589 Err(_) => {
590 let tcp = TcpStream::connect(&addr).await.map_err(DriverError::Io)?;
592 tcp.set_nodelay(true).map_err(DriverError::Io)?;
593 Self::set_keepalive(&tcp)?;
594 Stream::Plain(tcp)
595 }
596 }
597 }
598 #[cfg(not(feature = "tls"))]
599 SslMode::Require => {
600 return Err(DriverError::Protocol(
601 "TLS required but bsql-driver-postgres compiled without 'tls' feature".into(),
602 ));
603 }
604 #[cfg(not(feature = "tls"))]
605 SslMode::Prefer => Stream::Plain(tcp),
606 };
607
608 Self::finish_connect(stream, config).await
609 }
610
611 async fn finish_connect(stream: Stream, config: &Config) -> Result<Self, DriverError> {
615 let mut conn = Self {
616 stream,
617 read_buf: Vec::with_capacity(8192),
618
619 stream_buf: vec![0u8; 65536],
620 stream_buf_pos: 0,
621 stream_buf_end: 0,
622 write_buf: Vec::with_capacity(4096),
623 stmts: StmtCache::default(),
624 params: Vec::new(),
625 pid: 0,
626 secret: 0,
627 tx_status: b'I',
628 last_used: std::time::Instant::now(),
629 streaming_active: false,
630 created_at: std::time::Instant::now(),
631 pending_notifications: Vec::new(),
632 max_stmt_cache_size: 256,
633 query_counter: 0,
634 };
635
636 conn.startup(config).await?;
637
638 conn.validate_server_params()?;
640
641 if config.statement_timeout_secs > 0 {
642 conn.simple_query(&format!(
643 "SET statement_timeout = '{}s'",
644 config.statement_timeout_secs
645 ))
646 .await?;
647 }
648
649 Ok(conn)
650 }
651
652 async fn startup(&mut self, config: &Config) -> Result<(), DriverError> {
658 self.write_buf.clear();
660 proto::write_startup(&mut self.write_buf, &config.user, &config.database);
661 self.flush_write().await?;
662
663 loop {
665 let action = self.read_startup_action().await?;
666 match action {
667 StartupAction::AuthOk => {}
668 StartupAction::AuthCleartext => {
669 self.write_buf.clear();
670 let mut pw = config.password.as_bytes().to_vec();
671 pw.push(0);
672 proto::write_password(&mut self.write_buf, &pw);
673 self.flush_write().await?;
674 }
675 StartupAction::AuthMd5(salt) => {
676 self.write_buf.clear();
677 let hash = auth::md5_password(&config.user, &config.password, &salt);
678 proto::write_password(&mut self.write_buf, &hash);
679 self.flush_write().await?;
680 }
681 StartupAction::AuthSasl(mechanisms_data) => {
682 self.handle_scram(config, &mechanisms_data).await?;
683 }
684 StartupAction::ParameterStatus(name, value) => {
685 if let Some(entry) = self.params.iter_mut().find(|(k, _)| *k == name) {
687 entry.1 = value;
688 } else {
689 self.params.push((name, value));
690 }
691 }
692 StartupAction::BackendKeyData(pid, secret) => {
693 self.pid = pid;
694 self.secret = secret;
695 }
696 StartupAction::ReadyForQuery(status) => {
697 self.tx_status = status;
698 return Ok(());
699 }
700 StartupAction::Error(msg) => {
701 return Err(DriverError::Auth(msg));
702 }
703 StartupAction::Notice => {}
704 }
705 }
706 }
707
708 async fn read_startup_action(&mut self) -> Result<StartupAction, DriverError> {
713 let (msg_type, _) = self.read_message_buffered().await?;
714 self.read_startup_message_from_type(msg_type)
715 }
716
717 fn read_startup_message_from_type(&self, msg_type: u8) -> Result<StartupAction, DriverError> {
718 let payload = &self.read_buf;
719 let msg = proto::parse_backend_message(msg_type, payload)?;
720 match msg {
721 BackendMessage::AuthOk => Ok(StartupAction::AuthOk),
722 BackendMessage::AuthCleartext => Ok(StartupAction::AuthCleartext),
723 BackendMessage::AuthMd5 { salt } => Ok(StartupAction::AuthMd5(salt)),
724 BackendMessage::AuthSasl { mechanisms } => {
725 Ok(StartupAction::AuthSasl(mechanisms.to_vec()))
726 }
727 BackendMessage::ParameterStatus { name, value } => {
728 Ok(StartupAction::ParameterStatus(name.into(), value.into()))
729 }
730 BackendMessage::BackendKeyData { pid, secret } => {
731 Ok(StartupAction::BackendKeyData(pid, secret))
732 }
733 BackendMessage::ReadyForQuery { status } => Ok(StartupAction::ReadyForQuery(status)),
734 BackendMessage::ErrorResponse { data } => {
735 let fields = proto::parse_error_response(data);
736 Ok(StartupAction::Error(fields.to_string()))
737 }
738 BackendMessage::NoticeResponse { .. } => Ok(StartupAction::Notice),
739 other => Err(DriverError::Protocol(format!(
740 "unexpected message during startup: {other:?}"
741 ))),
742 }
743 }
744
    /// Authenticates with SASL using the SCRAM-SHA-256 mechanism.
    ///
    /// Message sequence:
    /// 1. Send SASLInitialResponse (client-first).
    /// 2. Read AuthSaslContinue (server-first).
    /// 3. Send SASLResponse (client-final).
    /// 4. Read AuthSaslFinal and verify it with the SCRAM client.
    /// 5. Read AuthOk.
    ///
    /// # Errors
    /// * `DriverError::Auth` if the server's mechanism list lacks
    ///   SCRAM-SHA-256, or if any step is answered with ErrorResponse.
    /// * `DriverError::Protocol` on any other unexpected message.
    /// * Errors from `auth::ScramClient` propagate (presumably including a
    ///   server-signature mismatch in `verify_server_final`; TODO confirm).
    async fn handle_scram(
        &mut self,
        config: &Config,
        mechanisms_data: &[u8],
    ) -> Result<(), DriverError> {
        let mechs = auth::parse_sasl_mechanisms(mechanisms_data);
        if !mechs.contains(&"SCRAM-SHA-256") {
            return Err(DriverError::Auth(format!(
                "server requires unsupported SASL mechanism(s): {mechs:?}"
            )));
        }

        let mut scram = auth::ScramClient::new(&config.user, &config.password)?;

        // Step 1: client-first message.
        let client_first = scram.client_first_message();
        self.write_buf.clear();
        proto::write_sasl_initial(&mut self.write_buf, "SCRAM-SHA-256", &client_first);
        self.flush_write().await?;

        // Step 2: server-first message. It is copied out of read_buf so the
        // buffer can be reused by the next read.
        let (msg_type, _) = self.read_message_buffered().await?;
        let server_first = {
            let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
            match msg {
                BackendMessage::AuthSaslContinue { data } => data.to_vec(),
                BackendMessage::ErrorResponse { data } => {
                    let fields = proto::parse_error_response(data);
                    return Err(DriverError::Auth(fields.to_string()));
                }
                other => {
                    return Err(DriverError::Protocol(format!(
                        "expected AuthSaslContinue, got: {other:?}"
                    )));
                }
            }
        };

        scram.process_server_first(&server_first)?;

        // Step 3: client-final message.
        let client_final = scram.client_final_message()?;
        self.write_buf.clear();
        proto::write_sasl_response(&mut self.write_buf, &client_final);
        self.flush_write().await?;

        // Step 4: server-final message, checked by the SCRAM client.
        let (msg_type, _) = self.read_message_buffered().await?;
        {
            let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
            match msg {
                BackendMessage::AuthSaslFinal { data } => {
                    let data_owned = data.to_vec();
                    scram.verify_server_final(&data_owned)?;
                }
                BackendMessage::ErrorResponse { data } => {
                    let fields = proto::parse_error_response(data);
                    return Err(DriverError::Auth(fields.to_string()));
                }
                other => {
                    return Err(DriverError::Protocol(format!(
                        "expected AuthSaslFinal, got: {other:?}"
                    )));
                }
            }
        }

        // Step 5: the server confirms authentication.
        let (msg_type, _) = self.read_message_buffered().await?;
        let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
        match msg {
            BackendMessage::AuthOk => Ok(()),
            BackendMessage::ErrorResponse { data } => {
                let fields = proto::parse_error_response(data);
                Err(DriverError::Auth(fields.to_string()))
            }
            other => Err(DriverError::Protocol(format!(
                "expected AuthOk after SCRAM, got: {other:?}"
            ))),
        }
    }
828
829 pub async fn prepare_only(&mut self, sql: &str, sql_hash: u64) -> Result<(), DriverError> {
836 if self.stmts.contains_key(&sql_hash) {
837 return Ok(());
838 }
839 let name = make_stmt_name(sql_hash);
840 self.write_buf.clear();
841 proto::write_parse(&mut self.write_buf, &name, sql, &[]);
842 proto::write_describe(&mut self.write_buf, b'S', &name);
843 proto::write_sync(&mut self.write_buf);
844 self.flush_write().await?;
845
846 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
848 .await?;
849
850 let columns = self.read_column_description().await?;
852
853 self.expect_ready().await?;
855
856 self.query_counter += 1;
858 self.cache_stmt(
859 sql_hash,
860 StmtInfo {
861 name,
862 columns,
863 last_used: self.query_counter,
864 bind_template: None,
865 },
866 );
867 Ok(())
868 }
869
870 pub async fn prepare_describe(&mut self, sql: &str) -> Result<PrepareResult, DriverError> {
883 self.write_buf.clear();
884 proto::write_parse(&mut self.write_buf, "", sql, &[]);
887 proto::write_describe(&mut self.write_buf, b'S', "");
888 proto::write_sync(&mut self.write_buf);
889 self.flush_write().await?;
890
891 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
893 .await?;
894
895 let mut param_oids: Vec<u32> = Vec::new();
897 let columns;
898 loop {
899 let msg = self.read_one_message().await?;
900 match msg {
901 BackendMessage::ParameterDescription { data } => {
902 param_oids = proto::parse_parameter_description(data)?;
903 }
904 BackendMessage::RowDescription { data } => {
905 columns = proto::parse_row_description(data)?;
906 break;
907 }
908 BackendMessage::NoData => {
909 columns = Vec::new();
910 break;
911 }
912 BackendMessage::NoticeResponse { .. } => {}
913 BackendMessage::ErrorResponse { data } => {
914 let fields = proto::parse_error_response(data);
915 self.drain_to_ready().await?;
916 return Err(self.make_server_error(fields));
917 }
918 other => {
919 return Err(DriverError::Protocol(format!(
920 "expected ParameterDescription/RowDescription/NoData, got: {other:?}"
921 )));
922 }
923 }
924 }
925
926 self.expect_ready().await?;
928
929 Ok(PrepareResult {
930 columns,
931 param_oids,
932 })
933 }
934
935 pub async fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
945 self.write_buf.clear();
946 proto::write_simple_query(&mut self.write_buf, sql);
947 self.flush_write().await?;
948
949 let mut rows: Vec<SimpleRow> = Vec::new();
950 loop {
951 let msg = self.read_one_message().await?;
952 match msg {
953 BackendMessage::ReadyForQuery { status } => {
954 self.tx_status = status;
955 return Ok(rows);
956 }
957 BackendMessage::DataRow { data } => {
958 rows.push(proto::parse_simple_data_row(data)?);
959 }
960 BackendMessage::RowDescription { .. }
961 | BackendMessage::CommandComplete { .. }
962 | BackendMessage::EmptyQuery
963 | BackendMessage::NoticeResponse { .. } => {}
964 BackendMessage::ErrorResponse { data } => {
965 let fields = proto::parse_error_response(data);
966 self.drain_to_ready().await?;
967 return Err(self.make_server_error(fields));
968 }
969 BackendMessage::ParameterStatus { .. } => {}
970 other => {
971 return Err(DriverError::Protocol(format!(
972 "unexpected message during simple_query_rows: {other:?}"
973 )));
974 }
975 }
976 }
977 }
978
979 pub async fn query_streaming_start(
992 &mut self,
993 sql: &str,
994 sql_hash: u64,
995 params: &[&(dyn Encode + Sync)],
996 chunk_size: i32,
997 ) -> Result<(Arc<[ColumnDesc]>, bool), DriverError> {
998 self.write_buf.clear();
999
1000 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
1002 self.query_counter += 1;
1004 info.last_used = self.query_counter;
1005
1006 let can_use_template = info
1007 .bind_template
1008 .as_ref()
1009 .is_some_and(|t| t.param_slots.len() == params.len());
1010
1011 if can_use_template {
1012 let tmpl = info.bind_template.as_ref().unwrap();
1013 self.write_buf
1016 .extend_from_slice(&tmpl.bytes[..tmpl.bind_end]);
1017
1018 let mut template_ok = true;
1019 for (i, param) in params.iter().enumerate() {
1020 let (data_offset, old_len) = tmpl.param_slots[i];
1021 if param.is_null() {
1022 let len_offset = data_offset - 4;
1023 self.write_buf[len_offset..len_offset + 4]
1024 .copy_from_slice(&(-1i32).to_be_bytes());
1025 } else if old_len >= 0 {
1026 let end = data_offset + old_len as usize;
1027 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
1028 template_ok = false;
1029 break;
1030 }
1031 } else {
1032 template_ok = false;
1033 break;
1034 }
1035 }
1036
1037 if !template_ok {
1038 self.write_buf.clear();
1039 proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
1040 info.bind_template = None;
1041 }
1042 } else {
1043 proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
1044 }
1045
1046 let cols = info.columns.clone();
1047
1048 if info.bind_template.is_none() && !self.write_buf.is_empty() {
1049 info.bind_template = build_bind_template(&self.write_buf, params.len());
1050 }
1051
1052 proto::write_execute(&mut self.write_buf, "", chunk_size);
1053 proto::write_flush(&mut self.write_buf);
1055 self.flush_write().await?;
1056
1057 cols
1058 } else {
1059 let name = make_stmt_name(sql_hash);
1061 let param_oids: smallvec::SmallVec<[u32; 8]> =
1062 params.iter().map(|p| p.type_oid()).collect();
1063 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1064 proto::write_describe(&mut self.write_buf, b'S', &name);
1065 proto::write_bind_params(&mut self.write_buf, "", &name, params);
1066
1067 proto::write_execute(&mut self.write_buf, "", chunk_size);
1068 proto::write_flush(&mut self.write_buf);
1069 self.flush_write().await?;
1070
1071 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
1072 .await?;
1073 let columns = self.read_column_description().await?;
1074 self.query_counter += 1;
1075 self.cache_stmt(
1076 sql_hash,
1077 StmtInfo {
1078 name,
1079 columns: columns.clone(),
1080 last_used: self.query_counter,
1081 bind_template: None,
1082 },
1083 );
1084 columns
1085 };
1086
1087 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
1089 .await?;
1090
1091 self.streaming_active = true;
1092
1093 Ok((columns, false))
1094 }
1095
    /// Reads the rows produced by the most recent Execute on the streaming
    /// portal.
    ///
    /// `all_col_offsets` is cleared first. Each DataRow is handed to
    /// `parse_data_row_flat` together with `arena` and `all_col_offsets`
    /// (presumably the row bytes go into the arena and per-column offsets
    /// are appended; TODO confirm).
    ///
    /// Returns:
    /// * `Ok(true)` on PortalSuspended. The portal is still open; presumably
    ///   the caller issues `streaming_send_execute` and calls this again.
    /// * `Ok(false)` once the portal finishes (CommandComplete or
    ///   EmptyQuery), after a Sync is sent and ReadyForQuery is consumed.
    ///
    /// Both terminal paths and the ErrorResponse path clear
    /// `streaming_active`.
    ///
    /// # Errors
    /// On ErrorResponse, a Sync is sent, the connection is drained to
    /// ReadyForQuery, and the server error is returned. Any other unexpected
    /// message yields `DriverError::Protocol`, and `streaming_active` stays
    /// set on that path.
    pub async fn streaming_next_chunk(
        &mut self,
        arena: &mut Arena,
        all_col_offsets: &mut Vec<(usize, i32)>,
    ) -> Result<bool, DriverError> {
        all_col_offsets.clear();

        loop {
            let msg = self.read_one_message().await?;
            match msg {
                BackendMessage::DataRow { data } => {
                    parse_data_row_flat(data, arena, all_col_offsets)?;
                }
                BackendMessage::PortalSuspended => {
                    // Chunk boundary: more rows may follow; no Sync is sent.
                    return Ok(true);
                }
                BackendMessage::CommandComplete { .. } => {
                    self.write_buf.clear();
                    // The streaming exchange was started without a Sync; send
                    // one now so the server answers with ReadyForQuery.
                    proto::write_sync(&mut self.write_buf);
                    self.flush_write().await?;
                    self.expect_ready().await?;
                    self.shrink_buffers();

                    self.streaming_active = false;
                    return Ok(false);
                }
                BackendMessage::EmptyQuery => {
                    self.write_buf.clear();
                    proto::write_sync(&mut self.write_buf);
                    self.flush_write().await?;
                    self.expect_ready().await?;

                    self.streaming_active = false;
                    return Ok(false);
                }
                BackendMessage::ErrorResponse { data } => {
                    let fields = proto::parse_error_response(data);
                    self.write_buf.clear();
                    // A Sync is required before the server emits ReadyForQuery,
                    // which drain_to_ready waits for.
                    proto::write_sync(&mut self.write_buf);
                    self.flush_write().await?;
                    self.drain_to_ready().await?;

                    self.streaming_active = false;
                    return Err(self.make_server_error(fields));
                }
                BackendMessage::NoticeResponse { .. } => {}
                other => {
                    return Err(DriverError::Protocol(format!(
                        "unexpected message during streaming: {other:?}"
                    )));
                }
            }
        }
    }
1163
1164 pub async fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
1172 self.write_buf.clear();
1173 proto::write_execute(&mut self.write_buf, "", chunk_size);
1174 proto::write_flush(&mut self.write_buf);
1175 self.flush_write().await
1176 }
1177
1178 async fn send_pipeline(
1190 &mut self,
1191 sql: &str,
1192 sql_hash: u64,
1193 params: &[&(dyn Encode + Sync)],
1194 need_columns: bool,
1195 skip_bind_complete: bool,
1196 ) -> Result<Option<Arc<[ColumnDesc]>>, DriverError> {
1197 debug_assert_eq!(
1198 hash_sql(sql),
1199 sql_hash,
1200 "sql_hash mismatch: caller-provided hash does not match hash_sql(sql)"
1201 );
1202
1203 if params.len() > i16::MAX as usize {
1204 return Err(DriverError::Protocol(format!(
1205 "parameter count {} exceeds maximum {} for PG wire protocol",
1206 params.len(),
1207 i16::MAX
1208 )));
1209 }
1210
1211 self.write_buf.clear();
1212
1213 let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
1215 self.query_counter += 1;
1217 info.last_used = self.query_counter;
1218
1219 let can_use_template = info
1220 .bind_template
1221 .as_ref()
1222 .is_some_and(|t| t.param_slots.len() == params.len());
1223
1224 let mut has_exec_sync = false;
1226
1227 if can_use_template {
1228 let tmpl = info.bind_template.as_ref().unwrap();
1231 self.write_buf.extend_from_slice(&tmpl.bytes);
1232
1233 let mut template_ok = true;
1234 for (i, param) in params.iter().enumerate() {
1235 let (data_offset, old_len) = tmpl.param_slots[i];
1236 if param.is_null() {
1237 let len_offset = data_offset - 4;
1238 self.write_buf[len_offset..len_offset + 4]
1239 .copy_from_slice(&(-1i32).to_be_bytes());
1240 } else if old_len >= 0 {
1241 let end = data_offset + old_len as usize;
1242 if !param.encode_at(&mut self.write_buf[data_offset..end]) {
1243 template_ok = false;
1244 break;
1245 }
1246 } else {
1247 template_ok = false;
1249 break;
1250 }
1251 }
1252
1253 if template_ok {
1254 has_exec_sync = true; } else {
1256 self.write_buf.clear();
1257 proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
1258 info.bind_template = None;
1259 }
1260 } else {
1261 proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
1262 }
1263
1264 let cols = if need_columns {
1267 Some(info.columns.clone())
1268 } else {
1269 None
1270 };
1271
1272 if info.bind_template.is_none() && !self.write_buf.is_empty() {
1275 info.bind_template = build_bind_template(&self.write_buf, params.len());
1276 }
1277
1278 if !has_exec_sync {
1279 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1280 }
1281 self.flush_write().await?;
1282
1283 cols
1284 } else {
1285 let name = make_stmt_name(sql_hash);
1287 let param_oids: smallvec::SmallVec<[u32; 8]> =
1288 params.iter().map(|p| p.type_oid()).collect();
1289 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1290 proto::write_describe(&mut self.write_buf, b'S', &name);
1291 proto::write_bind_params(&mut self.write_buf, "", &name, params);
1292
1293 self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1294 self.flush_write().await?;
1295
1296 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
1297 .await?;
1298 let columns = self.read_column_description().await?;
1299 self.query_counter += 1;
1300 self.cache_stmt(
1301 sql_hash,
1302 StmtInfo {
1303 name,
1304 columns: columns.clone(),
1305 last_used: self.query_counter,
1306 bind_template: None,
1307 },
1308 );
1309 if need_columns { Some(columns) } else { None }
1310 };
1311
1312 if !skip_bind_complete {
1314 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
1315 .await?;
1316 }
1317
1318 Ok(columns)
1319 }
1320
    /// Runs `sql` with `params` through `send_pipeline` and buffers the
    /// whole result.
    ///
    /// Each DataRow goes through `parse_data_row_flat` into `arena` and the
    /// flat offset list. The returned `QueryResult` contains four things:
    /// * those offsets;
    /// * the column count;
    /// * the column metadata;
    /// * the affected-row count parsed from CommandComplete (0 for
    ///   EmptyQuery).
    ///
    /// # Errors
    /// On ErrorResponse:
    /// 1. `maybe_invalidate_stmt_cache` is consulted for `sql_hash`.
    /// 2. The connection is drained to ReadyForQuery.
    /// 3. The server error is returned.
    ///
    /// Unexpected messages yield `DriverError::Protocol`.
    ///
    /// # Panics
    /// Only if `send_pipeline` returns `None` despite `need_columns = true`.
    pub async fn query(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        arena: &mut Arena,
    ) -> Result<QueryResult, DriverError> {
        let columns = self
            .send_pipeline(sql, sql_hash, params, true, false)
            .await?
            .expect("send_pipeline(need_columns=true) must return Some");

        let num_cols = columns.len();
        // Initial capacity of num_cols * 8 entries (presumably room for about
        // 8 rows if one entry is appended per column per row).
        let mut all_col_offsets: Vec<(usize, i32)> = Vec::with_capacity(num_cols.max(1) * 8);
        let mut affected_rows: u64 = 0;

        loop {
            let msg = self.read_one_message().await?;
            match msg {
                BackendMessage::DataRow { data } => {
                    parse_data_row_flat(data, arena, &mut all_col_offsets)?;
                }
                BackendMessage::CommandComplete { tag } => {
                    affected_rows = proto::parse_command_tag(tag);
                    break;
                }
                BackendMessage::EmptyQuery => {
                    break;
                }
                BackendMessage::NoticeResponse { .. } => {
                    // Notices are ignored.
                }
                BackendMessage::ErrorResponse { data } => {
                    let fields = proto::parse_error_response(data);

                    self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                    self.drain_to_ready().await?;
                    return Err(self.make_server_error(fields));
                }
                other => {
                    return Err(DriverError::Protocol(format!(
                        "unexpected message during query: {other:?}"
                    )));
                }
            }
        }

        self.expect_ready().await?;
        self.shrink_buffers();

        Ok(QueryResult {
            all_col_offsets,
            num_cols,
            columns,
            affected_rows,
        })
    }
1391
1392 async fn read_column_description(&mut self) -> Result<Arc<[ColumnDesc]>, DriverError> {
1395 loop {
1396 let msg = self.read_one_message().await?;
1397 match msg {
1398 BackendMessage::RowDescription { data } => {
1399 let cols = proto::parse_row_description(data)?;
1400 return Ok(cols.into());
1401 }
1402 BackendMessage::ParameterDescription { .. } => {
1403 }
1405 BackendMessage::NoData => return Ok(Arc::from(Vec::new())),
1406 BackendMessage::NoticeResponse { .. } => {}
1407 BackendMessage::ErrorResponse { data } => {
1408 let fields = proto::parse_error_response(data);
1409 self.drain_to_ready().await?;
1410 return Err(self.make_server_error(fields));
1411 }
1412 other => {
1413 return Err(DriverError::Protocol(format!(
1414 "expected RowDescription/NoData after Parse, got: {other:?}"
1415 )));
1416 }
1417 }
1418 }
1419 }
1420
    /// Runs `sql` with `params` through `send_pipeline` and returns the
    /// affected-row count parsed from CommandComplete (0 for EmptyQuery).
    /// Any DataRow messages are read and discarded.
    ///
    /// # Errors
    /// On ErrorResponse:
    /// 1. `maybe_invalidate_stmt_cache` is consulted for `sql_hash`.
    /// 2. The connection is drained to ReadyForQuery.
    /// 3. The server error is returned.
    ///
    /// Unexpected messages yield `DriverError::Protocol`.
    pub async fn execute(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<u64, DriverError> {
        // Column metadata is not needed for a row count.
        let _ = self
            .send_pipeline(sql, sql_hash, params, false, false)
            .await?;

        let mut affected_rows: u64 = 0;
        loop {
            let msg = self.read_one_message().await?;
            match msg {
                BackendMessage::DataRow { .. } => {
                    // Rows (e.g. from RETURNING) are discarded.
                }
                BackendMessage::CommandComplete { tag } => {
                    affected_rows = proto::parse_command_tag(tag);
                    break;
                }
                BackendMessage::EmptyQuery => break,
                BackendMessage::NoticeResponse { .. } => {}
                BackendMessage::ErrorResponse { data } => {
                    let fields = proto::parse_error_response(data);

                    self.maybe_invalidate_stmt_cache(&fields, sql_hash);
                    self.drain_to_ready().await?;
                    return Err(self.make_server_error(fields));
                }
                other => {
                    return Err(DriverError::Protocol(format!(
                        "unexpected message during execute: {other:?}"
                    )));
                }
            }
        }

        self.expect_ready().await?;
        self.shrink_buffers();
        Ok(affected_rows)
    }
1468
1469 pub async fn execute_pipeline(
1485 &mut self,
1486 sql: &str,
1487 sql_hash: u64,
1488 param_sets: &[&[&(dyn Encode + Sync)]],
1489 ) -> Result<Vec<u64>, DriverError> {
1490 if param_sets.is_empty() {
1491 return Ok(Vec::new());
1492 }
1493
1494 debug_assert_eq!(
1495 hash_sql(sql),
1496 sql_hash,
1497 "sql_hash mismatch: caller-provided hash does not match hash_sql(sql)"
1498 );
1499
1500 self.write_buf.clear();
1501
1502 if !self.stmts.contains_key(&sql_hash) {
1505 let name = make_stmt_name(sql_hash);
1506 let first_params = param_sets[0];
1507 if first_params.len() > i16::MAX as usize {
1508 return Err(DriverError::Protocol(format!(
1509 "parameter count {} exceeds maximum {}",
1510 first_params.len(),
1511 i16::MAX
1512 )));
1513 }
1514 let param_oids: smallvec::SmallVec<[u32; 8]> =
1515 first_params.iter().map(|p| p.type_oid()).collect();
1516 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1517 proto::write_describe(&mut self.write_buf, b'S', &name);
1518 proto::write_sync(&mut self.write_buf);
1519 self.flush_write().await?;
1520
1521 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
1522 .await?;
1523 let columns = self.read_column_description().await?;
1524 self.expect_ready().await?;
1525
1526 self.query_counter += 1;
1527 self.cache_stmt(
1528 sql_hash,
1529 StmtInfo {
1530 name,
1531 columns,
1532 last_used: self.query_counter,
1533 bind_template: None,
1534 },
1535 );
1536
1537 self.write_buf.clear();
1538 }
1539
1540 let stmt_name = self
1542 .stmts
1543 .get(&sql_hash)
1544 .expect("BUG: stmt just cached but not found")
1545 .name
1546 .clone();
1547 let count = param_sets.len();
1548
1549 for params in param_sets {
1550 if params.len() > i16::MAX as usize {
1551 return Err(DriverError::Protocol(format!(
1552 "parameter count {} exceeds maximum {}",
1553 params.len(),
1554 i16::MAX
1555 )));
1556 }
1557 proto::write_bind_params(&mut self.write_buf, "", &stmt_name, params);
1558 self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
1559 }
1560
1561 self.write_buf.extend_from_slice(proto::SYNC_ONLY);
1563 self.flush_write().await?;
1564
1565 let mut results = Vec::with_capacity(count);
1567 for _ in 0..count {
1568 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
1569 .await?;
1570
1571 let mut affected_rows: u64 = 0;
1573 loop {
1574 let msg = self.read_one_message().await?;
1575 match msg {
1576 BackendMessage::DataRow { .. } => {}
1577 BackendMessage::CommandComplete { tag } => {
1578 affected_rows = proto::parse_command_tag(tag);
1579 break;
1580 }
1581 BackendMessage::EmptyQuery => break,
1582 BackendMessage::NoticeResponse { .. } => {}
1583 BackendMessage::ErrorResponse { data } => {
1584 let fields = proto::parse_error_response(data);
1585 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1586 self.drain_to_ready().await?;
1587 return Err(self.make_server_error(fields));
1588 }
1589 other => {
1590 return Err(DriverError::Protocol(format!(
1591 "unexpected message during execute_pipeline: {other:?}"
1592 )));
1593 }
1594 }
1595 }
1596 results.push(affected_rows);
1597 }
1598
1599 self.expect_ready().await?;
1600 self.shrink_buffers();
1601 Ok(results)
1602 }
1603
1604 pub(crate) async fn ensure_stmt_prepared(
1613 &mut self,
1614 sql: &str,
1615 sql_hash: u64,
1616 params: &[&(dyn Encode + Sync)],
1617 ) -> Result<Box<str>, DriverError> {
1618 if let Some(info) = self.stmts.get(&sql_hash) {
1619 return Ok(info.name.clone());
1620 }
1621
1622 let name = make_stmt_name(sql_hash);
1624 if params.len() > i16::MAX as usize {
1625 return Err(DriverError::Protocol(format!(
1626 "parameter count {} exceeds maximum {}",
1627 params.len(),
1628 i16::MAX
1629 )));
1630 }
1631 let param_oids: smallvec::SmallVec<[u32; 8]> =
1632 params.iter().map(|p| p.type_oid()).collect();
1633
1634 self.write_buf.clear();
1635 proto::write_parse(&mut self.write_buf, &name, sql, ¶m_oids);
1636 proto::write_describe(&mut self.write_buf, b'S', &name);
1637 proto::write_sync(&mut self.write_buf);
1638 self.flush_write().await?;
1639
1640 self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
1641 .await?;
1642 let columns = self.read_column_description().await?;
1643 self.expect_ready().await?;
1644
1645 self.query_counter += 1;
1646 let stmt_name = name.clone();
1647 self.cache_stmt(
1648 sql_hash,
1649 StmtInfo {
1650 name,
1651 columns,
1652 last_used: self.query_counter,
1653 bind_template: None,
1654 },
1655 );
1656
1657 Ok(stmt_name)
1658 }
1659
1660 pub(crate) fn write_deferred_bind_execute(
1666 &self,
1667 sql_hash: u64,
1668 params: &[&(dyn Encode + Sync)],
1669 buf: &mut Vec<u8>,
1670 ) {
1671 let stmt_name = &self
1672 .stmts
1673 .get(&sql_hash)
1674 .expect("BUG: stmt just cached but not found")
1675 .name;
1676 proto::write_bind_params(buf, "", stmt_name, params);
1677 buf.extend_from_slice(proto::EXECUTE_ONLY);
1678 }
1679
1680 pub(crate) async fn flush_deferred_pipeline(
1686 &mut self,
1687 buf: &mut Vec<u8>,
1688 count: usize,
1689 ) -> Result<Vec<u64>, DriverError> {
1690 if count == 0 {
1691 buf.clear();
1692 return Ok(Vec::new());
1693 }
1694
1695 buf.extend_from_slice(proto::SYNC_ONLY);
1696
1697 self.stream.write_all(buf).await.map_err(DriverError::Io)?;
1699 self.stream.flush().await.map_err(DriverError::Io)?;
1700 buf.clear();
1701
1702 let mut results = Vec::with_capacity(count);
1704 for _ in 0..count {
1705 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
1706 .await?;
1707
1708 let mut affected_rows: u64 = 0;
1709 loop {
1710 let msg = self.read_one_message().await?;
1711 match msg {
1712 BackendMessage::DataRow { .. } => {}
1713 BackendMessage::CommandComplete { tag } => {
1714 affected_rows = proto::parse_command_tag(tag);
1715 break;
1716 }
1717 BackendMessage::EmptyQuery => break,
1718 BackendMessage::NoticeResponse { .. } => {}
1719 BackendMessage::ErrorResponse { data } => {
1720 let fields = proto::parse_error_response(data);
1721 self.drain_to_ready().await?;
1722 return Err(self.make_server_error(fields));
1723 }
1724 other => {
1725 return Err(DriverError::Protocol(format!(
1726 "unexpected message during flush_deferred_pipeline: {other:?}"
1727 )));
1728 }
1729 }
1730 }
1731 results.push(affected_rows);
1732 }
1733
1734 self.expect_ready().await?;
1735 self.shrink_buffers();
1736 Ok(results)
1737 }
1738
1739 pub async fn for_each<F>(
1748 &mut self,
1749 sql: &str,
1750 sql_hash: u64,
1751 params: &[&(dyn Encode + Sync)],
1752 mut f: F,
1753 ) -> Result<(), DriverError>
1754 where
1755 F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
1756 {
1757 let _ = self
1758 .send_pipeline(sql, sql_hash, params, false, false)
1759 .await?;
1760
1761 loop {
1762 let msg = self.read_one_message().await?;
1763 match msg {
1764 BackendMessage::DataRow { data } => {
1765 let row = PgDataRow::new(data)?;
1766 f(row)?;
1767 }
1768 BackendMessage::CommandComplete { .. } => break,
1769 BackendMessage::EmptyQuery => break,
1770 BackendMessage::NoticeResponse { .. } => {}
1771 BackendMessage::ErrorResponse { data } => {
1772 let fields = proto::parse_error_response(data);
1773 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1774 self.drain_to_ready().await?;
1775 return Err(self.make_server_error(fields));
1776 }
1777 other => {
1778 return Err(DriverError::Protocol(format!(
1779 "unexpected message during for_each: {other:?}"
1780 )));
1781 }
1782 }
1783 }
1784
1785 self.expect_ready().await?;
1786 self.shrink_buffers();
1787 Ok(())
1788 }
1789
1790 pub async fn for_each_raw<F>(
1804 &mut self,
1805 sql: &str,
1806 sql_hash: u64,
1807 params: &[&(dyn Encode + Sync)],
1808 mut f: F,
1809 ) -> Result<(), DriverError>
1810 where
1811 F: FnMut(&[u8]) -> Result<(), DriverError>,
1812 {
1813 let _ = self
1814 .send_pipeline(sql, sql_hash, params, false, true)
1815 .await?;
1816
1817 loop {
1821 let avail = self.stream_buf_end - self.stream_buf_pos;
1822 if avail >= 5 {
1823 let bc_type = self.stream_buf[self.stream_buf_pos];
1824 match bc_type {
1825 b'2' => {
1826 self.stream_buf_pos += 5;
1828 break;
1829 }
1830 b'E' => {
1831 let msg = self.read_one_message().await?;
1833 if let BackendMessage::ErrorResponse { data } = msg {
1834 let fields = proto::parse_error_response(data);
1835 self.drain_to_ready().await?;
1836 return Err(self.make_server_error(fields));
1837 }
1838 }
1839 b'N' | b'S' => {
1840 let raw_len = i32::from_be_bytes([
1843 self.stream_buf[self.stream_buf_pos + 1],
1844 self.stream_buf[self.stream_buf_pos + 2],
1845 self.stream_buf[self.stream_buf_pos + 3],
1846 self.stream_buf[self.stream_buf_pos + 4],
1847 ]);
1848 let total = 1 + raw_len as usize;
1849 if avail >= total {
1850 self.stream_buf_pos += total;
1851 continue;
1852 }
1853 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
1855 .await?;
1856 break;
1857 }
1858 _ => {
1859 self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
1862 .await?;
1863 break;
1864 }
1865 }
1866 } else {
1867 let remaining = self.stream_buf_end - self.stream_buf_pos;
1869 if remaining > 0 && self.stream_buf_pos > 0 {
1870 self.stream_buf
1871 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1872 }
1873 self.stream_buf_pos = 0;
1874 self.stream_buf_end = remaining;
1875
1876 let n = {
1877 let mut reader = StreamReader(&mut self.stream);
1878 use tokio::io::AsyncReadExt;
1879 reader
1880 .read(&mut self.stream_buf[remaining..])
1881 .await
1882 .map_err(DriverError::Io)?
1883 };
1884 if n == 0 {
1885 return Err(DriverError::Io(std::io::Error::new(
1886 std::io::ErrorKind::UnexpectedEof,
1887 "connection closed",
1888 )));
1889 }
1890 self.stream_buf_end = remaining + n;
1891 }
1892 }
1893
1894 'outer: loop {
1896 loop {
1898 let avail = self.stream_buf_end - self.stream_buf_pos;
1899 if avail < 5 {
1900 break; }
1902
1903 let msg_type = self.stream_buf[self.stream_buf_pos];
1904 let raw_len = i32::from_be_bytes([
1905 self.stream_buf[self.stream_buf_pos + 1],
1906 self.stream_buf[self.stream_buf_pos + 2],
1907 self.stream_buf[self.stream_buf_pos + 3],
1908 self.stream_buf[self.stream_buf_pos + 4],
1909 ]);
1910
1911 if raw_len < 4 {
1912 return Err(DriverError::Protocol(format!(
1913 "invalid message length {raw_len} for type '{}'",
1914 msg_type as char
1915 )));
1916 }
1917
1918 let payload_len = (raw_len - 4) as usize;
1919 let total_msg_len = 5 + payload_len; if avail < total_msg_len {
1922 if total_msg_len > self.stream_buf.len() {
1924 let msg = self.read_one_message().await?;
1927 match msg {
1928 BackendMessage::DataRow { data } => {
1929 f(data)?;
1930 continue;
1931 }
1932 BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1933 break 'outer;
1934 }
1935 BackendMessage::ErrorResponse { data } => {
1936 let fields = proto::parse_error_response(data);
1937 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1938 self.drain_to_ready().await?;
1939 return Err(self.make_server_error(fields));
1940 }
1941 BackendMessage::NoticeResponse { .. } => continue,
1942 other => {
1943 return Err(DriverError::Protocol(format!(
1944 "unexpected message during for_each_raw: {other:?}"
1945 )));
1946 }
1947 }
1948 }
1949 break;
1951 }
1952
1953 let payload_start = self.stream_buf_pos + 5;
1955 let payload_end = payload_start + payload_len;
1956
1957 if msg_type == b'D' {
1960 f(&self.stream_buf[payload_start..payload_end])?;
1964 } else if msg_type == b'C' || msg_type == b'I' {
1965 self.stream_buf_pos += total_msg_len;
1967 break 'outer;
1968 } else {
1969 self.handle_non_datarow_async(msg_type, payload_start, payload_end, sql_hash)
1970 .await?;
1971 }
1972
1973 self.stream_buf_pos += total_msg_len;
1974 }
1975
1976 let remaining = self.stream_buf_end - self.stream_buf_pos;
1978 if remaining > 0 && self.stream_buf_pos > 0 {
1979 self.stream_buf
1980 .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1981 }
1982 self.stream_buf_pos = 0;
1983 self.stream_buf_end = remaining;
1984
1985 let n = {
1987 let mut reader = StreamReader(&mut self.stream);
1988 use tokio::io::AsyncReadExt;
1989 reader
1990 .read(&mut self.stream_buf[remaining..])
1991 .await
1992 .map_err(DriverError::Io)?
1993 };
1994 if n == 0 {
1995 return Err(DriverError::Io(std::io::Error::new(
1996 std::io::ErrorKind::UnexpectedEof,
1997 "connection closed",
1998 )));
1999 }
2000 self.stream_buf_end = remaining + n;
2001 }
2002
2003 self.expect_ready().await?;
2005 self.shrink_buffers();
2006 Ok(())
2007 }
2008
2009 pub async fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
2013 self.write_buf.clear();
2014 proto::write_simple_query(&mut self.write_buf, sql);
2015 self.flush_write().await?;
2016
2017 loop {
2019 let msg = self.read_one_message().await?;
2020 match msg {
2021 BackendMessage::ReadyForQuery { status } => {
2022 self.tx_status = status;
2023 return Ok(());
2024 }
2025 BackendMessage::CommandComplete { .. }
2026 | BackendMessage::RowDescription { .. }
2027 | BackendMessage::DataRow { .. }
2028 | BackendMessage::EmptyQuery
2029 | BackendMessage::NoticeResponse { .. } => {}
2030 BackendMessage::ErrorResponse { data } => {
2031 let fields = proto::parse_error_response(data);
2032 self.drain_to_ready().await?;
2033 return Err(self.make_server_error(fields));
2034 }
2035
2036 BackendMessage::ParameterStatus { .. } => {}
2038
2039 BackendMessage::AuthOk
2042 | BackendMessage::AuthSaslFinal { .. }
2043 | BackendMessage::AuthSaslContinue { .. }
2044 | BackendMessage::AuthSasl { .. }
2045 | BackendMessage::AuthMd5 { .. }
2046 | BackendMessage::AuthCleartext
2047 | BackendMessage::BackendKeyData { .. } => {}
2048
2049 other => {
2050 return Err(DriverError::Protocol(format!(
2051 "unexpected message during simple_query: {other:?}"
2052 )));
2053 }
2054 }
2055 }
2056 }
2057
2058 pub async fn wait_for_notification(&mut self) -> Result<(String, String), DriverError> {
2067 loop {
2068 let (msg_type, _payload_len) = self.read_message_buffered().await?;
2069 let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2070 match msg {
2071 BackendMessage::NotificationResponse {
2072 channel, payload, ..
2073 } => {
2074 return Ok((channel.to_owned(), payload.to_owned()));
2075 }
2076 BackendMessage::ParameterStatus { .. } | BackendMessage::NoticeResponse { .. } => {
2077 continue;
2078 }
2079 _ => continue,
2080 }
2081 }
2082 }
2083
2084 pub async fn close(mut self) -> Result<(), DriverError> {
2086 self.write_buf.clear();
2087 proto::write_terminate(&mut self.write_buf);
2088 let _ = self.flush_write().await;
2090 Ok(())
2091 }
2092
2093 pub fn is_idle(&self) -> bool {
2095 self.tx_status == b'I'
2096 }
2097
2098 pub fn is_in_transaction(&self) -> bool {
2100 self.tx_status == b'T'
2101 }
2102
2103 pub fn is_in_failed_transaction(&self) -> bool {
2105 self.tx_status == b'E'
2106 }
2107
2108 pub fn touch(&mut self) {
2111 self.last_used = std::time::Instant::now();
2112 }
2113
2114 pub fn idle_duration(&self) -> std::time::Duration {
2116 self.last_used.elapsed()
2117 }
2118
2119 pub fn parameter(&self, name: &str) -> Option<&str> {
2121 self.params
2122 .iter()
2123 .find(|(k, _)| &**k == name)
2124 .map(|(_, v)| &**v)
2125 }
2126
2127 pub fn server_params(&self) -> &[(Box<str>, Box<str>)] {
2129 &self.params
2130 }
2131
2132 fn validate_server_params(&self) -> Result<(), DriverError> {
2141 if let Some(encoding) = self.parameter("server_encoding") {
2143 let normalized = encoding.to_uppercase();
2144 if normalized != "UTF8" && normalized != "UTF-8" {
2145 return Err(DriverError::Protocol(format!(
2146 "server_encoding is '{encoding}', but bsql requires UTF-8. \
2147 Set server encoding to UTF-8 in postgresql.conf or \
2148 use CREATE DATABASE ... ENCODING 'UTF8'."
2149 )));
2150 }
2151 }
2152
2153 if let Some(encoding) = self.parameter("client_encoding") {
2155 let normalized = encoding.to_uppercase();
2156 if normalized != "UTF8" && normalized != "UTF-8" {
2157 return Err(DriverError::Protocol(format!(
2158 "client_encoding is '{encoding}', but bsql requires UTF-8. \
2159 Check your connection or database configuration."
2160 )));
2161 }
2162 }
2163
2164 if let Some(idt) = self.parameter("integer_datetimes") {
2166 if idt != "on" {
2167 return Err(DriverError::Protocol(format!(
2168 "integer_datetimes is '{idt}', but bsql requires 'on'. \
2169 Our timestamp codec assumes integer-format timestamps \
2170 (microseconds since 2000-01-01). Float-format timestamps \
2171 would produce incorrect decode results."
2172 )));
2173 }
2174 }
2175
2176 Ok(())
2177 }
2178
2179 pub fn pid(&self) -> i32 {
2181 self.pid
2182 }
2183
2184 pub fn secret_key(&self) -> i32 {
2186 self.secret
2187 }
2188
2189 pub async fn cancel(&self, config: &Config) -> Result<(), DriverError> {
2197 let addr = format!("{}:{}", config.host, config.port);
2198 let mut tcp = TcpStream::connect(&addr).await.map_err(DriverError::Io)?;
2199 let mut buf = Vec::with_capacity(16);
2200 proto::write_cancel_request(&mut buf, self.pid, self.secret);
2201 tcp.write_all(&buf).await.map_err(DriverError::Io)?;
2202 tcp.flush().await.map_err(DriverError::Io)?;
2203 drop(tcp);
2205 Ok(())
2206 }
2207
2208 pub fn is_streaming(&self) -> bool {
2210 self.streaming_active
2211 }
2212
2213 pub fn drain_notifications(&mut self) -> Vec<Notification> {
2219 std::mem::take(&mut self.pending_notifications)
2220 }
2221
2222 pub fn pending_notification_count(&self) -> usize {
2224 self.pending_notifications.len()
2225 }
2226
2227 pub fn set_max_stmt_cache_size(&mut self, size: usize) {
2233 self.max_stmt_cache_size = size;
2234 }
2235
2236 pub fn stmt_cache_len(&self) -> usize {
2238 self.stmts.len()
2239 }
2240
2241 fn set_keepalive(tcp: &TcpStream) -> Result<(), DriverError> {
2243 let sock = socket2::SockRef::from(tcp);
2244 let ka = socket2::TcpKeepalive::new()
2245 .with_time(std::time::Duration::from_secs(60))
2246 .with_interval(std::time::Duration::from_secs(15));
2247 sock.set_tcp_keepalive(&ka).map_err(DriverError::Io)?;
2248 Ok(())
2249 }
2250
2251 pub fn created_at(&self) -> std::time::Instant {
2253 self.created_at
2254 }
2255
2256 fn cache_stmt(&mut self, sql_hash: u64, info: StmtInfo) {
2266 if self.stmts.len() >= self.max_stmt_cache_size && !self.stmts.contains_key(&sql_hash) {
2268 if let Some((_lru_hash, evicted)) = self.stmts.evict_lru() {
2269 proto::write_close(&mut self.write_buf, b'S', &evicted.name);
2272 }
2273 }
2274 self.stmts.insert(sql_hash, info);
2275 }
2276
2277 fn buffer_notification(&mut self, pid: i32, channel: &str, payload: &str) {
2279 if self.pending_notifications.len() < 1024 {
2281 self.pending_notifications.push(Notification {
2282 pid,
2283 channel: channel.to_owned(),
2284 payload: payload.to_owned(),
2285 });
2286 }
2287 }
2288
2289 fn shrink_buffers(&mut self) {
2294 if self.query_counter & 63 != 0 {
2298 return;
2299 }
2300 if self.read_buf.capacity() > 64 * 1024 {
2301 self.read_buf.clear();
2302 self.read_buf.shrink_to(8192);
2303 }
2304 if self.write_buf.capacity() > 16 * 1024 {
2305 self.write_buf.clear();
2306 self.write_buf.shrink_to(8192);
2307 }
2308 }
2309
    /// Reads the next backend message, transparently absorbing asynchronous
    /// `NotificationResponse` ('A') messages.
    ///
    /// Each notification is handed to `buffer_notification`, which keeps at
    /// most 1024 of them, and reading continues. Every other message is parsed
    /// from `read_buf` and returned. The returned message borrows
    /// `read_buf`, so it is only valid until the next read on this connection.
    async fn read_one_message(&mut self) -> Result<BackendMessage<'_>, DriverError> {
        loop {
            let (msg_type, _payload_len) = self.read_message_buffered().await?;
            if msg_type == b'A' {
                // Owned copies are taken so the borrow of `read_buf` ends
                // before `buffer_notification` needs `&mut self`.
                let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
                if let BackendMessage::NotificationResponse {
                    pid,
                    channel,
                    payload,
                } = msg
                {
                    let pid_owned = pid;
                    let channel_owned = channel.to_owned();
                    let payload_owned = payload.to_owned();
                    self.buffer_notification(pid_owned, &channel_owned, &payload_owned);
                    continue; // notification swallowed; read the next message
                }
            }
            // Parsed again here rather than returning the value from the 'A'
            // branch. NOTE(review): presumably to avoid the borrow checker's
            // rejection of a conditional return of a loop-held borrow — confirm.
            return proto::parse_backend_message(msg_type, &self.read_buf);
        }
    }
2339
2340 async fn expect_message(
2346 &mut self,
2347 pred: impl Fn(&BackendMessage<'_>) -> bool,
2348 ) -> Result<(), DriverError> {
2349 loop {
2350 let msg = self.read_one_message().await?;
2351 if pred(&msg) {
2352 return Ok(());
2353 }
2354 match msg {
2355 BackendMessage::ErrorResponse { data } => {
2356 let fields = proto::parse_error_response(data);
2357 self.drain_to_ready().await?;
2358 return Err(self.make_server_error(fields));
2359 }
2360 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {
2361 }
2364 other => {
2365 return Err(DriverError::Protocol(format!(
2366 "unexpected message while waiting for expected type: {other:?}"
2367 )));
2368 }
2369 }
2370 }
2371 }
2372
2373 async fn expect_ready(&mut self) -> Result<(), DriverError> {
2375 loop {
2376 let msg = self.read_one_message().await?;
2377 match msg {
2378 BackendMessage::ReadyForQuery { status } => {
2379 self.tx_status = status;
2380 return Ok(());
2381 }
2382 BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2383 BackendMessage::ErrorResponse { data } => {
2384 let fields = proto::parse_error_response(data);
2385 self.drain_to_ready().await?;
2387 return Err(self.make_server_error(fields));
2388 }
2389 _ => {}
2390 }
2391 }
2392 }
2393
2394 async fn drain_to_ready(&mut self) -> Result<(), DriverError> {
2397 loop {
2398 let msg = self.read_one_message().await?;
2399 if let BackendMessage::ReadyForQuery { status } = msg {
2400 self.tx_status = status;
2401 return Ok(());
2402 }
2403 }
2404 }
2405
2406 fn maybe_invalidate_stmt_cache(&mut self, fields: &proto::ErrorFields, sql_hash: u64) -> bool {
2409 if &*fields.code == "26000" {
2410 self.stmts.remove(&sql_hash);
2411 true
2412 } else {
2413 false
2414 }
2415 }
2416
2417 #[cold]
2419 #[inline(never)]
2420 fn make_server_error(&self, fields: proto::ErrorFields) -> DriverError {
2421 DriverError::Server {
2422 code: fields.code,
2423 message: fields.message.into_boxed_str(),
2424 detail: fields.detail.map(String::into_boxed_str),
2425 hint: fields.hint.map(String::into_boxed_str),
2426 position: fields.position,
2427 }
2428 }
2429
2430 #[cold]
2435 async fn handle_non_datarow_async(
2436 &mut self,
2437 msg_type: u8,
2438 payload_start: usize,
2439 payload_end: usize,
2440 sql_hash: u64,
2441 ) -> Result<(), DriverError> {
2442 match msg_type {
2443 b'E' => {
2444 let fields =
2445 proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2446 self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2447 self.drain_to_ready().await?;
2448 return Err(self.make_server_error(fields));
2449 }
2450 b'A' => {
2451 let msg = proto::parse_backend_message(
2452 msg_type,
2453 &self.stream_buf[payload_start..payload_end],
2454 )?;
2455 if let BackendMessage::NotificationResponse {
2456 pid,
2457 channel,
2458 payload,
2459 } = msg
2460 {
2461 let ch = channel.to_owned();
2462 let pl = payload.to_owned();
2463 self.buffer_notification(pid, &ch, &pl);
2464 }
2465 }
2466 _ => {} }
2468 Ok(())
2469 }
2470
2471 async fn flush_write(&mut self) -> Result<(), DriverError> {
2478 self.stream
2479 .write_all(&self.write_buf)
2480 .await
2481 .map_err(DriverError::Io)?;
2482 self.stream.flush().await.map_err(DriverError::Io)?;
2483 Ok(())
2484 }
2485
2486 async fn read_message_buffered(&mut self) -> Result<(u8, usize), DriverError> {
2490 let mut header = [0u8; 5];
2492 buffered_read_exact(
2493 &mut self.stream,
2494 &mut self.stream_buf,
2495 &mut self.stream_buf_pos,
2496 &mut self.stream_buf_end,
2497 &mut header,
2498 )
2499 .await?;
2500
2501 let msg_type = header[0];
2502 let len = i32::from_be_bytes([header[1], header[2], header[3], header[4]]);
2503
2504 if len < 4 {
2505 return Err(DriverError::Protocol(format!(
2506 "invalid message length {len} for type '{}'",
2507 msg_type as char
2508 )));
2509 }
2510
2511 const MAX_MESSAGE_LEN: i32 = 128 * 1024 * 1024;
2512 if len > MAX_MESSAGE_LEN {
2513 return Err(DriverError::Protocol(format!(
2514 "message length {len} exceeds maximum ({MAX_MESSAGE_LEN}) for type '{}'",
2515 msg_type as char
2516 )));
2517 }
2518
2519 let payload_len = (len - 4) as usize;
2520
2521 self.read_buf.clear();
2526 self.read_buf.resize(payload_len, 0);
2527 if payload_len > 0 {
2528 buffered_read_exact(
2529 &mut self.stream,
2530 &mut self.stream_buf,
2531 &mut self.stream_buf_pos,
2532 &mut self.stream_buf_end,
2533 &mut self.read_buf[..payload_len],
2534 )
2535 .await?;
2536 }
2537
2538 Ok((msg_type, payload_len))
2539 }
2540}
2541
2542async fn buffered_read_exact(
2547 stream: &mut Stream,
2548 buf: &mut [u8],
2549 pos: &mut usize,
2550 end: &mut usize,
2551 out: &mut [u8],
2552) -> Result<(), DriverError> {
2553 let mut filled = 0;
2554 while filled < out.len() {
2555 let avail = *end - *pos;
2556 if avail > 0 {
2557 let take = avail.min(out.len() - filled);
2558 out[filled..filled + take].copy_from_slice(&buf[*pos..*pos + take]);
2559 *pos += take;
2560 filled += take;
2561 } else {
2562 *pos = 0;
2564 let n = {
2565 let mut reader = StreamReader(stream);
2566 use tokio::io::AsyncReadExt;
2567 reader.read(buf).await.map_err(DriverError::Io)?
2568 };
2569 if n == 0 {
2570 return Err(DriverError::Io(std::io::Error::new(
2571 std::io::ErrorKind::UnexpectedEof,
2572 "connection closed",
2573 )));
2574 }
2575 *end = n;
2576 }
2577 }
2578 Ok(())
2579}
2580
2581fn build_bind_template(write_buf: &[u8], param_count: usize) -> Option<BindTemplate> {
2589 if write_buf.is_empty() || write_buf[0] != b'B' {
2590 return None;
2591 }
2592 if write_buf.len() < 5 {
2593 return None;
2594 }
2595
2596 let mut pos = 5; while pos < write_buf.len() && write_buf[pos] != 0 {
2600 pos += 1;
2601 }
2602 pos += 1;
2603
2604 while pos < write_buf.len() && write_buf[pos] != 0 {
2606 pos += 1;
2607 }
2608 pos += 1;
2609
2610 if pos + 2 > write_buf.len() {
2612 return None;
2613 }
2614 let num_fmt_codes = i16::from_be_bytes([write_buf[pos], write_buf[pos + 1]]);
2615 pos += 2;
2616 pos += num_fmt_codes.max(0) as usize * 2;
2617
2618 if pos + 2 > write_buf.len() {
2620 return None;
2621 }
2622 let wire_param_count = i16::from_be_bytes([write_buf[pos], write_buf[pos + 1]]) as usize;
2623 pos += 2;
2624
2625 if wire_param_count != param_count {
2626 return None;
2627 }
2628
2629 let mut param_slots = Vec::with_capacity(param_count);
2630 for _ in 0..param_count {
2631 if pos + 4 > write_buf.len() {
2632 return None;
2633 }
2634 let data_len = i32::from_be_bytes([
2635 write_buf[pos],
2636 write_buf[pos + 1],
2637 write_buf[pos + 2],
2638 write_buf[pos + 3],
2639 ]);
2640 pos += 4;
2641
2642 if data_len < 0 {
2643 param_slots.push((pos, -1));
2644 } else {
2645 param_slots.push((pos, data_len));
2646 pos += data_len as usize;
2647 }
2648 }
2649
2650 let bind_end = write_buf.len();
2652 let mut bytes = Vec::with_capacity(bind_end + proto::EXECUTE_SYNC.len());
2653 bytes.extend_from_slice(write_buf);
2654 bytes.extend_from_slice(proto::EXECUTE_SYNC);
2655
2656 Some(BindTemplate {
2657 bytes,
2658 bind_end,
2659 param_slots,
2660 })
2661}
2662
2663pub struct QueryResult {
2686 all_col_offsets: Vec<(usize, i32)>,
2689 num_cols: usize,
2691 columns: Arc<[ColumnDesc]>,
2692 affected_rows: u64,
2693}
2694
2695impl QueryResult {
2696 pub fn from_parts(
2700 all_col_offsets: Vec<(usize, i32)>,
2701 num_cols: usize,
2702 columns: Arc<[ColumnDesc]>,
2703 affected_rows: u64,
2704 ) -> Self {
2705 Self {
2706 all_col_offsets,
2707 num_cols,
2708 columns,
2709 affected_rows,
2710 }
2711 }
2712
2713 pub fn len(&self) -> usize {
2715 if self.num_cols == 0 {
2716 return 0;
2717 }
2718 self.all_col_offsets.len() / self.num_cols
2719 }
2720
2721 pub fn is_empty(&self) -> bool {
2723 self.all_col_offsets.is_empty()
2724 }
2725
2726 pub fn affected_rows(&self) -> u64 {
2728 self.affected_rows
2729 }
2730
2731 pub fn columns(&self) -> &[ColumnDesc] {
2733 &self.columns
2734 }
2735
2736 pub fn row<'a>(&'a self, idx: usize, arena: &'a Arena) -> Row<'a> {
2738 let start = idx * self.num_cols;
2739 let end = start + self.num_cols;
2740 Row {
2741 arena,
2742 col_offsets: &self.all_col_offsets[start..end],
2743 columns: &self.columns,
2744 }
2745 }
2746
2747 pub fn take_col_offsets(&mut self) -> Vec<(usize, i32)> {
2752 std::mem::take(&mut self.all_col_offsets)
2753 }
2754
2755 pub fn rows<'a>(&'a self, arena: &'a Arena) -> impl Iterator<Item = Row<'a>> {
2757 let num_cols = self.num_cols;
2758 let columns = &self.columns;
2759 self.all_col_offsets
2760 .chunks(num_cols.max(1))
2763 .map(move |chunk| Row {
2764 arena,
2765 col_offsets: chunk,
2766 columns,
2767 })
2768 }
2769}
2770
2771pub struct Row<'a> {
2780 arena: &'a Arena,
2781 col_offsets: &'a [(usize, i32)],
2782 columns: &'a [ColumnDesc],
2783}
2784
2785impl<'a> Row<'a> {
2786 pub fn get_raw(&self, idx: usize) -> Option<&'a [u8]> {
2788 let (offset, len) = self.col_offsets[idx];
2789 if len < 0 {
2790 None
2791 } else {
2792 Some(self.arena.get(offset, len as usize))
2793 }
2794 }
2795
2796 pub fn is_null(&self, idx: usize) -> bool {
2798 self.col_offsets[idx].1 < 0
2799 }
2800
2801 pub fn column_count(&self) -> usize {
2803 self.col_offsets.len()
2804 }
2805
2806 pub fn get_bool(&self, idx: usize) -> Option<bool> {
2808 self.get_raw(idx)
2809 .and_then(|data| crate::codec::decode_bool(data).ok())
2810 }
2811
2812 pub fn get_i16(&self, idx: usize) -> Option<i16> {
2814 self.get_raw(idx)
2815 .and_then(|data| crate::codec::decode_i16(data).ok())
2816 }
2817
2818 pub fn get_i32(&self, idx: usize) -> Option<i32> {
2820 self.get_raw(idx)
2821 .and_then(|data| crate::codec::decode_i32(data).ok())
2822 }
2823
2824 pub fn get_i64(&self, idx: usize) -> Option<i64> {
2826 self.get_raw(idx)
2827 .and_then(|data| crate::codec::decode_i64(data).ok())
2828 }
2829
2830 pub fn get_f32(&self, idx: usize) -> Option<f32> {
2832 self.get_raw(idx)
2833 .and_then(|data| crate::codec::decode_f32(data).ok())
2834 }
2835
2836 pub fn get_f64(&self, idx: usize) -> Option<f64> {
2838 self.get_raw(idx)
2839 .and_then(|data| crate::codec::decode_f64(data).ok())
2840 }
2841
2842 pub fn get_str(&self, idx: usize) -> Option<&'a str> {
2844 self.get_raw(idx)
2845 .and_then(|data| crate::codec::decode_str(data).ok())
2846 }
2847
2848 pub fn get_bytes(&self, idx: usize) -> Option<&'a [u8]> {
2850 self.get_raw(idx)
2851 }
2852
2853 pub fn column_name(&self, idx: usize) -> &str {
2855 &self.columns[idx].name
2856 }
2857
2858 pub fn column_type_oid(&self, idx: usize) -> u32 {
2860 self.columns[idx].type_oid
2861 }
2862}
2863
2864pub struct PgDataRow<'a> {
2875 data: &'a [u8],
2876 offsets: smallvec::SmallVec<[(usize, i32); 16]>,
2879}
2880
2881impl<'a> PgDataRow<'a> {
2882 pub fn new(data: &'a [u8]) -> Result<Self, DriverError> {
2887 if data.len() < 2 {
2888 return Err(DriverError::Protocol("DataRow too short".into()));
2889 }
2890 let num_cols = i16::from_be_bytes([data[0], data[1]]);
2891 if num_cols < 0 {
2892 return Err(DriverError::Protocol(
2893 "DataRow: negative column count".into(),
2894 ));
2895 }
2896 let num_cols = num_cols as usize;
2897 let mut offsets = smallvec::SmallVec::<[(usize, i32); 16]>::with_capacity(num_cols);
2898 let mut pos = 2usize;
2899 for _ in 0..num_cols {
2900 if pos + 4 > data.len() {
2901 return Err(DriverError::Protocol("DataRow truncated".into()));
2902 }
2903 let col_len =
2904 i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
2905 pos += 4;
2906 offsets.push((pos, col_len));
2907 if col_len > 0 {
2908 pos += col_len as usize;
2909 }
2910 }
2911 Ok(Self { data, offsets })
2912 }
2913
2914 #[inline]
2916 pub fn get_raw(&self, idx: usize) -> Option<&'a [u8]> {
2917 let (offset, len) = self.offsets[idx];
2918 if len < 0 {
2919 None
2920 } else {
2921 Some(&self.data[offset..offset + len as usize])
2922 }
2923 }
2924
2925 #[inline]
2927 pub fn is_null(&self, idx: usize) -> bool {
2928 self.offsets[idx].1 < 0
2929 }
2930
2931 #[inline]
2933 pub fn column_count(&self) -> usize {
2934 self.offsets.len()
2935 }
2936
2937 #[inline]
2939 pub fn get_bool(&self, idx: usize) -> Option<bool> {
2940 self.get_raw(idx)
2941 .and_then(|data| crate::codec::decode_bool(data).ok())
2942 }
2943
2944 #[inline]
2946 pub fn get_i16(&self, idx: usize) -> Option<i16> {
2947 self.get_raw(idx)
2948 .and_then(|data| crate::codec::decode_i16(data).ok())
2949 }
2950
2951 #[inline]
2953 pub fn get_i32(&self, idx: usize) -> Option<i32> {
2954 self.get_raw(idx)
2955 .and_then(|data| crate::codec::decode_i32(data).ok())
2956 }
2957
2958 #[inline]
2960 pub fn get_i64(&self, idx: usize) -> Option<i64> {
2961 self.get_raw(idx)
2962 .and_then(|data| crate::codec::decode_i64(data).ok())
2963 }
2964
2965 #[inline]
2967 pub fn get_f32(&self, idx: usize) -> Option<f32> {
2968 self.get_raw(idx)
2969 .and_then(|data| crate::codec::decode_f32(data).ok())
2970 }
2971
2972 #[inline]
2974 pub fn get_f64(&self, idx: usize) -> Option<f64> {
2975 self.get_raw(idx)
2976 .and_then(|data| crate::codec::decode_f64(data).ok())
2977 }
2978
2979 #[inline]
2981 pub fn get_str(&self, idx: usize) -> Option<&'a str> {
2982 self.get_raw(idx)
2983 .and_then(|data| crate::codec::decode_str(data).ok())
2984 }
2985
2986 #[inline]
2988 pub fn get_bytes(&self, idx: usize) -> Option<&'a [u8]> {
2989 self.get_raw(idx)
2990 }
2991}
2992
2993fn parse_data_row_flat(
3002 data: &[u8],
3003 arena: &mut Arena,
3004 out: &mut Vec<(usize, i32)>,
3005) -> Result<(), DriverError> {
3006 if data.len() < 2 {
3007 return Err(DriverError::Protocol("DataRow too short".into()));
3008 }
3009
3010 let num_cols_raw = i16::from_be_bytes([data[0], data[1]]);
3011 if num_cols_raw < 0 {
3012 return Err(DriverError::Protocol(
3013 "DataRow: negative column count".into(),
3014 ));
3015 }
3016 let num_cols = num_cols_raw as usize;
3017 out.reserve(num_cols);
3018 let mut pos = 2;
3019
3020 for _ in 0..num_cols {
3021 if pos + 4 > data.len() {
3022 return Err(DriverError::Protocol("DataRow truncated".into()));
3023 }
3024
3025 let col_len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3026 pos += 4;
3027
3028 if col_len < 0 {
3029 out.push((0, -1));
3031 } else {
3032 let len = col_len as usize;
3033 if pos + len > data.len() {
3034 return Err(DriverError::Protocol(
3035 "DataRow column data truncated".into(),
3036 ));
3037 }
3038
3039 let offset = arena.alloc_copy(&data[pos..pos + len]);
3040 out.push((offset, col_len));
3041 pos += len;
3042 }
3043 }
3044
3045 Ok(())
3046}
3047
3048pub fn hash_sql(sql: &str) -> u64 {
3052 use std::hash::{Hash, Hasher};
3053 let mut hasher = RapidHasher::default();
3054 sql.hash(&mut hasher);
3055 hasher.finish()
3056}
3057
3058#[cfg(test)]
3059#[allow(clippy::approx_constant)]
3060mod tests {
3061 use super::*;
3062
3063 #[test]
3064 fn config_parse_full_url() {
3065 let cfg = Config::from_url("postgres://user:pass@localhost:5432/mydb").unwrap();
3066 assert_eq!(cfg.user, "user");
3067 assert_eq!(cfg.password, "pass");
3068 assert_eq!(cfg.host, "localhost");
3069 assert_eq!(cfg.port, 5432);
3070 assert_eq!(cfg.database, "mydb");
3071 }
3072
3073 #[test]
3074 fn config_parse_default_port() {
3075 let cfg = Config::from_url("postgres://user:pass@localhost/mydb").unwrap();
3076 assert_eq!(cfg.port, 5432);
3077 }
3078
3079 #[test]
3080 fn config_parse_no_password() {
3081 let cfg = Config::from_url("postgres://user@localhost/mydb").unwrap();
3082 assert_eq!(cfg.user, "user");
3083 assert_eq!(cfg.password, "");
3084 }
3085
3086 #[test]
3087 fn config_parse_empty_database() {
3088 let cfg = Config::from_url("postgres://user:pass@localhost").unwrap();
3089 assert_eq!(cfg.database, "user");
3091 }
3092
3093 #[test]
3094 fn config_parse_sslmode() {
3095 let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=require").unwrap();
3096 assert_eq!(cfg.ssl, SslMode::Require);
3097 }
3098
3099 #[test]
3100 fn config_parse_percent_encoding() {
3101 let cfg = Config::from_url("postgres://user%40domain:p%40ss@localhost/db").unwrap();
3102 assert_eq!(cfg.user, "user@domain");
3103 assert_eq!(cfg.password, "p@ss");
3104 }
3105
3106 #[test]
3107 fn config_rejects_bad_scheme() {
3108 let result = Config::from_url("mysql://user:pass@localhost/db");
3109 assert!(result.is_err());
3110 }
3111
3112 #[test]
3114 fn config_rejects_unknown_sslmode() {
3115 let result = Config::from_url("postgres://user:pass@localhost/db?sslmode=requre");
3116 assert!(result.is_err(), "typo 'requre' should be rejected");
3117 let result = Config::from_url("postgres://user:pass@localhost/db?sslmode=REQUIRE");
3118 assert!(result.is_err(), "uppercase should be rejected");
3119 let result = Config::from_url("postgres://user:pass@localhost/db?sslmode=bogus");
3120 assert!(result.is_err(), "bogus value should be rejected");
3121 }
3122
3123 #[test]
3125 fn config_accepts_valid_sslmodes() {
3126 let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=disable").unwrap();
3127 assert_eq!(cfg.ssl, SslMode::Disable);
3128 let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=prefer").unwrap();
3129 assert_eq!(cfg.ssl, SslMode::Prefer);
3130 let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=require").unwrap();
3131 assert_eq!(cfg.ssl, SslMode::Require);
3132 }
3133
3134 #[test]
3136 fn stmt_cache_basic_ops() {
3137 let mut cache = StmtCache::default();
3138 assert_eq!(cache.len(), 0);
3139 assert!(!cache.contains_key(&42));
3140 assert!(cache.get(&42).is_none());
3141 assert!(cache.get_mut(&42).is_none());
3142 assert!(cache.remove(&42).is_none());
3143 }
3144
3145 #[test]
3147 fn stmt_name_format() {
3148 let name = make_stmt_name(0);
3149 assert_eq!(&*name, "s_0000000000000000");
3150 let name = make_stmt_name(0xDEADBEEF12345678);
3151 assert_eq!(&*name, "s_deadbeef12345678");
3152 let name = make_stmt_name(u64::MAX);
3153 assert_eq!(&*name, "s_ffffffffffffffff");
3154 }
3155
3156 #[test]
3157 fn hash_sql_deterministic() {
3158 let h1 = hash_sql("SELECT 1");
3159 let h2 = hash_sql("SELECT 1");
3160 assert_eq!(h1, h2);
3161 }
3162
3163 #[test]
3164 fn hash_sql_different_queries() {
3165 let h1 = hash_sql("SELECT 1");
3166 let h2 = hash_sql("SELECT 2");
3167 assert_ne!(h1, h2);
3168 }
3169
3170 #[test]
3171 fn data_row_parsing() {
3172 let mut arena = Arena::new();
3173 let mut out = Vec::new();
3174
3175 let mut data = Vec::new();
3177 data.extend_from_slice(&2i16.to_be_bytes()); data.extend_from_slice(&4i32.to_be_bytes()); data.extend_from_slice(&42i32.to_be_bytes()); data.extend_from_slice(&(-1i32).to_be_bytes()); parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3187 assert_eq!(out.len(), 2);
3188
3189 assert_eq!(out[0].1, 4);
3191
3192 assert_eq!(out[1].1, -1);
3194 }
3195
3196 #[test]
3197 fn data_row_empty() {
3198 let mut arena = Arena::new();
3199 let mut out = Vec::new();
3200 let data = 0i16.to_be_bytes();
3201 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3202 assert_eq!(out.len(), 0);
3203 }
3204
3205 #[test]
3206 fn query_result_empty() {
3207 let result = QueryResult {
3208 all_col_offsets: vec![],
3209 num_cols: 0,
3210 columns: Arc::from(Vec::new()),
3211 affected_rows: 0,
3212 };
3213 assert!(result.is_empty());
3214 assert_eq!(result.len(), 0);
3215 }
3216
3217 #[test]
3218 fn url_decode_works() {
3219 assert_eq!(url_decode("hello%20world").unwrap(), "hello world");
3220 assert_eq!(url_decode("no%20escape").unwrap(), "no escape");
3221 assert_eq!(url_decode("plain").unwrap(), "plain");
3222 assert_eq!(url_decode("a%40b").unwrap(), "a@b");
3223 }
3224
3225 #[test]
3226 fn url_decode_malformed_percent_trailing() {
3227 let result = url_decode("abc%2");
3229 assert!(result.is_err(), "truncated %2 should error");
3230 }
3231
3232 #[test]
3233 fn url_decode_malformed_percent_no_digits() {
3234 let result = url_decode("abc%");
3236 assert!(result.is_err(), "bare % at end should error");
3237 }
3238
3239 #[test]
3240 fn url_decode_invalid_hex_digit() {
3241 let result = url_decode("abc%GG");
3243 assert!(result.is_err(), "%GG should error");
3244 }
3245
3246 #[test]
3247 fn url_decode_invalid_hex_second_digit() {
3248 let result = url_decode("abc%2Z");
3250 assert!(result.is_err(), "%2Z should error");
3251 }
3252
3253 #[test]
3255 fn url_decode_invalid_utf8_percent() {
3256 let result = url_decode("%80%81");
3258 assert!(result.is_err(), "invalid UTF-8 bytes should error");
3259 }
3260
3261 #[test]
3263 fn url_decode_percent_everywhere() {
3264 assert_eq!(url_decode("%41%42%43").unwrap(), "ABC");
3265 assert_eq!(url_decode("%61").unwrap(), "a");
3266 assert_eq!(url_decode("x%2Fy%2Fz").unwrap(), "x/y/z");
3267 }
3268
3269 #[test]
3271 fn url_decode_bare_percent_middle() {
3272 assert!(url_decode("a%b").is_err(), "bare % in middle should error");
3273 }
3274
3275 #[test]
3277 fn url_decode_multibyte_utf8() {
3278 let result = url_decode("caf%C3%A9").unwrap();
3279 assert_eq!(result, "caf\u{00e9}"); }
3281
3282 #[test]
3286 fn config_parse_postgresql_scheme() {
3287 let cfg = Config::from_url("postgresql://user:pass@localhost:5432/mydb").unwrap();
3288 assert_eq!(cfg.user, "user");
3289 assert_eq!(cfg.password, "pass");
3290 assert_eq!(cfg.host, "localhost");
3291 assert_eq!(cfg.port, 5432);
3292 assert_eq!(cfg.database, "mydb");
3293 }
3294
3295 #[test]
3297 fn config_parse_no_password_standalone() {
3298 let cfg = Config::from_url("postgres://admin@db.example.com/myapp").unwrap();
3299 assert_eq!(cfg.user, "admin");
3300 assert_eq!(cfg.password, "");
3301 assert_eq!(cfg.host, "db.example.com");
3302 assert_eq!(cfg.database, "myapp");
3303 }
3304
3305 #[test]
3307 fn config_empty_database_falls_back_to_user() {
3308 let cfg = Config::from_url("postgres://testuser:pass@localhost").unwrap();
3309 assert_eq!(cfg.database, "testuser");
3310 }
3311
3312 #[test]
3314 fn config_unknown_sslmode_error() {
3315 let result = Config::from_url("postgres://u:p@h/d?sslmode=verify-full");
3316 assert!(result.is_err());
3317 let err = result.unwrap_err().to_string();
3318 assert!(
3319 err.contains("unknown sslmode"),
3320 "should describe unknown sslmode: {err}"
3321 );
3322 }
3323
3324 #[test]
3326 fn config_multiple_query_params() {
3327 let cfg = Config::from_url(
3328 "postgres://user:pass@localhost/db?sslmode=disable&statement_timeout=60",
3329 )
3330 .unwrap();
3331 assert_eq!(cfg.ssl, SslMode::Disable);
3332 assert_eq!(cfg.statement_timeout_secs, 60);
3333 }
3334
3335 #[test]
3337 fn url_decode_invalid_percent_zz() {
3338 let result = url_decode("abc%ZZ");
3339 assert!(result.is_err(), "%ZZ should error");
3340 }
3341
3342 #[test]
3344 fn url_decode_truncated_percent_trailing() {
3345 let result = url_decode("abc%");
3346 assert!(result.is_err(), "trailing % should error");
3347 }
3348
3349 #[test]
3351 fn url_decode_invalid_utf8() {
3352 let result = url_decode("%80");
3354 assert!(result.is_err(), "invalid UTF-8 should error");
3355 }
3356
3357 #[cfg(not(feature = "tls"))]
3359 #[test]
3360 fn config_sslmode_require_without_tls_feature() {
3361 let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=require").unwrap();
3364 assert_eq!(cfg.ssl, SslMode::Require);
3365 }
3366
3367 #[test]
3369 fn stmt_name_format_verification() {
3370 let name = make_stmt_name(0xDEADBEEFCAFEBABE);
3371 assert!(name.starts_with("s_"), "must start with s_");
3372 assert_eq!(name.len(), 18, "s_ (2) + 16 hex = 18");
3373 assert!(
3374 name[2..].chars().all(|c| c.is_ascii_hexdigit()),
3375 "remaining chars must be hex: {}",
3376 &*name
3377 );
3378 }
3379
3380 #[test]
3382 fn stmt_name_zero() {
3383 let name = make_stmt_name(0);
3384 assert_eq!(&*name, "s_0000000000000000");
3385 }
3386
3387 #[test]
3389 fn stmt_name_max() {
3390 let name = make_stmt_name(u64::MAX);
3391 assert_eq!(&*name, "s_ffffffffffffffff");
3392 }
3393
3394 #[test]
3396 fn config_validate_empty_host() {
3397 let cfg = Config {
3398 host: String::new(),
3399 port: 5432,
3400 user: "user".into(),
3401 password: "pass".into(),
3402 database: "db".into(),
3403 ssl: SslMode::Disable,
3404 statement_timeout_secs: 30,
3405 };
3406 assert!(cfg.validate().is_err());
3407 }
3408
3409 #[test]
3411 fn config_validate_empty_user() {
3412 let cfg = Config {
3413 host: "localhost".into(),
3414 port: 5432,
3415 user: String::new(),
3416 password: "pass".into(),
3417 database: "db".into(),
3418 ssl: SslMode::Disable,
3419 statement_timeout_secs: 30,
3420 };
3421 assert!(cfg.validate().is_err());
3422 }
3423
3424 #[test]
3426 fn config_validate_empty_database() {
3427 let cfg = Config {
3428 host: "localhost".into(),
3429 port: 5432,
3430 user: "user".into(),
3431 password: "pass".into(),
3432 database: String::new(),
3433 ssl: SslMode::Disable,
3434 statement_timeout_secs: 30,
3435 };
3436 assert!(cfg.validate().is_err());
3437 }
3438
3439 #[test]
3441 fn config_missing_at_sign() {
3442 let result = Config::from_url("postgres://userpasslocalhost/db");
3443 assert!(result.is_err());
3444 }
3445
3446 #[test]
3448 fn config_custom_port() {
3449 let cfg = Config::from_url("postgres://user:pass@localhost:5433/db").unwrap();
3450 assert_eq!(cfg.port, 5433);
3451 }
3452
3453 #[test]
3455 fn config_invalid_port() {
3456 let result = Config::from_url("postgres://user:pass@localhost:notaport/db");
3457 assert!(result.is_err());
3458 }
3459
3460 #[test]
3463 fn notification_struct_fields() {
3464 let n = Notification {
3465 pid: 42,
3466 channel: "test_chan".to_owned(),
3467 payload: "hello".to_owned(),
3468 };
3469 assert_eq!(n.pid, 42);
3470 assert_eq!(n.channel, "test_chan");
3471 assert_eq!(n.payload, "hello");
3472 }
3473
3474 #[test]
3475 fn notification_clone() {
3476 let n = Notification {
3477 pid: 1,
3478 channel: "c".to_owned(),
3479 payload: "p".to_owned(),
3480 };
3481 let n2 = n.clone();
3482 assert_eq!(n2.pid, 1);
3483 assert_eq!(n2.channel, "c");
3484 }
3485
3486 #[test]
3487 fn notification_debug() {
3488 let n = Notification {
3489 pid: 1,
3490 channel: "c".to_owned(),
3491 payload: "p".to_owned(),
3492 };
3493 let dbg = format!("{n:?}");
3494 assert!(dbg.contains("Notification"));
3495 }
3496
3497 #[test]
3500 fn stmt_info_has_last_used_counter() {
3501 let info = StmtInfo {
3502 name: "s_test".into(),
3503 columns: Arc::from(Vec::new()),
3504 last_used: 42,
3505 bind_template: None,
3506 };
3507 assert_eq!(info.last_used, 42);
3509 }
3510
3511 fn make_data_row(columns: &[Option<&[u8]>]) -> Vec<u8> {
3516 let mut buf = Vec::new();
3517 buf.extend_from_slice(&(columns.len() as i16).to_be_bytes());
3518 for col in columns {
3519 match col {
3520 Some(data) => {
3521 buf.extend_from_slice(&(data.len() as i32).to_be_bytes());
3522 buf.extend_from_slice(data);
3523 }
3524 None => {
3525 buf.extend_from_slice(&(-1i32).to_be_bytes());
3526 }
3527 }
3528 }
3529 buf
3530 }
3531
3532 #[test]
3533 fn pg_data_row_get_i32() {
3534 let data = make_data_row(&[Some(&42i32.to_be_bytes())]);
3535 let row = PgDataRow::new(&data).unwrap();
3536 assert_eq!(row.get_i32(0), Some(42));
3537 assert_eq!(row.column_count(), 1);
3538 }
3539
3540 #[test]
3541 fn pg_data_row_get_i64() {
3542 let data = make_data_row(&[Some(&12345i64.to_be_bytes())]);
3543 let row = PgDataRow::new(&data).unwrap();
3544 assert_eq!(row.get_i64(0), Some(12345));
3545 }
3546
3547 #[test]
3548 fn pg_data_row_get_str() {
3549 let data = make_data_row(&[Some(b"hello")]);
3550 let row = PgDataRow::new(&data).unwrap();
3551 assert_eq!(row.get_str(0), Some("hello"));
3552 }
3553
3554 #[test]
3555 fn pg_data_row_get_bytes() {
3556 let data = make_data_row(&[Some(&[0xDE, 0xAD, 0xBE, 0xEF])]);
3557 let row = PgDataRow::new(&data).unwrap();
3558 assert_eq!(row.get_bytes(0), Some(&[0xDE, 0xAD, 0xBE, 0xEF][..]));
3559 }
3560
3561 #[test]
3562 fn pg_data_row_get_bool() {
3563 let data = make_data_row(&[Some(&[1u8])]);
3564 let row = PgDataRow::new(&data).unwrap();
3565 assert_eq!(row.get_bool(0), Some(true));
3566
3567 let data = make_data_row(&[Some(&[0u8])]);
3568 let row = PgDataRow::new(&data).unwrap();
3569 assert_eq!(row.get_bool(0), Some(false));
3570 }
3571
3572 #[test]
3573 fn pg_data_row_get_f64() {
3574 let data = make_data_row(&[Some(&3.14f64.to_be_bytes())]);
3575 let row = PgDataRow::new(&data).unwrap();
3576 assert!((row.get_f64(0).unwrap() - 3.14).abs() < 1e-10);
3577 }
3578
3579 #[test]
3580 fn pg_data_row_null_column() {
3581 let data = make_data_row(&[None]);
3582 let row = PgDataRow::new(&data).unwrap();
3583 assert!(row.is_null(0));
3584 assert_eq!(row.get_i32(0), None);
3585 assert_eq!(row.get_str(0), None);
3586 }
3587
3588 #[test]
3589 fn pg_data_row_multiple_columns() {
3590 let data = make_data_row(&[
3591 Some(&42i32.to_be_bytes()),
3592 Some(b"alice"),
3593 Some(b"alice@example.com"),
3594 Some(&[1u8]),
3595 Some(&3.14f64.to_be_bytes()),
3596 ]);
3597 let row = PgDataRow::new(&data).unwrap();
3598 assert_eq!(row.column_count(), 5);
3599 assert_eq!(row.get_i32(0), Some(42));
3600 assert_eq!(row.get_str(1), Some("alice"));
3601 assert_eq!(row.get_str(2), Some("alice@example.com"));
3602 assert_eq!(row.get_bool(3), Some(true));
3603 assert!((row.get_f64(4).unwrap() - 3.14).abs() < 1e-10);
3604 }
3605
3606 #[test]
3607 fn pg_data_row_mixed_null() {
3608 let data = make_data_row(&[Some(&42i32.to_be_bytes()), None, Some(b"text")]);
3609 let row = PgDataRow::new(&data).unwrap();
3610 assert_eq!(row.get_i32(0), Some(42));
3611 assert!(row.is_null(1));
3612 assert_eq!(row.get_str(1), None);
3613 assert_eq!(row.get_str(2), Some("text"));
3614 }
3615
3616 #[test]
3617 fn pg_data_row_empty() {
3618 let data = make_data_row(&[]);
3619 let row = PgDataRow::new(&data).unwrap();
3620 assert_eq!(row.column_count(), 0);
3621 }
3622
3623 #[test]
3624 fn pg_data_row_too_short() {
3625 let data = vec![0u8]; assert!(PgDataRow::new(&data).is_err());
3627 }
3628
3629 #[test]
3630 fn pg_data_row_truncated() {
3631 let mut data = Vec::new();
3633 data.extend_from_slice(&2i16.to_be_bytes());
3634 data.extend_from_slice(&4i32.to_be_bytes());
3635 data.extend_from_slice(&42i32.to_be_bytes());
3636 assert!(PgDataRow::new(&data).is_err());
3638 }
3639
3640 #[test]
3641 fn pg_data_row_get_i16() {
3642 let data = make_data_row(&[Some(&7i16.to_be_bytes())]);
3643 let row = PgDataRow::new(&data).unwrap();
3644 assert_eq!(row.get_i16(0), Some(7));
3645 }
3646
3647 #[test]
3648 fn pg_data_row_get_f32() {
3649 let data = make_data_row(&[Some(&2.5f32.to_be_bytes())]);
3650 let row = PgDataRow::new(&data).unwrap();
3651 assert!((row.get_f32(0).unwrap() - 2.5).abs() < 1e-6);
3652 }
3653
3654 #[test]
3655 fn pg_data_row_get_raw_null() {
3656 let data = make_data_row(&[None]);
3657 let row = PgDataRow::new(&data).unwrap();
3658 assert_eq!(row.get_raw(0), None);
3659 }
3660
3661 #[test]
3662 fn pg_data_row_get_raw_data() {
3663 let data = make_data_row(&[Some(&[1, 2, 3])]);
3664 let row = PgDataRow::new(&data).unwrap();
3665 assert_eq!(row.get_raw(0), Some(&[1u8, 2, 3][..]));
3666 }
3667
3668 #[test]
3669 fn pg_data_row_stack_alloc_16_columns() {
3670 let cols: Vec<Option<&[u8]>> = (0..16).map(|_| Some(&[0u8][..])).collect();
3672 let data = make_data_row(&cols);
3673 let row = PgDataRow::new(&data).unwrap();
3674 assert_eq!(row.column_count(), 16);
3675 for i in 0..16 {
3677 assert_eq!(row.get_raw(i), Some(&[0u8][..]));
3678 }
3679 }
3680
3681 #[test]
3686 fn inline_sequential_decode_five_columns() {
3687 let data = make_data_row(&[
3688 Some(&42i32.to_be_bytes()),
3689 Some(b"alice"),
3690 Some(b"alice@example.com"),
3691 Some(&[1u8]),
3692 Some(&3.14f64.to_be_bytes()),
3693 ]);
3694
3695 let mut pos: usize = 2; let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3700 pos += 4;
3701 assert_eq!(len, 4);
3702 let id = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3703 pos += len as usize;
3704 assert_eq!(id, 42);
3705
3706 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3708 pos += 4;
3709 assert_eq!(len, 5);
3710 let name = std::str::from_utf8(&data[pos..pos + len as usize]).unwrap();
3711 pos += len as usize;
3712 assert_eq!(name, "alice");
3713
3714 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3716 pos += 4;
3717 let email = std::str::from_utf8(&data[pos..pos + len as usize]).unwrap();
3718 pos += len as usize;
3719 assert_eq!(email, "alice@example.com");
3720
3721 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3723 pos += 4;
3724 assert_eq!(len, 1);
3725 let active = data[pos] != 0;
3726 pos += len as usize;
3727 assert!(active);
3728
3729 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3731 pos += 4;
3732 assert_eq!(len, 8);
3733 let score = f64::from_be_bytes([
3734 data[pos],
3735 data[pos + 1],
3736 data[pos + 2],
3737 data[pos + 3],
3738 data[pos + 4],
3739 data[pos + 5],
3740 data[pos + 6],
3741 data[pos + 7],
3742 ]);
3743 pos += len as usize;
3744 assert!((score - 3.14).abs() < 1e-10);
3745 assert_eq!(pos, data.len());
3746 }
3747
3748 #[test]
3750 fn inline_sequential_decode_with_nulls() {
3751 let data = make_data_row(&[
3752 Some(&42i32.to_be_bytes()),
3753 None, Some(b"text"),
3755 ]);
3756
3757 let mut pos: usize = 2;
3758
3759 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3761 pos += 4;
3762 let id = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3763 pos += len as usize;
3764 assert_eq!(id, 42);
3765
3766 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3768 pos += 4;
3769 let name: Option<&str> = if len < 0 {
3770 None
3771 } else {
3772 let s = std::str::from_utf8(&data[pos..pos + len as usize]).unwrap();
3773 pos += len as usize;
3774 Some(s)
3775 };
3776 assert!(name.is_none());
3777
3778 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3780 pos += 4;
3781 let txt = std::str::from_utf8(&data[pos..pos + len as usize]).unwrap();
3782 pos += len as usize;
3783 assert_eq!(txt, "text");
3784 assert_eq!(pos, data.len());
3785 }
3786
3787 #[test]
3789 fn inline_sequential_decode_all_scalar_types() {
3790 let data = make_data_row(&[
3791 Some(&[1u8]), Some(&7i16.to_be_bytes()), Some(&42i32.to_be_bytes()), Some(&12345i64.to_be_bytes()), Some(&2.5f32.to_be_bytes()), Some(&3.14f64.to_be_bytes()), ]);
3798
3799 let mut pos: usize = 2;
3800
3801 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3803 pos += 4;
3804 let v_bool = data[pos] != 0;
3805 pos += len as usize;
3806 assert!(v_bool);
3807
3808 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3810 pos += 4;
3811 let v_i16 = i16::from_be_bytes([data[pos], data[pos + 1]]);
3812 pos += len as usize;
3813 assert_eq!(v_i16, 7);
3814
3815 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3817 pos += 4;
3818 let v_i32 = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3819 pos += len as usize;
3820 assert_eq!(v_i32, 42);
3821
3822 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3824 pos += 4;
3825 let v_i64 = i64::from_be_bytes([
3826 data[pos],
3827 data[pos + 1],
3828 data[pos + 2],
3829 data[pos + 3],
3830 data[pos + 4],
3831 data[pos + 5],
3832 data[pos + 6],
3833 data[pos + 7],
3834 ]);
3835 pos += len as usize;
3836 assert_eq!(v_i64, 12345);
3837
3838 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3840 pos += 4;
3841 let v_f32 = f32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3842 pos += len as usize;
3843 assert!((v_f32 - 2.5).abs() < 1e-6);
3844
3845 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3847 pos += 4;
3848 let v_f64 = f64::from_be_bytes([
3849 data[pos],
3850 data[pos + 1],
3851 data[pos + 2],
3852 data[pos + 3],
3853 data[pos + 4],
3854 data[pos + 5],
3855 data[pos + 6],
3856 data[pos + 7],
3857 ]);
3858 pos += len as usize;
3859 assert!((v_f64 - 3.14).abs() < 1e-10);
3860 assert_eq!(pos, data.len());
3861 }
3862
3863 #[test]
3865 fn pg_data_row_new_is_public() {
3866 let data = make_data_row(&[Some(&42i32.to_be_bytes())]);
3867 let row = PgDataRow::new(&data).unwrap();
3869 assert_eq!(row.get_i32(0), Some(42));
3870 }
3871
3872 #[test]
3874 fn inline_decode_matches_pgdatarow() {
3875 let data = make_data_row(&[
3876 Some(&99i32.to_be_bytes()),
3877 Some(b"hello world"),
3878 None,
3879 Some(&[0u8]),
3880 Some(&1.23f64.to_be_bytes()),
3881 ]);
3882
3883 let row = PgDataRow::new(&data).unwrap();
3885 let dr_i32 = row.get_i32(0);
3886 let dr_str = row.get_str(1);
3887 let dr_null = row.get_str(2);
3888 let dr_bool = row.get_bool(3);
3889 let dr_f64 = row.get_f64(4);
3890
3891 let mut pos: usize = 2;
3893
3894 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3895 pos += 4;
3896 let in_i32 = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3897 pos += len as usize;
3898
3899 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3900 pos += 4;
3901 let in_str = std::str::from_utf8(&data[pos..pos + len as usize]).unwrap();
3902 pos += len as usize;
3903
3904 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3905 pos += 4;
3906 let in_null: Option<&str> = if len < 0 { None } else { unreachable!() };
3907
3908 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3909 pos += 4;
3910 let in_bool = data[pos] != 0;
3911 pos += len as usize;
3912
3913 let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3914 pos += 4;
3915 let in_f64 = f64::from_be_bytes([
3916 data[pos],
3917 data[pos + 1],
3918 data[pos + 2],
3919 data[pos + 3],
3920 data[pos + 4],
3921 data[pos + 5],
3922 data[pos + 6],
3923 data[pos + 7],
3924 ]);
3925 pos += len as usize;
3926
3927 assert_eq!(dr_i32, Some(in_i32));
3929 assert_eq!(dr_str, Some(in_str));
3930 assert_eq!(dr_null, in_null);
3931 assert_eq!(dr_bool, Some(in_bool));
3932 assert!((dr_f64.unwrap() - in_f64).abs() < 1e-15);
3933 assert_eq!(pos, data.len());
3934 }
3935
3936 #[test]
3939 fn config_host_is_uds_absolute_path() {
3940 let cfg = Config {
3941 host: "/tmp".into(),
3942 port: 5432,
3943 user: "user".into(),
3944 password: "".into(),
3945 database: "db".into(),
3946 ssl: SslMode::Disable,
3947 statement_timeout_secs: 30,
3948 };
3949 assert!(cfg.host_is_uds());
3950 assert_eq!(cfg.uds_path(), "/tmp/.s.PGSQL.5432");
3951 }
3952
3953 #[test]
3954 fn config_host_is_uds_var_run() {
3955 let cfg = Config {
3956 host: "/var/run/postgresql".into(),
3957 port: 5433,
3958 user: "user".into(),
3959 password: "".into(),
3960 database: "db".into(),
3961 ssl: SslMode::Disable,
3962 statement_timeout_secs: 30,
3963 };
3964 assert!(cfg.host_is_uds());
3965 assert_eq!(cfg.uds_path(), "/var/run/postgresql/.s.PGSQL.5433");
3966 }
3967
3968 #[test]
3969 fn config_host_is_not_uds_for_hostname() {
3970 let cfg = Config {
3971 host: "localhost".into(),
3972 port: 5432,
3973 user: "user".into(),
3974 password: "".into(),
3975 database: "db".into(),
3976 ssl: SslMode::Disable,
3977 statement_timeout_secs: 30,
3978 };
3979 assert!(!cfg.host_is_uds());
3980 }
3981
3982 #[test]
3983 fn config_host_is_not_uds_for_ip() {
3984 let cfg = Config {
3985 host: "127.0.0.1".into(),
3986 port: 5432,
3987 user: "user".into(),
3988 password: "".into(),
3989 database: "db".into(),
3990 ssl: SslMode::Disable,
3991 statement_timeout_secs: 30,
3992 };
3993 assert!(!cfg.host_is_uds());
3994 }
3995
3996 #[test]
3997 fn config_parse_uds_host_query_param() {
3998 let cfg = Config::from_url("postgres://user@localhost/mydb?host=/tmp").unwrap();
3999 assert_eq!(cfg.host, "/tmp");
4000 assert!(cfg.host_is_uds());
4001 assert_eq!(cfg.uds_path(), "/tmp/.s.PGSQL.5432");
4002 assert_eq!(cfg.database, "mydb");
4003 assert_eq!(cfg.user, "user");
4004 }
4005
4006 #[test]
4007 fn config_parse_uds_host_query_param_custom_port() {
4008 let cfg = Config::from_url("postgres://user@localhost:5433/mydb?host=/var/run/postgresql")
4009 .unwrap();
4010 assert_eq!(cfg.host, "/var/run/postgresql");
4011 assert_eq!(cfg.port, 5433);
4012 assert_eq!(cfg.uds_path(), "/var/run/postgresql/.s.PGSQL.5433");
4013 }
4014
4015 #[test]
4016 fn config_parse_uds_host_with_other_params() {
4017 let cfg = Config::from_url(
4018 "postgres://user@localhost/db?host=/tmp&sslmode=disable&statement_timeout=60",
4019 )
4020 .unwrap();
4021 assert_eq!(cfg.host, "/tmp");
4022 assert!(cfg.host_is_uds());
4023 assert_eq!(cfg.ssl, SslMode::Disable);
4024 assert_eq!(cfg.statement_timeout_secs, 60);
4025 }
4026
4027 #[test]
4028 fn config_parse_uds_host_percent_encoded() {
4029 let cfg = Config::from_url("postgres://user@localhost/db?host=%2Ftmp").unwrap();
4031 assert_eq!(cfg.host, "/tmp");
4032 assert!(cfg.host_is_uds());
4033 }
4034
4035 #[test]
4036 fn config_parse_tcp_host_not_overridden_without_param() {
4037 let cfg = Config::from_url("postgres://user@myserver/db").unwrap();
4039 assert_eq!(cfg.host, "myserver");
4040 assert!(!cfg.host_is_uds());
4041 }
4042
4043 #[test]
4044 fn config_parse_uds_host_overrides_url_hostname() {
4045 let cfg = Config::from_url("postgres://user@db.example.com/mydb?host=/var/run/postgresql")
4047 .unwrap();
4048 assert_eq!(cfg.host, "/var/run/postgresql");
4049 assert!(cfg.host_is_uds());
4050 }
4051
4052 #[test]
4053 fn config_parse_uds_empty_url_host() {
4054 let cfg = Config::from_url("postgres://user@/mydb?host=/tmp").unwrap();
4056 assert_eq!(cfg.host, "/tmp");
4057 assert!(cfg.host_is_uds());
4058 assert_eq!(cfg.database, "mydb");
4059 }
4060
4061 #[test]
4066 fn pg_data_row_all_null_columns() {
4067 let data = make_data_row(&[None, None, None, None, None]);
4068 let row = PgDataRow::new(&data).unwrap();
4069 assert_eq!(row.column_count(), 5);
4070 for i in 0..5 {
4071 assert!(row.is_null(i), "column {i} should be null");
4072 assert_eq!(row.get_raw(i), None);
4073 assert_eq!(row.get_i32(i), None);
4074 assert_eq!(row.get_i64(i), None);
4075 assert_eq!(row.get_str(i), None);
4076 assert_eq!(row.get_bool(i), None);
4077 assert_eq!(row.get_f64(i), None);
4078 }
4079 }
4080
4081 #[test]
4082 fn pg_data_row_very_long_text() {
4083 let long_text = "x".repeat(2048);
4084 let data = make_data_row(&[Some(long_text.as_bytes())]);
4085 let row = PgDataRow::new(&data).unwrap();
4086 assert_eq!(row.get_str(0), Some(long_text.as_str()));
4087 }
4088
4089 #[test]
4090 fn pg_data_row_empty_text() {
4091 let data = make_data_row(&[Some(b"")]);
4092 let row = PgDataRow::new(&data).unwrap();
4093 assert!(!row.is_null(0));
4094 assert_eq!(row.get_str(0), Some(""));
4095 assert_eq!(row.get_bytes(0), Some(&[][..]));
4096 }
4097
4098 #[test]
4099 fn pg_data_row_20_columns_exceeds_inline() {
4100 let col_data: Vec<[u8; 4]> = (0..20).map(|i: i32| i.to_be_bytes()).collect();
4101 let cols: Vec<Option<&[u8]>> = col_data.iter().map(|b| Some(b.as_slice())).collect();
4102 let data = make_data_row(&cols);
4103 let row = PgDataRow::new(&data).unwrap();
4104 assert_eq!(row.column_count(), 20);
4105 for i in 0..20 {
4106 assert_eq!(row.get_i32(i), Some(i as i32));
4107 }
4108 }
4109
4110 #[test]
4111 fn pg_data_row_is_null_each_position() {
4112 let data = make_data_row(&[Some(&1i32.to_be_bytes()), None, Some(&3i32.to_be_bytes())]);
4114 let row = PgDataRow::new(&data).unwrap();
4115 assert!(!row.is_null(0));
4116 assert!(row.is_null(1));
4117 assert!(!row.is_null(2));
4118 }
4119
4120 #[test]
4121 fn pg_data_row_negative_column_count() {
4122 let data = (-1i16).to_be_bytes();
4123 assert!(PgDataRow::new(&data).is_err());
4124 }
4125
4126 #[test]
4127 fn pg_data_row_get_str_invalid_utf8() {
4128 let invalid_utf8 = &[0xFF, 0xFE, 0x80];
4129 let data = make_data_row(&[Some(invalid_utf8)]);
4130 let row = PgDataRow::new(&data).unwrap();
4131 assert_eq!(row.get_str(0), None);
4133 assert_eq!(row.get_bytes(0), Some(&[0xFF, 0xFE, 0x80][..]));
4134 }
4135
4136 #[test]
4137 fn pg_data_row_get_i32_wrong_length() {
4138 let data = make_data_row(&[Some(&7i16.to_be_bytes())]);
4140 let row = PgDataRow::new(&data).unwrap();
4141 assert_eq!(row.get_i32(0), None); assert_eq!(row.get_i16(0), Some(7)); }
4144
4145 #[test]
4146 fn pg_data_row_get_i64_wrong_length() {
4147 let data = make_data_row(&[Some(&42i32.to_be_bytes())]);
4149 let row = PgDataRow::new(&data).unwrap();
4150 assert_eq!(row.get_i64(0), None);
4151 }
4152
4153 #[test]
4154 fn pg_data_row_get_f64_wrong_length() {
4155 let data = make_data_row(&[Some(&2.5f32.to_be_bytes())]);
4156 let row = PgDataRow::new(&data).unwrap();
4157 assert_eq!(row.get_f64(0), None); }
4159
4160 #[test]
4161 fn pg_data_row_get_f32_wrong_length() {
4162 let data = make_data_row(&[Some(&3.14f64.to_be_bytes())]);
4163 let row = PgDataRow::new(&data).unwrap();
4164 assert_eq!(row.get_f32(0), None); }
4166
4167 #[test]
4168 fn pg_data_row_get_bool_wrong_length() {
4169 let data = make_data_row(&[Some(&42i32.to_be_bytes())]);
4171 let row = PgDataRow::new(&data).unwrap();
4172 assert_eq!(row.get_bool(0), None);
4173 }
4174
4175 #[test]
4176 fn pg_data_row_unicode_text() {
4177 let texts = [
4178 "\u{1F600}\u{1F4A9}\u{1F680}", "\u{4e16}\u{754c}", "\u{0645}\u{0631}\u{062D}", "\u{1F468}\u{200D}\u{1F469}", ];
4183 for text in &texts {
4184 let data = make_data_row(&[Some(text.as_bytes())]);
4185 let row = PgDataRow::new(&data).unwrap();
4186 assert_eq!(row.get_str(0), Some(*text));
4187 }
4188 }
4189
4190 #[test]
4191 fn pg_data_row_i32_boundary_values() {
4192 for &val in &[i32::MIN, -1, 0, 1, i32::MAX] {
4193 let data = make_data_row(&[Some(&val.to_be_bytes())]);
4194 let row = PgDataRow::new(&data).unwrap();
4195 assert_eq!(row.get_i32(0), Some(val), "failed for {val}");
4196 }
4197 }
4198
4199 #[test]
4200 fn pg_data_row_i64_boundary_values() {
4201 for &val in &[i64::MIN, -1, 0, 1, i64::MAX] {
4202 let data = make_data_row(&[Some(&val.to_be_bytes())]);
4203 let row = PgDataRow::new(&data).unwrap();
4204 assert_eq!(row.get_i64(0), Some(val), "failed for {val}");
4205 }
4206 }
4207
4208 #[test]
4209 fn pg_data_row_f64_special_values() {
4210 let data = make_data_row(&[Some(&f64::INFINITY.to_be_bytes())]);
4211 let row = PgDataRow::new(&data).unwrap();
4212 assert_eq!(row.get_f64(0), Some(f64::INFINITY));
4213
4214 let data = make_data_row(&[Some(&f64::NEG_INFINITY.to_be_bytes())]);
4215 let row = PgDataRow::new(&data).unwrap();
4216 assert_eq!(row.get_f64(0), Some(f64::NEG_INFINITY));
4217
4218 let data = make_data_row(&[Some(&f64::NAN.to_be_bytes())]);
4219 let row = PgDataRow::new(&data).unwrap();
4220 assert!(row.get_f64(0).unwrap().is_nan());
4221 }
4222
4223 #[test]
4224 fn pg_data_row_f32_special_values() {
4225 let data = make_data_row(&[Some(&f32::INFINITY.to_be_bytes())]);
4226 let row = PgDataRow::new(&data).unwrap();
4227 assert_eq!(row.get_f32(0), Some(f32::INFINITY));
4228
4229 let data = make_data_row(&[Some(&f32::NAN.to_be_bytes())]);
4230 let row = PgDataRow::new(&data).unwrap();
4231 assert!(row.get_f32(0).unwrap().is_nan());
4232 }
4233
4234 #[test]
4235 fn pg_data_row_i16_boundary_values() {
4236 for &val in &[i16::MIN, -1, 0, 1, i16::MAX] {
4237 let data = make_data_row(&[Some(&val.to_be_bytes())]);
4238 let row = PgDataRow::new(&data).unwrap();
4239 assert_eq!(row.get_i16(0), Some(val));
4240 }
4241 }
4242
4243 #[test]
4248 fn data_row_flat_all_null() {
4249 let mut arena = Arena::new();
4250 let mut out = Vec::new();
4251 let mut data = Vec::new();
4252 data.extend_from_slice(&4i16.to_be_bytes());
4253 for _ in 0..4 {
4254 data.extend_from_slice(&(-1i32).to_be_bytes());
4255 }
4256 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
4257 assert_eq!(out.len(), 4);
4258 for (_, len) in &out {
4259 assert_eq!(*len, -1);
4260 }
4261 }
4262
4263 #[test]
4264 fn data_row_flat_long_text() {
4265 let mut arena = Arena::new();
4266 let mut out = Vec::new();
4267 let long = vec![b'A'; 1024];
4268 let mut data = Vec::new();
4269 data.extend_from_slice(&1i16.to_be_bytes());
4270 data.extend_from_slice(&(long.len() as i32).to_be_bytes());
4271 data.extend_from_slice(&long);
4272 parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
4273 assert_eq!(out[0].1, 1024);
4274 let stored = arena.get(out[0].0, 1024);
4275 assert!(stored.iter().all(|&b| b == b'A'));
4276 }
4277
4278 #[test]
4279 fn data_row_flat_empty_text() {
4280 let mut arena = Arena::new();
4281 let mut out = Vec::new();
4282 let mut data = Vec::new();
4283 data.extend_from_slice(&1i16.to_be_bytes());
4284 data.extend_from_slice(&0i32.to_be_bytes()); parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
4286 assert_eq!(out[0].1, 0);
4287 }
4288
4289 #[test]
4294 fn query_result_from_parts() {
4295 let result = QueryResult::from_parts(vec![(0, 4), (0, -1)], 2, Arc::from(Vec::new()), 5);
4296 assert_eq!(result.len(), 1);
4297 assert_eq!(result.num_cols, 2);
4298 assert_eq!(result.affected_rows, 5);
4299 }
4300
4301 #[test]
4302 fn query_result_affected_rows() {
4303 let result = QueryResult {
4304 all_col_offsets: vec![],
4305 num_cols: 0,
4306 columns: Arc::from(Vec::new()),
4307 affected_rows: 42,
4308 };
4309 assert_eq!(result.affected_rows, 42);
4310 assert!(result.is_empty());
4311 }
4312
4313 #[test]
4318 fn driver_error_server_with_hint() {
4319 let e = DriverError::Server {
4320 code: "42601".into(),
4321 message: "syntax error".into(),
4322 detail: None,
4323 hint: Some("check your SQL".into()),
4324 position: Some(10),
4325 };
4326 let s = e.to_string();
4327 assert!(s.contains("HINT: check your SQL"));
4328 assert!(s.contains("(at position 10)"));
4329 }
4330
4331 #[test]
4332 fn driver_error_server_with_all_fields() {
4333 let e = DriverError::Server {
4334 code: "23505".into(),
4335 message: "unique violation".into(),
4336 detail: Some("Key (id)=(1) already exists.".into()),
4337 hint: Some("change the id".into()),
4338 position: Some(1),
4339 };
4340 let s = e.to_string();
4341 assert!(s.contains("23505"));
4342 assert!(s.contains("unique violation"));
4343 assert!(s.contains("Key (id)=(1) already exists."));
4344 assert!(s.contains("change the id"));
4345 assert!(s.contains("(at position 1)"));
4346 }
4347
4348 #[test]
4353 fn config_statement_timeout_default() {
4354 let cfg = Config::from_url("postgres://user:pass@localhost/db").unwrap();
4355 assert_eq!(cfg.statement_timeout_secs, 30);
4356 }
4357
4358 #[test]
4359 fn config_statement_timeout_custom() {
4360 let cfg =
4361 Config::from_url("postgres://user:pass@localhost/db?statement_timeout=120").unwrap();
4362 assert_eq!(cfg.statement_timeout_secs, 120);
4363 }
4364
4365 #[test]
4366 fn config_statement_timeout_zero() {
4367 let cfg =
4368 Config::from_url("postgres://user:pass@localhost/db?statement_timeout=0").unwrap();
4369 assert_eq!(cfg.statement_timeout_secs, 0);
4370 }
4371
4372 #[test]
4373 fn config_statement_timeout_invalid_falls_back() {
4374 let cfg =
4375 Config::from_url("postgres://user:pass@localhost/db?statement_timeout=notanumber")
4376 .unwrap();
4377 assert_eq!(cfg.statement_timeout_secs, 30); }
4379
4380 #[test]
4381 fn config_uds_path_format() {
4382 let cfg = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
4383 assert_eq!(cfg.uds_path(), "/tmp/.s.PGSQL.5432");
4384 }
4385
4386 #[test]
4387 fn config_uds_path_custom_port() {
4388 let cfg = Config::from_url("postgres://user@localhost:5433/db?host=/tmp").unwrap();
4389 assert_eq!(cfg.uds_path(), "/tmp/.s.PGSQL.5433");
4390 }
4391
4392 #[test]
4397 fn url_decode_empty_string() {
4398 assert_eq!(url_decode("").unwrap(), "");
4399 }
4400
4401 #[test]
4402 fn url_decode_no_encoding() {
4403 assert_eq!(url_decode("hello").unwrap(), "hello");
4404 }
4405
4406 #[test]
4407 fn url_decode_all_ascii_hex() {
4408 assert_eq!(url_decode("%2F").unwrap(), "/");
4410 assert_eq!(url_decode("%2f").unwrap(), "/");
4411 }
4412
4413 #[test]
4418 fn hash_sql_empty() {
4419 let _h = hash_sql(""); }
4421
4422 #[test]
4423 fn hash_sql_whitespace_only() {
4424 let h = hash_sql(" ");
4425 assert_ne!(h, hash_sql(""));
4426 }
4427
4428 #[test]
4429 fn hash_sql_very_long() {
4430 let long_sql = "SELECT ".to_string() + &"x".repeat(10_000);
4431 let h = hash_sql(&long_sql);
4432 assert_eq!(h, hash_sql(&long_sql));
4433 }
4434
4435 #[test]
4436 fn hash_sql_unicode() {
4437 let h = hash_sql("SELECT '\u{1F600}'");
4438 assert_ne!(h, hash_sql("SELECT 'x'"));
4439 }
4440}