revolt_config/
lib.rs

1use std::collections::HashMap;
2
3use cached::proc_macro::cached;
4use config::{Config, File, FileFormat};
5use futures_locks::RwLock;
6use once_cell::sync::Lazy;
7use serde::Deserialize;
8
9#[cfg(feature = "sentry")]
10pub use sentry::{capture_error, capture_message, Level};
11#[cfg(feature = "anyhow")]
12pub use sentry_anyhow::capture_anyhow;
13
/// Report the error of a `Result`-like expression and replace it with a
/// `revolt_result` error.
///
/// The expression's error is sent through [`capture_message`] at
/// [`Level::Error`], formatted with `Debug` followed by the
/// `file!()`/`line!()`/`column!()` location. The error is then mapped to
/// `::revolt_result::create_error!($variant)`.
///
/// NOTE(review): the pattern accepts one optional trailing token tree after
/// the variant name, but it is not forwarded to `create_error!` — confirm
/// whether that is intentional.
#[cfg(all(feature = "report-macros", feature = "sentry"))]
#[macro_export]
macro_rules! report_error {
    ( $result:expr, $variant:ident $( $extra:tt )? ) => {
        $result
            .inspect_err(|err| {
                let report = format!("{err:?} ({}:{}:{})", file!(), line!(), column!());
                $crate::capture_message(&report, $crate::Level::Error);
            })
            .map_err(|_| ::revolt_result::create_error!($variant))
    };
}
28
/// Send a value to Sentry as an error-level message.
///
/// The value is rendered with `Debug`, followed by the
/// `file!()`/`line!()`/`column!()` location, and passed to
/// [`capture_message`] at [`Level::Error`]. The expansion ends in a semicolon,
/// so the macro is meant to be used as a statement.
#[cfg(all(feature = "report-macros", feature = "sentry"))]
#[macro_export]
macro_rules! capture_internal_error {
    ( $value:expr ) => {
        $crate::capture_message(
            &format!("{:?} ({}:{}:{})", $value, file!(), line!(), column!()),
            $crate::Level::Error,
        );
    };
}
39
/// Report the error of a `Result`-like expression and replace it with
/// `InternalError`.
///
/// The expression's error is sent through [`capture_message`] at
/// [`Level::Error`], formatted with `Debug` followed by the
/// `file!()`/`line!()`/`column!()` location. The error is then mapped to
/// `::revolt_result::create_error!(InternalError)`.
#[cfg(all(feature = "report-macros", feature = "sentry"))]
#[macro_export]
macro_rules! report_internal_error {
    ( $result:expr ) => {
        $result
            .inspect_err(|err| {
                let report = format!("{err:?} ({}:{}:{})", file!(), line!(), column!());
                $crate::capture_message(&report, $crate::Level::Error);
            })
            .map_err(|_| ::revolt_result::create_error!(InternalError))
    };
}
54
/// Paths to search for configuration files on disk.
///
/// Each path that exists is registered as a configuration source, in the order
/// listed here, after the embedded defaults (see `CONFIG_BUILDER`).
static CONFIG_SEARCH_PATHS: [&str; 3] = [
    // current working directory
    "Revolt.toml",
    // current working directory - overrides file
    "Revolt.overrides.toml",
    // root directory, for Docker containers
    "/Revolt.toml",
];
64
/// File name of the test overrides file.
///
/// When the `TEST_DB` environment variable is set, this name is looked up in
/// the current working directory and in each of its parent directories.
static TEST_OVERRIDE_PATH: &str = "Revolt.test-overrides.toml";
67
68/// Configuration builder
69static CONFIG_BUILDER: Lazy<RwLock<Config>> = Lazy::new(|| {
70    RwLock::new({
71        let mut builder = Config::builder().add_source(File::from_str(
72            include_str!("../Revolt.toml"),
73            FileFormat::Toml,
74        ));
75
76        if std::env::var("TEST_DB").is_ok() {
77            builder = builder.add_source(File::from_str(
78                include_str!("../Revolt.test.toml"),
79                FileFormat::Toml,
80            ));
81
82            // recursively search upwards for an overrides file (if there is one)
83            if let Ok(cwd) = std::env::current_dir() {
84                let mut path = Some(cwd.as_path());
85                while let Some(current_path) = path {
86                    let target_path = current_path.join(TEST_OVERRIDE_PATH);
87                    if target_path.exists() {
88                        builder = builder
89                            .add_source(File::new(target_path.to_str().unwrap(), FileFormat::Toml));
90                    }
91
92                    path = current_path.parent();
93                }
94            }
95        }
96
97        for path in CONFIG_SEARCH_PATHS {
98            if std::path::Path::new(path).exists() {
99                builder = builder.add_source(File::new(path, FileFormat::Toml));
100            }
101        }
102
103        builder.build().unwrap()
104    })
105});
106
/// Settings deserialized from the `database` section of the configuration.
#[derive(Deserialize, Debug, Clone)]
pub struct Database {
    /// MongoDB connection string (presumably a `mongodb://` URI — confirm
    /// against the consumer).
    pub mongodb: String,
    /// Redis connection string; `config()` may export it as the `REDIS_URI`
    /// environment variable.
    pub redis: String,
}
112
/// Settings deserialized from the `rabbit` section of the configuration
/// (presumably the RabbitMQ broker connection — confirm against the consumer).
#[derive(Deserialize, Debug, Clone)]
pub struct Rabbit {
    /// Broker host.
    pub host: String,
    /// Broker port.
    pub port: u16,
    /// User name used to authenticate.
    pub username: String,
    /// Password used to authenticate.
    pub password: String,
}
120
/// Settings deserialized from the `hosts` section of the configuration:
/// one address per service (presumably public URLs — confirm).
#[derive(Deserialize, Debug, Clone)]
pub struct Hosts {
    /// Address of the `app` service.
    pub app: String,
    /// Address of the `api` service; `config()` inspects this value to
    /// auto-detect production deployments.
    pub api: String,
    /// Address of the `events` service.
    pub events: String,
    /// Address of the `autumn` service.
    pub autumn: String,
    /// Address of the `january` service.
    pub january: String,
    /// LiveKit addresses keyed by string (presumably the node name — confirm).
    pub livekit: HashMap<String, String>,
}
130
/// Settings deserialized from the `api.registration` section.
#[derive(Deserialize, Debug, Clone)]
pub struct ApiRegistration {
    /// Invite-only flag (presumably restricts account creation to invited
    /// users — confirm against the consumer).
    pub invite_only: bool,
}
135
/// Settings deserialized from the `api.smtp` section.
#[derive(Deserialize, Debug, Clone)]
pub struct ApiSmtp {
    /// SMTP server host; `Settings::preflight_checks` warns when this is empty.
    pub host: String,
    /// User name used to authenticate.
    pub username: String,
    /// Password used to authenticate.
    pub password: String,
    /// Sender address.
    pub from_address: String,
    /// Optional reply-to address.
    pub reply_to: Option<String>,
    /// Optional server port.
    pub port: Option<i32>,
    /// Optional TLS toggle.
    pub use_tls: Option<bool>,
    /// Optional STARTTLS toggle.
    pub use_starttls: Option<bool>,
}
147
/// Settings deserialized from the `pushd.vapid` section
/// (presumably Web Push / VAPID delivery — confirm).
#[derive(Deserialize, Debug, Clone)]
pub struct PushVapid {
    /// Queue name for this provider.
    pub queue: String,
    /// Private key.
    pub private_key: String,
    /// Public key.
    pub public_key: String,
}
154
/// Settings deserialized from the `pushd.fcm` section.
///
/// NOTE(review): apart from `queue`, the field names look like those of a
/// Google service-account credential — confirm against the consumer.
#[derive(Deserialize, Debug, Clone)]
pub struct PushFcm {
    /// Queue name for this provider.
    pub queue: String,
    pub key_type: String,
    pub project_id: String,
    pub private_key_id: String,
    pub private_key: String,
    pub client_email: String,
    pub client_id: String,
    pub auth_uri: String,
    pub token_uri: String,
    pub auth_provider_x509_cert_url: String,
    pub client_x509_cert_url: String,
}
169
/// Settings deserialized from the `pushd.apn` section
/// (presumably Apple Push Notification delivery — confirm).
#[derive(Deserialize, Debug, Clone)]
pub struct PushApn {
    /// Queue name for this provider.
    pub queue: String,
    /// Sandbox flag.
    pub sandbox: bool,
    /// PKCS#8 key material (encoding not visible here — confirm).
    pub pkcs8: String,
    /// Key identifier.
    pub key_id: String,
    /// Team identifier.
    pub team_id: String,
}
178
/// Settings deserialized from the `api.security.captcha` section.
#[derive(Deserialize, Debug, Clone)]
pub struct ApiSecurityCaptcha {
    /// hCaptcha key; `Settings::preflight_checks` warns when this is empty.
    pub hcaptcha_key: String,
    /// hCaptcha site key.
    pub hcaptcha_sitekey: String,
}
184
/// Settings deserialized from the `api.security` section.
///
/// The meaning of the individual keys and tokens is defined by their
/// consumers, which are not visible in this file.
#[derive(Deserialize, Debug, Clone)]
pub struct ApiSecurity {
    pub authifier_shield_key: String,
    pub voso_legacy_token: String,
    /// Captcha provider settings.
    pub captcha: ApiSecurityCaptcha,
    /// Cloudflare trust flag (presumably controls trusting Cloudflare-provided
    /// request headers — confirm).
    pub trust_cloudflare: bool,
    pub easypwned: String,
    pub tenor_key: String,
}
194
/// Settings deserialized from the `api.workers` section.
#[derive(Deserialize, Debug, Clone)]
pub struct ApiWorkers {
    /// Upper bound on concurrent connections.
    pub max_concurrent_connections: usize,
}
199
/// Settings deserialized from the `api.livekit` section.
#[derive(Deserialize, Debug, Clone)]
pub struct ApiLiveKit {
    /// How long a call rings (unit not visible here — TODO confirm).
    pub call_ring_duration: usize,
    /// LiveKit nodes keyed by string (presumably the node name — confirm).
    pub nodes: HashMap<String, LiveKitNode>,
}
205
/// A single entry of `api.livekit.nodes`.
#[derive(Deserialize, Debug, Clone)]
pub struct LiveKitNode {
    /// Address of the node.
    pub url: String,
    /// Latitude of the node (presumably in degrees — confirm).
    pub lat: f64,
    /// Longitude of the node (presumably in degrees — confirm).
    pub lon: f64,
    /// API key for the node.
    pub key: String,
    /// API secret for the node.
    pub secret: String,

