use std::collections::{HashSet, VecDeque};
use std::hash::Hash;
#[cfg(feature = "temp_cache")]
use std::sync::Arc;
#[cfg(feature = "temp_cache")]
use std::time::Duration;
use dashmap::mapref::entry::Entry;
use dashmap::mapref::one::{MappedRef, Ref};
use dashmap::DashMap;
#[cfg(feature = "temp_cache")]
use mini_moka::sync::Cache as MokaCache;
use parking_lot::RwLock;
use tracing::instrument;
pub use self::cache_update::CacheUpdate;
pub use self::settings::Settings;
use crate::model::prelude::*;
mod cache_update;
mod event;
mod settings;
pub(crate) mod wrappers;
#[cfg(feature = "temp_cache")]
pub(crate) use wrappers::MaybeOwnedArc;
use wrappers::{BuildHasher, MaybeMap, ReadOnlyMapRef};
/// Per-channel message storage: channel id → (message id → message).
type MessageCache = DashMap<ChannelId, HashMap<MessageId, Message>, BuildHasher>;
/// Marker type used only as `PhantomData<*const NotSend>` inside [`CacheRef`].
///
/// Because raw pointers are neither `Send` nor `Sync`, that marker makes `CacheRef` neither
/// `Send` nor `Sync`. Presumably this stops cache guards from being held across `.await` points
/// in `Send` futures — TODO confirm intent.
struct NotSend;
/// The storage-specific handle that backs a [`CacheRef`].
///
/// The type parameters are:
/// - `K`: the key type of the map the value came from.
/// - `V`: the type handed out through `Deref`.
/// - `T`: for mapped refs only, the type of the locked map value that `V` was projected out of.
enum CacheRefInner<'a, K, V, T> {
    /// An owned handle to an entry of one of the time-to-live (`temp_cache`) caches.
    #[cfg(feature = "temp_cache")]
    Arc(Arc<V>),
    /// A guard on a DashMap entry whose value is handed out directly.
    DashRef(Ref<'a, K, V, BuildHasher>),
    /// A guard on a DashMap entry, projected to a `V` that lives inside the entry's `T` value.
    DashMappedRef(MappedRef<'a, K, T, V, BuildHasher>),
    /// A read guard on a `parking_lot::RwLock`.
    ReadGuard(parking_lot::RwLockReadGuard<'a, V>),
}
/// A borrow of a value stored in the [`Cache`], dereferencing to `V`.
///
/// Depending on where the value lives, this wraps one of:
/// - a DashMap entry guard,
/// - a mapped DashMap entry guard,
/// - a `parking_lot` read guard,
/// - or, with `temp_cache`, an `Arc`.
///
/// Guard-backed refs hold the underlying lock for as long as they live. DashMap documents that
/// writing to a map while holding a reference into it may deadlock, so keep these short-lived.
pub struct CacheRef<'a, K, V, T = ()> {
    // The actual handle to the cached value.
    inner: CacheRefInner<'a, K, V, T>,
    // Raw-pointer marker: makes `CacheRef` neither `Send` nor `Sync`.
    phantom: std::marker::PhantomData<*const NotSend>,
}
impl<'a, K, V, T> CacheRef<'a, K, V, T> {
fn new(inner: CacheRefInner<'a, K, V, T>) -> Self {
Self {
inner,
phantom: std::marker::PhantomData,
}
}
#[cfg(feature = "temp_cache")]
fn from_arc(inner: MaybeOwnedArc<V>) -> Self {
Self::new(CacheRefInner::Arc(inner.get_inner()))
}
fn from_ref(inner: Ref<'a, K, V, BuildHasher>) -> Self {
Self::new(CacheRefInner::DashRef(inner))
}
fn from_mapped_ref(inner: MappedRef<'a, K, T, V, BuildHasher>) -> Self {
Self::new(CacheRefInner::DashMappedRef(inner))
}
fn from_guard(inner: parking_lot::RwLockReadGuard<'a, V>) -> Self {
Self::new(CacheRefInner::ReadGuard(inner))
}
}
impl<K: Eq + Hash, V, T> std::ops::Deref for CacheRef<'_, K, V, T> {
    type Target = V;

    /// Borrows the cached value, whichever kind of handle backs this reference.
    fn deref(&self) -> &V {
        match self.inner {
            #[cfg(feature = "temp_cache")]
            CacheRefInner::Arc(ref shared) => &**shared,
            CacheRefInner::DashRef(ref entry) => entry.value(),
            CacheRefInner::DashMappedRef(ref projected) => projected.value(),
            CacheRefInner::ReadGuard(ref guard) => &**guard,
        }
    }
}
/// Stand-in for [`CacheRef`] type parameters that a given alias does not use. Examples are the
/// key type of lock-backed refs and the parent type of unmapped refs.
type Never = std::convert::Infallible;
/// A reference to a `T` projected out of a cached [`Guild`].
type MappedGuildRef<'a, T> = CacheRef<'a, GuildId, T, Guild>;
/// A reference to a member stored inside a cached guild.
pub type MemberRef<'a> = MappedGuildRef<'a, Member>;
/// A reference to a single role stored inside a cached guild.
pub type GuildRoleRef<'a> = MappedGuildRef<'a, Role>;
/// A reference to a cached user.
pub type UserRef<'a> = CacheRef<'a, UserId, User, Never>;
/// A reference to a cached guild.
pub type GuildRef<'a> = CacheRef<'a, GuildId, Guild, Never>;
/// A read guard on the cache settings.
pub type SettingsRef<'a> = CacheRef<'a, Never, Settings, Never>;
/// A reference to a single channel stored inside a cached guild.
pub type GuildChannelRef<'a> = MappedGuildRef<'a, GuildChannel>;
/// A read guard on the cached current user.
pub type CurrentUserRef<'a> = CacheRef<'a, Never, CurrentUser, Never>;
/// A reference to the full role map of a cached guild.
pub type GuildRolesRef<'a> = MappedGuildRef<'a, HashMap<RoleId, Role>>;
/// A reference to the full channel map of a cached guild.
pub type GuildChannelsRef<'a> = MappedGuildRef<'a, HashMap<ChannelId, GuildChannel>>;
/// A reference to a single cached message, projected out of its channel's message map.
pub type MessageRef<'a> = CacheRef<'a, ChannelId, Message, HashMap<MessageId, Message>>;
/// A reference to all cached messages of one channel, keyed by message id.
pub type ChannelMessagesRef<'a> = CacheRef<'a, ChannelId, HashMap<MessageId, Message>, Never>;
/// Shard bookkeeping stored in the cache behind an `RwLock`.
#[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))]
#[derive(Debug)]
pub(crate) struct CachedShardData {
    /// Total number of shards. Starts at 1 in `Cache::new_with_settings` and is reported by
    /// `Cache::shard_count`.
    pub total: u32,
    /// Shards considered connected. Presumably maintained by shard/ready event handling outside
    /// this file — confirm.
    pub connected: HashSet<ShardId>,
    /// Whether the "all shards ready" notification has already been emitted. Starts as `false`;
    /// presumably flipped by event handling outside this file — confirm.
    pub has_sent_shards_ready: bool,
}
/// In-memory cache of guilds, channels, users, messages, and the current user.
///
/// The channel, guild, unavailable-guild, and user maps are allocated only when enabled by the
/// [`Settings`] passed to [`Cache::new_with_settings`]. The message maps are always allocated.
/// With the `temp_cache` feature, additional caches hold entries that expire after
/// `Settings::time_to_live`.
#[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))]
#[derive(Debug)]
#[non_exhaustive]
pub struct Cache {
    /// Temporary guild channels, keyed by channel id.
    #[cfg(feature = "temp_cache")]
    pub(crate) temp_channels: MokaCache<ChannelId, MaybeOwnedArc<GuildChannel>, BuildHasher>,
    /// Temporary private channels, keyed by the id of the other user.
    #[cfg(feature = "temp_cache")]
    pub(crate) temp_private_channels: MokaCache<UserId, MaybeOwnedArc<PrivateChannel>, BuildHasher>,
    /// Temporary messages, keyed by message id only (not by channel).
    #[cfg(feature = "temp_cache")]
    pub(crate) temp_messages: MokaCache<MessageId, MaybeOwnedArc<Message>, BuildHasher>,
    /// Temporary users, keyed by user id.
    #[cfg(feature = "temp_cache")]
    pub(crate) temp_users: MokaCache<UserId, MaybeOwnedArc<User>, BuildHasher>,
    /// Maps a guild channel's id to the id of the guild that contains it. The channel itself
    /// lives in that guild's `channels` map.
    pub(crate) channels: MaybeMap<ChannelId, GuildId>,
    /// Cached guilds, keyed by guild id.
    pub(crate) guilds: MaybeMap<GuildId, Guild>,
    /// Ids of guilds tracked as unavailable; the `()` value carries no data.
    pub(crate) unavailable_guilds: MaybeMap<GuildId, ()>,
    /// Cached users, keyed by user id.
    pub(crate) users: MaybeMap<UserId, User>,
    /// Cached messages per channel.
    pub(crate) messages: MessageCache,
    /// Message ids per channel. Presumably kept in arrival order so the oldest message can be
    /// evicted once `max_messages` is exceeded — confirm in the message update code.
    pub(crate) message_queue: DashMap<ChannelId, VecDeque<MessageId>, BuildHasher>,
    /// Shard bookkeeping.
    pub(crate) shard_data: RwLock<CachedShardData>,
    /// The current user; starts as `CurrentUser::default()`.
    pub(crate) user: RwLock<CurrentUser>,
    // Runtime-adjustable settings (see `Cache::settings` / `Cache::set_max_messages`).
    settings: RwLock<Settings>,
}
impl Cache {
    /// Creates a new cache with default settings.
    ///
    /// Equivalent to [`Cache::default`], i.e. [`Cache::new_with_settings`] with
    /// [`Settings::default`].
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new cache configured by `settings`.
    ///
    /// `cache_channels`, `cache_guilds`, and `cache_users` decide whether the matching maps are
    /// allocated at all. `cache_guilds` covers both the guild map and the unavailable-guild map.
    /// A disabled map is left as `None`.
    ///
    /// With the `temp_cache` feature, every temporary cache is built with
    /// `settings.time_to_live` as its time-to-live.
    #[instrument]
    pub fn new_with_settings(settings: Settings) -> Self {
        #[cfg(feature = "temp_cache")]
        fn temp_cache<K, V>(ttl: Duration) -> MokaCache<K, V, BuildHasher>
        where
            K: Hash + Eq + Send + Sync + 'static,
            V: Clone + Send + Sync + 'static,
        {
            MokaCache::builder().time_to_live(ttl).build_with_hasher(BuildHasher::default())
        }

        Self {
            #[cfg(feature = "temp_cache")]
            temp_private_channels: temp_cache(settings.time_to_live),
            #[cfg(feature = "temp_cache")]
            temp_channels: temp_cache(settings.time_to_live),
            #[cfg(feature = "temp_cache")]
            temp_messages: temp_cache(settings.time_to_live),
            #[cfg(feature = "temp_cache")]
            temp_users: temp_cache(settings.time_to_live),
            channels: MaybeMap(settings.cache_channels.then(DashMap::default)),
            guilds: MaybeMap(settings.cache_guilds.then(DashMap::default)),
            unavailable_guilds: MaybeMap(settings.cache_guilds.then(DashMap::default)),
            users: MaybeMap(settings.cache_users.then(DashMap::default)),
            messages: DashMap::default(),
            message_queue: DashMap::default(),
            shard_data: RwLock::new(CachedShardData {
                total: 1,
                connected: HashSet::new(),
                has_sent_shards_ready: false,
            }),
            user: RwLock::new(CurrentUser::default()),
            settings: RwLock::new(settings),
        }
    }

