revolt_database/util/
bulk_permissions.rs

1use std::{collections::HashMap, hash::RandomState};
2
3use revolt_permissions::{
4    ChannelPermission, ChannelType, Override, OverrideField, PermissionValue, ALLOW_IN_TIMEOUT,
5    DEFAULT_PERMISSION_DIRECT_MESSAGE,
6};
7
8use crate::{Channel, Database, Member, Server, User};
9
10#[derive(Clone)]
11pub struct BulkDatabasePermissionQuery<'a> {
12    #[allow(dead_code)]
13    database: &'a Database,
14
15    server: Server,
16    channel: Option<Channel>,
17    users: Option<Vec<User>>,
18    members: Option<Vec<Member>>,
19
20    // In case the users or members are fetched as part of the permissions checking operation
21    pub(crate) cached_users: Option<Vec<User>>,
22    pub(crate) cached_members: Option<Vec<Member>>,
23
24    cached_member_perms: Option<HashMap<String, PermissionValue>>,
25}
26
27impl<'z, 'x> BulkDatabasePermissionQuery<'x> {
28    pub async fn members_can_see_channel(&'z mut self) -> HashMap<String, bool>
29    where
30        'z: 'x,
31    {
32        let member_perms = if self.cached_member_perms.is_some() {
33            // This isn't done as an if let to prevent borrow checker errors with the mut self call when the perms aren't cached.
34            let perms = self.cached_member_perms.as_ref().unwrap();
35            perms
36                .iter()
37                .map(|(m, p)| {
38                    (
39                        m.clone(),
40                        p.has_channel_permission(ChannelPermission::ViewChannel),
41                    )
42                })
43                .collect()
44        } else {
45            calculate_members_permissions(self)
46                .await
47                .iter()
48                .map(|(m, p)| {
49                    (
50                        m.clone(),
51                        p.has_channel_permission(ChannelPermission::ViewChannel),
52                    )
53                })
54                .collect()
55        };
56        member_perms
57    }
58}
59
60impl<'z> BulkDatabasePermissionQuery<'z> {
61    pub fn new(database: &Database, server: Server) -> BulkDatabasePermissionQuery<'_> {
62        BulkDatabasePermissionQuery {
63            database,
64            server,
65            channel: None,
66            users: None,
67            members: None,
68            cached_members: None,
69            cached_users: None,
70            cached_member_perms: None,
71        }
72    }
73
74    pub async fn from_server_id<'a>(
75        db: &'a Database,
76        server: &str,
77    ) -> BulkDatabasePermissionQuery<'a> {
78        BulkDatabasePermissionQuery {
79            database: db,
80            server: db.fetch_server(server).await.unwrap(),
81            channel: None,
82            users: None,
83            members: None,
84            cached_members: None,
85            cached_users: None,
86            cached_member_perms: None,
87        }
88    }
89
90    pub fn channel(self, channel: &'z Channel) -> BulkDatabasePermissionQuery<'z> {
91        BulkDatabasePermissionQuery {
92            channel: Some(channel.clone()),
93            ..self
94        }
95    }
96
97    pub async fn from_channel_id(self, channel_id: String) -> BulkDatabasePermissionQuery<'z> {
98        let channel = self
99            .database
100            .fetch_channel(channel_id.as_str())
101            .await
102            .expect("Valid channel id");
103
104        drop(channel_id);
105
106        BulkDatabasePermissionQuery {
107            channel: Some(channel),
108            ..self
109        }
110    }
111
112    pub fn members(self, members: &'z [Member]) -> BulkDatabasePermissionQuery<'z> {
113        BulkDatabasePermissionQuery {
114            members: Some(members.to_owned()),
115            cached_member_perms: None,
116            users: None,
117            cached_members: None,
118            cached_users: None,
119            ..self
120        }
121    }
122
123    pub fn users(self, users: &'z [User]) -> BulkDatabasePermissionQuery<'z> {
124        BulkDatabasePermissionQuery {
125            users: Some(users.to_owned()),
126            cached_member_perms: None,
127            members: None,
128            cached_members: None,
129            cached_users: None,
130            ..self
131        }
132    }
133
134    /// Get the default channel permissions
135    /// Group channel defaults should be mapped to an allow-only override
136    #[allow(dead_code)]
137    async fn get_default_channel_permissions(&mut self) -> Override {
138        if let Some(channel) = &self.channel {
139            match channel {
140                Channel::Group { permissions, .. } => Override {
141                    allow: permissions.unwrap_or(*DEFAULT_PERMISSION_DIRECT_MESSAGE as i64) as u64,
142                    deny: 0,
143                },
144                Channel::TextChannel {
145                    default_permissions,
146                    ..
147                } => default_permissions.unwrap_or_default().into(),
148                _ => Default::default(),
149            }
150        } else {
151            Default::default()
152        }
153    }
154
155    #[allow(dead_code, deprecated)]
156    fn get_channel_type(&mut self) -> ChannelType {
157        if let Some(channel) = &self.channel {
158            match channel {
159                Channel::DirectMessage { .. } => ChannelType::DirectMessage,
160                Channel::Group { .. } => ChannelType::Group,
161                Channel::SavedMessages { .. } => ChannelType::SavedMessages,
162                Channel::TextChannel { .. } => ChannelType::ServerChannel,
163            }
164        } else {
165            ChannelType::Unknown
166        }
167    }
168
169    /// Get the ordered role overrides (from lowest to highest) for this member in this channel
170    #[allow(dead_code)]
171    async fn get_channel_role_overrides(&mut self) -> &HashMap<String, OverrideField> {
172        if let Some(channel) = &self.channel {
173            match channel {
174                Channel::TextChannel {
175                    role_permissions, ..
176                } => role_permissions,
177                _ => panic!("Not supported for non-server channels"),
178            }
179        } else {
180            panic!("No channel added to query")
181        }
182    }
183}
184
185/// Calculate members permissions in a server channel.
186async fn calculate_members_permissions<'a>(
187    query: &'a mut BulkDatabasePermissionQuery<'a>,
188) -> HashMap<String, PermissionValue> {
189    let mut resp = HashMap::new();
190
191    let (_, channel_role_permissions, channel_default_permissions) = match query
192        .channel
193        .as_ref()
194        .expect("A channel must be assigned to calculate channel permissions")
195        .clone()
196    {
197        Channel::TextChannel {
198            id,
199            role_permissions,
200            default_permissions,
201            ..
202        } => (id, role_permissions, default_permissions),
203        _ => panic!("Calculation of member permissions must be done on a server channel"),
204    };
205
206    if query.users.is_none() {
207        let ids: Vec<String> = query
208            .members
209            .as_ref()
210            .expect("No users or members added to the query")
211            .iter()
212            .map(|m| m.id.user.clone())
213            .collect();
214
215        query.cached_users = Some(
216            query
217                .database
218                .fetch_users(&ids[..])
219                .await
220                .expect("Failed to get data from the db"),
221        );
222
223        query.users = Some(query.cached_users.as_ref().unwrap().to_vec())
224    }
225
226    let users = query.users.as_ref().unwrap();
227
228    if query.members.is_none() {
229        let ids: Vec<String> = query
230            .users
231            .as_ref()
232            .expect("No users or members added to the query")
233            .iter()
234            .map(|m| m.id.clone())
235            .collect();
236
237        query.cached_members = Some(
238            query
239                .database
240                .fetch_members(&query.server.id, &ids[..])
241                .await
242                .expect("Failed to get data from the db"),
243        );
244        query.members = Some(query.cached_members.as_ref().unwrap().to_vec())
245    }
246
247    let members: HashMap<&String, &Member, RandomState> = HashMap::from_iter(
248        query
249            .members
250            .as_ref()
251            .unwrap()
252            .iter()
253            .map(|m| (&m.id.user, m)),
254    );
255
256    for user in users {
257        let member = members.get(&user.id);
258
259        // User isn't a part of the server
260        if member.is_none() {
261            resp.insert(user.id.clone(), 0_u64.into());
262            continue;
263        }
264
265        let member = *member.unwrap();
266
267        if user.privileged {
268            resp.insert(
269                user.id.clone(),
270                PermissionValue::from(ChannelPermission::GrantAllSafe),
271            );
272            continue;
273        }
274
275        if user.id == query.server.owner {
276            resp.insert(
277                user.id.clone(),
278                PermissionValue::from(ChannelPermission::GrantAllSafe),
279            );
280            continue;
281        }
282
283        // Get the user's server permissions
284        let mut permission = calculate_server_permissions(&query.server, user, member);
285
286        if let Some(defaults) = channel_default_permissions {
287            permission.apply(defaults.into());
288        }
289
290        // Get the applicable role overrides
291        let mut roles = channel_role_permissions
292            .iter()
293            .filter(|(id, _)| member.roles.contains(id))
294            .filter_map(|(id, permission)| {
295                query.server.roles.get(id).map(|role| {
296                    let v: Override = (*permission).into();
297                    (role.rank, v)
298                })
299            })
300            .collect::<Vec<(i64, Override)>>();
301
302        roles.sort_by(|a, b| b.0.cmp(&a.0));
303        let overrides = roles.into_iter().map(|(_, v)| v);
304
305        for role_override in overrides {
306            permission.apply(role_override)
307        }
308
309        resp.insert(user.id.clone(), permission);
310    }
311
312    resp
313}
314
315/// Calculates a member's server permissions
316fn calculate_server_permissions(server: &Server, user: &User, member: &Member) -> PermissionValue {
317    if user.privileged || server.owner == user.id {
318        return ChannelPermission::GrantAllSafe.into();
319    }
320
321    let mut permissions: PermissionValue = server.default_permissions.into();
322
323    let mut roles = server
324        .roles
325        .iter()
326        .filter(|(id, _)| member.roles.contains(id))
327        .map(|(_, role)| {
328            let v: Override = role.permissions.into();
329            (role.rank, v)
330        })
331        .collect::<Vec<(i64, Override)>>();
332
333    roles.sort_by(|a, b| b.0.cmp(&a.0));
334    let role_overrides: Vec<Override> = roles.into_iter().map(|(_, v)| v).collect();
335
336    for role in role_overrides {
337        permissions.apply(role);
338    }
339
340    if member.in_timeout() {
341        permissions.restrict(*ALLOW_IN_TIMEOUT);
342    }
343
344    permissions
345}