sonya_meta/
config.rs

1use serde::de::{Error, MapAccess, SeqAccess, Visitor};
2use serde::{de, Deserialize, Deserializer, Serialize};
3use std::env::VarError;
4use std::fmt;
5use std::fmt::{Display, Formatter};
6use std::fs::File;
7use std::io::BufReader;
8use std::marker::PhantomData;
9use std::net::SocketAddr;
10use std::path::PathBuf;
11use std::str::FromStr;
12
13/// Extracts config from yaml, json or environment
14/// You can manage it with env var `CONFIG`
15///
16/// Example:
17/// ```sh
18/// CONFIG=ENV ADDR=0.0.0.0:8080 ./bin
19/// CONFIG=./config.yaml ./bin
20/// CONFIG=./config.json ./bin
21/// ```
22///
23/// Available envs when `CONFIG=ENV` was set:
24/// ```env
25/// ADDR=addr:port // service address which will be listened
26/// // Tls options
27/// TLS_PRIVATE_KEY=key.pem
28/// TLS_CERT=key.pem
29/// SECURE_SERVICE_TOKEN=xxx // Service token
30/// SECURE_JWT_EXPIRATION_TIME=60 // Jwt expiration time
31/// QUEUE_DEFAULT=test1;test // Default queues splits by ;, queue server only
32/// QUEUE_DB_PATH=/tmp/sonya // DB data path, queue server only
33/// QUEUE_MAX_KEY_UPDATES=10 // Maximum key version to store
34/// SERVICE_DISCOVERY_TYPE=API // Possible service discovery types is API, ETCD
35/// SERVICE_DISCOVERY_HOSTS=http://etcd_host:port;http://etcd_host2:port // Hosts splits by ;, required by ETCD type
36/// SERVICE_DISCOVERY_DEFAULT_SHARDS=http://queue:port;http://queue2:port // Hosts splits by ;, required by ETCD type
37/// SERVICE_DISCOVERY_PREFIX=sonya // Prefix for service discovery key
38/// SERVICE_DISCOVERY_INSTANCE_ADDR=http://queue:port // instance addr which will be registered in service discovery, required by server
39/// SERVICE_DISCOVERY_INSTANCE_id=123 // instance id which will be registered in service discovery
40/// WEBSOCKET_KEY=SGVsbG8sIHdvcmxkIQ== // Sec Web Socket header, proxy only
41/// WEBSOCKET_VERSION=13 // Web Socket version, proxy only
42/// GARBAGE_COLLECTOR_INTERVAL=60 // Time in seconds when proxy storage will be cleared, proxy only
43/// ```
44pub fn get_config() -> Config {
45    env_logger::init();
46    let config_path = std::env::var("CONFIG").unwrap_or_else(|e| match e {
47        VarError::NotPresent => String::from("ENV"),
48        e => panic!("{}", e),
49    });
50
51    match ConfigParsingStrategy::from_str(&config_path).unwrap() {
52        ConfigParsingStrategy::Env => from_env().unwrap(),
53        ConfigParsingStrategy::Yaml(r) => from_yaml(&r).unwrap(),
54        ConfigParsingStrategy::Json(r) => from_json(&r).unwrap(),
55    }
56}
57
58fn from_env() -> Result<Config, std::env::VarError> {
59    Ok(Config {
60        addr: from_env_optional("ADDR")?.map(|a| SocketAddr::from_str(&a).expect("invalid addr")),
61        tls: tls_from_env()?,
62        secure: secure_from_env()?,
63        queue: queue_from_env()?,
64        service_discovery: service_discovery_from_env()?,
65        websocket: websocket_from_env()?,
66        garbage_collector: garbage_collector_from_env()?,
67    })
68}
69
70fn tls_from_env() -> Result<Option<Tls>, std::env::VarError> {
71    let private_key = from_env_optional("TLS_PRIVATE_KEY")?;
72    let cert = from_env_optional("TLS_CERT")?;
73
74    Ok(private_key
75        .and_then(|p| Some((p, cert?)))
76        .map(|(p, c)| Tls {
77            private_key: p,
78            cert: c,
79        }))
80}
81
82fn secure_from_env() -> Result<Option<Secure>, std::env::VarError> {
83    let jwt_token_expiration = from_env_optional("SECURE_JWT_EXPIRATION_TIME")?
84        .map(|e| e.parse().expect("invalid jwt expiration time"))
85        .unwrap_or_else(default_jwt_token_expiration);
86    let service_token = from_env_optional("SECURE_SERVICE_TOKEN")?.map(|st| Secure {
87        service_token: st,
88        jwt_token_expiration,
89    });
90    Ok(service_token)
91}
92
93fn garbage_collector_from_env() -> Result<GarbageCollector, std::env::VarError> {
94    let gb = from_env_optional("GARBAGE_COLLECTOR_INTERVAL")?
95        .map(|interval| GarbageCollector {
96            interval: interval.parse().expect("invalid garbage interval"),
97        })
98        .unwrap_or_default();
99    Ok(gb)
100}
101
102fn websocket_from_env() -> Result<WebSocket, std::env::VarError> {
103    let mut websocket = WebSocket::default();
104    if let Some(key) = from_env_optional("WEBSOCKET_KEY")? {
105        websocket.key = key;
106    }
107    if let Some(version) = from_env_optional("WEBSOCKET_VERSION")? {
108        websocket.version = version;
109    }
110    Ok(websocket)
111}
112
113fn queue_from_env() -> Result<Queue, std::env::VarError> {
114    let default: DefaultQueues = from_env_optional("QUEUE_DEFAULT")?
115        .map(|d| {
116            d.split(';')
117                .filter(|s| !s.is_empty())
118                .map(String::from)
119                .collect()
120        })
121        .unwrap_or_default();
122
123    let db_path = from_env_optional("QUEUE_DB_PATH")?.map(PathBuf::from);
124    let max_key_updates = from_env_optional("QUEUE_MAX_KEY_UPDATES")?
125        .map(|mku| mku.parse().expect("invalid max keys updates value"));
126    Ok(Queue {
127        default,
128        db_path,
129        max_key_updates,
130    })
131}
132
133fn service_discovery_from_env() -> Result<Option<ServiceDiscovery>, std::env::VarError> {
134    let service_discovery_type =
135        from_env_optional("SERVICE_DISCOVERY_TYPE")?.unwrap_or_else(|| String::from("API"));
136    let default_shards: Option<Shards> = from_env_optional("SERVICE_DISCOVERY_DEFAULT_SHARDS")?
137        .map(|d| {
138            d.split(';')
139                .filter(|s| !s.is_empty())
140                .map(String::from)
141                .collect()
142        });
143
144    let service_discovery = match service_discovery_type.as_str() {
145        "API" => ServiceDiscovery::Api {
146            default: default_shards,
147        },
148        "ETCD" => ServiceDiscovery::Etcd {
149            default: default_shards,
150            hosts: std::env::var("SERVICE_DISCOVERY_HOSTS")
151                .expect("empty service discovery hosts")
152                .split(';')
153                .filter(|s| !s.is_empty())
154                .map(String::from)
155                .collect(),
156            prefix: from_env_optional("SERVICE_DISCOVERY_PREFIX")?
157                .unwrap_or_else(default_sd_prefix),
158            instance_opts: instance_opts_from_env()?,
159        },
160        _ => panic!("Invalid service discovery type"),
161    };
162
163    Ok(Some(service_discovery))
164}
165
166fn instance_opts_from_env() -> Result<Option<ServiceDiscoveryInstanceOptions>, std::env::VarError> {
167    let instance_addr = from_env_optional("SERVICE_DISCOVERY_INSTANCE_ADDR")?;
168    let instance_id = from_env_optional("SERVICE_DISCOVERY_INSTANCE_id")?;
169
170    Ok(instance_addr.map(|ia| ServiceDiscoveryInstanceOptions {
171        instance_addr: ia,
172        instance_id,
173    }))
174}
175
/// Looks up `env_var`, treating an unset variable as a missing value rather than an error.
///
/// Returns `Ok(Some(value))` when the variable is set, `Ok(None)` when it is not present,
/// and forwards any other `VarError` to the caller.
fn from_env_optional(env_var: &str) -> Result<Option<String>, std::env::VarError> {
    match std::env::var(env_var) {
        Ok(value) => Ok(Some(value)),
        Err(VarError::NotPresent) => Ok(None),
        Err(other) => Err(other),
    }
}
182
183fn from_yaml(path: &str) -> serde_yaml::Result<Config> {
184    let reader = match File::open(path) {
185        Ok(r) => BufReader::new(r),
186        Err(e) => return Err(serde_yaml::Error::custom(e)),
187    };
188    serde_yaml::from_reader(reader)
189}
190
191fn from_json(path: &str) -> serde_json::Result<Config> {
192    let reader = match File::open(path) {
193        Ok(r) => BufReader::new(r),
194        Err(e) => return Err(serde_json::Error::custom(e)),
195    };
196    serde_json::from_reader(reader)
197}
198
/// Source the configuration is loaded from, as selected by the `CONFIG` value.
///
/// `T` is the payload of the file-based variants; the `FromStr` implementation below
/// stores the file path there as a `String`.
enum ConfigParsingStrategy<T> {
    /// Read the configuration from environment variables.
    Env,
    /// Parse the referenced file as YAML.
    Yaml(T),
    /// Parse the referenced file as JSON.
    Json(T),
}

