1use bytes::{Buf, BufMut, BytesMut};
9
10use crate::error::{PgWireError, PgWireResult};
11
/// Protocol versions this module can name.
///
/// `PROTOCOL3_2` is the variant produced by `Default`.
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
pub enum ProtocolVersion {
    /// Version `(3, 0)`.
    PROTOCOL3_0,
    /// Version `(3, 2)`.
    #[default]
    PROTOCOL3_2,
}

impl ProtocolVersion {
    /// Returns the `(major, minor)` pair for this version.
    pub fn version_number(&self) -> (u16, u16) {
        match *self {
            Self::PROTOCOL3_2 => (3, 2),
            Self::PROTOCOL3_0 => (3, 0),
        }
    }

    /// Looks up the variant whose `(major, minor)` pair matches the arguments.
    ///
    /// Returns `None` when no variant has that version number.
    pub fn from_version_number(major: u16, minor: u16) -> Option<Self> {
        let wanted = (major, minor);
        [Self::PROTOCOL3_0, Self::PROTOCOL3_2]
            .iter()
            .copied()
            .find(|candidate| candidate.version_number() == wanted)
    }
}
38
/// Packet size limit of `0x3fff_ffff - 1` bytes.
pub(crate) const LARGE_PACKET_SIZE_LIMIT: usize = 0x3fff_ffff - 1;
/// Packet size limit of 10,000 bytes; the default returned by
/// `Message::max_message_length`.
pub(crate) const SMALL_PACKET_SIZE_LIMIT: usize = 10_000;
/// Packet size limit of 30,000 bytes.
pub(crate) const SMALL_BACKEND_PACKET_SIZE_LIMIT: usize = 30_000;
/// Packet size limit equal to `i32::MAX`.
pub(crate) const LONG_BACKEND_PACKET_SIZE_LIMIT: usize = i32::MAX as usize;
46
47#[non_exhaustive]
48#[derive(Default, Debug, PartialEq, Eq, new)]
49pub struct DecodeContext {
50 pub protocol_version: ProtocolVersion,
51 #[new(value = "true")]
52 pub awaiting_ssl: bool,
53 #[new(value = "true")]
54 pub awaiting_startup: bool,
55}
56
57pub trait Message: Sized {
59 #[inline]
62 fn message_type() -> Option<u8> {
63 None
64 }
65
66 fn message_length(&self) -> usize;
68
69 #[inline]
73 fn max_message_length() -> usize {
74 SMALL_PACKET_SIZE_LIMIT
75 }
76
77 fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()>;
79
80 fn decode_body(buf: &mut BytesMut, full_len: usize, _ctx: &DecodeContext)
82 -> PgWireResult<Self>;
83
84 fn encode(&self, buf: &mut BytesMut) -> PgWireResult<()> {
89 if let Some(mt) = Self::message_type() {
90 buf.put_u8(mt);
91 }
92
93 let len = self.message_length();
94 if len > Self::max_message_length() {
95 return Err(PgWireError::MessageTooLarge(
96 len,
97 Self::max_message_length(),
98 ));
99 }
100
101 buf.put_i32(len as i32);
102 self.encode_body(buf)
103 }
104
105 fn decode(buf: &mut BytesMut, ctx: &DecodeContext) -> PgWireResult<Option<Self>> {
111 let offset = Self::message_type().is_some().into();
112
113 codec::decode_packet(buf, offset, Self::max_message_length(), |buf, full_len| {
114 Self::decode_body(buf, full_len, ctx)
115 })
116 }
117}
118
119pub mod cancel;
121mod codec;
122pub mod copy;
124pub mod data;
126pub mod extendedquery;
128pub mod response;
130pub mod simplequery;
132pub mod startup;
134pub mod terminate;
136
137#[derive(Debug)]
139pub enum PgWireFrontendMessage {
140 Startup(startup::Startup),
141 CancelRequest(cancel::CancelRequest),
142 SslRequest(startup::SslRequest),
143 GssEncRequest(startup::GssEncRequest),
144 PasswordMessageFamily(startup::PasswordMessageFamily),
145
146 Query(simplequery::Query),
147
148 Parse(extendedquery::Parse),
149 Close(extendedquery::Close),
150 Bind(extendedquery::Bind),
151 Describe(extendedquery::Describe),
152 Execute(extendedquery::Execute),
153 Flush(extendedquery::Flush),
154 Sync(extendedquery::Sync),
155
156 Terminate(terminate::Terminate),
157
158 CopyData(copy::CopyData),
159 CopyFail(copy::CopyFail),
160 CopyDone(copy::CopyDone),
161}
162
163impl PgWireFrontendMessage {
164 pub fn is_extended_query(&self) -> bool {
165 matches!(
166 self,
167 Self::Parse(_)
168 | Self::Bind(_)
169 | Self::Close(_)
170 | Self::Describe(_)
171 | Self::Execute(_)
172 | Self::Flush(_)
173 | Self::Sync(_)
174 )
175 }
176
177 pub fn encode(&self, buf: &mut BytesMut) -> PgWireResult<()> {
178 match self {
179 Self::Startup(msg) => msg.encode(buf),
180 Self::CancelRequest(msg) => msg.encode(buf),
181 Self::SslRequest(msg) => msg.encode(buf),
182 Self::GssEncRequest(msg) => msg.encode(buf),
183
184 Self::PasswordMessageFamily(msg) => msg.encode(buf),
185
186 Self::Query(msg) => msg.encode(buf),
187
188 Self::Parse(msg) => msg.encode(buf),
189 Self::Bind(msg) => msg.encode(buf),
190 Self::Close(msg) => msg.encode(buf),
191 Self::Describe(msg) => msg.encode(buf),
192 Self::Execute(msg) => msg.encode(buf),
193 Self::Flush(msg) => msg.encode(buf),
194 Self::Sync(msg) => msg.encode(buf),
195
196 Self::Terminate(msg) => msg.encode(buf),
197
198 Self::CopyData(msg) => msg.encode(buf),
199 Self::CopyFail(msg) => msg.encode(buf),
200 Self::CopyDone(msg) => msg.encode(buf),
201 }
202 }
203
204 pub fn decode(buf: &mut BytesMut, ctx: &DecodeContext) -> PgWireResult<Option<Self>> {
205 if ctx.awaiting_ssl {
206 if buf.remaining() >= 8 {
213 if cancel::CancelRequest::is_cancel_request_packet(buf) {
214 cancel::CancelRequest::decode(buf, ctx)
215 .map(|opt| opt.map(PgWireFrontendMessage::CancelRequest))
216 } else if startup::SslRequest::is_ssl_request_packet(buf) {
217 startup::SslRequest::decode(buf, ctx)
218 .map(|opt| opt.map(PgWireFrontendMessage::SslRequest))
219 } else if startup::GssEncRequest::is_gss_enc_request_packet(buf) {
220 startup::GssEncRequest::decode(buf, ctx)
221 .map(|opt| opt.map(PgWireFrontendMessage::GssEncRequest))
222 } else {
223 startup::Startup::decode(buf, ctx).map(|v| v.map(Self::Startup))
225 }
226 } else {
227 Ok(None)
228 }
229 } else if ctx.awaiting_startup {
230 if buf.remaining() >= 8 {
232 if cancel::CancelRequest::is_cancel_request_packet(buf) {
233 cancel::CancelRequest::decode(buf, ctx)
234 .map(|opt| opt.map(PgWireFrontendMessage::CancelRequest))
235 } else {
236 startup::Startup::decode(buf, ctx).map(|v| v.map(Self::Startup))
237 }
238 } else {
239 Ok(None)
240 }
241 } else if buf.remaining() > 1 {
242 let first_byte = buf[0];
243
244 match first_byte {
245 startup::MESSAGE_TYPE_BYTE_PASSWORD_MESSAGE_FAMILY => {
248 startup::PasswordMessageFamily::decode(buf, ctx)
249 .map(|v| v.map(Self::PasswordMessageFamily))
250 }
251
252 simplequery::MESSAGE_TYPE_BYTE_QUERY => {
253 simplequery::Query::decode(buf, ctx).map(|v| v.map(Self::Query))
254 }
255
256 extendedquery::MESSAGE_TYPE_BYTE_PARSE => {
257 extendedquery::Parse::decode(buf, ctx).map(|v| v.map(Self::Parse))
258 }
259 extendedquery::MESSAGE_TYPE_BYTE_BIND => {
260 extendedquery::Bind::decode(buf, ctx).map(|v| v.map(Self::Bind))
261 }
262 extendedquery::MESSAGE_TYPE_BYTE_CLOSE => {
263 extendedquery::Close::decode(buf, ctx).map(|v| v.map(Self::Close))
264 }
265 extendedquery::MESSAGE_TYPE_BYTE_DESCRIBE => {
266 extendedquery::Describe::decode(buf, ctx).map(|v| v.map(Self::Describe))
267 }
268 extendedquery::MESSAGE_TYPE_BYTE_EXECUTE => {
269 extendedquery::Execute::decode(buf, ctx).map(|v| v.map(Self::Execute))
270 }
271 extendedquery::MESSAGE_TYPE_BYTE_FLUSH => {
272 extendedquery::Flush::decode(buf, ctx).map(|v| v.map(Self::Flush))
273 }
274 extendedquery::MESSAGE_TYPE_BYTE_SYNC => {
275 extendedquery::Sync::decode(buf, ctx).map(|v| v.map(Self::Sync))
276 }
277
278 terminate::MESSAGE_TYPE_BYTE_TERMINATE => {
279 terminate::Terminate::decode(buf, ctx).map(|v| v.map(Self::Terminate))
280 }
281
282 copy::MESSAGE_TYPE_BYTE_COPY_DATA => {
283 copy::CopyData::decode(buf, ctx).map(|v| v.map(Self::CopyData))
284 }
285 copy::MESSAGE_TYPE_BYTE_COPY_FAIL => {
286 copy::CopyFail::decode(buf, ctx).map(|v| v.map(Self::CopyFail))
287 }
288 copy::MESSAGE_TYPE_BYTE_COPY_DONE => {
289 copy::CopyDone::decode(buf, ctx).map(|v| v.map(Self::CopyDone))
290 }
291 _ => Err(PgWireError::InvalidMessageType(first_byte)),
292 }
293 } else {
294 Ok(None)
295 }
296 }
297}
298
299#[derive(Debug)]
301pub enum PgWireBackendMessage {
302 SslResponse(response::SslResponse),
304 GssEncResponse(response::GssEncResponse),
305 Authentication(startup::Authentication),
306 ParameterStatus(startup::ParameterStatus),
307 BackendKeyData(startup::BackendKeyData),
308 NegotiateProtocolVersion(startup::NegotiateProtocolVersion),
309
310 ParseComplete(extendedquery::ParseComplete),
312 CloseComplete(extendedquery::CloseComplete),
313 BindComplete(extendedquery::BindComplete),
314 PortalSuspended(extendedquery::PortalSuspended),
315
316 CommandComplete(response::CommandComplete),
318 EmptyQueryResponse(response::EmptyQueryResponse),
319 ReadyForQuery(response::ReadyForQuery),
320 ErrorResponse(response::ErrorResponse),
321 NoticeResponse(response::NoticeResponse),
322 NotificationResponse(response::NotificationResponse),
323
324 ParameterDescription(data::ParameterDescription),
326 RowDescription(data::RowDescription),
327 DataRow(data::DataRow),
328 NoData(data::NoData),
329
330 CopyData(copy::CopyData),
332 CopyFail(copy::CopyFail),
333 CopyDone(copy::CopyDone),
334 CopyInResponse(copy::CopyInResponse),
335 CopyOutResponse(copy::CopyOutResponse),
336 CopyBothResponse(copy::CopyBothResponse),
337}
338
339impl PgWireBackendMessage {
340 pub fn encode(&self, buf: &mut BytesMut) -> PgWireResult<()> {
341 match self {
342 Self::SslResponse(msg) => msg.encode(buf),
343 Self::GssEncResponse(msg) => msg.encode(buf),
344 Self::Authentication(msg) => msg.encode(buf),
345 Self::ParameterStatus(msg) => msg.encode(buf),
346 Self::BackendKeyData(msg) => msg.encode(buf),
347 Self::NegotiateProtocolVersion(msg) => msg.encode(buf),
348
349 Self::ParseComplete(msg) => msg.encode(buf),
350 Self::BindComplete(msg) => msg.encode(buf),
351 Self::CloseComplete(msg) => msg.encode(buf),
352 Self::PortalSuspended(msg) => msg.encode(buf),
353
354 Self::CommandComplete(msg) => msg.encode(buf),
355 Self::EmptyQueryResponse(msg) => msg.encode(buf),
356 Self::ReadyForQuery(msg) => msg.encode(buf),
357 Self::ErrorResponse(msg) => msg.encode(buf),
358 Self::NoticeResponse(msg) => msg.encode(buf),
359 Self::NotificationResponse(msg) => msg.encode(buf),
360
361 Self::ParameterDescription(msg) => msg.encode(buf),
362 Self::RowDescription(msg) => msg.encode(buf),
363 Self::DataRow(msg) => msg.encode(buf),
364 Self::NoData(msg) => msg.encode(buf),
365
366 Self::CopyData(msg) => msg.encode(buf),
367 Self::CopyFail(msg) => msg.encode(buf),
368 Self::CopyDone(msg) => msg.encode(buf),
369 Self::CopyInResponse(msg) => msg.encode(buf),
370 Self::CopyOutResponse(msg) => msg.encode(buf),
371 Self::CopyBothResponse(msg) => msg.encode(buf),
372 }
373 }
374
375 pub fn decode(buf: &mut BytesMut, ctx: &DecodeContext) -> PgWireResult<Option<Self>> {
376 if buf.remaining() > 1 {
377 let first_byte = buf[0];
378 match first_byte {
379 startup::MESSAGE_TYPE_BYTE_AUTHENTICATION => {
380 startup::Authentication::decode(buf, ctx).map(|v| v.map(Self::Authentication))
381 }
382 startup::MESSAGE_TYPE_BYTE_PARAMETER_STATUS => {
383 startup::ParameterStatus::decode(buf, ctx).map(|v| v.map(Self::ParameterStatus))
384 }
385 startup::MESSAGE_TYPE_BYTE_BACKEND_KEY_DATA => {
386 startup::BackendKeyData::decode(buf, ctx).map(|v| v.map(Self::BackendKeyData))
387 }
388 startup::MESSAGE_TYPE_BYTE_NEGOTIATE_PROTOCOL_VERSION => {
389 startup::NegotiateProtocolVersion::decode(buf, ctx)
390 .map(|v| v.map(Self::NegotiateProtocolVersion))
391 }
392
393 extendedquery::MESSAGE_TYPE_BYTE_PARSE_COMPLETE => {
394 extendedquery::ParseComplete::decode(buf, ctx)
395 .map(|v| v.map(Self::ParseComplete))
396 }
397 extendedquery::MESSAGE_TYPE_BYTE_BIND_COMPLETE => {
398 extendedquery::BindComplete::decode(buf, ctx).map(|v| v.map(Self::BindComplete))
399 }
400 extendedquery::MESSAGE_TYPE_BYTE_CLOSE_COMPLETE => {
401 extendedquery::CloseComplete::decode(buf, ctx)
402 .map(|v| v.map(Self::CloseComplete))
403 }
404 extendedquery::MESSAGE_TYPE_BYTE_PORTAL_SUSPENDED => {
405 extendedquery::PortalSuspended::decode(buf, ctx)
406 .map(|v| v.map(PgWireBackendMessage::PortalSuspended))
407 }
408
409 response::MESSAGE_TYPE_BYTE_COMMAND_COMPLETE => {
410 response::CommandComplete::decode(buf, ctx)
411 .map(|v| v.map(Self::CommandComplete))
412 }
413 response::MESSAGE_TYPE_BYTE_EMPTY_QUERY_RESPONSE => {
414 response::EmptyQueryResponse::decode(buf, ctx)
415 .map(|v| v.map(Self::EmptyQueryResponse))
416 }
417 response::MESSAGE_TYPE_BYTE_READY_FOR_QUERY => {
418 response::ReadyForQuery::decode(buf, ctx).map(|v| v.map(Self::ReadyForQuery))
419 }
420 response::MESSAGE_TYPE_BYTE_ERROR_RESPONSE => {
421 response::ErrorResponse::decode(buf, ctx).map(|v| v.map(Self::ErrorResponse))
422 }
423 response::MESSAGE_TYPE_BYTE_NOTICE_RESPONSE => {
424 response::NoticeResponse::decode(buf, ctx).map(|v| v.map(Self::NoticeResponse))
425 }
426 response::MESSAGE_TYPE_BYTE_NOTIFICATION_RESPONSE => {
427 response::NotificationResponse::decode(buf, ctx)
428 .map(|v| v.map(Self::NotificationResponse))
429 }
430
431 data::MESSAGE_TYPE_BYTE_PARAMETER_DESCRITION => {
432 data::ParameterDescription::decode(buf, ctx)
433 .map(|v| v.map(Self::ParameterDescription))
434 }
435 data::MESSAGE_TYPE_BYTE_ROW_DESCRITION => {
436 data::RowDescription::decode(buf, ctx).map(|v| v.map(Self::RowDescription))
437 }
438 data::MESSAGE_TYPE_BYTE_DATA_ROW => {
439 data::DataRow::decode(buf, ctx).map(|v| v.map(Self::DataRow))
440 }
441 data::MESSAGE_TYPE_BYTE_NO_DATA => {
442 data::NoData::decode(buf, ctx).map(|v| v.map(Self::NoData))
443 }
444
445 copy::MESSAGE_TYPE_BYTE_COPY_DATA => {
446 copy::CopyData::decode(buf, ctx).map(|v| v.map(Self::CopyData))
447 }
448 copy::MESSAGE_TYPE_BYTE_COPY_FAIL => {
449 copy::CopyFail::decode(buf, ctx).map(|v| v.map(Self::CopyFail))
450 }
451 copy::MESSAGE_TYPE_BYTE_COPY_DONE => {
452 copy::CopyDone::decode(buf, ctx).map(|v| v.map(Self::CopyDone))
453 }
454 copy::MESSAGE_TYPE_BYTE_COPY_IN_RESPONSE => {
455 copy::CopyInResponse::decode(buf, ctx).map(|v| v.map(Self::CopyInResponse))
456 }
457 copy::MESSAGE_TYPE_BYTE_COPY_OUT_RESPONSE => {
458 copy::CopyOutResponse::decode(buf, ctx).map(|v| v.map(Self::CopyOutResponse))
459 }
460 copy::MESSAGE_TYPE_BYTE_COPY_BOTH_RESPONSE => {
461 copy::CopyBothResponse::decode(buf, ctx).map(|v| v.map(Self::CopyBothResponse))
462 }
463 _ => Err(PgWireError::InvalidMessageType(first_byte)),
464 }
465 } else {
466 Ok(None)
467 }
468 }
469}
470
#[cfg(test)]
mod test {
    use crate::messages::DecodeContext;

