use std::fmt::Debug;
use std::io;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use clap::{Args as ClapArgs, Command, FromArgMatches, Parser};
use http::header::HeaderName;
use http::Method;
use serde::{Deserialize, Serialize};
use serde_with::with_prefix;
use tracing::instrument;
use tracing::subscriber::set_global_default;
use tracing_subscriber::fmt::{format, layer};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::{EnvFilter, Registry};
use crate::config::cors::{AllowType, CorsConfig, HeaderValue, TaggedAllowTypes};
use crate::config::parser::from_path;
use crate::config::FormattingStyle::{Compact, Full, Json, Pretty};
use crate::error::Error::{ArgParseError, TracingError};
use crate::error::Result;
use crate::resolver::Resolver;
use crate::tls::TlsServerConfig;
pub mod cors;
pub mod parser;
/// Long-form usage text; passed to clap as the `long_about` of the `Args` command.
pub const USAGE: &str = "To configure htsget-rs use a config file or environment variables. \
See the documentation of the htsget-config crate for more information.";
/// Default data server address, as an unparsed `host:port` string.
///
/// `DataServerConfig::default` parses this into a `SocketAddr`.
pub(crate) fn default_localstorage_addr() -> &'static str {
    const LOCALSTORAGE_ADDR: &str = "127.0.0.1:8081";
    LOCALSTORAGE_ADDR
}
/// Default ticket server address, as an unparsed `host:port` string.
///
/// `TicketServerConfig::default` parses this into a `SocketAddr`.
fn default_addr() -> &'static str {
    const TICKET_SERVER_ADDR: &str = "127.0.0.1:8080";
    TICKET_SERVER_ADDR
}
/// Default server origin, as a URL string.
///
/// Not referenced in this part of the file; presumably consumed by a child
/// module (e.g. the CORS config) — confirm before removing.
fn default_server_origin() -> &'static str {
    const SERVER_ORIGIN: &str = "http://localhost:8080";
    SERVER_ORIGIN
}
/// Default local path, as a string.
///
/// `DataServerConfig::default` converts this into its `local_path`.
pub(crate) fn default_path() -> &'static str {
    const LOCAL_PATH: &str = "./";
    LOCAL_PATH
}
// Command-line arguments accepted by the htsget-rs executables.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive turns `///`
// doc comments into help/about text, which could change the generated CLI help.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = USAGE)]
struct Args {
// Location of the config file. Can be given as `-c`/`--config` or through the
// `HTSGET_CONFIG` environment variable; `None` when neither is set.
#[arg(
short,
long,
env = "HTSGET_CONFIG",
help = "Set the location of the config file"
)]
config: Option<PathBuf>,
// When set, the default config is printed instead of a config path being
// returned (see `Config::parse_with_args`). `exclusive = true` tells clap this
// flag cannot be combined with any other argument.
#[arg(short, long, exclusive = true, help = "Print a default config file")]
print_default_config: bool,
}
/// Selects the event format of the tracing output installed by
/// `Config::setup_tracing`.
#[derive(Debug, Copy, Clone, Serialize, Deserialize, Default)]
pub enum FormattingStyle {
/// The fmt layer with its default event format. This is the default variant.
#[default]
Full,
/// The fmt layer with the `compact()` event format.
Compact,
/// The fmt layer with the `pretty()` event format.
Pretty,
/// The fmt layer with the `json()` event format.
Json,
}
// Generate the serde helper modules referenced by the `#[serde(flatten, with = "...")]`
// attributes below. Each one adds the given prefix to the field names of a
// flattened struct, e.g. `TicketServerConfig::addr` is read from `ticket_server_addr`.
with_prefix!(ticket_server_prefix "ticket_server_");
with_prefix!(data_server_prefix "data_server_");
with_prefix!(cors_prefix "cors_");
/// Top-level configuration.
///
/// Because of `#[serde(default)]`, any key missing from the input falls back to
/// the value in `Config::default()`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(default)]
pub struct Config {
/// Tracing output format used by `setup_tracing`.
formatting_style: FormattingStyle,
/// Ticket server settings, flattened into this level with a `ticket_server_`
/// prefix (e.g. `ticket_server_addr`).
#[serde(flatten, with = "ticket_server_prefix")]
ticket_server: TicketServerConfig,
/// Data server settings, flattened into this level with a `data_server_`
/// prefix (e.g. `data_server_addr`).
#[serde(flatten, with = "data_server_prefix")]
data_server: DataServerConfig,
/// Service info fields, flattened into this level without a prefix (e.g. `id`).
#[serde(flatten)]
service_info: ServiceInfo,
/// The configured resolvers; defaults to a single `Resolver::default()`.
resolvers: Vec<Resolver>,
}
/// Configuration for the ticket server.
///
/// Missing keys fall back to `TicketServerConfig::default()`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(default)]
pub struct TicketServerConfig {
/// Socket address of the ticket server (presumably the address it binds to — confirm).
addr: SocketAddr,
/// Optional TLS settings. Read when deserializing but never written when serializing.
#[serde(skip_serializing)]
tls: Option<TlsServerConfig>,
/// CORS settings, flattened into this struct with a `cors_` prefix.
#[serde(flatten, with = "cors_prefix")]
cors: CorsConfig,
}
impl TicketServerConfig {
pub fn new(addr: SocketAddr, tls: Option<TlsServerConfig>, cors: CorsConfig) -> Self {
Self { addr, tls, cors }
}
pub fn addr(&self) -> SocketAddr {
self.addr
}
pub fn tls(&self) -> Option<&TlsServerConfig> {
self.tls.as_ref()
}
pub fn into_tls(self) -> Option<TlsServerConfig> {
self.tls
}
pub fn cors(&self) -> &CorsConfig {
&self.cors
}
pub fn allow_credentials(&self) -> bool {
self.cors.allow_credentials()
}
pub fn allow_origins(&self) -> &AllowType<HeaderValue, TaggedAllowTypes> {
self.cors.allow_origins()
}
pub fn allow_headers(&self) -> &AllowType<HeaderName> {
self.cors.allow_headers()
}
pub fn allow_methods(&self) -> &AllowType<Method> {
self.cors.allow_methods()
}
pub fn max_age(&self) -> usize {
self.cors.max_age()
}
pub fn expose_headers(&self) -> &AllowType<HeaderName> {
self.cors.expose_headers()
}
}
/// Configuration for the data server.
///
/// Missing keys fall back to `DataServerConfig::default()`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(default)]
pub struct DataServerConfig {
/// Whether the data server is enabled; `true` by default.
enabled: bool,
/// Socket address of the data server (presumably the address it binds to — confirm).
addr: SocketAddr,
/// Local filesystem path (presumably the directory whose files are served — confirm).
local_path: PathBuf,
/// Path string exposed through `serve_at()`; the test in this file shows it
/// becoming the path prefix of a local storage resolver.
serve_at: String,
/// Optional TLS settings. Read when deserializing but never written when serializing.
#[serde(skip_serializing)]
tls: Option<TlsServerConfig>,
/// CORS settings, flattened into this struct with a `cors_` prefix.
#[serde(flatten, with = "cors_prefix")]
cors: CorsConfig,
}
impl DataServerConfig {
pub fn new(
enabled: bool,
addr: SocketAddr,
local_path: PathBuf,
serve_at: String,
tls: Option<TlsServerConfig>,
cors: CorsConfig,
) -> Self {
Self {
enabled,
addr,
local_path,
serve_at,
tls,
cors,
}
}
pub fn addr(&self) -> SocketAddr {
self.addr
}
pub fn local_path(&self) -> &Path {
&self.local_path
}
pub fn serve_at(&self) -> &str {
&self.serve_at
}
pub fn tls(&self) -> Option<&TlsServerConfig> {
self.tls.as_ref()
}
pub fn into_tls(self) -> Option<TlsServerConfig> {
self.tls
}
pub fn cors(&self) -> &CorsConfig {
&self.cors
}
pub fn allow_credentials(&self) -> bool {
self.cors.allow_credentials()
}
pub fn allow_origins(&self) -> &AllowType<HeaderValue, TaggedAllowTypes> {
self.cors.allow_origins()
}
pub fn allow_headers(&self) -> &AllowType<HeaderName> {
self.cors.allow_headers()
}
pub fn allow_methods(&self) -> &AllowType<Method> {
self.cors.allow_methods()
}
pub fn max_age(&self) -> usize {
self.cors.max_age()
}
pub fn expose_headers(&self) -> &AllowType<HeaderName> {
self.cors.expose_headers()
}
pub fn enabled(&self) -> bool {
self.enabled
}
}
impl Default for DataServerConfig {
fn default() -> Self {
Self {
enabled: true,
addr: default_localstorage_addr()
.parse()
.expect("expected valid address"),
local_path: default_path().into(),
serve_at: Default::default(),
tls: None,
cors: CorsConfig::default(),
}
}
}
/// Optional, free-form service metadata. Every field is an optional string that
/// defaults to `None`.
///
/// NOTE(review): presumably these populate the htsget service-info response —
/// confirm against the code that reads them.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct ServiceInfo {
/// Service id.
id: Option<String>,
/// Service name.
name: Option<String>,
/// Service version.
version: Option<String>,
/// Name of the organization.
organization_name: Option<String>,
/// URL of the organization.
organization_url: Option<String>,
/// Contact URL.
contact_url: Option<String>,
/// Documentation URL.
documentation_url: Option<String>,
/// Creation timestamp, kept as an unparsed string.
created_at: Option<String>,
/// Last-update timestamp, kept as an unparsed string.
updated_at: Option<String>,
/// Environment label.
environment: Option<String>,
}
impl ServiceInfo {
    /// The service id, if set.
    pub fn id(&self) -> Option<&str> {
        Some(self.id.as_ref()?.as_str())
    }

