use std::collections::HashMap;
use std::str::FromStr;
use std::net::SocketAddr;
use std::fmt;
use std::result::Result as StdResult;
use serde::de::{Error as DeError, Visitor, Deserialize, Deserializer};
use common::prelude::*;
use utils;
// Generates `impl Default` for a struct from a list of per-field initializers.
//
// Invocation shape: `default!(Name { field: expr, other: expr, });`
// Every `field: expr` pair, including the last one, must end with a comma.
macro_rules! default {
    ($name:ident { $( $field:ident: $init:expr, )* }) => {
        impl Default for $name {
            fn default() -> $name {
                $name { $( $field: $init, )* }
            }
        }
    };
}
// Declares a private, zero-argument function `name` that returns `value`.
//
// Invocation shape: `default_fn!(name: Type = value);`
// This file uses it for the functions named in `#[serde(default = "...")]`.
macro_rules! default_fn {
    ($fn_name:ident: $ret:ty = $body:expr) => {
        fn $fn_name() -> $ret {
            $body
        }
    };
}
/// Top-level configuration, made of one section per subsystem.
///
/// Every section is optional in the input. The container-level
/// `#[serde(default)]` takes any missing field from `Config::default()`.
/// Because that impl is derived, this is simply each field type's own
/// `Default`.
#[derive(Debug, Default, PartialEq, Eq, Deserialize)]
#[serde(default)]
pub struct Config {
    /// Settings of the HTTP server section.
    pub http: HttpConfig,
    /// Settings of the scripts section.
    pub scripts: ScriptsConfig,
    /// Settings of the jobs section.
    pub jobs: JobsConfig,
    /// Free-form string key/value pairs.
    /// NOTE(review): presumably extra environment variables; confirm with users of `env`.
    pub env: HashMap<String, String>,
}
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct HttpConfig {
#[serde(rename="behind-proxies", default="default_behind_proxies")]
pub behind_proxies: u8,
#[serde(default="default_bind")]
pub bind: SocketAddr,
#[serde(rename="rate-limit", default)]
pub rate_limit: RateLimitConfig,
#[serde(rename="health-endpoint", default="default_health_endpoint")]
pub health_endpoint: bool,
}
default_fn!(default_behind_proxies: u8 = 0);
default_fn!(default_bind: SocketAddr = "127.0.0.1:8000".parse().unwrap());
default_fn!(default_health_endpoint: bool = true);
default!(HttpConfig {
behind_proxies: default_behind_proxies(),
bind: default_bind(),
rate_limit: RateLimitConfig::default(),
health_endpoint: default_health_endpoint(),
});
#[derive(Debug, PartialEq, Eq)]
pub struct RateLimitConfig {
pub allowed: u64,
pub interval: utils::TimeString,
}
default!(RateLimitConfig {
allowed: 10,
interval: 60.into(),
});
impl RateLimitConfig {
    /// Parses a rate limit written as `<requests>` or `<requests>/<interval>`.
    ///
    /// When no interval is given, it falls back to `60.into()`, the same
    /// interval `RateLimitConfig::default()` uses.
    ///
    /// # Errors
    ///
    /// - Returns `ErrorKind::RateLimitConfigTooManySlashes` if the input
    ///   contains more than one `/`. This check happens before any parsing.
    /// - Otherwise, propagates the parse error of the request count, or
    ///   else of the interval, in that order.
    fn from_str_internal(s: &str) -> Result<RateLimitConfig> {
        // Walk the '/'-separated pieces lazily instead of collecting every
        // slash position into a Vec just to count them.
        let mut parts = s.split('/');
        // `split` always yields at least one piece (possibly ""), so the
        // fallback is never actually used.
        let requests = parts.next().unwrap_or("");

        match (parts.next(), parts.next()) {
            // No slash: only a request count was given.
            (None, _) => Ok(RateLimitConfig {
                allowed: requests.parse()?,
                interval: 60.into(),
            }),
            // Exactly one slash: request count and interval.
            (Some(interval), None) => Ok(RateLimitConfig {
                allowed: requests.parse()?,
                interval: interval.parse()?,
            }),
            // Two or more slashes.
            (Some(_), Some(_)) => Err(ErrorKind::RateLimitConfigTooManySlashes.into()),
        }
    }
}
impl FromStr for RateLimitConfig {
type Err = Error;
fn from_str(s: &str) -> Result<RateLimitConfig> {
Self::from_str_internal(s)
.chain_err(|| ErrorKind::RateLimitConfigError(s.into()))
}
}
/// Serde visitor used by `RateLimitConfig`'s `Deserialize` impl.
struct RateLimitConfigVisitor;