    use super::cancel::CancelRequest;
    use super::copy::*;
    use super::data::*;
    use super::extendedquery::*;
    use super::response::*;
    use super::simplequery::*;
    use super::startup::*;
    use super::terminate::*;
    use super::{Message, ProtocolVersion};
    use bytes::{Buf, BufMut, Bytes, BytesMut};

    /// Encodes `$ins`, decodes the bytes back as `$st` with `$ctx`, and checks
    /// that the buffer is fully consumed and the decoded value equals `$ins`.
    macro_rules! roundtrip {
        ($ins:ident, $st:ty, $ctx:expr) => {
            let mut wire = BytesMut::new();
            $ins.encode(&mut wire).expect("encode packet");

            assert!(wire.remaining() > 0);

            let decoded = <$st>::decode(&mut wire, $ctx)
                .expect("decode packet")
                .expect("packet is none");

            assert_eq!(wire.remaining(), 0);
            assert_eq!($ins, decoded);
        };
    }

    /// Protocol 3.0 context exactly as `DecodeContext::new` builds it.
    fn initial_ctx() -> DecodeContext {
        DecodeContext::new(ProtocolVersion::PROTOCOL3_0)
    }

    /// Protocol 3.0 context with both `awaiting_*` flags cleared.
    fn post_startup_ctx() -> DecodeContext {
        let mut ctx = initial_ctx();
        ctx.awaiting_ssl = false;
        ctx.awaiting_startup = false;
        ctx
    }

