pgwire_replication/protocol/
messages.rs

1use bytes::Buf;
2
3use crate::error::{PgWireError, Result};
4
/// Fields of a PostgreSQL `ErrorResponse` / `NoticeResponse` message.
///
/// Each field corresponds to a single-byte field-type code in the wire payload
/// (noted on each field below). Fields that were not present in the payload
/// stay `None`; codes not listed here are dropped by [`ErrorFields::parse`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ErrorFields {
    /// `S`: severity.
    pub severity: Option<String>,
    /// `C`: SQLSTATE code.
    pub code: Option<String>,
    /// `M`: primary message.
    pub message: Option<String>,
    /// `D`: detail.
    pub detail: Option<String>,
    /// `H`: hint.
    pub hint: Option<String>,
    /// `P`: position, kept as the raw string sent by the server.
    pub position: Option<String>,
    /// `W`: where/context. The trailing underscore avoids the `where` keyword.
    pub where_: Option<String>,
    /// `s`: schema name.
    pub schema: Option<String>,
    /// `t`: table name.
    pub table: Option<String>,
    /// `c`: column name.
    pub column: Option<String>,
    /// `d`: data type name.
    pub data_type: Option<String>,
    /// `n`: constraint name.
    pub constraint: Option<String>,
    /// `F`: source file name.
    pub file: Option<String>,
    /// `L`: source line number, kept as the raw string sent by the server.
    pub line: Option<String>,
    /// `R`: source routine name.
    pub routine: Option<String>,
}

impl ErrorFields {
    /// Decode the field list of an `ErrorResponse` / `NoticeResponse` payload.
    ///
    /// The payload is a run of `(type byte, NUL-terminated value)` pairs closed
    /// by a zero type byte. Decoding stops at that terminator, at the end of
    /// the input, or at a value that has no closing NUL (that partial value is
    /// discarded). Values are decoded as UTF-8, with invalid sequences replaced
    /// by U+FFFD. Unrecognised type bytes are skipped, and if a type byte
    /// repeats, the last value wins. This never panics, whatever the input.
    pub fn parse(payload: &[u8]) -> Self {
        let mut fields = Self::default();
        let mut rest = payload;

        while let Some((&tag, tail)) = rest.split_first() {
            // A zero type byte marks the end of the field list.
            if tag == 0 {
                break;
            }

            // Every value is NUL-terminated; no NUL means the payload was cut
            // short, so the incomplete value is dropped and decoding ends.
            let nul = match tail.iter().position(|&byte| byte == 0) {
                Some(idx) => idx,
                None => break,
            };
            let value = String::from_utf8_lossy(&tail[..nul]).into_owned();
            rest = &tail[nul + 1..];

            let slot = match tag {
                b'S' => &mut fields.severity,
                b'C' => &mut fields.code,
                b'M' => &mut fields.message,
                b'D' => &mut fields.detail,
                b'H' => &mut fields.hint,
                b'P' => &mut fields.position,
                b'W' => &mut fields.where_,
                b's' => &mut fields.schema,
                b't' => &mut fields.table,
                b'c' => &mut fields.column,
                b'd' => &mut fields.data_type,
                b'n' => &mut fields.constraint,
                b'F' => &mut fields.file,
                b'L' => &mut fields.line,
                b'R' => &mut fields.routine,
                // Skip field types this struct has no slot for.
                _ => continue,
            };
            *slot = Some(value);
        }

        fields
    }

