1use bytes::Bytes;
2
3use crate::error::{Error, Result};
4
5#[derive(Debug)]
7pub enum BackendMessage {
8 AuthenticationOk,
9 AuthenticationCleartextPassword,
10 AuthenticationMd5Password {
11 salt: [u8; 4],
12 },
13 AuthenticationSasl {
14 mechanisms: Vec<String>,
15 },
16 AuthenticationSaslContinue {
17 data: Vec<u8>,
18 },
19 AuthenticationSaslFinal {
20 data: Vec<u8>,
21 },
22
23 BackendKeyData {
24 process_id: i32,
25 secret_key: i32,
26 },
27
28 ParameterStatus {
29 name: String,
30 value: String,
31 },
32
33 ReadyForQuery {
34 transaction_status: TransactionStatus,
35 },
36
37 RowDescription {
38 fields: Vec<FieldDescription>,
39 },
40
41 DataRow {
42 columns: DataRowColumns,
43 },
44
45 CommandComplete {
46 tag: String,
47 },
48
49 EmptyQueryResponse,
50
51 ErrorResponse {
52 fields: ErrorFields,
53 },
54
55 NoticeResponse {
56 fields: ErrorFields,
57 },
58
59 ParseComplete,
60 BindComplete,
61 CloseComplete,
62 NoData,
63 PortalSuspended,
64
65 ParameterDescription {
66 oids: Vec<u32>,
67 },
68
69 CopyInResponse {
70 format: CopyFormat,
71 column_formats: Vec<i16>,
72 },
73 CopyOutResponse {
74 format: CopyFormat,
75 column_formats: Vec<i16>,
76 },
77 CopyData {
78 data: Bytes,
79 },
80 CopyDone,
81
82 NotificationResponse {
83 process_id: i32,
84 channel: String,
85 payload: String,
86 },
87}
88
/// Transaction state carried by a `ReadyForQuery` message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Decoded from status byte `I`.
    Idle,
    /// Decoded from status byte `T`.
    InTransaction,
    /// Decoded from status byte `E`.
    Failed,
}
99
/// Overall data format announced by a copy-in / copy-out response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CopyFormat {
    /// Decoded from format byte `0`.
    Text,
    /// Decoded from format byte `1`.
    Binary,
}
106
/// One field entry of a `RowDescription` message.
///
/// On the wire each entry is a NUL-terminated name followed by 18 bytes of
/// fixed-width big-endian integers, listed below in wire order.
#[derive(Debug, Clone)]
pub struct FieldDescription {
    /// Field name (NUL-terminated UTF-8 on the wire).
    pub name: String,
    /// Bytes 0..4 of the fixed part.
    pub table_oid: u32,
    /// Bytes 4..6 of the fixed part.
    pub column_id: i16,
    /// Bytes 6..10 of the fixed part.
    pub type_oid: u32,
    /// Bytes 10..12 of the fixed part.
    pub type_size: i16,
    /// Bytes 12..16 of the fixed part.
    pub type_modifier: i32,
    /// Bytes 16..18 of the fixed part.
    pub format: i16,
}
118
/// Fields carried by an error or notice response.
///
/// Each field is identified on the wire by a one-byte code, noted below.
/// The three `String` fields stay empty when their code never appears; the
/// `Option` fields stay `None`.
#[derive(Debug, Clone)]
pub struct ErrorFields {
    /// Code `S`.
    pub severity: String,
    /// Code `C`.
    pub code: String,
    /// Code `M`.
    pub message: String,
    /// Code `D`.
    pub detail: Option<String>,
    /// Code `H`.
    pub hint: Option<String>,
    /// Code `P`; `None` if absent or not parseable as `u32`.
    pub position: Option<u32>,
    /// Code `p`; `None` if absent or not parseable as `u32`.
    pub internal_position: Option<u32>,
    /// Code `q`.
    pub internal_query: Option<String>,
    /// Code `W`.
    pub where_: Option<String>,
    /// Code `s`.
    pub schema: Option<String>,
    /// Code `t`.
    pub table: Option<String>,
    /// Code `c`.
    pub column: Option<String>,
    /// Code `d`.
    pub data_type: Option<String>,
    /// Code `n`.
    pub constraint: Option<String>,
    /// Code `F`.
    pub file: Option<String>,
    /// Code `L`; `None` if absent or not parseable as `u32`.
    pub line: Option<u32>,
    /// Code `R`.
    pub routine: Option<String>,
}
140
141#[derive(Debug)]
145pub struct DataRowColumns {
146 buf: Bytes,
147 columns: Vec<(usize, i32)>,
149}
150
151impl DataRowColumns {
152 pub fn len(&self) -> usize {
154 self.columns.len()
155 }
156
157 pub fn is_empty(&self) -> bool {
158 self.columns.is_empty()
159 }
160
161 pub fn get(&self, idx: usize) -> Option<Bytes> {
163 let &(offset, len) = self.columns.get(idx)?;
164 if len < 0 {
165 None } else {
167 Some(self.buf.slice(offset..offset + len as usize))
168 }
169 }
170
171 pub fn is_null(&self, idx: usize) -> bool {
173 self.columns.get(idx).map_or(true, |&(_, len)| len < 0)
174 }
175}
176
177pub fn decode(msg_type: u8, body: Bytes) -> Result<BackendMessage> {
181 match msg_type {
182 b'R' => decode_auth(&body),
183 b'K' => decode_backend_key_data(&body),
184 b'S' => decode_parameter_status(&body),
185 b'Z' => decode_ready_for_query(&body),
186 b'T' => decode_row_description(&body),
187 b'D' => decode_data_row(body),
188 b'C' => decode_command_complete(&body),
189 b'I' => Ok(BackendMessage::EmptyQueryResponse),
190 b'E' => decode_error_response(&body),
191 b'N' => decode_notice_response(&body),
192 b'1' => Ok(BackendMessage::ParseComplete),
193 b'2' => Ok(BackendMessage::BindComplete),
194 b'3' => Ok(BackendMessage::CloseComplete),
195 b'n' => Ok(BackendMessage::NoData),
196 b's' => Ok(BackendMessage::PortalSuspended),
197 b't' => decode_parameter_description(&body),
198 b'G' => decode_copy_in_response(&body),
199 b'H' => decode_copy_out_response(&body),
200 b'd' => Ok(BackendMessage::CopyData { data: body }),
201 b'c' => Ok(BackendMessage::CopyDone),
202 b'A' => decode_notification(&body),
203 _ => Err(Error::protocol(format!(
204 "unknown message type: 0x{msg_type:02x}"
205 ))),
206 }
207}
208
209fn decode_auth(body: &[u8]) -> Result<BackendMessage> {
212 if body.len() < 4 {
213 return Err(Error::protocol("auth message too short"));
214 }
215 let auth_type = read_i32(body, 0);
216
217 match auth_type {
218 0 => Ok(BackendMessage::AuthenticationOk),
219 3 => Ok(BackendMessage::AuthenticationCleartextPassword),
220 5 => {
221 if body.len() < 8 {
222 return Err(Error::protocol("MD5 auth message too short"));
223 }
224 let mut salt = [0u8; 4];
225 salt.copy_from_slice(&body[4..8]);
226 Ok(BackendMessage::AuthenticationMd5Password { salt })
227 }
228 10 => {
229 let mut mechanisms = Vec::new();
231 let mut pos = 4;
232 loop {
233 if pos >= body.len() {
234 break;
235 }
236 let s = read_cstr(body, &mut pos)?;
237 if s.is_empty() {
238 break;
239 }
240 mechanisms.push(s);
241 }
242 Ok(BackendMessage::AuthenticationSasl { mechanisms })
243 }
244 11 => Ok(BackendMessage::AuthenticationSaslContinue {
245 data: body[4..].to_vec(),
246 }),
247 12 => Ok(BackendMessage::AuthenticationSaslFinal {
248 data: body[4..].to_vec(),
249 }),
250 _ => Err(Error::protocol(format!(
251 "unsupported auth type: {auth_type}"
252 ))),
253 }
254}
255
256fn decode_backend_key_data(body: &[u8]) -> Result<BackendMessage> {
257 if body.len() < 8 {
258 return Err(Error::protocol("BackendKeyData too short"));
259 }
260 Ok(BackendMessage::BackendKeyData {
261 process_id: read_i32(body, 0),
262 secret_key: read_i32(body, 4),
263 })
264}
265
266fn decode_parameter_status(body: &[u8]) -> Result<BackendMessage> {
267 let mut pos = 0;
268 let name = read_cstr(body, &mut pos)?;
269 let value = read_cstr(body, &mut pos)?;
270 Ok(BackendMessage::ParameterStatus { name, value })
271}
272
273fn decode_ready_for_query(body: &[u8]) -> Result<BackendMessage> {
274 if body.is_empty() {
275 return Err(Error::protocol("ReadyForQuery empty"));
276 }
277 let status = match body[0] {
278 b'I' => TransactionStatus::Idle,
279 b'T' => TransactionStatus::InTransaction,
280 b'E' => TransactionStatus::Failed,
281 s => return Err(Error::protocol(format!("unknown transaction status: {s}"))),
282 };
283 Ok(BackendMessage::ReadyForQuery {
284 transaction_status: status,
285 })
286}
287
288fn decode_row_description(body: &[u8]) -> Result<BackendMessage> {
289 if body.len() < 2 {
290 return Err(Error::protocol("RowDescription too short"));
291 }
292 let field_count = read_i16(body, 0) as usize;
293 let mut fields = Vec::with_capacity(field_count);
294 let mut pos = 2;
295
296 for _ in 0..field_count {
297 let name = read_cstr(body, &mut pos)?;
298
299 if pos + 18 > body.len() {
300 return Err(Error::protocol("RowDescription field truncated"));
301 }
302
303 let table_oid = read_u32(body, pos);
304 let column_id = read_i16(body, pos + 4);
305 let type_oid = read_u32(body, pos + 6);
306 let type_size = read_i16(body, pos + 10);
307 let type_modifier = read_i32(body, pos + 12);
308 let format = read_i16(body, pos + 16);
309 pos += 18;
310
311 fields.push(FieldDescription {
312 name,
313 table_oid,
314 column_id,
315 type_oid,
316 type_size,
317 type_modifier,
318 format,
319 });
320 }
321
322 Ok(BackendMessage::RowDescription { fields })
323}
324
325fn decode_data_row(body: Bytes) -> Result<BackendMessage> {
326 if body.len() < 2 {
327 return Err(Error::protocol("DataRow too short"));
328 }
329 let col_count = read_i16(&body, 0) as usize;
330 let mut columns = Vec::with_capacity(col_count);
331 let mut pos = 2;
332
333 for _ in 0..col_count {
334 if pos + 4 > body.len() {
335 return Err(Error::protocol("DataRow column truncated"));
336 }
337 let len = read_i32(&body, pos);
338 pos += 4;
339
340 if len < 0 {
341 columns.push((0, -1)); } else {
343 let len_usize = len as usize;
344 if pos + len_usize > body.len() {
345 return Err(Error::protocol("DataRow column data truncated"));
346 }
347 columns.push((pos, len));
348 pos += len_usize;
349 }
350 }
351
352 Ok(BackendMessage::DataRow {
353 columns: DataRowColumns { buf: body, columns },
354 })
355}
356
357fn decode_command_complete(body: &[u8]) -> Result<BackendMessage> {
358 let mut pos = 0;
359 let tag = read_cstr(body, &mut pos)?;
360 Ok(BackendMessage::CommandComplete { tag })
361}
362
363fn decode_error_notice_fields(body: &[u8]) -> Result<ErrorFields> {
364 let mut severity = String::new();
365 let mut code = String::new();
366 let mut message = String::new();
367 let mut detail = None;
368 let mut hint = None;
369 let mut position = None;
370 let mut internal_position = None;
371 let mut internal_query = None;
372 let mut where_ = None;
373 let mut schema = None;
374 let mut table = None;
375 let mut column = None;
376 let mut data_type = None;
377 let mut constraint = None;
378 let mut file = None;
379 let mut line = None;
380 let mut routine = None;
381
382 let mut pos = 0;
383 loop {
384 if pos >= body.len() {
385 break;
386 }
387 let field_type = body[pos];
388 pos += 1;
389 if field_type == 0 {
390 break;
391 }
392 let value = read_cstr(body, &mut pos)?;
393
394 match field_type {
395 b'S' => severity = value,
396 b'C' => code = value,
397 b'M' => message = value,
398 b'D' => detail = Some(value),
399 b'H' => hint = Some(value),
400 b'P' => position = value.parse().ok(),
401 b'p' => internal_position = value.parse().ok(),
402 b'q' => internal_query = Some(value),
403 b'W' => where_ = Some(value),
404 b's' => schema = Some(value),
405 b't' => table = Some(value),
406 b'c' => column = Some(value),
407 b'd' => data_type = Some(value),
408 b'n' => constraint = Some(value),
409 b'F' => file = Some(value),
410 b'L' => line = value.parse().ok(),
411 b'R' => routine = Some(value),
412 _ => {} }
414 }
415
416 Ok(ErrorFields {
417 severity,
418 code,
419 message,
420 detail,
421 hint,
422 position,
423 internal_position,
424 internal_query,
425 where_,
426 schema,
427 table,
428 column,
429 data_type,
430 constraint,
431 file,
432 line,
433 routine,
434 })
435}
436
437fn decode_error_response(body: &[u8]) -> Result<BackendMessage> {
438 let fields = decode_error_notice_fields(body)?;
439 Ok(BackendMessage::ErrorResponse { fields })
440}
441
442fn decode_notice_response(body: &[u8]) -> Result<BackendMessage> {
443 let fields = decode_error_notice_fields(body)?;
444 Ok(BackendMessage::NoticeResponse { fields })
445}
446
447fn decode_parameter_description(body: &[u8]) -> Result<BackendMessage> {
448 if body.len() < 2 {
449 return Err(Error::protocol("ParameterDescription too short"));
450 }
451 let count = read_i16(body, 0) as usize;
452 let mut oids = Vec::with_capacity(count);
453 let mut pos = 2;
454
455 for _ in 0..count {
456 if pos + 4 > body.len() {
457 return Err(Error::protocol("ParameterDescription truncated"));
458 }
459 oids.push(read_u32(body, pos));
460 pos += 4;
461 }
462
463 Ok(BackendMessage::ParameterDescription { oids })
464}
465
466fn decode_copy_response(body: &[u8]) -> Result<(CopyFormat, Vec<i16>)> {
467 if body.len() < 3 {
468 return Err(Error::protocol("CopyResponse too short"));
469 }
470 let format = match body[0] {
471 0 => CopyFormat::Text,
472 1 => CopyFormat::Binary,
473 f => return Err(Error::protocol(format!("unknown copy format: {f}"))),
474 };
475 let col_count = read_i16(body, 1) as usize;
476 let mut column_formats = Vec::with_capacity(col_count);
477 let mut pos = 3;
478
479 for _ in 0..col_count {
480 if pos + 2 > body.len() {
481 return Err(Error::protocol("CopyResponse column formats truncated"));
482 }
483 column_formats.push(read_i16(body, pos));
484 pos += 2;
485 }
486
487 Ok((format, column_formats))
488}
489
490fn decode_copy_in_response(body: &[u8]) -> Result<BackendMessage> {
491 let (format, column_formats) = decode_copy_response(body)?;
492 Ok(BackendMessage::CopyInResponse {
493 format,
494 column_formats,
495 })
496}
497
498fn decode_copy_out_response(body: &[u8]) -> Result<BackendMessage> {
499 let (format, column_formats) = decode_copy_response(body)?;
500 Ok(BackendMessage::CopyOutResponse {
501 format,
502 column_formats,
503 })
504}
505
506fn decode_notification(body: &[u8]) -> Result<BackendMessage> {
507 if body.len() < 4 {
508 return Err(Error::protocol("NotificationResponse too short"));
509 }
510 let process_id = read_i32(body, 0);
511 let mut pos = 4;
512 let channel = read_cstr(body, &mut pos)?;
513 let payload = read_cstr(body, &mut pos)?;
514
515 Ok(BackendMessage::NotificationResponse {
516 process_id,
517 channel,
518 payload,
519 })
520}
521
/// Reads a big-endian `i32` from `buf` at `offset`.
///
/// # Panics
///
/// Panics if `buf` does not hold four bytes starting at `offset`; callers
/// are expected to have checked the length beforehand.
fn read_i32(buf: &[u8], offset: usize) -> i32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&buf[offset..offset + 4]);
    i32::from_be_bytes(raw)
}
532
/// Reads a big-endian `u32` from `buf` at `offset`.
///
/// # Panics
///
/// Panics if `buf` does not hold four bytes starting at `offset`; callers
/// are expected to have checked the length beforehand.
fn read_u32(buf: &[u8], offset: usize) -> u32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&buf[offset..offset + 4]);
    u32::from_be_bytes(raw)
}
541
/// Reads a big-endian `i16` from `buf` at `offset`.
///
/// # Panics
///
/// Panics if `buf` does not hold two bytes starting at `offset`; callers are
/// expected to have checked the length beforehand.
fn read_i16(buf: &[u8], offset: usize) -> i16 {
    let mut raw = [0u8; 2];
    raw.copy_from_slice(&buf[offset..offset + 2]);
    i16::from_be_bytes(raw)
}
545
546fn read_cstr(buf: &[u8], pos: &mut usize) -> Result<String> {
548 let start = *pos;
549 let null_pos = buf[start..]
550 .iter()
551 .position(|&b| b == 0)
552 .ok_or_else(|| Error::protocol("missing null terminator"))?;
553
554 let s = std::str::from_utf8(&buf[start..start + null_pos])
555 .map_err(|e| Error::protocol(format!("invalid UTF-8 in message: {e}")))?
556 .to_string();
557
558 *pos = start + null_pos + 1;
559 Ok(s)
560}