1use std::io;
21
22use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
23
/// Version code of protocol 3.0: major version 3 in the upper 16 bits and
/// minor version 0 in the lower 16 bits.
pub const PG_PROTOCOL_V3: u32 = 0x0003_0000;

/// Code sent in the version slot of the first packet to request SSL
/// negotiation: 1234 in the upper 16 bits, 5679 in the lower 16 bits.
pub const PG_SSL_REQUEST: u32 = (1234 << 16) | 5679;
/// Code sent in the version slot of the first packet to request GSSAPI
/// encryption: 1234 in the upper 16 bits, 5680 in the lower 16 bits.
pub const PG_GSSENC_REQUEST: u32 = (1234 << 16) | 5680;
/// Code sent in the version slot of the first packet to cancel a running
/// query: 1234 in the upper 16 bits, 5678 in the lower 16 bits.
pub const PG_CANCEL_REQUEST: u32 = (1234 << 16) | 5678;
33
/// Errors produced while reading or writing wire frames in this module.
#[derive(Debug)]
pub enum PgWireError {
    /// An I/O failure other than an unexpected end of stream.
    Io(io::Error),
    /// The bytes on the wire broke a framing or field rule; the string says which.
    Protocol(String),
    /// The stream ended before a read could be completed.
    Eof,
}

impl From<io::Error> for PgWireError {
    /// Turns `UnexpectedEof` into [`PgWireError::Eof`]; every other error is
    /// wrapped unchanged in [`PgWireError::Io`].
    fn from(err: io::Error) -> Self {
        match err.kind() {
            io::ErrorKind::UnexpectedEof => PgWireError::Eof,
            _ => PgWireError::Io(err),
        }
    }
}

impl std::fmt::Display for PgWireError {
    /// Renders a short, prefixed description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PgWireError::Eof => f.write_str("pg wire eof"),
            PgWireError::Protocol(detail) => write!(f, "pg wire protocol: {}", detail),
            PgWireError::Io(cause) => write!(f, "pg wire io: {}", cause),
        }
    }
}

impl std::error::Error for PgWireError {}
65
66#[derive(Debug, Clone)]
68pub enum FrontendMessage {
69 Startup(StartupParams),
71 SslRequest,
73 GssEncRequest,
75 Query(String),
77 Parse(ParseMessage),
79 Bind(BindMessage),
81 Describe(DescribeMessage),
83 Execute(ExecuteMessage),
85 Close(CloseMessage),
87 PasswordMessage(Vec<u8>),
89 Terminate,
91 Flush,
93 Sync,
95 Unknown { tag: u8, payload: Vec<u8> },
98}
99
/// Decoded payload of a `P` (Parse) frame.
#[derive(Debug, Clone)]
pub struct ParseMessage {
    /// First NUL-terminated string of the payload: the statement name.
    pub statement: String,
    /// Second NUL-terminated string of the payload: the query text.
    pub query: String,
    /// Parameter type OIDs, in the order they appear in the frame.
    pub param_type_oids: Vec<u32>,
}
106
/// Decoded payload of a `B` (Bind) frame.
#[derive(Debug, Clone)]
pub struct BindMessage {
    /// First NUL-terminated string of the payload: the portal name.
    pub portal: String,
    /// Second NUL-terminated string of the payload: the statement name.
    pub statement: String,
    /// Format codes supplied for the parameters, as sent.
    pub param_format_codes: Vec<i16>,
    /// Parameter values; `None` is a parameter sent with length `-1`.
    pub params: Vec<Option<Vec<u8>>>,
    /// Format codes requested for the result columns, as sent.
    pub result_format_codes: Vec<i16>,
}
115
116#[derive(Debug, Clone)]
117pub struct DescribeMessage {
118 pub target: DescribeTarget,
119 pub name: String,
120}
121
/// Kind of object named by a Describe or Close frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DescribeTarget {
    /// Target byte `S`: a prepared statement.
    Statement,
    /// Target byte `P`: a portal.
    Portal,
}
127
/// Decoded payload of an `E` (Execute) frame.
#[derive(Debug, Clone)]
pub struct ExecuteMessage {
    /// Name of the portal to run; may be empty.
    pub portal: String,
    /// Row limit exactly as read from the frame (an unsigned big-endian
    /// 32-bit value). Per the PostgreSQL protocol, zero means "no limit".
    pub max_rows: u32,
}
133
134#[derive(Debug, Clone)]
135pub struct CloseMessage {
136 pub target: DescribeTarget,
137 pub name: String,
138}
139
/// Key/value parameters carried by a startup packet, in wire order.
#[derive(Debug, Clone, Default)]
pub struct StartupParams {
    /// `(name, value)` pairs exactly as they were read; duplicates are kept.
    pub params: Vec<(String, String)>,
}

