gel_pg_protocol/protocol.rs
1use gel_protogen::prelude::*;
2
message_group!(
    /// The `Backend` message group contains messages sent from the backend to the frontend.
    // All `Authentication*` entries share message type 'R'; the concrete
    // message is selected by the `status` word that follows the length
    // (0, 2, 3, 5, 7, 8, 9, 10, 11, 12 as declared in `protocol!`).
    // NOTE(review): `AuthenticationSCMCredential` is declared in `protocol!`
    // but is not listed here — confirm whether that omission is intentional.
    // `CopyData` and `CopyDone` are also members of the `Frontend` group.
    Backend: Message = [
        AuthenticationOk,
        AuthenticationKerberosV5,
        AuthenticationCleartextPassword,
        AuthenticationMD5Password,
        AuthenticationGSS,
        AuthenticationGSSContinue,
        AuthenticationSSPI,
        AuthenticationSASL,
        AuthenticationSASLContinue,
        AuthenticationSASLFinal,
        BackendKeyData,
        BindComplete,
        CloseComplete,
        CommandComplete,
        CopyData,
        CopyDone,
        CopyInResponse,
        CopyOutResponse,
        CopyBothResponse,
        DataRow,
        EmptyQueryResponse,
        ErrorResponse,
        FunctionCallResponse,
        NegotiateProtocolVersion,
        NoData,
        NoticeResponse,
        NotificationResponse,
        ParameterDescription,
        ParameterStatus,
        ParseComplete,
        PortalSuspended,
        ReadyForQuery,
        RowDescription
    ]
);
41
message_group!(
    /// The `Frontend` message group contains messages sent from the frontend to the backend.
    // `GSSResponse`, `PasswordMessage`, `SASLInitialResponse` and
    // `SASLResponse` all use message type 'p', so the type byte alone does not
    // identify which of them a buffer holds. NOTE(review): callers presumably
    // rely on authentication state to pick the right one — confirm how
    // `message_group!` dispatches among them.
    // `CopyData` and `CopyDone` are also members of the `Backend` group.
    Frontend: Message = [
        Bind,
        Close,
        CopyData,
        CopyDone,
        CopyFail,
        Describe,
        Execute,
        Flush,
        FunctionCall,
        GSSResponse,
        Parse,
        PasswordMessage,
        Query,
        SASLInitialResponse,
        SASLResponse,
        Sync,
        Terminate
    ]
);
64
message_group!(
    /// The `Initial` message group contains messages that are sent before the
    /// normal message flow.
    // These messages have no leading type byte. Each starts with its length
    // followed by an i32 code (`InitialMessage::protocol_version`):
    // 80877102 for `CancelRequest`, 80877104 for `GSSENCRequest`,
    // 80877103 for `SSLRequest`, and 196608 (== 3 << 16) for `StartupMessage`.
    Initial: InitialMessage = [
        CancelRequest,
        GSSENCRequest,
        SSLRequest,
        StartupMessage
    ]
);
75
76protocol!(
77
78/// A generic base for all Postgres mtype/mlen-style messages.
79struct Message<'a> {
80 /// Identifies the message.
81 mtype: u8,
82 /// Length of message contents in bytes, including self.
83 mlen: len,
84 /// Message contents.
85 data: Rest<'a>,
86}
87
88/// A generic base for all initial Postgres messages.
89struct InitialMessage<'a> {
90 /// Length of message contents in bytes, including self.
91 mlen: len,
92 /// The identifier for this initial message.
93 protocol_version: i32,
94 /// Message contents.
95 data: Rest<'a>
96}
97
98/// The `AuthenticationMessage` struct is a base for all Postgres authentication messages.
99struct AuthenticationMessage<'a>: Message {
100 /// Identifies the message as an authentication request.
101 mtype: u8 = 'R',
102 /// Length of message contents in bytes, including self.
103 mlen: len,
104 /// Specifies that the authentication was successful.
105 status: i32,
106}
107
108/// The `AuthenticationOk` struct represents a message indicating successful authentication.
109struct AuthenticationOk<'a>: Message {
110 /// Identifies the message as an authentication request.
111 mtype: u8 = 'R',
112 /// Length of message contents in bytes, including self.
113 mlen: len = 8,
114 /// Specifies that the authentication was successful.
115 status: i32 = 0,
116}
117
118/// The `AuthenticationKerberosV5` struct represents a message indicating that Kerberos V5 authentication is required.
119struct AuthenticationKerberosV5<'a>: Message {
120 /// Identifies the message as an authentication request.
121 mtype: u8 = 'R',
122 /// Length of message contents in bytes, including self.
123 mlen: len = 8,
124 /// Specifies that Kerberos V5 authentication is required.
125 status: i32 = 2,
126}
127
128/// The `AuthenticationCleartextPassword` struct represents a message indicating that a cleartext password is required for authentication.
129struct AuthenticationCleartextPassword<'a>: Message {
130 /// Identifies the message as an authentication request.
131 mtype: u8 = 'R',
132 /// Length of message contents in bytes, including self.
133 mlen: len = 8,
134 /// Specifies that a clear-text password is required.
135 status: i32 = 3,
136}
137
138/// The `AuthenticationMD5Password` struct represents a message indicating that an MD5-encrypted password is required for authentication.
139struct AuthenticationMD5Password<'a>: Message {
140 /// Identifies the message as an authentication request.
141 mtype: u8 = 'R',
142 /// Length of message contents in bytes, including self.
143 mlen: len = 12,
144 /// Specifies that an MD5-encrypted password is required.
145 status: i32 = 5,
146 /// The salt to use when encrypting the password.
147 salt: [u8; 4],
148}
149
150/// The `AuthenticationSCMCredential` struct represents a message indicating that an SCM credential is required for authentication.
151struct AuthenticationSCMCredential<'a>: Message {
152 /// Identifies the message as an authentication request.
153 mtype: u8 = 'R',
154 /// Length of message contents in bytes, including self.
155 mlen: len = 6,
156 /// Any data byte, which is ignored.
157 byte: u8 = 0,
158}
159
160/// The `AuthenticationGSS` struct represents a message indicating that GSSAPI authentication is required.
161struct AuthenticationGSS<'a>: Message {
162 /// Identifies the message as an authentication request.
163 mtype: u8 = 'R',
164 /// Length of message contents in bytes, including self.
165 mlen: len = 8,
166 /// Specifies that GSSAPI authentication is required.
167 status: i32 = 7,
168}
169
170/// The `AuthenticationGSSContinue` struct represents a message indicating the continuation of GSSAPI authentication.
171struct AuthenticationGSSContinue<'a>: Message {
172 /// Identifies the message as an authentication request.
173 mtype: u8 = 'R',
174 /// Length of message contents in bytes, including self.
175 mlen: len,
176 /// Specifies that this message contains GSSAPI or SSPI data.
177 status: i32 = 8,
178 /// GSSAPI or SSPI authentication data.
179 data: Rest<'a>,
180}
181
182/// The `AuthenticationSSPI` struct represents a message indicating that SSPI authentication is required.
183struct AuthenticationSSPI<'a>: Message {
184 /// Identifies the message as an authentication request.
185 mtype: u8 = 'R',
186 /// Length of message contents in bytes, including self.
187 mlen: len = 8,
188 /// Specifies that SSPI authentication is required.
189 status: i32 = 9,
190}
191
192/// The `AuthenticationSASL` struct represents a message indicating that SASL authentication is required.
193struct AuthenticationSASL<'a>: Message {
194 /// Identifies the message as an authentication request.
195 mtype: u8 = 'R',
196 /// Length of message contents in bytes, including self.
197 mlen: len,
198 /// Specifies that SASL authentication is required.
199 status: i32 = 10,
200 /// List of SASL authentication mechanisms, terminated by a zero byte.
201 mechanisms: ZTArray<'a, ZTString<'a>>,
202}
203
204/// The `AuthenticationSASLContinue` struct represents a message containing a SASL challenge during the authentication process.
205struct AuthenticationSASLContinue<'a>: Message {
206 /// Identifies the message as an authentication request.
207 mtype: u8 = 'R',
208 /// Length of message contents in bytes, including self.
209 mlen: len,
210 /// Specifies that this message contains a SASL challenge.
211 status: i32 = 11,
212 /// SASL data, specific to the SASL mechanism being used.
213 data: Rest<'a>,
214}
215
216/// The `AuthenticationSASLFinal` struct represents a message indicating the completion of SASL authentication.
217struct AuthenticationSASLFinal<'a>: Message {
218 /// Identifies the message as an authentication request.
219 mtype: u8 = 'R',
220 /// Length of message contents in bytes, including self.
221 mlen: len,
222 /// Specifies that SASL authentication has completed.
223 status: i32 = 12,
224 /// SASL outcome "additional data", specific to the SASL mechanism being used.
225 data: Rest<'a>,
226}
227
228/// The `BackendKeyData` struct represents a message containing the process ID and secret key for this backend.
229struct BackendKeyData<'a>: Message {
230 /// Identifies the message as cancellation key data.
231 mtype: u8 = 'K',
232 /// Length of message contents in bytes, including self.
233 mlen: len = 12,
234 /// The process ID of this backend.
235 pid: i32,
236 /// The secret key of this backend.
237 key: i32,
238}
239
240/// The `Bind` struct represents a message to bind a named portal to a prepared statement.
241struct Bind<'a>: Message {
242 /// Identifies the message as a Bind command.
243 mtype: u8 = 'B',
244 /// Length of message contents in bytes, including self.
245 mlen: len,
246 /// The name of the destination portal.
247 portal: ZTString<'a>,
248 /// The name of the source prepared statement.
249 statement: ZTString<'a>,
250 /// The parameter format codes.
251 format_codes: Array<'a, i16, i16>,
252 /// Array of parameter values and their lengths.
253 values: Array<'a, i16, Encoded<'a>>,
254 /// The result-column format codes.
255 result_format_codes: Array<'a, i16, i16>,
256}
257
258/// The `BindComplete` struct represents a message indicating that a Bind operation was successful.
259struct BindComplete<'a>: Message {
260 /// Identifies the message as a Bind-complete indicator.
261 mtype: u8 = '2',
262 /// Length of message contents in bytes, including self.
263 mlen: len = 4,
264}
265
266/// The `CancelRequest` struct represents a message to request the cancellation of a query.
267struct CancelRequest<'a>: InitialMessage {
268 /// Length of message contents in bytes, including self.
269 mlen: len = 16,
270 /// The cancel request code.
271 code: i32 = 80877102,
272 /// The process ID of the target backend.
273 pid: i32,
274 /// The secret key for the target backend.
275 key: i32,
276}
277
278/// The `Close` struct represents a message to close a prepared statement or portal.
279struct Close<'a>: Message {
280 /// Identifies the message as a Close command.
281 mtype: u8 = 'C',
282 /// Length of message contents in bytes, including self.
283 mlen: len,
284 /// 'S' to close a prepared statement; 'P' to close a portal.
285 ctype: CloseType,
286 /// The name of the prepared statement or portal to close.
287 name: ZTString<'a>,
288}
289
290/// The `CloseComplete` struct represents a message indicating that a Close operation was successful.
291struct CloseComplete<'a>: Message {
292 /// Identifies the message as a Close-complete indicator.
293 mtype: u8 = '3',
294 /// Length of message contents in bytes, including self.
295 mlen: len = 4,
296}
297
298/// The `CommandComplete` struct represents a message indicating the successful completion of a command.
299struct CommandComplete<'a>: Message {
300 /// Identifies the message as a command-completed response.
301 mtype: u8 = 'C',
302 /// Length of message contents in bytes, including self.
303 mlen: len,
304 /// The command tag.
305 tag: ZTString<'a>,
306}
307
308/// The `CopyData` struct represents a message containing data for a copy operation.
309struct CopyData<'a>: Message {
310 /// Identifies the message as COPY data.
311 mtype: u8 = 'd',
312 /// Length of message contents in bytes, including self.
313 mlen: len,
314 /// Data that forms part of a COPY data stream.
315 data: Rest<'a>,
316}
317
318/// The `CopyDone` struct represents a message indicating that a copy operation is complete.
319struct CopyDone<'a>: Message {
320 /// Identifies the message as a COPY-complete indicator.
321 mtype: u8 = 'c',
322 /// Length of message contents in bytes, including self.
323 mlen: len = 4,
324}
325
326/// The `CopyFail` struct represents a message indicating that a copy operation has failed.
327struct CopyFail<'a>: Message {
328 /// Identifies the message as a COPY-failure indicator.
329 mtype: u8 = 'f',
330 /// Length of message contents in bytes, including self.
331 mlen: len,
332 /// An error message to report as the cause of failure.
333 error_msg: ZTString<'a>,
334}
335
336/// The `CopyInResponse` struct represents a message indicating that the server is ready to receive data for a copy-in operation.
337struct CopyInResponse<'a>: Message {
338 /// Identifies the message as a Start Copy In response.
339 mtype: u8 = 'G',
340 /// Length of message contents in bytes, including self.
341 mlen: len,
342 /// The format of the data.
343 format: CopyFormat,
344 /// The format codes for each column.
345 format_codes: Array<'a, i16, i16>,
346}
347
348/// The `CopyOutResponse` struct represents a message indicating that the server is ready to send data for a copy-out operation.
349struct CopyOutResponse<'a>: Message {
350 /// Identifies the message as a Start Copy Out response.
351 mtype: u8 = 'H',
352 /// Length of message contents in bytes, including self.
353 mlen: len,
354 /// The format of the data.
355 format: CopyFormat,
356 /// The format codes for each column.
357 format_codes: Array<'a, i16, i16>,
358}
359
360/// The `CopyBothResponse` is used only for Streaming Replication.
361struct CopyBothResponse<'a>: Message {
362 /// Identifies the message as a Start Copy Both response.
363 mtype: u8 = 'W',
364 /// Length of message contents in bytes, including self.
365 mlen: len,
366 /// The format of the data.
367 format: CopyFormat,
368 /// The format codes for each column.
369 format_codes: Array<'a, i16, i16>,
370}
371
372/// The `DataRow` struct represents a message containing a row of data.
373struct DataRow<'a>: Message {
374 /// Identifies the message as a data row.
375 mtype: u8 = 'D',
376 /// Length of message contents in bytes, including self.
377 mlen: len,
378 /// Array of column values and their lengths.
379 values: Array<'a, i16, Encoded<'a>>,
380}
381
382/// The `Describe` struct represents a message to describe a prepared statement or portal.
383struct Describe<'a>: Message {
384 /// Identifies the message as a Describe command.
385 mtype: u8 = 'D',
386 /// Length of message contents in bytes, including self.
387 mlen: len,
388 /// 'S' to describe a prepared statement; 'P' to describe a portal.
389 dtype: DescribeType,
390 /// The name of the prepared statement or portal.
391 name: ZTString<'a>,
392}
393
394/// The `EmptyQueryResponse` struct represents a message indicating that an empty query string was recognized.
395struct EmptyQueryResponse<'a>: Message {
396 /// Identifies the message as a response to an empty query String<'a>.
397 mtype: u8 = 'I',
398 /// Length of message contents in bytes, including self.
399 mlen: len = 4,
400}
401
402/// The `ErrorResponse` struct represents a message indicating that an error has occurred.
403struct ErrorResponse<'a>: Message {
404 /// Identifies the message as an error.
405 mtype: u8 = 'E',
406 /// Length of message contents in bytes, including self.
407 mlen: len,
408 /// Array of error fields and their values.
409 fields: ZTArray<'a, ErrorField<'a>>,
410}
411
412/// The `ErrorField` struct represents a single error message within an `ErrorResponse`.
413struct ErrorField<'a> {
414 /// A code identifying the field type.
415 etype: u8,
416 /// The field value.
417 value: ZTString<'a>,
418}
419
420/// The `Execute` struct represents a message to execute a prepared statement or portal.
421struct Execute<'a>: Message {
422 /// Identifies the message as an Execute command.
423 mtype: u8 = 'E',
424 /// Length of message contents in bytes, including self.
425 mlen: len,
426 /// The name of the portal to execute.
427 portal: ZTString<'a>,
428 /// Maximum number of rows to return.
429 max_rows: i32,
430}
431
432/// The `Flush` struct represents a message to flush the backend's output buffer.
433struct Flush<'a>: Message {
434 /// Identifies the message as a Flush command.
435 mtype: u8 = 'H',
436 /// Length of message contents in bytes, including self.
437 mlen: len = 4,
438}
439
440/// The `FunctionCall` struct represents a message to call a function.
441struct FunctionCall<'a>: Message {
442 /// Identifies the message as a function call.
443 mtype: u8 = 'F',
444 /// Length of message contents in bytes, including self.
445 mlen: len,
446 /// OID of the function to execute.
447 function_id: i32,
448 /// The parameter format codes.
449 format_codes: Array<'a, i16, FormatCode>,
450 /// Array of args and their lengths.
451 args: Array<'a, i16, Encoded<'a>>,
452 /// The format code for the result.
453 result_format_code: FormatCode,
454}
455
456/// The `FunctionCallResponse` struct represents a message containing the result of a function call.
457struct FunctionCallResponse<'a>: Message {
458 /// Identifies the message as a function-call response.
459 mtype: u8 = 'V',
460 /// Length of message contents in bytes, including self.
461 mlen: len,
462 /// The function result value.
463 result: Encoded<'a>,
464}
465
466/// The `GSSENCRequest` struct represents a message requesting GSSAPI encryption.
467struct GSSENCRequest<'a>: InitialMessage {
468 /// Length of message contents in bytes, including self.
469 mlen: len = 8,
470 /// The GSSAPI Encryption request code.
471 gssenc_request_code: i32 = 80877104,
472}
473
474/// The `GSSResponse` struct represents a message containing a GSSAPI or SSPI response.
475struct GSSResponse<'a>: Message {
476 /// Identifies the message as a GSSAPI or SSPI response.
477 mtype: u8 = 'p',
478 /// Length of message contents in bytes, including self.
479 mlen: len,
480 /// GSSAPI or SSPI authentication data.
481 data: Rest<'a>,
482}
483
484/// The `NegotiateProtocolVersion` struct represents a message requesting protocol version negotiation.
485struct NegotiateProtocolVersion<'a>: Message {
486 /// Identifies the message as a protocol version negotiation request.
487 mtype: u8 = 'v',
488 /// Length of message contents in bytes, including self.
489 mlen: len,
490 /// Newest minor protocol version supported by the server.
491 minor_version: i32,
492 /// List of protocol options not recognized.
493 options: Array<'a, i32, ZTString<'a>>,
494}
495
496/// The `NoData` struct represents a message indicating that there is no data to return.
497struct NoData<'a>: Message {
498 /// Identifies the message as a No Data indicator.
499 mtype: u8 = 'n',
500 /// Length of message contents in bytes, including self.
501 mlen: len = 4,
502}
503
504/// The `NoticeResponse` struct represents a message containing a notice.
505struct NoticeResponse<'a>: Message {
506 /// Identifies the message as a notice.
507 mtype: u8 = 'N',
508 /// Length of message contents in bytes, including self.
509 mlen: len,
510 /// Array of notice fields and their values.
511 fields: ZTArray<'a, NoticeField<'a>>,
512}
513
514/// The `NoticeField` struct represents a single error message within an `NoticeResponse`.
515struct NoticeField<'a>: Message {
516 /// A code identifying the field type.
517 ntype: u8,
518 /// The field value.
519 value: ZTString<'a>,
520}
521
522/// The `NotificationResponse` struct represents a message containing a notification from the backend.
523struct NotificationResponse<'a>: Message {
524 /// Identifies the message as a notification.
525 mtype: u8 = 'A',
526 /// Length of message contents in bytes, including self.
527 mlen: len,
528 /// The process ID of the notifying backend.
529 pid: i32,
530 /// The name of the notification channel.
531 channel: ZTString<'a>,
532 /// The notification payload.
533 payload: ZTString<'a>,
534}
535
536/// The `ParameterDescription` struct represents a message describing the parameters needed by a prepared statement.
537struct ParameterDescription<'a>: Message {
538 /// Identifies the message as a parameter description.
539 mtype: u8 = 't',
540 /// Length of message contents in bytes, including self.
541 mlen: len,
542 /// OIDs of the parameter data types.
543 param_types: Array<'a, i16, i32>,
544}
545
546/// The `ParameterStatus` struct represents a message containing the current status of a parameter.
547struct ParameterStatus<'a>: Message {
548 /// Identifies the message as a runtime parameter status report.
549 mtype: u8 = 'S',
550 /// Length of message contents in bytes, including self.
551 mlen: len,
552 /// The name of the parameter.
553 name: ZTString<'a>,
554 /// The current value of the parameter.
555 value: ZTString<'a>,
556}
557
558/// The `Parse` struct represents a message to parse a query string.
559struct Parse<'a>: Message {
560 /// Identifies the message as a Parse command.
561 mtype: u8 = 'P',
562 /// Length of message contents in bytes, including self.
563 mlen: len,
564 /// The name of the destination prepared statement.
565 statement: ZTString<'a>,
566 /// The query string to be parsed.
567 query: ZTString<'a>,
568 /// OIDs of the parameter data types.
569 param_types: Array<'a, i16, i32>,
570}
571
572/// The `ParseComplete` struct represents a message indicating that a Parse operation was successful.
573struct ParseComplete<'a>: Message {
574 /// Identifies the message as a Parse-complete indicator.
575 mtype: u8 = '1',
576 /// Length of message contents in bytes, including self.
577 mlen: len = 4,
578}
579
580/// The `PasswordMessage` struct represents a message containing a password.
581struct PasswordMessage<'a>: Message {
582 /// Identifies the message as a password response.
583 mtype: u8 = 'p',
584 /// Length of message contents in bytes, including self.
585 mlen: len,
586 /// The password (encrypted or plaintext, depending on context).
587 password: ZTString<'a>,
588}
589
590/// The `PortalSuspended` struct represents a message indicating that a portal has been suspended.
591struct PortalSuspended<'a>: Message {
592 /// Identifies the message as a portal-suspended indicator.
593 mtype: u8 = 's',
594 /// Length of message contents in bytes, including self.
595 mlen: len = 4,
596}
597
598/// The `Query` struct represents a message to execute a simple query.
599struct Query<'a>: Message {
600 /// Identifies the message as a simple query command.
601 mtype: u8 = 'Q',
602 /// Length of message contents in bytes, including self.
603 mlen: len,
604 /// The query String<'a> to be executed.
605 query: ZTString<'a>,
606}
607
608/// The `ReadyForQuery` struct represents a message indicating that the backend is ready for a new query.
609struct ReadyForQuery<'a>: Message {
610 /// Identifies the message as a ready-for-query indicator.
611 mtype: u8 = 'Z',
612 /// Length of message contents in bytes, including self.
613 mlen: len = 5,
614 /// Current transaction status indicator.
615 status: u8,
616}
617
618/// The `RowDescription` struct represents a message describing the rows that will be returned by a query.
619struct RowDescription<'a>: Message {
620 /// Identifies the message as a row description.
621 mtype: u8 = 'T',
622 /// Length of message contents in bytes, including self.
623 mlen: len,
624 /// Array of field descriptions.
625 fields: Array<'a, i16, RowField<'a>>,
626}
627
628/// The `RowField` struct represents a row within the `RowDescription` message.
629struct RowField<'a> {
630 /// The field name
631 name: ZTString<'a>,
632 /// The table ID (OID) of the table the column is from, or 0 if not a column reference
633 table_oid: i32,
634 /// The attribute number of the column, or 0 if not a column reference
635 column_attr_number: i16,
636 /// The object ID of the field's data type
637 data_type_oid: i32,
638 /// The data type size (negative if variable size)
639 data_type_size: i16,
640 /// The type modifier
641 type_modifier: i32,
642 /// The format code being used for the field (0 for text, 1 for binary)
643 format_code: FormatCode,
644}
645
646/// The `SASLInitialResponse` struct represents a message containing a SASL initial response.
647struct SASLInitialResponse<'a>: Message {
648 /// Identifies the message as a SASL initial response.
649 mtype: u8 = 'p',
650 /// Length of message contents in bytes, including self.
651 mlen: len,
652 /// Name of the SASL authentication mechanism.
653 mechanism: ZTString<'a>,
654 /// SASL initial response data.
655 response: Array<'a, i32, u8>,
656}
657
658/// The `SASLResponse` struct represents a message containing a SASL response.
659struct SASLResponse<'a>: Message {
660 /// Identifies the message as a SASL response.
661 mtype: u8 = 'p',
662 /// Length of message contents in bytes, including self.
663 mlen: len,
664 /// SASL response data.
665 response: Rest<'a>,
666}
667
668/// The `SSLRequest` struct represents a message requesting SSL encryption.
669struct SSLRequest<'a>: InitialMessage {
670 /// Length of message contents in bytes, including self.
671 mlen: len = 8,
672 /// The SSL request code.
673 code: i32 = 80877103,
674}
675
676struct SSLResponse<'a> {
677 /// Specifies if SSL was accepted or rejected.
678 code: u8,
679}
680
681/// The `StartupMessage` struct represents a message to initiate a connection.
682struct StartupMessage<'a>: InitialMessage {
683 /// Length of message contents in bytes, including self.
684 mlen: len,
685 /// The protocol version number.
686 protocol: i32 = 196608,
687 /// List of parameter name-value pairs, terminated by a zero byte.
688 params: ZTArray<'a, StartupNameValue<'a>>,
689}
690
691/// The `StartupMessage` struct represents a name/value pair within the `StartupMessage` message.
692struct StartupNameValue<'a> {
693 /// The parameter name.
694 name: ZTString<'a>,
695 /// The parameter value.
696 value: ZTString<'a>,
697}
698
699/// The `Sync` struct represents a message to synchronize the frontend and backend.
700struct Sync<'a>: Message {
701 /// Identifies the message as a Sync command.
702 mtype: u8 = 'S',
703 /// Length of message contents in bytes, including self.
704 mlen: len = 4,
705}
706
707/// The `Terminate` struct represents a message to terminate a connection.
708struct Terminate<'a>: Message {
709 /// Identifies the message as a Terminate command.
710 mtype: u8 = 'X',
711 /// Length of message contents in bytes, including self.
712 mlen: len = 4,
713}
714
715#[repr(u8)]
716/// The type of object to close.
717enum CloseType {
718 #[default]
719 Portal = b'P',
720 Statement = b'S',
721}
722
723#[repr(u8)]
724/// The type of object to describe.
725enum DescribeType {
726 #[default]
727 Portal = b'P',
728 Statement = b'S',
729}
730
731#[repr(u8)]
732/// The data format for a copy operation.
733enum CopyFormat {
734 #[default]
735 Text = 0,
736 Binary = 1,
737}
738
739#[repr(u16)]
740/// The format code for an input or output value.
741enum FormatCode {
742 #[default]
743 Text = 0,
744 Binary = 1,
745}
746
747);
748
749#[cfg(test)]
750mod tests {
751 use super::*;
752 use gel_protogen::prelude::{match_message, Encoded, StructBuffer, StructMeta};
753 use rand::Rng;
754
755 /// We want to ensure that no malformed messages will cause unexpected
756 /// panics, so we try all sorts of combinations of message mutation to
757 /// ensure we don't.
758 ///
759 /// This isn't a 100% foolproof test.
760 fn fuzz_test<S: StructMeta>(s: S) {
761 let buf = s.to_vec();
762 assert!(buf.len() > 4, "Buffer is unexpectedly too short: {buf:?}");
763
764 eprintln!("Fuzzing buffer: {buf:?}");
765
766 // Re-create, won't panic
767 fuzz_test_buf::<S>(&buf);
768
769 // Truncating at any given length won't panic
770 for i in 0..buf.len() {
771 let mut buf = s.to_vec();
772 buf.truncate(i);
773 fuzz_test_buf::<S>(&buf);
774 }
775
776 // Removing any particular value won't panic
777 for i in 0..buf.len() {
778 let mut buf = s.to_vec();
779 buf.remove(i);
780 fuzz_test_buf::<S>(&buf);
781 }
782
783 // Zeroing any particular value won't panic
784 for i in 0..buf.len() {
785 let mut buf = s.to_vec();
786 buf[i] = 0;
787 fuzz_test_buf::<S>(&buf);
788 }
789
790 // Corrupt each byte by incrementing (mod 256)
791 for i in 0..buf.len() {
792 let mut buf = s.to_vec();
793 buf[i] = buf[i].wrapping_add(1);
794 fuzz_test_buf::<S>(&buf);
795 }
796
797 // Corrupt each byte by decrementing (mod 256)
798 for i in 0..buf.len() {
799 let mut buf = s.to_vec();
800 buf[i] = buf[i].wrapping_sub(1);
801 fuzz_test_buf::<S>(&buf);
802 }
803
804 // Replace four-byte chunks at 1-byte offsets with "-2" in big-endian, one at a time
805 // This shakes out any negative length issues for i32 lengths
806 let negative_two_i32: i32 = -2;
807 let bytes_i32 = negative_two_i32.to_be_bytes();
808 for start_index in 0..buf.len().saturating_sub(3) {
809 if start_index + 4 <= buf.len() {
810 let mut buf = s.to_vec(); // Clean buffer for each iteration
811 buf[start_index..start_index + 4].copy_from_slice(&bytes_i32);
812 eprintln!("Replaced 4-byte chunk at offset {} with -2 (big-endian) in buffer of length {}", start_index, buf.len());
813 fuzz_test_buf::<S>(&buf);
814 }
815 }
816
817 // Replace two-byte chunks at 1-byte offsets with "-2" in big-endian, one at a time
818 // This shakes out any negative length issues for i16 lengths
819 let negative_two_i16: i16 = -2;
820 let bytes_i16 = negative_two_i16.to_be_bytes();
821 for start_index in 0..buf.len().saturating_sub(1) {
822 if start_index + 2 <= buf.len() {
823 let mut buf = s.to_vec(); // Clean buffer for each iteration
824 buf[start_index..start_index + 2].copy_from_slice(&bytes_i16);
825 eprintln!("Replaced 2-byte chunk at offset {} with -2 (big-endian) in buffer of length {}", start_index, buf.len());
826 fuzz_test_buf::<S>(&buf);
827 }
828 }
829
830 let run_count = if std::env::var("EXTENSIVE_FUZZ").is_ok() {
831 100000
832 } else {
833 10
834 };
835
836 // Insert a random byte at a random position
837 for i in 0..run_count {
838 let mut buf = s.to_vec();
839 let random_byte: u8 = rand::rng().random();
840 let random_position = rand::rng().random_range(0..=buf.len());
841 buf.insert(random_position, random_byte);
842 eprintln!(
843 "Test {}: Inserted byte 0x{:02X} at position {} in buffer of length {}",
844 i + 1,
845 random_byte,
846 random_position,
847 buf.len()
848 );
849 fuzz_test_buf::<S>(&buf);
850 }
851
852 // Corrupt random parts of the buffer. This is non-deterministic.
853 for i in 0..run_count {
854 let mut buf = s.to_vec();
855 let rand: [u8; 4] = rand::rng().random();
856 let n = rand::rng().random_range(0..buf.len() - 4);
857 let range = n..n + 4;
858 eprintln!(
859 "Test {}: Corrupting buffer of length {} at range {:?} with bytes {:?}",
860 i + 1,
861 buf.len(),
862 range,
863 rand
864 );
865 buf.get_mut(range).unwrap().copy_from_slice(&rand);
866 fuzz_test_buf::<S>(&buf);
867 }
868
869 // Corrupt 1..4 random bytes at random positions
870 for i in 0..run_count {
871 let mut buf = s.to_vec();
872 let num_bytes_to_corrupt = rand::rng().random_range(1..=4);
873 let mut positions = Vec::new();
874
875 for _ in 0..num_bytes_to_corrupt {
876 let random_position = rand::rng().random_range(0..buf.len());
877 if !positions.contains(&random_position) {
878 positions.push(random_position);
879 let random_byte: u8 = rand::rng().random();
880 buf[random_position] = random_byte;
881 }
882 }
883
884 eprintln!(
885 "Test {}: Corrupted {} byte(s) at position(s) {:?} in buffer of length {}",
886 i + 1,
887 positions.len(),
888 positions,
889 buf.len()
890 );
891 fuzz_test_buf::<S>(&buf);
892 }
893
894 // Attempt to parse randomly generated structs. This is non-deterministic.
895 for i in 0..run_count {
896 let buf: [u8; 16] = rand::rng().random();
897 eprintln!(
898 "Test {}: Attempting to parse random buffer: {:02X?}",
899 i + 1,
900 buf
901 );
902 fuzz_test_buf::<S>(&buf);
903 }
904 }
905
906 fn fuzz_test_buf<S: StructMeta>(buf: &[u8]) {
907 // Use std::fmt::Debug which will walk each field
908 if let Ok(m) = S::new(buf) {
909 let _ = format!("{m:?}");
910 }
911 }
912
913 #[test]
914 fn test_sasl_response() {
915 let buf = [b'p', 0, 0, 0, 5, 2];
916 assert!(SASLResponse::is_buffer(&buf));
917 let message = SASLResponse::new(&buf).unwrap();
918 assert_eq!(*message.mlen(), 5);
919 assert_eq!(message.response().len(), 1);
920 }
921
922 #[test]
923 fn test_sasl_response_measure() {
924 let measure = SASLResponseBuilder {
925 response: &[1, 2, 3, 4, 5],
926 };
927 assert_eq!(measure.measure(), 10)
928 }
929
#[test]
fn test_sasl_initial_response() {
    let response = b"n,,n=,r=pEkPLQu29GEvwNeVJt72arQI";

    // Wire layout: 'p' tag + length 0x36, the NUL-terminated mechanism
    // name, a 4-byte payload length (32), then the payload itself.
    let parts: [&[u8]; 4] = [
        &[b'p', 0, 0, 0, 0x36],
        b"SCRAM-SHA-256\0",
        &[0, 0, 0, 32],
        response,
    ];
    let buf = parts.concat();

    assert!(SASLInitialResponse::is_buffer(&buf));
    let parsed = SASLInitialResponse::new(&buf).unwrap();
    assert_eq!(*parsed.mlen(), 0x36);
    assert_eq!(parsed.mechanism(), "SCRAM-SHA-256");
    assert_eq!(parsed.response().as_ref(), response);

    fuzz_test(parsed);
}
952
#[test]
fn test_sasl_initial_response_builder() {
    let response = b"n,,n=,r=pEkPLQu29GEvwNeVJt72arQI";
    let encoded = SASLInitialResponseBuilder {
        mechanism: "SCRAM-SHA-256",
        response,
    }
    .to_vec();

    // Decoding the builder's output must reproduce both inputs, with the
    // same 0x36 length as the hand-written buffer in the decode test.
    let decoded = SASLInitialResponse::new(&encoded).unwrap();
    assert_eq!(*decoded.mlen(), 0x36);
    assert_eq!(decoded.mechanism(), "SCRAM-SHA-256");
    assert_eq!(decoded.response().as_ref(), response);

    fuzz_test(decoded);
}
971
#[test]
fn test_startup_message() {
    // Length 41, protocol 0x00030000, then NUL-terminated name/value pairs
    // ("user" = "postgres", "database" = "postgres") closed by a final NUL.
    let buf = [
        0, 0, 0, 41, 0, 0x03, 0, 0, 0x75, 0x73, 0x65, 0x72, 0, 0x70, 0x6f, 0x73, 0x74, 0x67,
        0x72, 0x65, 0x73, 0, 0x64, 0x61, 0x74, 0x61, 0x62, 0x61, 0x73, 0x65, 0, 0x70, 0x6f,
        0x73, 0x74, 0x67, 0x72, 0x65, 0x73, 0, 0,
    ];
    let message = StartupMessage::new(&buf).unwrap();
    assert_eq!(*message.mlen() as usize, buf.len());
    assert_eq!(message.protocol(), 196608);

    // Flatten each (name, value) entry into a single ordered list.
    let vals: Vec<_> = message
        .params()
        .into_iter()
        .flat_map(|entry| {
            [
                entry.name().to_owned().unwrap(),
                entry.value().to_owned().unwrap(),
            ]
        })
        .collect();
    assert_eq!(vals, vec!["user", "postgres", "database", "postgres"]);

    fuzz_test(message);
}
992
#[test]
fn test_row_description() {
    // 'T', length 48, two fields. Each field is a two-byte name plus its NUL
    // terminator, followed by 18 zero bytes of per-field metadata.
    let buf = [
        b'T', 0, 0, 0, 48, // header
        0, 2, // # of fields
        b'f', b'1', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // field 1
        b'f', b'2', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // field 2
    ];
    assert!(RowDescription::is_buffer(&buf));
    let parsed = RowDescription::new(&buf).unwrap();

    // The length field counts everything except the leading type byte.
    assert_eq!(*parsed.mlen() as usize, buf.len() - 1);
    assert_eq!(parsed.fields().len(), 2);

    let mut remaining = parsed.fields().into_iter();
    let first = remaining.next().unwrap();
    assert_eq!(first.name(), "f1");
    let second = remaining.next().unwrap();
    assert_eq!(second.name(), "f2");
    assert_eq!(None, remaining.next());

    fuzz_test(parsed);
}
1013
#[test]
fn test_row_description_measure() {
    // Both fields share all-zero metadata and text format; only the name differs.
    let field = |name| RowFieldBuilder {
        name,
        table_oid: 0,
        column_attr_number: 0,
        data_type_oid: 0,
        data_type_size: 0,
        type_modifier: 0,
        format_code: FormatCode::Text,
    };
    let builder = RowDescriptionBuilder {
        fields: &[field("F1"), field("F2")],
    };
    // 1 (tag) + 4 (length) + 2 (field count) + 2 fields * 21 bytes, where a
    // field is "Fn\0" (3) + 4 + 2 + 4 + 2 + 4 + 2.
    assert_eq!(49, builder.measure())
}
1040
#[test]
fn test_row_description_builder() {
    let builder = RowDescriptionBuilder {
        fields: &[
            RowFieldBuilder {
                name: "F1",
                table_oid: 1,
                column_attr_number: 1,
                data_type_oid: 0,
                data_type_size: 0,
                type_modifier: 0,
                format_code: FormatCode::Text,
            },
            RowFieldBuilder {
                name: "F2",
                table_oid: 2,
                column_attr_number: 2,
                data_type_oid: 1234,
                data_type_size: 0,
                type_modifier: 0,
                format_code: FormatCode::Binary,
            },
        ],
    };

    let encoded = builder.to_vec();
    assert_eq!(49, encoded.len());

    // Decode the freshly encoded bytes and confirm the non-default values survive.
    assert!(RowDescription::is_buffer(&encoded));
    let decoded = RowDescription::new(&encoded).unwrap();
    assert_eq!(decoded.fields().len(), 2);
    let mut fields = decoded.fields().into_iter();

    let first = fields.next().unwrap();
    assert_eq!(first.name(), "F1");
    assert_eq!(first.column_attr_number(), 1);

    let second = fields.next().unwrap();
    assert_eq!(second.name(), "F2");
    assert_eq!(second.data_type_oid(), 1234);
    assert_eq!(second.format_code(), FormatCode::Binary);

    assert_eq!(None, fields.next());

    fuzz_test(decoded);
}
1085
#[test]
fn test_message_polymorphism_sync() {
    // A Sync is just its tag byte plus a 4-byte length: five bytes total.
    let buf = SyncBuilder::default().to_vec();
    assert_eq!(buf.len(), 5);

    // Viewed through the generic envelope, the payload is empty...
    let generic = Message::new(&buf).unwrap();
    assert_eq!(generic.mtype(), b'S');
    assert_eq!(*generic.mlen(), 4);
    assert_eq!(generic.data(), &[]);

    // ...and the same bytes are also accepted as the concrete `Sync` message.
    assert!(Sync::is_buffer(&buf));
    let sync = Sync::new(&buf).unwrap();
    assert_eq!(sync.mtype(), b'S');
    assert_eq!(*sync.mlen(), 4);

    fuzz_test(sync);
}
1104
#[test]
fn test_message_polymorphism_rest() {
    let encoded = AuthenticationGSSContinueBuilder {
        data: &[1, 2, 3, 4, 5],
    }
    .to_vec();
    // Tag (1) + length (4) + the 4-byte word 0,0,0,8 + 5 payload bytes.
    assert_eq!(14, encoded.len());

    // The generic envelope exposes everything after the length as `data`,
    // including the leading 0,0,0,8 word.
    assert!(Message::is_buffer(&encoded));
    let envelope = Message::new(&encoded).unwrap();
    assert_eq!(*envelope.mlen(), 13);
    assert_eq!(envelope.mtype(), b'R');
    assert_eq!(envelope.data(), &[0, 0, 0, 8, 1, 2, 3, 4, 5]);

    // The concrete message consumes that word and yields only the payload.
    assert!(AuthenticationGSSContinue::is_buffer(&encoded));
    let specific = AuthenticationGSSContinue::new(&encoded).unwrap();
    assert_eq!(*specific.mlen(), 13);
    assert_eq!(specific.mtype(), b'R');
    assert_eq!(specific.data(), &[1, 2, 3, 4, 5]);

    fuzz_test(specific);
}
1127
#[test]
fn test_query_messages() {
    // A backend reply to `SELECT 1`, as four concatenated messages:
    // RowDescription ('T'), DataRow ('D'), CommandComplete ('C') and
    // ReadyForQuery ('Z').
    let data: Vec<u8> = vec![
        0x54, 0x00, 0x00, 0x00, 0x21, 0x00, 0x01, 0x3f, b'c', b'o', b'l', b'u', b'm', b'n',
        0x3f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04,
        0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x01, 0x00,
        0x00, 0x00, 0x01, b'1', b'C', 0x00, 0x00, 0x00, 0x0d, b'S', b'E', b'L', b'E', b'C',
        b'T', b' ', b'1', 0x00, 0x5a, 0x00, 0x00, 0x00, 0x05, b'I',
    ];

    // Record every message the callback handles. Without this, the test
    // would pass vacuously if `push` never invoked the callback or dropped a
    // message. A `RefCell` lets the callback record through a shared borrow.
    let seen = std::cell::RefCell::new(Vec::new());

    let mut buffer = StructBuffer::<Message>::default();
    buffer.push(&data, |message| {
        match_message!(message, Backend {
            (RowDescription as row) => {
                assert_eq!(row.fields().len(), 1);
                let field = row.fields().into_iter().next().unwrap();
                assert_eq!(field.name(), "?column?");
                assert_eq!(field.data_type_oid(), 23);
                assert_eq!(field.format_code(), FormatCode::Text);
                eprintln!("{row:?}");
                fuzz_test(row);
                seen.borrow_mut().push("RowDescription");
            },
            (DataRow as row) => {
                assert_eq!(row.values().len(), 1);
                assert_eq!(row.values().into_iter().next().unwrap(), "1");
                eprintln!("{row:?}");
                fuzz_test(row);
                seen.borrow_mut().push("DataRow");
            },
            (CommandComplete as complete) => {
                assert_eq!(complete.tag(), "SELECT 1");
                eprintln!("{complete:?}");
                seen.borrow_mut().push("CommandComplete");
            },
            (ReadyForQuery as ready) => {
                assert_eq!(ready.status(), b'I');
                eprintln!("{ready:?}");
                seen.borrow_mut().push("ReadyForQuery");
            },
            unknown => {
                panic!("Unknown message type: {unknown:?}");
            }
        });
    });

    // All four messages must have been delivered, in wire order.
    assert_eq!(
        seen.into_inner(),
        ["RowDescription", "DataRow", "CommandComplete", "ReadyForQuery"]
    );
}
1170
#[test]
fn test_encode_data_row() {
    let buf = DataRowBuilder {
        values: &[Encoded::Value(b"1")],
    }
    .to_vec();

    // 'D', length 11 (4 length + 2 column count + 4 value length + 1 value
    // byte), one column, value length 1, value "1". These are the same bytes
    // as the DataRow embedded in `test_query_messages`.
    assert_eq!(buf, [b'D', 0, 0, 0, 11, 0, 1, 0, 0, 0, 1, b'1']);

    // The encoded bytes must decode back to the single value we wrote.
    assert!(DataRow::is_buffer(&buf));
    let message = DataRow::new(&buf).unwrap();
    assert_eq!(*message.mlen(), 11);
    assert_eq!(message.values().len(), 1);
    assert_eq!(
        message.values().into_iter().next().unwrap(),
        Encoded::Value(b"1")
    );

    fuzz_test(message);
}
1178
#[test]
fn test_parse() {
    let pieces: [&[u8]; 6] = [
        b"P",           // message type
        &[0, 0, 0, 25], // message length
        b"Stmt\0",      // statement name
        b"SELECT $1\0", // query string
        &[0, 1],        // number of parameter data types
        &[0, 0, 0, 23], // OID of the single parameter
    ];
    let buf = pieces.concat();

    assert!(Parse::is_buffer(&buf));
    let parsed = Parse::new(&buf).unwrap();
    assert_eq!(*parsed.mlen(), 25);
    assert_eq!(parsed.statement(), "Stmt");
    assert_eq!(parsed.query(), "SELECT $1");
    assert_eq!(parsed.param_types().len(), 1);
    assert_eq!(parsed.param_types().get(0).unwrap(), 23);

    fuzz_test(parsed);
}
1200
#[test]
fn test_function_call() {
    let builder = FunctionCallBuilder {
        function_id: 100,
        format_codes: &[FormatCode::Text],
        args: &[Encoded::Value(b"123")],
        result_format_code: FormatCode::Text,
    };
    let encoded = builder.to_vec();

    // Every builder input must round-trip through the decoder unchanged.
    assert!(FunctionCall::is_buffer(&encoded));
    let call = FunctionCall::new(&encoded).unwrap();
    assert_eq!(call.function_id(), 100);

    assert_eq!(call.format_codes().len(), 1);
    let first_code = call.format_codes().into_iter().next().unwrap();
    assert_eq!(first_code, FormatCode::Text);

    assert_eq!(call.args().len(), 1);
    let first_arg = call.args().into_iter().next().unwrap();
    assert_eq!(first_arg, Encoded::Value(b"123"));

    assert_eq!(call.result_format_code(), FormatCode::Text);

    fuzz_test(call);
}
1228
#[test]
fn test_datarow() {
    // 'D', length 10, one column whose length word is 0xFFFFFFFF; that
    // decodes as `Encoded::Null`, as asserted below.
    let buf = [
        0x44, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff,
    ];
    assert!(DataRow::is_buffer(&buf));
    let message = DataRow::new(&buf).unwrap();
    assert_eq!(*message.mlen(), 10);
    assert_eq!(message.values().len(), 1);
    assert_eq!(message.values().into_iter().next().unwrap(), Encoded::Null);

    // Like the other decode tests, also fuzz the parsed message.
    fuzz_test(message);
}
1239}