1use std::collections::BTreeMap;
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4
5use super::codec;
6use super::DecodeContext;
7use super::Message;
8use super::ProtocolVersion;
9use crate::error::{PgWireError, PgWireResult};
10
11pub(crate) const MINIMUM_STARTUP_MESSAGE_LEN: usize = 8;
12pub(crate) const MAXIMUM_STARTUP_MESSAGE_LEN: usize = super::SMALL_PACKET_SIZE_LIMIT;
16
17#[non_exhaustive]
19#[derive(PartialEq, Eq, Debug, new)]
20pub struct Startup {
21 #[new(value = "3")]
22 pub protocol_number_major: u16,
23 #[new(value = "0")]
24 pub protocol_number_minor: u16,
25 #[new(default)]
26 pub parameters: BTreeMap<String, String>,
27}
28
29impl Default for Startup {
30 fn default() -> Startup {
31 Startup::new()
32 }
33}
34
35impl Startup {
36 pub const PROTOCOL_VERSION_3_0: i32 = 196608;
37 pub const PROTOCOL_VERSION_3_2: i32 = 196610;
38
39 pub const PG_PROTOCOL_EARLIEST: u16 = 3;
40 pub const PG_PROTOCOL_LATEST: u16 = 3;
41}
42
43impl Message for Startup {
44 fn message_length(&self) -> usize {
45 let param_length = self
46 .parameters
47 .iter()
48 .map(|(k, v)| k.len() + v.len() + 2)
49 .sum::<usize>();
50 9 + param_length
52 }
53
54 #[inline]
55 fn max_message_length() -> usize {
56 MAXIMUM_STARTUP_MESSAGE_LEN
57 }
58
59 fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
60 buf.put_u16(self.protocol_number_major);
62 buf.put_u16(self.protocol_number_minor);
63
64 for (k, v) in self.parameters.iter() {
66 codec::put_cstring(buf, k);
67 codec::put_cstring(buf, v);
68 }
69 codec::put_cstring(buf, "");
71
72 Ok(())
73 }
74
75 fn decode(buf: &mut BytesMut, ctx: &DecodeContext) -> PgWireResult<Option<Self>> {
76 codec::decode_packet(buf, 0, Self::max_message_length(), |buf, full_len| {
77 Self::decode_body(buf, full_len, ctx)
78 })
79 }
80
81 fn decode_body(buf: &mut BytesMut, msg_len: usize, _ctx: &DecodeContext) -> PgWireResult<Self> {
82 if msg_len <= MINIMUM_STARTUP_MESSAGE_LEN {
87 return Err(PgWireError::InvalidStartupMessage);
88 }
89
90 let protocol_number_major = buf.get_u16();
92
93 if !(Self::PG_PROTOCOL_EARLIEST..=Self::PG_PROTOCOL_LATEST).contains(&protocol_number_major)
95 {
96 return Err(PgWireError::InvalidStartupMessage);
97 }
98
99 let protocol_number_minor = buf.get_u16();
100
101 let mut parameters = BTreeMap::new();
103 while let Some(key) = codec::get_cstring(buf) {
104 let value = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned());
105 parameters.insert(key, value);
106 }
107
108 Ok(Startup {
109 protocol_number_major,
110 protocol_number_minor,
111 parameters,
112 })
113 }
114}
115
116#[non_exhaustive]
118#[derive(PartialEq, Eq, Debug)]
119pub enum Authentication {
120 Ok, CleartextPassword, KerberosV5, MD5Password(Vec<u8>), SASL(Vec<String>), SASLContinue(Bytes), SASLFinal(Bytes), }
136
137pub const MESSAGE_TYPE_BYTE_AUTHENTICATION: u8 = b'R';
138
139impl Message for Authentication {
140 #[inline]
141 fn message_type() -> Option<u8> {
142 Some(MESSAGE_TYPE_BYTE_AUTHENTICATION)
143 }
144
145 #[inline]
146 fn max_message_length() -> usize {
147 super::SMALL_BACKEND_PACKET_SIZE_LIMIT
148 }
149
150 #[inline]
151 fn message_length(&self) -> usize {
152 match self {
153 Authentication::Ok | Authentication::CleartextPassword | Authentication::KerberosV5 => {
154 8
155 }
156 Authentication::MD5Password(_) => 12,
157 Authentication::SASL(methods) => {
158 8 + methods.iter().map(|v| v.len() + 1).sum::<usize>() + 1
159 }
160 Authentication::SASLContinue(data) => 8 + data.len(),
161 Authentication::SASLFinal(data) => 8 + data.len(),
162 }
163 }
164
165 fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
166 match self {
167 Authentication::Ok => buf.put_i32(0),
168 Authentication::CleartextPassword => buf.put_i32(3),
169 Authentication::KerberosV5 => buf.put_i32(2),
170 Authentication::MD5Password(salt) => {
171 buf.put_i32(5);
172 buf.put_slice(salt.as_ref());
173 }
174 Authentication::SASL(methods) => {
175 buf.put_i32(10);
176 for method in methods {
177 codec::put_cstring(buf, method);
178 }
179 buf.put_u8(b'\0');
180 }
181 Authentication::SASLContinue(data) => {
182 buf.put_i32(11);
183 buf.put_slice(data.as_ref());
184 }
185 Authentication::SASLFinal(data) => {
186 buf.put_i32(12);
187 buf.put_slice(data.as_ref());
188 }
189 }
190 Ok(())
191 }
192
193 fn decode_body(buf: &mut BytesMut, msg_len: usize, _ctx: &DecodeContext) -> PgWireResult<Self> {
194 let code = buf.get_i32();
195 let msg = match code {
196 0 => Authentication::Ok,
197
198 2 => Authentication::KerberosV5,
199 3 => Authentication::CleartextPassword,
200 5 => {
201 let mut salt_vec = vec![0; 4];
202 buf.copy_to_slice(&mut salt_vec);
203 Authentication::MD5Password(salt_vec)
204 }
205 10 => {
206 let mut methods = Vec::new();
207 while let Some(method) = codec::get_cstring(buf) {
208 methods.push(method);
209 }
210 Authentication::SASL(methods)
211 }
212 11 => Authentication::SASLContinue(buf.split_to(msg_len - 8).freeze()),
213 12 => Authentication::SASLFinal(buf.split_to(msg_len - 8).freeze()),
214 _ => {
215 return Err(PgWireError::InvalidAuthenticationMessageCode(code));
216 }
217 };
218
219 Ok(msg)
220 }
221}
222
223pub const MESSAGE_TYPE_BYTE_PASSWORD_MESSAGE_FAMILY: u8 = b'p';
224
225#[non_exhaustive]
240#[derive(Debug)]
241pub enum PasswordMessageFamily {
242 Raw(BytesMut),
244 Password(Password),
246 SASLInitialResponse(SASLInitialResponse),
248 SASLResponse(SASLResponse),
250}
251
252impl Message for PasswordMessageFamily {
253 fn message_type() -> Option<u8> {
254 Some(MESSAGE_TYPE_BYTE_PASSWORD_MESSAGE_FAMILY)
255 }
256
257 fn message_length(&self) -> usize {
258 match self {
259 PasswordMessageFamily::Raw(body) => body.len() + 4,
260 PasswordMessageFamily::Password(inner) => inner.message_length(),
261 PasswordMessageFamily::SASLInitialResponse(inner) => inner.message_length(),
262 PasswordMessageFamily::SASLResponse(inner) => inner.message_length(),
263 }
264 }
265
266 fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
267 match self {
268 PasswordMessageFamily::Raw(body) => {
269 buf.put_slice(body.as_ref());
270 Ok(())
271 }
272 PasswordMessageFamily::Password(inner) => inner.encode_body(buf),
273 PasswordMessageFamily::SASLInitialResponse(inner) => inner.encode_body(buf),
274 PasswordMessageFamily::SASLResponse(inner) => inner.encode_body(buf),
275 }
276 }
277
278 fn decode_body(
279 buf: &mut BytesMut,
280 full_len: usize,
281 _ctx: &DecodeContext,
282 ) -> PgWireResult<Self> {
283 let body = buf.split_to(full_len - 4);
284 Ok(PasswordMessageFamily::Raw(body))
285 }
286}
287
288impl PasswordMessageFamily {
289 pub fn into_password(self) -> PgWireResult<Password> {
295 if let PasswordMessageFamily::Raw(mut body) = self {
296 let len = body.len() + 4;
297 Password::decode_body(&mut body, len, &DecodeContext::default())
298 } else {
299 unreachable!(
300 "Do not coerce password message when it has a concrete type {:?}",
301 self
302 )
303 }
304 }
305
306 pub fn into_sasl_initial_response(self) -> PgWireResult<SASLInitialResponse> {
312 if let PasswordMessageFamily::Raw(mut body) = self {
313 let len = body.len() + 4;
314 SASLInitialResponse::decode_body(&mut body, len, &DecodeContext::default())
315 } else {
316 unreachable!(
317 "Do not coerce password message when it has a concrete type {:?}",
318 self
319 )
320 }
321 }
322
323 pub fn into_sasl_response(self) -> PgWireResult<SASLResponse> {
329 if let PasswordMessageFamily::Raw(mut body) = self {
330 let len = body.len() + 4;
331 SASLResponse::decode_body(&mut body, len, &DecodeContext::default())
332 } else {
333 unreachable!(
334 "Do not coerce password message when it has a concrete type {:?}",
335 self
336 )
337 }
338 }
339}
340
341#[non_exhaustive]
343#[derive(PartialEq, Eq, Debug, new)]
344pub struct Password {
345 pub password: String,
346}
347
348impl Message for Password {
349 #[inline]
350 fn message_type() -> Option<u8> {
351 Some(MESSAGE_TYPE_BYTE_PASSWORD_MESSAGE_FAMILY)
352 }
353
354 fn message_length(&self) -> usize {
355 5 + self.password.len()
356 }
357
358 fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
359 codec::put_cstring(buf, &self.password);
360
361 Ok(())
362 }
363
364 fn decode_body(buf: &mut BytesMut, _: usize, _ctx: &DecodeContext) -> PgWireResult<Self> {
365 let pass = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned());
366
367 Ok(Password::new(pass))
368 }
369}
370
371#[non_exhaustive]
373#[derive(PartialEq, Eq, Debug, new)]
374pub struct ParameterStatus {
375 pub name: String,
376 pub value: String,
377}
378
379pub const MESSAGE_TYPE_BYTE_PARAMETER_STATUS: u8 = b'S';
380
381impl Message for ParameterStatus {
382 #[inline]
383 fn message_type() -> Option<u8> {
384 Some(MESSAGE_TYPE_BYTE_PARAMETER_STATUS)
385 }
386
387 #[inline]
388 fn max_message_length() -> usize {
389 super::SMALL_BACKEND_PACKET_SIZE_LIMIT
390 }
391
392 fn message_length(&self) -> usize {
393 4 + 2 + self.name.len() + self.value.len()
394 }
395
396 fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
397 codec::put_cstring(buf, &self.name);
398 codec::put_cstring(buf, &self.value);
399
400 Ok(())
401 }
402
403 fn decode_body(buf: &mut BytesMut, _: usize, _ctx: &DecodeContext) -> PgWireResult<Self> {
404 let name = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned());
405 let value = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned());
406
407 Ok(ParameterStatus::new(name, value))
408 }
409}
410
411#[derive(Debug, PartialEq, Eq, Clone, Hash)]
417pub enum SecretKey {
418 I32(i32),
419 Bytes(Bytes),
420}
421
422impl Default for SecretKey {
423 fn default() -> Self {
424 SecretKey::I32(0)
425 }
426}
427
428impl SecretKey {
429 pub fn as_i32(&self) -> Option<i32> {
433 match self {
434 Self::I32(v) => Some(*v),
435 Self::Bytes(v) => {
436 if v.len() == 4 {
437 Some((&v[..]).get_i32())
438 } else {
439 None
440 }
441 }
442 }
443 }
444
445 fn validate(&self) -> PgWireResult<()> {
446 match self {
447 SecretKey::I32(_) => Ok(()),
448 SecretKey::Bytes(key_bytes) => {
449 let len = key_bytes.len();
450 Self::validate_bytes_len(len)
451 }
452 }
453 }
454
455 fn validate_bytes_len(data_len: usize) -> PgWireResult<()> {
456 if !(4..=256).contains(&data_len) {
457 return Err(PgWireError::InvalidSecretKey);
458 }
459 Ok(())
460 }
461
462 pub(crate) fn len(&self) -> usize {
463 match self {
464 SecretKey::I32(_) => 4,
465 SecretKey::Bytes(key_bytes) => key_bytes.len(),
466 }
467 }
468
469 pub fn encode(&self, buf: &mut BytesMut) -> PgWireResult<()> {
470 match self {
471 SecretKey::I32(key) => buf.put_i32(*key),
472 SecretKey::Bytes(key) => {
473 self.validate()?;
474 buf.put_slice(key)
475 }
476 }
477 Ok(())
478 }
479
480 pub fn decode(buf: &mut BytesMut, data_len: usize, ctx: &DecodeContext) -> PgWireResult<Self> {
481 Self::validate_bytes_len(data_len)?;
482
483 match ctx.protocol_version {
484 ProtocolVersion::PROTOCOL3_2 => Ok(SecretKey::Bytes(buf.split_to(data_len).freeze())),
485 ProtocolVersion::PROTOCOL3_0 => Ok(SecretKey::I32(buf.get_i32())),
486 }
487 }
488}
489
490#[non_exhaustive]
493#[derive(PartialEq, Eq, Debug, new)]
494pub struct BackendKeyData {
495 pub pid: i32,
496 pub secret_key: SecretKey,
497}
498
499pub const MESSAGE_TYPE_BYTE_BACKEND_KEY_DATA: u8 = b'K';
500
501impl Message for BackendKeyData {
502 #[inline]
503 fn message_type() -> Option<u8> {
504 Some(MESSAGE_TYPE_BYTE_BACKEND_KEY_DATA)
505 }
506
507 #[inline]
508 fn max_message_length() -> usize {
509 super::SMALL_BACKEND_PACKET_SIZE_LIMIT
510 }
511
512 #[inline]
513 fn message_length(&self) -> usize {
514 8 + self.secret_key.len()
515 }
516
517 fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
518 buf.put_i32(self.pid);
519 self.secret_key.encode(buf)?;
520
521 Ok(())
522 }
523
524 fn decode_body(buf: &mut BytesMut, msg_len: usize, ctx: &DecodeContext) -> PgWireResult<Self> {
525 let pid = buf.get_i32();
526 let secret_key = SecretKey::decode(buf, msg_len - 8, ctx)?;
528
529 Ok(BackendKeyData { pid, secret_key })
530 }
531}
532
533#[non_exhaustive]
540#[derive(PartialEq, Eq, Debug, new)]
541pub struct SslRequest;
542
543impl SslRequest {
544 pub const BODY_MAGIC_NUMBER: i32 = 80877103;
545 pub const BODY_SIZE: usize = MINIMUM_STARTUP_MESSAGE_LEN;
546
547 pub fn is_ssl_request_packet(buf: &[u8]) -> bool {
548 if buf.remaining() >= Self::BODY_SIZE {
549 let magic_code = (&buf[4..8]).get_i32();
550 magic_code == Self::BODY_MAGIC_NUMBER
551 } else {
552 false
553 }
554 }
555}
556
557impl Message for SslRequest {
558 #[inline]
559 fn message_type() -> Option<u8> {
560 None
561 }
562
563 #[inline]
564 fn message_length(&self) -> usize {
565 Self::BODY_SIZE
566 }
567
568 fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
569 buf.put_i32(Self::BODY_MAGIC_NUMBER);
570 Ok(())
571 }
572
573 fn decode_body(
574 _buf: &mut BytesMut,
575 _full_len: usize,
576 _ctx: &DecodeContext,
577 ) -> PgWireResult<Self> {
578 unreachable!();
579 }
580
581 fn decode(buf: &mut BytesMut, _ctx: &DecodeContext) -> PgWireResult<Option<Self>> {
586 if buf.remaining() >= Self::BODY_SIZE {
587 if Self::is_ssl_request_packet(buf) {
588 buf.advance(8);
589 Ok(Some(SslRequest))
590 } else {
591 Err(PgWireError::InvalidSSLRequestMessage)
592 }
593 } else {
594 Ok(None)
595 }
596 }
597}
598
599#[non_exhaustive]
605#[derive(PartialEq, Eq, Debug, new)]
606pub struct GssEncRequest;
607
608impl GssEncRequest {
609 pub const BODY_MAGIC_NUMBER: i32 = 80877104;
610 pub const BODY_SIZE: usize = 8;
611
612 pub fn is_gss_enc_request_packet(buf: &[u8]) -> bool {
613 if buf.remaining() >= Self::BODY_SIZE {
614 let magic_code = (&buf[4..8]).get_i32();
615 magic_code == Self::BODY_MAGIC_NUMBER
616 } else {
617 false
618 }
619 }
620}
621
622impl Message for GssEncRequest {
623 #[inline]
624 fn message_type() -> Option<u8> {
625 None
626 }
627
628 #[inline]
629 fn message_length(&self) -> usize {
630 Self::BODY_SIZE
631 }
632
633 fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
634 buf.put_i32(Self::BODY_MAGIC_NUMBER);
635 Ok(())
636 }
637
638 fn decode_body(
639 _buf: &mut BytesMut,
640 _full_len: usize,
641 _ctx: &DecodeContext,
642 ) -> PgWireResult<Self> {
643 unreachable!();
644 }
645
646 fn decode(buf: &mut BytesMut, _ctx: &DecodeContext) -> PgWireResult<Option<Self>> {
651 if buf.remaining() >= Self::BODY_SIZE {
652 if Self::is_gss_enc_request_packet(buf) {
653 buf.advance(8);
654 Ok(Some(GssEncRequest))
655 } else {
656 Err(PgWireError::InvalidGssEncRequestMessage)
657 }
658 } else {
659 Ok(None)
660 }
661 }
662}
663
664#[non_exhaustive]
665#[derive(PartialEq, Eq, Debug, new)]
666pub struct SASLInitialResponse {
667 pub auth_method: String,
668 pub data: Option<Bytes>,
669}
670
671impl Message for SASLInitialResponse {
672 #[inline]
673 fn message_type() -> Option<u8> {
674 Some(MESSAGE_TYPE_BYTE_PASSWORD_MESSAGE_FAMILY)
675 }
676
677 #[inline]
678 fn message_length(&self) -> usize {
679 4 + self.auth_method.len() + 1 + 4 + self.data.as_ref().map(|b| b.len()).unwrap_or(0)
680 }
681
682 fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
683 codec::put_cstring(buf, &self.auth_method);
684 if let Some(ref data) = self.data {
685 buf.put_i32(data.len() as i32);
686 buf.put_slice(data.as_ref());
687 } else {
688 buf.put_i32(-1);
689 }
690 Ok(())
691 }
692
693 fn decode_body(
694 buf: &mut BytesMut,
695 _full_len: usize,
696 _ctx: &DecodeContext,
697 ) -> PgWireResult<Self> {
698 let auth_method = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned());
699 let data_len = buf.get_i32();
700 let data = if data_len == -1 {
701 None
702 } else {
703 Some(buf.split_to(data_len as usize).freeze())
704 };
705
706 Ok(SASLInitialResponse { auth_method, data })
707 }
708}
709
710#[non_exhaustive]
711#[derive(PartialEq, Eq, Debug, new)]
712pub struct SASLResponse {
713 pub data: Bytes,
714}
715
716impl Message for SASLResponse {
717 #[inline]
718 fn message_type() -> Option<u8> {
719 Some(MESSAGE_TYPE_BYTE_PASSWORD_MESSAGE_FAMILY)
720 }
721
722 #[inline]
723 fn message_length(&self) -> usize {
724 4 + self.data.len()
725 }
726
727 fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
728 buf.put_slice(self.data.as_ref());
729 Ok(())
730 }
731
732 fn decode_body(
733 buf: &mut BytesMut,
734 full_len: usize,
735 _ctx: &DecodeContext,
736 ) -> PgWireResult<Self> {
737 let data = buf.split_to(full_len - 4).freeze();
738 Ok(SASLResponse { data })
739 }
740}
741
742#[non_exhaustive]
743#[derive(PartialEq, Eq, Debug, new)]
744pub struct NegotiateProtocolVersion {
745 pub newest_minor_protocol: i32,
746 pub unsupported_options: Vec<String>,
747}
748
749pub const MESSAGE_TYPE_BYTE_NEGOTIATE_PROTOCOL_VERSION: u8 = b'v';
750
751impl Message for NegotiateProtocolVersion {
752 #[inline]
753 fn message_type() -> Option<u8> {
754 Some(MESSAGE_TYPE_BYTE_NEGOTIATE_PROTOCOL_VERSION)
755 }
756
757 #[inline]
758 fn max_message_length() -> usize {
759 super::SMALL_BACKEND_PACKET_SIZE_LIMIT
760 }
761
762 #[inline]
763 fn message_length(&self) -> usize {
764 12 + self
765 .unsupported_options
766 .iter()
767 .map(|s| s.len() + 1)
768 .sum::<usize>()
769 }
770
771 fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
772 buf.put_i32(self.newest_minor_protocol);
773 buf.put_i32(self.unsupported_options.len() as i32);
774
775 for s in &self.unsupported_options {
776 codec::put_cstring(buf, s);
777 }
778
779 Ok(())
780 }
781
782 fn decode_body(
783 buf: &mut BytesMut,
784 _full_len: usize,
785 _ctx: &DecodeContext,
786 ) -> PgWireResult<Self> {
787 let version = buf.get_i32();
788 let option_count = buf.get_i32();
789 let mut options = Vec::with_capacity(option_count as usize);
790
791 for _ in 0..option_count {
792 options.push(codec::get_cstring(buf).unwrap_or_else(|| "".to_owned()))
793 }
794
795 Ok(Self {
796 newest_minor_protocol: version,
797 unsupported_options: options,
798 })
799 }
800}