    /// The service name, if set.
    pub fn name(&self) -> Option<&str> {
        Some(self.name.as_ref()?.as_str())
    }

    /// The service version, if set.
    pub fn version(&self) -> Option<&str> {
        Some(self.version.as_ref()?.as_str())
    }

    /// The organization name, if set.
    pub fn organization_name(&self) -> Option<&str> {
        Some(self.organization_name.as_ref()?.as_str())
    }

    /// The organization URL, if set.
    pub fn organization_url(&self) -> Option<&str> {
        Some(self.organization_url.as_ref()?.as_str())
    }

    /// The contact URL, if set.
    pub fn contact_url(&self) -> Option<&str> {
        Some(self.contact_url.as_ref()?.as_str())
    }

    /// The documentation URL, if set.
    pub fn documentation_url(&self) -> Option<&str> {
        Some(self.documentation_url.as_ref()?.as_str())
    }

    /// The creation timestamp string, if set.
    pub fn created_at(&self) -> Option<&str> {
        Some(self.created_at.as_ref()?.as_str())
    }

    /// The last-update timestamp string, if set.
    pub fn updated_at(&self) -> Option<&str> {
        Some(self.updated_at.as_ref()?.as_str())
    }

    /// The environment label, if set.
    pub fn environment(&self) -> Option<&str> {
        Some(self.environment.as_ref()?.as_str())
    }
}
impl Default for TicketServerConfig {
fn default() -> Self {
Self {
addr: default_addr().parse().expect("expected valid address"),
tls: None,
cors: CorsConfig::default(),
}
}
}
impl Default for Config {
    /// `Full` formatting, default ticket/data server and service info settings,
    /// and a single default resolver.
    fn default() -> Self {
        Self::new(
            Full,
            TicketServerConfig::default(),
            DataServerConfig::default(),
            ServiceInfo::default(),
            vec![Resolver::default()],
        )
    }
}
impl Config {
    /// Assemble a config from its parts.
    pub fn new(
        formatting: FormattingStyle,
        ticket_server: TicketServerConfig,
        data_server: DataServerConfig,
        service_info: ServiceInfo,
        resolvers: Vec<Resolver>,
    ) -> Self {
        Config {
            formatting_style: formatting,
            ticket_server,
            data_server,
            service_info,
            resolvers,
        }
    }

