1use std::{collections::HashMap, hash::RandomState};
2
3use revolt_permissions::{
4 ChannelPermission, ChannelType, Override, OverrideField, PermissionValue, ALLOW_IN_TIMEOUT,
5 DEFAULT_PERMISSION_DIRECT_MESSAGE,
6};
7
8use crate::{Channel, Database, Member, Server, User};
9
10#[derive(Clone)]
11pub struct BulkDatabasePermissionQuery<'a> {
12 #[allow(dead_code)]
13 database: &'a Database,
14
15 server: Server,
16 channel: Option<Channel>,
17 users: Option<Vec<User>>,
18 members: Option<Vec<Member>>,
19
20 pub(crate) cached_users: Option<Vec<User>>,
22 pub(crate) cached_members: Option<Vec<Member>>,
23
24 cached_member_perms: Option<HashMap<String, PermissionValue>>,
25}
26
27impl<'z, 'x> BulkDatabasePermissionQuery<'x> {
28 pub async fn members_can_see_channel(&'z mut self) -> HashMap<String, bool>
29 where
30 'z: 'x,
31 {
32 let member_perms = if self.cached_member_perms.is_some() {
33 let perms = self.cached_member_perms.as_ref().unwrap();
35 perms
36 .iter()
37 .map(|(m, p)| {
38 (
39 m.clone(),
40 p.has_channel_permission(ChannelPermission::ViewChannel),
41 )
42 })
43 .collect()
44 } else {
45 calculate_members_permissions(self)
46 .await
47 .iter()
48 .map(|(m, p)| {
49 (
50 m.clone(),
51 p.has_channel_permission(ChannelPermission::ViewChannel),
52 )
53 })
54 .collect()
55 };
56 member_perms
57 }
58}
59
60impl<'z> BulkDatabasePermissionQuery<'z> {
61 pub fn new(database: &Database, server: Server) -> BulkDatabasePermissionQuery<'_> {
62 BulkDatabasePermissionQuery {
63 database,
64 server,
65 channel: None,
66 users: None,
67 members: None,
68 cached_members: None,
69 cached_users: None,
70 cached_member_perms: None,
71 }
72 }
73
74 pub async fn from_server_id<'a>(
75 db: &'a Database,
76 server: &str,
77 ) -> BulkDatabasePermissionQuery<'a> {
78 BulkDatabasePermissionQuery {
79 database: db,
80 server: db.fetch_server(server).await.unwrap(),
81 channel: None,
82 users: None,
83 members: None,
84 cached_members: None,
85 cached_users: None,
86 cached_member_perms: None,
87 }
88 }
89
90 pub fn channel(self, channel: &'z Channel) -> BulkDatabasePermissionQuery<'z> {
91 BulkDatabasePermissionQuery {
92 channel: Some(channel.clone()),
93 ..self
94 }
95 }
96
97 pub async fn from_channel_id(self, channel_id: String) -> BulkDatabasePermissionQuery<'z> {
98 let channel = self
99 .database
100 .fetch_channel(channel_id.as_str())
101 .await
102 .expect("Valid channel id");
103
104 drop(channel_id);
105
106 BulkDatabasePermissionQuery {
107 channel: Some(channel),
108 ..self
109 }
110 }
111
112 pub fn members(self, members: &'z [Member]) -> BulkDatabasePermissionQuery<'z> {
113 BulkDatabasePermissionQuery {
114 members: Some(members.to_owned()),
115 cached_member_perms: None,
116 users: None,
117 cached_members: None,
118 cached_users: None,
119 ..self
120 }
121 }
122
123 pub fn users(self, users: &'z [User]) -> BulkDatabasePermissionQuery<'z> {
124 BulkDatabasePermissionQuery {
125 users: Some(users.to_owned()),
126 cached_member_perms: None,
127 members: None,
128 cached_members: None,
129 cached_users: None,
130 ..self
131 }
132 }
133
134 #[allow(dead_code)]
137 async fn get_default_channel_permissions(&mut self) -> Override {
138 if let Some(channel) = &self.channel {
139 match channel {
140 Channel::Group { permissions, .. } => Override {
141 allow: permissions.unwrap_or(*DEFAULT_PERMISSION_DIRECT_MESSAGE as i64) as u64,
142 deny: 0,
143 },
144 Channel::TextChannel {
145 default_permissions,
146 ..
147 } => default_permissions.unwrap_or_default().into(),
148 _ => Default::default(),
149 }
150 } else {
151 Default::default()
152 }
153 }
154
155 #[allow(dead_code, deprecated)]
156 fn get_channel_type(&mut self) -> ChannelType {
157 if let Some(channel) = &self.channel {
158 match channel {
159 Channel::DirectMessage { .. } => ChannelType::DirectMessage,
160 Channel::Group { .. } => ChannelType::Group,
161 Channel::SavedMessages { .. } => ChannelType::SavedMessages,
162 Channel::TextChannel { .. } => ChannelType::ServerChannel,
163 }
164 } else {
165 ChannelType::Unknown
166 }
167 }
168
169 #[allow(dead_code)]
171 async fn get_channel_role_overrides(&mut self) -> &HashMap<String, OverrideField> {
172 if let Some(channel) = &self.channel {
173 match channel {
174 Channel::TextChannel {
175 role_permissions, ..
176 } => role_permissions,
177 _ => panic!("Not supported for non-server channels"),
178 }
179 } else {
180 panic!("No channel added to query")
181 }
182 }
183}
184
185async fn calculate_members_permissions<'a>(
187 query: &'a mut BulkDatabasePermissionQuery<'a>,
188) -> HashMap<String, PermissionValue> {
189 let mut resp = HashMap::new();
190
191 let (_, channel_role_permissions, channel_default_permissions) = match query
192 .channel
193 .as_ref()
194 .expect("A channel must be assigned to calculate channel permissions")
195 .clone()
196 {
197 Channel::TextChannel {
198 id,
199 role_permissions,
200 default_permissions,
201 ..
202 } => (id, role_permissions, default_permissions),
203 _ => panic!("Calculation of member permissions must be done on a server channel"),
204 };
205
206 if query.users.is_none() {
207 let ids: Vec<String> = query
208 .members
209 .as_ref()
210 .expect("No users or members added to the query")
211 .iter()
212 .map(|m| m.id.user.clone())
213 .collect();
214
215 query.cached_users = Some(
216 query
217 .database
218 .fetch_users(&ids[..])
219 .await
220 .expect("Failed to get data from the db"),
221 );
222
223 query.users = Some(query.cached_users.as_ref().unwrap().to_vec())
224 }
225
226 let users = query.users.as_ref().unwrap();
227
228 if query.members.is_none() {
229 let ids: Vec<String> = query
230 .users
231 .as_ref()
232 .expect("No users or members added to the query")
233 .iter()
234 .map(|m| m.id.clone())
235 .collect();
236
237 query.cached_members = Some(
238 query
239 .database
240 .fetch_members(&query.server.id, &ids[..])
241 .await
242 .expect("Failed to get data from the db"),
243 );
244 query.members = Some(query.cached_members.as_ref().unwrap().to_vec())
245 }
246
247 let members: HashMap<&String, &Member, RandomState> = HashMap::from_iter(
248 query
249 .members
250 .as_ref()
251 .unwrap()
252 .iter()
253 .map(|m| (&m.id.user, m)),
254 );
255
256 for user in users {
257 let member = members.get(&user.id);
258
259 if member.is_none() {
261 resp.insert(user.id.clone(), 0_u64.into());
262 continue;
263 }
264
265 let member = *member.unwrap();
266
267 if user.privileged {
268 resp.insert(
269 user.id.clone(),
270 PermissionValue::from(ChannelPermission::GrantAllSafe),
271 );
272 continue;
273 }
274
275 if user.id == query.server.owner {
276 resp.insert(
277 user.id.clone(),
278 PermissionValue::from(ChannelPermission::GrantAllSafe),
279 );
280 continue;
281 }
282
283 let mut permission = calculate_server_permissions(&query.server, user, member);
285
286 if let Some(defaults) = channel_default_permissions {
287 permission.apply(defaults.into());
288 }
289
290 let mut roles = channel_role_permissions
292 .iter()
293 .filter(|(id, _)| member.roles.contains(id))
294 .filter_map(|(id, permission)| {
295 query.server.roles.get(id).map(|role| {
296 let v: Override = (*permission).into();
297 (role.rank, v)
298 })
299 })
300 .collect::<Vec<(i64, Override)>>();
301
302 roles.sort_by(|a, b| b.0.cmp(&a.0));
303 let overrides = roles.into_iter().map(|(_, v)| v);
304
305 for role_override in overrides {
306 permission.apply(role_override)
307 }
308
309 resp.insert(user.id.clone(), permission);
310 }
311
312 resp
313}
314
315fn calculate_server_permissions(server: &Server, user: &User, member: &Member) -> PermissionValue {
317 if user.privileged || server.owner == user.id {
318 return ChannelPermission::GrantAllSafe.into();
319 }
320
321 let mut permissions: PermissionValue = server.default_permissions.into();
322
323 let mut roles = server
324 .roles
325 .iter()
326 .filter(|(id, _)| member.roles.contains(id))
327 .map(|(_, role)| {
328 let v: Override = role.permissions.into();
329 (role.rank, v)
330 })
331 .collect::<Vec<(i64, Override)>>();
332
333 roles.sort_by(|a, b| b.0.cmp(&a.0));
334 let role_overrides: Vec<Override> = roles.into_iter().map(|(_, v)| v).collect();
335
336 for role in role_overrides {
337 permissions.apply(role);
338 }
339
340 if member.in_timeout() {
341 permissions.restrict(*ALLOW_IN_TIMEOUT);
342 }
343
344 permissions
345}