1use std::{
6 any::Any,
7 collections::HashMap,
8 convert::TryFrom,
9 fmt::{self, Debug, Display, Formatter},
10 io::{Cursor, Error},
11 sync::Arc,
12};
13
14use async_trait::async_trait;
15
16use bytes::BufMut;
17use tokio::io::AsyncReadExt;
18
19#[cfg(feature = "with-chrono")]
20use crate::TimestampValue;
21use crate::{buffer, BindValue, FromProtocolValue, PgType, PgTypeId, ProtocolError};
22
/// Initial capacity, in bytes, reserved for the body buffer of variable-length
/// messages before serialization.
const DEFAULT_CAPACITY: usize = 64;
24
/// Startup packet: the requested protocol version plus the connection
/// parameters supplied by the client as name/value pairs.
#[derive(Debug, PartialEq, Clone)]
pub struct StartupMessage {
    /// Major protocol version.
    pub major: u16,
    /// Minor protocol version.
    pub minor: u16,
    /// Parameter name -> value (for example `user` or `database`).
    pub parameters: HashMap<String, String>,
}
31
32impl StartupMessage {
33 async fn from(mut buffer: &mut Cursor<Vec<u8>>) -> Result<Self, Error> {
34 let major = buffer.read_u16().await?;
35 let minor = buffer.read_u16().await?;
36
37 let mut parameters = HashMap::new();
38
39 loop {
40 let name = buffer::read_string(&mut buffer).await?;
41 if name.is_empty() {
42 break;
43 }
44 let value = buffer::read_string(&mut buffer).await?;
45 parameters.insert(name, value);
46 }
47
48 Ok(Self {
49 major,
50 minor,
51 parameters,
52 })
53 }
54}
55
/// Body of a cancel request: identifies the backend whose query should be
/// cancelled.
#[derive(Debug, PartialEq)]
pub struct CancelRequest {
    /// Target process id (presumably as announced in `BackendKeyData` — confirm).
    pub process_id: u32,
    /// Secret key that must accompany `process_id`.
    pub secret: u32,
}
64
65impl CancelRequest {
66 async fn from(buffer: &mut Cursor<Vec<u8>>) -> Result<Self, Error> {
67 Ok(Self {
68 process_id: buffer.read_u32().await?,
69 secret: buffer.read_u32().await?,
70 })
71 }
72}
73
/// Major version value that marks a special (non-startup) initial request.
pub const VERSION_MAJOR_SPECIAL: i16 = 1234;
/// Minor version that, combined with [`VERSION_MAJOR_SPECIAL`], means a cancel request.
pub const VERSION_MINOR_CANCEL: i16 = 5678;
/// Minor version that, combined with [`VERSION_MAJOR_SPECIAL`], means an SSL request.
pub const VERSION_MINOR_SSL: i16 = 5679;
/// Minor version that, combined with [`VERSION_MAJOR_SPECIAL`], means a GSSAPI
/// encryption request.
pub const VERSION_MINOR_GSSENC: i16 = 5680;
82
83pub enum InitialMessage {
85 Startup(StartupMessage),
86 CancelRequest(CancelRequest),
87 SslRequest,
88 Gssenc,
89}
90
91impl InitialMessage {
92 pub async fn from(buffer: &mut Cursor<Vec<u8>>) -> Result<InitialMessage, ProtocolError> {
93 let major = buffer.read_i16().await?;
94 let minor = buffer.read_i16().await?;
95
96 match major {
97 VERSION_MAJOR_SPECIAL => match minor {
98 VERSION_MINOR_CANCEL => Ok(InitialMessage::CancelRequest(
99 CancelRequest::from(buffer).await?,
100 )),
101 VERSION_MINOR_SSL => Ok(InitialMessage::SslRequest),
102 VERSION_MINOR_GSSENC => Ok(InitialMessage::Gssenc),
103 _ => Err(ErrorResponse::error(
104 ErrorCode::ProtocolViolation,
105 format!(
106 r#"Unsupported special version in initial message with code "{}""#,
107 minor
108 ),
109 )
110 .into()),
111 },
112 _ => {
113 buffer.set_position(0);
114 Ok(InitialMessage::Startup(StartupMessage::from(buffer).await?))
115 }
116 }
117 }
118}
119
120impl Serialize for StartupMessage {
121 const CODE: u8 = 0x00;
122
123 fn serialize(&self) -> Option<Vec<u8>> {
124 let mut buffer = Vec::with_capacity(DEFAULT_CAPACITY);
125 buffer.put_u16(self.major);
126 buffer.put_u16(self.minor);
127
128 for (name, value) in &self.parameters {
129 buffer::write_string(&mut buffer, name);
130 buffer::write_string(&mut buffer, value);
131 }
132
133 buffer.push(0);
134
135 Some(buffer)
136 }
137}
138
139#[derive(Debug)]
140pub struct NoticeResponse {
141 pub severity: NoticeSeverity,
143 pub code: ErrorCode,
144 pub message: String,
145}
146
147impl NoticeResponse {
148 pub fn warning(code: ErrorCode, message: String) -> Self {
149 Self {
150 severity: NoticeSeverity::Warning,
151 code,
152 message,
153 }
154 }
155}
156
157impl Serialize for NoticeResponse {
158 const CODE: u8 = b'N';
159
160 fn serialize(&self) -> Option<Vec<u8>> {
161 let mut buffer = Vec::with_capacity(DEFAULT_CAPACITY);
162
163 let severity = self.severity.to_string();
164 buffer.push(b'S');
165 buffer::write_string(&mut buffer, &severity);
166
167 buffer.push(b'C');
168 buffer::write_string(&mut buffer, &self.code.to_string());
169
170 buffer.push(b'M');
171 buffer::write_string(&mut buffer, &self.message);
172 buffer.push(0);
173
174 Some(buffer)
175 }
176}
177
178#[derive(thiserror::Error, Debug)]
179pub struct ErrorResponse {
180 pub severity: ErrorSeverity,
182 pub code: ErrorCode,
183 pub message: String,
184}
185
186impl Display for ErrorResponse {
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 write!(f, "ErrorResponse")
189 }
190}
191
192impl ErrorResponse {
193 pub fn new(severity: ErrorSeverity, code: ErrorCode, message: String) -> Self {
194 Self {
195 severity,
196 code,
197 message,
198 }
199 }
200
201 pub fn error(code: ErrorCode, message: String) -> Self {
202 Self {
203 severity: ErrorSeverity::Error,
204 code,
205 message,
206 }
207 }
208
209 pub fn fatal(code: ErrorCode, message: String) -> Self {
210 Self {
211 severity: ErrorSeverity::Fatal,
212 code,
213 message,
214 }
215 }
216
217 pub fn query_canceled() -> Self {
218 Self {
219 severity: ErrorSeverity::Error,
220 code: ErrorCode::QueryCanceled,
221 message: "canceling statement due to user request".to_string(),
222 }
223 }
224
225 pub fn admin_shutdown() -> Self {
226 Self {
227 severity: ErrorSeverity::Fatal,
228 code: ErrorCode::AdminShutdown,
229 message: "terminating connection due to shutdown signal".to_string(),
230 }
231 }
232}
233
234impl Serialize for ErrorResponse {
235 const CODE: u8 = b'E';
236
237 fn serialize(&self) -> Option<Vec<u8>> {
238 let mut buffer = Vec::with_capacity(DEFAULT_CAPACITY);
239
240 let severity = self.severity.to_string();
241 buffer.push(b'S');
242 buffer::write_string(&mut buffer, &severity);
243 buffer.push(b'V');
244 buffer::write_string(&mut buffer, &severity);
245 buffer.push(b'C');
246 buffer::write_string(&mut buffer, &self.code.to_string());
247 buffer.push(b'M');
248 buffer::write_string(&mut buffer, &self.message);
249 buffer.push(0);
250
251 Some(buffer)
252 }
253}
254
255pub struct SSLResponse {}
256
257impl SSLResponse {
258 pub fn new() -> Self {
259 Self {}
260 }
261}
262
263impl Serialize for SSLResponse {
264 const CODE: u8 = b'N';
265
266 fn serialize(&self) -> Option<Vec<u8>> {
267 None
268 }
269}
270
271pub struct Authentication {
272 response: AuthenticationRequest,
273}
274
275impl Authentication {
276 pub fn new(response: AuthenticationRequest) -> Self {
277 Self { response }
278 }
279}
280
281impl Serialize for Authentication {
282 const CODE: u8 = b'R';
283
284 fn serialize(&self) -> Option<Vec<u8>> {
285 Some(self.response.to_bytes())
286 }
287}
288
289pub struct ReadyForQuery {
290 transaction_status: TransactionStatus,
291}
292
293impl ReadyForQuery {
294 pub fn new(transaction_status: TransactionStatus) -> Self {
295 Self { transaction_status }
296 }
297}
298
299impl Serialize for ReadyForQuery {
300 const CODE: u8 = b'Z';
301
302 fn serialize(&self) -> Option<Vec<u8>> {
303 Some(vec![self.transaction_status.to_byte()])
304 }
305}
306
307pub struct EmptyQuery {}
308
309impl EmptyQuery {
310 pub fn new() -> Self {
311 Self {}
312 }
313}
314
315impl Serialize for EmptyQuery {
316 const CODE: u8 = b'I';
317
318 fn serialize(&self) -> Option<Vec<u8>> {
319 Some(vec![])
320 }
321}
322
323pub struct BackendKeyData {
324 process_id: u32,
325 secret: u32,
326}
327
328impl BackendKeyData {
329 pub fn new(process_id: u32, secret: u32) -> Self {
330 Self { process_id, secret }
331 }
332}
333
334impl Serialize for BackendKeyData {
335 const CODE: u8 = b'K';
336
337 fn serialize(&self) -> Option<Vec<u8>> {
338 let mut buffer = Vec::with_capacity(4 + 4);
339 buffer.put_u32(self.process_id);
340 buffer.put_u32(self.secret);
341
342 Some(buffer)
343 }
344}
345
346#[derive(Debug, PartialEq)]
348pub struct PortalSuspended {}
349
350impl PortalSuspended {
351 pub fn new() -> Self {
352 Self {}
353 }
354}
355
356impl Serialize for PortalSuspended {
357 const CODE: u8 = b's';
358
359 fn serialize(&self) -> Option<Vec<u8>> {
360 Some(vec![])
361 }
362}
363
364pub struct ParameterStatus {
365 name: String,
366 value: String,
367}
368
369impl ParameterStatus {
370 pub fn new(name: String, value: String) -> Self {
371 Self { name, value }
372 }
373}
374
375impl Serialize for ParameterStatus {
376 const CODE: u8 = b'S';
377
378 fn serialize(&self) -> Option<Vec<u8>> {
379 let mut buffer = Vec::with_capacity(DEFAULT_CAPACITY);
380 buffer::write_string(&mut buffer, &self.name);
381 buffer::write_string(&mut buffer, &self.value);
382 Some(buffer)
383 }
384}
385
386pub struct BindComplete {}
388
389impl BindComplete {
390 pub fn new() -> Self {
391 Self {}
392 }
393}
394
395impl Serialize for BindComplete {
396 const CODE: u8 = b'2';
397
398 fn serialize(&self) -> Option<Vec<u8>> {
399 Some(vec![])
401 }
402}
403
404pub struct CloseComplete {}
406
407impl CloseComplete {
408 pub fn new() -> Self {
409 Self {}
410 }
411}
412
413impl Serialize for CloseComplete {
414 const CODE: u8 = b'3';
415
416 fn serialize(&self) -> Option<Vec<u8>> {
417 Some(vec![])
419 }
420}
421
422#[derive(Debug)]
424pub struct ParseComplete {}
425
426impl ParseComplete {
427 pub fn new() -> Self {
428 Self {}
429 }
430}
431
432impl Serialize for ParseComplete {
433 const CODE: u8 = b'1';
434
435 fn serialize(&self) -> Option<Vec<u8>> {
436 Some(vec![])
438 }
439}
440
441#[derive(Debug, PartialEq)]
442pub enum PortalCompletion {
443 Complete(CommandComplete),
444 Suspended(PortalSuspended),
445}
446
447#[derive(Debug, PartialEq)]
451pub enum CommandComplete {
452 Select(u32),
453 Fetch(u32),
454 Plain(String),
455}
456
457impl CommandComplete {
458 pub fn new_selection(is_select: bool, rows: u32) -> Self {
459 match is_select {
460 true => CommandComplete::Select(rows),
461 false => CommandComplete::Fetch(rows),
462 }
463 }
464}
465
466impl Serialize for CommandComplete {
467 const CODE: u8 = b'C';
468
469 fn serialize(&self) -> Option<Vec<u8>> {
470 let mut buffer = Vec::with_capacity(DEFAULT_CAPACITY);
471 match self {
472 CommandComplete::Select(rows) => {
473 buffer::write_string(&mut buffer, &format!("SELECT {}", rows))
474 }
475 CommandComplete::Fetch(rows) => {
476 buffer::write_string(&mut buffer, &format!("FETCH {}", rows))
477 }
478 CommandComplete::Plain(tag) => buffer::write_string(&mut buffer, tag),
479 }
480
481 Some(buffer)
482 }
483}
484
485pub struct NoData {}
486
487impl NoData {
488 pub fn new() -> Self {
489 Self {}
490 }
491}
492
493impl Serialize for NoData {
494 const CODE: u8 = b'n';
495
496 fn serialize(&self) -> Option<Vec<u8>> {
497 Some(vec![])
498 }
499}
500
501pub struct EmptyQueryResponse {}
502
503impl EmptyQueryResponse {
504 pub fn new() -> Self {
505 Self {}
506 }
507}
508
509impl Serialize for EmptyQueryResponse {
510 const CODE: u8 = b'I';
511
512 fn serialize(&self) -> Option<Vec<u8>> {
513 Some(vec![])
514 }
515}
516
517#[derive(Debug, Clone)]
518pub struct ParameterDescription {
519 pub parameters: Vec<PgTypeId>,
520}
521
522impl ParameterDescription {
523 pub fn new(parameters: Vec<PgTypeId>) -> Self {
524 Self { parameters }
525 }
526
527 pub fn get(&self, i: usize) -> Option<&PgTypeId> {
528 self.parameters.get(i)
529 }
530}
531
532impl Serialize for ParameterDescription {
533 const CODE: u8 = b't';
534
535 fn serialize(&self) -> Option<Vec<u8>> {
536 let mut buffer: Vec<u8> = Vec::with_capacity(6 * self.parameters.len());
537 let size = i16::try_from(self.parameters.len()).unwrap();
539 buffer.put_i16(size);
540
541 for parameter in &self.parameters {
542 buffer.put_i32((*parameter as u32) as i32);
543 }
544
545 Some(buffer)
546 }
547}
548
549#[derive(Debug, Clone)]
550pub struct RowDescription {
551 fields: Vec<RowDescriptionField>,
552}
553
554impl RowDescription {
555 pub fn new(fields: Vec<RowDescriptionField>) -> Self {
556 Self { fields }
557 }
558
559 pub fn len(&self) -> usize {
560 self.fields.len()
561 }
562
563 pub fn get_formats(&self) -> Vec<Format> {
567 self.fields.iter().map(|f| f.format).collect()
568 }
569}
570
571impl Serialize for RowDescription {
572 const CODE: u8 = b'T';
573
574 fn serialize(&self) -> Option<Vec<u8>> {
575 let size = u16::try_from(self.fields.len()).unwrap();
577 let mut buffer = Vec::with_capacity(DEFAULT_CAPACITY);
578 buffer.extend_from_slice(&size.to_be_bytes());
579
580 for field in self.fields.iter() {
581 buffer::write_string(&mut buffer, &field.name);
582 buffer.extend_from_slice(&field.table_oid.to_be_bytes());
583 buffer.extend_from_slice(&field.attribute_number.to_be_bytes());
584 buffer.extend_from_slice(&field.data_type_oid.to_be_bytes());
585 buffer.extend_from_slice(&field.data_type_size.to_be_bytes());
586 buffer.extend_from_slice(&field.type_modifier.to_be_bytes());
587 buffer.extend_from_slice(&(field.format as i16).to_be_bytes());
588 }
589
590 Some(buffer)
591 }
592}
593
594#[derive(Debug, Clone)]
595pub struct RowDescriptionField {
596 name: String,
597 table_oid: i32,
599 attribute_number: i16,
601 data_type_oid: i32,
603 data_type_size: i16,
605 type_modifier: i32,
608 format: Format,
610}
611
612impl RowDescriptionField {
613 pub fn new(name: String, typ: &PgType, format: Format) -> Self {
614 Self {
615 name,
616 table_oid: 0,
618 attribute_number: 0,
620 data_type_oid: typ.oid as i32,
621 data_type_size: typ.typlen,
622 type_modifier: -1,
623 format: if format == Format::Binary && typ.is_binary_supported() {
624 Format::Binary
625 } else {
626 Format::Text
627 },
628 }
629 }
630}
631
632#[derive(Debug, PartialEq)]
633pub struct PasswordMessage {
634 pub password: String,
635}
636
637#[async_trait]
638impl Deserialize for PasswordMessage {
639 async fn deserialize(mut buffer: Cursor<Vec<u8>>) -> Result<Self, ProtocolError>
640 where
641 Self: Sized,
642 {
643 Ok(Self {
644 password: buffer::read_string(&mut buffer).await?,
645 })
646 }
647}
648
649#[derive(Debug, PartialEq)]
655pub struct Parse {
656 pub name: String,
658 pub query: String,
660 pub param_types: Vec<u32>,
662}
663
664#[async_trait]
665impl Deserialize for Parse {
666 async fn deserialize(mut buffer: Cursor<Vec<u8>>) -> Result<Self, ProtocolError>
667 where
668 Self: Sized,
669 {
670 let name = buffer::read_string(&mut buffer).await?;
671 let query = buffer::read_string(&mut buffer).await?;
672
673 let total = buffer.read_i16().await?;
674 let mut param_types = Vec::with_capacity(total as usize);
675
676 for _ in 0..total {
677 param_types.push(buffer.read_u32().await?);
678 }
679
680 Ok(Self {
681 name,
682 query,
683 param_types,
684 })
685 }
686}
687
688#[derive(Debug, PartialEq)]
690pub struct Execute {
691 pub portal: String,
693 pub max_rows: i32,
695}
696
697#[async_trait]
698impl Deserialize for Execute {
699 async fn deserialize(mut buffer: Cursor<Vec<u8>>) -> Result<Self, ProtocolError>
700 where
701 Self: Sized,
702 {
703 let portal = buffer::read_string(&mut buffer).await?;
704 let max_rows = buffer.read_i32().await?;
705
706 Ok(Self { portal, max_rows })
707 }
708}
709
710#[derive(Debug, PartialEq)]
711pub enum CloseType {
712 Statement,
713 Portal,
714}
715
716#[derive(Debug, PartialEq)]
717pub struct Close {
718 pub typ: CloseType,
719 pub name: String,
721}
722
723#[async_trait]
724impl Deserialize for Close {
725 async fn deserialize(mut buffer: Cursor<Vec<u8>>) -> Result<Self, ProtocolError>
726 where
727 Self: Sized,
728 {
729 let typ = match buffer.read_u8().await? {
730 b'S' => CloseType::Statement,
731 b'P' => CloseType::Portal,
732 code => {
733 return Err(ErrorResponse::error(
734 ErrorCode::ProtocolViolation,
735 format!("Unknown close code: {}", code),
736 )
737 .into());
738 }
739 };
740
741 let name = buffer::read_string(&mut buffer).await?;
742
743 Ok(Self { typ, name })
744 }
745}
746
747#[derive(Debug, PartialEq)]
749pub struct Bind {
750 pub portal: String,
752 pub statement: String,
754 pub parameter_formats: Vec<Format>,
756 pub parameter_values: Vec<Option<Vec<u8>>>,
758 pub result_formats: Vec<Format>,
760}
761
762impl Bind {
763 pub fn to_bind_values(
764 &self,
765 description: &ParameterDescription,
766 ) -> Result<Vec<BindValue>, ProtocolError> {
767 let mut values = Vec::with_capacity(self.parameter_values.len());
768
769 for (idx, raw_value) in self.parameter_values.iter().enumerate() {
770 let param_tid = description.get(idx).ok_or::<ProtocolError>({
771 ErrorResponse::error(
772 ErrorCode::InternalError,
773 format!("Unknown type for parameter: {}", idx),
774 )
775 .into()
776 })?;
777
778 let param_format = match self.parameter_formats.len() {
779 0 => Format::Text,
780 1 => self.parameter_formats[0],
781 _ => self.parameter_formats[idx],
782 };
783
784 values.push(match raw_value {
785 None => BindValue::Null,
786 Some(raw_value) => match param_tid {
787 PgTypeId::TEXT => {
788 BindValue::String(String::from_protocol(raw_value, param_format)?)
789 }
790 PgTypeId::BOOL => {
791 BindValue::Bool(bool::from_protocol(raw_value, param_format)?)
792 }
793 PgTypeId::INT8 => {
794 BindValue::Int64(i64::from_protocol(raw_value, param_format)?)
795 }
796 PgTypeId::FLOAT8 => {
797 BindValue::Float64(f64::from_protocol(raw_value, param_format)?)
798 }
799 #[cfg(feature = "with-chrono")]
800 PgTypeId::TIMESTAMP => BindValue::Timestamp(TimestampValue::from_protocol(
801 raw_value,
802 param_format,
803 )?),
804 #[cfg(feature = "with-chrono")]
805 PgTypeId::DATE => {
806 BindValue::Date(chrono::NaiveDate::from_protocol(raw_value, param_format)?)
807 }
808 _ => {
809 return Err(ErrorResponse::error(
810 ErrorCode::FeatureNotSupported,
811 format!(
812 r#"Type "{:?}" is not supported for parameters decoding"#,
813 param_tid
814 ),
815 )
816 .into())
817 }
818 },
819 })
820 }
821
822 Ok(values)
823 }
824}
825
826#[async_trait]
827impl Deserialize for Bind {
828 async fn deserialize(mut buffer: Cursor<Vec<u8>>) -> Result<Self, ProtocolError>
829 where
830 Self: Sized,
831 {
832 let portal = buffer::read_string(&mut buffer).await?;
833 let statement = buffer::read_string(&mut buffer).await?;
834
835 let mut parameter_formats = Vec::new();
836 {
837 let total = buffer.read_i16().await?;
838 for _ in 0..total {
839 parameter_formats.push(buffer::read_format(&mut buffer).await?);
840 }
841 }
842
843 let mut parameter_values = Vec::new();
844 {
845 let total = buffer.read_i16().await?;
846 for _ in 0..total {
847 let len = buffer.read_i32().await?;
848 if len == -1 {
849 parameter_values.push(None);
850 } else {
851 let mut value = Vec::with_capacity(len as usize);
852 for _ in 0..len {
853 value.push(buffer.read_u8().await?);
854 }
855
856 parameter_values.push(Some(value));
857 }
858 }
859 }
860
861 let mut result_formats = Vec::new();
862 {
863 let total = buffer.read_i16().await?;
864
865 for _ in 0..total {
866 result_formats.push(buffer::read_format(&mut buffer).await?);
867 }
868 }
869
870 Ok(Self {
871 portal,
872 statement,
873 parameter_formats,
874 parameter_values,
875 result_formats,
876 })
877 }
878}
879
880#[derive(Debug, PartialEq)]
881pub enum DescribeType {
882 Statement,
883 Portal,
884}
885
886#[derive(Debug, PartialEq)]
888pub struct Describe {
889 pub typ: DescribeType,
890 pub name: String,
891}
892
893#[async_trait]
894impl Deserialize for Describe {
895 async fn deserialize(mut buffer: Cursor<Vec<u8>>) -> Result<Self, ProtocolError>
896 where
897 Self: Sized,
898 {
899 let typ = match buffer.read_u8().await? {
900 b'S' => DescribeType::Statement,
901 b'P' => DescribeType::Portal,
902 code => {
903 return Err(ErrorResponse::error(
904 ErrorCode::ProtocolViolation,
905 format!("Unknown describe code: {}", code),
906 )
907 .into());
908 }
909 };
910 let name = buffer::read_string(&mut buffer).await?;
911
912 Ok(Self { typ, name })
913 }
914}
915
916#[derive(Debug, PartialEq)]
917pub struct Query {
918 pub query: String,
919}
920
921#[async_trait]
922impl Deserialize for Query {
923 async fn deserialize(mut buffer: Cursor<Vec<u8>>) -> Result<Self, ProtocolError>
924 where
925 Self: Sized,
926 {
927 Ok(Self {
928 query: buffer::read_string(&mut buffer).await?,
929 })
930 }
931}
932
/// Value format code; the discriminant is the code written on the wire
/// (`0` = text, `1` = binary).
#[derive(Debug, PartialEq, Clone, Copy)]
#[repr(u8)]
pub enum Format {
    Text = 0,
    Binary = 1,
}
939
/// Custom frontend message carried by [`FrontendMessage::Extension`].
pub trait FrontendMessageExtension: Send + Sync + Debug {
    /// Exposes the concrete value for downcasting via [`Any`].
    fn as_any(&self) -> &dyn Any;
}
943
944#[derive(Debug)]
946pub enum FrontendMessage {
947 PasswordMessage(PasswordMessage),
948 Query(Query),
950 Flush,
952 Terminate,
954 Sync,
956 Parse(Parse),
958 Bind(Bind),
960 Describe(Describe),
962 Execute(Execute),
964 Close(Close),
966 Extension(Box<dyn FrontendMessageExtension>),
968}
969
/// SQLSTATE codes the server can report; `Display` renders the
/// five-character code.
#[derive(Debug)]
// Not every code is produced yet; the full set is kept available.
#[allow(dead_code)]
pub enum ErrorCode {
    SqlStatementNotYetComplete,
    FeatureNotSupported,
    ProtocolViolation,
    InvalidAuthorizationSpecification,
    InvalidPassword,
    DataException,
    ActiveSqlTransaction,
    NoActiveSqlTransaction,
    InvalidSqlStatement,
    InvalidCursorName,
    SyntaxErrorOrAccessRuleViolation,
    DuplicateCursor,
    SyntaxError,
    TooManyConnections,
    ConfigurationLimitExceeded,
    ObjectNotInPrerequisiteState,
    QueryCanceled,
    AdminShutdown,
    SystemError,
    InternalError,
}

