use {
crate::{
error::{ClientResult, ConnectionSetupError, Error},
protocol::{
handshake::{ClientHandshake, ServerHandshake},
state_init::{DecodeState, MRespState, PipelineResult, RState},
Decoder,
},
query::Pipeline,
response::{FromResponse, Response},
Config, Query,
},
native_tls::Certificate,
std::ops::{Deref, DerefMut},
tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
},
tokio_native_tls::{TlsConnector, TlsStream},
};
/// An asynchronous connection whose transport is a plain [`TcpStream`].
///
/// This is a thin newtype around [`TcpConnection`]; it dereferences to the inner
/// connection, so the inner connection's methods can be called directly on it.
#[derive(Debug)]
pub struct ConnectionAsync(TcpConnection<TcpStream>);
/// An asynchronous connection whose transport is a TLS stream layered over a
/// [`TcpStream`].
///
/// This is a thin newtype around [`TcpConnection`]; it dereferences to the inner
/// connection, so the inner connection's methods can be called directly on it.
#[derive(Debug)]
pub struct ConnectionTlsAsync(TcpConnection<TlsStream<TcpStream>>);
impl Deref for ConnectionAsync {
type Target = TcpConnection<TcpStream>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for ConnectionAsync {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Deref for ConnectionTlsAsync {
type Target = TcpConnection<TlsStream<TcpStream>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for ConnectionTlsAsync {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Config {
    /// Opens a plain TCP connection to this configuration's host and port and
    /// performs the client handshake over it.
    ///
    /// # Errors
    /// Fails if the TCP connection cannot be established, if exchanging the
    /// handshake bytes fails, or if the server answers the handshake with an error.
    pub async fn connect_async(&self) -> ClientResult<ConnectionAsync> {
        let stream = TcpStream::connect((self.host(), self.port())).await?;
        TcpConnection::new(stream)
            ._handshake(self)
            .await
            .map(ConnectionAsync)
    }

    /// Opens a TLS connection to this configuration's host and port and performs
    /// the client handshake over it.
    ///
    /// `cert` is a PEM-encoded certificate that is added as a trusted root for
    /// this connection.
    ///
    /// # Errors
    /// Fails if the TCP connection cannot be established, if `cert` cannot be
    /// parsed, if the TLS connector cannot be built, if the TLS handshake fails,
    /// or if the subsequent client handshake fails.
    pub async fn connect_tls_async(&self, cert: &str) -> ClientResult<ConnectionTlsAsync> {
        let stream = TcpStream::connect((self.host(), self.port())).await?;
        let root_cert = Certificate::from_pem(cert.as_bytes()).map_err(|e| {
            ConnectionSetupError::Other(format!("failed to parse certificate: {e}"))
        })?;
        // Build the connector exactly once; building it twice only repeated the
        // setup work and discarded the first result.
        //
        // NOTE(review): `danger_accept_invalid_hostnames(true)` turns off hostname
        // verification, so any certificate chaining to `cert` is accepted regardless
        // of the name it was issued for. Kept as-is to preserve behavior; confirm
        // that this is intentional.
        let connector = native_tls::TlsConnector::builder()
            .add_root_certificate(root_cert)
            .danger_accept_invalid_hostnames(true)
            .build()
            .map_err(|e| {
                ConnectionSetupError::Other(format!("failed to set up TLS acceptor: {e}"))
            })?;
        let tls_stream = TlsConnector::from(connector)
            .connect(self.host(), stream)
            .await
            .map_err(|e| ConnectionSetupError::Other(format!("TLS handshake failed: {e}")))?;
        TcpConnection::new(tls_stream)
            ._handshake(self)
            .await
            .map(ConnectionTlsAsync)
    }
}
/// A connection over any asynchronous, readable and writable stream `C`.
///
/// Holds the stream together with a byte buffer that is reused across requests.
#[derive(Debug)]
pub struct TcpConnection<C: AsyncWriteExt + AsyncReadExt + Unpin> {
/// The underlying stream that requests are written to and responses are read from.
con: C,
/// Reusable buffer: holds the outgoing packet/header while sending, then is
/// cleared and accumulates the incoming response bytes while decoding.
buf: Vec<u8>,
}
impl<C: AsyncWriteExt + AsyncReadExt + Unpin> TcpConnection<C> {
fn new(con: C) -> Self {
Self {
con,
buf: Vec::with_capacity(crate::BUFSIZE),
}
}
async fn _handshake(mut self, cfg: &Config) -> ClientResult<Self> {
let handshake = ClientHandshake::new(cfg);
self.con.write_all(handshake.inner()).await?;
let mut resp = [0u8; 4];
self.con.read_exact(&mut resp).await?;
match ServerHandshake::parse(resp)? {
ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()),
ServerHandshake::Okay(_suggestion) => return Ok(self),
}
}
pub async fn execute_pipeline(&mut self, pipeline: &Pipeline) -> ClientResult<Vec<Response>> {
self.buf.clear();
self.buf.push(b'P');
self.buf
.extend(itoa::Buffer::new().format(pipeline.buf().len()).as_bytes());
self.buf.push(b'\n');
self.con.write_all(&self.buf).await?;
self.con.write_all(pipeline.buf()).await?;
self.buf.clear();
let mut cursor = 0;
let mut state = MRespState::default();
loop {
let mut buf = [0u8; crate::BUFSIZE];
let n = self.con.read(&mut buf).await?;
if n == 0 {
return Err(Error::IoError(std::io::ErrorKind::ConnectionReset.into()));
}
self.buf.extend_from_slice(&buf[..n]);
let mut decoder = Decoder::new(&self.buf, cursor);
match decoder.validate_pipe(pipeline.query_count(), state) {
PipelineResult::Completed(r) => return Ok(r),
PipelineResult::Pending(_state) => {
cursor = decoder.position();
state = _state;
}
PipelineResult::Error(e) => return Err(e.into()),
}
}
}
pub async fn query(&mut self, q: &Query) -> ClientResult<Response> {
self.buf.clear();
q.write_packet(&mut self.buf).unwrap();
self.con.write_all(&self.buf).await?;
self.buf.clear();
let mut state = RState::default();
let mut cursor = 0;
let mut expected = Decoder::MIN_READBACK;
loop {
let mut buf = [0u8; crate::BUFSIZE];
let n = self.con.read(&mut buf).await?;
if n == 0 {
return Err(Error::IoError(std::io::ErrorKind::ConnectionReset.into()));
}
if n < expected {
continue;
}
self.buf.extend_from_slice(&buf[..n]);
let mut decoder = Decoder::new(&self.buf, cursor);
match decoder.validate_response(state) {
DecodeState::Completed(resp) => return Ok(resp),
DecodeState::ChangeState(_state) => {
expected = 1;
state = _state;
cursor = decoder.position();
}
DecodeState::Error(e) => return Err(Error::ProtocolError(e)),
}
}
}
pub async fn query_parse<T: FromResponse>(&mut self, q: &Query) -> ClientResult<T> {
self.query(q).await.and_then(FromResponse::from_response)
}
pub fn reset_buffer(&mut self) {
self.buf.shrink_to_fit()
}
}