    /// Parse the command line using a caller-supplied `Command`, augmented with
    /// this crate's arguments.
    ///
    /// Returns the config path to load, or `None` when the default config was
    /// printed instead.
    ///
    /// # Errors
    ///
    /// Returns `ArgParseError` when the matches cannot be converted into `Args`.
    pub fn parse_args_with_command(augment_args: Command) -> Result<Option<PathBuf>> {
        let matches = Args::augment_args(augment_args).get_matches();
        let args =
            Args::from_arg_matches(&matches).map_err(|err| ArgParseError(err.to_string()))?;

        Ok(Self::parse_with_args(args))
    }

    /// Parse the command line.
    ///
    /// Returns the config path to load, or `None` when the default config was
    /// printed instead.
    pub fn parse_args() -> Option<PathBuf> {
        let args = Args::parse();
        Self::parse_with_args(args)
    }

    /// Act on parsed arguments: either print the default config as TOML and
    /// return `None`, or return the requested config path (an empty path when
    /// none was given).
    fn parse_with_args(args: Args) -> Option<PathBuf> {
        if !args.print_default_config {
            // `PathBuf::default()` is the empty path.
            return Some(args.config.unwrap_or_default());
        }

        let default_config = toml::ser::to_string_pretty(&Config::default()).unwrap();
        println!("{}", default_config);
        None
    }

