use std::{
collections::VecDeque,
sync::Arc,
time::{Duration, Instant},
};
use dashmap::DashMap;
use crate::types::{Channel, Guild, Message, Role, User};
/// Per-channel message limit applied by [`CacheSettings::default`].
pub const DEFAULT_MAX_MESSAGES: usize = 100;

/// Knobs controlling what the cache stores and for how long.
#[derive(Debug, Clone)]
pub struct CacheSettings {
    /// When `false`, guild upserts are ignored.
    pub cache_guilds: bool,
    /// When `false`, user upserts are ignored.
    pub cache_users: bool,
    /// When `false`, message upserts are ignored.
    pub cache_messages: bool,
    /// Maximum number of messages kept per channel; `0` disables message caching.
    pub max_messages: usize,
    /// Age after which cached guilds and users are treated as expired.
    /// `None` disables expiry.
    pub time_to_live: Option<Duration>,
}

impl Default for CacheSettings {
    /// Caches every kind of entity, keeps up to [`DEFAULT_MAX_MESSAGES`] messages
    /// per channel, and never expires anything.
    fn default() -> Self {
        let max_messages = DEFAULT_MAX_MESSAGES;
        let time_to_live = None;
        Self {
            time_to_live,
            max_messages,
            cache_messages: true,
            cache_users: true,
            cache_guilds: true,
        }
    }
}
/// Bounded buffer of one channel's messages, oldest at the front.
///
/// Pushing an unseen message while the buffer is full evicts the oldest entry.
/// Pushing a message whose id is already buffered overwrites it in place without
/// changing its position.
struct MessageRing {
    /// Buffered messages in insertion order.
    messages: VecDeque<Message>,
    /// Capacity limit; `0` means nothing is ever stored.
    max: usize,
}

impl MessageRing {
    /// Creates an empty ring holding at most `max` messages.
    ///
    /// Up-front allocation is capped at 128 slots so a very large `max` does not
    /// reserve memory before it is needed.
    fn new(max: usize) -> Self {
        let initial_capacity = std::cmp::min(max, 128);
        Self {
            max,
            messages: VecDeque::with_capacity(initial_capacity),
        }
    }

    /// Position of the buffered message with id `message_id`, if any.
    fn index_of(&self, message_id: &str) -> Option<usize> {
        self.messages.iter().position(|m| m.id == message_id)
    }

    /// Inserts `msg`, replacing an existing message with the same id in place, or
    /// appending it (evicting the oldest message first when full).
    fn push(&mut self, msg: Message) {
        match self.index_of(&msg.id) {
            Some(idx) => self.messages[idx] = msg,
            None if self.max == 0 => {}
            None => {
                if self.messages.len() >= self.max {
                    self.messages.pop_front();
                }
                self.messages.push_back(msg);
            }
        }
    }

    /// Removes and returns the message with id `message_id`, if buffered.
    fn remove(&mut self, message_id: &str) -> Option<Message> {
        let idx = self.index_of(message_id)?;
        self.messages.remove(idx)
    }

    /// Borrows the message with id `message_id`, if buffered.
    fn get(&self, message_id: &str) -> Option<&Message> {
        self.index_of(message_id)
            .and_then(|idx| self.messages.get(idx))
    }

    /// Clones every buffered message, oldest first.
    fn all(&self) -> Vec<Message> {
        let mut snapshot = Vec::with_capacity(self.messages.len());
        snapshot.extend(self.messages.iter().cloned());
        snapshot
    }
}
/// A cached value paired with the instant it was written.
///
/// Keeping the timestamp inside the same map entry as the value means the expiry
/// check and the eviction always look at the same entry. An eviction can therefore
/// never discard a value that a concurrent upsert has just refreshed, and a
/// timestamp can never drift apart from its value.
struct Stamped<T> {
    /// When `item` was inserted or last replaced.
    stored_at: Instant,
    /// The cached value.
    item: T,
}

impl<T> Stamped<T> {
    /// Wraps `item`, stamping it with the current instant.
    fn now(item: T) -> Self {
        Self {
            stored_at: Instant::now(),
            item,
        }
    }
}

