use std::net::{IpAddr, Ipv4Addr};
use figment::{Figment, Profile, Provider, Metadata, error::Result};
use figment::providers::{Serialized, Env, Toml, Format};
use figment::value::{Map, Dict, magic::RelativePathBuf};
use serde::{Deserialize, Serialize};
use yansi::{Paint, Style, Color::Primary};
use crate::log::PaintExt;
use crate::config::{LogLevel, Shutdown, Ident};
use crate::request::{self, Request, FromRequest};
use crate::http::uncased::Uncased;
use crate::data::Limits;
#[cfg(feature = "tls")]
use crate::config::TlsConfig;
#[cfg(feature = "secrets")]
use crate::config::SecretKey;
/// Server configuration values, extractable from any figment `Provider`.
///
/// Every field except `profile` and `__non_exhaustive` takes part in
/// (de)serialization. Default values are produced by
/// `Config::debug_default()` and `Config::release_default()`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct Config {
/// The profile this configuration belongs to.
///
/// Skipped by serde; `Config::try_from()` overwrites it with the profile
/// selected on the figment the configuration was extracted from.
#[serde(skip)]
pub profile: Profile,
/// IP address setting, reported at launch as "address".
/// Default: `127.0.0.1`.
pub address: IpAddr,
/// Port setting, reported at launch as "port". Default: `8000`.
pub port: u16,
/// Worker count, reported at launch as "workers".
/// Default: the value returned by `num_cpus::get()`.
pub workers: usize,
/// Limit reported at launch as "max blocking threads". Default: `512`.
pub max_blocking: usize,
/// Identifier setting, reported at launch as "ident".
/// Default: `Ident::default()`.
pub ident: Ident,
/// Header name reported at launch as "IP header"; `None` is reported as
/// "disabled". Default: `X-Real-IP`.
///
/// Deserialized through `crate::config::ip_header::deserialize`.
// NOTE(review): presumably the header a client's IP is read from —
// confirm against the request-handling code.
#[serde(deserialize_with = "crate::config::ip_header::deserialize")]
pub ip_header: Option<Uncased<'static>>,
/// Data limits, reported at launch as "limits".
/// Default: `Limits::default()`.
pub limits: Limits,
/// Temporary directory, reported at launch as "temp dir".
/// Default: the value returned by `std::env::temp_dir()`.
///
/// Serialized through `RelativePathBuf::serialize_relative`.
#[serde(serialize_with = "RelativePathBuf::serialize_relative")]
pub temp_dir: RelativePathBuf,
/// Keep-alive setting. Launch output prints the value with an `s` suffix
/// and prints `0` as "disabled". Default: `5`.
pub keep_alive: u32,
/// TLS configuration; only present when the `tls` feature is enabled.
/// Default: `None`.
#[cfg(feature = "tls")]
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
pub tls: Option<TlsConfig>,
/// Secret key; only present when the `secrets` feature is enabled.
/// Default: `SecretKey::zero()`.
///
/// Serialized through `SecretKey::serialize_zero`; the `Provider`
/// implementation for `Config` re-inserts the real key into the emitted
/// data when the key is not zero.
#[cfg(feature = "secrets")]
#[cfg_attr(nightly, doc(cfg(feature = "secrets")))]
#[serde(serialize_with = "SecretKey::serialize_zero")]
pub secret_key: SecretKey,
/// Shutdown configuration, reported at launch as "shutdown".
/// Default: `Shutdown::default()`.
pub shutdown: Shutdown,
/// Log level. Default: `LogLevel::Normal` in `debug_default()`,
/// `LogLevel::Critical` in `release_default()`.
pub log_level: LogLevel,
/// Setting reported at launch as "cli colors". Default: `true`.
///
/// Deserialized through `figment::util::bool_from_str_or_int`.
#[serde(deserialize_with = "figment::util::bool_from_str_or_int")]
pub cli_colors: bool,
/// Hidden unit field, skipped by serde.
// NOTE(review): judging by the name, this exists to discourage
// exhaustive struct literals outside the crate — confirm.
#[doc(hidden)]
#[serde(skip)]
pub __non_exhaustive: (),
}
/// Picks the default configuration from the compilation mode.
///
/// Builds with `debug_assertions` enabled use [`Config::debug_default()`];
/// all other builds use [`Config::release_default()`].
impl Default for Config {
    fn default() -> Config {
        if cfg!(debug_assertions) {
            Config::debug_default()
        } else {
            Config::release_default()
        }
    }
}
impl Config {
/// Deprecated configuration keys, each paired with its replacement key
/// (or `None` when there is no replacement).
///
/// `pretty_print()` emits a warning for every key in this table that the
/// figment has metadata for.
const DEPRECATED_KEYS: &'static [(&'static str, Option<&'static str>)] = &[
("env", Some(Self::PROFILE)), ("log", Some(Self::LOG_LEVEL)),
("read_timeout", None), ("write_timeout", None),
];
/// Deprecated profile name prefixes, each paired with the name of the
/// replacement profile (or `None` when there is no replacement).
///
/// `pretty_print()` matches these with `starts_with`, so any figment
/// profile whose name begins with one of the prefixes triggers a warning.
const DEPRECATED_PROFILES: &'static [(&'static str, Option<&'static str>)] = &[
("dev", Some("debug")), ("prod", Some("release")), ("stag", None)
];
pub fn debug_default() -> Config {
Config {
profile: Self::DEBUG_PROFILE,
address: Ipv4Addr::new(127, 0, 0, 1).into(),
port: 8000,
workers: num_cpus::get(),
max_blocking: 512,
ident: Ident::default(),
ip_header: Some(Uncased::from_borrowed("X-Real-IP")),
limits: Limits::default(),
temp_dir: std::env::temp_dir().into(),
keep_alive: 5,
#[cfg(feature = "tls")]
tls: None,
#[cfg(feature = "secrets")]
secret_key: SecretKey::zero(),
shutdown: Shutdown::default(),
log_level: LogLevel::Normal,
cli_colors: true,
__non_exhaustive: (),
}
}
/// Returns the default configuration for the release profile.
///
/// Identical to [`Config::debug_default()`] except that the profile is
/// `Config::RELEASE_PROFILE` and the log level is `LogLevel::Critical`.
pub fn release_default() -> Config {
    let mut config = Config::debug_default();
    config.profile = Self::RELEASE_PROFILE;
    config.log_level = LogLevel::Critical;
    config
}
/// Returns the default figment that configuration is extracted from.
///
/// Sources are layered in this order, later ones merged over earlier ones:
///
///   1. `Config::default()`
///   2. The TOML file named by the `ROCKET_CONFIG` environment variable,
///      or `Rocket.toml` when it is unset, read as a nested (per-profile)
///      source.
///   3. Environment variables prefixed with `ROCKET_`, except
///      `ROCKET_PROFILE`, applied globally.
///
/// The selected profile comes from `ROCKET_PROFILE`, falling back to
/// `Config::DEFAULT_PROFILE`.
pub fn figment() -> Figment {
    let defaults = Figment::from(Config::default());

    let config_file = Toml::file(Env::var_or("ROCKET_CONFIG", "Rocket.toml"));
    let with_file = defaults.merge(config_file.nested());

    let env_overrides = Env::prefixed("ROCKET_").ignore(&["PROFILE"]);
    let with_env = with_file.merge(env_overrides.global());

    with_env.select(Profile::from_env_or("ROCKET_PROFILE", Self::DEFAULT_PROFILE))
}
/// Attempts to extract a `Config` from `provider`.
///
/// On success, the returned configuration's `profile` is set to the
/// profile selected on the figment built from `provider`.
///
/// # Errors
///
/// Returns the figment error if extraction fails.
pub fn try_from<T: Provider>(provider: T) -> Result<Self> {
    let figment = Figment::from(provider);
    figment.extract::<Self>().map(|mut config| {
        // `profile` is `#[serde(skip)]`, so it must be filled in by hand.
        config.profile = figment.profile().clone();
        config
    })
}
pub fn from<T: Provider>(provider: T) -> Self {
Self::try_from(provider).unwrap_or_else(bail_with_config_error)
}
/// Returns `true` if TLS is enabled.
///
/// With the `tls` feature compiled in, this is the case when a TLS
/// configuration is present and its cipher list is not empty. Without the
/// feature, this is always `false`.
pub fn tls_enabled(&self) -> bool {
    #[cfg(feature = "tls")]
    let enabled = matches!(&self.tls, Some(tls) if !tls.ciphers.is_empty());

    #[cfg(not(feature = "tls"))]
    let enabled = false;

    enabled
}
/// Returns `true` if mutual TLS is enabled.
///
/// This requires [`Config::tls_enabled()`] to be `true` and, with the
/// `mtls` feature compiled in, the TLS configuration to carry a `mutual`
/// value. Without the `mtls` feature, this is always `false`.
pub fn mtls_enabled(&self) -> bool {
    #[cfg(feature = "mtls")]
    let mutual_configured = matches!(&self.tls, Some(tls) if tls.mutual.is_some());

    #[cfg(not(feature = "mtls"))]
    let mutual_configured = false;

    self.tls_enabled() && mutual_configured
}
/// Returns `true` if the configured secret key equals one of the keys
/// listed in `KNOWN_SECRET_KEYS`.
///
/// # Panics
///
/// Panics if an entry of `KNOWN_SECRET_KEYS` fails to deserialize into a
/// secret key; the entries are fixed at compile time, so that would be a
/// bug in the table.
#[cfg(feature = "secrets")]
pub(crate) fn known_secret_key_used(&self) -> bool {
    const KNOWN_SECRET_KEYS: &'static [&'static str] = &[
        "hPRYyVRiMyxpw5sBB1XeCMN1kFsDCqKvBi2QJxBVHQk="
    ];

    for &key_str in KNOWN_SECRET_KEYS {
        let value = figment::value::Value::from(key_str);
        if self.secret_key == value.deserialize().expect("known key is valid") {
            return true;
        }
    }

    false
}
/// Logs, at trace level, which source provided each configuration
/// parameter.
///
/// Does nothing unless `log_level` is `LogLevel::Debug`. Parameters in
/// `Config::PARAMETERS` for which `figment` has no metadata are skipped.
#[inline]
pub(crate) fn trace_print(&self, figment: &Figment) {
    if self.log_level == LogLevel::Debug {
        trace!("-- configuration trace information --");

        let located = Self::PARAMETERS.iter()
            .filter_map(|param| figment.find_metadata(param).map(|meta| (param, meta)));

        for (param, meta) in located {
            let (param, name) = (param.blue(), meta.name.primary());
            match meta.source {
                Some(ref source) => {
                    trace_!("{:?} parameter source: {} ({})", param, name, source);
                }
                None => {
                    trace_!("{:?} parameter source: {}", param, name);
                }
            }
        }
    }
}
/// Logs the active configuration as launch metadata.
///
/// The output consists of, in order:
///
///   1. The per-parameter source trace from `trace_print()`, which is only
///      emitted when the log level is `LogLevel::Debug`.
///   2. One line per configuration value.
///   3. A warning for every key in `DEPRECATED_KEYS` that `figment` has
///      metadata for, naming the replacement key when there is one.
///   4. A warning for every prefix in `DEPRECATED_PROFILES` matched by one
///      of `figment`'s profiles, naming the replacement when there is one.
///   5. With the `secrets` feature, the secret key line and a warning when
///      no key was provided.
pub(crate) fn pretty_print(&self, figment: &Figment) {
    // Style shared by every printed value.
    static VAL: Style = Primary.bold();

    self.trace_print(figment);
    launch_meta!("{}Configured for {}.", "🔧 ".emoji(), self.profile.underline());
    launch_meta_!("address: {}", self.address.paint(VAL));
    launch_meta_!("port: {}", self.port.paint(VAL));
    launch_meta_!("workers: {}", self.workers.paint(VAL));
    launch_meta_!("max blocking threads: {}", self.max_blocking.paint(VAL));
    launch_meta_!("ident: {}", self.ident.paint(VAL));

    match self.ip_header {
        Some(ref name) => launch_meta_!("IP header: {}", name.paint(VAL)),
        None => launch_meta_!("IP header: {}", "disabled".paint(VAL))
    }

    launch_meta_!("limits: {}", (&self.limits).paint(VAL));
    launch_meta_!("temp dir: {}", self.temp_dir.relative().display().paint(VAL));
    launch_meta_!("http/2: {}", (cfg!(feature = "http2").paint(VAL)));

    // A keep-alive of `0` is shown as disabled rather than as `0s`.
    match self.keep_alive {
        0 => launch_meta_!("keep-alive: {}", "disabled".paint(VAL)),
        ka => launch_meta_!("keep-alive: {}{}", ka.paint(VAL), "s".paint(VAL)),
    }

    match (self.tls_enabled(), self.mtls_enabled()) {
        (true, true) => launch_meta_!("tls: {}", "enabled w/mtls".paint(VAL)),
        (true, false) => launch_meta_!("tls: {} w/o mtls", "enabled".paint(VAL)),
        (false, _) => launch_meta_!("tls: {}", "disabled".paint(VAL)),
    }

    launch_meta_!("shutdown: {}", self.shutdown.paint(VAL));
    launch_meta_!("log level: {}", self.log_level.paint(VAL));
    launch_meta_!("cli colors: {}", self.cli_colors.paint(VAL));

    // Warn about deprecated configuration keys that were given a value.
    for (key, replacement) in Self::DEPRECATED_KEYS {
        if let Some(md) = figment.find_metadata(key) {
            warn!("found value for deprecated config key `{}`", key.paint(VAL));
            if let Some(ref source) = md.source {
                launch_meta_!("in {} {}", source.paint(VAL), md.name);
            }

            if let Some(new_key) = replacement {
                launch_meta_!("key has been replaced by `{}`", new_key.paint(VAL));
            } else {
                launch_meta_!("key has no special meaning");
            }
        }
    }

    // Warn about profiles whose names start with a deprecated prefix.
    for (prefix, replacement) in Self::DEPRECATED_PROFILES {
        if let Some(profile) = figment.profiles().find(|p| p.starts_with(prefix)) {
            warn!("found set deprecated profile `{}`", profile.paint(VAL));

            if let Some(new_profile) = replacement {
                launch_meta_!("profile was replaced by `{}`", new_profile.paint(VAL));
            } else {
                launch_meta_!("profile `{}` has no special meaning", profile);
            }
        }
    }

    #[cfg(feature = "secrets")] {
        launch_meta_!("secret key: {}", self.secret_key.paint(VAL));
        if !self.secret_key.is_provided() {
            warn!("secrets enabled without a stable `secret_key`");
            launch_meta_!("disable `secrets` feature or configure a `secret_key`");
            launch_meta_!("this becomes an {} in non-debug profiles", "error".red());
        }
    }
}
}
/// Profile constants.
impl Config {
/// The profile named `debug`; used by `Config::debug_default()`.
pub const DEBUG_PROFILE: Profile = Profile::const_new("debug");
/// The profile named `release`; used by `Config::release_default()`.
pub const RELEASE_PROFILE: Profile = Profile::const_new("release");
/// The profile selected when none is specified: `debug` when compiled
/// with `debug_assertions`.
#[cfg(debug_assertions)]
pub const DEFAULT_PROFILE: Profile = Self::DEBUG_PROFILE;
/// The profile selected when none is specified: `release` when compiled
/// without `debug_assertions`.
#[cfg(not(debug_assertions))]
pub const DEFAULT_PROFILE: Profile = Self::RELEASE_PROFILE;
}
/// String names of the configuration parameters.
impl Config {
/// The key `profile`. Private; referenced only as the replacement for the
/// deprecated `env` key in `DEPRECATED_KEYS`.
const PROFILE: &'static str = "profile";
/// The stringy parameter name for setting/extracting `Config::address`.
pub const ADDRESS: &'static str = "address";
/// The stringy parameter name for setting/extracting `Config::port`.
pub const PORT: &'static str = "port";
/// The stringy parameter name for setting/extracting `Config::workers`.
pub const WORKERS: &'static str = "workers";
/// The stringy parameter name for setting/extracting `Config::max_blocking`.
pub const MAX_BLOCKING: &'static str = "max_blocking";
/// The stringy parameter name for setting/extracting `Config::keep_alive`.
pub const KEEP_ALIVE: &'static str = "keep_alive";
/// The stringy parameter name for setting/extracting `Config::ident`.
pub const IDENT: &'static str = "ident";
/// The stringy parameter name for setting/extracting `Config::ip_header`.
pub const IP_HEADER: &'static str = "ip_header";
/// The stringy parameter name for setting/extracting `Config::limits`.
pub const LIMITS: &'static str = "limits";
/// The stringy parameter name for setting/extracting `Config::tls`.
/// Defined regardless of whether the `tls` feature is enabled.
pub const TLS: &'static str = "tls";
/// The stringy parameter name for setting/extracting `Config::secret_key`.
/// Defined regardless of whether the `secrets` feature is enabled.
pub const SECRET_KEY: &'static str = "secret_key";
/// The stringy parameter name for setting/extracting `Config::temp_dir`.
pub const TEMP_DIR: &'static str = "temp_dir";
/// The stringy parameter name for setting/extracting `Config::log_level`.
pub const LOG_LEVEL: &'static str = "log_level";
/// The stringy parameter name for setting/extracting `Config::shutdown`.
pub const SHUTDOWN: &'static str = "shutdown";
/// The stringy parameter name for setting/extracting `Config::cli_colors`.
pub const CLI_COLORS: &'static str = "cli_colors";
/// All public parameter names above. `trace_print()` iterates over this
/// list to report where each parameter's value came from.
pub const PARAMETERS: &'static [&'static str] = &[
Self::ADDRESS, Self::PORT, Self::WORKERS, Self::MAX_BLOCKING,
Self::KEEP_ALIVE, Self::IDENT, Self::IP_HEADER, Self::LIMITS, Self::TLS,
Self::SECRET_KEY, Self::TEMP_DIR, Self::LOG_LEVEL, Self::SHUTDOWN,
Self::CLI_COLORS,
];
}
/// Lets a `Config` value act as a figment configuration source.
impl Provider for Config {
    /// Names this provider `rocket::Config::default()` when `self` equals
    /// the default configuration and `rocket::Config` otherwise.
    #[track_caller]
    fn metadata(&self) -> Metadata {
        let name = match *self == Config::default() {
            true => "rocket::Config::default()",
            false => "rocket::Config",
        };

        Metadata::named(name)
    }

