dynamo_runtime/component/
client.rs1use std::sync::atomic::{AtomicU64, Ordering};
5use std::{
6 collections::{HashMap, HashSet},
7 sync::{Arc, Mutex as StdMutex},
8 time::Duration,
9};
10
11use anyhow::Result;
12use arc_swap::ArcSwap;
13use dashmap::DashMap;
14use futures::StreamExt;
15
16use crate::component::{Endpoint, Instance};
17use crate::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId};
18use crate::traits::DistributedRuntimeProvider;
19
20#[derive(Debug, Default)]
22pub(crate) struct RoutingOccupancyState {
23 counts: DashMap<u64, AtomicU64>,
24 exact_selection_lock: tokio::sync::Mutex<()>,
25}
26
27impl RoutingOccupancyState {
28 pub(crate) fn increment(&self, instance_id: u64) {
29 self.counts
30 .entry(instance_id)
31 .or_insert_with(|| AtomicU64::new(0))
32 .fetch_add(1, Ordering::Relaxed);
33 }
34
35 pub(crate) async fn select_exact_min_and_increment(&self, instance_ids: &[u64]) -> Option<u64> {
36 let _guard = self.exact_selection_lock.lock().await;
37 let id = *instance_ids.iter().min_by_key(|&&id| self.load(id))?;
38 self.increment(id);
39 Some(id)
40 }
41
42 pub(crate) fn decrement(&self, instance_id: u64) {
43 if let Some(count) = self.counts.get(&instance_id) {
44 let _ = count.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
45 Some(current.saturating_sub(1))
46 });
47 }
48 }
49
50 pub(crate) fn load(&self, instance_id: u64) -> u64 {
51 self.counts
52 .get(&instance_id)
53 .map(|c| c.load(Ordering::Relaxed))
54 .unwrap_or(0)
55 }
56
57 pub(crate) fn retain(&self, instance_ids: &[u64]) {
58 let live: HashSet<u64> = instance_ids.iter().copied().collect();
59 self.counts.retain(|id, _| live.contains(id));
60 }
61}
62
63pub(crate) async fn get_or_create_routing_occupancy_state(
65 endpoint: &Endpoint,
66) -> Arc<RoutingOccupancyState> {
67 let drt = endpoint.drt();
68 let registry = drt.routing_occupancy_states();
69 let mut registry = registry.lock().await;
70
71 if let Some(weak) = registry.get(endpoint) {
72 if let Some(state) = weak.upgrade() {
73 return state;
74 } else {
75 registry.remove(endpoint);
76 }
77 }
78
79 let state = Arc::new(RoutingOccupancyState::default());
80 registry.insert(endpoint.clone(), Arc::downgrade(&state));
81 state
82}
83
/// Default upper bound between two reconciliations of a client's local instance lists with
/// discovery (five seconds).
const DEFAULT_RECONCILE_INTERVAL: Duration = Duration::from_millis(5_000);
86
87#[derive(Debug)]
95pub(crate) struct EndpointDiscoverySource {
96 instance_source: tokio::sync::watch::Receiver<Vec<Instance>>,
97 event_subscribers: StdMutex<Vec<tokio::sync::mpsc::UnboundedSender<DiscoveryEvent>>>,
98}
99
100impl EndpointDiscoverySource {
101 fn new(instance_source: tokio::sync::watch::Receiver<Vec<Instance>>) -> Self {
102 Self {
103 instance_source,
104 event_subscribers: StdMutex::new(Vec::new()),
105 }
106 }
107
108 fn instance_receiver(&self) -> tokio::sync::watch::Receiver<Vec<Instance>> {
109 self.instance_source.clone()
110 }
111
112 fn subscribe_events(&self) -> tokio::sync::mpsc::UnboundedReceiver<DiscoveryEvent> {
113 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
114 self.event_subscribers.lock().unwrap().push(tx);
115 rx
116 }
117
118 fn broadcast_event(&self, event: &DiscoveryEvent) {
119 let subscribers = &mut *self.event_subscribers.lock().unwrap();
120 subscribers.retain(|tx| tx.send(event.clone()).is_ok());
121 }
122}
123
124#[derive(Clone, Debug)]
125pub struct Client {
126 pub endpoint: Endpoint,
128 endpoint_discovery_source: Arc<EndpointDiscoverySource>,
130 pub instance_source: Arc<tokio::sync::watch::Receiver<Vec<Instance>>>,
132 instance_avail: Arc<ArcSwap<Vec<u64>>>,
134 instance_free: Arc<ArcSwap<Vec<u64>>>,
136 instance_avail_tx: Arc<tokio::sync::watch::Sender<Vec<u64>>>,
138 instance_avail_rx: tokio::sync::watch::Receiver<Vec<u64>>,
140 reconcile_interval: Duration,
143}
144
145impl Client {
146 pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
148 Self::with_reconcile_interval(endpoint, DEFAULT_RECONCILE_INTERVAL).await
149 }
150
151 pub(crate) async fn with_reconcile_interval(
155 endpoint: Endpoint,
156 reconcile_interval: Duration,
157 ) -> Result<Self> {
158 tracing::trace!(
159 "Client::new_dynamic: Creating dynamic client for endpoint: {}",
160 endpoint.id()
161 );
162 let endpoint_discovery_source =
163 Self::get_or_create_dynamic_discovery_source(&endpoint).await?;
164 let instance_source = Arc::new(endpoint_discovery_source.instance_receiver());
165
166 let initial_ids: Vec<u64> = instance_source
171 .borrow()
172 .iter()
173 .map(|instance| instance.id())
174 .collect();
175 let (avail_tx, avail_rx) = tokio::sync::watch::channel(initial_ids.clone());
176 let client = Client {
177 endpoint: endpoint.clone(),
178 endpoint_discovery_source,
179 instance_source: instance_source.clone(),
180 instance_avail: Arc::new(ArcSwap::from(Arc::new(initial_ids.clone()))),
181 instance_free: Arc::new(ArcSwap::from(Arc::new(initial_ids))),
182 instance_avail_tx: Arc::new(avail_tx),
183 instance_avail_rx: avail_rx,
184 reconcile_interval,
185 };
186 client.monitor_instance_source();
187 Ok(client)
188 }
189
190 pub fn instances(&self) -> Vec<Instance> {
192 self.instance_source.borrow().clone()
193 }
194
195 pub fn instance_ids(&self) -> Vec<u64> {
196 self.instances().into_iter().map(|ep| ep.id()).collect()
197 }
198
199 pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
200 self.instance_avail.load()
201 }
202
203 pub fn instance_ids_free(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
204 self.instance_free.load()
205 }
206
207 pub fn instance_avail_watcher(&self) -> tokio::sync::watch::Receiver<Vec<u64>> {
209 self.instance_avail_rx.clone()
210 }
211
212 pub(crate) fn subscribe_discovery_events(
217 &self,
218 ) -> tokio::sync::mpsc::UnboundedReceiver<DiscoveryEvent> {
219 self.endpoint_discovery_source.subscribe_events()
220 }
221
222 pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
224 tracing::trace!(
225 "wait_for_instances: Starting wait for endpoint: {}",
226 self.endpoint.id()
227 );
228 let mut rx = self.instance_source.as_ref().clone();
229 let mut instances: Vec<Instance>;
231 loop {
232 instances = rx.borrow_and_update().to_vec();
233 if instances.is_empty() {
234 rx.changed().await?;
235 } else {
236 tracing::info!(
237 "wait_for_instances: Found {} instance(s) for endpoint: {}",
238 instances.len(),
239 self.endpoint.id()
240 );
241 break;
242 }
243 }
244 Ok(instances)
245 }
246
247 pub fn report_instance_down(&self, instance_id: u64) {
249 let filtered = self
250 .instance_ids_avail()
251 .iter()
252 .filter_map(|&id| if id == instance_id { None } else { Some(id) })
253 .collect::<Vec<_>>();
254 self.instance_avail.store(Arc::new(filtered.clone()));
255
256 let _ = self.instance_avail_tx.send(filtered);
258
259 tracing::debug!("inhibiting instance {instance_id}");
260 }
261
262 pub fn update_free_instances(&self, busy_instance_ids: &[u64]) {
264 let all_instance_ids = self.instance_ids();
265 let free_ids: Vec<u64> = all_instance_ids
266 .into_iter()
267 .filter(|id| !busy_instance_ids.contains(id))
268 .collect();
269 self.instance_free.store(Arc::new(free_ids));
270 }
271
272 fn monitor_instance_source(&self) {
279 let reconcile_interval = self.reconcile_interval;
280 let cancel_token = self.endpoint.drt().primary_token();
281 let client = self.clone();
282 let endpoint_id = self.endpoint.id();
283 tokio::task::spawn(async move {
284 let mut rx = client.instance_source.as_ref().clone();
285 while !cancel_token.is_cancelled() {
286 let instance_ids: Vec<u64> = rx
287 .borrow_and_update()
288 .iter()
289 .map(|instance| instance.id())
290 .collect();
291
292 client.instance_avail.store(Arc::new(instance_ids.clone()));
294 client.instance_free.store(Arc::new(instance_ids.clone()));
295
296 let registry = client.endpoint.drt().routing_occupancy_states();
298 if let Ok(registry) = registry.try_lock()
299 && let Some(weak) = registry.get(&client.endpoint)
300 && let Some(state) = weak.upgrade()
301 {
302 state.retain(&instance_ids);
303 }
304
305 let _ = client.instance_avail_tx.send(instance_ids);
307
308 tokio::select! {
309 result = rx.changed() => {
310 if let Err(err) = result {
311 tracing::error!(
312 "monitor_instance_source: The Sender is dropped: {err}, endpoint={endpoint_id}",
313 );
314 cancel_token.cancel();
315 }
316 }
317 _ = tokio::time::sleep(reconcile_interval) => {
318 tracing::trace!(
319 "monitor_instance_source: periodic reconciliation for endpoint={endpoint_id}",
320 );
321 }
322 }
323 }
324 });
325 }
326
327 #[cfg(test)]
330 pub(crate) fn override_instance_avail(&self, ids: Vec<u64>) {
331 self.instance_avail.store(Arc::new(ids));
332 }
333
334 async fn get_or_create_dynamic_discovery_source(
335 endpoint: &Endpoint,
336 ) -> Result<Arc<EndpointDiscoverySource>> {
337 let drt = endpoint.drt();
338 let sources = drt.endpoint_discovery_sources();
339 let mut sources = sources.lock().await;
340
341 if let Some(source) = sources.get(endpoint) {
342 if let Some(source) = source.upgrade() {
343 return Ok(source);
344 } else {
345 sources.remove(endpoint);
346 }
347 }
348
349 let discovery = drt.discovery();
350 let discovery_query = crate::discovery::DiscoveryQuery::Endpoint {
351 namespace: endpoint.component.namespace.name.clone(),
352 component: endpoint.component.name.clone(),
353 endpoint: endpoint.name.clone(),
354 };
355
356 let mut discovery_stream = discovery
357 .list_and_watch(discovery_query.clone(), None)
358 .await?;
359 let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
360 let discovery_source = Arc::new(EndpointDiscoverySource::new(watch_rx));
361
362 let secondary = endpoint.component.drt.runtime().secondary().clone();
363 let discovery_source_task = discovery_source.clone();
364
365 secondary.spawn(async move {
366 tracing::trace!("endpoint_watcher: Starting for discovery query: {:?}", discovery_query);
367 let mut map: HashMap<u64, Instance> = HashMap::new();
368
369 loop {
370 let discovery_event = tokio::select! {
371 _ = watch_tx.closed() => {
372 break;
373 }
374 discovery_event = discovery_stream.next() => {
375 match discovery_event {
376 Some(Ok(event)) => {
377 event
378 },
379 Some(Err(e)) => {
380 tracing::error!("endpoint_watcher: discovery stream error: {}; shutting down for discovery query: {:?}", e, discovery_query);
381 break;
382 }
383 None => {
384 break;
385 }
386 }
387 }
388 };
389
390 discovery_source_task.broadcast_event(&discovery_event);
391
392 match discovery_event {
393 DiscoveryEvent::Added(DiscoveryInstance::Endpoint(instance)) => {
394 map.insert(instance.instance_id, instance);
395 }
396 DiscoveryEvent::Added(_) => {}
397 DiscoveryEvent::Removed(id) => {
398 if let DiscoveryInstanceId::Endpoint(endpoint_id) = id {
399 map.remove(&endpoint_id.instance_id);
400 }
401 }
402 }
403
404 let instances: Vec<Instance> = map.values().cloned().collect();
405 if watch_tx.send(instances).is_err() {
406 break;
407 }
408 }
409 let _ = watch_tx.send(vec![]);
410 });
411
412 sources.insert(endpoint.clone(), Arc::downgrade(&discovery_source));
413 Ok(discovery_source)
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DistributedRuntime, Runtime, distributed::DistributedConfig};

    /// Local prunes of the available list are undone by the periodic reconciliation.
    #[tokio::test]
    async fn test_instance_reconciliation() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(100);

        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_reconciliation".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = Client::with_reconcile_interval(endpoint, TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();

        assert!(client.instance_ids_avail().is_empty());

        client.instance_avail.store(Arc::new(vec![1, 2, 3]));

        assert_eq!(**client.instance_ids_avail(), vec![1u64, 2, 3]);

        client.report_instance_down(2);
        assert_eq!(**client.instance_ids_avail(), vec![1u64, 3]);

        // Poll (bounded) instead of one fixed sleep so a slow scheduler cannot make this flaky.
        for _ in 0..10 {
            if client.instance_ids_avail().is_empty() {
                break;
            }
            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
        }

        assert!(
            client.instance_ids_avail().is_empty(),
            "After reconciliation, instance_avail should match instance_source"
        );

        rt.shutdown();
    }

    #[tokio::test]
    async fn test_report_instance_down() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_report_down".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = endpoint.client().await.unwrap();

        client.instance_avail.store(Arc::new(vec![1, 2, 3]));
        assert_eq!(**client.instance_ids_avail(), vec![1u64, 2, 3]);

        client.report_instance_down(2);

        let avail = client.instance_ids_avail();
        assert!(avail.contains(&1), "Instance 1 should still be available");
        assert!(
            !avail.contains(&2),
            "Instance 2 should be removed after report_instance_down"
        );
        assert!(avail.contains(&3), "Instance 3 should still be available");

        rt.shutdown();
    }

    #[tokio::test]
    async fn test_instance_avail_watcher() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_watcher".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = endpoint.client().await.unwrap();
        let watcher = client.instance_avail_watcher();

        client.instance_avail.store(Arc::new(vec![1, 2, 3]));

        client.report_instance_down(2);

        let current = watcher.borrow().clone();
        assert_eq!(current, vec![1, 3]);

        rt.shutdown();
    }

    /// Runs on a multi-threaded runtime so the spawned selections genuinely race. On the
    /// default current-thread runtime they would only interleave at await points.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_select_and_increment() {
        let state = Arc::new(RoutingOccupancyState::default());
        let instance_ids: Vec<u64> = vec![100, 200, 300];
        let num_requests = 90;

        let mut handles = Vec::new();
        for _ in 0..num_requests {
            let state = state.clone();
            let ids = instance_ids.clone();
            handles.push(tokio::spawn(async move {
                state.select_exact_min_and_increment(&ids).await
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(state.load(100), 30);
        assert_eq!(state.load(200), 30);
        assert_eq!(state.load(300), 30);
    }

    #[tokio::test]
    async fn test_connection_counts() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_ll_counts".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let state1 = get_or_create_routing_occupancy_state(&endpoint).await;
        let state2 = get_or_create_routing_occupancy_state(&endpoint).await;

        let picked1 = state1
            .select_exact_min_and_increment(&[10, 20, 30])
            .await
            .unwrap();
        assert_eq!(state1.load(picked1), 1);

        let picked2 = state1
            .select_exact_min_and_increment(&[10, 20, 30])
            .await
            .unwrap();
        assert_ne!(picked1, picked2);

        // Both handles must observe the same shared counters.
        assert_eq!(state2.load(10), state1.load(10));
        assert_eq!(state2.load(20), state1.load(20));
        assert_eq!(state2.load(30), state1.load(30));

        // picked1 != picked2 (asserted above), so releasing picked1 brings it back to zero.
        state2.decrement(picked1);
        assert_eq!(state1.load(picked1), 0);

        rt.shutdown();
    }

    #[tokio::test]
    async fn test_least_loaded_state_retain() {
        let state = RoutingOccupancyState::default();

        state.select_exact_min_and_increment(&[1, 2, 3]).await;
        state.select_exact_min_and_increment(&[1, 2, 3]).await;
        state.select_exact_min_and_increment(&[1, 2, 3]).await;
        assert_eq!(state.load(1), 1);
        assert_eq!(state.load(2), 1);
        assert_eq!(state.load(3), 1);

        state.retain(&[1, 3]);

        assert_eq!(state.load(1), 1);
        assert_eq!(state.load(2), 0);
        assert_eq!(state.load(3), 1);
    }

    #[tokio::test]
    async fn test_monitor_instance_source_cleans_up_removed_worker_counts() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);

        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_occupancy_cleanup".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();
        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let worker_id = client.instance_ids_avail()[0];
        let state = get_or_create_routing_occupancy_state(&endpoint).await;
        state.increment(worker_id);
        assert_eq!(state.load(worker_id), 1);

        endpoint.unregister_endpoint_instance().await.unwrap();

        for _ in 0..10 {
            if state.load(worker_id) == 0 {
                break;
            }
            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
        }

        assert_eq!(state.load(worker_id), 0);

        rt.shutdown();
    }
}