use std::collections::HashSet;
use std::time::Duration;
use anyhow::Result;
use dynamo_runtime::{
component::Component,
config::environment_names::nats as env_nats,
discovery::{DiscoveryEvent, DiscoveryQuery},
prelude::*,
transports::nats::{NatsQueue, Slug},
};
use futures::StreamExt;
use rand::Rng;
use tokio_util::sync::CancellationToken;
use crate::kv_router::{
Indexer, KV_EVENT_SUBJECT, KvRouterConfig, RADIX_STATE_BUCKET, RADIX_STATE_FILE,
protocols::RouterEvent, router_discovery_query,
};
/// Derive the NATS stream name for a component's KV-event subject.
///
/// The fully-qualified `namespace.<ns>.component.<name>.<subject>` path is
/// slugified and then normalized to hyphens, matching the naming used for
/// the radix-state bucket elsewhere in this file.
fn create_kv_stream_name(component: &Component, subject: &str) -> String {
    let qualified = format!(
        "namespace.{}.component.{}.{}",
        component.namespace().name(),
        component.name(),
        subject
    );
    let slug = Slug::slugify(&qualified).to_string();
    slug.replace('_', "-")
}
/// Delay between two consecutive snapshot reads when checking that the
/// object-store snapshot is no longer changing.
const SNAPSHOT_STABILITY_DELAY: Duration = Duration::from_millis(100);
/// Maximum number of stability re-reads before giving up and using the
/// latest snapshot as-is.
const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10;
/// Base period of the background loop's stream-size check tick.
const CHECK_INTERVAL_BASE: Duration = Duration::from_secs(1);
/// Jitter (+/- milliseconds) applied to the check interval so that
/// multiple router replicas do not tick in lockstep.
const CHECK_INTERVAL_JITTER_MS: i64 = 100;
/// Open a `list_and_watch` discovery stream over this component's
/// `generate` endpoint instances, and block until the stream has at least
/// one event buffered before returning it.
///
/// The peek ensures callers start processing only once a worker instance
/// is visible; the peeked event is not consumed and will still be yielded
/// by the returned stream.
async fn get_instance_discovery_stream(
    component: &Component,
    cancellation_token: &CancellationToken,
) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = Result<DiscoveryEvent>> + Send>>> {
    let discovery_client = component.drt().discovery();
    // Watch only the "generate" endpoint of this namespace/component.
    let generate_discovery_key = DiscoveryQuery::Endpoint {
        namespace: component.namespace().name().to_string(),
        component: component.name().to_string(),
        endpoint: "generate".to_string(),
    };
    let mut stream = discovery_client
        .list_and_watch(generate_discovery_key, Some(cancellation_token.clone()))
        .await?
        .peekable();
    tracing::info!("KV subscriber waiting for at least one worker instance...");
    // Pin::new is valid here because the underlying boxed stream is Unpin;
    // peek() waits for the first item without removing it from the stream.
    std::pin::Pin::new(&mut stream).peek().await;
    Ok(Box::pin(stream))
}
/// Download the radix-state snapshot from the NATS object store, waiting
/// until two consecutive reads agree (i.e. no other replica is mid-upload),
/// then replay all snapshot events into the indexer.
///
/// Best-effort: a missing snapshot (fresh deployment) is not an error, and
/// a read failure mid-stability-check falls back to the last good read.
async fn download_stable_snapshot(
    nats_client: &dynamo_runtime::transports::nats::Client,
    bucket_name: &str,
    indexer: &Indexer,
) -> Result<()> {
    let url = url::Url::parse(&format!(
        "nats://{}/{bucket_name}/{RADIX_STATE_FILE}",
        nats_client.addr()
    ))?;
    // First read: if the object doesn't exist yet, there is nothing to
    // replay — normal for a freshly started replica.
    let Ok(mut prev_events) = nats_client
        .object_store_download_data::<Vec<RouterEvent>>(&url)
        .await
    else {
        tracing::debug!(
            "Failed to download snapshots. This is normal for freshly started Router replicas."
        );
        return Ok(());
    };
    // Stability loop: re-read until two consecutive downloads are equal,
    // a read fails (use the previous read), or attempts are exhausted
    // (use the latest read).
    for attempt in 1..=MAX_SNAPSHOT_STABILITY_ATTEMPTS {
        tokio::time::sleep(SNAPSHOT_STABILITY_DELAY).await;
        let curr_events = match nats_client
            .object_store_download_data::<Vec<RouterEvent>>(&url)
            .await
        {
            Ok(events) => events,
            Err(e) => {
                tracing::warn!(
                    "Snapshot read failed on attempt {attempt}, using previous snapshot with {} events: {e:?}",
                    prev_events.len()
                );
                break;
            }
        };
        if prev_events == curr_events {
            tracing::info!(
                "Successfully downloaded stable snapshot with {} events from object store (stable after {attempt} attempts)",
                curr_events.len()
            );
            prev_events = curr_events;
            break;
        }
        tracing::debug!(
            "Snapshot changed between reads on attempt {attempt} ({} -> {} events), retrying",
            prev_events.len(),
            curr_events.len()
        );
        // Carry the newest read forward for the next comparison (or as the
        // final result if attempts run out).
        prev_events = curr_events;
        if attempt == MAX_SNAPSHOT_STABILITY_ATTEMPTS {
            tracing::warn!(
                "Max stability attempts reached, using latest snapshot with {} events",
                prev_events.len()
            );
        }
    }
    // Replay the chosen snapshot into the indexer, in order.
    for event in prev_events {
        indexer.apply_event(event).await;
    }
    tracing::info!("Successfully sent all initial events to indexer");
    Ok(())
}
/// Everything the background loop needs to purge acknowledged stream
/// messages and upload a radix-state snapshot to the NATS object store.
#[derive(Clone)]
struct SnapshotResources {
    // Connected NATS client used for object-store uploads.
    nats_client: dynamo_runtime::transports::nats::Client,
    // Object-store bucket that holds the serialized radix state.
    bucket_name: String,
    // Live view of the component's current worker instances, used to
    // drop stale workers from the indexer before snapshotting.
    instances_rx: tokio::sync::watch::Receiver<Vec<dynamo_runtime::component::Instance>>,
    // Indexer whose dumped events form the snapshot payload.
    indexer: Indexer,
}
impl SnapshotResources {
    /// Purge acknowledged messages from the KV event stream and upload a
    /// snapshot of the indexer's current radix state to the object store.
    ///
    /// Before purging, any worker present in the indexer but absent from
    /// the current instance list is removed so the snapshot does not
    /// contain stale workers. The upload itself runs on a spawned task so
    /// the caller's event loop is not blocked; upload failures are only
    /// logged.
    ///
    /// NOTE(review): the purge happens before the upload completes — if
    /// the upload then fails, purged events exist only in this replica's
    /// in-memory state until the next successful snapshot. Confirm this
    /// window is acceptable.
    async fn purge_then_snapshot(&self, nats_queue: &mut NatsQueue) -> anyhow::Result<()> {
        tracing::info!("Purging acknowledged messages and performing snapshot of radix tree");
        let start_time = std::time::Instant::now();
        // Snapshot the current worker set from the watch channel.
        let current_instances = self.instances_rx.borrow().clone();
        let current_worker_ids: std::collections::HashSet<u64> = current_instances
            .iter()
            .map(|instance| instance.instance_id)
            .collect();
        // Drop indexer workers that no longer have a live instance.
        let indexer_worker_ids = self.indexer.get_workers().await;
        for worker_id in indexer_worker_ids {
            if !current_worker_ids.contains(&worker_id) {
                tracing::info!("Removing stale worker {worker_id} from indexer during snapshot");
                self.indexer.remove_worker(worker_id).await;
            }
        }
        nats_queue.purge_acknowledged().await?;
        // Dump the full event history that reconstructs the radix state.
        let events = self
            .indexer
            .dump_events()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to dump events for snapshot: {e:?}"))?;
        // Upload asynchronously; clone what the task needs by value.
        let nats_client = self.nats_client.clone();
        let bucket_name = self.bucket_name.clone();
        let event_count = events.len();
        tokio::spawn(async move {
            let Ok(url) = url::Url::parse(&format!(
                "nats://{}/{bucket_name}/{RADIX_STATE_FILE}",
                nats_client.addr(),
            )) else {
                tracing::warn!("Failed to parse snapshot URL");
                return;
            };
            if let Err(e) = nats_client.object_store_upload_data(&events, &url).await {
                tracing::warn!("Failed to upload snapshot: {e:?}");
                return;
            }
            tracing::info!(
                "Successfully uploaded snapshot with {event_count} events to bucket {bucket_name} in {}ms",
                start_time.elapsed().as_millis()
            );
        });
        Ok(())
    }
}
/// Start the KV-router background task for `component`.
///
/// Setup (before spawning): connect the NATS queue consumer, optionally
/// restore state from (or reset) the radix-state object-store bucket,
/// clean up consumers left behind by dead router replicas, and open
/// discovery watches on both worker instances and router instances.
///
/// The spawned loop then multiplexes, until cancellation:
/// - worker-instance removal events → remove the worker from the indexer;
/// - KV events dequeued from the NATS stream → apply to the indexer;
/// - a jittered periodic tick → purge + snapshot once the stream exceeds
///   `router_snapshot_threshold` (if configured);
/// - router-instance removal events → delete that replica's consumer.
async fn start_kv_router_background_inner() {}
pub(crate) async fn start_kv_router_background(
    component: Component,
    consumer_id: String,
    indexer: Indexer,
    kv_router_config: &KvRouterConfig,
) -> Result<()> {
    let cancellation_token = component.drt().primary_token();
    let router_snapshot_threshold = kv_router_config.router_snapshot_threshold;
    let router_reset_states = kv_router_config.router_reset_states;
    let stream_name = create_kv_stream_name(&component, KV_EVENT_SUBJECT);
    let nats_server = std::env::var(env_nats::NATS_SERVER)
        .unwrap_or_else(|_| "nats://localhost:4222".to_string());
    let mut nats_queue = NatsQueue::new_with_consumer(
        stream_name.clone(),
        nats_server.clone(),
        std::time::Duration::from_secs(60),
        consumer_id.clone(),
    );
    nats_queue.connect_with_reset(router_reset_states).await?;
    let client_options = dynamo_runtime::transports::nats::Client::builder()
        .server(&nats_server)
        .build()?;
    let nats_client = client_options.connect().await?;
    let event_plane_subject = format!(
        "namespace.{}.component.{}",
        component.namespace().name(),
        component.name()
    );
    // Bucket name mirrors the slugify + hyphen normalization used for the
    // stream name in create_kv_stream_name.
    let bucket_name = Slug::slugify(&format!("{}-{RADIX_STATE_BUCKET}", event_plane_subject))
        .to_string()
        .replace("_", "-");
    if !router_reset_states {
        // Warm-start: replay the last stable snapshot into the indexer.
        download_stable_snapshot(&nats_client, &bucket_name, &indexer).await?;
    } else {
        // Cold-start requested: drop any persisted state.
        tracing::info!("Resetting router state, deleting bucket: {bucket_name}");
        if let Err(e) = nats_client.object_store_delete_bucket(&bucket_name).await {
            tracing::warn!("Failed to delete bucket (may not exist): {e:?}");
        }
    }
    cleanup_orphaned_consumers(&mut nats_queue, &component, &consumer_id).await;
    // Blocks until at least one worker instance is visible in discovery.
    let mut instance_event_stream =
        get_instance_discovery_stream(&component, &cancellation_token).await?;
    let generate_endpoint = component.endpoint("generate");
    let discovery_client = component.drt().discovery();
    // Separate watch over router replicas, used to reap their consumers.
    let router_discovery_key =
        router_discovery_query(component.namespace().name(), component.name().to_string());
    let mut router_event_stream = discovery_client
        .list_and_watch(router_discovery_key, Some(cancellation_token.clone()))
        .await?;
    let client = generate_endpoint.client().await?;
    let instances_rx = client.instance_source.as_ref().clone();
    // Snapshot machinery is only built when a threshold is configured.
    let snapshot_resources = router_snapshot_threshold.map(|_| SnapshotResources {
        nats_client,
        bucket_name,
        instances_rx,
        indexer: indexer.clone(),
    });
    tokio::spawn(async move {
        // Per-replica jitter so replicas don't all check (and snapshot)
        // at the same instant; clamped to at least 1ms.
        let jitter_ms =
            rand::rng().random_range(-CHECK_INTERVAL_JITTER_MS..=CHECK_INTERVAL_JITTER_MS);
        let interval_duration = Duration::from_millis(
            (CHECK_INTERVAL_BASE.as_millis() as i64 + jitter_ms).max(1) as u64,
        );
        let mut check_interval = tokio::time::interval(interval_duration);
        check_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            tokio::select! {
                // `biased` makes cancellation the highest-priority branch.
                biased;
                _ = cancellation_token.cancelled() => {
                    tracing::debug!("KV Router background task received cancellation signal");
                    break;
                }
                // A worker instance disappeared: drop its state from the indexer.
                Some(discovery_event_result) = instance_event_stream.next() => {
                    let Ok(discovery_event) = discovery_event_result else {
                        continue;
                    };
                    let DiscoveryEvent::Removed(id) = discovery_event else {
                        continue;
                    };
                    let worker_id = id.instance_id();
                    tracing::warn!(
                        "DISCOVERY: Generate endpoint instance removed, removing worker {worker_id}"
                    );
                    indexer.remove_worker(worker_id).await;
                }
                // KV event arrived on the NATS stream: apply it.
                result = nats_queue.dequeue_task(None) => {
                    match result {
                        Ok(Some(bytes)) => {
                            let event: RouterEvent = match serde_json::from_slice(&bytes) {
                                Ok(event) => event,
                                Err(e) => {
                                    tracing::warn!("Failed to deserialize RouterEvent: {e:?}");
                                    continue;
                                }
                            };
                            indexer.apply_event(event).await;
                        },
                        Ok(None) => {
                            tracing::trace!("Dequeue timeout, continuing");
                        },
                        Err(e) => {
                            tracing::error!("Failed to dequeue task: {e:?}");
                            // Back off briefly so a persistent error doesn't spin the loop.
                            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                        }
                    }
                }
                // Periodic check: snapshot + purge when the stream grows past the threshold.
                _ = check_interval.tick() => {
                    let Some(resources) = snapshot_resources.as_ref() else {
                        continue;
                    };
                    let Ok(message_count) = nats_queue.get_stream_messages().await else {
                        tracing::warn!("Failed to get stream message count");
                        continue;
                    };
                    // snapshot_resources is Some only when the threshold is Some,
                    // so the unwrap_or default is effectively unreachable here.
                    let threshold = router_snapshot_threshold.unwrap_or(u32::MAX) as u64;
                    if message_count <= threshold {
                        continue;
                    }
                    tracing::info!("Stream has {message_count} messages (threshold: {threshold}), performing purge and snapshot");
                    match resources.purge_then_snapshot(
                        &mut nats_queue,
                    ).await {
                        Ok(_) => tracing::info!("Successfully performed purge and snapshot"),
                        Err(e) => tracing::debug!("Could not perform purge and snapshot: {e:?}"),
                    }
                }
                // A router replica disappeared: delete its NATS consumer.
                Some(router_event_result) = router_event_stream.next() => {
                    let Ok(router_event) = router_event_result else {
                        continue;
                    };
                    let DiscoveryEvent::Removed(id) = router_event else {
                        continue;
                    };
                    let router_instance_id = id.instance_id();
                    // Consumers are named after the router instance id.
                    let consumer_to_delete = router_instance_id.to_string();
                    tracing::info!(
                        "DISCOVERY: Router instance {router_instance_id} removed, attempting to delete orphaned consumer: {consumer_to_delete}"
                    );
                    if let Err(e) = nats_queue.shutdown(Some(consumer_to_delete.clone())).await {
                        tracing::warn!("Failed to delete consumer {consumer_to_delete}: {e}");
                    } else {
                        tracing::info!("Successfully deleted orphaned consumer: {consumer_to_delete}");
                    }
                }
            }
        }
        // On cancellation, shut down this replica's own consumer.
        if let Err(e) = nats_queue.shutdown(None).await {
            tracing::warn!("Failed to shutdown NatsQueue: {e}");
        }
    });
    Ok(())
}
/// Delete NATS consumers on the KV event stream that belong to router
/// replicas no longer registered in discovery.
///
/// Best-effort: if either the consumer list or the discovery instance list
/// cannot be fetched, cleanup is skipped with a debug log (fix: the
/// `list_consumers` failure path previously returned silently, unlike the
/// discovery failure path). Individual deletion failures are ignored.
/// This replica's own consumer (`consumer_id`) is never deleted.
async fn cleanup_orphaned_consumers(
    nats_queue: &mut NatsQueue,
    component: &Component,
    consumer_id: &str,
) {
    let consumers = match nats_queue.list_consumers().await {
        Ok(consumers) => consumers,
        Err(e) => {
            tracing::debug!("Failed to list consumers from NATS, skipping cleanup: {e:?}");
            return;
        }
    };
    let discovery = component.drt().discovery();
    let Ok(router_instances) = discovery
        .list(router_discovery_query(
            component.namespace().name(),
            component.name().to_string(),
        ))
        .await
    else {
        tracing::debug!("Failed to list router instances from discovery, skipping cleanup");
        return;
    };
    // Consumers are named after router instance ids; anything not backed
    // by a live instance is an orphan.
    let active_instance_ids: HashSet<String> = router_instances
        .iter()
        .map(|instance| instance.instance_id().to_string())
        .collect();
    for consumer in consumers {
        if consumer == consumer_id {
            continue;
        }
        if !active_instance_ids.contains(&consumer) {
            tracing::info!("Cleaning up orphaned consumer: {consumer}");
            // Best-effort deletion; errors are intentionally ignored.
            let _ = nats_queue.shutdown(Some(consumer)).await;
        }
    }
}