    /// Load a config from `path` and apply the data server settings to the
    /// resolvers (see `resolvers_from_data_server_config`).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the parser's `from_path`.
    #[instrument]
    pub fn from_path(path: &Path) -> io::Result<Self> {
        // Unqualified `from_path` is the free function from `config::parser`,
        // not this method.
        let loaded: Self = from_path(path)?;
        let resolved = loaded.resolvers_from_data_server_config();
        Ok(resolved)
    }

    /// Install a global tracing subscriber, formatted according to
    /// `formatting_style`.
    ///
    /// The filter is taken from the default environment variable and falls back
    /// to `info` when that cannot be used.
    ///
    /// # Errors
    ///
    /// Returns `TracingError` when the global default subscriber cannot be set.
    pub fn setup_tracing(&self) -> Result<()> {
        let env_filter = match EnvFilter::try_from_default_env() {
            Ok(filter) => filter,
            Err(_) => EnvFilter::new("info"),
        };
        let registry = Registry::default().with(env_filter);

        // Each arm builds a differently-typed subscriber, so the install call
        // cannot be hoisted out of the match.
        let installed = match self.formatting_style() {
            Full => set_global_default(registry.with(layer())),
            Compact => set_global_default(registry.with(layer().event_format(format().compact()))),
            Pretty => set_global_default(registry.with(layer().event_format(format().pretty()))),
            Json => set_global_default(registry.with(layer().event_format(format().json()))),
        };

        installed.map_err(|err| TracingError(err.to_string()))
    }

    /// The configured tracing output format.
    pub fn formatting_style(&self) -> FormattingStyle {
        self.formatting_style
    }

    /// Borrow the ticket server settings.
    pub fn ticket_server(&self) -> &TicketServerConfig {
        &self.ticket_server
    }

    /// Borrow the data server settings.
    pub fn data_server(&self) -> &DataServerConfig {
        &self.data_server
    }

    /// Consume the config and take ownership of the data server settings.
    pub fn into_data_server(self) -> DataServerConfig {
        let Self { data_server, .. } = self;
        data_server
    }

    /// Borrow the service info.
    pub fn service_info(&self) -> &ServiceInfo {
        &self.service_info
    }

    /// Borrow the resolvers.
    pub fn resolvers(&self) -> &[Resolver] {
        self.resolvers.as_slice()
    }

    /// Consume the config and take ownership of the resolvers.
    pub fn owned_resolvers(self) -> Vec<Resolver> {
        let Self { resolvers, .. } = self;
        resolvers
    }

