use crate::{
error::Error,
modules::inner::ClientInner,
protocol::{command::Command, connection, connection::Connection},
runtime::RefCount,
types::config::Server,
};
use futures::future::join_all;
use std::{
collections::{HashMap, VecDeque},
fmt,
fmt::Formatter,
};
#[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))]
use crate::types::config::TlsHostMapping;
/// A pluggable predicate over (primary, replica) pairs.
///
/// Instances are stored in `ReplicaConfig::filter`. NOTE(review): the call site is not in this
/// file; presumably a `false` result excludes the replica from use — confirm there.
#[cfg_attr(docsrs, doc(cfg(feature = "replicas")))]
#[async_trait]
pub trait ReplicaFilter: Send + Sync + 'static {
  /// Decide whether `replica`, a replica of `primary`, passes the filter.
  ///
  /// The default implementation returns `true` for every pair.
  async fn filter(&self, primary: &Server, replica: &Server) -> bool {
    // The accept-all default does not need to inspect either server.
    let _ = (primary, replica);
    true
  }
}
/// Options that control how replica nodes are used.
#[cfg_attr(docsrs, doc(cfg(feature = "replicas")))]
#[derive(Clone)]
pub struct ReplicaConfig {
  /// Whether to defer opening replica connections until they are forced. Defaults to `true`.
  pub lazy_connections: bool,
  /// Optional user-supplied [`ReplicaFilter`]. Defaults to `None`.
  pub filter: Option<RefCount<dyn ReplicaFilter>>,
  /// Defaults to `true`. NOTE(review): presumably suppresses errors raised while reconnecting
  /// to replicas — confirm against the reconnection logic.
  pub ignore_reconnection_errors: bool,
  /// Defaults to `true`. NOTE(review): presumably allows routing to the primary when no replica
  /// is usable — confirm against the router.
  pub primary_fallback: bool,
}

impl fmt::Debug for ReplicaConfig {
  // `filter` is a trait object without a `Debug` bound, so it is omitted from the output.
  fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
    let mut builder = f.debug_struct("ReplicaConfig");
    builder.field("lazy_connections", &self.lazy_connections);
    builder.field("ignore_reconnection_errors", &self.ignore_reconnection_errors);
    builder.field("primary_fallback", &self.primary_fallback);
    builder.finish()
  }
}

impl PartialEq for ReplicaConfig {
  // Equality deliberately ignores `filter`: trait objects cannot be compared by value.
  fn eq(&self, other: &Self) -> bool {
    let lhs = (self.lazy_connections, self.ignore_reconnection_errors, self.primary_fallback);
    let rhs = (other.lazy_connections, other.ignore_reconnection_errors, other.primary_fallback);
    lhs == rhs
  }
}

impl Eq for ReplicaConfig {}

impl Default for ReplicaConfig {
  /// All boolean options enabled and no filter.
  fn default() -> Self {
    Self {
      filter: None,
      lazy_connections: true,
      ignore_reconnection_errors: true,
      primary_fallback: true,
    }
  }
}
/// A round-robin cursor over the replicas of a single primary.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "replicas")))]
pub struct ReplicaRouter {
  /// Index of the replica most recently returned by [`next`](Self::next).
  counter: usize,
  /// Known replicas, deduplicated, in insertion order.
  servers: Vec<Server>,
}

impl ReplicaRouter {
  /// Advance the cursor and return the replica it now points at.
  ///
  /// The cursor is advanced before reading, so on a fresh router with two or more replicas the
  /// first call returns the second replica. Returns `None` when the router has no replicas.
  pub fn next(&mut self) -> Option<&Server> {
    // An empty router (reachable via `Default`) would otherwise panic on `% 0`.
    if self.servers.is_empty() {
      return None;
    }
    self.counter = (self.counter + 1) % self.servers.len();
    self.servers.get(self.counter)
  }

  /// Add `server` unless it is already present.
  pub fn add(&mut self, server: Server) {
    if !self.servers.contains(&server) {
      self.servers.push(server);
    }
  }