/// In-memory cache of guilds, users, messages, channels and roles.
///
/// Every map sits behind an `Arc`, so cloning a `Cache` is shallow and all clones
/// observe the same data. All methods take `&self`, and the maps are `DashMap`s,
/// which permit concurrent access.
///
/// When [`CacheSettings::time_to_live`] is set, guild and user entries older than
/// the TTL are treated as absent:
/// - single lookups evict the expired entry they hit;
/// - the list and count accessors purge all expired entries first.
///
/// Messages, channels and roles never expire.
#[derive(Clone)]
pub struct Cache {
    settings: CacheSettings,
    /// Guilds keyed by guild id, each stamped with its last write time.
    guilds: Arc<DashMap<String, Stamped<Guild>>>,
    /// Users keyed by user id, each stamped with its last write time.
    users: Arc<DashMap<String, Stamped<User>>>,
    /// One bounded message buffer per channel, keyed by channel id.
    messages: Arc<DashMap<String, MessageRing>>,
    /// Channels keyed by channel id.
    channels: Arc<DashMap<String, Channel>>,
    /// Roles keyed by role id.
    roles: Arc<DashMap<String, Role>>,
}

impl Default for Cache {
    fn default() -> Self {
        Self::with_settings(CacheSettings::default())
    }
}

impl Cache {
    /// Creates a cache using [`CacheSettings::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty cache governed by `settings`.
    pub fn with_settings(settings: CacheSettings) -> Self {
        Self {
            settings,
            guilds: Arc::new(DashMap::new()),
            users: Arc::new(DashMap::new()),
            messages: Arc::new(DashMap::new()),
            channels: Arc::new(DashMap::new()),
            roles: Arc::new(DashMap::new()),
        }
    }

    /// Creates a cache with default settings except for the per-channel message
    /// limit; `0` disables message caching.
    pub fn with_max_messages(max_messages: usize) -> Self {
        Self::with_settings(CacheSettings {
            max_messages,
            ..CacheSettings::default()
        })
    }

    /// Returns the settings this cache was built with.
    pub fn settings(&self) -> &CacheSettings {
        &self.settings
    }

    // ----- TTL helpers (used for guilds and users) -----

    /// Returns `true` if an entry written at `stored_at` is strictly older than the
    /// configured TTL. Always `false` when no TTL is configured.
    fn is_expired(&self, stored_at: Instant) -> bool {
        matches!(self.settings.time_to_live, Some(ttl) if stored_at.elapsed() > ttl)
    }

    /// Returns a clone of the live value stored under `id`.
    ///
    /// If the entry has expired, it is evicted and `None` is returned.
    fn live_get<T: Clone>(&self, map: &DashMap<String, Stamped<T>>, id: &str) -> Option<T> {
        let entry = map.get(id)?;
        if !self.is_expired(entry.stored_at) {
            return Some(entry.item.clone());
        }
        // Release the shard read guard before `remove_if` needs write access to the
        // same shard; holding both would deadlock.
        drop(entry);
        // Re-check expiry at removal time so an entry refreshed by a concurrent
        // upsert in the meantime is left in place.
        map.remove_if(id, |_, e| self.is_expired(e.stored_at));
        None
    }

    /// Drops every expired entry from `map`. No-op when no TTL is configured.
    fn purge_expired<T>(&self, map: &DashMap<String, Stamped<T>>) {
        if self.settings.time_to_live.is_some() {
            // `retain` evaluates the predicate against the entry it is about to
            // drop, so entries refreshed concurrently survive.
            map.retain(|_, e| !self.is_expired(e.stored_at));
        }
    }

    /// Purges expired entries, then clones every remaining value.
    fn live_values<T: Clone>(&self, map: &DashMap<String, Stamped<T>>) -> Vec<T> {
        self.purge_expired(map);
        map.iter().map(|entry| entry.item.clone()).collect()
    }

    /// Purges expired entries, then returns how many remain.
    fn live_len<T>(&self, map: &DashMap<String, Stamped<T>>) -> usize {
        self.purge_expired(map);
        map.len()
    }

    // ----- guilds -----

    /// Returns a clone of guild `guild_id`, or `None` if it is absent or expired.
    pub fn guild(&self, guild_id: &str) -> Option<Guild> {
        self.live_get(&self.guilds, guild_id)
    }