    // whether to hide the node in the nodes list
    // (defaults to `false` when the key is absent)
    #[serde(default)]
    pub private: bool,
}
218
/// Settings deserialized from the `api.users` section.
#[derive(Deserialize, Debug, Clone)]
pub struct ApiUsers {
    /// Optional cutoff for "early adopter" status (presumably a timestamp;
    /// unit not visible here — TODO confirm).
    pub early_adopter_cutoff: Option<u64>,
}
223
/// Settings deserialized from the `api` section of the configuration.
#[derive(Deserialize, Debug, Clone)]
pub struct Api {
    /// Account registration settings.
    pub registration: ApiRegistration,
    /// Outgoing email settings.
    pub smtp: ApiSmtp,
    /// Security-related keys and toggles.
    pub security: ApiSecurity,
    /// Worker limits.
    pub workers: ApiWorkers,
    /// LiveKit settings.
    pub livekit: ApiLiveKit,
    /// User-related settings.
    pub users: ApiUsers,
}
233
/// Settings deserialized from the `pushd` section of the configuration.
#[derive(Deserialize, Debug, Clone)]
pub struct Pushd {
    /// Selects the suffix appended to routing keys: `-prd` when `true`,
    /// `-tst` when `false`.
    pub production: bool,
    /// Exchange name.
    pub exchange: String,
    /// Chunk size used for mass mentions.
    pub mass_mention_chunk_size: usize,

