// revolt_config/lib.rs

use std::collections::HashMap;
use std::path::{Path, PathBuf};

use cached::proc_macro::cached;
use config::{Config, File, FileFormat};
use futures_locks::RwLock;
use once_cell::sync::Lazy;
use serde::Deserialize;

pub use sentry::{capture_error, capture_message, Level};
pub use sentry_anyhow::capture_anyhow;
11
/// Evaluate a `Result` expression, report any error to Sentry, and replace the
/// error with a `revolt_result` error of the given type.
///
/// The original error's `Debug` output, tagged with the invocation's
/// `file:line:column`, is sent via `capture_message` at `Level::Error`.
///
/// The optional token tree following the error identifier is forwarded to
/// `create_error!` unchanged. This lets error variants that carry fields be
/// constructed; previously those tokens were accepted but silently dropped.
#[cfg(feature = "report-macros")]
#[macro_export]
macro_rules! report_error {
    ( $expr: expr, $error: ident $( $tt:tt )? ) => {
        $expr
            .inspect_err(|err| {
                $crate::capture_message(
                    &format!("{err:?} ({}:{}:{})", file!(), line!(), column!()),
                    $crate::Level::Error,
                );
            })
            .map_err(|_| ::revolt_result::create_error!($error $( $tt )?))
    };
}
26
/// Send the `Debug` rendering of a value to Sentry at `Level::Error`, tagged with
/// the `file:line:column` of the macro invocation.
///
/// Unlike `report_error!`, this produces no value; use it in statement position.
#[cfg(feature = "report-macros")]
#[macro_export]
macro_rules! capture_internal_error {
    ( $expr: expr ) => {
        $crate::capture_message(
            &format!(
                "{value:?} ({file}:{line}:{column})",
                value = $expr,
                file = file!(),
                line = line!(),
                column = column!(),
            ),
            $crate::Level::Error,
        );
    };
}
37
/// Shorthand for `report_error!(expr, InternalError)`.
///
/// Reports the error to Sentry (with the invocation's `file:line:column`) and
/// replaces it with a `revolt_result` `InternalError`.
#[cfg(feature = "report-macros")]
#[macro_export]
macro_rules! report_internal_error {
    ( $expr: expr ) => {
        $crate::report_error!($expr, InternalError)
    };
}
52
/// Configuration files added (when they exist) after the built-in defaults.
///
/// They are added in this order, so a later file overrides keys from earlier ones.
static CONFIG_SEARCH_PATHS: [&str; 3] = [
    "Revolt.toml",           // working directory
    "Revolt.overrides.toml", // working directory, local overrides
    "/Revolt.toml",          // filesystem root, for Docker containers
];

