fern_protocol_postgresql/codec/frontend.rs
1// SPDX-FileCopyrightText: Copyright © 2022 The Fern Authors <team@fernproxy.io>
2// SPDX-License-Identifier: Apache-2.0
3
//! [`Decoder`]/[`Encoder`] trait implementations
//! for PostgreSQL frontend messages.
6//!
7//! [`Decoder`]: https://docs.rs/tokio-util/*/tokio_util/codec/trait.Decoder.html
8//! [`Encoder`]: https://docs.rs/tokio-util/*/tokio_util/codec/trait.Encoder.html
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11use std::io;
12use tokio_util::codec::{Decoder, Encoder};
13
14use super::{PostgresMessage, SQLMessage};
15use crate::codec::constants::*;
16use crate::codec::utils::*;
17
/// Fixed prefix of `StartupMessage`/`SSLRequest` frames:
/// a 4-byte 'Message Length' followed by a 4-byte protocol/request code.
const BYTES_STARTUP_MESSAGE_HEADER: usize = 4 + 4;

/// `SSLRequest` code, encoded as protocol "version" 1234.5679.
const MESSAGE_ID_SSL_REQUEST: i32 = (1234 << 16) | 5679;

/// `StartupMessage` code for protocol version 3.0 (major in the high 16 bits).
const MESSAGE_ID_STARTUP_MESSAGE: i32 = 3 << 16;

// Frontend 'Message ID' bytes handled by this codec.
// const MESSAGE_ID_BIND: u8 = b'B'; //TODO(ppiotr3k): write tests
/// `Execute` message id.
const MESSAGE_ID_EXECUTE: u8 = b'E';
/// `Flush` message id.
const MESSAGE_ID_FLUSH: u8 = b'H';
/// Simple `Query` message id.
const MESSAGE_ID_QUERY: u8 = b'Q';
/// Shared id of `SASLInitialResponse` and `SASLResponse`.
const MESSAGE_ID_SASL: u8 = b'p';
/// `Sync` message id.
const MESSAGE_ID_SYNC: u8 = b'S';
/// `Terminate` message id.
const MESSAGE_ID_TERMINATE: u8 = b'X';