impl Display for ErrorCode {
    /// Writes the SQLSTATE code; width/fill flags of the outer formatter are ignored.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let sqlstate = match self {
            ErrorCode::SqlStatementNotYetComplete => "03000",
            ErrorCode::FeatureNotSupported => "0A000",
            ErrorCode::ProtocolViolation => "08P01",
            ErrorCode::InvalidAuthorizationSpecification => "28000",
            ErrorCode::InvalidPassword => "28P01",
            ErrorCode::DataException => "22000",
            ErrorCode::ActiveSqlTransaction => "25001",
            ErrorCode::NoActiveSqlTransaction => "25P01",
            ErrorCode::InvalidSqlStatement => "26000",
            ErrorCode::InvalidCursorName => "34000",
            ErrorCode::SyntaxErrorOrAccessRuleViolation => "42000",
            ErrorCode::DuplicateCursor => "42P03",
            ErrorCode::SyntaxError => "42601",
            ErrorCode::TooManyConnections => "53300",
            ErrorCode::ConfigurationLimitExceeded => "53400",
            ErrorCode::ObjectNotInPrerequisiteState => "55000",
            ErrorCode::QueryCanceled => "57014",
            ErrorCode::AdminShutdown => "57P01",
            ErrorCode::SystemError => "58000",
            ErrorCode::InternalError => "XX000",
        };
        f.write_str(sqlstate)
    }
}
1037
/// Severity of a [`NoticeResponse`]; `Display` renders the upper-case name.
#[derive(Debug)]
pub enum NoticeSeverity {
    Warning,
    Notice,
    Debug,
    Info,
    Log,
}

