/// A message sent from the client to the PostgreSQL server.
///
/// Wire bytes are produced by `FrontendMessage::encode`.
#[derive(Debug, Clone)]
pub enum FrontendMessage {
    /// Untagged startup message carrying the `user` and `database` parameters.
    Startup {
        user: String,
        database: String,
    },
    /// Password sent in response to an authentication request (tag `p`).
    PasswordMessage(String),
    /// Simple-protocol SQL text (tag `Q`).
    Query(String),
    /// Extended-protocol statement preparation (tag `P`).
    Parse {
        /// Name of the prepared statement.
        name: String,
        /// SQL text to prepare.
        query: String,
        /// Parameter type OIDs, written in order.
        param_types: Vec<u32>,
    },
    /// Binds parameter values to a prepared statement, creating a portal (tag `B`).
    Bind {
        /// Name of the portal to create.
        portal: String,
        /// Name of the prepared statement to bind.
        statement: String,
        /// Raw parameter values; `None` is encoded with length -1.
        params: Vec<Option<Vec<u8>>>,
    },
    /// Runs a bound portal, fetching at most `max_rows` rows (tag `E`).
    Execute {
        portal: String,
        max_rows: i32,
    },
    /// Extended-protocol synchronization point (tag `S`).
    Sync,
    /// Asks the server to close the connection (tag `X`).
    Terminate,
    /// First SASL message: chosen mechanism plus its initial data (tag `p`).
    SASLInitialResponse {
        mechanism: String,
        data: Vec<u8>,
    },
    /// Subsequent SASL data sent verbatim (tag `p`).
    SASLResponse(Vec<u8>),
}

35#[derive(Debug, Clone)]
37pub enum BackendMessage {
38 AuthenticationOk,
40 AuthenticationMD5Password([u8; 4]),
41 AuthenticationSASL(Vec<String>),
42 AuthenticationSASLContinue(Vec<u8>),
43 AuthenticationSASLFinal(Vec<u8>),
44 ParameterStatus {
46 name: String,
47 value: String,
48 },
49 BackendKeyData {
51 process_id: i32,
52 secret_key: i32,
53 },
54 ReadyForQuery(TransactionStatus),
55 RowDescription(Vec<FieldDescription>),
56 DataRow(Vec<Option<Vec<u8>>>),
57 CommandComplete(String),
58 ErrorResponse(ErrorFields),
59 ParseComplete,
60 BindComplete,
61 NoData,
62 CopyInResponse {
64 format: u8,
65 column_formats: Vec<u8>,
66 },
67 CopyOutResponse {
69 format: u8,
70 column_formats: Vec<u8>,
71 },
72 CopyData(Vec<u8>),
73 CopyDone,
74 NotificationResponse {
76 process_id: i32,
77 channel: String,
78 payload: String,
79 },
80 EmptyQueryResponse,
81 NoticeResponse(ErrorFields),
83}
84
/// Transaction state reported by a `ReadyForQuery` message.
///
/// Decoded from its status byte: `I` → `Idle`, `T` → `InBlock`, `E` → `Failed`.
#[derive(Debug, Clone, Copy)]
pub enum TransactionStatus {
    /// Status byte `I`.
    Idle,
    /// Status byte `T`.
    InBlock,
    /// Status byte `E`.
    Failed,
}

/// Metadata for one column of a `RowDescription` message.
///
/// Fields appear in the order they are read off the wire: the NUL-terminated name,
/// then the fixed-width values below.
#[derive(Debug, Clone)]
pub struct FieldDescription {
    /// Column name.
    pub name: String,
    /// Table object id sent by the server for this column.
    pub table_oid: u32,
    /// Column attribute number sent by the server.
    pub column_attr: i16,
    /// Data type object id.
    pub type_oid: u32,
    /// Data type size as sent by the server.
    pub type_size: i16,
    /// Type modifier as sent by the server.
    pub type_modifier: i32,
    /// Format code for this column's values.
    pub format: i16,
}

/// Fields extracted from an `ErrorResponse` or `NoticeResponse` payload.
///
/// Each member is filled from the field with the matching one-byte type code;
/// codes not listed here are skipped during parsing.
#[derive(Debug, Clone, Default)]
pub struct ErrorFields {
    /// Field `S`.
    pub severity: String,
    /// Field `C`.
    pub code: String,
    /// Field `M`.
    pub message: String,
    /// Field `D`, if present.
    pub detail: Option<String>,
    /// Field `H`, if present.
    pub hint: Option<String>,
}