// TODO(ppiotr3k): implement following messages
// const MESSAGE_ID_CANCEL_REQUEST: u8 = b''; // ! no id; maybe MSB will do //TODO(ppiotr3k): write tests
// const MESSAGE_ID_CLOSE: u8 = b'C'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_COPY_DATA: u8 = b'd'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_COPY_DONE: u8 = b'c'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_COPY_FAIL: u8 = b'f'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_DESCRIBE: u8 = b'D'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_FUNCTION_CALL: u8 = b'F'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_GSSENC_REQUEST: u8 = b''; // ! no id //TODO(ppiotr3k): write tests
// const MESSAGE_ID_GSS_RESPONSE: u8 = b'p'; // ! shared id //TODO(ppiotr3k): write tests
// const MESSAGE_ID_PARSE: u8 = b'P'; //TODO(ppiotr3k): write tests
// const MESSAGE_ID_PASSWORD_MESSAGE: u8 = b'p'; // ! shared id //TODO(ppiotr3k): write tests
42
/// A PostgreSQL frontend message, as produced by this file's decoder
/// and consumed by its encoder.
///
/// The bare-`Bytes` variants listed under the "implement following messages"
/// TODO, as well as `Bind`, are placeholders: this file's decoder never builds
/// them and its encoder does not support them.
//TODO(ppiotr3k): investigate if `Clone` is avoidable; currently only used in tests
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Message {
    /// Payload (bytes after the 'Message ID' and 'Message Length' fields) of a
    /// message whose id the decoder does not recognize.
    NotImplemented(Bytes),

    /// Test sentinel decoded from id `b'B'`; carries the decoded frame length
    /// truncated to `u8`.
    //#[cfg(test)] //TODO(ppiotr3k): fix enabling `Canary` only in tests
    Canary(u8),

    /// `Bind` message (not yet decoded nor encoded by this file).
    Bind {
        portal: Bytes,
        stmt_name: Bytes,
        parameters: Vec<BindParameter>,
        results_formats: Vec<u16>,
    },
    /// `Execute` message.
    Execute {
        /// Portal name, without its C-string terminator.
        portal: Bytes,
        /// Row limit, as read from the 4-byte field following the portal name.
        max_rows: u32,
    },
    /// `Flush` message (no payload).
    Flush(),
    /// Simple `Query` message; the query text without its C-string terminator.
    Query(Bytes),
    /// `SASLInitialResponse` message.
    SASLInitialResponse {
        /// SASL mechanism name, without its C-string terminator.
        /// (Field name spelling kept as-is: it is part of the public API.)
        mecanism: Bytes,
        /// Initial client response data.
        response: Bytes,
    },
    /// `SASLResponse` message; raw, non-empty SASL data.
    SASLResponse(Bytes),
    /// `SSLRequest` startup-phase message (no payload).
    SSLRequest(),
    /// `StartupMessage` for protocol version 3.0.
    StartupMessage {
        /// Value of the leading 'Message Length' field, i.e. the whole frame
        /// size in bytes, that length field included.
        frame_length: usize,
        /// Name/value parameters, in wire order.
        parameters: Vec<Parameter>,
    },
    /// `Sync` message (no payload).
    Sync(),
    /// `Terminate` message (no payload).
    Terminate(),

    //TODO(ppiotr3k): implement following messages
    CancelRequest(Bytes),
    Close(Bytes),
    CopyData(Bytes),
    CopyDone(Bytes),
    CopyFail(Bytes),
    Describe(Bytes),
    FunctionCall(Bytes),
    GSSENCRequest(Bytes),
    GSSResponse(Bytes),
    Parse(Bytes),
    PasswordMessage(Bytes),
}