    // Queues
    // (base names; the `get_*_routing_key` methods append the environment suffix)
    pub message_queue: String,
    pub mass_mention_queue: String,
    pub dm_call_queue: String,
    pub fr_accepted_queue: String,
    pub fr_received_queue: String,
    pub generic_queue: String,
    pub ack_queue: String,

    /// Settings from the `vapid` sub-section.
    pub vapid: PushVapid,
    /// Settings from the `fcm` sub-section.
    pub fcm: PushFcm,
    /// Settings from the `apn` sub-section.
    pub apn: PushApn,
}
253
254impl Pushd {
255    fn get_routing_key(&self, key: String) -> String {
256        match self.production {
257            true => key + "-prd",
258            false => key + "-tst",
259        }
260    }
261
262    pub fn get_ack_routing_key(&self) -> String {
263        self.get_routing_key(self.ack_queue.clone())
264    }
265
266    pub fn get_message_routing_key(&self) -> String {
267        self.get_routing_key(self.message_queue.clone())
268    }
269
270    pub fn get_mass_mention_routing_key(&self) -> String {
271        self.get_routing_key(self.mass_mention_queue.clone())
272    }
273
274    pub fn get_dm_call_routing_key(&self) -> String {
275        self.get_routing_key(self.dm_call_queue.clone())
276    }
277
278    pub fn get_fr_accepted_routing_key(&self) -> String {
279        self.get_routing_key(self.fr_accepted_queue.clone())
280    }
281
282    pub fn get_fr_received_routing_key(&self) -> String {
283        self.get_routing_key(self.fr_received_queue.clone())
284    }
285
286    pub fn get_generic_routing_key(&self) -> String {
287        self.get_routing_key(self.generic_queue.clone())
288    }
289}
290
/// Settings deserialized from the `files.limit` section.
#[derive(Deserialize, Debug, Clone)]
pub struct FilesLimit {
    /// Minimum accepted file size (presumably in bytes — confirm).
    pub min_file_size: usize,
    /// Minimum accepted resolution as a pair (presumably `[width, height]` —
    /// confirm).
    pub min_resolution: [usize; 2],
    /// Maximum accepted image size in megapixels.
    pub max_mega_pixels: usize,
    /// Maximum accepted length of a single image side, in pixels.
    pub max_pixel_side: usize,
}
298
/// Settings deserialized from the `files.s3` section.
#[derive(Deserialize, Debug, Clone)]
pub struct FilesS3 {
    /// Endpoint address.
    pub endpoint: String,
    /// Path-style bucket flag (presumably selects path-style over
    /// virtual-hosted bucket addressing — confirm).
    pub path_style_buckets: bool,
    /// Region name.
    pub region: String,
    /// Access key identifier.
    pub access_key_id: String,
    /// Secret access key.
    pub secret_access_key: String,
    /// Name of the default bucket.
    pub default_bucket: String,
}
308
/// Settings deserialized from the `files` section of the configuration.
#[derive(Deserialize, Debug, Clone)]
pub struct Files {
    /// Encryption key (encoding not visible here — TODO confirm).
    pub encryption_key: String,
    /// WebP quality setting.
    pub webp_quality: f32,
    /// MIME types listed as blocked.
    pub blocked_mime_types: Vec<String>,
    /// Address of the clamd host.
    pub clamd_host: String,
    /// MIME types listed for scanning.
    pub scan_mime_types: Vec<String>,

