use std::collections::BTreeMap;
use std::ops::{ControlFlow, Deref, DerefMut};
use std::str::FromStr;
use futures_channel::mpsc::UnboundedSender;
use futures_util::SinkExt;
use log::Level;
use sqlx_core::bytes::Buf;
use crate::connection::tls::MaybeUpgradeTls;
use crate::error::Error;
use crate::message::{
BackendMessage, BackendMessageFormat, EncodeMessage, FrontendMessage, Notice, Notification,
ParameterStatus, ReceivedMessage,
};
use crate::net::{self, BufferedSocket, Socket};
use crate::{PgConnectOptions, PgDatabaseError, PgSeverity};
/// A connection to a PostgreSQL server: a buffered socket plus per-connection state that
/// [`PgStream::recv`] keeps up to date from asynchronous backend messages.
pub struct PgStream {
    /// The buffered socket; `PgStream` derefs to it, so its read/write/flush API is
    /// available directly on the stream.
    inner: BufferedSocket<Box<dyn Socket>>,
    /// When set, `NotificationResponse` messages seen by `recv` are decoded and forwarded
    /// here instead of being returned to the caller.
    pub(crate) notifications: Option<UnboundedSender<Notification>>,
    /// Most recent value of each `ParameterStatus` the server has reported, keyed by name.
    /// `server_version` is not stored here; it is parsed into `server_version_num`.
    pub(crate) parameter_statuses: BTreeMap<String, String>,
    /// Numeric form of the server's `server_version` parameter. `None` until the server
    /// reports it, or if the reported string could not be parsed.
    pub(crate) server_version_num: Option<u32>,
}
impl PgStream {
pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
let socket_result = match options.fetch_socket() {
Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
};
let socket = socket_result?;
Ok(Self {
inner: BufferedSocket::new(socket),
notifications: None,
parameter_statuses: BTreeMap::default(),
server_version_num: None,
})
}
#[inline(always)]
pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> {
self.write(EncodeMessage(message))
}
pub(crate) async fn send<T>(&mut self, message: T) -> Result<(), Error>
where
T: FrontendMessage,
{
self.write_msg(message)?;
self.flush().await?;
Ok(())
}
pub(crate) async fn recv_expect<B: BackendMessage>(&mut self) -> Result<B, Error> {
self.recv().await?.decode()
}
pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
self.inner
.try_read(|buf| {
let Some(mut header) = buf.get(..5) else {
return Ok(ControlFlow::Continue(5));
};
let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
let message_len = header.get_u32() as usize;
let expected_len = message_len
.checked_add(1)
.ok_or_else(|| {
err_protocol!("message_len + 1 overflows usize: {message_len}")
})?;
if buf.len() < expected_len {
return Ok(ControlFlow::Continue(expected_len));
}
buf.advance(1);
let mut contents = buf.split_to(message_len).freeze();
contents.advance(4);
Ok(ControlFlow::Break(ReceivedMessage { format, contents }))
})
.await
}
pub(crate) async fn recv(&mut self) -> Result<ReceivedMessage, Error> {
loop {
let message = self.recv_unchecked().await?;
match message.format {
BackendMessageFormat::ErrorResponse => {
return Err(message.decode::<PgDatabaseError>()?.into());
}
BackendMessageFormat::NotificationResponse => {
if let Some(buffer) = &mut self.notifications {
let notification: Notification = message.decode()?;
let _ = buffer.send(notification).await;
continue;
}
}
BackendMessageFormat::ParameterStatus => {
let ParameterStatus { name, value } = message.decode()?;
match name.as_str() {
"server_version" => {
self.server_version_num = parse_server_version(&value);
}
_ => {
self.parameter_statuses.insert(name, value);
}
}
continue;
}
BackendMessageFormat::NoticeResponse => {
let notice: Notice = message.decode()?;
let (log_level, tracing_level) = match notice.severity() {
PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => {
(Level::Error, tracing::Level::ERROR)
}
PgSeverity::Warning => (Level::Warn, tracing::Level::WARN),
PgSeverity::Notice => (Level::Info, tracing::Level::INFO),
PgSeverity::Debug => (Level::Debug, tracing::Level::DEBUG),
PgSeverity::Info | PgSeverity::Log => (Level::Trace, tracing::Level::TRACE),
};
let log_is_enabled = log::log_enabled!(
target: "sqlx::postgres::notice",
log_level
) || sqlx_core::private_tracing_dynamic_enabled!(
target: "sqlx::postgres::notice",
tracing_level
);
if log_is_enabled {
sqlx_core::private_tracing_dynamic_event!(
target: "sqlx::postgres::notice",
tracing_level,
message = notice.message()
);
}
continue;
}
_ => {}
}
return Ok(message);
}
}
}
/// Exposes the underlying buffered socket's API (`write`, `flush`, `try_read`, ...)
/// directly on a shared `PgStream` reference.
impl Deref for PgStream {
    type Target = BufferedSocket<Box<dyn Socket>>;