    /// Returns how many members are counted by cached guilds' `member_count` but are missing
    /// from those guilds' `members` maps.
    ///
    /// A guild whose cached member map is at least as large as its `member_count` contributes
    /// zero.
    pub fn unknown_members(&self) -> u64 {
        self.guilds
            .iter()
            .map(|guild_entry| {
                let guild = guild_entry.value();
                // `usize` is at most 64 bits on supported targets, so this cast is lossless.
                let cached = guild.members.len() as u64;
                guild.member_count.saturating_sub(cached)
            })
            .sum()
    }

    /// Returns the ids of all cached guilds, followed by the ids of all unavailable guilds.
    pub fn guilds(&self) -> Vec<GuildId> {
        let unavailable_guilds = self.unavailable_guilds();
        let unavailable_guild_ids = unavailable_guilds.iter().map(|i| *i.key());
        self.guilds.iter().map(|i| *i.key()).chain(unavailable_guild_ids).collect()
    }

    /// Retrieves a guild channel from the cache by id.
    ///
    /// The channel is looked up through its owning cached guild. With the `temp_cache` feature,
    /// the temporary channel cache is consulted whenever that lookup fails.
    #[inline]
    #[deprecated = "Use Cache::guild and Guild::channels instead"]
    pub fn channel<C: Into<ChannelId>>(&self, id: C) -> Option<GuildChannelRef<'_>> {
        self.channel_(id.into())
    }

    fn channel_(&self, id: ChannelId) -> Option<GuildChannelRef<'_>> {
        // Primary path: channel id -> owning guild id -> channel inside that guild.
        //
        // The guard on `self.channels` is dropped as soon as the `GuildId` has been copied out.
        // This is written as an `Option` chain rather than with `?` so that a missing mapping or
        // guild still reaches the temporary cache below. That case arises, for example, when
        // channel or guild caching is disabled in `Settings`.
        let cached = self
            .channels
            .get(&id)
            .map(|guild_id| *guild_id)
            .and_then(|guild_id| self.guilds.get(&guild_id))
            .and_then(|guild| guild.try_map(|g| g.channels.get(&id)).ok());
        if let Some(channel) = cached {
            return Some(CacheRef::from_mapped_ref(channel));
        }

        #[cfg(feature = "temp_cache")]
        if let Some(channel) = self.temp_channels.get(&id) {
            return Some(CacheRef::from_arc(channel));
        }

        None
    }

    /// Returns all cached messages of `channel_id`, keyed by message id.
    ///
    /// Returns `None` if no message map exists for that channel.
    pub fn channel_messages(
        &self,
        channel_id: impl Into<ChannelId>,
    ) -> Option<ChannelMessagesRef<'_>> {
        self.messages.get(&channel_id.into()).map(CacheRef::from_ref)
    }

    /// Retrieves a cached guild by id.
    #[inline]
    pub fn guild<G: Into<GuildId>>(&self, id: G) -> Option<GuildRef<'_>> {
        self.guild_(id.into())
    }

    fn guild_(&self, id: GuildId) -> Option<GuildRef<'_>> {
        self.guilds.get(&id).map(CacheRef::from_ref)
    }

    /// Returns the number of cached guilds. Unavailable guilds are tracked separately and are
    /// not counted.
    pub fn guild_count(&self) -> usize {
        self.guilds.len()
    }

    /// Retrieves a member from a cached guild's member map.
    #[inline]
    #[deprecated = "Use Cache::guild and Guild::members instead"]
    pub fn member(
        &self,
        guild_id: impl Into<GuildId>,
        user_id: impl Into<UserId>,
    ) -> Option<MemberRef<'_>> {
        self.member_(guild_id.into(), user_id.into())
    }

    fn member_(&self, guild_id: GuildId, user_id: UserId) -> Option<MemberRef<'_>> {
        let member = self.guilds.get(&guild_id)?.try_map(|g| g.members.get(&user_id)).ok()?;
        Some(CacheRef::from_mapped_ref(member))
    }

    /// Returns the full role map of a cached guild.
    #[inline]
    #[deprecated = "Use Cache::guild and Guild::roles instead"]
    pub fn guild_roles(&self, guild_id: impl Into<GuildId>) -> Option<GuildRolesRef<'_>> {
        self.guild_roles_(guild_id.into())
    }

    fn guild_roles_(&self, guild_id: GuildId) -> Option<GuildRolesRef<'_>> {
        let roles = self.guilds.get(&guild_id)?.map(|g| &g.roles);
        Some(CacheRef::from_mapped_ref(roles))
    }

    /// Returns a read-only view of the ids of unavailable guilds.
    #[inline]
    pub fn unavailable_guilds(&self) -> ReadOnlyMapRef<'_, GuildId, ()> {
        self.unavailable_guilds.as_read_only()
    }

    /// Returns the full channel map of a cached guild.
    #[inline]
    #[deprecated = "Use Cache::guild and Guild::channels instead"]
    pub fn guild_channels(&self, guild_id: impl Into<GuildId>) -> Option<GuildChannelsRef<'_>> {
        self.guild_channels_(guild_id.into())
    }

    fn guild_channels_(&self, guild_id: GuildId) -> Option<GuildChannelsRef<'_>> {
        let channels = self.guilds.get(&guild_id)?.map(|g| &g.channels);
        Some(CacheRef::from_mapped_ref(channels))
    }

    /// Returns the number of entries in the channel-id → guild-id map.
    pub fn guild_channel_count(&self) -> usize {
        self.channels.len()
    }

    /// Returns the total shard count recorded in the shard bookkeeping.
    #[inline]
    pub fn shard_count(&self) -> u32 {
        self.shard_data.read().total
    }

    /// Retrieves a cached message.
    ///
    /// With the `temp_cache` feature, the temporary message cache is checked first. That cache
    /// is keyed by message id alone, so `channel_id` is only used for the main message cache.
    #[inline]
    pub fn message<C, M>(&self, channel_id: C, message_id: M) -> Option<MessageRef<'_>>
    where
        C: Into<ChannelId>,
        M: Into<MessageId>,
    {
        self.message_(channel_id.into(), message_id.into())
    }

    fn message_(&self, channel_id: ChannelId, message_id: MessageId) -> Option<MessageRef<'_>> {
        #[cfg(feature = "temp_cache")]
        if let Some(message) = self.temp_messages.get(&message_id) {
            return Some(CacheRef::from_arc(message));
        }

        let channel_messages = self.messages.get(&channel_id)?;
        let message = channel_messages.try_map(|messages| messages.get(&message_id)).ok()?;
        Some(CacheRef::from_mapped_ref(message))
    }

    /// Retrieves a single role from a cached guild's role map.
    #[inline]
    #[deprecated = "Use Cache::guild and Guild::roles instead"]
    pub fn role<G, R>(&self, guild_id: G, role_id: R) -> Option<GuildRoleRef<'_>>
    where
        G: Into<GuildId>,
        R: Into<RoleId>,
    {
        self.role_(guild_id.into(), role_id.into())
    }

    fn role_(&self, guild_id: GuildId, role_id: RoleId) -> Option<GuildRoleRef<'_>> {
        let role = self.guilds.get(&guild_id)?.try_map(|g| g.roles.get(&role_id)).ok()?;
        Some(CacheRef::from_mapped_ref(role))
    }

    /// Returns a read guard on the cache settings.
    ///
    /// While the guard is held, writers such as [`Cache::set_max_messages`] block.
    pub fn settings(&self) -> SettingsRef<'_> {
        CacheRef::from_guard(self.settings.read())
    }

    /// Sets `Settings::max_messages`.
    ///
    /// Only the setting is changed here; already-cached messages are not touched by this call.
    pub fn set_max_messages(&self, max: usize) {
        self.settings.write().max_messages = max;
    }

    /// Retrieves a cached user.
    ///
    /// With the `temp_cache` feature, falls back to the temporary user cache.
    #[inline]
    pub fn user<U: Into<UserId>>(&self, user_id: U) -> Option<UserRef<'_>> {
        self.user_(user_id.into())
    }

    #[cfg(feature = "temp_cache")]
    fn user_(&self, user_id: UserId) -> Option<UserRef<'_>> {
        if let Some(user) = self.users.get(&user_id) {
            Some(CacheRef::from_ref(user))
        } else {
            self.temp_users.get(&user_id).map(CacheRef::from_arc)
        }
    }

    #[cfg(not(feature = "temp_cache"))]
    fn user_(&self, user_id: UserId) -> Option<UserRef<'_>> {
        self.users.get(&user_id).map(CacheRef::from_ref)
    }

    /// Returns a read-only view of all cached users.
    #[inline]
    pub fn users(&self) -> ReadOnlyMapRef<'_, UserId, User> {
        self.users.as_read_only()
    }

    /// Returns the number of cached users.
    #[inline]
    pub fn user_count(&self) -> usize {
        self.users.len()
    }

    /// Returns a read guard on the cached current user.
    #[inline]
    pub fn current_user(&self) -> CurrentUserRef<'_> {
        CacheRef::from_guard(self.user.read())
    }

    /// Returns the channel if it is cached and is a category; otherwise returns `None`.
    #[deprecated = "Use Cache::guild, Guild::channels, and GuildChannel::kind"]
    pub fn category(&self, channel_id: ChannelId) -> Option<GuildChannelRef<'_>> {
        #[allow(deprecated)] // internal use of the deprecated lookup this method is built on
        let channel = self.channel(channel_id)?;
        if channel.kind == ChannelType::Category {
            Some(channel)
        } else {
            None
        }
    }

    /// Returns the `parent_id` of a cached channel.
    #[deprecated = "Use Cache::guild, Guild::channels, and GuildChannel::parent_id"]
    pub fn channel_category_id(&self, channel_id: ChannelId) -> Option<ChannelId> {
        #[allow(deprecated)] // internal use of the deprecated lookup this method is built on
        self.channel(channel_id)?.parent_id
    }

    /// Returns owned clones of all category channels of a cached guild, keyed by channel id.
    ///
    /// Returns `None` if the guild is not cached.
    pub fn guild_categories(&self, guild_id: GuildId) -> Option<HashMap<ChannelId, GuildChannel>> {
        let guild = self.guilds.get(&guild_id)?;
        Some(
            guild
                .channels
                .iter()
                .filter(|(_id, channel)| channel.kind == ChannelType::Category)
                .map(|(id, channel)| (*id, channel.clone()))
                .collect(),
        )
    }

    /// Applies an event's changes to the cache via [`CacheUpdate::update`] and returns that
    /// method's output.
    #[instrument(skip(self, e))]
    pub fn update<E: CacheUpdate>(&self, e: &mut E) -> Option<E::Output> {
        e.update(self)
    }

    /// Inserts `user`, or overwrites the existing cached copy, when user caching is enabled.
    /// Does nothing when it is disabled.
    pub(crate) fn update_user_entry(&self, user: &User) {
        if let Some(users) = &self.users.0 {
            match users.entry(user.id) {
                Entry::Vacant(e) => {
                    e.insert(user.clone());
                },
                Entry::Occupied(mut e) => {
                    // `clone_from` can reuse the existing value's allocations.
                    e.get_mut().clone_from(user);
                },
            }
        }
    }
}
impl Default for Cache {
fn default() -> Self {
Self::new_with_settings(Settings::default())
}
}
#[cfg(test)]
mod test {
    use crate::cache::{Cache, CacheUpdate, Settings};
    use crate::model::prelude::*;