    /// Upload limits.
    pub limit: FilesLimit,
    /// Preview sizes keyed by string; each value is a pair (presumably
    /// `[width, height]` — confirm).
    pub preview: HashMap<String, [usize; 2]>,
    /// Object storage settings.
    pub s3: FilesS3,
}
321
/// Settings deserialized from the `features.limits.global` section.
///
/// How each limit is enforced is defined by the consumers, which are not
/// visible in this file.
#[derive(Deserialize, Debug, Clone)]
pub struct GlobalLimits {
    pub group_size: usize,
    pub message_embeds: usize,
    pub message_replies: usize,
    pub message_reactions: usize,
    pub server_emoji: usize,
    pub server_roles: usize,
    pub server_channels: usize,

    /// Number of hours associated with "new user" status (presumably how long
    /// an account counts as new — confirm).
    pub new_user_hours: usize,

    /// Body size limit (presumably in bytes — confirm).
    pub body_limit_size: usize,
}
336
/// One set of per-user limits, as used by the entries of
/// `FeaturesLimitsCollection`.
///
/// How each limit is enforced is defined by the consumers, which are not
/// visible in this file.
#[derive(Deserialize, Debug, Clone)]
pub struct FeaturesLimits {
    pub outgoing_friend_requests: usize,

    pub bots: usize,
    pub message_length: usize,
    pub message_attachments: usize,
    pub servers: usize,
    pub voice_quality: u32,
    pub video: bool,
    /// Video resolution as a pair (presumably `[width, height]` — confirm).
    pub video_resolution: [u32; 2],
    /// Video aspect ratio as a pair (presumably `[min, max]` — confirm).
    pub video_aspect_ratio: [f32; 2],