  /// Remove every entry equal to `server`.
  ///
  /// The cursor is not reset; the modulo in [`next`](Self::next) brings it back in range.
  pub fn remove(&mut self, server: &Server) {
    self.servers.retain(|known| known != server);
  }

  /// Number of replicas in the router.
  pub fn len(&self) -> usize {
    self.servers.len()
  }

  /// Whether the router has no replicas.
  pub fn is_empty(&self) -> bool {
    self.servers.is_empty()
  }

  /// Iterate over the replicas in insertion order.
  pub fn iter(&self) -> impl Iterator<Item = &Server> {
    self.servers.iter()
  }
}

/// Mapping from each primary to a [`ReplicaRouter`] over its replicas.
#[cfg_attr(docsrs, doc(cfg(feature = "replicas")))]
#[derive(Clone, Debug, Eq, PartialEq, Default)]
pub struct ReplicaSet {
  /// Primary -> replicas. Every router stored here is kept non-empty by the methods below.
  servers: HashMap<Server, ReplicaRouter>,
}

impl ReplicaSet {
  /// Create an empty replica set.
  pub fn new() -> ReplicaSet {
    ReplicaSet {
      servers: HashMap::new(),
    }
  }

  /// Record `replica` as a replica of `primary`. Duplicate pairs are ignored.
  pub fn add(&mut self, primary: Server, replica: Server) {
    self.servers.entry(primary).or_default().add(replica);
  }

  /// Remove the `primary` -> `replica` pair, dropping `primary` entirely if it has no replicas
  /// left.
  pub fn remove(&mut self, primary: &Server, replica: &Server) {
    if let Some(router) = self.servers.get_mut(primary) {
      router.remove(replica);
      if router.is_empty() {
        self.servers.remove(primary);
      }
    }
  }

  /// Remove `replica` from every primary, dropping primaries that end up with no replicas.
  pub fn remove_replica(&mut self, replica: &Server) {
    // Filter in place rather than draining and rebuilding the whole map.
    self.servers.retain(|_, router| {
      router.remove(replica);
      !router.is_empty()
    });
  }

  /// Round-robin to the next replica of `primary`, or `None` if it has none.
  pub fn next_replica(&mut self, primary: &Server) -> Option<&Server> {
    self.servers.get_mut(primary).and_then(|router| router.next())
  }

  /// Iterate over the replicas of `primary`. The iterator is empty for unknown primaries.
  pub fn replicas(&self, primary: &Server) -> impl Iterator<Item = &Server> {
    self
      .servers
      .get(primary)
      .map(|router| router.iter())
      .into_iter()
      .flatten()
  }

  /// Build a replica -> primary map.
  ///
  /// If a replica is listed under several primaries, only one of them is kept. Which one depends
  /// on map iteration order.
  pub fn to_map(&self) -> HashMap<Server, Server> {
    // One output entry per replica, so size by replica count rather than primary count.
    let capacity = self.servers.values().map(ReplicaRouter::len).sum();
    let mut out = HashMap::with_capacity(capacity);
    for (primary, replicas) in self.servers.iter() {
      for replica in replicas.iter() {
        out.insert(replica.clone(), primary.clone());
      }
    }
    out
  }

  /// Remove all primaries and replicas.
  pub fn clear(&mut self) {
    self.servers.clear();
  }
}
/// Connection and routing state for replica nodes.
#[cfg(feature = "replicas")]
pub struct Replicas {
  /// Open replica connections, keyed by replica server.
  pub connections: HashMap<Server, Connection>,
  /// Primary -> replica routing table. A replica can be routable without an entry in
  /// `connections`, e.g. when connections are created lazily or after `remove_connection` with
  /// `keep_routable = true`.
  pub routing: ReplicaSet,
  /// Commands returned by closed connections, handed out by `take_retry_buffer`.
  pub buffer: VecDeque<Command>,
}

