1use std::collections::HashMap;
2
3use cached::proc_macro::cached;
4use config::{Config, File, FileFormat};
5use futures_locks::RwLock;
6use once_cell::sync::Lazy;
7use serde::Deserialize;
8
9pub use sentry::{capture_error, capture_message, Level};
10pub use sentry_anyhow::capture_anyhow;
11
/// Convert the error of a `Result` into a `revolt_result` error of kind `$error`,
/// first reporting the original error to Sentry.
///
/// The reported message is the original error's `Debug` output followed by the
/// invocation site as `(file:line:column)`, sent at `Level::Error`. An optional
/// single trailing token tree after the error kind is accepted by the pattern but
/// not used in the expansion.
#[cfg(feature = "report-macros")]
#[macro_export]
macro_rules! report_error {
    ( $expr: expr, $error: ident $( $tt:tt )? ) => {
        $expr.map_err(|err| {
            $crate::capture_message(
                &format!("{err:?} ({}:{}:{})", file!(), line!(), column!()),
                $crate::Level::Error,
            );
            ::revolt_result::create_error!($error)
        })
    };
}
26
/// Report a value's `Debug` representation to Sentry at `Level::Error`,
/// tagged with the invocation site as `(file:line:column)`.
///
/// Unlike `report_error!` / `report_internal_error!`, this only reports: it does
/// not transform or return anything. `$expr` is evaluated once and only borrowed
/// by `format!`. The expansion ends with `;`, so invoke it in statement position.
#[cfg(feature = "report-macros")]
#[macro_export]
macro_rules! capture_internal_error {
    ( $expr: expr ) => {
        $crate::capture_message(
            &format!("{:?} ({}:{}:{})", $expr, file!(), line!(), column!()),
            $crate::Level::Error,
        );
    };
}
37
/// Convert the error of a `Result` into an `InternalError`, first reporting the
/// original error (with `file:line:column` of the call site) to Sentry.
///
/// This is exactly `report_error!($expr, InternalError)`. It delegates instead
/// of duplicating the expansion so the two macros cannot drift apart.
/// `file!()`, `line!()` and `column!()` inside a nested macro expansion resolve to
/// the outermost invocation, so the reported location is still the user's call site.
#[cfg(feature = "report-macros")]
#[macro_export]
macro_rules! report_internal_error {
    ( $expr: expr ) => {
        $crate::report_error!($expr, InternalError)
    };
}
52
/// On-disk configuration files, checked in this order after the bundled defaults.
///
/// Each entry that exists (`Path::exists`) is added to the builder as a TOML
/// source by `CONFIG_BUILDER`. Relative entries resolve against the process's
/// current working directory.
///
/// NOTE(review): `/Revolt.toml` is added after `Revolt.overrides.toml`. Because
/// later `config` sources override earlier ones, values in `/Revolt.toml` win over
/// the overrides file. Confirm this ordering is intended.
static CONFIG_SEARCH_PATHS: [&str; 3] = [
    "Revolt.toml",
    "Revolt.overrides.toml",
    "/Revolt.toml",
];

