use crate::{
Error, Future, Result, RetryReason, TcpStreamReader, TcpStreamWriter,
client::{Config, PreparedCommand},
commands::{
ClusterCommands, ConnectionCommands, HelloOptions, SentinelCommands, ServerCommands,
},
resp::{BufferDecoder, Command, CommandEncoder, RespResponse},
tcp_connect,
};
#[cfg(any(feature = "native-tls", feature = "rustls"))]
use crate::{TcpTlsStreamReader, TcpTlsStreamWriter, tcp_tls_connect};
use futures_util::{SinkExt, Stream, StreamExt, task::noop_waker_ref};
use log::{Level, debug, log_enabled, trace};
use serde::de::DeserializeOwned;
use std::{
future::IntoFuture,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio_util::codec::{FramedRead, FramedWrite};
/// Framed read and write halves of an established connection.
///
/// Each variant pairs a `FramedRead` that uses `BufferDecoder` with a
/// `FramedWrite` that uses `CommandEncoder`.
pub(crate) enum Streams {
    /// Halves of a plain (non-TLS) TCP connection.
    Tcp(
        FramedRead<TcpStreamReader, BufferDecoder>,
        FramedWrite<TcpStreamWriter, CommandEncoder>,
    ),
    /// Halves of a TLS connection; only compiled when the `native-tls` or
    /// `rustls` feature is enabled.
    #[cfg(any(feature = "native-tls", feature = "rustls"))]
    TcpTls(
        FramedRead<TcpTlsStreamReader, BufferDecoder>,
        FramedWrite<TcpTlsStreamWriter, CommandEncoder>,
    ),
}
impl Streams {
    /// Connects to `host:port` and wraps both halves of the socket in framed
    /// codecs.
    ///
    /// When the crate is built with the `native-tls` or `rustls` feature and
    /// `config.tls_config` is set, a TLS connection is established; in every
    /// other case this defers to [`Streams::connect_non_secure`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying connect function.
    pub async fn connect(host: &str, port: u16, config: &Config) -> Result<Self> {
        // Early return for the TLS case so the plain-TCP fallback below is
        // written once, whatever the enabled features are.
        #[cfg(any(feature = "native-tls", feature = "rustls"))]
        if let Some(tls_config) = &config.tls_config {
            let (read_half, write_half) =
                tcp_tls_connect(host, port, tls_config, config.connect_timeout).await?;
            return Ok(Streams::TcpTls(
                FramedRead::new(read_half, BufferDecoder),
                FramedWrite::new(write_half, CommandEncoder),
            ));
        }

        Self::connect_non_secure(host, port, config).await
    }

    /// Connects to `host:port` over plain TCP and wraps both halves of the
    /// socket in framed codecs.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `tcp_connect`.
    pub async fn connect_non_secure(host: &str, port: u16, config: &Config) -> Result<Self> {
        let (read_half, write_half) = tcp_connect(host, port, config).await?;
        Ok(Streams::Tcp(
            FramedRead::new(read_half, BufferDecoder),
            FramedWrite::new(write_half, CommandEncoder),
        ))
    }
}
/// A single connection to one server, identified by host and port.
pub struct StandaloneConnection {
    /// Host passed to `connect`; reused when reconnecting.
    host: String,
    /// Port passed to `connect`; reused when reconnecting.
    port: u16,
    /// Private clone of the configuration passed to `connect`.
    config: Config,
    /// Current framed read/write halves; replaced by `reconnect`.
    streams: Streams,
    /// `version` field of the last `hello` reply; empty until `post_connect`
    /// has completed.
    version: String,
    /// Prefix used in this connection's log messages: `host:port`, preceded
    /// by the configured connection name when it is not empty.
    tag: Arc<str>,
}
impl StandaloneConnection {
    /// Opens a connection to `host:port` and runs the post-connect handshake.
    ///
    /// The connection keeps its own clone of `config`. Its log tag is
    /// `host:port`, prefixed with `config.connection_name` when that name is
    /// not empty.
    ///
    /// # Errors
    ///
    /// Returns an error if the streams cannot be established or if the
    /// handshake (see `post_connect`) fails.
    pub async fn connect(host: &str, port: u16, config: &Config) -> Result<Self> {
        let streams = Streams::connect(host, port, config).await?;
        let tag: Arc<str> = if config.connection_name.is_empty() {
            format!("{host}:{port}").into()
        } else {
            format!("{}:{}:{}", config.connection_name, host, port).into()
        };
        let mut connection = Self {
            host: host.to_owned(),
            port,
            config: config.clone(),
            streams,
            version: String::new(),
            tag,
        };
        connection.post_connect().await?;
        Ok(connection)
    }

    /// Sends `command` through the write sink with `send` (as opposed to
    /// `feed`, which is used by [`StandaloneConnection::feed`]).
    async fn write(&mut self, command: &Command) -> Result<()> {
        debug!("[{}] Sending command: {command}", self.tag);
        match &mut self.streams {
            Streams::Tcp(_, framed_write) => framed_write.send(command).await,
            #[cfg(any(feature = "native-tls", feature = "rustls"))]
            Streams::TcpTls(_, framed_write) => framed_write.send(command).await,
        }
    }

    /// Hands `command` to the write sink with `feed`; callers are expected to
    /// call [`StandaloneConnection::flush`] afterwards.
    ///
    /// `_retry_reasons` is accepted but not used by this connection type.
    ///
    /// In builds with `debug_assertions`, when
    /// `command.try_decrement_kill_connection_on_write()` returns `true`, a
    /// second connection named `"killer"` is opened to the same server and
    /// used to kill this connection by client id before the command is fed.
    /// NOTE(review): this looks like a fault-injection hook for tests —
    /// confirm against the callers that arm it.
    pub async fn feed(&mut self, command: &Command, _retry_reasons: &[RetryReason]) -> Result<()> {
        debug!("[{}] Sending command: {command}", self.tag);

        #[cfg(debug_assertions)]
        if command.try_decrement_kill_connection_on_write() {
            let client_id = self.client_id().await?;
            let mut config = self.config.clone();
            "killer".clone_into(&mut config.connection_name);
            let mut connection =
                StandaloneConnection::connect(&self.host, self.port, &config).await?;
            connection
                .client_kill(crate::commands::ClientKillOptions::default().id(client_id))
                .await?;
        }

        match &mut self.streams {
            Streams::Tcp(_, framed_write) => framed_write.feed(command).await,
            #[cfg(any(feature = "native-tls", feature = "rustls"))]
            Streams::TcpTls(_, framed_write) => framed_write.feed(command).await,
        }
    }

    /// Flushes the write sink.
    pub async fn flush(&mut self) -> Result<()> {
        trace!("[{}] Flushing...", self.tag);
        match &mut self.streams {
            Streams::Tcp(_, framed_write) => framed_write.flush().await,
            #[cfg(any(feature = "native-tls", feature = "rustls"))]
            Streams::TcpTls(_, framed_write) => framed_write.flush().await,
        }
    }

    /// Waits for the next item of the read stream.
    ///
    /// Returns `None` when the stream has ended, otherwise the decoded
    /// response or the decoding/transport error.
    pub async fn read(&mut self) -> Option<Result<RespResponse>> {
        let next = match &mut self.streams {
            Streams::Tcp(framed_read, _) => framed_read.next().await,
            #[cfg(any(feature = "native-tls", feature = "rustls"))]
            Streams::TcpTls(framed_read, _) => framed_read.next().await,
        };

        match next {
            Some(result) => {
                // Skip the formatting work entirely when debug logging is off.
                if log_enabled!(Level::Debug) {
                    match &result {
                        Ok(response) => debug!("[{}] Received response {response:?}", self.tag),
                        Err(err) => debug!("[{}] Received response {err:?}", self.tag),
                    }
                }
                Some(result)
            }
            None => {
                debug!("[{}] Socket is closed", self.tag);
                None
            }
        }
    }

    /// Polls the read stream exactly once without waiting.
    ///
    /// The poll uses a no-op waker, so a `Poll::Pending` result does not
    /// schedule any wake-up: the caller has to poll again on its own.
    /// `Poll::Ready(None)` means the stream has ended.
    pub fn try_read(&mut self) -> Poll<Option<Result<RespResponse>>> {
        let waker = noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let poll_result = match &mut self.streams {
            Streams::Tcp(framed_read, _) => Pin::new(framed_read).poll_next(&mut cx),
            #[cfg(any(feature = "native-tls", feature = "rustls"))]
            Streams::TcpTls(framed_read, _) => Pin::new(framed_read).poll_next(&mut cx),
        };

        match poll_result {
            Poll::Ready(Some(result)) => {
                // Skip the formatting work entirely when debug logging is off.
                if log_enabled!(Level::Debug) {
                    match &result {
                        Ok(response) => {
                            debug!("[{}] (try_read) Received result {response:?}", self.tag)
                        }
                        Err(err) => debug!("[{}] (try_read) Received result {err:?}", self.tag),
                    }
                }
                Poll::Ready(Some(result))
            }
            Poll::Ready(None) => {
                debug!("[{}] Socket is closed", self.tag);
                Poll::Ready(None)
            }
            Poll::Pending => Poll::Pending,
        }
    }

    /// Replaces the current streams with a fresh connection to the stored
    /// host and port, then repeats the post-connect handshake.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting or the handshake fails. The old streams
    /// are only replaced once the new connection has been established.
    pub async fn reconnect(&mut self) -> Result<()> {
        self.streams = Streams::connect(&self.host, self.port, &self.config).await?;
        self.post_connect().await
    }

    /// Handshake run after every (re)connection.
    ///
    /// Issues `hello` with `HelloOptions::new(3)`, adding:
    /// * credentials when a password is configured (the username falls back
    ///   to `"default"` when none is configured);
    /// * the connection name when it is not empty.
    ///
    /// The `version` of the reply is stored, and `select` is issued when the
    /// configured database is not `0`.
    async fn post_connect(&mut self) -> Result<()> {
        let mut hello_options = HelloOptions::new(3);

        // Work on clones so the options do not keep `self.config` borrowed
        // while `self` is mutably borrowed by the `hello` call below.
        let config_username = self.config.username.clone();
        let config_password = self.config.password.clone();
        let config_connection_name = self.config.connection_name.clone();

        if let Some(password) = &config_password {
            hello_options = hello_options.auth(
                match &config_username {
                    Some(username) => username,
                    None => "default",
                },
                password,
            );
        }

        if !config_connection_name.is_empty() {
            hello_options = hello_options.set_name(&config_connection_name);
        }

        let hello_result = self.hello(hello_options).await?;
        self.version = hello_result.version;

        if self.config.database != 0 {
            self.select(self.config.database).await?;
        }

        Ok(())
    }

    /// Version string taken from the last `hello` reply.
    pub fn get_version(&self) -> &str {
        &self.version
    }

    /// Shared handle to this connection's log tag.
    pub(crate) fn tag(&self) -> Arc<str> {
        Arc::clone(&self.tag)
    }
}
impl<'a, R> IntoFuture for PreparedCommand<'a, &'a mut StandaloneConnection, R>
where
R: DeserializeOwned + Send + 'a,
{
type Output = Result<R>;
type IntoFuture = Future<'a, R>;
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
self.executor.write(&self.command).await?;
let response = self
.executor
.read()
.await
.ok_or_else(|| Error::DisconnectedByPeer)??;
response.to()
})
}
}
// Command traits implemented for a mutable borrow of the connection.
// The impl bodies are empty, so every method comes from the traits' provided
// (default) implementations.
// NOTE(review): `hello`, `select`, `client_id` and `client_kill`, which
// `StandaloneConnection` calls on itself, presumably come from these traits —
// confirm in `crate::commands`.
impl<'a> ClusterCommands<'a> for &'a mut StandaloneConnection {}
impl<'a> ConnectionCommands<'a> for &'a mut StandaloneConnection {}
impl<'a> SentinelCommands<'a> for &'a mut StandaloneConnection {}
impl<'a> ServerCommands<'a> for &'a mut StandaloneConnection {}