#[cfg(feature = "replicas")]
// NOTE(review): the blanket `dead_code` allow presumably covers helpers that go unused under some
// feature combinations — confirm before removing it.
#[allow(dead_code)]
impl Replicas {
  /// Create an empty set of replica connections with no routing information.
  pub fn new() -> Replicas {
    Replicas {
      connections: HashMap::new(),
      routing: ReplicaSet::new(),
      buffer: VecDeque::new(),
    }
  }

  /// Close every replica connection, then re-add each routing entry via
  /// [`add_connection`](Self::add_connection) with `force = false`.
  ///
  /// Commands returned by the closed connections are appended to `buffer`. When lazy replica
  /// connections are enabled, no new connections are opened here.
  ///
  /// # Errors
  /// Returns the first error from `add_connection`; remaining routing entries are not re-added.
  pub async fn sync_connections(&mut self, inner: &RefCount<ClientInner>) -> Result<(), Error> {
    // Close through `drop_writer` so each connection's ID is also removed from the backchannel,
    // instead of leaving stale IDs behind (notably when reconnection is lazy).
    let servers: Vec<Server> = self.connections.keys().cloned().collect();
    for server in servers.iter() {
      self.drop_writer(inner, server).await;
    }
    for (replica, primary) in self.routing.to_map() {
      self.add_connection(inner, primary, replica, false).await?;
    }
    Ok(())
  }

  /// Clear the routing table, then close every replica connection (buffering its commands).
  pub async fn clear_connections(&mut self, inner: &RefCount<ClientInner>) -> Result<(), Error> {
    self.routing.clear();
    self.sync_connections(inner).await
  }

  /// Clear the routing table without touching open connections.
  pub fn clear_routing(&mut self) {
    self.routing.clear();
  }

  /// Register `replica` under `primary`, opening a connection when lazy connections are disabled
  /// or `force` is `true`.
  ///
  /// A new connection is set up, has `readonly` called on it when the server config is
  /// clustered, and has its ID (if any) recorded in the backchannel. Any existing connection to
  /// `replica` is closed only after the new one is ready, and its commands are buffered.
  ///
  /// # Errors
  /// Returns any error from creating, setting up, or calling `readonly` on the connection. In that
  /// case the routing entry is not added and any existing connection is left in place.
  pub async fn add_connection(
    &mut self,
    inner: &RefCount<ClientInner>,
    primary: Server,
    replica: Server,
    force: bool,
  ) -> Result<(), Error> {
    _debug!(
      inner,
      "Adding replica connection {} (replica) -> {} (primary)",
      replica,
      primary
    );

    if !inner.connection.replica.lazy_connections || force {
      let mut transport = connection::create(inner, &replica, None).await?;
      transport.setup(inner, None).await?;
      if inner.config.server.is_clustered() {
        transport.readonly(inner, None).await?;
      }
      if let Some(id) = transport.id {
        inner
          .backchannel
          .connection_ids
          .lock()
          .insert(transport.server.clone(), id);
      }

      // `insert` hands back any connection it replaces. Close it and keep its commands for
      // retry instead of silently dropping them.
      if let Some(mut previous) = self.connections.insert(replica.clone(), transport.into_pipelined(true)) {
        self.buffer.extend(previous.close().await);
      }
    }

    self.routing.add(primary, replica);
    Ok(())
  }

  /// Close the connection to `replica`, if any: buffer its commands and remove its ID from the
  /// backchannel. Routing is not changed.
  pub async fn drop_writer(&mut self, inner: &RefCount<ClientInner>, replica: &Server) {
    if let Some(mut writer) = self.connections.remove(replica) {
      self.buffer.extend(writer.close().await);
      inner.backchannel.connection_ids.lock().remove(replica);
    }
  }

  /// Remove `replica` from the routing table for every primary. Open connections are kept.
  pub fn remove_replica(&mut self, replica: &Server) {
    self.routing.remove_replica(replica);
  }

