1use crate::{
2 error::Error,
3 modules::inner::ClientInner,
4 protocol::{command::Command, connection, connection::Connection},
5 runtime::RefCount,
6 types::config::Server,
7};
8use futures::future::join_all;
9use std::{
10 collections::{HashMap, VecDeque},
11 fmt,
12 fmt::Formatter,
13};
14
15#[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))]
16use crate::types::config::TlsHostMapping;
17
18#[cfg_attr(docsrs, doc(cfg(feature = "replicas")))]
20#[async_trait]
21pub trait ReplicaFilter: Send + Sync + 'static {
22 #[allow(unused_variables)]
24 async fn filter(&self, primary: &Server, replica: &Server) -> bool {
25 true
26 }
27}
28
29#[cfg_attr(docsrs, doc(cfg(feature = "replicas")))]
36#[derive(Clone)]
37pub struct ReplicaConfig {
38 pub lazy_connections: bool,
42 pub filter: Option<RefCount<dyn ReplicaFilter>>,
46 pub ignore_reconnection_errors: bool,
52 pub primary_fallback: bool,
56}
57
58impl fmt::Debug for ReplicaConfig {
59 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
60 f.debug_struct("ReplicaConfig")
61 .field("lazy_connections", &self.lazy_connections)
62 .field("ignore_reconnection_errors", &self.ignore_reconnection_errors)
63 .field("primary_fallback", &self.primary_fallback)
64 .finish()
65 }
66}
67
68impl PartialEq for ReplicaConfig {
69 fn eq(&self, other: &Self) -> bool {
70 self.lazy_connections == other.lazy_connections
71 && self.ignore_reconnection_errors == other.ignore_reconnection_errors
72 && self.primary_fallback == other.primary_fallback
73 }
74}
75
76impl Eq for ReplicaConfig {}
77
78impl Default for ReplicaConfig {
79 fn default() -> Self {
80 ReplicaConfig {
81 lazy_connections: true,
82 filter: None,
83 ignore_reconnection_errors: true,
84 primary_fallback: true,
85 }
86 }
87}
88
89#[derive(Clone, Debug, PartialEq, Eq, Default)]
92#[cfg_attr(docsrs, doc(cfg(feature = "replicas")))]
93pub struct ReplicaRouter {
94 counter: usize,
95 servers: Vec<Server>,
96}
97
98impl ReplicaRouter {
99 pub fn next(&mut self) -> Option<&Server> {
101 self.counter = (self.counter + 1) % self.servers.len();
102 self.servers.get(self.counter)
103 }
104
105 pub fn add(&mut self, server: Server) {
107 if !self.servers.contains(&server) {
108 self.servers.push(server);
109 }
110 }
111
112 pub fn remove(&mut self, server: &Server) {
114 self.servers = self.servers.drain(..).filter(|_server| server != _server).collect();
115 }
116
117 pub fn len(&self) -> usize {
119 self.servers.len()
120 }
121
122 pub fn iter(&self) -> impl Iterator<Item = &Server> {
124 self.servers.iter()
125 }
126}
127
128#[cfg_attr(docsrs, doc(cfg(feature = "replicas")))]
130#[derive(Clone, Debug, Eq, PartialEq, Default)]
131pub struct ReplicaSet {
132 servers: HashMap<Server, ReplicaRouter>,
134}
135
136impl ReplicaSet {
137 pub fn new() -> ReplicaSet {
139 ReplicaSet {
140 servers: HashMap::new(),
141 }
142 }
143
144 pub fn add(&mut self, primary: Server, replica: Server) {
146 self.servers.entry(primary).or_default().add(replica);
147 }
148
149 pub fn remove(&mut self, primary: &Server, replica: &Server) {
151 let should_remove = if let Some(router) = self.servers.get_mut(primary) {
152 router.remove(replica);
153 router.len() == 0
154 } else {
155 false
156 };
157
158 if should_remove {
159 self.servers.remove(primary);
160 }
161 }
162
163 pub fn remove_replica(&mut self, replica: &Server) {
165 self.servers = self
166 .servers
167 .drain()
168 .filter_map(|(primary, mut routing)| {
169 routing.remove(replica);
170
171 if routing.len() > 0 {
172 Some((primary, routing))
173 } else {
174 None
175 }
176 })
177 .collect();
178 }
179
180 pub fn next_replica(&mut self, primary: &Server) -> Option<&Server> {
182 self.servers.get_mut(primary).and_then(|router| router.next())
183 }
184
185 pub fn replicas(&self, primary: &Server) -> impl Iterator<Item = &Server> {
187 self
188 .servers
189 .get(primary)
190 .map(|router| router.iter())
191 .into_iter()
192 .flatten()
193 }
194
195 pub fn to_map(&self) -> HashMap<Server, Server> {
197 let mut out = HashMap::with_capacity(self.servers.len());
198 for (primary, replicas) in self.servers.iter() {
199 for replica in replicas.iter() {
200 out.insert(replica.clone(), primary.clone());
201 }
202 }
203
204 out
205 }
206
207 pub fn clear(&mut self) {
209 self.servers.clear();
210 }
211}
212
/// Replica connection state: open connections, the primary -> replica routing table, and the
/// commands recovered from connections that were closed.
#[cfg(feature = "replicas")]
pub struct Replicas {
  /// Pipelined connections, keyed by replica server.
  pub connections: HashMap<Server, Connection>,
  /// Routing table mapping each primary to its replicas.
  pub routing: ReplicaSet,
  /// Commands returned by `close()` on replica connections that were torn down. They are handed
  /// out by `take_retry_buffer`.
  pub buffer: VecDeque<Command>,
}

