use crate::{
error::{Error, ErrorKind},
modules::inner::ClientInner,
protocol::{command::Command, connection, connection::ExclusiveConnection, types::Server},
router::connections::Connections,
runtime::{AsyncRwLock, RefCount},
utils,
};
use parking_lot::Mutex;
use redis_protocol::resp3::types::BytesFrame as Resp3Frame;
use std::{
collections::HashMap,
ops::{Deref, DerefMut},
};
/// Ensures `backchannel` holds a transport connected to `server`.
///
/// The transport write lock is held for the whole call. The existing transport is reused only if
/// it targets `server` and answers a ping. Otherwise it is discarded, and a new connection is built
/// with `connection::create` and initialized with `setup` before being stored.
///
/// Returns `Ok(false)` when the existing transport was reused and `Ok(true)` when a new one was
/// created. Errors from `create`/`setup` are propagated; in that case the transport is left as
/// `None`, because the old one was already cleared.
async fn check_and_create_transport(
  backchannel: &Backchannel,
  inner: &RefCount<ClientInner>,
  server: &Server,
) -> Result<bool, Error> {
  let mut guard = backchannel.transport.write().await;

  // The ping is only attempted when the server matches (short-circuit `&&`).
  let reusable = match guard.deref_mut() {
    Some(existing) => &existing.server == server && existing.ping(inner).await.is_ok(),
    None => false,
  };
  if reusable {
    _debug!(inner, "Using existing backchannel connection to {}", server);
    return Ok(false);
  }

  *guard = None;
  let mut fresh = connection::create(inner, server, None).await?;
  fresh.setup(inner, None).await?;
  *guard = Some(fresh);

  Ok(true)
}
/// Backchannel state: a single dedicated transport (used by `request_response`) plus
/// bookkeeping about the blocked server and per-server connection IDs.
pub struct Backchannel {
/// The backchannel connection, if one is open. Its guard is held across `.await` points
/// (e.g. while pinging or sending a command), hence the async lock.
pub transport: AsyncRwLock<Option<ExclusiveConnection>>,
/// The server marked as blocked via `set_blocked`, if any.
pub blocked: Mutex<Option<Server>>,
/// Connection IDs keyed by server, replaced wholesale from the router's `Connections` in
/// `update_connection_ids` (presumably server-assigned client IDs — confirm).
pub connection_ids: Mutex<HashMap<Server, i64>>,
}
impl Default for Backchannel {
  /// Creates an empty backchannel: no open transport, no blocked server, no cached connection IDs.
  fn default() -> Self {
    let transport = AsyncRwLock::new(None);
    let blocked = Mutex::new(None);
    let connection_ids = Mutex::new(HashMap::new());

    Self {
      transport,
      blocked,
      connection_ids,
    }
  }
}
impl Backchannel {
pub async fn check_and_disconnect(&self, inner: &RefCount<ClientInner>, server: Option<&Server>) {
let should_close = self
.current_server()
.await
.map(|current| server.map(|server| *server == current).unwrap_or(true))
.unwrap_or(false);
if should_close {
if let Some(ref mut transport) = self.transport.write().await.take() {
let _ = transport.disconnect(inner).await;
}
}
}
pub fn check_and_unblock(&self, server: &Server) {
let mut guard = self.blocked.lock();
let matches = if let Some(blocked) = guard.as_ref() {
blocked == server
} else {
false
};
if matches {
*guard = None;
}
}
pub async fn clear_router_state(&self, inner: &RefCount<ClientInner>) {
self.connection_ids.lock().clear();
self.blocked.lock().take();
if let Some(ref mut transport) = self.transport.write().await.take() {
let _ = transport.disconnect(inner).await;
}
}
pub fn update_connection_ids(&self, connections: &Connections) {
let mut guard = self.connection_ids.lock();
*guard.deref_mut() = connections.connection_ids();
}
pub fn remove_connection_id(&self, server: &Server) {
self.connection_ids.lock().get(server);
}
pub fn connection_id(&self, server: &Server) -> Option<i64> {
self.connection_ids.lock().get(server).cloned()
}
pub fn set_blocked(&self, server: &Server) {
self.blocked.lock().replace(server.clone());
}
pub fn set_unblocked(&self) {
self.blocked.lock().take();
}
pub fn check_and_set_unblocked(&self, server: &Server) {
let mut guard = self.blocked.lock();
if guard.as_ref().map(|b| b == server).unwrap_or(false) {
guard.take();
}
}
pub fn is_blocked(&self) -> bool {
self.blocked.lock().is_some()
}
pub async fn has_blocked_transport(&self) -> bool {
if let Some(server) = self.blocked_server() {
match self.transport.read().await.deref() {
Some(ref transport) => transport.server == server,
None => false,
}
} else {
false
}
}
pub fn blocked_server(&self) -> Option<Server> {
self.blocked.lock().clone()
}
pub async fn current_server(&self) -> Option<Server> {
self.transport.read().await.as_ref().map(|t| t.server.clone())
}
pub async fn any_server(&self) -> Option<Server> {
self
.current_server()
.await
.or(self.blocked_server())
.or_else(|| self.connection_ids.lock().keys().next().cloned())
}
pub async fn current_server_is_blocked(&self) -> bool {
self
.current_server()
.await
.and_then(|server| self.blocked_server().map(|blocked| server == blocked))
.unwrap_or(false)
}
pub async fn request_response(
&self,
inner: &RefCount<ClientInner>,
server: &Server,
command: Command,
) -> Result<Resp3Frame, Error> {
let _ = check_and_create_transport(self, inner, server).await?;
if let Some(ref mut transport) = self.transport.write().await.deref_mut() {
_debug!(
inner,
"Sending {} ({}) on backchannel to {}",
command.kind.to_str_debug(),
command.debug_id(),
server
);
utils::timeout(
transport.request_response(command, inner.is_resp3()),
inner.connection_timeout(),
)
.await
} else {
Err(Error::new(
ErrorKind::Unknown,
"Failed to create backchannel connection.",
))
}
}
pub async fn find_server(
&self,
inner: &RefCount<ClientInner>,
command: &Command,
use_blocked: bool,
) -> Result<Server, Error> {
if use_blocked {
if let Some(server) = self.blocked.lock().deref() {
Ok(server.clone())
} else {
Err(Error::new(ErrorKind::Unknown, "No connections are blocked."))
}
} else if inner.config.server.is_clustered() {
if command.kind.use_random_cluster_node() {
self
.any_server()
.await
.ok_or_else(|| Error::new(ErrorKind::Unknown, "Failed to find backchannel server."))
} else {
inner.with_cluster_state(|state| {
let slot = match command.cluster_hash() {
Some(slot) => slot,
None => return Err(Error::new(ErrorKind::Cluster, "Failed to find cluster hash slot.")),
};
state
.get_server(slot)
.cloned()
.ok_or_else(|| Error::new(ErrorKind::Cluster, "Failed to find cluster owner."))
})
}
} else {
self
.any_server()
.await
.ok_or_else(|| Error::new(ErrorKind::Unknown, "Failed to find backchannel server."))
}
}
}