1use std::io;
21
22use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
23
/// Version word for protocol 3.0 as it appears in a startup packet: the
/// value 3 placed in the upper 16 bits with the lower 16 bits zero.
pub const PG_PROTOCOL_V3: u32 = 3 << 16;

/// Request code that `read_startup` reports as `FrontendMessage::SslRequest`.
/// Numerically equal to `(1234 << 16) | 5679`.
pub const PG_SSL_REQUEST: u32 = 80877103;
/// Request code that `read_startup` reports as
/// `FrontendMessage::GssEncRequest`. Numerically equal to
/// `(1234 << 16) | 5680`.
pub const PG_GSSENC_REQUEST: u32 = 80877104;
/// Request code for a cancel request. `read_startup` surfaces it as
/// `FrontendMessage::Unknown` with tag `b'K'`. Numerically equal to
/// `(1234 << 16) | 5678`.
pub const PG_CANCEL_REQUEST: u32 = 80877102;
33
/// Failures that can occur while reading or writing wire-protocol data.
#[derive(Debug)]
pub enum PgWireError {
    /// An I/O failure other than an unexpected end of stream.
    Io(io::Error),
    /// The peer sent data that breaks the framing or encoding rules; the
    /// string describes what was wrong.
    Protocol(String),
    /// The stream ended before a read could be completed.
    Eof,
}

impl From<io::Error> for PgWireError {
    /// Converts an I/O error, folding `UnexpectedEof` into [`PgWireError::Eof`]
    /// so callers can tell a closed connection apart from other failures.
    fn from(err: io::Error) -> Self {
        match err.kind() {
            io::ErrorKind::UnexpectedEof => Self::Eof,
            _ => Self::Io(err),
        }
    }
}

impl std::fmt::Display for PgWireError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Eof => f.write_str("pg wire eof"),
            Self::Io(cause) => write!(f, "pg wire io: {cause}"),
            Self::Protocol(detail) => write!(f, "pg wire protocol: {detail}"),
        }
    }
}

impl std::error::Error for PgWireError {}
65
/// Messages received from a client, as decoded by `read_startup` and
/// `read_frame`.
#[derive(Debug, Clone)]
pub enum FrontendMessage {
    /// Startup packet whose version word equals `PG_PROTOCOL_V3`, together
    /// with the key/value parameters it carried.
    Startup(StartupParams),
    /// Startup-phase packet whose request code equals `PG_SSL_REQUEST`.
    SslRequest,
    /// Startup-phase packet whose request code equals `PG_GSSENC_REQUEST`.
    GssEncRequest,
    /// Frame tagged `b'Q'`; holds the payload text up to, but not including,
    /// its first NUL byte.
    Query(String),
    /// Frame tagged `b'p'`; holds the payload bytes exactly as received.
    PasswordMessage(Vec<u8>),
    /// Frame tagged `b'X'`.
    Terminate,
    /// Frame tagged `b'H'`.
    Flush,
    /// Frame tagged `b'S'`.
    Sync,
    /// Any frame whose tag is not handled above, kept verbatim.
    /// `read_startup` also uses this variant, with tag `b'K'`, for a packet
    /// whose request code is `PG_CANCEL_REQUEST`; the payload is then the
    /// whole packet body including the 4-byte request code.
    Unknown { tag: u8, payload: Vec<u8> },
}
89
/// Key/value parameters carried by a startup packet.
#[derive(Debug, Clone, Default)]
pub struct StartupParams {
    /// `(name, value)` pairs in the order they were stored; duplicate names
    /// are kept.
    pub params: Vec<(String, String)>,
}

