revolt_database/util/bulk_permissions.rs

use std::{collections::HashMap, hash::RandomState};

use revolt_permissions::{
    ChannelPermission, ChannelType, Override, OverrideField, PermissionValue, ALLOW_IN_TIMEOUT,
    DEFAULT_PERMISSION_DIRECT_MESSAGE,
};

use crate::{Channel, Database, Member, Server, User};
9
10#[derive(Clone)]
11pub struct BulkDatabasePermissionQuery<'a> {
12    #[allow(dead_code)]
13    database: &'a Database,
14
15    server: Server,
16    channel: Option<Channel>,
17    users: Option<Vec<User>>,
18    members: Option<Vec<Member>>,
19
20    // In case the users or members are fetched as part of the permissions checking operation
21    pub(crate) cached_users: Option<Vec<User>>,
22    pub(crate) cached_members: Option<Vec<Member>>,
23
24    cached_member_perms: Option<HashMap<String, PermissionValue>>,
25}
26
27impl<'z, 'x> BulkDatabasePermissionQuery<'x> {
28    pub async fn members_can_see_channel(&'z mut self) -> HashMap<String, bool>
29    where
30        'z: 'x,
31    {
32        let member_perms = if self.cached_member_perms.is_some() {
33            // This isn't done as an if let to prevent borrow checker errors with the mut self call when the perms aren't cached.
34            let perms = self.cached_member_perms.as_ref().unwrap();
35            perms
36                .iter()
37                .map(|(m, p)| {
38                    (
39                        m.clone(),
40                        p.has_channel_permission(ChannelPermission::ViewChannel),
41                    )
42                })
43                .collect()
44        } else {
45            calculate_members_permissions(self)
46                .await
47                .iter()
48                .map(|(m, p)| {
49                    (
50                        m.clone(),
51                        p.has_channel_permission(ChannelPermission::ViewChannel),
52                    )
53                })
54                .collect()
55        };
56        member_perms
57    }
58}
59
60impl<'z> BulkDatabasePermissionQuery<'z> {
61    pub fn new(database: &Database, server: Server) -> BulkDatabasePermissionQuery<'_> {
62        BulkDatabasePermissionQuery {
63            database,
64            server,
65            channel: None,
66            users: None,
67            members: None,
68            cached_members: None,
69            cached_users: None,
70            cached_member_perms: None,
71        }
72    }
73
74    pub async fn from_server_id<'a>(
75        db: &'a Database,
76        server: &str,
77    ) -> BulkDatabasePermissionQuery<'a> {
78        BulkDatabasePermissionQuery {
79            database: db,
80            server: db.fetch_server(server).await.unwrap(),
81            channel: None,
82            users: None,
83            members: None,
84            cached_members: None,
85            cached_users: None,
86            cached_member_perms: None,
87        }
88    }
89
90    pub fn channel(self, channel: &'z Channel) -> BulkDatabasePermissionQuery<'z> {
91        BulkDatabasePermissionQuery {
92            channel: Some(channel.clone()),
93            ..self
94        }
95    }
96
97    pub async fn from_channel_id(self, channel_id: String) -> BulkDatabasePermissionQuery<'z> {
98        let channel = self
99            .database
100            .fetch_channel(channel_id.as_str())
101            .await
102            .expect("Valid channel id");
103
104        drop(channel_id);
105
106        BulkDatabasePermissionQuery {
107            channel: Some(channel),
108            ..self
109        }
110    }
111
112    pub fn members(self, members: &'z [Member]) -> BulkDatabasePermissionQuery<'z> {
113        BulkDatabasePermissionQuery {
114            members: Some(members.to_owned()),
115            cached_member_perms: None,
116            users: None,
117            cached_members: None,
118            cached_users: None,
119            ..self
120        }
121    }
122
123    pub fn users(self, users: &'z [User]) -> BulkDatabasePermissionQuery<'z> {
124        BulkDatabasePermissionQuery {
125            users: Some(users.to_owned()),
126            cached_member_perms: None,
127            members: None,
128            cached_members: None,
129            cached_users: None,
130            ..self
131        }
132    }
133
134    /// Get the default channel permissions
135    /// Group channel defaults should be mapped to an allow-only override
136    #[allow(dead_code)]
137    async fn get_default_channel_permissions(&mut self) -> Override {
138        if let Some(channel) = &self.channel {
139            match channel {
140                Channel::Group { permissions, .. } => Override {
141                    allow: permissions.unwrap_or(*DEFAULT_PERMISSION_DIRECT_MESSAGE as i64) as u64,
142                    deny: 0,
143                },
144                Channel::TextChannel {
145                    default_permissions,
146                    ..
147                }
148                | Channel::VoiceChannel {
149                    default_permissions,
150                    ..
151                } => default_permissions.unwrap_or_default().into(),
152                _ => Default::default(),
153            }
154        } else {
155            Default::default()
156        }
157    }
158
159    #[allow(dead_code)]
160    fn get_channel_type(&mut self) -> ChannelType {
161        if let Some(channel) = &self.channel {
162            match channel {
163                Channel::DirectMessage { .. } => ChannelType::DirectMessage,
164                Channel::Group { .. } => ChannelType::Group,
165                Channel::SavedMessages { .. } => ChannelType::SavedMessages,
166                Channel::TextChannel { .. } | Channel::VoiceChannel { .. } => {
167                    ChannelType::ServerChannel
168                }
169            }
170        } else {
171            ChannelType::Unknown
172        }
173    }
174
175    /// Get the ordered role overrides (from lowest to highest) for this member in this channel
176    #[allow(dead_code)]
177    async fn get_channel_role_overrides(&mut self) -> &HashMap<String, OverrideField> {
178        if let Some(channel) = &self.channel {
179            match channel {
180                Channel::TextChannel {
181                    role_permissions, ..
182                }
183                | Channel::VoiceChannel {
184                    role_permissions, ..
185                } => role_permissions,
186                _ => panic!("Not supported for non-server channels"),
187            }
188        } else {
189            panic!("No channel added to query")
190        }
191    }
192}
193
194/// Calculate members permissions in a server channel.
195async fn calculate_members_permissions<'a>(
196    query: &'a mut BulkDatabasePermissionQuery<'a>,
197) -> HashMap<String, PermissionValue> {
198    let mut resp = HashMap::new();
199
200    let (_, channel_role_permissions, channel_default_permissions) = match query
201        .channel
202        .as_ref()
203        .expect("A channel must be assigned to calculate channel permissions")
204        .clone()
205    {
206        Channel::TextChannel {
207            id,
208            role_permissions,
209            default_permissions,
210            ..
211        }
212        | Channel::VoiceChannel {
213            id,
214            role_permissions,
215            default_permissions,
216            ..
217        } => (id, role_permissions, default_permissions),
218        _ => panic!("Calculation of member permissions must be done on a server channel"),
219    };
220
221    if query.users.is_none() {
222        let ids: Vec<String> = query
223            .members
224            .as_ref()
225            .expect("No users or members added to the query")
226            .iter()
227            .map(|m| m.id.user.clone())
228            .collect();
229
230        query.cached_users = Some(
231            query
232                .database
233                .fetch_users(&ids[..])
234                .await
235                .expect("Failed to get data from the db"),
236        );
237
238        query.users = Some(query.cached_users.as_ref().unwrap().to_vec())
239    }
240
241    let users = query.users.as_ref().unwrap();
242
243    if query.members.is_none() {
244        let ids: Vec<String> = query
245            .users
246            .as_ref()
247            .expect("No users or members added to the query")
248            .iter()
249            .map(|m| m.id.clone())
250            .collect();
251
252        query.cached_members = Some(
253            query
254                .database
255                .fetch_members(&query.server.id, &ids[..])
256                .await
257                .expect("Failed to get data from the db"),
258        );
259        query.members = Some(query.cached_members.as_ref().unwrap().to_vec())
260    }
261
262    let members: HashMap<&String, &Member, RandomState> = HashMap::from_iter(
263        query
264            .members
265            .as_ref()
266            .unwrap()
267            .iter()
268            .map(|m| (&m.id.user, m)),
269    );
270
271    for user in users {
272        let member = members.get(&user.id);
273
274        // User isn't a part of the server
275        if member.is_none() {
276            resp.insert(user.id.clone(), 0_u64.into());
277            continue;
278        }
279
280        let member = *member.unwrap();
281
282        if user.privileged {
283            resp.insert(
284                user.id.clone(),
285                PermissionValue::from(ChannelPermission::GrantAllSafe),
286            );
287            continue;
288        }
289
290        if user.id == query.server.owner {
291            resp.insert(
292                user.id.clone(),
293                PermissionValue::from(ChannelPermission::GrantAllSafe),
294            );
295            continue;
296        }
297
298        // Get the user's server permissions
299        let mut permission = calculate_server_permissions(&query.server, user, member);
300
301        if let Some(defaults) = channel_default_permissions {
302            permission.apply(defaults.into());
303        }
304
305        // Get the applicable role overrides
306        let mut roles = channel_role_permissions
307            .iter()
308            .filter(|(id, _)| member.roles.contains(id))
309            .filter_map(|(id, permission)| {
310                query.server.roles.get(id).map(|role| {
311                    let v: Override = (*permission).into();
312                    (role.rank, v)
313                })
314            })
315            .collect::<Vec<(i64, Override)>>();
316
317        roles.sort_by(|a, b| b.0.cmp(&a.0));
318        let overrides = roles.into_iter().map(|(_, v)| v);
319
320        for role_override in overrides {
321            permission.apply(role_override)
322        }
323
324        resp.insert(user.id.clone(), permission);
325    }
326
327    resp
328}
329
330/// Calculates a member's server permissions
331fn calculate_server_permissions(server: &Server, user: &User, member: &Member) -> PermissionValue {
332    if user.privileged || server.owner == user.id {
333        return ChannelPermission::GrantAllSafe.into();
334    }
335
336    let mut permissions: PermissionValue = server.default_permissions.into();
337
338    let mut roles = server
339        .roles
340        .iter()
341        .filter(|(id, _)| member.roles.contains(id))
342        .map(|(_, role)| {
343            let v: Override = role.permissions.into();
344            (role.rank, v)
345        })
346        .collect::<Vec<(i64, Override)>>();
347
348    roles.sort_by(|a, b| b.0.cmp(&a.0));
349    let role_overrides: Vec<Override> = roles.into_iter().map(|(_, v)| v).collect();
350
351    for role in role_overrides {
352        permissions.apply(role);
353    }
354
355    if member.in_timeout() {
356        permissions.restrict(*ALLOW_IN_TIMEOUT);
357    }
358
359    permissions
360}