pgwire/messages/
mod.rs

//! The `messages` module contains the PostgreSQL wire protocol message
//! definitions and their codecs.
//!
//! [`PgWireFrontendMessage`] and [`PgWireBackendMessage`] are enums covering
//! every supported message type, and the [`Message`] trait encodes/decodes
//! individual messages to and from a `BytesMut` buffer.
7
8use bytes::{Buf, BufMut, BytesMut};
9
10use crate::error::{PgWireError, PgWireResult};
11
/// PostgreSQL frontend/backend protocol versions supported by this crate.
///
/// Defaults to [`ProtocolVersion::PROTOCOL3_2`], the newer of the two.
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
pub enum ProtocolVersion {
    /// Protocol version 3.0.
    PROTOCOL3_0,
    /// Protocol version 3.2.
    #[default]
    PROTOCOL3_2,
}

impl ProtocolVersion {
    /// Returns the `(major, minor)` version number of this protocol version.
    #[must_use]
    pub const fn version_number(&self) -> (u16, u16) {
        match self {
            Self::PROTOCOL3_0 => (3, 0),
            Self::PROTOCOL3_2 => (3, 2),
        }
    }

    /// Gets the [`ProtocolVersion`] for a `(major, minor)` version tuple.
    ///
    /// Returns `None` if the protocol version is not supported.
    #[must_use]
    pub const fn from_version_number(major: u16, minor: u16) -> Option<Self> {
        match (major, minor) {
            (3, 0) => Some(Self::PROTOCOL3_0),
            (3, 2) => Some(Self::PROTOCOL3_2),
            _ => None,
        }
    }
}
38
/// Length limit for message types that may carry large payloads.
///
/// `0x3fff_ffff` is 1 GiB - 1 (PostgreSQL's `MaxAllocSize`), so this limit is
/// 1 GiB - 2, mirroring the server's `PQ_LARGE_MESSAGE_LIMIT`
/// (`MaxAllocSize - 1`).
pub(crate) const LARGE_PACKET_SIZE_LIMIT: usize = 0x3fff_ffff - 1;

/// Default length limit, returned by `Message::max_message_length` unless a
/// message type overrides it (same value as the server's
/// `PQ_SMALL_MESSAGE_LIMIT`).
pub(crate) const SMALL_PACKET_SIZE_LIMIT: usize = 10_000;

/// libpq rejects backend messages longer than this, except for the message
/// types listed in its `VALID_LONG_MESSAGE_TYPE`.
pub(crate) const SMALL_BACKEND_PACKET_SIZE_LIMIT: usize = 30_000;

/// Length limit for backend messages libpq accepts as long messages: the
/// largest value the signed 32-bit length field can hold.
pub(crate) const LONG_BACKEND_PACKET_SIZE_LIMIT: usize = i32::MAX as usize;
46
/// Connection state consulted by the frontend/backend message decoders.
///
/// Note that the two constructors start in different phases:
/// `DecodeContext::new(protocol_version)` (generated by `derive(new)`) sets
/// both `awaiting_ssl` and `awaiting_startup` to `true`, whereas
/// `DecodeContext::default()` leaves both `false` and uses the default
/// [`ProtocolVersion`].
#[non_exhaustive]
#[derive(Default, Debug, PartialEq, Eq, new)]
pub struct DecodeContext {
    /// Protocol version made available to message decoders (it is passed
    /// through to `Message::decode_body`).
    pub protocol_version: ProtocolVersion,
    /// While `true`, `PgWireFrontendMessage::decode` expects an untyped
    /// pre-startup packet and checks for `CancelRequest`, `SslRequest` and
    /// `GssEncRequest` before falling back to `Startup`.
    #[new(value = "true")]
    pub awaiting_ssl: bool,
    /// While `true` (and `awaiting_ssl` is `false`),
    /// `PgWireFrontendMessage::decode` expects an untyped `CancelRequest` or
    /// `Startup` packet instead of a typed message.
    #[new(value = "true")]
    pub awaiting_startup: bool,
}
56
57/// Define how message encode and decoded.
58pub trait Message: Sized {
59    /// Return the type code of the message. In order to maintain backward
60    /// compatibility, `Startup` has no message type.
61    #[inline]
62    fn message_type() -> Option<u8> {
63        None
64    }
65
66    /// Return the length of the message, including the length integer itself.
67    fn message_length(&self) -> usize;
68
69    /// Return the max length of message in this type.
70    ///
71    /// This is to validate the length field in decode
72    #[inline]
73    fn max_message_length() -> usize {
74        SMALL_PACKET_SIZE_LIMIT
75    }
76
77    /// Encode body part of the message.
78    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()>;
79
80    /// Decode body part of the message.
81    fn decode_body(buf: &mut BytesMut, full_len: usize, _ctx: &DecodeContext)
82        -> PgWireResult<Self>;
83
84    /// Default implementation for encoding message.
85    ///
86    /// Message type and length are encoded in this implementation and it calls
87    /// `encode_body` for remaining parts.
88    fn encode(&self, buf: &mut BytesMut) -> PgWireResult<()> {
89        if let Some(mt) = Self::message_type() {
90            buf.put_u8(mt);
91        }
92
93        let len = self.message_length();
94        if len > Self::max_message_length() {
95            return Err(PgWireError::MessageTooLarge(
96                len,
97                Self::max_message_length(),
98            ));
99        }
100
101        buf.put_i32(len as i32);
102        self.encode_body(buf)
103    }
104
105    /// Default implementation for decoding message.
106    ///
107    /// Message type and length are decoded in this implementation and it calls
108    /// `decode_body` for remaining parts. Return `None` if the packet is not
109    /// complete for parsing.
110    fn decode(buf: &mut BytesMut, ctx: &DecodeContext) -> PgWireResult<Option<Self>> {
111        let offset = Self::message_type().is_some().into();
112
113        codec::decode_packet(buf, offset, Self::max_message_length(), |buf, full_len| {
114            Self::decode_body(buf, full_len, ctx)
115        })
116    }
117}
118
119/// Cancel message
120pub mod cancel;
121mod codec;
122/// Copy messages
123pub mod copy;
124/// Data related messages
125pub mod data;
126/// Extended query messages, including request/response for parse, bind and etc.
127pub mod extendedquery;
128/// General response messages
129pub mod response;
130/// Simple query messages, including descriptions
131pub mod simplequery;
132/// Startup messages
133pub mod startup;
134/// Termination messages
135pub mod terminate;
136
/// Messages sent from the frontend (client) to the backend (server).
///
/// Use [`PgWireFrontendMessage::decode`] to parse one from a buffer and
/// [`PgWireFrontendMessage::encode`] to serialize one.
#[derive(Debug)]
pub enum PgWireFrontendMessage {
    // Untyped pre-startup packets (decoded while `DecodeContext::awaiting_ssl`
    // or `awaiting_startup` is set).
    Startup(startup::Startup),
    CancelRequest(cancel::CancelRequest),
    SslRequest(startup::SslRequest),
    GssEncRequest(startup::GssEncRequest),
    /// `Password`, `SASLInitialResponse` or `SASLResponse`: they share one
    /// type byte, so they are decoded together as a single family.
    PasswordMessageFamily(startup::PasswordMessageFamily),