    /// Pass the data server settings to every resolver via
    /// `Resolver::resolvers_from_data_server_config`, then return the updated
    /// config.
    pub fn resolvers_from_data_server_config(mut self) -> Self {
        // `resolvers` and `data_server` are disjoint fields, so the mutable and
        // shared borrows can coexist.
        for resolver in &mut self.resolvers {
            resolver.resolvers_from_data_server_config(&self.data_server);
        }
        self
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use std::fmt::Display;

    use crate::config::parser::from_str;
    use figment::Jail;
    use http::uri::Authority;

    use crate::storage::Storage;
    use crate::tls::tests::with_test_certificates;
    use crate::types::Scheme::Http;

    use super::*;

    /// Run `test_fn` against a config built from optional file `contents` and a
    /// set of environment variables, inside a `figment::Jail`.
    ///
    /// The config is loaded three ways and `test_fn` is called on each result:
    /// `Config::from_path`, the parser's `from_path`, and the parser's
    /// `from_str`. The latter two are followed by
    /// `resolvers_from_data_server_config`. Any load error fails the jail, which
    /// panics.
    fn test_config<K, V, F>(contents: Option<&str>, env_variables: Vec<(K, V)>, test_fn: F)
    where
        K: AsRef<str>,
        V: Display,
        F: Fn(Config),
    {
        Jail::expect_with(|jail| {
            let file = "test.toml";

            if let Some(contents) = contents {
                jail.create_file(file, contents)?;
            }

            for (key, value) in env_variables {
                jail.set_env(key, value);
            }

            let path = Path::new(file);

            test_fn(Config::from_path(path).map_err(|err| err.to_string())?);

            test_fn(
                from_path::<Config>(path)
                    .map_err(|err| err.to_string())?
                    .resolvers_from_data_server_config(),
            );

            test_fn(
                from_str::<Config>(contents.unwrap_or(""))
                    .map_err(|err| err.to_string())?
                    .resolvers_from_data_server_config(),
            );

            Ok(())
        });
    }

    /// Run `test_fn` against a config built only from environment variables.
    pub(crate) fn test_config_from_env<K, V, F>(env_variables: Vec<(K, V)>, test_fn: F)
    where
        K: AsRef<str>,
        V: Display,
        F: Fn(Config),
    {
        test_config(None, env_variables, test_fn);
    }

    /// Run `test_fn` against a config built only from TOML file `contents`.
    pub(crate) fn test_config_from_file<F>(contents: &str, test_fn: F)
    where
        F: Fn(Config),
    {
        test_config(Some(contents), Vec::<(&str, &str)>::new(), test_fn);
    }

    #[test]
    fn config_ticket_server_addr_env() {
        test_config_from_env(
            vec![("HTSGET_TICKET_SERVER_ADDR", "127.0.0.1:8082")],
            |config| {
                assert_eq!(
                    config.ticket_server().addr(),
                    "127.0.0.1:8082".parse().unwrap()
                );
            },
        );
    }

    // Renamed from `..._allow_origin_env`: the setting under test is `allow_credentials`.
    #[test]
    fn config_ticket_server_cors_allow_credentials_env() {
        test_config_from_env(
            vec![("HTSGET_TICKET_SERVER_CORS_ALLOW_CREDENTIALS", true)],
            |config| {
                assert!(config.ticket_server().allow_credentials());
            },
        );
    }

    #[test]
    fn config_service_info_id_env() {
        test_config_from_env(vec![("HTSGET_ID", "id")], |config| {
            assert_eq!(config.service_info().id(), Some("id"));
        });
    }

    #[test]
    fn config_data_server_addr_env() {
        test_config_from_env(
            vec![("HTSGET_DATA_SERVER_ADDR", "127.0.0.1:8082")],
            |config| {
                assert_eq!(
                    config.data_server().addr(),
                    "127.0.0.1:8082".parse().unwrap()
                );
            },
        );
    }

    // The data server is enabled by default, so only the non-default value
    // `false` proves the setting is actually read from the environment.
    #[test]
    fn config_no_data_server_env() {
        test_config_from_env(vec![("HTSGET_DATA_SERVER_ENABLED", "false")], |config| {
            assert!(!config.data_server().enabled());
        });
    }

    #[test]
    fn config_ticket_server_addr_file() {
        test_config_from_file(r#"ticket_server_addr = "127.0.0.1:8082""#, |config| {
            assert_eq!(
                config.ticket_server().addr(),
                "127.0.0.1:8082".parse().unwrap()
            );
        });
    }

