fern_protocol_postgresql/codec/
frontend.rs

1// SPDX-FileCopyrightText:  Copyright © 2022 The Fern Authors <team@fernproxy.io>
2// SPDX-License-Identifier: Apache-2.0
3
//! [`Decoder`]/[`Encoder`] trait implementations
//! for PostgreSQL frontend messages.
6//!
7//! [`Decoder`]: https://docs.rs/tokio-util/*/tokio_util/codec/trait.Decoder.html
8//! [`Encoder`]: https://docs.rs/tokio-util/*/tokio_util/codec/trait.Encoder.html
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11use std::io;
12use tokio_util::codec::{Decoder, Encoder};
13
14use super::{PostgresMessage, SQLMessage};
15use crate::codec::constants::*;
16use crate::codec::utils::*;
17
/// Size of a startup-phase header: 4-byte 'Message Length' followed by a
/// 4-byte code (protocol version or request code); there is no 'Message ID'.
const BYTES_STARTUP_MESSAGE_HEADER: usize = 8;
/// `SSLRequest` code, sent in place of a protocol version:
/// 1234 in the most significant 16 bits, 5679 in the least significant 16 bits.
const MESSAGE_ID_SSL_REQUEST: i32 = (1234 << 16) | 5679;
/// Protocol version 3.0: major version in the most significant 16 bits,
/// minor version in the least significant 16 bits.
const MESSAGE_ID_STARTUP_MESSAGE: i32 = 3 << 16;

// 'Message ID' (first frame byte) of regular, post-startup frontend messages.
// const MESSAGE_ID_BIND: u8 = b'B'; //TODO(ppiotr3k): write tests
/// `Execute` message id.
const MESSAGE_ID_EXECUTE: u8 = b'E';
/// `Flush` message id.
const MESSAGE_ID_FLUSH: u8 = b'H';
/// Simple `Query` message id.
const MESSAGE_ID_QUERY: u8 = b'Q';
/// Shared by `SASLInitialResponse` and `SASLResponse` (the protocol also
/// reuses it for `GSSResponse` and `PasswordMessage`).
const MESSAGE_ID_SASL: u8 = b'p';
/// `Sync` message id.
const MESSAGE_ID_SYNC: u8 = b'S';
/// `Terminate` message id.
const MESSAGE_ID_TERMINATE: u8 = b'X';