  /// Close the connection to `replica` and, unless `keep_routable`, remove the
  /// `primary` -> `replica` routing entry. Currently always returns `Ok`.
  pub async fn remove_connection(
    &mut self,
    inner: &RefCount<ClientInner>,
    primary: &Server,
    replica: &Server,
    keep_routable: bool,
  ) -> Result<(), Error> {
    _debug!(
      inner,
      "Removing replica connection {} (replica) -> {} (primary)",
      replica,
      primary
    );
    self.drop_writer(inner, replica).await;

    if !keep_routable {
      self.routing.remove(primary, replica);
    }
    Ok(())
  }

  /// Flush every replica connection in turn, stopping at the first error.
  pub async fn flush(&mut self) -> Result<(), Error> {
    for writer in self.connections.values_mut() {
      writer.flush().await?;
    }
    Ok(())
  }

  /// Whether `primary` has at least one routed replica whose connection reports no reader errors.
  pub async fn has_replica_connection(&mut self, primary: &Server) -> bool {
    for replica in self.routing.replicas(primary) {
      if let Some(writer) = self.connections.get_mut(replica) {
        if writer.peek_reader_errors().await.is_none() {
          return true;
        }
      }
    }
    false
  }

  /// The routing table as a replica -> primary map.
  pub fn routing_table(&self) -> HashMap<Server, Server> {
    self.routing.to_map()
  }

  /// Close every connection whose reader reports an error, buffer its commands, and remove that
  /// replica from routing.
  ///
  /// NOTE(review): unlike `drop_writer`, this does not remove connection IDs from the backchannel
  /// (no `inner` is available here) — confirm that is acceptable.
  pub async fn drop_broken_connections(&mut self) {
    let mut healthy = HashMap::with_capacity(self.connections.len());
    for (server, mut writer) in self.connections.drain() {
      if writer.peek_reader_errors().await.is_some() {
        self.buffer.extend(writer.close().await);
        self.routing.remove_replica(&server);
      } else {
        healthy.insert(server, writer);
      }
    }
    self.connections = healthy;
  }

  /// Replicas whose connections currently report no reader errors.
  pub async fn active_connections(&mut self) -> Vec<Server> {
    join_all(self.connections.iter_mut().map(|(server, conn)| async move {
      if conn.peek_reader_errors().await.is_some() {
        None
      } else {
        Some(server.clone())
      }
    }))
    .await
    .into_iter()
    .flatten()
    .collect()
  }

  /// Take every buffered command, leaving the buffer empty.
  pub fn take_retry_buffer(&mut self) -> VecDeque<Command> {
    // Move the deque out wholesale instead of copying element-by-element.
    std::mem::take(&mut self.buffer)
  }

  /// Drain every connection concurrently.
  ///
  /// All drains run to completion. If any fail, the first error in map iteration order is
  /// returned.
  pub async fn drain(&mut self, inner: &RefCount<ClientInner>) -> Result<(), Error> {
    let results = join_all(self.connections.values_mut().map(|conn| conn.drain(inner))).await;
    // `Result<(), E>` collects directly, without building a throwaway `Vec<()>`.
    results.into_iter().collect()
  }
}
/// Rewrite `replica`'s TLS server name from `primary`'s host, using the client's TLS hostname
/// policy.
///
/// Does nothing (apart from a trace log) when TLS is not configured or the policy is
/// `TlsHostMapping::None`.
#[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))]
pub fn map_replica_tls_names(inner: &RefCount<ClientInner>, primary: &Server, replica: &mut Server) {
  if let Some(tls_config) = inner.config.tls.as_ref() {
    if tls_config.hostnames == TlsHostMapping::None {
      _trace!(inner, "Skip modifying TLS hostnames for replicas.");
    } else {
      replica.set_tls_server_name(&tls_config.hostnames, &primary.host);
    }
  } else {
    _trace!(inner, "Skip modifying TLS hostname for replicas.");
  }
}
/// No-op variant used when no TLS feature is compiled in.
#[cfg(not(any(feature = "enable-native-tls", feature = "enable-rustls")))]
pub fn map_replica_tls_names(_inner: &RefCount<ClientInner>, _primary: &Server, _replica: &mut Server) {
  // Without TLS support there is no server name to rewrite.
}