dynamo_runtime/component/
client.rs1use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::{
7 collections::{HashMap, HashSet},
8 time::Duration,
9};
10
11use anyhow::Result;
12use arc_swap::ArcSwap;
13use dashmap::DashMap;
14use futures::StreamExt;
15use tokio::net::unix::pipe::Receiver;
16
17use crate::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId};
18use crate::{
19 component::{Endpoint, Instance},
20 pipeline::async_trait,
21 pipeline::{
22 AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
23 SingleIn,
24 },
25 traits::DistributedRuntimeProvider,
26 transports::etcd::Client as EtcdClient,
27};
28
/// Per-endpoint occupancy counters keyed by instance id.
///
/// Instances are shared between all holders of the `Arc` returned by
/// `get_or_create_routing_occupancy_state` for the same endpoint. What a "unit" of
/// occupancy means is decided by callers (presumably one in-flight request per
/// increment — NOTE(review): confirm against the routers that call this).
#[derive(Debug, Default)]
pub(crate) struct RoutingOccupancyState {
    // Counter per instance id; entries are created lazily by `increment` and pruned by `retain`.
    counts: DashMap<u64, AtomicU64>,
    // Serializes `select_exact_min_and_increment` so concurrent selectors see each other's increments.
    exact_selection_lock: tokio::sync::Mutex<()>,
}
35
36impl RoutingOccupancyState {
37 pub(crate) fn increment(&self, instance_id: u64) {
38 self.counts
39 .entry(instance_id)
40 .or_insert_with(|| AtomicU64::new(0))
41 .fetch_add(1, Ordering::Relaxed);
42 }
43
44 pub(crate) async fn select_exact_min_and_increment(&self, instance_ids: &[u64]) -> Option<u64> {
45 let _guard = self.exact_selection_lock.lock().await;
46 let id = *instance_ids.iter().min_by_key(|&&id| self.load(id))?;
47 self.increment(id);
48 Some(id)
49 }
50
51 pub(crate) fn decrement(&self, instance_id: u64) {
52 if let Some(count) = self.counts.get(&instance_id) {
53 let _ = count.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
54 Some(current.saturating_sub(1))
55 });
56 }
57 }
58
59 pub(crate) fn load(&self, instance_id: u64) -> u64 {
60 self.counts
61 .get(&instance_id)
62 .map(|c| c.load(Ordering::Relaxed))
63 .unwrap_or(0)
64 }
65
66 pub(crate) fn retain(&self, instance_ids: &[u64]) {
67 let live: HashSet<u64> = instance_ids.iter().copied().collect();
68 self.counts.retain(|id, _| live.contains(id));
69 }
70}
71
72pub(crate) async fn get_or_create_routing_occupancy_state(
74 endpoint: &Endpoint,
75) -> Arc<RoutingOccupancyState> {
76 let drt = endpoint.drt();
77 let registry = drt.routing_occupancy_states();
78 let mut registry = registry.lock().await;
79
80 if let Some(weak) = registry.get(endpoint) {
81 if let Some(state) = weak.upgrade() {
82 return state;
83 } else {
84 registry.remove(endpoint);
85 }
86 }
87
88 let state = Arc::new(RoutingOccupancyState::default());
89 registry.insert(endpoint.clone(), Arc::downgrade(&state));
90 state
91}
92
/// Longest time a client's monitor task waits without a discovery change before
/// re-syncing its local instance lists (five seconds).
const DEFAULT_RECONCILE_INTERVAL: Duration = Duration::from_millis(5_000);
95
/// Client for a single endpoint.
///
/// It tracks the instances discovered for the endpoint plus two locally adjustable
/// id lists ("available" and "free"). A background task spawned at construction
/// periodically resets both lists to the discovered set.
#[derive(Clone, Debug)]
pub struct Client {
    /// The endpoint whose instances this client tracks.
    pub endpoint: Endpoint,
    /// Watch of the discovered instances; one source is shared by every client of the
    /// same endpoint while any strong reference to it exists.
    pub instance_source: Arc<tokio::sync::watch::Receiver<Vec<Instance>>>,
    // Ids eligible for routing; `report_instance_down` removes ids, the monitor task resets it.
    instance_avail: Arc<ArcSwap<Vec<u64>>>,
    // Ids not marked busy via `update_free_instances`; also reset by the monitor task.
    instance_free: Arc<ArcSwap<Vec<u64>>>,
    // Publishes updated available-id lists to receivers from `instance_avail_watcher`.
    instance_avail_tx: Arc<tokio::sync::watch::Sender<Vec<u64>>>,
    // Template receiver cloned by `instance_avail_watcher`; also keeps the channel open.
    instance_avail_rx: tokio::sync::watch::Receiver<Vec<u64>>,
    // Maximum wait between monitor re-syncs when discovery reports no change.
    reconcile_interval: Duration,
}
114
115impl Client {
116 pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
118 Self::with_reconcile_interval(endpoint, DEFAULT_RECONCILE_INTERVAL).await
119 }
120
121 pub(crate) async fn with_reconcile_interval(
125 endpoint: Endpoint,
126 reconcile_interval: Duration,
127 ) -> Result<Self> {
128 tracing::trace!(
129 "Client::new_dynamic: Creating dynamic client for endpoint: {}",
130 endpoint.id()
131 );
132 let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?;
133
134 let initial_ids: Vec<u64> = instance_source
139 .borrow()
140 .iter()
141 .map(|instance| instance.id())
142 .collect();
143 let (avail_tx, avail_rx) = tokio::sync::watch::channel(initial_ids.clone());
144 let client = Client {
145 endpoint: endpoint.clone(),
146 instance_source: instance_source.clone(),
147 instance_avail: Arc::new(ArcSwap::from(Arc::new(initial_ids.clone()))),
148 instance_free: Arc::new(ArcSwap::from(Arc::new(initial_ids))),
149 instance_avail_tx: Arc::new(avail_tx),
150 instance_avail_rx: avail_rx,
151 reconcile_interval,
152 };
153 client.monitor_instance_source();
154 Ok(client)
155 }
156
157 pub fn instances(&self) -> Vec<Instance> {
159 self.instance_source.borrow().clone()
160 }
161
162 pub fn instance_ids(&self) -> Vec<u64> {
163 self.instances().into_iter().map(|ep| ep.id()).collect()
164 }
165
166 pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
167 self.instance_avail.load()
168 }
169
170 pub fn instance_ids_free(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
171 self.instance_free.load()
172 }
173
174 pub fn instance_avail_watcher(&self) -> tokio::sync::watch::Receiver<Vec<u64>> {
176 self.instance_avail_rx.clone()
177 }
178
179 pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
181 tracing::trace!(
182 "wait_for_instances: Starting wait for endpoint: {}",
183 self.endpoint.id()
184 );
185 let mut rx = self.instance_source.as_ref().clone();
186 let mut instances: Vec<Instance>;
188 loop {
189 instances = rx.borrow_and_update().to_vec();
190 if instances.is_empty() {
191 rx.changed().await?;
192 } else {
193 tracing::info!(
194 "wait_for_instances: Found {} instance(s) for endpoint: {}",
195 instances.len(),
196 self.endpoint.id()
197 );
198 break;
199 }
200 }
201 Ok(instances)
202 }
203
204 pub fn report_instance_down(&self, instance_id: u64) {
206 let filtered = self
207 .instance_ids_avail()
208 .iter()
209 .filter_map(|&id| if id == instance_id { None } else { Some(id) })
210 .collect::<Vec<_>>();
211 self.instance_avail.store(Arc::new(filtered.clone()));
212
213 let _ = self.instance_avail_tx.send(filtered);
215
216 tracing::debug!("inhibiting instance {instance_id}");
217 }
218
219 pub fn update_free_instances(&self, busy_instance_ids: &[u64]) {
221 let all_instance_ids = self.instance_ids();
222 let free_ids: Vec<u64> = all_instance_ids
223 .into_iter()
224 .filter(|id| !busy_instance_ids.contains(id))
225 .collect();
226 self.instance_free.store(Arc::new(free_ids));
227 }
228
229 fn monitor_instance_source(&self) {
236 let reconcile_interval = self.reconcile_interval;
237 let cancel_token = self.endpoint.drt().primary_token();
238 let client = self.clone();
239 let endpoint_id = self.endpoint.id();
240 tokio::task::spawn(async move {
241 let mut rx = client.instance_source.as_ref().clone();
242 while !cancel_token.is_cancelled() {
243 let instance_ids: Vec<u64> = rx
244 .borrow_and_update()
245 .iter()
246 .map(|instance| instance.id())
247 .collect();
248
249 client.instance_avail.store(Arc::new(instance_ids.clone()));
251 client.instance_free.store(Arc::new(instance_ids.clone()));
252
253 let registry = client.endpoint.drt().routing_occupancy_states();
255 if let Ok(registry) = registry.try_lock()
256 && let Some(weak) = registry.get(&client.endpoint)
257 && let Some(state) = weak.upgrade()
258 {
259 state.retain(&instance_ids);
260 }
261
262 let _ = client.instance_avail_tx.send(instance_ids);
264
265 tokio::select! {
266 result = rx.changed() => {
267 if let Err(err) = result {
268 tracing::error!(
269 "monitor_instance_source: The Sender is dropped: {err}, endpoint={endpoint_id}",
270 );
271 cancel_token.cancel();
272 }
273 }
274 _ = tokio::time::sleep(reconcile_interval) => {
275 tracing::trace!(
276 "monitor_instance_source: periodic reconciliation for endpoint={endpoint_id}",
277 );
278 }
279 }
280 }
281 });
282 }
283
284 #[cfg(test)]
287 pub(crate) fn override_instance_avail(&self, ids: Vec<u64>) {
288 self.instance_avail.store(Arc::new(ids));
289 }
290
291 async fn get_or_create_dynamic_instance_source(
292 endpoint: &Endpoint,
293 ) -> Result<Arc<tokio::sync::watch::Receiver<Vec<Instance>>>> {
294 let drt = endpoint.drt();
295 let instance_sources = drt.instance_sources();
296 let mut instance_sources = instance_sources.lock().await;
297
298 if let Some(instance_source) = instance_sources.get(endpoint) {
299 if let Some(instance_source) = instance_source.upgrade() {
300 return Ok(instance_source);
301 } else {
302 instance_sources.remove(endpoint);
303 }
304 }
305
306 let discovery = drt.discovery();
307 let discovery_query = crate::discovery::DiscoveryQuery::Endpoint {
308 namespace: endpoint.component.namespace.name.clone(),
309 component: endpoint.component.name.clone(),
310 endpoint: endpoint.name.clone(),
311 };
312
313 let mut discovery_stream = discovery
314 .list_and_watch(discovery_query.clone(), None)
315 .await?;
316 let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
317
318 let secondary = endpoint.component.drt.runtime().secondary().clone();
319
320 secondary.spawn(async move {
321 tracing::trace!("endpoint_watcher: Starting for discovery query: {:?}", discovery_query);
322 let mut map: HashMap<u64, Instance> = HashMap::new();
323
324 loop {
325 let discovery_event = tokio::select! {
326 _ = watch_tx.closed() => {
327 break;
328 }
329 discovery_event = discovery_stream.next() => {
330 match discovery_event {
331 Some(Ok(event)) => {
332 event
333 },
334 Some(Err(e)) => {
335 tracing::error!("endpoint_watcher: discovery stream error: {}; shutting down for discovery query: {:?}", e, discovery_query);
336 break;
337 }
338 None => {
339 break;
340 }
341 }
342 }
343 };
344
345 match discovery_event {
346 DiscoveryEvent::Added(discovery_instance) => {
347 if let DiscoveryInstance::Endpoint(instance) = discovery_instance {
348
349 map.insert(instance.instance_id, instance);
350 }
351 }
352 DiscoveryEvent::Removed(id) => {
353 map.remove(&id.instance_id());
354 }
355 }
356
357 let instances: Vec<Instance> = map.values().cloned().collect();
358 if watch_tx.send(instances).is_err() {
359 break;
360 }
361 }
362 let _ = watch_tx.send(vec![]);
363 });
364
365 let instance_source = Arc::new(watch_rx);
366 instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source));
367 Ok(instance_source)
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DistributedRuntime, Runtime, distributed::DistributedConfig};

    // Verifies that the monitor's periodic reconciliation overwrites locally modified
    // `instance_avail` state with the discovered instance set (empty here).
    #[tokio::test]
    async fn test_instance_reconciliation() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(100);

        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_reconciliation".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = Client::with_reconcile_interval(endpoint, TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();

        // Nothing is registered for this endpoint, so the client starts empty.
        assert!(client.instance_ids_avail().is_empty());

        // Inject ids directly so there is local state for reconciliation to undo.
        // NOTE(review): a monitor tick landing between this store and the asserts
        // below would reset the list; the 100ms interval makes that unlikely but possible.
        client.instance_avail.store(Arc::new(vec![1, 2, 3]));

        assert_eq!(**client.instance_ids_avail(), vec![1u64, 2, 3]);

        client.report_instance_down(2);
        assert_eq!(**client.instance_ids_avail(), vec![1u64, 3]);

        // Wait past one reconcile interval (plus slack) so the monitor re-syncs.
        tokio::time::sleep(TEST_RECONCILE_INTERVAL + Duration::from_millis(50)).await;

        assert!(
            client.instance_ids_avail().is_empty(),
            "After reconciliation, instance_avail should match instance_source"
        );

        rt.shutdown();
    }

    // Verifies that `report_instance_down` removes only the reported id.
    #[tokio::test]
    async fn test_report_instance_down() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_report_down".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        // Uses the default (5s) reconcile interval, so no reset is expected mid-test.
        let client = endpoint.client().await.unwrap();

        client.instance_avail.store(Arc::new(vec![1, 2, 3]));
        assert_eq!(**client.instance_ids_avail(), vec![1u64, 2, 3]);

        client.report_instance_down(2);

        let avail = client.instance_ids_avail();
        assert!(avail.contains(&1), "Instance 1 should still be available");
        assert!(
            !avail.contains(&2),
            "Instance 2 should be removed after report_instance_down"
        );
        assert!(avail.contains(&3), "Instance 3 should still be available");

        rt.shutdown();
    }

    // Verifies that `report_instance_down` publishes the filtered list to watchers.
    #[tokio::test]
    async fn test_instance_avail_watcher() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_watcher".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = endpoint.client().await.unwrap();
        let watcher = client.instance_avail_watcher();

        // Store bypasses the watch channel; only the report below publishes.
        client.instance_avail.store(Arc::new(vec![1, 2, 3]));

        client.report_instance_down(2);

        let current = watcher.borrow().clone();
        assert_eq!(current, vec![1, 3]);

        rt.shutdown();
    }

    // Verifies that serialized min-selection spreads 90 concurrent picks evenly over 3 ids.
    #[tokio::test]
    async fn test_concurrent_select_and_increment() {
        let state = Arc::new(RoutingOccupancyState::default());
        let instance_ids: Vec<u64> = vec![100, 200, 300];
        let num_requests = 90;

        let mut handles = Vec::new();
        for _ in 0..num_requests {
            let state = state.clone();
            let ids = instance_ids.clone();
            handles.push(tokio::spawn(async move {
                state.select_exact_min_and_increment(&ids).await
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(state.load(100), 30);
        assert_eq!(state.load(200), 30);
        assert_eq!(state.load(300), 30);
    }

    // Verifies that two lookups for the same endpoint share one counter table.
    #[tokio::test]
    async fn test_connection_counts() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_ll_counts".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        // `state1` is still held when `state2` is requested, so both are the same Arc.
        let state1 = get_or_create_routing_occupancy_state(&endpoint).await;
        let state2 = get_or_create_routing_occupancy_state(&endpoint).await;

        let picked1 = state1
            .select_exact_min_and_increment(&[10, 20, 30])
            .await
            .unwrap();
        assert_eq!(state1.load(picked1), 1);

        // The first pick now has count 1, so the second must choose a different id.
        let picked2 = state1
            .select_exact_min_and_increment(&[10, 20, 30])
            .await
            .unwrap();
        assert_ne!(picked1, picked2);

        assert_eq!(state2.load(10), state1.load(10));
        assert_eq!(state2.load(20), state1.load(20));
        assert_eq!(state2.load(30), state1.load(30));

        // Decrementing through `state2` is visible via `state1`. Because
        // `picked1 != picked2` was asserted above, the expected value here is 0.
        state2.decrement(picked1);
        assert_eq!(state1.load(picked1), if picked1 == picked2 { 1 } else { 0 });

        rt.shutdown();
    }

    // Verifies that `retain` drops counters of ids outside the given set and keeps the rest.
    #[tokio::test]
    async fn test_least_loaded_state_retain() {
        let state = RoutingOccupancyState::default();

        // Three picks over three empty counters land one on each id.
        state.select_exact_min_and_increment(&[1, 2, 3]).await;
        state.select_exact_min_and_increment(&[1, 2, 3]).await;
        state.select_exact_min_and_increment(&[1, 2, 3]).await;
        assert_eq!(state.load(1), 1);
        assert_eq!(state.load(2), 1);
        assert_eq!(state.load(3), 1);

        state.retain(&[1, 3]);

        assert_eq!(state.load(1), 1);
        assert_eq!(state.load(2), 0);
        assert_eq!(state.load(3), 1);
    }

    // Verifies that the monitor task prunes occupancy counters of an instance that
    // disappears from discovery.
    #[tokio::test]
    async fn test_monitor_instance_source_cleans_up_removed_worker_counts() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);

        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_occupancy_cleanup".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();
        // Presumably publishes this process as an instance of `endpoint` via discovery.
        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let worker_id = client.instance_ids_avail()[0];
        let state = get_or_create_routing_occupancy_state(&endpoint).await;
        state.increment(worker_id);
        assert_eq!(state.load(worker_id), 1);

        endpoint.unregister_endpoint_instance().await.unwrap();

        // Poll for up to ~10 reconcile intervals; pruning is best effort per pass.
        for _ in 0..10 {
            if state.load(worker_id) == 0 {
                break;
            }
            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
        }

        assert_eq!(state.load(worker_id), 0);

        rt.shutdown();
    }
}
603}