1use dashmap::DashMap;
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use uuid::Uuid;
8
/// Network address of a module's control plane or one of its gRPC services, kept as a
/// URI string.
///
/// Build one with [`Endpoint::from_uri`], [`Endpoint::http`], [`Endpoint::https`] or
/// [`Endpoint::uds`]; classify the transport with [`Endpoint::kind`].
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct Endpoint {
    /// The full URI, e.g. `http://127.0.0.1:8080` or `unix:///tmp/app.sock`.
    pub uri: String,
}
14
/// Transport classification of an [`Endpoint`], as returned by [`Endpoint::kind`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EndpointKind {
    /// An `http://` / `https://` URI whose host:port part is a literal socket address
    /// (no DNS resolution is attempted).
    Tcp(std::net::SocketAddr),
    /// A `unix://` URI; holds the filesystem path that follows the scheme.
    Uds(std::path::PathBuf),
    /// Anything else, carrying the original URI verbatim.
    Other(String),
}
25
26impl Endpoint {
27 pub fn from_uri<S: Into<String>>(s: S) -> Self {
28 Self { uri: s.into() }
29 }
30
31 pub fn uds(path: impl AsRef<std::path::Path>) -> Self {
32 Self {
33 uri: format!("unix://{}", path.as_ref().display()),
34 }
35 }
36
37 #[must_use]
38 pub fn http(host: &str, port: u16) -> Self {
39 Self {
40 uri: format!("http://{host}:{port}"),
41 }
42 }
43
44 #[must_use]
45 pub fn https(host: &str, port: u16) -> Self {
46 Self {
47 uri: format!("https://{host}:{port}"),
48 }
49 }
50
51 #[must_use]
53 pub fn kind(&self) -> EndpointKind {
54 if let Some(rest) = self.uri.strip_prefix("unix://") {
55 return EndpointKind::Uds(std::path::PathBuf::from(rest));
56 }
57 if let Some(rest) = self.uri.strip_prefix("http://")
58 && let Ok(addr) = rest.parse::<std::net::SocketAddr>()
59 {
60 return EndpointKind::Tcp(addr);
61 }
62 if let Some(rest) = self.uri.strip_prefix("https://")
63 && let Ok(addr) = rest.parse::<std::net::SocketAddr>()
64 {
65 return EndpointKind::Tcp(addr);
66 }
67 EndpointKind::Other(self.uri.clone())
68 }
69}
70
/// Lifecycle state of a registered module instance.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InstanceState {
    /// Initial state of every newly constructed instance.
    Registered,
    /// Set explicitly via `ModuleManager::mark_ready`; preferred by the round-robin pickers.
    Ready,
    /// Set by the first heartbeat received while `Registered`; preferred by the pickers.
    Healthy,
    /// Heartbeat reached the TTL (or marked explicitly); removed by `evict_stale` once the
    /// heartbeat age also exceeds the grace period.
    Quarantined,
    /// Marked for shutdown; not preferred by the pickers and never touched by `evict_stale`.
    Draining,
}
79
80#[derive(Clone, Debug)]
82pub struct InstanceRuntimeState {
83 pub last_heartbeat: Instant,
84 pub state: InstanceState,
85}
86
87#[derive(Debug)]
89#[must_use]
90pub struct ModuleInstance {
91 pub module: String,
92 pub instance_id: Uuid,
93 pub control: Option<Endpoint>,
94 pub grpc_services: HashMap<String, Endpoint>,
95 pub version: Option<String>,
96 inner: Arc<parking_lot::RwLock<InstanceRuntimeState>>,
97}
98
99impl Clone for ModuleInstance {
100 fn clone(&self) -> Self {
101 Self {
102 module: self.module.clone(),
103 instance_id: self.instance_id,
104 control: self.control.clone(),
105 grpc_services: self.grpc_services.clone(),
106 version: self.version.clone(),
107 inner: Arc::clone(&self.inner),
108 }
109 }
110}
111
112impl ModuleInstance {
113 pub fn new(module: impl Into<String>, instance_id: Uuid) -> Self {
114 Self {
115 module: module.into(),
116 instance_id,
117 control: None,
118 grpc_services: HashMap::new(),
119 version: None,
120 inner: Arc::new(parking_lot::RwLock::new(InstanceRuntimeState {
121 last_heartbeat: Instant::now(),
122 state: InstanceState::Registered,
123 })),
124 }
125 }
126
127 pub fn with_control(mut self, ep: Endpoint) -> Self {
128 self.control = Some(ep);
129 self
130 }
131
132 pub fn with_version(mut self, v: impl Into<String>) -> Self {
133 self.version = Some(v.into());
134 self
135 }
136
137 pub fn with_grpc_service(mut self, name: impl Into<String>, ep: Endpoint) -> Self {
138 self.grpc_services.insert(name.into(), ep);
139 self
140 }
141
142 #[must_use]
144 pub fn state(&self) -> InstanceState {
145 self.inner.read().state
146 }
147
148 #[must_use]
150 pub fn last_heartbeat(&self) -> Instant {
151 self.inner.read().last_heartbeat
152 }
153}
154
155#[derive(Clone)]
158#[must_use]
159pub struct ModuleManager {
160 inner: DashMap<String, Vec<Arc<ModuleInstance>>>,
161 rr_counters: DashMap<String, usize>,
162 hb_ttl: Duration,
163 hb_grace: Duration,
164}
165
166impl std::fmt::Debug for ModuleManager {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 let modules: Vec<String> = self.inner.iter().map(|e| e.key().clone()).collect();
169 f.debug_struct("ModuleManager")
170 .field("instances_count", &self.inner.len())
171 .field("modules", &modules)
172 .field("heartbeat_ttl", &self.hb_ttl)
173 .field("heartbeat_grace", &self.hb_grace)
174 .finish_non_exhaustive()
175 }
176}
177
178impl ModuleManager {
179 pub fn new() -> Self {
180 Self {
181 inner: DashMap::new(),
182 rr_counters: DashMap::new(),
183 hb_ttl: Duration::from_secs(15),
184 hb_grace: Duration::from_secs(30),
185 }
186 }
187
188 pub fn with_heartbeat_policy(mut self, ttl: Duration, grace: Duration) -> Self {
189 self.hb_ttl = ttl;
190 self.hb_grace = grace;
191 self
192 }
193
194 pub fn register_instance(&self, instance: Arc<ModuleInstance>) {
196 let module = instance.module.clone();
197 let mut vec = self.inner.entry(module).or_default();
198 if let Some(pos) = vec
200 .iter()
201 .position(|i| i.instance_id == instance.instance_id)
202 {
203 vec[pos] = instance;
204 } else {
205 vec.push(instance);
206 }
207 }
208
209 pub fn mark_ready(&self, module: &str, instance_id: Uuid) {
211 if let Some(mut vec) = self.inner.get_mut(module)
212 && let Some(inst) = vec.iter_mut().find(|i| i.instance_id == instance_id)
213 {
214 let mut state = inst.inner.write();
215 state.state = InstanceState::Ready;
216 }
217 }
218
219 pub fn update_heartbeat(&self, module: &str, instance_id: Uuid, at: Instant) {
221 if let Some(mut vec) = self.inner.get_mut(module)
222 && let Some(inst) = vec.iter_mut().find(|i| i.instance_id == instance_id)
223 {
224 let mut state = inst.inner.write();
225 state.last_heartbeat = at;
226 if state.state == InstanceState::Registered {
228 state.state = InstanceState::Healthy;
229 }
230 }
231 }
232
233 pub fn mark_quarantined(&self, module: &str, instance_id: Uuid) {
235 if let Some(mut vec) = self.inner.get_mut(module)
236 && let Some(inst) = vec.iter_mut().find(|i| i.instance_id == instance_id)
237 {
238 inst.inner.write().state = InstanceState::Quarantined;
239 }
240 }
241
242 pub fn mark_draining(&self, module: &str, instance_id: Uuid) {
244 if let Some(mut vec) = self.inner.get_mut(module)
245 && let Some(inst) = vec.iter_mut().find(|i| i.instance_id == instance_id)
246 {
247 inst.inner.write().state = InstanceState::Draining;
248 }
249 }
250
251 pub fn deregister(&self, module: &str, instance_id: Uuid) {
253 let mut remove_module = false;
254 {
255 if let Some(mut vec) = self.inner.get_mut(module) {
256 let list = vec.value_mut();
257 list.retain(|inst| inst.instance_id != instance_id);
258 if list.is_empty() {
259 remove_module = true;
260 }
261 }
262 }
263
264 if remove_module {
265 self.inner.remove(module);
266 self.rr_counters.remove(module);
267 }
268 }
269
270 #[must_use]
272 pub fn instances_of(&self, module: &str) -> Vec<Arc<ModuleInstance>> {
273 self.inner
274 .get(module)
275 .map(|v| v.clone())
276 .unwrap_or_default()
277 }
278
279 #[must_use]
281 pub fn all_instances(&self) -> Vec<Arc<ModuleInstance>> {
282 self.inner
283 .iter()
284 .flat_map(|entry| entry.value().clone())
285 .collect()
286 }
287
288 pub fn evict_stale(&self, now: Instant) {
290 use InstanceState::{Draining, Quarantined};
291 let mut empty_modules = Vec::new();
292
293 for mut entry in self.inner.iter_mut() {
294 let module = entry.key().clone();
295 let vec = entry.value_mut();
296 vec.retain(|inst| {
297 let state = inst.inner.read();
298 let age = now.saturating_duration_since(state.last_heartbeat);
299
300 if age >= self.hb_ttl && !matches!(state.state, Quarantined | Draining) {
302 drop(state); inst.inner.write().state = Quarantined;
304 return true; }
306
307 if state.state == Quarantined && age >= self.hb_ttl + self.hb_grace {
309 return false; }
311
312 true
313 });
314
315 if vec.is_empty() {
316 empty_modules.push(module);
317 }
318 }
319
320 for module in empty_modules {
321 self.inner.remove(&module);
322 self.rr_counters.remove(&module);
323 }
324 }
325
326 #[must_use]
328 pub fn pick_instance_round_robin(&self, module: &str) -> Option<Arc<ModuleInstance>> {
329 let instances_entry = self.inner.get(module)?;
330 let instances = instances_entry.value();
331
332 if instances.is_empty() {
333 return None;
334 }
335
336 let healthy: Vec<_> = instances
338 .iter()
339 .filter(|inst| matches!(inst.state(), InstanceState::Healthy | InstanceState::Ready))
340 .cloned()
341 .collect();
342
343 let candidates: Vec<_> = if healthy.is_empty() {
344 instances.clone()
345 } else {
346 healthy
347 };
348
349 if candidates.is_empty() {
350 return None;
351 }
352
353 let len = candidates.len();
354 let mut counter = self.rr_counters.entry(module.to_owned()).or_insert(0);
355 let idx = *counter % len;
356 *counter = (*counter + 1) % len;
357
358 candidates.get(idx).cloned()
359 }
360
361 #[must_use]
364 pub fn pick_service_round_robin(
365 &self,
366 service_name: &str,
367 ) -> Option<(String, Arc<ModuleInstance>, Endpoint)> {
368 let mut candidates = Vec::new();
370 for entry in &self.inner {
371 let module = entry.key().clone();
372 for inst in entry.value() {
373 if let Some(ep) = inst.grpc_services.get(service_name) {
374 let state = inst.state();
375 if matches!(state, InstanceState::Healthy | InstanceState::Ready) {
376 candidates.push((module.clone(), inst.clone(), ep.clone()));
377 }
378 }
379 }
380 }
381
382 if candidates.is_empty() {
383 return None;
384 }
385
386 let len = candidates.len();
388 let service_key = service_name.to_owned();
389 let mut counter = self.rr_counters.entry(service_key).or_insert(0);
390 let idx = *counter % len;
391 *counter = (*counter + 1) % len;
392
393 candidates.get(idx).cloned()
394 }
395}
396
397impl Default for ModuleManager {
398 fn default() -> Self {
399 Self::new()
400 }
401}
402
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    //! Unit tests for endpoint helpers, instance lifecycle and `ModuleManager` routing.
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_register_and_retrieve_instances() {
        let dir = ModuleManager::new();
        let instance_id = Uuid::new_v4();
        let instance = Arc::new(
            ModuleInstance::new("test_module", instance_id)
                .with_control(Endpoint::http("localhost", 8080))
                .with_version("1.0.0"),
        );

        dir.register_instance(instance);

        let instances = dir.instances_of("test_module");
        assert_eq!(instances.len(), 1);
        assert_eq!(instances[0].instance_id, instance_id);
        assert_eq!(instances[0].module, "test_module");
        assert_eq!(instances[0].version, Some("1.0.0".to_owned()));
    }

    #[test]
    fn test_register_multiple_instances() {
        let dir = ModuleManager::new();

        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        let instance1 = Arc::new(ModuleInstance::new("test_module", id1));
        let instance2 = Arc::new(ModuleInstance::new("test_module", id2));

        dir.register_instance(instance1);
        dir.register_instance(instance2);

        let registered = dir.instances_of("test_module");
        assert_eq!(registered.len(), 2);

        let ids: Vec<_> = registered.iter().map(|i| i.instance_id).collect();
        assert!(ids.contains(&id1));
        assert!(ids.contains(&id2));
    }

    #[test]
    fn test_update_existing_instance() {
        let dir = ModuleManager::new();
        let instance_id = Uuid::new_v4();

        let initial_instance =
            Arc::new(ModuleInstance::new("test_module", instance_id).with_version("1.0.0"));
        dir.register_instance(initial_instance);

        let updated_instance =
            Arc::new(ModuleInstance::new("test_module", instance_id).with_version("2.0.0"));
        dir.register_instance(updated_instance);

        let registered = dir.instances_of("test_module");
        assert_eq!(registered.len(), 1, "Should not duplicate instance");
        assert_eq!(registered[0].version, Some("2.0.0".to_owned()));
    }

    #[test]
    fn test_mark_ready() {
        let dir = ModuleManager::new();
        let instance_id = Uuid::new_v4();
        let instance = Arc::new(ModuleInstance::new("test_module", instance_id));

        dir.register_instance(instance);

        dir.mark_ready("test_module", instance_id);

        let instances = dir.instances_of("test_module");
        assert_eq!(instances.len(), 1);
        assert!(matches!(instances[0].state(), InstanceState::Ready));
    }

    #[test]
    fn test_update_heartbeat() {
        let dir = ModuleManager::new();
        let instance_id = Uuid::new_v4();
        let instance = Arc::new(ModuleInstance::new("test_module", instance_id));
        let initial_heartbeat = instance.last_heartbeat();

        dir.register_instance(instance);

        // An explicit later timestamp instead of sleeping keeps the test instant and
        // immune to coarse clock resolution.
        let new_heartbeat = initial_heartbeat + Duration::from_millis(10);
        dir.update_heartbeat("test_module", instance_id, new_heartbeat);

        let instances = dir.instances_of("test_module");
        assert_eq!(instances[0].last_heartbeat(), new_heartbeat);
        assert!(instances[0].last_heartbeat() > initial_heartbeat);
        assert!(matches!(instances[0].state(), InstanceState::Healthy));
    }

    #[test]
    fn test_all_instances() {
        let dir = ModuleManager::new();

        let instance1 = Arc::new(ModuleInstance::new("module_a", Uuid::new_v4()));
        let instance2 = Arc::new(ModuleInstance::new("module_b", Uuid::new_v4()));
        let instance3 = Arc::new(ModuleInstance::new("module_a", Uuid::new_v4()));

        dir.register_instance(instance1);
        dir.register_instance(instance2);
        dir.register_instance(instance3);

        let all = dir.all_instances();
        assert_eq!(all.len(), 3);

        let modules: Vec<_> = all.iter().map(|i| i.module.as_str()).collect();
        assert_eq!(modules.iter().filter(|&m| *m == "module_a").count(), 2);
        assert_eq!(modules.iter().filter(|&m| *m == "module_b").count(), 1);
    }

    #[test]
    fn test_deregister() {
        let dir = ModuleManager::new();
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        dir.register_instance(Arc::new(ModuleInstance::new("test_module", id1)));
        dir.register_instance(Arc::new(ModuleInstance::new("test_module", id2)));

        dir.deregister("test_module", id1);
        let remaining = dir.instances_of("test_module");
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].instance_id, id2);

        dir.deregister("test_module", id2);
        assert!(dir.instances_of("test_module").is_empty());
        assert!(dir.all_instances().is_empty());
        assert!(dir.pick_instance_round_robin("test_module").is_none());

        // Unknown module or instance is a no-op.
        dir.deregister("nonexistent", Uuid::new_v4());
        assert!(dir.all_instances().is_empty());
    }

    #[test]
    fn test_pick_instance_round_robin() {
        let dir = ModuleManager::new();

        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        let instance1 = Arc::new(ModuleInstance::new("test_module", id1));
        let instance2 = Arc::new(ModuleInstance::new("test_module", id2));

        dir.register_instance(instance1);
        dir.register_instance(instance2);

        let picked1 = dir.pick_instance_round_robin("test_module").unwrap();
        let picked2 = dir.pick_instance_round_robin("test_module").unwrap();
        let picked3 = dir.pick_instance_round_robin("test_module").unwrap();

        let ids = [
            picked1.instance_id,
            picked2.instance_id,
            picked3.instance_id,
        ];

        assert!(ids.contains(&id1));
        assert!(ids.contains(&id2));
        assert_eq!(picked1.instance_id, picked3.instance_id);
        assert_ne!(picked1.instance_id, picked2.instance_id);
    }

    #[test]
    fn test_pick_instance_none_available() {
        let dir = ModuleManager::new();
        let picked = dir.pick_instance_round_robin("nonexistent_module");
        assert!(picked.is_none());
    }

    #[test]
    fn test_endpoint_creation() {
        let plain_ep = Endpoint::http("localhost", 8080);
        assert_eq!(plain_ep.uri, "http://localhost:8080");

        let secure_ep = Endpoint::https("localhost", 8443);
        assert_eq!(secure_ep.uri, "https://localhost:8443");

        let uds_ep = Endpoint::uds("/tmp/socket.sock");
        assert!(uds_ep.uri.starts_with("unix://"));
        assert!(uds_ep.uri.contains("socket.sock"));

        let custom_ep = Endpoint::from_uri("http://example.com");
        assert_eq!(custom_ep.uri, "http://example.com");
    }

    #[test]
    fn test_endpoint_kind() {
        let plain_ep = Endpoint::http("127.0.0.1", 8080);
        match plain_ep.kind() {
            EndpointKind::Tcp(addr) => {
                assert_eq!(addr.ip().to_string(), "127.0.0.1");
                assert_eq!(addr.port(), 8080);
            }
            _ => panic!("Expected TCP endpoint for http"),
        }

        let secure_ep = Endpoint::https("127.0.0.1", 8443);
        match secure_ep.kind() {
            EndpointKind::Tcp(addr) => {
                assert_eq!(addr.ip().to_string(), "127.0.0.1");
                assert_eq!(addr.port(), 8443);
            }
            _ => panic!("Expected TCP endpoint for https"),
        }

        let uds_ep = Endpoint::uds("/tmp/test.sock");
        match uds_ep.kind() {
            EndpointKind::Uds(path) => {
                assert!(path.to_string_lossy().contains("test.sock"));
            }
            _ => panic!("Expected UDS endpoint"),
        }

        let other_ep = Endpoint::from_uri("grpc://example.com");
        match other_ep.kind() {
            EndpointKind::Other(uri) => {
                assert_eq!(uri, "grpc://example.com");
            }
            _ => panic!("Expected Other endpoint"),
        }
    }

    #[test]
    fn test_module_instance_builder() {
        let instance_id = Uuid::new_v4();
        let instance = ModuleInstance::new("test_module", instance_id)
            .with_control(Endpoint::http("localhost", 8080))
            .with_version("1.2.3")
            .with_grpc_service("service1", Endpoint::http("localhost", 8082))
            .with_grpc_service("service2", Endpoint::http("localhost", 8083));

        assert_eq!(instance.module, "test_module");
        assert_eq!(instance.instance_id, instance_id);
        assert!(instance.control.is_some());
        assert_eq!(instance.version, Some("1.2.3".to_owned()));
        assert_eq!(instance.grpc_services.len(), 2);
        assert!(instance.grpc_services.contains_key("service1"));
        assert!(instance.grpc_services.contains_key("service2"));
        assert!(matches!(instance.state(), InstanceState::Registered));
    }

    #[test]
    fn test_quarantine_and_evict() {
        let ttl = Duration::from_millis(50);
        let grace = Duration::from_millis(50);
        let dir = ModuleManager::new().with_heartbeat_policy(ttl, grace);

        let now = Instant::now();
        let instance = ModuleInstance::new("test_module", Uuid::new_v4());
        instance.inner.write().last_heartbeat = now
            .checked_sub(ttl)
            .and_then(|t| t.checked_sub(Duration::from_millis(10)))
            .expect("test duration subtraction should not underflow");

        dir.register_instance(Arc::new(instance));

        dir.evict_stale(now);
        let instances = dir.instances_of("test_module");
        assert_eq!(instances.len(), 1);
        assert!(matches!(instances[0].state(), InstanceState::Quarantined));

        let later = now + grace + Duration::from_millis(10);
        dir.evict_stale(later);

        let instances_after = dir.instances_of("test_module");
        assert!(instances_after.is_empty());
    }

    #[test]
    fn test_instances_of_empty() {
        let dir = ModuleManager::new();
        let instances = dir.instances_of("nonexistent");
        assert!(instances.is_empty());
    }

    #[test]
    fn test_rr_prefers_healthy() {
        let dir = ModuleManager::new();

        let healthy_id = Uuid::new_v4();
        let healthy = Arc::new(ModuleInstance::new("test_module", healthy_id));
        dir.register_instance(healthy);
        dir.update_heartbeat("test_module", healthy_id, Instant::now());

        let quarantined_id = Uuid::new_v4();
        let quarantined = Arc::new(ModuleInstance::new("test_module", quarantined_id));
        dir.register_instance(quarantined);
        dir.mark_quarantined("test_module", quarantined_id);

        for _ in 0..5 {
            let picked = dir.pick_instance_round_robin("test_module").unwrap();
            assert_eq!(picked.instance_id, healthy_id);
        }
    }

    #[test]
    fn test_pick_service_round_robin() {
        let dir = ModuleManager::new();

        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        let inst1 = Arc::new(
            ModuleInstance::new("test_module", id1)
                .with_grpc_service("test.Service", Endpoint::http("127.0.0.1", 8001)),
        );
        let inst2 = Arc::new(
            ModuleInstance::new("test_module", id2)
                .with_grpc_service("test.Service", Endpoint::http("127.0.0.1", 8002)),
        );

        dir.register_instance(inst1);
        dir.register_instance(inst2);

        dir.update_heartbeat("test_module", id1, Instant::now());
        dir.update_heartbeat("test_module", id2, Instant::now());

        let pick1 = dir.pick_service_round_robin("test.Service");
        let pick2 = dir.pick_service_round_robin("test.Service");
        let pick3 = dir.pick_service_round_robin("test.Service");

        assert!(pick1.is_some());
        assert!(pick2.is_some());
        assert!(pick3.is_some());

        let (_, inst1, ep1) = pick1.unwrap();
        let (_, inst2, ep2) = pick2.unwrap();
        let (_, inst3, _) = pick3.unwrap();

        assert_eq!(inst1.instance_id, inst3.instance_id);
        assert_ne!(inst1.instance_id, inst2.instance_id);
        assert_ne!(ep1, ep2);
    }
}