#![allow(clippy::cast_possible_truncation)]
use super::messages::{
CANCEL_REQUEST_CODE, DescribeKind, FrontendMessage, SSL_REQUEST_CODE, frontend_type,
};
/// Encoder for PostgreSQL frontend (client-to-server) protocol messages.
///
/// A single byte buffer is reused for every encoded message, so one writer can
/// serialize many messages without reallocating each time.
#[derive(Clone, Debug)]
pub struct MessageWriter {
    /// Output buffer holding the bytes of the most recently encoded message.
    buf: Vec<u8>,
}
impl Default for MessageWriter {
fn default() -> Self {
Self::new()
}
}
impl MessageWriter {
    /// Creates a writer with a 1024-byte pre-allocated buffer.
    pub fn new() -> Self {
        Self::with_capacity(1024)
    }

    /// Creates a writer whose buffer is pre-allocated to `capacity` bytes.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            buf: Vec::with_capacity(capacity),
        }
    }

    /// Empties the buffer while keeping its allocation for reuse.
    pub fn clear(&mut self) {
        self.buf.clear();
    }

    /// Returns the currently buffered bytes: the output of the last
    /// [`write`](Self::write), unless [`clear`](Self::clear) or
    /// [`take`](Self::take) has run since.
    pub fn as_bytes(&self) -> &[u8] {
        &self.buf
    }

    /// Moves the buffered bytes out, leaving an empty (unallocated) buffer behind.
    pub fn take(&mut self) -> Vec<u8> {
        std::mem::take(&mut self.buf)
    }

    /// Encodes `msg` into the buffer, replacing its previous contents, and
    /// returns the encoded bytes.
    ///
    /// Every length field is back-patched from the bytes actually written, so
    /// it cannot drift out of sync with the payload.
    ///
    /// # Panics
    ///
    /// Panics if a message or value length exceeds `i32::MAX` bytes, or if a
    /// parameter/format count exceeds 65535. The wire format has no room for
    /// either value, and encoding a truncated one would produce a malformed frame.
    pub fn write(&mut self, msg: &FrontendMessage) -> &[u8] {
        self.buf.clear();
        match msg {
            FrontendMessage::Startup { version, params } => {
                self.write_startup(*version, params);
            }
            FrontendMessage::PasswordMessage(password) => {
                self.write_password(password);
            }
            FrontendMessage::SASLInitialResponse { mechanism, data } => {
                self.write_sasl_initial(mechanism, data);
            }
            FrontendMessage::SASLResponse(data) => {
                self.write_sasl_response(data);
            }
            FrontendMessage::Query(query) => {
                self.write_query(query);
            }
            FrontendMessage::Parse {
                name,
                query,
                param_types,
            } => {
                self.write_parse(name, query, param_types);
            }
            FrontendMessage::Bind {
                portal,
                statement,
                param_formats,
                params,
                result_formats,
            } => {
                self.write_bind(portal, statement, param_formats, params, result_formats);
            }
            FrontendMessage::Describe { kind, name } => {
                self.write_describe(*kind, name);
            }
            FrontendMessage::Execute { portal, max_rows } => {
                self.write_execute(portal, *max_rows);
            }
            FrontendMessage::Close { kind, name } => {
                self.write_close(*kind, name);
            }
            FrontendMessage::Sync => {
                self.write_sync();
            }
            FrontendMessage::Flush => {
                self.write_flush();
            }
            FrontendMessage::CopyData(data) => {
                self.write_copy_data(data);
            }
            FrontendMessage::CopyDone => {
                self.write_copy_done();
            }
            FrontendMessage::CopyFail(message) => {
                self.write_copy_fail(message);
            }
            FrontendMessage::Terminate => {
                self.write_terminate();
            }
            FrontendMessage::CancelRequest {
                process_id,
                secret_key,
            } => {
                self.write_cancel_request(*process_id, *secret_key);
            }
            FrontendMessage::SSLRequest => {
                self.write_ssl_request();
            }
        }
        &self.buf
    }

    /// Startup message (no type byte): Int32 length, Int32 protocol version,
    /// NUL-terminated key/value pairs, then one extra NUL ending the list.
    ///
    /// NOTE(review): keys and values are written verbatim. An interior NUL
    /// would split a pair and could smuggle extra startup parameters in, so
    /// confirm that callers reject NULs in caller-controlled values.
    fn write_startup(&mut self, version: i32, params: &[(String, String)]) {
        let len_at = self.begin_untyped();
        self.put_i32(version);
        for (key, value) in params {
            self.put_cstr(key);
            self.put_cstr(value);
        }
        self.buf.push(0);
        self.finish(len_at);
    }

    /// Password message: the password as a single C string.
    fn write_password(&mut self, password: &str) {
        self.write_simple_string_message(frontend_type::PASSWORD, password);
    }

    /// SASL initial response: the mechanism name as a C string, followed by
    /// the client's first message with an Int32 length prefix.
    ///
    /// An empty `data` is encoded as length -1 ("no initial response"), not as
    /// a zero-length response.
    fn write_sasl_initial(&mut self, mechanism: &str, data: &[u8]) {
        let len_at = self.begin(frontend_type::PASSWORD);
        self.put_cstr(mechanism);
        if data.is_empty() {
            self.put_i32(-1);
        } else {
            self.put_len_prefixed(data);
        }
        self.finish(len_at);
    }

    /// SASL continuation response: raw mechanism-specific bytes, no inner prefix.
    fn write_sasl_response(&mut self, data: &[u8]) {
        let len_at = self.begin(frontend_type::PASSWORD);
        self.buf.extend_from_slice(data);
        self.finish(len_at);
    }

    /// Simple-query message: the SQL text as a single C string.
    fn write_query(&mut self, query: &str) {
        self.write_simple_string_message(frontend_type::QUERY, query);
    }

    /// Parse: statement name, query text, then a count-prefixed list of
    /// Int32 parameter type OIDs.
    fn write_parse(&mut self, name: &str, query: &str, param_types: &[u32]) {
        let len_at = self.begin(frontend_type::PARSE);
        self.put_cstr(name);
        self.put_cstr(query);
        self.put_count(param_types.len());
        for oid in param_types {
            self.buf.extend_from_slice(&oid.to_be_bytes());
        }
        self.finish(len_at);
    }

    /// Bind: portal and statement names, the parameter format codes, the
    /// parameter values (each Int32-length-prefixed; `None` is length -1,
    /// i.e. SQL NULL), then the result format codes.
    fn write_bind(
        &mut self,
        portal: &str,
        statement: &str,
        param_formats: &[i16],
        params: &[Option<Vec<u8>>],
        result_formats: &[i16],
    ) {
        let len_at = self.begin(frontend_type::BIND);
        self.put_cstr(portal);
        self.put_cstr(statement);
        self.put_i16_array(param_formats);
        self.put_count(params.len());
        for param in params {
            match param {
                Some(data) => self.put_len_prefixed(data),
                None => self.put_i32(-1),
            }
        }
        self.put_i16_array(result_formats);
        self.finish(len_at);
    }

    /// Describe: target kind byte plus the statement/portal name.
    fn write_describe(&mut self, kind: DescribeKind, name: &str) {
        self.write_kind_and_name(frontend_type::DESCRIBE, kind, name);
    }

    /// Execute: portal name, then the Int32 row limit.
    fn write_execute(&mut self, portal: &str, max_rows: i32) {
        let len_at = self.begin(frontend_type::EXECUTE);
        self.put_cstr(portal);
        self.put_i32(max_rows);
        self.finish(len_at);
    }

    /// Close: same layout as Describe (kind byte plus name).
    fn write_close(&mut self, kind: DescribeKind, name: &str) {
        self.write_kind_and_name(frontend_type::CLOSE, kind, name);
    }

    fn write_sync(&mut self) {
        self.write_empty_message(frontend_type::SYNC);
    }

    fn write_flush(&mut self) {
        self.write_empty_message(frontend_type::FLUSH);
    }

    /// CopyData: raw payload bytes with no inner framing.
    fn write_copy_data(&mut self, data: &[u8]) {
        let len_at = self.begin(frontend_type::COPY_DATA);
        self.buf.extend_from_slice(data);
        self.finish(len_at);
    }

    fn write_copy_done(&mut self) {
        self.write_empty_message(frontend_type::COPY_DONE);
    }

    /// CopyFail: the failure reason as a single C string.
    fn write_copy_fail(&mut self, message: &str) {
        self.write_simple_string_message(frontend_type::COPY_FAIL, message);
    }

    fn write_terminate(&mut self) {
        self.write_empty_message(frontend_type::TERMINATE);
    }

    /// Cancel request (no type byte): length, request code, backend process
    /// id, secret key. This comes to 16 bytes in total.
    fn write_cancel_request(&mut self, process_id: i32, secret_key: i32) {
        let len_at = self.begin_untyped();
        self.buf
            .extend_from_slice(&CANCEL_REQUEST_CODE.to_be_bytes());
        self.put_i32(process_id);
        self.put_i32(secret_key);
        self.finish(len_at);
    }

    /// SSL request (no type byte): length plus request code, 8 bytes in total.
    fn write_ssl_request(&mut self) {
        let len_at = self.begin_untyped();
        self.buf.extend_from_slice(&SSL_REQUEST_CODE.to_be_bytes());
        self.finish(len_at);
    }

    /// Shared body of Describe and Close: kind byte, then NUL-terminated name.
    fn write_kind_and_name(&mut self, type_byte: u8, kind: DescribeKind, name: &str) {
        let len_at = self.begin(type_byte);
        self.buf.push(kind.as_byte());
        self.put_cstr(name);
        self.finish(len_at);
    }

    /// A typed message with no body: the type byte plus a length of 4.
    fn write_empty_message(&mut self, type_byte: u8) {
        let len_at = self.begin(type_byte);
        self.finish(len_at);
    }

    /// A typed message whose whole body is one NUL-terminated string.
    fn write_simple_string_message(&mut self, type_byte: u8, s: &str) {
        let len_at = self.begin(type_byte);
        self.put_cstr(s);
        self.finish(len_at);
    }

    /// Starts a typed message. Pushes `type_byte`, reserves the Int32 length
    /// field, and returns the field's offset for [`Self::finish`].
    fn begin(&mut self, type_byte: u8) -> usize {
        self.buf.push(type_byte);
        self.begin_untyped()
    }

    /// Reserves a zeroed Int32 length field and returns its offset. The
    /// startup-phase messages, which have no type byte, use this directly.
    fn begin_untyped(&mut self) -> usize {
        let len_at = self.buf.len();
        self.buf.extend_from_slice(&[0; 4]);
        len_at
    }

    /// Back-patches the length field at `len_at` with the number of bytes from
    /// that field to the end of the buffer. The protocol's length counts
    /// itself but not a leading type byte, which is exactly this span.
    fn finish(&mut self, len_at: usize) {
        let len = Self::wire_len(self.buf.len() - len_at);
        self.buf[len_at..len_at + 4].copy_from_slice(&len.to_be_bytes());
    }

    /// Appends a big-endian Int32.
    fn put_i32(&mut self, value: i32) {
        self.buf.extend_from_slice(&value.to_be_bytes());
    }

    /// Appends `s` followed by a terminating NUL.
    ///
    /// NOTE(review): an interior NUL in `s` is not rejected and would end the
    /// field early on the receiving side. Confirm that callers never pass one.
    fn put_cstr(&mut self, s: &str) {
        self.buf.extend_from_slice(s.as_bytes());
        self.buf.push(0);
    }

    /// Appends `data` prefixed by its Int32 byte length.
    fn put_len_prefixed(&mut self, data: &[u8]) {
        let len = Self::wire_len(data.len());
        self.put_i32(len);
        self.buf.extend_from_slice(data);
    }

    /// Appends an Int16 element count.
    ///
    /// Counts of 32768..=65535 produce the same bytes the previous `as i16`
    /// cast did. Larger counts cannot be represented, so they panic instead of
    /// silently wrapping.
    fn put_count(&mut self, count: usize) {
        assert!(
            count <= usize::from(u16::MAX),
            "element count {} exceeds the protocol's 16-bit limit of 65535",
            count
        );
        // Guarded by the assert above, so this cast cannot truncate.
        self.buf.extend_from_slice(&(count as u16).to_be_bytes());
    }

    /// Appends an Int16 count followed by each value as a big-endian Int16.
    fn put_i16_array(&mut self, values: &[i16]) {
        self.put_count(values.len());
        for value in values {
            self.buf.extend_from_slice(&value.to_be_bytes());
        }
    }

    /// Converts a byte length to the protocol's signed Int32 length,
    /// panicking rather than wrapping if it does not fit.
    fn wire_len(len: usize) -> i32 {
        assert!(
            len <= i32::MAX as usize,
            "length {} exceeds the protocol's Int32 limit",
            len
        );
        // Guarded by the assert above, so this cast cannot truncate.
        len as i32
    }
}
#[cfg(test)]
mod tests {
    // Wire-format tests for `MessageWriter`. Expected bytes are spelled out
    // exactly wherever practical, so a wrong length field or field order fails
    // loudly instead of merely "producing some bytes".
    use super::*;
    use crate::protocol::PROTOCOL_VERSION;

