// revolt_database/models/servers/model.rs

1use std::collections::{HashMap, HashSet};
2
3use revolt_models::v0::{self, DataCreateServerChannel};
4use revolt_permissions::{OverrideField, DEFAULT_PERMISSION_SERVER};
5use revolt_result::Result;
6use ulid::Ulid;
7
8use crate::{events::client::EventV1, Channel, Database, File, User};
9
auto_derived_partial!(
    /// Server
    ///
    /// Persisted server record. The second macro argument names the partial
    /// variant, `PartialServer`, which this file uses for partial updates
    /// (`Server::update`, `Server::set_role_ordering`).
    pub struct Server {
        /// Unique Id (serialised under the key `_id`)
        #[serde(rename = "_id")]
        pub id: String,
        /// User id of the owner
        pub owner: String,

        /// Name of the server
        pub name: String,
        /// Description for the server; omitted from output when `None`
        #[serde(skip_serializing_if = "Option::is_none")]
        pub description: Option<String>,

        /// IDs of the channels within this server
        // TODO: investigate if this is redundant and can be removed
        pub channels: Vec<String>,
        /// Categories for this server
        #[serde(skip_serializing_if = "Option::is_none")]
        pub categories: Option<Vec<Category>>,
        /// Configuration for sending system event messages
        #[serde(skip_serializing_if = "Option::is_none")]
        pub system_messages: Option<SystemMessageChannels>,

        /// Roles for this server, keyed by role id
        ///
        /// Deserialises to an empty map when absent and is omitted from
        /// output when empty.
        #[serde(
            default = "HashMap::<String, Role>::new",
            skip_serializing_if = "HashMap::<String, Role>::is_empty"
        )]
        pub roles: HashMap<String, Role>,
        /// Default set of server and channel permissions
        ///
        /// Initialised from `DEFAULT_PERMISSION_SERVER` when a server is created.
        pub default_permissions: i64,

        /// Icon attachment
        #[serde(skip_serializing_if = "Option::is_none")]
        pub icon: Option<File>,
        /// Banner attachment
        #[serde(skip_serializing_if = "Option::is_none")]
        pub banner: Option<File>,

        /// Bitfield of server flags
        #[serde(skip_serializing_if = "Option::is_none")]
        pub flags: Option<i32>,

        // The three flags below use `default`, so a missing value
        // deserialises as `false`.
        /// Whether this server is flagged as not safe for work
        #[serde(skip_serializing_if = "crate::if_false", default)]
        pub nsfw: bool,
        /// Whether to enable analytics
        #[serde(skip_serializing_if = "crate::if_false", default)]
        pub analytics: bool,
        /// Whether this server should be publicly discoverable
        #[serde(skip_serializing_if = "crate::if_false", default)]
        pub discoverable: bool,
    },
    "PartialServer"
);
67
auto_derived_partial!(
    /// Role
    ///
    /// A role defined on a server. The second macro argument names the
    /// partial variant, `PartialRole`, which this file uses for partial
    /// updates (`Role::update`, `Role::into_optional`).
    pub struct Role {
        /// Unique Id (serialised under the key `_id`)
        #[serde(rename = "_id")]
        pub id: String,
        /// Role name
        pub name: String,
        /// Permissions available to this role
        pub permissions: OverrideField,
        /// Colour used for this role
        ///
        /// This can be any valid CSS colour
        #[serde(skip_serializing_if = "Option::is_none")]
        pub colour: Option<String>,
        /// Whether this role should be shown separately on the member sidebar
        ///
        /// Deserialises as `false` when missing.
        #[serde(skip_serializing_if = "crate::if_false", default)]
        pub hoist: bool,
        /// Ranking of this role
        ///
        /// `Server::ordered_roles` sorts roles by ascending rank, and
        /// `Server::set_role_ordering` assigns ranks 0, 1, 2, ... in list
        /// order. Deserialises as 0 when missing.
        #[serde(default)]
        pub rank: i64,
        /// Custom icon attachment
        #[serde(skip_serializing_if = "Option::is_none")]
        pub icon: Option<File>,
    },
    "PartialRole"
);
95
auto_derived!(
    /// Channel category
    pub struct Category {
        /// Unique ID for this category
        pub id: String,
        /// Title for this category
        pub title: String,
        /// Channels in this category
        pub channels: Vec<String>,
    }

    /// System message channel assignments
    ///
    /// Each slot is optional; `into_channel_ids` gathers the ones that are set.
    pub struct SystemMessageChannels {
        /// ID of channel to send user join messages in
        #[serde(skip_serializing_if = "Option::is_none")]
        pub user_joined: Option<String>,
        /// ID of channel to send user left messages in
        #[serde(skip_serializing_if = "Option::is_none")]
        pub user_left: Option<String>,
        /// ID of channel to send user kicked messages in
        #[serde(skip_serializing_if = "Option::is_none")]
        pub user_kicked: Option<String>,
        /// ID of channel to send user banned messages in
        #[serde(skip_serializing_if = "Option::is_none")]
        pub user_banned: Option<String>,
    }

    /// Optional fields on server object
    ///
    /// Each variant names an optional `Server` field that can be cleared
    /// (set to `None`) via `Server::update` / `Server::remove_field`.
    pub enum FieldsServer {
        /// `Server::description`
        Description,
        /// `Server::categories`
        Categories,
        /// `Server::system_messages`
        SystemMessages,
        /// `Server::icon`
        Icon,
        /// `Server::banner`
        Banner,
    }

    /// Optional fields on role object
    ///
    /// Each variant names an optional `Role` field that can be cleared
    /// (set to `None`) via `Role::update` / `Role::remove_field`.
    pub enum FieldsRole {
        /// `Role::colour`
        Colour,
        /// `Role::icon`
        Icon,
    }
);
138
139#[allow(clippy::disallowed_methods)]
140impl Server {
141    /// Create a server
142    pub async fn create(
143        db: &Database,
144        data: v0::DataCreateServer,
145        owner: &User,
146        create_default_channels: bool,
147    ) -> Result<(Server, Vec<Channel>)> {
148        let mut server = Server {
149            id: ulid::Ulid::new().to_string(),
150            owner: owner.id.to_string(),
151            name: data.name,
152            description: data.description,
153            channels: vec![],
154            nsfw: data.nsfw.unwrap_or(false),
155            default_permissions: *DEFAULT_PERMISSION_SERVER as i64,
156
157            analytics: false,
158            banner: None,
159            categories: None,
160            discoverable: false,
161            flags: None,
162            icon: None,
163            roles: HashMap::new(),
164            system_messages: None,
165        };
166
167        let channels: Vec<Channel> = if create_default_channels {
168            vec![
169                Channel::create_server_channel(
170                    db,
171                    &mut server,
172                    DataCreateServerChannel {
173                        channel_type: v0::LegacyServerChannelType::Text,
174                        name: "General".to_string(),
175                        ..Default::default()
176                    },
177                    false,
178                )
179                .await?,
180            ]
181        } else {
182            vec![]
183        };
184
185        server.channels = channels.iter().map(|c| c.id().to_string()).collect();
186        db.insert_server(&server).await?;
187        Ok((server, channels))
188    }
189
190    /// Update server data
191    pub async fn update(
192        &mut self,
193        db: &Database,
194        partial: PartialServer,
195        remove: Vec<FieldsServer>,
196    ) -> Result<()> {
197        for field in &remove {
198            self.remove_field(field);
199        }
200
201        self.apply_options(partial.clone());
202
203        db.update_server(&self.id, &partial, remove.clone()).await?;
204
205        EventV1::ServerUpdate {
206            id: self.id.clone(),
207            data: partial.into(),
208            clear: remove.into_iter().map(|v| v.into()).collect(),
209        }
210        .p(self.id.clone())
211        .await;
212
213        Ok(())
214    }
215
216    /// Delete a server
217    pub async fn delete(self, db: &Database) -> Result<()> {
218        EventV1::ServerDelete {
219            id: self.id.clone(),
220        }
221        .p(self.id.clone())
222        .await;
223
224        db.delete_server(&self.id).await
225    }
226
227    /// Remove a field from Server
228    pub fn remove_field(&mut self, field: &FieldsServer) {
229        match field {
230            FieldsServer::Description => self.description = None,
231            FieldsServer::Categories => self.categories = None,
232            FieldsServer::SystemMessages => self.system_messages = None,
233            FieldsServer::Icon => self.icon = None,
234            FieldsServer::Banner => self.banner = None,
235        }
236    }
237
238    /// Ordered roles list
239    pub fn ordered_roles(&self) -> Vec<(String, Role)> {
240        let mut ordered_roles = self.roles.clone().into_iter().collect::<Vec<_>>();
241        ordered_roles.sort_by(|(_, role_a), (_, role_b)| role_a.rank.cmp(&role_b.rank));
242        ordered_roles
243    }
244
245    /// Set role permission on a server
246    pub async fn set_role_permission(
247        &mut self,
248        db: &Database,
249        role_id: &str,
250        permissions: OverrideField,
251    ) -> Result<()> {
252        if let Some(role) = self.roles.get_mut(role_id) {
253            role.update(
254                db,
255                &self.id,
256                PartialRole {
257                    permissions: Some(permissions),
258                    ..Default::default()
259                },
260                vec![],
261            )
262            .await?;
263
264            Ok(())
265        } else {
266            Err(create_error!(NotFound))
267        }
268    }
269
270    /// Reorders the server's roles rankings
271    pub async fn set_role_ordering(&mut self, db: &Database, new_order: Vec<String>) -> Result<()> {
272        // New order must always contain every role
273        debug_assert_eq!(self.roles.len(), new_order.len());
274
275        // Set the role's ranks to the positions in the vec
276        for (rank, id) in new_order.iter().enumerate() {
277            self.roles.get_mut(id).unwrap().rank = rank as i64;
278        }
279
280        db.update_server(
281            &self.id,
282            &PartialServer {
283                roles: Some(self.roles.clone()),
284                ..Default::default()
285            },
286            Vec::new(),
287        )
288        .await?;
289
290        // Publish bulk update event
291        EventV1::ServerRoleRanksUpdate {
292            id: self.id.clone(),
293            ranks: new_order,
294        }
295        .p(self.id.clone())
296        .await;
297
298        Ok(())
299    }
300}
301
302impl Role {
303    /// Into optional struct
304    pub fn into_optional(self) -> PartialRole {
305        PartialRole {
306            id: Some(self.id),
307            name: Some(self.name),
308            permissions: Some(self.permissions),
309            colour: self.colour,
310            hoist: Some(self.hoist),
311            rank: Some(self.rank),
312            icon: self.icon,
313        }
314    }
315
316    /// Create a role
317    pub async fn create(db: &Database, server: &Server, name: String) -> Result<Self> {
318        let role = Role {
319            id: Ulid::new().to_string(),
320            name,
321            // Rank of the new role should be below the lowest role
322            rank: server.roles.len() as i64,
323            colour: None,
324            hoist: false,
325            permissions: Default::default(),
326            icon: None,
327        };
328
329        db.insert_role(&server.id, &role).await?;
330
331        EventV1::ServerRoleUpdate {
332            id: server.id.clone(),
333            role_id: role.id.clone(),
334            data: role.clone().into_optional().into(),
335            clear: vec![],
336        }
337        .p(server.id.clone())
338        .await;
339
340        Ok(role)
341    }
342
343    /// Update server data
344    pub async fn update(
345        &mut self,
346        db: &Database,
347        server_id: &str,
348        partial: PartialRole,
349        remove: Vec<FieldsRole>,
350    ) -> Result<()> {
351        for field in &remove {
352            self.remove_field(field);
353        }
354
355        self.apply_options(partial.clone());
356
357        db.update_role(server_id, &self.id, &partial, remove.clone())
358            .await?;
359
360        EventV1::ServerRoleUpdate {
361            id: server_id.to_string(),
362            role_id: self.id.clone(),
363            data: partial.into(),
364            clear: remove.into_iter().map(Into::into).collect(),
365        }
366        .p(server_id.to_string())
367        .await;
368
369        Ok(())
370    }
371
372    /// Remove field from Role
373    pub fn remove_field(&mut self, field: &FieldsRole) {
374        match field {
375            FieldsRole::Colour => self.colour = None,
376            FieldsRole::Icon => self.icon = None,
377        }
378    }
379
380    /// Delete a role
381    pub async fn delete(self, db: &Database, server_id: &str) -> Result<()> {
382        EventV1::ServerRoleDelete {
383            id: server_id.to_string(),
384            role_id: self.id.clone(),
385        }
386        .p(server_id.to_string())
387        .await;
388
389        db.delete_role(server_id, &self.id).await
390    }
391}
392
393impl SystemMessageChannels {
394    pub fn into_channel_ids(self) -> HashSet<String> {
395        let mut ids = HashSet::new();
396
397        if let Some(id) = self.user_joined {
398            ids.insert(id);
399        }
400
401        if let Some(id) = self.user_left {
402            ids.insert(id);
403        }
404
405        if let Some(id) = self.user_kicked {
406            ids.insert(id);
407        }
408
409        if let Some(id) = self.user_banned {
410            ids.insert(id);
411        }
412
413        ids
414    }
415}
416
#[cfg(test)]
mod tests {
    use revolt_permissions::{calculate_server_permissions, ChannelPermission};