/// File name looked up in the current directory and each of its ancestors when
/// the `TEST_DB` environment variable is set. Every match is added as an extra
/// TOML source.
static TEST_OVERRIDE_PATH: &str = "Revolt.test-overrides.toml";
65
66static CONFIG_BUILDER: Lazy<RwLock<Config>> = Lazy::new(|| {
68 RwLock::new({
69 let mut builder = Config::builder().add_source(File::from_str(
70 include_str!("../Revolt.toml"),
71 FileFormat::Toml,
72 ));
73
74 if std::env::var("TEST_DB").is_ok() {
75 builder = builder.add_source(File::from_str(
76 include_str!("../Revolt.test.toml"),
77 FileFormat::Toml,
78 ));
79
80 if let Ok(cwd) = std::env::current_dir() {
82 let mut path = Some(cwd.as_path());
83 while let Some(current_path) = path {
84 let target_path = current_path.join(TEST_OVERRIDE_PATH);
85 if target_path.exists() {
86 builder = builder
87 .add_source(File::new(target_path.to_str().unwrap(), FileFormat::Toml));
88 }
89
90 path = current_path.parent();
91 }
92 }
93 }
94
95 for path in CONFIG_SEARCH_PATHS {
96 if std::path::Path::new(path).exists() {
97 builder = builder.add_source(File::new(path, FileFormat::Toml));
98 }
99 }
100
101 builder.build().unwrap()
102 })
103});
104
/// Database connection settings (the `[database]` table).
#[derive(Deserialize, Debug, Clone)]
pub struct Database {
    /// MongoDB connection string (presumably a `mongodb://` URI; TODO confirm).
    pub mongodb: String,
    /// Redis connection URI; `config()` may export it as the `REDIS_URI`
    /// environment variable.
    pub redis: String,
}
110
/// Message-broker connection settings (the `[rabbit]` table), presumably for
/// RabbitMQ. Confirm against the consumer.
#[derive(Deserialize, Debug, Clone)]
pub struct Rabbit {
    /// Broker host name.
    pub host: String,
    /// Broker port.
    pub port: u16,
    /// Login user name.
    pub username: String,
    /// Login password.
    pub password: String,
}
118
/// Public addresses of each service (the `[hosts]` table).
///
/// NOTE(review): values are presumably base URLs; only `api` is inspected in
/// this file. `config()` checks it for `https` and `revolt.chat` to auto-detect
/// production.
#[derive(Deserialize, Debug, Clone)]
pub struct Hosts {
    /// Web client.
    pub app: String,
    /// REST API.
    pub api: String,
    /// Events (websocket) service.
    pub events: String,
    /// Autumn (file server).
    pub autumn: String,
    /// January (embed/proxy service).
    pub january: String,
    /// Legacy voice server (HTTP).
    pub voso_legacy: String,
    /// Legacy voice server (websocket).
    pub voso_legacy_ws: String,
}
129
/// Account registration settings (the `[api.registration]` table).
#[derive(Deserialize, Debug, Clone)]
pub struct ApiRegistration {
    /// Presumably: when `true`, signing up requires an invite. Confirm against
    /// the registration handler.
    pub invite_only: bool,
}
134
/// Outgoing mail (SMTP) settings (the `[api.smtp]` table).
///
/// An empty `host` makes `Settings::preflight_checks` log a warning.
#[derive(Deserialize, Debug, Clone)]
pub struct ApiSmtp {
    /// SMTP server host; empty means email is not configured.
    pub host: String,
    /// SMTP login user name.
    pub username: String,
    /// SMTP login password.
    pub password: String,
    /// Sender address for outgoing mail.
    pub from_address: String,
    /// Optional `Reply-To` address; `None` when absent from the config.
    pub reply_to: Option<String>,
    /// Optional server port; `None` when absent from the config.
    pub port: Option<i32>,
    /// Optional implicit-TLS toggle; `None` when absent from the config.
    pub use_tls: Option<bool>,
    /// Optional STARTTLS toggle; `None` when absent from the config.
    pub use_starttls: Option<bool>,
}
146
/// Web Push (VAPID) settings (the `[pushd.vapid]` table).
#[derive(Deserialize, Debug, Clone)]
pub struct PushVapid {
    /// Queue name for Web Push deliveries (presumably; confirm against pushd).
    pub queue: String,
    /// VAPID private key.
    pub private_key: String,
    /// VAPID public key.
    pub public_key: String,
}
153
/// Firebase Cloud Messaging settings (the `[pushd.fcm]` table).
///
/// NOTE(review): apart from `queue`, the field names appear to mirror a Google
/// service-account key file (`key_type` presumably being its `type` entry).
/// Confirm against the FCM client.
#[derive(Deserialize, Debug, Clone)]
pub struct PushFcm {
    /// Queue name for FCM deliveries (presumably; confirm against pushd).
    pub queue: String,
    pub key_type: String,
    pub project_id: String,
    pub private_key_id: String,
    pub private_key: String,
    pub client_email: String,
    pub client_id: String,
    pub auth_uri: String,
    pub token_uri: String,
    pub auth_provider_x509_cert_url: String,
    pub client_x509_cert_url: String,
}
168
/// Apple Push Notification settings (the `[pushd.apn]` table).
#[derive(Deserialize, Debug, Clone)]
pub struct PushApn {
    /// Queue name for APN deliveries (presumably; confirm against pushd).
    pub queue: String,
    /// Presumably selects Apple's sandbox environment instead of production.
    pub sandbox: bool,
    /// Signing key, presumably PKCS#8-encoded as the name suggests.
    pub pkcs8: String,
    /// Key identifier.
    pub key_id: String,
    /// Apple developer team identifier.
    pub team_id: String,
}
177
/// hCaptcha credentials (the `[api.security.captcha]` table).
#[derive(Deserialize, Debug, Clone)]
pub struct ApiSecurityCaptcha {
    /// hCaptcha secret key. When empty, `Settings::preflight_checks` logs a warning.
    pub hcaptcha_key: String,
    /// hCaptcha site key.
    pub hcaptcha_sitekey: String,
}
183
/// API security settings (the `[api.security]` table).
#[derive(Deserialize, Debug, Clone)]
pub struct ApiSecurity {
    /// Key for the authifier "shield" integration (semantics defined by authifier).
    pub authifier_shield_key: String,
    /// Token shared with the legacy voice server.
    pub voso_legacy_token: String,
    /// hCaptcha credentials.
    pub captcha: ApiSecurityCaptcha,
    /// Presumably: whether to trust Cloudflare-provided client headers. TODO confirm.
    pub trust_cloudflare: bool,
    /// Presumably the address of an easypwned (breached-password) service. TODO confirm.
    pub easypwned: String,
}
192
/// API worker settings (the `[api.workers]` table).
#[derive(Deserialize, Debug, Clone)]
pub struct ApiWorkers {
    /// Upper bound on concurrent connections (scope defined by the consumer).
    pub max_concurrent_connections: usize,
}
197
/// User-related API settings (the `[api.users]` table).
#[derive(Deserialize, Debug, Clone)]
pub struct ApiUsers {
    /// Optional early-adopter cutoff; `None` when absent. Presumably a timestamp;
    /// confirm the unit against its consumer.
    pub early_adopter_cutoff: Option<u64>,
}
202
/// All API settings (the `[api]` table and its sub-tables).
#[derive(Deserialize, Debug, Clone)]
pub struct Api {
    /// `[api.registration]`
    pub registration: ApiRegistration,
    /// `[api.smtp]`
    pub smtp: ApiSmtp,
    /// `[api.security]`
    pub security: ApiSecurity,
    /// `[api.workers]`
    pub workers: ApiWorkers,
    /// `[api.users]`
    pub users: ApiUsers,
}
211
/// Push-notification daemon settings (the `[pushd]` table).
#[derive(Deserialize, Debug, Clone)]
pub struct Pushd {
    /// Selects the routing-key suffix: `-prd` when `true`, `-tst` otherwise.
    pub production: bool,
    /// Exchange name (presumably on the broker configured in `[rabbit]`).
    pub exchange: String,
    /// Chunk size used when fanning out mass mentions.
    pub mass_mention_chunk_size: usize,

