use std::borrow::Cow;
use std::collections::HashMap;
use std::convert::TryInto;
use std::sync::Arc;
use futures_core::future::BoxFuture;
use futures_util::TryFutureExt;
use crate::connection::{Connect, Connection};
use crate::executor::Executor;
use crate::postgres::protocol::{
Authentication, AuthenticationMd5, AuthenticationSasl, BackendKeyData, Message,
PasswordMessage, StartupMessage, StatementId, Terminate,
};
use crate::postgres::row::Statement;
use crate::postgres::stream::PgStream;
use crate::postgres::type_info::SharedStr;
use crate::postgres::{sasl, tls};
use crate::url::Url;
/// A single connection to a PostgreSQL server.
///
/// Instances are built by `PgConnection::new`, which opens the stream, performs the
/// optional TLS request and runs the startup handshake before any field is populated.
pub struct PgConnection {
/// Stream used to exchange protocol messages with the server.
pub(super) stream: PgStream,
/// Statement ID counter; a fresh connection starts at 1.
/// NOTE(review): presumably the ID assigned to the next prepared statement — confirm in the executor.
pub(super) next_statement_id: u32,
/// Set to `true` once startup has completed (the handshake loop only exits on `ReadyForQuery`).
/// NOTE(review): how and when this is cleared is not visible in this file.
pub(super) is_ready: bool,
/// Lookup from a string key to a statement ID.
/// NOTE(review): the key is presumably the query text — verify against the caller.
pub(super) cache_statement_id: HashMap<Box<str>, StatementId>,
/// Lookup from a statement ID to its shared `Statement` value.
pub(super) cache_statement: HashMap<StatementId, Arc<Statement>>,
/// Lookup from a shared string to a `u32`.
/// NOTE(review): by name, type name -> type OID — confirm.
pub(super) cache_type_oid: HashMap<SharedStr, u32>,
/// Lookup from a `u32` to a shared string.
/// NOTE(review): by name, the reverse of `cache_type_oid` (type OID -> type name) — confirm.
pub(super) cache_type_name: HashMap<u32, SharedStr>,
/// Scratch vector, pre-sized for 10 entries when the connection is created.
/// NOTE(review): looks like per-column value ranges for the row being decoded, with `None`
/// for SQL NULL — confirm against the row decoding code.
pub(super) current_row_values: Vec<Option<(u32, u32)>>,
// The two fields below are copied from the `BackendKeyData` obtained during startup.
// Nothing in this file reads them back, hence the `dead_code` allowances.
#[allow(dead_code)]
process_id: u32,
#[allow(dead_code)]
secret_key: u32,
}
/// Runs the startup handshake on `stream` and returns the backend key data.
///
/// The user name is taken from `url`, falling back to the `USER` environment variable and
/// finally to `"postgres"`. The database name is taken from `url`, falling back to the
/// resolved user name. A `StartupMessage` carrying those two values plus a fixed set of
/// session parameters is sent, after which server messages are processed until
/// `ReadyForQuery` arrives.
///
/// Returns the `BackendKeyData` sent by the server, or a zeroed value if the server never
/// sent one.
///
/// # Errors
///
/// * Any error produced by the stream or by decoding a message is propagated.
/// * A protocol error is returned when the server requests an authentication method other
///   than OK, cleartext, MD5 or SASL, when none of the offered SASL mechanisms is
///   `SCRAM-SHA-256` / `SCRAM-SHA-256-PLUS`, or when an unexpected message type arrives.
async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<BackendKeyData> {
    // Resolution order for the user name: URL, then $USER, then the literal "postgres".
    let user = url
        .username()
        .or_else(|| std::env::var("USER").ok().map(Cow::Owned))
        .unwrap_or_else(|| Cow::Borrowed("postgres"));

    // With no database in the URL, the user name doubles as the database name.
    let dbname = url.database().unwrap_or(&user);

    let startup_params = &[
        ("user", user.as_ref()),
        ("database", dbname),
        ("DateStyle", "ISO, MDY"),
        ("TimeZone", "UTC"),
        ("extra_float_digits", "3"),
        ("client_encoding", "UTF-8"),
    ];

    stream.write(StartupMessage {
        params: startup_params,
    });
    stream.flush().await?;

    // Stays zeroed unless the server sends a BackendKeyData message.
    let mut backend_key = BackendKeyData {
        process_id: 0,
        secret_key: 0,
    };

    loop {
        match stream.receive().await? {
            Message::Authentication => match Authentication::read(stream.buffer())? {
                // Nothing to send back; keep reading until ReadyForQuery.
                Authentication::Ok => {}

                Authentication::CleartextPassword => {
                    // The password is only looked up once the server actually asks for it.
                    let password = url.password().unwrap_or_default();
                    stream.write(PasswordMessage::ClearText(&password));
                    stream.flush().await?;
                }

                Authentication::Md5Password => {
                    // The method-specific payload follows the first four bytes of the buffer.
                    let md5 = AuthenticationMd5::read(&stream.buffer()[4..])?;
                    let password = url.password().unwrap_or_default();
                    stream.write(PasswordMessage::Md5 {
                        password: &password,
                        user: user.as_ref(),
                        salt: md5.salt,
                    });
                    stream.flush().await?;
                }

                Authentication::Sasl => {
                    let offer = AuthenticationSasl::read(&stream.buffer()[4..])?;

                    // Either SCRAM variant is enough to proceed; anything else is only logged.
                    let mut scram_offered = false;
                    for mechanism in &*offer.mechanisms {
                        match &**mechanism {
                            "SCRAM-SHA-256" | "SCRAM-SHA-256-PLUS" => {
                                scram_offered = true;
                            }
                            _ => {
                                log::info!("unsupported auth mechanism: {}", mechanism);
                            }
                        }
                    }

                    if !scram_offered {
                        return Err(protocol_err!(
                            "unsupported SASL auth mechanisms: {:?}",
                            offer.mechanisms
                        )
                        .into());
                    }

                    // NOTE(review): a PLUS-only offer goes through the same call; whether
                    // `sasl::authenticate` handles channel binding is not visible here.
                    sasl::authenticate(
                        stream,
                        user.as_ref(),
                        &url.password().unwrap_or_default(),
                    )
                    .await?;
                }

                unsupported => {
                    return Err(protocol_err!(
                        "requested unsupported authentication: {:?}",
                        unsupported
                    )
                    .into());
                }
            },

            Message::BackendKeyData => {
                backend_key = BackendKeyData::read(stream.buffer())?;
            }

            // Server-reported parameters are ignored during startup.
            Message::ParameterStatus => {}

            // Handshake finished.
            Message::ReadyForQuery => return Ok(backend_key),

            other => {
                return Err(protocol_err!("unexpected message: {:?}", other).into());
            }
        }
    }
}
/// Closes a connection: writes a `Terminate` message, flushes the stream, then shuts the
/// stream down.
///
/// Takes the stream by value, so it cannot be used again afterwards.
///
/// # Errors
///
/// Propagates any error returned by `flush` or `shutdown`.
async fn terminate(mut stream: PgStream) -> crate::Result<()> {
// `write` is neither awaited nor fallible here; the awaited `flush` is the step that can fail.
// NOTE(review): presumably `write` only buffers the message and `flush` sends it — confirm in `PgStream`.
stream.write(Terminate);
stream.flush().await?;
stream.shutdown()?;
Ok(())
}
impl PgConnection {
pub(super) async fn new(url: std::result::Result<Url, url::ParseError>) -> crate::Result<Self> {
let url = url?;
let mut stream = PgStream::new(&url).await?;
tls::request_if_needed(&mut stream, &url).await?;
let key_data = startup(&mut stream, &url).await?;
Ok(Self {
stream,
current_row_values: Vec::with_capacity(10),
next_statement_id: 1,
is_ready: true,
cache_type_oid: HashMap::new(),
cache_type_name: HashMap::new(),
cache_statement_id: HashMap::with_capacity(10),
cache_statement: HashMap::with_capacity(10),
process_id: key_data.process_id,
secret_key: key_data.secret_key,
})
}
}
impl Connect for PgConnection {
    /// Starts connecting to the server described by `url`.
    ///
    /// The conversion to `Url` runs immediately, before the future is created, so the
    /// returned future does not hold on to `url`. A failed conversion is reported when the
    /// future is awaited.
    fn connect<T>(url: T) -> BoxFuture<'static, crate::Result<PgConnection>>
    where
        T: TryInto<Url, Error = url::ParseError>,
        Self: Sized,
    {
        let parsed: std::result::Result<Url, url::ParseError> = url.try_into();
        Box::pin(Self::new(parsed))
    }
}
impl Connection for PgConnection {
fn close(self) -> BoxFuture<'static, crate::Result<()>> {
Box::pin(terminate(self.stream))
}
fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
Box::pin(Executor::execute(self, "SELECT 1").map_ok(|_| ()))
}
}