1use crate::raft::NodeId;
7use std::collections::{HashMap, HashSet};
8use std::sync::Arc;
9
/// A geographic region (for example `"us-east"`) that groups datacenters.
///
/// Regions are ordered lexicographically by name.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Region(String);

impl Region {
    /// Creates a region from anything convertible into an owned `String`.
    pub fn new(name: impl Into<String>) -> Self {
        let owned: String = name.into();
        Region(owned)
    }

    /// Returns the region name.
    pub fn name(&self) -> &str {
        self.0.as_str()
    }
}

impl From<&str> for Region {
    fn from(s: &str) -> Self {
        Region::new(s)
    }
}

impl std::fmt::Display for Region {
    /// Writes the bare region name (width/fill flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
37
/// Opaque identifier of a single datacenter (for example `"us-east-1"`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DatacenterId(String);

impl DatacenterId {
    /// Wraps any string-like value as a datacenter id.
    pub fn new(id: impl Into<String>) -> Self {
        let raw: String = id.into();
        DatacenterId(raw)
    }

    /// Borrows the id as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl From<&str> for DatacenterId {
    fn from(s: &str) -> Self {
        DatacenterId::new(s)
    }
}

impl std::fmt::Display for DatacenterId {
    /// Writes the bare id (width/fill flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
65
66#[derive(Debug, Clone)]
68pub struct Datacenter {
69 pub id: DatacenterId,
71 pub region: Region,
73 pub nodes: HashSet<NodeId>,
75 pub priority: i32,
77}
78
79impl Datacenter {
80 pub fn new(id: DatacenterId, region: Region) -> Self {
82 Self {
83 id,
84 region,
85 nodes: HashSet::new(),
86 priority: 0,
87 }
88 }
89
90 pub fn add_node(&mut self, node_id: NodeId) {
92 self.nodes.insert(node_id);
93 }
94
95 pub fn remove_node(&mut self, node_id: &NodeId) -> bool {
97 self.nodes.remove(node_id)
98 }
99
100 pub fn has_node(&self, node_id: &NodeId) -> bool {
102 self.nodes.contains(node_id)
103 }
104}
105
106pub struct MultiDatacenterCoordinator {
108 datacenters: HashMap<DatacenterId, Datacenter>,
110 node_to_dc: HashMap<NodeId, DatacenterId>,
112 latencies: HashMap<(DatacenterId, DatacenterId), u64>,
114}
115
116impl MultiDatacenterCoordinator {
117 pub fn new() -> Self {
119 Self {
120 datacenters: HashMap::new(),
121 node_to_dc: HashMap::new(),
122 latencies: HashMap::new(),
123 }
124 }
125
126 pub fn add_datacenter(&mut self, dc: Datacenter) {
128 self.datacenters.insert(dc.id.clone(), dc);
129 }
130
131 pub fn register_node(&mut self, node_id: NodeId, dc_id: DatacenterId) -> Result<(), String> {
133 let dc = self
134 .datacenters
135 .get_mut(&dc_id)
136 .ok_or_else(|| format!("Datacenter {dc_id} not found"))?;
137
138 dc.add_node(node_id);
139 self.node_to_dc.insert(node_id, dc_id);
140 Ok(())
141 }
142
143 pub fn unregister_node(&mut self, node_id: &NodeId) {
145 if let Some(dc_id) = self.node_to_dc.remove(node_id) {
146 if let Some(dc) = self.datacenters.get_mut(&dc_id) {
147 dc.remove_node(node_id);
148 }
149 }
150 }
151
152 pub fn get_node_datacenter(&self, node_id: &NodeId) -> Option<&Datacenter> {
154 self.node_to_dc
155 .get(node_id)
156 .and_then(|dc_id| self.datacenters.get(dc_id))
157 }
158
159 pub fn record_latency(&mut self, from: DatacenterId, to: DatacenterId, latency_ms: u64) {
161 self.latencies
162 .insert((from.clone(), to.clone()), latency_ms);
163 self.latencies.insert((to, from), latency_ms);
165 }
166
167 pub fn get_latency(&self, from: &DatacenterId, to: &DatacenterId) -> Option<u64> {
169 self.latencies.get(&(from.clone(), to.clone())).copied()
170 }
171
172 pub fn datacenters(&self) -> &HashMap<DatacenterId, Datacenter> {
174 &self.datacenters
175 }
176
177 pub fn total_nodes(&self) -> usize {
179 self.node_to_dc.len()
180 }
181}
182
183impl Default for MultiDatacenterCoordinator {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188
189#[derive(Debug, Clone)]
191pub enum ReplicationPolicy {
192 AllDatacenters,
194 Regions(Vec<Region>),
196 NClosest(usize),
198 Custom(Vec<DatacenterId>),
200}
201
202impl ReplicationPolicy {
203 pub fn select_datacenters(
205 &self,
206 coordinator: &MultiDatacenterCoordinator,
207 source_dc: &DatacenterId,
208 ) -> Vec<DatacenterId> {
209 match self {
210 ReplicationPolicy::AllDatacenters => coordinator.datacenters.keys().cloned().collect(),
211 ReplicationPolicy::Regions(regions) => coordinator
212 .datacenters
213 .values()
214 .filter(|dc| regions.contains(&dc.region))
215 .map(|dc| dc.id.clone())
216 .collect(),
217 ReplicationPolicy::NClosest(n) => {
218 let mut dcs: Vec<_> = coordinator
219 .datacenters
220 .keys()
221 .filter(|dc_id| *dc_id != source_dc)
222 .cloned()
223 .collect();
224
225 dcs.sort_by_key(|dc_id| {
227 coordinator
228 .get_latency(source_dc, dc_id)
229 .unwrap_or(u64::MAX)
230 });
231
232 dcs.into_iter().take(*n).collect()
233 }
234 ReplicationPolicy::Custom(dcs) => dcs.clone(),
235 }
236 }
237}
238
239pub struct LatencyAwareSelector {
241 coordinator: Arc<MultiDatacenterCoordinator>,
242 local_preference: bool,
244 max_latency_ms: Option<u64>,
246}
247
248impl LatencyAwareSelector {
249 pub fn new(coordinator: Arc<MultiDatacenterCoordinator>) -> Self {
251 Self {
252 coordinator,
253 local_preference: true,
254 max_latency_ms: None,
255 }
256 }
257
258 pub fn with_local_preference(mut self, enabled: bool) -> Self {
260 self.local_preference = enabled;
261 self
262 }
263
264 pub fn with_max_latency(mut self, latency_ms: u64) -> Self {
266 self.max_latency_ms = Some(latency_ms);
267 self
268 }
269
270 pub fn select_read_nodes(
272 &self,
273 available_nodes: &[NodeId],
274 local_node: &NodeId,
275 ) -> Vec<NodeId> {
276 let local_dc = self.coordinator.get_node_datacenter(local_node);
277
278 let mut candidates: Vec<_> = available_nodes
279 .iter()
280 .filter_map(|node_id| {
281 let node_dc = self.coordinator.get_node_datacenter(node_id)?;
282
283 let latency = if let Some(local) = local_dc {
285 self.coordinator
286 .get_latency(&local.id, &node_dc.id)
287 .unwrap_or(0)
288 } else {
289 0
290 };
291
292 if let Some(max_lat) = self.max_latency_ms {
294 if latency > max_lat {
295 return None;
296 }
297 }
298
299 Some((node_id, node_dc, latency))
300 })
301 .collect();
302
303 candidates.sort_by(|(_, dc1, lat1), (_, dc2, lat2)| {
305 if let (true, Some(local)) = (self.local_preference, local_dc) {
306 match (dc1.id == local.id, dc2.id == local.id) {
307 (true, false) => std::cmp::Ordering::Less,
308 (false, true) => std::cmp::Ordering::Greater,
309 _ => lat1.cmp(lat2),
310 }
311 } else {
312 lat1.cmp(lat2)
313 }
314 });
315
316 candidates
317 .into_iter()
318 .map(|(node_id, _, _)| *node_id)
319 .collect()
320 }
321}
322
/// Counters recording how many requests stayed local versus crossed a
/// datacenter boundary, plus the running mean latency of the cross-DC ones.
#[derive(Debug, Clone, Default)]
pub struct CrossDcStats {
    /// Number of requests recorded through `record_cross_dc`.
    pub cross_dc_requests: u64,
    /// Number of requests recorded through `record_local`.
    pub local_requests: u64,
    /// Mean of all latencies passed to `record_cross_dc`, in milliseconds;
    /// 0.0 until the first cross-DC request is recorded.
    pub avg_cross_dc_latency_ms: f64,
}

impl CrossDcStats {
    /// Creates zeroed statistics; identical to `CrossDcStats::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Records one cross-datacenter request and folds `latency_ms` into the mean.
    pub fn record_cross_dc(&mut self, latency_ms: u64) {
        // Rebuild the running sum from the current mean, then divide by the new count.
        let sum_so_far = self.avg_cross_dc_latency_ms * self.cross_dc_requests as f64;
        self.cross_dc_requests += 1;
        let new_count = self.cross_dc_requests as f64;
        self.avg_cross_dc_latency_ms = (sum_so_far + latency_ms as f64) / new_count;
    }

    /// Records one request served from the local datacenter.
    pub fn record_local(&mut self) {
        self.local_requests += 1;
    }

    /// Total number of requests recorded, local and cross-DC.
    pub fn total_requests(&self) -> u64 {
        self.local_requests + self.cross_dc_requests
    }

    /// Share of requests that crossed datacenters, as a percentage in `[0, 100]`;
    /// 0.0 when nothing has been recorded.
    pub fn cross_dc_percentage(&self) -> f64 {
        match self.total_requests() {
            0 => 0.0,
            total => (self.cross_dc_requests as f64 / total as f64) * 100.0,
        }
    }
}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a coordinator holding empty datacenters given as `(id, region)` pairs.
    fn coordinator_with(dcs: &[(&str, &str)]) -> MultiDatacenterCoordinator {
        let mut coord = MultiDatacenterCoordinator::new();
        for &(id, region) in dcs {
            coord.add_datacenter(Datacenter::new(DatacenterId::new(id), Region::new(region)));
        }
        coord
    }

    #[test]
    fn test_datacenter_creation() {
        let dc = Datacenter::new(DatacenterId::new("us-east-1"), Region::new("us-east"));

        assert_eq!(dc.id.as_str(), "us-east-1");
        assert_eq!(dc.region.name(), "us-east");
        assert!(dc.nodes.is_empty());
    }

    #[test]
    fn test_datacenter_nodes() {
        let mut dc = Datacenter::new(DatacenterId::new("us-west-2"), Region::new("us-west"));
        let (first, second) = (NodeId(1), NodeId(2));

        dc.add_node(first);
        dc.add_node(second);
        assert_eq!(dc.nodes.len(), 2);
        assert!(dc.has_node(&first));

        assert!(dc.remove_node(&first));
        assert_eq!(dc.nodes.len(), 1);
        assert!(!dc.has_node(&first));
    }

    #[test]
    fn test_multi_dc_coordinator() {
        let mut coord = coordinator_with(&[("us-east-1", "us-east"), ("us-west-2", "us-west")]);
        let (east_node, west_node) = (NodeId(1), NodeId(2));

        coord
            .register_node(east_node, DatacenterId::new("us-east-1"))
            .unwrap();
        coord
            .register_node(west_node, DatacenterId::new("us-west-2"))
            .unwrap();

        assert_eq!(coord.total_nodes(), 2);

        let home = coord
            .get_node_datacenter(&east_node)
            .expect("east node should be registered");
        assert_eq!(home.id.as_str(), "us-east-1");
    }

    #[test]
    fn test_latency_tracking() {
        let mut coord = MultiDatacenterCoordinator::new();
        let east = DatacenterId::new("us-east-1");
        let west = DatacenterId::new("us-west-2");

        coord.record_latency(east.clone(), west.clone(), 50);

        for (from, to) in [(&east, &west), (&west, &east)] {
            assert_eq!(coord.get_latency(from, to), Some(50));
        }
    }

    #[test]
    fn test_replication_policy_all() {
        let coord = coordinator_with(&[("dc1", "r1"), ("dc2", "r2")]);

        let chosen = ReplicationPolicy::AllDatacenters
            .select_datacenters(&coord, &DatacenterId::new("dc1"));

        assert_eq!(chosen.len(), 2);
    }

    #[test]
    fn test_replication_policy_regions() {
        let coord = coordinator_with(&[
            ("us-east-1", "us-east"),
            ("us-west-2", "us-west"),
            ("eu-west-1", "eu-west"),
        ]);

        let policy = ReplicationPolicy::Regions(vec!["us-east".into(), "us-west".into()]);
        let chosen = policy.select_datacenters(&coord, &DatacenterId::new("us-east-1"));

        assert_eq!(chosen.len(), 2);
    }

    #[test]
    fn test_latency_aware_selector() {
        let dc1 = DatacenterId::new("dc1");
        let dc2 = DatacenterId::new("dc2");
        let mut coord = coordinator_with(&[("dc1", "r1"), ("dc2", "r2")]);
        let (node1, node2) = (NodeId(1), NodeId(2));

        coord.register_node(node1, dc1.clone()).unwrap();
        coord.register_node(node2, dc2.clone()).unwrap();
        coord.record_latency(dc1, dc2, 100);

        let selector = LatencyAwareSelector::new(Arc::new(coord));
        let ranked = selector.select_read_nodes(&[node1, node2], &node1);

        assert_eq!(ranked[0], node1);
    }

    #[test]
    fn test_cross_dc_stats() {
        let mut stats = CrossDcStats::new();

        for _ in 0..2 {
            stats.record_local();
        }
        for latency in [50, 100] {
            stats.record_cross_dc(latency);
        }

        assert_eq!(stats.local_requests, 2);
        assert_eq!(stats.cross_dc_requests, 2);
        assert_eq!(stats.total_requests(), 4);
        assert_eq!(stats.cross_dc_percentage(), 50.0);
        assert_eq!(stats.avg_cross_dc_latency_ms, 75.0);
    }
}