pgwire_replication/protocol/
messages.rs1use bytes::Buf;
2
3use crate::error::{PgWireError, Result};
4
/// Structured view of the fields carried by a server error payload.
///
/// Each member holds the value of one single-byte field code (noted on the
/// member), decoded with lossy UTF-8 conversion. Codes absent from the payload
/// stay `None`.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct ErrorFields {
    /// Field `S`.
    pub severity: Option<String>,
    /// Field `C`; rendered as the SQLSTATE by [`ErrorFields::to_error_string`].
    pub code: Option<String>,
    /// Field `M`; the primary text used by [`ErrorFields::to_error_string`].
    pub message: Option<String>,
    /// Field `D`.
    pub detail: Option<String>,
    /// Field `H`.
    pub hint: Option<String>,
    /// Field `P` (kept as text, not parsed as a number).
    pub position: Option<String>,
    /// Field `W` (trailing underscore because `where` is a Rust keyword).
    pub where_: Option<String>,
    /// Field `s`.
    pub schema: Option<String>,
    /// Field `t`.
    pub table: Option<String>,
    /// Field `c`.
    pub column: Option<String>,
    /// Field `d`.
    pub data_type: Option<String>,
    /// Field `n`.
    pub constraint: Option<String>,
    /// Field `F`.
    pub file: Option<String>,
    /// Field `L` (kept as text, not parsed as a number).
    pub line: Option<String>,
    /// Field `R`.
    pub routine: Option<String>,
}

impl ErrorFields {
    /// Decodes a payload made of `<code byte><NUL-terminated value>` entries.
    ///
    /// Decoding stops at any of the following:
    /// - a zero code byte (the list terminator);
    /// - the end of the input;
    /// - an entry whose value has no NUL terminator, in which case that
    ///   truncated entry is dropped.
    ///
    /// Unrecognised codes are skipped. A code that appears more than once
    /// keeps its last value. Invalid UTF-8 is replaced with U+FFFD rather than
    /// rejected, so this never fails.
    pub fn parse(payload: &[u8]) -> Self {
        let mut out = Self::default();
        let mut cursor = payload;

        while let Some((&tag, rest)) = cursor.split_first() {
            if tag == 0 {
                break;
            }
            let end = match rest.iter().position(|&byte| byte == 0) {
                Some(end) => end,
                // No terminator: the final entry is truncated, so discard it.
                None => break,
            };
            let (raw, after) = rest.split_at(end);
            // `after` begins with the NUL terminator; skip it for the next entry.
            cursor = &after[1..];

            let value = String::from_utf8_lossy(raw).into_owned();
            let slot = match tag {
                b'S' => &mut out.severity,
                b'C' => &mut out.code,
                b'M' => &mut out.message,
                b'D' => &mut out.detail,
                b'H' => &mut out.hint,
                b'P' => &mut out.position,
                b'W' => &mut out.where_,
                b's' => &mut out.schema,
                b't' => &mut out.table,
                b'c' => &mut out.column,
                b'd' => &mut out.data_type,
                b'n' => &mut out.constraint,
                b'F' => &mut out.file,
                b'L' => &mut out.line,
                b'R' => &mut out.routine,
                _ => continue,
            };
            *slot = Some(value);
        }

        out
    }

    /// Renders the fields as a single line for error reporting.
    ///
    /// Only the message (`M`) and the SQLSTATE code (`C`) are used. Fixed
    /// placeholder text stands in for whichever of them is missing.
    pub fn to_error_string(&self) -> String {
        match (self.message.as_deref(), self.code.as_deref()) {
            (Some(text), Some(sqlstate)) => format!("{text} (SQLSTATE {sqlstate})"),
            (Some(text), None) => text.to_owned(),
            (None, Some(sqlstate)) => format!("error (SQLSTATE {sqlstate})"),
            (None, None) => String::from("unknown server error"),
        }
    }
}
76
77pub fn parse_error_response(payload: &[u8]) -> String {
81 ErrorFields::parse(payload).to_error_string()
82}
83
84pub fn parse_auth_request(payload: &[u8]) -> Result<(i32, &[u8])> {
95 if payload.len() < 4 {
96 return Err(PgWireError::Protocol("auth request too short".into()));
97 }
98 let mut b = payload;
99 let code = b.get_i32();
100 Ok((code, b))
101}
102
/// Authentication request codes, compared against the code returned by
/// [`parse_auth_request`].
pub mod auth {
    /// `AuthenticationOk`.
    pub const OK: i32 = 0;

    /// `AuthenticationCleartextPassword`.
    pub const CLEARTEXT_PASSWORD: i32 = 3;

    /// `AuthenticationMD5Password`; the bytes after the code carry the salt.
    pub const MD5_PASSWORD: i32 = 5;

    /// `AuthenticationSASL`; the bytes after the code list mechanism names.
    pub const SASL: i32 = 10;

    /// `AuthenticationSASLContinue`.
    pub const SASL_CONTINUE: i32 = 11;