    /// Returns clones of all non-expired guilds, purging expired ones first.
    pub fn guilds(&self) -> Vec<Guild> {
        self.live_values(&self.guilds)
    }

    /// Number of non-expired guilds; expired ones are purged first so the count
    /// agrees with [`Cache::guilds`].
    pub fn guild_count(&self) -> usize {
        self.live_len(&self.guilds)
    }

    // ----- users -----

    /// Returns a clone of user `user_id`, or `None` if it is absent or expired.
    pub fn user(&self, user_id: &str) -> Option<User> {
        self.live_get(&self.users, user_id)
    }

    /// Returns clones of all non-expired users, purging expired ones first.
    pub fn users(&self) -> Vec<User> {
        self.live_values(&self.users)
    }

    /// Number of non-expired users; expired ones are purged first so the count
    /// agrees with [`Cache::users`].
    pub fn user_count(&self) -> usize {
        self.live_len(&self.users)
    }

    // ----- messages -----

    /// Returns a clone of message `message_id` from channel `channel_id`'s buffer.
    pub fn message(&self, channel_id: &str, message_id: &str) -> Option<Message> {
        self.messages.get(channel_id)?.get(message_id).cloned()
    }

    /// Returns clones of every buffered message in `channel_id`, oldest first.
    /// Returns an empty `Vec` for unknown channels.
    pub fn channel_messages(&self, channel_id: &str) -> Vec<Message> {
        self.messages.get(channel_id).map(|r| r.all()).unwrap_or_default()
    }

    /// Total number of buffered messages across all channels.
    pub fn message_count(&self) -> usize {
        self.messages.iter().map(|r| r.messages.len()).sum()
    }

    // ----- channels -----

    /// Returns a clone of channel `channel_id`, if cached.
    pub fn channel(&self, channel_id: &str) -> Option<Channel> {
        self.channels.get(channel_id).map(|c| c.clone())
    }

    /// Returns clones of all cached channels.
    pub fn channels(&self) -> Vec<Channel> {
        self.channels.iter().map(|r| r.value().clone()).collect()
    }

    /// Number of cached channels.
    pub fn channel_count(&self) -> usize {
        self.channels.len()
    }

    // ----- roles -----

    /// Returns a clone of role `role_id`, if cached.
    pub fn role(&self, role_id: &str) -> Option<Role> {
        self.roles.get(role_id).map(|r| r.clone())
    }

    /// Returns clones of all cached roles.
    pub fn roles(&self) -> Vec<Role> {
        self.roles.iter().map(|r| r.value().clone()).collect()
    }

    /// Number of cached roles.
    pub fn role_count(&self) -> usize {
        self.roles.len()
    }

    // ----- mutation (crate-internal) -----

    /// Inserts or replaces a guild and restarts its TTL clock.
    /// Ignored when `cache_guilds` is disabled.
    pub(crate) fn upsert_guild(&self, guild: Guild) {
        if self.settings.cache_guilds {
            self.guilds.insert(guild.id.clone(), Stamped::now(guild));
        }
    }

    /// Removes and returns guild `guild_id`, even if its TTL had already elapsed.
    pub(crate) fn remove_guild(&self, guild_id: &str) -> Option<Guild> {
        self.guilds.remove(guild_id).map(|(_, entry)| entry.item)
    }

    /// Inserts or replaces a user and restarts its TTL clock.
    /// Ignored when `cache_users` is disabled.
    pub(crate) fn upsert_user(&self, user: User) {
        if self.settings.cache_users {
            self.users.insert(user.id.clone(), Stamped::now(user));
        }
    }

    /// Buffers `msg` in its channel's ring, creating the ring on first use.
    /// Ignored when message caching is disabled or `max_messages` is `0`.
    pub(crate) fn upsert_message(&self, msg: Message) {
        if !self.settings.cache_messages || self.settings.max_messages == 0 {
            return;
        }
        let max = self.settings.max_messages;
        self.messages
            .entry(msg.channel_id.clone())
            .or_insert_with(|| MessageRing::new(max))
            .push(msg);
    }

