kv_core/
ttl.rs

1//! TTL (Time To Live) management for the KV service
2//! 
3//! Handles expiration of keys and background cleanup of expired entries.
4
5use std::collections::{BTreeMap, HashMap};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::RwLock;
9use tokio::time::interval;
10#[cfg(test)]
11use tokio::time::sleep;
12use chrono::{DateTime, Utc};
13use tracing::debug;
14
15use crate::{KVError, KVResult, Key, Entry, TTL};
16
17/// TTL manager for handling key expiration
18pub struct TTLManager {
19    /// Map of expiration time -> set of keys that expire at that time
20    expiration_map: Arc<RwLock<BTreeMap<DateTime<Utc>, Vec<Key>>>>,
21    /// Map of key -> expiration time (for quick lookup)
22    key_expirations: Arc<RwLock<HashMap<Key, DateTime<Utc>>>>,
23    /// Check interval for expired keys
24    check_interval: Duration,
25    /// Background task handle
26    cleanup_handle: Option<tokio::task::JoinHandle<()>>,
27}
28
29impl TTLManager {
30    /// Create a new TTL manager
31    #[must_use] 
32    pub fn new(check_interval: Duration) -> Self {
33        Self {
34            expiration_map: Arc::new(RwLock::new(BTreeMap::new())),
35            key_expirations: Arc::new(RwLock::new(HashMap::new())),
36            check_interval,
37            cleanup_handle: None,
38        }
39    }
40
41    /// Start the background cleanup task
42    pub fn start_cleanup(&mut self, cleanup_callback: impl Fn(Vec<Key>) + Send + Sync + 'static) {
43        let expiration_map = Arc::clone(&self.expiration_map);
44        let key_expirations = Arc::clone(&self.key_expirations);
45        let check_interval = self.check_interval;
46
47        let handle = tokio::spawn(async move {
48            let mut interval = interval(check_interval);
49            
50            loop {
51                interval.tick().await;
52                
53                let now = Utc::now();
54                let expired_keys = {
55                    let mut exp_map = expiration_map.write().await;
56                    let mut key_exps = key_expirations.write().await;
57                    
58                    let mut expired = Vec::new();
59                    
60                    // Find all keys that have expired
61                    let expired_times: Vec<DateTime<Utc>> = exp_map
62                        .range(..now)
63                        .map(|(time, _)| *time)
64                        .collect();
65                    
66                    for time in expired_times {
67                        if let Some(keys) = exp_map.remove(&time) {
68                            for key in keys {
69                                key_exps.remove(&key);
70                                expired.push(key);
71                            }
72                        }
73                    }
74                    
75                    expired
76                };
77                
78                if !expired_keys.is_empty() {
79                    debug!("Found {} expired keys", expired_keys.len());
80                    cleanup_callback(expired_keys);
81                }
82            }
83        });
84
85        self.cleanup_handle = Some(handle);
86    }
87
88    /// Stop the background cleanup task
89    pub fn stop_cleanup(&mut self) {
90        if let Some(handle) = self.cleanup_handle.take() {
91            handle.abort();
92        }
93    }
94
95    /// Set TTL for a key
96    /// 
97    /// # Errors
98    /// Returns error if TTL operation fails
99    #[allow(clippy::significant_drop_tightening)]
100    pub async fn set_ttl(&self, key: Key, ttl: TTL) -> KVResult<()> {
101        #[allow(clippy::cast_possible_wrap)]
102        let expiration_time = Utc::now() + chrono::Duration::seconds(ttl as i64);
103        
104        let mut exp_map = self.expiration_map.write().await;
105        let mut key_exps = self.key_expirations.write().await;
106        
107        // Remove from old expiration time if exists
108        if let Some(old_time) = key_exps.get(&key)
109            && let Some(keys) = exp_map.get_mut(old_time) {
110                keys.retain(|k| k != &key);
111                if keys.is_empty() {
112                    exp_map.remove(old_time);
113                }
114            }
115        
116        // Add to new expiration time
117        exp_map.entry(expiration_time).or_insert_with(Vec::new).push(key.clone());
118        key_exps.insert(key, expiration_time);
119        
120        Ok(())
121    }
122
123    /// Remove TTL for a key (make it persistent)
124    /// 
125    /// # Errors
126    /// Returns error if TTL removal fails
127    #[allow(clippy::significant_drop_tightening, clippy::option_if_let_else)]
128    pub async fn remove_ttl(&self, key: &Key) -> KVResult<bool> {
129        let mut exp_map = self.expiration_map.write().await;
130        let mut key_exps = self.key_expirations.write().await;
131        
132        if let Some(expiration_time) = key_exps.remove(key) {
133            if let Some(keys) = exp_map.get_mut(&expiration_time) {
134                keys.retain(|k| k != key);
135                if keys.is_empty() {
136                    exp_map.remove(&expiration_time);
137                }
138            }
139            Ok(true)
140        } else {
141            Ok(false)
142        }
143    }
144
145    /// Get remaining TTL for a key
146    /// 
147    /// # Errors
148    /// Returns error if TTL retrieval fails
149    #[allow(clippy::significant_drop_tightening, clippy::option_if_let_else)]
150    pub async fn get_ttl(&self, key: &Key) -> KVResult<Option<TTL>> {
151        let key_exps = self.key_expirations.read().await;
152        
153        if let Some(expiration_time) = key_exps.get(key) {
154            let now = Utc::now();
155            if now < *expiration_time {
156                #[allow(clippy::cast_sign_loss)]
157                let remaining = (*expiration_time - now).num_seconds() as u64;
158                Ok(Some(remaining))
159            } else {
160                Ok(Some(0)) // Expired
161            }
162        } else {
163            Ok(None) // No TTL set
164        }
165    }
166
167    /// Check if a key has expired
168    /// 
169    /// # Errors
170    /// Returns error if expiration check fails
171    #[allow(clippy::significant_drop_tightening, clippy::option_if_let_else)]
172    pub async fn is_expired(&self, key: &Key) -> KVResult<bool> {
173        let key_exps = self.key_expirations.read().await;
174        
175        if let Some(expiration_time) = key_exps.get(key) {
176            Ok(Utc::now() > *expiration_time)
177        } else {
178            Ok(false) // No TTL means not expired
179        }
180    }
181
182    /// Get all keys that will expire within the given duration
183    /// 
184    /// # Errors
185    /// Returns error if duration conversion fails
186    #[allow(clippy::significant_drop_tightening)]
187    pub async fn get_expiring_keys(&self, within: Duration) -> KVResult<Vec<Key>> {
188        let now = Utc::now();
189        let future_time = now + chrono::Duration::from_std(within)
190            .map_err(|e| KVError::TTL(format!("Invalid duration: {e}")))?;
191        
192        let exp_map = self.expiration_map.read().await;
193        let mut expiring_keys = Vec::new();
194        
195        for (_expiration_time, keys) in exp_map.range(now..=future_time) {
196            expiring_keys.extend(keys.clone());
197        }
198        
199        Ok(expiring_keys)
200    }
201
202    /// Get statistics about TTL usage
203    /// 
204    /// # Errors
205    /// Returns error if stats calculation fails
206    #[allow(clippy::significant_drop_tightening)]
207    pub async fn get_stats(&self) -> KVResult<TTLStats> {
208        let exp_map = self.expiration_map.read().await;
209        let key_exps = self.key_expirations.read().await;
210        
211        let now = Utc::now();
212        let mut expired_count = 0;
213        let mut active_count = 0;
214        
215        for expiration_time in key_exps.values() {
216            if *expiration_time <= now {
217                expired_count += 1;
218            } else {
219                active_count += 1;
220            }
221        }
222        
223        Ok(TTLStats {
224            total_keys_with_ttl: key_exps.len(),
225            active_keys: active_count,
226            expired_keys: expired_count,
227            next_expiration: exp_map.keys().next().copied(),
228        })
229    }
230
231    /// Clear all TTL information (for testing or reset)
232    #[allow(clippy::significant_drop_tightening)]
233    pub async fn clear_all(&self) {
234        let mut exp_map = self.expiration_map.write().await;
235        let mut key_exps = self.key_expirations.write().await;
236        
237        exp_map.clear();
238        key_exps.clear();
239    }
240}
241
242/// Statistics for TTL usage
243#[derive(Debug, Clone)]
244pub struct TTLStats {
245    pub total_keys_with_ttl: usize,
246    pub active_keys: usize,
247    pub expired_keys: usize,
248    pub next_expiration: Option<DateTime<Utc>>,
249}
250
251/// Helper trait for entries with TTL support
252pub trait TTLSupport {
253    fn is_expired(&self) -> bool;
254    fn remaining_ttl(&self) -> Option<TTL>;
255    fn set_ttl(&mut self, ttl: TTL);
256    fn remove_ttl(&mut self);
257}
258
259impl TTLSupport for Entry {
260    fn is_expired(&self) -> bool {
261        self.is_expired()
262    }
263
264    fn remaining_ttl(&self) -> Option<TTL> {
265        self.remaining_ttl()
266    }
267
268    fn set_ttl(&mut self, ttl: TTL) {
269        #[allow(clippy::cast_possible_wrap)]
270        let expiration_time = Utc::now() + chrono::Duration::seconds(ttl as i64);
271        self.expires_at = Some(expiration_time);
272    }
273
274    fn remove_ttl(&mut self) {
275        self.expires_at = None;
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Build an owned key from a string literal.
    fn key(name: &str) -> Key {
        name.to_string()
    }

    #[tokio::test]
    async fn test_ttl_set_and_get() {
        let manager = TTLManager::new(Duration::from_secs(1));
        let k = key("test_key");

        manager.set_ttl(k.clone(), 60).await.unwrap();

        // A fresh 60-second TTL reports at most 60 seconds remaining...
        let remaining = manager.get_ttl(&k).await.unwrap();
        assert!(remaining.is_some_and(|secs| secs <= 60));

        // ...and is obviously not expired yet.
        assert!(!manager.is_expired(&k).await.unwrap());
    }

    #[tokio::test]
    async fn test_ttl_removal() {
        let manager = TTLManager::new(Duration::from_secs(1));
        let k = key("test_key");

        manager.set_ttl(k.clone(), 60).await.unwrap();

        // Removing an existing TTL reports success and leaves no TTL behind.
        assert!(manager.remove_ttl(&k).await.unwrap());
        assert_eq!(manager.get_ttl(&k).await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_expiration_cleanup() {
        let mut manager = TTLManager::new(Duration::from_millis(100));
        let swept = Arc::new(AtomicUsize::new(0));

        let swept_in_callback = Arc::clone(&swept);
        manager.start_cleanup(move |expired| {
            swept_in_callback.fetch_add(expired.len(), Ordering::Relaxed);
        });

        // One-second TTL; a 100 ms sweep should have collected it within 1.5 s.
        manager.set_ttl(key("short_ttl_key"), 1).await.unwrap();
        sleep(Duration::from_millis(1500)).await;

        assert!(swept.load(Ordering::Relaxed) > 0);

        manager.stop_cleanup();
    }

    #[tokio::test]
    async fn test_ttl_stats() {
        let manager = TTLManager::new(Duration::from_secs(1));
        for (name, ttl) in [("key1", 60), ("key2", 120)] {
            manager.set_ttl(key(name), ttl).await.unwrap();
        }

        let stats = manager.get_stats().await.unwrap();
        assert_eq!(stats.total_keys_with_ttl, 2);
        assert_eq!(stats.active_keys, 2);
        assert_eq!(stats.expired_keys, 0);
        assert!(stats.next_expiration.is_some());
    }
}