1use std::collections::BTreeSet;
2use std::fs;
3use std::net::{IpAddr, Ipv4Addr};
4#[cfg(target_family = "unix")]
5use std::path::PathBuf;
6
7use anyhow::{Context, Result};
8use directories::ProjectDirs;
9use serde::de::{Error, Unexpected};
10use serde::{Deserialize, Deserializer, Serialize};
11
12mod default {
13 use super::*;
14
15 pub fn instance_timeout() -> Option<u32> {
16 Some(5 * 60)
18 }
19
20 pub fn gc_interval() -> u32 {
21 10
23 }
24
25 pub fn listen() -> Address {
26 Address::Tcp(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 27_631)
28 }
29
30 pub fn connect() -> Address {
31 listen()
32 }
33
34 pub fn log_filters() -> String {
35 "info".to_owned()
36 }
37
38 pub fn pass_environment() -> BTreeSet<String> {
39 BTreeSet::new()
40 }
41}
42
43mod de {
44 use super::*;
45
46 pub fn instance_timeout<'de, D>(deserializer: D) -> Result<Option<u32>, D::Error>
48 where
49 D: Deserializer<'de>,
50 {
51 #[derive(Deserialize)]
52 #[serde(untagged)]
53 enum OneOf {
54 Bool(bool),
55 U32(u32),
56 }
57
58 match OneOf::deserialize(deserializer) {
59 Ok(OneOf::U32(value)) => Ok(Some(value)),
60 Ok(OneOf::Bool(false)) => Ok(None),
61 Ok(OneOf::Bool(true)) => Err(Error::invalid_value(
62 Unexpected::Bool(true),
63 &"a non-negative integer or false",
64 )),
65 Err(_) => Err(Error::custom(
66 "invalid type: expected a non-negative integer or false",
67 )),
68 }
69 }
70
71 pub fn gc_interval<'de, D>(deserializer: D) -> Result<u32, D::Error>
73 where
74 D: Deserializer<'de>,
75 {
76 match u32::deserialize(deserializer)? {
77 0 => Err(Error::invalid_value(
78 Unexpected::Unsigned(0),
79 &"an integer 1 or greater",
80 )),
81 value => Ok(value),
82 }
83 }
84}
85
86#[derive(Serialize, Deserialize, Debug)]
87#[serde(untagged)]
88pub enum Address {
89 Tcp(IpAddr, u16),
90 #[cfg(target_family = "unix")]
91 Unix(PathBuf),
92}
93
94#[derive(Serialize, Deserialize, Debug)]
95#[serde(deny_unknown_fields)]
96pub struct Config {
97 #[serde(default = "default::instance_timeout")]
98 #[serde(deserialize_with = "de::instance_timeout")]
99 pub instance_timeout: Option<u32>,
100
101 #[serde(default = "default::gc_interval")]
102 #[serde(deserialize_with = "de::gc_interval")]
103 pub gc_interval: u32,
104
105 #[serde(default = "default::listen")]
106 pub listen: Address,
107
108 #[serde(default = "default::connect")]
109 pub connect: Address,
110
111 #[serde(default = "default::log_filters")]
112 pub log_filters: String,
113
114 #[serde(default = "default::pass_environment")]
115 pub pass_environment: BTreeSet<String>,
116}
117
/// Guards against drift between `Config::default()` and the checked-in `defaults.toml`.
///
/// It serializes the built-in defaults to TOML and compares the result, byte for byte, with the
/// file at the crate root.
#[cfg(test)]
#[test]
fn generate_default_and_check_it_matches_commited_defaults() {
    use std::fs;
    use std::path::Path;

    // Serialize first, so a serialization failure is reported before any file I/O.
    let rendered = toml::to_string(&Config::default()).expect("failed serialize");

    let defaults_file = Path::new(env!("CARGO_MANIFEST_DIR")).join("defaults.toml");
    let on_disk = fs::read_to_string(defaults_file).expect("failed reading defaults.toml file");

    assert_eq!(rendered, on_disk);
}
132
133impl Default for Config {
134 fn default() -> Self {
135 Config {
136 instance_timeout: default::instance_timeout(),
137 gc_interval: default::gc_interval(),
138 listen: default::listen(),
139 connect: default::connect(),
140 log_filters: default::log_filters(),
141 pass_environment: default::pass_environment(),
142 }
143 }
144}
145
146impl Config {
147 pub fn try_load() -> Result<Self> {
149 let pkg_name = env!("CARGO_PKG_NAME");
150 let config_path = ProjectDirs::from("", "", pkg_name)
151 .context("project config directory not found")?
152 .config_dir()
153 .join("config.toml");
154 let path = config_path.display();
155 let config_data =
156 fs::read(&config_path).with_context(|| format!("cannot read config file `{path}`"))?;
157 toml::from_slice(&config_data).with_context(|| format!("cannot parse config file `{path}`"))
158 }
159
160 pub fn init_logger(&self) {
165 use tracing_subscriber::prelude::*;
166 use tracing_subscriber::EnvFilter;
167
168 let format = tracing_subscriber::fmt::layer()
169 .without_time()
170 .with_target(false)
171 .with_writer(std::io::stderr);
172
173 let filter = EnvFilter::try_from_default_env()
174 .or_else(|_| EnvFilter::try_new(&self.log_filters))
175 .unwrap_or_else(|_| EnvFilter::new("info"));
176
177 tracing_subscriber::registry()
178 .with(filter)
179 .with(format)
180 .init();
181 }
182}