1use crate::config::LeaseConfig;
2use crate::error::Result;
3use crate::lease::Lease;
4use crate::storage::LeaseStorage;
5use rand::Rng;
6use std::sync::Arc;
7use tokio::time::{sleep, Duration};
8use tracing::{error, warn};
9
10pub struct LeaseManager<S: LeaseStorage> {
11 storage: Arc<S>,
12 worker_id: String,
13 config: LeaseConfig,
14}
15
16impl<S: LeaseStorage + 'static> LeaseManager<S> {
17 pub fn new(storage: Arc<S>, worker_id: String, config: LeaseConfig) -> Self {
18 Self {
19 storage,
20 worker_id,
21 config,
22 }
23 }
24
25 pub fn worker_id(&self) -> &str {
26 &self.worker_id
27 }
28
29 pub async fn ensure_lease(&self, key: &str) -> Result<()> {
32 let leases = self.storage.list_leases().await?;
33 if !leases.iter().any(|l| l.lease_key == key) {
34 match self.storage.create_lease(key).await {
35 Ok(_) => {}
36 Err(crate::error::LeaseError::Conflict) => {}
38 Err(e) => return Err(e),
39 }
40 }
41 Ok(())
42 }
43
44 pub async fn get_my_lease_keys(&self) -> Result<Vec<String>> {
46 let leases = self.storage.list_leases().await?;
47 Ok(self
48 .my_leases(&leases)
49 .iter()
50 .map(|l| l.lease_key.clone())
51 .collect())
52 }
53
54 pub async fn rebalance(&self) -> Result<Vec<String>> {
57 let all_leases = self.storage.list_leases().await?;
58
59 let active_workers = self.count_active_workers(&all_leases);
60 let total = all_leases.len();
61
62 let fair_share = (total + active_workers - 1) / active_workers;
64 let target = match self.config.max_leases_per_worker {
65 Some(max) => fair_share.min(max),
66 None => fair_share,
67 };
68
69 let my_count = self.my_leases(&all_leases).len();
70 let mut remaining_deficit = target.saturating_sub(my_count);
71
72 if remaining_deficit > 0 {
73 let unowned: Vec<&Lease> = all_leases
75 .iter()
76 .filter(|l| l.owner.is_none())
77 .take(remaining_deficit)
78 .collect();
79
80 for lease in &unowned {
81 if self
82 .storage
83 .acquire_lease(lease, &self.worker_id)
84 .await
85 .unwrap_or(false)
86 {
87 remaining_deficit = remaining_deficit.saturating_sub(1);
88 }
89 }
90
91 if remaining_deficit > 0 {
93 let expired: Vec<&Lease> = all_leases
94 .iter()
95 .filter(|l| l.is_expired() && !l.is_owned_by(&self.worker_id))
96 .take(remaining_deficit)
97 .collect();
98
99 for lease in &expired {
100 let _ = self.storage.acquire_lease(lease, &self.worker_id).await;
102 }
103 }
104 }
105
106 let updated = self.storage.list_leases().await?;
108 Ok(self
109 .my_leases(&updated)
110 .iter()
111 .map(|l| l.lease_key.clone())
112 .collect())
113 }
114
115 pub async fn renew_my_leases(&self) -> Result<()> {
118 let leases = self.storage.list_leases().await?;
119 let my_leases = self.my_leases(&leases);
120
121 let mut handles = tokio::task::JoinSet::new();
122
123 for lease in my_leases {
124 let storage = self.storage.clone();
125 let l = lease.clone();
126 handles.spawn(async move {
127 if let Err(e) = storage.renew_lease(&l).await {
128 warn!("Failed to renew lease {}: {}", l.lease_key, e);
129 }
130 });
131 }
132
133 while let Some(res) = handles.join_next().await {
134 if let Err(e) = res {
135 error!("Task join error during renewal: {}", e);
136 }
137 }
138
139 Ok(())
140 }
141
142 pub async fn get_checkpoint(&self, lease_key: &str) -> Result<Option<String>> {
143 self.storage.get_checkpoint(lease_key).await
144 }
145
146 pub async fn checkpoint(&self, lease_key: &str, checkpoint: &str) -> Result<()> {
147 self.storage.update_checkpoint(lease_key, checkpoint).await?;
148 Ok(())
149 }
150
151 fn count_active_workers(&self, leases: &[Lease]) -> usize {
153 let mut owners = std::collections::HashSet::new();
154 owners.insert(&self.worker_id);
156 for lease in leases {
157 if !lease.is_expired() {
158 if let Some(ref owner) = lease.owner {
159 owners.insert(owner);
160 }
161 }
162 }
163 owners.len().max(1)
164 }
165
166 fn my_leases<'a>(&self, leases: &'a [Lease]) -> Vec<&'a Lease> {
167 leases
168 .iter()
169 .filter(|l| l.is_owned_by(&self.worker_id))
170 .collect()
171 }
172
173 pub fn start_background_tasks(self: Arc<Self>) {
176 let manager_renew = self.clone();
177
178 tokio::spawn(async move {
180 loop {
181 sleep(Duration::from_millis(
182 manager_renew.config.renewal_interval_ms as u64,
183 ))
184 .await;
185
186 if let Err(e) = manager_renew.renew_my_leases().await {
187 error!("Renewal failed: {}", e);
188 }
189 }
190 });
191
192 let manager_rebalance = self.clone();
194 tokio::spawn(async move {
195 loop {
196 let jitter = {
199 let mut rng = rand::thread_rng();
200 rng.gen_range(0..1000u64)
201 };
202 sleep(Duration::from_millis(
203 (manager_rebalance.config.rebalance_interval_ms as u64) + jitter,
204 ))
205 .await;
206
207 match manager_rebalance.rebalance().await {
208 Ok(leases) => tracing::debug!("Holding {} leases: {:?}", leases.len(), leases),
209 Err(e) => error!("Rebalance failed: {}", e),
210 }
211 }
212 });
213 }
214}