impl StartupParams {
    /// Returns the value of the first parameter named `key`, or `None` when
    /// no parameter has that name.
    pub fn get(&self, key: &str) -> Option<&str> {
        for (name, value) in &self.params {
            if name == key {
                return Some(value.as_str());
            }
        }
        None
    }
}
104
/// Messages sent to a client. `encode_backend` turns each variant into a
/// one-byte tag plus payload, which `write_frame` then frames and writes.
#[derive(Debug, Clone)]
pub enum BackendMessage {
    /// Encoded with tag `b'R'` and a payload of four zero bytes.
    AuthenticationOk,
    /// Encoded with tag `b'S'`: `name` then `value`, each NUL-terminated.
    ParameterStatus { name: String, value: String },
    /// Encoded with tag `b'K'`: `pid` then `key`, both big-endian.
    BackendKeyData { pid: u32, key: u32 },
    /// Encoded with tag `b'Z'` and the single status byte from
    /// `TransactionStatus::as_byte`.
    ReadyForQuery(TransactionStatus),
    /// Encoded with tag `b'T'`: a 16-bit column count followed by one
    /// descriptor per column.
    RowDescription(Vec<ColumnDescriptor>),
    /// Encoded with tag `b'D'`: a 16-bit field count, then for each field a
    /// 32-bit length and the bytes. `None` is written as length `-1` with no
    /// bytes following.
    DataRow(Vec<Option<Vec<u8>>>),
    /// Encoded with tag `b'C'`: the string, NUL-terminated.
    CommandComplete(String),
    /// Encoded with tag `b'E'` as the fields `S`, `V`, `C`, `M` followed by
    /// a terminating zero byte.
    ErrorResponse {
        /// Written twice: once as field `S` and once as field `V`.
        severity: String,
        /// Written as field `C`.
        code: String,
        /// Written as field `M`.
        message: String,
    },
    /// Encoded with tag `b'N'` as field `S` with the fixed text `NOTICE`,
    /// then field `M` with `message`, followed by a terminating zero byte.
    NoticeResponse { message: String },
    /// Encoded with tag `b'I'` and an empty payload.
    EmptyQueryResponse,
}
133
/// Transaction state reported to the client in a `ReadyForQuery` message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Encoded as `b'I'`.
    Idle,
    /// Encoded as `b'T'`.
    InTransaction,
    /// Encoded as `b'E'`.
    Failed,
}