    /// Removes and returns message `message_id` from channel `channel_id`'s buffer.
    pub(crate) fn remove_message(&self, channel_id: &str, message_id: &str) -> Option<Message> {
        self.messages.get_mut(channel_id)?.remove(message_id)
    }

    /// Inserts or replaces a channel.
    pub(crate) fn upsert_channel(&self, channel: Channel) {
        self.channels.insert(channel.id.clone(), channel);
    }

    /// Removes and returns channel `channel_id`. Its buffered messages are not touched.
    pub(crate) fn remove_channel(&self, channel_id: &str) -> Option<Channel> {
        self.channels.remove(channel_id).map(|(_, c)| c)
    }

    /// Inserts or replaces a role.
    pub(crate) fn upsert_role(&self, role: Role) {
        self.roles.insert(role.id.clone(), role);
    }

    /// Removes and returns role `role_id`.
    pub(crate) fn remove_role(&self, role_id: &str) -> Option<Role> {
        self.roles.remove(role_id).map(|(_, r)| r)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ----- fixtures -----

    /// Guild fixture: only `id` and `name` are set, everything else is defaulted.
    fn make_guild(id: &str, name: &str) -> Guild {
        Guild {
            name: Some(name.to_string()),
            id: id.to_string(),
            ..Guild::default()
        }
    }

    /// User fixture deserialized from a minimal JSON payload.
    fn make_user(id: &str, name: &str) -> User {
        let payload = serde_json::json!({
            "id": id,
            "username": name,
            "discriminator": "0000",
            "avatar": null
        });
        serde_json::from_value(payload).unwrap()
    }

    /// Message fixture with a fixed author and timestamp; only the ids and content vary.
    fn make_msg(channel_id: &str, message_id: &str, content: &str) -> Message {
        let payload = serde_json::json!({
            "id": message_id,
            "channel_id": channel_id,
            "author": { "id": "1", "username": "u", "discriminator": "0", "avatar": null },
            "content": content,
            "timestamp": "2024-01-01T00:00:00Z",
            "tts": false,
            "mention_everyone": false,
            "mentions": [],
            "mention_roles": [],
            "attachments": [],
            "embeds": [],
            "pinned": false,
            "type": 0
        });
        serde_json::from_value(payload).unwrap()
    }

    /// Settings with a zero TTL, so any entry expires as soon as time advances.
    fn zero_ttl() -> CacheSettings {
        CacheSettings {
            time_to_live: Some(Duration::from_nanos(0)),
            ..CacheSettings::default()
        }
    }

    /// Waits long enough for a zero TTL to be exceeded.
    fn let_time_pass() {
        std::thread::sleep(Duration::from_millis(1));
    }

    // ----- guilds -----

    #[test]
    fn upsert_and_get_guild() {
        let cache = Cache::new();
        cache.upsert_guild(make_guild("111", "Test Guild"));
        let cached = cache.guild("111").expect("guild should be cached");
        assert_eq!(cached.name.as_deref(), Some("Test Guild"));
    }

    #[test]
    fn update_guild_replaces_entry() {
        let cache = Cache::new();
        for name in ["Old Name", "New Name"] {
            cache.upsert_guild(make_guild("222", name));
        }
        let latest = cache.guild("222").unwrap();
        assert_eq!(latest.name.as_deref(), Some("New Name"));
        assert_eq!(cache.guild_count(), 1);
    }

    #[test]
    fn remove_guild() {
        let cache = Cache::new();
        cache.upsert_guild(make_guild("333", "To Remove"));
        let removed = cache.remove_guild("333");
        assert!(removed.is_some());
        assert!(cache.guild("333").is_none());
    }

    #[test]
    fn guilds_list_returns_all() {
        let cache = Cache::new();
        for (id, name) in [("1", "A"), ("2", "B")] {
            cache.upsert_guild(make_guild(id, name));
        }
        assert_eq!(cache.guild_count(), 2);
        assert_eq!(cache.guilds().len(), 2);
    }

    // ----- users -----

    #[test]
    fn upsert_and_get_user() {
        let cache = Cache::new();
        cache.upsert_user(make_user("u1", "Alice"));
        let cached = cache.user("u1").expect("user should be cached");
        assert_eq!(cached.username, "Alice");
    }

    #[test]
    fn update_user_replaces_entry() {
        let cache = Cache::new();
        for name in ["Bob", "Bobby"] {
            cache.upsert_user(make_user("u2", name));
        }
        let latest = cache.user("u2").unwrap();
        assert_eq!(latest.username, "Bobby");
        assert_eq!(cache.user_count(), 1);
    }

    // ----- messages -----

    #[test]
    fn message_cache_stores_and_retrieves() {
        let cache = Cache::new();
        cache.upsert_message(make_msg("ch1", "m1", "hello"));
        let cached = cache.message("ch1", "m1").expect("message should be cached");
        assert_eq!(cached.content, "hello");
    }

    #[test]
    fn message_cache_lru_eviction() {
        let cache = Cache::with_max_messages(3);
        let batch = [("1", "first"), ("2", "second"), ("3", "third"), ("4", "fourth")];
        for (id, text) in batch {
            cache.upsert_message(make_msg("ch2", id, text));
        }
        assert!(cache.message("ch2", "1").is_none(), "oldest should be evicted");
        assert!(cache.message("ch2", "4").is_some());
        assert_eq!(cache.channel_messages("ch2").len(), 3);
    }

    #[test]
    fn message_cache_delete() {
        let cache = Cache::new();
        cache.upsert_message(make_msg("ch3", "m10", "to delete"));
        let removed = cache.remove_message("ch3", "m10");
        assert!(removed.is_some());
        assert!(cache.message("ch3", "m10").is_none());
    }

    #[test]
    fn message_cache_disabled_when_max_zero() {
        let cache = Cache::with_max_messages(0);
        cache.upsert_message(make_msg("ch4", "m1", "ignored"));
        assert_eq!(cache.message_count(), 0);
    }

    // ----- settings -----

    #[test]
    fn settings_cache_guilds_false_skips_upsert() {
        let cache = Cache::with_settings(CacheSettings {
            cache_guilds: false,
            ..CacheSettings::default()
        });
        cache.upsert_guild(make_guild("g1", "Ignored"));
        assert!(cache.guild("g1").is_none());
    }

    #[test]
    fn settings_cache_users_false_skips_upsert() {
        let cache = Cache::with_settings(CacheSettings {
            cache_users: false,
            ..CacheSettings::default()
        });
        cache.upsert_user(make_user("u99", "Ghost"));
        assert!(cache.user("u99").is_none());
    }

    #[test]
    fn settings_cache_messages_false_skips_upsert() {
        let cache = Cache::with_settings(CacheSettings {
            cache_messages: false,
            ..CacheSettings::default()
        });
        cache.upsert_message(make_msg("ch5", "m1", "ignored"));
        assert_eq!(cache.message_count(), 0);
    }

    #[test]
    fn settings_accessor_returns_config() {
        let settings = CacheSettings {
            cache_guilds: false,
            max_messages: 42,
            ..CacheSettings::default()
        };
        let cache = Cache::with_settings(settings.clone());
        let stored = cache.settings();
        assert_eq!(stored.max_messages, 42);
        assert!(!stored.cache_guilds);
    }

    // ----- time to live -----

    #[test]
    fn ttl_expired_guild_returns_none_on_access() {
        let cache = Cache::with_settings(zero_ttl());
        cache.upsert_guild(make_guild("ttl1", "Expiring"));
        let_time_pass();
        assert!(cache.guild("ttl1").is_none(), "entry should have expired");
    }

    #[test]
    fn ttl_expired_user_returns_none_on_access() {
        let cache = Cache::with_settings(zero_ttl());
        cache.upsert_user(make_user("uttl1", "Expiring"));
        let_time_pass();
        assert!(cache.user("uttl1").is_none(), "user entry should have expired");
    }

    #[test]
    fn no_ttl_entries_stay_indefinitely() {
        let cache = Cache::with_settings(CacheSettings {
            time_to_live: None,
            ..CacheSettings::default()
        });
        cache.upsert_guild(make_guild("perm1", "Permanent"));
        let_time_pass();
        assert!(cache.guild("perm1").is_some(), "entry without TTL should persist");
    }
}