impl StartupParams {
    /// Returns the value of the first parameter named `key`, or `None` when
    /// no parameter has that name.
    pub fn get(&self, key: &str) -> Option<&str> {
        for (name, value) in &self.params {
            if name == key {
                return Some(value.as_str());
            }
        }
        None
    }
}
154
155#[derive(Debug, Clone)]
157pub enum BackendMessage {
158 AuthenticationOk,
160 ParameterStatus { name: String, value: String },
162 BackendKeyData { pid: u32, key: u32 },
164 ReadyForQuery(TransactionStatus),
166 RowDescription(Vec<ColumnDescriptor>),
168 DataRow(Vec<Option<Vec<u8>>>),
170 CommandComplete(String),
172 ParseComplete,
174 BindComplete,
176 CloseComplete,
178 ParameterDescription(Vec<u32>),
180 NoData,
182 ErrorResponse {
184 severity: String,
185 code: String,
186 message: String,
187 },
188 NoticeResponse { message: String },
190 EmptyQueryResponse,
192}
193
/// Transaction state reported in a ReadyForQuery message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Encoded as `I`.
    Idle,
    /// Encoded as `T`.
    InTransaction,
    /// Encoded as `E`.
    Failed,
}

impl TransactionStatus {
    /// Returns the single status byte written on the wire for this state.
    pub fn as_byte(self) -> u8 {
        match self {
            Self::Idle => b'I',
            Self::InTransaction => b'T',
            Self::Failed => b'E',
        }
    }
}
213
/// One column entry of a RowDescription message. Fields are serialised in
/// declaration order, integers big-endian.
#[derive(Debug, Clone)]
pub struct ColumnDescriptor {
    /// Column name, written NUL-terminated.
    pub name: String,
    /// OID of the source table (per the PostgreSQL protocol, zero when the
    /// column does not come from a table).
    pub table_oid: u32,
    /// Attribute number of the column within its table (zero when not applicable).
    pub column_attr: i16,
    /// OID of the column's data type.
    pub type_oid: u32,
    /// Size of the data type; the protocol uses negative values for
    /// variable-width types.
    pub type_size: i16,
    /// Type modifier; its meaning is type-specific.
    pub type_mod: i32,
    /// Format code for the column (per the protocol, 0 = text, 1 = binary).
    pub format: i16,
}
230
231pub async fn read_startup<R: AsyncRead + Unpin>(
239 stream: &mut R,
240) -> Result<FrontendMessage, PgWireError> {
241 let mut len_buf = [0u8; 4];
242 stream.read_exact(&mut len_buf).await?;
243 let len = u32::from_be_bytes(len_buf);
244 if !(8..=65536).contains(&len) {
245 return Err(PgWireError::Protocol(format!(
246 "startup length {len} out of range"
247 )));
248 }
249 let body_len = (len as usize) - 4;
250 let mut body = vec![0u8; body_len];
251 stream.read_exact(&mut body).await?;
252 if body_len < 4 {
253 return Err(PgWireError::Protocol("startup payload too short".into()));
254 }
255 let version = u32::from_be_bytes([body[0], body[1], body[2], body[3]]);
256
257 match version {
258 PG_SSL_REQUEST => Ok(FrontendMessage::SslRequest),
259 PG_GSSENC_REQUEST => Ok(FrontendMessage::GssEncRequest),
260 PG_PROTOCOL_V3 => {
261 let mut params: Vec<(String, String)> = Vec::new();
264 let mut pos = 4usize;
265 while pos < body_len {
266 if body[pos] == 0 {
267 break;
268 }
269 let key = read_cstring(&body, &mut pos)?;
270 if pos >= body_len {
271 return Err(PgWireError::Protocol(
272 "startup parameter missing value".into(),
273 ));
274 }
275 let value = read_cstring(&body, &mut pos)?;
276 params.push((key, value));
277 }
278 Ok(FrontendMessage::Startup(StartupParams { params }))
279 }
280 PG_CANCEL_REQUEST => Ok(FrontendMessage::Unknown {
283 tag: b'K',
284 payload: body,
285 }),
286 _ => Err(PgWireError::Protocol(format!(
287 "unsupported protocol version {version}"
288 ))),
289 }
290}
291
292pub async fn read_frame<R: AsyncRead + Unpin>(
294 stream: &mut R,
295) -> Result<FrontendMessage, PgWireError> {
296 let mut tag_buf = [0u8; 1];
297 match stream.read_exact(&mut tag_buf).await {
298 Ok(_) => {}
299 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Err(PgWireError::Eof),
300 Err(e) => return Err(PgWireError::Io(e)),
301 }
302 let tag = tag_buf[0];
303
304 let mut len_buf = [0u8; 4];
305 stream.read_exact(&mut len_buf).await?;
306 let len = u32::from_be_bytes(len_buf);
307 if !(4..=1_048_576).contains(&len) {
308 return Err(PgWireError::Protocol(format!(
309 "frame length {len} out of bounds"
310 )));
311 }
312 let payload_len = (len as usize) - 4;
313 let mut payload = vec![0u8; payload_len];
314 stream.read_exact(&mut payload).await?;
315
316 Ok(match tag {
317 b'Q' => {
318 let mut pos = 0;
320 let query = read_cstring(&payload, &mut pos)?;
321 FrontendMessage::Query(query)
322 }
323 b'P' => FrontendMessage::Parse(parse_parse_message(&payload)?),
324 b'B' => FrontendMessage::Bind(parse_bind_message(&payload)?),
325 b'D' => FrontendMessage::Describe(parse_describe_message(&payload)?),
326 b'E' => FrontendMessage::Execute(parse_execute_message(&payload)?),
327 b'C' => FrontendMessage::Close(parse_close_message(&payload)?),
328 b'p' => FrontendMessage::PasswordMessage(payload),
329 b'X' => FrontendMessage::Terminate,
330 b'H' => FrontendMessage::Flush,
331 b'S' => FrontendMessage::Sync,
332 other => FrontendMessage::Unknown {
333 tag: other,
334 payload,
335 },
336 })
337}
338
339pub async fn write_raw_byte<W: AsyncWrite + Unpin>(
346 stream: &mut W,
347 byte: u8,
348) -> Result<(), PgWireError> {
349 stream.write_all(&[byte]).await?;
350 Ok(())
351}
352
353pub async fn write_frame<W: AsyncWrite + Unpin>(
355 stream: &mut W,
356 msg: &BackendMessage,
357) -> Result<(), PgWireError> {
358 let (tag, payload) = encode_backend(msg);
359 let length = (payload.len() + 4) as u32;
361 stream.write_all(&[tag]).await?;
362 stream.write_all(&length.to_be_bytes()).await?;
363 stream.write_all(&payload).await?;
364 Ok(())
365}
366
/// Returns a copy of `input` in which every NUL byte is replaced by the
/// three-byte UTF-8 encoding of U+FFFD (`EF BF BD`).
///
/// The result therefore never contains a zero byte, so it can be followed by
/// a terminator without the string being cut short. Input without NUL bytes
/// is returned unchanged.
fn sanitize_cstring_bytes(input: &[u8]) -> Vec<u8> {
    const REPLACEMENT: [u8; 3] = [0xEF, 0xBF, 0xBD];

    // Splitting on NUL always yields at least one (possibly empty) piece;
    // every piece after the first was preceded by a NUL in the input.
    let mut pieces = input.split(|&byte| byte == 0);
    let mut cleaned = Vec::with_capacity(input.len());
    if let Some(head) = pieces.next() {
        cleaned.extend_from_slice(head);
    }
    for piece in pieces {
        cleaned.extend_from_slice(&REPLACEMENT);
        cleaned.extend_from_slice(piece);
    }
    cleaned
}
397
398#[inline]
399fn push_cstring(buf: &mut Vec<u8>, value: &str) {
400 buf.extend_from_slice(&sanitize_cstring_bytes(value.as_bytes()));
401 buf.push(0);
402}
403
404fn encode_backend(msg: &BackendMessage) -> (u8, Vec<u8>) {
405 match msg {
406 BackendMessage::AuthenticationOk => {
407 (b'R', vec![0, 0, 0, 0])
409 }
410 BackendMessage::ParameterStatus { name, value } => {
411 let mut buf = Vec::with_capacity(name.len() + value.len() + 2);
412 push_cstring(&mut buf, name);
414 push_cstring(&mut buf, value);
415 (b'S', buf)
416 }
417 BackendMessage::BackendKeyData { pid, key } => {
418 let mut buf = Vec::with_capacity(8);
419 buf.extend_from_slice(&pid.to_be_bytes());
420 buf.extend_from_slice(&key.to_be_bytes());
421 (b'K', buf)
422 }
423 BackendMessage::ReadyForQuery(status) => (b'Z', vec![status.as_byte()]),
424 BackendMessage::RowDescription(cols) => {
425 let mut buf = Vec::new();
426 buf.extend_from_slice(&(cols.len() as i16).to_be_bytes());
427 for col in cols {
428 push_cstring(&mut buf, &col.name);
430 buf.extend_from_slice(&col.table_oid.to_be_bytes());
431 buf.extend_from_slice(&col.column_attr.to_be_bytes());
432 buf.extend_from_slice(&col.type_oid.to_be_bytes());
433 buf.extend_from_slice(&col.type_size.to_be_bytes());
434 buf.extend_from_slice(&col.type_mod.to_be_bytes());
435 buf.extend_from_slice(&col.format.to_be_bytes());
436 }
437 (b'T', buf)
438 }
439 BackendMessage::DataRow(fields) => {
440 let mut buf = Vec::new();
441 buf.extend_from_slice(&(fields.len() as i16).to_be_bytes());
442 for field in fields {
443 match field {
444 None => {
445 buf.extend_from_slice(&(-1i32).to_be_bytes());
447 }
448 Some(bytes) => {
449 buf.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
453 buf.extend_from_slice(bytes);
454 }
455 }
456 }
457 (b'D', buf)
458 }
459 BackendMessage::CommandComplete(tag) => {
460 let mut buf = Vec::with_capacity(tag.len() + 1);
461 push_cstring(&mut buf, tag);
464 (b'C', buf)
465 }
466 BackendMessage::ParseComplete => (b'1', Vec::new()),
467 BackendMessage::BindComplete => (b'2', Vec::new()),
468 BackendMessage::CloseComplete => (b'3', Vec::new()),
469 BackendMessage::ParameterDescription(oids) => {
470 let mut buf = Vec::with_capacity(2 + oids.len() * 4);
471 buf.extend_from_slice(&(oids.len() as i16).to_be_bytes());
472 for oid in oids {
473 buf.extend_from_slice(&oid.to_be_bytes());
474 }
475 (b't', buf)
476 }
477 BackendMessage::NoData => (b'n', Vec::new()),
478 BackendMessage::ErrorResponse {
479 severity,
480 code,
481 message,
482 } => {
483 let mut buf = Vec::new();
484 buf.push(b'S');
486 push_cstring(&mut buf, severity);
487 buf.push(b'V');
489 push_cstring(&mut buf, severity);
490 buf.push(b'C');
492 push_cstring(&mut buf, code);
493 buf.push(b'M');
495 push_cstring(&mut buf, message);
496 buf.push(0);
498 (b'E', buf)
499 }
500 BackendMessage::NoticeResponse { message } => {
501 let mut buf = Vec::new();
502 buf.push(b'S');
503 buf.extend_from_slice(b"NOTICE");
504 buf.push(0);
505 buf.push(b'M');
506 push_cstring(&mut buf, message);
508 buf.push(0);
509 (b'N', buf)
510 }
511 BackendMessage::EmptyQueryResponse => (b'I', Vec::new()),
512 }
513}
514
515fn read_cstring(buf: &[u8], pos: &mut usize) -> Result<String, PgWireError> {
522 let start = *pos;
523 while *pos < buf.len() && buf[*pos] != 0 {
524 *pos += 1;
525 }
526 if *pos >= buf.len() {
527 return Err(PgWireError::Protocol("cstring missing terminator".into()));
528 }
529 let s = std::str::from_utf8(&buf[start..*pos])
530 .map_err(|e| PgWireError::Protocol(format!("invalid utf8: {e}")))?
531 .to_string();
532 *pos += 1; Ok(s)
534}
535
536fn parse_parse_message(payload: &[u8]) -> Result<ParseMessage, PgWireError> {
537 let mut pos = 0;
538 let statement = read_cstring(payload, &mut pos)?;
539 let query = read_cstring(payload, &mut pos)?;
540 let nparams = read_i16(payload, &mut pos, "Parse parameter count")?;
541 if nparams < 0 {
542 return Err(PgWireError::Protocol(
543 "negative Parse parameter count".into(),
544 ));
545 }
546 let mut param_type_oids = Vec::with_capacity(nparams as usize);
547 for _ in 0..nparams {
548 param_type_oids.push(read_u32(payload, &mut pos, "Parse parameter type OID")?);
549 }
550 ensure_consumed(payload, pos, "Parse")?;
551 Ok(ParseMessage {
552 statement,
553 query,
554 param_type_oids,
555 })
556}
557
558fn parse_bind_message(payload: &[u8]) -> Result<BindMessage, PgWireError> {
559 let mut pos = 0;
560 let portal = read_cstring(payload, &mut pos)?;
561 let statement = read_cstring(payload, &mut pos)?;
562
563 let nformats = read_i16(payload, &mut pos, "Bind format count")?;
564 if nformats < 0 {
565 return Err(PgWireError::Protocol("negative Bind format count".into()));
566 }
567 let mut param_format_codes = Vec::with_capacity(nformats as usize);
568 for _ in 0..nformats {
569 param_format_codes.push(read_i16(payload, &mut pos, "Bind format code")?);
570 }
571
572 let nparams = read_i16(payload, &mut pos, "Bind parameter count")?;
573 if nparams < 0 {
574 return Err(PgWireError::Protocol(
575 "negative Bind parameter count".into(),
576 ));
577 }
578 let mut params = Vec::with_capacity(nparams as usize);
579 for _ in 0..nparams {
580 let len = read_i32(payload, &mut pos, "Bind parameter length")?;
581 if len == -1 {
582 params.push(None);
583 } else if len < -1 {
584 return Err(PgWireError::Protocol(
585 "invalid Bind parameter length".into(),
586 ));
587 } else {
588 params.push(Some(
589 read_bytes(payload, &mut pos, len as usize, "Bind parameter")?.to_vec(),
590 ));
591 }
592 }
593
594 let nresult_formats = read_i16(payload, &mut pos, "Bind result format count")?;
595 if nresult_formats < 0 {
596 return Err(PgWireError::Protocol(
597 "negative Bind result format count".into(),
598 ));
599 }
600 let mut result_format_codes = Vec::with_capacity(nresult_formats as usize);
601 for _ in 0..nresult_formats {
602 result_format_codes.push(read_i16(payload, &mut pos, "Bind result format code")?);
603 }
604 ensure_consumed(payload, pos, "Bind")?;
605
606 Ok(BindMessage {
607 portal,
608 statement,
609 param_format_codes,
610 params,
611 result_format_codes,
612 })
613}
614
615fn parse_describe_message(payload: &[u8]) -> Result<DescribeMessage, PgWireError> {
616 let mut pos = 0;
617 let target = read_describe_target(payload, &mut pos, "Describe")?;
618 let name = read_cstring(payload, &mut pos)?;
619 ensure_consumed(payload, pos, "Describe")?;
620 Ok(DescribeMessage { target, name })
621}
622
623fn parse_execute_message(payload: &[u8]) -> Result<ExecuteMessage, PgWireError> {
624 let mut pos = 0;
625 let portal = read_cstring(payload, &mut pos)?;
626 let max_rows = read_u32(payload, &mut pos, "Execute max rows")?;
627 ensure_consumed(payload, pos, "Execute")?;
628 Ok(ExecuteMessage { portal, max_rows })
629}
630
631fn parse_close_message(payload: &[u8]) -> Result<CloseMessage, PgWireError> {
632 let mut pos = 0;
633 let target = read_describe_target(payload, &mut pos, "Close")?;
634 let name = read_cstring(payload, &mut pos)?;
635 ensure_consumed(payload, pos, "Close")?;
636 Ok(CloseMessage { target, name })
637}
638
639fn read_describe_target(
640 payload: &[u8],
641 pos: &mut usize,
642 frame: &'static str,
643) -> Result<DescribeTarget, PgWireError> {
644 let byte = *read_bytes(payload, pos, 1, frame)?
645 .first()
646 .expect("one target byte");
647 match byte {
648 b'S' => Ok(DescribeTarget::Statement),
649 b'P' => Ok(DescribeTarget::Portal),
650 other => Err(PgWireError::Protocol(format!(
651 "{frame} target must be 'S' or 'P', got 0x{other:02x}"
652 ))),
653 }
654}
655
656fn read_i16(payload: &[u8], pos: &mut usize, field: &'static str) -> Result<i16, PgWireError> {
657 let bytes = read_bytes(payload, pos, 2, field)?;
658 Ok(i16::from_be_bytes([bytes[0], bytes[1]]))
659}
660
661fn read_i32(payload: &[u8], pos: &mut usize, field: &'static str) -> Result<i32, PgWireError> {
662 let bytes = read_bytes(payload, pos, 4, field)?;
663 Ok(i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
664}
665
666fn read_u32(payload: &[u8], pos: &mut usize, field: &'static str) -> Result<u32, PgWireError> {
667 let bytes = read_bytes(payload, pos, 4, field)?;
668 Ok(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
669}
670
671fn read_bytes<'a>(
672 payload: &'a [u8],
673 pos: &mut usize,
674 len: usize,
675 field: &'static str,
676) -> Result<&'a [u8], PgWireError> {
677 let end = pos
678 .checked_add(len)
679 .ok_or_else(|| PgWireError::Protocol(format!("{field} length overflow")))?;
680 if end > payload.len() {
681 return Err(PgWireError::Protocol(format!("{field} truncated")));
682 }
683 let bytes = &payload[*pos..end];
684 *pos = end;
685 Ok(bytes)
686}
687
688fn ensure_consumed(payload: &[u8], pos: usize, frame: &'static str) -> Result<(), PgWireError> {
689 if pos == payload.len() {
690 Ok(())
691 } else {
692 Err(PgWireError::Protocol(format!(
693 "{frame} had {} trailing bytes",
694 payload.len() - pos
695 )))
696 }
697}
698
#[cfg(test)]
mod tests {
    use super::*;