    // Base queue names; use the matching `get_*_routing_key` methods to obtain
    // the environment-suffixed routing keys.
    pub message_queue: String,
    pub mass_mention_queue: String,
    pub fr_accepted_queue: String,
    pub fr_received_queue: String,
    pub generic_queue: String,
    pub ack_queue: String,

    /// `[pushd.vapid]`
    pub vapid: PushVapid,
    /// `[pushd.fcm]`
    pub fcm: PushFcm,
    /// `[pushd.apn]`
    pub apn: PushApn,
}
230
231impl Pushd {
232 fn get_routing_key(&self, key: String) -> String {
233 match self.production {
234 true => key + "-prd",
235 false => key + "-tst",
236 }
237 }
238
239 pub fn get_ack_routing_key(&self) -> String {
240 self.get_routing_key(self.ack_queue.clone())
241 }
242
243 pub fn get_message_routing_key(&self) -> String {
244 self.get_routing_key(self.message_queue.clone())
245 }
246
247 pub fn get_mass_mention_routing_key(&self) -> String {
248 self.get_routing_key(self.mass_mention_queue.clone())
249 }
250
251 pub fn get_fr_accepted_routing_key(&self) -> String {
252 self.get_routing_key(self.fr_accepted_queue.clone())
253 }
254
255 pub fn get_fr_received_routing_key(&self) -> String {
256 self.get_routing_key(self.fr_received_queue.clone())
257 }
258
259 pub fn get_generic_routing_key(&self) -> String {
260 self.get_routing_key(self.generic_queue.clone())
261 }
262}
263
/// Upload/processing limits for files (the `[files.limit]` table).
///
/// NOTE(review): descriptions are inferred from field names; confirm units
/// against the file service.
#[derive(Deserialize, Debug, Clone)]
pub struct FilesLimit {
    /// Minimum file size.
    pub min_file_size: usize,
    /// Minimum resolution, presumably `[width, height]`.
    pub min_resolution: [usize; 2],
    /// Maximum image size in megapixels.
    pub max_mega_pixels: usize,
    /// Maximum length of any image side, presumably in pixels.
    pub max_pixel_side: usize,
}
271
/// S3-compatible object storage settings (the `[files.s3]` table).
#[derive(Deserialize, Debug, Clone)]
pub struct FilesS3 {
    /// Storage endpoint.
    pub endpoint: String,
    /// Presumably selects path-style bucket addressing instead of virtual-host style.
    pub path_style_buckets: bool,
    /// Storage region.
    pub region: String,
    /// Access key ID.
    pub access_key_id: String,
    /// Secret access key.
    pub secret_access_key: String,
    /// Bucket used when no other bucket is specified.
    pub default_bucket: String,
}
281
/// File service settings (the `[files]` table).
#[derive(Deserialize, Debug, Clone)]
pub struct Files {
    /// Key used for file encryption (format defined by the file service).
    pub encryption_key: String,
    /// Quality setting for WebP encoding.
    pub webp_quality: f32,
    /// MIME types that are rejected.
    pub blocked_mime_types: Vec<String>,
    /// Presumably the address of a ClamAV `clamd` daemon used for scanning.
    pub clamd_host: String,
    /// MIME types that are scanned (presumably via `clamd_host`).
    pub scan_mime_types: Vec<String>,

