use crate::helpers::strip_bom;
use super::cacher::{BoxedCacher, Cacher, MemoryCache};
use super::configs::{Config, PartialConfig};
use super::error::ConfigError;
#[cfg(feature = "extends")]
use super::extender::ExtendsFrom;
use super::layer::Layer;
use super::source::{Source, SourceFormat};
use serde::Serialize;
use std::borrow::Cow;
use std::fs;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use tracing::{instrument, trace};
/// Result of successfully loading a configuration through
/// [`ConfigLoader::load`] or [`ConfigLoader::load_with_context`].
#[derive(Serialize)]
pub struct ConfigLoadResult<T: Config> {
    /// The final configuration, converted from the merged and finalized
    /// partial layers.
    pub config: T,
    /// Every layer that was merged to produce `config`, in merge order.
    /// Layers pulled in via `extends` appear before the layer that extended them.
    pub layers: Vec<Layer<T>>,
}
/// Loads a configuration of type `T` by parsing a list of sources (inline
/// code, files and, with the `url` feature, URLs) into partial layers and
/// merging them in order.
pub struct ConfigLoader<T: Config> {
    // Ties the loader to `T` without storing a value of it.
    _config: PhantomData<T>,
    // Cache used for URL sources. Wrapped in a `Mutex` because parsing only
    // has `&self` but needs to read/write the cache.
    cacher: Mutex<BoxedCacher>,
    // Registered formats; the first whose `should_parse` accepts a source parses it.
    formats: Vec<Arc<dyn SourceFormat<T::Partial>>>,
    // Optional help text attached to parser and validator errors.
    help: Option<String>,
    // Name used in trace logs and as the fallback error location.
    name: String,
    // Sources in the order they were added; merged in this order.
    sources: Vec<Source>,
    // Optional root stripped from file paths when reporting error locations.
    root: Option<PathBuf>,
}
impl<T: Config> Default for ConfigLoader<T> {
fn default() -> Self {
ConfigLoader {
_config: PhantomData,
cacher: Mutex::new(Box::<MemoryCache>::default()),
formats: vec![],
help: None,
name: T::schema_name().unwrap_or_else(|| "<unknown>".into()),
sources: vec![],
root: None,
}
}
}
impl<T: Config> ConfigLoader<T> {
    /// Creates a loader with every source format enabled through cargo
    /// features (`json`, `pkl`, `ron`, `toml`, `yaml`) already registered,
    /// in that order.
    pub fn new() -> Self {
        // `loader` is only mutated when at least one format feature is enabled.
        #[allow(unused_mut)]
        let mut loader = ConfigLoader::default();

        #[cfg(feature = "json")]
        loader.add_format(super::formats::JsonFormat::default());

        #[cfg(feature = "pkl")]
        loader.add_format(super::formats::PklFormat::default());

        #[cfg(feature = "ron")]
        loader.add_format(super::formats::RonFormat::default());

        #[cfg(feature = "toml")]
        loader.add_format(super::formats::TomlFormat::default());

        #[cfg(feature = "yaml")]
        loader.add_format(super::formats::YamlFormat::default());

        loader
    }

    /// Registers an additional source format.
    ///
    /// Formats are consulted in registration order; the first one whose
    /// `should_parse` accepts a source is used to parse it.
    pub fn add_format(&mut self, format: impl SourceFormat<T::Partial> + 'static) -> &mut Self {
        self.formats.push(Arc::new(format));
        self
    }

    /// Queues an inline code source.
    ///
    /// `path` is forwarded to `Source::code` (presumably used to select a
    /// format by extension; confirm against `Source`).
    ///
    /// # Errors
    ///
    /// Returns an error if `Source::code` rejects `code` or `path`.
    pub fn code<S: TryInto<String>, P: TryInto<PathBuf>>(
        &mut self,
        code: S,
        path: P,
    ) -> Result<&mut Self, ConfigError> {
        self.source(Source::code(code, path)?)
    }