    /// Render the message and SQLSTATE as one human-readable line.
    ///
    /// - message and code: `"<message> (SQLSTATE <code>)"`
    /// - message only: `"<message>"`
    /// - code only: `"error (SQLSTATE <code>)"`
    /// - neither: `"unknown server error"`
    ///
    /// All other fields are ignored.
    pub fn to_error_string(&self) -> String {
        let sqlstate = self.code.as_deref();
        if let Some(msg) = self.message.as_deref() {
            match sqlstate {
                Some(state) => format!("{msg} (SQLSTATE {state})"),
                None => msg.to_owned(),
            }
        } else if let Some(state) = sqlstate {
            format!("error (SQLSTATE {state})")
        } else {
            "unknown server error".to_owned()
        }
    }
}
76
77/// Parse an ErrorResponse payload into a human-readable string.
78///
79/// For more detailed error information, use `ErrorFields::parse()` instead.
80pub fn parse_error_response(payload: &[u8]) -> String {
81    ErrorFields::parse(payload).to_error_string()
82}
83
84/// Parse an AuthenticationRequest payload.
85///
86/// Returns (auth_type, remaining_data).
87/// Auth types:
88/// - 0 = AuthenticationOk
89/// - 3 = AuthenticationCleartextPassword
90/// - 5 = AuthenticationMD5Password (data contains 4-byte salt)
91/// - 10 = AuthenticationSASL (data contains mechanism names)
92/// - 11 = AuthenticationSASLContinue
93/// - 12 = AuthenticationSASLFinal
94pub fn parse_auth_request(payload: &[u8]) -> Result<(i32, &[u8])> {
95    if payload.len() < 4 {
96        return Err(PgWireError::Protocol("auth request too short".into()));
97    }
98    let mut b = payload;
99    let code = b.get_i32();
100    Ok((code, b))
101}
102
/// Request codes carried in the first four bytes of an `AuthenticationRequest`
/// message (as returned by `parse_auth_request`).
///
/// Includes the GSSAPI/SSPI/Kerberos codes so callers can recognise, and
/// explicitly reject, methods they do not implement without magic numbers.
pub mod auth {
    /// AuthenticationOk: authentication succeeded.
    pub const OK: i32 = 0;
    /// AuthenticationKerberosV5: Kerberos V5 authentication is required.
    pub const KERBEROS_V5: i32 = 2;
    /// AuthenticationCleartextPassword: send the password in clear text.
    pub const CLEARTEXT_PASSWORD: i32 = 3;
    /// AuthenticationMD5Password: body is a 4-byte salt.
    pub const MD5_PASSWORD: i32 = 5;
    /// AuthenticationGSS: GSSAPI authentication is required.
    pub const GSS: i32 = 7;
    /// AuthenticationGSSContinue: body carries GSSAPI or SSPI continuation data.
    pub const GSS_CONTINUE: i32 = 8;
    /// AuthenticationSSPI: SSPI authentication is required.
    pub const SSPI: i32 = 9;
    /// AuthenticationSASL: body lists the SASL mechanisms the server accepts.
    pub const SASL: i32 = 10;
    /// AuthenticationSASLContinue: body carries SASL challenge data.
    pub const SASL_CONTINUE: i32 = 11;
    /// AuthenticationSASLFinal: body carries SASL outcome ("additional data").
    pub const SASL_FINAL: i32 = 12;
}
112
#[cfg(test)]
mod tests {
    //! Unit tests for error-field parsing and authentication-request parsing.

    use super::*;

    #[test]
    fn parse_error_response_extracts_message_and_code() {
        // 'M' "hello" \0 'C' "12345" \0 \0
        let payload = [
            b'M', b'h', b'e', b'l', b'l', b'o', 0, b'C', b'1', b'2', b'3', b'4', b'5', 0, 0,
        ];
        let s = parse_error_response(&payload);
        assert!(s.contains("hello"));
        assert!(s.contains("SQLSTATE 12345"));
    }

    #[test]
    fn parse_error_response_handles_message_only() {
        let payload = [b'M', b't', b'e', b's', b't', 0, 0];
        let s = parse_error_response(&payload);
        assert_eq!(s, "test");
    }

    #[test]
    fn parse_error_response_handles_code_only() {
        let payload = [b'C', b'4', b'2', b'0', b'0', b'0', 0, 0];
        let s = parse_error_response(&payload);
        assert_eq!(s, "error (SQLSTATE 42000)");
    }

    #[test]
    fn parse_error_response_handles_empty() {
        let payload = [0];
        let s = parse_error_response(&payload);
        assert_eq!(s, "unknown server error");
    }

    #[test]
    fn parse_error_response_handles_truly_empty() {
        let payload: &[u8] = &[];
        let s = parse_error_response(payload);
        assert_eq!(s, "unknown server error");
    }

    #[test]
    fn error_fields_parses_all_standard_fields() {
        let mut payload = Vec::new();
        // Build a realistic error response that exercises every known field code.
        payload.extend_from_slice(b"SERROR\0");
        payload.extend_from_slice(b"C42P01\0");
        payload.extend_from_slice(b"Mrelation \"foo\" does not exist\0");
        payload.extend_from_slice(b"Dsome detail\0");
        payload.extend_from_slice(b"Htry this\0");
        payload.extend_from_slice(b"P42\0");
        payload.extend_from_slice(b"Wsome context\0");
        payload.extend_from_slice(b"sschema_name\0");
        payload.extend_from_slice(b"ttable_name\0");
        payload.extend_from_slice(b"ccolumn_name\0");
        payload.extend_from_slice(b"dtype_name\0");
        payload.extend_from_slice(b"nconstraint_name\0");
        payload.extend_from_slice(b"Fparse_relation.c\0");
        payload.extend_from_slice(b"L1234\0");
        payload.extend_from_slice(b"Rsome_routine\0");
        payload.push(0); // terminator

        let fields = ErrorFields::parse(&payload);

        assert_eq!(fields.severity.as_deref(), Some("ERROR"));
        assert_eq!(fields.code.as_deref(), Some("42P01"));
        assert_eq!(
            fields.message.as_deref(),
            Some("relation \"foo\" does not exist")
        );
        assert_eq!(fields.detail.as_deref(), Some("some detail"));
        assert_eq!(fields.hint.as_deref(), Some("try this"));
        assert_eq!(fields.position.as_deref(), Some("42"));
        assert_eq!(fields.where_.as_deref(), Some("some context"));
        assert_eq!(fields.schema.as_deref(), Some("schema_name"));
        assert_eq!(fields.table.as_deref(), Some("table_name"));
        assert_eq!(fields.column.as_deref(), Some("column_name"));
        assert_eq!(fields.data_type.as_deref(), Some("type_name"));
        assert_eq!(fields.constraint.as_deref(), Some("constraint_name"));
        assert_eq!(fields.file.as_deref(), Some("parse_relation.c"));
        assert_eq!(fields.line.as_deref(), Some("1234"));
        assert_eq!(fields.routine.as_deref(), Some("some_routine"));
    }

