Skip to main content

revolt_config/
lib.rs

1use std::{collections::HashMap, path::Path};
2
3use cached::proc_macro::cached;
4use config::{Config, Environment, File, FileFormat};
5use futures_locks::RwLock;
6use once_cell::sync::Lazy;
7use serde::Deserialize;
8
9#[cfg(feature = "sentry")]
10pub use sentry::{capture_error, capture_message, Level};
11#[cfg(feature = "anyhow")]
12pub use sentry_anyhow::capture_anyhow;
13
/// Report the error of a `Result` to Sentry and convert it into a `revolt_result` error.
///
/// `$expr` must evaluate to a `Result`. On `Err`, the error's `Debug` text, suffixed with
/// the file, line and column of the macro call, is sent via `capture_message` at
/// `Level::Error`; the error is then replaced by `revolt_result::create_error!($error)`.
/// `Ok` values pass through untouched. One optional trailing token after `$error` is
/// accepted and ignored.
#[cfg(all(feature = "report-macros", feature = "sentry"))]
#[macro_export]
macro_rules! report_error {
    ( $expr: expr, $error: ident $( $tt:tt )? ) => {
        $expr.map_err(|err| {
            let report = format!("{err:?} ({}:{}:{})", file!(), line!(), column!());
            $crate::capture_message(&report, $crate::Level::Error);
            ::revolt_result::create_error!($error)
        })
    };
}
28
/// Send the `Debug` form of `$expr` to Sentry at `Level::Error`.
///
/// The message is suffixed with the file, line and column of the macro call.
/// The body is wrapped in a block so the macro is a valid `()` expression: it can be used
/// as a statement (`capture_internal_error!(e);`) and also in expression position (e.g. a
/// `match` arm) without tripping the `semicolon_in_expressions_from_macros` lint.
#[cfg(all(feature = "report-macros", feature = "sentry"))]
#[macro_export]
macro_rules! capture_internal_error {
    ( $expr: expr ) => {{
        $crate::capture_message(
            &format!("{:?} ({}:{}:{})", $expr, file!(), line!(), column!()),
            $crate::Level::Error,
        );
    }};
}
39
/// Shorthand for `report_error!($expr, InternalError)`.
///
/// Reports the `Err` of `$expr` to Sentry and maps it to
/// `revolt_result::create_error!(InternalError)`. `file!`/`line!`/`column!` inside the
/// delegated macro still resolve to this macro's call site, so the reported location is
/// unchanged.
#[cfg(all(feature = "report-macros", feature = "sentry"))]
#[macro_export]
macro_rules! report_internal_error {
    ( $expr: expr ) => {
        $crate::report_error!($expr, InternalError)
    };
}
54
/// File names looked up, in this order, in the working directory and in each of its
/// ancestors while building the configuration.
///
/// An absolute entry (`/Revolt.toml`) resolves to the same file at every level, because
/// `Path::join` with an absolute path returns that path unchanged.
static CONFIG_SEARCH_PATHS: [&str; 3] = [
    "Revolt.toml",           // base file next to the process (or in an ancestor)
    "Revolt.overrides.toml", // local overrides for the base file
    "/Revolt.toml",          // root directory, for Docker containers
];
64
/// File name of the optional test-only overrides, searched for from the working
/// directory upwards when the `TEST_DB` environment variable is set.
static TEST_OVERRIDE_PATH: &str = "Revolt.test-overrides.toml";
67
68/// Configuration builder
69static CONFIG_BUILDER: Lazy<RwLock<Config>> = Lazy::new(|| {
70    RwLock::new({
71        let mut builder = Config::builder().add_source(File::from_str(
72            include_str!("../Revolt.toml"),
73            FileFormat::Toml,
74        ));
75
76        if std::env::var("TEST_DB").is_ok() {
77            builder = builder.add_source(File::from_str(
78                include_str!("../Revolt.test.toml"),
79                FileFormat::Toml,
80            ));
81
82            // recursively search upwards for an overrides file (if there is one)
83            if let Ok(cwd) = std::env::current_dir() {
84                let mut path = Some(cwd.as_path());
85                while let Some(current_path) = path {
86                    let target_path = current_path.join(TEST_OVERRIDE_PATH);
87                    if target_path.exists() {
88                        builder = builder
89                            .add_source(File::new(target_path.to_str().unwrap(), FileFormat::Toml));
90                    }
91
92                    path = current_path.parent();
93                }
94            }
95        }
96
97        let cwd = std::env::current_dir().unwrap();
98        let mut cwd: Option<&Path> = Some(&cwd);
99
100        while let Some(path) = cwd {
101            for config_path in CONFIG_SEARCH_PATHS {
102                let config_path = path.join(config_path);
103                if config_path.exists() {
104                    builder = builder
105                        .add_source(File::new(config_path.to_str().unwrap(), FileFormat::Toml));
106                }
107            }
108
109            cwd = path.parent();
110        }
111
112        builder = builder.add_source(Environment::with_prefix("REVOLT").separator("__"));
113
114        builder.build().unwrap()
115    })
116});
117
118#[derive(Deserialize, Debug, Clone)]
119pub struct Database {
120    pub mongodb: String,
121    pub redis: String,
122    pub redis_pubsub: Option<String>,
123}
124
125#[derive(Deserialize, Debug, Clone)]
126pub struct RabbitQueues {
127    pub acks: String,
128}
129
130#[derive(Deserialize, Debug, Clone)]
131pub struct Rabbit {
132    pub host: String,
133    pub port: u16,
134    pub username: String,
135    pub password: String,
136    pub default_exchange: String,
137    pub queues: RabbitQueues,
138}
139
140#[derive(Deserialize, Debug, Clone)]
141pub struct Hosts {
142    pub app: String,
143    pub api: String,
144    pub events: String,
145    pub autumn: String,
146    pub january: String,
147    pub livekit: HashMap<String, String>,
148}
149
150#[derive(Deserialize, Debug, Clone)]
151pub struct ApiRegistration {
152    pub invite_only: bool,
153}
154
155#[derive(Deserialize, Debug, Clone)]
156pub struct ApiSmtp {
157    pub host: String,
158    pub username: String,
159    pub password: String,
160    pub from_address: String,
161    pub reply_to: Option<String>,
162    pub port: Option<i32>,
163    pub use_tls: Option<bool>,
164    pub use_starttls: Option<bool>,
165}
166
167#[derive(Deserialize, Debug, Clone)]
168pub struct PushVapid {
169    pub queue: String,
170    pub private_key: String,
171    pub public_key: String,
172}
173
174#[derive(Deserialize, Debug, Clone)]
175pub struct PushFcm {
176    pub queue: String,
177    pub key_type: String,
178    pub project_id: String,
179    pub private_key_id: String,
180    pub private_key: String,
181    pub client_email: String,
182    pub client_id: String,
183    pub auth_uri: String,
184    pub token_uri: String,
185    pub auth_provider_x509_cert_url: String,
186    pub client_x509_cert_url: String,
187}
188
189#[derive(Deserialize, Debug, Clone)]
190pub struct PushApn {
191    pub queue: String,
192    pub sandbox: bool,
193    pub pkcs8: String,
194    pub key_id: String,
195    pub team_id: String,
196}
197
198#[derive(Deserialize, Debug, Clone)]
199pub struct ApiSecurityCaptcha {
200    pub hcaptcha_key: String,
201    pub hcaptcha_sitekey: String,
202}
203
204#[derive(Deserialize, Debug, Clone)]
205pub struct ApiSecurity {
206    pub authifier_shield_key: String,
207    pub voso_legacy_token: String,
208    pub captcha: ApiSecurityCaptcha,
209    pub trust_cloudflare: bool,
210    pub easypwned: String,
211    pub tenor_key: String,
212}
213
214#[derive(Deserialize, Debug, Clone)]
215pub struct ApiWorkers {
216    pub max_concurrent_connections: usize,
217}
218
219#[derive(Deserialize, Debug, Clone)]
220pub struct ApiLiveKit {
221    pub call_ring_duration: usize,
222    pub nodes: HashMap<String, LiveKitNode>,
223}
224
225#[derive(Deserialize, Debug, Clone)]
226pub struct LiveKitNode {
227    pub url: String,
228    pub lat: f64,
229    pub lon: f64,
230    pub key: String,
231    pub secret: String,
232
233    // whether to hide the node in the nodes list
234    #[serde(default)]
235    pub private: bool,
236}
237
238#[derive(Deserialize, Debug, Clone)]
239pub struct ApiUsers {
240    pub early_adopter_cutoff: Option<u64>,
241    pub min_username_length: usize,
242}
243
244#[derive(Deserialize, Debug, Clone)]
245pub struct Api {
246    pub registration: ApiRegistration,
247    pub smtp: ApiSmtp,
248    pub security: ApiSecurity,
249    pub workers: ApiWorkers,
250    pub livekit: ApiLiveKit,
251    pub users: ApiUsers,
252}
253
254#[derive(Deserialize, Debug, Clone)]
255pub struct Pushd {
256    pub production: bool,
257    pub exchange: String,
258    pub mass_mention_chunk_size: usize,
259    pub render_cache_time: usize,
260
261    // Queues
262    pub message_queue: String,
263    pub mass_mention_queue: String,
264    pub dm_call_queue: String,
265    pub fr_accepted_queue: String,
266    pub fr_received_queue: String,
267    pub generic_queue: String,
268    pub ack_queue: String,
269
270    pub vapid: PushVapid,
271    pub fcm: PushFcm,
272    pub apn: PushApn,
273}
274
275impl Pushd {
276    fn get_routing_key(&self, key: String) -> String {
277        match self.production {
278            true => key + "-prd",
279            false => key + "-tst",
280        }
281    }
282
283    pub fn get_ack_routing_key(&self) -> String {
284        self.get_routing_key(self.ack_queue.clone())
285    }
286
287    pub fn get_message_routing_key(&self) -> String {
288        self.get_routing_key(self.message_queue.clone())
289    }
290
291    pub fn get_mass_mention_routing_key(&self) -> String {
292        self.get_routing_key(self.mass_mention_queue.clone())
293    }
294
295    pub fn get_dm_call_routing_key(&self) -> String {
296        self.get_routing_key(self.dm_call_queue.clone())
297    }
298
299    pub fn get_fr_accepted_routing_key(&self) -> String {
300        self.get_routing_key(self.fr_accepted_queue.clone())
301    }
302
303    pub fn get_fr_received_routing_key(&self) -> String {
304        self.get_routing_key(self.fr_received_queue.clone())
305    }
306
307    pub fn get_generic_routing_key(&self) -> String {
308        self.get_routing_key(self.generic_queue.clone())
309    }
310}
311
312#[derive(Deserialize, Debug, Clone)]
313pub struct January {
314    pub blocked_domains: Vec<String>,
315}
316
317#[derive(Deserialize, Debug, Clone)]
318pub struct FilesLimit {
319    pub min_file_size: usize,
320    pub min_resolution: [usize; 2],
321    pub max_mega_pixels: usize,
322    pub max_pixel_side: usize,
323}
324
325#[derive(Deserialize, Debug, Clone)]
326pub struct FilesS3 {
327    pub endpoint: String,
328    pub path_style_buckets: bool,
329    pub region: String,
330    pub access_key_id: String,
331    pub secret_access_key: String,
332    pub default_bucket: String,
333}
334
335#[derive(Deserialize, Debug, Clone)]
336pub struct Files {
337    pub encryption_key: String,
338    pub webp_quality: f32,
339    pub blocked_mime_types: Vec<String>,
340    pub clamd_host: String,
341    pub scan_mime_types: Vec<String>,
342
343    pub limit: FilesLimit,
344    pub preview: HashMap<String, [usize; 2]>,
345    pub s3: FilesS3,
346}
347
348#[derive(Deserialize, Debug, Clone)]
349pub struct GlobalLimits {
350    pub group_size: usize,
351    pub message_embeds: usize,
352    pub message_replies: usize,
353    pub message_reactions: usize,
354    pub server_emoji: usize,
355    pub server_roles: usize,
356    pub server_channels: usize,
357
358    pub new_user_hours: usize,
359
360    pub body_limit_size: usize,
361
362    pub restrict_server_creation: Vec<String>,
363}
364
365#[derive(Deserialize, Debug, Clone)]
366pub struct FeaturesLimits {
367    pub outgoing_friend_requests: usize,
368
369    pub bots: usize,
370    pub message_length: usize,
371    pub message_attachments: usize,
372    pub servers: usize,
373    pub voice_quality: u32,
374    pub video: bool,
375    pub video_resolution: [u32; 2],
376    pub video_aspect_ratio: [f32; 2],
377
378    pub file_upload_size_limit: HashMap<String, usize>,
379}
380
381#[derive(Deserialize, Debug, Clone)]
382pub struct FeaturesLimitsCollection {
383    pub global: GlobalLimits,
384
385    pub new_user: FeaturesLimits,
386    pub default: FeaturesLimits,
387
388    #[serde(flatten)]
389    pub roles: HashMap<String, FeaturesLimits>,
390}
391
392#[derive(Deserialize, Debug, Clone)]
393pub struct LegalLinks {
394    /// Terms of Service URL
395    pub terms_of_service: String,
396    /// Privacy Policy URL
397    pub privacy_policy: String,
398    /// Guidelines URL
399    pub guidelines: String,
400}
401
402#[derive(Deserialize, Debug, Clone)]
403pub struct FeaturesAdvanced {
404    #[serde(default)]
405    pub process_message_delay_limit: u16,
406}
407
408impl Default for FeaturesAdvanced {
409    fn default() -> Self {
410        Self {
411            process_message_delay_limit: 5,
412        }
413    }
414}
415
416#[derive(Deserialize, Debug, Clone)]
417pub struct Features {
418    pub limits: FeaturesLimitsCollection,
419    pub legal_links: LegalLinks,
420    pub webhooks_enabled: bool,
421    pub mass_mentions_send_notifications: bool,
422    pub mass_mentions_enabled: bool,
423
424    #[serde(default)]
425    pub advanced: FeaturesAdvanced,
426}
427
428#[derive(Deserialize, Debug, Clone)]
429pub struct Sentry {
430    pub api: String,
431    pub events: String,
432    pub voice_ingress: String,
433    pub files: String,
434    pub proxy: String,
435    pub pushd: String,
436    pub crond: String,
437    pub gifbox: String,
438}
439
440#[derive(Deserialize, Debug, Clone)]
441pub struct Settings {
442    pub database: Database,
443    pub rabbit: Rabbit,
444    pub hosts: Hosts,
445    pub api: Api,
446    pub pushd: Pushd,
447    pub january: January,
448    pub files: Files,
449    pub features: Features,
450    pub sentry: Sentry,
451    pub production: bool,
452    pub disable_events_dont_use: bool,
453}
454
455impl Settings {
456    pub fn preflight_checks(&self) {
457        if self.api.smtp.host.is_empty() {
458            log::warn!("No SMTP settings specified! Remember to configure email.");
459        }
460
461        if self.api.security.captcha.hcaptcha_key.is_empty() {
462            log::warn!("No Captcha key specified! Remember to add hCaptcha key.");
463        }
464    }
465}
466
467pub async fn init() {
468    println!(
469        ":: Revolt Configuration ::\n\x1b[32m{:?}\x1b[0m",
470        config().await
471    );
472}
473
474pub async fn read() -> Config {
475    CONFIG_BUILDER.read().await.clone()
476}
477
478#[cached(time = 30)]
479pub async fn config() -> Settings {
480    let mut config = read().await.try_deserialize::<Settings>().unwrap();
481
482    // inject REDIS_URI for redis-kiss library
483    if std::env::var("REDIS_URI").is_err() {
484        std::env::set_var("REDIS_URI", config.database.redis.clone());
485    }
486
487    // auto-detect production nodes
488    if config.hosts.api.contains("https")
489        && (config.hosts.api.contains("revolt.chat") || config.hosts.api.contains("stoat.chat"))
490    {
491        config.production = true;
492    }
493
494    config
495}
496
/// Configure logging and common Rust variables
///
/// - Sets `RUST_LOG=info` and `ROCKET_ADDRESS=0.0.0.0`, each only if not already set.
/// - Initialises `pretty_env_logger` and logs `Starting {release}`.
/// - When `dsn` is non-empty, initialises Sentry with `release` as the release name and
///   returns its guard; an empty `dsn` returns `None` and leaves Sentry uninitialised.
#[cfg(feature = "sentry")]
pub async fn setup_logging(release: &'static str, dsn: String) -> Option<sentry::ClientInitGuard> {
    // Defaults applied only when the environment does not already provide a value.
    for (key, fallback) in [("RUST_LOG", "info"), ("ROCKET_ADDRESS", "0.0.0.0")] {
        if std::env::var(key).is_err() {
            std::env::set_var(key, fallback);
        }
    }

    pretty_env_logger::init();
    log::info!("Starting {release}");

    (!dsn.is_empty()).then(|| {
        sentry::init((
            dsn,
            sentry::ClientOptions {
                release: Some(release.into()),
                ..Default::default()
            },
        ))
    })
}
523
/// Load the configuration and set up logging/Sentry for `$application`.
///
/// `$application` must name a field of [`Sentry`] (e.g. `api`). That field's value is
/// passed as the DSN to [`setup_logging`], with `"<CARGO_PKG_NAME>@<CARGO_PKG_VERSION>"`
/// as the release. Expands to `let` statements, so it must be used in statement
/// position inside an async context. The returned guard stays bound until the
/// caller's scope ends.
#[cfg(feature = "sentry")]
#[macro_export]
macro_rules! configure {
    ($application: ident) => {
        let settings = $crate::config().await;
        let release = concat!(env!("CARGO_PKG_NAME"), "@", env!("CARGO_PKG_VERSION"));
        let _sentry = $crate::setup_logging(release, settings.sentry.$application).await;
    };
}
536
#[cfg(all(test, feature = "test"))]
mod tests {
    use crate::init;

    /// Smoke test: loading and printing the configuration must not panic.
    #[async_std::test]
    async fn it_works() {
        init().await;
    }
}