fern_protocol_postgresql/codec/
backend.rs

1// SPDX-FileCopyrightText:  Copyright © 2022 The Fern Authors <team@fernproxy.io>
2// SPDX-License-Identifier: Apache-2.0
3
//! Implementations of the [`Decoder`]/[`Encoder`] traits
//! for PostgreSQL backend messages.
6//!
7//! [`Decoder`]: https://docs.rs/tokio-util/*/tokio_util/codec/trait.Decoder.html
8//! [`Encoder`]: https://docs.rs/tokio-util/*/tokio_util/codec/trait.Encoder.html
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11use std::io;
12use tokio_util::codec::{Decoder, Encoder};
13
14use super::{PostgresMessage, SQLMessage};
15use crate::codec::constants::*;
16use crate::codec::utils::*;
17
// Identifiers (first byte of a frame) of the backend messages this codec handles.

/// `'R'`: authentication family; a 32-bit code after the header selects the variant.
const MESSAGE_ID_AUTHENTICATION: u8 = b'R';
/// `'K'`: backend key data.
const MESSAGE_ID_BACKEND_KEY_DATA: u8 = b'K';
/// `'C'`: command complete.
const MESSAGE_ID_COMMAND_COMPLETE: u8 = b'C';
/// `'D'`: data row.
const MESSAGE_ID_DATA_ROW: u8 = b'D';
/// `'I'`: empty query response.
const MESSAGE_ID_EMPTY_QUERY_RESPONSE: u8 = b'I';
/// `'E'`: error response.
//TODO(ppiotr3k): write tests
const MESSAGE_ID_ERROR_RESPONSE: u8 = b'E';
/// `'S'`: parameter status.
const MESSAGE_ID_PARAMETER_STATUS: u8 = b'S';
/// `'Z'`: ready for query.
const MESSAGE_ID_READY_FOR_QUERY: u8 = b'Z';
/// `'T'`: row description.
const MESSAGE_ID_ROW_DESCRIPTION: u8 = b'T';
27
28//TODO(ppiotr3k): implement following messages
29// const MESSAGE_ID_AUTHENTICATION_KERBEROS_V5: u8 = b'R'; // 2
30// const MESSAGE_ID_AUTHENTICATION_CLEARTEXT_PASSWORD: u8 = b'R'; // 3
31// const MESSAGE_ID_AUTHENTICATION_MD5_PASSWORD: u8 = b'R'; // 5
32// const MESSAGE_ID_AUTHENTICATION_SCM_CREDENTIAL: u8 = b'R'; // 6
33// const MESSAGE_ID_AUTHENTICATION_GSS: u8 = b'R'; // 7
34// const MESSAGE_ID_AUTHENTICATION_GSS_CONTINUE: u8 = b'R'; // 8
35// const MESSAGE_ID_AUTHENTICATION_SSPI: u8 = b'R'; // 9
36// const MESSAGE_ID_BIND_COMPLETE: u8 = b'2';
37// const MESSAGE_ID_CLOSE_COMPLETE: u8 = b'3';
38// const MESSAGE_ID_COPY_DATA: u8 = b'd';
39// const MESSAGE_ID_COPY_DONE: u8 = b'c';
40// const MESSAGE_ID_COPY_IN_RESPONSE: u8 = b'G';
41// const MESSAGE_ID_COPY_OUT_RESPONSE: u8 = b'H';
42// const MESSAGE_ID_COPY_BOTH_RESPONSE: u8 = b'W';
43// const MESSAGE_ID_FUNCTION_CALL_RESPONSE: u8 = b'V';
44// const MESSAGE_ID_NEGOTIATE_PROTOCOL_VERSION: u8 = b'v';
45// const MESSAGE_ID_NO_DATA: u8 = b'n';
46// const MESSAGE_ID_NOTICE_RESPONSE: u8 = b'N';
47// const MESSAGE_ID_NOTIFICATION_RESPONSE: u8 = b'A';
48// const MESSAGE_ID_PARAMETER_DESCRIPTION: u8 = b'B';
49// const MESSAGE_ID_PARSE_COMPLETE: u8 = b'1';
50// const MESSAGE_ID_PORTAL_SUSPENDED: u8 = b's';
51
52///TODO(ppiotr3k): write description
53//TODO(ppiotr3k): investigate if `Clone` is avoidable; currently only used in tests
54#[derive(Clone, Debug, Eq, PartialEq)]
55pub enum Message {
56    NotImplemented(Bytes),
57
58    //#[cfg(test)] //TODO(ppiotr3k): fix enabling `Canary` only in tests
59    Canary(u8),
60
61    AuthenticationOk(),
62    AuthenticationSASL(Bytes),
63    AuthenticationSASLContinue(Bytes),
64    AuthenticationSASLFinal(Bytes),
65    CommandComplete(Bytes),
66    BackendKeyData { process: u32, secret_key: u32 },
67    DataRow(Vec<Bytes>),
68    EmptyQueryResponse(),
69    ErrorResponse(Bytes),
70    ParameterStatus { parameter: Bytes, value: Bytes },
71    ReadyForQuery(u8),
72    RowDescription(Vec<RowDescription>),
73
74    //TODO(ppiotr3k): implement following messages
75    AuthenticationKerberosV5(Bytes),
76    AuthenticationCleartextPassword(Bytes),
77    AuthenticationMD5Password(Bytes),
78    AuthenticationSCMCredential(Bytes),
79    AuthenticationGSS(Bytes),
80    AuthenticationGSSContinue(Bytes),
81    AuthenticationSSPI(Bytes),
82    BindComplete(Bytes),
83    CloseComplete(Bytes),
84    CopyData(Bytes),
85    CopyDone(Bytes),
86    CopyInResponse(Bytes),
87    CopyOutResponse(Bytes),
88    CopyBothResponse(Bytes),
89    FunctionCallResponse(Bytes),
90    NegotiateProtocolVersion(Bytes),
91    NoData(),
92    NoticeResponse(Bytes),
93    NotificationResponse(Bytes),
94    ParameterDescription(Bytes),
95    ParseComplete(),
96    PortalSuspended(),
97}
98
99///TODO(ppiotr3k): write description
100//TODO(ppiotr3k): investigate if `Clone` is avoidable; currently only used in tests
101//TODO(ppiotr3k): internal fields encapsulation
102#[derive(Clone, Debug, Eq, PartialEq)]
103pub struct RowDescription {
104    pub name: Bytes,
105    pub table_oid: u32,
106    pub column_attr: u16,
107    pub data_type_oid: u32,
108    pub data_type_size: i16,
109    pub type_modifier: i32,
110    pub format: u16,
111}
112
/// Where [`Codec`] stands within the incoming byte stream.
#[derive(Debug, Clone)]
enum DecodeState {
    /// Waiting for a complete header (message ID + length).
    Head,
    /// Header already parsed; waiting for the whole frame, whose size in
    /// bytes (message ID byte included) is carried here.
    Message(usize),
}
119
120///TODO(ppiotr3k): write description
121#[derive(Debug, Clone)]
122pub struct Codec {
123    /// Read state management / optimization
124    state: DecodeState,
125}
126
127impl Codec {
128    ///TODO(ppiotr3k): write function description
129    #[must_use]
130    pub const fn new() -> Self {
131        Self {
132            state: DecodeState::Head,
133        }
134    }
135
136    ///TODO(ppiotr3k): write function description
137    fn decode_header(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
138        if src.len() < BYTES_MESSAGE_HEADER {
139            // Incomplete header, await for more data.
140            log::trace!(
141                "not enough header data ({} bytes), awaiting more ({} bytes)",
142                src.len(),
143                BYTES_MESSAGE_HEADER,
144            );
145            return Ok(None);
146        }
147
148        let mut buf = io::Cursor::new(&mut *src);
149        buf.advance(BYTES_MESSAGE_ID);
150
151        // 'Message Length' field accounts for self, but not 'Message ID' field.
152        // Note: `usize` prevents from 'Message Length' `i32` value overflow.
153        let frame_length = (buf.get_u32() as usize) + BYTES_MESSAGE_ID;
154
155        // Strict "less than", as null-payload messages exist in protocol.
156        if frame_length < BYTES_MESSAGE_HEADER {
157            log::trace!("invalid frame: {:?}", buf);
158            let err = std::io::Error::new(
159                std::io::ErrorKind::InvalidInput,
160                "malformed packet - invalid message length",
161            );
162            log::error!("{}", err);
163            return Err(err);
164        }
165
166        Ok(Some(frame_length))
167    }
168
169    ///TODO(ppiotr3k): write function description
170    fn decode_message(&mut self, len: usize, src: &mut BytesMut) -> io::Result<Option<Message>> {
171        if src.len() < len {
172            // Incomplete message, await for more data.
173            log::trace!(
174                "not enough message data ({} bytes), awaiting more ({} bytes)",
175                src.len(),
176                len
177            );
178            return Ok(None);
179        }
180
181        // Full message, pop it out.
182        let mut frame = src.split_to(len);
183        //TODO(ppiotr3k): consider zero-cost `frame.freeze()` for lazy passing in `Pipe`
184
185        // Frames have at least `BYTES_MESSAGE_HEADER` bytes at this point.
186        let msg_id = frame.get_u8();
187        log::trace!("incoming msg id: '{}' ({})", msg_id as char, msg_id);
188        let msg_length = (frame.get_u32() as usize) - BYTES_MESSAGE_SIZE;
189        log::trace!("incoming msg length: {}", msg_length);
190
191        let msg = match msg_id {
192
193            // Canary
194            //#[cfg(test)] //TODO(ppiotr3k): fix enabling `Canary` only in tests
195            b'B' /* 0x42 */ => {
196                frame.advance(msg_length);
197                Message::Canary(len as u8)
198            },
199            //#[cfg(test)] //TODO(ppiotr3k): fix enabling `Canary` only in tests
200            b'!' /* 0x21 */ => {
201                return Err(io::Error::new(io::ErrorKind::InvalidData, "expected canary error"));
202            },
203
204            // Backend
205            MESSAGE_ID_AUTHENTICATION => {
206                let authn_case = get_u32(&mut frame, "malformed packet - invalid authentication data")?;
207                match authn_case {
208                     0 /* AuthenticationOk */=> Message::AuthenticationOk(),
209                    10 /* AuthenticationSASL */ => {
210                        let data = get_cstr(&mut frame)?;
211
212                        // A zero byte is required as terminator after the last authn mechanism.
213                        //TODO(ppiotr3k): write a test where it is a different value than zero
214                        if frame.is_empty() {
215                            let err = std::io::Error::new(
216                                std::io::ErrorKind::InvalidInput,
217                                "malformed packet - invalid SASL mecanism data",
218                            );
219                            log::error!("{}", err);
220                            return Err(err);
221                        }
222                        frame.advance(1); // zero byte list terminator
223
224                        Message::AuthenticationSASL(data)
225                    },
226                    11 /* AuthenticationSASLContinue */ => {
227                        let response = frame.copy_to_bytes(frame.remaining());
228
229                        // AuthenticationSASLContinue `response` cannot be empty.
230                        if response.is_empty() {
231                            let err = std::io::Error::new(
232                                std::io::ErrorKind::InvalidInput,
233                                "malformed packet - invalid SASL response data",
234                            );
235                            log::error!("{}", err);
236                            return Err(err);
237                        }
238
239                        Message::AuthenticationSASLContinue(response)
240                    },
241                    12 /* AuthenticationSASLFinal */ => {
242                        let response = frame.copy_to_bytes(frame.remaining());
243
244                        // AuthenticationSASLFinal `response` cannot be empty.
245                        if response.is_empty() {
246                            let err = std::io::Error::new(
247                                std::io::ErrorKind::InvalidInput,
248                                "malformed packet - invalid SASL response data",
249                            );
250                            log::error!("{}", err);
251                            return Err(err);
252                        }
253
254                        Message::AuthenticationSASLFinal(response)
255                    },
256                    _ => {
257                        let err = std::io::Error::new(
258                            std::io::ErrorKind::InvalidInput,
259                            "malformed packet - invalid SASL identifier",
260                        );
261                        log::error!("{}", err);
262                        return Err(err);
263                    }
264                }
265            },
266            MESSAGE_ID_BACKEND_KEY_DATA => {
267                let process = get_u32(&mut frame, "malformed packet - invalid key data")?;
268                let secret_key = get_u32(&mut frame, "malformed packet - invalid key data")?;
269                Message::BackendKeyData { process, secret_key }
270            },
271            MESSAGE_ID_COMMAND_COMPLETE => {
272                let command = get_cstr(&mut frame)?;
273                Message::CommandComplete(command)
274            },
275            MESSAGE_ID_DATA_ROW => {
276                let fields = self.get_data_row_fields(&mut frame)?;
277                Message::DataRow(fields)
278            },
279            MESSAGE_ID_ERROR_RESPONSE => {
280                //TODO(ppiotr3k): identify if parsing those fields is of interest
281                let unparsed_fields = frame.copy_to_bytes(msg_length);
282                Message::ErrorResponse(unparsed_fields)
283            },
284            MESSAGE_ID_EMPTY_QUERY_RESPONSE => Message::EmptyQueryResponse(),
285            MESSAGE_ID_PARAMETER_STATUS => {
286                let parameter = get_cstr(&mut frame)?;
287                let value = get_cstr(&mut frame)?;
288                Message::ParameterStatus { parameter, value }
289            },
290            MESSAGE_ID_READY_FOR_QUERY => {
291                let status = get_u8(&mut frame, "malformed packet - missing status indicator")?;
292                match status {
293                    b'I' | b'T'| b'E' => Message::ReadyForQuery(status),
294                    _ => {
295                        let err = std::io::Error::new(
296                            std::io::ErrorKind::InvalidInput,
297                            "malformed packet - invalid status indicator",
298                        );
299                        log::error!("{}", err);
300                        return Err(err);
301                    },
302                }
303            },
304            MESSAGE_ID_ROW_DESCRIPTION => {
305                let descriptions = self.get_row_descriptions(&mut frame)?;
306                Message::RowDescription(descriptions)
307            },
308            _ => {
309                let bytes = frame.copy_to_bytes(msg_length);
310                unimplemented!("msg_id: {} ({:?})", msg_id, bytes);
311            },
312        };
313
314        // At this point, all data should have been consumed from `frame`.
315        if !frame.is_empty() {
316            log::trace!("invalid frame: {:?}", frame);
317            let err = std::io::Error::new(
318                std::io::ErrorKind::InvalidInput,
319                "malformed packet - invalid message length",
320            );
321            log::error!("{}", err);
322            return Err(err);
323        }
324
325        log::debug!("decoded message frame: {:?}", msg);
326        Ok(Some(msg))
327    }
328
329    ///TODO(ppiotr3k): write function description
330    fn get_row_descriptions(&mut self, buf: &mut BytesMut) -> io::Result<Vec<RowDescription>> {
331        let mut columns = get_u16(buf, "malformed packet - invalid data size")?;
332        log::trace!("decoded number of description columns: {}", columns);
333
334        let mut decoded = Vec::new();
335
336        const BYTES_ROW_DESCRIPTION_COMMON_LENGTH: usize = 18;
337        while columns > 0 {
338            let column_name = get_cstr(buf)?;
339
340            if buf.remaining() < BYTES_ROW_DESCRIPTION_COMMON_LENGTH {
341                let err = std::io::Error::new(
342                    std::io::ErrorKind::InvalidInput,
343                    "malformed packet - invalid row description structure",
344                );
345                log::error!("{}", err);
346                return Err(err);
347            }
348
349            let description = RowDescription {
350                name: column_name,
351                table_oid: get_u32(buf, "malformed packet - invalid data size")?,
352                column_attr: get_u16(buf, "malformed packet - invalid data size")?,
353                data_type_oid: get_u32(buf, "malformed packet - invalid data size")?,
354                data_type_size: get_i16(buf, "malformed packet - invalid data size")?,
355                type_modifier: get_i32(buf, "malformed packet - invalid data size")?,
356                format: get_u16(buf, "malformed packet - invalid data size")?,
357            };
358
359            log::trace!("decoded row description: {:?}", description);
360            decoded.push(description);
361            columns -= 1;
362        }
363
364        Ok(decoded)
365    }
366
367    ///TODO(ppiotr3k): write function description
368    fn get_data_row_fields(&mut self, buf: &mut BytesMut) -> io::Result<Vec<Bytes>> {
369        let mut fields = buf.get_u16();
370        log::trace!("decoded number of row fields: {}", fields);
371
372        let mut decoded = Vec::new();
373
374        const BYTES_DATA_ROW_FIELD_LENGTH: usize = 4;
375        while fields > 0 {
376            let value = get_bytes(
377                buf,
378                BYTES_DATA_ROW_FIELD_LENGTH,
379                "malformed packet - invalid field size",
380            )?;
381
382            log::trace!("decoded field: {:?}", value);
383            decoded.push(value);
384            fields -= 1;
385        }
386
387        Ok(decoded)
388    }
389
390    ///TODO(ppiotr3k): write function description
391    //TODO(ppiotr3k): get size from Message struct
392    // -> pre-requisite: enum variants are considered as types in Rust
393    fn encode_header(&mut self, msg_id: u8, msg_size: usize, dst: &mut BytesMut) {
394        dst.reserve(BYTES_MESSAGE_HEADER + msg_size);
395        dst.put_u8(msg_id);
396        dst.put_u32((BYTES_MESSAGE_SIZE + msg_size) as u32);
397    }
398}
399
400impl PostgresMessage for Message {}
401impl SQLMessage for Message {}
402
403impl Decoder for Codec {
404    type Item = Message;
405    type Error = io::Error;
406
407    fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Self::Item>> {
408        let msg_length = match self.state {
409            DecodeState::Head => match self.decode_header(src)? {
410                // Incomplete header, await for more data.
411                None => return Ok(None),
412                // Header available, try getting full message.
413                Some(length) => {
414                    self.state = DecodeState::Message(length);
415
416                    // Ensure enough space is available to read incoming payload.
417                    // Note: acceptable over-allocation by content of `BYTES_MESSAGE_SIZE`.
418                    src.reserve(length);
419                    log::trace!("stream buffer capacity: {} bytes", src.capacity());
420
421                    length
422                }
423            },
424            DecodeState::Message(length) => length,
425        };
426        log::trace!("decoded frame length: {} bytes", msg_length);
427
428        match self.decode_message(msg_length, src)? {
429            // Incomplete message, await for more data.
430            None => Ok(None),
431            // Full message, pop it out, move on to parsing a new one.
432            Some(msg) => {
433                self.state = DecodeState::Head;
434
435                // Ensure enough space is available to read next header.
436                src.reserve(BYTES_MESSAGE_HEADER);
437                log::trace!("stream buffer capacity: {} bytes", src.capacity());
438
439                Ok(Some(msg))
440            }
441        }
442    }
443}
444
445impl Encoder<Message> for Codec {
446    type Error = io::Error;
447
448    fn encode(&mut self, msg: Message, dst: &mut BytesMut) -> Result<(), io::Error> {
449        //TODO(ppiotr3k): rationalize capacity reservation with `dst.reserve(msg.len())`
450        // -> pre-requisite: enum variants are considered as types in Rust
451        match msg {
452            Message::AuthenticationOk() => {
453                self.encode_header(MESSAGE_ID_AUTHENTICATION, 4, dst);
454                dst.put_i32(0);
455            }
456            Message::AuthenticationSASL(data) => {
457                self.encode_header(MESSAGE_ID_AUTHENTICATION, 4 + data.len() + 1 + 1, dst);
458                dst.put_i32(10);
459                put_cstr(&data, dst);
460                dst.put_u8(0); // zero byte list terminator
461            }
462            Message::AuthenticationSASLContinue(response) => {
463                self.encode_header(MESSAGE_ID_AUTHENTICATION, 4 + response.len(), dst);
464                dst.put_i32(11);
465                dst.put(response);
466            }
467            Message::AuthenticationSASLFinal(response) => {
468                self.encode_header(MESSAGE_ID_AUTHENTICATION, 4 + response.len(), dst);
469                dst.put_i32(12);
470                dst.put(response);
471            }
472            Message::BackendKeyData {
473                process,
474                secret_key,
475            } => {
476                self.encode_header(MESSAGE_ID_BACKEND_KEY_DATA, 4 + 4, dst);
477                dst.put_i32(process as i32);
478                dst.put_i32(secret_key as i32);
479            }
480            Message::CommandComplete(command) => {
481                self.encode_header(MESSAGE_ID_COMMAND_COMPLETE, command.len() + 1, dst);
482                put_cstr(&command, dst);
483            }
484            Message::DataRow(fields) => {
485                let mut msg_size = 2;
486                for field in fields.iter() {
487                    msg_size += field.len() + 4;
488                }
489
490                self.encode_header(MESSAGE_ID_DATA_ROW, msg_size, dst);
491                dst.put_u16(fields.len() as u16);
492
493                for field in fields.iter() {
494                    put_bytes(field, dst)
495                }
496            }
497            Message::EmptyQueryResponse() => {
498                self.encode_header(MESSAGE_ID_EMPTY_QUERY_RESPONSE, 0, dst);
499            }
500            Message::ErrorResponse(unparsed_fields) => {
501                self.encode_header(MESSAGE_ID_ERROR_RESPONSE, unparsed_fields.len(), dst);
502                dst.put(unparsed_fields);
503            }
504            Message::ParameterStatus { parameter, value } => {
505                self.encode_header(
506                    MESSAGE_ID_PARAMETER_STATUS,
507                    parameter.len() + 1 + value.len() + 1,
508                    dst,
509                );
510                put_cstr(&parameter, dst);
511                put_cstr(&value, dst);
512            }
513            Message::ReadyForQuery(status) => {
514                self.encode_header(MESSAGE_ID_READY_FOR_QUERY, 1, dst);
515                dst.put_u8(status);
516            }
517            Message::RowDescription(descriptions) => {
518                let mut msg_size = 2;
519                for column in descriptions.iter() {
520                    msg_size += column.name.len() + 1 + 4 + 2 + 4 + 2 + 4 + 2;
521                }
522
523                self.encode_header(MESSAGE_ID_ROW_DESCRIPTION, msg_size, dst);
524                dst.put_u16(descriptions.len() as u16);
525
526                for column in descriptions.iter() {
527                    put_cstr(&column.name, dst);
528                    dst.put_u32(column.table_oid);
529                    dst.put_u16(column.column_attr);
530                    dst.put_u32(column.data_type_oid);
531                    dst.put_i16(column.data_type_size);
532                    dst.put_i32(column.type_modifier);
533                    dst.put_u16(column.format);
534                }
535            }
536            other => {
537                unimplemented!("msg: {:?}", other)
538            }
539        }
540
541        // Message has been written to `Sink`, nothing left to do.
542        // Note: if bytes remain in frame, encoding tests need a review.
543        Ok(())
544    }
545}
546
547impl Default for Codec {
548    fn default() -> Self {
549        Self::new()
550    }
551}