    /// Queues a required file source. Loading fails with
    /// `ConfigError::MissingFile` if the file does not exist.
    ///
    /// # Errors
    ///
    /// Returns an error if `path` cannot be converted into a file source.
    pub fn file<P: TryInto<PathBuf>>(&mut self, path: P) -> Result<&mut Self, ConfigError> {
        self.source(Source::file(path, true)?)
    }

    /// Queues an optional file source. A missing file contributes an empty
    /// (default) partial layer instead of an error.
    ///
    /// # Errors
    ///
    /// Returns an error if `path` cannot be converted into a file source.
    pub fn file_optional<P: TryInto<PathBuf>>(
        &mut self,
        path: P,
    ) -> Result<&mut Self, ConfigError> {
        self.source(Source::file(path, false)?)
    }

    /// Appends an already-built source. Sources are parsed and merged in the
    /// order they are added.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` keeps the builder methods uniform.
    pub fn source(&mut self, source: Source) -> Result<&mut Self, ConfigError> {
        self.sources.push(source);
        Ok(self)
    }

    /// Queues a URL source. Only URLs accepted by `is_secure_url` are fetched,
    /// and responses go through the configured cacher.
    ///
    /// # Errors
    ///
    /// Returns an error if `url` cannot be converted into a URL source.
    #[cfg(feature = "url")]
    pub fn url<S: TryInto<String>>(&mut self, url: S) -> Result<&mut Self, ConfigError> {
        self.source(Source::url(url)?)
    }

    /// Loads the configuration using a default context.
    ///
    /// # Errors
    ///
    /// See [`ConfigLoader::load_with_context`].
    pub fn load(&self) -> Result<ConfigLoadResult<T>, ConfigError> {
        let context = <T::Partial as PartialConfig>::Context::default();
        self.load_with_context(&context)
    }

    /// Loads the configuration using the given context.
    ///
    /// All sources are parsed into layers, merged in order and finalized.
    /// With the `validate` feature the result is validated. It is then
    /// converted into `T`.
    ///
    /// # Errors
    ///
    /// Returns an error if a source cannot be read or parsed, if merging or
    /// finalizing fails, or (with `validate`) if validation fails.
    #[instrument(name = "load_config", skip_all)]
    pub fn load_with_context(
        &self,
        context: &<T::Partial as PartialConfig>::Context,
    ) -> Result<ConfigLoadResult<T>, ConfigError> {
        trace!(config = &self.name, "Loading configuration");

        let layers = self.parse_into_layers(&self.sources, context)?;
        let partial = self.merge_layers(&layers, context)?.finalize(context)?;

        #[cfg(feature = "validate")]
        {
            // Errors in the merged result are attributed to the last layer.
            partial.validate(context, true).map_err(|error| {
                self.map_validator_error(error, layers.last().map(|layer| &layer.source))
            })?;
        }

        Ok(ConfigLoadResult {
            config: T::from_partial(partial),
            layers,
        })
    }

    /// Parses and merges all sources into a partial configuration without
    /// finalizing it or running the final validation pass.
    ///
    /// # Errors
    ///
    /// Returns an error if a source cannot be read or parsed, or if merging fails.
    #[instrument(name = "load_partial_config", skip_all)]
    pub fn load_partial(
        &self,
        context: &<T::Partial as PartialConfig>::Context,
    ) -> Result<T::Partial, ConfigError> {
        trace!(config = &self.name, "Loading partial configuration");

        let layers = self.parse_into_layers(&self.sources, context)?;
        let partial = self.merge_layers(&layers, context)?;

        Ok(partial)
    }

    /// Replaces the cacher used for URL sources.
    pub fn set_cacher(&mut self, cacher: impl Cacher + 'static) -> &mut Self {
        self.cacher = Mutex::new(Box::new(cacher));
        self
    }

    /// Sets help text that is attached to parser and validator errors.
    pub fn set_help<H: AsRef<str>>(&mut self, help: H) -> &mut Self {
        self.help = Some(help.as_ref().to_owned());
        self
    }