    /// `AuthenticationSASLFinal`.
    pub const SASL_FINAL: i32 = 12;
}
112
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_error_response_extracts_message_and_code() {
        let payload = b"Mhello\0C12345\0\0";
        assert_eq!(parse_error_response(payload), "hello (SQLSTATE 12345)");
    }

    #[test]
    fn parse_error_response_handles_message_only() {
        assert_eq!(parse_error_response(b"Mtest\0\0"), "test");
    }

    #[test]
    fn parse_error_response_handles_code_only() {
        assert_eq!(
            parse_error_response(b"C42000\0\0"),
            "error (SQLSTATE 42000)"
        );
    }

    #[test]
    fn parse_error_response_handles_empty() {
        // A lone terminator byte: the field list is empty.
        assert_eq!(parse_error_response(b"\0"), "unknown server error");
    }

    #[test]
    fn parse_error_response_handles_truly_empty() {
        assert_eq!(parse_error_response(b""), "unknown server error");
    }

    #[test]
    fn error_fields_parses_all_standard_fields() {
        // One entry for every field code that `ErrorFields::parse` maps, so a
        // broken or swapped match arm for any of them is caught here.
        let entries: [&[u8]; 15] = [
            b"SERROR\0",
            b"C42P01\0",
            b"Mrelation \"foo\" does not exist\0",
            b"Dsome detail\0",
            b"Htry this\0",
            b"P15\0",
            b"Wsome context\0",
            b"sschema_name\0",
            b"ttable_name\0",
            b"ccolumn_name\0",
            b"dtype_name\0",
            b"nconstraint_name\0",
            b"Fparse_relation.c\0",
            b"L1234\0",
            b"Rsome_routine\0",
        ];
        let mut payload = entries.concat();
        payload.push(0);

        let fields = ErrorFields::parse(&payload);

        assert_eq!(fields.severity.as_deref(), Some("ERROR"));
        assert_eq!(fields.code.as_deref(), Some("42P01"));
        assert_eq!(
            fields.message.as_deref(),
            Some("relation \"foo\" does not exist")
        );
        assert_eq!(fields.detail.as_deref(), Some("some detail"));
        assert_eq!(fields.hint.as_deref(), Some("try this"));
        assert_eq!(fields.position.as_deref(), Some("15"));
        assert_eq!(fields.where_.as_deref(), Some("some context"));
        assert_eq!(fields.schema.as_deref(), Some("schema_name"));
        assert_eq!(fields.table.as_deref(), Some("table_name"));
        assert_eq!(fields.column.as_deref(), Some("column_name"));
        assert_eq!(fields.data_type.as_deref(), Some("type_name"));
        assert_eq!(fields.constraint.as_deref(), Some("constraint_name"));
        assert_eq!(fields.file.as_deref(), Some("parse_relation.c"));
        assert_eq!(fields.line.as_deref(), Some("1234"));
        assert_eq!(fields.routine.as_deref(), Some("some_routine"));
    }

    #[test]
    fn error_fields_handles_truncated_payload() {
        // The value has no NUL terminator, so the entry is dropped.
        let fields = ErrorFields::parse(b"Mhello");
        assert!(fields.message.is_none());

        // Complete entries that precede the truncated one are still kept.
        let fields = ErrorFields::parse(b"C42000\0Mhel");
        assert_eq!(fields.code.as_deref(), Some("42000"));
        assert!(fields.message.is_none());
    }

    #[test]
    fn error_fields_ignores_unknown_field_codes() {
        let fields = ErrorFields::parse(b"Xunk\0Mok\0\0");
        assert_eq!(fields.message.as_deref(), Some("ok"));
    }

    #[test]
    fn error_fields_replaces_invalid_utf8() {
        // Invalid UTF-8 is replaced lossily instead of aborting the parse.
        let fields = ErrorFields::parse(b"M\xffok\0\0");
        assert_eq!(fields.message.as_deref(), Some("\u{FFFD}ok"));
    }

    #[test]
    fn parse_auth_request_ok() {
        let payload = [0u8; 4];
        let (code, rest) = parse_auth_request(&payload).unwrap();
        assert_eq!(code, auth::OK);
        assert!(rest.is_empty());
    }

    #[test]
    fn parse_auth_request_md5_with_salt() {
        let mut payload = 5i32.to_be_bytes().to_vec();
        payload.extend_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
        let (code, salt) = parse_auth_request(&payload).unwrap();
        assert_eq!(code, auth::MD5_PASSWORD);
        assert_eq!(salt, &[0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn parse_auth_request_sasl_with_mechanisms() {
        let mechanism_list = b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0";
        let mut payload = 10i32.to_be_bytes().to_vec();
        payload.extend_from_slice(mechanism_list);
        let (code, mechanisms) = parse_auth_request(&payload).unwrap();
        assert_eq!(code, auth::SASL);
        // Everything after the code must be handed back untouched.
        assert_eq!(mechanisms, &mechanism_list[..]);
    }

    #[test]
    fn parse_auth_request_rejects_short_payload() {
        let payload = [0u8, 0, 0];
        let err = parse_auth_request(&payload).unwrap_err();
        assert!(err.to_string().contains("too short"));

        let empty: [u8; 0] = [];
        assert!(parse_auth_request(&empty).is_err());
    }

    #[test]
    fn auth_constants_have_correct_values() {
        assert_eq!(auth::OK, 0);
        assert_eq!(auth::CLEARTEXT_PASSWORD, 3);
        assert_eq!(auth::MD5_PASSWORD, 5);
        assert_eq!(auth::SASL, 10);
        assert_eq!(auth::SASL_CONTINUE, 11);
        assert_eq!(auth::SASL_FINAL, 12);
    }
}