1use std::collections::{BTreeMap, HashMap};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::RwLock;
9use tokio::time::interval;
10#[cfg(test)]
11use tokio::time::sleep;
12use chrono::{DateTime, Utc};
13use tracing::debug;
14
15use crate::{KVError, KVResult, Key, Entry, TTL};
16
17pub struct TTLManager {
19 expiration_map: Arc<RwLock<BTreeMap<DateTime<Utc>, Vec<Key>>>>,
21 key_expirations: Arc<RwLock<HashMap<Key, DateTime<Utc>>>>,
23 check_interval: Duration,
25 cleanup_handle: Option<tokio::task::JoinHandle<()>>,
27}
28
29impl TTLManager {
30 #[must_use]
32 pub fn new(check_interval: Duration) -> Self {
33 Self {
34 expiration_map: Arc::new(RwLock::new(BTreeMap::new())),
35 key_expirations: Arc::new(RwLock::new(HashMap::new())),
36 check_interval,
37 cleanup_handle: None,
38 }
39 }
40
41 pub fn start_cleanup(&mut self, cleanup_callback: impl Fn(Vec<Key>) + Send + Sync + 'static) {
43 let expiration_map = Arc::clone(&self.expiration_map);
44 let key_expirations = Arc::clone(&self.key_expirations);
45 let check_interval = self.check_interval;
46
47 let handle = tokio::spawn(async move {
48 let mut interval = interval(check_interval);
49
50 loop {
51 interval.tick().await;
52
53 let now = Utc::now();
54 let expired_keys = {
55 let mut exp_map = expiration_map.write().await;
56 let mut key_exps = key_expirations.write().await;
57
58 let mut expired = Vec::new();
59
60 let expired_times: Vec<DateTime<Utc>> = exp_map
62 .range(..now)
63 .map(|(time, _)| *time)
64 .collect();
65
66 for time in expired_times {
67 if let Some(keys) = exp_map.remove(&time) {
68 for key in keys {
69 key_exps.remove(&key);
70 expired.push(key);
71 }
72 }
73 }
74
75 expired
76 };
77
78 if !expired_keys.is_empty() {
79 debug!("Found {} expired keys", expired_keys.len());
80 cleanup_callback(expired_keys);
81 }
82 }
83 });
84
85 self.cleanup_handle = Some(handle);
86 }
87
88 pub fn stop_cleanup(&mut self) {
90 if let Some(handle) = self.cleanup_handle.take() {
91 handle.abort();
92 }
93 }
94
95 #[allow(clippy::significant_drop_tightening)]
100 pub async fn set_ttl(&self, key: Key, ttl: TTL) -> KVResult<()> {
101 #[allow(clippy::cast_possible_wrap)]
102 let expiration_time = Utc::now() + chrono::Duration::seconds(ttl as i64);
103
104 let mut exp_map = self.expiration_map.write().await;
105 let mut key_exps = self.key_expirations.write().await;
106
107 if let Some(old_time) = key_exps.get(&key)
109 && let Some(keys) = exp_map.get_mut(old_time) {
110 keys.retain(|k| k != &key);
111 if keys.is_empty() {
112 exp_map.remove(old_time);
113 }
114 }
115
116 exp_map.entry(expiration_time).or_insert_with(Vec::new).push(key.clone());
118 key_exps.insert(key, expiration_time);
119
120 Ok(())
121 }
122
123 #[allow(clippy::significant_drop_tightening, clippy::option_if_let_else)]
128 pub async fn remove_ttl(&self, key: &Key) -> KVResult<bool> {
129 let mut exp_map = self.expiration_map.write().await;
130 let mut key_exps = self.key_expirations.write().await;
131
132 if let Some(expiration_time) = key_exps.remove(key) {
133 if let Some(keys) = exp_map.get_mut(&expiration_time) {
134 keys.retain(|k| k != key);
135 if keys.is_empty() {
136 exp_map.remove(&expiration_time);
137 }
138 }
139 Ok(true)
140 } else {
141 Ok(false)
142 }
143 }
144
145 #[allow(clippy::significant_drop_tightening, clippy::option_if_let_else)]
150 pub async fn get_ttl(&self, key: &Key) -> KVResult<Option<TTL>> {
151 let key_exps = self.key_expirations.read().await;
152
153 if let Some(expiration_time) = key_exps.get(key) {
154 let now = Utc::now();
155 if now < *expiration_time {
156 #[allow(clippy::cast_sign_loss)]
157 let remaining = (*expiration_time - now).num_seconds() as u64;
158 Ok(Some(remaining))
159 } else {
160 Ok(Some(0)) }
162 } else {
163 Ok(None) }
165 }
166
167 #[allow(clippy::significant_drop_tightening, clippy::option_if_let_else)]
172 pub async fn is_expired(&self, key: &Key) -> KVResult<bool> {
173 let key_exps = self.key_expirations.read().await;
174
175 if let Some(expiration_time) = key_exps.get(key) {
176 Ok(Utc::now() > *expiration_time)
177 } else {
178 Ok(false) }
180 }
181
182 #[allow(clippy::significant_drop_tightening)]
187 pub async fn get_expiring_keys(&self, within: Duration) -> KVResult<Vec<Key>> {
188 let now = Utc::now();
189 let future_time = now + chrono::Duration::from_std(within)
190 .map_err(|e| KVError::TTL(format!("Invalid duration: {e}")))?;
191
192 let exp_map = self.expiration_map.read().await;
193 let mut expiring_keys = Vec::new();
194
195 for (_expiration_time, keys) in exp_map.range(now..=future_time) {
196 expiring_keys.extend(keys.clone());
197 }
198
199 Ok(expiring_keys)
200 }
201
202 #[allow(clippy::significant_drop_tightening)]
207 pub async fn get_stats(&self) -> KVResult<TTLStats> {
208 let exp_map = self.expiration_map.read().await;
209 let key_exps = self.key_expirations.read().await;
210
211 let now = Utc::now();
212 let mut expired_count = 0;
213 let mut active_count = 0;
214
215 for expiration_time in key_exps.values() {
216 if *expiration_time <= now {
217 expired_count += 1;
218 } else {
219 active_count += 1;
220 }
221 }
222
223 Ok(TTLStats {
224 total_keys_with_ttl: key_exps.len(),
225 active_keys: active_count,
226 expired_keys: expired_count,
227 next_expiration: exp_map.keys().next().copied(),
228 })
229 }
230
231 #[allow(clippy::significant_drop_tightening)]
233 pub async fn clear_all(&self) {
234 let mut exp_map = self.expiration_map.write().await;
235 let mut key_exps = self.key_expirations.write().await;
236
237 exp_map.clear();
238 key_exps.clear();
239 }
240}
241
242#[derive(Debug, Clone)]
244pub struct TTLStats {
245 pub total_keys_with_ttl: usize,
246 pub active_keys: usize,
247 pub expired_keys: usize,
248 pub next_expiration: Option<DateTime<Utc>>,
249}
250
251pub trait TTLSupport {
253 fn is_expired(&self) -> bool;
254 fn remaining_ttl(&self) -> Option<TTL>;
255 fn set_ttl(&mut self, ttl: TTL);
256 fn remove_ttl(&mut self);
257}
258
259impl TTLSupport for Entry {
260 fn is_expired(&self) -> bool {
261 self.is_expired()
262 }
263
264 fn remaining_ttl(&self) -> Option<TTL> {
265 self.remaining_ttl()
266 }
267
268 fn set_ttl(&mut self, ttl: TTL) {
269 #[allow(clippy::cast_possible_wrap)]
270 let expiration_time = Utc::now() + chrono::Duration::seconds(ttl as i64);
271 self.expires_at = Some(expiration_time);
272 }
273
274 fn remove_ttl(&mut self) {
275 self.expires_at = None;
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds an owned key from a literal.
    fn make_key(name: &str) -> Key {
        name.to_string()
    }

    /// A freshly set TTL is visible, no larger than requested, and not yet expired.
    #[tokio::test]
    async fn test_ttl_set_and_get() {
        let manager = TTLManager::new(Duration::from_secs(1));
        let key = make_key("test_key");

        manager.set_ttl(key.clone(), 60).await.unwrap();

        let remaining = manager.get_ttl(&key).await.unwrap();
        assert!(remaining.is_some_and(|secs| secs <= 60));
        assert!(!manager.is_expired(&key).await.unwrap());
    }

    /// Removing a TTL reports success and leaves the key untracked.
    #[tokio::test]
    async fn test_ttl_removal() {
        let manager = TTLManager::new(Duration::from_secs(1));
        let key = make_key("test_key");

        manager.set_ttl(key.clone(), 60).await.unwrap();

        assert!(manager.remove_ttl(&key).await.unwrap());
        assert_eq!(manager.get_ttl(&key).await.unwrap(), None);
    }

    /// The background sweep reports a key whose 1-second TTL has elapsed.
    #[tokio::test]
    async fn test_expiration_cleanup() {
        let mut manager = TTLManager::new(Duration::from_millis(100));
        let swept = Arc::new(AtomicUsize::new(0));

        let counter = Arc::clone(&swept);
        manager.start_cleanup(move |keys| {
            counter.fetch_add(keys.len(), Ordering::Relaxed);
        });

        manager.set_ttl(make_key("short_ttl_key"), 1).await.unwrap();

        // Wait past the TTL plus several sweep intervals.
        sleep(Duration::from_millis(1500)).await;

        assert!(swept.load(Ordering::Relaxed) > 0);

        manager.stop_cleanup();
    }

    /// While no TTL has elapsed, every tracked key counts as active.
    #[tokio::test]
    async fn test_ttl_stats() {
        let manager = TTLManager::new(Duration::from_secs(1));

        for (name, ttl) in [("key1", 60), ("key2", 120)] {
            manager.set_ttl(make_key(name), ttl).await.unwrap();
        }

        let stats = manager.get_stats().await.unwrap();
        assert_eq!(stats.total_keys_with_ttl, 2);
        assert_eq!(stats.active_keys, 2);
        assert_eq!(stats.expired_keys, 0);
        assert!(stats.next_expiration.is_some());
    }
}