    /// Sets a root directory that is stripped from file paths when reporting
    /// error locations.
    pub fn set_root<P: AsRef<Path>>(&mut self, root: P) -> &mut Self {
        self.root = Some(root.as_ref().to_path_buf());
        self
    }

    /// Resolves the `extends` value(s) of `parent_source` into sources and
    /// parses them into layers. Inline code sources are rejected.
    ///
    /// NOTE(review): this recurses through `parse_into_layers` with no record
    /// of visited sources, so a cyclic `extends` chain recurses without bound.
    #[cfg(feature = "extends")]
    #[instrument(skip_all)]
    fn extend_additional_layers(
        &self,
        context: &<T::Partial as PartialConfig>::Context,
        parent_source: &Source,
        extends_from: &ExtendsFrom,
    ) -> Result<Vec<Layer<T>>, ConfigError> {
        let mut sources = vec![];

        let mut extend_source = |value: &str| {
            // Relative values are resolved against the parent source.
            let source = Source::new(value, Some(parent_source))?;

            if matches!(source, Source::Code { .. }) {
                return Err(ConfigError::ExtendsFromNoCode);
            }

            trace!(
                config = &self.name,
                source = source.as_str(),
                "Extending additional source"
            );

            sources.push(source);

            Ok(())
        };

        match extends_from {
            ExtendsFrom::String(value) => {
                extend_source(value)?;
            }
            ExtendsFrom::List(values) => {
                for value in values.iter() {
                    extend_source(value)?;
                }
            }
        };

        self.parse_into_layers(&sources, context)
    }

    /// Human-readable location of `source` for error messages.
    ///
    /// For code sources this is the loader name. For files it is the path
    /// relative to the configured root, falling back to the loader name if
    /// the path is not valid UTF-8. For URLs it is the URL itself.
    fn get_location<'l>(&'l self, source: &'l Source) -> &'l str {
        match source {
            Source::Code { .. } => &self.name,
            Source::File { path, .. } => {
                let rel_path = if let Some(root) = &self.root {
                    path.strip_prefix(root).unwrap_or(path)
                } else {
                    path
                };

                rel_path.to_str().unwrap_or(&self.name)
            }
            #[cfg(feature = "url")]
            Source::Url { url, .. } => url,
        }
    }

    /// Merges `layers`, in order, on top of a default partial.
    #[instrument(skip_all)]
    fn merge_layers(
        &self,
        layers: &[Layer<T>],
        context: &<T::Partial as PartialConfig>::Context,
    ) -> Result<T::Partial, ConfigError> {
        trace!(
            config = &self.name,
            "Merging partial layers into a final result"
        );

        let mut merged = T::Partial::default();

        for layer in layers {
            merged.merge(context, layer.partial.clone())?;
        }

        Ok(merged)
    }

    /// Parses each source into a layer.
    ///
    /// With the `validate` feature each layer is validated on its own. With
    /// `extends`, the layers of any extended sources are inserted before the
    /// layer that extends them.
    #[instrument(skip_all)]
    fn parse_into_layers(
        &self,
        sources_to_parse: &[Source],
        // `context` is only read when the `validate` or `extends` feature is enabled.
        #[allow(unused_variables)] context: &<T::Partial as PartialConfig>::Context,
    ) -> Result<Vec<Layer<T>>, ConfigError> {
        let mut layers: Vec<Layer<T>> = vec![];

        for source in sources_to_parse {
            trace!(
                config = &self.name,
                source = source.as_str(),
                "Creating layer from source"
            );

            let partial: T::Partial = self
                .parse_source(source)
                .map_err(|error| self.map_parser_error(error, source))?;

            #[cfg(feature = "validate")]
            {
                partial
                    .validate(context, false)
                    .map_err(|error| self.map_validator_error(error, Some(source)))?;
            }

            #[cfg(feature = "extends")]
            if let Some(extends_from) = partial.extends_from() {
                layers.extend(self.extend_additional_layers(context, source, &extends_from)?);
            }

            layers.push(Layer {
                partial,
                source: source.clone(),
            });
        }

        Ok(layers)
    }

