robespierre_cache/
lib.rs

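//! In-memory cache for `robespierre_models` entities (users, servers, roles, members,
//! channels and per-channel messages), kept up to date through [`CommitToCache`].
//!
//! TODO: expand documentation.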
2
3use std::{
4    collections::{hash_map::Entry, HashMap, VecDeque},
5    iter::FromIterator,
6    sync::Arc,
7};
8
9use async_trait::async_trait;
10use tokio::sync::RwLock;
11
12use robespierre_models::{
13    channels::{Channel, ChannelField, Message, PartialChannel, PartialMessage},
14    events::ServerToClientEvent,
15    id::{ChannelId, MemberId, MessageId, RoleId, ServerId, UserId},
16    servers::{
17        Member, MemberField, PartialMember, PartialRole, PartialServer, RoleField, Server,
18        ServerField,
19    },
20    users::{User, UserField, UserPatch},
21};
22
/// Settings for a [`Cache`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheConfig {
    /// Number of messages to cache per channel; `0` disables message caching.
    pub messages: usize,
}

impl CacheConfig {
    /// Returns this configuration with the per-channel message limit set to `messages`.
    ///
    /// This is a by-value builder: the receiver is consumed and the updated configuration is
    /// returned, so ignoring the result discards the change.
    #[must_use]
    pub fn messages(mut self, messages: usize) -> CacheConfig {
        // Plain assignment (rather than `Self { messages, ..self }`) keeps every other field
        // intact and stays correct as fields are added, without a no-op struct update today.
        self.messages = messages;
        self
    }
}
34
35pub struct Cache {
36    config: CacheConfig,
37
38    users: RwLock<HashMap<UserId, User>>,
39    servers: RwLock<HashMap<ServerId, Server>>,
40    roles: RwLock<HashMap<RoleId, ServerId>>,
41    members: RwLock<HashMap<MemberId, Member>>,
42    channels: RwLock<HashMap<ChannelId, Channel>>,
43    messages: RwLock<HashMap<ChannelId, HashMap<MessageId, Message>>>,
44    message_queue: RwLock<HashMap<ChannelId, VecDeque<MessageId>>>,
45}
46
47impl Cache {
48    pub fn new(config: CacheConfig) -> Arc<Self> {
49        Arc::new(Self {
50            config,
51
52            users: RwLock::new(HashMap::new()),
53            servers: RwLock::new(HashMap::new()),
54            roles: RwLock::new(HashMap::new()),
55            members: RwLock::new(HashMap::new()),
56            channels: RwLock::new(HashMap::new()),
57            messages: RwLock::new(HashMap::new()),
58            message_queue: RwLock::new(HashMap::new()),
59        })
60    }
61}
62
/// Generates the accessors for one keyed map of [`Cache`].
///
/// Short form `(id type, value type, cloner, data getter, field)` generates:
/// * `$get_data(id, f)`, which runs `f` on the cached value while the field's read lock is
///   held and returns its result (`None` if nothing is cached under `id`);
/// * `$cloner(id)`, which returns a clone of the cached value.
///
/// The long form additionally takes `(commit fn, key field)` and generates
/// `$commit_function(v)`, which stores a clone of `v` under `v.$key_field`, replacing any
/// previous entry.
macro_rules! cache_field {
    ($id_ty:ty, $full_ty:ty, $cloner:ident, $get_data:ident, $field:ident) => {
        impl Cache {
            pub async fn $get_data<F, T>(&self, id: $id_ty, f: F) -> Option<T>
            where
                F: FnOnce(&$full_ty) -> T,
            {
                self.$field.read().await.get(&id).map(f)
            }

            pub async fn $cloner(&self, id: $id_ty) -> Option<$full_ty> {
                self.$get_data(id, <$full_ty as Clone>::clone).await
            }
        }
    };

    ($id_ty:ty, $full_ty:ty, $cloner:ident, $get_data:ident, $field:ident, $commit_function:ident, $key_field:ident) => {
        // The read accessors are exactly those of the short form; only the commit is extra.
        cache_field! { $id_ty, $full_ty, $cloner, $get_data, $field }

        impl Cache {
            pub async fn $commit_function(&self, v: &$full_ty) {
                let key = v.$key_field;
                self.$field.write().await.insert(key, v.clone());
            }
        }
    };
}
98
99cache_field! {UserId, User, get_user, get_user_data, users, commit_user, id}
100
101impl Cache {
102    pub async fn patch_user(
103        &self,
104        user_id: UserId,
105        patch: impl FnOnce() -> UserPatch,
106        remove: Option<UserField>,
107    ) {
108        let mut lock = self.users.write().await;
109        if let Some(user) = lock.get_mut(&user_id) {
110            let patch = patch();
111
112            patch.patch(user);
113            if let Some(remove) = remove {
114                remove.remove_patch(user);
115            }
116        }
117    }
118
119    pub async fn get_users_aggregate<T, F>(&self, f: F) -> T
120    where
121        F: FnOnce(UserIter) -> T,
122    {
123        f(UserIter(self.users.read().await.values()))
124    }
125}
126
127pub struct UserIter<'a>(std::collections::hash_map::Values<'a, UserId, User>);
128
129impl<'a> Iterator for UserIter<'a> {
130    type Item = &'a User;
131
132    fn next(&mut self) -> Option<Self::Item> {
133        self.0.next()
134    }
135
136    fn size_hint(&self) -> (usize, Option<usize>) {
137        self.0.size_hint()
138    }
139}
140
141impl Cache {
142    pub async fn get_server(&self, id: ServerId) -> Option<Server> {
143        self.get_server_data(id, Clone::clone).await
144    }
145    pub async fn get_server_data<F, T>(&self, id: ServerId, f: F) -> Option<T>
146    where
147        F: FnOnce(&Server) -> T,
148    {
149        self.servers.read().await.get(&id).map(f)
150    }
151    pub async fn commit_server(&self, v: &Server) {
152        self.servers.write().await.insert(v.id, v.clone());
153
154        if let Some(ref roles) = v.roles {
155            let mut roles_write_lock = self.roles.write().await;
156
157            for (role_id, _role) in roles.iter() {
158                roles_write_lock.insert(*role_id, v.id);
159            }
160        }
161    }
162}
163
164impl Cache {
165    pub async fn patch_server(
166        &self,
167        server_id: ServerId,
168        patch: impl FnOnce() -> PartialServer,
169        remove: Option<ServerField>,
170    ) {
171        let mut lock = self.servers.write().await;
172        if let Some(server) = lock.get_mut(&server_id) {
173            let patch = patch();
174
175            patch.patch(server);
176            if let Some(remove) = remove {
177                remove.remove_patch(server);
178            }
179        }
180    }
181
182    pub async fn delete_server(&self, server_id: ServerId) {
183        self.servers.write().await.remove(&server_id);
184    }
185
186    pub async fn get_servers_aggregate<T, F>(&self, f: F) -> T
187    where
188        F: FnOnce(ServerIter) -> T,
189    {
190        f(ServerIter(self.servers.read().await.values()))
191    }
192}
193
194pub struct ServerIter<'a>(std::collections::hash_map::Values<'a, ServerId, Server>);
195
196impl<'a> Iterator for ServerIter<'a> {
197    type Item = &'a Server;
198
199    fn next(&mut self) -> Option<Self::Item> {
200        self.0.next()
201    }
202
203    fn size_hint(&self) -> (usize, Option<usize>) {
204        self.0.size_hint()
205    }
206}
207
208impl Cache {
209    pub async fn get_server_of_role(&self, id: RoleId) -> Option<ServerId> {
210        self.roles.read().await.get(&id).copied()
211    }
212
213    pub async fn patch_role(
214        &self,
215        server_id: ServerId,
216        role_id: RoleId,
217        patch: impl FnOnce() -> PartialRole,
218        remove: Option<RoleField>,
219    ) {
220        let mut lock = self.servers.write().await;
221        if let Some(server) = lock.get_mut(&server_id) {
222            if let Some(ref mut roles_obj) = server.roles {
223                let patch = patch();
224
225                roles_obj.patch_role(&role_id, patch, remove);
226            }
227        }
228    }
229
230    pub async fn delete_role(&self, id: ServerId, role: RoleId) {
231        let mut lock = self.servers.write().await;
232        if let Some(server) = lock.get_mut(&id) {
233            if let Some(ref mut roles_obj) = server.roles {
234                roles_obj.remove(&role);
235            }
236        }
237    }
238}
239
240cache_field! {MemberId, Member, get_member, get_member_data, members, commit_member, id}
241
242impl Cache {
243    pub async fn patch_member(
244        &self,
245        member_id: MemberId,
246        patch: impl FnOnce() -> PartialMember,
247        remove: Option<MemberField>,
248    ) {
249        let mut lock = self.members.write().await;
250        if let Some(member) = lock.get_mut(&member_id) {
251            let patch = patch();
252
253            patch.patch(member);
254            if let Some(remove) = remove {
255                remove.remove_patch(member);
256            }
257        }
258    }
259
260    pub async fn get_members_aggregate<T, F>(&self, f: F) -> T
261    where
262        F: FnOnce(MemberIter) -> T,
263    {
264        f(MemberIter(self.members.read().await.values()))
265    }
266}
267
268pub struct MemberIter<'a>(std::collections::hash_map::Values<'a, MemberId, Member>);
269
270impl<'a> Iterator for MemberIter<'a> {
271    type Item = &'a Member;
272
273    fn next(&mut self) -> Option<Self::Item> {
274        self.0.next()
275    }
276
277    fn size_hint(&self) -> (usize, Option<usize>) {
278        self.0.size_hint()
279    }
280}
281
282cache_field! {ChannelId, Channel, get_channel, get_channel_data, channels}
283
284impl Cache {
285    pub async fn commit_channel(&self, channel: &Channel) {
286        self.channels
287            .write()
288            .await
289            .insert(channel.id(), channel.clone());
290    }
291
292    pub async fn patch_channel(
293        &self,
294        channel_id: ChannelId,
295        patch: impl FnOnce() -> PartialChannel,
296        remove: Option<ChannelField>,
297    ) {
298        let mut lock = self.channels.write().await;
299        if let Some(channel) = lock.get_mut(&channel_id) {
300            let patch = patch();
301
302            patch.patch(channel);
303            if let Some(remove) = remove {
304                remove.remove_patch(channel);
305            }
306        }
307    }
308
309    pub async fn delete_channel(&self, channel_id: ChannelId) {
310        self.channels.write().await.remove(&channel_id);
311    }
312
313    pub async fn get_channels_aggregate<T, F>(&self, f: F) -> T
314    where
315        F: FnOnce(ChannelIter) -> T,
316    {
317        f(ChannelIter(self.channels.read().await.values()))
318    }
319}
320
321pub struct ChannelIter<'a>(std::collections::hash_map::Values<'a, ChannelId, Channel>);
322
323impl<'a> Iterator for ChannelIter<'a> {
324    type Item = &'a Channel;
325
326    fn next(&mut self) -> Option<Self::Item> {
327        self.0.next()
328    }
329
330    fn size_hint(&self) -> (usize, Option<usize>) {
331        self.0.size_hint()
332    }
333}
334
335impl Cache {
336    pub async fn get_message(&self, channel: ChannelId, message: MessageId) -> Option<Message> {
337        self.get_message_data(channel, message, Clone::clone).await
338    }
339
340    pub async fn get_message_data<F, T>(
341        &self,
342        channel: ChannelId,
343        message: MessageId,
344        f: F,
345    ) -> Option<T>
346    where
347        F: FnOnce(&Message) -> T,
348    {
349        self.messages
350            .read()
351            .await
352            .get(&channel)?
353            .get(&message)
354            .map(f)
355    }
356
357    pub async fn commit_message(&self, message: &Message) {
358        if self.config.messages == 0 {
359            return;
360        }
361
362        let mut queue_lock = self.message_queue.write().await;
363        let deque = queue_lock
364            .entry(message.channel)
365            .or_insert_with(VecDeque::new);
366
367        match self.messages.write().await.entry(message.channel) {
368            Entry::Occupied(mut m) => {
369                m.get_mut().insert(message.id, message.clone());
370
371                deque.push_back(message.id);
372
373                if deque.len() > self.config.messages {
374                    if let Some(oldest) = deque.pop_front() {
375                        m.get_mut().remove(&oldest);
376                    }
377                }
378            }
379            Entry::Vacant(v) => {
380                deque.push_back(message.id);
381                v.insert(HashMap::from_iter([(message.id, message.clone())]));
382            }
383        }
384    }
385
386    pub async fn patch_message(
387        &self,
388        channel_id: ChannelId,
389        message_id: MessageId,
390        patch: impl FnOnce() -> PartialMessage,
391    ) {
392        let mut lock = self.messages.write().await;
393        if let Some(ch) = lock.get_mut(&channel_id) {
394            if let Some(message) = ch.get_mut(&message_id) {
395                let patch = patch();
396
397                patch.patch(message);
398            }
399        }
400    }
401
402    pub async fn get_messages_aggregate<T, F>(&self, channel_id: ChannelId, f: F) -> Option<T>
403    where
404        F: FnOnce(MessageIter) -> T,
405    {
406        let lock = self.messages.read().await;
407        let iter = lock.get(&channel_id)?.values();
408        Some(f(MessageIter(iter)))
409    }
410}
411
412pub struct MessageIter<'a>(std::collections::hash_map::Values<'a, MessageId, Message>);
413
414impl<'a> Iterator for MessageIter<'a> {
415    type Item = &'a Message;
416
417    fn next(&mut self) -> Option<Self::Item> {
418        self.0.next()
419    }
420
421    fn size_hint(&self) -> (usize, Option<usize>) {
422        self.0.size_hint()
423    }
424}
425
426pub trait HasCache: Send + Sync {
427    fn get_cache(&self) -> Option<&Cache>;
428}
429
430impl HasCache for Cache {
431    fn get_cache(&self) -> Option<&Cache> {
432        Some(self)
433    }
434}
435
436impl HasCache for Arc<Cache> {
437    fn get_cache(&self) -> Option<&Cache> {
438        Some(self)
439    }
440}
441
442#[async_trait]
443pub trait CommitToCache: Send + Sync {
444    async fn commit_to_cache<C: HasCache>(self, c: &C) -> Self
445    where
446        Self: Sized,
447    {
448        self.commit_to_cache_ref(c).await;
449
450        self
451    }
452
453    async fn commit_to_cache_ref<C: HasCache>(&self, c: &C) {
454        if let Some(c) = c.get_cache() {
455            Self::__commit_to_cache(self, c).await;
456        }
457    }
458
459    async fn __commit_to_cache(&self, cache: &Cache);
460}
461
462#[async_trait]
463impl CommitToCache for User {
464    async fn __commit_to_cache(&self, cache: &Cache) {
465        cache.commit_user(self).await;
466    }
467}
468
469#[async_trait]
470impl CommitToCache for Channel {
471    async fn __commit_to_cache(&self, cache: &Cache) {
472        cache.commit_channel(self).await;
473    }
474}
475
476#[async_trait]
477impl CommitToCache for Server {
478    async fn __commit_to_cache(&self, cache: &Cache) {
479        cache.commit_server(self).await;
480    }
481}
482
483#[async_trait]
484impl CommitToCache for Member {
485    async fn __commit_to_cache(&self, cache: &Cache) {
486        cache.commit_member(self).await;
487    }
488}
489
490#[async_trait]
491impl CommitToCache for Message {
492    async fn __commit_to_cache(&self, cache: &Cache) {
493        cache.commit_message(self).await;
494    }
495}
496
497#[async_trait]
498impl CommitToCache for ServerToClientEvent {
499    async fn __commit_to_cache(&self, cache: &Cache) {
500        #[allow(unused_variables)]
501        match self {
502            ServerToClientEvent::Error { .. } => {}
503            ServerToClientEvent::Authenticated => {}
504            ServerToClientEvent::Pong { .. } => {}
505            ServerToClientEvent::Ready { event } => {
506                for user in event.users.iter() {
507                    user.commit_to_cache_ref(cache).await;
508                }
509                for channel in event.channels.iter() {
510                    channel.commit_to_cache_ref(cache).await;
511                }
512                for server in event.servers.iter() {
513                    server.commit_to_cache_ref(cache).await;
514                }
515                for member in event.members.iter() {
516                    member.commit_to_cache_ref(cache).await;
517                }
518            }
519            ServerToClientEvent::Message { message } => {
520                message.commit_to_cache_ref(cache).await;
521            }
522            ServerToClientEvent::MessageUpdate { id, channel, data } => {
523                cache.patch_message(*channel, *id, || data.clone()).await;
524            }
525            ServerToClientEvent::MessageDelete { id, channel } => {}
526            ServerToClientEvent::ChannelCreate { channel } => {
527                cache.commit_channel(channel).await;
528            }
529            ServerToClientEvent::ChannelUpdate { id, data, clear } => {
530                cache.patch_channel(*id, || data.clone(), *clear).await;
531            }
532            ServerToClientEvent::ChannelDelete { id } => {
533                cache.delete_channel(*id).await;
534            }
535            ServerToClientEvent::ChannelGroupJoin { id, user } => {}
536            ServerToClientEvent::ChannelGroupLeave { id, user } => {}
537            ServerToClientEvent::ChannelStartTyping { id, user } => {}
538            ServerToClientEvent::ChannelStopTyping { id, user } => {}
539            ServerToClientEvent::ChannelAck {
540                id,
541                user,
542                message_id,
543            } => {}
544            ServerToClientEvent::ServerUpdate { id, data, clear } => {
545                cache.patch_server(*id, || data.clone(), *clear).await;
546            }
547            ServerToClientEvent::ServerDelete { id } => {
548                cache.delete_server(*id).await;
549            }
550            ServerToClientEvent::ServerMemberUpdate { id, data, clear } => {
551                cache.patch_member(*id, || data.clone(), *clear).await;
552            }
553            ServerToClientEvent::ServerMemberJoin { id, user } => {}
554            ServerToClientEvent::ServerMemberLeave { id, user } => {}
555            ServerToClientEvent::ServerRoleUpdate {
556                id,
557                role_id,
558                data,
559                clear,
560            } => {
561                cache
562                    .patch_role(*id, *role_id, || data.clone(), *clear)
563                    .await;
564            }
565            ServerToClientEvent::ServerRoleDelete { id, role_id } => {
566                cache.delete_role(*id, *role_id).await;
567            }
568            ServerToClientEvent::UserUpdate { id, data, clear } => {
569                cache.patch_user(*id, || data.clone(), *clear).await;
570            }
571            ServerToClientEvent::UserRelationship { id, user, status } => {}
572        }
573    }
574}