115impl FrontendMessage {
116 pub fn encode(&self) -> Vec<u8> {
118 match self {
119 FrontendMessage::Startup { user, database } => {
120 let mut buf = Vec::new();
121 buf.extend_from_slice(&196608i32.to_be_bytes());
123 buf.extend_from_slice(b"user\0");
125 buf.extend_from_slice(user.as_bytes());
126 buf.push(0);
127 buf.extend_from_slice(b"database\0");
128 buf.extend_from_slice(database.as_bytes());
129 buf.push(0);
130 buf.push(0); let len = (buf.len() + 4) as i32;
134 let mut result = len.to_be_bytes().to_vec();
135 result.extend(buf);
136 result
137 }
138 FrontendMessage::Query(sql) => {
139 let mut buf = Vec::new();
140 buf.push(b'Q');
141 let content = format!("{}\0", sql);
142 let len = (content.len() + 4) as i32;
143 buf.extend_from_slice(&len.to_be_bytes());
144 buf.extend_from_slice(content.as_bytes());
145 buf
146 }
147 FrontendMessage::Terminate => {
148 vec![b'X', 0, 0, 0, 4]
149 }
150 FrontendMessage::SASLInitialResponse { mechanism, data } => {
151 let mut buf = Vec::new();
152 buf.push(b'p'); let mut content = Vec::new();
155 content.extend_from_slice(mechanism.as_bytes());
156 content.push(0); content.extend_from_slice(&(data.len() as i32).to_be_bytes());
158 content.extend_from_slice(data);
159
160 let len = (content.len() + 4) as i32;
161 buf.extend_from_slice(&len.to_be_bytes());
162 buf.extend_from_slice(&content);
163 buf
164 }
165 FrontendMessage::SASLResponse(data) => {
166 let mut buf = Vec::new();
167 buf.push(b'p');
168
169 let len = (data.len() + 4) as i32;
170 buf.extend_from_slice(&len.to_be_bytes());
171 buf.extend_from_slice(data);
172 buf
173 }
174 FrontendMessage::PasswordMessage(password) => {
175 let mut buf = Vec::new();
176 buf.push(b'p');
177 let content = format!("{}\0", password);
178 let len = (content.len() + 4) as i32;
179 buf.extend_from_slice(&len.to_be_bytes());
180 buf.extend_from_slice(content.as_bytes());
181 buf
182 }
183 FrontendMessage::Parse { name, query, param_types } => {
184 let mut buf = Vec::new();
185 buf.push(b'P');
186
187 let mut content = Vec::new();
188 content.extend_from_slice(name.as_bytes());
189 content.push(0);
190 content.extend_from_slice(query.as_bytes());
191 content.push(0);
192 content.extend_from_slice(&(param_types.len() as i16).to_be_bytes());
193 for oid in param_types {
194 content.extend_from_slice(&oid.to_be_bytes());
195 }
196
197 let len = (content.len() + 4) as i32;
198 buf.extend_from_slice(&len.to_be_bytes());
199 buf.extend_from_slice(&content);
200 buf
201 }
202 FrontendMessage::Bind { portal, statement, params } => {
203 let mut buf = Vec::new();
204 buf.push(b'B');
205
206 let mut content = Vec::new();
207 content.extend_from_slice(portal.as_bytes());
208 content.push(0);
209 content.extend_from_slice(statement.as_bytes());
210 content.push(0);
211 content.extend_from_slice(&0i16.to_be_bytes());
213 content.extend_from_slice(&(params.len() as i16).to_be_bytes());
215 for param in params {
216 match param {
217 Some(data) => {
218 content.extend_from_slice(&(data.len() as i32).to_be_bytes());
219 content.extend_from_slice(data);
220 }
221 None => content.extend_from_slice(&(-1i32).to_be_bytes()),
222 }
223 }
224 content.extend_from_slice(&0i16.to_be_bytes());
226
227 let len = (content.len() + 4) as i32;
228 buf.extend_from_slice(&len.to_be_bytes());
229 buf.extend_from_slice(&content);
230 buf
231 }
232 FrontendMessage::Execute { portal, max_rows } => {
233 let mut buf = Vec::new();
234 buf.push(b'E');
235
236 let mut content = Vec::new();
237 content.extend_from_slice(portal.as_bytes());
238 content.push(0);
239 content.extend_from_slice(&max_rows.to_be_bytes());
240
241 let len = (content.len() + 4) as i32;
242 buf.extend_from_slice(&len.to_be_bytes());
243 buf.extend_from_slice(&content);
244 buf
245 }
246 FrontendMessage::Sync => {
247 vec![b'S', 0, 0, 0, 4]
248 }
249 }
250 }
251}
252
253impl BackendMessage {
254 pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
256 if buf.len() < 5 {
257 return Err("Buffer too short".to_string());
258 }
259
260 let msg_type = buf[0];
261 let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
262
263 if buf.len() < len + 1 {
264 return Err("Incomplete message".to_string());
265 }
266
267 let payload = &buf[5..len + 1];
268
269 let message = match msg_type {
270 b'R' => Self::decode_auth(payload)?,
271 b'S' => Self::decode_parameter_status(payload)?,
272 b'K' => Self::decode_backend_key(payload)?,
273 b'Z' => Self::decode_ready_for_query(payload)?,
274 b'T' => Self::decode_row_description(payload)?,
275 b'D' => Self::decode_data_row(payload)?,
276 b'C' => Self::decode_command_complete(payload)?,
277 b'E' => Self::decode_error_response(payload)?,
278 b'1' => BackendMessage::ParseComplete,
279 b'2' => BackendMessage::BindComplete,
280 b'n' => BackendMessage::NoData,
281 b'G' => Self::decode_copy_in_response(payload)?,
282 b'H' => Self::decode_copy_out_response(payload)?,
283 b'd' => BackendMessage::CopyData(payload.to_vec()),
284 b'c' => BackendMessage::CopyDone,
285 b'A' => Self::decode_notification_response(payload)?,
286 b'I' => BackendMessage::EmptyQueryResponse,
287 b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
288 _ => return Err(format!("Unknown message type: {}", msg_type as char)),
289 };
290
291 Ok((message, len + 1))
292 }
293
294 fn decode_auth(payload: &[u8]) -> Result<Self, String> {
295 let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
296 match auth_type {
297 0 => Ok(BackendMessage::AuthenticationOk),
298 5 => {
299 let salt: [u8; 4] = payload[4..8].try_into().unwrap();
300 Ok(BackendMessage::AuthenticationMD5Password(salt))
301 }
302 10 => {
303 let mut mechanisms = Vec::new();
305 let mut pos = 4;
306 while pos < payload.len() && payload[pos] != 0 {
307 let end = payload[pos..]
308 .iter()
309 .position(|&b| b == 0)
310 .map(|p| pos + p)
311 .unwrap_or(payload.len());
312 mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
313 pos = end + 1;
314 }
315 Ok(BackendMessage::AuthenticationSASL(mechanisms))
316 }
317 11 => {
318 Ok(BackendMessage::AuthenticationSASLContinue(
320 payload[4..].to_vec(),
321 ))
322 }
323 12 => {
324 Ok(BackendMessage::AuthenticationSASLFinal(
326 payload[4..].to_vec(),
327 ))
328 }
329 _ => Err(format!("Unknown auth type: {}", auth_type)),
330 }
331 }
332
333 fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
334 let parts: Vec<&[u8]> = payload.split(|&b| b == 0).collect();
335 let empty: &[u8] = b"";
336 Ok(BackendMessage::ParameterStatus {
337 name: String::from_utf8_lossy(parts.first().unwrap_or(&empty)).to_string(),
338 value: String::from_utf8_lossy(parts.get(1).unwrap_or(&empty)).to_string(),
339 })
340 }
341
342 fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
343 Ok(BackendMessage::BackendKeyData {
344 process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
345 secret_key: i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]),
346 })
347 }
348
349 fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
350 let status = match payload[0] {
351 b'I' => TransactionStatus::Idle,
352 b'T' => TransactionStatus::InBlock,
353 b'E' => TransactionStatus::Failed,
354 _ => return Err("Unknown transaction status".to_string()),
355 };
356 Ok(BackendMessage::ReadyForQuery(status))
357 }
358
359 fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
360 if payload.len() < 2 {
361 return Err("RowDescription payload too short".to_string());
362 }
363
364 let field_count = i16::from_be_bytes([payload[0], payload[1]]) as usize;
365 let mut fields = Vec::with_capacity(field_count);
366 let mut pos = 2;
367
368 for _ in 0..field_count {
369 let name_end = payload[pos..]
371 .iter()
372 .position(|&b| b == 0)
373 .ok_or("Missing null terminator in field name")?;
374 let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
375 pos += name_end + 1; if pos + 18 > payload.len() {
379 return Err("RowDescription field truncated".to_string());
380 }
381
382 let table_oid = u32::from_be_bytes([
383 payload[pos],
384 payload[pos + 1],
385 payload[pos + 2],
386 payload[pos + 3],
387 ]);
388 pos += 4;
389
390 let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
391 pos += 2;
392
393 let type_oid = u32::from_be_bytes([
394 payload[pos],
395 payload[pos + 1],
396 payload[pos + 2],
397 payload[pos + 3],
398 ]);
399 pos += 4;
400
401 let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
402 pos += 2;
403
404 let type_modifier = i32::from_be_bytes([
405 payload[pos],
406 payload[pos + 1],
407 payload[pos + 2],
408 payload[pos + 3],
409 ]);
410 pos += 4;
411
412 let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
413 pos += 2;
414
415 fields.push(FieldDescription {
416 name,
417 table_oid,
418 column_attr,
419 type_oid,
420 type_size,
421 type_modifier,
422 format,
423 });
424 }
425
426 Ok(BackendMessage::RowDescription(fields))
427 }
428
429 fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
430 if payload.len() < 2 {
431 return Err("DataRow payload too short".to_string());
432 }
433
434 let column_count = i16::from_be_bytes([payload[0], payload[1]]) as usize;
435 let mut columns = Vec::with_capacity(column_count);
436 let mut pos = 2;
437
438 for _ in 0..column_count {
439 if pos + 4 > payload.len() {
440 return Err("DataRow truncated".to_string());
441 }
442
443 let len = i32::from_be_bytes([
444 payload[pos],
445 payload[pos + 1],
446 payload[pos + 2],
447 payload[pos + 3],
448 ]);
449 pos += 4;
450
451 if len == -1 {
452 columns.push(None);
454 } else {
455 let len = len as usize;
456 if pos + len > payload.len() {
457 return Err("DataRow column data truncated".to_string());
458 }
459 let data = payload[pos..pos + len].to_vec();
460 pos += len;
461 columns.push(Some(data));
462 }
463 }
464
465 Ok(BackendMessage::DataRow(columns))
466 }
467
468 fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
469 let tag = String::from_utf8_lossy(payload)
470 .trim_end_matches('\0')
471 .to_string();
472 Ok(BackendMessage::CommandComplete(tag))
473 }
474
475 fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
476 Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
477 payload,
478 )?))
479 }
480
481 fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
482 let mut fields = ErrorFields::default();
483 let mut i = 0;
484 while i < payload.len() && payload[i] != 0 {
485 let field_type = payload[i];
486 i += 1;
487 let end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
488 let value = String::from_utf8_lossy(&payload[i..end]).to_string();
489 i = end + 1;
490
491 match field_type {
492 b'S' => fields.severity = value,
493 b'C' => fields.code = value,
494 b'M' => fields.message = value,
495 b'D' => fields.detail = Some(value),
496 b'H' => fields.hint = Some(value),
497 _ => {}
498 }
499 }
500 Ok(fields)
501 }
502
503 fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
504 if payload.is_empty() {
505 return Err("Empty CopyInResponse payload".to_string());
506 }
507 let format = payload[0];
508 let num_columns = if payload.len() >= 3 {
509 i16::from_be_bytes([payload[1], payload[2]]) as usize
510 } else {
511 0
512 };
513 let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
514 payload[3..].iter().take(num_columns).copied().collect()
515 } else {
516 vec![]
517 };
518 Ok(BackendMessage::CopyInResponse {
519 format,
520 column_formats,
521 })
522 }
523
524 fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
525 if payload.is_empty() {
526 return Err("Empty CopyOutResponse payload".to_string());
527 }
528 let format = payload[0];
529 let num_columns = if payload.len() >= 3 {
530 i16::from_be_bytes([payload[1], payload[2]]) as usize
531 } else {
532 0
533 };
534 let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
535 payload[3..].iter().take(num_columns).copied().collect()
536 } else {
537 vec![]
538 };
539 Ok(BackendMessage::CopyOutResponse {
540 format,
541 column_formats,
542 })
543 }
544
545 fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
546 if payload.len() < 4 {
547 return Err("NotificationResponse too short".to_string());
548 }
549 let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
550
551 let mut i = 4;
553 let channel_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
554 let channel = String::from_utf8_lossy(&payload[i..channel_end]).to_string();
555 i = channel_end + 1;
556
557 let payload_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
559 let notification_payload = String::from_utf8_lossy(&payload[i..payload_end]).to_string();
560
561 Ok(BackendMessage::NotificationResponse {
562 process_id,
563 channel,
564 payload: notification_payload,
565 })
566 }
567}