Skip to main content

leasy/
manager.rs

1use crate::config::LeaseConfig;
2use crate::error::Result;
3use crate::lease::Lease;
4use crate::storage::LeaseStorage;
5use rand::Rng;
6use std::sync::Arc;
7use tokio::time::{sleep, Duration};
8use tracing::{error, warn};
9
10pub struct LeaseManager<S: LeaseStorage> {
11    storage: Arc<S>,
12    worker_id: String,
13    config: LeaseConfig,
14}
15
16impl<S: LeaseStorage + 'static> LeaseManager<S> {
17    pub fn new(storage: Arc<S>, worker_id: String, config: LeaseConfig) -> Self {
18        Self {
19            storage,
20            worker_id,
21            config,
22        }
23    }
24
25    pub fn worker_id(&self) -> &str {
26        &self.worker_id
27    }
28
29    /// Ensures a lease row exists for this key. Idempotent — safe to call
30    /// from multiple workers simultaneously.
31    pub async fn ensure_lease(&self, key: &str) -> Result<()> {
32        let leases = self.storage.list_leases().await?;
33        if !leases.iter().any(|l| l.lease_key == key) {
34            match self.storage.create_lease(key).await {
35                Ok(_) => {}
36                // Another worker raced and created it — that's fine
37                Err(crate::error::LeaseError::Conflict) => {}
38                Err(e) => return Err(e),
39            }
40        }
41        Ok(())
42    }
43
44    /// Returns lease keys currently owned by this worker.
45    pub async fn get_my_lease_keys(&self) -> Result<Vec<String>> {
46        let leases = self.storage.list_leases().await?;
47        Ok(self
48            .my_leases(&leases)
49            .iter()
50            .map(|l| l.lease_key.clone())
51            .collect())
52    }
53
54    /// Core rebalance algorithm. Tries to reach a fair share of leases for
55    /// this worker by first claiming unowned leases, then stealing expired ones.
56    pub async fn rebalance(&self) -> Result<Vec<String>> {
57        let all_leases = self.storage.list_leases().await?;
58
59        let active_workers = self.count_active_workers(&all_leases);
60        let total = all_leases.len();
61
62        // ceil(total / active_workers), capped by max_leases_per_worker
63        let fair_share = (total + active_workers - 1) / active_workers;
64        let target = match self.config.max_leases_per_worker {
65            Some(max) => fair_share.min(max),
66            None => fair_share,
67        };
68
69        let my_count = self.my_leases(&all_leases).len();
70        let mut remaining_deficit = target.saturating_sub(my_count);
71
72        if remaining_deficit > 0 {
73            // Priority 1: take unowned leases
74            let unowned: Vec<&Lease> = all_leases
75                .iter()
76                .filter(|l| l.owner.is_none())
77                .take(remaining_deficit)
78                .collect();
79
80            for lease in &unowned {
81                if self
82                    .storage
83                    .acquire_lease(lease, &self.worker_id)
84                    .await
85                    .unwrap_or(false)
86                {
87                    remaining_deficit = remaining_deficit.saturating_sub(1);
88                }
89            }
90
91            // Priority 2: steal expired leases (only if still short)
92            if remaining_deficit > 0 {
93                let expired: Vec<&Lease> = all_leases
94                    .iter()
95                    .filter(|l| l.is_expired() && !l.is_owned_by(&self.worker_id))
96                    .take(remaining_deficit)
97                    .collect();
98
99                for lease in &expired {
100                    // acquire_lease is conditional — if another worker stole it first, we get false
101                    let _ = self.storage.acquire_lease(lease, &self.worker_id).await;
102                }
103            }
104        }
105
106        // Return final snapshot of our leases
107        let updated = self.storage.list_leases().await?;
108        Ok(self
109            .my_leases(&updated)
110            .iter()
111            .map(|l| l.lease_key.clone())
112            .collect())
113    }
114
115    /// Renew all leases owned by this worker. Uses concurrent requests to
116    /// minimize total latency when holding many leases.
117    pub async fn renew_my_leases(&self) -> Result<()> {
118        let leases = self.storage.list_leases().await?;
119        let my_leases = self.my_leases(&leases);
120
121        let mut handles = tokio::task::JoinSet::new();
122
123        for lease in my_leases {
124            let storage = self.storage.clone();
125            let l = lease.clone();
126            handles.spawn(async move {
127                if let Err(e) = storage.renew_lease(&l).await {
128                    warn!("Failed to renew lease {}: {}", l.lease_key, e);
129                }
130            });
131        }
132
133        while let Some(res) = handles.join_next().await {
134            if let Err(e) = res {
135                error!("Task join error during renewal: {}", e);
136            }
137        }
138
139        Ok(())
140    }
141
142    pub async fn get_checkpoint(&self, lease_key: &str) -> Result<Option<String>> {
143        self.storage.get_checkpoint(lease_key).await
144    }
145
146    pub async fn checkpoint(&self, lease_key: &str, checkpoint: &str) -> Result<()> {
147        self.storage.update_checkpoint(lease_key, checkpoint).await?;
148        Ok(())
149    }
150
151    /// Count distinct non-expired owners. Returns at least 1 (this worker).
152    fn count_active_workers(&self, leases: &[Lease]) -> usize {
153        let mut owners = std::collections::HashSet::new();
154        // Always count ourselves even if we hold zero leases yet
155        owners.insert(&self.worker_id);
156        for lease in leases {
157            if !lease.is_expired() {
158                if let Some(ref owner) = lease.owner {
159                    owners.insert(owner);
160                }
161            }
162        }
163        owners.len().max(1)
164    }
165
166    fn my_leases<'a>(&self, leases: &'a [Lease]) -> Vec<&'a Lease> {
167        leases
168            .iter()
169            .filter(|l| l.is_owned_by(&self.worker_id))
170            .collect()
171    }
172
173    /// Spawns background tokio tasks for periodic renewal and rebalancing.
174    /// Call this once after creating the manager.
175    pub fn start_background_tasks(self: Arc<Self>) {
176        let manager_renew = self.clone();
177
178        // Renewal loop
179        tokio::spawn(async move {
180            loop {
181                sleep(Duration::from_millis(
182                    manager_renew.config.renewal_interval_ms as u64,
183                ))
184                .await;
185
186                if let Err(e) = manager_renew.renew_my_leases().await {
187                    error!("Renewal failed: {}", e);
188                }
189            }
190        });
191
192        // Rebalance loop with jitter to prevent thundering herd
193        let manager_rebalance = self.clone();
194        tokio::spawn(async move {
195            loop {
196                // Generate jitter inline — thread_rng() is !Send so we
197                // must not hold it across the .await boundary
198                let jitter = {
199                    let mut rng = rand::thread_rng();
200                    rng.gen_range(0..1000u64)
201                };
202                sleep(Duration::from_millis(
203                    (manager_rebalance.config.rebalance_interval_ms as u64) + jitter,
204                ))
205                .await;
206
207                match manager_rebalance.rebalance().await {
208                    Ok(leases) => tracing::debug!("Holding {} leases: {:?}", leases.len(), leases),
209                    Err(e) => error!("Rebalance failed: {}", e),
210                }
211            }
212        });
213    }
214}