// Marker trait implementations, defined in the parent `codec` module.
impl PostgresMessage for Message {}
impl SQLMessage for Message {}
93
/// A single parameter value carried by a [`Message::Bind`].
//TODO(ppiotr3k): internal fields encapsulation
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BindParameter {
    /// Format code of `value` (presumably 0 = text, 1 = binary as in the
    /// PostgreSQL protocol — not interpreted anywhere in this file; confirm).
    pub format: u16,
    /// Raw parameter value bytes.
    pub value: Bytes,
}
101
/// One name/value pair from a [`Message::StartupMessage`] parameter list.
///
/// Both fields hold the C-string content without its 0 terminator; the
/// encoder appends it back when writing the message.
//TODO(ppiotr3k): internal fields encapsulation
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Parameter {
    /// Parameter name, e.g. `user` (the only one the decoder requires).
    pub name: Bytes,
    /// Parameter value.
    pub value: Bytes,
}
108}
109
/// Position of the decoder within the incoming byte stream.
#[derive(Debug, Clone)]
enum DecodeState {
    /// Expecting a `StartupMessage` or `SSLRequest`, which have no 'Message ID' byte.
    Startup,
    /// Expecting the header ('Message ID' + 'Message Length') of a regular message.
    Head,
    /// Header already parsed; waiting for a full frame of this many bytes,
    /// 'Message ID' byte included.
    Message(usize),
}
116}
117
/// [`Decoder`]/[`Encoder`] for PostgreSQL frontend messages.
///
/// Decoding is stateful: a new codec expects the startup sequence, and
/// switches to regular message framing once a `StartupMessage` is decoded
/// (or `startup_complete` is called).
#[derive(Debug, Clone)]
pub struct Codec {
    /// Read state management / optimization.
    state: DecodeState,
}
123}
124
125impl Codec {
126 ///TODO(ppiotr3k): write function description
127 #[must_use]
128 pub const fn new() -> Self {
129 Self {
130 state: DecodeState::Startup,
131 }
132 }
133
134 /// Transitions decoder from `Startup` to next state.
135 pub fn startup_complete(&mut self) {
136 self.state = DecodeState::Head;
137 }
138
139 ///TODO(ppiotr3k): write function description
140 fn decode_header(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
141 if src.len() < BYTES_MESSAGE_HEADER {
142 // Incomplete header, await for more data.
143 log::trace!(
144 "not enough header data ({} bytes), awaiting more ({} bytes)",
145 src.len(),
146 BYTES_MESSAGE_HEADER,
147 );
148 return Ok(None);
149 }
150
151 // Peek into data with a `Cursor` to avoid advancing underlying buffer.
152 let mut buf = io::Cursor::new(&mut *src);
153 buf.advance(BYTES_MESSAGE_ID);
154
155 // 'Message Length' field accounts for self, but not 'Message ID' field.
156 // Note: `usize` prevents from 'Message Length' `i32` value overflow.
157 let frame_length = (buf.get_u32() as usize) + BYTES_MESSAGE_ID;
158
159 // Strict "less than", as null-payload messages exist in protocol.
160 if frame_length < BYTES_MESSAGE_HEADER {
161 log::trace!("invalid frame: {:?}", buf);
162 let err = io::Error::new(
163 io::ErrorKind::InvalidInput,
164 "malformed packet - invalid message length",
165 );
166 log::error!("{}", err);
167 return Err(err);
168 }
169
170 Ok(Some(frame_length))
171 }
172
173 ///TODO(ppiotr3k): write function description
174 fn decode_message(&mut self, len: usize, src: &mut BytesMut) -> io::Result<Option<Message>> {
175 if src.len() < len {
176 // Incomplete message, await for more data.
177 log::trace!(
178 "not enough message data ({} bytes), awaiting more ({} bytes)",
179 src.len(),
180 len
181 );
182 return Ok(None);
183 }
184
185 // Full message, pop it out.
186 let mut frame = src.split_to(len);
187 //TODO(ppiotr3k): consider zero-cost `frame.freeze()` for lazy passing in `Pipe`
188
189 // Frames have at least `BYTES_MESSAGE_HEADER` bytes at this point.
190 let msg_id = frame.get_u8();
191 log::trace!("incoming msg id: '{}' ({})", msg_id as char, msg_id);
192 let msg_length = (frame.get_u32() as usize) - BYTES_MESSAGE_SIZE;
193 log::trace!("incoming msg length: {}", msg_length);
194
195 let msg = match msg_id {
196
197 // Canary
198 //#[cfg(test)] //TODO(ppiotr3k): fix enabling `Canary` only in tests
199 b'B' /* 0x42 */ => {
200 frame.advance(msg_length);
201 Message::Canary(len as u8)
202 },
203 //#[cfg(test)] //TODO(ppiotr3k): fix enabling `Canary` only in tests
204 b'!' /* 0x21 */ => {
205 return Err(io::Error::new(io::ErrorKind::InvalidData, "expected canary error"));
206 },
207
208 // Frontend
209 // MESSAGE_ID_BIND => {
210 // //TODO(ppiotr3k): implement this message
211 // },
212 MESSAGE_ID_EXECUTE => {
213 let portal = get_cstr(&mut frame)?;
214 let max_rows = get_u32(&mut frame, "malformed packet - invalid execute data")?;
215 Message::Execute { portal, max_rows }
216 },
217 MESSAGE_ID_FLUSH => Message::Flush(),
218 MESSAGE_ID_QUERY => {
219 let query = get_cstr(&mut frame)?;
220 Message::Query(query)
221 },
222 MESSAGE_ID_SASL => {
223 // `SASLInitialResponse` holds a C-style null char terminated string,
224 // while `SASlResponse` holds bytes with no 0 byte at all in them.
225 // Therefore trying first to look for a `SASLInitialResponse`.
226 //TODO(ppiotr3k): rethink, as `get_cstr` writes errors to logs
227 // -> peeking at last frame byte and looking for a 0 maybe?
228 if let Ok(mecanism) = get_cstr(&mut frame) {
229 const SASL_RESPONSE_SIZE_BYTES: usize = 4;
230 let response = get_bytes(
231 &mut frame,
232 SASL_RESPONSE_SIZE_BYTES,
233 "malformed packet - invalid SASL response data",
234 )?;
235
236 Message::SASLInitialResponse { mecanism, response }
237 } else {
238 let response = frame.copy_to_bytes(frame.remaining());
239
240 // SASLResponse `response` field cannot be empty.
241 if response.is_empty() {
242 let err = std::io::Error::new(
243 std::io::ErrorKind::InvalidInput,
244 "malformed packet - invalid SASL response data",
245 );
246 log::error!("{}", err);
247 return Err(err);
248 }
249
250 Message::SASLResponse(response)
251 }
252 },
253 MESSAGE_ID_SYNC => Message::Sync(),
254 MESSAGE_ID_TERMINATE => Message::Terminate(),
255 _ => {
256 let bytes = frame.copy_to_bytes(msg_length);
257 Message::NotImplemented(bytes)
258 },
259 };
260
261 // At this point, all data should have been consumed from `frame`.
262 if !frame.is_empty() {
263 log::trace!("invalid frame: {:?}", frame);
264 let err = std::io::Error::new(
265 std::io::ErrorKind::InvalidInput,
266 "malformed packet - invalid message length",
267 );
268 log::error!("{}", err);
269 return Err(err);
270 }
271
272 log::debug!("decoded message frame: {:?}", msg);
273 Ok(Some(msg))
274 }
275
276 ///TODO(ppiotr3k): write function description
277 pub fn decode_startup_message(&mut self, src: &mut BytesMut) -> io::Result<Option<Message>> {
278 if src.len() < BYTES_STARTUP_MESSAGE_HEADER {
279 // Incomplete message, await for more data.
280 log::trace!(
281 "not enough header data ({} bytes), awaiting more ({} bytes)",
282 src.len(),
283 BYTES_STARTUP_MESSAGE_HEADER,
284 );
285 return Ok(None);
286 }
287
288 // Peek into data with a `Cursor` to avoid advancing underlying buffer.
289 let mut buf = io::Cursor::new(&mut *src);
290
291 // Note: `usize` prevents from 'Message Length' `i32` value overflow.
292 let frame_length = buf.get_u32() as usize;
293 if src.len() < frame_length {
294 // Incomplete message, await for more data.
295 log::trace!(
296 "not enough message data ({} bytes), awaiting more ({} bytes)",
297 src.len(),
298 frame_length,
299 );
300 return Ok(None);
301 }
302
303 // Full message, pop it out.
304 let mut frame = src.split_to(frame_length);
305 log::trace!("decoded frame length: {}", frame_length);
306 //TODO(ppiotr3k): consider zero-cost `frame.freeze()` for lazy passing in `Pipe`
307
308 frame.advance(4); // `Message Length`
309
310 let msg_id = frame.get_i32();
311 log::trace!("msg id: {}", msg_id);
312 let msg = match msg_id {
313 MESSAGE_ID_STARTUP_MESSAGE => {
314 let mut parameters = Vec::new();
315 let mut user_param_exists = false;
316
317 // At least one parameter and name/value pair terminator are expected.
318 while frame.remaining() > 2 {
319 let parameter_name = get_cstr(&mut frame)?;
320
321 // Note: `user` is the sole required parameter, others are optional.
322 if parameter_name == "user" {
323 user_param_exists = true;
324 }
325
326 let parameter = Parameter {
327 name: parameter_name,
328 value: get_cstr(&mut frame)?,
329 };
330 log::trace!("decoded parameter: {:?}", parameter);
331 parameters.push(parameter);
332 }
333
334 // At this point, only name/value pair terminator should remain,
335 // and a parameter named `user` should have been found.
336 if frame.remaining() < 1 || !user_param_exists {
337 let err = std::io::Error::new(
338 std::io::ErrorKind::InvalidInput,
339 "malformed packet - missing parameter fields",
340 );
341 log::error!("{}", err);
342 return Err(err);
343 }
344 frame.advance(1); // name/value pair terminator
345
346 Message::StartupMessage {
347 frame_length,
348 parameters,
349 }
350 }
351 MESSAGE_ID_SSL_REQUEST => Message::SSLRequest(),
352 _ => {
353 // If neither a recognized `StartupMessage` nor `SSLRequest`,
354 // consider as `StartupMessage` with unsupported protocol version.
355 let err = std::io::Error::new(
356 std::io::ErrorKind::InvalidInput,
357 "malformed packet - invalid protocol version",
358 );
359 log::error!("{}", err);
360 return Err(err);
361 }
362 };
363 log::debug!("decoded message frame: {:?}", msg);
364 Ok(Some(msg))
365 }
366
367 ///TODO(ppiotr3k): write function description
368 //TODO(ppiotr3k): get size from Message struct
369 // -> pre-requisite: enum variants are considered as types in Rust
370 fn encode_header(&mut self, msg_id: u8, msg_size: usize, dst: &mut BytesMut) {
371 dst.reserve(BYTES_MESSAGE_HEADER + msg_size);
372 dst.put_u8(msg_id);
373 dst.put_u32((BYTES_MESSAGE_SIZE + msg_size) as u32);
374 }
375}
376
377impl Decoder for Codec {
378 type Item = Message;
379 type Error = io::Error;
380
381 fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Self::Item>> {
382 log::trace!("decoder state: {:?}", self.state);
383 let msg_length = match self.state {
384 // During startup sequence, frontend can send an `SSLRequest` message rather
385 // than a `StartupMessage`. `Startup` state handles this initial edge case.
386 // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.12
387 DecodeState::Startup => match self.decode_startup_message(src)? {
388 None => return Ok(None),
389 Some(Message::SSLRequest()) => return Ok(Some(Message::SSLRequest())),
390 Some(Message::StartupMessage {
391 frame_length,
392 parameters,
393 }) => {
394 self.startup_complete();
395 return Ok(Some(Message::StartupMessage {
396 frame_length,
397 parameters,
398 }));
399 }
400 Some(other) => {
401 let err = io::Error::new(
402 io::ErrorKind::InvalidData,
403 //TODO(ppiotr3k): rewrite without debug symbols
404 format!("unexpected message during startup: {:?}", other),
405 );
406 log::error!("{}", err);
407 return Err(err);
408 }
409 },
410
411 DecodeState::Head => match self.decode_header(src)? {
412 // Incomplete header, await for more data.
413 None => return Ok(None),
414 // Header available, try getting full message.
415 Some(length) => {
416 self.state = DecodeState::Message(length);
417
418 // Ensure enough space is available to read incoming payload.
419 // Note: acceptable over-allocation by content of `BYTES_MESSAGE_SIZE`.
420 src.reserve(length);
421 log::trace!("stream buffer capacity: {} bytes", src.capacity());
422
423 length
424 }
425 },
426
427 DecodeState::Message(length) => length,
428 };
429 log::trace!("decoded frame length: {} bytes", msg_length);
430
431 match self.decode_message(msg_length, src)? {
432 // Incomplete message, await for more data.
433 None => Ok(None),
434 // Full message, pop it out, move on to parsing a new one.
435 Some(msg) => {
436 self.state = DecodeState::Head;
437
438 // Ensure enough space is available to read next header.
439 src.reserve(BYTES_MESSAGE_HEADER);
440 log::trace!("stream buffer capacity: {} bytes", src.capacity());
441
442 Ok(Some(msg))
443 }
444 }
445 }
446}
447
448impl Encoder<Message> for Codec {
449 type Error = io::Error;
450
451 fn encode(&mut self, msg: Message, dst: &mut BytesMut) -> Result<(), io::Error> {
452 //TODO(ppiotr3k): rationalize capacity reservation with `dst.reserve(msg.len())`
453 // -> pre-requisite: enum variants are considered as types in Rust
454 match msg {
455 Message::Execute { portal, max_rows } => {
456 self.encode_header(MESSAGE_ID_EXECUTE, portal.len() + 1 + 4, dst);
457 put_cstr(&portal, dst);
458 dst.put_i32(max_rows as i32);
459 }
460 Message::Flush() => {
461 self.encode_header(MESSAGE_ID_FLUSH, 0, dst);
462 }
463 Message::Query(query) => {
464 self.encode_header(MESSAGE_ID_QUERY, query.len() + 1, dst);
465 put_cstr(&query, dst);
466 }
467 Message::SASLInitialResponse { mecanism, response } => {
468 self.encode_header(
469 MESSAGE_ID_SASL,
470 mecanism.len() + 1 + 4 + response.len(),
471 dst,
472 );
473 put_cstr(&mecanism, dst);
474 put_bytes(&response, dst);
475 }
476 Message::SASLResponse(response) => {
477 self.encode_header(MESSAGE_ID_SASL, response.len(), dst);
478 dst.put(response);
479 }
480 Message::StartupMessage {
481 frame_length,
482 parameters,
483 } => {
484 dst.reserve(frame_length);
485 dst.put_i32(frame_length as i32);
486 dst.put_i32(196608);
487 for parameter in ¶meters {
488 put_cstr(¶meter.name, dst);
489 put_cstr(¶meter.value, dst);
490 }
491 dst.put_u8(0); // name/value pair terminator
492 }
493 Message::SSLRequest() => {
494 dst.reserve(8);
495 dst.put_i32(8);
496 dst.put_i32(80877103);
497 }
498 Message::Sync() => {
499 self.encode_header(MESSAGE_ID_SYNC, 0, dst);
500 }
501 Message::Terminate() => {
502 self.encode_header(MESSAGE_ID_TERMINATE, 0, dst);
503 }
504 other => {
505 unimplemented!("not implemented: {:?}", other)
506 }
507 }
508
509 // Message has been written to `Sink`, nothing left to do.
510 // Note: if bytes remain in frame, encoding tests need a review.
511 Ok(())
512 }
513}
514
515impl Default for Codec {
516 fn default() -> Self {
517 Self::new()
518 }
519}
520
#[cfg(test)]
mod decode_tests {