impl TransactionStatus {
    /// Returns the single ASCII byte that represents this status on the wire.
    pub fn as_byte(self) -> u8 {
        let indicator = match self {
            Self::Idle => 'I',
            Self::InTransaction => 'T',
            Self::Failed => 'E',
        };
        // All three indicators are ASCII, so the cast is lossless.
        indicator as u8
    }
}
153
/// One column entry of a `BackendMessage::RowDescription`.
///
/// `encode_backend` writes the name as a NUL-terminated string and then the
/// remaining fields, big-endian, in the order they are declared here.
#[derive(Debug, Clone)]
pub struct ColumnDescriptor {
    /// Column name; embedded NUL bytes are replaced when the name is encoded.
    pub name: String,
    /// Written as a 32-bit value. Presumably the OID of the source table,
    /// per the PostgreSQL RowDescription layout — confirm against callers.
    pub table_oid: u32,
    /// Written as a 16-bit value. Presumably the column's attribute number
    /// within its table — confirm against callers.
    pub column_attr: i16,
    /// Written as a 32-bit value. Presumably the OID of the column's data
    /// type — confirm against callers.
    pub type_oid: u32,
    /// Written as a 16-bit value. Presumably the data type's size — confirm
    /// against callers.
    pub type_size: i16,
    /// Written as a 32-bit value. Presumably the type modifier — confirm
    /// against callers.
    pub type_mod: i32,
    /// Written as a 16-bit value. Presumably the field's format code —
    /// confirm against callers.
    pub format: i16,
}
170
171pub async fn read_startup<R: AsyncRead + Unpin>(
179 stream: &mut R,
180) -> Result<FrontendMessage, PgWireError> {
181 let mut len_buf = [0u8; 4];
182 stream.read_exact(&mut len_buf).await?;
183 let len = u32::from_be_bytes(len_buf);
184 if !(8..=65536).contains(&len) {
185 return Err(PgWireError::Protocol(format!(
186 "startup length {len} out of range"
187 )));
188 }
189 let body_len = (len as usize) - 4;
190 let mut body = vec![0u8; body_len];
191 stream.read_exact(&mut body).await?;
192 if body_len < 4 {
193 return Err(PgWireError::Protocol("startup payload too short".into()));
194 }
195 let version = u32::from_be_bytes([body[0], body[1], body[2], body[3]]);
196
197 match version {
198 PG_SSL_REQUEST => Ok(FrontendMessage::SslRequest),
199 PG_GSSENC_REQUEST => Ok(FrontendMessage::GssEncRequest),
200 PG_PROTOCOL_V3 => {
201 let mut params: Vec<(String, String)> = Vec::new();
204 let mut pos = 4usize;
205 while pos < body_len {
206 if body[pos] == 0 {
207 break;
208 }
209 let key = read_cstring(&body, &mut pos)?;
210 if pos >= body_len {
211 return Err(PgWireError::Protocol(
212 "startup parameter missing value".into(),
213 ));
214 }
215 let value = read_cstring(&body, &mut pos)?;
216 params.push((key, value));
217 }
218 Ok(FrontendMessage::Startup(StartupParams { params }))
219 }
220 PG_CANCEL_REQUEST => Ok(FrontendMessage::Unknown {
223 tag: b'K',
224 payload: body,
225 }),
226 _ => Err(PgWireError::Protocol(format!(
227 "unsupported protocol version {version}"
228 ))),
229 }
230}
231
232pub async fn read_frame<R: AsyncRead + Unpin>(
234 stream: &mut R,
235) -> Result<FrontendMessage, PgWireError> {
236 let mut tag_buf = [0u8; 1];
237 match stream.read_exact(&mut tag_buf).await {
238 Ok(_) => {}
239 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Err(PgWireError::Eof),
240 Err(e) => return Err(PgWireError::Io(e)),
241 }
242 let tag = tag_buf[0];
243
244 let mut len_buf = [0u8; 4];
245 stream.read_exact(&mut len_buf).await?;
246 let len = u32::from_be_bytes(len_buf);
247 if !(4..=1_048_576).contains(&len) {
248 return Err(PgWireError::Protocol(format!(
249 "frame length {len} out of bounds"
250 )));
251 }
252 let payload_len = (len as usize) - 4;
253 let mut payload = vec![0u8; payload_len];
254 stream.read_exact(&mut payload).await?;
255
256 Ok(match tag {
257 b'Q' => {
258 let mut pos = 0;
260 let query = read_cstring(&payload, &mut pos)?;
261 FrontendMessage::Query(query)
262 }
263 b'p' => FrontendMessage::PasswordMessage(payload),
264 b'X' => FrontendMessage::Terminate,
265 b'H' => FrontendMessage::Flush,
266 b'S' => FrontendMessage::Sync,
267 other => FrontendMessage::Unknown {
268 tag: other,
269 payload,
270 },
271 })
272}
273
274pub async fn write_raw_byte<W: AsyncWrite + Unpin>(
281 stream: &mut W,
282 byte: u8,
283) -> Result<(), PgWireError> {
284 stream.write_all(&[byte]).await?;
285 Ok(())
286}
287
288pub async fn write_frame<W: AsyncWrite + Unpin>(
290 stream: &mut W,
291 msg: &BackendMessage,
292) -> Result<(), PgWireError> {
293 let (tag, payload) = encode_backend(msg);
294 let length = (payload.len() + 4) as u32;
296 stream.write_all(&[tag]).await?;
297 stream.write_all(&length.to_be_bytes()).await?;
298 stream.write_all(&payload).await?;
299 Ok(())
300}
301
/// Returns a copy of `input` in which every NUL byte is replaced by the
/// three-byte UTF-8 encoding of U+FFFD (`EF BF BD`).
///
/// Input without NUL bytes is returned unchanged. The result never contains
/// a NUL, so a terminator can be appended to it unambiguously.
fn sanitize_cstring_bytes(input: &[u8]) -> Vec<u8> {
    const REPLACEMENT: [u8; 3] = [0xEF, 0xBF, 0xBD];

    // Splitting on NUL always yields at least one (possibly empty) piece;
    // every further piece corresponds to exactly one NUL that was removed.
    let mut pieces = input.split(|&byte| byte == 0);
    let mut sanitized = pieces
        .next()
        .map(|head| head.to_vec())
        .unwrap_or_default();
    for piece in pieces {
        sanitized.extend_from_slice(&REPLACEMENT);
        sanitized.extend_from_slice(piece);
    }
    sanitized
}
332
333#[inline]
334fn push_cstring(buf: &mut Vec<u8>, value: &str) {
335 buf.extend_from_slice(&sanitize_cstring_bytes(value.as_bytes()));
336 buf.push(0);
337}
338
339fn encode_backend(msg: &BackendMessage) -> (u8, Vec<u8>) {
340 match msg {
341 BackendMessage::AuthenticationOk => {
342 (b'R', vec![0, 0, 0, 0])
344 }
345 BackendMessage::ParameterStatus { name, value } => {
346 let mut buf = Vec::with_capacity(name.len() + value.len() + 2);
347 push_cstring(&mut buf, name);
349 push_cstring(&mut buf, value);
350 (b'S', buf)
351 }
352 BackendMessage::BackendKeyData { pid, key } => {
353 let mut buf = Vec::with_capacity(8);
354 buf.extend_from_slice(&pid.to_be_bytes());
355 buf.extend_from_slice(&key.to_be_bytes());
356 (b'K', buf)
357 }
358 BackendMessage::ReadyForQuery(status) => (b'Z', vec![status.as_byte()]),
359 BackendMessage::RowDescription(cols) => {
360 let mut buf = Vec::new();
361 buf.extend_from_slice(&(cols.len() as i16).to_be_bytes());
362 for col in cols {
363 push_cstring(&mut buf, &col.name);
365 buf.extend_from_slice(&col.table_oid.to_be_bytes());
366 buf.extend_from_slice(&col.column_attr.to_be_bytes());
367 buf.extend_from_slice(&col.type_oid.to_be_bytes());
368 buf.extend_from_slice(&col.type_size.to_be_bytes());
369 buf.extend_from_slice(&col.type_mod.to_be_bytes());
370 buf.extend_from_slice(&col.format.to_be_bytes());
371 }
372 (b'T', buf)
373 }
374 BackendMessage::DataRow(fields) => {
375 let mut buf = Vec::new();
376 buf.extend_from_slice(&(fields.len() as i16).to_be_bytes());
377 for field in fields {
378 match field {
379 None => {
380 buf.extend_from_slice(&(-1i32).to_be_bytes());
382 }
383 Some(bytes) => {
384 buf.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
388 buf.extend_from_slice(bytes);
389 }
390 }
391 }
392 (b'D', buf)
393 }
394 BackendMessage::CommandComplete(tag) => {
395 let mut buf = Vec::with_capacity(tag.len() + 1);
396 push_cstring(&mut buf, tag);
399 (b'C', buf)
400 }
401 BackendMessage::ErrorResponse {
402 severity,
403 code,
404 message,
405 } => {
406 let mut buf = Vec::new();
407 buf.push(b'S');
409 push_cstring(&mut buf, severity);
410 buf.push(b'V');
412 push_cstring(&mut buf, severity);
413 buf.push(b'C');
415 push_cstring(&mut buf, code);
416 buf.push(b'M');
418 push_cstring(&mut buf, message);
419 buf.push(0);
421 (b'E', buf)
422 }
423 BackendMessage::NoticeResponse { message } => {
424 let mut buf = Vec::new();
425 buf.push(b'S');
426 buf.extend_from_slice(b"NOTICE");
427 buf.push(0);
428 buf.push(b'M');
429 push_cstring(&mut buf, message);
431 buf.push(0);
432 (b'N', buf)
433 }
434 BackendMessage::EmptyQueryResponse => (b'I', Vec::new()),
435 }
436}
437
438fn read_cstring(buf: &[u8], pos: &mut usize) -> Result<String, PgWireError> {
445 let start = *pos;
446 while *pos < buf.len() && buf[*pos] != 0 {
447 *pos += 1;
448 }
449 if *pos >= buf.len() {
450 return Err(PgWireError::Protocol("cstring missing terminator".into()));
451 }
452 let s = std::str::from_utf8(&buf[start..*pos])
453 .map_err(|e| PgWireError::Protocol(format!("invalid utf8: {e}")))?
454 .to_string();
455 *pos += 1; Ok(s)
457}
458
// Unit tests for startup/frame parsing, backend frame encoding, and the
// replacement of embedded NUL bytes in outgoing strings.
#[cfg(test)]
mod tests {
    use super::*;