    // simple query
    Query(simplequery::Query),

    // extended query (see `is_extended_query`)
    Parse(extendedquery::Parse),
    Close(extendedquery::Close),
    Bind(extendedquery::Bind),
    Describe(extendedquery::Describe),
    Execute(extendedquery::Execute),
    Flush(extendedquery::Flush),
    Sync(extendedquery::Sync),

    // connection termination
    Terminate(terminate::Terminate),

    // copy
    CopyData(copy::CopyData),
    CopyFail(copy::CopyFail),
    CopyDone(copy::CopyDone),
}
162
163impl PgWireFrontendMessage {
164    pub fn is_extended_query(&self) -> bool {
165        matches!(
166            self,
167            Self::Parse(_)
168                | Self::Bind(_)
169                | Self::Close(_)
170                | Self::Describe(_)
171                | Self::Execute(_)
172                | Self::Flush(_)
173                | Self::Sync(_)
174        )
175    }
176
177    pub fn encode(&self, buf: &mut BytesMut) -> PgWireResult<()> {
178        match self {
179            Self::Startup(msg) => msg.encode(buf),
180            Self::CancelRequest(msg) => msg.encode(buf),
181            Self::SslRequest(msg) => msg.encode(buf),
182            Self::GssEncRequest(msg) => msg.encode(buf),
183
184            Self::PasswordMessageFamily(msg) => msg.encode(buf),
185
186            Self::Query(msg) => msg.encode(buf),
187
188            Self::Parse(msg) => msg.encode(buf),
189            Self::Bind(msg) => msg.encode(buf),
190            Self::Close(msg) => msg.encode(buf),
191            Self::Describe(msg) => msg.encode(buf),
192            Self::Execute(msg) => msg.encode(buf),
193            Self::Flush(msg) => msg.encode(buf),
194            Self::Sync(msg) => msg.encode(buf),
195
196            Self::Terminate(msg) => msg.encode(buf),
197
198            Self::CopyData(msg) => msg.encode(buf),
199            Self::CopyFail(msg) => msg.encode(buf),
200            Self::CopyDone(msg) => msg.encode(buf),
201        }
202    }
203
204    pub fn decode(buf: &mut BytesMut, ctx: &DecodeContext) -> PgWireResult<Option<Self>> {
205        if ctx.awaiting_ssl {
206            // Connection just estabilished, the incoming message can be:
207            // - SSLRequest
208            // - Startup
209            // - CancelRequest
210            // Try to read the magic number to tell whether it SSLRequest or
211            // Startup, all these messages should have at least 8 bytes
212            if buf.remaining() >= 8 {
213                if cancel::CancelRequest::is_cancel_request_packet(buf) {
214                    cancel::CancelRequest::decode(buf, ctx)
215                        .map(|opt| opt.map(PgWireFrontendMessage::CancelRequest))
216                } else if startup::SslRequest::is_ssl_request_packet(buf) {
217                    startup::SslRequest::decode(buf, ctx)
218                        .map(|opt| opt.map(PgWireFrontendMessage::SslRequest))
219                } else if startup::GssEncRequest::is_gss_enc_request_packet(buf) {
220                    startup::GssEncRequest::decode(buf, ctx)
221                        .map(|opt| opt.map(PgWireFrontendMessage::GssEncRequest))
222                } else {
223                    // startup
224                    startup::Startup::decode(buf, ctx).map(|v| v.map(Self::Startup))
225                }
226            } else {
227                Ok(None)
228            }
229        } else if ctx.awaiting_startup {
230            // we will check for cancel request again in case it's sent in ssl connection
231            if buf.remaining() >= 8 {
232                if cancel::CancelRequest::is_cancel_request_packet(buf) {
233                    cancel::CancelRequest::decode(buf, ctx)
234                        .map(|opt| opt.map(PgWireFrontendMessage::CancelRequest))
235                } else {
236                    startup::Startup::decode(buf, ctx).map(|v| v.map(Self::Startup))
237                }
238            } else {
239                Ok(None)
240            }
241        } else if buf.remaining() > 1 {
242            let first_byte = buf[0];
243
244            match first_byte {
245                // Password, SASLInitialResponse, SASLResponse can only be
246                // decoded under certain context
247                startup::MESSAGE_TYPE_BYTE_PASSWORD_MESSAGE_FAMILY => {
248                    startup::PasswordMessageFamily::decode(buf, ctx)
249                        .map(|v| v.map(Self::PasswordMessageFamily))
250                }
251
252                simplequery::MESSAGE_TYPE_BYTE_QUERY => {
253                    simplequery::Query::decode(buf, ctx).map(|v| v.map(Self::Query))
254                }
255
256                extendedquery::MESSAGE_TYPE_BYTE_PARSE => {
257                    extendedquery::Parse::decode(buf, ctx).map(|v| v.map(Self::Parse))
258                }
259                extendedquery::MESSAGE_TYPE_BYTE_BIND => {
260                    extendedquery::Bind::decode(buf, ctx).map(|v| v.map(Self::Bind))
261                }
262                extendedquery::MESSAGE_TYPE_BYTE_CLOSE => {
263                    extendedquery::Close::decode(buf, ctx).map(|v| v.map(Self::Close))
264                }
265                extendedquery::MESSAGE_TYPE_BYTE_DESCRIBE => {
266                    extendedquery::Describe::decode(buf, ctx).map(|v| v.map(Self::Describe))
267                }
268                extendedquery::MESSAGE_TYPE_BYTE_EXECUTE => {
269                    extendedquery::Execute::decode(buf, ctx).map(|v| v.map(Self::Execute))
270                }
271                extendedquery::MESSAGE_TYPE_BYTE_FLUSH => {
272                    extendedquery::Flush::decode(buf, ctx).map(|v| v.map(Self::Flush))
273                }
274                extendedquery::MESSAGE_TYPE_BYTE_SYNC => {
275                    extendedquery::Sync::decode(buf, ctx).map(|v| v.map(Self::Sync))
276                }
277
278                terminate::MESSAGE_TYPE_BYTE_TERMINATE => {
279                    terminate::Terminate::decode(buf, ctx).map(|v| v.map(Self::Terminate))
280                }
281
282                copy::MESSAGE_TYPE_BYTE_COPY_DATA => {
283                    copy::CopyData::decode(buf, ctx).map(|v| v.map(Self::CopyData))
284                }
285                copy::MESSAGE_TYPE_BYTE_COPY_FAIL => {
286                    copy::CopyFail::decode(buf, ctx).map(|v| v.map(Self::CopyFail))
287                }
288                copy::MESSAGE_TYPE_BYTE_COPY_DONE => {
289                    copy::CopyDone::decode(buf, ctx).map(|v| v.map(Self::CopyDone))
290                }
291                _ => Err(PgWireError::InvalidMessageType(first_byte)),
292            }
293        } else {
294            Ok(None)
295        }
296    }
297}
298
/// Messages sent from the backend (server) to the frontend (client).
///
/// Use [`PgWireBackendMessage::encode`] to serialize one and
/// [`PgWireBackendMessage::decode`] to parse a typed message from a buffer.
#[derive(Debug)]
pub enum PgWireBackendMessage {
    // startup: SSL/GSSAPI negotiation replies, authentication and session setup
    SslResponse(response::SslResponse),
    GssEncResponse(response::GssEncResponse),
    Authentication(startup::Authentication),
    ParameterStatus(startup::ParameterStatus),
    BackendKeyData(startup::BackendKeyData),
    NegotiateProtocolVersion(startup::NegotiateProtocolVersion),