impl Display for NoticeSeverity {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let label = match self {
            NoticeSeverity::Warning => "WARNING",
            NoticeSeverity::Notice => "NOTICE",
            NoticeSeverity::Debug => "DEBUG",
            NoticeSeverity::Info => "INFO",
            NoticeSeverity::Log => "LOG",
        };
        f.write_str(label)
    }
}
1060
/// Severity of an [`ErrorResponse`]; `Display` renders the upper-case name.
#[derive(Debug)]
pub enum ErrorSeverity {
    Error,
    Fatal,
    Panic,
}

impl Display for ErrorSeverity {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let label = match self {
            ErrorSeverity::Error => "ERROR",
            ErrorSeverity::Fatal => "FATAL",
            ErrorSeverity::Panic => "PANIC",
        };
        f.write_str(label)
    }
}
1079
/// Transaction state reported in `ReadyForQuery`.
pub enum TransactionStatus {
    /// Not inside a transaction block (`'I'`).
    Idle,
    /// Inside a transaction block (`'T'`).
    InTransactionBlock,
}

impl TransactionStatus {
    /// Status byte written on the wire.
    pub fn to_byte(&self) -> u8 {
        match self {
            TransactionStatus::Idle => b'I',
            TransactionStatus::InTransactionBlock => b'T',
        }
    }
}
1095
/// Custom authentication request kind carried by [`AuthenticationRequest::Extension`].
pub trait AuthenticationRequestExtension: Send + Sync {
    /// Exposes the concrete value for downcasting via [`Any`].
    fn as_any(&self) -> &dyn Any;