impl FromStr for ConfigParsingStrategy<String> {
    type Err = &'static str;

    /// Maps a `CONFIG` value to a strategy.
    ///
    /// The exact string `ENV` selects the environment; a value ending in `.yaml` or
    /// `.yml` selects YAML and one ending in `.json` selects JSON, keeping the whole
    /// value as the file path. Anything else is rejected.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "ENV" => Ok(Self::Env),
            // `.yml` is the other widespread YAML extension and is handled identically.
            s if s.ends_with(".yaml") || s.ends_with(".yml") => Ok(Self::Yaml(String::from(s))),
            s if s.ends_with(".json") => Ok(Self::Json(String::from(s))),
            _ => Err("invalid config type, expected ENV or a path ending in .yaml, .yml or .json"),
        }
    }
}
217
/// Complete service configuration, either assembled from environment variables or
/// deserialized from a YAML/JSON file.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Config {
    /// Socket address the service listens on (`ADDR` in the environment).
    pub addr: Option<SocketAddr>,
    /// TLS key/certificate pair; `None` when TLS is not configured.
    pub tls: Option<Tls>,
    /// Service token and JWT settings; `None` when no service token is configured.
    pub secure: Option<Secure>,
    /// Queue settings.
    pub queue: Queue,
    /// Service discovery settings.
    pub service_discovery: Option<ServiceDiscovery>,
    /// Websocket header settings; `WebSocket::default()` is used when the section is
    /// missing from a config file.
    #[serde(default)]
    pub websocket: WebSocket,
    /// Garbage collector settings; `GarbageCollector::default()` is used when the section
    /// is missing from a config file.
    #[serde(default)]
    pub garbage_collector: GarbageCollector,
}
230
/// TLS settings (`TLS_PRIVATE_KEY` and `TLS_CERT` in the environment).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Tls {
    /// Private key value, e.g. `key.pem` — presumably a file path; confirm against the consumer.
    pub private_key: String,
    /// Certificate value — presumably a file path; confirm against the consumer.
    pub cert: String,
}
236
/// Security settings.
///
/// `Deserialize` is not derived here: it is generated by `string_or_struct_impl!` further
/// down in this file, which accepts either a bare token string or a full map.
#[derive(Serialize, Clone, Debug)]
pub struct Secure {
    /// Service token (`SECURE_SERVICE_TOKEN` in the environment).
    pub service_token: SecureToken,
    /// JWT expiration time (`SECURE_JWT_EXPIRATION_TIME` in the environment).
    /// NOTE(review): the unit is not visible in this file — confirm against the consumer.
    #[serde(default = "default_jwt_token_expiration")]
    pub jwt_token_expiration: u64,
}
243
/// Field-for-field mirror of [`Secure`] used only for serde's remote derive: its generated
/// `deserialize` builds a `Secure` from a map and is called from the `visit_map` branch of
/// `string_or_struct_impl!`. Must be kept in sync with `Secure`.
#[derive(Deserialize)]
#[serde(remote = "Secure")]
struct SecureDef {
    /// Mirrors `Secure::service_token`.
    pub service_token: SecureToken,
    /// Mirrors `Secure::jwt_token_expiration`; filled from
    /// `default_jwt_token_expiration()` when the key is missing.
    #[serde(default = "default_jwt_token_expiration")]
    pub jwt_token_expiration: u64,
}
251
/// Returns the JWT expiration time used when none is configured.
pub fn default_jwt_token_expiration() -> u64 {
    const FALLBACK_EXPIRATION: u64 = 60;
    FALLBACK_EXPIRATION
}
255
/// Service token, stored as a plain string.
pub type SecureToken = String;
257
258impl From<SecureToken> for Secure {
259    fn from(service_token: SecureToken) -> Self {
260        Self {
261            service_token,
262            jwt_token_expiration: default_jwt_token_expiration(),
263        }
264    }
265}
266
/// Websocket header settings (proxy only, per the environment documentation in this file).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct WebSocket {
    /// Sec Web Socket header value (`WEBSOCKET_KEY` in the environment).
    pub key: String,
    /// Web Socket version (`WEBSOCKET_VERSION` in the environment); filled from
    /// `default_websocket_v()` when the key is missing from a config file.
    #[serde(default = "default_websocket_v")]
    pub version: String,
}
273
/// Returns the Web Socket version used when none is configured.
fn default_websocket_v() -> String {
    String::from("13")
}
277
278impl Default for WebSocket {
279    fn default() -> Self {
280        Self {
281            key: "SGVsbG8sIHdvcmxkIQ==".into(),
282            version: default_websocket_v(),
283        }
284    }
285}
286
/// Queue settings.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Queue {
    /// Default queue names (`QUEUE_DEFAULT` in the environment); empty when the key is
    /// missing from a config file.
    #[serde(default)]
    pub default: DefaultQueues,
    /// DB data path (`QUEUE_DB_PATH` in the environment).
    pub db_path: Option<PathBuf>,
    /// Maximum key versions to store (`QUEUE_MAX_KEY_UPDATES` in the environment).
    pub max_key_updates: Option<usize>,
}
294
/// List of default queue names.
pub type DefaultQueues = Vec<String>;
296
/// Garbage collector settings.
///
/// `Deserialize` is not derived here: it is generated by `u64_or_struct!` further down in
/// this file, which accepts either a bare unsigned integer or a full map.
#[derive(Serialize, Clone, Debug)]
pub struct GarbageCollector {
    /// Interval in seconds (`GARBAGE_COLLECTOR_INTERVAL` in the environment).
    pub interval: u64,
}
301
/// Field-for-field mirror of [`GarbageCollector`] used only for serde's remote derive: its
/// generated `deserialize` builds a `GarbageCollector` from a map and is called from the
/// `visit_map` branch of `u64_or_struct!`. Must be kept in sync with `GarbageCollector`.
#[derive(Deserialize)]
#[serde(remote = "GarbageCollector")]
struct GarbageCollectorDef {
    /// Mirrors `GarbageCollector::interval`.
    pub interval: u64,
}
307
308impl From<u64> for GarbageCollector {
309    fn from(interval: u64) -> Self {
310        Self { interval }
311    }
312}
313
314impl Default for GarbageCollector {
315    fn default() -> Self {
316        Self::from(60)
317    }
318}
319
/// List of shard addresses (e.g. `http://queue:port`).
pub type Shards = Vec<String>;
321
/// Service discovery settings, serialized as an internally tagged enum whose `type` key
/// holds the lowercase variant name (`api` or `etcd`).
///
/// `Deserialize` is not derived here: it is generated by `vec_or_struct_impl!` further
/// down in this file, which accepts a single string, a list of strings (both become
/// `Api` with those default shards) or a full map.
#[derive(Serialize, Clone, Debug)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ServiceDiscovery {
    /// API based discovery.
    Api {
        /// Optional default shards.
        default: Option<Shards>,
    },
    /// etcd based discovery.
    Etcd {
        /// Optional default shards.
        default: Option<Shards>,
        /// etcd hosts (`SERVICE_DISCOVERY_HOSTS` in the environment).
        hosts: ServiceDiscoveryHosts,
        /// Prefix for the service discovery key.
        #[serde(default = "default_sd_prefix")]
        prefix: String,
        /// Options of the instance to register, if any.
        instance_opts: Option<ServiceDiscoveryInstanceOptions>,
    },
}
336
/// Returns the service discovery key prefix used when none is configured.
fn default_sd_prefix() -> String {
    String::from("sonya")
}
340
/// Variant-for-variant mirror of [`ServiceDiscovery`] used only for serde's remote derive:
/// its generated `deserialize` builds a `ServiceDiscovery` from a map and is called from
/// the `visit_map` branch of `vec_or_struct_impl!`. Must be kept in sync with
/// `ServiceDiscovery`.
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
#[serde(remote = "ServiceDiscovery")]
enum ServiceDiscoveryDef {
    /// Mirrors `ServiceDiscovery::Api`.
    Api {
        default: Option<Shards>,
    },
    /// Mirrors `ServiceDiscovery::Etcd`.
    Etcd {
        default: Option<Shards>,
        hosts: ServiceDiscoveryHosts,
        // Filled from `default_sd_prefix()` when the key is missing.
        #[serde(default = "default_sd_prefix")]
        prefix: String,
        instance_opts: Option<ServiceDiscoveryInstanceOptions>,
    },
}
356
/// Options of the instance that will be registered in service discovery.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ServiceDiscoveryInstanceOptions {
    /// Optional instance id to register.
    pub instance_id: Option<String>,
    /// Instance address to register (e.g. `http://queue:port`).
    pub instance_addr: String,
}
362
363impl Display for ServiceDiscovery {
364    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
365        write!(
366            f,
367            "{}",
368            match self {
369                ServiceDiscovery::Api { .. } => "api",
370                ServiceDiscovery::Etcd { .. } => "etcd",
371            }
372        )
373    }
374}
375
376impl From<Shards> for ServiceDiscovery {
377    fn from(default: Shards) -> Self {
378        Self::Api {
379            default: Some(default),
380        }
381    }
382}
383
/// List of service discovery host addresses (e.g. `http://etcd_host:port`).
pub type ServiceDiscoveryHosts = Vec<String>;
385
// Zero-sized serde visitor markers. `T` is the type being deserialized; the actual
// `Visitor` implementations are generated per type by the macros below.