    /// Upload size limits keyed by string (presumably the upload tag, sizes in
    /// bytes — confirm).
    pub file_upload_size_limit: HashMap<String, usize>,
}
352
/// Settings deserialized from the `features.limits` section.
#[derive(Deserialize, Debug, Clone)]
pub struct FeaturesLimitsCollection {
    /// Limits from the `global` sub-section.
    pub global: GlobalLimits,

    /// Limits from the `new_user` sub-section.
    pub new_user: FeaturesLimits,
    /// Limits from the `default` sub-section.
    pub default: FeaturesLimits,

    /// Flattened map: the remaining sub-sections of `features.limits` are
    /// collected here, keyed by their name.
    #[serde(flatten)]
    pub roles: HashMap<String, FeaturesLimits>,
}
363
364#[derive(Deserialize, Debug, Clone)]
365pub struct FeaturesAdvanced {
366    #[serde(default)]
367    pub process_message_delay_limit: u16,
368}
369
370impl Default for FeaturesAdvanced {
371    fn default() -> Self {
372        Self {
373            process_message_delay_limit: 5,
374        }
375    }
376}
377
/// Settings deserialized from the `features` section of the configuration.
#[derive(Deserialize, Debug, Clone)]
pub struct Features {
    /// Global and per-user limits.
    pub limits: FeaturesLimitsCollection,
    /// Webhooks toggle.
    pub webhooks_enabled: bool,
    /// Toggle for sending notifications on mass mentions.
    pub mass_mentions_send_notifications: bool,
    /// Mass mentions toggle.
    pub mass_mentions_enabled: bool,