/// Name of the test-only overrides file, looked up from the working directory
/// upwards when `TEST_DB` is set.
static TEST_OVERRIDE_PATH: &str = "Revolt.test-overrides.toml";
65
66/// Configuration builder
67static CONFIG_BUILDER: Lazy<RwLock<Config>> = Lazy::new(|| {
68    RwLock::new({
69        let mut builder = Config::builder().add_source(File::from_str(
70            include_str!("../Revolt.toml"),
71            FileFormat::Toml,
72        ));
73
74        if std::env::var("TEST_DB").is_ok() {
75            builder = builder.add_source(File::from_str(
76                include_str!("../Revolt.test.toml"),
77                FileFormat::Toml,
78            ));
79
80            // recursively search upwards for an overrides file (if there is one)
81            if let Ok(cwd) = std::env::current_dir() {
82                let mut path = Some(cwd.as_path());
83                while let Some(current_path) = path {
84                    let target_path = current_path.join(TEST_OVERRIDE_PATH);
85                    if target_path.exists() {
86                        builder = builder
87                            .add_source(File::new(target_path.to_str().unwrap(), FileFormat::Toml));
88                    }
89
90                    path = current_path.parent();
91                }
92            }
93        }
94
95        for path in CONFIG_SEARCH_PATHS {
96            if std::path::Path::new(path).exists() {
97                builder = builder.add_source(File::new(path, FileFormat::Toml));
98            }
99        }
100
101        builder.build().unwrap()
102    })
103});
104
105#[derive(Deserialize, Debug, Clone)]
106pub struct Database {
107    pub mongodb: String,
108    pub redis: String,
109}
110
111#[derive(Deserialize, Debug, Clone)]
112pub struct Rabbit {
113    pub host: String,
114    pub port: u16,
115    pub username: String,
116    pub password: String,
117}
118
119#[derive(Deserialize, Debug, Clone)]
120pub struct Hosts {
121    pub app: String,
122    pub api: String,
123    pub events: String,
124    pub autumn: String,
125    pub january: String,
126    pub voso_legacy: String,
127    pub voso_legacy_ws: String,
128}
129
130#[derive(Deserialize, Debug, Clone)]
131pub struct ApiRegistration {
132    pub invite_only: bool,
133}
134
135#[derive(Deserialize, Debug, Clone)]
136pub struct ApiSmtp {
137    pub host: String,
138    pub username: String,
139    pub password: String,
140    pub from_address: String,
141    pub reply_to: Option<String>,
142    pub port: Option<i32>,
143    pub use_tls: Option<bool>,
144    pub use_starttls: Option<bool>,
145}
146
147#[derive(Deserialize, Debug, Clone)]
148pub struct PushVapid {
149    pub queue: String,
150    pub private_key: String,
151    pub public_key: String,
152}
153
154#[derive(Deserialize, Debug, Clone)]
155pub struct PushFcm {
156    pub queue: String,
157    pub key_type: String,
158    pub project_id: String,
159    pub private_key_id: String,
160    pub private_key: String,
161    pub client_email: String,
162    pub client_id: String,
163    pub auth_uri: String,
164    pub token_uri: String,
165    pub auth_provider_x509_cert_url: String,
166    pub client_x509_cert_url: String,
167}
168
169#[derive(Deserialize, Debug, Clone)]
170pub struct PushApn {
171    pub queue: String,
172    pub sandbox: bool,
173    pub pkcs8: String,
174    pub key_id: String,
175    pub team_id: String,
176}
177
178#[derive(Deserialize, Debug, Clone)]
179pub struct ApiSecurityCaptcha {
180    pub hcaptcha_key: String,
181    pub hcaptcha_sitekey: String,
182}
183
184#[derive(Deserialize, Debug, Clone)]
185pub struct ApiSecurity {
186    pub authifier_shield_key: String,
187    pub voso_legacy_token: String,
188    pub captcha: ApiSecurityCaptcha,
189    pub trust_cloudflare: bool,
190    pub easypwned: String,
191}
192
193#[derive(Deserialize, Debug, Clone)]
194pub struct ApiWorkers {
195    pub max_concurrent_connections: usize,
196}
197
198#[derive(Deserialize, Debug, Clone)]
199pub struct ApiUsers {
200    pub early_adopter_cutoff: Option<u64>,
201}
202
203#[derive(Deserialize, Debug, Clone)]
204pub struct Api {
205    pub registration: ApiRegistration,
206    pub smtp: ApiSmtp,
207    pub security: ApiSecurity,
208    pub workers: ApiWorkers,
209    pub users: ApiUsers,
210}
211
212#[derive(Deserialize, Debug, Clone)]
213pub struct Pushd {
214    pub production: bool,
215    pub exchange: String,
216    pub mass_mention_chunk_size: usize,
217
218    // Queues
219    pub message_queue: String,
220    pub mass_mention_queue: String,
221    pub fr_accepted_queue: String,
222    pub fr_received_queue: String,
223    pub generic_queue: String,
224    pub ack_queue: String,
225
226    pub vapid: PushVapid,
227    pub fcm: PushFcm,
228    pub apn: PushApn,
229}
230
231impl Pushd {
232    fn get_routing_key(&self, key: String) -> String {
233        match self.production {
234            true => key + "-prd",
235            false => key + "-tst",
236        }
237    }
238
239    pub fn get_ack_routing_key(&self) -> String {
240        self.get_routing_key(self.ack_queue.clone())
241    }
242
243    pub fn get_message_routing_key(&self) -> String {
244        self.get_routing_key(self.message_queue.clone())
245    }
246
247    pub fn get_mass_mention_routing_key(&self) -> String {
248        self.get_routing_key(self.mass_mention_queue.clone())
249    }
250
251    pub fn get_fr_accepted_routing_key(&self) -> String {
252        self.get_routing_key(self.fr_accepted_queue.clone())
253    }
254
255    pub fn get_fr_received_routing_key(&self) -> String {
256        self.get_routing_key(self.fr_received_queue.clone())
257    }
258
259    pub fn get_generic_routing_key(&self) -> String {
260        self.get_routing_key(self.generic_queue.clone())
261    }
262}
263
264#[derive(Deserialize, Debug, Clone)]
265pub struct FilesLimit {
266    pub min_file_size: usize,
267    pub min_resolution: [usize; 2],
268    pub max_mega_pixels: usize,
269    pub max_pixel_side: usize,
270}
271
272#[derive(Deserialize, Debug, Clone)]
273pub struct FilesS3 {
274    pub endpoint: String,
275    pub path_style_buckets: bool,
276    pub region: String,
277    pub access_key_id: String,
278    pub secret_access_key: String,
279    pub default_bucket: String,
280}
281
282#[derive(Deserialize, Debug, Clone)]
283pub struct Files {
284    pub encryption_key: String,
285    pub webp_quality: f32,
286    pub blocked_mime_types: Vec<String>,
287    pub clamd_host: String,
288    pub scan_mime_types: Vec<String>,
289
290    pub limit: FilesLimit,
291    pub preview: HashMap<String, [usize; 2]>,
292    pub s3: FilesS3,
293}
294
295#[derive(Deserialize, Debug, Clone)]
296pub struct GlobalLimits {
297    pub group_size: usize,
298    pub message_embeds: usize,
299    pub message_replies: usize,
300    pub message_reactions: usize,
301    pub server_emoji: usize,
302    pub server_roles: usize,
303    pub server_channels: usize,
304
305    pub new_user_hours: usize,
306
307    pub body_limit_size: usize,
308}
309
310#[derive(Deserialize, Debug, Clone)]
311pub struct FeaturesLimits {
312    pub outgoing_friend_requests: usize,
313
314    pub bots: usize,
315    pub message_length: usize,
316    pub message_attachments: usize,
317    pub servers: usize,
318
319    pub file_upload_size_limit: HashMap<String, usize>,
320}
321
322#[derive(Deserialize, Debug, Clone)]
323pub struct FeaturesLimitsCollection {
324    pub global: GlobalLimits,
325
326    pub new_user: FeaturesLimits,
327    pub default: FeaturesLimits,
328
329    #[serde(flatten)]
330    pub roles: HashMap<String, FeaturesLimits>,
331}
332
333#[derive(Deserialize, Debug, Clone)]
334pub struct FeaturesAdvanced {
335    #[serde(default)]
336    pub process_message_delay_limit: u16,
337}
338
339impl Default for FeaturesAdvanced {
340    fn default() -> Self {
341        Self {
342            process_message_delay_limit: 5,
343        }
344    }
345}
346
347#[derive(Deserialize, Debug, Clone)]
348pub struct Features {
349    pub limits: FeaturesLimitsCollection,
350    pub webhooks_enabled: bool,
351    pub mass_mentions_send_notifications: bool,
352    pub mass_mentions_enabled: bool,
353
354    #[serde(default)]
355    pub advanced: FeaturesAdvanced,
356}
357
358#[derive(Deserialize, Debug, Clone)]
359pub struct Sentry {
360    pub api: String,
361    pub events: String,
362    pub files: String,
363    pub proxy: String,
364    pub pushd: String,
365    pub crond: String,
366}
367
368#[derive(Deserialize, Debug, Clone)]
369pub struct Settings {
370    pub database: Database,
371    pub rabbit: Rabbit,
372    pub hosts: Hosts,
373    pub api: Api,
374    pub pushd: Pushd,
375    pub files: Files,
376    pub features: Features,
377    pub sentry: Sentry,
378    pub production: bool,
379}
380
381impl Settings {
382    pub fn preflight_checks(&self) {
383        if self.api.smtp.host.is_empty() {
384            log::warn!("No SMTP settings specified! Remember to configure email.");
385        }
386
387        if self.api.security.captcha.hcaptcha_key.is_empty() {
388            log::warn!("No Captcha key specified! Remember to add hCaptcha key.");
389        }
390    }
391}
392
393pub async fn init() {
394    println!(
395        ":: Revolt Configuration ::\n\x1b[32m{:?}\x1b[0m",
396        config().await
397    );
398}
399
400pub async fn read() -> Config {
401    CONFIG_BUILDER.read().await.clone()
402}
403
404#[cached(time = 30)]
405pub async fn config() -> Settings {
406    let mut config = read().await.try_deserialize::<Settings>().unwrap();
407
408    // inject REDIS_URI for redis-kiss library
409    if std::env::var("REDIS_URL").is_err() {
410        std::env::set_var("REDIS_URI", config.database.redis.clone());
411    }
412
413    // auto-detect production nodes
414    if config.hosts.api.contains("https") && config.hosts.api.contains("revolt.chat") {
415        config.production = true;
416    }
417
418    config
419}
420
421/// Configure logging and common Rust variables
422pub async fn setup_logging(release: &'static str, dsn: String) -> Option<sentry::ClientInitGuard> {
423    if std::env::var("RUST_LOG").is_err() {
424        std::env::set_var("RUST_LOG", "info");
425    }
426
427    if std::env::var("ROCKET_ADDRESS").is_err() {
428        std::env::set_var("ROCKET_ADDRESS", "0.0.0.0");
429    }
430
431    pretty_env_logger::init();
432    log::info!("Starting {release}");
433
434    if dsn.is_empty() {
435        None
436    } else {
437        Some(sentry::init((
438            dsn,
439            sentry::ClientOptions {
440                release: Some(release.into()),
441                ..Default::default()
442            },
443        )))
444    }
445}
446
/// Load the configuration and initialize logging and Sentry for one service.
///
/// `$application` names a field of `Sentry` (e.g. `api`) whose DSN is used. The
/// release string is `CARGO_PKG_NAME@CARGO_PKG_VERSION` of the invoking crate.
/// The Sentry guard is bound in the caller's scope, so it lives until that scope
/// ends. Must be invoked inside an `async` context.
#[macro_export]
macro_rules! configure {
    ($application: ident) => {
        let _sentry = $crate::setup_logging(
            concat!(env!("CARGO_PKG_NAME"), "@", env!("CARGO_PKG_VERSION")),
            $crate::config().await.sentry.$application,
        )
        .await;
    };
}
458
#[cfg(all(test, feature = "test"))]
mod tests {
    use super::init;

    /// Smoke test: loading, deserializing and printing the configuration must not panic.
    #[async_std::test]
    async fn it_works() {
        init().await;
    }
}