    /// Numeric request code written in the authentication message body.
    fn to_code(&self) -> u32;
}

/// Authentication request sent to the client.
#[derive(Clone)]
pub enum AuthenticationRequest {
    /// Code 0.
    Ok,
    /// Code 3.
    CleartextPassword,
    /// Code supplied by the extension.
    Extension(Arc<dyn AuthenticationRequestExtension>),
}

impl AuthenticationRequest {
    /// The request code as 4 big-endian bytes.
    pub fn to_bytes(&self) -> Vec<u8> {
        let code = self.to_code();
        code.to_be_bytes().to_vec()
    }

    /// Numeric request code.
    pub fn to_code(&self) -> u32 {
        match self {
            AuthenticationRequest::Ok => 0,
            AuthenticationRequest::CleartextPassword => 3,
            AuthenticationRequest::Extension(ext) => ext.to_code(),
        }
    }
}
1122
/// Encoding of a backend message.
pub trait Serialize {
    /// Message tag byte.
    const CODE: u8;

    /// Message body without tag or length prefix. `Some(vec![])` is an empty body;
    /// `None` means no body at all (see `SSLResponse`).
    fn serialize(&self) -> Option<Vec<u8>>;

    /// Tag byte of this message ([`Serialize::CODE`]).
    fn code(&self) -> u8 {
        <Self as Serialize>::CODE
    }
}
1132
1133#[async_trait]
1134pub trait Deserialize {
1135 async fn deserialize(mut buffer: Cursor<Vec<u8>>) -> Result<Self, ProtocolError>
1136 where
1137 Self: Sized;
1138}
1139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{read_message, MessageTagParserDefaultImpl, ProtocolError};

    use std::io::Cursor;

    /// Decodes a `hexdump -C`-style dump into bytes.
    ///
    /// Each line holds up to 16 two-digit hex tokens followed by an ASCII column.
    /// Only the leading run of hex tokens is decoded, so the result does not depend
    /// on how many spaces separate the columns. Splitting on a single space and
    /// keeping the first piece would decode just one byte per line.
    fn parse_hex_dump(input: String) -> Vec<u8> {
        input
            .trim()
            .lines()
            .flat_map(|line| {
                line.split_whitespace()
                    .take(16)
                    .take_while(|token| {
                        token.len() == 2 && token.chars().all(|c| c.is_ascii_hexdigit())
                    })
                    .map(|token| u8::from_str_radix(token, 16).unwrap())
            })
            .collect()
    }

    #[tokio::test]
    async fn test_startup_message_duplex() -> Result<(), ProtocolError> {
        let expected_message = {
            let mut parameters = HashMap::new();
            parameters.insert("database".to_string(), "test".to_string());
            parameters.insert("application_name".to_string(), "psql".to_string());
            parameters.insert("user".to_string(), "test".to_string());
            parameters.insert("client_encoding".to_string(), "UTF8".to_string());

            StartupMessage {
                major: 3,
                minor: 0,
                parameters,
            }
        };

        let mut cursor = Cursor::new(vec![]);
        buffer::write_message(
            &mut bytes::BytesMut::new(),
            &mut cursor,
            expected_message.clone(),
        )
        .await?;

        let buffer = cursor.get_ref()[..].to_vec();
        let mut cursor = Cursor::new(buffer);
        // Skip the length prefix; no tag byte precedes it for startup messages.
        cursor.read_u32().await?;

        let actual_message = StartupMessage::from(&mut cursor).await?;
        assert_eq!(actual_message, expected_message);

        Ok(())
    }

    #[tokio::test]
    async fn test_frontend_message_parse_parse() -> Result<(), ProtocolError> {
        let buffer = parse_hex_dump(
            r#"
            50 00 00 00 77 6e 61 6d 65 64 2d 73 74 6d 74 00   P...wnamed-stmt.
            0a 20 20 20 20 20 20 53 45 4c 45 43 54 20 6e 75   .      SELECT nu
            6d 2c 20 73 74 72 2c 20 62 6f 6f 6c 0a 20 20 20   m, str, bool.
            20 20 20 46 52 4f 4d 20 74 65 73 74 64 61 74 61      FROM testdata
            0a 20 20 20 20 20 20 57 48 45 52 45 20 6e 75 6d   .      WHERE num
            20 3d 20 24 31 20 41 4e 44 20 73 74 72 20 3d 20    = $1 AND str =
            24 32 20 41 4e 44 20 62 6f 6f 6c 20 3d 20 24 33   $2 AND bool = $3
            0a 20 20 20 20 00 00 00                           .    ...
            "#
            .to_string(),
        );
        let mut cursor = Cursor::new(buffer);

        let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?;
        match message {
            FrontendMessage::Parse(parse) => {
                assert_eq!(
                    parse,
                    Parse {
                        name: "named-stmt".to_string(),
                        query: "\n      SELECT num, str, bool\n      FROM testdata\n      WHERE num = $1 AND str = $2 AND bool = $3\n    ".to_string(),
                        param_types: vec![],
                    },
                )
            }
            _ => panic!("Wrong message, must be Parse"),
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_frontend_message_parse_bind_variant1() -> Result<(), ProtocolError> {
        let buffer = parse_hex_dump(
            r#"
            42 00 00 00 2d 00 6e 61 6d 65 64 2d 73 74 6d 74   B...-.named-stmt
            00 00 00 00 03 00 00 00 01 35 00 00 00 04 74 65   .........5....te
            73 74 00 00 00 04 74 72 75 65 00 01 00 00         st....true....
            "#
            .to_string(),
        );
        let mut cursor = Cursor::new(buffer);

        let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?;
        match message {
            FrontendMessage::Bind(bind) => {
                assert_eq!(
                    bind,
                    Bind {
                        portal: "".to_string(),
                        statement: "named-stmt".to_string(),
                        parameter_formats: vec![],
                        parameter_values: vec![
                            Some(vec![53]),
                            Some(vec![116, 101, 115, 116]),
                            Some(vec![116, 114, 117, 101]),
                        ],
                        result_formats: vec![Format::Text]
                    },
                );
            }
            _ => panic!("Wrong message, must be Bind"),
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_frontend_message_parse_bind_variant2() -> Result<(), ProtocolError> {
        let buffer = parse_hex_dump(
            r#"
            42 00 00 00 1a 00 73 30 00 00 01 00 01 00 01 00   B.....s0........
            00 00 04 74 65 73 74 00 01 00 01                  ...test....
            "#
            .to_string(),
        );
        let mut cursor = Cursor::new(buffer);

        let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?;
        match message {
            FrontendMessage::Bind(body) => {
                assert_eq!(
                    body,
                    Bind {
                        portal: "".to_string(),
                        statement: "s0".to_string(),
                        parameter_formats: vec![Format::Binary],
                        parameter_values: vec![Some(vec![116, 101, 115, 116])],
                        result_formats: vec![Format::Binary]
                    },
                );

                assert_eq!(
                    body.to_bind_values(&ParameterDescription::new(vec![PgTypeId::TEXT]))
                        .unwrap(),
                    vec![BindValue::String("test".to_string())]
                );
            }
            _ => panic!("Wrong message, must be Bind"),
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_frontend_message_parse_bind_float64() -> Result<(), ProtocolError> {
        // Text-format float parameter.
        let buffer = parse_hex_dump(
            r#"
            42 00 00 00 1a 00 73 30 00 00 01 00 00 00 01 00   B.....s0........
            00 00 05 32 36 2e 31 31 00 00 00 00               ...26.11....
            "#
            .to_string(),
        );
        let mut cursor = Cursor::new(buffer);

        let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?;
        match message {
            FrontendMessage::Bind(body) => {
                assert_eq!(
                    body,
                    Bind {
                        portal: "".to_string(),
                        statement: "s0".to_string(),
                        parameter_formats: vec![Format::Text],
                        parameter_values: vec![Some(vec![50, 54, 46, 49, 49])],
                        result_formats: vec![]
                    },
                );

                assert_eq!(
                    body.to_bind_values(&ParameterDescription::new(vec![PgTypeId::FLOAT8]))?,
                    vec![BindValue::Float64(26.11)]
                );
            }
            _ => panic!("Wrong message, must be Bind"),
        }

        // Binary-format float parameter.
        let buffer = parse_hex_dump(
            r#"
            42 00 00 00 1e 00 73 30 00 00 01 00 01 00 01 00   B.....s0........
            00 00 08 40 3a 1c 28 f5 c2 8f 5c 00 00 00 00      ...@:.(....\...
            "#
            .to_string(),
        );
        let mut cursor = Cursor::new(buffer);

        let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?;
        match message {
            FrontendMessage::Bind(body) => {
                assert_eq!(body.parameter_formats, vec![Format::Binary]);
                assert_eq!(
                    body.to_bind_values(&ParameterDescription::new(vec![PgTypeId::FLOAT8]))?,
                    vec![BindValue::Float64(26.11)]
                );
            }
            _ => panic!("Wrong message, must be Bind"),
        }

        Ok(())
    }

    #[cfg(feature = "with-chrono")]
    #[tokio::test]
    async fn test_frontend_message_parse_bind_date() -> Result<(), ProtocolError> {
        use chrono::NaiveDate;

        // Text-format date parameter.
        let buffer = parse_hex_dump(
            r#"
            42 00 00 00 1e 00 73 30 00 00 01 00 00 00 01 00   B.....s0........
            00 00 0a 32 30 32 35 2d 30 38 2d 30 38 00 00 00   ...2025-08-08...
            00                                                .
            "#
            .to_string(),
        );
        let mut cursor = Cursor::new(buffer);
        let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?;
        match message {
            FrontendMessage::Bind(body) => {
                assert_eq!(
                    body.to_bind_values(&ParameterDescription::new(vec![PgTypeId::DATE]))?,
                    vec![BindValue::Date(
                        NaiveDate::from_ymd_opt(2025, 8, 8).unwrap()
                    )]
                );
            }
            _ => panic!("Wrong message, must be Bind"),
        }

        // Binary-format date parameter.
        let buffer = parse_hex_dump(
            r#"
            42 00 00 00 1a 00 73 30 00 00 01 00 01 00 01 00   B.....s0........
            00 00 04 00 00 24 87 00 00 00 00                  .....$......
            "#
            .to_string(),
        );
        let mut cursor = Cursor::new(buffer);
        let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?;
        match message {
            FrontendMessage::Bind(body) => {
                assert_eq!(body.parameter_formats, vec![Format::Binary]);
                assert_eq!(
                    body.to_bind_values(&ParameterDescription::new(vec![PgTypeId::DATE]))?,
                    vec![BindValue::Date(
                        NaiveDate::from_ymd_opt(2025, 8, 8).unwrap()
                    )]
                );
            }
            _ => panic!("Wrong message, must be Bind"),
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_frontend_message_parse_describe() -> Result<(), ProtocolError> {
        let buffer = parse_hex_dump(
            r#"
            44 00 00 00 08 53 73 30 00                        D....Ss0.
            "#
            .to_string(),
        );
        let mut cursor = Cursor::new(buffer);

        let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?;
        match message {
            FrontendMessage::Describe(desc) => {
                assert_eq!(
                    desc,
                    Describe {
                        typ: DescribeType::Statement,
                        name: "s0".to_string(),
                    },
                )
            }
            _ => panic!("Wrong message, must be Describe"),
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_frontend_message_parse_password_message() -> Result<(), ProtocolError> {
        let buffer = parse_hex_dump(
            r#"
            70 00 00 00 09 74 65 73 74 00                     p....test.
            "#
            .to_string(),
        );
        let mut cursor = Cursor::new(buffer);

        let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?;
        match message {
            FrontendMessage::PasswordMessage(body) => {
                assert_eq!(
                    body,
                    PasswordMessage {
                        password: "test".to_string()
                    },
                )
            }
            _ => panic!("Wrong message, must be PasswordMessage"),
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_frontend_message_execute() -> Result<(), ProtocolError> {
        let buffer = parse_hex_dump(
            r#"
            45 00 00 00 09 00 00 00 00 00                     E.........
            "#
            .to_string(),
        );
        let mut cursor = Cursor::new(buffer);

        let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?;
        match message {
            FrontendMessage::Execute(body) => {
                assert_eq!(
                    body,
                    Execute {
                        portal: "".to_string(),
                        max_rows: 0
                    },
                )
            }
            _ => panic!("Wrong message, must be Execute"),
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_frontend_message_parse_sequence_sync() -> Result<(), ProtocolError> {
        let buffer = parse_hex_dump(
            r#"
            53 00 00 00 04                                    S....
            53 00 00 00 04                                    S....
            "#
            .to_string(),
        );
        let mut cursor = Cursor::new(buffer);

        read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?;
        read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?;

        Ok(())
    }

    #[tokio::test]
    async fn test_frontend_message_write_complete_parse() -> Result<(), ProtocolError> {
        let mut cursor = Cursor::new(vec![]);

        buffer::write_message(&mut bytes::BytesMut::new(), &mut cursor, ParseComplete {}).await?;

        assert_eq!(cursor.get_ref()[0..], vec![49, 0, 0, 0, 4]);

        Ok(())
    }

    #[tokio::test]
    async fn test_frontend_message_write_row_description() -> Result<(), ProtocolError> {
        let mut cursor = Cursor::new(vec![]);
        let desc = RowDescription::new(vec![
            RowDescriptionField::new(
                "num".to_string(),
                PgType::get_by_tid(PgTypeId::INT8),
                Format::Text,
            ),
            RowDescriptionField::new(
                "str".to_string(),
                PgType::get_by_tid(PgTypeId::INT8),
                Format::Text,
            ),
            RowDescriptionField::new(
                "bool".to_string(),
                PgType::get_by_tid(PgTypeId::INT8),
                Format::Text,
            ),
        ]);
        buffer::write_message(&mut bytes::BytesMut::new(), &mut cursor, desc).await?;

        assert_eq!(
            cursor.get_ref()[0..],
            vec![
                84, 0, 0, 0, 73, 0, 3, 110, 117, 109, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 0, 8, 255,
                255, 255, 255, 0, 0, 115, 116, 114, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 0, 8, 255,
                255, 255, 255, 0, 0, 98, 111, 111, 108, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 0, 8,
                255, 255, 255, 255, 0, 0
            ]
        );

        Ok(())
    }
}