Skip to main content

pglite_oxide/protocol/
messages.rs

1use std::fmt;
2
3use anyhow::Result;
4
5use crate::protocol::types::Mode;
6
/// Names every kind of backend message this module can represent.
///
/// The value carries no payload: it only identifies which message a
/// `BackendMessage` is. Its `Display` form is the lowerCamelCase label
/// spelled out in the `fmt` implementation below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageName {
    ParseComplete,
    BindComplete,
    CloseComplete,
    NoData,
    PortalSuspended,
    ReplicationStart,
    EmptyQuery,
    CopyDone,
    CopyData,
    RowDescription,
    ParameterDescription,
    ParameterStatus,
    BackendKeyData,
    Notification,
    ReadyForQuery,
    CommandComplete,
    DataRow,
    CopyInResponse,
    CopyOutResponse,
    AuthenticationOk,
    AuthenticationMD5Password,
    AuthenticationCleartextPassword,
    AuthenticationSasl,
    AuthenticationSaslContinue,
    AuthenticationSaslFinal,
    Error,
    Notice,
}

impl fmt::Display for MessageName {
    /// Writes the lowerCamelCase label for this message kind.
    ///
    /// The SASL variants keep the acronym upper-cased in their label
    /// (`authenticationSASL…`) even though the Rust variant names do not.
    /// Width and alignment flags on the formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::ParseComplete => "parseComplete",
            Self::BindComplete => "bindComplete",
            Self::CloseComplete => "closeComplete",
            Self::NoData => "noData",
            Self::PortalSuspended => "portalSuspended",
            Self::ReplicationStart => "replicationStart",
            Self::EmptyQuery => "emptyQuery",
            Self::CopyDone => "copyDone",
            Self::CopyData => "copyData",
            Self::RowDescription => "rowDescription",
            Self::ParameterDescription => "parameterDescription",
            Self::ParameterStatus => "parameterStatus",
            Self::BackendKeyData => "backendKeyData",
            Self::Notification => "notification",
            Self::ReadyForQuery => "readyForQuery",
            Self::CommandComplete => "commandComplete",
            Self::DataRow => "dataRow",
            Self::CopyInResponse => "copyInResponse",
            Self::CopyOutResponse => "copyOutResponse",
            Self::AuthenticationOk => "authenticationOk",
            Self::AuthenticationMD5Password => "authenticationMD5Password",
            Self::AuthenticationCleartextPassword => "authenticationCleartextPassword",
            Self::AuthenticationSasl => "authenticationSASL",
            Self::AuthenticationSaslContinue => "authenticationSASLContinue",
            Self::AuthenticationSaslFinal => "authenticationSASLFinal",
            Self::Error => "error",
            Self::Notice => "notice",
        };
        f.write_str(label)
    }
}
73
/// One backend protocol message.
///
/// Variants without a payload carry only their `length`; every other variant
/// wraps the struct (or sub-enum) that holds its fields. Use
/// `BackendMessage::name` and `BackendMessage::length` to query any variant
/// uniformly.
#[derive(Debug, Clone)]
pub enum BackendMessage {
    ParseComplete { length: usize },
    BindComplete { length: usize },
    CloseComplete { length: usize },
    NoData { length: usize },
    PortalSuspended { length: usize },
    ReplicationStart { length: usize },
    EmptyQuery { length: usize },
    CopyDone { length: usize },
    ReadyForQuery(ReadyForQueryMessage),
    CommandComplete(CommandCompleteMessage),
    DataRow(DataRowMessage),
    RowDescription(RowDescriptionMessage),
    ParameterDescription(ParameterDescriptionMessage),
    ParameterStatus(ParameterStatusMessage),
    BackendKeyData(BackendKeyDataMessage),
    Notification(NotificationResponseMessage),
    /// A single variant for copy responses; the payload's `name` field says
    /// which `MessageName` it reports.
    CopyResponse(CopyResponse),
    CopyData(CopyDataMessage),
    /// Any of the authentication messages; see `AuthenticationMessage`.
    Authentication(AuthenticationMessage),
    Error(DatabaseError),
    Notice(NoticeMessage),
}
98
99impl BackendMessage {
100    pub fn name(&self) -> MessageName {
101        use BackendMessage::*;
102        match self {
103            ParseComplete { .. } => MessageName::ParseComplete,
104            BindComplete { .. } => MessageName::BindComplete,
105            CloseComplete { .. } => MessageName::CloseComplete,
106            NoData { .. } => MessageName::NoData,
107            PortalSuspended { .. } => MessageName::PortalSuspended,
108            ReplicationStart { .. } => MessageName::ReplicationStart,
109            EmptyQuery { .. } => MessageName::EmptyQuery,
110            CopyDone { .. } => MessageName::CopyDone,
111            ReadyForQuery(_) => MessageName::ReadyForQuery,
112            CommandComplete(_) => MessageName::CommandComplete,
113            DataRow(_) => MessageName::DataRow,
114            RowDescription(_) => MessageName::RowDescription,
115            ParameterDescription(_) => MessageName::ParameterDescription,
116            ParameterStatus(_) => MessageName::ParameterStatus,
117            BackendKeyData(_) => MessageName::BackendKeyData,
118            Notification(_) => MessageName::Notification,
119            CopyResponse(resp) => match resp.name {
120                MessageName::CopyInResponse => MessageName::CopyInResponse,
121                MessageName::CopyOutResponse => MessageName::CopyOutResponse,
122                _ => resp.name,
123            },
124            CopyData(_) => MessageName::CopyData,
125            Authentication(auth) => auth.name(),
126            Error(_) => MessageName::Error,
127            Notice(_) => MessageName::Notice,
128        }
129    }
130
131    pub fn length(&self) -> usize {
132        use BackendMessage::*;
133        match self {
134            ParseComplete { length }
135            | BindComplete { length }
136            | CloseComplete { length }
137            | NoData { length }
138            | PortalSuspended { length }
139            | ReplicationStart { length }
140            | EmptyQuery { length }
141            | CopyDone { length } => *length,
142            ReadyForQuery(msg) => msg.length,
143            CommandComplete(msg) => msg.length,
144            DataRow(msg) => msg.length,
145            RowDescription(msg) => msg.length,
146            ParameterDescription(msg) => msg.length,
147            ParameterStatus(msg) => msg.length,
148            BackendKeyData(msg) => msg.length,
149            Notification(msg) => msg.length,
150            CopyResponse(msg) => msg.length,
151            CopyData(msg) => msg.length,
152            Authentication(msg) => msg.length(),
153            Error(msg) => msg.length,
154            Notice(msg) => msg.length,
155        }
156    }
157}
158
/// Payload carried by `BackendMessage::ReadyForQuery`.
#[derive(Debug, Clone)]
pub struct ReadyForQueryMessage {
    /// Length recorded for this message; surfaced through `BackendMessage::length`.
    pub length: usize,
    /// Raw status byte (presumably the transaction-status indicator sent with
    /// the message — TODO confirm against the parser).
    pub status: u8,
}
164
/// Payload carried by `BackendMessage::CommandComplete`.
#[derive(Debug, Clone)]
pub struct CommandCompleteMessage {
    /// Length recorded for this message; surfaced through `BackendMessage::length`.
    pub length: usize,
    /// Text carried by the message (presumably the command tag — TODO confirm).
    pub text: String,
}
170
/// Payload carried by `BackendMessage::CopyData`.
#[derive(Debug, Clone)]
pub struct CopyDataMessage {
    /// Length recorded for this message; surfaced through `BackendMessage::length`.
    pub length: usize,
    /// Raw bytes carried by the message.
    pub chunk: Vec<u8>,
}
176
/// Payload carried by `BackendMessage::CopyResponse`.
#[derive(Debug, Clone)]
pub struct CopyResponse {
    /// Length recorded for this message; surfaced through `BackendMessage::length`.
    pub length: usize,
    /// The name this response reports; `BackendMessage::name` returns this
    /// value for the `CopyResponse` variant.
    pub name: MessageName,
    /// Binary flag (presumably true when the copy uses the binary format —
    /// TODO confirm against the parser).
    pub binary: bool,
    /// One `i16` per column (presumably per-column format codes — TODO confirm).
    pub column_types: Vec<i16>,
}
184
/// One entry of `RowDescriptionMessage::fields`.
///
/// NOTE(review): the meaning of each member is inferred from its name; the
/// code that fills them in is not part of this file — confirm against the
/// parser.
#[derive(Debug, Clone)]
pub struct Field {
    /// Field name.
    pub name: String,
    /// Identifier of the originating table.
    pub table_id: i32,
    /// Identifier of the originating column.
    pub column_id: i16,
    /// Identifier of the field's data type.
    pub data_type_id: i32,
    /// Size reported for the data type.
    pub data_type_size: i16,
    /// Modifier reported for the data type.
    pub data_type_modifier: i32,
    /// Format of the field, expressed as the crate's `Mode` type.
    pub format: Mode,
}
195
/// Payload carried by `BackendMessage::RowDescription`.
#[derive(Debug, Clone)]
pub struct RowDescriptionMessage {
    /// Length recorded for this message; surfaced through `BackendMessage::length`.
    pub length: usize,
    /// The described fields, one `Field` each.
    pub fields: Vec<Field>,
}
201
/// Payload carried by `BackendMessage::ParameterDescription`.
#[derive(Debug, Clone)]
pub struct ParameterDescriptionMessage {
    /// Length recorded for this message; surfaced through `BackendMessage::length`.
    pub length: usize,
    /// Data type identifiers (presumably one per statement parameter — TODO confirm).
    pub data_type_ids: Vec<i32>,
}
207
/// Payload carried by `BackendMessage::ParameterStatus`: a name/value pair.
#[derive(Debug, Clone)]
pub struct ParameterStatusMessage {
    /// Length recorded for this message; surfaced through `BackendMessage::length`.
    pub length: usize,
    /// Name of the reported parameter.
    pub parameter_name: String,
    /// Value of the reported parameter.
    pub parameter_value: String,
}
214
/// Payload carried by `BackendMessage::BackendKeyData`.
#[derive(Debug, Clone)]
pub struct BackendKeyDataMessage {
    /// Length recorded for this message; surfaced through `BackendMessage::length`.
    pub length: usize,
    /// Process identifier reported by the backend.
    pub process_id: i32,
    /// Secret key reported by the backend (presumably used for cancel
    /// requests — TODO confirm).
    pub secret_key: i32,
}
221
/// Payload carried by `BackendMessage::Notification`.
#[derive(Debug, Clone)]
pub struct NotificationResponseMessage {
    /// Length recorded for this message; surfaced through `BackendMessage::length`.
    pub length: usize,
    /// Process identifier attached to the notification (presumably the
    /// notifying backend — TODO confirm).
    pub process_id: i32,
    /// Channel the notification was sent on.
    pub channel: String,
    /// Payload string attached to the notification.
    pub payload: String,
}
229
/// Newtype wrapper around a `String`.
///
/// NOTE(review): not referenced by the other items visible in this file;
/// presumably holds a command tag — confirm against its users.
#[derive(Debug, Clone)]
pub struct CommandTag(pub String);
232
/// Payload carried by `BackendMessage::DataRow`.
#[derive(Debug, Clone)]
pub struct DataRowMessage {
    /// Length recorded for this message; surfaced through `BackendMessage::length`.
    pub length: usize,
    /// Field values as text, one entry per field (`None` presumably marks a
    /// SQL NULL — TODO confirm against the parser).
    pub fields: Vec<Option<String>>,
}
238
/// Implemented by message types whose optional members can be filled in from
/// a map of field code → value.
///
/// In this file it is implemented by `NoticeMessage` and `DatabaseError`.
pub trait NoticeOrErrorFields {
    /// Updates `self` from `fields`, a map keyed by field code with the field
    /// text as value (the same map type that `collect_fields` returns).
    fn apply_fields(&mut self, fields: &std::collections::HashMap<String, String>);
}
242
/// Payload carried by `BackendMessage::Notice`.
///
/// Only `length` and `message` are set at construction. The remaining
/// members are filled from field codes by `apply_fields`; the code that
/// feeds each member is noted on it.
#[derive(Debug, Clone)]
pub struct NoticeMessage {
    /// Length recorded for this message; surfaced through `BackendMessage::length`.
    pub length: usize,
    /// Message text, if any; supplied to `new` and not touched by `apply_fields`.
    pub message: Option<String>,
    /// Field code `S`.
    pub severity: Option<String>,
    /// Field code `C`.
    pub code: Option<String>,
    /// Field code `D`.
    pub detail: Option<String>,
    /// Field code `H`.
    pub hint: Option<String>,
    /// Field code `P`.
    pub position: Option<String>,
    /// Field code `p`.
    pub internal_position: Option<String>,
    /// Field code `q`.
    pub internal_query: Option<String>,
    /// Field code `W`.
    pub r#where: Option<String>,
    /// Field code `s`.
    pub schema: Option<String>,
    /// Field code `t`.
    pub table: Option<String>,
    /// Field code `c`.
    pub column: Option<String>,
    /// Field code `d`.
    pub data_type: Option<String>,
    /// Field code `n`.
    pub constraint: Option<String>,
    /// Field code `F`.
    pub file: Option<String>,
    /// Field code `L`.
    pub line: Option<String>,
    /// Field code `R`.
    pub routine: Option<String>,
}