    /// `[files.limit]`
    pub limit: FilesLimit,
    /// Preview dimensions keyed by name, presumably `[width, height]` per tag.
    pub preview: HashMap<String, [usize; 2]>,
    /// `[files.s3]`
    pub s3: FilesS3,
}
294
/// Limits that exist once for everyone (the `[features.limits.global]` table).
///
/// NOTE(review): descriptions are inferred from field names; confirm against the
/// code enforcing each limit.
#[derive(Deserialize, Debug, Clone)]
pub struct GlobalLimits {
    /// Maximum group size.
    pub group_size: usize,
    /// Maximum embeds per message.
    pub message_embeds: usize,
    /// Maximum replies per message.
    pub message_replies: usize,
    /// Maximum reactions per message.
    pub message_reactions: usize,
    /// Maximum emoji per server.
    pub server_emoji: usize,
    /// Maximum roles per server.
    pub server_roles: usize,
    /// Maximum channels per server.
    pub server_channels: usize,

    /// Presumably how many hours an account is treated as "new".
    pub new_user_hours: usize,

    /// Request body size limit (unit not visible here).
    pub body_limit_size: usize,
}
309
/// One tier of per-user limits (e.g. `[features.limits.default]`).
///
/// NOTE(review): descriptions are inferred from field names; confirm against callers.
#[derive(Deserialize, Debug, Clone)]
pub struct FeaturesLimits {
    /// Maximum pending outgoing friend requests.
    pub outgoing_friend_requests: usize,

    /// Maximum bots owned.
    pub bots: usize,
    /// Maximum message length.
    pub message_length: usize,
    /// Maximum attachments per message.
    pub message_attachments: usize,
    /// Maximum servers joined.
    pub servers: usize,

    /// Upload size limits keyed by name (presumably by file tag; unit not visible here).
    pub file_upload_size_limit: HashMap<String, usize>,
}
321
/// All limit tiers (the `[features.limits]` table).
#[derive(Deserialize, Debug, Clone)]
pub struct FeaturesLimitsCollection {
    /// Global limits; a single set, not tiered.
    pub global: GlobalLimits,

    /// Limits for new users (presumably accounts younger than
    /// `global.new_user_hours`; confirm against callers).
    pub new_user: FeaturesLimits,
    /// Baseline tier (presumably used when no more specific tier applies).
    pub default: FeaturesLimits,

    /// Every remaining key under `[features.limits]`, each deserialized as a
    /// `FeaturesLimits` and collected here via `#[serde(flatten)]`. Presumably
    /// keyed by role name.
    #[serde(flatten)]
    pub roles: HashMap<String, FeaturesLimits>,
}
332
333#[derive(Deserialize, Debug, Clone)]
334pub struct FeaturesAdvanced {
335 #[serde(default)]
336 pub process_message_delay_limit: u16,
337}
338
339impl Default for FeaturesAdvanced {
340 fn default() -> Self {
341 Self {
342 process_message_delay_limit: 5,
343 }
344 }
345}
346
/// Feature flags and limits (the `[features]` table).
#[derive(Deserialize, Debug, Clone)]
pub struct Features {
    /// `[features.limits]`
    pub limits: FeaturesLimitsCollection,
    /// Whether webhooks are enabled.
    pub webhooks_enabled: bool,
    /// Whether mass mentions produce notifications.
    pub mass_mentions_send_notifications: bool,
    /// Whether mass mentions are enabled.
    pub mass_mentions_enabled: bool,

