1use crate::MapletResult;
6use std::collections::BTreeMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tokio::time::{Duration, interval};
10use serde::{Serialize, Deserialize};
11use std::time::{SystemTime, UNIX_EPOCH};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TTLConfig {
16 pub cleanup_interval_secs: u64,
18 pub max_cleanup_batch_size: usize,
20 pub enable_background_cleanup: bool,
22}
23
24impl Default for TTLConfig {
25 fn default() -> Self {
26 Self {
27 cleanup_interval_secs: 60, max_cleanup_batch_size: 1000,
29 enable_background_cleanup: true,
30 }
31 }
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct TTLEntry {
37 pub key: String,
39 pub expires_at: u64,
41 pub db_id: u8,
43}
44
45impl TTLEntry {
46 #[must_use]
48 pub fn new(key: String, db_id: u8, ttl_seconds: u64) -> Self {
49 let expires_at = SystemTime::now()
50 .duration_since(UNIX_EPOCH)
51 .unwrap_or_default()
52 .as_secs() + ttl_seconds;
53
54 Self {
55 key,
56 expires_at,
57 db_id,
58 }
59 }
60
61 #[must_use]
63 pub fn is_expired(&self) -> bool {
64 let now = SystemTime::now()
65 .duration_since(UNIX_EPOCH)
66 .unwrap_or_default()
67 .as_secs();
68 now >= self.expires_at
69 }
70
71 #[must_use]
73 pub fn remaining_ttl(&self) -> i64 {
74 let now = SystemTime::now()
75 .duration_since(UNIX_EPOCH)
76 .unwrap_or_default()
77 .as_secs();
78 #[allow(clippy::cast_possible_wrap)]
79 { self.expires_at as i64 - now as i64 }
80 }
81}
82
83pub struct TTLManager {
85 config: TTLConfig,
87 expiration_map: Arc<RwLock<BTreeMap<u64, Vec<TTLEntry>>>>,
89 key_to_expiration: Arc<RwLock<std::collections::HashMap<String, u64>>>,
91 cleanup_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
93 shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
95}
96
97impl TTLManager {
98 #[must_use]
100 pub fn new(config: TTLConfig) -> Self {
101 Self {
102 config,
103 expiration_map: Arc::new(RwLock::new(BTreeMap::new())),
104 key_to_expiration: Arc::new(RwLock::new(std::collections::HashMap::new())),
105 cleanup_handle: Arc::new(RwLock::new(None)),
106 shutdown_tx: Arc::new(RwLock::new(None)),
107 }
108 }
109
110 pub async fn set_ttl(&self, key: String, db_id: u8, ttl_seconds: u64) -> MapletResult<()> {
112 let entry = TTLEntry::new(key.clone(), db_id, ttl_seconds);
113 let expires_at = entry.expires_at;
114
115 self.remove_ttl(&key).await?;
117
118 {
120 let mut expiration_map = self.expiration_map.write().await;
121 expiration_map.entry(expires_at)
122 .or_insert_with(Vec::new)
123 .push(entry);
124 }
125
126 {
128 let mut key_map = self.key_to_expiration.write().await;
129 key_map.insert(key, expires_at);
130 }
131
132 Ok(())
133 }
134
135 pub async fn get_ttl(&self, key: &str) -> MapletResult<Option<i64>> {
137 let key_map = self.key_to_expiration.read().await;
138 if let Some(&expires_at) = key_map.get(key) {
139 let now = SystemTime::now()
140 .duration_since(UNIX_EPOCH)
141 .unwrap_or_default()
142 .as_secs();
143 #[allow(clippy::cast_possible_wrap)]
144 let remaining = expires_at as i64 - now as i64;
145 Ok(Some(remaining.max(0)))
146 } else {
147 Ok(None)
148 }
149 }
150
151 pub async fn remove_ttl(&self, key: &str) -> MapletResult<()> {
153 let mut key_map = self.key_to_expiration.write().await;
154 if let Some(expires_at) = key_map.remove(key) {
155 drop(key_map);
156
157 let mut expiration_map = self.expiration_map.write().await;
158 if let Some(entries) = expiration_map.get_mut(&expires_at) {
159 entries.retain(|entry| entry.key != key);
160 if entries.is_empty() {
161 expiration_map.remove(&expires_at);
162 }
163 }
164 }
165
166 Ok(())
167 }
168
169 pub async fn is_expired(&self, key: &str) -> MapletResult<bool> {
171 let key_map = self.key_to_expiration.read().await;
172 if let Some(&expires_at) = key_map.get(key) {
173 let now = SystemTime::now()
174 .duration_since(UNIX_EPOCH)
175 .unwrap_or_default()
176 .as_secs();
177 Ok(now >= expires_at)
178 } else {
179 Ok(false)
180 }
181 }
182
183 pub async fn get_expired_keys(&self) -> MapletResult<Vec<TTLEntry>> {
185 let now = SystemTime::now()
186 .duration_since(UNIX_EPOCH)
187 .unwrap_or_default()
188 .as_secs();
189
190 let mut expired_entries = Vec::new();
191 let mut expiration_map = self.expiration_map.write().await;
192
193 let expired_times: Vec<u64> = expiration_map
195 .range(..=now)
196 .map(|(&time, _)| time)
197 .collect();
198
199 for time in expired_times {
200 if let Some(entries) = expiration_map.remove(&time) {
201 expired_entries.extend(entries);
202 }
203 }
204
205 let mut key_map = self.key_to_expiration.write().await;
207 for entry in &expired_entries {
208 key_map.remove(&entry.key);
209 }
210
211 Ok(expired_entries)
212 }
213
214 pub async fn start_cleanup<F>(&self, mut cleanup_callback: F) -> MapletResult<()>
216 where
217 F: FnMut(Vec<TTLEntry>) -> MapletResult<()> + Send + Sync + 'static,
218 {
219 if !self.config.enable_background_cleanup {
220 return Ok(());
221 }
222
223 let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
224 let expiration_map = Arc::clone(&self.expiration_map);
225 let key_to_expiration = Arc::clone(&self.key_to_expiration);
226 let config = self.config.clone();
227
228 let handle = tokio::spawn(async move {
229 let mut interval = interval(Duration::from_secs(config.cleanup_interval_secs));
230
231 loop {
232 tokio::select! {
233 _ = interval.tick() => {
234 let now = SystemTime::now()
236 .duration_since(UNIX_EPOCH)
237 .unwrap_or_default()
238 .as_secs();
239
240 let mut expired_entries = Vec::new();
241 {
242 let mut expiration_map = expiration_map.write().await;
243 let expired_times: Vec<u64> = expiration_map
244 .range(..=now)
245 .take(config.max_cleanup_batch_size)
246 .map(|(&time, _)| time)
247 .collect();
248
249 for time in expired_times {
250 if let Some(entries) = expiration_map.remove(&time) {
251 expired_entries.extend(entries);
252 }
253 }
254 }
255
256 if !expired_entries.is_empty() {
257 {
259 let mut key_map = key_to_expiration.write().await;
260 for entry in &expired_entries {
261 key_map.remove(&entry.key);
262 }
263 }
264
265 if let Err(e) = cleanup_callback(expired_entries) {
267 eprintln!("TTL cleanup callback error: {e}");
268 }
269 }
270 }
271 _ = &mut shutdown_rx => {
272 break;
273 }
274 }
275 }
276 });
277
278 {
280 let mut cleanup_handle = self.cleanup_handle.write().await;
281 *cleanup_handle = Some(handle);
282 }
283 {
284 let mut shutdown_tx_guard = self.shutdown_tx.write().await;
285 *shutdown_tx_guard = Some(shutdown_tx);
286 }
287
288 Ok(())
289 }
290
291 pub async fn stop_cleanup(&self) -> MapletResult<()> {
293 {
295 let mut shutdown_tx = self.shutdown_tx.write().await;
296 if let Some(tx) = shutdown_tx.take() {
297 let _ = tx.send(());
298 }
299 }
300
301 {
303 let mut cleanup_handle = self.cleanup_handle.write().await;
304 if let Some(handle) = cleanup_handle.take() {
305 let _ = handle.await;
306 }
307 }
308
309 Ok(())
310 }
311
312 pub async fn get_stats(&self) -> MapletResult<TTLStats> {
314 let expiration_map = self.expiration_map.read().await;
315 let key_map = self.key_to_expiration.read().await;
316
317 let total_keys = key_map.len();
318 let now = SystemTime::now()
319 .duration_since(UNIX_EPOCH)
320 .unwrap_or_default()
321 .as_secs();
322
323 let expired_count: usize = expiration_map
324 .range(..=now)
325 .map(|(_, entries)| entries.len())
326 .sum();
327
328 Ok(TTLStats {
329 total_keys_with_ttl: total_keys as u64,
330 expired_keys: expired_count as u64,
331 next_expiration: expiration_map
332 .range(now..)
333 .next()
334 .map(|(&time, _)| time),
335 })
336 }
337
338 pub async fn clear_all(&self) -> MapletResult<()> {
340 {
341 let mut expiration_map = self.expiration_map.write().await;
342 expiration_map.clear();
343 }
344 {
345 let mut key_map = self.key_to_expiration.write().await;
346 key_map.clear();
347 }
348 Ok(())
349 }
350}
351
352#[derive(Debug, Clone, Serialize, Deserialize, Default)]
354pub struct TTLStats {
355 pub total_keys_with_ttl: u64,
357 pub expired_keys: u64,
359 pub next_expiration: Option<u64>,
361}
362
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a manager built from the default configuration.
    fn default_manager() -> TTLManager {
        TTLManager::new(TTLConfig::default())
    }

    #[tokio::test]
    async fn test_ttl_manager_basic_operations() {
        let manager = default_manager();
        manager.set_ttl("key1".to_string(), 0, 60).await.unwrap();

        // A freshly set key reports a TTL no larger than the one requested.
        let remaining = manager.get_ttl("key1").await.unwrap();
        assert!(remaining.is_some());
        assert!(remaining.unwrap() <= 60);
        assert!(!manager.is_expired("key1").await.unwrap());

        // Once removed, the key is no longer tracked.
        manager.remove_ttl("key1").await.unwrap();
        assert!(manager.get_ttl("key1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let manager = default_manager();
        manager.set_ttl("key1".to_string(), 0, 1).await.unwrap();

        // Deadlines have one-second resolution, so wait slightly past the 1s TTL.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        assert!(manager.is_expired("key1").await.unwrap());

        let expired = manager.get_expired_keys().await.unwrap();
        assert!(!expired.is_empty());
        assert_eq!(expired[0].key, "key1");
    }

    #[tokio::test]
    async fn test_ttl_stats() {
        let manager = default_manager();
        for (key, ttl) in [("key1", 60), ("key2", 120)] {
            manager.set_ttl(key.to_string(), 0, ttl).await.unwrap();
        }

        let stats = manager.get_stats().await.unwrap();
        assert_eq!(stats.total_keys_with_ttl, 2);
        assert_eq!(stats.expired_keys, 0);
        assert!(stats.next_expiration.is_some());
    }

    #[tokio::test]
    async fn test_ttl_clear_all() {
        let manager = default_manager();
        for (key, ttl) in [("key1", 60), ("key2", 120)] {
            manager.set_ttl(key.to_string(), 0, ttl).await.unwrap();
        }

        manager.clear_all().await.unwrap();

        let stats = manager.get_stats().await.unwrap();
        assert_eq!(stats.total_keys_with_ttl, 0);
    }
}