// TODO(ppiotr3k): implement following messages
// const MESSAGE_ID_CANCEL_REQUEST: u8 = b''; // ! no id; maybe MSB will do //TODO(ppiotr3k): write tests
// const MESSAGE_ID_CLOSE: u8 = b'C'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_COPY_DATA: u8 = b'd'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_COPY_DONE: u8 = b'c'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_COPY_FAIL: u8 = b'f'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_DESCRIBE: u8 = b'D'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_FUNCTION_CALL: u8 = b'F'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_GSSENC_REQUEST: u8 = b''; // ! no id //TODO(ppiotr3k): write tests
// const MESSAGE_ID_GSS_RESPONSE: u8 = b'p'; // ! shared id //TODO(ppiotr3k): write tests
// const MESSAGE_ID_PARSE: u8 = b'P'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_PASSWORD_MESSAGE: u8 = b'p'; // ! shared id //TODO(ppiotr3k): write tests
42
43///TODO(ppiotr3k): write description
44//TODO(ppiotr3k): investigate if `Clone` is avoidable; currently only used in tests
45#[derive(Clone, Debug, Eq, PartialEq)]
46pub enum Message {
47    NotImplemented(Bytes),
48
49    //#[cfg(test)] //TODO(ppiotr3k): fix enabling `Canary` only in tests
50    Canary(u8),
51
52    Bind {
53        portal: Bytes,
54        stmt_name: Bytes,
55        parameters: Vec<BindParameter>,
56        results_formats: Vec<u16>,
57    },
58    Execute {
59        portal: Bytes,
60        max_rows: u32,
61    },
62    Flush(),
63    Query(Bytes),
64    SASLInitialResponse {
65        mecanism: Bytes,
66        response: Bytes,
67    },
68    SASLResponse(Bytes),
69    SSLRequest(),
70    StartupMessage {
71        frame_length: usize,
72        parameters: Vec<Parameter>,
73    },
74    Sync(),
75    Terminate(),
76
77    //TODO(ppiotr3k): implement following messages
78    CancelRequest(Bytes),
79    Close(Bytes),
80    CopyData(Bytes),
81    CopyDone(Bytes),
82    CopyFail(Bytes),
83    Describe(Bytes),
84    FunctionCall(Bytes),
85    GSSENCRequest(Bytes),
86    GSSResponse(Bytes),
87    Parse(Bytes),
88    PasswordMessage(Bytes),
89}
90
91impl PostgresMessage for Message {}
92impl SQLMessage for Message {}
93
94///TODO(ppiotr3k): write description
95//TODO(ppiotr3k): internal fields encapsulation
96#[derive(Clone, Debug, Eq, PartialEq)]
97pub struct BindParameter {
98    pub format: u16,
99    pub value: Bytes,
100}
101
102///TODO(ppiotr3k): write description
103//TODO(ppiotr3k): internal fields encapsulation
104#[derive(Clone, Debug, Eq, PartialEq)]
105pub struct Parameter {
106    pub name: Bytes,
107    pub value: Bytes,
108}
109
/// Where the decoder currently is in the incoming byte stream.
#[derive(Debug, Clone)]
enum DecodeState {
    /// Awaiting an `SSLRequest` or a `StartupMessage` (frames without a
    /// 'Message ID' byte).
    Startup,
    /// Awaiting the header of a regular message.
    Head,
    /// Header already validated; awaiting the full frame, whose total length
    /// ('Message ID' included) is stored here.
    Message(usize),
}
117
118///TODO(ppiotr3k): write description
119#[derive(Debug, Clone)]
120pub struct Codec {
121    /// Read state management / optimization.
122    state: DecodeState,
123}
124
125impl Codec {
126    ///TODO(ppiotr3k): write function description
127    #[must_use]
128    pub const fn new() -> Self {
129        Self {
130            state: DecodeState::Startup,
131        }
132    }
133
134    /// Transitions decoder from `Startup` to next state.
135    pub fn startup_complete(&mut self) {
136        self.state = DecodeState::Head;
137    }
138
139    ///TODO(ppiotr3k): write function description
140    fn decode_header(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
141        if src.len() < BYTES_MESSAGE_HEADER {
142            // Incomplete header, await for more data.
143            log::trace!(
144                "not enough header data ({} bytes), awaiting more ({} bytes)",
145                src.len(),
146                BYTES_MESSAGE_HEADER,
147            );
148            return Ok(None);
149        }
150
151        // Peek into data with a `Cursor` to avoid advancing underlying buffer.
152        let mut buf = io::Cursor::new(&mut *src);
153        buf.advance(BYTES_MESSAGE_ID);
154
155        // 'Message Length' field accounts for self, but not 'Message ID' field.
156        // Note: `usize` prevents from 'Message Length' `i32` value overflow.
157        let frame_length = (buf.get_u32() as usize) + BYTES_MESSAGE_ID;
158
159        // Strict "less than", as null-payload messages exist in protocol.
160        if frame_length < BYTES_MESSAGE_HEADER {
161            log::trace!("invalid frame: {:?}", buf);
162            let err = io::Error::new(
163                io::ErrorKind::InvalidInput,
164                "malformed packet - invalid message length",
165            );
166            log::error!("{}", err);
167            return Err(err);
168        }
169
170        Ok(Some(frame_length))
171    }
172
173    ///TODO(ppiotr3k): write function description
174    fn decode_message(&mut self, len: usize, src: &mut BytesMut) -> io::Result<Option<Message>> {
175        if src.len() < len {
176            // Incomplete message, await for more data.
177            log::trace!(
178                "not enough message data ({} bytes), awaiting more ({} bytes)",
179                src.len(),
180                len
181            );
182            return Ok(None);
183        }
184
185        // Full message, pop it out.
186        let mut frame = src.split_to(len);
187        //TODO(ppiotr3k): consider zero-cost `frame.freeze()` for lazy passing in `Pipe`
188
189        // Frames have at least `BYTES_MESSAGE_HEADER` bytes at this point.
190        let msg_id = frame.get_u8();
191        log::trace!("incoming msg id: '{}' ({})", msg_id as char, msg_id);
192        let msg_length = (frame.get_u32() as usize) - BYTES_MESSAGE_SIZE;
193        log::trace!("incoming msg length: {}", msg_length);
194
195        let msg = match msg_id {
196
197            // Canary
198            //#[cfg(test)] //TODO(ppiotr3k): fix enabling `Canary` only in tests
199            b'B' /* 0x42 */ => {
200                frame.advance(msg_length);
201                Message::Canary(len as u8)
202            },
203            //#[cfg(test)] //TODO(ppiotr3k): fix enabling `Canary` only in tests
204            b'!' /* 0x21 */ => {
205                return Err(io::Error::new(io::ErrorKind::InvalidData, "expected canary error"));
206            },
207
208            // Frontend
209            // MESSAGE_ID_BIND => {
210            //     //TODO(ppiotr3k): implement this message
211            // },
212            MESSAGE_ID_EXECUTE => {
213                let portal = get_cstr(&mut frame)?;
214                let max_rows = get_u32(&mut frame, "malformed packet - invalid execute data")?;
215                Message::Execute { portal, max_rows }
216            },
217            MESSAGE_ID_FLUSH => Message::Flush(),
218            MESSAGE_ID_QUERY => {
219                let query = get_cstr(&mut frame)?;
220                Message::Query(query)
221            },
222            MESSAGE_ID_SASL => {
223                // `SASLInitialResponse` holds a C-style null char terminated string,
224                // while `SASlResponse` holds bytes with no 0 byte at all in them.
225                // Therefore trying first to look for a `SASLInitialResponse`.
226                //TODO(ppiotr3k): rethink, as `get_cstr` writes errors to logs
227                // -> peeking at last frame byte and looking for a 0 maybe?
228                if let Ok(mecanism) = get_cstr(&mut frame) {
229                    const SASL_RESPONSE_SIZE_BYTES: usize = 4;
230                    let response = get_bytes(
231                        &mut frame,
232                        SASL_RESPONSE_SIZE_BYTES,
233                        "malformed packet - invalid SASL response data",
234                    )?;
235
236                    Message::SASLInitialResponse { mecanism, response }
237                } else {
238                    let response = frame.copy_to_bytes(frame.remaining());
239
240                    // SASLResponse `response` field cannot be empty.
241                    if response.is_empty() {
242                        let err = std::io::Error::new(
243                            std::io::ErrorKind::InvalidInput,
244                            "malformed packet - invalid SASL response data",
245                        );
246                        log::error!("{}", err);
247                        return Err(err);
248                    }
249
250                    Message::SASLResponse(response)
251                }
252            },
253            MESSAGE_ID_SYNC => Message::Sync(),
254            MESSAGE_ID_TERMINATE => Message::Terminate(),
255            _ => {
256                let bytes = frame.copy_to_bytes(msg_length);
257                Message::NotImplemented(bytes)
258            },
259        };
260
261        // At this point, all data should have been consumed from `frame`.
262        if !frame.is_empty() {
263            log::trace!("invalid frame: {:?}", frame);
264            let err = std::io::Error::new(
265                std::io::ErrorKind::InvalidInput,
266                "malformed packet - invalid message length",
267            );
268            log::error!("{}", err);
269            return Err(err);
270        }
271
272        log::debug!("decoded message frame: {:?}", msg);
273        Ok(Some(msg))
274    }
275
276    ///TODO(ppiotr3k): write function description
277    pub fn decode_startup_message(&mut self, src: &mut BytesMut) -> io::Result<Option<Message>> {
278        if src.len() < BYTES_STARTUP_MESSAGE_HEADER {
279            // Incomplete message, await for more data.
280            log::trace!(
281                "not enough header data ({} bytes), awaiting more ({} bytes)",
282                src.len(),
283                BYTES_STARTUP_MESSAGE_HEADER,
284            );
285            return Ok(None);
286        }
287
288        // Peek into data with a `Cursor` to avoid advancing underlying buffer.
289        let mut buf = io::Cursor::new(&mut *src);
290
291        // Note: `usize` prevents from 'Message Length' `i32` value overflow.
292        let frame_length = buf.get_u32() as usize;
293        if src.len() < frame_length {
294            // Incomplete message, await for more data.
295            log::trace!(
296                "not enough message data ({} bytes), awaiting more ({} bytes)",
297                src.len(),
298                frame_length,
299            );
300            return Ok(None);
301        }
302
303        // Full message, pop it out.
304        let mut frame = src.split_to(frame_length);
305        log::trace!("decoded frame length: {}", frame_length);
306        //TODO(ppiotr3k): consider zero-cost `frame.freeze()` for lazy passing in `Pipe`
307
308        frame.advance(4); // `Message Length`
309
310        let msg_id = frame.get_i32();
311        log::trace!("msg id: {}", msg_id);
312        let msg = match msg_id {
313            MESSAGE_ID_STARTUP_MESSAGE => {
314                let mut parameters = Vec::new();
315                let mut user_param_exists = false;
316
317                // At least one parameter and name/value pair terminator are expected.
318                while frame.remaining() > 2 {
319                    let parameter_name = get_cstr(&mut frame)?;
320
321                    // Note: `user` is the sole required parameter, others are optional.
322                    if parameter_name == "user" {
323                        user_param_exists = true;
324                    }
325
326                    let parameter = Parameter {
327                        name: parameter_name,
328                        value: get_cstr(&mut frame)?,
329                    };
330                    log::trace!("decoded parameter: {:?}", parameter);
331                    parameters.push(parameter);
332                }
333
334                // At this point, only name/value pair terminator should remain,
335                // and a parameter named `user` should have been found.
336                if frame.remaining() < 1 || !user_param_exists {
337                    let err = std::io::Error::new(
338                        std::io::ErrorKind::InvalidInput,
339                        "malformed packet - missing parameter fields",
340                    );
341                    log::error!("{}", err);
342                    return Err(err);
343                }
344                frame.advance(1); // name/value pair terminator
345
346                Message::StartupMessage {
347                    frame_length,
348                    parameters,
349                }
350            }
351            MESSAGE_ID_SSL_REQUEST => Message::SSLRequest(),
352            _ => {
353                // If neither a recognized `StartupMessage` nor `SSLRequest`,
354                // consider as `StartupMessage` with unsupported protocol version.
355                let err = std::io::Error::new(
356                    std::io::ErrorKind::InvalidInput,
357                    "malformed packet - invalid protocol version",
358                );
359                log::error!("{}", err);
360                return Err(err);
361            }
362        };
363        log::debug!("decoded message frame: {:?}", msg);
364        Ok(Some(msg))
365    }
366
367    ///TODO(ppiotr3k): write function description
368    //TODO(ppiotr3k): get size from Message struct
369    // -> pre-requisite: enum variants are considered as types in Rust
370    fn encode_header(&mut self, msg_id: u8, msg_size: usize, dst: &mut BytesMut) {
371        dst.reserve(BYTES_MESSAGE_HEADER + msg_size);
372        dst.put_u8(msg_id);
373        dst.put_u32((BYTES_MESSAGE_SIZE + msg_size) as u32);
374    }
375}
376
377impl Decoder for Codec {
378    type Item = Message;
379    type Error = io::Error;
380
381    fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Self::Item>> {
382        log::trace!("decoder state: {:?}", self.state);
383        let msg_length = match self.state {
384            // During startup sequence, frontend can send an `SSLRequest` message rather
385            // than a `StartupMessage`. `Startup` state handles this initial edge case.
386            // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.12
387            DecodeState::Startup => match self.decode_startup_message(src)? {
388                None => return Ok(None),
389                Some(Message::SSLRequest()) => return Ok(Some(Message::SSLRequest())),
390                Some(Message::StartupMessage {
391                    frame_length,
392                    parameters,
393                }) => {
394                    self.startup_complete();
395                    return Ok(Some(Message::StartupMessage {
396                        frame_length,
397                        parameters,
398                    }));
399                }
400                Some(other) => {
401                    let err = io::Error::new(
402                        io::ErrorKind::InvalidData,
403                        //TODO(ppiotr3k): rewrite without debug symbols
404                        format!("unexpected message during startup: {:?}", other),
405                    );
406                    log::error!("{}", err);
407                    return Err(err);
408                }
409            },
410
411            DecodeState::Head => match self.decode_header(src)? {
412                // Incomplete header, await for more data.
413                None => return Ok(None),
414                // Header available, try getting full message.
415                Some(length) => {
416                    self.state = DecodeState::Message(length);
417
418                    // Ensure enough space is available to read incoming payload.
419                    // Note: acceptable over-allocation by content of `BYTES_MESSAGE_SIZE`.
420                    src.reserve(length);
421                    log::trace!("stream buffer capacity: {} bytes", src.capacity());
422
423                    length
424                }
425            },
426
427            DecodeState::Message(length) => length,
428        };
429        log::trace!("decoded frame length: {} bytes", msg_length);
430
431        match self.decode_message(msg_length, src)? {
432            // Incomplete message, await for more data.
433            None => Ok(None),
434            // Full message, pop it out, move on to parsing a new one.
435            Some(msg) => {
436                self.state = DecodeState::Head;
437
438                // Ensure enough space is available to read next header.
439                src.reserve(BYTES_MESSAGE_HEADER);
440                log::trace!("stream buffer capacity: {} bytes", src.capacity());
441
442                Ok(Some(msg))
443            }
444        }
445    }
446}
447
448impl Encoder<Message> for Codec {
449    type Error = io::Error;
450
451    fn encode(&mut self, msg: Message, dst: &mut BytesMut) -> Result<(), io::Error> {
452        //TODO(ppiotr3k): rationalize capacity reservation with `dst.reserve(msg.len())`
453        // -> pre-requisite: enum variants are considered as types in Rust
454        match msg {
455            Message::Execute { portal, max_rows } => {
456                self.encode_header(MESSAGE_ID_EXECUTE, portal.len() + 1 + 4, dst);
457                put_cstr(&portal, dst);
458                dst.put_i32(max_rows as i32);
459            }
460            Message::Flush() => {
461                self.encode_header(MESSAGE_ID_FLUSH, 0, dst);
462            }
463            Message::Query(query) => {
464                self.encode_header(MESSAGE_ID_QUERY, query.len() + 1, dst);
465                put_cstr(&query, dst);
466            }
467            Message::SASLInitialResponse { mecanism, response } => {
468                self.encode_header(
469                    MESSAGE_ID_SASL,
470                    mecanism.len() + 1 + 4 + response.len(),
471                    dst,
472                );
473                put_cstr(&mecanism, dst);
474                put_bytes(&response, dst);
475            }
476            Message::SASLResponse(response) => {
477                self.encode_header(MESSAGE_ID_SASL, response.len(), dst);
478                dst.put(response);
479            }
480            Message::StartupMessage {
481                frame_length,
482                parameters,
483            } => {
484                dst.reserve(frame_length);
485                dst.put_i32(frame_length as i32);
486                dst.put_i32(196608);
487                for parameter in &parameters {
488                    put_cstr(&parameter.name, dst);
489                    put_cstr(&parameter.value, dst);
490                }
491                dst.put_u8(0); // name/value pair terminator
492            }
493            Message::SSLRequest() => {
494                dst.reserve(8);
495                dst.put_i32(8);
496                dst.put_i32(80877103);
497            }
498            Message::Sync() => {
499                self.encode_header(MESSAGE_ID_SYNC, 0, dst);
500            }
501            Message::Terminate() => {
502                self.encode_header(MESSAGE_ID_TERMINATE, 0, dst);
503            }
504            other => {
505                unimplemented!("not implemented: {:?}", other)
506            }
507        }
508
509        // Message has been written to `Sink`, nothing left to do.
510        // Note: if bytes remain in frame, encoding tests need a review.
511        Ok(())
512    }
513}
514
515impl Default for Codec {
516    fn default() -> Self {
517        Self::new()
518    }
519}
520
#[cfg(test)]
mod decode_tests {

