/// A message the client (frontend) sends to the server (backend).
///
/// [`FrontendMessage::encode`] turns each variant into its wire bytes; the tag
/// noted on a variant is the first byte it encodes to.
#[derive(Debug, Clone)]
pub enum FrontendMessage {
    /// Untagged startup packet carrying the `user` and `database` parameters.
    Startup {
        user: String,
        database: String,
    },
    /// `p`: a password, sent NUL-terminated.
    PasswordMessage(String),
    /// `Q`: query text, sent NUL-terminated.
    Query(String),
    /// `P`: statement name, query text and parameter type OIDs.
    Parse {
        /// Statement name (NUL-terminated on the wire).
        name: String,
        /// Query text (NUL-terminated on the wire).
        query: String,
        /// Parameter type OIDs, each encoded as a big-endian `u32`.
        param_types: Vec<u32>,
    },
    /// `B`: portal name, statement name and parameter values.
    Bind {
        /// Portal name (NUL-terminated on the wire).
        portal: String,
        /// Statement name (NUL-terminated on the wire).
        statement: String,
        /// Parameter values; `None` is encoded as a length of -1 with no data.
        params: Vec<Option<Vec<u8>>>,
    },
    /// `E`: portal name and the `max_rows` value.
    Execute {
        portal: String,
        max_rows: i32,
    },
    /// `S`: empty body.
    Sync,
    /// `X`: empty body.
    Terminate,
    /// `p`: mechanism name followed by length-prefixed data.
    SASLInitialResponse {
        mechanism: String,
        data: Vec<u8>,
    },
    /// `p`: raw data, sent without any inner length prefix.
    SASLResponse(Vec<u8>),
}

35#[derive(Debug, Clone)]
37pub enum BackendMessage {
38 AuthenticationOk,
40 AuthenticationMD5Password([u8; 4]),
41 AuthenticationSASL(Vec<String>),
42 AuthenticationSASLContinue(Vec<u8>),
43 AuthenticationSASLFinal(Vec<u8>),
44 ParameterStatus {
46 name: String,
47 value: String,
48 },
49 BackendKeyData {
51 process_id: i32,
52 secret_key: i32,
53 },
54 ReadyForQuery(TransactionStatus),
55 RowDescription(Vec<FieldDescription>),
56 DataRow(Vec<Option<Vec<u8>>>),
57 CommandComplete(String),
58 ErrorResponse(ErrorFields),
59 ParseComplete,
60 BindComplete,
61 NoData,
62 CopyInResponse {
64 format: u8,
65 column_formats: Vec<u8>,
66 },
67 CopyOutResponse {
69 format: u8,
70 column_formats: Vec<u8>,
71 },
72 CopyData(Vec<u8>),
73 CopyDone,
74 NotificationResponse {
76 process_id: i32,
77 channel: String,
78 payload: String,
79 },
80 EmptyQueryResponse,
81 NoticeResponse(ErrorFields),
83 ParameterDescription(Vec<u32>),
86}
87
/// Transaction status carried by a `ReadyForQuery` (`Z`) message.
#[derive(Debug, Clone, Copy)]
pub enum TransactionStatus {
    /// Decoded from status byte `I`.
    Idle,
    /// Decoded from status byte `T`.
    InBlock,
    /// Decoded from status byte `E`.
    Failed,
}

/// Description of one result column, decoded from a `RowDescription` (`T`) message.
///
/// The numeric fields are listed in the order the decoder reads them from the wire.
#[derive(Debug, Clone)]
pub struct FieldDescription {
    /// Column name, decoded lossily as UTF-8.
    pub name: String,
    /// Table OID (big-endian `u32`).
    pub table_oid: u32,
    /// Column attribute number (big-endian `i16`).
    pub column_attr: i16,
    /// Type OID (big-endian `u32`).
    pub type_oid: u32,
    /// Type size (big-endian `i16`).
    pub type_size: i16,
    /// Type modifier (big-endian `i32`).
    pub type_modifier: i32,
    /// Format code (big-endian `i16`).
    pub format: i16,
}

/// Fields extracted from an `ErrorResponse` (`E`) or `NoticeResponse` (`N`).
///
/// Each field comes from the value tagged with the letter noted below; the
/// decoder discards values with any other tag letter.
#[derive(Debug, Clone, Default)]
pub struct ErrorFields {
    /// Tag `S`.
    pub severity: String,
    /// Tag `C`.
    pub code: String,
    /// Tag `M`.
    pub message: String,
    /// Tag `D`; `None` if absent.
    pub detail: Option<String>,
    /// Tag `H`; `None` if absent.
    pub hint: Option<String>,
}