    #[inline]
    fn deref(&self) -> &BufferedSocket<Box<dyn Socket>> {
        let Self { inner, .. } = self;
        inner
    }
}
impl DerefMut for PgStream {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
/// Converts a PostgreSQL `server_version` string into its numeric form, e.g. `9.6.1` -> `90601`.
///
/// Reads up to three leading dot-separated integer components. Parsing stops at the first
/// character that is neither an ASCII digit nor `.`, so suffixes such as `devel`, `beta1`
/// or ` (Ubuntu ...)` are ignored.
///
/// | components                 | result                          | example              |
/// |----------------------------|---------------------------------|----------------------|
/// | `major.minor.rev`          | `(100*major + minor)*100 + rev` | `9.6.1` -> `90601`   |
/// | `major.minor`, major >= 10 | `10000*major + minor`           | `10.1` -> `100001`   |
/// | `major.minor`, major < 10  | `(100*major + minor)*100`       | `9.6devel` -> `90600` |
/// | `major`                    | `10000*major`                   | `10devel` -> `100000` |
///
/// Returns `None` in two cases: no leading integer can be parsed, or the result would
/// overflow `u32`.
fn parse_server_version(s: &str) -> Option<u32> {
    // Like libpq's `%d.%d.%d` parse, any components after the third are ignored.
    const MAX_PARTS: usize = 3;

    let mut parts = Vec::<u32>::with_capacity(MAX_PARTS);

    // Byte offset where the current numeric component starts.
    let mut from = 0;
    let mut chs = s.char_indices().peekable();
    while let Some((i, ch)) = chs.next() {
        match ch {
            '.' => {
                if let Ok(num) = u32::from_str(&s[from..i]) {
                    parts.push(num);
                    from = i + 1;
                    if parts.len() == MAX_PARTS {
                        break;
                    }
                } else {
                    // Empty or unparseable component (e.g. "..", or one that overflows u32).
                    break;
                }
            }
            _ if ch.is_ascii_digit() => {
                // A component only ends at a non-digit, so flush it at end of input.
                if chs.peek().is_none() {
                    if let Ok(num) = u32::from_str(&s[from..]) {
                        parts.push(num);
                    }
                    break;
                }
            }
            _ => {
                // First non-version character: keep any pending component, then stop.
                if let Ok(num) = u32::from_str(&s[from..i]) {
                    parts.push(num);
                }
                break;
            }
        }
    }

    // The string comes from the server. Use checked arithmetic so oversized components
    // yield `None` instead of panicking (debug) or silently wrapping (release).
    match *parts.as_slice() {
        [major, minor, rev] => major
            .checked_mul(100)?
            .checked_add(minor)?
            .checked_mul(100)?
            .checked_add(rev),
        [major, minor] if major >= 10 => major.checked_mul(10_000)?.checked_add(minor),
        [major, minor] => major.checked_mul(100)?.checked_add(minor)?.checked_mul(100),
        [major] => major.checked_mul(10_000),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::parse_server_version;

    #[test]
    fn test_parse_server_version_num() {
        // Old-style (pre-10) three-part versions: major.minor.revision.
        assert_eq!(parse_server_version("9.6.1"), Some(90601));
        assert_eq!(parse_server_version("8.4.22"), Some(80422));

        // New-style (10+) two-part versions: major.minor.
        assert_eq!(parse_server_version("10.1"), Some(100001));
        assert_eq!(parse_server_version("14.2"), Some(140002));

        // Development builds: the non-numeric suffix ends parsing.
        assert_eq!(parse_server_version("9.6devel"), Some(90600));
        assert_eq!(parse_server_version("10devel"), Some(100000));
        assert_eq!(parse_server_version("13devel87"), Some(130000));

        // Distribution-decorated version strings.
        assert_eq!(
            parse_server_version("13.4 (Ubuntu 13.4-1.pgdg20.04+1)"),
            Some(130004)
        );

        // Nothing parseable.
        assert_eq!(parse_server_version("unknown"), None);
        assert_eq!(parse_server_version(""), None);
    }
}