    use bytes::{Bytes, BytesMut};
    use std::io;
    use test_log::test;

    use super::{Codec, Message, Parameter};

    /// Feeds `data` to a fresh codec's startup decoder until it stops producing
    /// messages; returns the decoded messages, the error that stopped decoding
    /// (if any), and the number of bytes left in the read buffer.
    fn decode_startup_messages(data: &[u8]) -> (Vec<Message>, Option<io::Error>, usize) {
        let buf = &mut BytesMut::from(data);
        let mut decoded = Vec::new();

        let mut codec = Codec::new();
        let error = loop {
            match codec.decode_startup_message(buf) {
                Ok(Some(msg)) => decoded.push(msg),
                Ok(None) => break None,
                Err(err) => break Some(err),
            }
        };

        (decoded, error, buf.len())
    }

    /// Helper function to ease writing decoding tests for startup sequence:
    /// decoding must succeed, yield `expected` and leave `remaining` bytes.
    fn assert_decode_startup_message(data: &[u8], expected: &[Message], remaining: usize) {
        let (decoded, error, left) = decode_startup_messages(data);

        assert!(error.is_none(), "unexpected decoding error: {:?}", error);
        assert_eq!(remaining, left, "remaining bytes in read buffer");
        assert_eq!(expected.len(), decoded.len(), "decoded messages");
        assert_eq!(expected, &decoded[..], "decoded messages");
    }

