1use std::{
4 collections::{hash_map::Entry, HashMap, VecDeque},
5 iter::FromIterator,
6 sync::Arc,
7};
8
9use async_trait::async_trait;
10use tokio::sync::RwLock;
11
12use robespierre_models::{
13 channels::{Channel, ChannelField, Message, PartialChannel, PartialMessage},
14 events::ServerToClientEvent,
15 id::{ChannelId, MemberId, MessageId, RoleId, ServerId, UserId},
16 servers::{
17 Member, MemberField, PartialMember, PartialRole, PartialServer, RoleField, Server,
18 ServerField,
19 },
20 users::{User, UserField, UserPatch},
21};
22
/// Settings for a [`Cache`].
///
/// The `Default` configuration caches no messages at all (`messages == 0`).
#[derive(Debug, Clone, Default)]
pub struct CacheConfig {
    /// Maximum number of messages kept per channel. When a channel goes over this limit the
    /// oldest cached message of that channel is evicted. `0` disables message caching.
    pub messages: usize,
}

impl CacheConfig {
    /// Returns this configuration with the per-channel message limit set to `messages`.
    ///
    /// Builder-style: the receiver is consumed, so the returned value must be used.
    #[must_use]
    pub fn messages(mut self, messages: usize) -> CacheConfig {
        // Assigning the one field (instead of `Self { messages, ..self }`, which names every
        // field and trips clippy's `needless_update`) keeps any future fields untouched.
        self.messages = messages;
        self
    }
}
34
35pub struct Cache {
36 config: CacheConfig,
37
38 users: RwLock<HashMap<UserId, User>>,
39 servers: RwLock<HashMap<ServerId, Server>>,
40 roles: RwLock<HashMap<RoleId, ServerId>>,
41 members: RwLock<HashMap<MemberId, Member>>,
42 channels: RwLock<HashMap<ChannelId, Channel>>,
43 messages: RwLock<HashMap<ChannelId, HashMap<MessageId, Message>>>,
44 message_queue: RwLock<HashMap<ChannelId, VecDeque<MessageId>>>,
45}
46
47impl Cache {
48 pub fn new(config: CacheConfig) -> Arc<Self> {
49 Arc::new(Self {
50 config,
51
52 users: RwLock::new(HashMap::new()),
53 servers: RwLock::new(HashMap::new()),
54 roles: RwLock::new(HashMap::new()),
55 members: RwLock::new(HashMap::new()),
56 channels: RwLock::new(HashMap::new()),
57 messages: RwLock::new(HashMap::new()),
58 message_queue: RwLock::new(HashMap::new()),
59 })
60 }
61}
62
/// Generates accessor methods on [`Cache`] for one of its `RwLock<HashMap<..>>` fields.
///
/// * 5 arguments: `$cloner` (owned clone lookup) and `$get_data` (closure projection).
/// * 7 arguments: the same two, plus `$commit_function`, which stores a clone of an entry
///   keyed by its `$key_field` field.
macro_rules! cache_field {
    ($id_ty:ty, $full_ty:ty, $cloner:ident, $get_data:ident, $field:ident) => {
        impl Cache {
            /// Returns a clone of the cached entry with key `id`, if there is one.
            pub async fn $cloner(&self, id: $id_ty) -> Option<$full_ty> {
                self.$get_data(id, Clone::clone).await
            }

            /// Applies `f` to the cached entry with key `id` while the read lock is held and
            /// returns its result, or `None` when nothing is cached under `id`.
            pub async fn $get_data<F, T>(&self, id: $id_ty, f: F) -> Option<T>
            where
                F: FnOnce(&$full_ty) -> T,
            {
                let guard = self.$field.read().await;
                guard.get(&id).map(f)
            }
        }
    };

    ($id_ty:ty, $full_ty:ty, $cloner:ident, $get_data:ident, $field:ident, $commit_function:ident, $key_field:ident) => {
        // The read accessors are exactly the 5-argument form; only the commit method is extra.
        cache_field! {$id_ty, $full_ty, $cloner, $get_data, $field}

        impl Cache {
            /// Stores a clone of `v`, replacing whatever was cached under the same key.
            pub async fn $commit_function(&self, v: &$full_ty) {
                let key = v.$key_field;
                self.$field.write().await.insert(key, v.clone());
            }
        }
    };
}
98
99cache_field! {UserId, User, get_user, get_user_data, users, commit_user, id}
100
101impl Cache {
102 pub async fn patch_user(
103 &self,
104 user_id: UserId,
105 patch: impl FnOnce() -> UserPatch,
106 remove: Option<UserField>,
107 ) {
108 let mut lock = self.users.write().await;
109 if let Some(user) = lock.get_mut(&user_id) {
110 let patch = patch();
111
112 patch.patch(user);
113 if let Some(remove) = remove {
114 remove.remove_patch(user);
115 }
116 }
117 }
118
119 pub async fn get_users_aggregate<T, F>(&self, f: F) -> T
120 where
121 F: FnOnce(UserIter) -> T,
122 {
123 f(UserIter(self.users.read().await.values()))
124 }
125}
126
127pub struct UserIter<'a>(std::collections::hash_map::Values<'a, UserId, User>);
128
129impl<'a> Iterator for UserIter<'a> {
130 type Item = &'a User;
131
132 fn next(&mut self) -> Option<Self::Item> {
133 self.0.next()
134 }
135
136 fn size_hint(&self) -> (usize, Option<usize>) {
137 self.0.size_hint()
138 }
139}
140
141impl Cache {
142 pub async fn get_server(&self, id: ServerId) -> Option<Server> {
143 self.get_server_data(id, Clone::clone).await
144 }
145 pub async fn get_server_data<F, T>(&self, id: ServerId, f: F) -> Option<T>
146 where
147 F: FnOnce(&Server) -> T,
148 {
149 self.servers.read().await.get(&id).map(f)
150 }
151 pub async fn commit_server(&self, v: &Server) {
152 self.servers.write().await.insert(v.id, v.clone());
153
154 if let Some(ref roles) = v.roles {
155 let mut roles_write_lock = self.roles.write().await;
156
157 for (role_id, _role) in roles.iter() {
158 roles_write_lock.insert(*role_id, v.id);
159 }
160 }
161 }
162}
163
164impl Cache {
165 pub async fn patch_server(
166 &self,
167 server_id: ServerId,
168 patch: impl FnOnce() -> PartialServer,
169 remove: Option<ServerField>,
170 ) {
171 let mut lock = self.servers.write().await;
172 if let Some(server) = lock.get_mut(&server_id) {
173 let patch = patch();
174
175 patch.patch(server);
176 if let Some(remove) = remove {
177 remove.remove_patch(server);
178 }
179 }
180 }
181
182 pub async fn delete_server(&self, server_id: ServerId) {
183 self.servers.write().await.remove(&server_id);
184 }
185
186 pub async fn get_servers_aggregate<T, F>(&self, f: F) -> T
187 where
188 F: FnOnce(ServerIter) -> T,
189 {
190 f(ServerIter(self.servers.read().await.values()))
191 }
192}
193
194pub struct ServerIter<'a>(std::collections::hash_map::Values<'a, ServerId, Server>);
195
196impl<'a> Iterator for ServerIter<'a> {
197 type Item = &'a Server;
198
199 fn next(&mut self) -> Option<Self::Item> {
200 self.0.next()
201 }
202
203 fn size_hint(&self) -> (usize, Option<usize>) {
204 self.0.size_hint()
205 }
206}
207
208impl Cache {
209 pub async fn get_server_of_role(&self, id: RoleId) -> Option<ServerId> {
210 self.roles.read().await.get(&id).copied()
211 }
212
213 pub async fn patch_role(
214 &self,
215 server_id: ServerId,
216 role_id: RoleId,
217 patch: impl FnOnce() -> PartialRole,
218 remove: Option<RoleField>,
219 ) {
220 let mut lock = self.servers.write().await;
221 if let Some(server) = lock.get_mut(&server_id) {
222 if let Some(ref mut roles_obj) = server.roles {
223 let patch = patch();
224
225 roles_obj.patch_role(&role_id, patch, remove);
226 }
227 }
228 }
229
230 pub async fn delete_role(&self, id: ServerId, role: RoleId) {
231 let mut lock = self.servers.write().await;
232 if let Some(server) = lock.get_mut(&id) {
233 if let Some(ref mut roles_obj) = server.roles {
234 roles_obj.remove(&role);
235 }
236 }
237 }
238}
239
240cache_field! {MemberId, Member, get_member, get_member_data, members, commit_member, id}
241
242impl Cache {
243 pub async fn patch_member(
244 &self,
245 member_id: MemberId,
246 patch: impl FnOnce() -> PartialMember,
247 remove: Option<MemberField>,
248 ) {
249 let mut lock = self.members.write().await;
250 if let Some(member) = lock.get_mut(&member_id) {
251 let patch = patch();
252
253 patch.patch(member);
254 if let Some(remove) = remove {
255 remove.remove_patch(member);
256 }
257 }
258 }
259
260 pub async fn get_members_aggregate<T, F>(&self, f: F) -> T
261 where
262 F: FnOnce(MemberIter) -> T,
263 {
264 f(MemberIter(self.members.read().await.values()))
265 }
266}
267
268pub struct MemberIter<'a>(std::collections::hash_map::Values<'a, MemberId, Member>);
269
270impl<'a> Iterator for MemberIter<'a> {
271 type Item = &'a Member;
272
273 fn next(&mut self) -> Option<Self::Item> {
274 self.0.next()
275 }
276
277 fn size_hint(&self) -> (usize, Option<usize>) {
278 self.0.size_hint()
279 }
280}
281
282cache_field! {ChannelId, Channel, get_channel, get_channel_data, channels}
283
284impl Cache {
285 pub async fn commit_channel(&self, channel: &Channel) {
286 self.channels
287 .write()
288 .await
289 .insert(channel.id(), channel.clone());
290 }
291
292 pub async fn patch_channel(
293 &self,
294 channel_id: ChannelId,
295 patch: impl FnOnce() -> PartialChannel,
296 remove: Option<ChannelField>,
297 ) {
298 let mut lock = self.channels.write().await;
299 if let Some(channel) = lock.get_mut(&channel_id) {
300 let patch = patch();
301
302 patch.patch(channel);
303 if let Some(remove) = remove {
304 remove.remove_patch(channel);
305 }
306 }
307 }
308
309 pub async fn delete_channel(&self, channel_id: ChannelId) {
310 self.channels.write().await.remove(&channel_id);
311 }
312
313 pub async fn get_channels_aggregate<T, F>(&self, f: F) -> T
314 where
315 F: FnOnce(ChannelIter) -> T,
316 {
317 f(ChannelIter(self.channels.read().await.values()))
318 }
319}
320
321pub struct ChannelIter<'a>(std::collections::hash_map::Values<'a, ChannelId, Channel>);
322
323impl<'a> Iterator for ChannelIter<'a> {
324 type Item = &'a Channel;
325
326 fn next(&mut self) -> Option<Self::Item> {
327 self.0.next()
328 }
329
330 fn size_hint(&self) -> (usize, Option<usize>) {
331 self.0.size_hint()
332 }
333}
334
335impl Cache {
336 pub async fn get_message(&self, channel: ChannelId, message: MessageId) -> Option<Message> {
337 self.get_message_data(channel, message, Clone::clone).await
338 }
339
340 pub async fn get_message_data<F, T>(
341 &self,
342 channel: ChannelId,
343 message: MessageId,
344 f: F,
345 ) -> Option<T>
346 where
347 F: FnOnce(&Message) -> T,
348 {
349 self.messages
350 .read()
351 .await
352 .get(&channel)?
353 .get(&message)
354 .map(f)
355 }
356
357 pub async fn commit_message(&self, message: &Message) {
358 if self.config.messages == 0 {
359 return;
360 }
361
362 let mut queue_lock = self.message_queue.write().await;
363 let deque = queue_lock
364 .entry(message.channel)
365 .or_insert_with(VecDeque::new);
366
367 match self.messages.write().await.entry(message.channel) {
368 Entry::Occupied(mut m) => {
369 m.get_mut().insert(message.id, message.clone());
370
371 deque.push_back(message.id);
372
373 if deque.len() > self.config.messages {
374 if let Some(oldest) = deque.pop_front() {
375 m.get_mut().remove(&oldest);
376 }
377 }
378 }
379 Entry::Vacant(v) => {
380 deque.push_back(message.id);
381 v.insert(HashMap::from_iter([(message.id, message.clone())]));
382 }
383 }
384 }
385
386 pub async fn patch_message(
387 &self,
388 channel_id: ChannelId,
389 message_id: MessageId,
390 patch: impl FnOnce() -> PartialMessage,
391 ) {
392 let mut lock = self.messages.write().await;
393 if let Some(ch) = lock.get_mut(&channel_id) {
394 if let Some(message) = ch.get_mut(&message_id) {
395 let patch = patch();
396
397 patch.patch(message);
398 }
399 }
400 }
401
402 pub async fn get_messages_aggregate<T, F>(&self, channel_id: ChannelId, f: F) -> Option<T>
403 where
404 F: FnOnce(MessageIter) -> T,
405 {
406 let lock = self.messages.read().await;
407 let iter = lock.get(&channel_id)?.values();
408 Some(f(MessageIter(iter)))
409 }
410}
411
412pub struct MessageIter<'a>(std::collections::hash_map::Values<'a, MessageId, Message>);
413
414impl<'a> Iterator for MessageIter<'a> {
415 type Item = &'a Message;
416
417 fn next(&mut self) -> Option<Self::Item> {
418 self.0.next()
419 }
420
421 fn size_hint(&self) -> (usize, Option<usize>) {
422 self.0.size_hint()
423 }
424}
425
426pub trait HasCache: Send + Sync {
427 fn get_cache(&self) -> Option<&Cache>;
428}
429
430impl HasCache for Cache {
431 fn get_cache(&self) -> Option<&Cache> {
432 Some(self)
433 }
434}
435
436impl HasCache for Arc<Cache> {
437 fn get_cache(&self) -> Option<&Cache> {
438 Some(self)
439 }
440}
441
442#[async_trait]
443pub trait CommitToCache: Send + Sync {
444 async fn commit_to_cache<C: HasCache>(self, c: &C) -> Self
445 where
446 Self: Sized,
447 {
448 self.commit_to_cache_ref(c).await;
449
450 self
451 }
452
453 async fn commit_to_cache_ref<C: HasCache>(&self, c: &C) {
454 if let Some(c) = c.get_cache() {
455 Self::__commit_to_cache(self, c).await;
456 }
457 }
458
459 async fn __commit_to_cache(&self, cache: &Cache);
460}
461
462#[async_trait]
463impl CommitToCache for User {
464 async fn __commit_to_cache(&self, cache: &Cache) {
465 cache.commit_user(self).await;
466 }
467}
468
469#[async_trait]
470impl CommitToCache for Channel {
471 async fn __commit_to_cache(&self, cache: &Cache) {
472 cache.commit_channel(self).await;
473 }
474}
475
476#[async_trait]
477impl CommitToCache for Server {
478 async fn __commit_to_cache(&self, cache: &Cache) {
479 cache.commit_server(self).await;
480 }
481}
482
483#[async_trait]
484impl CommitToCache for Member {
485 async fn __commit_to_cache(&self, cache: &Cache) {
486 cache.commit_member(self).await;
487 }
488}
489
490#[async_trait]
491impl CommitToCache for Message {
492 async fn __commit_to_cache(&self, cache: &Cache) {
493 cache.commit_message(self).await;
494 }
495}
496
497#[async_trait]
498impl CommitToCache for ServerToClientEvent {
499 async fn __commit_to_cache(&self, cache: &Cache) {
500 #[allow(unused_variables)]
501 match self {
502 ServerToClientEvent::Error { .. } => {}
503 ServerToClientEvent::Authenticated => {}
504 ServerToClientEvent::Pong { .. } => {}
505 ServerToClientEvent::Ready { event } => {
506 for user in event.users.iter() {
507 user.commit_to_cache_ref(cache).await;
508 }
509 for channel in event.channels.iter() {
510 channel.commit_to_cache_ref(cache).await;
511 }
512 for server in event.servers.iter() {
513 server.commit_to_cache_ref(cache).await;
514 }
515 for member in event.members.iter() {
516 member.commit_to_cache_ref(cache).await;
517 }
518 }
519 ServerToClientEvent::Message { message } => {
520 message.commit_to_cache_ref(cache).await;
521 }
522 ServerToClientEvent::MessageUpdate { id, channel, data } => {
523 cache.patch_message(*channel, *id, || data.clone()).await;
524 }
525 ServerToClientEvent::MessageDelete { id, channel } => {}
526 ServerToClientEvent::ChannelCreate { channel } => {
527 cache.commit_channel(channel).await;
528 }
529 ServerToClientEvent::ChannelUpdate { id, data, clear } => {
530 cache.patch_channel(*id, || data.clone(), *clear).await;
531 }
532 ServerToClientEvent::ChannelDelete { id } => {
533 cache.delete_channel(*id).await;
534 }
535 ServerToClientEvent::ChannelGroupJoin { id, user } => {}
536 ServerToClientEvent::ChannelGroupLeave { id, user } => {}
537 ServerToClientEvent::ChannelStartTyping { id, user } => {}
538 ServerToClientEvent::ChannelStopTyping { id, user } => {}
539 ServerToClientEvent::ChannelAck {
540 id,
541 user,
542 message_id,
543 } => {}
544 ServerToClientEvent::ServerUpdate { id, data, clear } => {
545 cache.patch_server(*id, || data.clone(), *clear).await;
546 }
547 ServerToClientEvent::ServerDelete { id } => {
548 cache.delete_server(*id).await;
549 }
550 ServerToClientEvent::ServerMemberUpdate { id, data, clear } => {
551 cache.patch_member(*id, || data.clone(), *clear).await;
552 }
553 ServerToClientEvent::ServerMemberJoin { id, user } => {}
554 ServerToClientEvent::ServerMemberLeave { id, user } => {}
555 ServerToClientEvent::ServerRoleUpdate {
556 id,
557 role_id,
558 data,
559 clear,
560 } => {
561 cache
562 .patch_role(*id, *role_id, || data.clone(), *clear)
563 .await;
564 }
565 ServerToClientEvent::ServerRoleDelete { id, role_id } => {
566 cache.delete_role(*id, *role_id).await;
567 }
568 ServerToClientEvent::UserUpdate { id, data, clear } => {
569 cache.patch_user(*id, || data.clone(), *clear).await;
570 }
571 ServerToClientEvent::UserRelationship { id, user, status } => {}
572 }
573 }
574}