1use anyhow::{anyhow, Result};
16use bytes::{BufMut, BytesMut};
17use std::collections::HashMap;
18
19pub const PROTOCOL_VERSION: u32 = 0x00030000; #[derive(Debug, Clone)]
22pub enum FrontendMessage {
23 StartupMessage(StartupMessage),
24 PasswordMessage(String),
25 Query(String),
26 Parse {
27 name: String,
28 query: String,
29 param_types: Vec<u32>,
30 },
31 Bind {
32 portal: String,
33 statement: String,
34 formats: Vec<i16>,
35 values: Vec<Option<Vec<u8>>>,
36 result_formats: Vec<i16>,
37 },
38 Execute {
39 portal: String,
40 max_rows: i32,
41 },
42 Sync,
43 Terminate,
44 CopyData(Vec<u8>),
45 CopyDone,
46 CopyFail(String),
47 SASLInitialResponse {
49 mechanism: String,
50 data: Vec<u8>,
51 },
52 SASLResponse(Vec<u8>),
53 StandbyStatusUpdate {
55 write_lsn: u64,
56 flush_lsn: u64,
57 apply_lsn: u64,
58 timestamp: i64,
59 reply: u8,
60 },
61}
62
/// Parameters sent in the startup packet, as name/value pairs.
///
/// The parameters are sent in `HashMap` iteration order.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StartupMessage {
    /// Startup parameter names mapped to their values (e.g. `"user"`, `"database"`).
    pub parameters: HashMap<String, String>,
}
67
68impl StartupMessage {
69 pub fn new_replication(database: &str, user: &str) -> Self {
70 let mut parameters = HashMap::new();
71 parameters.insert("user".to_string(), user.to_string());
72 parameters.insert("database".to_string(), database.to_string());
73 parameters.insert("replication".to_string(), "database".to_string());
74 Self { parameters }
75 }
76}
77
78impl FrontendMessage {
79 pub fn encode(&self, buf: &mut BytesMut) -> Result<()> {
80 match self {
81 FrontendMessage::StartupMessage(msg) => {
82 let mut msg_buf = BytesMut::new();
83 msg_buf.put_u32(PROTOCOL_VERSION);
84
85 for (key, value) in &msg.parameters {
86 msg_buf.put_slice(key.as_bytes());
87 msg_buf.put_u8(0);
88 msg_buf.put_slice(value.as_bytes());
89 msg_buf.put_u8(0);
90 }
91 msg_buf.put_u8(0); buf.put_u32((msg_buf.len() + 4) as u32);
95 buf.put_slice(&msg_buf);
96 }
97
98 FrontendMessage::PasswordMessage(password) => {
99 buf.put_u8(b'p');
100 buf.put_u32((4 + password.len() + 1) as u32);
101 buf.put_slice(password.as_bytes());
102 buf.put_u8(0);
103 }
104
105 FrontendMessage::Query(query) => {
106 buf.put_u8(b'Q');
107 buf.put_u32((4 + query.len() + 1) as u32);
108 buf.put_slice(query.as_bytes());
109 buf.put_u8(0);
110 }
111
112 FrontendMessage::Terminate => {
113 buf.put_u8(b'X');
114 buf.put_u32(4);
115 }
116
117 FrontendMessage::CopyData(data) => {
118 buf.put_u8(b'd');
119 buf.put_u32((4 + data.len()) as u32);
120 buf.put_slice(data);
121 }
122
123 FrontendMessage::CopyDone => {
124 buf.put_u8(b'c');
125 buf.put_u32(4);
126 }
127
128 FrontendMessage::CopyFail(msg) => {
129 buf.put_u8(b'f');
130 buf.put_u32((4 + msg.len() + 1) as u32);
131 buf.put_slice(msg.as_bytes());
132 buf.put_u8(0);
133 }
134
135 FrontendMessage::SASLInitialResponse { mechanism, data } => {
136 buf.put_u8(b'p');
137 let mut msg_buf = BytesMut::new();
138 msg_buf.put_slice(mechanism.as_bytes());
139 msg_buf.put_u8(0);
140 msg_buf.put_u32(data.len() as u32);
141 msg_buf.put_slice(data);
142 buf.put_u32((4 + msg_buf.len()) as u32);
143 buf.put_slice(&msg_buf);
144 }
145
146 FrontendMessage::SASLResponse(data) => {
147 buf.put_u8(b'p');
148 buf.put_u32((4 + data.len()) as u32);
149 buf.put_slice(data);
150 }
151
152 FrontendMessage::StandbyStatusUpdate {
153 write_lsn,
154 flush_lsn,
155 apply_lsn,
156 timestamp,
157 reply,
158 } => {
159 let mut data = BytesMut::new();
160 data.put_u8(b'r'); data.put_u64(*write_lsn);
162 data.put_u64(*flush_lsn);
163 data.put_u64(*apply_lsn);
164 data.put_i64(*timestamp);
165 data.put_u8(*reply);
166
167 buf.put_u8(b'd'); buf.put_u32((4 + data.len()) as u32);
169 buf.put_slice(&data);
170 }
171
172 _ => return Err(anyhow!("Unsupported message type for encoding")),
173 }
174
175 Ok(())
176 }
177}
178
179#[allow(clippy::large_enum_variant)]
180#[derive(Debug)]
181pub enum BackendMessage {
182 Authentication(AuthenticationMessage),
183 BackendKeyData {
184 process_id: i32,
185 secret_key: i32,
186 },
187 BindComplete,
188 CloseComplete,
189 CommandComplete(String),
190 CopyBothResponse,
191 CopyData(Vec<u8>),
192 CopyDone,
193 CopyInResponse,
194 CopyOutResponse,
195 DataRow(Vec<Option<Vec<u8>>>),
196 EmptyQueryResponse,
197 ErrorResponse(ErrorResponse),
198 NoData,
199 NoticeResponse(NoticeResponse),
200 NotificationResponse,
201 ParameterDescription,
202 ParameterStatus {
203 name: String,
204 value: String,
205 },
206 ParseComplete,
207 PortalSuspended,
208 ReadyForQuery(TransactionStatus),
209 RowDescription(Vec<FieldDescription>),
210 PrimaryKeepaliveMessage {
212 wal_end: u64,
213 timestamp: i64,
214 reply: u8,
215 },
216}
217
/// An authentication request or result sent by the server in an `'R'` message.
///
/// The variant is chosen by the Int32 authentication code at the start of the
/// message body (codes per the PostgreSQL protocol documentation).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthenticationMessage {
    /// Code 0: authentication succeeded.
    Ok,
    /// Code 2: Kerberos V5 authentication requested.
    KerberosV5,
    /// Code 3: a cleartext password is requested.
    CleartextPassword,
    /// Code 5: an MD5-hashed password is requested, using this 4-byte salt.
    MD5Password([u8; 4]),
    /// Code 6: SCM credential authentication requested.
    SCMCredential,
    /// Code 7: GSSAPI authentication requested.
    GSS,
    /// Code 8: GSSAPI or SSPI continuation data.
    GSSContinue(Vec<u8>),
    /// Code 9: SSPI authentication requested.
    SSPI,
    /// Code 10: SASL authentication; lists the mechanism names the server offers.
    SASL(Vec<String>),
    /// Code 11: SASL challenge bytes from the server.
    SASLContinue(Vec<u8>),
    /// Code 12: final SASL data from the server.
    SASLFinal(Vec<u8>),
}
232
/// Fields of an `'E'` (ErrorResponse) message.
///
/// Each field is keyed on the wire by a one-character code, noted per field.
/// Absent required fields decode as empty strings; absent optional fields as `None`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ErrorResponse {
    /// `S`: severity.
    pub severity: String,
    /// `C`: SQLSTATE code.
    pub code: String,
    /// `M`: primary message.
    pub message: String,
    /// `D`: detail.
    pub detail: Option<String>,
    /// `H`: hint.
    pub hint: Option<String>,
    /// `P`: cursor position, when it parses as an integer.
    pub position: Option<i32>,
    /// `p`: internal position, when it parses as an integer.
    pub internal_position: Option<i32>,
    /// `q`: internal query.
    pub internal_query: Option<String>,
    /// `W`: context ("where").
    pub where_: Option<String>,
    /// `s`: schema name.
    pub schema: Option<String>,
    /// `t`: table name.
    pub table: Option<String>,
    /// `c`: column name.
    pub column: Option<String>,
    /// `d`: data type name.
    pub datatype: Option<String>,
    /// `n`: constraint name.
    pub constraint: Option<String>,
    /// `F`: source file.
    pub file: Option<String>,
    /// `L`: source line, when it parses as an integer.
    pub line: Option<i32>,
    /// `R`: source routine.
    pub routine: Option<String>,
}
253
/// Selected fields of an `'N'` (NoticeResponse) message.
///
/// Wire field codes: `S` severity, `C` SQLSTATE code, `M` message, `D` detail,
/// `H` hint. Missing required fields decode as empty strings.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NoticeResponse {
    /// `S`: severity.
    pub severity: String,
    /// `C`: SQLSTATE code.
    pub code: String,
    /// `M`: primary message.
    pub message: String,
    /// `D`: detail.
    pub detail: Option<String>,
    /// `H`: hint.
    pub hint: Option<String>,
}
262
/// Transaction state reported by a `ReadyForQuery` (`'Z'`) message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransactionStatus {
    /// `'I'`: not in a transaction block.
    Idle,
    /// `'T'`: inside a transaction block.
    Transaction,
    /// `'E'`: inside a failed transaction block.
    Failed,
}
269
/// Metadata for one result column, decoded from a `RowDescription` (`'T'`) message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldDescription {
    /// Column name.
    pub name: String,
    /// OID of the source table (0 if the column is not a plain table column, per protocol docs).
    pub table_oid: u32,
    /// Attribute number of the column within `table_oid`.
    pub column_id: i16,
    /// OID of the column's data type.
    pub type_oid: u32,
    /// Data type size as reported by the server (negative for variable-width types, per protocol docs).
    pub type_size: i16,
    /// Type-specific modifier.
    pub type_modifier: i32,
    /// Format code (per protocol docs: 0 = text, 1 = binary).
    pub format: i16,
}
280
281pub fn parse_backend_message(msg_type: u8, body: &[u8]) -> Result<BackendMessage> {
282 match msg_type {
283 b'R' => parse_authentication(body),
284 b'K' => parse_backend_key_data(body),
285 b'Z' => parse_ready_for_query(body),
286 b'S' => parse_parameter_status(body),
287 b'E' => parse_error_response(body),
288 b'N' => parse_notice_response(body),
289 b'C' => parse_command_complete(body),
290 b'T' => parse_row_description(body),
291 b'D' => parse_data_row(body),
292 b'W' => parse_copy_both_response(body),
293 b'd' => Ok(BackendMessage::CopyData(body.to_vec())),
294 b'c' => Ok(BackendMessage::CopyDone),
295 b'1' => Ok(BackendMessage::ParseComplete),
296 b'2' => Ok(BackendMessage::BindComplete),
297 b'3' => Ok(BackendMessage::CloseComplete),
298 b'n' => Ok(BackendMessage::NoData),
299 b'I' => Ok(BackendMessage::EmptyQueryResponse),
300 b's' => Ok(BackendMessage::PortalSuspended),
301 _ => Err(anyhow!(
302 "Unknown backend message type: {}",
303 msg_type as char
304 )),
305 }
306}
307
308fn parse_authentication(body: &[u8]) -> Result<BackendMessage> {
309 if body.len() < 4 {
310 return Err(anyhow!("Authentication message too short"));
311 }
312
313 let auth_type = u32::from_be_bytes([body[0], body[1], body[2], body[3]]);
314 let auth = match auth_type {
315 0 => AuthenticationMessage::Ok,
316 3 => AuthenticationMessage::CleartextPassword,
317 5 => {
318 if body.len() < 8 {
319 return Err(anyhow!("MD5 authentication message too short"));
320 }
321 let mut salt = [0u8; 4];
322 salt.copy_from_slice(&body[4..8]);
323 AuthenticationMessage::MD5Password(salt)
324 }
325 10 => {
326 if body.len() > 4 {
327 let mechanisms = parse_sasl_mechanisms(&body[4..])?;
328 AuthenticationMessage::SASL(mechanisms)
329 } else {
330 AuthenticationMessage::SASL(vec![])
331 }
332 }
333 11 => {
334 if body.len() > 4 {
335 AuthenticationMessage::SASLContinue(body[4..].to_vec())
336 } else {
337 AuthenticationMessage::SASLContinue(vec![])
338 }
339 }
340 12 => {
341 if body.len() > 4 {
342 AuthenticationMessage::SASLFinal(body[4..].to_vec())
343 } else {
344 AuthenticationMessage::SASLFinal(vec![])
345 }
346 }
347 _ => return Err(anyhow!("Unsupported authentication type: {auth_type}")),
348 };
349
350 Ok(BackendMessage::Authentication(auth))
351}
352
353fn parse_sasl_mechanisms(body: &[u8]) -> Result<Vec<String>> {
354 let mut mechanisms = Vec::new();
355 let mut pos = 0;
356
357 while pos < body.len() {
358 let end = body[pos..]
359 .iter()
360 .position(|&b| b == 0)
361 .ok_or_else(|| anyhow!("Unterminated SASL mechanism"))?;
362
363 if end == 0 {
364 break; }
366
367 mechanisms.push(String::from_utf8_lossy(&body[pos..pos + end]).to_string());
368 pos += end + 1;
369 }
370
371 Ok(mechanisms)
372}
373
374fn parse_backend_key_data(body: &[u8]) -> Result<BackendMessage> {
375 if body.len() != 8 {
376 return Err(anyhow!("BackendKeyData message wrong size"));
377 }
378
379 let process_id = i32::from_be_bytes([body[0], body[1], body[2], body[3]]);
380 let secret_key = i32::from_be_bytes([body[4], body[5], body[6], body[7]]);
381
382 Ok(BackendMessage::BackendKeyData {
383 process_id,
384 secret_key,
385 })
386}
387
388fn parse_ready_for_query(body: &[u8]) -> Result<BackendMessage> {
389 if body.len() != 1 {
390 return Err(anyhow!("ReadyForQuery message wrong size"));
391 }
392
393 let status = match body[0] {
394 b'I' => TransactionStatus::Idle,
395 b'T' => TransactionStatus::Transaction,
396 b'E' => TransactionStatus::Failed,
397 _ => return Err(anyhow!("Unknown transaction status: {}", body[0])),
398 };
399
400 Ok(BackendMessage::ReadyForQuery(status))
401}
402
403fn parse_parameter_status(body: &[u8]) -> Result<BackendMessage> {
404 let name_end = body
405 .iter()
406 .position(|&b| b == 0)
407 .ok_or_else(|| anyhow!("Unterminated parameter name"))?;
408
409 let name = String::from_utf8_lossy(&body[..name_end]).to_string();
410
411 let value_start = name_end + 1;
412 let value_end = body[value_start..]
413 .iter()
414 .position(|&b| b == 0)
415 .ok_or_else(|| anyhow!("Unterminated parameter value"))?;
416
417 let value = String::from_utf8_lossy(&body[value_start..value_start + value_end]).to_string();
418
419 Ok(BackendMessage::ParameterStatus { name, value })
420}
421
422fn parse_error_response(body: &[u8]) -> Result<BackendMessage> {
423 let fields = parse_notice_fields(body)?;
424 Ok(BackendMessage::ErrorResponse(ErrorResponse {
425 severity: fields.get("S").cloned().unwrap_or_default(),
426 code: fields.get("C").cloned().unwrap_or_default(),
427 message: fields.get("M").cloned().unwrap_or_default(),
428 detail: fields.get("D").cloned(),
429 hint: fields.get("H").cloned(),
430 position: fields.get("P").and_then(|s| s.parse().ok()),
431 internal_position: fields.get("p").and_then(|s| s.parse().ok()),
432 internal_query: fields.get("q").cloned(),
433 where_: fields.get("W").cloned(),
434 schema: fields.get("s").cloned(),
435 table: fields.get("t").cloned(),
436 column: fields.get("c").cloned(),
437 datatype: fields.get("d").cloned(),
438 constraint: fields.get("n").cloned(),
439 file: fields.get("F").cloned(),
440 line: fields.get("L").and_then(|s| s.parse().ok()),
441 routine: fields.get("R").cloned(),
442 }))
443}
444
445fn parse_notice_response(body: &[u8]) -> Result<BackendMessage> {
446 let fields = parse_notice_fields(body)?;
447 Ok(BackendMessage::NoticeResponse(NoticeResponse {
448 severity: fields.get("S").cloned().unwrap_or_default(),
449 code: fields.get("C").cloned().unwrap_or_default(),
450 message: fields.get("M").cloned().unwrap_or_default(),
451 detail: fields.get("D").cloned(),
452 hint: fields.get("H").cloned(),
453 }))
454}
455
456fn parse_notice_fields(body: &[u8]) -> Result<HashMap<String, String>> {
457 let mut fields = HashMap::new();
458 let mut pos = 0;
459
460 while pos < body.len() && body[pos] != 0 {
461 let field_type = body[pos] as char;
462 pos += 1;
463
464 let end = body[pos..]
465 .iter()
466 .position(|&b| b == 0)
467 .ok_or_else(|| anyhow!("Unterminated field value"))?;
468
469 let value = String::from_utf8_lossy(&body[pos..pos + end]).to_string();
470 fields.insert(field_type.to_string(), value);
471
472 pos += end + 1;
473 }
474
475 Ok(fields)
476}
477
478fn parse_command_complete(body: &[u8]) -> Result<BackendMessage> {
479 let end = body
480 .iter()
481 .position(|&b| b == 0)
482 .ok_or_else(|| anyhow!("Unterminated command tag"))?;
483
484 let tag = String::from_utf8_lossy(&body[..end]).to_string();
485 Ok(BackendMessage::CommandComplete(tag))
486}
487
488fn parse_row_description(body: &[u8]) -> Result<BackendMessage> {
489 let mut pos = 0;
490 let field_count = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
491 pos += 2;
492
493 let mut fields = Vec::with_capacity(field_count);
494
495 for _ in 0..field_count {
496 let name_end = body[pos..]
497 .iter()
498 .position(|&b| b == 0)
499 .ok_or_else(|| anyhow!("Unterminated field name"))?;
500
501 let name = String::from_utf8_lossy(&body[pos..pos + name_end]).to_string();
502 pos += name_end + 1;
503
504 if pos + 18 > body.len() {
505 return Err(anyhow!("Row description truncated"));
506 }
507
508 let table_oid =
509 u32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
510 pos += 4;
511
512 let column_id = i16::from_be_bytes([body[pos], body[pos + 1]]);
513 pos += 2;
514
515 let type_oid = u32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
516 pos += 4;
517
518 let type_size = i16::from_be_bytes([body[pos], body[pos + 1]]);
519 pos += 2;
520
521 let type_modifier =
522 i32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
523 pos += 4;
524
525 let format = i16::from_be_bytes([body[pos], body[pos + 1]]);
526 pos += 2;
527
528 fields.push(FieldDescription {
529 name,
530 table_oid,
531 column_id,
532 type_oid,
533 type_size,
534 type_modifier,
535 format,
536 });
537 }
538
539 Ok(BackendMessage::RowDescription(fields))
540}
541
542fn parse_data_row(body: &[u8]) -> Result<BackendMessage> {
543 let mut pos = 0;
544 let column_count = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
545 pos += 2;
546
547 let mut columns = Vec::with_capacity(column_count);
548
549 for _ in 0..column_count {
550 if pos + 4 > body.len() {
551 return Err(anyhow!("Data row truncated"));
552 }
553
554 let length = i32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
555 pos += 4;
556
557 if length == -1 {
558 columns.push(None);
559 } else {
560 let length = length as usize;
561 if pos + length > body.len() {
562 return Err(anyhow!("Data row value truncated"));
563 }
564 columns.push(Some(body[pos..pos + length].to_vec()));
565 pos += length;
566 }
567 }
568
569 Ok(BackendMessage::DataRow(columns))
570}
571
572fn parse_copy_both_response(_body: &[u8]) -> Result<BackendMessage> {
573 Ok(BackendMessage::CopyBothResponse)
587}