    /// A v3 startup packet with one `user` parameter parses into `Startup`.
    #[tokio::test]
    async fn parse_startup_v3() {
        let mut payload: Vec<u8> = Vec::new();
        // Body layout: version word, key/value pairs, then an empty key.
        payload.extend_from_slice(&PG_PROTOCOL_V3.to_be_bytes());
        payload.extend_from_slice(b"user\0alice\0");
        payload.push(0);
        // The length prefix counts its own four bytes.
        let len = (4 + payload.len()) as u32;
        let mut frame = Vec::new();
        frame.extend_from_slice(&len.to_be_bytes());
        frame.extend_from_slice(&payload);

        let mut cursor = std::io::Cursor::new(frame);
        let msg = read_startup(&mut cursor).await.unwrap();
        match msg {
            FrontendMessage::Startup(params) => {
                assert_eq!(params.get("user"), Some("alice"));
            }
            other => panic!("expected Startup, got {:?}", other),
        }
    }

    /// The minimal 8-byte packet carrying the SSL request code is recognised.
    #[tokio::test]
    async fn parse_ssl_request() {
        let mut frame: Vec<u8> = Vec::new();
        frame.extend_from_slice(&8u32.to_be_bytes());
        frame.extend_from_slice(&PG_SSL_REQUEST.to_be_bytes());
        let mut cursor = std::io::Cursor::new(frame);
        assert!(matches!(
            read_startup(&mut cursor).await.unwrap(),
            FrontendMessage::SslRequest
        ));
    }