impl NoticeMessage {
    /// Builds a notice holding `length` and `message`, with every other
    /// member left as `None`.
    pub fn new(length: usize, message: Option<String>) -> Self {
        Self {
            // Caller-supplied values.
            length,
            message,

            // Classification and human-readable extras.
            severity: None,
            code: None,
            detail: None,
            hint: None,

            // Query positions and context.
            position: None,
            internal_position: None,
            internal_query: None,
            r#where: None,

            // Database object references.
            schema: None,
            table: None,
            column: None,
            data_type: None,
            constraint: None,

            // Source location.
            file: None,
            line: None,
            routine: None,
        }
    }
}
289
290impl NoticeOrErrorFields for NoticeMessage {
291    fn apply_fields(&mut self, fields: &std::collections::HashMap<String, String>) {
292        self.severity = fields.get("S").cloned();
293        self.code = fields.get("C").cloned();
294        self.detail = fields.get("D").cloned();
295        self.hint = fields.get("H").cloned();
296        self.position = fields.get("P").cloned();
297        self.internal_position = fields.get("p").cloned();
298        self.internal_query = fields.get("q").cloned();
299        self.r#where = fields.get("W").cloned();
300        self.schema = fields.get("s").cloned();
301        self.table = fields.get("t").cloned();
302        self.column = fields.get("c").cloned();
303        self.data_type = fields.get("d").cloned();
304        self.constraint = fields.get("n").cloned();
305        self.file = fields.get("F").cloned();
306        self.line = fields.get("L").cloned();
307        self.routine = fields.get("R").cloned();
308    }
309}
310
/// Payload carried by `BackendMessage::Error`, also usable as a Rust error.
///
/// Only `length` and `message` are set at construction. The remaining
/// members are filled from field codes by `apply_fields`; the code that
/// feeds each member is noted on it.
#[derive(Debug, Clone)]
pub struct DatabaseError {
    /// Length recorded for this message; surfaced through `BackendMessage::length`.
    pub length: usize,
    /// Error text; this is what `Display` prints.
    pub message: String,
    /// Field code `S`.
    pub severity: Option<String>,
    /// Field code `C`.
    pub code: Option<String>,
    /// Field code `D`.
    pub detail: Option<String>,
    /// Field code `H`.
    pub hint: Option<String>,
    /// Field code `P`.
    pub position: Option<String>,
    /// Field code `p`.
    pub internal_position: Option<String>,
    /// Field code `q`.
    pub internal_query: Option<String>,
    /// Field code `W`.
    pub r#where: Option<String>,
    /// Field code `s`.
    pub schema: Option<String>,
    /// Field code `t`.
    pub table: Option<String>,
    /// Field code `c`.
    pub column: Option<String>,
    /// Field code `d`.
    pub data_type: Option<String>,
    /// Field code `n`.
    pub constraint: Option<String>,
    /// Field code `F`.
    pub file: Option<String>,
    /// Field code `L`.
    pub line: Option<String>,
    /// Field code `R`.
    pub routine: Option<String>,
}