    /// Asserts that decoding `data` fails with an `InvalidInput` error before
    /// producing any message, leaving `remaining` bytes in the read buffer.
    fn assert_decode_startup_message_error(data: &[u8], remaining: usize) {
        let (decoded, error, left) = decode_startup_messages(data);

        assert!(decoded.is_empty(), "unexpected decoded messages: {:?}", decoded);
        let err = error.expect("expected a decoding error");
        assert_eq!(io::ErrorKind::InvalidInput, err.kind(), "decoding error kind");
        assert_eq!(remaining, left, "remaining bytes in read buffer");
    }

    #[test]
    #[rustfmt::skip]
    fn valid_startup_message() {
        let data = [
            0, 0, 0, 78,                                                                  // total length: 78
            0, 3, 0, 0,                                                                   // protocol version: 3.0
            117, 115, 101, 114, 0,                                                        // cstr: "user\0"
            114, 111, 111, 116, 0,                                                        // cstr: "root\0"
            100, 97, 116, 97, 98, 97, 115, 101, 0,                                        // cstr: "database\0"
            116, 101, 115, 116, 100, 98, 0,                                               // cstr: "testdb\0"
            97, 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, 95, 110, 97, 109, 101, 0, // cstr: "application_name\0"
            112, 115, 113, 108, 0,                                                        // cstr: "psql\0"
            99, 108, 105, 101, 110, 116, 95, 101, 110, 99, 111, 100, 105, 110, 103, 0,    // cstr: "client_encoding\0"
            85, 84, 70, 56, 0,                                                            // cstr: "UTF8\0"
            0,                                                                            // name/value pair terminator
        ];

        let expected = vec![
            Message::StartupMessage {
                frame_length: 78,
                parameters: vec![
                    Parameter {
                        name: Bytes::from_static(b"user"),
                        value: Bytes::from_static(b"root"),
                    },
                    Parameter {
                        name: Bytes::from_static(b"database"),
                        value: Bytes::from_static(b"testdb"),
                    },
                    Parameter {
                        name: Bytes::from_static(b"application_name"),
                        value: Bytes::from_static(b"psql"),
                    },
                    Parameter {
                        name: Bytes::from_static(b"client_encoding"),
                        value: Bytes::from_static(b"UTF8"),
                    },
                ]},
        ];
        let remaining = 0;

        assert_decode_startup_message(&data[..], &expected, remaining);
    }