    use crate::{fixture, util::permissions::DatabasePermissionQuery};

    /// Server permissions computed for the three users of the
    /// `server_with_roles` fixture: the owner holds `GrantAllSafe`, the
    /// moderator can ban members and the plain user cannot.
    #[async_std::test]
    async fn permissions() {
        database_test!(|db| async move {
            fixture!(db, "server_with_roles",
                owner user 0
                moderator user 1
                user user 2
                server server 4);

            // Owner
            let mut owner_query = DatabasePermissionQuery::new(&db, &owner).server(&server);
            let owner_permissions = calculate_server_permissions(&mut owner_query).await;
            assert!(owner_permissions.has_channel_permission(ChannelPermission::GrantAllSafe));

            // Moderator
            let mut moderator_query =
                DatabasePermissionQuery::new(&db, &moderator).server(&server);
            let moderator_permissions = calculate_server_permissions(&mut moderator_query).await;
            assert!(moderator_permissions.has_channel_permission(ChannelPermission::BanMembers));

            // Regular member
            let mut member_query = DatabasePermissionQuery::new(&db, &user).server(&server);
            let member_permissions = calculate_server_permissions(&mut member_query).await;
            assert!(!member_permissions.has_channel_permission(ChannelPermission::BanMembers));
        });
    }
}
448}