impl DatabaseError {
    /// Builds an error holding `length` and `message`, with every other
    /// member left as `None`.
    pub fn new(length: usize, message: String) -> Self {
        Self {
            // Caller-supplied values.
            length,
            message,

            // Classification and human-readable extras.
            severity: None,
            code: None,
            detail: None,
            hint: None,

            // Query positions and context.
            position: None,
            internal_position: None,
            internal_query: None,
            r#where: None,

            // Database object references.
            schema: None,
            table: None,
            column: None,
            data_type: None,
            constraint: None,

            // Source location.
            file: None,
            line: None,
            routine: None,
        }
    }
}

impl fmt::Display for DatabaseError {
    /// Prints `message` verbatim; none of the optional members are included,
    /// and width/alignment flags on the formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

// Marker impl: relies on the `Debug` derive and the `Display` impl above and
// keeps the trait's default methods.
impl std::error::Error for DatabaseError {}
365
366impl NoticeOrErrorFields for DatabaseError {
367    fn apply_fields(&mut self, fields: &std::collections::HashMap<String, String>) {
368        self.severity = fields.get("S").cloned();
369        self.code = fields.get("C").cloned();
370        self.detail = fields.get("D").cloned();
371        self.hint = fields.get("H").cloned();
372        self.position = fields.get("P").cloned();
373        self.internal_position = fields.get("p").cloned();
374        self.internal_query = fields.get("q").cloned();
375        self.r#where = fields.get("W").cloned();
376        self.schema = fields.get("s").cloned();
377        self.table = fields.get("t").cloned();
378        self.column = fields.get("c").cloned();
379        self.data_type = fields.get("d").cloned();
380        self.constraint = fields.get("n").cloned();
381        self.file = fields.get("F").cloned();
382        self.line = fields.get("L").cloned();
383        self.routine = fields.get("R").cloned();
384    }
385}
386
/// Payload wrapped by `AuthenticationMessage::Ok`; it has no data beyond its length.
#[derive(Debug, Clone)]
pub struct AuthenticationOk {
    /// Length recorded for this message; surfaced through `AuthenticationMessage::length`.
    pub length: usize,
}
391
/// Payload wrapped by `AuthenticationMessage::Cleartext`; it has no data beyond its length.
#[derive(Debug, Clone)]
pub struct AuthenticationCleartextPassword {
    /// Length recorded for this message; surfaced through `AuthenticationMessage::length`.
    pub length: usize,
}
396
/// Payload wrapped by `AuthenticationMessage::Md5`.
#[derive(Debug, Clone)]
pub struct AuthenticationMD5Password {
    /// Length recorded for this message; surfaced through `AuthenticationMessage::length`.
    pub length: usize,
    /// Salt bytes sent with the request (presumably the salt to mix into the
    /// MD5 password hash — TODO confirm against the parser).
    pub salt: Vec<u8>,
}
402
/// Payload wrapped by `AuthenticationMessage::Sasl`.
#[derive(Debug, Clone)]
pub struct AuthenticationSasl {
    /// Length recorded for this message; surfaced through `AuthenticationMessage::length`.
    pub length: usize,
    /// Mechanism names carried by the message (presumably the SASL mechanisms
    /// the backend offers — TODO confirm).
    pub mechanisms: Vec<String>,
}
408
/// Payload wrapped by `AuthenticationMessage::SaslContinue`.
#[derive(Debug, Clone)]
pub struct AuthenticationSaslContinue {
    /// Length recorded for this message; surfaced through `AuthenticationMessage::length`.
    pub length: usize,
    /// Data carried by the message, held as text (presumably the
    /// mechanism-specific challenge — TODO confirm).
    pub data: String,
}
414
/// Payload wrapped by `AuthenticationMessage::SaslFinal`.
#[derive(Debug, Clone)]
pub struct AuthenticationSaslFinal {
    /// Length recorded for this message; surfaced through `AuthenticationMessage::length`.
    pub length: usize,
    /// Data carried by the message, held as text (presumably the
    /// mechanism-specific final payload — TODO confirm).
    pub data: String,
}
420
/// The authentication messages, grouped under `BackendMessage::Authentication`.
///
/// Each variant wraps the payload struct of the corresponding message;
/// `name` and `length` give uniform access across variants.
#[derive(Debug, Clone)]
pub enum AuthenticationMessage {
    Ok(AuthenticationOk),
    Cleartext(AuthenticationCleartextPassword),
    Md5(AuthenticationMD5Password),
    Sasl(AuthenticationSasl),
    SaslContinue(AuthenticationSaslContinue),
    SaslFinal(AuthenticationSaslFinal),
}
430
431impl AuthenticationMessage {
432    pub fn name(&self) -> MessageName {
433        use AuthenticationMessage::*;
434        match self {
435            Ok(_) => MessageName::AuthenticationOk,
436            Cleartext(_) => MessageName::AuthenticationCleartextPassword,
437            Md5(_) => MessageName::AuthenticationMD5Password,
438            Sasl(_) => MessageName::AuthenticationSasl,
439            SaslContinue(_) => MessageName::AuthenticationSaslContinue,
440            SaslFinal(_) => MessageName::AuthenticationSaslFinal,
441        }
442    }
443
444    pub fn length(&self) -> usize {
445        use AuthenticationMessage::*;
446        match self {
447            Ok(msg) => msg.length,
448            Cleartext(msg) => msg.length,
449            Md5(msg) => msg.length,
450            Sasl(msg) => msg.length,
451            SaslContinue(msg) => msg.length,
452            SaslFinal(msg) => msg.length,
453        }
454    }
455}
456
457pub fn collect_fields(
458    reader: &mut crate::protocol::buffer_reader::BufferReader<'_>,
459) -> Result<std::collections::HashMap<String, String>> {
460    use std::collections::HashMap;
461    let mut map = HashMap::new();
462    loop {
463        let field_type = reader.string(1)?;
464        if field_type == "\0" {
465            break;
466        }
467        let value = reader.cstring()?;
468        map.insert(field_type, value);
469    }
470    Ok(map)
471}