    #[test]
    #[rustfmt::skip]
    fn valid_ssl_request() {
        let data = [
            0, 0, 0, 8,     // total length: 8
            4, 210, 22, 47, // SSL request code: 80877103
        ];

        let expected = vec![Message::SSLRequest()];
        let remaining = 0;

        assert_decode_startup_message(&data[..], &expected, remaining);
    }

    #[test]
    #[rustfmt::skip]
    fn incomplete_startup_message() {
        let data = [
            0, 0, 0, 78,        // total length: 78
            0, 3, 0, 0,         // protocol version: 3.0
            117, 115, 101, 114, // partial cstr: "user"
        ];

        let expected: Vec<Message> = vec![];
        let remaining = 12;

        assert_decode_startup_message(&data[..], &expected, remaining);
    }

    #[test]
    #[rustfmt::skip]
    fn invalid_startup_message_wrong_protocol_version() {
        let data = [
            0, 0, 0, 78,                                                                  // total length: 78
            0, 2, 0, 0,                                                                   // wrong protocol version: 2.0
            117, 115, 101, 114, 0,                                                        // cstr: "user\0"
            114, 111, 111, 116, 0,                                                        // cstr: "root\0"
            100, 97, 116, 97, 98, 97, 115, 101, 0,                                        // cstr: "database\0"
            116, 101, 115, 116, 100, 98, 0,                                               // cstr: "testdb\0"
            97, 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, 95, 110, 97, 109, 101, 0, // cstr: "application_name\0"
            112, 115, 113, 108, 0,                                                        // cstr: "psql\0"
            99, 108, 105, 101, 110, 116, 95, 101, 110, 99, 111, 100, 105, 110, 103, 0,    // cstr: "client_encoding\0"
            85, 84, 70, 56, 0,                                                            // cstr: "UTF8\0"
            0,                                                                            // name/value pair terminator
        ];

        assert_decode_startup_message_error(&data[..], 0);
    }