    /// Serializes `self` as figment data.
    ///
    /// With the `secrets` feature, a non-zero secret key is written into the
    /// default profile's dictionary explicitly, because the field itself is
    /// serialized through `SecretKey::serialize_zero`.
    #[track_caller]
    fn data(&self) -> Result<Map<Profile, Dict>> {
        #[allow(unused_mut)]
        let mut profiles: Map<Profile, Dict> = Serialized::defaults(self).data()?;

        #[cfg(feature = "secrets")]
        if let Some(dict) = profiles.get_mut(&Profile::Default) {
            if !self.secret_key.is_zero() {
                dict.insert("secret_key".into(), self.secret_key.key.master().into());
            }
        }

        Ok(profiles)
    }

    /// Reports this configuration's own profile as the one to select.
    fn profile(&self) -> Option<Profile> {
        let selected = self.profile.clone();
        Some(selected)
    }
}
/// Request guard that yields a reference to the configuration of the
/// `Rocket` instance handling the request. The guard always succeeds.
#[crate::async_trait]
impl<'r> FromRequest<'r> for &'r Config {
    type Error = std::convert::Infallible;

    async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
        let config: &'r Config = req.rocket().config();
        request::Outcome::Success(config)
    }
}
/// Prints `error` with `pretty_print_error()` and then panics.
///
/// The return type is generic so the function can stand in for a value of
/// any type, e.g. as the fallback passed to `unwrap_or_else`; it never
/// actually returns.
///
/// # Panics
///
/// Always panics with "aborting due to configuration error(s)".
#[doc(hidden)]
pub fn bail_with_config_error<T>(error: figment::Error) -> T {
pretty_print_error(error);
panic!("aborting due to configuration error(s)")
}
#[doc(hidden)]
pub fn pretty_print_error(error: figment::Error) {
use figment::error::{Kind, OneOf};
crate::log::init_default();
error!("Failed to extract valid configuration.");
for e in error {
fn w<T>(v: T) -> yansi::Painted<T> { Paint::new(v).primary() }
match e.kind {
Kind::Message(msg) => error_!("{}", msg),
Kind::InvalidType(v, exp) => {
error_!("invalid type: found {}, expected {}", w(v), w(exp));
}
Kind::InvalidValue(v, exp) => {
error_!("invalid value {}, expected {}", w(v), w(exp));
},
Kind::InvalidLength(v, exp) => {
error_!("invalid length {}, expected {}", w(v), w(exp))
},
Kind::UnknownVariant(v, exp) => {
error_!("unknown variant: found `{}`, expected `{}`", w(v), w(OneOf(exp)))
}
Kind::UnknownField(v, exp) => {
error_!("unknown field: found `{}`, expected `{}`", w(v), w(OneOf(exp)))
}
Kind::MissingField(v) => {
error_!("missing field `{}`", w(v))
}
Kind::DuplicateField(v) => {
error_!("duplicate field `{}`", w(v))
}
Kind::ISizeOutOfRange(v) => {
error_!("signed integer `{}` is out of range", w(v))
}
Kind::USizeOutOfRange(v) => {
error_!("unsigned integer `{}` is out of range", w(v))
}
Kind::Unsupported(v) => {
error_!("unsupported type `{}`", w(v))
}
Kind::UnsupportedKey(a, e) => {
error_!("unsupported type `{}` for key: must be `{}`", w(a), w(e))
}
}
if let (Some(ref profile), Some(ref md)) = (&e.profile, &e.metadata) {
if !e.path.is_empty() {
let key = md.interpolate(profile, &e.path);
info_!("for key {}", w(key));
}
}
if let Some(md) = e.metadata {
if let Some(source) = md.source {
info_!("in {} {}", w(source), md.name);
} else {
info_!("in {}", w(md.name));
}
}
}
}