    /// `[features.advanced]`; uses `FeaturesAdvanced::default()` when the table
    /// is absent.
    #[serde(default)]
    pub advanced: FeaturesAdvanced,
}
357
/// Sentry DSNs per service (the `[sentry]` table).
///
/// The `configure!` macro passes the field named by its argument to
/// `setup_logging`, which skips Sentry initialisation when the DSN is empty.
#[derive(Deserialize, Debug, Clone)]
pub struct Sentry {
    pub api: String,
    pub events: String,
    pub files: String,
    pub proxy: String,
    pub pushd: String,
    pub crond: String,
}
367
/// Root of the deserialized configuration; obtain it via `config()`.
#[derive(Deserialize, Debug, Clone)]
pub struct Settings {
    pub database: Database,
    pub rabbit: Rabbit,
    pub hosts: Hosts,
    pub api: Api,
    pub pushd: Pushd,
    pub files: Files,
    pub features: Features,
    pub sentry: Sentry,
    /// Read from config, but `config()` forces it to `true` when `hosts.api`
    /// contains both `https` and `revolt.chat`.
    pub production: bool,
}
380
381impl Settings {
382 pub fn preflight_checks(&self) {
383 if self.api.smtp.host.is_empty() {
384 log::warn!("No SMTP settings specified! Remember to configure email.");
385 }
386
387 if self.api.security.captcha.hcaptcha_key.is_empty() {
388 log::warn!("No Captcha key specified! Remember to add hCaptcha key.");
389 }
390 }
391}
392
393pub async fn init() {
394 println!(
395 ":: Revolt Configuration ::\n\x1b[32m{:?}\x1b[0m",
396 config().await
397 );
398}
399
400pub async fn read() -> Config {
401 CONFIG_BUILDER.read().await.clone()
402}
403
404#[cached(time = 30)]
405pub async fn config() -> Settings {
406 let mut config = read().await.try_deserialize::<Settings>().unwrap();
407
408 if std::env::var("REDIS_URL").is_err() {
410 std::env::set_var("REDIS_URI", config.database.redis.clone());
411 }
412
413 if config.hosts.api.contains("https") && config.hosts.api.contains("revolt.chat") {
415 config.production = true;
416 }
417
418 config
419}
420
421pub async fn setup_logging(release: &'static str, dsn: String) -> Option<sentry::ClientInitGuard> {
423 if std::env::var("RUST_LOG").is_err() {
424 std::env::set_var("RUST_LOG", "info");
425 }
426
427 if std::env::var("ROCKET_ADDRESS").is_err() {
428 std::env::set_var("ROCKET_ADDRESS", "0.0.0.0");
429 }
430
431 pretty_env_logger::init();
432 log::info!("Starting {release}");
433
434 if dsn.is_empty() {
435 None
436 } else {
437 Some(sentry::init((
438 dsn,
439 sentry::ClientOptions {
440 release: Some(release.into()),
441 ..Default::default()
442 },
443 )))
444 }
445}
446
/// Load the configuration and initialise logging and Sentry for one service.
///
/// `$application` must name a field of [`Sentry`] (e.g. `api`, `events`). That
/// field's value is passed to `setup_logging` as the DSN. The release string is
/// `CARGO_PKG_NAME@CARGO_PKG_VERSION` of the crate invoking the macro, because
/// `env!` is expanded at the call site.
///
/// Must be invoked inside an `async` context, since it `.await`s. The Sentry guard
/// is bound to a local named `_sentry` rather than `_`. This keeps it alive until
/// the end of the enclosing scope instead of dropping it immediately.
#[macro_export]
macro_rules! configure {
    ($application: ident) => {
        let config = $crate::config().await;
        let _sentry = $crate::setup_logging(
            concat!(env!("CARGO_PKG_NAME"), "@", env!("CARGO_PKG_VERSION")),
            config.sentry.$application,
        )
        .await;
    };
}
458
#[cfg(feature = "test")]
#[cfg(test)]
mod tests {
    /// Smoke test: building, deserializing and printing the configuration must
    /// complete without panicking.
    #[async_std::test]
    async fn it_works() {
        crate::init().await;
    }
}