/// Visitor marker for types that deserialize from a string or a map.
struct StringOrStruct<T>(PhantomData<T>);
/// Visitor marker for types that deserialize from a string, a list of strings or a map.
struct VecOrStruct<T>(PhantomData<T>);
/// Visitor marker for types that deserialize from an unsigned integer or a map.
struct U64OrStruct<T>(PhantomData<T>);
389
/// Implements `Deserialize` for `$struct_name` so that it accepts either a bare string
/// or a map.
///
/// * `$struct_name` - target type; must implement `From<String>`, which handles the
///   string form.
/// * `$struct_name_remote` - `#[serde(remote = "...")]` mirror of the target whose
///   derived `deserialize` handles the map form.
///
/// The expansion uses `StringOrStruct`, `Visitor`, `MapAccess`, `Deserialize`,
/// `Deserializer`, `PhantomData`, `de` and `fmt` by their unqualified names, so all of
/// them must be in scope where the macro is invoked.
#[macro_export]
macro_rules! string_or_struct_impl {
    ($struct_name: ident, $struct_name_remote: ident) => {
        impl<'de> Visitor<'de> for StringOrStruct<$struct_name> {
            type Value = $struct_name;

            // serde prefixes this text with "expected ", so it must be a noun phrase and
            // must not repeat the word itself.
            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                write!(
                    formatter,
                    "a string or struct {}",
                    std::any::type_name::<Self::Value>()
                )
            }

            // String form: delegate to the target's `From<String>`.
            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(Self::Value::from(value.to_owned()))
            }

            // Map form: delegate to the remote-derived deserializer.
            fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
            where
                M: MapAccess<'de>,
            {
                $struct_name_remote::deserialize(de::value::MapAccessDeserializer::new(map))
            }
        }

        impl<'de> Deserialize<'de> for $struct_name {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: Deserializer<'de>,
            {
                // `deserialize_any` lets the input decide which `visit_*` method is called.
                deserializer.deserialize_any(StringOrStruct::<Self>(PhantomData))
            }
        }
    };
}
429
/// Implements `Deserialize` for `$struct_name` so that it accepts a single string, a list
/// of strings or a map.
///
/// * `$struct_name` - target type; must implement `From<Vec<String>>`, which handles the
///   string and list forms (a single string becomes a one-element list).
/// * `$struct_name_remote` - `#[serde(remote = "...")]` mirror of the target whose
///   derived `deserialize` handles the map form.
///
/// The expansion uses `VecOrStruct`, `Visitor`, `SeqAccess`, `MapAccess`, `Deserialize`,
/// `Deserializer`, `PhantomData`, `de` and `fmt` by their unqualified names, so all of
/// them must be in scope where the macro is invoked.
#[macro_export]
macro_rules! vec_or_struct_impl {
    ($struct_name: ident, $struct_name_remote: ident) => {
        impl<'de> Visitor<'de> for VecOrStruct<$struct_name> {
            type Value = $struct_name;

            // serde prefixes this text with "expected ", so it must be a noun phrase and
            // must not repeat the word itself. All three accepted forms are listed.
            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                write!(
                    formatter,
                    "a string, a list of strings or struct {}",
                    std::any::type_name::<Self::Value>()
                )
            }

            // String form: treated as a list with exactly one element.
            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(Self::Value::from(vec![value.to_owned()]))
            }

            // List form: every element must be a string.
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let mut vec = Vec::new();

                while let Some(elem) = seq.next_element::<String>()? {
                    vec.push(elem);
                }

                Ok(Self::Value::from(vec))
            }

            // Map form: delegate to the remote-derived deserializer.
            fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
            where
                M: MapAccess<'de>,
            {
                $struct_name_remote::deserialize(de::value::MapAccessDeserializer::new(map))
            }
        }

        impl<'de> Deserialize<'de> for $struct_name {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: Deserializer<'de>,
            {
                // `deserialize_any` lets the input decide which `visit_*` method is called.
                deserializer.deserialize_any(VecOrStruct::<Self>(PhantomData))
            }
        }
    };
}
482
/// Implements `Deserialize` for `$struct_name` so that it accepts either a bare unsigned
/// integer or a map.
///
/// * `$struct_name` - target type; must implement `From<u64>`, which handles the integer
///   form.
/// * `$struct_name_remote` - `#[serde(remote = "...")]` mirror of the target whose
///   derived `deserialize` handles the map form.
///
/// The expansion uses `U64OrStruct`, `Visitor`, `MapAccess`, `Deserialize`,
/// `Deserializer`, `PhantomData`, `de` and `fmt` by their unqualified names, so all of
/// them must be in scope where the macro is invoked.
#[macro_export]
macro_rules! u64_or_struct {
    ($struct_name: ident, $struct_name_remote: ident) => {
        impl<'de> Visitor<'de> for U64OrStruct<$struct_name> {
            type Value = $struct_name;

            // serde prefixes this text with "expected ", so it must be a noun phrase
            // describing what this visitor really accepts: an integer, not a string list.
            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                write!(
                    formatter,
                    "an unsigned integer or struct {}",
                    std::any::type_name::<Self::Value>()
                )
            }

            // Integer form: delegate to the target's `From<u64>`.
            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(Self::Value::from(value))
            }

            // Map form: delegate to the remote-derived deserializer.
            fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
            where
                M: MapAccess<'de>,
            {
                $struct_name_remote::deserialize(de::value::MapAccessDeserializer::new(map))
            }
        }

        impl<'de> Deserialize<'de> for $struct_name {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: Deserializer<'de>,
            {
                // `deserialize_any` lets the input decide which `visit_*` method is called.
                deserializer.deserialize_any(U64OrStruct::<Self>(PhantomData))
            }
        }
    };
}
522
// Generate the flexible `Deserialize` implementations; each target is paired with its
// `#[serde(remote = "...")]` mirror, which handles the map form.
// `Secure`: a bare token string or a map.
string_or_struct_impl!(Secure, SecureDef);
// `ServiceDiscovery`: a string, a list of strings (default shards) or a map.
vec_or_struct_impl!(ServiceDiscovery, ServiceDiscoveryDef);
// `GarbageCollector`: a bare interval or a map.
u64_or_struct!(GarbageCollector, GarbageCollectorDef);