use crate::MapletResult;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use tokio::time::{Duration, interval};
/// Tuning knobs for [`TTLManager`]'s background expiry processing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TTLConfig {
    /// Seconds between two consecutive background cleanup passes.
    pub cleanup_interval_secs: u64,
    /// Caps how much expired data a single background cleanup pass removes.
    pub max_cleanup_batch_size: usize,
    /// When `false`, `TTLManager::start_cleanup` returns without spawning a task.
    pub enable_background_cleanup: bool,
}
impl Default for TTLConfig {
    /// A cleanup pass every 60 seconds with a batch cap of 1000, background cleanup enabled.
    fn default() -> Self {
        let cleanup_interval_secs = 60;
        let max_cleanup_batch_size = 1_000;
        Self {
            cleanup_interval_secs,
            max_cleanup_batch_size,
            enable_background_cleanup: true,
        }
    }
}
/// A single key scheduled to expire.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TTLEntry {
    /// The key whose lifetime is being tracked.
    pub key: String,
    /// Deadline in whole seconds since the Unix epoch; the entry counts as expired once the
    /// current time reaches this value.
    pub expires_at: u64,
    /// Database the key belongs to; stored as given and handed back with expired entries.
    pub db_id: u8,
}
impl TTLEntry {
    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// A clock that reads earlier than the epoch yields `0` instead of an error.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Creates an entry for `key` in database `db_id` that expires `ttl_seconds` from now.
    ///
    /// The deadline saturates at `u64::MAX`, so a huge TTL means "effectively never". Plain
    /// addition would instead overflow: a panic in debug builds, and a wrap-around to an
    /// already-past deadline in release builds. The current time is truncated to whole
    /// seconds, so the effective lifetime can be up to one second shorter than `ttl_seconds`.
    #[must_use]
    pub fn new(key: String, db_id: u8, ttl_seconds: u64) -> Self {
        Self {
            key,
            expires_at: Self::now_secs().saturating_add(ttl_seconds),
            db_id,
        }
    }