118impl FrontendMessage {
119 pub fn encode(&self) -> Vec<u8> {
121 match self {
122 FrontendMessage::Startup { user, database } => {
123 let mut buf = Vec::new();
124 buf.extend_from_slice(&196608i32.to_be_bytes());
126 buf.extend_from_slice(b"user\0");
128 buf.extend_from_slice(user.as_bytes());
129 buf.push(0);
130 buf.extend_from_slice(b"database\0");
131 buf.extend_from_slice(database.as_bytes());
132 buf.push(0);
133 buf.push(0); let len = (buf.len() + 4) as i32;
137 let mut result = len.to_be_bytes().to_vec();
138 result.extend(buf);
139 result
140 }
141 FrontendMessage::Query(sql) => {
142 let mut buf = Vec::new();
143 buf.push(b'Q');
144 let content = format!("{}\0", sql);
145 let len = (content.len() + 4) as i32;
146 buf.extend_from_slice(&len.to_be_bytes());
147 buf.extend_from_slice(content.as_bytes());
148 buf
149 }
150 FrontendMessage::Terminate => {
151 vec![b'X', 0, 0, 0, 4]
152 }
153 FrontendMessage::SASLInitialResponse { mechanism, data } => {
154 let mut buf = Vec::new();
155 buf.push(b'p'); let mut content = Vec::new();
158 content.extend_from_slice(mechanism.as_bytes());
159 content.push(0); content.extend_from_slice(&(data.len() as i32).to_be_bytes());
161 content.extend_from_slice(data);
162
163 let len = (content.len() + 4) as i32;
164 buf.extend_from_slice(&len.to_be_bytes());
165 buf.extend_from_slice(&content);
166 buf
167 }
168 FrontendMessage::SASLResponse(data) => {
169 let mut buf = Vec::new();
170 buf.push(b'p');
171
172 let len = (data.len() + 4) as i32;
173 buf.extend_from_slice(&len.to_be_bytes());
174 buf.extend_from_slice(data);
175 buf
176 }
177 FrontendMessage::PasswordMessage(password) => {
178 let mut buf = Vec::new();
179 buf.push(b'p');
180 let content = format!("{}\0", password);
181 let len = (content.len() + 4) as i32;
182 buf.extend_from_slice(&len.to_be_bytes());
183 buf.extend_from_slice(content.as_bytes());
184 buf
185 }
186 FrontendMessage::Parse { name, query, param_types } => {
187 let mut buf = Vec::new();
188 buf.push(b'P');
189
190 let mut content = Vec::new();
191 content.extend_from_slice(name.as_bytes());
192 content.push(0);
193 content.extend_from_slice(query.as_bytes());
194 content.push(0);
195 content.extend_from_slice(&(param_types.len() as i16).to_be_bytes());
196 for oid in param_types {
197 content.extend_from_slice(&oid.to_be_bytes());
198 }
199
200 let len = (content.len() + 4) as i32;
201 buf.extend_from_slice(&len.to_be_bytes());
202 buf.extend_from_slice(&content);
203 buf
204 }
205 FrontendMessage::Bind { portal, statement, params } => {
206 let mut buf = Vec::new();
207 buf.push(b'B');
208
209 let mut content = Vec::new();
210 content.extend_from_slice(portal.as_bytes());
211 content.push(0);
212 content.extend_from_slice(statement.as_bytes());
213 content.push(0);
214 content.extend_from_slice(&0i16.to_be_bytes());
216 content.extend_from_slice(&(params.len() as i16).to_be_bytes());
218 for param in params {
219 match param {
220 Some(data) => {
221 content.extend_from_slice(&(data.len() as i32).to_be_bytes());
222 content.extend_from_slice(data);
223 }
224 None => content.extend_from_slice(&(-1i32).to_be_bytes()),
225 }
226 }
227 content.extend_from_slice(&0i16.to_be_bytes());
229
230 let len = (content.len() + 4) as i32;
231 buf.extend_from_slice(&len.to_be_bytes());
232 buf.extend_from_slice(&content);
233 buf
234 }
235 FrontendMessage::Execute { portal, max_rows } => {
236 let mut buf = Vec::new();
237 buf.push(b'E');
238
239 let mut content = Vec::new();
240 content.extend_from_slice(portal.as_bytes());
241 content.push(0);
242 content.extend_from_slice(&max_rows.to_be_bytes());
243
244 let len = (content.len() + 4) as i32;
245 buf.extend_from_slice(&len.to_be_bytes());
246 buf.extend_from_slice(&content);
247 buf
248 }
249 FrontendMessage::Sync => {
250 vec![b'S', 0, 0, 0, 4]
251 }
252 }
253 }
254}
255
256impl BackendMessage {
257 pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
259 if buf.len() < 5 {
260 return Err("Buffer too short".to_string());
261 }
262
263 let msg_type = buf[0];
264 let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
265
266 if buf.len() < len + 1 {
267 return Err("Incomplete message".to_string());
268 }
269
270 let payload = &buf[5..len + 1];
271
272 let message = match msg_type {
273 b'R' => Self::decode_auth(payload)?,
274 b'S' => Self::decode_parameter_status(payload)?,
275 b'K' => Self::decode_backend_key(payload)?,
276 b'Z' => Self::decode_ready_for_query(payload)?,
277 b'T' => Self::decode_row_description(payload)?,
278 b'D' => Self::decode_data_row(payload)?,
279 b'C' => Self::decode_command_complete(payload)?,
280 b'E' => Self::decode_error_response(payload)?,
281 b'1' => BackendMessage::ParseComplete,
282 b'2' => BackendMessage::BindComplete,
283 b'n' => BackendMessage::NoData,
284 b't' => Self::decode_parameter_description(payload)?,
285 b'G' => Self::decode_copy_in_response(payload)?,
286 b'H' => Self::decode_copy_out_response(payload)?,
287 b'd' => BackendMessage::CopyData(payload.to_vec()),
288 b'c' => BackendMessage::CopyDone,
289 b'A' => Self::decode_notification_response(payload)?,
290 b'I' => BackendMessage::EmptyQueryResponse,
291 b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
292 _ => return Err(format!("Unknown message type: {}", msg_type as char)),
293 };
294
295 Ok((message, len + 1))
296 }
297
298 fn decode_auth(payload: &[u8]) -> Result<Self, String> {
299 let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
300 match auth_type {
301 0 => Ok(BackendMessage::AuthenticationOk),
302 5 => {
303 let salt: [u8; 4] = payload[4..8].try_into().unwrap();
304 Ok(BackendMessage::AuthenticationMD5Password(salt))
305 }
306 10 => {
307 let mut mechanisms = Vec::new();
309 let mut pos = 4;
310 while pos < payload.len() && payload[pos] != 0 {
311 let end = payload[pos..]
312 .iter()
313 .position(|&b| b == 0)
314 .map(|p| pos + p)
315 .unwrap_or(payload.len());
316 mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
317 pos = end + 1;
318 }
319 Ok(BackendMessage::AuthenticationSASL(mechanisms))
320 }
321 11 => {
322 Ok(BackendMessage::AuthenticationSASLContinue(
324 payload[4..].to_vec(),
325 ))
326 }
327 12 => {
328 Ok(BackendMessage::AuthenticationSASLFinal(
330 payload[4..].to_vec(),
331 ))
332 }
333 _ => Err(format!("Unknown auth type: {}", auth_type)),
334 }
335 }
336
337 fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
338 let parts: Vec<&[u8]> = payload.split(|&b| b == 0).collect();
339 let empty: &[u8] = b"";
340 Ok(BackendMessage::ParameterStatus {
341 name: String::from_utf8_lossy(parts.first().unwrap_or(&empty)).to_string(),
342 value: String::from_utf8_lossy(parts.get(1).unwrap_or(&empty)).to_string(),
343 })
344 }
345
346 fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
347 Ok(BackendMessage::BackendKeyData {
348 process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
349 secret_key: i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]),
350 })
351 }
352
353 fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
354 let status = match payload[0] {
355 b'I' => TransactionStatus::Idle,
356 b'T' => TransactionStatus::InBlock,
357 b'E' => TransactionStatus::Failed,
358 _ => return Err("Unknown transaction status".to_string()),
359 };
360 Ok(BackendMessage::ReadyForQuery(status))
361 }
362
363 fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
364 if payload.len() < 2 {
365 return Err("RowDescription payload too short".to_string());
366 }
367
368 let field_count = i16::from_be_bytes([payload[0], payload[1]]) as usize;
369 let mut fields = Vec::with_capacity(field_count);
370 let mut pos = 2;
371
372 for _ in 0..field_count {
373 let name_end = payload[pos..]
375 .iter()
376 .position(|&b| b == 0)
377 .ok_or("Missing null terminator in field name")?;
378 let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
379 pos += name_end + 1; if pos + 18 > payload.len() {
383 return Err("RowDescription field truncated".to_string());
384 }
385
386 let table_oid = u32::from_be_bytes([
387 payload[pos],
388 payload[pos + 1],
389 payload[pos + 2],
390 payload[pos + 3],
391 ]);
392 pos += 4;
393
394 let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
395 pos += 2;
396
397 let type_oid = u32::from_be_bytes([
398 payload[pos],
399 payload[pos + 1],
400 payload[pos + 2],
401 payload[pos + 3],
402 ]);
403 pos += 4;
404
405 let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
406 pos += 2;
407
408 let type_modifier = i32::from_be_bytes([
409 payload[pos],
410 payload[pos + 1],
411 payload[pos + 2],
412 payload[pos + 3],
413 ]);
414 pos += 4;
415
416 let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
417 pos += 2;
418
419 fields.push(FieldDescription {
420 name,
421 table_oid,
422 column_attr,
423 type_oid,
424 type_size,
425 type_modifier,
426 format,
427 });
428 }
429
430 Ok(BackendMessage::RowDescription(fields))
431 }
432
433 fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
434 if payload.len() < 2 {
435 return Err("DataRow payload too short".to_string());
436 }
437
438 let column_count = i16::from_be_bytes([payload[0], payload[1]]) as usize;
439 let mut columns = Vec::with_capacity(column_count);
440 let mut pos = 2;
441
442 for _ in 0..column_count {
443 if pos + 4 > payload.len() {
444 return Err("DataRow truncated".to_string());
445 }
446
447 let len = i32::from_be_bytes([
448 payload[pos],
449 payload[pos + 1],
450 payload[pos + 2],
451 payload[pos + 3],
452 ]);
453 pos += 4;
454
455 if len == -1 {
456 columns.push(None);
458 } else {
459 let len = len as usize;
460 if pos + len > payload.len() {
461 return Err("DataRow column data truncated".to_string());
462 }
463 let data = payload[pos..pos + len].to_vec();
464 pos += len;
465 columns.push(Some(data));
466 }
467 }
468
469 Ok(BackendMessage::DataRow(columns))
470 }
471
472 fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
473 let tag = String::from_utf8_lossy(payload)
474 .trim_end_matches('\0')
475 .to_string();
476 Ok(BackendMessage::CommandComplete(tag))
477 }
478
479 fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
480 Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
481 payload,
482 )?))
483 }
484
485 fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
486 let mut fields = ErrorFields::default();
487 let mut i = 0;
488 while i < payload.len() && payload[i] != 0 {
489 let field_type = payload[i];
490 i += 1;
491 let end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
492 let value = String::from_utf8_lossy(&payload[i..end]).to_string();
493 i = end + 1;
494
495 match field_type {
496 b'S' => fields.severity = value,
497 b'C' => fields.code = value,
498 b'M' => fields.message = value,
499 b'D' => fields.detail = Some(value),
500 b'H' => fields.hint = Some(value),
501 _ => {}
502 }
503 }
504 Ok(fields)
505 }
506
507 fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
508 let count = if payload.len() >= 2 {
509 i16::from_be_bytes([payload[0], payload[1]]) as usize
510 } else {
511 0
512 };
513 let mut oids = Vec::with_capacity(count);
514 let mut pos = 2;
515 for _ in 0..count {
516 if pos + 4 <= payload.len() {
517 oids.push(u32::from_be_bytes([
518 payload[pos], payload[pos + 1], payload[pos + 2], payload[pos + 3],
519 ]));
520 pos += 4;
521 }
522 }
523 Ok(BackendMessage::ParameterDescription(oids))
524 }
525
526 fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
527 if payload.is_empty() {
528 return Err("Empty CopyInResponse payload".to_string());
529 }
530 let format = payload[0];
531 let num_columns = if payload.len() >= 3 {
532 i16::from_be_bytes([payload[1], payload[2]]) as usize
533 } else {
534 0
535 };
536 let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
537 payload[3..].iter().take(num_columns).copied().collect()
538 } else {
539 vec![]
540 };
541 Ok(BackendMessage::CopyInResponse {
542 format,
543 column_formats,
544 })
545 }
546
547 fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
548 if payload.is_empty() {
549 return Err("Empty CopyOutResponse payload".to_string());
550 }
551 let format = payload[0];
552 let num_columns = if payload.len() >= 3 {
553 i16::from_be_bytes([payload[1], payload[2]]) as usize
554 } else {
555 0
556 };
557 let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
558 payload[3..].iter().take(num_columns).copied().collect()
559 } else {
560 vec![]
561 };
562 Ok(BackendMessage::CopyOutResponse {
563 format,
564 column_formats,
565 })
566 }
567
568 fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
569 if payload.len() < 4 {
570 return Err("NotificationResponse too short".to_string());
571 }
572 let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
573
574 let mut i = 4;
576 let channel_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
577 let channel = String::from_utf8_lossy(&payload[i..channel_end]).to_string();
578 i = channel_end + 1;
579
580 let payload_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
582 let notification_payload = String::from_utf8_lossy(&payload[i..payload_end]).to_string();
583
584 Ok(BackendMessage::NotificationResponse {
585 process_id,
586 channel,
587 payload: notification_payload,
588 })
589 }
590}