1use crate::Error;
2use crate::{duration::NonZeroDuration, hash::NoOpHasherDefault, Result};
3use config::{Config, Environment, File, FileFormat};
4use notify::{event::ModifyKind, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
5use parking_lot::RwLock;
6use serde::de::DeserializeOwned;
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use std::{
10 any::{Any, TypeId},
11 collections::HashMap,
12 fs,
13 ops::Deref,
14 path::{Path, PathBuf},
15 sync::Arc,
16 time::Duration,
17};
18use tracing::{error, info};
19
/// Crate version captured by Cargo at compile time; `None` when the
/// `CARGO_PKG_VERSION` environment variable was not set during the build.
pub const CARGO_PKG_VERSION: Option<&'static str> = option_env!("CARGO_PKG_VERSION");

/// Default for `Information::version`: the compile-time crate version, or an
/// empty string when it is unavailable.
fn default_version() -> String {
    match CARGO_PKG_VERSION {
        Some(version) => String::from(version),
        None => String::new(),
    }
}

/// Default for `Information::supported_nips` (listed in ascending order).
fn default_nips() -> Vec<u32> {
    const NIPS: [u32; 16] = [1, 2, 4, 9, 11, 12, 15, 16, 20, 22, 25, 26, 28, 33, 40, 70];
    NIPS.to_vec()
}
29
30#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
31#[serde(default)]
32pub struct Information {
33 pub name: String,
34 pub description: String,
35 pub pubkey: Option<String>,
36 pub contact: Option<String>,
37 pub software: String,
38 #[serde(skip_deserializing)]
39 pub version: String,
40 #[serde(skip_deserializing)]
41 pub supported_nips: Vec<u32>,
42}
43
44impl Default for Information {
45 fn default() -> Self {
46 Self {
47 name: Default::default(),
48 description: Default::default(),
49 pubkey: Default::default(),
50 contact: Default::default(),
51 software: Default::default(),
52 version: default_version(),
53 supported_nips: default_nips(),
54 }
55 }
56}
57
58#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
59#[serde(default)]
60pub struct Data {
61 pub path: PathBuf,
62
63 pub db_query_timeout: Option<NonZeroDuration>,
65}
66
67impl Default for Data {
68 fn default() -> Self {
69 Self {
70 path: PathBuf::from("./data"),
71 db_query_timeout: None,
72 }
73 }
74}
75
76#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Default)]
78#[serde(default)]
79pub struct Thread {
80 pub http: usize,
82 pub reader: usize,
84}
85
86#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
88#[serde(default)]
89pub struct Network {
90 pub host: String,
92 pub port: u16,
94 pub heartbeat_timeout: NonZeroDuration,
97
98 pub heartbeat_interval: NonZeroDuration,
101
102 pub real_ip_header: Option<String>,
103
104 pub index_redirect_to: Option<String>,
106}
107
108impl Default for Network {
109 fn default() -> Self {
110 Self {
111 host: "127.0.0.1".to_string(),
112 port: 8080,
113 heartbeat_interval: Duration::from_secs(60).try_into().unwrap(),
114 heartbeat_timeout: Duration::from_secs(120).try_into().unwrap(),
115 real_ip_header: None,
116 index_redirect_to: None,
117 }
118 }
119}
120
121#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
122#[serde(default)]
123pub struct Limitation {
124 pub max_message_length: usize,
126 pub max_subscriptions: usize,
128 pub max_filters: usize,
130 pub max_limit: u64,
132 pub max_subid_length: usize,
134 pub min_prefix: usize,
136 pub max_event_tags: usize,
138 pub max_event_time_older_than_now: u64,
140 pub max_event_time_newer_than_now: u64,
142}
143
144impl Default for Limitation {
145 fn default() -> Self {
146 Self {
147 max_message_length: 524288,
148 max_subscriptions: 20,
149 max_filters: 10,
150 max_limit: 300,
151 max_subid_length: 100,
152 min_prefix: 10,
153 max_event_tags: 5000,
154 max_event_time_older_than_now: 94608000,
155 max_event_time_newer_than_now: 900,
156 }
157 }
158}
159
160#[derive(Debug, Serialize, Deserialize, Default)]
161#[serde(default)]
162pub struct Setting {
163 pub information: Information,
164 pub data: Data,
165 pub thread: Thread,
166 pub network: Network,
167 pub limitation: Limitation,
168
169 #[serde(flatten)]
171 pub extra: HashMap<String, Value>,
172
173 #[serde(skip)]
175 extensions: HashMap<TypeId, Box<dyn Any + Send + Sync>, NoOpHasherDefault>,
176
177 #[serde(skip)]
179 ext_information: HashMap<String, Value>,
180
181 #[serde(skip)]
183 ext_limitation: HashMap<String, Value>,
184}
185
186impl PartialEq for Setting {
187 fn eq(&self, other: &Self) -> bool {
188 self.information == other.information
189 && self.data == other.data
190 && self.thread == other.thread
191 && self.network == other.network
192 && self.limitation == other.limitation
193 && self.extra == other.extra
194 }
195}
196
197#[derive(Debug, Clone)]
198pub struct SettingWrapper {
199 inner: Arc<RwLock<Setting>>,
200 watcher: Option<Arc<RecommendedWatcher>>,
201}
202
203impl Deref for SettingWrapper {
204 type Target = Arc<RwLock<Setting>>;
205 fn deref(&self) -> &Self::Target {
206 &self.inner
207 }
208}
209
210impl From<Setting> for SettingWrapper {
211 fn from(setting: Setting) -> Self {
212 Self {
213 inner: Arc::new(RwLock::new(setting)),
214 watcher: None,
215 }
216 }
217}
218
219impl SettingWrapper {
220 pub fn reload<P: AsRef<Path>>(&self, file: P, env_prefix: Option<String>) -> Result<()> {
222 let setting = Setting::read(&file, env_prefix)?;
223 {
224 let mut w = self.write();
225 *w = setting;
226 }
227 Ok(())
228 }
229
230 pub fn watch<P: AsRef<Path>, F: Fn(&SettingWrapper) + Send + 'static>(
232 file: P,
233 env_prefix: Option<String>,
234 f: F,
235 ) -> Result<Self> {
236 let mut setting: SettingWrapper = Setting::read(&file, env_prefix.clone())?.into();
237 let c_setting = setting.clone();
238
239 let file = fs::canonicalize(file.as_ref())?;
242 let c_file = file.clone();
243
244 let dir = file
249 .parent()
250 .ok_or_else(|| Error::Message("failed to get config dir".to_owned()))?;
251
252 let mut watcher = RecommendedWatcher::new(
253 move |result: Result<Event, notify::Error>| match result {
254 Ok(event) => {
255 #[cfg(target_os = "windows")]
256 let is_modify = matches!(event.kind, EventKind::Modify(ModifyKind::Any));
258 #[cfg(not(target_os = "windows"))]
259 let is_modify = matches!(event.kind, EventKind::Modify(ModifyKind::Data(_)));
260 if is_modify && event.paths.contains(&c_file) {
261 match c_setting.reload(&c_file, env_prefix.clone()) {
262 Ok(_) => {
263 info!("Reload config success {:?}", c_file);
264 info!("{:?}", c_setting.read());
265 f(&c_setting);
266 }
267 Err(e) => {
268 error!(
269 error = e.to_string(),
270 "failed to reload config {:?}", c_file
271 );
272 }
273 }
274 }
275 }
276 Err(e) => {
277 error!(error = e.to_string(), "failed to watch file {:?}", c_file);
278 }
279 },
280 notify::Config::default(),
281 )?;
282
283 watcher.watch(dir, RecursiveMode::NonRecursive)?;
284 setting.watcher = Some(Arc::new(watcher));
286
287 Ok(setting)
288 }
289}
290
291impl Setting {
292 pub fn add_nip(&mut self, nip: u32) {
294 if !self.information.supported_nips.contains(&nip) {
295 self.information.supported_nips.push(nip);
296 self.information.supported_nips.sort();
297 }
298 }
299
300 pub fn add_information(&mut self, key: String, value: Value) {
302 self.ext_information.insert(key, value);
303 }
304
305 pub fn add_limitation(&mut self, key: String, value: Value) {
307 self.ext_limitation.insert(key, value);
308 }
309
310 pub fn parse_extension<T: DeserializeOwned + Default>(&self, key: &str) -> T {
312 self.extra
313 .get(key)
314 .and_then(|v| {
315 let r = serde_json::from_value::<T>(v.clone());
316 if let Err(err) = &r {
317 error!(error = err.to_string(), "failed to parse {:?} setting", key);
318 }
319 r.ok()
320 })
321 .unwrap_or_default()
322 }
323
324 pub fn set_extension<T: Send + Sync + 'static>(&mut self, val: T) {
326 self.extensions.insert(TypeId::of::<T>(), Box::new(val));
327 }
328
329 pub fn get_extension<T: 'static>(&self) -> Option<&T> {
331 self.extensions
332 .get(&TypeId::of::<T>())
333 .and_then(|boxed| boxed.downcast_ref())
334 }
335
336 pub fn render_information(&self) -> Result<String> {
338 let info = &self.information;
339 let mut val = json!({
340 "name": info.name,
341 "description": info.description,
342 "pubkey": info.pubkey,
343 "contact": info.contact,
344 "software": info.software,
345 "version": info.version,
346 "supported_nips": info.supported_nips,
347 "limitation": &self.limitation,
348 });
349 self.ext_limitation.iter().for_each(|(k, v)| {
350 val["limitation"][k] = v.clone();
351 });
352 self.ext_information.iter().for_each(|(k, v)| {
353 val[k] = v.clone();
354 });
355 Ok(serde_json::to_string_pretty(&val)?)
356 }
357
358 pub fn read<P: AsRef<Path>>(file: P, env_prefix: Option<String>) -> Result<Self> {
360 let builder = Config::builder();
361 let mut config = builder
362 .add_source(File::with_name(file.as_ref().to_str().unwrap()));
367 if let Some(prefix) = env_prefix {
368 config = config.add_source(Self::env_source(&prefix));
369 }
370
371 let config = config.build()?;
372 let mut setting: Setting = config.try_deserialize()?;
373 setting.correct();
374 Ok(setting)
375 }
376
377 fn env_source(prefix: &str) -> Environment {
378 Environment::with_prefix(prefix)
379 .try_parsing(true)
380 .prefix_separator("_")
381 .separator("__")
382 }
385
386 pub fn from_env(env_prefix: String) -> Result<Self> {
388 let mut config = Config::builder();
389 config = config.add_source(Self::env_source(&env_prefix));
390 let config = config.build()?;
391 let mut setting: Setting = config.try_deserialize()?;
392 setting.correct();
393 Ok(setting)
394 }
395
396 pub fn from_str(s: &str, format: FileFormat) -> Result<Self> {
398 let builder = Config::builder();
399 let config = builder.add_source(File::from_str(s, format)).build()?;
400 let mut setting: Setting = config.try_deserialize()?;
401 setting.correct();
402 Ok(setting)
403 }
404
405 fn correct(&mut self) {
406 if self.network.heartbeat_timeout <= self.network.heartbeat_interval {
407 error!("network heartbeat_timeout must bigger than heartbeat_interval, use defaults");
408 self.network.heartbeat_interval = Duration::from_secs(60).try_into().unwrap();
409 self.network.heartbeat_timeout = Duration::from_secs(120).try_into().unwrap();
410 }
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use config::FileFormat;
    use std::{fs, thread::sleep, time::Duration};
    use tempfile::Builder;

    // Plain serde deserialization and the `config`-based `Setting::from_str`
    // must both yield the defaults patched with the same three overrides.
    #[test]
    fn der() -> Result<()> {
        let json = r#"{
            "network": {"port": 1},
            "information": {"name": "test"},
            "data": {},
            "thread": {"http": 1},
            "limitation": {}
        }"#;

        let mut def = Setting::default();
        def.network.port = 1;
        def.information.name = "test".to_owned();
        def.thread.http = 1;

        let s2 = serde_json::from_str::<Setting>(json)?;
        let s1: Setting = Setting::from_str(json, FileFormat::Json)?;

        assert_eq!(def, s1);
        assert_eq!(def, s2);

        Ok(())
    }

    // Runtime additions (a NIP, a limitation entry, an information entry) must
    // all show up in the rendered JSON document.
    #[test]
    fn render() -> Result<()> {
        let mut def = Setting::default();
        def.add_nip(1234567);
        def.add_limitation("payment_required".to_owned(), json!(true));
        def.add_information("payments_url".to_owned(), json!("https://payments"));
        let info = def.render_information()?;
        let val: Value = serde_json::from_str(&info)?;
        assert!(val["supported_nips"]
            .as_array()
            .unwrap()
            .contains(&Value::Number(serde_json::Number::from(1234567))));
        assert_eq!(val["payments_url"], json!("https://payments"));
        assert_eq!(val["limitation"]["payment_required"], json!(true));
        Ok(())
    }

    // Reads an empty file, then a populated one with env overrides layered on top.
    #[test]
    fn read() -> Result<()> {
        let setting = Setting::default();
        assert_eq!(setting.information.name, "");
        assert!(setting.information.supported_nips.contains(&1));

        // `rand_bytes(0)` gives the temp file a fixed, non-randomized name.
        let file = Builder::new()
            .prefix("nostr-relay-config-test-read")
            .suffix(".toml")
            .rand_bytes(0)
            .tempfile()?;

        // An empty TOML file must still produce the defaults.
        let setting = Setting::read(&file, None)?;
        assert_eq!(setting.information.name, "");
        assert!(setting.information.supported_nips.contains(&1));
        fs::write(
            &file,
            r#"
        [information]
        name = "nostr"
        [network]
        host = "127.0.0.1"
        "#,
        )?;

        // Overrides use mixed case and both `.` and `__` as the key separator;
        // the asserts below check that every spelling is picked up.
        temp_env::with_vars(
            [
                ("NOSTR_information.description", Some("test")),
                ("NOSTR_information__contact", Some("test")),
                ("NOSTR_INFORMATION__PUBKEY", Some("test")),
                ("NOSTR_NETWORK__PORT", Some("1")),
            ],
            || {
                let setting = Setting::read(&file, Some("NOSTR".to_owned())).unwrap();
                assert_eq!(setting.information.name, "nostr".to_string());
                assert_eq!(setting.information.description, "test".to_string());
                assert_eq!(setting.information.contact, Some("test".to_string()));
                assert_eq!(setting.information.pubkey, Some("test".to_string()));
                assert_eq!(setting.network.port, 1);
            },
        );
        Ok(())
    }

    // End-to-end hot reload: rewriting the watched file must update the shared setting.
    #[test]
    fn watch() -> Result<()> {
        let file = Builder::new()
            .prefix("nostr-relay-config-test-watch")
            .suffix(".toml")
            .tempfile()?;

        let setting = SettingWrapper::watch(&file, None, |_s| {})?;
        {
            let r = setting.read();
            assert_eq!(r.information.name, "");
            assert!(r.information.supported_nips.contains(&1));
        }

        fs::write(
            &file,
            r#"[information]
name = "nostr"
"#,
        )?;
        // Timing-dependent: gives the watcher thread time to see the modify
        // event and reload before the assertions run.
        sleep(Duration::from_secs(1));
        {
            let r = setting.read();
            assert_eq!(r.information.name, "nostr");
            assert!(r.information.supported_nips.contains(&1));
        }
        Ok(())
    }
}