use crate::config::LeaseConfig;
use crate::error::Result;
use crate::lease::Lease;
use crate::storage::LeaseStorage;
use rand::Rng;
use std::sync::Arc;
use tokio::time::{sleep, Duration};
use tracing::{error, warn};
pub struct LeaseManager<S: LeaseStorage> {
storage: Arc<S>,
worker_id: String,
config: LeaseConfig,
}
impl<S: LeaseStorage + 'static> LeaseManager<S> {
pub fn new(storage: Arc<S>, worker_id: String, config: LeaseConfig) -> Self {
Self {
storage,
worker_id,
config,
}
}
pub fn worker_id(&self) -> &str {
&self.worker_id
}
pub async fn ensure_lease(&self, key: &str) -> Result<()> {
let leases = self.storage.list_leases().await?;
if !leases.iter().any(|l| l.lease_key == key) {
match self.storage.create_lease(key).await {
Ok(_) => {}
Err(crate::error::LeaseError::Conflict) => {}
Err(e) => return Err(e),
}
}
Ok(())
}
pub async fn get_my_lease_keys(&self) -> Result<Vec<String>> {
let leases = self.storage.list_leases().await?;
Ok(self
.my_leases(&leases)
.iter()
.map(|l| l.lease_key.clone())
.collect())
}
pub async fn rebalance(&self) -> Result<Vec<String>> {
let all_leases = self.storage.list_leases().await?;
let active_workers = self.count_active_workers(&all_leases);
let total = all_leases.len();
let fair_share = (total + active_workers - 1) / active_workers;
let target = match self.config.max_leases_per_worker {
Some(max) => fair_share.min(max),
None => fair_share,
};
let my_count = self.my_leases(&all_leases).len();
let mut remaining_deficit = target.saturating_sub(my_count);
if remaining_deficit > 0 {
let unowned: Vec<&Lease> = all_leases
.iter()
.filter(|l| l.owner.is_none())
.take(remaining_deficit)
.collect();
for lease in &unowned {
if self
.storage
.acquire_lease(lease, &self.worker_id)
.await
.unwrap_or(false)
{
remaining_deficit = remaining_deficit.saturating_sub(1);
}
}
if remaining_deficit > 0 {
let expired: Vec<&Lease> = all_leases
.iter()
.filter(|l| l.is_expired() && !l.is_owned_by(&self.worker_id))
.take(remaining_deficit)
.collect();
for lease in &expired {
let _ = self.storage.acquire_lease(lease, &self.worker_id).await;
}
}
}
let updated = self.storage.list_leases().await?;
Ok(self
.my_leases(&updated)
.iter()
.map(|l| l.lease_key.clone())
.collect())
}
pub async fn renew_my_leases(&self) -> Result<()> {
let leases = self.storage.list_leases().await?;
let my_leases = self.my_leases(&leases);
let mut handles = tokio::task::JoinSet::new();
for lease in my_leases {
let storage = self.storage.clone();
let l = lease.clone();
handles.spawn(async move {
if let Err(e) = storage.renew_lease(&l).await {
warn!("Failed to renew lease {}: {}", l.lease_key, e);
}
});
}
while let Some(res) = handles.join_next().await {
if let Err(e) = res {
error!("Task join error during renewal: {}", e);
}
}
Ok(())
}
pub async fn get_checkpoint(&self, lease_key: &str) -> Result<Option<String>> {
self.storage.get_checkpoint(lease_key).await
}
pub async fn checkpoint(&self, lease_key: &str, checkpoint: &str) -> Result<()> {
self.storage.update_checkpoint(lease_key, checkpoint).await?;
Ok(())
}
fn count_active_workers(&self, leases: &[Lease]) -> usize {
let mut owners = std::collections::HashSet::new();
owners.insert(&self.worker_id);
for lease in leases {
if !lease.is_expired() {
if let Some(ref owner) = lease.owner {
owners.insert(owner);
}
}
}
owners.len().max(1)
}
fn my_leases<'a>(&self, leases: &'a [Lease]) -> Vec<&'a Lease> {
leases
.iter()
.filter(|l| l.is_owned_by(&self.worker_id))
.collect()
}
pub fn start_background_tasks(self: Arc<Self>) {
let manager_renew = self.clone();
tokio::spawn(async move {
loop {
sleep(Duration::from_millis(
manager_renew.config.renewal_interval_ms as u64,
))
.await;
if let Err(e) = manager_renew.renew_my_leases().await {
error!("Renewal failed: {}", e);
}
}
});
let manager_rebalance = self.clone();
tokio::spawn(async move {
loop {
let jitter = {
let mut rng = rand::thread_rng();
rng.gen_range(0..1000u64)
};
sleep(Duration::from_millis(
(manager_rebalance.config.rebalance_interval_ms as u64) + jitter,
))
.await;
match manager_rebalance.rebalance().await {
Ok(leases) => tracing::debug!("Holding {} leases: {:?}", leases.len(), leases),
Err(e) => error!("Rebalance failed: {}", e),
}
}
});
}
}