    #[test]
    fn error_fields_handles_truncated_payload() {
        // Missing null terminator for value
        let payload = [b'M', b'h', b'e', b'l', b'l', b'o'];
        let fields = ErrorFields::parse(&payload);
        // Should not panic, just skip incomplete field
        assert!(fields.message.is_none());
    }

    #[test]
    fn error_fields_ignores_unknown_field_codes() {
        let payload = [b'X', b'u', b'n', b'k', 0, b'M', b'o', b'k', 0, 0];
        let fields = ErrorFields::parse(&payload);
        // Unknown 'X' field ignored, 'M' field parsed
        assert_eq!(fields.message.as_deref(), Some("ok"));
    }

    #[test]
    fn error_fields_stops_at_terminator() {
        // Anything after the zero terminator must not be parsed.
        let fields = ErrorFields::parse(b"Mfirst\0\0Msecond\0");
        assert_eq!(fields.message.as_deref(), Some("first"));
    }

    #[test]
    fn error_fields_last_duplicate_wins() {
        let fields = ErrorFields::parse(b"Mone\0Mtwo\0\0");
        assert_eq!(fields.message.as_deref(), Some("two"));
    }

    #[test]
    fn error_fields_replaces_invalid_utf8() {
        let payload = [b'M', 0xFF, b'o', b'k', 0, 0];
        let fields = ErrorFields::parse(&payload);
        assert_eq!(fields.message.as_deref(), Some("\u{FFFD}ok"));
    }

    #[test]
    fn parse_auth_request_ok() {
        let payload = [0, 0, 0, 0]; // auth type 0 = OK
        let (code, rest) = parse_auth_request(&payload).unwrap();
        assert_eq!(code, auth::OK);
        assert!(rest.is_empty());
    }

    #[test]
    fn parse_auth_request_md5_with_salt() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&5i32.to_be_bytes()); // MD5
        payload.extend_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]); // salt

        let (code, salt) = parse_auth_request(&payload).unwrap();
        assert_eq!(code, auth::MD5_PASSWORD);
        assert_eq!(salt, &[0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn parse_auth_request_sasl_with_mechanisms() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&10i32.to_be_bytes()); // SASL
        payload.extend_from_slice(b"SCRAM-SHA-256\0");
        payload.extend_from_slice(b"SCRAM-SHA-256-PLUS\0");
        payload.push(0); // terminator

        let (code, mechanisms) = parse_auth_request(&payload).unwrap();
        assert_eq!(code, auth::SASL);
        // The whole body after the 4-byte code is returned untouched.
        assert_eq!(mechanisms, b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0");
    }

    #[test]
    fn parse_auth_request_reads_big_endian() {
        let (code, rest) = parse_auth_request(&[0x00, 0x00, 0x01, 0x02, 0xAA]).unwrap();
        assert_eq!(code, 0x0102);
        assert_eq!(rest, &[0xAA]);

        let (code, rest) = parse_auth_request(&[0xFF, 0xFF, 0xFF, 0xFF]).unwrap();
        assert_eq!(code, -1);
        assert!(rest.is_empty());
    }

    #[test]
    fn parse_auth_request_rejects_short_payload() {
        let payload = [0, 0, 0]; // only 3 bytes
        let err = parse_auth_request(&payload).unwrap_err();
        assert!(err.to_string().contains("too short"));
    }

    #[test]
    fn parse_auth_request_rejects_empty_payload() {
        let err = parse_auth_request(&[]).unwrap_err();
        assert!(err.to_string().contains("too short"));
    }

    #[test]
    fn auth_constants_have_correct_values() {
        assert_eq!(auth::OK, 0);
        assert_eq!(auth::CLEARTEXT_PASSWORD, 3);
        assert_eq!(auth::MD5_PASSWORD, 5);
        assert_eq!(auth::SASL, 10);
        assert_eq!(auth::SASL_CONTINUE, 11);
        assert_eq!(auth::SASL_FINAL, 12);
    }
}