    // Renamed from `..._allow_origin_file`: the setting under test is `allow_credentials`.
    #[test]
    fn config_ticket_server_cors_allow_credentials_file() {
        test_config_from_file(r#"ticket_server_cors_allow_credentials = true"#, |config| {
            assert!(config.ticket_server().allow_credentials());
        });
    }

    #[test]
    fn config_service_info_id_file() {
        test_config_from_file(r#"id = "id""#, |config| {
            assert_eq!(config.service_info().id(), Some("id"));
        });
    }

    #[test]
    fn config_data_server_addr_file() {
        test_config_from_file(r#"data_server_addr = "127.0.0.1:8082""#, |config| {
            assert_eq!(
                config.data_server().addr(),
                "127.0.0.1:8082".parse().unwrap()
            );
        });
    }

    // A TLS key without a certificate must be rejected.
    // NOTE(review): the panic presumably comes from the config failing to load
    // inside the jail, not from the inner assertion — confirm.
    #[test]
    #[should_panic]
    fn config_data_server_tls_no_cert() {
        with_test_certificates(|path, _, _| {
            let key_path = path.join("key.pem");

            test_config_from_file(
                &format!(
                    r#"
data_server_tls.key = "{}"
"#,
                    key_path.to_string_lossy().escape_default()
                ),
                |config| {
                    assert!(config.data_server().tls().is_none());
                },
            );
        });
    }

    #[test]
    fn config_data_server_tls() {
        with_test_certificates(|path, _, _| {
            let key_path = path.join("key.pem");
            let cert_path = path.join("cert.pem");

            test_config_from_file(
                &format!(
                    r#"
data_server_tls.key = "{}"
data_server_tls.cert = "{}"
"#,
                    key_path.to_string_lossy().escape_default(),
                    cert_path.to_string_lossy().escape_default()
                ),
                |config| {
                    assert!(config.data_server().tls().is_some());
                },
            );
        });
    }

    #[test]
    fn config_data_server_tls_env() {
        with_test_certificates(|path, _, _| {
            let key_path = path.join("key.pem");
            let cert_path = path.join("cert.pem");

            test_config_from_env(
                vec![
                    ("HTSGET_DATA_SERVER_TLS_KEY", key_path.to_string_lossy()),
                    ("HTSGET_DATA_SERVER_TLS_CERT", cert_path.to_string_lossy()),
                ],
                |config| {
                    assert!(config.data_server().tls().is_some());
                },
            );
        });
    }

    // A TLS key without a certificate must be rejected.
    // NOTE(review): the panic presumably comes from the config failing to load
    // inside the jail, not from the inner assertion — confirm.
    #[test]
    #[should_panic]
    fn config_ticket_server_tls_no_cert() {
        with_test_certificates(|path, _, _| {
            let key_path = path.join("key.pem");

            test_config_from_file(
                &format!(
                    r#"
ticket_server_tls.key = "{}"
"#,
                    key_path.to_string_lossy().escape_default()
                ),
                |config| {
                    assert!(config.ticket_server().tls().is_none());
                },
            );
        });
    }

    #[test]
    fn config_ticket_server_tls() {
        with_test_certificates(|path, _, _| {
            let key_path = path.join("key.pem");
            let cert_path = path.join("cert.pem");

            test_config_from_file(
                &format!(
                    r#"
ticket_server_tls.key = "{}"
ticket_server_tls.cert = "{}"
"#,
                    key_path.to_string_lossy().escape_default(),
                    cert_path.to_string_lossy().escape_default()
                ),
                |config| {
                    assert!(config.ticket_server().tls().is_some());
                },
            );
        });
    }

    #[test]
    fn config_ticket_server_tls_env() {
        with_test_certificates(|path, _, _| {
            let key_path = path.join("key.pem");
            let cert_path = path.join("cert.pem");

            test_config_from_env(
                vec![
                    ("HTSGET_TICKET_SERVER_TLS_KEY", key_path.to_string_lossy()),
                    ("HTSGET_TICKET_SERVER_TLS_CERT", cert_path.to_string_lossy()),
                ],
                |config| {
                    assert!(config.ticket_server().tls().is_some());
                },
            );
        });
    }

    // The data server is enabled by default, so only the non-default value
    // `false` proves the setting is actually read from the file.
    #[test]
    fn config_no_data_server_file() {
        test_config_from_file(r#"data_server_enabled = false"#, |config| {
            assert!(!config.data_server().enabled());
        });
    }

    // A `Local` resolver with no settings of its own picks up the data server's
    // local path, address and serve-at path.
    #[test]
    fn resolvers_from_data_server_config() {
        test_config_from_file(
            r#"
data_server_addr = "127.0.0.1:8080"
data_server_local_path = "path"
data_server_serve_at = "/path"
[[resolvers]]
storage = "Local"
"#,
            |config| {
                assert_eq!(config.resolvers.len(), 1);

                assert!(matches!(
                    config.resolvers.first().unwrap().storage(),
                    Storage::Local { local_storage }
                        if local_storage.local_path() == "path"
                            && local_storage.scheme() == Http
                            && local_storage.authority()
                                == &Authority::from_static("127.0.0.1:8080")
                            && local_storage.path_prefix() == "/path"
                ));
            },
        );
    }
}