use crate::config::Config;
use crate::error::PgsqlError;
use crate::format::FieldFormat;
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use hmac::digest::FixedOutput;
use hmac::{Hmac, Mac};
use json::{array, object, JsonValue};
use log::{debug, warn};
use rand::RngExt;
use sha2::{Digest, Sha256};
use std::collections::BTreeMap;
/// Per-connection state used to encode frontend messages and decode backend
/// messages.
///
/// The `pack_*` methods build outgoing byte buffers; `unpack` consumes incoming
/// bytes and records what it learns in the fields below.
#[derive(Clone, Debug)]
pub struct Packet {
/// Connection settings; this type reads `username`, `userpass`, `database` and `debug`.
config: Config,
/// Protocol version written at the start of the startup message.
version: Version,
/// Startup parameters, written as NUL-terminated key/value pairs by `pack_first`.
params: BTreeMap<String, String>,
/// SASL mechanism name captured from the server's SCRAM initialization message
/// and echoed back by `pack_auth`.
pub auth_mechanism: String,
/// Client nonce, generated once in `new` and sent in the SCRAM exchange.
pub client_nonce: String,
/// Server-chosen part of the combined SCRAM nonce (the part after `client_nonce`).
pub server_nonce: String,
/// SCRAM salt exactly as received (it is base64-decoded when the salted password is computed).
pub server_salt: String,
/// Salt bytes received with an MD5 password request.
pub md5_salt: Vec<u8>,
/// SCRAM iteration count received in the server challenge.
pub iterations: u32,
/// Fields of the most recent ErrorResponse or Notice; reset by the query `pack_*` methods.
pub error_message: ErrorMessage,
/// Result accumulated by `unpack` for the current request; reset by the query `pack_*` methods.
pub success_message: SuccessMessage,
/// Text of the SCRAM completion message with a leading `v=` removed.
pub server_proof: String,
/// Name/value pairs received in ParameterStatus messages.
pub service_params: BTreeMap<String, String>,
/// First 32-bit value of the BackendKeyData message.
process_id: u32,
/// Second 32-bit value of the BackendKeyData message.
secret_key: u32,
/// Status decoded from the most recent ReadyForQuery message.
status: Status,
/// SQL text of the most recently packed query; copied into `error_message.sql` on errors.
sql: String,
/// Command tag of the most recent CommandComplete message.
tag: String,
/// Counter used to derive statement names (`0` is the unnamed statement, `n` becomes `s{n}`);
/// incremented after every packed query.
statement: u32,
/// Transaction depth tracked from BEGIN/START TRANSACTION and COMMIT/ROLLBACK command tags.
transactions: usize,
}
impl Packet {
/// Creates a packet codec in its pre-handshake state.
///
/// Every field starts empty or zero, the protocol version defaults to 3.0 and a
/// fresh client nonce is generated for a later SCRAM exchange.
pub fn new(config: Config) -> Self {
    let client_nonce = generate_nonce();
    Self {
        config,
        version: Version::V3,
        params: BTreeMap::new(),
        auth_mechanism: String::default(),
        client_nonce,
        server_nonce: String::default(),
        server_salt: String::default(),
        md5_salt: Vec::default(),
        iterations: 0,
        error_message: ErrorMessage::default(),
        success_message: SuccessMessage::default(),
        server_proof: String::default(),
        service_params: BTreeMap::new(),
        process_id: 0,
        secret_key: 0,
        status: Status::None,
        sql: String::default(),
        tag: String::default(),
        statement: 0,
        transactions: 0,
    }
}
/// Selects the protocol version from its numeric form (`3.0`, `3.1` or `3.2`).
///
/// Any other value falls back to 3.0.
pub fn set_version(&mut self, version: f64) {
    let selected = Version::into(version);
    self.version = selected;
}
/// Sets (or overwrites) one startup parameter sent by `pack_first`.
pub fn set_params(&mut self, key: &str, value: &str) {
    let (name, setting) = (String::from(key), String::from(value));
    self.params.insert(name, setting);
}
/// Builds the startup message.
///
/// Forces protocol version 3.0 and sets the `client_encoding`, `user` and
/// `database` parameters (from the configuration) before serializing every
/// stored parameter. The layout is: 4-byte big-endian total length (including
/// itself), 4-byte version, `key\0value\0` pairs, and a final NUL.
///
/// # Panics
/// Panics if the body length does not fit in a `u32`.
pub fn pack_first(&mut self) -> Vec<u8> {
    self.set_version(3.0);
    let user = self.config.username.clone();
    let db = self.config.database.clone();
    self.set_params("client_encoding", "UTF8");
    self.set_params("user", &user);
    self.set_params("database", &db);

    let mut body = Vec::new();
    body.extend_from_slice(&self.version.i32().to_be_bytes());
    for (name, setting) in &self.params {
        body.extend(name.bytes());
        body.push(0);
        body.extend(setting.bytes());
        body.push(0);
    }
    // Terminator of the parameter list.
    body.push(0);

    // The length field counts its own four bytes.
    let total = u32::try_from(body.len()).unwrap() + 4;
    let mut out = Vec::with_capacity(body.len() + 4);
    out.extend_from_slice(&total.to_be_bytes());
    out.extend(body);
    out
}
/// Builds the initial SASL response (`p` message) that starts the SCRAM exchange.
///
/// The body is the mechanism name captured from the server, a NUL, the 4-byte
/// length of the client-first message and the client-first message itself
/// (`n,,n=,r=<client nonce>`).
///
/// # Panics
/// Panics if a length does not fit in its wire integer type.
pub fn pack_auth(&mut self) -> Vec<u8> {
    let client_first = format!("n,,n=,r={}", self.client_nonce);
    let client_first_len = i32::try_from(client_first.len()).unwrap();

    let mut body = Vec::new();
    body.extend_from_slice(self.auth_mechanism.as_bytes());
    body.push(0);
    body.extend_from_slice(&client_first_len.to_be_bytes());
    body.extend_from_slice(client_first.as_bytes());

    // The length field counts its own four bytes but not the type byte.
    let framed_len = u32::try_from(body.len()).unwrap() + 4;
    let mut out = vec![b'p'];
    out.extend_from_slice(&framed_len.to_be_bytes());
    out.extend(body);
    out
}
/// Builds the password message answering an MD5 authentication request.
///
/// The response text is `"md5" + hex(md5(hex(md5(password + username)) + salt))`,
/// using the salt previously stored in `md5_salt`, followed by a NUL.
pub fn pack_md5_password(&self) -> Vec<u8> {
    let credentials = format!("{}{}", self.config.userpass, self.config.username);
    let inner_hex = format!("{:x}", md5::compute(credentials.as_bytes()));

    let mut salted = inner_hex.into_bytes();
    salted.extend_from_slice(&self.md5_salt);
    let response = format!("md5{:x}", md5::compute(&salted));

    let mut out = vec![b'p'];
    // +5 = four length bytes plus the trailing NUL.
    out.extend_from_slice(&(response.len() as u32 + 5).to_be_bytes());
    out.extend_from_slice(response.as_bytes());
    out.push(0);
    out
}
/// Builds the password message answering a cleartext authentication request:
/// the configured password followed by a NUL.
pub fn pack_cleartext_password(&self) -> Vec<u8> {
    let secret = self.config.userpass.as_bytes();
    let mut out = vec![b'p'];
    // +5 = four length bytes plus the trailing NUL.
    out.extend_from_slice(&(secret.len() as u32 + 5).to_be_bytes());
    out.extend_from_slice(secret);
    out.push(0);
    out
}
/// Builds the SCRAM client-final message (`p` message) carrying the client proof.
///
/// The proof is derived from the configured password, the stored server salt
/// and iteration count, and an auth message reassembled from the client nonce,
/// the server nonce, the salt and the iteration count.
///
/// # Panics
/// Panics if the body length does not fit in an `i32`.
pub fn pack_auth_verify(&mut self) -> Vec<u8> {
    let salted_password = compute_salted_password(
        self.config.userpass.as_bytes(),
        self.server_salt.as_str(),
        self.iterations,
    );
    let client_key = hmac_sha256(&salted_password, b"Client Key");
    let stored_key = sha256(&client_key);

    let combined_nonce = format!("{}{}", self.client_nonce, self.server_nonce);
    // Client-final message without the proof; also the tail of the auth message.
    let without_proof = format!("c=biws,r={combined_nonce}");
    // client-first-bare "," server-first "," client-final-without-proof
    let auth_message = format!(
        "n=,r={},r={},s={},i={},{}",
        self.client_nonce, combined_nonce, self.server_salt, self.iterations, without_proof
    );
    let client_proof = compute_client_proof(&client_key, &stored_key, &auth_message);
    let body = format!("{without_proof},p={client_proof}");

    let framed_len = i32::try_from(body.len()).unwrap() + 4;
    let mut out = vec![b'p'];
    out.extend_from_slice(&framed_len.to_be_bytes());
    out.extend_from_slice(body.as_bytes());
    out
}
/// Packs a parameterless query as an extended-protocol batch:
/// Parse, Bind, Describe (unnamed portal), Execute (no row limit), Sync.
///
/// Resets the stored success/error messages, remembers `sql` for error
/// reporting and advances the statement counter.
pub fn pack_query(&mut self, sql: &str) -> Vec<u8> {
    self.success_message = SuccessMessage::default();
    self.error_message = ErrorMessage::default();
    self.sql = sql.to_string();

    let statement = self.statement;
    let steps = [
        Query::Parse(statement, sql.to_string(), vec![]),
        Query::Bind(0, statement, vec![]),
        Query::Describe(b'P', 0),
        Query::Execute(0, 0),
        Query::Sync,
    ];
    self.statement += 1;
    steps.iter().flat_map(Query::vec_u8).collect()
}
pub fn pack_execute(&mut self, sql: &str) -> Vec<u8> {
self.success_message = SuccessMessage::default();
self.error_message = ErrorMessage::default();
self.sql = sql.to_string();
let mut message = vec![];
message.extend(Query::Parse(self.statement, sql.to_string(), vec![]).vec_u8());
message.extend(Query::Bind(0, self.statement, vec![]).vec_u8());
message.extend(Query::Describe(b'P', 0).vec_u8());
message.extend(Query::Execute(0, 0).vec_u8());
message.extend(Query::Sync.vec_u8());
self.statement += 1;
message
}
/// Packs a query with text-format parameters as an extended-protocol batch:
/// Parse, Bind, Describe (statement), Execute (no row limit), Sync.
///
/// `None` parameters are sent as SQL NULL. At most `u16::MAX` parameters are
/// sent; any beyond that are dropped. Every parameter type OID is sent as `0`.
/// Resets the stored success/error messages, remembers `sql` and advances the
/// statement counter.
pub fn pack_query_params(&mut self, sql: &str, params: &[Option<&str>]) -> Vec<u8> {
    self.success_message = SuccessMessage::default();
    self.error_message = ErrorMessage::default();
    self.sql = sql.to_string();

    let limit = usize::from(u16::try_from(params.len()).unwrap_or(u16::MAX));
    let values: Vec<Option<String>> = params
        .iter()
        .take(limit)
        .map(|value| value.map(String::from))
        .collect();

    let statement = self.statement;
    let type_oids = vec![0u32; values.len()];
    let mut out = Query::Parse(statement, sql.to_string(), type_oids).vec_u8();
    out.extend(Query::Bind(0, statement, values).vec_u8());
    out.extend(Query::Describe(b'S', statement).vec_u8());
    out.extend(Query::Execute(0, 0).vec_u8());
    out.extend(Query::Sync.vec_u8());
    self.statement += 1;
    out
}
pub fn pack_execute_params(&mut self, sql: &str, params: &[Option<&str>]) -> Vec<u8> {
self.success_message = SuccessMessage::default();
self.error_message = ErrorMessage::default();
self.sql = sql.to_string();
let param_count = u16::try_from(params.len()).unwrap_or(u16::MAX);
let param_values: Vec<Option<String>> = params
.iter()
.take(usize::from(param_count))
.map(|param| param.map(str::to_string))
.collect();
let mut message = vec![];
message.extend(
Query::Parse(
self.statement,
sql.to_string(),
vec![0u32; param_values.len()],
)
.vec_u8(),
);
message.extend(Query::Bind(0, self.statement, param_values).vec_u8());
message.extend(Query::Describe(b'S', self.statement).vec_u8());
message.extend(Query::Execute(0, 0).vec_u8());
message.extend(Query::Sync.vec_u8());
self.statement += 1;
message
}
pub fn unpack(
&mut self,
mut message: Vec<u8>,
_level: usize,
) -> Result<SuccessMessage, PgsqlError> {
loop {
if message.is_empty() {
return Ok(self.success_message.clone());
}
if message.len() < 5 {
return Err(PgsqlError::Protocol("Message too short".into()));
}
let message_type = message.remove(0);
let length = message.drain(..4).collect::<Vec<u8>>();
let length: [u8; 4] = length
.try_into()
.map_err(|_| PgsqlError::Protocol("消息长度字段不完整".into()))?;
let raw_len = u32::from_be_bytes(length) as usize;
if raw_len < 4 {
return Err(PgsqlError::Protocol("Invalid message length".into()));
}
let mut len = raw_len - 4;
match MessageType::form(message_type) {
MessageType::Authentication => {
let auth_type = u32::from_be_bytes(
message
.drain(..4)
.collect::<Vec<u8>>()
.as_slice()
.try_into()
.unwrap_or([0, 0, 0, 0]),
);
match AuthStatus::form(auth_type) {
AuthStatus::Md5Password => {
if message.len() < 4 {
return Err(PgsqlError::Protocol("MD5 salt too short".into()));
}
let salt = message.drain(..4).collect::<Vec<u8>>();
self.md5_salt = salt;
return Ok(self.success_message.clone());
}
AuthStatus::CleartextPassword => {
return Ok(self.success_message.clone());
}
AuthStatus::ScramInitialization => {
if message.len() < 2 {
return Err(PgsqlError::Protocol(
"SCRAM initialization message too short".into(),
));
}
let data = message.drain(..message.len() - 2).collect::<Vec<u8>>();
self.auth_mechanism = String::from_utf8_lossy(&data).into_owned();
return Ok(self.success_message.clone());
}
AuthStatus::ScramChallenge => {
let data = std::mem::take(&mut message);
let text = String::from_utf8_lossy(&data);
let list: Vec<&str> = text.split(',').collect();
if list.len() < 3 {
return Err(PgsqlError::Protocol(
"SCRAM challenge invalid format".into(),
));
}
let full_nonce = list[0].trim_start_matches("r=");
if !full_nonce.starts_with(&self.client_nonce) {
return Err(PgsqlError::Auth(
"SCRAM server nonce does not start with client nonce".into(),
));
}
self.server_nonce = full_nonce[self.client_nonce.len()..].to_string();
self.server_salt = list[1].trim_start_matches("s=").to_string();
self.iterations = list[2]
.trim_start_matches("i=")
.parse::<u32>()
.unwrap_or(4096);
return Ok(self.success_message.clone());
}
AuthStatus::ScramComplete => {
if len < 4 {
return Err(PgsqlError::Protocol(
"SCRAM complete message too short".into(),
));
}
len -= 4;
if message.len() < len {
return Err(PgsqlError::Protocol(
"SCRAM complete message too short".into(),
));
}
let data = message.drain(..len).collect::<Vec<u8>>();
let text = String::from_utf8_lossy(&data);
self.server_proof = text.trim_start_matches("v=").to_string();
continue;
}
AuthStatus::None => {
return Err(self.to_query_error());
}
AuthStatus::AuthenticationOk => {
continue;
}
}
}
MessageType::ErrorResponse => {
if message.len() < len {
return Err(PgsqlError::Protocol(
"ErrorResponse message too short".into(),
));
}
let mut err_buf = message.drain(..len).collect::<Vec<u8>>();
loop {
if err_buf.is_empty() {
break;
}
let field = err_buf.remove(0);
if field == 0 {
break;
}
let target = [0];
match err_buf
.windows(target.len())
.position(|window| window == target)
{
None => break,
Some(index) => {
let field_name = String::from_utf8_lossy(&[field]).to_string();
match field_name.as_str() {
"C" => {
let data = err_buf.drain(..=index).collect::<Vec<u8>>();
self.error_message.code = self.text(&data);
}
"M" => {
let data = err_buf.drain(..=index).collect::<Vec<u8>>();
self.error_message.message = self.text(&data);
self.error_message.sql = self.sql.clone();
}
"D" | "H" => {
let data = err_buf.drain(..=index).collect::<Vec<u8>>();
self.error_message.detail = self.text(&data);
}
"P" => {
let data = err_buf.drain(..index).collect::<Vec<u8>>();
err_buf.remove(0);
let position = String::from_utf8_lossy(&data).to_string();
self.error_message.position =
position.parse::<u16>().unwrap_or(0);
self.error_message.sql = self.sql.clone();
}
_ => {
let _ = err_buf.drain(..=index).collect::<Vec<u8>>();
}
}
}
}
}
if self.config.debug {
debug!("ErrorMessage {:?}", self.error_message);
}
return Err(self.to_query_error());
}
MessageType::Notice => {
if message.len() < len {
return Err(PgsqlError::Protocol("Notice message too short".into()));
}
let mut notice_buf = message.drain(..len).collect::<Vec<u8>>();
loop {
if notice_buf.is_empty() {
break;
}
let field = notice_buf.remove(0);
if field == 0 {
break;
}
let target = [0];
match notice_buf
.windows(target.len())
.position(|window| window == target)
{
None => break,
Some(index) => {
let field_name = String::from_utf8_lossy(&[field]).to_string();
match field_name.as_str() {
"C" => {
let data = notice_buf.drain(..=index).collect::<Vec<u8>>();
self.error_message.code = self.text(&data);
}
"M" => {
let data = notice_buf.drain(..=index).collect::<Vec<u8>>();
self.error_message.message = self.text(&data);
self.error_message.sql = self.sql.clone();
}
"D" | "H" => {
let data = notice_buf.drain(..=index).collect::<Vec<u8>>();
self.error_message.detail = self.text(&data);
}
"P" => {
let data = notice_buf.drain(..index).collect::<Vec<u8>>();
notice_buf.remove(0);
let position = String::from_utf8_lossy(&data).to_string();
self.error_message.position =
position.parse::<u16>().unwrap_or(0);
self.error_message.sql = self.sql.clone();
}
_ => {
let _ = notice_buf.drain(..=index).collect::<Vec<u8>>();
}
}
}
}
}
continue;
}
MessageType::None => {
return Err(PgsqlError::Protocol(format!(
"消息类型未知: {message_type}"
)));
}
MessageType::ParameterStatus => {
if message.len() < len {
return Err(PgsqlError::Protocol(
"ParameterStatus message too short".into(),
));
}
let len_sub = message.drain(..len).collect::<Vec<u8>>();
match len_sub.windows(1).position(|window| window == [0]) {
None => {}
Some(e) => {
if e + 1 < len_sub.len() {
let name = String::from_utf8_lossy(&len_sub[..e]).into_owned();
let end = len_sub.len().saturating_sub(1);
let value = if e + 1 < end {
String::from_utf8_lossy(&len_sub[e + 1..end]).into_owned()
} else {
String::new()
};
self.service_params.insert(name, value);
}
}
}
continue;
}
MessageType::BackendKeyData => {
if message.len() < len || len < 8 {
return Err(PgsqlError::Protocol(
"BackendKeyData message too short".into(),
));
}
let len_sub = message.drain(..len).collect::<Vec<u8>>();
self.process_id =
u32::from_be_bytes(len_sub[..4].try_into().unwrap_or([0, 0, 0, 0]));
self.secret_key =
u32::from_be_bytes(len_sub[4..8].try_into().unwrap_or([0, 0, 0, 0]));
continue;
}
MessageType::ReadyForQuery => {
if message.len() < len {
return Err(PgsqlError::Protocol(
"ReadyForQuery message too short".into(),
));
}
let len_sub = message.drain(..len).collect::<Vec<u8>>();
self.status = Status::form(len_sub.as_slice());
return Ok(self.success_message.clone());
}
MessageType::ParameterDescription => {
if message.len() < len {
return Err(PgsqlError::Protocol(
"ParameterDescription message too short".into(),
));
}
let payload: Vec<u8> = message.drain(..len).collect();
self.success_message.param_oids.clear();
if payload.len() >= 2 {
let num_params = u16::from_be_bytes([payload[0], payload[1]]);
let mut oids = Vec::with_capacity(usize::from(num_params));
let mut offset = 2;
for _ in 0..num_params {
if offset + 4 > payload.len() {
break;
}
let oid = u32::from_be_bytes([
payload[offset],
payload[offset + 1],
payload[offset + 2],
payload[offset + 3],
]);
oids.push(oid);
offset += 4;
}
self.success_message.param_oids = oids;
if self.config.debug {
debug!(
"ParameterDescription OIDs: {:?}",
self.success_message.param_oids
);
}
}
continue;
}
MessageType::RowDescription => {
if len < 2 || message.len() < 2 {
return Err(PgsqlError::Protocol(
"RowDescription message too short".into(),
));
}
let field_count = u16::from_be_bytes(
message
.drain(..2)
.collect::<Vec<u8>>()
.as_slice()
.try_into()
.unwrap_or([0, 0]),
);
len -= 2;
if self.config.debug {
debug!("field_count: {field_count:?}");
}
let mut fields = if field_count > 0 && message.len() >= len {
message.drain(..len).collect::<Vec<u8>>()
} else {
vec![]
};
for _ in 0..field_count {
let name = match fields.windows(1).position(|window| window == [0]) {
None => {
continue;
}
Some(e) => {
let name = fields.drain(..e).collect::<Vec<u8>>();
if !fields.is_empty() {
fields.remove(0);
}
String::from_utf8_lossy(&name).into_owned()
}
};
if fields.len() < 18 {
continue;
}
let _table_oid = u32::from_be_bytes(
fields
.drain(..4)
.collect::<Vec<u8>>()
.try_into()
.unwrap_or([0, 0, 0, 0]),
);
let column_index = u16::from_be_bytes(
fields
.drain(..2)
.collect::<Vec<u8>>()
.try_into()
.unwrap_or([0, 0]),
);
let type_oid = u32::from_be_bytes(
fields
.drain(..4)
.collect::<Vec<u8>>()
.try_into()
.unwrap_or([0, 0, 0, 0]),
);
let _column_length = i16::from_be_bytes(
fields
.drain(..2)
.collect::<Vec<u8>>()
.try_into()
.unwrap_or([0, 0]),
);
let _type_modifier = i32::from_be_bytes(
fields
.drain(..4)
.collect::<Vec<u8>>()
.try_into()
.unwrap_or([0, 0, 0, 0]),
);
let format = u16::from_be_bytes(
fields
.drain(..2)
.collect::<Vec<u8>>()
.try_into()
.unwrap_or([0, 0]),
);
let field =
Field::new(name, column_index, FieldFormat::from_u16(format, type_oid));
self.success_message.fields.push(field);
}
continue;
}
MessageType::ParseCompletion
| MessageType::BindCompletion
| MessageType::NoData
| MessageType::CloseComplete
| MessageType::CopyDone => {
if len > 0 && message.len() >= len {
let _ = message.drain(..len);
}
continue;
}
MessageType::PortalSuspended => {
if len > 0 && message.len() >= len {
let _ = message.drain(..len);
}
self.success_message.has_more = true;
continue;
}
MessageType::EmptyQueryResponse => {
return Ok(self.success_message.clone());
}
MessageType::NotificationResponse => {
if message.len() >= len {
let _ = message.drain(..len).collect::<Vec<u8>>();
}
continue;
}
MessageType::CopyInResponse | MessageType::CopyOutResponse => {
if message.len() >= len {
let _ = message.drain(..len).collect::<Vec<u8>>();
}
continue;
}
MessageType::CopyData => {
if message.len() >= len {
let _ = message.drain(..len).collect::<Vec<u8>>();
}
continue;
}
MessageType::DataRow => {
if len < 2 || message.len() < 2 {
return Err(PgsqlError::Protocol("DataRow message too short".into()));
}
let field_count = u16::from_be_bytes(
message
.drain(..2)
.collect::<Vec<u8>>()
.as_slice()
.try_into()
.unwrap_or([0, 0]),
);
len -= 2;
if self.config.debug {
debug!("field_count: {field_count:?}");
}
let mut fields = if field_count > 0 && message.len() >= len {
message.drain(..len).collect::<Vec<u8>>()
} else {
vec![]
};
let mut row = object! {};
for i in 0..field_count {
if fields.len() < 4 {
break;
}
let length = u32::from_be_bytes(
fields
.drain(..4)
.collect::<Vec<u8>>()
.try_into()
.unwrap_or([0, 0, 0, 0]),
);
let data = if length == 4_294_967_295 {
vec![]
} else if fields.len() >= length as usize {
fields.drain(..length as usize).collect::<Vec<u8>>()
} else {
vec![]
};
if let Some(field) = self.success_message.fields.get(i as usize) {
let field = field.clone();
row[field.name] = FieldFormat::from_value(&field.format_type, data);
}
}
let _ = self.success_message.rows.push(row.clone());
continue;
}
MessageType::CommandCompletion => {
if message.len() < len {
return Err(PgsqlError::Protocol(
"CommandCompletion message too short".into(),
));
}
let tag = message.drain(..len).collect::<Vec<u8>>();
let tag_end = if !tag.is_empty() && tag[tag.len() - 1] == 0 {
tag.len() - 1
} else {
tag.len()
};
self.tag = String::from_utf8_lossy(&tag[..tag_end]).into_owned();
self.success_message.tag = self.tag.clone();
match self.tag.as_str() {
v if v.starts_with("SELECT") => {
self.success_message.affect_count = v
.trim_start_matches("SELECT")
.trim()
.parse::<usize>()
.unwrap_or(0);
}
v if v.starts_with("UPDATE") => {
self.success_message.affect_count = v
.trim_start_matches("UPDATE")
.trim()
.parse::<usize>()
.unwrap_or(0);
}
v if v.starts_with("INSERT") => {
let t = v.trim_start_matches("INSERT").trim();
let tt: Vec<usize> = t
.split(' ')
.map(|x| x.parse::<usize>().unwrap_or(0))
.collect();
self.success_message.row_count = tt.first().copied().unwrap_or(0);
self.success_message.affect_count = tt.get(1).copied().unwrap_or(0);
}
v if v.starts_with("DELETE") => {
self.success_message.affect_count = v
.trim_start_matches("DELETE")
.trim()
.parse::<usize>()
.unwrap_or(0);
}
v if v.starts_with("BEGIN") => {
self.transactions += 1;
}
v if v.starts_with("START TRANSACTION") => {
self.transactions += 1;
}
v if v.starts_with("COMMIT") => {
self.transactions = self.transactions.saturating_sub(1);
}
v if v.starts_with("ROLLBACK") => {
self.transactions = self.transactions.saturating_sub(1);
}
_ => {
if self.config.debug {
debug!("tag: {:?}", self.tag);
}
}
}
self.success_message.transaction = self.transactions > 0;
continue;
}
}
}
}
/// Decodes `response` as (lossy) UTF-8 text, dropping one trailing NUL byte if present.
///
/// An empty slice yields an empty string.
pub fn text(&self, response: &[u8]) -> String {
    let body = match response.split_last() {
        None => return String::new(),
        Some((&0, head)) => head,
        Some(_) => response,
    };
    String::from_utf8_lossy(body).into_owned()
}
fn to_query_error(&self) -> PgsqlError {
PgsqlError::Query {
code: self.error_message.code.clone(),
message: self.error_message.message.clone(),
detail: self.error_message.detail.clone(),
sql: self.error_message.sql.clone(),
position: self.error_message.position,
}
}
/// Builds the Terminate message: type byte `X` followed by the length `4`.
pub fn pack_terminate() -> Vec<u8> {
    let [b0, b1, b2, b3] = 4u32.to_be_bytes();
    vec![b'X', b0, b1, b2, b3]
}
/// Builds the 8-byte SSL request: length `8` followed by the request code `80877103`,
/// both as big-endian 32-bit integers.
pub fn pack_ssl_request() -> Vec<u8> {
    [8u32, 80877103u32]
        .iter()
        .flat_map(|word| word.to_be_bytes())
        .collect()
}
/// Packs a parameterless query whose Execute step is limited to `max_rows` rows:
/// Parse, Bind, Describe (unnamed portal), Execute, Sync.
///
/// Resets the stored success/error messages, remembers `sql` and advances the
/// statement counter.
pub fn pack_query_portal(&mut self, sql: &str, max_rows: u32) -> Vec<u8> {
    self.success_message = SuccessMessage::default();
    self.error_message = ErrorMessage::default();
    self.sql = sql.to_string();

    let statement = self.statement;
    let mut out = Query::Parse(statement, sql.to_string(), vec![]).vec_u8();
    out.extend(Query::Bind(0, statement, vec![]).vec_u8());
    out.extend(Query::Describe(b'P', 0).vec_u8());
    out.extend(Query::Execute(0, max_rows).vec_u8());
    out.extend(Query::Sync.vec_u8());
    self.statement += 1;
    out
}
/// Packs an Execute (up to `max_rows` rows from the unnamed portal) followed by Sync.
///
/// Resets the stored success/error messages first.
pub fn pack_fetch_more(&mut self, max_rows: u32) -> Vec<u8> {
    self.success_message = SuccessMessage::default();
    self.error_message = ErrorMessage::default();
    let mut out = Query::Execute(0, max_rows).vec_u8();
    out.extend(Query::Sync.vec_u8());
    out
}
/// Packs a Close for the unnamed portal followed by Sync.
///
/// Resets the stored success/error messages first.
pub fn pack_close_portal(&mut self) -> Vec<u8> {
    self.success_message = SuccessMessage::default();
    self.error_message = ErrorMessage::default();
    let mut out = Query::Close(b'P', 0).vec_u8();
    out.extend(Query::Sync.vec_u8());
    out
}
}
/// Protocol versions that can be requested in the startup message.
#[derive(Clone, Debug)]
enum Version {
    /// Protocol 3.0.
    V3,
    /// Protocol 3.1.
    V31,
    /// Protocol 3.2.
    V32,
}
impl Version {
    /// Returns the version code written to the wire: the major version in the
    /// upper 16 bits and the minor version in the lower 16 bits
    /// (`3 << 16 | minor`).
    ///
    /// Takes `&self`: the method only reads the variant, so a mutable borrow is
    /// not required.
    pub fn i32(&self) -> i32 {
        match self {
            Version::V3 => 196_608,
            Version::V31 => 196_609,
            Version::V32 => 196_610,
        }
    }
    /// Maps the numeric version `3.0`, `3.1` or `3.2` to its variant.
    ///
    /// Any other value (including NaN) falls back to [`Version::V3`].
    pub fn into(version: f64) -> Self {
        match version {
            3.0 => Version::V3,
            3.1 => Version::V31,
            3.2 => Version::V32,
            _ => Version::V3,
        }
    }
}
/// Backend message kinds recognized by `Packet::unpack`, keyed by the
/// message's type byte (shown in parentheses).
#[derive(Clone, Debug)]
enum MessageType {
    /// (`1`)
    ParseCompletion,
    /// (`2`)
    BindCompletion,
    /// (`t`)
    ParameterDescription,
    /// (`n`)
    NoData,
    /// (`T`)
    RowDescription,
    /// (`R`)
    Authentication,
    /// (`E`)
    ErrorResponse,
    /// (`S`)
    ParameterStatus,
    /// (`K`)
    BackendKeyData,
    /// (`Z`)
    ReadyForQuery,
    /// (`D`)
    DataRow,
    /// (`C`)
    CommandCompletion,
    /// (`N`)
    Notice,
    /// (`I`)
    EmptyQueryResponse,
    /// (`A`)
    NotificationResponse,
    /// (`3`)
    CloseComplete,
    /// (`s`)
    PortalSuspended,
    /// (`G`)
    CopyInResponse,
    /// (`H`)
    CopyOutResponse,
    /// (`c`)
    CopyDone,
    /// (`d`)
    CopyData,
    /// Any type byte not listed above.
    None,
}
impl MessageType {
    /// Classifies a message by its type byte.
    ///
    /// Unknown bytes are logged with `warn!` and mapped to [`MessageType::None`].
    pub fn form(msg: u8) -> Self {
        match msg {
            // Extended-query acknowledgements.
            b'1' => Self::ParseCompletion,
            b'2' => Self::BindCompletion,
            b'3' => Self::CloseComplete,
            b'n' => Self::NoData,
            b's' => Self::PortalSuspended,
            // Descriptions and result data.
            b't' => Self::ParameterDescription,
            b'T' => Self::RowDescription,
            b'D' => Self::DataRow,
            b'C' => Self::CommandCompletion,
            b'I' => Self::EmptyQueryResponse,
            // Session-level messages.
            b'R' => Self::Authentication,
            b'S' => Self::ParameterStatus,
            b'K' => Self::BackendKeyData,
            b'Z' => Self::ReadyForQuery,
            // Diagnostics and asynchronous messages.
            b'E' => Self::ErrorResponse,
            b'N' => Self::Notice,
            b'A' => Self::NotificationResponse,
            // COPY sub-protocol.
            b'G' => Self::CopyInResponse,
            b'H' => Self::CopyOutResponse,
            b'c' => Self::CopyDone,
            b'd' => Self::CopyData,
            unknown => {
                warn!(
                    "未知msg: {} code: {}",
                    unknown,
                    String::from_utf8_lossy(&[unknown])
                );
                Self::None
            }
        }
    }
}
/// Authentication sub-message kinds, keyed by the 32-bit code that follows the
/// header of an Authentication message (shown in parentheses).
#[derive(Clone, Debug)]
pub enum AuthStatus {
    /// (5)
    Md5Password,
    /// (3)
    CleartextPassword,
    /// (10)
    ScramInitialization,
    /// (11)
    ScramChallenge,
    /// (12)
    ScramComplete,
    /// (0)
    AuthenticationOk,
    /// Any other code.
    None,
}
impl AuthStatus {
    /// Classifies an authentication code.
    ///
    /// Unknown codes are logged with `warn!` and mapped to [`AuthStatus::None`].
    pub fn form(code: u32) -> Self {
        match code {
            0 => AuthStatus::AuthenticationOk,
            3 => AuthStatus::CleartextPassword,
            5 => AuthStatus::Md5Password,
            10 => AuthStatus::ScramInitialization,
            11 => AuthStatus::ScramChallenge,
            12 => AuthStatus::ScramComplete,
            unknown => {
                let code = unknown;
                warn!("未知 AuthStatus: {code}");
                AuthStatus::None
            }
        }
    }
}
/// Generates a 24-character nonce of printable ASCII characters that never
/// contains a comma.
///
/// Each character is drawn from `0x21..0x7e`; a drawn comma (`0x2c`) is
/// replaced by `~` (`0x7e`).
fn generate_nonce() -> String {
    let mut rng = rand::rng();
    let mut nonce = String::with_capacity(24);
    for _ in 0..24 {
        let byte = match rng.random_range(0x21u8..0x7e) {
            0x2c => 0x7e,
            other => other,
        };
        nonce.push(byte as char);
    }
    nonce
}
/// Derives the SCRAM salted password: iterated HMAC-SHA-256 keyed with
/// `password`, seeded with the decoded salt followed by the block index
/// `00 00 00 01`, XOR-ing every round's output together.
///
/// `server_salt` is base64; if it cannot be decoded an empty salt is used.
/// With `iterations` of 0 or 1 only the first round is performed.
fn compute_salted_password(password: &[u8], server_salt: &str, iterations: u32) -> Vec<u8> {
    let salt = STANDARD.decode(server_salt).unwrap_or_default();

    let mut first_round =
        Hmac::<Sha256>::new_from_slice(password).expect("HMAC is able to accept all key sizes");
    first_round.update(&salt);
    first_round.update(&[0, 0, 0, 1]);
    let mut block = first_round.finalize().into_bytes();

    // Running XOR of every round's output.
    let mut accumulated = block;
    for _ in 1..iterations {
        let mut round =
            Hmac::<Sha256>::new_from_slice(password).expect("already checked above");
        round.update(&block);
        block = round.finalize().into_bytes();
        accumulated
            .iter_mut()
            .zip(block.iter())
            .for_each(|(acc, byte)| *acc ^= *byte);
    }
    accumulated.as_slice().to_vec()
}
/// Returns the HMAC-SHA-256 of `data` under `key`.
fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
    let mut keyed = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take a key of any size");
    keyed.update(data);
    let tag = keyed.finalize().into_bytes();
    tag.to_vec()
}
/// Returns the SHA-256 digest of `input`.
fn sha256(input: &[u8]) -> Vec<u8> {
    let mut hasher = Sha256::default();
    hasher.update(input);
    let digest = hasher.finalize_fixed();
    digest.to_vec()
}
/// HMAC instantiated with SHA-256.
type HmacSha256 = Hmac<Sha256>;
/// Computes the base64-encoded SCRAM client proof:
/// `client_key XOR HMAC(stored_key, auth_message)`.
///
/// The XOR covers as many bytes as the shorter of the two inputs provides.
fn compute_client_proof(client_key: &[u8], stored_key: &[u8], auth_message: &str) -> String {
    let mut signer =
        HmacSha256::new_from_slice(stored_key).expect("HMAC can take key of any size");
    signer.update(auth_message.as_bytes());
    let signature = signer.finalize().into_bytes();

    let mut proof = Vec::with_capacity(client_key.len());
    for (key_byte, sig_byte) in client_key.iter().zip(signature.iter()) {
        proof.push(key_byte ^ sig_byte);
    }
    STANDARD.encode(proof)
}
/// Fields collected from an ErrorResponse or Notice message.
#[derive(Clone, Debug, Default)]
pub struct ErrorMessage {
/// Text of the `C` field.
pub code: String,
/// Text of the `M` field.
pub message: String,
/// Text of the `D` or `H` field, whichever was seen last.
pub detail: String,
/// SQL text that was being processed when the `M` or `P` field was recorded.
pub sql: String,
/// Numeric value of the `P` field; `0` when absent or not parseable as `u16`.
pub position: u16,
}
/// Result accumulated while decoding the responses to one request.
#[derive(Clone, Debug)]
pub struct SuccessMessage {
    /// Column descriptions taken from RowDescription.
    pub fields: Vec<Field>,
    /// JSON array with one object per DataRow, keyed by column name.
    pub rows: JsonValue,
    /// Parameter type OIDs taken from ParameterDescription.
    pub param_oids: Vec<u32>,
    /// First number of an `INSERT` command tag.
    pub row_count: usize,
    /// Count parsed from a SELECT/UPDATE/DELETE tag, or the second number of an INSERT tag.
    pub affect_count: usize,
    /// Whether the tracked transaction depth was above zero after the last command tag.
    pub transaction: bool,
    /// Last command tag received.
    pub tag: String,
    /// Set when PortalSuspended was received.
    pub has_more: bool,
}
impl Default for SuccessMessage {
    /// Creates an empty result: no fields, an empty JSON array of rows, zero
    /// counts and all flags cleared.
    fn default() -> Self {
        SuccessMessage {
            fields: Vec::new(),
            rows: array![],
            param_oids: Vec::new(),
            row_count: 0,
            affect_count: 0,
            transaction: false,
            tag: String::default(),
            has_more: false,
        }
    }
}
/// Description of one result column.
#[derive(Clone, Debug)]
pub struct Field {
    /// Column name.
    name: String,
    /// Column index as reported in the row description.
    column_index: u16,
    /// Decoder selected from the column's format code and type OID.
    format_type: FieldFormat,
}
impl Field {
    /// Creates a column description from its parts.
    pub fn new(name: String, column_index: u16, format_type: FieldFormat) -> Self {
        Field {
            name,
            column_index,
            format_type,
        }
    }
    /// Returns `{ "name": ..., "index": ... }` for this column.
    pub fn json(&self) -> JsonValue {
        let column_name = self.name.clone();
        let position = self.column_index;
        object! {
            name: column_name,
            index: position,
        }
    }
}
/// Status decoded from the payload of a ReadyForQuery message.
#[derive(Clone, Debug)]
enum Status {
    /// Payload `I` (idle, per the PostgreSQL protocol documentation).
    I,
    /// Payload `T` (inside a transaction block, per the PostgreSQL protocol documentation).
    T,
    /// Payload `E` (inside a failed transaction block, per the PostgreSQL protocol documentation).
    E,
    /// Payload `S`. NOTE(review): not a documented ReadyForQuery status; confirm intent.
    S,
    /// Anything else, including payloads that are not exactly one byte long.
    None,
}
impl Status {
    /// Maps a ReadyForQuery payload to a status; only the exact one-byte
    /// payloads `I`, `T`, `E` and `S` are recognized.
    pub fn form(code: &[u8]) -> Self {
        match code {
            [b'I'] => Self::I,
            [b'T'] => Self::T,
            [b'E'] => Self::E,
            [b'S'] => Self::S,
            _ => Self::None,
        }
    }
}
/// Frontend messages of the extended-query protocol emitted by this client.
///
/// Numeric identifiers are turned into wire names by [`Query::vec_u8`]:
/// `0` is the unnamed statement/portal, statement `n > 0` is named `s{n}` and
/// portal `n > 0` is named `{n}`.
enum Query {
    /// `Parse(statement, sql, parameter_type_oids)`.
    Parse(u32, String, Vec<u32>),
    /// `Bind(portal, statement, text_parameters)`; `None` is sent as SQL NULL.
    Bind(u8, u32, Vec<Option<String>>),
    /// `Describe(kind, id)`; kind `b'S'` describes a statement, anything else a portal.
    Describe(u8, u32),
    /// `Execute(portal, max_rows)`; `0` rows means no limit.
    Execute(u32, u32),
    /// `Sync`.
    Sync,
    /// `Close(kind, id)`; kind `b'S'` closes a statement, anything else a portal.
    Close(u8, u32),
}
impl Query {
    /// NUL-terminated wire name of statement `id`.
    fn statement_name(id: u32) -> String {
        if id > 0 {
            format!("s{id}\0")
        } else {
            "\0".to_string()
        }
    }
    /// NUL-terminated wire name of portal `id`.
    fn portal_name(id: u32) -> String {
        if id > 0 {
            format!("{id}\0")
        } else {
            "\0".to_string()
        }
    }
    /// Prefixes `payload` with the type byte and the already-computed length field.
    fn frame(tag: u8, declared_len: u32, payload: &[u8]) -> Vec<u8> {
        let mut data = Vec::with_capacity(payload.len() + 5);
        data.push(tag);
        data.extend_from_slice(&declared_len.to_be_bytes());
        data.extend_from_slice(payload);
        data
    }
    /// Length field for `payload`: its size plus the four bytes of the field itself.
    ///
    /// # Panics
    /// Panics if the value does not fit in a `u32`.
    fn declared_len(payload: &[u8]) -> u32 {
        u32::try_from(payload.len() + 4).unwrap()
    }
    /// Serializes the message to its wire form: type byte, big-endian length
    /// (counting itself but not the type byte), payload.
    ///
    /// At most `u16::MAX` parameters/OIDs are written; extra ones are dropped.
    ///
    /// # Panics
    /// Panics if a Parse, Describe, Execute or Close payload exceeds the `u32`
    /// length field (Bind saturates instead).
    pub fn vec_u8(&self) -> Vec<u8> {
        match self {
            Query::Parse(statement, query, param_oids) => {
                let num_params = u16::try_from(param_oids.len()).unwrap_or(u16::MAX);
                let mut payload = Self::statement_name(*statement).into_bytes();
                payload.extend_from_slice(query.as_bytes());
                payload.push(0);
                payload.extend_from_slice(&num_params.to_be_bytes());
                for oid in param_oids.iter().take(usize::from(num_params)) {
                    payload.extend_from_slice(&oid.to_be_bytes());
                }
                Self::frame(b'P', Self::declared_len(&payload), &payload)
            }
            Query::Bind(portal, statement, params) => {
                let mut payload = Self::portal_name(u32::from(*portal)).into_bytes();
                payload.extend_from_slice(Self::statement_name(*statement).as_bytes());
                if params.is_empty() {
                    // No parameter format codes, no parameters.
                    payload.extend_from_slice(&0u16.to_be_bytes());
                    payload.extend_from_slice(&0u16.to_be_bytes());
                } else {
                    // One format code (0 = text) applied to every parameter.
                    payload.extend_from_slice(&1u16.to_be_bytes());
                    payload.extend_from_slice(&0u16.to_be_bytes());
                    let param_count = u16::try_from(params.len()).unwrap_or(u16::MAX);
                    payload.extend_from_slice(&param_count.to_be_bytes());
                    for param in params.iter().take(usize::from(param_count)) {
                        match param {
                            Some(value) => {
                                let value_bytes = value.as_bytes();
                                let value_len =
                                    i32::try_from(value_bytes.len()).unwrap_or(i32::MAX);
                                payload.extend_from_slice(&value_len.to_be_bytes());
                                payload.extend_from_slice(value_bytes);
                            }
                            // Length -1 marks NULL.
                            None => payload.extend_from_slice(&(-1i32).to_be_bytes()),
                        }
                    }
                }
                // No result-column format codes.
                payload.extend_from_slice(&0u16.to_be_bytes());
                let declared = u32::try_from(payload.len() + 4).unwrap_or(u32::MAX);
                Self::frame(b'B', declared, &payload)
            }
            Query::Describe(describe_type, name) => {
                let (kind, target) = if *describe_type == b'S' {
                    (b'S', Self::statement_name(*name))
                } else {
                    (b'P', Self::portal_name(*name))
                };
                let mut payload = vec![kind];
                payload.extend_from_slice(target.as_bytes());
                Self::frame(b'D', Self::declared_len(&payload), &payload)
            }
            Query::Execute(portal, returns) => {
                let mut payload = Self::portal_name(*portal).into_bytes();
                payload.extend_from_slice(&returns.to_be_bytes());
                Self::frame(b'E', Self::declared_len(&payload), &payload)
            }
            Query::Sync => Self::frame(b'S', 4, &[]),
            Query::Close(close_type, name) => {
                // Use the same naming scheme as Bind/Execute/Describe so that a
                // portal is addressed as `{n}` and only a statement as `s{n}`.
                let target = if *close_type == b'S' {
                    Self::statement_name(*name)
                } else {
                    Self::portal_name(*name)
                };
                let mut payload = vec![*close_type];
                payload.extend_from_slice(target.as_bytes());
                Self::frame(b'C', Self::declared_len(&payload), &payload)
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use json::object;
/// Builds a `Config` with fixed test credentials and the given debug flag.
fn test_config(debug: bool) -> Config {
    let settings = object! {
        debug: debug,
        username: "test_user",
        userpass: "test_pass",
        database: "test_db",
        hostname: "localhost",
        hostport: 5432,
        charset: "utf8mb4",
        pool_max: 5,
    };
    Config::new(&settings)
}
/// Builds a `Packet` from the test configuration with a fixed, predictable client nonce.
fn test_packet(debug: bool) -> Packet {
    let config = test_config(debug);
    let mut packet = Packet::new(config);
    packet.client_nonce = String::from("test_nonce");
    packet
}
/// Frames `payload` as a backend message: type byte, big-endian length
/// (payload size plus 4), payload.
fn build_msg(type_byte: u8, payload: &[u8]) -> Vec<u8> {
    let declared = u32::try_from(payload.len()).unwrap() + 4;
    let mut framed = vec![type_byte];
    framed.extend_from_slice(&declared.to_be_bytes());
    framed.extend_from_slice(payload);
    framed
}
/// Frames `payload` with an explicitly chosen length field, so tests can
/// declare a length that disagrees with the actual payload size.
fn build_msg_with_len(type_byte: u8, length_with_header: u32, payload: &[u8]) -> Vec<u8> {
    let mut framed = vec![type_byte];
    framed.extend_from_slice(&length_with_header.to_be_bytes());
    framed.extend_from_slice(payload);
    framed
}
/// Frames an Authentication (`R`) message: the 32-bit `auth_type` code followed by `payload`.
fn build_auth(auth_type: u32, payload: &[u8]) -> Vec<u8> {
    let mut body = auth_type.to_be_bytes().to_vec();
    body.extend_from_slice(payload);
    build_msg(b'R', &body)
}
/// Encodes one error/notice field: the code byte, the value text, and a NUL terminator.
fn build_field(code: u8, value: &str) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(value.len() + 2);
    encoded.push(code);
    encoded.extend(value.bytes());
    encoded.push(0);
    encoded
}
/// Asserts that `result` is a `PgsqlError::Protocol` whose message contains `expected_substr`.
///
/// # Panics
/// Panics when `result` is anything else or the message does not match.
fn assert_protocol_error(result: Result<SuccessMessage, PgsqlError>, expected_substr: &str) {
    if let Err(PgsqlError::Protocol(msg)) = &result {
        assert!(
            msg.contains(expected_substr),
            "expected protocol error containing '{expected_substr}', got '{msg}'"
        );
    } else {
        panic!("expected protocol error, got {result:?}");
    }
}
/// Decodes a hexadecimal string (two digits per byte) into bytes.
///
/// # Panics
/// Panics if the length is odd or a pair is not valid hexadecimal.
fn hex_to_bytes(hex: &str) -> Vec<u8> {
    assert_eq!(hex.len() % 2, 0);
    let mut decoded = Vec::with_capacity(hex.len() / 2);
    for pair in hex.as_bytes().chunks(2) {
        let digits = std::str::from_utf8(pair).unwrap();
        decoded.push(u8::from_str_radix(digits, 16).unwrap());
    }
    decoded
}
/// A freshly constructed `Packet` starts with empty/zeroed state and a
/// generated 24-character client nonce that contains no comma.
#[test]
fn packet_new_defaults() {
    let fresh = Packet::new(test_config(false));

    // Startup-related state.
    assert!(matches!(fresh.version, Version::V3));
    assert!(fresh.params.is_empty());

    // Authentication state.
    assert!(fresh.auth_mechanism.is_empty());
    assert_eq!(fresh.client_nonce.len(), 24);
    assert!(!fresh.client_nonce.contains(','));
    assert!(fresh.server_nonce.is_empty());
    assert!(fresh.server_salt.is_empty());
    assert!(fresh.md5_salt.is_empty());
    assert_eq!(fresh.iterations, 0);

    // Result and error accumulators.
    assert!(fresh.error_message.code.is_empty());
    assert!(fresh.success_message.fields.is_empty());
    assert!(fresh.success_message.param_oids.is_empty());

    // Session bookkeeping.
    assert_eq!(fresh.process_id, 0);
    assert_eq!(fresh.secret_key, 0);
    assert!(matches!(fresh.status, Status::None));
    assert!(fresh.sql.is_empty());
    assert!(fresh.tag.is_empty());
    assert_eq!(fresh.statement, 0);
    assert_eq!(fresh.transactions, 0);
}
/// `set_version` maps 3.1 and 3.2 to their own codes and an unrecognised
/// value (9.9) to the 3.0 code; `set_params` stores the key/value pair.
#[test]
fn packet_set_version_and_params() {
    let mut packet = test_packet(false);

    let cases = [(3.1, 196_609), (3.2, 196_610), (9.9, 196_608)];
    for &(requested, wire_code) in cases.iter() {
        packet.set_version(requested);
        assert_eq!(packet.version.i32(), wire_code);
    }

    packet.set_params("application_name", "br-pgsql");
    let stored = packet.params.get("application_name").map(String::as_str);
    assert_eq!(stored, Some("br-pgsql"));
}
/// The startup message begins with its own total length and the 3.0 version
/// code, ends with a NUL, and carries the encoding, user and database pairs.
#[test]
fn pack_first_builds_startup_message() {
    let mut packet = test_packet(false);
    let startup = packet.pack_first();

    let declared_len = u32::from_be_bytes(startup[..4].try_into().unwrap()) as usize;
    assert_eq!(declared_len, startup.len());

    let version_code = i32::from_be_bytes(startup[4..8].try_into().unwrap());
    assert_eq!(version_code, 196_608);

    assert!(startup.ends_with(&[0]));

    // Each NUL-terminated key/value pair must appear somewhere in the message.
    let expected_pairs: [&[u8]; 3] = [
        b"client_encoding\0UTF8\0",
        b"user\0test_user\0",
        b"database\0test_db\0",
    ];
    for needle in expected_pairs.iter() {
        assert!(startup.windows(needle.len()).any(|w| w == *needle));
    }
}
/// `pack_auth` emits a `'p'` message containing the mechanism name, a NUL,
/// the big-endian length of the client-first data, and that data.
#[test]
fn pack_auth_builds_sasl_initial_response() {
    let mut packet = test_packet(false);
    packet.auth_mechanism = String::from("SCRAM-SHA-256");
    let wire = packet.pack_auth();

    let auth_data = "n,,n=,r=test_nonce";
    let mut expected_payload = b"SCRAM-SHA-256\0".to_vec();
    expected_payload.extend_from_slice(&i32::try_from(auth_data.len()).unwrap().to_be_bytes());
    expected_payload.extend_from_slice(auth_data.as_bytes());

    assert_eq!(wire[0], b'p');
    let declared_len = u32::from_be_bytes(wire[1..5].try_into().unwrap()) as usize;
    assert_eq!(declared_len, expected_payload.len() + 4);
    assert_eq!(&wire[5..], expected_payload.as_slice());
}
/// `pack_auth_verify` emits a `'p'` message whose payload is the
/// client-final text ending in a proof recomputed here from the same inputs.
#[test]
fn pack_auth_verify_builds_scram_verify_message() {
    let mut packet = test_packet(false);
    packet.server_nonce = String::from("server_nonce");
    packet.server_salt = String::from("c2FsdA==");
    packet.iterations = 2;
    let wire = packet.pack_auth_verify();

    assert_eq!(wire[0], b'p');
    let declared_len = u32::from_be_bytes(wire[1..5].try_into().unwrap()) as usize;
    assert_eq!(declared_len, wire.len() - 1);

    // Recompute the proof independently with the module's own primitives.
    let salted = compute_salted_password(
        packet.config.userpass.as_bytes(),
        packet.server_salt.as_str(),
        2,
    );
    let client_key = hmac_sha256(&salted, b"Client Key");
    let stored_key = sha256(&client_key);
    let auth_message =
        "n=,r=test_nonce,r=test_nonceserver_nonce,s=c2FsdA==,i=2,c=biws,r=test_nonceserver_nonce";
    let expected_proof = compute_client_proof(&client_key, &stored_key, auth_message);

    let expected_payload = format!("c=biws,r=test_nonceserver_nonce,p={expected_proof}");
    assert_eq!(&wire[5..], expected_payload.as_bytes());
}
/// `pack_md5_password` emits `'p'`, a length, `"md5"` + hex digest, and a NUL,
/// where the digest is md5(hex(md5(password + username)) + salt).
#[test]
fn pack_md5_password_builds_password_message() {
    let mut packet = test_packet(false);
    packet.md5_salt = vec![1, 2, 3, 4];

    // Inner digest over password followed by username, rendered as lowercase hex.
    let credentials = format!("{}{}", packet.config.userpass, packet.config.username);
    let inner = format!("{:x}", md5::compute(credentials.as_bytes()));

    // Outer digest over the hex text followed by the raw salt bytes.
    let mut salted = inner.into_bytes();
    salted.extend_from_slice(&packet.md5_salt);
    let hash = format!("md5{:x}", md5::compute(&salted));

    let wire = packet.pack_md5_password();
    assert_eq!(wire[0], b'p');
    let declared_len = u32::from_be_bytes(wire[1..5].try_into().unwrap()) as usize;
    assert_eq!(declared_len, hash.len() + 5);
    assert_eq!(&wire[5..wire.len() - 1], hash.as_bytes());
    assert_eq!(wire.last().copied(), Some(0));
}
/// `pack_cleartext_password` emits `'p'`, a length of password + 5, the raw
/// password bytes, and a terminating NUL.
#[test]
fn pack_cleartext_password_builds_password_message() {
    let packet = test_packet(false);
    let wire = packet.pack_cleartext_password();
    let password = packet.config.userpass.as_bytes();

    assert_eq!(wire[0], b'p');
    let declared_len = u32::from_be_bytes(wire[1..5].try_into().unwrap()) as usize;
    assert_eq!(declared_len, packet.config.userpass.len() + 5);
    assert_eq!(&wire[5..wire.len() - 1], password);
    assert_eq!(wire.last().copied(), Some(0));
}
/// `pack_query` concatenates Parse, Bind, Describe('P'), Execute and Sync,
/// records the SQL, clears stale result/error state and bumps `statement`.
#[test]
fn pack_query_builds_parse_bind_describe_execute_sync() {
    let mut packet = test_packet(false);
    // Seed stale state that the call is expected to reset.
    packet.success_message.affect_count = 9;
    packet.error_message.message = String::from("old-error");

    let sql = "SELECT 1";
    let wire = packet.pack_query(sql);

    let expected = [
        Query::Parse(0, sql.to_string(), vec![]).vec_u8(),
        Query::Bind(0, 0, vec![]).vec_u8(),
        Query::Describe(b'P', 0).vec_u8(),
        Query::Execute(0, 0).vec_u8(),
        Query::Sync.vec_u8(),
    ]
    .concat();

    assert_eq!(wire, expected);
    assert_eq!(packet.sql, sql);
    assert_eq!(packet.success_message.affect_count, 0);
    assert!(packet.error_message.message.is_empty());
    assert_eq!(packet.statement, 1);
}
/// `pack_execute` produces the same five-message sequence as `pack_query`
/// (with Describe('P')) and performs the same state resets.
#[test]
fn pack_execute_builds_parse_bind_describe_execute_sync() {
    let mut packet = test_packet(false);
    // Seed stale state that the call is expected to reset.
    packet.success_message.affect_count = 9;
    packet.error_message.message = String::from("old-error");

    let sql = "UPDATE t SET c = 1";
    let wire = packet.pack_execute(sql);

    let expected = [
        Query::Parse(0, sql.to_string(), vec![]).vec_u8(),
        Query::Bind(0, 0, vec![]).vec_u8(),
        Query::Describe(b'P', 0).vec_u8(),
        Query::Execute(0, 0).vec_u8(),
        Query::Sync.vec_u8(),
    ]
    .concat();

    assert_eq!(wire, expected);
    assert_eq!(packet.sql, sql);
    assert_eq!(packet.success_message.affect_count, 0);
    assert!(packet.error_message.message.is_empty());
    assert_eq!(packet.statement, 1);
}
/// `pack_query_params` sends one zero parameter OID in Parse, the text value
/// in Bind, and uses Describe('S') instead of Describe('P').
#[test]
fn pack_query_params_builds_correct_message() {
    let mut packet = test_packet(false);
    // Seed stale state that the call is expected to reset.
    packet.success_message.affect_count = 9;
    packet.error_message.message = String::from("old-error");

    let sql = "SELECT $1";
    let wire = packet.pack_query_params(sql, &[Some("test")]);

    let expected = [
        Query::Parse(0, sql.to_string(), vec![0u32]).vec_u8(),
        Query::Bind(0, 0, vec![Some("test".to_string())]).vec_u8(),
        Query::Describe(b'S', 0).vec_u8(),
        Query::Execute(0, 0).vec_u8(),
        Query::Sync.vec_u8(),
    ]
    .concat();

    assert_eq!(wire, expected);
    assert_eq!(packet.sql, sql);
    assert_eq!(packet.success_message.affect_count, 0);
    assert!(packet.error_message.message.is_empty());
    assert_eq!(packet.statement, 1);
}
/// `pack_execute_params` mirrors `pack_query_params`: one zero parameter OID,
/// the text value bound, and Describe('S').
#[test]
fn pack_execute_params_builds_correct_message() {
    let mut packet = test_packet(false);
    // Seed stale state that the call is expected to reset.
    packet.success_message.affect_count = 9;
    packet.error_message.message = String::from("old-error");

    let sql = "UPDATE t SET c = $1";
    let wire = packet.pack_execute_params(sql, &[Some("11")]);

    let expected = [
        Query::Parse(0, sql.to_string(), vec![0u32]).vec_u8(),
        Query::Bind(0, 0, vec![Some("11".to_string())]).vec_u8(),
        Query::Describe(b'S', 0).vec_u8(),
        Query::Execute(0, 0).vec_u8(),
        Query::Sync.vec_u8(),
    ]
    .concat();

    assert_eq!(wire, expected);
    assert_eq!(packet.sql, sql);
    assert_eq!(packet.success_message.affect_count, 0);
    assert!(packet.error_message.message.is_empty());
    assert_eq!(packet.statement, 1);
}
/// Every `pack_*` query builder advances `statement` by exactly one.
#[test]
fn statement_counter_increments() {
    let mut packet = test_packet(false);
    assert_eq!(packet.statement, 0);

    // Built messages are irrelevant here; only the counter is observed.
    let _ = packet.pack_query("SELECT 1");
    assert_eq!(packet.statement, 1);

    let _ = packet.pack_execute("UPDATE t SET x = 1");
    assert_eq!(packet.statement, 2);

    let _ = packet.pack_query_params("SELECT $1", &[Some("1")]);
    assert_eq!(packet.statement, 3);

    let _ = packet.pack_execute_params("UPDATE t SET x = $1", &[Some("1")]);
    assert_eq!(packet.statement, 4);
}
/// The terminate message is `'X'` followed by a big-endian length of 4.
#[test]
fn pack_terminate_builds_expected_message() {
    let expected = vec![b'X', 0, 0, 0, 4];
    assert_eq!(Packet::pack_terminate(), expected);
}
/// `text` returns plain bytes unchanged, yields an empty string for empty
/// input, and drops a trailing NUL.
#[test]
fn text_handles_normal_empty_and_trailing_zero() {
    let packet = test_packet(false);
    let plain = packet.text(b"hello");
    assert_eq!(plain, "hello");
    let empty = packet.text(b"");
    assert_eq!(empty, "");
    let nul_terminated = packet.text(b"world\0");
    assert_eq!(nul_terminated, "world");
}
/// An authentication message with type 0 and no payload unpacks without error.
#[test]
fn unpack_authentication_ok_continue() {
    let mut packet = test_packet(false);
    let auth_ok = build_auth(0, &[]);
    let outcome = packet.unpack(auth_ok, 0);
    assert!(outcome.is_ok());
}
/// An authentication message with type 3 (cleartext password) unpacks
/// without error.
#[test]
fn unpack_auth_cleartext_password() {
    let mut packet = test_packet(false);
    let cleartext_request = build_auth(3, &[]);
    let outcome = packet.unpack(cleartext_request, 0);
    assert!(outcome.is_ok());
}
/// An authentication message with type 5 stores its four payload bytes in
/// `md5_salt`.
#[test]
fn unpack_auth_md5_password_sets_salt() {
    let salt = [0xaa, 0xbb, 0xcc, 0xdd];
    let mut packet = test_packet(false);
    packet.unpack(build_auth(5, &salt), 0).unwrap();
    assert_eq!(packet.md5_salt, vec![0xaa, 0xbb, 0xcc, 0xdd]);
}
/// An authentication message with type 10 records the advertised mechanism
/// name in `auth_mechanism`.
#[test]
fn unpack_auth_scram_initialization_sets_mechanism() {
    let mut packet = test_packet(false);
    let advertisement = build_auth(10, b"SCRAM-SHA-256\0\0");
    packet.unpack(advertisement, 0).unwrap();
    assert_eq!(packet.auth_mechanism, "SCRAM-SHA-256");
}
/// An authentication message with type 11 yields the server part of the
/// nonce, the salt, and the iteration count.
#[test]
fn unpack_auth_scram_challenge_parses_nonce_salt_iterations() {
    let mut packet = test_packet(false);
    // The combined nonce starts with the fixture's client nonce, "test_nonce".
    let challenge = build_auth(11, b"r=test_nonceSERVER,s=c2FsdA==,i=4096");
    packet.unpack(challenge, 0).unwrap();

    assert_eq!(packet.server_nonce, "SERVER");
    assert_eq!(packet.server_salt, "c2FsdA==");
    assert_eq!(packet.iterations, 4096);
}
/// An authentication message with type 12 stores the value after `v=` in
/// `server_proof`.
#[test]
fn unpack_auth_scram_complete_parses_server_proof() {
    let mut packet = test_packet(false);
    let server_final = build_auth(12, b"v=proof_value");
    packet.unpack(server_final, 0).unwrap();
    assert_eq!(packet.server_proof, "proof_value");
}
/// An unrecognised authentication type (999) yields a `PgsqlError::Query`
/// whose fields are all empty or zero.
#[test]
fn unpack_auth_unknown_type_returns_query_error() {
    let mut packet = test_packet(false);
    let outcome = packet.unpack(build_auth(999, &[]), 0);
    match outcome {
        Err(PgsqlError::Query {
            code,
            message,
            detail,
            sql,
            position,
        }) => {
            assert!(code.is_empty());
            assert!(message.is_empty());
            assert!(detail.is_empty());
            assert!(sql.is_empty());
            assert_eq!(position, 0);
        }
        other => panic!("expected query error, got {other:?}"),
    }
}
/// A SCRAM challenge lacking the `i=` part is rejected as a protocol error.
#[test]
fn unpack_auth_scram_challenge_invalid_format() {
    let mut packet = test_packet(false);
    let truncated_challenge = build_auth(11, b"r=test_nonce,s=c2FsdA==");
    let outcome = packet.unpack(truncated_challenge, 0);
    assert_protocol_error(outcome, "SCRAM challenge invalid format");
}
/// A one-byte SCRAM initialization payload is rejected as too short.
#[test]
fn unpack_auth_scram_initialization_too_short() {
    let mut packet = test_packet(false);
    let stub = build_auth(10, b"x");
    let outcome = packet.unpack(stub, 0);
    assert_protocol_error(outcome, "SCRAM initialization message too short");
}
/// An MD5 authentication request carrying only three salt bytes is rejected.
#[test]
fn unpack_auth_md5_salt_too_short() {
    let mut packet = test_packet(false);
    let short_salt = build_auth(5, &[1, 2, 3]);
    let outcome = packet.unpack(short_salt, 0);
    assert_protocol_error(outcome, "MD5 salt too short");
}
/// A SCRAM-complete message whose declared length (20) exceeds the bytes
/// actually present is rejected as too short.
#[test]
fn unpack_auth_scram_complete_too_short() {
    let mut packet = test_packet(false);

    // Body: auth type 12 followed by only three bytes of server-final data.
    let mut body = 12u32.to_be_bytes().to_vec();
    body.extend_from_slice(b"v=x");
    let malformed = build_msg_with_len(b'R', 20, &body);

    let outcome = packet.unpack(malformed, 0);
    assert_protocol_error(outcome, "SCRAM complete message too short");
}
/// An `'E'` message with code, message, detail, hint and position fields
/// becomes a `PgsqlError::Query` that also carries the packet's current SQL.
#[test]
fn unpack_error_response_with_all_fields() {
    let mut packet = test_packet(false);
    packet.sql = String::from("SELECT * FROM t");

    let fields = [
        (b'C', "23505"),
        (b'M', "duplicate key value"),
        (b'D', "detail text"),
        (b'H', "hint text"),
        (b'P', "12"),
    ];
    let mut payload: Vec<u8> = fields
        .iter()
        .flat_map(|&(code, value)| build_field(code, value))
        .collect();
    payload.push(0);

    match packet.unpack(build_msg(b'E', &payload), 0) {
        Err(PgsqlError::Query {
            code,
            message,
            detail,
            sql,
            position,
        }) => {
            assert_eq!(code, "23505");
            assert_eq!(message, "duplicate key value");
            // NOTE(review): `detail` is expected to hold the 'H' value rather
            // than the 'D' value sent before it; confirm this is intended.
            assert_eq!(detail, "hint text");
            assert_eq!(sql, "SELECT * FROM t");
            assert_eq!(position, 12);
        }
        other => panic!("expected query error, got {other:?}"),
    }
}
/// A leading field with an unrecognised code (`'X'`) does not prevent the
/// code, message and position fields that follow from being parsed.
#[test]
fn unpack_error_response_skips_unknown_fields() {
    let mut packet = test_packet(false);
    packet.sql = String::from("SELECT broken");

    let fields = [
        (b'X', "ignored"),
        (b'C', "42601"),
        (b'M', "syntax error"),
        (b'P', "2"),
    ];
    let mut payload: Vec<u8> = fields
        .iter()
        .flat_map(|&(code, value)| build_field(code, value))
        .collect();
    payload.push(0);

    match packet.unpack(build_msg(b'E', &payload), 0) {
        Err(PgsqlError::Query {
            code,
            message,
            position,
            ..
        }) => {
            assert_eq!(code, "42601");
            assert_eq!(message, "syntax error");
            assert_eq!(position, 2);
        }
        other => panic!("expected query error, got {other:?}"),
    }
}
/// With `debug` enabled in the config, an `'E'` message still produces a
/// `PgsqlError::Query` carrying the code and message.
#[test]
fn unpack_error_response_in_debug_mode() {
    let mut packet = test_packet(true);
    packet.sql = String::from("SELECT debug");

    let mut payload = build_field(b'C', "22000");
    payload.extend(build_field(b'M', "debug error"));
    payload.push(0);

    let outcome = packet.unpack(build_msg(b'E', &payload), 0);
    match outcome {
        Err(PgsqlError::Query { code, message, .. }) => {
            assert_eq!(code, "22000");
            assert_eq!(message, "debug error");
        }
        other => panic!("expected query error, got {other:?}"),
    }
}
/// An `'N'` message is not an error: `unpack` succeeds, and the notice's
/// fields are recorded on `error_message`.
#[test]
fn unpack_notice_with_fields_continues() {
    let mut packet = test_packet(false);
    packet.sql = String::from("SELECT notice");

    let fields = [
        (b'C', "01000"),
        (b'M', "notice message"),
        (b'D', "notice detail"),
        (b'P', "99"),
    ];
    let mut payload: Vec<u8> = fields
        .iter()
        .flat_map(|&(code, value)| build_field(code, value))
        .collect();
    payload.push(0);

    let result = packet.unpack(build_msg(b'N', &payload), 0).unwrap();
    assert!(result.fields.is_empty());

    let recorded = &packet.error_message;
    assert_eq!(recorded.code, "01000");
    assert_eq!(recorded.message, "notice message");
    assert_eq!(recorded.detail, "notice detail");
    assert_eq!(recorded.position, 99);
}
/// An `'S'` message with a NUL-terminated name and value is stored in
/// `service_params`.
#[test]
fn unpack_parameter_status_sets_service_params() {
    let mut packet = test_packet(false);
    let status = build_msg(b'S', b"server_version\x0016\0");
    packet.unpack(status, 0).unwrap();

    let stored = packet
        .service_params
        .get("server_version")
        .map(String::as_str);
    assert_eq!(stored, Some("16"));
}
/// A `'K'` message with two big-endian `u32` values populates `process_id`
/// and `secret_key`, in that order.
#[test]
fn unpack_backend_key_data_sets_process_and_secret() {
    let mut packet = test_packet(false);

    let mut key_data = 1234u32.to_be_bytes().to_vec();
    key_data.extend_from_slice(&5678u32.to_be_bytes());
    packet.unpack(build_msg(b'K', &key_data), 0).unwrap();

    assert_eq!(packet.process_id, 1234);
    assert_eq!(packet.secret_key, 5678);
}
/// A `'Z'` message sets `status` according to its single status byte
/// (`I`, `T` or `E`). Each case uses its own fresh packet.
#[test]
fn unpack_ready_for_query_status_values() {
    let mut idle = test_packet(false);
    idle.unpack(build_msg(b'Z', b"I"), 0).unwrap();
    assert!(matches!(idle.status, Status::I));

    let mut in_transaction = test_packet(false);
    in_transaction.unpack(build_msg(b'Z', b"T"), 0).unwrap();
    assert!(matches!(in_transaction.status, Status::T));

    let mut failed = test_packet(false);
    failed.unpack(build_msg(b'Z', b"E"), 0).unwrap();
    assert!(matches!(failed.status, Status::E));
}
/// A `'t'` message with an arbitrary four-byte payload unpacks without error.
#[test]
fn unpack_parameter_description_skips_payload() {
    let mut packet = test_packet(false);
    let description = build_msg(b't', &[0, 1, 0, 2]);
    assert!(packet.unpack(description, 0).is_ok());
}
/// A stream of empty `'1'`, `'2'`, `'n'`, `'3'`, `'s'` and `'c'` messages is
/// consumed in one `unpack` call without error.
#[test]
fn unpack_simple_completion_message_types_continue() {
    let mut packet = test_packet(false);
    let stream: Vec<u8> = [b'1', b'2', b'n', b'3', b's', b'c']
        .iter()
        .flat_map(|&kind| build_msg(kind, &[]))
        .collect();
    assert!(packet.unpack(stream, 0).is_ok());
}
/// An empty `'I'` message unpacks without error.
#[test]
fn unpack_empty_query_response_returns_ok() {
    let mut packet = test_packet(false);
    let empty_query = build_msg(b'I', &[]);
    assert!(packet.unpack(empty_query, 0).is_ok());
}
/// An `'A'` message with an arbitrary payload unpacks without error.
#[test]
fn unpack_notification_response_skips_payload() {
    let mut packet = test_packet(false);
    let notification = build_msg(b'A', b"notification");
    assert!(packet.unpack(notification, 0).is_ok());
}
/// Consecutive `'G'`, `'H'` and `'d'` messages with arbitrary payloads are
/// consumed in one `unpack` call without error.
#[test]
fn unpack_copy_messages_skip_payload() {
    let mut packet = test_packet(false);
    let stream = [
        build_msg(b'G', &[1, 2]),
        build_msg(b'H', &[3, 4]),
        build_msg(b'd', &[5, 6, 7]),
    ]
    .concat();
    assert!(packet.unpack(stream, 0).is_ok());
}
/// An unrecognised type byte (`'?'`, decimal 63) yields a protocol error
/// whose message mentions "63".
#[test]
fn unpack_unknown_message_type_returns_error() {
    let mut packet = test_packet(false);
    let outcome = packet.unpack(build_msg(b'?', b"abc"), 0);
    match outcome {
        Err(PgsqlError::Protocol(msg)) => assert!(msg.contains("63")),
        other => panic!("expected protocol error, got {other:?}"),
    }
}
/// Builds a payload describing exactly one column.
///
/// Layout, all big-endian: `u16` field count (1), NUL-terminated `name`,
/// a zero `u32`, `column_index`, `type_oid`, a zero `i16`, a zero `i32`,
/// and `format`.
fn row_description_payload(
    name: &str,
    column_index: u16,
    type_oid: u32,
    format: u16,
) -> Vec<u8> {
    let mut bytes = 1u16.to_be_bytes().to_vec();
    bytes.extend(name.bytes());
    bytes.push(0);
    // NOTE(review): presumably the table-OID slot; always zero here.
    bytes.extend_from_slice(&0u32.to_be_bytes());
    bytes.extend_from_slice(&column_index.to_be_bytes());
    bytes.extend_from_slice(&type_oid.to_be_bytes());
    // NOTE(review): presumably type size and type modifier; always zero here.
    bytes.extend_from_slice(&0i16.to_be_bytes());
    bytes.extend_from_slice(&0i32.to_be_bytes());
    bytes.extend_from_slice(&format.to_be_bytes());
    bytes
}
/// A `'T'` message describing an `id` column of type OID 23, followed by a
/// `'D'` message with the text value "7", yields one Int field and one row.
#[test]
fn unpack_row_description_and_data_row() {
    let mut packet = test_packet(false);

    // Data row: one column, value length 1, value "7".
    let mut row_bytes = 1u16.to_be_bytes().to_vec();
    row_bytes.extend_from_slice(&1u32.to_be_bytes());
    row_bytes.push(b'7');

    let mut stream = build_msg(b'T', &row_description_payload("id", 1, 23, 0));
    stream.extend(build_msg(b'D', &row_bytes));

    let result = packet.unpack(stream, 0).unwrap();
    assert_eq!(result.fields.len(), 1);
    let column = &result.fields[0];
    assert_eq!(column.name, "id");
    assert!(matches!(column.format_type, FieldFormat::Int));
    assert_eq!(result.rows.len(), 1);
    assert_eq!(result.rows[0]["id"].as_i32(), Some(7));
}
/// A data row whose column length is `0xFFFF_FFFF` (the NULL marker) still
/// yields one row.
#[test]
fn unpack_data_row_null_column() {
    let mut packet = test_packet(false);

    // Data row: one column whose length field is all ones, with no value bytes.
    let mut row_bytes = 1u16.to_be_bytes().to_vec();
    row_bytes.extend_from_slice(&0xFFFF_FFFFu32.to_be_bytes());

    let mut stream = build_msg(b'T', &row_description_payload("id", 1, 23, 0));
    stream.extend(build_msg(b'D', &row_bytes));

    let result = packet.unpack(stream, 0).unwrap();
    assert_eq!(result.rows.len(), 1);
    // NOTE(review): the NULL column is expected to compare equal to 0;
    // confirm this matches the intended NULL handling for Int columns.
    assert_eq!(result.rows[0]["id"], 0);
}
/// Zero-field `'T'` and `'D'` messages produce no fields and a single row
/// that is an (empty) JSON object.
#[test]
fn unpack_empty_field_count_for_row_description_and_data_row() {
    let mut packet = test_packet(false);
    let no_fields = 0u16.to_be_bytes();
    let stream = [build_msg(b'T', &no_fields), build_msg(b'D', &no_fields)].concat();

    let result = packet.unpack(stream, 0).unwrap();
    assert!(result.fields.is_empty());
    assert_eq!(result.rows.len(), 1);
    assert!(result.rows[0].is_object());
}
/// A `'C'` message tagged "SELECT 5" reports the tag and an affect count of 5.
#[test]
fn unpack_command_completion_select() {
    let mut packet = test_packet(false);
    let completion = build_msg(b'C', b"SELECT 5\0");
    let result = packet.unpack(completion, 0).unwrap();
    assert_eq!(result.tag, "SELECT 5");
    assert_eq!(result.affect_count, 5);
}
/// A `'C'` message tagged "UPDATE 3" reports the tag and an affect count of 3.
#[test]
fn unpack_command_completion_update() {
    let mut packet = test_packet(false);
    let completion = build_msg(b'C', b"UPDATE 3\0");
    let result = packet.unpack(completion, 0).unwrap();
    assert_eq!(result.tag, "UPDATE 3");
    assert_eq!(result.affect_count, 3);
}
/// A `'C'` message tagged "INSERT 0 1" reports a row count of 0 and an
/// affect count of 1.
#[test]
fn unpack_command_completion_insert() {
    let mut packet = test_packet(false);
    let completion = build_msg(b'C', b"INSERT 0 1\0");
    let result = packet.unpack(completion, 0).unwrap();
    assert_eq!(result.tag, "INSERT 0 1");
    assert_eq!(result.row_count, 0);
    assert_eq!(result.affect_count, 1);
}
/// A `'C'` message tagged "DELETE 2" reports the tag and an affect count of 2.
#[test]
fn unpack_command_completion_delete() {
    let mut packet = test_packet(false);
    let completion = build_msg(b'C', b"DELETE 2\0");
    let result = packet.unpack(completion, 0).unwrap();
    assert_eq!(result.tag, "DELETE 2");
    assert_eq!(result.affect_count, 2);
}
/// "BEGIN" marks the result as transactional and raises `transactions` to 1;
/// a following "COMMIT" clears the flag and returns the counter to 0.
#[test]
fn unpack_command_completion_begin_and_commit() {
    let mut packet = test_packet(false);

    let opened = packet.unpack(build_msg(b'C', b"BEGIN\0"), 0).unwrap();
    assert!(opened.transaction);
    assert_eq!(packet.transactions, 1);

    let closed = packet.unpack(build_msg(b'C', b"COMMIT\0"), 0).unwrap();
    assert!(!closed.transaction);
    assert_eq!(packet.transactions, 0);
}
/// "START TRANSACTION" behaves like "BEGIN", and "ROLLBACK" like "COMMIT",
/// with respect to the transaction flag and counter.
#[test]
fn unpack_command_completion_start_and_rollback() {
    let mut packet = test_packet(false);

    let start_msg = build_msg(b'C', b"START TRANSACTION\0");
    let opened = packet.unpack(start_msg, 0).unwrap();
    assert!(opened.transaction);
    assert_eq!(packet.transactions, 1);

    let rollback_msg = build_msg(b'C', b"ROLLBACK\0");
    let closed = packet.unpack(rollback_msg, 0).unwrap();
    assert!(!closed.transaction);
    assert_eq!(packet.transactions, 0);
}
/// With `debug` enabled, a `'C'` message tagged "MERGE 1" keeps the tag but
/// leaves the affect count at 0.
#[test]
fn unpack_command_completion_unknown_tag_in_debug_mode() {
    let mut packet = test_packet(true);
    let completion = build_msg(b'C', b"MERGE 1\0");
    let result = packet.unpack(completion, 0).unwrap();
    assert_eq!(result.tag, "MERGE 1");
    assert_eq!(result.affect_count, 0);
}
/// Unpacking an empty byte vector succeeds.
#[test]
fn unpack_empty_message_returns_ok() {
    let mut packet = test_packet(false);
    let nothing: Vec<u8> = vec![];
    assert!(packet.unpack(nothing, 0).is_ok());
}
/// Four bytes are not enough for a type byte plus a length field, so
/// `unpack` reports "Message too short".
#[test]
fn unpack_message_too_short_returns_error() {
    let mut packet = test_packet(false);
    let truncated = vec![b'R', 0, 0, 0];
    let outcome = packet.unpack(truncated, 0);
    assert_protocol_error(outcome, "Message too short");
}
/// An `'S'` message declaring length 10 but carrying only two payload bytes
/// is rejected as too short.
#[test]
fn unpack_parameter_status_too_short_returns_error() {
    let mut packet = test_packet(false);
    let overdeclared = build_msg_with_len(b'S', 10, b"a\0");
    let outcome = packet.unpack(overdeclared, 0);
    assert_protocol_error(outcome, "ParameterStatus message too short");
}
/// A `'K'` message with only four payload bytes (one value instead of two)
/// is rejected as too short.
#[test]
fn unpack_backend_key_data_too_short_returns_error() {
    let mut packet = test_packet(false);
    let half_key = build_msg(b'K', &[1, 2, 3, 4]);
    let outcome = packet.unpack(half_key, 0);
    assert_protocol_error(outcome, "BackendKeyData message too short");
}
/// A `'T'` message with an empty payload (no field count) is rejected.
#[test]
fn unpack_row_description_too_short_returns_error() {
    let mut packet = test_packet(false);
    let empty_description = build_msg(b'T', &[]);
    let outcome = packet.unpack(empty_description, 0);
    assert_protocol_error(outcome, "RowDescription message too short");
}
/// A `'D'` message with an empty payload (no column count) is rejected.
#[test]
fn unpack_data_row_too_short_returns_error() {
    let mut packet = test_packet(false);
    let empty_row = build_msg(b'D', &[]);
    let outcome = packet.unpack(empty_row, 0);
    assert_protocol_error(outcome, "DataRow message too short");
}
/// A `'C'` message declaring length 12 but carrying only two payload bytes
/// is rejected as too short.
#[test]
fn unpack_command_completion_too_short_returns_error() {
    let mut packet = test_packet(false);
    let overdeclared = build_msg_with_len(b'C', 12, b"OK");
    let outcome = packet.unpack(overdeclared, 0);
    assert_protocol_error(outcome, "CommandCompletion message too short");
}
/// `Version::into` maps 3.0/3.1/3.2 to distinct codes and anything else to
/// the 3.0 code; the `V31`/`V32` variants report the same codes directly.
#[test]
fn version_into_and_i32_cover_all_paths() {
    // Conversions from a floating-point version number.
    let mut from_3_0 = Version::into(3.0);
    assert_eq!(from_3_0.i32(), 196_608);
    let mut from_3_1 = Version::into(3.1);
    assert_eq!(from_3_1.i32(), 196_609);
    let mut from_3_2 = Version::into(3.2);
    assert_eq!(from_3_2.i32(), 196_610);
    let mut unrecognised = Version::into(8.8);
    assert_eq!(unrecognised.i32(), 196_608);

    // Variants constructed directly.
    let mut variant_31 = Version::V31;
    assert_eq!(variant_31.i32(), 196_609);
    let mut variant_32 = Version::V32;
    assert_eq!(variant_32.i32(), 196_610);
}
/// `MessageType::form` maps every known type byte to its variant and any
/// other byte to `MessageType::None`.
#[test]
fn message_type_form_for_all_values() {
    // Keeps the byte-to-variant table readable: one line per mapping.
    macro_rules! assert_form {
        ($byte:expr => $variant:ident) => {
            assert!(matches!(MessageType::form($byte), MessageType::$variant));
        };
    }

    assert_form!(b'1' => ParseCompletion);
    assert_form!(b'2' => BindCompletion);
    assert_form!(b'3' => CloseComplete);
    assert_form!(b't' => ParameterDescription);
    assert_form!(b'n' => NoData);
    assert_form!(b'T' => RowDescription);
    assert_form!(b'R' => Authentication);
    assert_form!(b'E' => ErrorResponse);
    assert_form!(b'S' => ParameterStatus);
    assert_form!(b'K' => BackendKeyData);
    assert_form!(b'Z' => ReadyForQuery);
    assert_form!(b'D' => DataRow);
    assert_form!(b'C' => CommandCompletion);
    assert_form!(b'N' => Notice);
    assert_form!(b'I' => EmptyQueryResponse);
    assert_form!(b'A' => NotificationResponse);
    assert_form!(b's' => PortalSuspended);
    assert_form!(b'G' => CopyInResponse);
    assert_form!(b'H' => CopyOutResponse);
    assert_form!(b'c' => CopyDone);
    assert_form!(b'd' => CopyData);
    // Anything unrecognised falls through to `None`.
    assert_form!(b'?' => None);
}
/// `AuthStatus::form` maps each known authentication code to its variant and
/// any other code to `AuthStatus::None`.
#[test]
fn auth_status_form_for_all_values() {
    macro_rules! assert_form {
        ($code:expr => $variant:ident) => {
            assert!(matches!(AuthStatus::form($code), AuthStatus::$variant));
        };
    }

    assert_form!(0 => AuthenticationOk);
    assert_form!(3 => CleartextPassword);
    assert_form!(5 => Md5Password);
    assert_form!(10 => ScramInitialization);
    assert_form!(11 => ScramChallenge);
    assert_form!(12 => ScramComplete);
    assert_form!(999 => None);
}
/// `Status::form` maps `I`, `T`, `E` and `S` to the matching variants and
/// any other byte to `Status::None`.
#[test]
fn status_form_for_all_values() {
    macro_rules! assert_form {
        ($tag:expr => $variant:ident) => {
            assert!(matches!(Status::form($tag), Status::$variant));
        };
    }

    assert_form!(b"I" => I);
    assert_form!(b"T" => T);
    assert_form!(b"E" => E);
    assert_form!(b"S" => S);
    assert_form!(b"X" => None);
}
/// `Field::new` stores its arguments, and `json` exposes the name and the
/// column index under the `name` and `index` keys.
#[test]
fn field_new_and_json() {
    let column = Field::new(String::from("col"), 7, FieldFormat::Text);
    assert_eq!(column.name, "col");
    assert_eq!(column.column_index, 7);
    assert!(matches!(column.format_type, FieldFormat::Text));

    let rendered = column.json();
    assert_eq!(rendered["name"].as_str(), Some("col"));
    assert_eq!(rendered["index"].as_u32(), Some(7));
}
/// Parse with statement id 0 uses an empty statement name; a non-zero id `n`
/// is rendered as the name `s<n>`.
#[test]
fn query_vec_u8_parse_statement_variants() {
    let unnamed = Query::Parse(0, String::from("SELECT 1"), vec![]).vec_u8();
    assert_eq!(unnamed[0], b'P');
    let declared_len = u32::from_be_bytes(unnamed[1..5].try_into().unwrap()) as usize;
    assert_eq!(declared_len, unnamed.len() - 1);
    assert!(unnamed[5..].starts_with(b"\0SELECT 1\0"));

    let named = Query::Parse(2, String::from("SELECT 1"), vec![]).vec_u8();
    assert_eq!(named[0], b'P');
    assert!(named[5..].starts_with(b"s2\0SELECT 1\0"));
}
/// Parse ends with a `u16` parameter count followed by one `u32` OID per
/// parameter: `[0, 0]` with no parameters, and count 2 plus eight zero bytes
/// for two zero OIDs.
#[test]
fn parse_with_param_oids_produces_correct_wire_format() {
    // No parameters: the message ends with a zero count.
    let without_params = Query::Parse(0, String::from("SELECT 1"), vec![]).vec_u8();
    assert_eq!(without_params[0], b'P');
    let tail = &without_params[without_params.len() - 2..];
    assert_eq!(tail, &[0, 0]);

    // Two parameters with OID 0 each.
    let with_params = Query::Parse(0, String::from("SELECT $1, $2"), vec![0, 0]).vec_u8();
    assert_eq!(with_params[0], b'P');
    assert!(with_params[5..].starts_with(b"\0SELECT $1, $2\0"));

    // Index of the NUL terminating the SQL text: header (5) + empty name NUL (1) + SQL.
    let query_end = 5 + 1 + "SELECT $1, $2".len();
    let count_bytes = &with_params[query_end + 1..query_end + 3];
    let num_params = u16::from_be_bytes(count_bytes.try_into().unwrap());
    assert_eq!(num_params, 2);
    assert_eq!(with_params.len(), query_end + 3 + 8);
    assert_eq!(
        &with_params[with_params.len() - 8..],
        &[0, 0, 0, 0, 0, 0, 0, 0]
    );
}
/// Bind with ids 0/0 uses empty portal and statement names; non-zero ids are
/// rendered as `<portal>` and `s<statement>`, and the single text parameter
/// is length-prefixed.
#[test]
fn query_vec_u8_bind_portal_and_statement_variants() {
    // Unnamed portal and statement, no parameters.
    let unnamed = Query::Bind(0, 0, vec![]).vec_u8();
    let mut expected_unnamed = vec![b'B'];
    expected_unnamed.extend_from_slice(&12u32.to_be_bytes());
    expected_unnamed.extend_from_slice(b"\0\0");
    for value in [0u16, 0, 0].iter() {
        expected_unnamed.extend_from_slice(&value.to_be_bytes());
    }
    assert_eq!(unnamed, expected_unnamed);

    // Portal 3, statement 7, one parameter "v".
    let named = Query::Bind(3, 7, vec![Some(String::from("v"))]).vec_u8();
    let mut expected_named = vec![b'B'];
    expected_named.extend_from_slice(&22u32.to_be_bytes());
    expected_named.extend_from_slice(b"3\0s7\0");
    for value in [1u16, 0, 1].iter() {
        expected_named.extend_from_slice(&value.to_be_bytes());
    }
    expected_named.extend_from_slice(&1i32.to_be_bytes());
    expected_named.push(b'v');
    expected_named.extend_from_slice(&0u16.to_be_bytes());
    assert_eq!(named, expected_named);
}
/// Bind with no parameters is `'B'`, length 12, two empty names, and three
/// zero `u16` values.
#[test]
fn bind_with_no_params_produces_correct_wire_format() {
    let wire = Query::Bind(0, 0, vec![]).vec_u8();
    let expected: Vec<u8> = vec![
        b'B', // message type
        0, 0, 0, 12, // length
        0, 0, // two empty NUL-terminated names
        0, 0, 0, 0, 0, 0, // three zero u16 values
    ];
    assert_eq!(wire, expected);
}
/// Bind with two text parameters writes each one as an `i32` length followed
/// by the raw bytes.
#[test]
fn bind_with_text_params_produces_correct_wire_format() {
    let params = vec![Some(String::from("hello")), Some(String::from("42"))];
    let wire = Query::Bind(0, 0, params).vec_u8();

    let mut expected = vec![b'B'];
    expected.extend_from_slice(&29u32.to_be_bytes());
    expected.extend_from_slice(b"\0\0");
    for value in [1u16, 0, 2].iter() {
        expected.extend_from_slice(&value.to_be_bytes());
    }
    // Length-prefixed parameter values.
    for text in ["hello", "42"].iter() {
        expected.extend_from_slice(&i32::try_from(text.len()).unwrap().to_be_bytes());
        expected.extend_from_slice(text.as_bytes());
    }
    expected.extend_from_slice(&0u16.to_be_bytes());
    assert_eq!(wire, expected);
}
/// A `None` parameter is written as length -1 with no value bytes; the
/// following `Some("x")` is written normally.
#[test]
fn bind_with_null_param_produces_correct_wire_format() {
    let wire = Query::Bind(0, 0, vec![None, Some(String::from("x"))]).vec_u8();

    let mut expected = vec![b'B'];
    expected.extend_from_slice(&23u32.to_be_bytes());
    expected.extend_from_slice(b"\0\0");
    for value in [1u16, 0, 2].iter() {
        expected.extend_from_slice(&value.to_be_bytes());
    }
    // NULL parameter: length -1 and nothing else.
    expected.extend_from_slice(&(-1i32).to_be_bytes());
    // Text parameter "x".
    expected.extend_from_slice(&1i32.to_be_bytes());
    expected.push(b'x');
    expected.extend_from_slice(&0u16.to_be_bytes());
    assert_eq!(wire, expected);
}
/// Describe('P') uses an empty name for id 0 and the bare decimal id
/// otherwise.
#[test]
fn query_vec_u8_describe_portal_variants() {
    let unnamed = Query::Describe(b'P', 0).vec_u8();
    let expected_unnamed = vec![b'D', 0, 0, 0, 6, b'P', 0];
    assert_eq!(unnamed, expected_unnamed);

    let named = Query::Describe(b'P', 9).vec_u8();
    let expected_named = vec![b'D', 0, 0, 0, 7, b'P', b'9', 0];
    assert_eq!(named, expected_named);
}
/// Describe('S') uses an empty name for id 0 and `s<id>` otherwise.
#[test]
fn query_vec_u8_describe_statement_variants() {
    let unnamed = Query::Describe(b'S', 0).vec_u8();
    let expected_unnamed = vec![b'D', 0, 0, 0, 6, b'S', 0];
    assert_eq!(unnamed, expected_unnamed);

    let named = Query::Describe(b'S', 3).vec_u8();
    let expected_named = vec![b'D', 0, 0, 0, 8, b'S', b's', b'3', 0];
    assert_eq!(named, expected_named);
}
/// Execute writes the portal name (empty for id 0, decimal id otherwise)
/// followed by the second argument as a big-endian `u32`.
#[test]
fn query_vec_u8_execute_portal_variants() {
    // Unnamed portal: 'E', length 9, empty name, then 5.
    let unnamed = Query::Execute(0, 5).vec_u8();
    let expected_unnamed: Vec<u8> = vec![b'E', 0, 0, 0, 9, 0, 0, 0, 0, 5];
    assert_eq!(unnamed, expected_unnamed);

    // Portal 4: 'E', length 10, name "4", then 5.
    let named = Query::Execute(4, 5).vec_u8();
    let expected_named: Vec<u8> = vec![b'E', 0, 0, 0, 10, b'4', 0, 0, 0, 0, 5];
    assert_eq!(named, expected_named);
}
/// Sync is `'S'` followed by a big-endian length of 4.
#[test]
fn query_vec_u8_sync() {
    let expected = vec![b'S', 0, 0, 0, 4];
    assert_eq!(Query::Sync.vec_u8(), expected);
}
/// `compute_salted_password` reproduces known base64 outputs for "password"
/// with salt "c2FsdA==" at one and two iterations.
#[test]
fn compute_salted_password_known_inputs() {
    let known = [
        (1, "Eg+2z/z4syxD5yJSVsT4N6hlSMkszDVICAWYfLcL4Xs="),
        (2, "rk0Mla9rRtMtCt/5KPBt0CowP47zwlHf1uLYWpVHTEM="),
    ];
    for &(rounds, expected_b64) in known.iter() {
        let salted = compute_salted_password(b"password", "c2FsdA==", rounds);
        assert_eq!(STANDARD.encode(salted), expected_b64);
    }
}
/// `hmac_sha256` reproduces a known base64 digest for key "key" and the
/// pangram message.
#[test]
fn hmac_sha256_known_inputs() {
    let key = b"key";
    let message = b"The quick brown fox jumps over the lazy dog";
    let encoded = STANDARD.encode(hmac_sha256(key, message));
    assert_eq!(encoded, "97yD9DBThCSxMpjmqm+xQ+9NWaFJRhdZl0edvC0aPNg=");
}
/// `sha256` reproduces a known base64 digest for "abc".
#[test]
fn sha256_known_inputs() {
    let encoded = STANDARD.encode(sha256(b"abc"));
    assert_eq!(encoded, "ungWv48Bz+pBQUDeXa4iI7ADYaOWF3qctBD/YfIAFa0=");
}
/// `compute_client_proof` reproduces a known base64 proof for fixed client
/// key, stored key and auth message.
#[test]
fn compute_client_proof_known_inputs() {
    let auth_message = "n=,r=client,r=clientserver,s=c2FsdA==,i=2,c=biws,r=clientserver";
    let stored_key =
        hex_to_bytes("4773d12e2371bb935b9a0f5439b4a1c3ad3f2414b86980f8418d1cfabdfadfef");
    let client_key =
        hex_to_bytes("00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff");

    let proof = compute_client_proof(&client_key, &stored_key, auth_message);
    assert_eq!(proof, "04Ibe7P688bLoyuT7EM9jOuaLOoqPlIaonAvVZOSWCk=");
}
/// A generated nonce is 24 bytes long, contains no comma, and uses only
/// bytes in the range 0x21..=0x7e.
#[test]
fn generate_nonce_length_and_charset() {
    let nonce = generate_nonce();
    assert_eq!(nonce.len(), 24);
    assert!(!nonce.contains(','));
    let all_in_range = nonce.bytes().all(|byte| matches!(byte, 0x21..=0x7e));
    assert!(all_in_range);
}
/// `to_query_error` copies every `ErrorMessage` field into the matching
/// field of `PgsqlError::Query`.
#[test]
fn to_query_error_maps_all_fields() {
    let mut packet = test_packet(false);
    packet.error_message = ErrorMessage {
        code: String::from("22023"),
        message: String::from("bad value"),
        detail: String::from("detail text"),
        sql: String::from("SELECT bad"),
        position: 17,
    };

    let converted = packet.to_query_error();
    match converted {
        PgsqlError::Query {
            code,
            message,
            detail,
            sql,
            position,
        } => {
            assert_eq!(code, "22023");
            assert_eq!(message, "bad value");
            assert_eq!(detail, "detail text");
            assert_eq!(sql, "SELECT bad");
            assert_eq!(position, 17);
        }
        other => panic!("expected query error, got {other:?}"),
    }
}
/// An `R` message whose declared length is 3 must be rejected with the
/// protocol error "Invalid message length".
#[test]
fn unpack_invalid_message_length_less_than_4() {
    let mut packet = test_packet(false);
    let frame = build_msg_with_len(b'R', 3, &[]);
    let outcome = packet.unpack(frame, 0);
    assert_protocol_error(outcome, "Invalid message length");
}
/// An auth message with code 11 whose `r=` value does not begin with the
/// client nonce must yield a `PgsqlError::Auth` describing the mismatch.
#[test]
fn unpack_scram_nonce_validation_failure() {
    let expected_reason = "SCRAM server nonce does not start with client nonce";
    let mut packet = test_packet(false);
    packet.client_nonce = String::from("myclientnonce");

    // The challenge nonce deliberately lacks the "myclientnonce" prefix.
    let challenge = b"r=WRONG_nonce_prefix,s=c2FsdA==,i=4096";
    let outcome = packet.unpack(build_auth(11, challenge), 0);
    match outcome {
        Err(PgsqlError::Auth(msg)) => assert!(msg.contains(expected_reason), "got: {msg}"),
        other => panic!("expected auth error, got {other:?}"),
    }
}
/// A hand-built `R` frame declaring length 7 and carrying the 4-byte code 12
/// must be rejected with "SCRAM complete message too short".
#[test]
fn unpack_scram_complete_len_less_than_4() {
    let mut packet = test_packet(false);
    let mut frame = vec![b'R'];
    // Declared length, then the 4-byte value 12.
    frame.extend(7u32.to_be_bytes());
    frame.extend(12u32.to_be_bytes());
    let outcome = packet.unpack(frame, 0);
    assert_protocol_error(outcome, "SCRAM complete message too short");
}
/// An `E` frame declaring length 20 but carrying only two payload bytes must
/// be rejected with "ErrorResponse message too short".
#[test]
fn unpack_error_response_too_short() {
    let mut packet = test_packet(false);
    let body: [u8; 2] = [0x01, 0x02];
    let frame = build_msg_with_len(b'E', 20, &body);
    let outcome = packet.unpack(frame, 0);
    assert_protocol_error(outcome, "ErrorResponse message too short");
}
/// An `E` message with an empty payload must still surface as
/// `PgsqlError::Query`.
#[test]
fn unpack_error_response_empty_buf_break() {
    let mut packet = test_packet(false);
    let payload = Vec::<u8>::new();
    let other = packet.unpack(build_msg(b'E', &payload), 0);
    if !matches!(other, Err(PgsqlError::Query { .. })) {
        panic!("expected query error, got {other:?}");
    }
}
/// An `E` message whose single `C` field has no terminating NUL byte must
/// still surface as `PgsqlError::Query`.
#[test]
fn unpack_error_response_no_null_terminator() {
    let mut packet = test_packet(false);
    // Field tag followed by a value that is never NUL-terminated.
    let mut payload = vec![b'C'];
    payload.extend_from_slice(b"no_null_here");
    let other = packet.unpack(build_msg(b'E', &payload), 0);
    if !matches!(other, Err(PgsqlError::Query { .. })) {
        panic!("expected query error, got {other:?}");
    }
}
/// An `N` frame declaring length 20 but carrying only two payload bytes must
/// be rejected with "Notice message too short".
#[test]
fn unpack_notice_too_short() {
    let mut packet = test_packet(false);
    let body: [u8; 2] = [0x01, 0x02];
    let frame = build_msg_with_len(b'N', 20, &body);
    let outcome = packet.unpack(frame, 0);
    assert_protocol_error(outcome, "Notice message too short");
}
/// An `N` message with an empty payload must unpack successfully.
#[test]
fn unpack_notice_empty_buf_break() {
    let mut packet = test_packet(false);
    let payload = Vec::<u8>::new();
    let frame = build_msg(b'N', &payload);
    assert!(packet.unpack(frame, 0).is_ok());
}
/// An `N` message whose single `M` field has no terminating NUL byte must
/// unpack successfully.
#[test]
fn unpack_notice_no_null_terminator() {
    let mut packet = test_packet(false);
    // Field tag followed by a value that is never NUL-terminated.
    let mut payload = vec![b'M'];
    payload.extend_from_slice(b"no_null_here");
    let frame = build_msg(b'N', &payload);
    assert!(packet.unpack(frame, 0).is_ok());
}
/// An `N` message containing a field tagged `X` followed by a zero byte must
/// unpack successfully.
#[test]
fn unpack_notice_unknown_field() {
    let mut packet = test_packet(false);
    let mut payload = Vec::new();
    payload.extend(build_field(b'X', "ignored_notice"));
    // Single zero byte after the field.
    payload.push(0);
    let frame = build_msg(b'N', &payload);
    assert!(packet.unpack(frame, 0).is_ok());
}
/// An `S` message whose payload contains no NUL byte must unpack without
/// error and leave `service_params` empty.
#[test]
fn unpack_parameter_status_no_null_separator() {
    let mut packet = test_packet(false);
    let frame = build_msg(b'S', b"nonullhere");
    packet.unpack(frame, 0).unwrap();
    assert_eq!(packet.service_params.len(), 0);
}
/// An `S` message with payload `key\0\0` must store `key` in
/// `service_params` with an empty string value.
#[test]
fn unpack_parameter_status_value_empty_when_equal() {
    let mut packet = test_packet(false);
    let frame = build_msg(b'S', b"key\0\0");
    packet.unpack(frame, 0).unwrap();
    let stored = packet.service_params.get("key");
    assert_eq!(stored.map(|value| value.as_str()), Some(""));
}
/// An `S` message with payload `k\0` (nothing after the NUL) must unpack
/// without error and leave `service_params` empty.
#[test]
fn unpack_parameter_status_null_at_end_only() {
    let mut packet = test_packet(false);
    let frame = build_msg(b'S', b"k\0");
    packet.unpack(frame, 0).unwrap();
    assert_eq!(packet.service_params.len(), 0);
}
/// A `Z` frame declaring length 20 but carrying a single payload byte must be
/// rejected with "ReadyForQuery message too short".
#[test]
fn unpack_ready_for_query_too_short() {
    let mut packet = test_packet(false);
    let frame = build_msg_with_len(b'Z', 20, b"I");
    let outcome = packet.unpack(frame, 0);
    assert_protocol_error(outcome, "ReadyForQuery message too short");
}
/// A `t` frame declaring length 20 but carrying only two payload bytes must
/// be rejected with "ParameterDescription message too short".
#[test]
fn unpack_parameter_description_too_short() {
    let mut packet = test_packet(false);
    let body: [u8; 2] = [0, 1];
    let frame = build_msg_with_len(b't', 20, &body);
    let outcome = packet.unpack(frame, 0);
    assert_protocol_error(outcome, "ParameterDescription message too short");
}
/// A `t` message declaring two parameters with OIDs 23 and 25, followed by a
/// `Z` message, must produce `param_oids == [23, 25]`.
#[test]
fn unpack_parameter_description_parses_oids() {
    let mut packet = test_packet(false);
    // u16 parameter count, then one u32 per OID, all big-endian.
    let mut payload = 2u16.to_be_bytes().to_vec();
    for oid in [23u32, 25u32] {
        payload.extend_from_slice(&oid.to_be_bytes());
    }
    let mut stream = build_msg(b't', &payload);
    stream.extend(build_msg(b'Z', b"I"));

    let outcome = packet.unpack(stream, 0);
    assert!(outcome.is_ok());
    let parsed = outcome.unwrap();
    assert_eq!(parsed.param_oids, vec![23, 25]);
}
/// A `t` message declaring zero parameters, followed by a `Z` message, must
/// produce an empty `param_oids`.
#[test]
fn unpack_parameter_description_empty_params() {
    let mut packet = test_packet(false);
    // Payload is just the big-endian u16 count of zero.
    let payload = 0u16.to_be_bytes().to_vec();
    let mut stream = build_msg(b't', &payload);
    stream.extend(build_msg(b'Z', b"I"));

    let outcome = packet.unpack(stream, 0);
    assert!(outcome.is_ok());
    let parsed = outcome.unwrap();
    assert!(parsed.param_oids.is_empty());
}
/// A `t` message that declares three parameters but supplies only one OID
/// (23) must still succeed and yield just that OID.
#[test]
fn unpack_parameter_description_truncated_oids() {
    let mut packet = test_packet(false);
    // Count says 3, but only a single 4-byte OID follows.
    let mut payload = 3u16.to_be_bytes().to_vec();
    payload.extend(23u32.to_be_bytes());
    let mut stream = build_msg(b't', &payload);
    stream.extend(build_msg(b'Z', b"I"));

    let outcome = packet.unpack(stream, 0);
    assert!(outcome.is_ok());
    let parsed = outcome.unwrap();
    assert_eq!(parsed.param_oids, vec![23]);
}
/// A one-parameter `t` message followed by a `Z` message must unpack
/// successfully on a packet built with `test_packet(true)`.
#[test]
fn unpack_parameter_description_debug_mode() {
    // NOTE(review): the `true` flag presumably enables debug output, per the
    // test name — confirm against `test_packet`.
    let mut packet = test_packet(true);
    let mut payload = 1u16.to_be_bytes().to_vec();
    payload.extend(23u32.to_be_bytes());
    let mut stream = build_msg(b't', &payload);
    stream.extend(build_msg(b'Z', b"I"));
    assert!(packet.unpack(stream, 0).is_ok());
}
/// A `t` message with a one-byte payload, followed by a `Z` message, must
/// succeed and leave `param_oids` empty.
#[test]
fn unpack_parameter_description_payload_too_short_ignored() {
    let mut packet = test_packet(false);
    let body: [u8; 1] = [1];
    let mut stream = build_msg(b't', &body);
    stream.extend(build_msg(b'Z', b"I"));

    let outcome = packet.unpack(stream, 0);
    assert!(outcome.is_ok());
    let parsed = outcome.unwrap();
    assert!(parsed.param_oids.is_empty());
}
/// A `T` message describing one column must yield exactly one field on a
/// packet built with `test_packet(true)`.
#[test]
fn unpack_row_description_debug_mode() {
    // NOTE(review): the `true` flag presumably enables debug output, per the
    // test name — confirm against `test_packet`.
    let mut packet = test_packet(true);
    let description = row_description_payload("col", 0, 25, 0);
    let frame = build_msg(b'T', &description);

    let outcome = packet.unpack(frame, 0);
    assert!(outcome.is_ok());
    let parsed = outcome.unwrap();
    assert_eq!(parsed.fields.len(), 1);
}
/// A `T` message declaring one field whose name is never NUL-terminated must
/// succeed and produce no fields.
#[test]
fn unpack_row_description_field_name_no_null() {
    let mut packet = test_packet(false);
    // u16 field count, then a name with no terminating NUL.
    let mut payload = 1u16.to_be_bytes().to_vec();
    payload.extend_from_slice(b"fieldwithoutnull");
    let frame = build_msg(b'T', &payload);

    let outcome = packet.unpack(frame, 0);
    assert!(outcome.is_ok());
    let parsed = outcome.unwrap();
    assert!(parsed.fields.is_empty());
}
/// A `T` message declaring one field named `f` followed by only ten bytes of
/// metadata must succeed and produce no fields.
#[test]
fn unpack_row_description_field_too_short_after_name() {
    let mut packet = test_packet(false);
    // u16 field count, NUL-terminated name, then ten zero bytes.
    let mut payload = 1u16.to_be_bytes().to_vec();
    payload.extend_from_slice(b"f\0");
    payload.extend([0u8; 10]);
    let frame = build_msg(b'T', &payload);

    let outcome = packet.unpack(frame, 0);
    assert!(outcome.is_ok());
    let parsed = outcome.unwrap();
    assert!(parsed.fields.is_empty());
}
/// A `1` message carrying two payload bytes must unpack successfully.
#[test]
fn unpack_parse_completion_with_payload() {
    let mut packet = test_packet(false);
    let body: [u8; 2] = [0xAA, 0xBB];
    let frame = build_msg(b'1', &body);
    assert!(packet.unpack(frame, 0).is_ok());
}
/// A `T` message followed by a `D` message holding the single value `abc`
/// must yield one row on a packet built with `test_packet(true)`.
#[test]
fn unpack_data_row_debug_mode() {
    // NOTE(review): the `true` flag presumably enables debug output, per the
    // test name — confirm against `test_packet`.
    let mut packet = test_packet(true);
    let mut message = build_msg(b'T', &row_description_payload("v", 0, 25, 0));

    // u16 column count, u32 value length, then the value bytes.
    let mut row_body = 1u16.to_be_bytes().to_vec();
    row_body.extend(3u32.to_be_bytes());
    row_body.extend_from_slice(b"abc");
    message.extend(build_msg(b'D', &row_body));

    let parsed = packet.unpack(message, 0).unwrap();
    assert_eq!(parsed.rows.len(), 1);
}
/// A `D` message that declares one column but has only two bytes after the
/// count must still unpack and yield one row.
#[test]
fn unpack_data_row_fields_less_than_4_break() {
    let mut packet = test_packet(false);
    let mut message = build_msg(b'T', &row_description_payload("x", 0, 25, 0));

    // u16 column count followed by just two bytes.
    let mut row_body = 1u16.to_be_bytes().to_vec();
    row_body.extend([0xAAu8, 0xBB]);
    message.extend(build_msg(b'D', &row_body));

    let parsed = packet.unpack(message, 0).unwrap();
    assert_eq!(parsed.rows.len(), 1);
}
/// A `D` message whose value length says 100 while only five bytes follow
/// must still unpack and yield one row.
#[test]
fn unpack_data_row_field_length_exceeds_remaining() {
    let mut packet = test_packet(false);
    let mut message = build_msg(b'T', &row_description_payload("y", 0, 25, 0));

    // u16 column count, u32 length of 100, then only the 5 bytes of "short".
    let mut row_body = 1u16.to_be_bytes().to_vec();
    row_body.extend(100u32.to_be_bytes());
    row_body.extend_from_slice(b"short");
    message.extend(build_msg(b'D', &row_body));

    let parsed = packet.unpack(message, 0).unwrap();
    assert_eq!(parsed.rows.len(), 1);
}
/// A `C` message whose tag `SELECT 3` has no trailing NUL must still report
/// the tag verbatim and an `affect_count` of 3.
#[test]
fn unpack_command_completion_tag_without_trailing_null() {
    let mut packet = test_packet(false);
    let frame = build_msg(b'C', b"SELECT 3");
    let completed = packet.unpack(frame, 0).unwrap();
    assert_eq!(completed.tag, "SELECT 3");
    assert_eq!(completed.affect_count, 3);
}
/// `Query::Bind(0, 0, …)` with 100 parameters must start with `B`, report a
/// format-code count of 1 and a parameter count of 100.
#[test]
fn bind_with_many_params_produces_correct_wire_format() {
    let mut params: Vec<Option<String>> = Vec::with_capacity(100);
    for i in 0..100 {
        params.push(Some(format!("v{i}")));
    }
    let data = Query::Bind(0, 0, params).vec_u8();
    assert_eq!(data[0], b'B');

    // Offsets: 5 header bytes, then one byte each for the portal and
    // statement sections, puts the format-code count at index 7.
    let portal_end = 5 + 1;
    let stmt_end = portal_end + 1;
    let fmt_count_pos = stmt_end;
    let fmt_count = i16::from_be_bytes([data[fmt_count_pos], data[fmt_count_pos + 1]]);
    assert_eq!(fmt_count, 1);

    // Skip the 2-byte format-code count and the single 2-byte format code.
    let param_count_pos = fmt_count_pos + 2 + 2;
    let param_count = i16::from_be_bytes([data[param_count_pos], data[param_count_pos + 1]]);
    assert_eq!(param_count, 100);
}
/// A 100 000-character parameter must appear intact in the bytes produced by
/// `Query::Bind`.
#[test]
fn bind_with_long_param_value() {
    let long_val = "x".repeat(100_000);
    let data = Query::Bind(0, 0, vec![Some(long_val.clone())]).vec_u8();
    assert_eq!(data[0], b'B');
    // Lossy decoding lets the frame be searched as text despite binary headers.
    let as_text = String::from_utf8_lossy(&data);
    assert!(as_text.contains(long_val.as_str()));
}
/// An empty-string parameter and a `None` parameter must produce frames of
/// equal size that differ in the parameter length field: 0 versus -1.
#[test]
fn bind_empty_string_vs_null() {
    let empty_str = Query::Bind(0, 0, vec![Some(String::new())]).vec_u8();
    let null_val = Query::Bind(0, 0, vec![None]).vec_u8();
    assert_eq!(empty_str.len(), null_val.len());

    // Index of the first parameter's 4-byte length field.
    let param_len_pos = 5 + 1 + 1 + 2 + 2 + 2;
    let read_param_len = |frame: &[u8]| {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(&frame[param_len_pos..param_len_pos + 4]);
        i32::from_be_bytes(raw)
    };

    let empty_len = read_param_len(&empty_str[..]);
    let null_len = read_param_len(&null_val[..]);
    assert_eq!(empty_len, 0);
    assert_eq!(null_len, -1);
}
}