    /// Advanced settings; `FeaturesAdvanced::default()` is used when the
    /// `advanced` sub-section is absent.
    #[serde(default)]
    pub advanced: FeaturesAdvanced,
}
388
/// Settings deserialized from the `sentry` section: one DSN per service.
///
/// The `configure!` macro selects one of these fields by name and passes it to
/// `setup_logging`, which skips Sentry initialisation when the value is empty.
#[derive(Deserialize, Debug, Clone)]
pub struct Sentry {
    pub api: String,
    pub events: String,
    pub voice_ingress: String,
    pub files: String,
    pub proxy: String,
    pub pushd: String,
    pub crond: String,
    pub gifbox: String,
}
400
/// Root of the configuration, as returned by `config()`.
#[derive(Deserialize, Debug, Clone)]
pub struct Settings {
    /// The `database` section.
    pub database: Database,
    /// The `rabbit` section.
    pub rabbit: Rabbit,
    /// The `hosts` section.
    pub hosts: Hosts,
    /// The `api` section.
    pub api: Api,
    /// The `pushd` section.
    pub pushd: Pushd,
    /// The `files` section.
    pub files: Files,
    /// The `features` section.
    pub features: Features,
    /// The `sentry` section.
    pub sentry: Sentry,
    /// Production flag; `config()` additionally forces it to `true` when
    /// `hosts.api` contains both `https` and `revolt.chat`.
    pub production: bool,
}
413
414impl Settings {
415    pub fn preflight_checks(&self) {
416        if self.api.smtp.host.is_empty() {
417            log::warn!("No SMTP settings specified! Remember to configure email.");
418        }
419
420        if self.api.security.captcha.hcaptcha_key.is_empty() {
421            log::warn!("No Captcha key specified! Remember to add hCaptcha key.");
422        }
423    }
424}
425
426pub async fn init() {
427    println!(
428        ":: Revolt Configuration ::\n\x1b[32m{:?}\x1b[0m",
429        config().await
430    );
431}
432
433pub async fn read() -> Config {
434    CONFIG_BUILDER.read().await.clone()
435}
436
437#[cached(time = 30)]
438pub async fn config() -> Settings {
439    let mut config = read().await.try_deserialize::<Settings>().unwrap();
440
441    // inject REDIS_URI for redis-kiss library
442    if std::env::var("REDIS_URL").is_err() {
443        std::env::set_var("REDIS_URI", config.database.redis.clone());
444    }
445
446    // auto-detect production nodes
447    if config.hosts.api.contains("https") && config.hosts.api.contains("revolt.chat") {
448        config.production = true;
449    }
450
451    config
452}
453
/// Configure logging and common Rust variables
///
/// Sets `RUST_LOG` to `info` and `ROCKET_ADDRESS` to `0.0.0.0` when they are
/// not already present in the environment, initialises `pretty_env_logger`,
/// and logs the release name.
///
/// Returns `None` when `dsn` is empty. Otherwise Sentry is initialised with
/// `dsn` and `release`, and the resulting guard is returned; the caller has to
/// keep the guard for as long as the Sentry client should stay alive.
#[cfg(feature = "sentry")]
pub async fn setup_logging(release: &'static str, dsn: String) -> Option<sentry::ClientInitGuard> {
    // fill in defaults only; values already present in the environment win
    for (variable, fallback) in [("RUST_LOG", "info"), ("ROCKET_ADDRESS", "0.0.0.0")] {
        if std::env::var(variable).is_err() {
            std::env::set_var(variable, fallback);
        }
    }

    pretty_env_logger::init();
    log::info!("Starting {release}");

    if dsn.is_empty() {
        return None;
    }

    let options = sentry::ClientOptions {
        release: Some(release.into()),
        ..Default::default()
    };

    Some(sentry::init((dsn, options)))
}
480
/// Load the configuration and set up logging for one service.
///
/// `$service` names a field of the `sentry` settings; its value is passed to
/// `setup_logging` as the DSN, together with a release string built from the
/// `CARGO_PKG_NAME` and `CARGO_PKG_VERSION` of the crate being compiled.
///
/// The macro expands to statements and must be used inside an `async` context.
/// The guard returned by `setup_logging` is bound to a local in the caller's
/// scope, so it lives until that scope ends.
#[cfg(feature = "sentry")]
#[macro_export]
macro_rules! configure {
    ($service:ident) => {
        let settings = $crate::config().await;
        let dsn = settings.sentry.$service;
        let _guard = $crate::setup_logging(
            concat!(env!("CARGO_PKG_NAME"), "@", env!("CARGO_PKG_VERSION")),
            dsn,
        )
        .await;
    };
}
493
// Smoke test: only compiled for test builds with the `test` feature enabled.
#[cfg(all(feature = "test", test))]
mod tests {
    use super::init;

    /// Loading and printing the configuration must not panic.
    #[async_std::test]
    async fn it_works() {
        init().await;
    }
}
503}