use std::time::Duration;
use chrono::Utc;
use tokio_util::sync::CancellationToken;
use super::cron::CRON_LEASE_TTL;
use crate::storage::Storage;
/// Period of the interval timer driving [`rebalance_loop`]: one rebalance attempt every 5 seconds.
pub const REBALANCE_TICK: Duration = Duration::from_secs(5);
/// Splits `total` slots across `pods` recipients as evenly as possible.
///
/// Every recipient receives `total / pods`; the first `total % pods` entries
/// receive one extra so the shares always sum to `total`. Returns an empty
/// vector when `pods` is zero.
fn fair_shares(total: usize, pods: usize) -> Vec<usize> {
    if pods == 0 {
        return Vec::new();
    }

    let floor = total / pods;
    let leftover = total % pods;

    // Start everyone at the floor, then hand the remainder out one slot at a
    // time to the leading entries.
    let mut shares = vec![floor; pods];
    for share in shares.iter_mut().take(leftover) {
        *share += 1;
    }
    shares
}
pub(super) async fn pod_heartbeat_loop(
storage: Storage,
host_id: String,
worker_name: Option<String>,
queues: Vec<String>,
shutdown: CancellationToken,
) {
let mut tick = tokio::time::interval(super::SUPERVISOR_TICK);
tick.tick().await;
loop {
tokio::select! {
biased;
() = shutdown.cancelled() => return,
_ = tick.tick() => {
if let Err(e) = storage
.procs
.pod_heartbeat(&host_id, worker_name.as_deref(), &queues)
.await
{
tracing::warn!(?e, %host_id, "rebalance: pod heartbeat failed");
}
}
}
}
}
pub(super) async fn rebalance_loop(storage: Storage, host_id: String, shutdown: CancellationToken) {
let mut tick = tokio::time::interval(REBALANCE_TICK);
tick.tick().await;
loop {
tokio::select! {
biased;
() = shutdown.cancelled() => return,
_ = tick.tick() => {
match storage.cron.try_cron_lease(&host_id, CRON_LEASE_TTL).await {
Ok(true) => {}
Ok(false) => continue,
Err(e) => {
tracing::warn!(?e, %host_id, "rebalance: lease check failed");
continue;
}
}
if let Err(e) = rebalance_once(&storage).await {
tracing::warn!(?e, "rebalance: tick failed");
}
}
}
}
}
/// Runs a single rebalance pass, distributing every queue's worker slots
/// across the live pods that handle it.
///
/// Steps:
/// 1. List the live pods, using `Utc::now() - super::STALE_THRESHOLD` as the
///    staleness cutoff. With no live pods the pass ends without writing
///    anything.
/// 2. List the queues and take a snapshot of the current slot assignments.
///    If the snapshot cannot be read, a warning is logged and the zero-out
///    step below is skipped for this pass.
/// 3. For each queue, split `max_workers` (treated as 0 when it does not fit
///    in `usize`, e.g. when negative) with [`fair_shares`] across the pods for
///    which `handles(&queue.name)` is true, in the order the pods were listed,
///    and write each share with `set_slots`.
/// 4. For each live pod that does *not* handle the queue but has a positive
///    assignment in the snapshot, write 0 slots.
///
/// # Errors
///
/// Returns the error from `list_live_pods` or `list_queues`. Failures of
/// `list_slot_assignments` and of individual `set_slots` calls are logged as
/// warnings and do not abort the pass.
pub async fn rebalance_once(storage: &Storage) -> crate::storage::error::Result<()> {
    let stale_before = Utc::now() - super::STALE_THRESHOLD;
    let pods = storage.procs.list_live_pods(stale_before).await?;
    if pods.is_empty() {
        return Ok(());
    }
    let queues = storage.config.list_queues().await?;

    // (queue, host) pairs that currently hold at least one slot. Only these
    // need an explicit zero write when the host stops handling the queue, so
    // an empty set (the read-failure fallback) disables the zero-out step.
    let assignments_snapshot = storage.procs.list_slot_assignments().await;
    let positive: std::collections::HashSet<(&str, &str)> = match assignments_snapshot {
        Ok(ref a) => a
            .iter()
            .filter(|s| s.slots > 0)
            .map(|s| (s.queue_name.as_str(), s.host_id.as_str()))
            .collect(),
        Err(ref e) => {
            tracing::warn!(
                ?e,
                "rebalance: slot-assignment read failed; skipping zero-out this tick"
            );
            std::collections::HashSet::new()
        }
    };

    for q in queues {
        let eligible: Vec<&str> = pods
            .iter()
            .filter(|p| p.handles(&q.name))
            .map(|p| p.host_id.as_str())
            .collect();
        let total = usize::try_from(q.max_workers).unwrap_or(0);
        if eligible.is_empty() && total > 0 {
            tracing::warn!(
                queue = %q.name,
                "rebalance: no live worker declares this queue; its jobs will not run \
                 until a worker lists it in FORGE_QUEUES / with_queues",
            );
        }

        let shares = fair_shares(total, eligible.len());
        for (host, slots) in eligible.iter().zip(shares) {
            // Saturate rather than fall back to 0: a share too large for i32
            // must not silently turn into "no slots" for this host.
            let slots = i32::try_from(slots).unwrap_or(i32::MAX);
            if let Err(e) = storage.procs.set_slots(&q.name, host, slots).await {
                tracing::warn!(?e, queue = %q.name, %host, "rebalance: set_slots failed");
            }
        }

        // Release slots still recorded for live pods that no longer handle
        // this queue.
        for p in &pods {
            if eligible.contains(&p.host_id.as_str()) {
                continue;
            }
            if !positive.contains(&(q.name.as_str(), p.host_id.as_str())) {
                continue;
            }
            if let Err(e) = storage.procs.set_slots(&q.name, &p.host_id, 0).await {
                tracing::warn!(?e, queue = %q.name, host = %p.host_id, "rebalance: zero set_slots failed");
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::fair_shares;

    /// The remainder of an uneven split goes to the leading entries.
    #[test]
    fn fair_shares_distributes_remainder_to_leaders() {
        let cases: Vec<(usize, usize, Vec<usize>)> = vec![
            (10, 3, vec![4, 3, 3]),
            (9, 3, vec![3, 3, 3]),
            (2, 3, vec![1, 1, 0]),
            (0, 3, vec![0, 0, 0]),
            (7, 1, vec![7]),
        ];
        for (total, pods, expected) in cases {
            assert_eq!(fair_shares(total, pods), expected);
        }
    }

    /// With nobody to share with, the result is empty.
    #[test]
    fn fair_shares_zero_pods_is_empty() {
        assert!(fair_shares(10, 0).is_empty());
    }

    /// No slot is lost or invented, whatever the split.
    #[test]
    fn fair_shares_conserves_the_total() {
        let grid = (0..50usize).flat_map(|total| (1..8usize).map(move |pods| (total, pods)));
        for (total, pods) in grid {
            let sum: usize = fair_shares(total, pods).iter().sum();
            assert_eq!(sum, total, "total={total} pods={pods}");
        }
    }
}