1use bytes::{Buf, BufMut, Bytes, BytesMut};
11use std::io;
12use tokio_util::codec::{Decoder, Encoder};
13
14use super::{PostgresMessage, SQLMessage};
15use crate::codec::constants::*;
16use crate::codec::utils::*;
17
// One-byte message identifiers: the first byte of every frame handled by this codec.
// The decoder matches on these values and the encoder writes them via `encode_header`.

// Authentication family; the concrete kind is selected by a u32 code in the payload.
const MESSAGE_ID_AUTHENTICATION: u8 = b'R';
// Payload: two u32 values (process, secret key).
const MESSAGE_ID_BACKEND_KEY_DATA: u8 = b'K';
// Payload: a single NUL-terminated string.
const MESSAGE_ID_COMMAND_COMPLETE: u8 = b'C';
// Payload: u16 field count followed by the fields.
const MESSAGE_ID_DATA_ROW: u8 = b'D';
// No payload.
const MESSAGE_ID_EMPTY_QUERY_RESPONSE: u8 = b'I';
// Payload is carried through unparsed.
const MESSAGE_ID_ERROR_RESPONSE: u8 = b'E';
// Payload: two NUL-terminated strings (parameter, value).
const MESSAGE_ID_PARAMETER_STATUS: u8 = b'S';
// Payload: a single status byte.
const MESSAGE_ID_READY_FOR_QUERY: u8 = b'Z';
// Payload: u16 column count followed by the column descriptions.
const MESSAGE_ID_ROW_DESCRIPTION: u8 = b'T';
27
/// A backend message as produced by the decoder and accepted by the encoder of [`Codec`].
///
/// Only the first group of protocol variants below has dedicated decode/encode arms in
/// this file; the trailing group is declared but has no arm on either side.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Message {
    /// No decode or encode arm in this file constructs or serializes this variant.
    // NOTE(review): presumably a placeholder for unsupported messages; confirm intended use.
    NotImplemented(Bytes),

    /// Produced by the decoder for message id `b'B'`; carries the total frame length
    /// truncated to `u8`. The encoder has no arm for it.
    // NOTE(review): looks like a test hook; confirm.
    Canary(u8),

    /// Authentication message with code 0.
    AuthenticationOk(),
    /// Authentication message with code 10; holds one mechanism name without its terminator.
    AuthenticationSASL(Bytes),
    /// Authentication message with code 11; holds the remaining payload bytes.
    AuthenticationSASLContinue(Bytes),
    /// Authentication message with code 12; holds the remaining payload bytes.
    AuthenticationSASLFinal(Bytes),
    /// Holds the NUL-terminated command string of the message, without the terminator.
    CommandComplete(Bytes),
    /// Two u32 values read in this order from the payload.
    BackendKeyData { process: u32, secret_key: u32 },
    /// One entry per field of the row, in wire order.
    DataRow(Vec<Bytes>),
    /// Message without payload.
    EmptyQueryResponse(),
    /// The whole payload, kept unparsed.
    ErrorResponse(Bytes),
    /// Two NUL-terminated strings read in this order from the payload.
    ParameterStatus { parameter: Bytes, value: Bytes },
    /// Status byte; the decoder only accepts `b'I'`, `b'T'` and `b'E'`.
    ReadyForQuery(u8),
    /// One entry per described column, in wire order.
    RowDescription(Vec<RowDescription>),

    // The variants below are declared only: neither the decoder nor the encoder in this
    // file has an arm for them.
    AuthenticationKerberosV5(Bytes),
    AuthenticationCleartextPassword(Bytes),
    AuthenticationMD5Password(Bytes),
    AuthenticationSCMCredential(Bytes),
    AuthenticationGSS(Bytes),
    AuthenticationGSSContinue(Bytes),
    AuthenticationSSPI(Bytes),
    BindComplete(Bytes),
    CloseComplete(Bytes),
    CopyData(Bytes),
    CopyDone(Bytes),
    CopyInResponse(Bytes),
    CopyOutResponse(Bytes),
    CopyBothResponse(Bytes),
    FunctionCallResponse(Bytes),
    NegotiateProtocolVersion(Bytes),
    NoData(),
    NoticeResponse(Bytes),
    NotificationResponse(Bytes),
    ParameterDescription(Bytes),
    ParseComplete(),
    PortalSuspended(),
}
98
/// Description of a single column inside a [`Message::RowDescription`].
///
/// Fields are listed in the order they are read from and written to the wire: a
/// NUL-terminated name followed by 18 bytes of fixed-width, big-endian values.
// NOTE(review): the per-field meanings below follow the field names and the PostgreSQL
// RowDescription layout; confirm against the protocol documentation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RowDescription {
    /// Column name, without the trailing NUL.
    pub name: Bytes,
    /// Presumably the object id of the originating table.
    pub table_oid: u32,
    /// Presumably the attribute number of the column within its table.
    pub column_attr: u16,
    /// Presumably the object id of the column's data type.
    pub data_type_oid: u32,
    /// Presumably the data type size; signed on the wire.
    pub data_type_size: i16,
    /// Presumably the type modifier; signed on the wire.
    pub type_modifier: i32,
    /// Presumably the format code of the column.
    pub format: u16,
}
112
/// Position of the decoder inside the byte stream.
#[derive(Debug, Clone)]
enum DecodeState {
    /// Waiting for a complete frame header.
    Head,
    /// Header already parsed; holds the total frame length in bytes (header included)
    /// that must be buffered before the message can be decoded.
    Message(usize),
}
119
/// Stateful codec that decodes byte streams into [`Message`] values and encodes
/// [`Message`] values back into bytes.
#[derive(Debug, Clone)]
pub struct Codec {
    /// Remembers whether a header has already been parsed for the frame in progress.
    state: DecodeState,
}
126
127impl Codec {
128 #[must_use]
130 pub const fn new() -> Self {
131 Self {
132 state: DecodeState::Head,
133 }
134 }
135
136 fn decode_header(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
138 if src.len() < BYTES_MESSAGE_HEADER {
139 log::trace!(
141 "not enough header data ({} bytes), awaiting more ({} bytes)",
142 src.len(),
143 BYTES_MESSAGE_HEADER,
144 );
145 return Ok(None);
146 }
147
148 let mut buf = io::Cursor::new(&mut *src);
149 buf.advance(BYTES_MESSAGE_ID);
150
151 let frame_length = (buf.get_u32() as usize) + BYTES_MESSAGE_ID;
154
155 if frame_length < BYTES_MESSAGE_HEADER {
157 log::trace!("invalid frame: {:?}", buf);
158 let err = std::io::Error::new(
159 std::io::ErrorKind::InvalidInput,
160 "malformed packet - invalid message length",
161 );
162 log::error!("{}", err);
163 return Err(err);
164 }
165
166 Ok(Some(frame_length))
167 }
168
169 fn decode_message(&mut self, len: usize, src: &mut BytesMut) -> io::Result<Option<Message>> {
171 if src.len() < len {
172 log::trace!(
174 "not enough message data ({} bytes), awaiting more ({} bytes)",
175 src.len(),
176 len
177 );
178 return Ok(None);
179 }
180
181 let mut frame = src.split_to(len);
183 let msg_id = frame.get_u8();
187 log::trace!("incoming msg id: '{}' ({})", msg_id as char, msg_id);
188 let msg_length = (frame.get_u32() as usize) - BYTES_MESSAGE_SIZE;
189 log::trace!("incoming msg length: {}", msg_length);
190
191 let msg = match msg_id {
192
193 b'B' => {
196 frame.advance(msg_length);
197 Message::Canary(len as u8)
198 },
199 b'!' => {
201 return Err(io::Error::new(io::ErrorKind::InvalidData, "expected canary error"));
202 },
203
204 MESSAGE_ID_AUTHENTICATION => {
206 let authn_case = get_u32(&mut frame, "malformed packet - invalid authentication data")?;
207 match authn_case {
208 0 => Message::AuthenticationOk(),
209 10 => {
210 let data = get_cstr(&mut frame)?;
211
212 if frame.is_empty() {
215 let err = std::io::Error::new(
216 std::io::ErrorKind::InvalidInput,
217 "malformed packet - invalid SASL mecanism data",
218 );
219 log::error!("{}", err);
220 return Err(err);
221 }
222 frame.advance(1); Message::AuthenticationSASL(data)
225 },
226 11 => {
227 let response = frame.copy_to_bytes(frame.remaining());
228
229 if response.is_empty() {
231 let err = std::io::Error::new(
232 std::io::ErrorKind::InvalidInput,
233 "malformed packet - invalid SASL response data",
234 );
235 log::error!("{}", err);
236 return Err(err);
237 }
238
239 Message::AuthenticationSASLContinue(response)
240 },
241 12 => {
242 let response = frame.copy_to_bytes(frame.remaining());
243
244 if response.is_empty() {
246 let err = std::io::Error::new(
247 std::io::ErrorKind::InvalidInput,
248 "malformed packet - invalid SASL response data",
249 );
250 log::error!("{}", err);
251 return Err(err);
252 }
253
254 Message::AuthenticationSASLFinal(response)
255 },
256 _ => {
257 let err = std::io::Error::new(
258 std::io::ErrorKind::InvalidInput,
259 "malformed packet - invalid SASL identifier",
260 );
261 log::error!("{}", err);
262 return Err(err);
263 }
264 }
265 },
266 MESSAGE_ID_BACKEND_KEY_DATA => {
267 let process = get_u32(&mut frame, "malformed packet - invalid key data")?;
268 let secret_key = get_u32(&mut frame, "malformed packet - invalid key data")?;
269 Message::BackendKeyData { process, secret_key }
270 },
271 MESSAGE_ID_COMMAND_COMPLETE => {
272 let command = get_cstr(&mut frame)?;
273 Message::CommandComplete(command)
274 },
275 MESSAGE_ID_DATA_ROW => {
276 let fields = self.get_data_row_fields(&mut frame)?;
277 Message::DataRow(fields)
278 },
279 MESSAGE_ID_ERROR_RESPONSE => {
280 let unparsed_fields = frame.copy_to_bytes(msg_length);
282 Message::ErrorResponse(unparsed_fields)
283 },
284 MESSAGE_ID_EMPTY_QUERY_RESPONSE => Message::EmptyQueryResponse(),
285 MESSAGE_ID_PARAMETER_STATUS => {
286 let parameter = get_cstr(&mut frame)?;
287 let value = get_cstr(&mut frame)?;
288 Message::ParameterStatus { parameter, value }
289 },
290 MESSAGE_ID_READY_FOR_QUERY => {
291 let status = get_u8(&mut frame, "malformed packet - missing status indicator")?;
292 match status {
293 b'I' | b'T'| b'E' => Message::ReadyForQuery(status),
294 _ => {
295 let err = std::io::Error::new(
296 std::io::ErrorKind::InvalidInput,
297 "malformed packet - invalid status indicator",
298 );
299 log::error!("{}", err);
300 return Err(err);
301 },
302 }
303 },
304 MESSAGE_ID_ROW_DESCRIPTION => {
305 let descriptions = self.get_row_descriptions(&mut frame)?;
306 Message::RowDescription(descriptions)
307 },
308 _ => {
309 let bytes = frame.copy_to_bytes(msg_length);
310 unimplemented!("msg_id: {} ({:?})", msg_id, bytes);
311 },
312 };
313
314 if !frame.is_empty() {
316 log::trace!("invalid frame: {:?}", frame);
317 let err = std::io::Error::new(
318 std::io::ErrorKind::InvalidInput,
319 "malformed packet - invalid message length",
320 );
321 log::error!("{}", err);
322 return Err(err);
323 }
324
325 log::debug!("decoded message frame: {:?}", msg);
326 Ok(Some(msg))
327 }
328
329 fn get_row_descriptions(&mut self, buf: &mut BytesMut) -> io::Result<Vec<RowDescription>> {
331 let mut columns = get_u16(buf, "malformed packet - invalid data size")?;
332 log::trace!("decoded number of description columns: {}", columns);
333
334 let mut decoded = Vec::new();
335
336 const BYTES_ROW_DESCRIPTION_COMMON_LENGTH: usize = 18;
337 while columns > 0 {
338 let column_name = get_cstr(buf)?;
339
340 if buf.remaining() < BYTES_ROW_DESCRIPTION_COMMON_LENGTH {
341 let err = std::io::Error::new(
342 std::io::ErrorKind::InvalidInput,
343 "malformed packet - invalid row description structure",
344 );
345 log::error!("{}", err);
346 return Err(err);
347 }
348
349 let description = RowDescription {
350 name: column_name,
351 table_oid: get_u32(buf, "malformed packet - invalid data size")?,
352 column_attr: get_u16(buf, "malformed packet - invalid data size")?,
353 data_type_oid: get_u32(buf, "malformed packet - invalid data size")?,
354 data_type_size: get_i16(buf, "malformed packet - invalid data size")?,
355 type_modifier: get_i32(buf, "malformed packet - invalid data size")?,
356 format: get_u16(buf, "malformed packet - invalid data size")?,
357 };
358
359 log::trace!("decoded row description: {:?}", description);
360 decoded.push(description);
361 columns -= 1;
362 }
363
364 Ok(decoded)
365 }
366
367 fn get_data_row_fields(&mut self, buf: &mut BytesMut) -> io::Result<Vec<Bytes>> {
369 let mut fields = buf.get_u16();
370 log::trace!("decoded number of row fields: {}", fields);
371
372 let mut decoded = Vec::new();
373
374 const BYTES_DATA_ROW_FIELD_LENGTH: usize = 4;
375 while fields > 0 {
376 let value = get_bytes(
377 buf,
378 BYTES_DATA_ROW_FIELD_LENGTH,
379 "malformed packet - invalid field size",
380 )?;
381
382 log::trace!("decoded field: {:?}", value);
383 decoded.push(value);
384 fields -= 1;
385 }
386
387 Ok(decoded)
388 }
389
390 fn encode_header(&mut self, msg_id: u8, msg_size: usize, dst: &mut BytesMut) {
394 dst.reserve(BYTES_MESSAGE_HEADER + msg_size);
395 dst.put_u8(msg_id);
396 dst.put_u32((BYTES_MESSAGE_SIZE + msg_size) as u32);
397 }
398}
399
// Marker implementations: both impl bodies are empty, so `Message` only opts in to the
// traits declared in the parent module.
// NOTE(review): the traits' contracts are not visible from this file; confirm there.
impl PostgresMessage for Message {}
impl SQLMessage for Message {}
402
403impl Decoder for Codec {
404 type Item = Message;
405 type Error = io::Error;
406
407 fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Self::Item>> {
408 let msg_length = match self.state {
409 DecodeState::Head => match self.decode_header(src)? {
410 None => return Ok(None),
412 Some(length) => {
414 self.state = DecodeState::Message(length);
415
416 src.reserve(length);
419 log::trace!("stream buffer capacity: {} bytes", src.capacity());
420
421 length
422 }
423 },
424 DecodeState::Message(length) => length,
425 };
426 log::trace!("decoded frame length: {} bytes", msg_length);
427
428 match self.decode_message(msg_length, src)? {
429 None => Ok(None),
431 Some(msg) => {
433 self.state = DecodeState::Head;
434
435 src.reserve(BYTES_MESSAGE_HEADER);
437 log::trace!("stream buffer capacity: {} bytes", src.capacity());
438
439 Ok(Some(msg))
440 }
441 }
442 }
443}
444
445impl Encoder<Message> for Codec {
446 type Error = io::Error;
447
448 fn encode(&mut self, msg: Message, dst: &mut BytesMut) -> Result<(), io::Error> {
449 match msg {
452 Message::AuthenticationOk() => {
453 self.encode_header(MESSAGE_ID_AUTHENTICATION, 4, dst);
454 dst.put_i32(0);
455 }
456 Message::AuthenticationSASL(data) => {
457 self.encode_header(MESSAGE_ID_AUTHENTICATION, 4 + data.len() + 1 + 1, dst);
458 dst.put_i32(10);
459 put_cstr(&data, dst);
460 dst.put_u8(0); }
462 Message::AuthenticationSASLContinue(response) => {
463 self.encode_header(MESSAGE_ID_AUTHENTICATION, 4 + response.len(), dst);
464 dst.put_i32(11);
465 dst.put(response);
466 }
467 Message::AuthenticationSASLFinal(response) => {
468 self.encode_header(MESSAGE_ID_AUTHENTICATION, 4 + response.len(), dst);
469 dst.put_i32(12);
470 dst.put(response);
471 }
472 Message::BackendKeyData {
473 process,
474 secret_key,
475 } => {
476 self.encode_header(MESSAGE_ID_BACKEND_KEY_DATA, 4 + 4, dst);
477 dst.put_i32(process as i32);
478 dst.put_i32(secret_key as i32);
479 }
480 Message::CommandComplete(command) => {
481 self.encode_header(MESSAGE_ID_COMMAND_COMPLETE, command.len() + 1, dst);
482 put_cstr(&command, dst);
483 }
484 Message::DataRow(fields) => {
485 let mut msg_size = 2;
486 for field in fields.iter() {
487 msg_size += field.len() + 4;
488 }
489
490 self.encode_header(MESSAGE_ID_DATA_ROW, msg_size, dst);
491 dst.put_u16(fields.len() as u16);
492
493 for field in fields.iter() {
494 put_bytes(field, dst)
495 }
496 }
497 Message::EmptyQueryResponse() => {
498 self.encode_header(MESSAGE_ID_EMPTY_QUERY_RESPONSE, 0, dst);
499 }
500 Message::ErrorResponse(unparsed_fields) => {
501 self.encode_header(MESSAGE_ID_ERROR_RESPONSE, unparsed_fields.len(), dst);
502 dst.put(unparsed_fields);
503 }
504 Message::ParameterStatus { parameter, value } => {
505 self.encode_header(
506 MESSAGE_ID_PARAMETER_STATUS,
507 parameter.len() + 1 + value.len() + 1,
508 dst,
509 );
510 put_cstr(¶meter, dst);
511 put_cstr(&value, dst);
512 }
513 Message::ReadyForQuery(status) => {
514 self.encode_header(MESSAGE_ID_READY_FOR_QUERY, 1, dst);
515 dst.put_u8(status);
516 }
517 Message::RowDescription(descriptions) => {
518 let mut msg_size = 2;
519 for column in descriptions.iter() {
520 msg_size += column.name.len() + 1 + 4 + 2 + 4 + 2 + 4 + 2;
521 }
522
523 self.encode_header(MESSAGE_ID_ROW_DESCRIPTION, msg_size, dst);
524 dst.put_u16(descriptions.len() as u16);
525
526 for column in descriptions.iter() {
527 put_cstr(&column.name, dst);
528 dst.put_u32(column.table_oid);
529 dst.put_u16(column.column_attr);
530 dst.put_u32(column.data_type_oid);
531 dst.put_i16(column.data_type_size);
532 dst.put_i32(column.type_modifier);
533 dst.put_u16(column.format);
534 }
535 }
536 other => {
537 unimplemented!("msg: {:?}", other)
538 }
539 }
540
541 Ok(())
544 }
545}
546
547impl Default for Codec {
548 fn default() -> Self {
549 Self::new()
550 }
551}