    #[test]
    fn test_startup() {
        let mut ctx = initial_ctx();

        let mut startup = Startup::default();
        startup
            .parameters
            .insert("user".to_owned(), "tomcat".to_owned());
        roundtrip!(startup, Startup, &ctx);

        ctx.awaiting_ssl = false;
        roundtrip!(startup, Startup, &ctx);

        let mut oversized = Startup::default();
        oversized
            .parameters
            .insert("user".to_owned(), "a".repeat(10000));
        let mut sink = BytesMut::new();
        assert!(oversized.encode(&mut sink).is_err());
    }

    #[test]
    fn test_cancel_request() {
        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_2);

        let request = CancelRequest::new(100, SecretKey::Bytes(Bytes::from("server2008")));
        roundtrip!(request, CancelRequest, &ctx);

        ctx.protocol_version = ProtocolVersion::PROTOCOL3_0;
        let request = CancelRequest::new(100, SecretKey::I32(1900));
        roundtrip!(request, CancelRequest, &ctx);
    }

    #[test]
    fn test_authentication() {
        let ctx = post_startup_ctx();

        let variants = vec![
            Authentication::Ok,
            Authentication::CleartextPassword,
            Authentication::KerberosV5,
            Authentication::SASLContinue(Bytes::from("hello")),
            Authentication::SASLFinal(Bytes::from("world")),
        ];
        for auth in variants {
            roundtrip!(auth, Authentication, &ctx);
        }

        let md5 = Authentication::MD5Password(b"pstg".to_vec());
        roundtrip!(md5, Authentication, &ctx);
    }

    #[test]
    fn test_password() {
        let ctx = post_startup_ctx();

        let password = Password::new("pgwire".to_owned());
        roundtrip!(password, Password, &ctx);
    }

    #[test]
    fn test_parameter_status() {
        let ctx = post_startup_ctx();

        let status = ParameterStatus::new("cli".to_owned(), "psql".to_owned());
        roundtrip!(status, ParameterStatus, &ctx);
    }

    #[test]
    fn test_query() {
        let ctx = post_startup_ctx();

        let query = Query::new("SELECT 1".to_owned());
        roundtrip!(query, Query, &ctx);
    }

    #[test]
    fn test_command_complete() {
        let ctx = post_startup_ctx();

        let complete = CommandComplete::new("DELETE 5".to_owned());
        roundtrip!(complete, CommandComplete, &ctx);

        // A tag this long must be rejected by `encode`.
        let oversized = CommandComplete::new("DELETE 5".repeat(10_000));
        let mut sink = BytesMut::new();
        assert!(oversized.encode(&mut sink).is_err());
    }

    #[test]
    fn test_ready_for_query() {
        let ctx = post_startup_ctx();

        let idle = ReadyForQuery::new(TransactionStatus::Idle);
        roundtrip!(idle, ReadyForQuery, &ctx);
        let in_transaction = ReadyForQuery::new(TransactionStatus::Transaction);
        roundtrip!(in_transaction, ReadyForQuery, &ctx);
        let failed = ReadyForQuery::new(TransactionStatus::Error);
        roundtrip!(failed, ReadyForQuery, &ctx);
    }

    #[test]
    fn test_error_response() {
        let ctx = post_startup_ctx();

        let mut response = ErrorResponse::default();
        response.fields.push((b'R', "ERROR".to_owned()));
        response.fields.push((b'K', "cli".to_owned()));

        roundtrip!(response, ErrorResponse, &ctx);
    }

    #[test]
    fn test_notice_response() {
        let ctx = post_startup_ctx();

        let mut response = NoticeResponse::default();
        response.fields.push((b'R', "NOTICE".to_owned()));
        response.fields.push((b'K', "cli".to_owned()));

        roundtrip!(response, NoticeResponse, &ctx);
    }

    #[test]
    fn test_row_description() {
        let ctx = post_startup_ctx();

        let mut id_field = FieldDescription::default();
        id_field.name = "id".into();
        id_field.table_id = 1001;
        id_field.column_id = 10001;
        id_field.type_id = 1083;
        id_field.type_size = 4;
        id_field.type_modifier = -1;
        id_field.format_code = FORMAT_CODE_TEXT;

        let mut name_field = FieldDescription::default();
        name_field.name = "name".into();
        name_field.table_id = 1001;
        name_field.column_id = 10001;
        name_field.type_id = 1099;
        name_field.type_size = -1;
        name_field.type_modifier = -1;
        name_field.format_code = FORMAT_CODE_TEXT;

        let mut row_description = RowDescription::default();
        row_description.fields.push(id_field);
        row_description.fields.push(name_field);

        roundtrip!(row_description, RowDescription, &ctx);
    }

    #[test]
    fn test_data_row() {
        let ctx = post_startup_ctx();

        let mut row = DataRow::default();
        row.data.put_i32(4);
        row.data.put_slice(b"data");
        row.data.put_i32(4);
        row.data.put_i32(1001);
        row.data.put_i32(-1);

        roundtrip!(row, DataRow, &ctx);
    }

    #[test]
    fn test_terminate() {
        let ctx = post_startup_ctx();

        let terminate = Terminate::new();
        roundtrip!(terminate, Terminate, &ctx);
    }

    #[test]
    fn test_parse() {
        let ctx = post_startup_ctx();

        let statement_name = Some("find-user-by-id".to_owned());
        let sql = "SELECT * FROM user WHERE id = ?".to_owned();
        let parse = Parse::new(statement_name, sql, vec![1]);
        roundtrip!(parse, Parse, &ctx);
    }

    #[test]
    fn test_parse_65k() {
        let ctx = post_startup_ctx();

        let param_count = u16::MAX as usize;
        let parse = Parse::new(
            Some("many-params".to_owned()),
            "it won't be parsed anyway".to_owned(),
            vec![25; param_count],
        );
        roundtrip!(parse, Parse, &ctx);
    }

    #[test]
    fn test_parse_complete() {
        let ctx = post_startup_ctx();

        let parse_complete = ParseComplete::new();
        roundtrip!(parse_complete, ParseComplete, &ctx);
    }

    #[test]
    fn test_close() {
        let ctx = post_startup_ctx();

        let target_name = Some("find-user-by-id".to_owned());
        let close = Close::new(TARGET_TYPE_BYTE_STATEMENT, target_name);
        roundtrip!(close, Close, &ctx);
    }

    #[test]
    fn test_bind() {
        let ctx = post_startup_ctx();

        let bind = Bind::new(
            Some("find-user-by-id-0".to_owned()),
            Some("find-user-by-id".to_owned()),
            vec![0],
            vec![Some(Bytes::from_static(b"1234"))],
            vec![0],
        );
        roundtrip!(bind, Bind, &ctx);
    }

    #[test]
    fn test_bind_65k() {
        let ctx = post_startup_ctx();

        let param_count = u16::MAX as usize;
        let bind = Bind::new(
            Some("lol".to_owned()),
            Some("kek".to_owned()),
            vec![0; param_count],
            vec![Some(Bytes::from_static(b"1234")); param_count],
            vec![0],
        );
        roundtrip!(bind, Bind, &ctx);
    }

    #[test]
    fn test_execute() {
        let ctx = post_startup_ctx();

        let execute = Execute::new(Some("find-user-by-id-0".to_owned()), 100);
        roundtrip!(execute, Execute, &ctx);
    }

    #[test]
    fn test_sslrequest() {
        let ctx = initial_ctx();

        let request = SslRequest::new();
        roundtrip!(request, SslRequest, &ctx);
    }

    #[test]
    fn test_sslresponse() {
        let ctx = initial_ctx();

        let accept = SslResponse::Accept;
        roundtrip!(accept, SslResponse, &ctx);
        let refuse = SslResponse::Refuse;
        roundtrip!(refuse, SslResponse, &ctx);
    }

    #[test]
    fn test_gssencrequest() {
        let ctx = initial_ctx();

        let request = GssEncRequest::new();
        roundtrip!(request, GssEncRequest, &ctx);
    }

    #[test]
    fn test_gssencresponse() {
        let ctx = initial_ctx();

        let accept = GssEncResponse::Accept;
        roundtrip!(accept, GssEncResponse, &ctx);
        let refuse = GssEncResponse::Refuse;
        roundtrip!(refuse, GssEncResponse, &ctx);
    }

    #[test]
    fn test_saslresponse() {
        let ctx = post_startup_ctx();

        let initial =
            SASLInitialResponse::new("SCRAM-SHA-256".to_owned(), Some(Bytes::from_static(b"abc")));
        roundtrip!(initial, SASLInitialResponse, &ctx);

        let followup = SASLResponse::new(Bytes::from_static(b"abc"));
        roundtrip!(followup, SASLResponse, &ctx);
    }

    #[test]
    fn test_parameter_description() {
        let ctx = post_startup_ctx();

        let description = ParameterDescription::new(vec![100, 200]);
        roundtrip!(description, ParameterDescription, &ctx);
    }

    #[test]
    fn test_password_family() {
        let ctx = post_startup_ctx();

        // A plain password decodes through the family type.
        let password = Password::new("tomcat".to_owned());

        let mut wire = BytesMut::new();
        password.encode(&mut wire).unwrap();
        assert!(wire.remaining() > 0);

        let family = PasswordMessageFamily::decode(&mut wire, &ctx)
            .unwrap()
            .unwrap();
        assert_eq!(wire.remaining(), 0);
        assert_eq!(password, family.into_password().unwrap());

        // So does a SASL initial response.
        let initial =
            SASLInitialResponse::new("SCRAM-SHA-256".to_owned(), Some(Bytes::from_static(b"abc")));
        let mut wire = BytesMut::new();
        initial.encode(&mut wire).unwrap();
        assert!(wire.remaining() > 0);

        let family = PasswordMessageFamily::decode(&mut wire, &ctx)
            .unwrap()
            .unwrap();
        assert_eq!(wire.remaining(), 0);
        assert_eq!(initial, family.into_sasl_initial_response().unwrap());
    }

    #[test]
    fn test_no_data() {
        let ctx = post_startup_ctx();

        let no_data = NoData::new();
        roundtrip!(no_data, NoData, &ctx);
    }

    #[test]
    fn test_copy_data() {
        let ctx = post_startup_ctx();

        let copy_data = CopyData::new(Bytes::from_static(b"tomcat"));
        roundtrip!(copy_data, CopyData, &ctx);
    }

    #[test]
    fn test_copy_done() {
        let ctx = post_startup_ctx();

        let copy_done = CopyDone::new();
        roundtrip!(copy_done, CopyDone, &ctx);
    }

    #[test]
    fn test_copy_fail() {
        let ctx = post_startup_ctx();

        let copy_fail = CopyFail::new("copy failed".to_owned());
        roundtrip!(copy_fail, CopyFail, &ctx);
    }

    #[test]
    fn test_copy_response() {
        let ctx = post_startup_ctx();

        let copy_in = CopyInResponse::new(0, 3, vec![0, 0, 0]);
        roundtrip!(copy_in, CopyInResponse, &ctx);

        let copy_out = CopyOutResponse::new(0, 3, vec![0, 0, 0]);
        roundtrip!(copy_out, CopyOutResponse, &ctx);

        let copy_both = CopyBothResponse::new(0, 3, vec![0, 0, 0]);
        roundtrip!(copy_both, CopyBothResponse, &ctx);
    }

    #[test]
    fn test_notification_response() {
        let ctx = post_startup_ctx();

        let notification =
            NotificationResponse::new(10087, "channel".to_owned(), "payload".to_owned());
        roundtrip!(notification, NotificationResponse, &ctx);
    }

    #[test]
    fn test_negotiate_protocol_version() {
        let ctx = post_startup_ctx();

        let unsupported_options = vec!["database".to_owned(), "user".to_owned()];
        let negotiate = NegotiateProtocolVersion::new(2, unsupported_options);
        roundtrip!(negotiate, NegotiateProtocolVersion, &ctx);
    }
}