    /// Reads `source` (stripping a BOM) and parses it with the first matching format.
    ///
    /// A missing optional file yields a default partial without consulting any format.
    ///
    /// # Errors
    ///
    /// - `MissingFile` for a missing required file.
    /// - `ReadFileFailed` for other file read errors.
    /// - `HttpsOnly` / `ReadUrlFailed` for URL sources.
    /// - `NoMatchingFormat` when no registered format accepts the source.
    /// - Whatever the chosen format's `parse` returns.
    #[instrument(skip_all)]
    fn parse_source(&self, source: &Source) -> Result<T::Partial, ConfigError> {
        let (content, cache_path): (Cow<'_, str>, Option<PathBuf>) = match source {
            Source::Code { code, .. } => (Cow::Borrowed(strip_bom(code)), None),
            Source::File { path, required } => {
                // Read first instead of checking `exists()` beforehand, so a
                // file deleted between the check and the read is still treated
                // as missing rather than as a read failure. The `!exists()`
                // fallback keeps the previous behavior for paths whose
                // existence can't be determined.
                let content = match fs::read_to_string(path) {
                    Ok(content) => content,
                    Err(error)
                        if error.kind() == std::io::ErrorKind::NotFound || !path.exists() =>
                    {
                        if *required {
                            return Err(ConfigError::MissingFile(path.to_path_buf()));
                        }

                        return Ok(T::Partial::default());
                    }
                    Err(error) => {
                        return Err(ConfigError::ReadFileFailed {
                            path: path.to_path_buf(),
                            error: Box::new(error),
                        });
                    }
                };

                (Cow::Owned(strip_bom(&content).to_owned()), None)
            }
            #[cfg(feature = "url")]
            Source::Url { url } => {
                use crate::helpers::is_secure_url;

                if !is_secure_url(url) {
                    return Err(ConfigError::HttpsOnly(url.to_owned()));
                }

                // The cacher lock is held for the whole fetch, so concurrent
                // URL loads through this loader are serialized.
                let mut cacher = self.cacher.lock().unwrap();

                let handle_reqwest_error = |error: reqwest::Error| ConfigError::ReadUrlFailed {
                    url: url.to_owned(),
                    error: Box::new(error),
                };

                let content = if let Some(cache) = cacher.read(url)? {
                    cache
                } else {
                    let body = reqwest::blocking::get(url)
                        .map_err(handle_reqwest_error)?
                        .text()
                        .map_err(handle_reqwest_error)?;

                    cacher.write(url, &body)?;

                    body
                };

                (
                    Cow::Owned(strip_bom(&content).to_owned()),
                    cacher.get_file_path(url)?,
                )
            }
        };

        for format in &self.formats {
            if format.should_parse(source) {
                return format.parse(source, &content, cache_path.as_deref());
            }
        }

        Err(ConfigError::NoMatchingFormat {
            src: source.as_str().to_owned(),
            ext: source.get_file_ext().unwrap_or("(none)").into(),
        })
    }

    /// Rewrites a `Parser` error with this source's location and the
    /// loader's help text. Other errors pass through unchanged.
    fn map_parser_error(&self, outer: ConfigError, source: &Source) -> ConfigError {
        match outer {
            ConfigError::Parser { error, .. } => ConfigError::Parser {
                location: self.get_location(source).to_owned(),
                error,
                help: self.help.clone(),
            },
            _ => outer,
        }
    }

    /// Rewrites a `Validator` error with the source's location (or the loader
    /// name when there is no source) and the loader's help text. Other errors
    /// pass through unchanged.
    #[cfg(feature = "validate")]
    fn map_validator_error(&self, outer: ConfigError, source: Option<&Source>) -> ConfigError {
        match outer {
            ConfigError::Validator { error, .. } => ConfigError::Validator {
                location: source
                    .map(|src| self.get_location(src))
                    .unwrap_or(&self.name)
                    .to_owned(),
                error,
                help: self.help.clone(),
            },
            _ => outer,
        }
    }
}