use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use nodedb_cluster::DescriptorId;
use tokio::time::timeout;
use crate::control::state::SharedState;
/// Two-second time budget for releasing this node's descriptor leases on shutdown.
///
/// NOTE(review): judging by its name, this is the intended `deadline` argument for
/// [`release_all_local_leases`]. No use is visible in this file, so confirm at the
/// call site.
pub const DEFAULT_SHUTDOWN_RELEASE_TIMEOUT: Duration = Duration::from_secs(2);
/// Best-effort release of every descriptor lease held by this node, bounded by `deadline`.
///
/// This function never reports failure to the caller. Every outcome is only logged:
/// - nothing to release;
/// - a successful release;
/// - a failed release proposal;
/// - a failed blocking task;
/// - an exceeded deadline.
///
/// In the failure cases the leases are left to expire via their TTL.
pub async fn release_all_local_leases(shared: Arc<SharedState>, deadline: Duration) {
    let local_ids = collect_local_descriptor_ids(&shared);
    let count = local_ids.len();
    if count == 0 {
        tracing::debug!("shutdown release: no local leases to release");
        return;
    }

    // `release_descriptor_leases` is called directly (not awaited), i.e. it is
    // synchronous. Run it on tokio's blocking pool so it cannot stall async workers.
    let worker_shared = Arc::clone(&shared);
    let handle =
        tokio::task::spawn_blocking(move || worker_shared.release_descriptor_leases(local_ids));

    // If the deadline fires, the JoinHandle is dropped. That detaches the blocking
    // task but does not abort it: spawn_blocking closures cannot be cancelled.
    let outcome = timeout(deadline, handle).await;
    match outcome {
        Err(_) => {
            tracing::warn!(
                count,
                deadline = ?deadline,
                "shutdown release: deadline exceeded, leases will drain via TTL"
            );
        }
        Ok(Err(join_err)) => {
            tracing::warn!(
                error = %join_err,
                "shutdown release: spawn_blocking task panicked"
            );
        }
        Ok(Ok(Err(e))) => {
            tracing::warn!(
                error = %e,
                count,
                "shutdown release: propose failed, leases will drain via TTL"
            );
        }
        Ok(Ok(Ok(()))) => {
            tracing::info!(count, "shutdown release: released {count} local leases");
        }
    }
}
/// Returns the distinct descriptor ids whose leases in the metadata cache are
/// held by this node (`shared.node_id`).
///
/// A poisoned `metadata_cache` lock is recovered rather than turned into a panic.
/// The read guard is held until the ids have been copied out.
fn collect_local_descriptor_ids(shared: &SharedState) -> Vec<DescriptorId> {
    let cache = shared
        .metadata_cache
        .read()
        .unwrap_or_else(|p| p.into_inner());
    local_descriptor_ids(
        cache.leases.iter().map(|((id, node_id), _)| (id, node_id)),
        |node_id| *node_id == shared.node_id,
    )
}

/// Reduces `(descriptor id, holder node)` pairs to the ids whose holder is
/// accepted by `is_local`.
///
/// Repeated ids are dropped; the first occurrence keeps its position. This logic
/// is split out from [`collect_local_descriptor_ids`] so it can be unit-tested
/// without constructing a `SharedState`.
fn local_descriptor_ids<'a, N: 'a>(
    entries: impl IntoIterator<Item = (&'a DescriptorId, &'a N)>,
    is_local: impl Fn(&N) -> bool,
) -> Vec<DescriptorId> {
    // Dedupe on borrowed ids so each surviving id is cloned exactly once.
    let mut seen: HashSet<&'a DescriptorId> = HashSet::new();
    entries
        .into_iter()
        .filter(|&(_, node_id)| is_local(node_id))
        .filter(|&(id, _)| seen.insert(id))
        .map(|(id, _)| id.clone())
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use nodedb_cluster::{DescriptorKind, DescriptorLease};
    use nodedb_types::Hlc;

    fn make_lease(node_id: u64, name: &str) -> DescriptorLease {
        DescriptorLease {
            descriptor_id: DescriptorId::new(1, DescriptorKind::Collection, name.to_string()),
            version: 1,
            node_id,
            expires_at: Hlc::new(u64::MAX, 0),
        }
    }

    fn descriptor(name: &str) -> DescriptorId {
        DescriptorId::new(1, DescriptorKind::Collection, name.to_string())
    }

    #[test]
    fn collect_filters_by_node_id_and_dedupes() {
        let mut map: std::collections::HashMap<(DescriptorId, u64), DescriptorLease> =
            std::collections::HashMap::new();
        let a = descriptor("a");
        let b = descriptor("b");
        let c = descriptor("c");
        map.insert((a.clone(), 1), make_lease(1, "a"));
        map.insert((b, 1), make_lease(1, "b"));
        map.insert((a, 2), make_lease(2, "a"));
        // `c` is held only by a remote node and must not be selected.
        map.insert((c, 2), make_lease(2, "c"));
        let self_node_id = 1u64;

        // Same projection that `collect_local_descriptor_ids` applies to `cache.leases`.
        let out = local_descriptor_ids(
            map.iter().map(|((id, node_id), _)| (id, node_id)),
            |node_id| *node_id == self_node_id,
        );

        assert_eq!(out.len(), 2);
        let names: HashSet<_> = out.iter().map(|i| i.name.clone()).collect();
        assert!(names.contains("a"));
        assert!(names.contains("b"));
        assert!(!names.contains("c"));
    }

    #[test]
    fn local_descriptor_ids_dedupes_repeated_entries() {
        let a = descriptor("a");
        let b = descriptor("b");
        let local = 1u64;
        let remote = 2u64;
        let entries = vec![(&a, &local), (&b, &remote), (&a, &local)];

        let out = local_descriptor_ids(entries, |node_id| *node_id == local);

        assert_eq!(out.len(), 1);
        assert!(out[0] == a);
    }

    #[test]
    fn local_descriptor_ids_empty_input() {
        let out = local_descriptor_ids(Vec::<(&DescriptorId, &u64)>::new(), |_| true);
        assert!(out.is_empty());
    }
}