    /// Returns `true` once the current time has reached `expires_at`.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        Self::now_secs() >= self.expires_at
    }

    /// Seconds until expiry; zero or negative once the entry has expired.
    ///
    /// The difference is computed in `i128` and clamped to the `i64` range. Deadlines beyond
    /// `i64::MAX` (e.g. from a saturated TTL) therefore report a large positive value instead
    /// of wrapping negative.
    #[must_use]
    pub fn remaining_ttl(&self) -> i64 {
        let remaining = i128::from(self.expires_at) - i128::from(Self::now_secs());
        i64::try_from(remaining).unwrap_or(if remaining < 0 { i64::MIN } else { i64::MAX })
    }
}
/// Tracks per-key expiration deadlines and can expire them from a background task.
///
/// Two indexes are kept: a time-ordered map of deadline buckets, and a key → deadline
/// reverse map. The shutdown sender is owned only by the manager. Dropping the manager
/// therefore closes the channel, which ends a running background cleanup loop.
pub struct TTLManager {
    /// Settings supplied at construction time.
    config: TTLConfig,
    /// Deadline (Unix seconds) → entries expiring at that second, kept in ascending order.
    expiration_map: Arc<RwLock<BTreeMap<u64, Vec<TTLEntry>>>>,
    /// Key → deadline, used to find a key's bucket without scanning `expiration_map`.
    key_to_expiration: Arc<RwLock<std::collections::HashMap<String, u64>>>,
    /// Join handle of the background cleanup task, if one has been started.
    cleanup_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
    /// One-shot sender used to ask the background cleanup task to stop.
    shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
}
impl TTLManager {
#[must_use]
pub fn new(config: TTLConfig) -> Self {
Self {
config,
expiration_map: Arc::new(RwLock::new(BTreeMap::new())),
key_to_expiration: Arc::new(RwLock::new(std::collections::HashMap::new())),
cleanup_handle: Arc::new(RwLock::new(None)),
shutdown_tx: Arc::new(RwLock::new(None)),
}
}
pub async fn set_ttl(&self, key: String, db_id: u8, ttl_seconds: u64) -> MapletResult<()> {
let entry = TTLEntry::new(key.clone(), db_id, ttl_seconds);
let expires_at = entry.expires_at;
self.remove_ttl(&key).await?;
{
let mut expiration_map = self.expiration_map.write().await;
expiration_map
.entry(expires_at)
.or_insert_with(Vec::new)
.push(entry);
}
{
let mut key_map = self.key_to_expiration.write().await;
key_map.insert(key, expires_at);
}
Ok(())
}
pub async fn get_ttl(&self, key: &str) -> MapletResult<Option<i64>> {
let key_map = self.key_to_expiration.read().await;
if let Some(&expires_at) = key_map.get(key) {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
#[allow(clippy::cast_possible_wrap)]
let remaining = expires_at as i64 - now as i64;
Ok(Some(remaining.max(0)))
} else {
Ok(None)
}
}
pub async fn remove_ttl(&self, key: &str) -> MapletResult<()> {
let mut key_map = self.key_to_expiration.write().await;
if let Some(expires_at) = key_map.remove(key) {
drop(key_map);
let mut expiration_map = self.expiration_map.write().await;
if let Some(entries) = expiration_map.get_mut(&expires_at) {
entries.retain(|entry| entry.key != key);
if entries.is_empty() {
expiration_map.remove(&expires_at);
}
}
}
Ok(())
}
pub async fn is_expired(&self, key: &str) -> MapletResult<bool> {
let key_map = self.key_to_expiration.read().await;
if let Some(&expires_at) = key_map.get(key) {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
Ok(now >= expires_at)
} else {
Ok(false)
}
}
pub async fn get_expired_keys(&self) -> MapletResult<Vec<TTLEntry>> {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let mut expired_entries = Vec::new();
let mut expiration_map = self.expiration_map.write().await;
let expired_times: Vec<u64> = expiration_map
.range(..=now)
.map(|(&time, _)| time)
.collect();
for time in expired_times {
if let Some(entries) = expiration_map.remove(&time) {
expired_entries.extend(entries);
}
}
let mut key_map = self.key_to_expiration.write().await;
for entry in &expired_entries {
key_map.remove(&entry.key);
}
Ok(expired_entries)
}
pub async fn start_cleanup<F>(&self, mut cleanup_callback: F) -> MapletResult<()>
where
F: FnMut(Vec<TTLEntry>) -> MapletResult<()> + Send + Sync + 'static,
{
if !self.config.enable_background_cleanup {
return Ok(());
}
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
let expiration_map = Arc::clone(&self.expiration_map);
let key_to_expiration = Arc::clone(&self.key_to_expiration);
let config = self.config.clone();
let handle = tokio::spawn(async move {
let mut interval = interval(Duration::from_secs(config.cleanup_interval_secs));
loop {
tokio::select! {
_ = interval.tick() => {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let mut expired_entries = Vec::new();
{
let mut expiration_map = expiration_map.write().await;
let expired_times: Vec<u64> = expiration_map
.range(..=now)
.take(config.max_cleanup_batch_size)
.map(|(&time, _)| time)
.collect();
for time in expired_times {
if let Some(entries) = expiration_map.remove(&time) {
expired_entries.extend(entries);
}
}
}
if !expired_entries.is_empty() {
{
let mut key_map = key_to_expiration.write().await;
for entry in &expired_entries {
key_map.remove(&entry.key);
}
}
if let Err(e) = cleanup_callback(expired_entries) {
eprintln!("TTL cleanup callback error: {e}");
}
}
}
_ = &mut shutdown_rx => {
break;
}
}
}
});
{
let mut cleanup_handle = self.cleanup_handle.write().await;
*cleanup_handle = Some(handle);
}
{
let mut shutdown_tx_guard = self.shutdown_tx.write().await;
*shutdown_tx_guard = Some(shutdown_tx);
}
Ok(())
}
pub async fn stop_cleanup(&self) -> MapletResult<()> {
{
let mut shutdown_tx = self.shutdown_tx.write().await;
if let Some(tx) = shutdown_tx.take() {
let _ = tx.send(());
}
}
{
let mut cleanup_handle = self.cleanup_handle.write().await;
if let Some(handle) = cleanup_handle.take() {
let _ = handle.await;
}
}
Ok(())
}
pub async fn get_stats(&self) -> MapletResult<TTLStats> {
#[allow(clippy::significant_drop_in_scrutinee)] let expiration_map = self.expiration_map.read().await;
#[allow(clippy::significant_drop_in_scrutinee)] let key_map = self.key_to_expiration.read().await;
let total_keys = key_map.len();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let expired_count: usize = expiration_map
.range(..=now)
.map(|(_, entries)| entries.len())
.sum();
Ok(TTLStats {
total_keys_with_ttl: total_keys as u64,
expired_keys: expired_count as u64,
next_expiration: expiration_map.range(now..).next().map(|(&time, _)| time),
})
}
pub async fn clear_all(&self) -> MapletResult<()> {
{
let mut expiration_map = self.expiration_map.write().await;
expiration_map.clear();
}
{
let mut key_map = self.key_to_expiration.write().await;
key_map.clear();
}
Ok(())
}
}
/// Point-in-time snapshot produced by `TTLManager::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TTLStats {
    /// Number of keys that currently have a TTL recorded.
    pub total_keys_with_ttl: u64,
    /// Entries whose deadline has been reached but which have not been drained yet.
    pub expired_keys: u64,
    /// The next upcoming deadline in Unix seconds, or `None` when none is tracked.
    pub next_expiration: Option<u64>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_ttl_manager_basic_operations() {
        let manager = TTLManager::new(TTLConfig::default());
        manager.set_ttl("key1".to_string(), 0, 60).await.unwrap();
        let ttl = manager.get_ttl("key1").await.unwrap();
        assert!(ttl.is_some());
        assert!(ttl.unwrap() <= 60);
        assert!(!manager.is_expired("key1").await.unwrap());
        manager.remove_ttl("key1").await.unwrap();
        assert!(manager.get_ttl("key1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let manager = TTLManager::new(TTLConfig::default());
        manager.set_ttl("key1".to_string(), 0, 1).await.unwrap();
        tokio::time::sleep(Duration::from_millis(1100)).await;
        assert!(manager.is_expired("key1").await.unwrap());
        let expired = manager.get_expired_keys().await.unwrap();
        assert!(!expired.is_empty());
        assert_eq!(expired[0].key, "key1");
    }

    #[tokio::test]
    async fn test_ttl_stats() {
        let manager = TTLManager::new(TTLConfig::default());
        manager.set_ttl("key1".to_string(), 0, 60).await.unwrap();
        manager.set_ttl("key2".to_string(), 0, 120).await.unwrap();
        let stats = manager.get_stats().await.unwrap();
        assert_eq!(stats.total_keys_with_ttl, 2);
        assert_eq!(stats.expired_keys, 0);
        assert!(stats.next_expiration.is_some());
    }

    #[tokio::test]
    async fn test_ttl_clear_all() {
        let manager = TTLManager::new(TTLConfig::default());
        manager.set_ttl("key1".to_string(), 0, 60).await.unwrap();
        manager.set_ttl("key2".to_string(), 0, 120).await.unwrap();
        manager.clear_all().await.unwrap();
        let stats = manager.get_stats().await.unwrap();
        assert_eq!(stats.total_keys_with_ttl, 0);
    }

    /// Re-setting a key's TTL replaces the old deadline instead of tracking it twice.
    #[tokio::test]
    async fn test_set_ttl_replaces_previous_deadline() {
        let manager = TTLManager::new(TTLConfig::default());
        manager.set_ttl("key1".to_string(), 0, 60).await.unwrap();
        manager.set_ttl("key1".to_string(), 0, 120).await.unwrap();
        let ttl = manager.get_ttl("key1").await.unwrap().unwrap();
        assert!(ttl > 60 && ttl <= 120);
        let stats = manager.get_stats().await.unwrap();
        assert_eq!(stats.total_keys_with_ttl, 1);
        assert_eq!(stats.expired_keys, 0);
    }

    /// A zero TTL is expired immediately and is drained exactly once.
    #[tokio::test]
    async fn test_zero_ttl_expires_immediately_and_drains_once() {
        let manager = TTLManager::new(TTLConfig::default());
        manager.set_ttl("key1".to_string(), 3, 0).await.unwrap();
        assert!(manager.is_expired("key1").await.unwrap());
        let expired = manager.get_expired_keys().await.unwrap();
        assert_eq!(expired.len(), 1);
        assert_eq!(expired[0].key, "key1");
        assert_eq!(expired[0].db_id, 3);
        assert!(manager.get_expired_keys().await.unwrap().is_empty());
        assert!(manager.get_ttl("key1").await.unwrap().is_none());
    }

    /// Removing or querying a key that never had a TTL is harmless.
    #[tokio::test]
    async fn test_remove_unknown_key_is_noop() {
        let manager = TTLManager::new(TTLConfig::default());
        manager.remove_ttl("missing").await.unwrap();
        assert!(!manager.is_expired("missing").await.unwrap());
        assert!(manager.get_ttl("missing").await.unwrap().is_none());
    }

    /// The background task can be started and then stopped cleanly.
    #[tokio::test]
    async fn test_start_and_stop_cleanup() {
        let manager = TTLManager::new(TTLConfig::default());
        manager.start_cleanup(|_| Ok(())).await.unwrap();
        manager.stop_cleanup().await.unwrap();
        // Stopping again with nothing running is a no-op.
        manager.stop_cleanup().await.unwrap();
    }
}