#[cfg(feature = "replicas")]
// NOTE(review): presumably not every helper is called in every build configuration; confirm
// before removing this allow.
#[allow(dead_code)]
impl Replicas {
  /// Create an empty instance with no connections, routes, or buffered commands.
  pub fn new() -> Replicas {
    Replicas {
      connections: HashMap::new(),
      routing: ReplicaSet::new(),
      buffer: VecDeque::new(),
    }
  }

  /// Close every replica connection, buffering the commands each `close()` returns, then re-add
  /// every replica in the routing table.
  ///
  /// Replicas are re-added with `force = false`, so when lazy connections are enabled only the
  /// routing entries are restored.
  ///
  /// # Errors
  /// Returns the first error from `add_connection`. Replicas after it are not re-added.
  pub async fn sync_connections(&mut self, inner: &RefCount<ClientInner>) -> Result<(), Error> {
    for (_, mut writer) in self.connections.drain() {
      let commands = writer.close().await;
      self.buffer.extend(commands);
    }

    for (replica, primary) in self.routing.to_map() {
      self.add_connection(inner, primary, replica, false).await?;
    }

    Ok(())
  }

  /// Clear the routing table, then close all connections via `sync_connections`. With an empty
  /// table, nothing is reconnected.
  pub async fn clear_connections(&mut self, inner: &RefCount<ClientInner>) -> Result<(), Error> {
    self.routing.clear();
    self.sync_connections(inner).await
  }

  /// Clear the routing table without touching any connections.
  pub fn clear_routing(&mut self) {
    self.routing.clear();
  }

  /// Register `replica` as a replica of `primary`, opening a connection when required.
  ///
  /// A connection is created only when `lazy_connections` is disabled in the client's replica
  /// config or `force` is `true`. A new connection is set up, switched to `readonly` when the
  /// client is clustered, and its connection ID (if any) is recorded in the backchannel. If a
  /// connection for `replica` already existed, it is closed and its commands are buffered instead
  /// of being silently dropped.
  ///
  /// # Errors
  /// Returns any error from connecting, setup, or `readonly`. In that case the routing table is not
  /// updated.
  pub async fn add_connection(
    &mut self,
    inner: &RefCount<ClientInner>,
    primary: Server,
    replica: Server,
    force: bool,
  ) -> Result<(), Error> {
    _debug!(
      inner,
      "Adding replica connection {} (replica) -> {} (primary)",
      replica,
      primary
    );

    if !inner.connection.replica.lazy_connections || force {
      let mut transport = connection::create(inner, &replica, None).await?;
      transport.setup(inner, None).await?;

      if inner.config.server.is_clustered() {
        transport.readonly(inner, None).await?;
      };

      if let Some(id) = transport.id {
        inner
          .backchannel
          .connection_ids
          .lock()
          .insert(transport.server.clone(), id);
      }

      let writer = transport.into_pipelined(true);
      // `insert` hands back any writer already registered for this replica. Dropping it without
      // `close()` would discard the commands `close()` returns, so buffer them the same way
      // `drop_writer` does.
      if let Some(mut previous) = self.connections.insert(replica.clone(), writer) {
        self.buffer.extend(previous.close().await);
      }
    }

    self.routing.add(primary, replica);
    Ok(())
  }

  /// Close and remove the connection for `replica`, if any. Its returned commands are buffered
  /// and its connection ID is removed from the backchannel. Routing is left untouched.
  pub async fn drop_writer(&mut self, inner: &RefCount<ClientInner>, replica: &Server) {
    if let Some(mut writer) = self.connections.remove(replica) {
      self.buffer.extend(writer.close().await);
      inner.backchannel.connection_ids.lock().remove(replica);
    }
  }