    use bytes::{Bytes, BytesMut};
    use test_log::test;

    use super::{Codec, Message, Parameter};

    /// Joins wire-format fragments into one contiguous input buffer.
    fn frame(parts: &[&[u8]]) -> Vec<u8> {
        parts.concat()
    }

    /// Builds a startup [`Parameter`] from static name/value strings.
    fn param(name: &'static str, value: &'static str) -> Parameter {
        Parameter {
            name: Bytes::from_static(name.as_bytes()),
            value: Bytes::from_static(value.as_bytes()),
        }
    }

    /// Helper function to ease writing decoding tests for startup sequence.
    ///
    /// Decodes `data` with a fresh [`Codec`] until the decoder errors or asks
    /// for more data, then checks the decoded messages and the number of bytes
    /// left unconsumed in the read buffer.
    fn assert_decode_startup_message(data: &[u8], expected: &[Message], remaining: usize) {
        let mut codec = Codec::new();
        let mut buf = BytesMut::from(data);

        let decoded: Vec<Message> =
            std::iter::from_fn(|| codec.decode_startup_message(&mut buf).ok().flatten())
                .collect();

        assert_eq!(remaining, buf.len(), "remaining bytes in read buffer");
        assert_eq!(expected.len(), decoded.len(), "decoded messages");
        assert_eq!(expected, decoded.as_slice(), "decoded messages");
    }