    /// Reads the big-endian Int32 length field of a typed message (bytes 1..5).
    fn typed_len(data: &[u8]) -> usize {
        i32::from_be_bytes([data[1], data[2], data[3], data[4]]) as usize
    }

    #[test]
    fn test_startup_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Startup {
            version: PROTOCOL_VERSION,
            params: vec![
                ("user".to_string(), "postgres".to_string()),
                ("database".to_string(), "test".to_string()),
            ],
        };
        let data = writer.write(&msg);
        let len = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
        // 4 (length) + 4 (version) + 28 (the four C strings) + 1 (list terminator).
        assert_eq!(len, 37);
        assert_eq!(len as usize, data.len());
        let version = i32::from_be_bytes([data[4], data[5], data[6], data[7]]);
        assert_eq!(version, PROTOCOL_VERSION);
        assert_eq!(&data[8..], b"user\0postgres\0database\0test\0\0");
    }

    #[test]
    fn test_query_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Query("SELECT 1".to_string());
        let data = writer.write(&msg);
        assert_eq!(data[0], b'Q');
        let len = typed_len(data);
        assert_eq!(len, 4 + 8 + 1);
        assert_eq!(data[len], 0);
        assert_eq!(data, b"Q\0\0\0\x0dSELECT 1\0");
    }

    #[test]
    fn test_sync_message() {
        let mut writer = MessageWriter::new();
        let data = writer.write(&FrontendMessage::Sync);
        assert_eq!(data, &[b'S', 0, 0, 0, 4]);
    }

    #[test]
    fn test_flush_message() {
        let mut writer = MessageWriter::new();
        let data = writer.write(&FrontendMessage::Flush);
        assert_eq!(data, &[b'H', 0, 0, 0, 4]);
    }

    #[test]
    fn test_terminate_message() {
        let mut writer = MessageWriter::new();
        let data = writer.write(&FrontendMessage::Terminate);
        assert_eq!(data, &[b'X', 0, 0, 0, 4]);
    }

    #[test]
    fn test_copy_done_message() {
        let mut writer = MessageWriter::new();
        let data = writer.write(&FrontendMessage::CopyDone);
        assert_eq!(data, &[b'c', 0, 0, 0, 4]);
    }

    #[test]
    fn test_copy_fail_message() {
        let mut writer = MessageWriter::new();
        let data = writer.write(&FrontendMessage::CopyFail("oops".to_string()));
        assert_eq!(data, b"f\0\0\0\x09oops\0");
    }

    #[test]
    fn test_password_message() {
        let mut writer = MessageWriter::new();
        let data = writer.write(&FrontendMessage::PasswordMessage("secret".to_string()));
        assert_eq!(data, b"p\0\0\0\x0bsecret\0");
    }

    #[test]
    fn test_sasl_initial_response_with_data() {
        let mut writer = MessageWriter::new();
        let payload = b"n,,n=,r=abc";
        let msg = FrontendMessage::SASLInitialResponse {
            mechanism: "SCRAM-SHA-256".to_string(),
            data: payload.to_vec(),
        };
        let data = writer.write(&msg);
        assert_eq!(data[0], b'p');
        assert_eq!(typed_len(data), data.len() - 1);
        assert_eq!(&data[5..19], b"SCRAM-SHA-256\0");
        let data_len = i32::from_be_bytes([data[19], data[20], data[21], data[22]]);
        assert_eq!(data_len, payload.len() as i32);
        assert_eq!(&data[23..], payload);
    }

    #[test]
    fn test_sasl_initial_response_without_data_uses_minus_one() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::SASLInitialResponse {
            mechanism: "SCRAM-SHA-256".to_string(),
            data: Vec::new(),
        };
        let data = writer.write(&msg);
        assert_eq!(typed_len(data), 4 + 14 + 4);
        assert_eq!(&data[data.len() - 4..], &(-1_i32).to_be_bytes());
    }

    #[test]
    fn test_sasl_response() {
        let mut writer = MessageWriter::new();
        let data = writer.write(&FrontendMessage::SASLResponse(vec![9, 8]));
        assert_eq!(data, &[b'p', 0, 0, 0, 6, 9, 8]);
    }

    #[test]
    fn test_parse_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Parse {
            name: "stmt1".to_string(),
            query: "SELECT $1".to_string(),
            param_types: vec![23],
        };
        let data = writer.write(&msg);
        assert_eq!(data[0], b'P');
        let name_start = 5;
        let name_end = data[name_start..].iter().position(|&b| b == 0).unwrap() + name_start;
        assert_eq!(&data[name_start..name_end], b"stmt1");
        assert_eq!(typed_len(data), data.len() - 1);
        // Int16 parameter count (1), then the Int32 type OID (23).
        assert_eq!(&data[data.len() - 6..], &[0, 1, 0, 0, 0, 23]);
    }

    #[test]
    fn test_describe_statement() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Describe {
            kind: DescribeKind::Statement,
            name: "stmt1".to_string(),
        };
        let data = writer.write(&msg);
        assert_eq!(data[0], b'D');
        assert_eq!(data[5], b'S');
    }

    #[test]
    fn test_describe_portal() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Describe {
            kind: DescribeKind::Portal,
            name: "portal1".to_string(),
        };
        let data = writer.write(&msg);
        assert_eq!(data[0], b'D');
        assert_eq!(data[5], b'P');
    }

    #[test]
    fn test_close_statement() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Close {
            kind: DescribeKind::Statement,
            name: "stmt1".to_string(),
        };
        let data = writer.write(&msg);
        assert_eq!(typed_len(data), 4 + 1 + 5 + 1);
        assert_eq!(&data[5..], b"Sstmt1\0");
    }

    #[test]
    fn test_execute_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Execute {
            portal: String::new(),
            max_rows: 0,
        };
        let data = writer.write(&msg);
        assert_eq!(data[0], b'E');
        assert_eq!(typed_len(data), 9);
        // Skip the type byte, the length field and the empty portal's NUL.
        let max_rows_offset = 5 + 1;
        let max_rows = i32::from_be_bytes([
            data[max_rows_offset],
            data[max_rows_offset + 1],
            data[max_rows_offset + 2],
            data[max_rows_offset + 3],
        ]);
        assert_eq!(max_rows, 0);
    }

    #[test]
    fn test_cancel_request() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::CancelRequest {
            process_id: 12345,
            secret_key: 67890,
        };
        let data = writer.write(&msg);
        assert_eq!(data.len(), 16);
        let len = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
        assert_eq!(len, 16);
        let code = i32::from_be_bytes([data[4], data[5], data[6], data[7]]);
        assert_eq!(code, CANCEL_REQUEST_CODE);
        let pid = i32::from_be_bytes([data[8], data[9], data[10], data[11]]);
        assert_eq!(pid, 12345);
        let key = i32::from_be_bytes([data[12], data[13], data[14], data[15]]);
        assert_eq!(key, 67890);
    }

    #[test]
    fn test_ssl_request() {
        let mut writer = MessageWriter::new();
        let data = writer.write(&FrontendMessage::SSLRequest);
        assert_eq!(data.len(), 8);
        let len = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
        assert_eq!(len, 8);
        let code = i32::from_be_bytes([data[4], data[5], data[6], data[7]]);
        assert_eq!(code, SSL_REQUEST_CODE);
    }

    #[test]
    fn test_bind_with_null_params() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Bind {
            portal: String::new(),
            statement: "stmt1".to_string(),
            param_formats: vec![0],
            params: vec![None],
            result_formats: vec![],
        };
        let data = writer.write(&msg);
        assert_eq!(data[0], b'B');
        assert_eq!(typed_len(data), data.len() - 1);
        let null_indicator = (-1_i32).to_be_bytes();
        assert!(data.windows(4).any(|w| w == null_indicator));
    }

    #[test]
    fn test_bind_exact_encoding() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Bind {
            portal: String::new(),
            statement: "s".to_string(),
            param_formats: vec![1],
            params: vec![Some(b"ab".to_vec()), None],
            result_formats: vec![1],
        };
        let data = writer.write(&msg);
        let expected: &[u8] = &[
            b'B', 0, 0, 0, 27, // type + length
            0, // empty portal
            b's', 0, // statement
            0, 1, 0, 1, // one parameter format: binary
            0, 2, // two parameters
            0, 0, 0, 2, b'a', b'b', // "ab"
            0xff, 0xff, 0xff, 0xff, // NULL
            0, 1, 0, 1, // one result format: binary
        ];
        assert_eq!(data, expected);
    }

    #[test]
    fn test_copy_data() {
        let mut writer = MessageWriter::new();
        let payload = b"hello\nworld\n";
        let msg = FrontendMessage::CopyData(payload.to_vec());
        let data = writer.write(&msg);
        assert_eq!(data[0], b'd');
        let len = i32::from_be_bytes([data[1], data[2], data[3], data[4]]);
        assert_eq!(len, (4 + payload.len()) as i32);
        assert_eq!(&data[5..], payload);
    }

    #[test]
    fn test_length_field_matches_payload() {
        let mut writer = MessageWriter::new();
        let messages = [
            FrontendMessage::Query("SELECT 1".to_string()),
            FrontendMessage::PasswordMessage("pw".to_string()),
            FrontendMessage::SASLResponse(b"c=biws".to_vec()),
            FrontendMessage::CopyFail("oops".to_string()),
            FrontendMessage::CopyData(vec![1, 2, 3]),
            FrontendMessage::Close {
                kind: DescribeKind::Portal,
                name: "p".to_string(),
            },
            FrontendMessage::Bind {
                portal: "p".to_string(),
                statement: "s".to_string(),
                param_formats: vec![0, 1],
                params: vec![Some(vec![1, 2, 3]), None, Some(Vec::new())],
                result_formats: vec![0],
            },
            FrontendMessage::Sync,
        ];
        for msg in &messages {
            let data = writer.write(msg);
            assert_eq!(typed_len(data), data.len() - 1);
        }
    }

    #[test]
    fn test_writer_reuse() {
        let mut writer = MessageWriter::new();
        writer.write(&FrontendMessage::Sync);
        assert_eq!(writer.as_bytes(), &[b'S', 0, 0, 0, 4]);
        writer.write(&FrontendMessage::Flush);
        assert_eq!(writer.as_bytes(), &[b'H', 0, 0, 0, 4]);
    }

    #[test]
    fn test_take_and_clear() {
        let mut writer = MessageWriter::with_capacity(16);
        writer.write(&FrontendMessage::Terminate);
        assert_eq!(writer.take(), vec![b'X', 0, 0, 0, 4]);
        assert!(writer.as_bytes().is_empty());
        writer.write(&FrontendMessage::Sync);
        writer.clear();
        assert!(writer.as_bytes().is_empty());
    }
}