    /// A `Q` frame yields the query text without its trailing NUL.
    #[tokio::test]
    async fn parse_query_frame() {
        let query = "SELECT 1\0";
        let mut frame = Vec::new();
        frame.push(b'Q');
        let len = (4 + query.len()) as u32;
        frame.extend_from_slice(&len.to_be_bytes());
        frame.extend_from_slice(query.as_bytes());
        let mut cursor = std::io::Cursor::new(frame);
        match read_frame(&mut cursor).await.unwrap() {
            FrontendMessage::Query(s) => assert_eq!(s, "SELECT 1"),
            other => panic!("expected Query, got {:?}", other),
        }
    }

    /// `ReadyForQuery(Idle)` is exactly: tag `Z`, length 5, status byte `I`.
    #[tokio::test]
    async fn emit_ready_for_query() {
        let mut out: Vec<u8> = Vec::new();
        write_frame(
            &mut out,
            &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
        )
        .await
        .unwrap();
        assert_eq!(out, vec![b'Z', 0, 0, 0, 5, b'I']);
    }

    /// Smoke test: `RowDescription` and `DataRow` are emitted with the
    /// expected tag bytes.
    #[tokio::test]
    async fn emit_row_description_and_data_row() {
        let mut out: Vec<u8> = Vec::new();
        write_frame(
            &mut out,
            &BackendMessage::RowDescription(vec![ColumnDescriptor {
                name: "id".to_string(),
                table_oid: 0,
                column_attr: 0,
                type_oid: 23,
                type_size: 4,
                type_mod: -1,
                format: 0,
            }]),
        )
        .await
        .unwrap();
        assert_eq!(out[0], b'T');

        let mut data: Vec<u8> = Vec::new();
        write_frame(
            &mut data,
            &BackendMessage::DataRow(vec![Some(b"42".to_vec()), None]),
        )
        .await
        .unwrap();
        assert_eq!(data[0], b'D');
    }

    /// Counts the NUL bytes in `buf`.
    fn count_nul(buf: &[u8]) -> usize {
        buf.iter().filter(|&&b| b == 0).count()
    }