    #[test]
    #[rustfmt::skip]
    fn valid_startup_message() {
        let data = frame(&[
            &[0, 0, 0, 78],         // total length: 78
            &[0, 3, 0, 0],          // protocol version: 3.0
            b"user\0",
            b"root\0",
            b"database\0",
            b"testdb\0",
            b"application_name\0",
            b"psql\0",
            b"client_encoding\0",
            b"UTF8\0",
            &[0],                   // name/value pair terminator
        ]);

        let expected = [Message::StartupMessage {
            frame_length: 78,
            parameters: vec![
                param("user", "root"),
                param("database", "testdb"),
                param("application_name", "psql"),
                param("client_encoding", "UTF8"),
            ],
        }];

        assert_decode_startup_message(&data, &expected, 0);
    }

    #[test]
    #[rustfmt::skip]
    fn invalid_startup_message_wrong_protocol_version() {
        let data = frame(&[
            &[0, 0, 0, 78],         // total length: 78
            &[0, 2, 0, 0],          // wrong protocol version: 2.0
            b"user\0",
            b"root\0",
            b"database\0",
            b"testdb\0",
            b"application_name\0",
            b"psql\0",
            b"client_encoding\0",
            b"UTF8\0",
            &[0],                   // name/value pair terminator
        ]);

        assert_decode_startup_message(&data, &[], 0);
    }

