// ra_multiplex/config.rs

1use std::collections::BTreeSet;
2use std::fs;
3use std::net::{IpAddr, Ipv4Addr};
4#[cfg(target_family = "unix")]
5use std::path::PathBuf;
6
7use anyhow::{Context, Result};
8use directories::ProjectDirs;
9use serde::de::{Error, Unexpected};
10use serde::{Deserialize, Deserializer, Serialize};
11
12mod default {
13    use super::*;
14
15    pub fn instance_timeout() -> Option<u32> {
16        // 5 minutes
17        Some(5 * 60)
18    }
19
20    pub fn gc_interval() -> u32 {
21        // 10 seconds
22        10
23    }
24
25    pub fn listen() -> Address {
26        // localhost & some random unprivileged port
27        Address::Tcp(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 27_631)
28    }
29
30    pub fn connect() -> Address {
31        listen()
32    }
33
34    pub fn log_filters() -> String {
35        "info".to_owned()
36    }
37
38    pub fn pass_environment() -> BTreeSet<String> {
39        BTreeSet::new()
40    }
41}
42
43mod de {
44    use super::*;
45
46    /// parse either bool(false) or u32
47    pub fn instance_timeout<'de, D>(deserializer: D) -> Result<Option<u32>, D::Error>
48    where
49        D: Deserializer<'de>,
50    {
51        #[derive(Deserialize)]
52        #[serde(untagged)]
53        enum OneOf {
54            Bool(bool),
55            U32(u32),
56        }
57
58        match OneOf::deserialize(deserializer) {
59            Ok(OneOf::U32(value)) => Ok(Some(value)),
60            Ok(OneOf::Bool(false)) => Ok(None),
61            Ok(OneOf::Bool(true)) => Err(Error::invalid_value(
62                Unexpected::Bool(true),
63                &"a non-negative integer or false",
64            )),
65            Err(_) => Err(Error::custom(
66                "invalid type: expected a non-negative integer or false",
67            )),
68        }
69    }
70
71    /// make sure the value is greater than 0 to giver users feedback on invalid configuration
72    pub fn gc_interval<'de, D>(deserializer: D) -> Result<u32, D::Error>
73    where
74        D: Deserializer<'de>,
75    {
76        match u32::deserialize(deserializer)? {
77            0 => Err(Error::invalid_value(
78                Unexpected::Unsigned(0),
79                &"an integer 1 or greater",
80            )),
81            value => Ok(value),
82        }
83    }
84}
85
86#[derive(Serialize, Deserialize, Debug)]
87#[serde(untagged)]
88pub enum Address {
89    Tcp(IpAddr, u16),
90    #[cfg(target_family = "unix")]
91    Unix(PathBuf),
92}
93
94#[derive(Serialize, Deserialize, Debug)]
95#[serde(deny_unknown_fields)]
96pub struct Config {
97    #[serde(default = "default::instance_timeout")]
98    #[serde(deserialize_with = "de::instance_timeout")]
99    pub instance_timeout: Option<u32>,
100
101    #[serde(default = "default::gc_interval")]
102    #[serde(deserialize_with = "de::gc_interval")]
103    pub gc_interval: u32,
104
105    #[serde(default = "default::listen")]
106    pub listen: Address,
107
108    #[serde(default = "default::connect")]
109    pub connect: Address,
110
111    #[serde(default = "default::log_filters")]
112    pub log_filters: String,
113
114    #[serde(default = "default::pass_environment")]
115    pub pass_environment: BTreeSet<String>,
116}
117
/// Guards against the committed `defaults.toml` drifting away from
/// `Config::default()`.
#[cfg(test)]
#[test]
fn generate_default_and_check_it_matches_commited_defaults() {
    use std::path::PathBuf;

    // Serialize the in-code defaults first...
    let generated = toml::to_string(&Config::default()).expect("failed serialize");

    // ...then load the file committed next to `Cargo.toml`.
    let mut defaults_file = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    defaults_file.push("defaults.toml");
    let committed =
        std::fs::read_to_string(defaults_file).expect("failed reading defaults.toml file");

    assert_eq!(generated, committed);
}
132
133impl Default for Config {
134    fn default() -> Self {
135        Config {
136            instance_timeout: default::instance_timeout(),
137            gc_interval: default::gc_interval(),
138            listen: default::listen(),
139            connect: default::connect(),
140            log_filters: default::log_filters(),
141            pass_environment: default::pass_environment(),
142        }
143    }
144}
145
146impl Config {
147    /// Try loading config file from the system default location
148    pub fn try_load() -> Result<Self> {
149        let pkg_name = env!("CARGO_PKG_NAME");
150        let config_path = ProjectDirs::from("", "", pkg_name)
151            .context("project config directory not found")?
152            .config_dir()
153            .join("config.toml");
154        let path = config_path.display();
155        let config_data =
156            fs::read(&config_path).with_context(|| format!("cannot read config file `{path}`"))?;
157        toml::from_slice(&config_data).with_context(|| format!("cannot parse config file `{path}`"))
158    }
159
160    /// Configure tracing-subscriber with env filter set to `log_filters` (if
161    /// not overriden by RUST_LOG env var)
162    ///
163    /// Panics if called multiple times.
164    pub fn init_logger(&self) {
165        use tracing_subscriber::prelude::*;
166        use tracing_subscriber::EnvFilter;
167
168        let format = tracing_subscriber::fmt::layer()
169            .without_time()
170            .with_target(false)
171            .with_writer(std::io::stderr);
172
173        let filter = EnvFilter::try_from_default_env()
174            .or_else(|_| EnvFilter::try_new(&self.log_filters))
175            .unwrap_or_else(|_| EnvFilter::new("info"));
176
177        tracing_subscriber::registry()
178            .with(filter)
179            .with(format)
180            .init();
181    }
182}