    /// Exercises message caching. It checks:
    /// - the `max_messages` cap and the `Some` result when it is exceeded;
    /// - removal of a channel's messages on channel deletion;
    /// - removal of a channel's messages on guild deletion.
    #[test]
    fn test_cache_messages() {
        // Cap each channel at two cached messages.
        let settings = Settings {
            max_messages: 2,
            ..Default::default()
        };
        let cache = Cache::new_with_settings(settings);
        // `channel_id` is left at its default value; every message below goes to that channel.
        let mut event = MessageCreateEvent {
            message: Message {
                id: MessageId::new(3),
                guild_id: Some(GuildId::new(1)),
                ..Default::default()
            },
        };
        assert!(!cache.messages.contains_key(&event.message.channel_id));
        // Applying the same message twice returns `None` both times and stores it only once.
        assert!(event.update(&cache).is_none());
        assert!(event.update(&cache).is_none());
        assert_eq!(cache.messages.get(&event.message.channel_id).unwrap().len(), 1);
        event.message.id = MessageId::new(4);
        assert!(event.update(&cache).is_none());
        assert_eq!(cache.messages.get(&event.message.channel_id).unwrap().len(), 2);
        // A third distinct message exceeds the cap. The update returns `Some`, the channel still
        // holds two messages, and message 3 is gone.
        event.message.id = MessageId::new(5);
        assert!(event.update(&cache).is_some());
        {
            // Scoped so the DashMap `Ref` (and its shard lock) is released before the
            // following cache updates.
            let channel = cache.messages.get(&event.message.channel_id).unwrap();
            assert_eq!(channel.len(), 2);
            assert!(!channel.contains_key(&MessageId::new(3)));
        }
        let channel = GuildChannel {
            id: event.message.channel_id,
            guild_id: event.message.guild_id.unwrap(),
            ..Default::default()
        };
        // Deleting the channel returns `Some` and removes its message map.
        let mut delete = ChannelDeleteEvent {
            channel: channel.clone(),
        };
        assert!(cache.update(&mut delete).is_some());
        assert!(!cache.messages.contains_key(&delete.channel.id));
        // Re-create the guild containing that channel (stored under map key 2), then cache a
        // message again.
        // NOTE(review): the channel's `id` is `event.message.channel_id` (the default), which may
        // not equal 2. If so, the final assertion on `ChannelId::new(2)` may hold trivially —
        // confirm.
        let mut guild_create = GuildCreateEvent {
            guild: Guild {
                id: GuildId::new(1),
                channels: HashMap::from([(ChannelId::new(2), channel)]),
                ..Default::default()
            },
        };
        assert!(cache.update(&mut guild_create).is_none());
        assert!(cache.update(&mut event).is_none());
        // Deleting the (not merely unavailable) guild returns `Some`.
        let mut guild_delete = GuildDeleteEvent {
            guild: UnavailableGuild {
                id: GuildId::new(1),
                unavailable: false,
            },
        };
        assert!(cache.update(&mut guild_delete).is_some());
        // The guild's channel should no longer have a message map.
        assert!(!cache.messages.contains_key(&ChannelId::new(2)));
    }
}