  /// Remove `replica` from every primary's routing entry. The connection itself is not closed.
  pub fn remove_replica(&mut self, replica: &Server) {
    self.routing.remove_replica(replica);
  }

  /// Close the connection for `replica` (see `drop_writer`). Unless `keep_routable` is set, also
  /// remove `replica` from `primary`'s routing entry. Always returns `Ok(())`.
  pub async fn remove_connection(
    &mut self,
    inner: &RefCount<ClientInner>,
    primary: &Server,
    replica: &Server,
    keep_routable: bool,
  ) -> Result<(), Error> {
    _debug!(
      inner,
      "Removing replica connection {} (replica) -> {} (primary)",
      replica,
      primary
    );
    self.drop_writer(inner, replica).await;

    if !keep_routable {
      self.routing.remove(primary, replica);
    }
    Ok(())
  }

  /// Flush every replica connection one at a time, stopping at the first error.
  pub async fn flush(&mut self) -> Result<(), Error> {
    for writer in self.connections.values_mut() {
      writer.flush().await?;
    }

    Ok(())
  }

  /// Whether `primary` has at least one routed replica with an open connection whose reader has
  /// not reported an error.
  pub async fn has_replica_connection(&mut self, primary: &Server) -> bool {
    for replica in self.routing.replicas(primary) {
      if let Some(conn) = self.connections.get_mut(replica) {
        if conn.peek_reader_errors().await.is_none() {
          return true;
        }
      }
    }

    false
  }

  /// The current `replica -> primary` routing map.
  pub fn routing_table(&self) -> HashMap<Server, Server> {
    self.routing.to_map()
  }

  /// Close every connection whose reader has reported an error, buffer its commands, and remove
  /// that replica from routing entirely.
  ///
  /// NOTE(review): unlike `drop_writer`, this does not remove the connection ID from
  /// `inner.backchannel` (no `inner` is available here). Confirm whether a stale ID matters.
  pub async fn drop_broken_connections(&mut self) {
    let mut new_writers = HashMap::with_capacity(self.connections.len());
    for (server, mut writer) in self.connections.drain() {
      if writer.peek_reader_errors().await.is_some() {
        self.buffer.extend(writer.close().await);
        self.routing.remove_replica(&server);
      } else {
        new_writers.insert(server, writer);
      }
    }

    self.connections = new_writers;
  }

  /// The servers whose connections have not reported reader errors. All connections are checked
  /// concurrently.
  pub async fn active_connections(&mut self) -> Vec<Server> {
    join_all(self.connections.iter_mut().map(|(server, conn)| async move {
      if conn.peek_reader_errors().await.is_some() {
        None
      } else {
        Some(server.clone())
      }
    }))
    .await
    .into_iter()
    .flatten()
    .collect()
  }

  /// Take all buffered commands, leaving the buffer empty.
  pub fn take_retry_buffer(&mut self) -> VecDeque<Command> {
    std::mem::take(&mut self.buffer)
  }

  /// Drain every replica connection concurrently.
  ///
  /// # Errors
  /// Returns the first error, in connection iteration order, after all drains have completed.
  pub async fn drain(&mut self, inner: &RefCount<ClientInner>) -> Result<(), Error> {
    join_all(self.connections.values_mut().map(|conn| conn.drain(inner)))
      .await
      .into_iter()
      .collect::<Result<(), Error>>()
  }
}
405
/// Update `replica`'s TLS server name according to the client's TLS hostname mapping policy,
/// passing along `primary`'s host.
///
/// When TLS is not configured, or the policy is [`TlsHostMapping::None`], only a trace message is
/// logged and `replica` is left unchanged.
#[cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))]
pub fn map_replica_tls_names(inner: &RefCount<ClientInner>, primary: &Server, replica: &mut Server) {
  match inner.config.tls {
    None => {
      _trace!(inner, "Skip modifying TLS hostname for replicas.");
    },
    Some(ref tls) if tls.hostnames == TlsHostMapping::None => {
      _trace!(inner, "Skip modifying TLS hostnames for replicas.");
    },
    Some(ref tls) => {
      replica.set_tls_server_name(&tls.hostnames, &primary.host);
    },
  }
}
422
423#[cfg(not(any(feature = "enable-native-tls", feature = "enable-rustls")))]
424pub fn map_replica_tls_names(_: &RefCount<ClientInner>, _: &Server, _: &mut Server) {}