1use std::{collections::HashMap, hash::RandomState};
2
3use revolt_permissions::{
4 ChannelPermission, ChannelType, Override, OverrideField, PermissionValue, ALLOW_IN_TIMEOUT,
5 DEFAULT_PERMISSION_DIRECT_MESSAGE,
6};
7
8use crate::{Channel, Database, Member, Server, User};
9
10#[derive(Clone)]
11pub struct BulkDatabasePermissionQuery<'a> {
12 #[allow(dead_code)]
13 database: &'a Database,
14
15 server: Server,
16 channel: Option<Channel>,
17 users: Option<Vec<User>>,
18 members: Option<Vec<Member>>,
19
20 pub(crate) cached_users: Option<Vec<User>>,
22 pub(crate) cached_members: Option<Vec<Member>>,
23
24 cached_member_perms: Option<HashMap<String, PermissionValue>>,
25}
26
27impl<'z, 'x> BulkDatabasePermissionQuery<'x> {
28 pub async fn members_can_see_channel(&'z mut self) -> HashMap<String, bool>
29 where
30 'z: 'x,
31 {
32 let member_perms = if self.cached_member_perms.is_some() {
33 let perms = self.cached_member_perms.as_ref().unwrap();
35 perms
36 .iter()
37 .map(|(m, p)| {
38 (
39 m.clone(),
40 p.has_channel_permission(ChannelPermission::ViewChannel),
41 )
42 })
43 .collect()
44 } else {
45 calculate_members_permissions(self)
46 .await
47 .iter()
48 .map(|(m, p)| {
49 (
50 m.clone(),
51 p.has_channel_permission(ChannelPermission::ViewChannel),
52 )
53 })
54 .collect()
55 };
56 member_perms
57 }
58}
59
60impl<'z> BulkDatabasePermissionQuery<'z> {
61 pub fn new(database: &Database, server: Server) -> BulkDatabasePermissionQuery<'_> {
62 BulkDatabasePermissionQuery {
63 database,
64 server,
65 channel: None,
66 users: None,
67 members: None,
68 cached_members: None,
69 cached_users: None,
70 cached_member_perms: None,
71 }
72 }
73
74 pub async fn from_server_id<'a>(
75 db: &'a Database,
76 server: &str,
77 ) -> BulkDatabasePermissionQuery<'a> {
78 BulkDatabasePermissionQuery {
79 database: db,
80 server: db.fetch_server(server).await.unwrap(),
81 channel: None,
82 users: None,
83 members: None,
84 cached_members: None,
85 cached_users: None,
86 cached_member_perms: None,
87 }
88 }
89
90 pub fn channel(self, channel: &'z Channel) -> BulkDatabasePermissionQuery<'z> {
91 BulkDatabasePermissionQuery {
92 channel: Some(channel.clone()),
93 ..self
94 }
95 }
96
97 pub async fn from_channel_id(self, channel_id: String) -> BulkDatabasePermissionQuery<'z> {
98 let channel = self
99 .database
100 .fetch_channel(channel_id.as_str())
101 .await
102 .expect("Valid channel id");
103
104 drop(channel_id);
105
106 BulkDatabasePermissionQuery {
107 channel: Some(channel),
108 ..self
109 }
110 }
111
112 pub fn members(self, members: &'z [Member]) -> BulkDatabasePermissionQuery<'z> {
113 BulkDatabasePermissionQuery {
114 members: Some(members.to_owned()),
115 cached_member_perms: None,
116 users: None,
117 cached_members: None,
118 cached_users: None,
119 ..self
120 }
121 }
122
123 pub fn users(self, users: &'z [User]) -> BulkDatabasePermissionQuery<'z> {
124 BulkDatabasePermissionQuery {
125 users: Some(users.to_owned()),
126 cached_member_perms: None,
127 members: None,
128 cached_members: None,
129 cached_users: None,
130 ..self
131 }
132 }
133
134 #[allow(dead_code)]
137 async fn get_default_channel_permissions(&mut self) -> Override {
138 if let Some(channel) = &self.channel {
139 match channel {
140 Channel::Group { permissions, .. } => Override {
141 allow: permissions.unwrap_or(*DEFAULT_PERMISSION_DIRECT_MESSAGE as i64) as u64,
142 deny: 0,
143 },
144 Channel::TextChannel {
145 default_permissions,
146 ..
147 }
148 | Channel::VoiceChannel {
149 default_permissions,
150 ..
151 } => default_permissions.unwrap_or_default().into(),
152 _ => Default::default(),
153 }
154 } else {
155 Default::default()
156 }
157 }
158
159 #[allow(dead_code)]
160 fn get_channel_type(&mut self) -> ChannelType {
161 if let Some(channel) = &self.channel {
162 match channel {
163 Channel::DirectMessage { .. } => ChannelType::DirectMessage,
164 Channel::Group { .. } => ChannelType::Group,
165 Channel::SavedMessages { .. } => ChannelType::SavedMessages,
166 Channel::TextChannel { .. } | Channel::VoiceChannel { .. } => {
167 ChannelType::ServerChannel
168 }
169 }
170 } else {
171 ChannelType::Unknown
172 }
173 }
174
175 #[allow(dead_code)]
177 async fn get_channel_role_overrides(&mut self) -> &HashMap<String, OverrideField> {
178 if let Some(channel) = &self.channel {
179 match channel {
180 Channel::TextChannel {
181 role_permissions, ..
182 }
183 | Channel::VoiceChannel {
184 role_permissions, ..
185 } => role_permissions,
186 _ => panic!("Not supported for non-server channels"),
187 }
188 } else {
189 panic!("No channel added to query")
190 }
191 }
192}
193
194async fn calculate_members_permissions<'a>(
196 query: &'a mut BulkDatabasePermissionQuery<'a>,
197) -> HashMap<String, PermissionValue> {
198 let mut resp = HashMap::new();
199
200 let (_, channel_role_permissions, channel_default_permissions) = match query
201 .channel
202 .as_ref()
203 .expect("A channel must be assigned to calculate channel permissions")
204 .clone()
205 {
206 Channel::TextChannel {
207 id,
208 role_permissions,
209 default_permissions,
210 ..
211 }
212 | Channel::VoiceChannel {
213 id,
214 role_permissions,
215 default_permissions,
216 ..
217 } => (id, role_permissions, default_permissions),
218 _ => panic!("Calculation of member permissions must be done on a server channel"),
219 };
220
221 if query.users.is_none() {
222 let ids: Vec<String> = query
223 .members
224 .as_ref()
225 .expect("No users or members added to the query")
226 .iter()
227 .map(|m| m.id.user.clone())
228 .collect();
229
230 query.cached_users = Some(
231 query
232 .database
233 .fetch_users(&ids[..])
234 .await
235 .expect("Failed to get data from the db"),
236 );
237
238 query.users = Some(query.cached_users.as_ref().unwrap().to_vec())
239 }
240
241 let users = query.users.as_ref().unwrap();
242
243 if query.members.is_none() {
244 let ids: Vec<String> = query
245 .users
246 .as_ref()
247 .expect("No users or members added to the query")
248 .iter()
249 .map(|m| m.id.clone())
250 .collect();
251
252 query.cached_members = Some(
253 query
254 .database
255 .fetch_members(&query.server.id, &ids[..])
256 .await
257 .expect("Failed to get data from the db"),
258 );
259 query.members = Some(query.cached_members.as_ref().unwrap().to_vec())
260 }
261
262 let members: HashMap<&String, &Member, RandomState> = HashMap::from_iter(
263 query
264 .members
265 .as_ref()
266 .unwrap()
267 .iter()
268 .map(|m| (&m.id.user, m)),
269 );
270
271 for user in users {
272 let member = members.get(&user.id);
273
274 if member.is_none() {
276 resp.insert(user.id.clone(), 0_u64.into());
277 continue;
278 }
279
280 let member = *member.unwrap();
281
282 if user.privileged {
283 resp.insert(
284 user.id.clone(),
285 PermissionValue::from(ChannelPermission::GrantAllSafe),
286 );
287 continue;
288 }
289
290 if user.id == query.server.owner {
291 resp.insert(
292 user.id.clone(),
293 PermissionValue::from(ChannelPermission::GrantAllSafe),
294 );
295 continue;
296 }
297
298 let mut permission = calculate_server_permissions(&query.server, user, member);
300
301 if let Some(defaults) = channel_default_permissions {
302 permission.apply(defaults.into());
303 }
304
305 let mut roles = channel_role_permissions
307 .iter()
308 .filter(|(id, _)| member.roles.contains(id))
309 .filter_map(|(id, permission)| {
310 query.server.roles.get(id).map(|role| {
311 let v: Override = (*permission).into();
312 (role.rank, v)
313 })
314 })
315 .collect::<Vec<(i64, Override)>>();
316
317 roles.sort_by(|a, b| b.0.cmp(&a.0));
318 let overrides = roles.into_iter().map(|(_, v)| v);
319
320 for role_override in overrides {
321 permission.apply(role_override)
322 }
323
324 resp.insert(user.id.clone(), permission);
325 }
326
327 resp
328}
329
330fn calculate_server_permissions(server: &Server, user: &User, member: &Member) -> PermissionValue {
332 if user.privileged || server.owner == user.id {
333 return ChannelPermission::GrantAllSafe.into();
334 }
335
336 let mut permissions: PermissionValue = server.default_permissions.into();
337
338 let mut roles = server
339 .roles
340 .iter()
341 .filter(|(id, _)| member.roles.contains(id))
342 .map(|(_, role)| {
343 let v: Override = role.permissions.into();
344 (role.rank, v)
345 })
346 .collect::<Vec<(i64, Override)>>();
347
348 roles.sort_by(|a, b| b.0.cmp(&a.0));
349 let role_overrides: Vec<Override> = roles.into_iter().map(|(_, v)| v).collect();
350
351 for role in role_overrides {
352 permissions.apply(role);
353 }
354
355 if member.in_timeout() {
356 permissions.restrict(*ALLOW_IN_TIMEOUT);
357 }
358
359 permissions
360}