    #[test]
    #[rustfmt::skip]
    fn invalid_startup_message_missing_required_user() {
        let data = [
            0, 0, 0, 68,                                                                  // total length: 68
            0, 3, 0, 0,                                                                   // protocol version: 3.0
            100, 97, 116, 97, 98, 97, 115, 101, 0,                                        // cstr: "database\0"
            116, 101, 115, 116, 100, 98, 0,                                               // cstr: "testdb\0"
            97, 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, 95, 110, 97, 109, 101, 0, // cstr: "application_name\0"
            112, 115, 113, 108, 0,                                                        // cstr: "psql\0"
            99, 108, 105, 101, 110, 116, 95, 101, 110, 99, 111, 100, 105, 110, 103, 0,    // cstr: "client_encoding\0"
            85, 84, 70, 56, 0,                                                            // cstr: "UTF8\0"
            0,                                                                            // name/value pair terminator
        ];

        assert_decode_startup_message_error(&data[..], 0);
    }

    #[test]
    #[rustfmt::skip]
    fn invalid_startup_message_empty_parameters_list() {
        let data = [
            0, 0, 0, 9, // total length: 9
            0, 3, 0, 0, // protocol version: 3.0
            0,          // name/value pair terminator
        ];

        assert_decode_startup_message_error(&data[..], 0);
    }

