use std::sync::Arc;
use std::{collections::HashMap, time::Duration};
use anyhow::Result;
use arc_swap::ArcSwap;
use futures::StreamExt;
use tokio::net::unix::pipe::Receiver;
use crate::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId};
use crate::{
component::{Endpoint, Instance},
pipeline::async_trait,
pipeline::{
AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
SingleIn,
},
traits::DistributedRuntimeProvider,
transports::etcd::Client as EtcdClient,
};
/// Cadence at which the monitor task re-syncs availability views with the
/// instance source even when no change notification arrives.
const DEFAULT_RECONCILE_INTERVAL: Duration = Duration::from_secs(5);
/// Client for a dynamic endpoint: tracks the live instances discovered for
/// `endpoint` and maintains derived "available" and "free" views of them.
#[derive(Clone, Debug)]
pub struct Client {
    /// The endpoint this client routes to.
    pub endpoint: Endpoint,
    /// Authoritative instance list, fed by the discovery watcher task.
    pub instance_source: Arc<tokio::sync::watch::Receiver<Vec<Instance>>>,
    // Lock-free snapshot of instance ids currently considered available
    // (source ids minus any inhibited via `report_instance_down`).
    instance_avail: Arc<ArcSwap<Vec<u64>>>,
    // Lock-free snapshot of instance ids not reported busy
    // (see `update_free_instances`).
    instance_free: Arc<ArcSwap<Vec<u64>>>,
    // Broadcast side of the availability watch; updated alongside
    // `instance_avail` so callers can await changes.
    instance_avail_tx: Arc<tokio::sync::watch::Sender<Vec<u64>>>,
    // Template receiver cloned out by `instance_avail_watcher`.
    instance_avail_rx: tokio::sync::watch::Receiver<Vec<u64>>,
    // How often the monitor task force-reconciles views with the source.
    reconcile_interval: Duration,
}
impl Client {
pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
Self::with_reconcile_interval(endpoint, DEFAULT_RECONCILE_INTERVAL).await
}
pub(crate) async fn with_reconcile_interval(
endpoint: Endpoint,
reconcile_interval: Duration,
) -> Result<Self> {
tracing::trace!(
"Client::new_dynamic: Creating dynamic client for endpoint: {}",
endpoint.id()
);
let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?;
let initial_ids: Vec<u64> = instance_source
.borrow()
.iter()
.map(|instance| instance.id())
.collect();
let (avail_tx, avail_rx) = tokio::sync::watch::channel(initial_ids.clone());
let client = Client {
endpoint: endpoint.clone(),
instance_source: instance_source.clone(),
instance_avail: Arc::new(ArcSwap::from(Arc::new(initial_ids.clone()))),
instance_free: Arc::new(ArcSwap::from(Arc::new(initial_ids))),
instance_avail_tx: Arc::new(avail_tx),
instance_avail_rx: avail_rx,
reconcile_interval,
};
client.monitor_instance_source();
Ok(client)
}
pub fn instances(&self) -> Vec<Instance> {
self.instance_source.borrow().clone()
}
pub fn instance_ids(&self) -> Vec<u64> {
self.instances().into_iter().map(|ep| ep.id()).collect()
}
    /// Lock-free read of the ids currently considered available.
    pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
        self.instance_avail.load()
    }
    /// Lock-free read of the ids not currently reported busy.
    pub fn instance_ids_free(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
        self.instance_free.load()
    }
    /// Returns a watch receiver that is notified whenever the available-id
    /// set is republished (by the monitor task or `report_instance_down`).
    pub fn instance_avail_watcher(&self) -> tokio::sync::watch::Receiver<Vec<u64>> {
        self.instance_avail_rx.clone()
    }
pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
tracing::trace!(
"wait_for_instances: Starting wait for endpoint: {}",
self.endpoint.id()
);
let mut rx = self.instance_source.as_ref().clone();
let mut instances: Vec<Instance>;
loop {
instances = rx.borrow_and_update().to_vec();
if instances.is_empty() {
rx.changed().await?;
} else {
tracing::info!(
"wait_for_instances: Found {} instance(s) for endpoint: {}",
instances.len(),
self.endpoint.id()
);
break;
}
}
Ok(instances)
}
pub fn report_instance_down(&self, instance_id: u64) {
let filtered = self
.instance_ids_avail()
.iter()
.filter_map(|&id| if id == instance_id { None } else { Some(id) })
.collect::<Vec<_>>();
self.instance_avail.store(Arc::new(filtered.clone()));
let _ = self.instance_avail_tx.send(filtered);
tracing::debug!("inhibiting instance {instance_id}");
}
pub fn update_free_instances(&self, busy_instance_ids: &[u64]) {
let all_instance_ids = self.instance_ids();
let free_ids: Vec<u64> = all_instance_ids
.into_iter()
.filter(|id| !busy_instance_ids.contains(id))
.collect();
self.instance_free.store(Arc::new(free_ids));
}
    /// Spawns a background task that mirrors `instance_source` into the
    /// `instance_avail`/`instance_free` snapshots and the avail watch channel.
    ///
    /// The task wakes on every source change and additionally every
    /// `reconcile_interval`. Each wake-up rebuilds both views from the source,
    /// so ids inhibited via `report_instance_down` are restored on the next
    /// reconciliation — presumably intentional self-healing; confirm if the
    /// inhibition is expected to persist longer.
    fn monitor_instance_source(&self) {
        let reconcile_interval = self.reconcile_interval;
        let cancel_token = self.endpoint.drt().primary_token();
        // Clone the whole client so the task owns the shared state it updates.
        let client = self.clone();
        let endpoint_id = self.endpoint.id();
        tokio::task::spawn(async move {
            let mut rx = client.instance_source.as_ref().clone();
            while !cancel_token.is_cancelled() {
                // Snapshot the current source ids (marking the value as seen).
                let instance_ids: Vec<u64> = rx
                    .borrow_and_update()
                    .iter()
                    .map(|instance| instance.id())
                    .collect();
                // Reset both derived views, then notify watchers.
                client.instance_avail.store(Arc::new(instance_ids.clone()));
                client.instance_free.store(Arc::new(instance_ids.clone()));
                let _ = client.instance_avail_tx.send(instance_ids);
                // Sleep until the source changes or the reconcile timer fires.
                tokio::select! {
                    result = rx.changed() => {
                        if let Err(err) = result {
                            // Sender dropped: the source is gone; stop everything.
                            tracing::error!(
                                "monitor_instance_source: The Sender is dropped: {err}, endpoint={endpoint_id}",
                            );
                            cancel_token.cancel();
                        }
                    }
                    _ = tokio::time::sleep(reconcile_interval) => {
                        tracing::trace!(
                            "monitor_instance_source: periodic reconciliation for endpoint={endpoint_id}",
                        );
                    }
                }
            }
        });
    }
    /// Returns the shared instance-source watch for `endpoint`, creating it
    /// (and its discovery watcher task) if no live one exists.
    ///
    /// Sources are cached per endpoint as weak references in the runtime's
    /// `instance_sources` map, so the watcher task shuts down once every
    /// client clone holding the `Arc` is dropped.
    async fn get_or_create_dynamic_instance_source(
        endpoint: &Endpoint,
    ) -> Result<Arc<tokio::sync::watch::Receiver<Vec<Instance>>>> {
        let drt = endpoint.drt();
        let instance_sources = drt.instance_sources();
        let mut instance_sources = instance_sources.lock().await;
        // Fast path: reuse a still-alive cached source; evict dead entries.
        if let Some(instance_source) = instance_sources.get(endpoint) {
            if let Some(instance_source) = instance_source.upgrade() {
                return Ok(instance_source);
            } else {
                instance_sources.remove(endpoint);
            }
        }
        // Slow path: subscribe to discovery for this specific endpoint.
        let discovery = drt.discovery();
        let discovery_query = crate::discovery::DiscoveryQuery::Endpoint {
            namespace: endpoint.component.namespace.name.clone(),
            component: endpoint.component.name.clone(),
            endpoint: endpoint.name.clone(),
        };
        let mut discovery_stream = discovery
            .list_and_watch(discovery_query.clone(), None)
            .await?;
        let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
        let secondary = endpoint.component.drt.runtime().secondary().clone();
        // Watcher task: folds discovery add/remove events into an id->instance
        // map and republishes the full list on every event.
        secondary.spawn(async move {
            tracing::trace!("endpoint_watcher: Starting for discovery query: {:?}", discovery_query);
            let mut map: HashMap<u64, Instance> = HashMap::new();
            loop {
                let discovery_event = tokio::select! {
                    // All receivers dropped: nobody is listening anymore.
                    _ = watch_tx.closed() => {
                        break;
                    }
                    discovery_event = discovery_stream.next() => {
                        match discovery_event {
                            Some(Ok(event)) => {
                                event
                            },
                            Some(Err(e)) => {
                                tracing::error!("endpoint_watcher: discovery stream error: {}; shutting down for discovery query: {:?}", e, discovery_query);
                                break;
                            }
                            // Stream exhausted: discovery backend went away.
                            None => {
                                break;
                            }
                        }
                    }
                };
                match discovery_event {
                    DiscoveryEvent::Added(discovery_instance) => {
                        // Only endpoint instances are tracked here.
                        if let DiscoveryInstance::Endpoint(instance) = discovery_instance {
                            map.insert(instance.instance_id, instance);
                        }
                    }
                    DiscoveryEvent::Removed(id) => {
                        map.remove(&id.instance_id());
                    }
                }
                let instances: Vec<Instance> = map.values().cloned().collect();
                if watch_tx.send(instances).is_err() {
                    break;
                }
            }
            // On shutdown, publish an empty list so readers see no instances.
            let _ = watch_tx.send(vec![]);
        });
        let instance_source = Arc::new(watch_rx);
        // Cache weakly so the watcher can die with its last strong holder.
        instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source));
        Ok(instance_source)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DistributedRuntime, Runtime, distributed::DistributedConfig};

    /// Verifies that the monitor task overwrites manual availability edits
    /// with the (empty) source state after one reconcile interval.
    #[tokio::test]
    async fn test_instance_reconciliation() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(100);
        let runtime = Runtime::from_current().unwrap();
        let distributed = DistributedRuntime::new(runtime.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let namespace = distributed.namespace("test_reconciliation".to_string()).unwrap();
        let component = namespace.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = Client::with_reconcile_interval(endpoint, TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();
        // No instances registered, so availability starts empty.
        assert!(client.instance_ids_avail().is_empty());
        // Inject ids directly, then inhibit one of them.
        client.instance_avail.store(Arc::new(vec![1, 2, 3]));
        assert_eq!(**client.instance_ids_avail(), vec![1u64, 2, 3]);
        client.report_instance_down(2);
        assert_eq!(**client.instance_ids_avail(), vec![1u64, 3]);
        // Let one reconcile tick pass (plus slack) and expect a reset.
        tokio::time::sleep(TEST_RECONCILE_INTERVAL + Duration::from_millis(50)).await;
        assert!(
            client.instance_ids_avail().is_empty(),
            "After reconciliation, instance_avail should match instance_source"
        );
        runtime.shutdown();
    }

    /// Verifies that `report_instance_down` removes exactly the reported id.
    #[tokio::test]
    async fn test_report_instance_down() {
        let runtime = Runtime::from_current().unwrap();
        let distributed = DistributedRuntime::new(runtime.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let namespace = distributed.namespace("test_report_down".to_string()).unwrap();
        let component = namespace.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();
        client.instance_avail.store(Arc::new(vec![1, 2, 3]));
        assert_eq!(**client.instance_ids_avail(), vec![1u64, 2, 3]);
        client.report_instance_down(2);
        let remaining = client.instance_ids_avail();
        assert!(remaining.contains(&1), "Instance 1 should still be available");
        assert!(
            !remaining.contains(&2),
            "Instance 2 should be removed after report_instance_down"
        );
        assert!(remaining.contains(&3), "Instance 3 should still be available");
        runtime.shutdown();
    }

    /// Verifies that availability watchers observe `report_instance_down`.
    #[tokio::test]
    async fn test_instance_avail_watcher() {
        let runtime = Runtime::from_current().unwrap();
        let distributed = DistributedRuntime::new(runtime.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let namespace = distributed.namespace("test_watcher".to_string()).unwrap();
        let component = namespace.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();
        let watcher = client.instance_avail_watcher();
        client.instance_avail.store(Arc::new(vec![1, 2, 3]));
        client.report_instance_down(2);
        // report_instance_down publishes synchronously through the sender.
        let observed = watcher.borrow().clone();
        assert_eq!(observed, vec![1, 3]);
        runtime.shutdown();
    }
}