    // extended query
    ParseComplete(extendedquery::ParseComplete),
    CloseComplete(extendedquery::CloseComplete),
    BindComplete(extendedquery::BindComplete),
    PortalSuspended(extendedquery::PortalSuspended),

    // command responses, notices and notifications
    CommandComplete(response::CommandComplete),
    EmptyQueryResponse(response::EmptyQueryResponse),
    ReadyForQuery(response::ReadyForQuery),
    ErrorResponse(response::ErrorResponse),
    NoticeResponse(response::NoticeResponse),
    NotificationResponse(response::NotificationResponse),

    // data: parameter/row descriptions and result rows
    ParameterDescription(data::ParameterDescription),
    RowDescription(data::RowDescription),
    DataRow(data::DataRow),
    NoData(data::NoData),

    // copy
    CopyData(copy::CopyData),
    CopyFail(copy::CopyFail),
    CopyDone(copy::CopyDone),
    CopyInResponse(copy::CopyInResponse),
    CopyOutResponse(copy::CopyOutResponse),
    CopyBothResponse(copy::CopyBothResponse),
}
338
339impl PgWireBackendMessage {
340    pub fn encode(&self, buf: &mut BytesMut) -> PgWireResult<()> {
341        match self {
342            Self::SslResponse(msg) => msg.encode(buf),
343            Self::GssEncResponse(msg) => msg.encode(buf),
344            Self::Authentication(msg) => msg.encode(buf),
345            Self::ParameterStatus(msg) => msg.encode(buf),
346            Self::BackendKeyData(msg) => msg.encode(buf),
347            Self::NegotiateProtocolVersion(msg) => msg.encode(buf),
348
349            Self::ParseComplete(msg) => msg.encode(buf),
350            Self::BindComplete(msg) => msg.encode(buf),
351            Self::CloseComplete(msg) => msg.encode(buf),
352            Self::PortalSuspended(msg) => msg.encode(buf),
353
354            Self::CommandComplete(msg) => msg.encode(buf),
355            Self::EmptyQueryResponse(msg) => msg.encode(buf),
356            Self::ReadyForQuery(msg) => msg.encode(buf),
357            Self::ErrorResponse(msg) => msg.encode(buf),
358            Self::NoticeResponse(msg) => msg.encode(buf),
359            Self::NotificationResponse(msg) => msg.encode(buf),
360
361            Self::ParameterDescription(msg) => msg.encode(buf),
362            Self::RowDescription(msg) => msg.encode(buf),
363            Self::DataRow(msg) => msg.encode(buf),
364            Self::NoData(msg) => msg.encode(buf),
365
366            Self::CopyData(msg) => msg.encode(buf),
367            Self::CopyFail(msg) => msg.encode(buf),
368            Self::CopyDone(msg) => msg.encode(buf),
369            Self::CopyInResponse(msg) => msg.encode(buf),
370            Self::CopyOutResponse(msg) => msg.encode(buf),
371            Self::CopyBothResponse(msg) => msg.encode(buf),
372        }
373    }
374
375    pub fn decode(buf: &mut BytesMut, ctx: &DecodeContext) -> PgWireResult<Option<Self>> {
376        if buf.remaining() > 1 {
377            let first_byte = buf[0];
378            match first_byte {
379                startup::MESSAGE_TYPE_BYTE_AUTHENTICATION => {
380                    startup::Authentication::decode(buf, ctx).map(|v| v.map(Self::Authentication))
381                }
382                startup::MESSAGE_TYPE_BYTE_PARAMETER_STATUS => {
383                    startup::ParameterStatus::decode(buf, ctx).map(|v| v.map(Self::ParameterStatus))
384                }
385                startup::MESSAGE_TYPE_BYTE_BACKEND_KEY_DATA => {
386                    startup::BackendKeyData::decode(buf, ctx).map(|v| v.map(Self::BackendKeyData))
387                }
388                startup::MESSAGE_TYPE_BYTE_NEGOTIATE_PROTOCOL_VERSION => {
389                    startup::NegotiateProtocolVersion::decode(buf, ctx)
390                        .map(|v| v.map(Self::NegotiateProtocolVersion))
391                }
392
393                extendedquery::MESSAGE_TYPE_BYTE_PARSE_COMPLETE => {
394                    extendedquery::ParseComplete::decode(buf, ctx)
395                        .map(|v| v.map(Self::ParseComplete))
396                }
397                extendedquery::MESSAGE_TYPE_BYTE_BIND_COMPLETE => {
398                    extendedquery::BindComplete::decode(buf, ctx).map(|v| v.map(Self::BindComplete))
399                }
400                extendedquery::MESSAGE_TYPE_BYTE_CLOSE_COMPLETE => {
401                    extendedquery::CloseComplete::decode(buf, ctx)
402                        .map(|v| v.map(Self::CloseComplete))
403                }
404                extendedquery::MESSAGE_TYPE_BYTE_PORTAL_SUSPENDED => {
405                    extendedquery::PortalSuspended::decode(buf, ctx)
406                        .map(|v| v.map(PgWireBackendMessage::PortalSuspended))
407                }
408
409                response::MESSAGE_TYPE_BYTE_COMMAND_COMPLETE => {
410                    response::CommandComplete::decode(buf, ctx)
411                        .map(|v| v.map(Self::CommandComplete))
412                }
413                response::MESSAGE_TYPE_BYTE_EMPTY_QUERY_RESPONSE => {
414                    response::EmptyQueryResponse::decode(buf, ctx)
415                        .map(|v| v.map(Self::EmptyQueryResponse))
416                }
417                response::MESSAGE_TYPE_BYTE_READY_FOR_QUERY => {
418                    response::ReadyForQuery::decode(buf, ctx).map(|v| v.map(Self::ReadyForQuery))
419                }
420                response::MESSAGE_TYPE_BYTE_ERROR_RESPONSE => {
421                    response::ErrorResponse::decode(buf, ctx).map(|v| v.map(Self::ErrorResponse))
422                }
423                response::MESSAGE_TYPE_BYTE_NOTICE_RESPONSE => {
424                    response::NoticeResponse::decode(buf, ctx).map(|v| v.map(Self::NoticeResponse))
425                }
426                response::MESSAGE_TYPE_BYTE_NOTIFICATION_RESPONSE => {
427                    response::NotificationResponse::decode(buf, ctx)
428                        .map(|v| v.map(Self::NotificationResponse))
429                }
430
431                data::MESSAGE_TYPE_BYTE_PARAMETER_DESCRITION => {
432                    data::ParameterDescription::decode(buf, ctx)
433                        .map(|v| v.map(Self::ParameterDescription))
434                }
435                data::MESSAGE_TYPE_BYTE_ROW_DESCRITION => {
436                    data::RowDescription::decode(buf, ctx).map(|v| v.map(Self::RowDescription))
437                }
438                data::MESSAGE_TYPE_BYTE_DATA_ROW => {
439                    data::DataRow::decode(buf, ctx).map(|v| v.map(Self::DataRow))
440                }
441                data::MESSAGE_TYPE_BYTE_NO_DATA => {
442                    data::NoData::decode(buf, ctx).map(|v| v.map(Self::NoData))
443                }
444
445                copy::MESSAGE_TYPE_BYTE_COPY_DATA => {
446                    copy::CopyData::decode(buf, ctx).map(|v| v.map(Self::CopyData))
447                }
448                copy::MESSAGE_TYPE_BYTE_COPY_FAIL => {
449                    copy::CopyFail::decode(buf, ctx).map(|v| v.map(Self::CopyFail))
450                }
451                copy::MESSAGE_TYPE_BYTE_COPY_DONE => {
452                    copy::CopyDone::decode(buf, ctx).map(|v| v.map(Self::CopyDone))
453                }
454                copy::MESSAGE_TYPE_BYTE_COPY_IN_RESPONSE => {
455                    copy::CopyInResponse::decode(buf, ctx).map(|v| v.map(Self::CopyInResponse))
456                }
457                copy::MESSAGE_TYPE_BYTE_COPY_OUT_RESPONSE => {
458                    copy::CopyOutResponse::decode(buf, ctx).map(|v| v.map(Self::CopyOutResponse))
459                }
460                copy::MESSAGE_TYPE_BYTE_COPY_BOTH_RESPONSE => {
461                    copy::CopyBothResponse::decode(buf, ctx).map(|v| v.map(Self::CopyBothResponse))
462                }
463                _ => Err(PgWireError::InvalidMessageType(first_byte)),
464            }
465        } else {
466            Ok(None)
467        }
468    }
469}
470
471#[cfg(test)]
472mod test {
473    use crate::messages::DecodeContext;
474
475    use super::cancel::CancelRequest;
476    use super::copy::*;
477    use super::data::*;
478    use super::extendedquery::*;
479    use super::response::*;
480    use super::simplequery::*;
481    use super::startup::*;
482    use super::terminate::*;
483    use super::{Message, ProtocolVersion};
484    use bytes::{Buf, BufMut, Bytes, BytesMut};
485
486    macro_rules! roundtrip {
487        ($ins:ident, $st:ty, $ctx:expr) => {
488            let mut buffer = BytesMut::new();
489            $ins.encode(&mut buffer).expect("encode packet");
490
491            assert!(buffer.remaining() > 0);
492
493            let item2 = <$st>::decode(&mut buffer, $ctx)
494                .expect("decode packet")
495                .expect("packet is none");
496
497            assert_eq!(buffer.remaining(), 0);
498            assert_eq!($ins, item2);
499        };
500    }
501
502    #[test]
503    fn test_startup() {
504        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
505
506        let mut s = Startup::default();
507        s.parameters.insert("user".to_owned(), "tomcat".to_owned());
508        roundtrip!(s, Startup, &ctx);
509
510        ctx.awaiting_ssl = false;
511        roundtrip!(s, Startup, &ctx);
512
513        let mut s_too_large = Startup::default();
514        s_too_large
515            .parameters
516            .insert("user".to_owned(), "a".repeat(10000));
517        let mut buffer = BytesMut::new();
518        assert!(s_too_large.encode(&mut buffer).is_err());
519    }
520
521    #[test]
522    fn test_cancel_request() {
523        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_2);
524
525        let s = CancelRequest::new(100, SecretKey::Bytes(Bytes::from("server2008")));
526        roundtrip!(s, CancelRequest, &ctx);
527
528        ctx.protocol_version = ProtocolVersion::PROTOCOL3_0;
529        let s = CancelRequest::new(100, SecretKey::I32(1900));
530        roundtrip!(s, CancelRequest, &ctx);
531    }
532
533    #[test]
534    fn test_authentication() {
535        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
536        ctx.awaiting_ssl = false;
537        ctx.awaiting_startup = false;
538
539        let ss = vec![
540            Authentication::Ok,
541            Authentication::CleartextPassword,
542            Authentication::KerberosV5,
543            Authentication::SASLContinue(Bytes::from("hello")),
544            Authentication::SASLFinal(Bytes::from("world")),
545        ];
546        for s in ss {
547            roundtrip!(s, Authentication, &ctx);
548        }
549
550        let md5pass = Authentication::MD5Password(vec![b'p', b's', b't', b'g']);
551        roundtrip!(md5pass, Authentication, &ctx);
552    }
553
554    #[test]
555    fn test_password() {
556        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
557        ctx.awaiting_ssl = false;
558        ctx.awaiting_startup = false;
559
560        let s = Password::new("pgwire".to_owned());
561        roundtrip!(s, Password, &ctx);
562    }
563
564    #[test]
565    fn test_parameter_status() {
566        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
567        ctx.awaiting_ssl = false;
568        ctx.awaiting_startup = false;
569
570        let pps = ParameterStatus::new("cli".to_owned(), "psql".to_owned());
571        roundtrip!(pps, ParameterStatus, &ctx);
572    }
573
574    #[test]
575    fn test_query() {
576        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
577        ctx.awaiting_ssl = false;
578        ctx.awaiting_startup = false;
579
580        let query = Query::new("SELECT 1".to_owned());
581        roundtrip!(query, Query, &ctx);
582    }
583
584    #[test]
585    fn test_command_complete() {
586        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
587        ctx.awaiting_ssl = false;
588        ctx.awaiting_startup = false;
589
590        let cc = CommandComplete::new("DELETE 5".to_owned());
591        roundtrip!(cc, CommandComplete, &ctx);
592
593        let cc = CommandComplete::new("DELETE 5".repeat(10_000));
594        let mut buffer = BytesMut::new();
595        assert!(cc.encode(&mut buffer).is_err());
596    }
597
598    #[test]
599    fn test_ready_for_query() {
600        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
601        ctx.awaiting_ssl = false;
602        ctx.awaiting_startup = false;
603
604        let r4q = ReadyForQuery::new(TransactionStatus::Idle);
605        roundtrip!(r4q, ReadyForQuery, &ctx);
606        let r4q = ReadyForQuery::new(TransactionStatus::Transaction);
607        roundtrip!(r4q, ReadyForQuery, &ctx);
608        let r4q = ReadyForQuery::new(TransactionStatus::Error);
609        roundtrip!(r4q, ReadyForQuery, &ctx);
610    }
611
612    #[test]
613    fn test_error_response() {
614        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
615        ctx.awaiting_ssl = false;
616        ctx.awaiting_startup = false;
617
618        let mut error = ErrorResponse::default();
619        error.fields.push((b'R', "ERROR".to_owned()));
620        error.fields.push((b'K', "cli".to_owned()));
621
622        roundtrip!(error, ErrorResponse, &ctx);
623    }
624
625    #[test]
626    fn test_notice_response() {
627        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
628        ctx.awaiting_ssl = false;
629        ctx.awaiting_startup = false;
630
631        let mut error = NoticeResponse::default();
632        error.fields.push((b'R', "NOTICE".to_owned()));
633        error.fields.push((b'K', "cli".to_owned()));
634
635        roundtrip!(error, NoticeResponse, &ctx);
636    }
637
638    #[test]
639    fn test_row_description() {
640        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
641        ctx.awaiting_ssl = false;
642        ctx.awaiting_startup = false;
643
644        let mut row_description = RowDescription::default();
645
646        let mut f1 = FieldDescription::default();
647        f1.name = "id".into();
648        f1.table_id = 1001;
649        f1.column_id = 10001;
650        f1.type_id = 1083;
651        f1.type_size = 4;
652        f1.type_modifier = -1;
653        f1.format_code = FORMAT_CODE_TEXT;
654        row_description.fields.push(f1);
655
656        let mut f2 = FieldDescription::default();
657        f2.name = "name".into();
658        f2.table_id = 1001;
659        f2.column_id = 10001;
660        f2.type_id = 1099;
661        f2.type_size = -1;
662        f2.type_modifier = -1;
663        f2.format_code = FORMAT_CODE_TEXT;
664        row_description.fields.push(f2);
665
666        roundtrip!(row_description, RowDescription, &ctx);
667    }
668
669    #[test]
670    fn test_data_row() {
671        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
672        ctx.awaiting_ssl = false;
673        ctx.awaiting_startup = false;
674
675        let mut row0 = DataRow::default();
676        row0.data.put_i32(4);
677        row0.data.put_slice("data".as_bytes());
678        row0.data.put_i32(4);
679        row0.data.put_i32(1001);
680        row0.data.put_i32(-1);
681
682        roundtrip!(row0, DataRow, &ctx);
683    }
684
685    #[test]
686    fn test_terminate() {
687        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
688        ctx.awaiting_ssl = false;
689        ctx.awaiting_startup = false;
690
691        let terminate = Terminate::new();
692        roundtrip!(terminate, Terminate, &ctx);
693    }
694
695    #[test]
696    fn test_parse() {
697        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
698        ctx.awaiting_ssl = false;
699        ctx.awaiting_startup = false;
700
701        let parse = Parse::new(
702            Some("find-user-by-id".to_owned()),
703            "SELECT * FROM user WHERE id = ?".to_owned(),
704            vec![1],
705        );
706        roundtrip!(parse, Parse, &ctx);
707    }
708
709    #[test]
710    fn test_parse_65k() {
711        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
712        ctx.awaiting_ssl = false;
713        ctx.awaiting_startup = false;
714
715        let parse = Parse::new(
716            Some("many-params".to_owned()),
717            "it won't be parsed anyway".to_owned(),
718            vec![25; u16::MAX as usize],
719        );
720        roundtrip!(parse, Parse, &ctx);
721    }
722
723    #[test]
724    fn test_parse_complete() {
725        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
726        ctx.awaiting_ssl = false;
727        ctx.awaiting_startup = false;
728
729        let parse_complete = ParseComplete::new();
730        roundtrip!(parse_complete, ParseComplete, &ctx);
731    }
732
733    #[test]
734    fn test_close() {
735        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
736        ctx.awaiting_ssl = false;
737        ctx.awaiting_startup = false;
738
739        let close = Close::new(
740            TARGET_TYPE_BYTE_STATEMENT,
741            Some("find-user-by-id".to_owned()),
742        );
743        roundtrip!(close, Close, &ctx);
744    }
745
746    #[test]
747    fn test_bind() {
748        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
749        ctx.awaiting_ssl = false;
750        ctx.awaiting_startup = false;
751        let bind = Bind::new(
752            Some("find-user-by-id-0".to_owned()),
753            Some("find-user-by-id".to_owned()),
754            vec![0],
755            vec![Some(Bytes::from_static(b"1234"))],
756            vec![0],
757        );
758        roundtrip!(bind, Bind, &ctx);
759    }
760
761    #[test]
762    fn test_bind_65k() {
763        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
764        ctx.awaiting_ssl = false;
765        ctx.awaiting_startup = false;
766
767        let bind = Bind::new(
768            Some("lol".to_owned()),
769            Some("kek".to_owned()),
770            vec![0; u16::MAX as usize],
771            vec![Some(Bytes::from_static(b"1234")); u16::MAX as usize],
772            vec![0],
773        );
774        roundtrip!(bind, Bind, &ctx);
775    }
776
777    #[test]
778    fn test_execute() {
779        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
780        ctx.awaiting_ssl = false;
781        ctx.awaiting_startup = false;
782
783        let exec = Execute::new(Some("find-user-by-id-0".to_owned()), 100);
784        roundtrip!(exec, Execute, &ctx);
785    }
786
787    #[test]
788    fn test_sslrequest() {
789        let ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
790
791        let sslreq = SslRequest::new();
792        roundtrip!(sslreq, SslRequest, &ctx);
793    }
794
795    #[test]
796    fn test_sslresponse() {
797        let ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
798
799        let sslaccept = SslResponse::Accept;
800        roundtrip!(sslaccept, SslResponse, &ctx);
801        let sslrefuse = SslResponse::Refuse;
802        roundtrip!(sslrefuse, SslResponse, &ctx);
803    }
804
805    #[test]
806    fn test_gssencrequest() {
807        let ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
808
809        let sslreq = GssEncRequest::new();
810        roundtrip!(sslreq, GssEncRequest, &ctx);
811    }
812
813    #[test]
814    fn test_gssencresponse() {
815        let ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
816
817        let gssenc_accept = GssEncResponse::Accept;
818        roundtrip!(gssenc_accept, GssEncResponse, &ctx);
819        let gssenc_refuse = GssEncResponse::Refuse;
820        roundtrip!(gssenc_refuse, GssEncResponse, &ctx);
821    }
822
823    #[test]
824    fn test_saslresponse() {
825        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
826        ctx.awaiting_ssl = false;
827        ctx.awaiting_startup = false;
828
829        let saslinitialresp =
830            SASLInitialResponse::new("SCRAM-SHA-256".to_owned(), Some(Bytes::from_static(b"abc")));
831        roundtrip!(saslinitialresp, SASLInitialResponse, &ctx);
832
833        let saslresp = SASLResponse::new(Bytes::from_static(b"abc"));
834        roundtrip!(saslresp, SASLResponse, &ctx);
835    }
836
837    #[test]
838    fn test_parameter_description() {
839        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
840        ctx.awaiting_ssl = false;
841        ctx.awaiting_startup = false;
842
843        let param_desc = ParameterDescription::new(vec![100, 200]);
844        roundtrip!(param_desc, ParameterDescription, &ctx);
845    }
846
847    #[test]
848    fn test_password_family() {
849        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
850        ctx.awaiting_ssl = false;
851        ctx.awaiting_startup = false;
852
853        let password = Password::new("tomcat".to_owned());
854
855        let mut buffer = BytesMut::new();
856        password.encode(&mut buffer).unwrap();
857        assert!(buffer.remaining() > 0);
858
859        let item2 = PasswordMessageFamily::decode(&mut buffer, &ctx)
860            .unwrap()
861            .unwrap();
862        assert_eq!(buffer.remaining(), 0);
863        assert_eq!(password, item2.into_password().unwrap());
864
865        let saslinitialresp =
866            SASLInitialResponse::new("SCRAM-SHA-256".to_owned(), Some(Bytes::from_static(b"abc")));
867        let mut buffer = BytesMut::new();
868        saslinitialresp.encode(&mut buffer).unwrap();
869        assert!(buffer.remaining() > 0);
870
871        let item2 = PasswordMessageFamily::decode(&mut buffer, &ctx)
872            .unwrap()
873            .unwrap();
874        assert_eq!(buffer.remaining(), 0);
875        assert_eq!(saslinitialresp, item2.into_sasl_initial_response().unwrap());
876    }
877
878    #[test]
879    fn test_no_data() {
880        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
881        ctx.awaiting_ssl = false;
882        ctx.awaiting_startup = false;
883
884        let nodata = NoData::new();
885        roundtrip!(nodata, NoData, &ctx);
886    }
887
888    #[test]
889    fn test_copy_data() {
890        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
891        ctx.awaiting_ssl = false;
892        ctx.awaiting_startup = false;
893
894        let copydata = CopyData::new(Bytes::from_static("tomcat".as_bytes()));
895        roundtrip!(copydata, CopyData, &ctx);
896    }
897
898    #[test]
899    fn test_copy_done() {
900        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
901        ctx.awaiting_ssl = false;
902        ctx.awaiting_startup = false;
903
904        let copydone = CopyDone::new();
905        roundtrip!(copydone, CopyDone, &ctx);
906    }
907
908    #[test]
909    fn test_copy_fail() {
910        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
911        ctx.awaiting_ssl = false;
912        ctx.awaiting_startup = false;
913
914        let copyfail = CopyFail::new("copy failed".to_owned());
915        roundtrip!(copyfail, CopyFail, &ctx);
916    }
917
918    #[test]
919    fn test_copy_response() {
920        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
921        ctx.awaiting_ssl = false;
922        ctx.awaiting_startup = false;
923
924        let copyresponse = CopyInResponse::new(0, 3, vec![0, 0, 0]);
925        roundtrip!(copyresponse, CopyInResponse, &ctx);
926
927        let copyresponse = CopyOutResponse::new(0, 3, vec![0, 0, 0]);
928        roundtrip!(copyresponse, CopyOutResponse, &ctx);
929
930        let copyresponse = CopyBothResponse::new(0, 3, vec![0, 0, 0]);
931        roundtrip!(copyresponse, CopyBothResponse, &ctx);
932    }
933
934    #[test]
935    fn test_notification_response() {
936        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
937        ctx.awaiting_ssl = false;
938        ctx.awaiting_startup = false;
939
940        let notification_response =
941            NotificationResponse::new(10087, "channel".to_owned(), "payload".to_owned());
942        roundtrip!(notification_response, NotificationResponse, &ctx);
943    }
944
945    #[test]
946    fn test_negotiate_protocol_version() {
947        let mut ctx = DecodeContext::new(ProtocolVersion::PROTOCOL3_0);
948        ctx.awaiting_ssl = false;
949        ctx.awaiting_startup = false;
950
951        let negotiate_protocol_version =
952            NegotiateProtocolVersion::new(2, vec!["database".to_owned(), "user".to_owned()]);
953        roundtrip!(negotiate_protocol_version, NegotiateProtocolVersion, &ctx);
954    }
955}