    /// NUL bytes inside an error message must not add extra terminators to
    /// the encoded `ErrorResponse`.
    #[tokio::test]
    async fn pg3_nul_error_response_message_field_sanitized() {
        let mut out: Vec<u8> = Vec::new();
        write_frame(
            &mut out,
            &BackendMessage::ErrorResponse {
                severity: "ERROR".to_string(),
                code: "42000".to_string(),
                message: "smuggled\0M\x00injection".to_string(),
            },
        )
        .await
        .unwrap();
        assert_eq!(out[0], b'E');
        // Skip the 1-byte tag and the 4-byte length prefix.
        let body = &out[5..];
        // Only the four field terminators (S, V, C, M) and the list
        // terminator may remain as NUL bytes.
        assert_eq!(
            count_nul(body),
            5,
            "expected 5 NULs (4 field + 1 list-end), got {} :: body={:?}",
            count_nul(body),
            body
        );
        assert!(
            body.windows(3).any(|w| w == [0xEF, 0xBF, 0xBD]),
            "expected U+FFFD substitution in body"
        );
    }

    /// NUL bytes inside a notice message are replaced as well.
    #[tokio::test]
    async fn pg3_nul_notice_response_sanitized() {
        let mut out: Vec<u8> = Vec::new();
        write_frame(
            &mut out,
            &BackendMessage::NoticeResponse {
                message: "evil\0field".to_string(),
            },
        )
        .await
        .unwrap();
        assert_eq!(out[0], b'N');
        let body = &out[5..];
        // Two field terminators (S, M) plus the list terminator.
        assert_eq!(count_nul(body), 3);
        assert!(body.windows(3).any(|w| w == [0xEF, 0xBF, 0xBD]));
    }

    /// A command tag containing a NUL still encodes with a single terminator.
    #[tokio::test]
    async fn pg3_nul_command_complete_sanitized() {
        let mut out: Vec<u8> = Vec::new();
        write_frame(
            &mut out,
            &BackendMessage::CommandComplete("SELECT\0;DROP".to_string()),
        )
        .await
        .unwrap();
        assert_eq!(out[0], b'C');
        let body = &out[5..];
        // Only the string terminator may be a NUL.
        assert_eq!(count_nul(body), 1);
    }

    /// A column name containing a NUL is encoded with the replacement bytes
    /// ahead of its terminator.
    #[tokio::test]
    async fn pg3_nul_row_description_column_name_sanitized() {
        let mut out: Vec<u8> = Vec::new();
        write_frame(
            &mut out,
            &BackendMessage::RowDescription(vec![ColumnDescriptor {
                name: "evil\0col".to_string(),
                table_oid: 0,
                column_attr: 0,
                type_oid: 23,
                type_size: 4,
                type_mod: -1,
                format: 0,
            }]),
        )
        .await
        .unwrap();
        assert_eq!(out[0], b'T');
        let body = &out[5..];
        // Skip the 2-byte column count; the column name follows.
        let name_region = &body[2..];
        // The first NUL after the count is the name's terminator.
        let first_nul = name_region.iter().position(|&b| b == 0).unwrap();
        assert!(
            name_region[..first_nul]
                .windows(3)
                .any(|w| w == [0xEF, 0xBF, 0xBD]),
            "U+FFFD missing from sanitized column name"
        );
    }

    /// Input without NUL bytes comes back unchanged.
    #[test]
    fn sanitize_cstring_fastpath_no_nul() {
        let s = "no nuls here";
        let out = sanitize_cstring_bytes(s.as_bytes());
        assert_eq!(out, s.as_bytes());
    }

    /// Each NUL is replaced by the three bytes `EF BF BD`.
    #[test]
    fn sanitize_cstring_substitutes_nul_with_replacement_codepoint() {
        let s = b"a\0b\0c";
        let out = sanitize_cstring_bytes(s);
        // Three literal bytes plus two 3-byte replacements.
        assert_eq!(out.len(), 9);
        assert!(!out.contains(&0));
        assert_eq!(&out[1..4], &[0xEF, 0xBF, 0xBD]);
        assert_eq!(&out[5..8], &[0xEF, 0xBF, 0xBD]);
    }
}