    #[test]
    #[rustfmt::skip]
    fn invalid_startup_message_missing_parameters_data() {
        let data = [
            0, 0, 0, 8, // total length: 8
            0, 3, 0, 0, // protocol version: 3.0
                        // missing parameters data
        ];

        assert_decode_startup_message_error(&data[..], 0);
    }

    #[test]
    #[rustfmt::skip]
    fn invalid_startup_message_missing_parameters_list_terminator() {
        let data = [
            0, 0, 0, 77,                                                                  // total length: 77
            0, 3, 0, 0,                                                                   // protocol version: 3.0
            117, 115, 101, 114, 0,                                                        // cstr: "user\0"
            114, 111, 111, 116, 0,                                                        // cstr: "root\0"
            100, 97, 116, 97, 98, 97, 115, 101, 0,                                        // cstr: "database\0"
            116, 101, 115, 116, 100, 98, 0,                                               // cstr: "testdb\0"
            97, 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, 95, 110, 97, 109, 101, 0, // cstr: "application_name\0"
            112, 115, 113, 108, 0,                                                        // cstr: "psql\0"
            99, 108, 105, 101, 110, 116, 95, 101, 110, 99, 111, 100, 105, 110, 103, 0,    // cstr: "client_encoding\0"
            85, 84, 70, 56, 0,                                                            // cstr: "UTF8\0"
                                                                                          // missing name/value pair terminator
        ];

        assert_decode_startup_message_error(&data[..], 0);
    }

    #[test]
    #[rustfmt::skip]
    fn invalid_startup_message_missing_parameter_field() {
        let data = [
            0, 0, 0, 28,                           // total length: 28
            0, 3, 0, 0,                            // protocol version: 3.0
            117, 115, 101, 114, 0,                 // cstr: "user\0"
            114, 111, 111, 116, 0,                 // cstr: "root\0"
            100, 97, 116, 97, 98, 97, 115, 101, 0, // cstr: "database\0"
            0,                                     // missing value field || missing name/value pair terminator
        ];

        assert_decode_startup_message_error(&data[..], 0);
    }
}