    /// UTF-8 encoding of U+FFFD, which the encoder substitutes for NUL bytes.
    const REPLACEMENT: [u8; 3] = [0xEF, 0xBF, 0xBD];

    /// Builds a tagged frontend frame: tag, self-inclusive length, payload.
    fn tagged_frame(tag: u8, payload: Vec<u8>) -> Vec<u8> {
        let mut frame = vec![tag];
        frame.extend_from_slice(&((payload.len() + 4) as u32).to_be_bytes());
        frame.extend_from_slice(&payload);
        frame
    }

    /// Appends `value` plus a NUL terminator, without any sanitising.
    fn push_test_cstring(out: &mut Vec<u8>, value: &str) {
        out.extend_from_slice(value.as_bytes());
        out.push(0);
    }

    /// Walks a byte stream of back-to-back backend frames and returns their tags.
    fn collect_tags(bytes: &[u8]) -> Vec<u8> {
        let mut tags = Vec::new();
        let mut pos = 0;
        while pos < bytes.len() {
            tags.push(bytes[pos]);
            let len = u32::from_be_bytes([
                bytes[pos + 1],
                bytes[pos + 2],
                bytes[pos + 3],
                bytes[pos + 4],
            ]) as usize;
            pos += 1 + len;
        }
        tags
    }

