// revolt_database/util/permissions.rs

1use std::borrow::Cow;
2
3use revolt_permissions::{
4    calculate_user_permissions, ChannelType, Override, PermissionQuery, PermissionValue,
5    RelationshipStatus, DEFAULT_PERMISSION_DIRECT_MESSAGE,
6};
7
8use crate::{Channel, Database, Member, Server, User};
9
10/// Permissions calculator
11#[derive(Clone)]
12pub struct DatabasePermissionQuery<'a> {
13    #[allow(dead_code)]
14    database: &'a Database,
15
16    perspective: &'a User,
17    user: Option<Cow<'a, User>>,
18    channel: Option<Cow<'a, Channel>>,
19    server: Option<Cow<'a, Server>>,
20    member: Option<Cow<'a, Member>>,
21
22    // flag_known_relationship: Option<&'a RelationshipStatus>,
23    cached_user_permission: Option<PermissionValue>,
24    cached_mutual_connection: Option<bool>,
25    cached_permission: Option<u64>,
26}
27
28#[async_trait]
29impl PermissionQuery for DatabasePermissionQuery<'_> {
30    // * For calculating user permission
31
32    /// Is our perspective user privileged?
33    async fn are_we_privileged(&mut self) -> bool {
34        self.perspective.privileged
35    }
36
37    /// Is our perspective user a bot?
38    async fn are_we_a_bot(&mut self) -> bool {
39        self.perspective.bot.is_some()
40    }
41
42    /// Is our perspective user and the currently selected user the same?
43    async fn are_the_users_same(&mut self) -> bool {
44        if let Some(other_user) = &self.user {
45            self.perspective.id == other_user.id
46        } else {
47            false
48        }
49    }
50
51    /// Get the relationship with have with the currently selected user
52    async fn user_relationship(&mut self) -> RelationshipStatus {
53        if let Some(other_user) = &self.user {
54            if self.perspective.id == other_user.id {
55                return RelationshipStatus::User;
56            } else if let Some(bot) = &other_user.bot {
57                // For the purposes of permissions checks,
58                // assume owner is the same as bot
59                if self.perspective.id == bot.owner {
60                    return RelationshipStatus::User;
61                }
62            }
63
64            if let Some(relations) = &self.perspective.relations {
65                for entry in relations {
66                    if entry.id == other_user.id {
67                        return match entry.status {
68                            crate::RelationshipStatus::None => RelationshipStatus::None,
69                            crate::RelationshipStatus::User => RelationshipStatus::User,
70                            crate::RelationshipStatus::Friend => RelationshipStatus::Friend,
71                            crate::RelationshipStatus::Outgoing => RelationshipStatus::Outgoing,
72                            crate::RelationshipStatus::Incoming => RelationshipStatus::Incoming,
73                            crate::RelationshipStatus::Blocked => RelationshipStatus::Blocked,
74                            crate::RelationshipStatus::BlockedOther => {
75                                RelationshipStatus::BlockedOther
76                            }
77                        };
78                    }
79                }
80            }
81        }
82
83        RelationshipStatus::None
84    }
85
86    /// Whether the currently selected user is a bot
87    async fn user_is_bot(&mut self) -> bool {
88        if let Some(other_user) = &self.user {
89            other_user.bot.is_some()
90        } else {
91            false
92        }
93    }
94
95    /// Do we have a mutual connection with the currently selected user?
96    async fn have_mutual_connection(&mut self) -> bool {
97        if let Some(value) = self.cached_mutual_connection {
98            value
99        } else if let Some(user) = &self.user {
100            let value = self
101                .perspective
102                .has_mutual_connection(self.database, &user.id)
103                .await
104                .unwrap_or_default();
105
106            self.cached_mutual_connection = Some(value);
107            value
108        } else {
109            false
110        }
111    }
112
113    // * For calculating server permission
114
115    /// Is our perspective user the server's owner?
116    async fn are_we_server_owner(&mut self) -> bool {
117        if let Some(server) = &self.server {
118            server.owner == self.perspective.id
119        } else {
120            false
121        }
122    }
123
124    /// Is our perspective user a member of the server?
125    async fn are_we_a_member(&mut self) -> bool {
126        if let Some(server) = &self.server {
127            if self.member.is_some() {
128                true
129            } else if let Ok(member) = self
130                .database
131                .fetch_member(&server.id, &self.perspective.id)
132                .await
133            {
134                self.member = Some(Cow::Owned(member));
135                true
136            } else {
137                false
138            }
139        } else {
140            false
141        }
142    }
143
144    /// Get default server permission
145    async fn get_default_server_permissions(&mut self) -> u64 {
146        if let Some(server) = &self.server {
147            server.default_permissions as u64
148        } else {
149            0
150        }
151    }
152
153    /// Get the ordered role overrides (from lowest to highest) for this member in this server
154    async fn get_our_server_role_overrides(&mut self) -> Vec<Override> {
155        if let Some(server) = &self.server {
156            let member_roles = self
157                .member
158                .as_ref()
159                .map(|member| member.roles.clone())
160                .unwrap_or_default();
161
162            let mut roles = server
163                .roles
164                .iter()
165                .filter(|(id, _)| member_roles.contains(id))
166                .map(|(_, role)| {
167                    let v: Override = role.permissions.into();
168                    (role.rank, v)
169                })
170                .collect::<Vec<(i64, Override)>>();
171
172            roles.sort_by(|a, b| b.0.cmp(&a.0));
173            roles.into_iter().map(|(_, v)| v).collect()
174        } else {
175            vec![]
176        }
177    }
178
179    /// Is our perspective user timed out on this server?
180    async fn are_we_timed_out(&mut self) -> bool {
181        if let Some(member) = &self.member {
182            member.in_timeout()
183        } else {
184            false
185        }
186    }
187
188    // * For calculating channel permission
189
190    /// Get the type of the channel
191    async fn get_channel_type(&mut self) -> ChannelType {
192        if let Some(channel) = &self.channel {
193            match channel {
194                Cow::Borrowed(Channel::DirectMessage { .. })
195                | Cow::Owned(Channel::DirectMessage { .. }) => ChannelType::DirectMessage,
196                Cow::Borrowed(Channel::Group { .. }) | Cow::Owned(Channel::Group { .. }) => {
197                    ChannelType::Group
198                }
199                Cow::Borrowed(Channel::SavedMessages { .. })
200                | Cow::Owned(Channel::SavedMessages { .. }) => ChannelType::SavedMessages,
201                Cow::Borrowed(Channel::TextChannel { .. })
202                | Cow::Owned(Channel::TextChannel { .. })
203                | Cow::Borrowed(Channel::VoiceChannel { .. })
204                | Cow::Owned(Channel::VoiceChannel { .. }) => ChannelType::ServerChannel,
205            }
206        } else {
207            ChannelType::Unknown
208        }
209    }
210
211    /// Get the default channel permissions
212    /// Group channel defaults should be mapped to an allow-only override
213    async fn get_default_channel_permissions(&mut self) -> Override {
214        if let Some(channel) = &self.channel {
215            match channel {
216                Cow::Borrowed(Channel::Group { permissions, .. })
217                | Cow::Owned(Channel::Group { permissions, .. }) => Override {
218                    allow: permissions.unwrap_or(*DEFAULT_PERMISSION_DIRECT_MESSAGE as i64) as u64,
219                    deny: 0,
220                },
221                Cow::Borrowed(Channel::TextChannel {
222                    default_permissions,
223                    ..
224                })
225                | Cow::Owned(Channel::TextChannel {
226                    default_permissions,
227                    ..
228                })
229                | Cow::Borrowed(Channel::VoiceChannel {
230                    default_permissions,
231                    ..
232                })
233                | Cow::Owned(Channel::VoiceChannel {
234                    default_permissions,
235                    ..
236                }) => default_permissions.unwrap_or_default().into(),
237                _ => Default::default(),
238            }
239        } else {
240            Default::default()
241        }
242    }
243
244    /// Get the ordered role overrides (from lowest to highest) for this member in this channel
245    async fn get_our_channel_role_overrides(&mut self) -> Vec<Override> {
246        if let Some(channel) = &self.channel {
247            match channel {
248                Cow::Borrowed(Channel::TextChannel {
249                    role_permissions, ..
250                })
251                | Cow::Owned(Channel::TextChannel {
252                    role_permissions, ..
253                })
254                | Cow::Borrowed(Channel::VoiceChannel {
255                    role_permissions, ..
256                })
257                | Cow::Owned(Channel::VoiceChannel {
258                    role_permissions, ..
259                }) => {
260                    if let Some(server) = &self.server {
261                        let member_roles = self
262                            .member
263                            .as_ref()
264                            .map(|member| member.roles.clone())
265                            .unwrap_or_default();
266
267                        let mut roles = role_permissions
268                            .iter()
269                            .filter(|(id, _)| member_roles.contains(id))
270                            .filter_map(|(id, permission)| {
271                                server.roles.get(id).map(|role| {
272                                    let v: Override = (*permission).into();
273                                    (role.rank, v)
274                                })
275                            })
276                            .collect::<Vec<(i64, Override)>>();
277
278                        roles.sort_by(|a, b| b.0.cmp(&a.0));
279                        roles.into_iter().map(|(_, v)| v).collect()
280                    } else {
281                        vec![]
282                    }
283                }
284                _ => vec![],
285            }
286        } else {
287            vec![]
288        }
289    }
290
291    /// Do we own this group or saved messages channel if it is one of those?
292    async fn do_we_own_the_channel(&mut self) -> bool {
293        if let Some(channel) = &self.channel {
294            match channel {
295                Cow::Borrowed(Channel::Group { owner, .. })
296                | Cow::Owned(Channel::Group { owner, .. }) => owner == &self.perspective.id,
297                Cow::Borrowed(Channel::SavedMessages { user, .. })
298                | Cow::Owned(Channel::SavedMessages { user, .. }) => user == &self.perspective.id,
299                _ => false,
300            }
301        } else {
302            false
303        }
304    }
305
306    /// Are we a recipient of this channel?
307    async fn are_we_part_of_the_channel(&mut self) -> bool {
308        if let Some(
309            Cow::Borrowed(Channel::DirectMessage { recipients, .. })
310            | Cow::Owned(Channel::DirectMessage { recipients, .. })
311            | Cow::Borrowed(Channel::Group { recipients, .. })
312            | Cow::Owned(Channel::Group { recipients, .. }),
313        ) = &self.channel
314        {
315            recipients.contains(&self.perspective.id)
316        } else {
317            false
318        }
319    }
320
321    /// Set the current user as the recipient of this channel
322    /// (this will only ever be called for DirectMessage channels, use unimplemented!() for other code paths)
323    async fn set_recipient_as_user(&mut self) {
324        if let Some(channel) = &self.channel {
325            match channel {
326                Cow::Borrowed(Channel::DirectMessage { recipients, .. })
327                | Cow::Owned(Channel::DirectMessage { recipients, .. }) => {
328                    let recipient_id = recipients
329                        .iter()
330                        .find(|recipient| recipient != &&self.perspective.id)
331                        .expect("Missing recipient for DM");
332
333                    if let Ok(user) = self.database.fetch_user(recipient_id).await {
334                        self.user.replace(Cow::Owned(user));
335                    }
336                }
337                _ => unimplemented!(),
338            }
339        }
340    }
341
342    /// Set the current server as the server owning this channel
343    /// (this will only ever be called for server channels, use unimplemented!() for other code paths)
344    async fn set_server_from_channel(&mut self) {
345        if let Some(channel) = &self.channel {
346            match channel {
347                Cow::Borrowed(Channel::TextChannel { server, .. })
348                | Cow::Owned(Channel::TextChannel { server, .. })
349                | Cow::Borrowed(Channel::VoiceChannel { server, .. })
350                | Cow::Owned(Channel::VoiceChannel { server, .. }) => {
351                    if let Some(known_server) =
352                        // I'm not sure why I can't just pattern match both at once here?
353                        // It throws some weird error and the provided fix doesn't work :/
354                        if let Some(Cow::Borrowed(known_server)) = self.server {
355                                Some(known_server)
356                            } else if let Some(Cow::Owned(ref known_server)) = self.server {
357                                Some(known_server)
358                            } else {
359                                None
360                            }
361                    {
362                        if server == &known_server.id {
363                            // Already cached, return early.
364                            return;
365                        }
366                    }
367
368                    if let Ok(server) = self.database.fetch_server(server).await {
369                        self.server.replace(Cow::Owned(server));
370                    }
371                }
372                _ => unimplemented!(),
373            }
374        }
375    }
376}
377
378impl<'a> DatabasePermissionQuery<'a> {
379    /// Create a new permission calculator
380    pub fn new(database: &'a Database, perspective: &'a User) -> DatabasePermissionQuery<'a> {
381        DatabasePermissionQuery {
382            database,
383            perspective,
384            user: None,
385            channel: None,
386            server: None,
387            member: None,
388
389            cached_mutual_connection: None,
390            cached_user_permission: None,
391            cached_permission: None,
392        }
393    }
394
395    /// Calculate the user permission value
396    pub async fn calc_user(mut self) -> DatabasePermissionQuery<'a> {
397        if self.cached_user_permission.is_some() {
398            return self;
399        }
400
401        if self.user.is_none() {
402            panic!("Expected `PermissionCalculator.user to exist.");
403        }
404
405        DatabasePermissionQuery {
406            cached_user_permission: Some(calculate_user_permissions(&mut self).await),
407            ..self
408        }
409    }
410
411    /// Calculate the permission value
412    pub async fn calc(self) -> DatabasePermissionQuery<'a> {
413        if self.cached_permission.is_some() {
414            return self;
415        }
416
417        self
418    }
419
420    /// Use user
421    pub fn user(self, user: &'a User) -> DatabasePermissionQuery<'a> {
422        DatabasePermissionQuery {
423            user: Some(Cow::Borrowed(user)),
424            ..self
425        }
426    }
427
428    /// Use channel
429    pub fn channel(self, channel: &'a Channel) -> DatabasePermissionQuery<'a> {
430        DatabasePermissionQuery {
431            channel: Some(Cow::Borrowed(channel)),
432            ..self
433        }
434    }
435
436    /// Use server
437    pub fn server(self, server: &'a Server) -> DatabasePermissionQuery<'a> {
438        DatabasePermissionQuery {
439            server: Some(Cow::Borrowed(server)),
440            ..self
441        }
442    }
443
444    /// Use member
445    pub fn member(self, member: &'a Member) -> DatabasePermissionQuery<'a> {
446        DatabasePermissionQuery {
447            member: Some(Cow::Borrowed(member)),
448            ..self
449        }
450    }
451
452    /// Access the underlying user
453    pub fn user_ref(&self) -> &Option<Cow<User>> {
454        &self.user
455    }
456
457    /// Access the underlying server
458    pub fn channel_ref(&self) -> &Option<Cow<Channel>> {
459        &self.channel
460    }
461
462    /// Access the underlying server
463    pub fn server_ref(&self) -> &Option<Cow<Server>> {
464        &self.server
465    }
466
467    /// Access the underlying member
468    pub fn member_ref(&self) -> &Option<Cow<Member>> {
469        &self.member
470    }
471
472    /// Get the known member's current ranking
473    pub fn get_member_rank(&self) -> Option<i64> {
474        self.member
475            .as_ref()
476            .map(|member| member.get_ranking(self.server.as_ref().unwrap()))
477    }
478}
479
480/// Short-hand for creating a permission calculator
481pub fn perms<'a>(database: &'a Database, perspective: &'a User) -> DatabasePermissionQuery<'a> {
482    DatabasePermissionQuery::new(database, perspective)
483}