1use crate::MapletResult;
6use serde::{Deserialize, Serialize};
7use std::collections::BTreeMap;
8use std::sync::Arc;
9use std::time::{SystemTime, UNIX_EPOCH};
10use tokio::sync::RwLock;
11use tokio::time::{Duration, interval};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TTLConfig {
16 pub cleanup_interval_secs: u64,
18 pub max_cleanup_batch_size: usize,
20 pub enable_background_cleanup: bool,
22}
23
24impl Default for TTLConfig {
25 fn default() -> Self {
26 Self {
27 cleanup_interval_secs: 60, max_cleanup_batch_size: 1000,
29 enable_background_cleanup: true,
30 }
31 }
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct TTLEntry {
37 pub key: String,
39 pub expires_at: u64,
41 pub db_id: u8,
43}
44
45impl TTLEntry {
46 #[must_use]
48 pub fn new(key: String, db_id: u8, ttl_seconds: u64) -> Self {
49 let expires_at = SystemTime::now()
50 .duration_since(UNIX_EPOCH)
51 .unwrap_or_default()
52 .as_secs()
53 + ttl_seconds;
54
55 Self {
56 key,
57 expires_at,
58 db_id,
59 }
60 }
61
62 #[must_use]
64 pub fn is_expired(&self) -> bool {
65 let now = SystemTime::now()
66 .duration_since(UNIX_EPOCH)
67 .unwrap_or_default()
68 .as_secs();
69 now >= self.expires_at
70 }
71
72 #[must_use]
74 pub fn remaining_ttl(&self) -> i64 {
75 let now = SystemTime::now()
76 .duration_since(UNIX_EPOCH)
77 .unwrap_or_default()
78 .as_secs();
79 #[allow(clippy::cast_possible_wrap)]
80 {
81 self.expires_at as i64 - now as i64
82 }
83 }
84}
85
86pub struct TTLManager {
88 config: TTLConfig,
90 expiration_map: Arc<RwLock<BTreeMap<u64, Vec<TTLEntry>>>>,
92 key_to_expiration: Arc<RwLock<std::collections::HashMap<String, u64>>>,
94 cleanup_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
96 shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
98}
99
100impl TTLManager {
101 #[must_use]
103 pub fn new(config: TTLConfig) -> Self {
104 Self {
105 config,
106 expiration_map: Arc::new(RwLock::new(BTreeMap::new())),
107 key_to_expiration: Arc::new(RwLock::new(std::collections::HashMap::new())),
108 cleanup_handle: Arc::new(RwLock::new(None)),
109 shutdown_tx: Arc::new(RwLock::new(None)),
110 }
111 }
112
113 pub async fn set_ttl(&self, key: String, db_id: u8, ttl_seconds: u64) -> MapletResult<()> {
115 let entry = TTLEntry::new(key.clone(), db_id, ttl_seconds);
116 let expires_at = entry.expires_at;
117
118 self.remove_ttl(&key).await?;
120
121 {
123 let mut expiration_map = self.expiration_map.write().await;
124 expiration_map
125 .entry(expires_at)
126 .or_insert_with(Vec::new)
127 .push(entry);
128 }
129
130 {
132 let mut key_map = self.key_to_expiration.write().await;
133 key_map.insert(key, expires_at);
134 }
135
136 Ok(())
137 }
138
139 pub async fn get_ttl(&self, key: &str) -> MapletResult<Option<i64>> {
141 let key_map = self.key_to_expiration.read().await;
142 if let Some(&expires_at) = key_map.get(key) {
143 let now = SystemTime::now()
144 .duration_since(UNIX_EPOCH)
145 .unwrap_or_default()
146 .as_secs();
147 #[allow(clippy::cast_possible_wrap)]
148 let remaining = expires_at as i64 - now as i64;
149 Ok(Some(remaining.max(0)))
150 } else {
151 Ok(None)
152 }
153 }
154
155 pub async fn remove_ttl(&self, key: &str) -> MapletResult<()> {
157 let mut key_map = self.key_to_expiration.write().await;
158 if let Some(expires_at) = key_map.remove(key) {
159 drop(key_map);
160
161 let mut expiration_map = self.expiration_map.write().await;
162 if let Some(entries) = expiration_map.get_mut(&expires_at) {
163 entries.retain(|entry| entry.key != key);
164 if entries.is_empty() {
165 expiration_map.remove(&expires_at);
166 }
167 }
168 }
169
170 Ok(())
171 }
172
173 pub async fn is_expired(&self, key: &str) -> MapletResult<bool> {
175 let key_map = self.key_to_expiration.read().await;
176 if let Some(&expires_at) = key_map.get(key) {
177 let now = SystemTime::now()
178 .duration_since(UNIX_EPOCH)
179 .unwrap_or_default()
180 .as_secs();
181 Ok(now >= expires_at)
182 } else {
183 Ok(false)
184 }
185 }
186
187 pub async fn get_expired_keys(&self) -> MapletResult<Vec<TTLEntry>> {
189 let now = SystemTime::now()
190 .duration_since(UNIX_EPOCH)
191 .unwrap_or_default()
192 .as_secs();
193
194 let mut expired_entries = Vec::new();
195 let mut expiration_map = self.expiration_map.write().await;
196
197 let expired_times: Vec<u64> = expiration_map
199 .range(..=now)
200 .map(|(&time, _)| time)
201 .collect();
202
203 for time in expired_times {
204 if let Some(entries) = expiration_map.remove(&time) {
205 expired_entries.extend(entries);
206 }
207 }
208
209 let mut key_map = self.key_to_expiration.write().await;
211 for entry in &expired_entries {
212 key_map.remove(&entry.key);
213 }
214
215 Ok(expired_entries)
216 }
217
218 pub async fn start_cleanup<F>(&self, mut cleanup_callback: F) -> MapletResult<()>
220 where
221 F: FnMut(Vec<TTLEntry>) -> MapletResult<()> + Send + Sync + 'static,
222 {
223 if !self.config.enable_background_cleanup {
224 return Ok(());
225 }
226
227 let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
228 let expiration_map = Arc::clone(&self.expiration_map);
229 let key_to_expiration = Arc::clone(&self.key_to_expiration);
230 let config = self.config.clone();
231
232 let handle = tokio::spawn(async move {
233 let mut interval = interval(Duration::from_secs(config.cleanup_interval_secs));
234
235 loop {
236 tokio::select! {
237 _ = interval.tick() => {
238 let now = SystemTime::now()
240 .duration_since(UNIX_EPOCH)
241 .unwrap_or_default()
242 .as_secs();
243
244 let mut expired_entries = Vec::new();
245 {
246 let mut expiration_map = expiration_map.write().await;
247 let expired_times: Vec<u64> = expiration_map
248 .range(..=now)
249 .take(config.max_cleanup_batch_size)
250 .map(|(&time, _)| time)
251 .collect();
252
253 for time in expired_times {
254 if let Some(entries) = expiration_map.remove(&time) {
255 expired_entries.extend(entries);
256 }
257 }
258 }
259
260 if !expired_entries.is_empty() {
261 {
263 let mut key_map = key_to_expiration.write().await;
264 for entry in &expired_entries {
265 key_map.remove(&entry.key);
266 }
267 }
268
269 if let Err(e) = cleanup_callback(expired_entries) {
271 eprintln!("TTL cleanup callback error: {e}");
272 }
273 }
274 }
275 _ = &mut shutdown_rx => {
276 break;
277 }
278 }
279 }
280 });
281
282 {
284 let mut cleanup_handle = self.cleanup_handle.write().await;
285 *cleanup_handle = Some(handle);
286 }
287 {
288 let mut shutdown_tx_guard = self.shutdown_tx.write().await;
289 *shutdown_tx_guard = Some(shutdown_tx);
290 }
291
292 Ok(())
293 }
294
295 pub async fn stop_cleanup(&self) -> MapletResult<()> {
297 {
299 let mut shutdown_tx = self.shutdown_tx.write().await;
300 if let Some(tx) = shutdown_tx.take() {
301 let _ = tx.send(());
302 }
303 }
304
305 {
307 let mut cleanup_handle = self.cleanup_handle.write().await;
308 if let Some(handle) = cleanup_handle.take() {
309 let _ = handle.await;
310 }
311 }
312
313 Ok(())
314 }
315
316 pub async fn get_stats(&self) -> MapletResult<TTLStats> {
318 #[allow(clippy::significant_drop_in_scrutinee)] let expiration_map = self.expiration_map.read().await;
320 #[allow(clippy::significant_drop_in_scrutinee)] let key_map = self.key_to_expiration.read().await;
322
323 let total_keys = key_map.len();
324 let now = SystemTime::now()
325 .duration_since(UNIX_EPOCH)
326 .unwrap_or_default()
327 .as_secs();
328
329 let expired_count: usize = expiration_map
330 .range(..=now)
331 .map(|(_, entries)| entries.len())
332 .sum();
333
334 Ok(TTLStats {
335 total_keys_with_ttl: total_keys as u64,
336 expired_keys: expired_count as u64,
337 next_expiration: expiration_map.range(now..).next().map(|(&time, _)| time),
338 })
339 }
340
341 pub async fn clear_all(&self) -> MapletResult<()> {
343 {
344 let mut expiration_map = self.expiration_map.write().await;
345 expiration_map.clear();
346 }
347 {
348 let mut key_map = self.key_to_expiration.write().await;
349 key_map.clear();
350 }
351 Ok(())
352 }
353}
354
355#[derive(Debug, Clone, Serialize, Deserialize, Default)]
357pub struct TTLStats {
358 pub total_keys_with_ttl: u64,
360 pub expired_keys: u64,
362 pub next_expiration: Option<u64>,
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a manager built from the default configuration.
    fn manager_with_defaults() -> TTLManager {
        TTLManager::new(TTLConfig::default())
    }

    #[tokio::test]
    async fn test_ttl_manager_basic_operations() {
        let manager = manager_with_defaults();
        manager.set_ttl("key1".to_string(), 0, 60).await.unwrap();

        // A freshly set TTL is reported and never exceeds the requested duration.
        let remaining = manager.get_ttl("key1").await.unwrap();
        assert!(matches!(remaining, Some(secs) if secs <= 60));
        assert!(!manager.is_expired("key1").await.unwrap());

        // Removing the TTL makes the key untracked again.
        manager.remove_ttl("key1").await.unwrap();
        assert!(manager.get_ttl("key1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let manager = manager_with_defaults();
        manager.set_ttl("key1".to_string(), 0, 1).await.unwrap();

        // Deadlines have whole-second resolution; sleep just past one second.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        assert!(manager.is_expired("key1").await.unwrap());

        let expired = manager.get_expired_keys().await.unwrap();
        let first = expired.first().expect("key1 should be reported as expired");
        assert_eq!(first.key, "key1");
    }

    #[tokio::test]
    async fn test_ttl_stats() {
        let manager = manager_with_defaults();
        for (key, ttl) in [("key1", 60), ("key2", 120)] {
            manager.set_ttl(key.to_string(), 0, ttl).await.unwrap();
        }

        let stats = manager.get_stats().await.unwrap();
        assert_eq!(stats.total_keys_with_ttl, 2);
        assert_eq!(stats.expired_keys, 0);
        assert!(stats.next_expiration.is_some());
    }

    #[tokio::test]
    async fn test_ttl_clear_all() {
        let manager = manager_with_defaults();
        for (key, ttl) in [("key1", 60), ("key2", 120)] {
            manager.set_ttl(key.to_string(), 0, ttl).await.unwrap();
        }

        manager.clear_all().await.unwrap();

        let stats = manager.get_stats().await.unwrap();
        assert_eq!(stats.total_keys_with_ttl, 0);
    }
}