1use crate::MapletResult;
6use std::collections::BTreeMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tokio::time::{Duration, interval};
10use serde::{Serialize, Deserialize};
11use std::time::{SystemTime, UNIX_EPOCH};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TTLConfig {
16 pub cleanup_interval_secs: u64,
18 pub max_cleanup_batch_size: usize,
20 pub enable_background_cleanup: bool,
22}
23
24impl Default for TTLConfig {
25 fn default() -> Self {
26 Self {
27 cleanup_interval_secs: 60, max_cleanup_batch_size: 1000,
29 enable_background_cleanup: true,
30 }
31 }
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct TTLEntry {
37 pub key: String,
39 pub expires_at: u64,
41 pub db_id: u8,
43}
44
45impl TTLEntry {
46 pub fn new(key: String, db_id: u8, ttl_seconds: u64) -> Self {
48 let expires_at = SystemTime::now()
49 .duration_since(UNIX_EPOCH)
50 .unwrap_or_default()
51 .as_secs() + ttl_seconds;
52
53 Self {
54 key,
55 db_id,
56 expires_at,
57 }
58 }
59
60 pub fn is_expired(&self) -> bool {
62 let now = SystemTime::now()
63 .duration_since(UNIX_EPOCH)
64 .unwrap_or_default()
65 .as_secs();
66 now >= self.expires_at
67 }
68
69 pub fn remaining_ttl(&self) -> i64 {
71 let now = SystemTime::now()
72 .duration_since(UNIX_EPOCH)
73 .unwrap_or_default()
74 .as_secs();
75 self.expires_at as i64 - now as i64
76 }
77}
78
79pub struct TTLManager {
81 config: TTLConfig,
83 expiration_map: Arc<RwLock<BTreeMap<u64, Vec<TTLEntry>>>>,
85 key_to_expiration: Arc<RwLock<std::collections::HashMap<String, u64>>>,
87 cleanup_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
89 shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
91}
92
93impl TTLManager {
94 pub fn new(config: TTLConfig) -> Self {
96 Self {
97 config,
98 expiration_map: Arc::new(RwLock::new(BTreeMap::new())),
99 key_to_expiration: Arc::new(RwLock::new(std::collections::HashMap::new())),
100 cleanup_handle: Arc::new(RwLock::new(None)),
101 shutdown_tx: Arc::new(RwLock::new(None)),
102 }
103 }
104
105 pub async fn set_ttl(&self, key: String, db_id: u8, ttl_seconds: u64) -> MapletResult<()> {
107 let entry = TTLEntry::new(key.clone(), db_id, ttl_seconds);
108 let expires_at = entry.expires_at;
109
110 self.remove_ttl(&key).await?;
112
113 {
115 let mut expiration_map = self.expiration_map.write().await;
116 expiration_map.entry(expires_at)
117 .or_insert_with(Vec::new)
118 .push(entry);
119 }
120
121 {
123 let mut key_map = self.key_to_expiration.write().await;
124 key_map.insert(key, expires_at);
125 }
126
127 Ok(())
128 }
129
130 pub async fn get_ttl(&self, key: &str) -> MapletResult<Option<i64>> {
132 let key_map = self.key_to_expiration.read().await;
133 if let Some(&expires_at) = key_map.get(key) {
134 let now = SystemTime::now()
135 .duration_since(UNIX_EPOCH)
136 .unwrap_or_default()
137 .as_secs();
138 let remaining = expires_at as i64 - now as i64;
139 Ok(Some(remaining.max(0)))
140 } else {
141 Ok(None)
142 }
143 }
144
145 pub async fn remove_ttl(&self, key: &str) -> MapletResult<()> {
147 let mut key_map = self.key_to_expiration.write().await;
148 if let Some(expires_at) = key_map.remove(key) {
149 drop(key_map);
150
151 let mut expiration_map = self.expiration_map.write().await;
152 if let Some(entries) = expiration_map.get_mut(&expires_at) {
153 entries.retain(|entry| entry.key != key);
154 if entries.is_empty() {
155 expiration_map.remove(&expires_at);
156 }
157 }
158 }
159
160 Ok(())
161 }
162
163 pub async fn is_expired(&self, key: &str) -> MapletResult<bool> {
165 let key_map = self.key_to_expiration.read().await;
166 if let Some(&expires_at) = key_map.get(key) {
167 let now = SystemTime::now()
168 .duration_since(UNIX_EPOCH)
169 .unwrap_or_default()
170 .as_secs();
171 Ok(now >= expires_at)
172 } else {
173 Ok(false)
174 }
175 }
176
177 pub async fn get_expired_keys(&self) -> MapletResult<Vec<TTLEntry>> {
179 let now = SystemTime::now()
180 .duration_since(UNIX_EPOCH)
181 .unwrap_or_default()
182 .as_secs();
183
184 let mut expired_entries = Vec::new();
185 let mut expiration_map = self.expiration_map.write().await;
186
187 let expired_times: Vec<u64> = expiration_map
189 .range(..=now)
190 .map(|(&time, _)| time)
191 .collect();
192
193 for time in expired_times {
194 if let Some(entries) = expiration_map.remove(&time) {
195 expired_entries.extend(entries);
196 }
197 }
198
199 let mut key_map = self.key_to_expiration.write().await;
201 for entry in &expired_entries {
202 key_map.remove(&entry.key);
203 }
204
205 Ok(expired_entries)
206 }
207
208 pub async fn start_cleanup<F>(&self, mut cleanup_callback: F) -> MapletResult<()>
210 where
211 F: FnMut(Vec<TTLEntry>) -> MapletResult<()> + Send + Sync + 'static,
212 {
213 if !self.config.enable_background_cleanup {
214 return Ok(());
215 }
216
217 let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
218 let expiration_map = Arc::clone(&self.expiration_map);
219 let key_to_expiration = Arc::clone(&self.key_to_expiration);
220 let config = self.config.clone();
221
222 let handle = tokio::spawn(async move {
223 let mut interval = interval(Duration::from_secs(config.cleanup_interval_secs));
224
225 loop {
226 tokio::select! {
227 _ = interval.tick() => {
228 let now = SystemTime::now()
230 .duration_since(UNIX_EPOCH)
231 .unwrap_or_default()
232 .as_secs();
233
234 let mut expired_entries = Vec::new();
235 {
236 let mut expiration_map = expiration_map.write().await;
237 let expired_times: Vec<u64> = expiration_map
238 .range(..=now)
239 .take(config.max_cleanup_batch_size)
240 .map(|(&time, _)| time)
241 .collect();
242
243 for time in expired_times {
244 if let Some(entries) = expiration_map.remove(&time) {
245 expired_entries.extend(entries);
246 }
247 }
248 }
249
250 if !expired_entries.is_empty() {
251 {
253 let mut key_map = key_to_expiration.write().await;
254 for entry in &expired_entries {
255 key_map.remove(&entry.key);
256 }
257 }
258
259 if let Err(e) = cleanup_callback(expired_entries) {
261 eprintln!("TTL cleanup callback error: {}", e);
262 }
263 }
264 }
265 _ = &mut shutdown_rx => {
266 break;
267 }
268 }
269 }
270 });
271
272 {
274 let mut cleanup_handle = self.cleanup_handle.write().await;
275 *cleanup_handle = Some(handle);
276 }
277 {
278 let mut shutdown_tx_guard = self.shutdown_tx.write().await;
279 *shutdown_tx_guard = Some(shutdown_tx);
280 }
281
282 Ok(())
283 }
284
285 pub async fn stop_cleanup(&self) -> MapletResult<()> {
287 {
289 let mut shutdown_tx = self.shutdown_tx.write().await;
290 if let Some(tx) = shutdown_tx.take() {
291 let _ = tx.send(());
292 }
293 }
294
295 {
297 let mut cleanup_handle = self.cleanup_handle.write().await;
298 if let Some(handle) = cleanup_handle.take() {
299 let _ = handle.await;
300 }
301 }
302
303 Ok(())
304 }
305
306 pub async fn get_stats(&self) -> MapletResult<TTLStats> {
308 let expiration_map = self.expiration_map.read().await;
309 let key_map = self.key_to_expiration.read().await;
310
311 let total_keys = key_map.len();
312 let now = SystemTime::now()
313 .duration_since(UNIX_EPOCH)
314 .unwrap_or_default()
315 .as_secs();
316
317 let expired_count: usize = expiration_map
318 .range(..=now)
319 .map(|(_, entries)| entries.len())
320 .sum();
321
322 Ok(TTLStats {
323 total_keys_with_ttl: total_keys as u64,
324 expired_keys: expired_count as u64,
325 next_expiration: expiration_map
326 .range(now..)
327 .next()
328 .map(|(&time, _)| time),
329 })
330 }
331
332 pub async fn clear_all(&self) -> MapletResult<()> {
334 {
335 let mut expiration_map = self.expiration_map.write().await;
336 expiration_map.clear();
337 }
338 {
339 let mut key_map = self.key_to_expiration.write().await;
340 key_map.clear();
341 }
342 Ok(())
343 }
344}
345
346#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct TTLStats {
349 pub total_keys_with_ttl: u64,
351 pub expired_keys: u64,
353 pub next_expiration: Option<u64>,
355}
356
357impl Default for TTLStats {
358 fn default() -> Self {
359 Self {
360 total_keys_with_ttl: 0,
361 expired_keys: 0,
362 next_expiration: None,
363 }
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_ttl_manager_basic_operations() {
        let manager = TTLManager::new(TTLConfig::default());

        manager.set_ttl("key1".to_string(), 0, 60).await.unwrap();

        let ttl = manager.get_ttl("key1").await.unwrap();
        assert!(matches!(ttl, Some(remaining) if (0..=60).contains(&remaining)));
        assert!(!manager.is_expired("key1").await.unwrap());

        manager.remove_ttl("key1").await.unwrap();
        assert!(manager.get_ttl("key1").await.unwrap().is_none());
        // A key with no TTL is never reported as expired.
        assert!(!manager.is_expired("key1").await.unwrap());
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let manager = TTLManager::new(TTLConfig::default());

        // A zero TTL is due in the current second, so no real-time sleep is needed.
        manager.set_ttl("key1".to_string(), 0, 0).await.unwrap();
        assert!(manager.is_expired("key1").await.unwrap());

        let expired = manager.get_expired_keys().await.unwrap();
        assert_eq!(expired.len(), 1);
        assert_eq!(expired[0].key, "key1");

        // Draining removes the key from the index, so a second drain finds nothing.
        assert!(manager.get_ttl("key1").await.unwrap().is_none());
        assert!(manager.get_expired_keys().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_set_ttl_replaces_previous_expiry() {
        let manager = TTLManager::new(TTLConfig::default());

        manager.set_ttl("key1".to_string(), 0, 0).await.unwrap();
        manager.set_ttl("key1".to_string(), 0, 60).await.unwrap();

        assert!(!manager.is_expired("key1").await.unwrap());
        assert!(manager.get_expired_keys().await.unwrap().is_empty());
        assert_eq!(manager.get_stats().await.unwrap().total_keys_with_ttl, 1);
    }

    #[tokio::test]
    async fn test_ttl_stats() {
        let manager = TTLManager::new(TTLConfig::default());

        manager.set_ttl("key1".to_string(), 0, 60).await.unwrap();
        manager.set_ttl("key2".to_string(), 0, 120).await.unwrap();

        let stats = manager.get_stats().await.unwrap();
        assert_eq!(stats.total_keys_with_ttl, 2);
        assert_eq!(stats.expired_keys, 0);
        assert!(stats.next_expiration.is_some());
    }

    #[tokio::test]
    async fn test_ttl_clear_all() {
        let manager = TTLManager::new(TTLConfig::default());

        manager.set_ttl("key1".to_string(), 0, 60).await.unwrap();
        manager.set_ttl("key2".to_string(), 0, 120).await.unwrap();

        manager.clear_all().await.unwrap();

        let stats = manager.get_stats().await.unwrap();
        assert_eq!(stats.total_keys_with_ttl, 0);
        assert!(stats.next_expiration.is_none());
        assert!(manager.get_ttl("key1").await.unwrap().is_none());
    }
}