impl<'de> Visitor<'de> for RateLimitConfigVisitor {
    type Value = RateLimitConfig;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        // Describes only what this visitor actually accepts: an integer
        // request count, or a string handled by `RateLimitConfig::from_str`.
        formatter.write_str("a number of requests or a \"requests/interval\" string")
    }

    /// Parses `"<requests>"` or `"<requests>/<interval>"` through `FromStr`,
    /// converting the parse error into the deserializer's error type.
    fn visit_str<E: DeError>(self, s: &str) -> StdResult<RateLimitConfig, E> {
        s.parse::<RateLimitConfig>()
            .map_err(|e| E::custom(e.to_string()))
    }

    /// A bare signed integer is the request count, with a `60.into()` interval.
    ///
    /// Negative values are rejected. Previously `num as u64` silently wrapped
    /// them into enormous allowances.
    fn visit_i64<E: DeError>(self, num: i64) -> StdResult<RateLimitConfig, E> {
        if num < 0 {
            return Err(E::custom(format!(
                "the number of allowed requests can't be negative (got {})",
                num
            )));
        }
        // Lossless: `num` is non-negative here, so it always fits in a `u64`.
        self.visit_u64(num as u64)
    }

    /// A bare unsigned integer is the request count, with a `60.into()` interval.
    ///
    /// Some self-describing formats report non-negative integers through
    /// `visit_u64`. Serde's default `visit_u64` would reject them.
    fn visit_u64<E: DeError>(self, num: u64) -> StdResult<RateLimitConfig, E> {
        Ok(RateLimitConfig {
            allowed: num,
            interval: 60.into(),
        })
    }
}
impl<'de> Deserialize<'de> for RateLimitConfig {
    /// Hands the input to `RateLimitConfigVisitor` through `deserialize_any`.
    /// The type of the input data decides which visitor method runs.
    fn deserialize<D>(deserializer: D) -> StdResult<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor = RateLimitConfigVisitor;
        deserializer.deserialize_any(visitor)
    }
}
/// Settings of the jobs section.
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct JobsConfig {
    /// Thread count, defaulting to 1.
    /// Presumably the number of threads that run jobs (confirm with the job runner).
    #[serde(default = "default_threads")]
    pub threads: u16,
}

/// Fallback for `threads`, shared by serde and `JobsConfig::default()`.
fn default_threads() -> u16 {
    1
}

impl Default for JobsConfig {
    fn default() -> JobsConfig {
        JobsConfig { threads: default_threads() }
    }
}
/// Settings of the scripts section.
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct ScriptsConfig {
    /// Path string, defaulting to `"."`.
    /// Presumably the directory scripts are loaded from (confirm).
    #[serde(default = "default_path")]
    pub path: String,
    /// Recursion flag, defaulting to `false`.
    /// Presumably controls whether subdirectories of `path` are scanned (confirm).
    #[serde(default = "default_recursive")]
    pub recursive: bool,
}

// Fallbacks shared by serde's per-field defaults and `ScriptsConfig::default()`.
default_fn!(default_path: String = String::from("."));
default_fn!(default_recursive: bool = false);

default!(ScriptsConfig {
    path: default_path(),
    recursive: default_recursive(),
});