/// Messages sent from the client (frontend) to the server.
///
/// The byte layout of every variant is defined by `FrontendMessage::encode`; the
/// tag bytes quoted below are the ones that method writes. The set of messages
/// appears to follow the PostgreSQL v3 frontend/backend protocol.
///
/// `PartialEq`/`Eq` are derived so messages can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FrontendMessage {
    /// Connection start-up packet. Encoded without a tag byte: length, the
    /// protocol version `196608`, then `user` and `database` key/value pairs.
    Startup { user: String, database: String },
    /// Password reply, encoded with tag `p` as a NUL-terminated string.
    PasswordMessage(String),
    /// Simple query, encoded with tag `Q` as a NUL-terminated SQL string.
    Query(String),
    /// Prepared-statement definition, encoded with tag `P`.
    Parse {
        /// Statement name (written NUL-terminated).
        name: String,
        /// Query text (written NUL-terminated).
        query: String,
        /// Parameter type OIDs; the count is written as a 16-bit integer.
        param_types: Vec<u32>,
    },
    /// Binds parameter values to a statement, encoded with tag `B`.
    Bind {
        /// Destination portal name.
        portal: String,
        /// Source statement name.
        statement: String,
        /// Parameter values; `None` is written as length `-1`.
        params: Vec<Option<Vec<u8>>>,
    },
    /// Runs a portal, encoded with tag `E`; `max_rows` is written as a 32-bit integer.
    Execute { portal: String, max_rows: i32 },
    /// Encoded as tag `S` with an empty body.
    Sync,
    /// Encoded as tag `X` with an empty body.
    Terminate,
    /// First SASL message, encoded with tag `p`: mechanism name, then the
    /// length-prefixed `data`.
    SASLInitialResponse { mechanism: String, data: Vec<u8> },
    /// Follow-up SASL message, encoded with tag `p` and the raw bytes as body.
    SASLResponse(Vec<u8>),
    /// Encoded with tag `f` as a NUL-terminated error text.
    CopyFail(String),
    /// Encoded with tag `C`; `is_portal` selects the target byte `P`
    /// (otherwise `S`), followed by the NUL-terminated `name`.
    Close { is_portal: bool, name: String },
}
38
39#[derive(Debug, Clone)]
41pub enum BackendMessage {
42 AuthenticationOk,
44 AuthenticationMD5Password([u8; 4]),
45 AuthenticationSASL(Vec<String>),
46 AuthenticationSASLContinue(Vec<u8>),
47 AuthenticationSASLFinal(Vec<u8>),
48 ParameterStatus {
50 name: String,
51 value: String,
52 },
53 BackendKeyData {
55 process_id: i32,
56 secret_key: i32,
57 },
58 ReadyForQuery(TransactionStatus),
59 RowDescription(Vec<FieldDescription>),
60 DataRow(Vec<Option<Vec<u8>>>),
61 CommandComplete(String),
62 ErrorResponse(ErrorFields),
63 ParseComplete,
64 BindComplete,
65 NoData,
66 CopyInResponse {
68 format: u8,
69 column_formats: Vec<u8>,
70 },
71 CopyOutResponse {
73 format: u8,
74 column_formats: Vec<u8>,
75 },
76 CopyData(Vec<u8>),
77 CopyDone,
78 NotificationResponse {
80 process_id: i32,
81 channel: String,
82 payload: String,
83 },
84 EmptyQueryResponse,
85 NoticeResponse(ErrorFields),
87 ParameterDescription(Vec<u32>),
90 CloseComplete,
92}
93
/// Transaction state carried by `BackendMessage::ReadyForQuery`.
///
/// The decoder maps the status byte `I` to `Idle`, `T` to `InBlock` and `E` to
/// `Failed`. `PartialEq`/`Eq` are derived so callers can compare states with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Status byte `I`.
    Idle,
    /// Status byte `T`.
    InBlock,
    /// Status byte `E`.
    Failed,
}
101
/// One field entry of a `BackendMessage::RowDescription`.
///
/// The fields are listed in the order the row-description decoder reads them:
/// a NUL-terminated name followed by 18 bytes of big-endian integers.
/// `PartialEq`/`Eq` are derived so descriptions can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldDescription {
    /// Field name (decoded lossily as UTF-8).
    pub name: String,
    /// 32-bit value read right after the name.
    pub table_oid: u32,
    /// 16-bit value following `table_oid`.
    pub column_attr: i16,
    /// 32-bit value following `column_attr`.
    pub type_oid: u32,
    /// 16-bit value following `type_oid`.
    pub type_size: i16,
    /// 32-bit value following `type_size`.
    pub type_modifier: i32,
    /// 16-bit value closing the entry.
    /// NOTE(review): presumably the protocol's format code (0 = text, 1 = binary) — confirm.
    pub format: i16,
}
113
/// Fields extracted from an `ErrorResponse` or `NoticeResponse` message.
///
/// The error-field parser fills these from the one-byte field tags noted on
/// each member; every other tag is ignored. `PartialEq`/`Eq` are derived so
/// parsed results can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ErrorFields {
    /// Value of field tag `S`; empty when absent.
    pub severity: String,
    /// Value of field tag `C`; empty when absent.
    pub code: String,
    /// Value of field tag `M`; empty when absent.
    pub message: String,
    /// Value of field tag `D`, if present.
    pub detail: Option<String>,
    /// Value of field tag `H`, if present.
    pub hint: Option<String>,
}
123
124impl FrontendMessage {
125 pub fn encode(&self) -> Vec<u8> {
127 match self {
128 FrontendMessage::Startup { user, database } => {
129 let mut buf = Vec::new();
130 buf.extend_from_slice(&196608i32.to_be_bytes());
132 buf.extend_from_slice(b"user\0");
134 buf.extend_from_slice(user.as_bytes());
135 buf.push(0);
136 buf.extend_from_slice(b"database\0");
137 buf.extend_from_slice(database.as_bytes());
138 buf.push(0);
139 buf.push(0); let len = (buf.len() + 4) as i32;
143 let mut result = len.to_be_bytes().to_vec();
144 result.extend(buf);
145 result
146 }
147 FrontendMessage::Query(sql) => {
148 let mut buf = Vec::new();
149 buf.push(b'Q');
150 let content = format!("{}\0", sql);
151 let len = (content.len() + 4) as i32;
152 buf.extend_from_slice(&len.to_be_bytes());
153 buf.extend_from_slice(content.as_bytes());
154 buf
155 }
156 FrontendMessage::Terminate => {
157 vec![b'X', 0, 0, 0, 4]
158 }
159 FrontendMessage::SASLInitialResponse { mechanism, data } => {
160 let mut buf = Vec::new();
161 buf.push(b'p'); let mut content = Vec::new();
164 content.extend_from_slice(mechanism.as_bytes());
165 content.push(0); content.extend_from_slice(&(data.len() as i32).to_be_bytes());
167 content.extend_from_slice(data);
168
169 let len = (content.len() + 4) as i32;
170 buf.extend_from_slice(&len.to_be_bytes());
171 buf.extend_from_slice(&content);
172 buf
173 }
174 FrontendMessage::SASLResponse(data) => {
175 let mut buf = Vec::new();
176 buf.push(b'p');
177
178 let len = (data.len() + 4) as i32;
179 buf.extend_from_slice(&len.to_be_bytes());
180 buf.extend_from_slice(data);
181 buf
182 }
183 FrontendMessage::PasswordMessage(password) => {
184 let mut buf = Vec::new();
185 buf.push(b'p');
186 let content = format!("{}\0", password);
187 let len = (content.len() + 4) as i32;
188 buf.extend_from_slice(&len.to_be_bytes());
189 buf.extend_from_slice(content.as_bytes());
190 buf
191 }
192 FrontendMessage::Parse { name, query, param_types } => {
193 let mut buf = Vec::new();
194 buf.push(b'P');
195
196 let mut content = Vec::new();
197 content.extend_from_slice(name.as_bytes());
198 content.push(0);
199 content.extend_from_slice(query.as_bytes());
200 content.push(0);
201 content.extend_from_slice(&(param_types.len() as i16).to_be_bytes());
202 for oid in param_types {
203 content.extend_from_slice(&oid.to_be_bytes());
204 }
205
206 let len = (content.len() + 4) as i32;
207 buf.extend_from_slice(&len.to_be_bytes());
208 buf.extend_from_slice(&content);
209 buf
210 }
211 FrontendMessage::Bind { portal, statement, params } => {
212 let mut buf = Vec::new();
213 buf.push(b'B');
214
215 let mut content = Vec::new();
216 content.extend_from_slice(portal.as_bytes());
217 content.push(0);
218 content.extend_from_slice(statement.as_bytes());
219 content.push(0);
220 content.extend_from_slice(&0i16.to_be_bytes());
222 content.extend_from_slice(&(params.len() as i16).to_be_bytes());
224 for param in params {
225 match param {
226 Some(data) => {
227 content.extend_from_slice(&(data.len() as i32).to_be_bytes());
228 content.extend_from_slice(data);
229 }
230 None => content.extend_from_slice(&(-1i32).to_be_bytes()),
231 }
232 }
233 content.extend_from_slice(&0i16.to_be_bytes());
235
236 let len = (content.len() + 4) as i32;
237 buf.extend_from_slice(&len.to_be_bytes());
238 buf.extend_from_slice(&content);
239 buf
240 }
241 FrontendMessage::Execute { portal, max_rows } => {
242 let mut buf = Vec::new();
243 buf.push(b'E');
244
245 let mut content = Vec::new();
246 content.extend_from_slice(portal.as_bytes());
247 content.push(0);
248 content.extend_from_slice(&max_rows.to_be_bytes());
249
250 let len = (content.len() + 4) as i32;
251 buf.extend_from_slice(&len.to_be_bytes());
252 buf.extend_from_slice(&content);
253 buf
254 }
255 FrontendMessage::Sync => {
256 vec![b'S', 0, 0, 0, 4]
257 }
258 FrontendMessage::CopyFail(msg) => {
259 let mut buf = Vec::new();
260 buf.push(b'f');
261 let content = format!("{}\0", msg);
262 let len = (content.len() + 4) as i32;
263 buf.extend_from_slice(&len.to_be_bytes());
264 buf.extend_from_slice(content.as_bytes());
265 buf
266 }
267 FrontendMessage::Close { is_portal, name } => {
268 let mut buf = Vec::new();
269 buf.push(b'C');
270 let type_byte = if *is_portal { b'P' } else { b'S' };
271 let mut content = vec![type_byte];
272 content.extend_from_slice(name.as_bytes());
273 content.push(0);
274 let len = (content.len() + 4) as i32;
275 buf.extend_from_slice(&len.to_be_bytes());
276 buf.extend_from_slice(&content);
277 buf
278 }
279 }
280 }
281}
282
283impl BackendMessage {
284 pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
286 if buf.len() < 5 {
287 return Err("Buffer too short".to_string());
288 }
289
290 let msg_type = buf[0];
291 let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
292
293 if buf.len() < len + 1 {
294 return Err("Incomplete message".to_string());
295 }
296
297 let payload = &buf[5..len + 1];
298
299 let message = match msg_type {
300 b'R' => Self::decode_auth(payload)?,
301 b'S' => Self::decode_parameter_status(payload)?,
302 b'K' => Self::decode_backend_key(payload)?,
303 b'Z' => Self::decode_ready_for_query(payload)?,
304 b'T' => Self::decode_row_description(payload)?,
305 b'D' => Self::decode_data_row(payload)?,
306 b'C' => Self::decode_command_complete(payload)?,
307 b'E' => Self::decode_error_response(payload)?,
308 b'1' => BackendMessage::ParseComplete,
309 b'2' => BackendMessage::BindComplete,
310 b'3' => BackendMessage::CloseComplete,
311 b'n' => BackendMessage::NoData,
312 b't' => Self::decode_parameter_description(payload)?,
313 b'G' => Self::decode_copy_in_response(payload)?,
314 b'H' => Self::decode_copy_out_response(payload)?,
315 b'd' => BackendMessage::CopyData(payload.to_vec()),
316 b'c' => BackendMessage::CopyDone,
317 b'A' => Self::decode_notification_response(payload)?,
318 b'I' => BackendMessage::EmptyQueryResponse,
319 b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
320 _ => return Err(format!("Unknown message type: {}", msg_type as char)),
321 };
322
323 Ok((message, len + 1))
324 }
325
326 fn decode_auth(payload: &[u8]) -> Result<Self, String> {
327 if payload.len() < 4 {
328 return Err("Auth payload too short".to_string());
329 }
330 let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
331 match auth_type {
332 0 => Ok(BackendMessage::AuthenticationOk),
333 5 => {
334 if payload.len() < 8 {
335 return Err("MD5 auth payload too short (need salt)".to_string());
336 }
337 let salt: [u8; 4] = payload[4..8].try_into().expect("salt length verified above");
338 Ok(BackendMessage::AuthenticationMD5Password(salt))
339 }
340 10 => {
341 let mut mechanisms = Vec::new();
343 let mut pos = 4;
344 while pos < payload.len() && payload[pos] != 0 {
345 let end = payload[pos..]
346 .iter()
347 .position(|&b| b == 0)
348 .map(|p| pos + p)
349 .unwrap_or(payload.len());
350 mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
351 pos = end + 1;
352 }
353 Ok(BackendMessage::AuthenticationSASL(mechanisms))
354 }
355 11 => {
356 Ok(BackendMessage::AuthenticationSASLContinue(
358 payload[4..].to_vec(),
359 ))
360 }
361 12 => {
362 Ok(BackendMessage::AuthenticationSASLFinal(
364 payload[4..].to_vec(),
365 ))
366 }
367 _ => Err(format!("Unknown auth type: {}", auth_type)),
368 }
369 }
370
371 fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
372 let parts: Vec<&[u8]> = payload.split(|&b| b == 0).collect();
373 let empty: &[u8] = b"";
374 Ok(BackendMessage::ParameterStatus {
375 name: String::from_utf8_lossy(parts.first().unwrap_or(&empty)).to_string(),
376 value: String::from_utf8_lossy(parts.get(1).unwrap_or(&empty)).to_string(),
377 })
378 }
379
380 fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
381 if payload.len() < 8 {
382 return Err("BackendKeyData payload too short".to_string());
383 }
384 Ok(BackendMessage::BackendKeyData {
385 process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
386 secret_key: i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]),
387 })
388 }
389
390 fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
391 if payload.is_empty() {
392 return Err("ReadyForQuery payload empty".to_string());
393 }
394 let status = match payload[0] {
395 b'I' => TransactionStatus::Idle,
396 b'T' => TransactionStatus::InBlock,
397 b'E' => TransactionStatus::Failed,
398 _ => return Err("Unknown transaction status".to_string()),
399 };
400 Ok(BackendMessage::ReadyForQuery(status))
401 }
402
403 fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
404 if payload.len() < 2 {
405 return Err("RowDescription payload too short".to_string());
406 }
407
408 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
409 if raw_count < 0 {
410 return Err(format!("RowDescription invalid field count: {}", raw_count));
411 }
412 let field_count = raw_count as usize;
413 let mut fields = Vec::with_capacity(field_count);
414 let mut pos = 2;
415
416 for _ in 0..field_count {
417 let name_end = payload[pos..]
419 .iter()
420 .position(|&b| b == 0)
421 .ok_or("Missing null terminator in field name")?;
422 let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
423 pos += name_end + 1; if pos + 18 > payload.len() {
427 return Err("RowDescription field truncated".to_string());
428 }
429
430 let table_oid = u32::from_be_bytes([
431 payload[pos],
432 payload[pos + 1],
433 payload[pos + 2],
434 payload[pos + 3],
435 ]);
436 pos += 4;
437
438 let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
439 pos += 2;
440
441 let type_oid = u32::from_be_bytes([
442 payload[pos],
443 payload[pos + 1],
444 payload[pos + 2],
445 payload[pos + 3],
446 ]);
447 pos += 4;
448
449 let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
450 pos += 2;
451
452 let type_modifier = i32::from_be_bytes([
453 payload[pos],
454 payload[pos + 1],
455 payload[pos + 2],
456 payload[pos + 3],
457 ]);
458 pos += 4;
459
460 let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
461 pos += 2;
462
463 fields.push(FieldDescription {
464 name,
465 table_oid,
466 column_attr,
467 type_oid,
468 type_size,
469 type_modifier,
470 format,
471 });
472 }
473
474 Ok(BackendMessage::RowDescription(fields))
475 }
476
477 fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
478 if payload.len() < 2 {
479 return Err("DataRow payload too short".to_string());
480 }
481
482 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
483 if raw_count < 0 {
484 return Err(format!("DataRow invalid column count: {}", raw_count));
485 }
486 let column_count = raw_count as usize;
487 if column_count > (payload.len() - 2) / 4 + 1 {
489 return Err(format!(
490 "DataRow claims {} columns but payload is only {} bytes",
491 column_count,
492 payload.len()
493 ));
494 }
495 let mut columns = Vec::with_capacity(column_count);
496 let mut pos = 2;
497
498 for _ in 0..column_count {
499 if pos + 4 > payload.len() {
500 return Err("DataRow truncated".to_string());
501 }
502
503 let len = i32::from_be_bytes([
504 payload[pos],
505 payload[pos + 1],
506 payload[pos + 2],
507 payload[pos + 3],
508 ]);
509 pos += 4;
510
511 if len == -1 {
512 columns.push(None);
514 } else {
515 let len = len as usize;
516 if pos + len > payload.len() {
517 return Err("DataRow column data truncated".to_string());
518 }
519 let data = payload[pos..pos + len].to_vec();
520 pos += len;
521 columns.push(Some(data));
522 }
523 }
524
525 Ok(BackendMessage::DataRow(columns))
526 }
527
528 fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
529 let tag = String::from_utf8_lossy(payload)
530 .trim_end_matches('\0')
531 .to_string();
532 Ok(BackendMessage::CommandComplete(tag))
533 }
534
535 fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
536 Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
537 payload,
538 )?))
539 }
540
541 fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
542 let mut fields = ErrorFields::default();
543 let mut i = 0;
544 while i < payload.len() && payload[i] != 0 {
545 let field_type = payload[i];
546 i += 1;
547 let end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
548 let value = String::from_utf8_lossy(&payload[i..end]).to_string();
549 i = end + 1;
550
551 match field_type {
552 b'S' => fields.severity = value,
553 b'C' => fields.code = value,
554 b'M' => fields.message = value,
555 b'D' => fields.detail = Some(value),
556 b'H' => fields.hint = Some(value),
557 _ => {}
558 }
559 }
560 Ok(fields)
561 }
562
563 fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
564 let count = if payload.len() >= 2 {
565 i16::from_be_bytes([payload[0], payload[1]]) as usize
566 } else {
567 0
568 };
569 let mut oids = Vec::with_capacity(count);
570 let mut pos = 2;
571 for _ in 0..count {
572 if pos + 4 <= payload.len() {
573 oids.push(u32::from_be_bytes([
574 payload[pos], payload[pos + 1], payload[pos + 2], payload[pos + 3],
575 ]));
576 pos += 4;
577 }
578 }
579 Ok(BackendMessage::ParameterDescription(oids))
580 }
581
582 fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
583 if payload.is_empty() {
584 return Err("Empty CopyInResponse payload".to_string());
585 }
586 let format = payload[0];
587 let num_columns = if payload.len() >= 3 {
588 i16::from_be_bytes([payload[1], payload[2]]) as usize
589 } else {
590 0
591 };
592 let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
593 payload[3..].iter().take(num_columns).copied().collect()
594 } else {
595 vec![]
596 };
597 Ok(BackendMessage::CopyInResponse {
598 format,
599 column_formats,
600 })
601 }
602
603 fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
604 if payload.is_empty() {
605 return Err("Empty CopyOutResponse payload".to_string());
606 }
607 let format = payload[0];
608 let num_columns = if payload.len() >= 3 {
609 i16::from_be_bytes([payload[1], payload[2]]) as usize
610 } else {
611 0
612 };
613 let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
614 payload[3..].iter().take(num_columns).copied().collect()
615 } else {
616 vec![]
617 };
618 Ok(BackendMessage::CopyOutResponse {
619 format,
620 column_formats,
621 })
622 }
623
624 fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
625 if payload.len() < 4 {
626 return Err("NotificationResponse too short".to_string());
627 }
628 let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
629
630 let mut i = 4;
632 let channel_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
633 let channel = String::from_utf8_lossy(&payload[i..channel_end]).to_string();
634 i = channel_end + 1;
635
636 let payload_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
638 let notification_payload = String::from_utf8_lossy(&payload[i..payload_end]).to_string();
639
640 Ok(BackendMessage::NotificationResponse {
641 process_id,
642 channel,
643 payload: notification_payload,
644 })
645 }
646}