    #[test]
    #[rustfmt::skip]
    fn invalid_startup_message_missing_required_user() {
        let data = frame(&[
            &[0, 0, 0, 68],         // total length: 68
            &[0, 3, 0, 0],          // protocol version: 3.0
            b"database\0",
            b"testdb\0",
            b"application_name\0",
            b"psql\0",
            b"client_encoding\0",
            b"UTF8\0",
            &[0],                   // name/value pair terminator
        ]);

        assert_decode_startup_message(&data, &[], 0);
    }

    #[test]
    #[rustfmt::skip]
    fn invalid_startup_message_empty_parameters_list() {
        let data = frame(&[
            &[0, 0, 0, 9],          // total length: 9
            &[0, 3, 0, 0],          // protocol version: 3.0
            &[0],                   // name/value pair terminator
        ]);

        assert_decode_startup_message(&data, &[], 0);
    }

    #[test]
    #[rustfmt::skip]
    fn invalid_startup_message_missing_parameters_data() {
        let data = frame(&[
            &[0, 0, 0, 8],          // total length: 8
            &[0, 3, 0, 0],          // protocol version: 3.0
            // missing parameters data
        ]);

        assert_decode_startup_message(&data, &[], 0);
    }

    #[test]
    #[rustfmt::skip]
    fn invalid_startup_message_missing_parameters_list_terminator() {
        let data = frame(&[
            &[0, 0, 0, 77],         // total length: 77
            &[0, 3, 0, 0],          // protocol version: 3.0
            b"user\0",
            b"root\0",
            b"database\0",
            b"testdb\0",
            b"application_name\0",
            b"psql\0",
            b"client_encoding\0",
            b"UTF8\0",
            // missing name/value pair terminator
        ]);

        assert_decode_startup_message(&data, &[], 0);
    }

    #[test]
    #[rustfmt::skip]
    fn invalid_startup_message_missing_parameter_field() {
        let data = frame(&[
            &[0, 0, 0, 28],         // total length: 28
            &[0, 3, 0, 0],          // protocol version: 3.0
            b"user\0",
            b"root\0",
            b"database\0",
            &[0],                   // missing value field || missing name/value pair terminator
        ]);

        assert_decode_startup_message(&data, &[], 0);
    }
}
702}