// revolt_database/models/channels/ops/reference.rs

1use std::collections::hash_map::Entry;
2
3use super::AbstractChannels;
4use crate::ReferenceDb;
5use crate::{Channel, FieldsChannel, PartialChannel};
6use revolt_permissions::OverrideField;
7use revolt_result::Result;
8
9#[async_trait]
10impl AbstractChannels for ReferenceDb {
11    /// Insert a new channel in the database
12    async fn insert_channel(&self, channel: &Channel) -> Result<()> {
13        let mut channels = self.channels.lock().await;
14        if let Entry::Vacant(entry) = channels.entry(channel.id().to_string()) {
15            entry.insert(channel.clone());
16            Ok(())
17        } else {
18            Err(create_database_error!("insert", "channel"))
19        }
20    }
21
22    /// Fetch a channel from the database
23    async fn fetch_channel(&self, channel_id: &str) -> Result<Channel> {
24        let channels = self.channels.lock().await;
25        channels
26            .get(channel_id)
27            .cloned()
28            .ok_or_else(|| create_error!(NotFound))
29    }
30
31    /// Fetch all channels from the database
32    async fn fetch_channels<'a>(&self, ids: &'a [String]) -> Result<Vec<Channel>> {
33        let channels = self.channels.lock().await;
34        ids.iter()
35            .map(|id| {
36                channels
37                    .get(id)
38                    .cloned()
39                    .ok_or_else(|| create_error!(NotFound))
40            })
41            .collect()
42    }
43
44    /// Fetch all direct messages for a user
45    async fn find_direct_messages(&self, user_id: &str) -> Result<Vec<Channel>> {
46        let channels = self.channels.lock().await;
47        Ok(channels
48            .values()
49            .filter(|channel| channel.contains_user(user_id))
50            .cloned()
51            .collect())
52    }
53
54    // Fetch saved messages channel
55    async fn find_saved_messages_channel(&self, user_id: &str) -> Result<Channel> {
56        let channels = self.channels.lock().await;
57        channels
58            .get(user_id)
59            .cloned()
60            .ok_or_else(|| create_database_error!("fetch", "channel"))
61    }
62
63    // Fetch direct message channel (DM or Saved Messages)
64    async fn find_direct_message_channel(&self, user_a: &str, user_b: &str) -> Result<Channel> {
65        let channels = self.channels.lock().await;
66        for (_, data) in channels.iter() {
67            if data.contains_user(user_a) && data.contains_user(user_b) {
68                return Ok(data.to_owned());
69            }
70        }
71        Err(create_error!(NotFound))
72    }
73    /// Insert a user to a group
74    async fn add_user_to_group(&self, channel_id: &str, user_id: &str) -> Result<()> {
75        let mut channels = self.channels.lock().await;
76
77        if let Some(Channel::Group { recipients, .. }) = channels.get_mut(channel_id) {
78            recipients.push(String::from(user_id));
79            Ok(())
80        } else {
81            Err(create_error!(InvalidOperation))
82        }
83    }
84    /// Insert channel role permissions
85    async fn set_channel_role_permission(
86        &self,
87        channel_id: &str,
88        role_id: &str,
89        permissions: OverrideField,
90    ) -> Result<()> {
91        let mut channels = self.channels.lock().await;
92
93        if let Some(mut channel) = channels.get_mut(channel_id) {
94            match &mut channel {
95                Channel::TextChannel {
96                    role_permissions, ..
97                } => {
98                    if role_permissions.get(role_id).is_some() {
99                        role_permissions.remove(role_id);
100                        role_permissions.insert(String::from(role_id), permissions);
101
102                        Ok(())
103                    } else {
104                        Err(create_error!(NotFound))
105                    }
106                }
107                _ => Err(create_error!(NotFound)),
108            }
109        } else {
110            Err(create_error!(NotFound))
111        }
112    }
113
114    // Update channel
115    async fn update_channel(
116        &self,
117        id: &str,
118        channel: &PartialChannel,
119        remove: Vec<FieldsChannel>,
120    ) -> Result<()> {
121        let mut channels = self.channels.lock().await;
122        if let Some(channel_data) = channels.get_mut(id) {
123            channel_data.apply_options(channel.to_owned());
124            channel_data.remove_fields(remove);
125            Ok(())
126        } else {
127            Err(create_error!(NotFound))
128        }
129    }
130
131    // Remove a user from a group
132    async fn remove_user_from_group(&self, channel: &str, user: &str) -> Result<()> {
133        let mut channels = self.channels.lock().await;
134        if let Some(channel_data) = channels.get_mut(channel) {
135            if channel_data.users()?.contains(&String::from(user)) {
136                channel_data.users()?.retain(|x| x != user);
137                return Ok(());
138            } else {
139                return Err(create_error!(NotFound));
140            }
141        }
142        Err(create_error!(NotFound))
143    }
144
145    // Delete a channel
146    async fn delete_channel(&self, channel: &Channel) -> Result<()> {
147        let mut channels = self.channels.lock().await;
148        if channels.remove(channel.id()).is_some() {
149            Ok(())
150        } else {
151            Err(create_error!(NotFound))
152        }
153    }
154}