    fn count_nul(buf: &[u8]) -> usize {
        buf.iter().filter(|&&b| b == 0).count()
    }

    fn contains_replacement(buf: &[u8]) -> bool {
        buf.windows(3).any(|w| w == REPLACEMENT)
    }

    /// Serialises one backend message through `write_frame` into a fresh buffer.
    async fn encode(msg: BackendMessage) -> Vec<u8> {
        let mut out: Vec<u8> = Vec::new();
        write_frame(&mut out, &msg).await.unwrap();
        out
    }

    #[tokio::test]
    async fn parse_startup_v3() {
        let mut payload: Vec<u8> = Vec::new();
        payload.extend_from_slice(&PG_PROTOCOL_V3.to_be_bytes());
        payload.extend_from_slice(b"user\0alice\0");
        payload.push(0);
        let len = (4 + payload.len()) as u32;
        let mut frame = Vec::new();
        frame.extend_from_slice(&len.to_be_bytes());
        frame.extend_from_slice(&payload);

        let mut cursor = std::io::Cursor::new(frame);
        let msg = read_startup(&mut cursor).await.unwrap();
        match msg {
            FrontendMessage::Startup(params) => {
                assert_eq!(params.get("user"), Some("alice"));
            }
            other => panic!("expected Startup, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn parse_ssl_request() {
        let mut frame: Vec<u8> = Vec::new();
        frame.extend_from_slice(&8u32.to_be_bytes());
        frame.extend_from_slice(&PG_SSL_REQUEST.to_be_bytes());
        let mut cursor = std::io::Cursor::new(frame);
        assert!(matches!(
            read_startup(&mut cursor).await.unwrap(),
            FrontendMessage::SslRequest
        ));
    }

    #[tokio::test]
    async fn parse_query_frame() {
        let frame = tagged_frame(b'Q', b"SELECT 1\0".to_vec());
        let mut cursor = std::io::Cursor::new(frame);
        match read_frame(&mut cursor).await.unwrap() {
            FrontendMessage::Query(s) => assert_eq!(s, "SELECT 1"),
            other => panic!("expected Query, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn parse_extended_query_frames() {
        let mut parse_payload = Vec::new();
        push_test_cstring(&mut parse_payload, "");
        push_test_cstring(&mut parse_payload, "SELECT $1");
        parse_payload.extend_from_slice(&1i16.to_be_bytes());
        parse_payload.extend_from_slice(&23u32.to_be_bytes());
        let mut parse_cursor = std::io::Cursor::new(tagged_frame(b'P', parse_payload));
        match read_frame(&mut parse_cursor).await.unwrap() {
            FrontendMessage::Parse(msg) => {
                assert_eq!(msg.statement, "");
                assert_eq!(msg.query, "SELECT $1");
                assert_eq!(msg.param_type_oids, vec![23]);
            }
            other => panic!("expected Parse, got {other:?}"),
        }

        let mut bind_payload = Vec::new();
        push_test_cstring(&mut bind_payload, "");
        push_test_cstring(&mut bind_payload, "");
        bind_payload.extend_from_slice(&1i16.to_be_bytes());
        bind_payload.extend_from_slice(&0i16.to_be_bytes());
        bind_payload.extend_from_slice(&1i16.to_be_bytes());
        bind_payload.extend_from_slice(&2i32.to_be_bytes());
        bind_payload.extend_from_slice(b"42");
        bind_payload.extend_from_slice(&0i16.to_be_bytes());
        let mut bind_cursor = std::io::Cursor::new(tagged_frame(b'B', bind_payload));
        match read_frame(&mut bind_cursor).await.unwrap() {
            FrontendMessage::Bind(msg) => {
                assert_eq!(msg.portal, "");
                assert_eq!(msg.statement, "");
                assert_eq!(msg.param_format_codes, vec![0]);
                assert_eq!(msg.params, vec![Some(b"42".to_vec())]);
                assert!(msg.result_format_codes.is_empty());
            }
            other => panic!("expected Bind, got {other:?}"),
        }

        let mut describe_payload = vec![b'P'];
        push_test_cstring(&mut describe_payload, "");
        let mut describe_cursor = std::io::Cursor::new(tagged_frame(b'D', describe_payload));
        match read_frame(&mut describe_cursor).await.unwrap() {
            FrontendMessage::Describe(msg) => {
                assert_eq!(msg.target, DescribeTarget::Portal);
                assert_eq!(msg.name, "");
            }
            other => panic!("expected Describe, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_ready_for_query() {
        let out = encode(BackendMessage::ReadyForQuery(TransactionStatus::Idle)).await;
        assert_eq!(out, vec![b'Z', 0, 0, 0, 5, b'I']);
    }

    #[tokio::test]
    async fn emit_row_description_and_data_row() {
        let out = encode(BackendMessage::RowDescription(vec![ColumnDescriptor {
            name: "id".to_string(),
            table_oid: 0,
            column_attr: 0,
            type_oid: 23,
            type_size: 4,
            type_mod: -1,
            format: 0,
        }]))
        .await;
        assert_eq!(out[0], b'T');
        // Pin the full encoding, not just the tag, so field order and widths
        // are covered.
        let expected_row_description: Vec<u8> = vec![
            b'T', 0, 0, 0, 27, // tag, length = 4 + 23 payload bytes
            0, 1, // column count
            b'i', b'd', 0, // name
            0, 0, 0, 0, // table_oid
            0, 0, // column_attr
            0, 0, 0, 23, // type_oid
            0, 4, // type_size
            0xFF, 0xFF, 0xFF, 0xFF, // type_mod = -1
            0, 0, // format
        ];
        assert_eq!(out, expected_row_description);

        let data = encode(BackendMessage::DataRow(vec![Some(b"42".to_vec()), None])).await;
        assert_eq!(data[0], b'D');
        let expected_data_row: Vec<u8> = vec![
            b'D', 0, 0, 0, 16, // tag, length = 4 + 12 payload bytes
            0, 2, // field count
            0, 0, 0, 2, b'4', b'2', // first value: length 2, "42"
            0xFF, 0xFF, 0xFF, 0xFF, // second value: length -1, no bytes
        ];
        assert_eq!(data, expected_data_row);
    }

    #[tokio::test]
    async fn emit_extended_completion_frames() {
        let mut out = Vec::new();
        write_frame(&mut out, &BackendMessage::ParseComplete)
            .await
            .unwrap();
        write_frame(&mut out, &BackendMessage::BindComplete)
            .await
            .unwrap();
        write_frame(
            &mut out,
            &BackendMessage::ParameterDescription(vec![23, 25]),
        )
        .await
        .unwrap();
        write_frame(&mut out, &BackendMessage::NoData)
            .await
            .unwrap();
        write_frame(&mut out, &BackendMessage::CloseComplete)
            .await
            .unwrap();
        assert_eq!(collect_tags(&out), vec![b'1', b'2', b't', b'n', b'3']);
    }

    #[tokio::test]
    async fn pg3_nul_error_response_message_field_sanitized() {
        let out = encode(BackendMessage::ErrorResponse {
            severity: "ERROR".to_string(),
            code: "42000".to_string(),
            message: "smuggled\0M\x00injection".to_string(),
        })
        .await;
        assert_eq!(out[0], b'E');
        let body = &out[5..];
        assert_eq!(
            count_nul(body),
            5,
            "expected 5 NULs (4 field + 1 list-end), got {} :: body={:?}",
            count_nul(body),
            body
        );
        assert!(
            contains_replacement(body),
            "expected U+FFFD substitution in body"
        );
    }

    #[tokio::test]
    async fn pg3_nul_notice_response_sanitized() {
        let out = encode(BackendMessage::NoticeResponse {
            message: "evil\0field".to_string(),
        })
        .await;
        assert_eq!(out[0], b'N');
        let body = &out[5..];
        assert_eq!(count_nul(body), 3);
        assert!(contains_replacement(body));
    }

    #[tokio::test]
    async fn pg3_nul_command_complete_sanitized() {
        let out = encode(BackendMessage::CommandComplete("SELECT\0;DROP".to_string())).await;
        assert_eq!(out[0], b'C');
        let body = &out[5..];
        assert_eq!(count_nul(body), 1);
        // The embedded NUL must be substituted, not dropped, and the only
        // remaining NUL must be the terminator.
        assert_eq!(body, &b"SELECT\xEF\xBF\xBD;DROP\0"[..]);
    }

    #[tokio::test]
    async fn pg3_nul_row_description_column_name_sanitized() {
        let out = encode(BackendMessage::RowDescription(vec![ColumnDescriptor {
            name: "evil\0col".to_string(),
            table_oid: 0,
            column_attr: 0,
            type_oid: 23,
            type_size: 4,
            type_mod: -1,
            format: 0,
        }]))
        .await;
        assert_eq!(out[0], b'T');
        let body = &out[5..];
        // Skip the 2-byte column count; the name runs up to the first NUL.
        let name_region = &body[2..];
        let first_nul = name_region.iter().position(|&b| b == 0).unwrap();
        assert!(
            contains_replacement(&name_region[..first_nul]),
            "U+FFFD missing from sanitized column name"
        );
    }

    #[test]
    fn sanitize_cstring_fastpath_no_nul() {
        let s = "no nuls here";
        let out = sanitize_cstring_bytes(s.as_bytes());
        assert_eq!(out, s.as_bytes());
    }

    #[test]
    fn sanitize_cstring_substitutes_nul_with_replacement_codepoint() {
        let s = b"a\0b\0c";
        let out = sanitize_cstring_bytes(s);
        assert_eq!(out.len(), 9);
        assert!(!out.contains(&0));
        assert_eq!(&out[1..4], &REPLACEMENT);
        assert_eq!(&out[5..8], &REPLACEMENT);
    }
}